In [None]:
# Install this notebook's local dependencies into the running kernel's environment (%pip targets the kernel).
%pip install transformers torch ipywidgets pillow

In [1]:
from ipywidgets import FileUpload, Button, Label, HBox
from IPython.display import display, clear_output
import subprocess
import os
import time
from transformers import AutoFeatureExtractor, ResNetForImageClassification

In [2]:
def initialize_globals():
    """(Re)define the module-level settings shared by the pipeline steps.

    Sets ``username``, ``remote_host``, ``remote_path`` and
    ``local_model_path``, and clears ``job_id`` so that a run whose
    submission fails does not go on to monitor an earlier run's job.
    """
    global username, remote_host, remote_path, local_model_path, job_id

    # Your ICE login -- replace with your own username before running.
    username = "gburdell3"

    # Fixed settings; don't change these!
    remote_host = "login-ice.pace.gatech.edu"
    local_model_path = "./my_local_model"

    # Per-user project directory in the ICE home area.
    remote_path = "/home/hice1/" + username + "/cybershuttle_project"

    # Filled in by submit_job once sbatch accepts the job.
    job_id = None

In [3]:
def setup_ui():
    """Create the upload widgets, wire the submit handler, and display them.

    Defines the module-level ``message`` (status label), ``uploader``
    (single-file FileUpload restricted to images) and ``submit_button``,
    which ``on_submit_clicked`` reads and updates.
    """
    global message, uploader, submit_button

    message = Label('⏳ Please upload an image file to process')
    uploader = FileUpload(accept='image/*', multiple=False)
    submit_button = Button(description="Process Image", button_style='success')

    # Clicking the button runs the whole pipeline for the uploaded image.
    submit_button.on_click(on_submit_clicked)
    display(message, uploader, submit_button)

In [4]:
def on_submit_clicked(b):
    """Button callback: save the uploaded image and run the whole ICE pipeline.

    Parameters
    ----------
    b : ipywidgets.Button
        The clicked button, passed by ``on_click`` (unused).

    The upload widgets are disabled while the pipeline runs and are always
    re-enabled afterwards, even when a step raises; in that case the error
    is shown in ``message`` and re-raised.
    """
    if not uploader.value:
        message.value = "❗ No file uploaded. Please upload an image first."
        return

    # ipywidgets 8 stores uploads as a tuple of dicts carrying a 'name' key;
    # ipywidgets 7 used a dict keyed by filename. Accept both formats.
    if isinstance(uploader.value, dict):
        uploaded_filename, file_info = next(iter(uploader.value.items()))
    else:
        file_info = uploader.value[0]
        uploaded_filename = file_info['name']
    uploaded_content = file_info['content']

    # Clear output, then display the UI again before processing
    clear_output()
    display(message, uploader, submit_button)
    message.value = "⏳ Processing in progress..."
    submit_button.disabled = True
    uploader.disabled = True

    try:
        print(f"🔄 Processing file: {uploaded_filename}...")

        # Always saved as input.png: that is the name upload_job_files sends.
        with open('input.png', 'wb') as f:
            f.write(uploaded_content)

        print(f"✅ File '{uploaded_filename}' saved as 'input.png'")

        # Run each step in sequence; any step may raise to abort the run.
        initialize_globals()
        check_local_model()
        check_remote_project()
        check_remote_env()
        check_remote_model()
        upload_job_files()
        submit_job()
        monitor_job()
        download_results()
        display_results()
    except Exception as exc:
        message.value = f"❗ Processing failed: {exc}"
        raise
    else:
        message.value = "✅ Processing complete! Upload another file to process again."
    finally:
        # Re-enable even on failure so the user can retry without re-running the cell.
        submit_button.disabled = False
        uploader.disabled = False

