In [None]:
#| default_exp core

# ConKernelClient source

> Concurrent-safe Jupyter KernelClient

## Imports

In [None]:
#| export
from queue import Empty
from asyncio import create_task
from jupyter_client import KernelClient, AsyncKernelClient
from jupyter_client.session import Session
from jupyter_client.channels import AsyncZMQSocketChannel
from zmq.error import ZMQError
from jupyter_client.kernelspec import KernelSpec
from jupyter_client import AsyncKernelManager
from traitlets import Type
import asyncio, zmq.asyncio

In [None]:
from fastcore.test import test_eq
from fastcore.utils import patch

## Setup

In [None]:
#| export
# Keep a handle on the unpatched method exactly once, so re-running this cell never wraps an already-wrapped `send`
if not hasattr(Session, '_orig_send'): Session._orig_send = Session.send

def _send(self, stream, msg_or_type, content=None, parent=None, ident=None,
          buffers=None, track=False, header=None, metadata=None):
    "Drop-in replacement for `Session.send`: delegate to the original, then sync with the stream's I/O thread if it has one."
    sent_msg = self._orig_send(
        stream, msg_or_type,
        content=content, parent=parent, ident=ident, buffers=buffers,
        track=track, header=header, metadata=metadata)
    if not stream or not hasattr(stream, 'io_thread'): return sent_msg
    # Reading `zmq.EVENTS` forces a sync so the send is fully registered internally. Without it, once the
    # caller's lock is released another thread could call `send()` and both would touch the internal state
    # before the I/O thread has caught up.
    stream.io_thread.socket.get(zmq.EVENTS)
    return sent_msg

Session.send = _send

