In [None]:

import getpass
import logging
import os
import posixpath
import shlex
import sys
import time

import luigi
import paramiko
import sciluigi as sl
import yaml

# Put the parent directory on the import path so `from src import ...` below can resolve.
sys.path.append('..')

# Everyone needs to be quiet: only WARNING and above from the root logger
# (name None) and from each of these chatty libraries.
NOISY_LOGGERS = (
    None,
    'SMB',
    'napari',
    'matplotlib',
    'in_n_out',
    'numcodecs',
    'numba',
    'luigi',
    'numexpr',
    'luigi-interface',
    'sciluigi-interface',
    'cellpose',
)
for noisy_logger in NOISY_LOGGERS:
    logging.getLogger(noisy_logger).setLevel(logging.WARNING)

from src import Receipt, NASConnection
from src.steps import get_task

In [None]:
# `sshpass -e` (used by the `!sshpass -e ...` cell and, presumably, by
# RemoteTarget(..., sshpass=True)) reads the password from $SSHPASS.
# Never hard-code it in the notebook: reuse an existing value from the
# environment, otherwise prompt for it without echoing.
if not os.environ.get("SSHPASS"):
    os.environ["SSHPASS"] = getpass.getpass("SSH password for keck.engr.colostate.edu: ")

In [None]:
from luigi.contrib.ssh import RemoteTarget
import subprocess

class AngelFISHLuigiTask(sl.Task):
    """Run one AngelFISH pipeline step as a SLURM job on the cluster.

    ``run`` connects over SSH, submits ``run_step.sh <receipt> <step_name>``
    with ``sbatch --parsable`` from inside ``remote_path``, then polls
    ``squeue`` until the job is no longer listed.

    Parameters
    ----------
    receipt_path : str
        Receipt path on the cluster. Only its basename is passed to
        ``run_step.sh``, because the job is launched from ``remote_path``.
    remote_path : str
        Remote directory that contains ``run_step.sh``.
    step_name : str
        Name of the pipeline step to run.
    output_path : str
        Remote file used as this task's done flag. This task does not write
        it; presumably the remote step creates it (TODO confirm).
    """
    receipt_path = luigi.Parameter()
    remote_path = luigi.Parameter()
    step_name = luigi.Parameter()
    output_path = luigi.Parameter()

    # Connection and polling settings. These are plain class attributes, not
    # luigi parameters, so task identities are unaffected. Override them on a
    # subclass to target another cluster.
    CONFIG_PATH = r'C:\Users\formanj\GitHub\AngelFISH\config_cluster.yml'
    REMOTE_HOST = 'keck.engr.colostate.edu'
    REMOTE_USER = 'formanj'
    SSH_PORT = 22
    POLL_INTERVAL_S = 10   # seconds between squeue checks
    MAX_WAIT_S = 3600      # give up after one hour

    def out_doneflag(self):
        """Remote target for ``output_path``, checked over SSH via ``sshpass``."""
        return sl.TargetInfo(self, RemoteTarget(self.output_path, host=self.REMOTE_HOST,
                                                username=self.REMOTE_USER, sshpass=True))

    def run(self):
        """Submit the step to SLURM and block until the job leaves the queue.

        Raises
        ------
        RuntimeError
            If ``sbatch`` exits non-zero or prints no job ID.
        TimeoutError
            If the job is still listed by ``squeue`` after ``MAX_WAIT_S`` seconds.
        """
        ssh = self._open_ssh()
        try:
            job_id = self._submit_job(ssh)
            print(f"Submitted SLURM job with ID: {job_id}")
            self._wait_for_job(ssh, job_id)
            print(f"[{job_id}] Job complete.")
        finally:
            # Release the connection on every path, including timeout/failure.
            ssh.close()

        # The original author found this pause necessary ("this was annoying to
        # discover"); presumably it gives the remote done flag time to become
        # visible before luigi checks the target. TODO confirm.
        time.sleep(1)

    def _open_ssh(self):
        """Open an SSH connection with the credentials stored in ``CONFIG_PATH``."""
        with open(self.CONFIG_PATH) as config_file:
            conf = yaml.safe_load(config_file)
        user_conf = conf['user']

        ssh = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy trusts unknown host keys (no MITM protection).
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(str(user_conf['remote_address']), self.SSH_PORT,
                    str(user_conf['username']), str(user_conf['password']))
        return ssh

    def _submit_job(self, ssh):
        """Run ``sbatch --parsable`` remotely and return the SLURM job ID as a string."""
        # run_step.sh is resolved relative to remote_path, so only the receipt's
        # basename is passed. Quote every interpolated value for the remote shell,
        # and use `&&` so sbatch never runs if the `cd` fails.
        receipt_name = os.path.basename(self.receipt_path)
        command = (
            f'cd {shlex.quote(self.remote_path)} && '
            f'sbatch --parsable run_step.sh {shlex.quote(receipt_name)} '
            f'{shlex.quote(self.step_name)}'
        )
        _, stdout, stderr = ssh.exec_command(command)
        output = stdout.read().decode().strip()
        error = stderr.read().decode().strip()
        # Read both streams before waiting for the exit status. Judge success by
        # the exit status, not by stderr: sbatch may print warnings to stderr
        # while still submitting the job.
        exit_status = stdout.channel.recv_exit_status()
        if exit_status != 0 or not output:
            raise RuntimeError(
                f"SLURM job submission failed (exit status {exit_status}): {error or output}")
        # --parsable prints "<jobid>" or "<jobid>;<cluster>".
        return output.split(';')[0]

    def _wait_for_job(self, ssh, job_id):
        """Poll ``squeue`` every ``POLL_INTERVAL_S`` seconds until ``job_id`` is gone."""
        deadline = time.monotonic() + self.MAX_WAIT_S
        while self._is_job_active(ssh, job_id):
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"SLURM job {job_id} did not complete within {self.MAX_WAIT_S} seconds.")
            print(f"[{job_id}] Still running... waiting {self.POLL_INTERVAL_S}s")
            time.sleep(self.POLL_INTERVAL_S)

    @staticmethod
    def _is_job_active(ssh, job_id):
        """Return True while ``squeue -h`` still prints a line for ``job_id``.

        NOTE(review): empty output is read as "finished". A transient squeue
        failure would therefore also end the wait, and a job that left the queue
        may have failed rather than succeeded. The done-flag target is the real
        success check.
        """
        _, stdout, _ = ssh.exec_command(f'squeue -j {shlex.quote(job_id)} -h')
        return bool(stdout.read().decode().strip())

