#| export
# core

> Remote Python execution for SolveIt

## Requirements
- Allows me to connect to either my own machine or rented GPU machines
    - Accepts either a full ssh command or the individual user name, host and port
- Runs SolveIt cells on the remote machine and prints the output back in SolveIt, so the AI can see it without much copy-pasting
    - Works with the virtual environment (if any) on the remote machine
- Is a tiny pip-installable Python package, so it can easily be reused across my dialogs
    - Exposes `connect`, `disconnect` and `is_connected`
    - Whenever `is_connected` returns true, the connection must be usable
- Gives reliable error messages when an operation fails
- Exposes a cell magic for ease of use, so the user doesn't have to write code as strings

In [None]:
#| export
#| default_exp core

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
import json
import os
import re
import subprocess
import sys
import tempfile
from dataclasses import dataclass
from queue import Empty

from IPython.core.magic import register_cell_magic
from jupyter_client import BlockingKernelClient

In [None]:
#| export
@dataclass(frozen=True)
class Client:
    """
    Immutable handle for one remote session.

    Bundles the ssh coordinates of the remote machine with the Jupyter
    kernel client used to talk to the IPython kernel started there.
    """
    # ssh login name on the remote machine
    user: str
    # remote host name or address
    host: str
    # ssh port of the remote machine
    port: int
    # kernel client whose channels are connected to the remote IPython kernel
    kernel: BlockingKernelClient
    # python interpreter (path or command) used on the remote side
    py: str

In [None]:
#| export
# Registry of open connections: (user, host, port) -> Client.
connected_clients = dict()

In [None]:
#| export
# Python source run on the remote machine (via `remote_python`) to start a fresh
# IPython kernel. It is a `str.format` template whose only placeholder is
# `{py!r}` (the interpreter path); `!r` turns the path into a valid Python string
# literal even if it contains quotes. Any literal brace added to the script must
# be doubled so `.format` leaves it alone.
#
# The script:
#   1. best-effort stops a kernel left over from a previous run (pid stored in
#      ~/ipykernel.pid), waits briefly for it to exit, and removes its state files;
#   2. launches `ipykernel_launcher` detached, on fixed ports 51001-51005;
#   3. polls until the kernel has written a parseable connection file, the kernel
#      exits, or STARTUP_TIMEOUT seconds pass;
#   4. on success records the kernel pid and prints the connection JSON to stdout;
#      on failure prints the kernel's stderr log to stderr and exits with status 1.
SETUP_SCRIPT_TEMPLATE = """
import json
import os
import signal
import subprocess
import sys
import time

# Seconds to wait for the kernel to write its connection file.
STARTUP_TIMEOUT = 30

def safe(fn):
    try: fn()
    except Exception: pass

def read_pid(path):
    # Return the positive pid stored in path, or None if missing or corrupt.
    try:
        with open(path, encoding="utf-8") as f:
            pid = int(f.read().strip())
    except (OSError, ValueError):
        return None
    return pid if pid > 0 else None

def wait_for_exit(pid, timeout):
    # Signal 0 only checks that pid still exists; stop once it is gone.
    end = time.monotonic() + timeout
    while time.monotonic() < end:
        try:
            os.kill(pid, 0)
        except OSError:
            return
        time.sleep(0.1)

py = os.path.expanduser({py!r})
pid_path = os.path.expanduser("~/ipykernel.pid")
json_config_path = os.path.expanduser("~/ipykernel.json")
error_log_path = os.path.expanduser("~/ipykernel_error.log")

# Clear up old state if any
old_pid = read_pid(pid_path)
if old_pid is not None:
    safe(lambda: os.kill(old_pid, signal.SIGTERM))
    # Give the old kernel a moment to release the fixed ports before reusing them.
    wait_for_exit(old_pid, 5)
safe(lambda: os.remove(pid_path))
safe(lambda: os.remove(json_config_path))
safe(lambda: os.remove(error_log_path))

with open(error_log_path, "w") as stderr_file:
    process = subprocess.Popen([
            py,
            "-m",
            "ipykernel_launcher",
            "--shell=51001",
            "--iopub=51002",
            "--stdin=51003",
            "--control=51004",
            "--hb=51005",
            "-f="+json_config_path,
        ],
        stdin=subprocess.DEVNULL,
        stdout=subprocess.DEVNULL,
        stderr=stderr_file,
        start_new_session=True,
    )

# Wait for a complete (parseable) connection file instead of a fixed sleep.
config_text = None
deadline = time.monotonic() + STARTUP_TIMEOUT
while config_text is None and process.poll() is None and time.monotonic() < deadline:
    try:
        with open(json_config_path, "r") as f:
            candidate = f.read()
        json.loads(candidate)
        config_text = candidate
    except (OSError, ValueError):
        time.sleep(0.2)

if config_text is None or process.poll() is not None:
    if process.poll() is None:
        safe(lambda: os.kill(process.pid, signal.SIGTERM))
        print("Timed out waiting for the IPython kernel to start", file=sys.stderr)
    with open(error_log_path, "r") as f:
        print(f.read(), file=sys.stderr, end="")
    sys.exit(1)

with open(pid_path, "w") as f:
    f.write(str(process.pid))

print(config_text, file=sys.stdout, end="")
"""