In [5]:
# Cell 5: Check local model
def check_local_model():
    """Download ResNet-50 into ``local_model_path`` unless that path exists.

    The model and its feature extractor are loaded from
    ``microsoft/resnet-50`` and written with ``save_pretrained`` into the
    same folder, which check_remote_model later copies to ICE as-is.
    An existing path is trusted and left untouched.
    """
    print(f"🔍 Checking if local model folder exists at {local_model_path}...")

    if os.path.exists(local_model_path):
        print("✅ Local model folder already exists. Skipping download.")
        return

    print("⚡ Local model folder not found. Downloading ResNet-50 model...")

    model_name = "microsoft/resnet-50"
    model = ResNetForImageClassification.from_pretrained(model_name)
    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)

    os.makedirs(local_model_path, exist_ok=True)
    model.save_pretrained(local_model_path)
    feature_extractor.save_pretrained(local_model_path)

    print("✅ Model downloaded and saved locally.")


In [6]:
def check_remote_project():
    """Ensure the project directory ``remote_path`` exists on ``remote_host``.

    Uses the module-level ``username``, ``remote_host`` and ``remote_path``
    set by ``initialize_globals``.

    Raises
    ------
    RuntimeError
        If the remote ``mkdir -p`` fails (e.g. ssh cannot connect); its
        stderr is printed first.
    """
    import shlex

    print(f"🔍 Checking if project folder exists at {remote_path}...")

    target = f"{username}@{remote_host}"
    # Quote the path so the remote shell always sees it as one argument.
    quoted_path = shlex.quote(remote_path)

    check_project = subprocess.run(
        ["ssh", target, f"test -d {quoted_path}"],
        capture_output=True
    )

    if check_project.returncode != 0:
        print("⚡ Project folder not found. Creating...")
        create = subprocess.run(
            ["ssh", target, f"mkdir -p {quoted_path}"],
            capture_output=True,
            text=True
        )
        if create.returncode != 0:
            print("❗ Failed to create project folder. See error:")
            print(create.stderr)
            raise RuntimeError("Failed to create remote project folder.")
        print("✅ Project folder created.")
    else:
        print("✅ Project folder already exists.")

In [7]:
def check_remote_env():
    """Create the ``myenv`` virtual environment under ``remote_path`` if missing.

    If ``{remote_path}/myenv`` is not a directory on ``remote_host``, a login
    shell there loads the ``python/3.10`` module, creates the venv and
    installs torch, torchvision, transformers and pillow into it.

    Raises
    ------
    RuntimeError
        If the remote setup command exits non-zero. Unless the failure came
        from ssh itself, the partially built ``myenv`` is removed first so
        the next run redoes the setup instead of skipping it.
    """
    import shlex

    print(f"🔍 Checking if virtual environment exists at {remote_path}/myenv...")

    target = f"{username}@{remote_host}"
    quoted_env = shlex.quote(f"{remote_path}/myenv")

    check_env = subprocess.run(
        ["ssh", target, f"test -d {quoted_env}"],
        capture_output=True
    )

    if check_env.returncode != 0:
        print("⚡ Virtual environment not found. Setting up myenv...")

        setup_env_commands = (
            f"cd {shlex.quote(remote_path)} && "
            "module load python/3.10 && "
            "python -m venv myenv && "
            "source myenv/bin/activate && "
            "pip install --upgrade pip && "
            "pip install torch torchvision transformers pillow"
        )

        # shlex.quote keeps the whole chain a single `bash -c` argument even if
        # remote_path contains spaces or quotes (hand-written '...' would break).
        setup = subprocess.run(
            ["ssh", target, f"bash -l -c {shlex.quote(setup_env_commands)}"],
            text=True,
            shell=False
        )

        if setup.returncode != 0:
            print("❗ Virtual environment setup failed.")
            # ssh exits 255 when ssh itself fails (e.g. no connection); then the
            # setup never ran and there is nothing to clean up. Otherwise a
            # half-built myenv would pass `test -d` next time and be kept forever.
            if setup.returncode != 255:
                subprocess.run(["ssh", target, f"rm -rf {quoted_env}"])
            raise RuntimeError("Failed to set up remote virtual environment.")

        print("✅ Virtual environment created and packages installed.")
    else:
        print("✅ Virtual environment already exists. Skipping setup.")

