# core

> Submit LSF (`bsub`) jobs from a multiprocessing pool and monitor their progress with live tqdm progress bars.

In [None]:
#| default_exp inference.multinode_from_aiop_tool

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from multiprocessing import Value, Queue
import subprocess
import signal
from threading import Thread
import tqdm
import multiprocessing
import os
import time
import sys
import uuid
import re
import select
from functools import reduce
from colorama import Fore, Style

# Module-level queue shared by this module: DistributeHPC._submit_hpc_job puts
# HPC_Job snapshots on it and DistributeHPC._update reads them to draw progress.
queue = multiprocessing.Queue()


def print_status(current=None, total=None, description=None):
    """
    Emit a progress line that DistributeHPC parses to update a job's progress bar.

    Intended to be called from the python script that runs as an HPC job. The line
    has the form ``STATUS, <current>, <total>, <description>``; any argument left
    as ``None`` is rendered as an empty field.

    Args:
        current: Current progress value (the ``n`` in ``n/N``).
        total: Total number of steps (the ``N`` in ``n/N``).
        description: Free text shown next to the progress bar.
    """
    fields = ["" if value is None else value for value in (current, total, description)]
    # Flush explicitly: when stdout is a pipe it is block-buffered, and the status
    # line would otherwise only reach the reader once the buffer fills up.
    print(f"STATUS, {fields[0]}, {fields[1]}, {fields[2]}", flush=True)


class HPC_Job:
    """
    Represents a High-Performance Computing (HPC) job with attributes and methods 
    to manage its state, commands, and submission to an HPC system.

    Attributes:
        state (int): The current state of the job, a bit field built from the JOB_* constants.
        uuid (str): A unique identifier for the job.
        command (list): The command to be executed by the job.
        hpc_command (list): The full HPC command including submission parameters.
        description (str): A description of the job shown in the progress bar.
        status_current (str): Current progress: n in n/N shown in the progress bar.
        status_total (str): Current progress: N in n/N shown in the progress bar.
        lsf_job_id (str or None): The job ID assigned by the LSF scheduler.
        lsf_job_queue (str or None): The queue to which the job is submitted.
        bsub_error_msg (str or None): Error message from the `bsub` command, if any.

    Constants:
        JOB_NONE (int): Represents a job with no state.
        JOB_SUBMITTED (int): Represents a job that has been submitted.
        JOB_WAITING (int): Represents a job that is waiting to run.
        JOB_RUNNING (int): Represents a job that is currently running.
        JOB_COMPLETED (int): Represents a job that has completed successfully or failed.
        JOB_BSUB_FAILED (int): Represents a job where the job submission failed.
        JOB_TASK_FAILED (int): Represents a job where a task failed.

        BSUB_ARGS_DEFAULT (dict): Default arguments for the `bsub` command.
        BSUB_ARGS_DEFAULT_GPU (dict): Default arguments for GPU-based `bsub` commands.

    """
    # Class-level defaults; __init__ overrides state, uuid, command and hpc_command per instance.
    state = 0
    uuid = ""
    command = []
    hpc_command = []
    description = ""
    status_current = ""
    status_total = ""
    lsf_job_id = None
    lsf_job_queue = None
    bsub_error_msg = None

    # State flags: distinct bits so several can be OR-ed into `state` at once.
    JOB_NONE        = 0x0001
    JOB_SUBMITTED   = 0x0002
    JOB_WAITING     = 0x0004
    JOB_RUNNING     = 0x0008
    JOB_COMPLETED   = 0x0010
    JOB_BSUB_FAILED = 0x1000
    JOB_TASK_FAILED = 0x2000

    BSUB_ARGS_DEFAULT = {
        "-q": "batch",
        "-R": "ui=aiml_batch_training_dy && um=background && osrel>=70",
    }

    BSUB_ARGS_DEFAULT_GPU = {
        "-q": "gpu",
        "-gpu": "num=1:j_exclusive=yes",
        "-R": "osrel>=70 && type=any"        
    }


    def __init__(self, cmd=None, cores=4, bsub_args=None):
        """
        Initializes an HPC_Job instance with the specified command, number of cores, 
        and optional bsub arguments.

        Args:
            cmd (list, optional): The command to be executed on the HPC system. 
                                  e.g. ["python", "script.py", "--arg1", "value1"]
                                  Every element is converted to `str`. Defaults to an empty command.
            cores (int, optional): The number of cores to allocate for the job. 
                                   Defaults to 4. Always written to the `-n` option,
                                   overriding any `-n` given in `bsub_args`.
            bsub_args (dict, optional): Additional arguments for the `bsub` command. 
                                        If not provided, defaults to `HPC_Job.BSUB_ARGS_DEFAULT`.
                                        The dict passed in is not modified.
        """
        # Always work on a private copy: writing "-n" into the caller's dict would
        # silently alter shared dicts such as HPC_Job.BSUB_ARGS_DEFAULT_GPU.
        source_args = HPC_Job.BSUB_ARGS_DEFAULT if bsub_args is None else bsub_args
        bsub_args = dict(source_args)
        bsub_args["-n"] = str(cores)

        # Flatten {"-q": "batch", ...} into ["-q", "batch", ...]
        params = [str(token) for option in bsub_args.items() for token in option]

        cmd = [] if cmd is None else [str(x) for x in cmd]
        hpc_command = ["bsub", "-Is"] + params + cmd

        self.hpc_command = hpc_command
        self.command = cmd
        self.uuid = str(uuid.uuid4())
        self.state = HPC_Job.JOB_WAITING