In [None]:
#| export
# Python source run on the remote machine (via `remote_python`) by `disconnect`.
# It best-effort stops the kernel whose pid is stored in ~/ipykernel.pid and
# deletes the kernel state files. Each step is guarded, so a missing or corrupt
# file never makes cleanup fail. A non-positive pid is never signalled, because
# os.kill(0 or -1, ...) would hit whole process groups.
CLEANUP_SCRIPT = """
import os
import signal

def safe(fn):
    try: fn()
    except Exception: pass

def read_pid(path):
    # Return the positive pid stored in path, or None if missing or corrupt.
    try:
        with open(path, encoding="utf-8") as f:
            pid = int(f.read().strip())
    except (OSError, ValueError):
        return None
    return pid if pid > 0 else None

pid_path = os.path.expanduser("~/ipykernel.pid")
json_config_path = os.path.expanduser("~/ipykernel.json")
error_log_path = os.path.expanduser("~/ipykernel_error.log")

pid = read_pid(pid_path)
if pid is not None:
    safe(lambda: os.kill(pid, signal.SIGTERM))
safe(lambda: os.remove(pid_path))
safe(lambda: os.remove(json_config_path))
safe(lambda: os.remove(error_log_path))
"""

In [None]:
#| export
# Path of the ssh ControlMaster socket shared by the ssh invocations in this module.
# NOTE(review): one fixed path means only one remote machine can be multiplexed at a time.
CONTROL_SOCKET: str = "/tmp/remoteipy.sock"

In [None]:
#| export
def remote_python(code:str, user:str, host:str, port:int, py:str = "python3") -> str:
    """
    Run `code` with the interpreter `py` on `user@host` and return its stdout.

    The source is piped to the remote interpreter's stdin through ssh, using
    the control socket CONTROL_SOCKET.

    code: Python source to execute remotely.
    user, host, port: ssh coordinates of the remote machine.
    py: interpreter command or path on the remote machine.

    Raises RuntimeError if ssh or the remote interpreter exits non-zero; the
    message includes the exit status and the captured stderr.
    """
    # Build argv as a list so no argument is ever re-split on whitespace locally.
    argv = ["ssh", "-S", CONTROL_SOCKET, "-p", str(port), f"{user}@{host}", py]
    result = subprocess.run(argv, input=code, capture_output=True, text=True)
    if result.returncode != 0:
        detail = result.stderr.strip() or "(no stderr output)"
        raise RuntimeError(
            f"Remote python on {user}@{host}:{port} failed with exit status {result.returncode}:\n{detail}"
        )
    return result.stdout

In [None]:
#| export
def _parse_ssh_cmd(ssh_cmd:str) -> tuple:
    """
    Extract (user, host, port) from an ssh command line such as
    "ssh -i key user@host -p 2222".

    The destination must be a single `user@host` token. The port comes from
    `-p PORT`, `--port PORT` or `-pPORT` and defaults to 22.
    Raises RuntimeError if the user/host is missing or the port is invalid.
    """
    user = None
    host = None
    port = 22
    tokens = ssh_cmd.split()
    for i, token in enumerate(tokens):
        if token in ("-p", "--port"):
            raw_port = tokens[i + 1] if i + 1 < len(tokens) else ""
        elif token.startswith("-p"):
            raw_port = token[2:]  # attached form, e.g. -p2222
        else:
            # Match the whole token so users/hosts containing '.' or '-'
            # (e.g. first.last@gpu-box.example.com) are kept intact, and option
            # values such as ProxyJump=u@h are not taken as the destination.
            if m := re.fullmatch(r"([\w.-]+)@([\w.-]+)", token):
                user, host = m.group(1), m.group(2)
            continue
        try:
            port = int(raw_port)
        except ValueError:
            raise RuntimeError(f"Invalid ssh command: {ssh_cmd}") from None
        if not 0 < port < 65536:
            raise RuntimeError(f"ssh_cmd: invalid port {port}")

    if not user: raise RuntimeError("ssh_cmd: user not found")
    if not host: raise RuntimeError("ssh_cmd: host not found")
    return user, host, port