In [None]:
#| export
class ConKernelClient(AsyncKernelClient):
    """`AsyncKernelClient` that supports many concurrent shell requests.

    `start_channels` launches one background task that owns the shell channel. It routes each reply to the
    queue registered (in `execute(..., reply=True)`) under the reply's parent `msg_id`, so concurrent waiters
    never consume each other's replies. Replies nobody registered for are dropped."""

    async def control(self, cmd, **kw):
        "Send a `cmd` request on the control channel with `kw` as its content; return the next control message's `content`."
        # Imported here because the exported module has no top-level fastcore import; `dict2obj` was otherwise undefined
        from fastcore.utils import dict2obj
        msg = self.session.msg(cmd, content=kw)
        self.control_channel.send(msg)
        # NOTE(review): takes the next control message without checking that its parent is `msg`
        return dict2obj(await self.get_control_msg(timeout=2))['content']

    # Subshell management requests, sent on the control channel
    async def create_subshell(self): return await self.control("create_subshell_request")
    async def list_subshells(self): return await self.control("list_subshell_request")
    async def delete_subshell(self, subsh_id:str): return await self.control("delete_subshell_request", subshell_id=subsh_id)

    async def start_channels(self, shell:bool=True, iopub:bool=True, stdin:bool=True, hb:bool=True, control:bool=True):
        "Start the channels, wait until the kernel is ready, then launch the shell-reply router task. Returns `self`."
        super().start_channels(shell=shell, iopub=iopub, stdin=stdin, hb=hb, control=control)
        # The router only starts after the kernel is ready
        await self.wait_for_ready()
        self._pending = {}  # msg_id -> asyncio.Queue(maxsize=1) awaiting that request's reply
        async def _reader():
            while True:
                try: reply = await self.get_shell_msg(timeout=None)
                except Exception as e:
                    # Fail every outstanding waiter. Iterate over a snapshot, since waiters remove their own
                    # entries. Use `put_nowait`: a queue that already holds its reply is full, and an awaited
                    # `put` would block this task forever.
                    for q in list(self._pending.values()):
                        try: q.put_nowait(e)
                        except asyncio.QueueFull: pass  # that waiter already has its reply
                    # NOTE(review): the router stops here; later reply=True requests will hit their timeout
                    break
                q = self._pending.get(reply["parent_header"].get("msg_id"))
                if q is None: continue  # unclaimed reply (e.g. from `execute(reply=False)`): drop it
                try: q.put_nowait(reply)
                except asyncio.QueueFull: pass  # duplicate reply for a not-yet-consumed request: keep the first
        self._shell_reader_task = asyncio.create_task(_reader())
        return self

    def stop_channels(self):
        "Stop all channels and cancel the shell-reply router task, if one is running."
        super().stop_channels()
        if (tk := getattr(self, '_shell_reader_task', None)):
            tk.cancel()
            self._shell_reader_task = None

    async def _async_recv_reply(self, msg_id, timeout=None, channel="shell"):
        """Wait for the reply to `msg_id`.

        Control-channel replies are read directly. Shell replies come from the queue registered for `msg_id`
        (a missing registration raises `KeyError`). Raises `TimeoutError` after `timeout` seconds, or re-raises
        the exception that stopped the router. The registration is always removed afterwards."""
        if channel == "control": return await self._async_get_control_msg(timeout=timeout)
        q = self._pending[msg_id]
        try:
            res = await asyncio.wait_for(q.get(), timeout)
            if isinstance(res, Exception): raise res
            return res
        except asyncio.TimeoutError as e: raise TimeoutError("Timeout waiting for reply") from e
        finally: self._pending.pop(msg_id, None)

    def execute(self, code, user_expressions=None, allow_stdin=None, reply=False, subsh_id=None,
                cts_typ='code', timeout=60, msg_id=None, **kw):
        """Send an `execute_request` on the shell channel.

        `code` is stored under content key `cts_typ`, and extra `kw` are merged into the content. A given
        `msg_id` replaces the generated one. A non-None `subsh_id` is set as the header's `subshell_id` (it is
        also copied into the content). With `reply=False`, returns the request's msg_id. With `reply=True`,
        returns a coroutine resolving to the shell reply, which raises `TimeoutError` after `timeout` seconds."""
        if user_expressions is None: user_expressions = {}
        if allow_stdin is None: allow_stdin = self.allow_stdin
        content = dict(user_expressions=user_expressions, allow_stdin=allow_stdin, subsh_id=subsh_id, **kw)
        content[cts_typ] = code
        msg = self.session.msg("execute_request", content)
        if msg_id is not None: msg["header"]["msg_id"] = msg_id
        if subsh_id is not None: msg["header"]["subshell_id"] = subsh_id
        msg_id = msg["header"]["msg_id"]
        # Register before sending, so the router can never see the reply before its queue exists
        if reply: self._pending[msg_id] = asyncio.Queue(maxsize=1)
        try: self.shell_channel.send(msg)
        except BaseException:
            # An unsent request will never be answered; don't leave its queue behind
            if reply: self._pending.pop(msg_id, None)
            raise
        if not reply: return msg_id
        return self._async_recv_reply(msg_id, timeout=timeout)

In [None]:
#| export
class ConKernelManager(AsyncKernelManager):
    "`AsyncKernelManager` whose `client()` builds `ConKernelClient` instances"
    client_class = ConKernelClient
    client_factory = Type(ConKernelClient)

In [None]:
# Start a kernel via `ConKernelManager`, so that `km.client()` builds `ConKernelClient`s
# NOTE(review): the session key is a fixed, hard-coded demo value; fine for a local test, don't reuse it elsewhere
km = ConKernelManager(session=Session(key=b'x'))
await km.start_kernel()
await km.is_alive()

True

In [None]:
# `start_channels` returns the client itself, so creating and starting it chain into one expression
kc = await km.client().start_channels()
await kc.is_alive()

True

In [None]:
# Fire-and-forget: with `reply=False`, `execute` only sends the request and returns its msg_id
mid = kc.execute('1+1', reply=False)
mid

'da1daba0-04dbd1f288df4b48f666fff6_65859_1'

In [None]:
@patch
async def get_pubs(self:KernelClient, timeout=0.2):
    "Retrieve all outstanding iopub messages, stopping once a `timeout`-second wait yields nothing"
    collected = []
    while True:
        try: msg = await self.get_iopub_msg(timeout=timeout)
        except Empty: break
        if not msg: break
        collected.append(msg)
    return collected

In [None]:
# Drain the iopub traffic produced by the request above and show each message's type and content
pubs = await kc.get_pubs()
[(o['msg_type'],o['content']) for o in pubs]