class DistributeHPC(object):
    """
    DistributeHPC is a class designed to manage and execute distributed High-Performance Computing (HPC) jobs. 
    It provides functionality for job submission and progress tracking. The class is particularly useful for 
    running multiple computational tasks concurrently while monitoring their progress in real-time.
    
    Methods:
        __init__(worker=10):
            Initializes the DistributeHPC instance with the specified number of worker processes.
        set_jobs_hpc(jobs, num_cpu=4):
            Deprecated. Use `set_jobs` instead.
        set_jobs(jobs, num_cpu=4):
            Adds jobs to the job list. Jobs can either be instances of `HPC_Job` or lists of commands.
        start():
            Starts the distributed HPC job processing workflow. Manages job execution using a multiprocessing 
            pool and visualizes progress with progress bars.
    """

    # NOTE(review): mutable class-level list; it is shared by every instance that
    # does not assign its own `jobs` attribute.
    jobs = []
    # Default number of pool worker processes (also the number of per-worker bars).
    worker = 10
    # Set in __init__ to a multiprocessing.Value used as the progress thread's stop flag.
    closed = None
    
    def __init__(self, worker=10):
        self.worker = worker
        self.closed = Value('i', 0)
        
    def _show_progress_bar(self):
        self.thread = Thread(target=self._update, args=(self.closed, self.jobs, self.worker))
        self.closed.value = 0
        self.thread.daemon = True
        self.thread.start()
        
    def _hide_progress_bar(self):
        self.closed.value = 1
        self.thread.join()   
        self.thread = None

    def _signal_handler(self, sig, frame):
        self._hide_progress_bar()
        sys.exit(0)
        

    @staticmethod
    def _submit_hpc_job(args):
        """
        Run one job through `bsub` and publish its progress on the module-level queue.

        (Called by start() in a pool worker via Pool.imap())

        Args:
            args (HPC_Job): The job to execute. Despite the parameter name this is the
                job object itself; updates are sent through the module-level `queue`.

        Workflow:
            1. Announces the job on the queue.
            2. Launches `job.hpc_command` with `subprocess.Popen`.
            3. While the process runs, reads stderr/stdout line by line:
                - stderr: detects "waiting for dispatch", "starting on" and the
                  `Job <id> ... <queue>` line (job id and queue name).
                - stdout: a line starting with "Traceback" marks the task as failed and
                  terminates the process; a line starting with "STATUS," updates progress.
            4. Once the job id is known, appends both streams to
               `lsf_logs/lsf_<queue>_<id>.err` / `.out`.
            5. After the process exited, processes whatever is still left in the pipes.
            6. Marks the job as completed and puts the final state on the queue.

        Any exception (including a failing `Popen`) is recorded in `job.bsub_error_msg`
        and flagged with `JOB_BSUB_FAILED`; it is not re-raised.
        """
        job = args

        # Log files are opened lazily, once the LSF job id is known
        logs = {"err": None, "out": None}
        p = None

        def _on_stderr(s):
            """Handle one stderr line (bsub reports the submission status here)."""
            if not s:
                return  # EOF marker, nothing to process
            lowered = s.lower()
            if "waiting for dispatch" in lowered:
                job.state |= HPC_Job.JOB_WAITING
                queue.put(job)

            if "starting on" in lowered:
                job.state |= HPC_Job.JOB_RUNNING
                queue.put(job)

            # e.g. "Job <12345> is submitted to queue <batch>."
            m = re.match(r"^Job <([0-9]*)>(?:[^<]*<([^>]*))?", s.strip(), re.IGNORECASE)
            if m:
                job.state |= HPC_Job.JOB_SUBMITTED
                # Open the logs only once; a later "Job <..>" line must not leak the handles
                if logs["err"] is None:
                    queue_name = m.group(2)
                    job.lsf_job_id = m.group(1)
                    job.lsf_job_queue = "unknown" if queue_name is None else queue_name.lower()
                    os.makedirs("lsf_logs", exist_ok=True)
                    logs["err"] = open(f'lsf_logs/lsf_{job.lsf_job_queue}_{job.lsf_job_id}.err', 'a+')
                    logs["out"] = open(f'lsf_logs/lsf_{job.lsf_job_queue}_{job.lsf_job_id}.out', 'a+')
                queue.put(job)
            if logs["err"] is not None:
                logs["err"].write(s)

        def _on_stdout(s):
            """Handle one stdout line (task output, STATUS updates, stack traces)."""
            if not s:
                return  # EOF marker, nothing to process
            if s.startswith("Traceback"):
                # A python stack trace in the output is treated as a failed task
                job.state |= HPC_Job.JOB_TASK_FAILED
                queue.put(job)
                if p.poll() is None:
                    p.terminate()

            if s.startswith("STATUS,"):
                # "STATUS, <current>, <total>, <description>"; split at most three times so
                # commas inside the description survive, and pad short lines with ""
                fields = [x.strip() for x in s.split(",", 3)]
                fields += [""] * (4 - len(fields))
                _, job.status_current, job.status_total, job.description = fields
                queue.put(job)
            if logs["out"] is not None:
                logs["out"].write(s)

        def _drain(stream, handler):
            """Process the lines still pending on `stream` without blocking indefinitely."""
            while True:
                ready, _, _ = select.select([stream], [], [], 0.1)
                if not ready:
                    break  # pipe still open elsewhere but idle; do not wait for it
                line = stream.readline()
                if not line:
                    break  # EOF
                handler(line)

        try:

            queue.put(job) # Announce the job to the queue
            p = subprocess.Popen(job.hpc_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, preexec_fn=os.setpgrp, text=True) 

            # Poll until the process has exited
            while p.poll() is None:

                # select lets us watch stdout and stderr at the same time without blocking
                readable, _, _ = select.select([p.stdout, p.stderr], [], [], 0.01)

                if p.stderr in readable:
                    _on_stderr(p.stderr.readline())

                if p.stdout in readable:
                    _on_stdout(p.stdout.readline())

            # The process is gone, but its last lines (often the stack trace or the final
            # STATUS update) may still sit in the pipes. stderr first, so the log files
            # are open before the remaining stdout is written.
            _drain(p.stderr, _on_stderr)
            _drain(p.stdout, _on_stdout)

        except Exception as e:
            job.bsub_error_msg = str(e)
            job.state |= HPC_Job.JOB_BSUB_FAILED
        finally:
            for handle in logs.values():
                if handle is not None:
                    handle.close()

        # Add queue entry to indicate completion
        job.state |= HPC_Job.JOB_COMPLETED
        queue.put(job)

        # p is None when Popen itself failed (e.g. bsub is not installed)
        if p is not None:
            p.wait()

        
    @staticmethod
    def _clear_progress_bar(bar):
        """
        Blank out a worker progress bar so its line renders empty.

        Used for worker slots that currently have no job assigned, e.g. when
        there are more workers than open tasks.

        Args:
            bar (tqdm.tqdm): The progress bar to reset.
        """
        # An empty format string plus zeroed counters leaves nothing to draw
        bar.bar_format = ""
        bar.total = None
        bar.n = 0
        bar.set_description(f"{Style.RESET_ALL}")
        bar.update(0)

    @staticmethod
    def _update_main_progress_bar(bar, all_jobs):
        """
        Refresh the overall progress bar from the latest known job states.

        Args:
            bar (tqdm.tqdm): The main progress bar.
            all_jobs (dict): Maps job uuid to the most recent HPC_Job snapshot.

        The bar position is moved to the number of completed jobs. The description
        shows how many jobs are not yet completed and how many are done, followed
        by the number of task failures and the bsub error messages, if any.
        """
        snapshots = list(all_jobs.values())

        task_failures = [j for j in snapshots if j.state & HPC_Job.JOB_TASK_FAILED]
        bsub_failures = [j for j in snapshots if j.state & HPC_Job.JOB_BSUB_FAILED]
        finished = [j for j in snapshots if j.state & HPC_Job.JOB_COMPLETED]
        # Everything that is not completed counts as running
        n_running = len(snapshots) - len(finished)

        # tqdm.update() is relative, so advance by the difference to the current position
        bar.update(len(finished) - bar.n)

        pieces = [f"{Style.BRIGHT}{Fore.CYAN}RUNNING:{n_running}, DONE:{len(finished)}{Style.RESET_ALL}"]
        if task_failures:
            pieces.append(f"\033[31m, FAILED:{len(task_failures)}\033[0m")
        if bsub_failures:
            messages = ", ".join(j.bsub_error_msg for j in bsub_failures)
            pieces.append(f"\033[31m, FAILED:{messages}\033[0m")

        bar.set_description("".join(pieces))

    @staticmethod
    def _update_progress_bar(bar, job):
        """
        Updates a progress bar to reflect the current state and progress of a job.

        Args:
            bar (tqdm.tqdm): The progress bar object to be updated.
            job (HPC_Job): The job object containing the current state, progress, and metadata.

        Behavior:
            - Updates the bar's total and position from `job.status_total` / `job.status_current`.
              Values that are empty or not valid integers are ignored, so a malformed
              STATUS line cannot raise inside the progress thread.
            - Sets the description to the job's state, the first 50 characters of its
              command, and its description, colored according to the state.
            - Appends a "[Stacktrace!]" marker when the task failed.
        """

        def _as_int(value):
            """Return int(value), or None when value is not a valid integer."""
            try:
                return int(value)
            except (TypeError, ValueError):
                return None

        # Job states and their labels/colors. Order matters: when several state bits
        # are set, the last matching entry wins.
        job_states = {            
            HPC_Job.JOB_SUBMITTED   : ("submitted", Fore.YELLOW),
            HPC_Job.JOB_WAITING     : ("wait", Fore.YELLOW),
            HPC_Job.JOB_RUNNING     : ("run", Fore.GREEN),
            HPC_Job.JOB_COMPLETED   : ("done", Style.RESET_ALL),
            HPC_Job.JOB_BSUB_FAILED : ("fail", Fore.RED),
            HPC_Job.JOB_TASK_FAILED : ("FAIL",Fore.RED),
        }

        # Restore the regular layout (a cleared bar has an empty format)
        bar.bar_format = "{l_bar}{bar}|{n_fmt}/{total_fmt} [{elapsed}<{remaining}]"

        total = _as_int(job.status_total)
        if total is not None:
            bar.total = total

        current = _as_int(job.status_current)
        if current is not None:
            # tqdm.update() is relative, so advance by the difference
            bar.update(current - bar.n)

        # Determine the color and state based on the job's state
        state_color = Style.RESET_ALL
        state_str = ""
        for state, (_str, _style) in job_states.items():
            if job.state & state:
                state_color = _style
                state_str = _str

        cmd_str = " ".join(job.command)[0:50]

        description = f"{state_color}[{state_str}]: {cmd_str} >> {job.description}"
        if job.state & HPC_Job.JOB_TASK_FAILED:
            description += " [Stacktrace!] "
        description += f"{Style.RESET_ALL}"
        bar.set_description(description)

    @staticmethod
    def _update(closed, jobs, workers):
        """
        Render loop of the progress thread.

        Reads HPC_Job snapshots from the module-level `queue` and mirrors them onto
        one overall tqdm bar plus one tqdm bar per worker slot. Runs until `closed`
        is non-zero and the queue has been emptied, then clears and closes all bars.

        Args:
            closed (multiprocessing.Value): 0 while processing is ongoing, non-zero
                once the loop should finish.
            jobs (list): All jobs; its length is the total of the overall bar.
            workers (int): Number of worker bars to create.

        Behavior:
            - Keeps the latest snapshot per job uuid.
            - Skips rendering when the same job reported less than 0.5 seconds ago.
            - Frees the bar of a completed job and hands free bars to new jobs,
              lowest slot first, so the display order stays stable.
        """

        latest = {}  # job uuid -> most recent snapshot received from the queue

        # bars[0] is the overall bar, bars[1..workers] are the worker slots
        bars = [tqdm.tqdm(total=len(jobs), leave=True, desc="Total")]
        bars.extend(tqdm.tqdm(leave=True) for _ in range(workers))
        slot_ids = range(1, workers + 1)
        slots = {slot: None for slot in slot_ids}  # slot -> uuid of the job shown there

        last_render = 0
        last_rendered_uuid = None

        while closed.value == 0 or not queue.empty():

            # Fetch the next snapshot, or idle briefly when there is none
            if queue.empty():
                time.sleep(0.01)
                continue
            job = queue.get_nowait()
            if job is None:
                time.sleep(0.01)
                continue

            latest[job.uuid] = job

            # Throttle: the same job is rendered at most every 0.5 seconds
            if last_rendered_uuid == job.uuid and last_render > (time.time() - 0.5):
                continue
            last_render = time.time()
            last_rendered_uuid = job.uuid

            # Release the slots whose job has completed
            for slot in slot_ids:
                shown_uuid = slots[slot]
                if shown_uuid is None or shown_uuid not in latest:
                    continue
                if latest[shown_uuid].state & HPC_Job.JOB_COMPLETED:
                    slots[slot] = None
                    DistributeHPC._clear_progress_bar(bars[slot])

            # Give this job the first free slot unless it already owns one
            if job.uuid not in slots.values():
                free_slot = next((slot for slot in slot_ids if slots[slot] is None), None)
                if free_slot is not None:
                    slots[free_slot] = job.uuid

            # Draw the job's own bar, if it has a slot
            own_slot = next((slot for slot, shown in slots.items() if shown == job.uuid), None)
            if own_slot is not None:
                DistributeHPC._update_progress_bar(bars[own_slot], job)

            # Draw the overall bar
            DistributeHPC._update_main_progress_bar(bars[0], latest)

        # Final refresh, then blank the worker bars and close everything
        DistributeHPC._update_main_progress_bar(bars[0], latest)
        for worker_bar in bars[1:]:
            DistributeHPC._clear_progress_bar(worker_bar)
        for any_bar in bars:
            any_bar.close()

    # Legacy function 
    def set_jobs_hpc(self, jobs, num_cpu=4):
        print("WARNING: set_jobs_hpc is deprecated. Use set_jobs instead.")
        self.set_jobs(jobs)        

    def set_jobs(self, jobs, num_cpu=4):
        """
        Adds jobs to the job list

        Args:
            jobs (list): A list of jobs to be added. Each job can either be an instance
                         of `HPC_Job` or a [list] of commands.
        """
        for j in jobs:
            if isinstance(j, HPC_Job):
                self.jobs.append(j)
            else:
                _j = HPC_Job(cmd=j, cores=num_cpu)
                self.jobs.append(_j)

        
        
    def start(self):
        """
        Starts the distributed HPC job processing workflow.

        Runs all jobs in `self.jobs` through a multiprocessing pool while a
        background thread draws the progress bars.

        Workflow:
        1. Starts the progress bar thread.
        2. Installs a SIGINT handler for the duration of the run.
        3. Creates a pool with `self.worker` processes and submits every job to
           `_submit_hpc_job`; job updates travel through the module-level queue.
        4. Always cleans up afterwards: terminates the pool, restores the previous
           SIGINT handler and stops the progress bar thread.

        Raises:
            Any exception raised during job submission or processing propagates
            after the cleanup has run.
        """

        # Create a thread to plot progress
        self._show_progress_bar()

        pool = None
        previous_handler = None
        handler_installed = False
        try:
            previous_handler = signal.signal(signal.SIGINT, self._signal_handler)
            handler_installed = True

            # Create a pool to process multiple tasks. No initializer is used, so no
            # initargs are passed either.
            pool = multiprocessing.Pool(processes=self.worker)
            for _ in pool.imap(self._submit_hpc_job, list(self.jobs)):
                pass # Not interested in the result

            pool.close()
            time.sleep(0.5)
        finally:
            if pool is not None:
                pool.terminate()

            # Put the previous handler back; ours would otherwise outlive this run.
            # A previous handler of None cannot be re-installed, so it is left alone.
            if handler_installed and previous_handler is not None:
                signal.signal(signal.SIGINT, previous_handler)

            # Hide progress bar (the signal handler may have stopped it already)
            if getattr(self, "thread", None) is not None:
                self._hide_progress_bar()

In [None]:
#| hide
#import nbdev; nbdev.nbdev_export('hpc.ipynb')