In [None]:
class Upload_Task(sl.Task):
    """Upload the local receipt file to the cluster over SFTP.

    Parameters
    ----------
    output_path : str
        Remote path used as this task's done flag. AngelFISHWorkflow passes the
        uploaded receipt's own remote path, so the uploaded file is the flag.
    receipt_path : str
        Local path of the receipt to upload.
    remote_path : str
        Remote directory the receipt is copied into, keeping its basename.
    """
    output_path = luigi.Parameter()
    receipt_path = luigi.Parameter()
    remote_path = luigi.Parameter()

    # Connection settings. These are plain class attributes, not luigi
    # parameters, so task identities are unaffected.
    CONFIG_PATH = r'C:\Users\formanj\GitHub\AngelFISH\config_cluster.yml'
    REMOTE_HOST = 'keck.engr.colostate.edu'
    REMOTE_USER = 'formanj'
    SSH_PORT = 22

    def out_doneflag(self):
        """Remote target for ``output_path``, checked over SSH via ``sshpass``."""
        return sl.TargetInfo(self, RemoteTarget(self.output_path, host=self.REMOTE_HOST,
                                                username=self.REMOTE_USER, sshpass=True))

    def run(self):
        """Copy ``receipt_path`` to ``<remote_path>/<basename>`` on the cluster.

        The SFTP session and the SSH connection are both closed even if the
        transfer fails.
        """
        # The remote side is POSIX. Build the path with posixpath rather than
        # os.path, which would insert backslashes when run on Windows.
        remote_receipt_path = posixpath.join(self.remote_path.replace('\\', '/'),
                                             os.path.basename(self.receipt_path))

        ssh = self._open_ssh()
        try:
            sftp = ssh.open_sftp()
            try:
                sftp.put(self.receipt_path, remote_receipt_path)
            finally:
                sftp.close()
        finally:
            ssh.close()

        # The original author found this pause necessary ("this was annoying to
        # discover"); presumably it lets the uploaded file become visible before
        # luigi checks the target. TODO confirm.
        time.sleep(1)

    def _open_ssh(self):
        """Open an SSH connection with the credentials stored in ``CONFIG_PATH``."""
        with open(self.CONFIG_PATH) as config_file:
            conf = yaml.safe_load(config_file)
        user_conf = conf['user']

        ssh = paramiko.SSHClient()
        # NOTE(review): AutoAddPolicy trusts unknown host keys (no MITM protection).
        ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
        ssh.connect(str(user_conf['remote_address']), self.SSH_PORT,
                    str(user_conf['username']), str(user_conf['password']))
        return ssh