def _run_ssh_checked(argv:list, action:str):
    """
    Run a local ssh command and raise RuntimeError, including its stderr, if it fails.

    action: short description used in the error message, e.g. "forward the kernel ports".
    """
    # stderr goes to a temporary file rather than a pipe: `ssh -f` leaves a
    # background process holding the inherited stderr, so reading a pipe to EOF
    # would block for as long as that process lives.
    with tempfile.TemporaryFile() as err:
        result = subprocess.run(argv, stderr=err)
        if result.returncode != 0:
            err.seek(0)
            detail = err.read().decode("utf-8", "replace").strip() or "(no stderr output)"
            raise RuntimeError(f"Failed to {action} (exit status {result.returncode}):\n{detail}")


def _register_remote_magic(client:"Client"):
    """(Re)define the %%remote cell magic so cell bodies run on `client`."""
    @register_cell_magic
    def remote(line, cell):
        run_remote_blocking(client, cell)


def connect(ssh_cmd:str, py:str = "python3") -> "Client":
    """
    Start an IPython kernel on a remote machine over ssh and connect to it,
    so commands can be sent to it. Also (re)registers the %%remote cell magic
    so that it runs cells on the returned client.
    Pass the path of a virtual environment's interpreter as `py` to run the
    remote code inside that environment.

    ssh_cmd: ssh command to connect to the remote machine, e.g.
        "ssh user@host -p 2222". The port option is optional (default 22).
    py: path to the python interpreter on the remote machine.

    Returns the Client for (user, host, port). An existing client that is
    still connected is reused.
    Raises RuntimeError if the command cannot be parsed, an ssh step fails,
    the remote setup script fails or prints unexpected output, or the kernel
    never becomes ready.
    """
    user, host, port = _parse_ssh_cmd(ssh_cmd)

    key = (user, host, port)
    client = connected_clients.get(key)
    if client is not None:
        if is_connected(client):
            print(f"Already connected to {user}@{host}:{port}")
            # Point %%remote back at this client in case a later connect rebound it.
            _register_remote_magic(client)
            return client
        del connected_clients[key]

    # Background (-f) ssh master (-M) on CONTROL_SOCKET that later ssh calls reuse.
    _run_ssh_checked(
        ssh_cmd.split() + ["-fNMS", CONTROL_SOCKET, "-o", "StrictHostKeyChecking=accept-new"],
        f"open an ssh connection to {user}@{host}:{port}",
    )

    setup_output = remote_python(SETUP_SCRIPT_TEMPLATE.format(py=py), user, host, port, py)
    try:
        config_json = json.loads(setup_output)
    except ValueError:
        raise RuntimeError(
            f"Unexpected output from the kernel setup script on {user}@{host}:{port}:\n{setup_output}"
        ) from None

    # Forward the kernel's fixed ports 51001-51005 (set in SETUP_SCRIPT_TEMPLATE) to localhost.
    forward_argv = ["ssh", "-S", CONTROL_SOCKET, "-O", "forward"]
    for p in range(51001, 51006):
        forward_argv += ["-L", f"{p}:localhost:{p}"]
    forward_argv += [f"{user}@{host}", "-p", str(port)]
    _run_ssh_checked(forward_argv, "forward the kernel ports")

    kernel = BlockingKernelClient()
    kernel.load_connection_info(config_json)
    kernel.start_channels()

    # Retry with a doubling timeout (5, 10, 20, 40 seconds), then give up.
    timeout = 5
    attempts = 4
    for attempt in range(1, attempts + 1):
        try:
            kernel.wait_for_ready(timeout=timeout)
            break
        except RuntimeError as err:
            if attempt == attempts:
                kernel.stop_channels()
                raise RuntimeError(
                    f"IPython kernel at {user}@{host}:{port} did not become ready after {attempts} attempts"
                ) from err
            timeout *= 2
            print(f"Connection to {user}@{host}:{port} timed out. Trying again with {timeout} seconds")

    print(f"Connected to IPython kernel at {user}@{host}:{port}")

    client = Client(user=user, host=host, port=port, py=py, kernel=kernel)
    connected_clients[key] = client
    _register_remote_magic(client)
    return client

