# core

> Launch an IPython kernel on a remote machine over SSH and run notebook cells on it with the `%%remote` cell magic.

In [None]:
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import subprocess
import json
from jupyter_client import BlockingKernelClient
import sys
from dataclasses import dataclass
from IPython.core.magic import register_cell_magic
from queue import Empty

In [None]:
#| export
@dataclass(frozen=True)
class Client:
    """Immutable handle to a remote IPython kernel reached over SSH.

    Instances are built by `connect`. Because `kernel_ports` is a dict,
    ``hash(client)`` raises TypeError even though the dataclass is frozen.
    """

    # SSH login name, host name/address and SSH port of the remote machine.
    user: str
    host: str
    port: int
    # Port numbers keyed by "shell_port", "iopub_port", "stdin_port",
    # "control_port" and "heartbeat_port"; the same numbers are used locally
    # and remotely when forwarding through SSH.
    kernel_ports: dict
    # Jupyter client whose channels talk to the remote kernel.
    kernel: BlockingKernelClient
    # Interpreter used on the remote side to launch and clean up the kernel.
    python_path: str

In [None]:
#| export
# Live clients keyed by
# (user, host, port, shell_port, iopub_port, stdin_port, control_port, heartbeat_port).
# connect() returns an existing entry for a matching key; disconnect() removes it.
connected_clients = dict()

In [None]:
#| export
def remote_python(cmd:str, user:str, host:str, port:int, python_path="python3") -> subprocess.CompletedProcess:
    """Run the Python source `cmd` on `user@host` over SSH.

    The source is fed through stdin to `python_path` on the remote machine.
    The path is passed unquoted to the remote shell, so a leading ``~`` is
    expanded there.

    Returns the finished `subprocess.CompletedProcess` with stdout/stderr
    captured as text. A non-zero exit status is NOT raised; callers inspect
    `returncode` themselves.
    """
    # This cell must be exported (see the directive above): the exported
    # connect() and disconnect() call this function.
    return subprocess.run(
        ["ssh", "-p", str(port), f"{user}@{host}", python_path],
        input=cmd,
        capture_output=True,
        text=True,
    )

In [None]:
#| export
def _kernel_setup_script(
    python_path:str,
    shell_port:int,
    iopub_port:int,
    stdin_port:int,
    control_port:int,
    heartbeat_port:int,
    startup_timeout:float = 10,
) -> str:
    """Build the Python source that (re)starts an ipykernel on the remote host.

    When run remotely, the script:
    - terminates the kernel recorded in ~/ipykernel.pid, if any;
    - launches a detached `ipykernel_launcher` bound to the given ports and
      records its pid;
    - prints the kernel's JSON connection file to stdout.

    If the kernel exits, or does not produce a parseable connection file
    within `startup_timeout` seconds, the script prints the kernel's stderr
    and exits with status 1.
    """
    # python_path is embedded with !r so quotes or backslashes in it cannot
    # break the generated source.
    return f"""
import json
import os
import signal
import subprocess
import sys
import time

def safe(fn):
    try: fn()
    except OSError: pass

python_path = os.path.expanduser({python_path!r})
pid_path = os.path.expanduser("~/ipykernel.pid")
json_config_path = os.path.expanduser("~/ipykernel.json")
error_log_path = os.path.expanduser("~/ipykernel_error.log")

# Clear up old state if any
if os.path.exists(pid_path):
    with open(pid_path, encoding="utf-8") as f:
        pid = int(f.read())
        safe(lambda: os.kill(pid, signal.SIGTERM))
    safe(lambda: os.remove(pid_path))
safe(lambda: os.remove(json_config_path))
safe(lambda: os.remove(error_log_path))

with open(error_log_path, "w") as stderr_file:
    process = subprocess.Popen([
            python_path,
            "-m",
            "ipykernel_launcher",
            "--shell={shell_port}",
            "--iopub={iopub_port}",
            "--stdin={stdin_port}",
            "--control={control_port}",
            "--hb={heartbeat_port}",
            "-f="+json_config_path,
        ],
        stdin=subprocess.DEVNULL,
        stdout=subprocess.DEVNULL,
        stderr=stderr_file,
        start_new_session=True,
    )

# Record the pid immediately so the next run (or disconnect) can stop this
# kernel even if the startup checks below fail.
with open(pid_path, "w") as f:
    f.write(str(process.pid))

# Poll for a complete connection file instead of sleeping a fixed time that
# may be too short on a slow host.
config = None
deadline = time.monotonic() + {startup_timeout}
while config is None and process.poll() is None and time.monotonic() < deadline:
    time.sleep(0.1)
    try:
        with open(json_config_path, "r") as f:
            text = f.read()
        json.loads(text)
        config = text
    except (OSError, ValueError):
        pass

if config is None:
    if process.poll() is None:
        safe(lambda: os.kill(process.pid, signal.SIGTERM))
        print("Timed out waiting for the kernel connection file", file=sys.stderr)
    safe(lambda: os.remove(pid_path))
    with open(error_log_path, "r") as f:
        print(f.read(), file=sys.stderr, end="")
    sys.exit(1)

print(config, file=sys.stdout, end="")
"""