In [None]:
class AngelFISHWorkflow(sl.WorkflowTask):
    """Upload a receipt to the cluster, then chain one SLURM task per step.

    The first task uploads the local receipt. Each name in the receipt's
    ``step_order`` then becomes an ``AngelFISHLuigiTask`` whose upstream is the
    previous task's done flag, so the steps run strictly in order.
    """
    # Local path of the receipt JSON that drives the workflow. This location
    # may get changed over the course of the workflow's execution.
    receipt_path = luigi.Parameter()
    # Remote directory holding run_step.sh; the receipt is uploaded here.
    cluster_path = luigi.Parameter(default='/home/formanj/Github/AngelFISH/cluster')

    def workflow(self):
        """Build the task chain and return it as a list (upload task first)."""
        # All remote paths are POSIX. Use posixpath so that running this on
        # Windows never introduces backslashes.
        cluster_path = self.cluster_path.replace('\\', '/')
        receipt_filename = os.path.basename(self.receipt_path)
        remote_receipt_path = posixpath.join(cluster_path, receipt_filename)

        upload_task = self.new_task(
            'upload_task',
            Upload_Task,
            output_path=remote_receipt_path,  # the uploaded file doubles as the done flag
            receipt_path=self.receipt_path,
            remote_path=cluster_path,
        )
        task_refs = [upload_task]

        receipt = Receipt(path=self.receipt_path)
        meta = receipt['meta_arguments']
        # Step status files live under
        # <parent of cluster_path>/database/<nas name>/<analysis_name>/status
        database_dir = posixpath.join(posixpath.dirname(cluster_path), 'database')
        nas_name = os.path.basename(meta['nas_location'])
        remote_status_dir = posixpath.join(database_dir, nas_name, meta['analysis_name'], 'status')

        previous_task = upload_task
        for step_name in receipt['step_order']:
            step_task = self.new_task(
                step_name,
                AngelFISHLuigiTask,
                receipt_path=remote_receipt_path,
                step_name=step_name,
                output_path=posixpath.join(remote_status_dir, f'step_{step_name}.txt'),
                remote_path=cluster_path,
            )
            # Each step waits on the previous task's done flag.
            step_task.in_upstream = previous_task.out_doneflag
            previous_task = step_task
            task_refs.append(step_task)
        return task_refs

In [None]:
# Run the whole chain (upload + one SLURM job per step) with an in-process scheduler.
pipeline_tasks = AngelFISHWorkflow(
    receipt_path=r'C:\Users\formanj\GitHub\AngelFISH\examples\new_pipeline.json'
).workflow()
luigi.build(pipeline_tasks, local_scheduler=True)

In [None]:
# Check whether the upload task's first output target exists on the cluster
# (the upload task is the first element returned by workflow()).
upload_task = AngelFISHWorkflow(
    receipt_path=r'C:\Users\formanj\GitHub\AngelFISH\examples\new_pipeline.json'
).workflow()[0]
upload_task.output()[0].exists()

In [None]:
# Manually check one step's status file on the cluster.
status_file = RemoteTarget(
    '/home/formanj/Github/AngelFISH/database/SS004_10min_100nM_7/default_name/status/step_segment_nuc.txt',
    host='keck.engr.colostate.edu',
    username='formanj',
    sshpass=True,
)
status_file.exists()

In [None]:
# Connectivity smoke test: `sshpass -e` takes the SSH password from $SSHPASS; should print "hello".
!sshpass -e ssh formanj@keck.engr.colostate.edu echo "hello"