[('status', {'execution_state': 'busy'}),
 ('execute_input', {'code': '1+1', 'execution_count': 1}),
 ('execute_result',
  {'data': {'text/plain': '2'}, 'metadata': {}, 'execution_count': 1}),
 ('status', {'execution_state': 'idle'})]

In [None]:
# Each iopub message's `parent_header` identifies the request that caused it; its msg_id matches `mid`
pubs[0]['parent_header']

{'msg_id': 'da1daba0-04dbd1f288df4b48f666fff6_65859_1',
 'msg_type': 'execute_request',
 'username': 'jhoward',
 'session': 'da1daba0-04dbd1f288df4b48f666fff6',
 'date': datetime.datetime(2026, 2, 25, 23, 27, 50, 974152, tzinfo=tzutc()),
 'version': '5.4'}

In [None]:
# Close this client's channels; this also cancels its shell-reader task
kc.stop_channels()

In [None]:
# A fresh client for the reply-routing checks below
kc = await km.client().start_channels()

In [None]:
# With `reply=True`, `execute` returns a coroutine resolving to this request's shell reply (`TimeoutError` after `timeout` seconds)
r = await kc.execute('2+1', timeout=1, reply=True)
r

{'header': {'msg_id': 'ef2cb125-7690f6c7ec100946c4a9ac71_65865_21',
  'msg_type': 'execute_reply',
  'username': 'jhoward',
  'session': 'ef2cb125-7690f6c7ec100946c4a9ac71',
  'date': datetime.datetime(2026, 2, 25, 23, 27, 51, 502259, tzinfo=tzutc()),
  'version': '5.4'},
 'msg_id': 'ef2cb125-7690f6c7ec100946c4a9ac71_65865_21',
 'msg_type': 'execute_reply',
 'parent_header': {'msg_id': 'da1daba0-04dbd1f288df4b48f666fff6_65859_1',
  'msg_type': 'execute_request',
  'username': 'jhoward',
  'session': 'da1daba0-04dbd1f288df4b48f666fff6',
  'date': datetime.datetime(2026, 2, 25, 23, 27, 51, 499371, tzinfo=tzutc()),
  'version': '5.4'},
 'metadata': {'started': '2026-02-25T23:27:51.500099Z',
  'dependencies_met': True,
  'engine': 'd5073d74-29f2-4b60-a29c-25d4107c3680',
  'status': 'ok'},
 'content': {'status': 'ok',
  'execution_count': 2,
  'user_expressions': {},
  'payload': []},
 'buffers': []}

In [None]:
# Concurrency check: a reply nobody is waiting for must not derail later requests.
# Drain leftover iopub traffic, then send a fire-and-forget request whose reply the reader will drop.
# The pause presumably gives that orphan reply time to arrive first.
await kc.get_pubs()
kc.execute('print("orphan")')
await asyncio.sleep(0.3)
# Two concurrent `reply=True` requests. With `return_exceptions=True`, a timeout or other error comes back
# as an exception object rather than a dict, so the type checks below catch it.
slow, fast = await asyncio.gather(
    kc.execute('import time; time.sleep(0.3)', timeout=5, reply=True),
    kc.execute('1+1', timeout=5, reply=True),
    return_exceptions=True)
test_eq(type(slow), dict)
test_eq(type(fast), dict)

In [None]:
# Each call registers its reply queue, sends right away, and returns an un-awaited coroutine,
# so both requests are in flight before either reply is awaited
a = kc.execute('x=2', reply=True)
b = kc.execute('y=3', reply=True)

In [None]:
# Await both replies together, bounded by 2 seconds overall
r = await asyncio.wait_for(asyncio.gather(a,b), timeout=2)
test_eq(len(r), 2)
r[0]['parent_header']['msg_id']

'da1daba0-04dbd1f288df4b48f666fff6_65859_5'

In [None]:
if await km.is_alive():
    kc.stop_channels()
    await km.shutdown_kernel()

## export -

In [None]:
#| hide
# Regenerate the library from the notebooks' `#| export` cells (this notebook's cells go to `core`)
from nbdev import nbdev_export
nbdev_export()