def _forward_kernel_ports(user:str, host:str, port:int, ports) -> None:
    """Forward each port in `ports` from localhost to the same port on `host`.

    Uses a backgrounded SSH master connection.

    Raises:
        RuntimeError: if ssh exits with a non-zero status.
    """
    # NOTE(achal): We use ssh control sockets for port forwarding.
    # Launching a background subprocess directly i.e. with a subprocess.Popen
    # runs into race condition with BlockingKernelClient.start_channels because
    # the port forwarding needs to be set up before we call this function.
    # `-f` forks into the background only after authentication. Only with
    # ExitOnForwardFailure=yes does ssh exit non-zero when a local forward
    # cannot be bound, so a zero exit status means every forward is in place.
    # During disconnect we will just use the socket we specified to kill the process.
    control_socket = "/tmp/ipykernel.sock"  # must match the path disconnect() uses
    destination = [f"{user}@{host}", "-p", str(port)]

    # A master left over from an earlier session (e.g. after a local kernel
    # restart) would still own the control socket and the local ports.
    # Ask it to exit first; this fails harmlessly when there is none.
    subprocess.run(
        ["ssh", "-S", control_socket, "-O", "exit", *destination],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )

    cmd = ["ssh", "-N", "-f", "-M", "-S", control_socket, "-o", "ExitOnForwardFailure=yes"]
    for p in ports:
        cmd += ["-L", f"{p}:localhost:{p}"]
    cmd += destination
    if subprocess.run(cmd).returncode != 0:
        raise RuntimeError(f"Failed to forward kernel ports {list(ports)} to {user}@{host}:{port}")


def connect(
    user:str,
    host:str,
    port:int,
    shell_port:int = 50001,
    iopub_port:int = 50002,
    stdin_port:int = 50003,
    control_port:int = 50004,
    heartbeat_port:int = 50005,
    python_path:str = "python3",
) -> Client:
    """Start an IPython kernel on `user@host` and return a client connected to it.

    The kernel is launched on the remote machine with `python_path` (which
    must provide `ipykernel`) and bound to the five given ports. The same
    port numbers are forwarded from localhost through an SSH master
    connection. A `%%remote` cell magic that runs cells on the new client is
    also registered.

    A second call with the same user/host/port/kernel ports returns the
    cached client without reconnecting.

    Only one tunnel can exist at a time, because the SSH control socket path
    is fixed. Launching a kernel also kills any kernel previously started by
    this module for the same remote account.

    Raises:
        RuntimeError: if the remote kernel fails to start, if port forwarding
            fails, or if the kernel is not ready within 10 seconds.
    """
    key = (user, host, port, shell_port, iopub_port, stdin_port, control_port, heartbeat_port)
    if key in connected_clients:
        # TODO(achal): Check integrity of the client
        return connected_clients[key]

    kernel_ports = {
        "shell_port": shell_port,
        "iopub_port": iopub_port,
        "stdin_port": stdin_port,
        "control_port": control_port,
        "heartbeat_port": heartbeat_port,
    }

    setup_script = _kernel_setup_script(python_path, **kernel_ports)
    result = remote_python(setup_script, user, host, port, python_path)
    if result.returncode != 0:
        raise RuntimeError(f"Failed to connect to the IPython kernel:\n{result.stderr}")
    # The setup script prints only the kernel's connection file to stdout.
    config_json = json.loads(result.stdout)

    # Forwarding must be up before start_channels() connects to localhost.
    _forward_kernel_ports(user, host, port, list(kernel_ports.values()))

    kernel = BlockingKernelClient()
    kernel.load_connection_info(config_json)
    kernel.start_channels()
    kernel.wait_for_ready(timeout=10)

    print(f"Connected to IPython kernel at {user}@{host}:{port}")

    client = Client(user=user, host=host, port=port, kernel_ports=kernel_ports, python_path=python_path, kernel=kernel)
    connected_clients[key] = client

    # TODO(achal): Check before re-registering if connect is called twice?
    @register_cell_magic
    def remote(line, cell):
        run_remote_blocking(client, cell)
    return client