In [8]:
def check_remote_model():
    """Upload ``local_model_path`` to ``{remote_path}/my_local_model`` if absent.

    Raises
    ------
    RuntimeError
        If ``scp`` fails. Any partially copied remote folder is removed
        first, so the next run re-uploads instead of skipping the upload.
    """
    import shlex

    remote_model_path = f"{remote_path}/my_local_model"
    print(f"🔍 Checking if model directory exists at {remote_model_path}...")

    target = f"{username}@{remote_host}"
    quoted_model_path = shlex.quote(remote_model_path)

    check_model = subprocess.run(
        ["ssh", target, f"test -d {quoted_model_path}"],
        capture_output=True
    )

    if check_model.returncode != 0:
        print("⚡ Model not found on ICE. Uploading local my_local_model/ folder...")

        upload = subprocess.run(
            ["scp", "-r", local_model_path, f"{target}:{remote_model_path}"],
            text=True,
            shell=False
        )

        if upload.returncode != 0:
            print("❗ Failed to upload model to ICE.")
            # A partial copy would pass `test -d` next time; the local copy
            # still exists, so removing the remote one is always safe.
            subprocess.run(["ssh", target, f"rm -rf {quoted_model_path}"])
            raise RuntimeError("Failed to upload model to ICE.")

        print("✅ Model uploaded to ICE HPC.")
    else:
        print("✅ Model already exists on ICE. Skipping upload.")

In [9]:
def upload_job_files(files=("job_script.sh", "run_model.py", "input.png")):
    """Copy the job inputs into ``remote_path`` on ICE with a single ``scp``.

    Parameters
    ----------
    files : sequence of str, optional
        Local files to upload. Defaults to the batch script, the model
        runner and the saved input image.

    Raises
    ------
    RuntimeError
        If ``scp`` exits non-zero; its stderr is printed first.
    """
    file_list = ", ".join(files)
    print(f"📤 Uploading {file_list} to ICE...")

    upload_result = subprocess.run(
        ["scp", *files, f"{username}@{remote_host}:{remote_path}/"],
        capture_output=True,
        text=True
    )

    if upload_result.returncode == 0:
        print(f"✅ {file_list} uploaded successfully.")
    else:
        print(f"❗ Failed to upload {file_list}. See error:")
        print(upload_result.stderr)
        raise RuntimeError("Failed to upload job files.")

In [10]:
def submit_job():
    """Submit ``job_script.sh`` with ``sbatch`` and record the Slurm job id.

    Sets the module-level ``job_id`` to the id parsed from sbatch's
    ``Submitted batch job <id>`` line.

    Raises
    ------
    RuntimeError
        If no job id appears in sbatch's output. ``job_id`` is left as
        ``None`` so later steps cannot act on a previous run's job.
    """
    global job_id
    import re
    import shlex

    job_id = None
    print("🚀 Submitting job...")

    result = subprocess.run(
        ["ssh", f"{username}@{remote_host}", f"cd {shlex.quote(remote_path)} && sbatch job_script.sh"],
        capture_output=True,
        text=True
    )

    # Print output for debugging
    print("STDOUT:")
    print(result.stdout)

    print("STDERR:")
    print(result.stderr)

    # Search for the id instead of taking the last token: if anything else is
    # printed after sbatch's line, split()[-1] would pick up the wrong word.
    match = re.search(r"Submitted batch job (\d+)", result.stdout)
    if match:
        job_id = match.group(1)
        print(f"🚀 Job ID submitted: {job_id}")
    else:
        print("❗ sbatch did not submit correctly. Please check error above.")
        raise RuntimeError("sbatch did not submit the job.")

