# Lab1 (deployment): Provision resources for running the Progen2 model on AWS Batch

This notebook guides you through setting up the AWS Batch infrastructure needed to run parameterized Progen2 jobs. We'll create each resource step by step.

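#### Prerequisites
- IAM roles configured with appropriate permissions: a Batch service role matching `BatchServiceRole`, a Batch job role matching `BatchJobRole`, and the `EcsInstanceProfile` instance profile
- The Progen2 container image in ECR at `models/progen2:latest` (Step 4 waits for it to become available)
- A default VPC with at least one subnet in the current region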


## Step 1: Setup and Configuration

First, let's get our AWS account information and set up variables we'll use throughout the notebook.

In [None]:
import boto3
import json
import time
import datetime
from utils.iam_helper import IamHelper

##########################################################

# Resolve the region first: every URI and ARN below embeds it, and an unset
# region would otherwise show up as the literal string "None" inside them.
region = boto3.Session().region_name
if not region:
    raise RuntimeError(
        "No AWS region configured. Set AWS_DEFAULT_REGION or a default region "
        "in your AWS config before running this notebook."
    )

# Get AWS account information
sts_client = boto3.client('sts')
account_id = sts_client.get_caller_identity()['Account']

# Define S3 bucket and folder names
S3_BUCKET = f'workshop-data-{account_id}'
LAB1_FOLDER = 'lab1-progen'
LAB2_FOLDER = 'lab2-amplify'
LAB3_FOLDER = 'lab3-esmfold'

print(f"Account ID: {account_id}")
print(f"Region: {region}")
print(f"S3 Bucket: {S3_BUCKET}")

##########################################################

# Define URI of the Progen2 image registered in ECR
ECR_IMAGE_URI = f"{account_id}.dkr.ecr.{region}.amazonaws.com/models/progen2:latest"
print(f"ECR Image URI: {ECR_IMAGE_URI}")

# Retrieve ARNs of IAM roles required for provisioning Batch resources.
# The instance profile name is fixed; the two roles are looked up by name pattern.
iam_helper = IamHelper()
batch_service_role_arn = iam_helper.find_role_arn_by_pattern('BatchServiceRole')
instance_profile_arn = f"arn:aws:iam::{account_id}:instance-profile/EcsInstanceProfile"
job_role_arn = iam_helper.find_role_arn_by_pattern('BatchJobRole')

print()
print(f"BatchServiceRole ARN: {batch_service_role_arn}")
print(f"Instance Profile ARN: {instance_profile_arn}")
print(f"BatchJobRole ARN : {job_role_arn}")

# Fail fast here rather than with an opaque Batch API error in Step 3.
# NOTE(review): assumes find_role_arn_by_pattern returns a falsy value when no role matches — confirm.
missing_roles = [
    pattern
    for pattern, arn in (('BatchServiceRole', batch_service_role_arn), ('BatchJobRole', job_role_arn))
    if not arn
]
if missing_roles:
    raise RuntimeError(f"No IAM role found matching: {', '.join(missing_roles)}")


In [None]:
# Create AWS S3 bucket and local data folders for labs 1-3.
# `!` runs a shell command; IPython expands `$S3_BUCKET` / `$LABn_FOLDER` from the Python namespace.
# A failing `!` command only prints its error and does not stop the notebook, so re-running
# this cell after the bucket exists is presumably harmless — TODO confirm for your region.
!aws s3 mb s3://$S3_BUCKET
!mkdir -p data/$LAB1_FOLDER
!mkdir -p data/$LAB2_FOLDER
!mkdir -p data/$LAB3_FOLDER

## Step 2: Get VPC Information

We need to identify the VPC and subnets where our Batch compute environment will run. We'll use the default VPC.

In [None]:
ec2_client = boto3.client('ec2')

# Batch instances will be launched into the account's default VPC.
vpcs = ec2_client.describe_vpcs(Filters=[{'Name': 'isDefault', 'Values': ['true']}])
default_vpcs = vpcs['Vpcs']
if len(default_vpcs) == 0:
    raise Exception("No default VPC found. Please create one or specify a custom VPC.")

default_vpc_id = default_vpcs[0]['VpcId']
print(f"Default VPC ID: {default_vpc_id}")

# Every subnet of that VPC is offered to the compute environment.
vpc_filter = {'Name': 'vpc-id', 'Values': [default_vpc_id]}
subnets = ec2_client.describe_subnets(Filters=[vpc_filter])
subnet_ids = []
for subnet in subnets['Subnets']:
    subnet_ids.append(subnet['SubnetId'])

print(f"Found {len(subnet_ids)} subnets:")
for subnet_id in subnet_ids:
    print(f"  - {subnet_id}")

if len(subnet_ids) == 0:
    raise Exception("No subnets found in default VPC")

## Step 3: Create Batch Resources

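AWS Batch manages large-scale parallel processing by dynamically scheduling containerized workloads onto compute resources. In this step we create three resources: a compute environment, a job queue, and a job definition.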

### Step 3.1: Create Batch Compute Environment

The compute environment defines the compute resources (EC2 instances) that will run Progen2 jobs. We'll create a managed compute environment that automatically scales based on job demand.

In [None]:
batch_client = boto3.client('batch')

# Instances launched by Batch are attached to the VPC's built-in "default" security group.
security_groups = ec2_client.describe_security_groups(
    Filters=[
        {'Name': 'vpc-id', 'Values': [default_vpc_id]},
        {'Name': 'group-name', 'Values': ['default']},
    ]
)
default_sg_group = security_groups['SecurityGroups'][0]
default_sg_id = default_sg_group['GroupId']
print(f"Default Security Group ID: {default_sg_id}")

compute_env_name = 'progen2-batch-compute-env'

# c6i.xlarge EC2 fleet that scales between 0 and 50 vCPUs with queue demand.
compute_resources = {
    'type': 'EC2',
    'minvCpus': 0,
    'desiredvCpus': 0,
    'maxvCpus': 50,
    'instanceTypes': ['c6i.xlarge'],
    'subnets': subnet_ids,
    'securityGroupIds': [default_sg_id],
    'instanceRole': instance_profile_arn,
    'tags': {'Name': 'BatchComputeEnvironment', 'Purpose': 'Inference'},
}

try:
    response = batch_client.create_compute_environment(
        computeEnvironmentName=compute_env_name,
        type='MANAGED',
        state='ENABLED',
        serviceRole=batch_service_role_arn,
        computeResources=compute_resources,
    )
except batch_client.exceptions.ClientException as err:
    # Re-running the notebook is expected; only an "already exists" error is benign.
    if 'already exists' not in str(err):
        raise err
    print(f"Compute environment {compute_env_name} already exists")
else:
    print(f"Created compute environment: {response['computeEnvironmentName']}")
    print(f"   ARN: {response['computeEnvironmentArn']}")

compute_env_arn = f"arn:aws:batch:{region}:{account_id}:compute-environment/{compute_env_name}"

In [None]:
print("Waiting for compute environment to be ready...")
# Poll describe_compute_environments until the environment reaches VALID.
# The wait is bounded so a stuck (or never-created) environment cannot hang the notebook.
ce_poll_interval_s = 15
ce_max_wait_s = 20 * 60
ce_deadline = time.monotonic() + ce_max_wait_s
while True:
    response = batch_client.describe_compute_environments(
        computeEnvironments=[compute_env_name]
    )

    if response['computeEnvironments']:
        compute_env = response['computeEnvironments'][0]
        status = compute_env['status']
        print(f"Compute environment status: {status}")

        if status == 'VALID':
            break
        elif status in ['INVALID', 'DELETING', 'DELETED']:
            # statusReason explains e.g. a bad role or subnet, which is what the user needs to fix.
            reason = compute_env.get('statusReason', 'no reason reported')
            raise Exception(f"Compute environment failed with status: {status} ({reason})")

    if time.monotonic() >= ce_deadline:
        raise TimeoutError(
            f"Compute environment {compute_env_name} did not become VALID within {ce_max_wait_s} seconds"
        )
    time.sleep(ce_poll_interval_s)

print("Compute environment is ready")

### Step 3.2: Create Batch Job Queue

When you submit an AWS Batch job, you submit it to a particular job queue, where the job resides until it's scheduled onto a compute environment.
The job queue connects jobs to compute environments. Jobs submitted to this queue will run on the compute environment we just created.

In [None]:
job_queue_name = 'progen2-batch-job-queue'

# A single queue backed by the one compute environment created above.
queue_compute_order = [{'order': 1, 'computeEnvironment': compute_env_name}]

try:
    response = batch_client.create_job_queue(
        jobQueueName=job_queue_name,
        state='ENABLED',
        priority=1,
        computeEnvironmentOrder=queue_compute_order,
    )
except batch_client.exceptions.ClientException as err:
    # A queue left over from a previous run is fine; anything else is a real failure.
    if 'already exists' not in str(err):
        raise err
    print(f"Job queue {job_queue_name} already exists")
else:
    print(f"Created job queue: {response['jobQueueName']}")
    print(f"   ARN: {response['jobQueueArn']}")

In [None]:
print("Waiting for job queue to be ready...")
# Poll the queue's `status` (CREATING -> VALID), not its `state`: the queue is created
# with state='ENABLED', so `state` reads ENABLED immediately, before the queue is usable.
# The wait is bounded so a stuck queue cannot hang the notebook.
queue_poll_interval_s = 15
queue_max_wait_s = 10 * 60
queue_deadline = time.monotonic() + queue_max_wait_s
while True:
    response = batch_client.describe_job_queues(
        jobQueues=[job_queue_name]
    )

    if response['jobQueues']:
        queue = response['jobQueues'][0]
        status = queue['status']
        state = queue['state']
        print(f"Job queue status: {status} (state: {state})")

        if status in ['INVALID', 'DELETING', 'DELETED']:
            reason = queue.get('statusReason', 'no reason reported')
            raise Exception(f"Job queue failed with status: {status} ({reason})")
        if state == 'DISABLED':
            raise Exception("Job queue is disabled")
        if status == 'VALID':
            break

    if time.monotonic() >= queue_deadline:
        raise TimeoutError(
            f"Job queue {job_queue_name} did not become VALID within {queue_max_wait_s} seconds"
        )
    time.sleep(queue_poll_interval_s)

print("Job queue is ready")

### Step 3.3: Create Batch Job Definition

The job definition serves as a blueprint that defines the execution parameters for Progen2 jobs, specifying:

- Container image: The Docker image containing the Progen2 model and dependencies
- Resource allocation: vCPU and memory requirements for protein sequence generation workloads
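- Parameter configuration: `Ref::` placeholders for per-job inputs — the Hugging Face model ID, the S3 path of the input parameters, the batch ID, batch size and batch number, and the S3 output path
- Environment setup: the AWS region and S3 bucket name, passed as environment variables (AWS credentials come from the job IAM role, not from the environment)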

This standardized definition ensures consistent job execution across the AWS Batch environment while allowing flexibility for 
different protein engineering tasks.

Note the `Ref::` syntax for parameters - these will be replaced with actual values when you submit jobs.

In [None]:
job_definition_name = 'progen2-job-definition'

# register_job_definition always creates a NEW revision (1, 2, 3, ...) of the
# definition, so every re-run of this cell bumps the revision number.
try:
    response = batch_client.register_job_definition(
        jobDefinitionName=job_definition_name,
        type='container',
        containerProperties={
            'image': ECR_IMAGE_URI,
            # resourceRequirements replaces the deprecated top-level 'vcpus'/'memory' keys.
            'resourceRequirements': [
                {'type': 'VCPU', 'value': '2'},
                {'type': 'MEMORY', 'value': '4096'},  # MiB
            ],
            # Ref::<name> tokens are substituted from the parameters supplied at job submission.
            'command': [
                'Ref::hfModelId',
                'Ref::s3InputParamsPath',
                'Ref::batchId',
                'Ref::batchSize',
                'Ref::batchNumber',
                'Ref::s3OutputPath'
            ],
            'environment': [
                {'name': 'AWS_DEFAULT_REGION', 'value': region},
                {'name': 'S3_BUCKET', 'value': S3_BUCKET}
            ],
            'jobRoleArn': job_role_arn
        },
        retryStrategy={'attempts': 2},
        timeout={'attemptDurationSeconds': 3600}  # 1 hour timeout
    )
except Exception as e:
    # Later steps need a registered definition, so surface the failure instead of continuing.
    print(f"Error creating job definition: {e}")
    raise

print(f"Created job definition: {response['jobDefinitionName']}")
print(f"   ARN: {response['jobDefinitionArn']}")
print(f"   Revision: {response['revision']}")

# Use the ARN of the revision just registered; a hard-coded ':1' would point
# re-runs at the first (possibly stale) revision.
job_definition_arn = response['jobDefinitionArn']

## Step 4: Verify ECR Image Registration
Verify that the Progen2 Docker image is fully registered and available in ECR before proceeding to inference.

In [None]:
def _parse_ecr_image_uri(image_uri):
    """Split ``<registry>/<repository>[:<tag>]`` into ``(repository_name, image_tag)``.

    The tag defaults to ``'latest'`` when absent or empty. Only a ``:`` in the final
    path segment is treated as the tag separator, so a ``host:port`` registry is not
    mistaken for a tag.

    Raises:
        ValueError: if there is no repository path after the registry host.
    """
    _registry, _, path = image_uri.partition('/')
    repo_prefix, _, last_segment = path.rpartition('/')
    repo_leaf, _, tag = last_segment.partition(':')
    if not repo_leaf:
        raise ValueError("expected '<registry>/<repository>[:<tag>]'")
    repository_name = f"{repo_prefix}/{repo_leaf}" if repo_prefix else repo_leaf
    return repository_name, tag or 'latest'


def _find_ecr_image(ecr_client, repository_name, image_tag):
    """Look up ``repository_name:image_tag`` in ECR.

    Returns:
        ``(image_details, None)`` when the image exists, otherwise ``(None, reason)``
        with a human-readable reason. Other API errors propagate to the caller.
    """
    try:
        ecr_client.describe_repositories(repositoryNames=[repository_name])
    except ecr_client.exceptions.RepositoryNotFoundException:
        return None, f"Repository '{repository_name}' not found"
    print(f" Repository '{repository_name}' exists")

    try:
        images_response = ecr_client.describe_images(
            repositoryName=repository_name,
            imageIds=[{'imageTag': image_tag}],
        )
    except ecr_client.exceptions.ImageNotFoundException:
        return None, f"Image with tag '{image_tag}' not found"

    image_details = images_response.get('imageDetails') or []
    if not image_details:
        return None, f"Image tag '{image_tag}' not found in repository"
    return image_details[0], None


def _report_ecr_image(image_details, image_tag, image_uri):
    """Print size and push-time diagnostics for an image that was found."""
    image_size_mb = image_details.get('imageSizeInBytes', 0) / (1024 * 1024)
    pushed_at = image_details.get('imagePushedAt', 'Unknown')

    print(f" Image tag '{image_tag}' found")
    print(f" Image size: {image_size_mb:.2f} MB")
    print(f" Pushed at: {pushed_at}")

    # Heuristic: an image under 1 MB likely indicates an incomplete or corrupted push.
    if image_size_mb < 1:
        print(f" Warning: Image size is unusually small ({image_size_mb:.2f} MB)")
        print("This might indicate an incomplete or corrupted image")

    print("\n ECR image is fully registered and ready for use!")
    print(f"Image URI: {image_uri}")


def wait_for_ecr_image_availability(image_uri, check_interval=30, max_attempts=20, ecr_client=None):
    """
    Wait for ECR image to be fully registered and available for use.

    Args:
        image_uri: ``<account>.dkr.ecr.<region>.amazonaws.com/<repository>[:<tag>]``;
            the tag defaults to ``latest`` when omitted.
        check_interval: Seconds to sleep between attempts.
        max_attempts: Maximum number of lookups before giving up.
        ecr_client: Optional pre-built ECR client (e.g. for another region, or a
            stand-in for testing). A default ``boto3.client('ecr')`` is used when omitted.

    Returns:
        True as soon as both the repository and the tag are found.

    Raises:
        Exception: if the URI cannot be parsed, or the image is still unavailable
            after ``max_attempts`` attempts (the message includes the last reason).
    """
    try:
        repository_name, image_tag = _parse_ecr_image_uri(image_uri)
    except ValueError as e:
        raise Exception(f"Failed to parse ECR image URI: {image_uri}. Error: {e}")

    print("Checking ECR image availability...")
    print(f"Repository: {repository_name}")
    print(f"Tag: {image_tag}")
    print(f"Full URI: {image_uri}")
    print()

    if ecr_client is None:
        ecr_client = boto3.client('ecr')

    last_reason = None
    for attempt in range(1, max_attempts + 1):
        print(f"Attempt {attempt}/{max_attempts}: Checking image availability...")
        try:
            image_details, last_reason = _find_ecr_image(ecr_client, repository_name, image_tag)
        except Exception as e:
            # Unexpected API errors (throttling, transient network issues, ...) are retried too.
            image_details, last_reason = None, f"Error during attempt {attempt}: {e}"

        if image_details is not None:
            _report_ecr_image(image_details, image_tag, image_uri)
            return True

        print(f" {last_reason}")
        if attempt < max_attempts:
            print(f"Waiting {check_interval} seconds before retry...")
            time.sleep(check_interval)
            print()

    suffix = f": {last_reason}" if last_reason else ""
    raise Exception(f"ECR image verification timed out after {max_attempts} attempts{suffix}")

# Block until the Progen2 image is visible in ECR, then point the reader at the inference notebook.
intro_lines = (
    "Verifying Progen2 Docker image registration in ECR...",
    "This step ensures the image is fully pushed and available before running Batch jobs.",
    "",
)
for intro_line in intro_lines:
    print(intro_line)

try:
    wait_for_ecr_image_availability(ECR_IMAGE_URI)
except Exception as e:
    print(f"\n ECR image verification failed: {e}")
else:
    print("\n All deployment steps completed successfully!")
    print("Your AWS Batch environment is ready for Progen2 inference jobs.")
    print(f"You can now proceed to run the inference notebook: lab1-progen-on-batch-inference.ipynb")