In [None]:
#| export
def run_remote_blocking(client:Client, code:str, timeout:float = 5):
    """Execute `code` on the client's remote kernel and echo its output locally.

    Blocks until the kernel reports "idle" for this particular request.
    Output is mirrored locally as follows:
    - stream output goes to the matching local `sys.stdout` / `sys.stderr`;
    - error tracebacks go to `sys.stderr`;
    - the ``text/plain`` form of results and display data goes to stdout.

    Args:
        client: connected client whose kernel runs the code.
        code: source to execute remotely.
        timeout: seconds to wait for each IOPub message; None waits forever.

    Raises:
        queue.Empty: if no IOPub message arrives within `timeout` seconds.
    """
    msg_id = client.kernel.execute(code)
    while True:
        msg = client.kernel.get_iopub_msg(timeout=timeout)
        # IOPub is a broadcast channel. Skip messages that belong to other
        # requests (e.g. kernel_info traffic from wait_for_ready, or a previous
        # cell), otherwise a stale "idle" status ends this loop early and the
        # cell's output is lost.
        if msg["parent_header"].get("msg_id") != msg_id:
            continue
        msg_type = msg["msg_type"]
        content = msg["content"]
        if msg_type == "status":
            if content["execution_state"] == "idle":
                break
        elif msg_type == "stream":
            # Route to the local stream of the same name ("stdout"/"stderr").
            print(content["text"], file=getattr(sys, content["name"]), end='')
        elif msg_type == "error":
            print('\n'.join(content["traceback"]), file=sys.stderr)
        elif msg_type in ("execute_result", "display_data"):
            text = content["data"].get("text/plain")
            if text is not None:
                print(text)
        # Other message types (execute_input, clear_output, ...) are ignored.

In [None]:
def _debug_eat_pending_messages(client:Client):
    """Debug helper: drain the IOPub queue, printing each message type.

    Stops once no message arrives for one second, then reports it.
    """
    kernel = client.kernel
    while True:
        try:
            pending = kernel.get_iopub_msg(timeout=1)
        except Empty:
            break
        print(pending['msg_type'])
    print("Queue cleared!")

In [None]:
#| export
def disconnect(client:Client):
    """Shut down `client`'s remote kernel and SSH tunnel, and drop it from the cache.

    Steps:
    1. Stop the local kernel channels.
    2. Ask the SSH master on /tmp/ipykernel.sock to exit, which closes its
       port forwards.
    3. On the remote host, kill the kernel recorded in ~/ipykernel.pid and
       delete the kernel's state files.
    4. Remove the client from `connected_clients`.

    Remote cleanup failures are printed, not raised.
    """
    client.kernel.stop_channels()
    subprocess.run(["ssh", "-S", "/tmp/ipykernel.sock", "-O", "exit", f"{client.user}@{client.host}", "-p", str(client.port)])
    cleanup_script = """
import os
import signal

def safe(fn):
    try: fn()
    except OSError: pass

pid_path = os.path.expanduser("~/ipykernel.pid")
json_config_path = os.path.expanduser("~/ipykernel.json")
error_log_path = os.path.expanduser("~/ipykernel_error.log")

if os.path.exists(pid_path):
    with open(pid_path, encoding="utf-8") as f:
        pid = int(f.read())
        safe(lambda: os.kill(pid, signal.SIGTERM))
    safe(lambda: os.remove(pid_path))
safe(lambda: os.remove(json_config_path))
safe(lambda: os.remove(error_log_path))
"""
    result = remote_python(cleanup_script, client.user, client.host, client.port, python_path=client.python_path)
    if result.returncode != 0:
        print(result.stderr)
    elif result.stdout:
        # The cleanup script prints nothing itself; only echo unexpected output.
        print(result.stdout)

    ports = client.kernel_ports
    key = (
        client.user,
        client.host,
        client.port,
        ports["shell_port"],
        ports["iopub_port"],
        ports["stdin_port"],
        ports["control_port"],
        ports["heartbeat_port"],
    )
    # pop() rather than del: a client missing from the cache (e.g. one that
    # was already disconnected) now only triggers the warning, not a KeyError.
    if connected_clients.pop(key, None) is None:
        print("Warning: client was not in connected clients")

In [None]:
# Launch (or reuse) the remote kernel; this also registers the %%remote cell magic.
client: Client = connect(
    "achal",
    "91.99.226.117",
    8000,
    python_path="~/venv/bin/python3",
)

bind [127.0.0.1]:50004: Address already in use


Connected to IPython kernel at achal@91.99.226.117:8000


In [None]:
%%remote
# The cell body below is executed on the remote kernel by the %%remote magic.

print("Hello")

Hello


In [None]:
%%remote
# Report the Python version of the remote kernel's interpreter.
import sys
print(sys.version)

In [None]:
# Stop the remote kernel, close the SSH tunnel and forget the cached client.
disconnect(client)

Exit request sent.