In [None]:
#| export
def is_connected(client:"Client") -> bool:
    """
    Check if the client is currently connected.
    A client is considered connected if all of the following are true:
    1. The ssh control socket file CONTROL_SOCKET exists.
    2. The ssh master behind that socket answers `ssh -O check`.
    3. `client.kernel.is_alive()` reports the IPython kernel as alive.
    """
    if not os.path.exists(CONTROL_SOCKET):
        return False
    # A socket file can outlive its ssh master (e.g. after a crash), so ask the
    # master directly. The timeout keeps a wedged master from hanging the caller.
    try:
        check = subprocess.run(
            ["ssh", "-S", CONTROL_SOCKET, "-O", "check", "-p", str(client.port), f"{client.user}@{client.host}"],
            capture_output=True,
            timeout=10,
        )
    except subprocess.TimeoutExpired:
        return False
    if check.returncode != 0:
        return False
    return bool(client.kernel.is_alive())

In [None]:
#| export
def run_remote_blocking(client:"Client", code:str):
    """
    Run code on the remote machine associated with this client and wait until
    the kernel reports it is idle again.
    Use the %%remote cell magic instead of directly calling this function.

    Only iopub messages whose parent is this execute request are handled, so
    output or a stray "idle" status left over from an earlier request cannot
    leak into this one or end it early:
    - stream: text is written to the local stream of the same name (stdout/stderr)
    - error: the traceback is written to stderr
    - execute_result / display_data: the text/plain form, if present, is printed
    Other message types are ignored.

    There is no overall time limit, so long-running code is fine. RuntimeError
    is raised if `client.kernel.is_alive()` turns False while waiting.
    """
    kernel = client.kernel
    msg_id = kernel.execute(code)
    while True:
        try:
            msg = kernel.get_iopub_msg(timeout=1)
        except Empty:
            # Silence is normal for long-running code; a dead kernel/tunnel is not.
            if not kernel.is_alive():
                raise RuntimeError(
                    f"Lost connection to the IPython kernel at {client.user}@{client.host}:{client.port}"
                )
            continue
        if msg.get("parent_header", {}).get("msg_id") != msg_id:
            continue
        msg_type = msg["msg_type"]
        content = msg["content"]
        if msg_type == "status":
            if content["execution_state"] == "idle":
                break
        elif msg_type == "stream":
            print(content["text"], file=getattr(sys, content["name"]), end='')
        elif msg_type == "error":
            print('\n'.join(content["traceback"]), file=sys.stderr)
        elif msg_type in ("execute_result", "display_data"):
            text = content.get("data", {}).get("text/plain")
            if text is not None:
                print(text)

In [None]:
from queue import Empty

def _debug_eat_pending_messages(client:Client):
    """
    Debug helper: drain the client's iopub queue, printing each message's
    type, until no message arrives within one second.
    """
    fetch = client.kernel.get_iopub_msg
    while True:
        try:
            pending = fetch(timeout=1)
        except Empty:
            break
        print(pending['msg_type'])
    print("Queue cleared!")

In [None]:
#| export
def disconnect(client:"Client"):
    """
    Disconnect the client. This stops the local kernel channels, runs the
    cleanup script on the remote machine, closes the ssh master on
    CONTROL_SOCKET and removes the client from `connected_clients`.

    Each step is attempted even if an earlier one fails, so a half-broken
    connection can still be torn down. A failed remote cleanup is reported
    as a warning on stderr instead of aborting the disconnect.
    """
    # Always release the local channels: is_alive() being False does not mean
    # the channel threads/sockets are already closed.
    client.kernel.stop_channels()

    try:
        remote_python(CLEANUP_SCRIPT, client.user, client.host, client.port, client.py)
    except RuntimeError as err:
        print(
            f"Warning: could not clean up the remote kernel on {client.user}@{client.host}:{client.port}: {err}",
            file=sys.stderr,
        )

    # Use the shared constant rather than a duplicated literal path.
    subprocess.run(["ssh", "-S", CONTROL_SOCKET, "-O", "exit", "-p", str(client.port), f"{client.user}@{client.host}"])

    key = (client.user, client.host, client.port)
    if connected_clients.pop(key, None) is None:
        print("Warning: client was not found in connected clients")