In [11]:
def monitor_job(poll_interval=10, max_failures=3):
    """Block until the submitted Slurm job ``job_id`` leaves the queue.

    Polls ``squeue`` on ICE every ``poll_interval`` seconds; the job counts
    as finished once its id no longer appears among ``username``'s jobs.
    Returns immediately if no job id is set.

    Parameters
    ----------
    poll_interval : float, optional
        Seconds to wait between polls (default 10).
    max_failures : int, optional
        Consecutive failed ``squeue`` queries tolerated before giving up.

    Raises
    ------
    RuntimeError
        If ``squeue`` cannot be queried ``max_failures`` times in a row.
    """
    import shlex

    if not job_id:
        print("❗ No job ID available to monitor.")
        return

    print("⏳ Monitoring job status...")
    failures = 0
    while True:
        # -h drops the header and -o %i prints only job ids, so the id can be
        # matched exactly (grep would also match e.g. 1234 when looking for 123).
        queue_check = subprocess.run(
            ["ssh", f"{username}@{remote_host}", f"squeue -h -u {shlex.quote(username)} -o %i"],
            capture_output=True,
            text=True
        )

        if queue_check.returncode != 0:
            # A failed query (e.g. a dropped SSH connection) says nothing about
            # the job; treating its empty output as "done" would end monitoring early.
            failures += 1
            print(f"❗ Could not query job status ({failures}/{max_failures}):")
            print(queue_check.stderr)
            if failures >= max_failures:
                raise RuntimeError("Unable to query job status with squeue.")
        elif job_id not in queue_check.stdout.split():
            print("✅ Job completed!")
            break
        else:
            failures = 0
            print(f"⏳ Job still running...waiting {poll_interval} seconds...")
        time.sleep(poll_interval)


In [12]:
def download_results():
    """Copy ``output.txt`` and ``error.log`` from ``remote_path`` into the cwd.

    Local copies left by an earlier run are deleted first, so a failed
    download cannot make ``display_results`` show an old prediction. Only
    files that were downloaded successfully are then deleted on ICE, so a
    failed transfer can be retried without losing the job's output.
    """
    import shlex

    result_files = ("output.txt", "error.log")
    target = f"{username}@{remote_host}"

    print("📥 Downloading results...")
    downloaded = []
    for name in result_files:
        if os.path.isfile(name):
            os.remove(name)
        copy = subprocess.run(["scp", f"{target}:{remote_path}/{name}", "."])
        if copy.returncode == 0:
            downloaded.append(name)
        else:
            print(f"❗ Failed to download {name}.")

    if len(downloaded) == len(result_files):
        print("✅ Files downloaded.")
    if not downloaded:
        return

    print("🧹 Cleaning up remote output files...")
    remote_files = " ".join(shlex.quote(f"{remote_path}/{name}") for name in downloaded)
    subprocess.run(["ssh", target, f"rm -f {remote_files}"])


In [None]:
def display_results():
    """Print the contents of the downloaded ``output.txt`` prediction file.

    Prints a notice instead of raising when the file is missing (for
    example because the download step failed).
    """
    print("📄 Prediction Result:")
    try:
        # Decode explicitly rather than with the local locale's default;
        # assumes run_model.py writes UTF-8 -- TODO confirm. errors="replace"
        # keeps a stray byte from hiding the whole result.
        with open('output.txt', 'r', encoding='utf-8', errors='replace') as f:
            result_content = f.read()
    except FileNotFoundError:
        print("❗ Output file not found. Check for errors in the process.")
    else:
        print(result_content)

In [None]:
# Build and show the upload UI; clicking "Process Image" runs the whole pipeline.
setup_ui()

Label(value='⏳ Please upload an image file to process')

FileUpload(value=(), accept='image/*', description='Upload')

Button(button_style='success', description='Process Image', style=ButtonStyle())