In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys
from pathlib import Path

import nest_asyncio


# Put the parent directory first on the import path so project packages
# (e.g. `math_rag`) resolve when this notebook runs from a subdirectory.
sys.path[:0] = [os.path.abspath('..')]
# Patch asyncio so event-loop helpers can be used inside the notebook's
# already-running loop (presumably needed for `asyncio.run` — confirm).
nest_asyncio.apply()

In [None]:
from asyncio import run, sleep

from asyncssh import connect

In [None]:
from math_rag.application.enums import (
    ApptainerBuildStatus,
    ApptainerOverlayCreateStatus,
)
from math_rag.infrastructure.clients import ApptainerClient


apptainer_client = ApptainerClient()

build_id = await apptainer_client.build(...)
max_retries = 3
retry_count = 0
poll_interval = 3

while True:
    status = await apptainer_client.build_status(build_id)

    if status == ApptainerBuildStatus.DONE:
        break

    if status == ApptainerBuildStatus.FAILED:
        if retry_count < max_retries:
            build_id = await apptainer_client.build(...)
            retry_count += 1

        else:
            raise Exception('Max retries reached')

    await sleep(poll_interval)


async for chunk in apptainer_client.build_result(build_id):
    pass

In [None]:
# AsyncExitStack keeps the nested async context managers flat (less nesting)
from contextlib import AsyncExitStack
from typing import AsyncGenerator


# Placeholders (the Ellipsis singleton): set these to the real HPC login
# user and host before using the clients below.
HPC_USERNAME = Ellipsis
HPC_HOSTNAME = Ellipsis


class LocalHPCClient:
    """Push data from the local machine to a remote host over SSH (asyncssh)."""

    def __init__(self, host: str, username: str):
        # Passed to asyncssh's `connect` on every transfer.
        self.host = host
        self.username = username

    async def scp(
        self,
        source_stream: AsyncGenerator[bytes, None],
        target_path: Path,
    ) -> None:
        """Stream `source_stream` into `target_path` on the remote host.

        Despite the name, the transfer goes through an SFTP client session
        (`start_sftp_client`), not the scp protocol. A new SSH connection is
        opened for this call and closed when it finishes. The remote file is
        opened with mode 'wb', so an existing file is presumably truncated.

        Args:
            source_stream: async stream of byte chunks, written in order.
            target_path: destination path on the remote host.
        """
        async with AsyncExitStack() as stack:
            # Registered first so it runs last (LIFO unwinding). If the
            # transfer fails midway, the partially consumed source generator
            # is closed deterministically instead of being left to the GC.
            # On success the generator is exhausted, so this is a no-op.
            aclose = getattr(source_stream, 'aclose', None)
            if aclose is not None:
                stack.push_async_callback(aclose)

            conn = await stack.enter_async_context(
                connect(self.host, username=self.username)
            )
            sftp = await stack.enter_async_context(conn.start_sftp_client())
            file = await stack.enter_async_context(sftp.open(str(target_path), 'wb'))

            async for chunk in source_stream:
                await file.write(chunk)


class RemoteHPCClient:
    """Run commands on a remote host over SSH (via asyncssh)."""

    def __init__(self, host: str, username: str):
        self.host, self.username = host, username

    async def qstat(self):
        """Run `qstat` on the remote host and return its output.

        A new SSH connection is opened for this call only. The command runs
        with `check=True`, so asyncssh is asked to raise when the command
        fails (presumably on a non-zero exit status) instead of returning.

        Returns:
            The command's stdout with leading/trailing whitespace stripped.
        """
        async with connect(self.host, username=self.username) as ssh:
            completed = await ssh.run('qstat', check=True)
        return completed.stdout.strip()