### 0. Install dependencies

In [228]:
%pip install -q --upgrade pip
%pip install -q --upgrade sagemaker boto3 awscli boto3 ipywidgets

Note: you may need to restart the kernel to use updated packages.
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
aiobotocore 2.7.0 requires botocore<1.31.65,>=1.31.16, but you have botocore 1.34.81 which is incompatible.[0m[31m
[0mNote: you may need to restart the kernel to use updated packages.


In [27]:
import json
from pathlib import Path
import os
from time import strftime
from functools import partial
import importlib

import utilities as u

import boto3
import sagemaker
from sagemaker.experiments.run import Run
from sagemaker.inputs import TrainingInput
from sagemaker.pytorch import PyTorch

In [2]:
# Display the installed SageMaker Python SDK version.
sagemaker.__version__

'2.214.0'

In [17]:
# Name of the SageMaker Experiment that the training run is recorded under.
EXPERIMENT_NAME = "hyenaDNA-pretraining-v2"

# A single boto3 session backs the SageMaker session used by the later cells.
boto_session = boto3.session.Session()
sagemaker_session = sagemaker.session.Session(boto_session=boto_session)

# Region and default bucket are taken from the SageMaker session itself.
REGION_NAME = sagemaker_session.boto_region_name
S3_BUCKET = sagemaker_session.default_bucket()

# IAM role used for the training job and the deployed model.
SAGEMAKER_EXECUTION_ROLE = sagemaker.session.get_execution_role(
    sagemaker_session=sagemaker_session
)
print(f"Assumed SageMaker role is {SAGEMAKER_EXECUTION_ROLE}")

INFO:botocore.credentials:Found credentials from IAM Role: BaseNotebookInstanceEc2InstanceRole


Assumed SageMaker role is arn:aws:iam::111918798052:role/Admin


### 1. Read the data from AWS HealthOmics

In [4]:
# ID of the AWS HealthOmics sequence store to read from.
seq_store_id = "4308389581"
# Save the store's description as JSON; the next cell parses this file.
!aws omics get-sequence-store --id {seq_store_id} > /tmp/seq-store.json

In [12]:
# Parse the saved sequence-store description and pull out the S3 location,
# the S3 access point ARN and the KMS key ARN.
with open("/tmp/seq-store.json") as seq_store_file:
    seq_store_info = json.load(seq_store_file)

s3_access = seq_store_info["s3Access"]
s3_uri = s3_access["s3Uri"]
s3_arn = s3_access["s3AccessPointArn"]
key_arn = seq_store_info["sseConfig"]["keyArn"]

# Show all three values as the cell output.
(s3_uri, s3_arn, key_arn)

('s3://111918798052-4308389-m7r4grkrg7nkpmf5swnjwf1iqsdieuse1b-s3alias/111918798052/sequenceStore/4308389581/',
 'arn:aws:s3:us-east-1:559620149354:accesspoint/111918798052-4308389581',
 'arn:aws:kms:us-east-1:559620149354:key/ef42c6a8-5692-4a6c-9a66-a2d1058a9a41')

For this notebook to access the objects behind the S3 access point above (`s3_uri`), you must add a policy
to this notebook's execution role (`SAGEMAKER_EXECUTION_ROLE`). The following cell prints the policy that
you should attach to this role:

In [15]:
# Statement 1: read access to objects, limited to requests that go through
# the sequence store's S3 access point.
s3_read_statement = {
    "Sid": "S3DirectAccess",
    "Effect": "Allow",
    "Action": ["s3:GetObject", "s3:ListBucket"],
    "Resource": "*",
    "Condition": {"StringEquals": {"s3:DataAccessPointArn": s3_arn}},
}

# Statement 2: permission to decrypt with the sequence store's KMS key.
kms_decrypt_statement = {
    "Sid": "DefaultSequenceStoreKMSDecrypt",
    "Effect": "Allow",
    "Action": "kms:Decrypt",
    "Resource": key_arn,
}

access_policy = {
    "Version": "2012-10-17",
    "Statement": [s3_read_statement, kms_decrypt_statement],
}
print(json.dumps(access_policy, indent=2))

{
  "Version": "2012-10-17",
  "Statement": [
    {
      "Sid": "S3DirectAccess",
      "Effect": "Allow",
      "Action": [
        "s3:GetObject",
        "s3:ListBucket"
      ],
      "Resource": "*",
      "Condition": {
        "StringEquals": {
          "s3:DataAccessPointArn": "arn:aws:s3:us-east-1:559620149354:accesspoint/111918798052-4308389581"
        }
      }
    },
    {
      "Sid": "DefaultSequenceStoreKMSDecrypt",
      "Effect": "Allow",
      "Action": "kms:Decrypt",
      "Resource": "arn:aws:kms:us-east-1:559620149354:key/ef42c6a8-5692-4a6c-9a66-a2d1058a9a41"
    }
  ]
}


#### 1.1 Load the data from S3 to the local disk

In [54]:
# Local directory that receives the downloaded read sets.
data_dir = Path("/tmp/data/mouse")
# Recursively copy everything under the store's readSet/ prefix into data_dir.
# s3_uri is interpolated as-is, so it is expected to end with "/".
!aws s3 cp --recursive {s3_uri}readSet/ {str(data_dir)}

download: s3://111918798052-4308389-m7r4grkrg7nkpmf5swnjwf1iqsdieuse1b-s3alias/111918798052/sequenceStore/4308389581/readSet/2677481661/chr8.fq.gz to ../../../tmp/data/mouse/2677481661/chr8.fq.gz
download: s3://111918798052-4308389-m7r4grkrg7nkpmf5swnjwf1iqsdieuse1b-s3alias/111918798052/sequenceStore/4308389581/readSet/2063114856/chr16.fq.gz to ../../../tmp/data/mouse/2063114856/chr16.fq.gz
download: s3://111918798052-4308389-m7r4grkrg7nkpmf5swnjwf1iqsdieuse1b-s3alias/111918798052/sequenceStore/4308389581/readSet/2043193651/chr17.fq.gz to ../../../tmp/data/mouse/2043193651/chr17.fq.gz
download: s3://111918798052-4308389-m7r4grkrg7nkpmf5swnjwf1iqsdieuse1b-s3alias/111918798052/sequenceStore/4308389581/readSet/1040759226/chr12.fq.gz to ../../../tmp/data/mouse/1040759226/chr12.fq.gz
download: s3://111918798052-4308389-m7r4grkrg7nkpmf5swnjwf1iqsdieuse1b-s3alias/111918798052/sequenceStore/4308389581/readSet/4019662486/chr5.fq.gz to ../../../tmp/data/mouse/4019662486/chr5.fq.gz
download: s3:/

#### 1.2 Uncompress the files

In [55]:
# Run the gunzip helper over every *.fq.gz file under data_dir; the original
# archives are requested to be deleted after conversion.
gunzip = partial(u.gunzip_file, suffix=".gz")
fastq_files = u.convert_directory(
    data_dir,
    suffix=".fq.gz",
    convertor=gunzip,
    delete_orig_file=True,
)

/tmp/data/mouse/1040759226/chr12.fq.gz -> /tmp/data/mouse/1040759226/chr12.fq
Deleted /tmp/data/mouse/1040759226/chr12.fq.gz
/tmp/data/mouse/2043193651/chr17.fq.gz -> /tmp/data/mouse/2043193651/chr17.fq
Deleted /tmp/data/mouse/2043193651/chr17.fq.gz
/tmp/data/mouse/2677481661/chr8.fq.gz -> /tmp/data/mouse/2677481661/chr8.fq
Deleted /tmp/data/mouse/2677481661/chr8.fq.gz
/tmp/data/mouse/2063114856/chr16.fq.gz -> /tmp/data/mouse/2063114856/chr16.fq
Deleted /tmp/data/mouse/2063114856/chr16.fq.gz
/tmp/data/mouse/4019662486/chr5.fq.gz -> /tmp/data/mouse/4019662486/chr5.fq
Deleted /tmp/data/mouse/4019662486/chr5.fq.gz
/tmp/data/mouse/4164737562/chr9.fq.gz -> /tmp/data/mouse/4164737562/chr9.fq
Deleted /tmp/data/mouse/4164737562/chr9.fq.gz
/tmp/data/mouse/4291333584/chr10.fq.gz -> /tmp/data/mouse/4291333584/chr10.fq
Deleted /tmp/data/mouse/4291333584/chr10.fq.gz
/tmp/data/mouse/4399289471/chr1.fq.gz -> /tmp/data/mouse/4399289471/chr1.fq
Deleted /tmp/data/mouse/4399289471/chr1.fq.gz
/tmp/data/mo

#### 1.3 Convert each FASTQ into an equivalent FASTA

In [56]:
# Reload the helper module so that edits made to it since the first import
# take effect, then convert every *.fq file under data_dir to FASTA. The
# FASTQ originals are requested to be deleted after conversion.
u = importlib.reload(u)
fasta_files = u.convert_directory(
    data_dir,
    suffix=".fq",
    convertor=u.convert_fastq_to_fasta,
    delete_orig_file=True,
)

/tmp/data/mouse/1040759226/chr12.fq -> /tmp/data/mouse/1040759226/chr12.fa
Deleted /tmp/data/mouse/1040759226/chr12.fq
/tmp/data/mouse/2043193651/chr17.fq -> /tmp/data/mouse/2043193651/chr17.fa
Deleted /tmp/data/mouse/2043193651/chr17.fq
/tmp/data/mouse/2677481661/chr8.fq -> /tmp/data/mouse/2677481661/chr8.fa
Deleted /tmp/data/mouse/2677481661/chr8.fq
/tmp/data/mouse/2063114856/chr16.fq -> /tmp/data/mouse/2063114856/chr16.fa
Deleted /tmp/data/mouse/2063114856/chr16.fq
/tmp/data/mouse/4019662486/chr5.fq -> /tmp/data/mouse/4019662486/chr5.fa
Deleted /tmp/data/mouse/4019662486/chr5.fq
/tmp/data/mouse/4164737562/chr9.fq -> /tmp/data/mouse/4164737562/chr9.fa
Deleted /tmp/data/mouse/4164737562/chr9.fq
/tmp/data/mouse/4291333584/chr10.fq -> /tmp/data/mouse/4291333584/chr10.fa
Deleted /tmp/data/mouse/4291333584/chr10.fq
/tmp/data/mouse/4399289471/chr1.fq -> /tmp/data/mouse/4399289471/chr1.fa
Deleted /tmp/data/mouse/4399289471/chr1.fq
/tmp/data/mouse/5719009676/chr4.fq -> /tmp/data/mouse/571900

#### 1.4 Rearrange the directory hierarchy to match what HyenaDNA needs

In [57]:
# Remove the readSet dir segments
for child in data_dir.rglob("**/*"):
    if child.is_file():
        path = str(child).split("/")
        new_path = path[:-2] + path[-1:]
        target = child.rename(Path("/".join(new_path)))
        print(f"Moved {child} to {target}")

Moved /tmp/data/mouse/1040759226/chr12.fa to /tmp/data/mouse/chr12.fa
Moved /tmp/data/mouse/2043193651/chr17.fa to /tmp/data/mouse/chr17.fa
Moved /tmp/data/mouse/2677481661/chr8.fa to /tmp/data/mouse/chr8.fa
Moved /tmp/data/mouse/2063114856/chr16.fa to /tmp/data/mouse/chr16.fa
Moved /tmp/data/mouse/4019662486/chr5.fa to /tmp/data/mouse/chr5.fa
Moved /tmp/data/mouse/4164737562/chr9.fa to /tmp/data/mouse/chr9.fa
Moved /tmp/data/mouse/4291333584/chr10.fa to /tmp/data/mouse/chr10.fa
Moved /tmp/data/mouse/4399289471/chr1.fa to /tmp/data/mouse/chr1.fa
Moved /tmp/data/mouse/5719009676/chr4.fa to /tmp/data/mouse/chr4.fa
Moved /tmp/data/mouse/4552296085/chr11.fa to /tmp/data/mouse/chr11.fa
Moved /tmp/data/mouse/4815912818/chrX.fa to /tmp/data/mouse/chrX.fa
Moved /tmp/data/mouse/5628991692/chr19.fa to /tmp/data/mouse/chr19.fa
Moved /tmp/data/mouse/6611795136/chr18.fa to /tmp/data/mouse/chr18.fa
Moved /tmp/data/mouse/7039739896/chr13.fa to /tmp/data/mouse/chr13.fa
Moved /tmp/data/mouse/6849506425

### 2. Training



#### 2.1 Define the training container

In [253]:
# ECR URI of the PyTorch 2.2.0 GPU training image (Python 3.10, CUDA 12.1,
# per the image tag), resolved for the region this notebook runs in.
pytorch_image_uri = f"763104351884.dkr.ecr.{REGION_NAME}.amazonaws.com/pytorch-training:2.2.0-gpu-py310-cu121-ubuntu20.04-sagemaker"
pytorch_image_uri

'763104351884.dkr.ecr.us-east-1.amazonaws.com/pytorch-training:2.2.0-gpu-py310-cu121-ubuntu20.04-sagemaker'

#### 2.2 Define the training job parameters

In [401]:
# Checkpoint identifier passed to the training script, and the base name
# used for the training job.
MODEL_ID = 'LongSafari/hyenadna-small-32k-seqlen-hf'
TRAINING_JOB_NAME = 'hyenaDNA-pretraining'

# Additional training parameters handed to the training script.
hyperparameters = dict(
    species="mouse",
    epochs=150,
    model_checkpoint=MODEL_ID,
    max_length=32_000,
    batch_size=4,
    learning_rate=6e-4,
    weight_decay=0.1,
    log_level="INFO",
    log_interval=100,
)


#### 2.3 Define Metrics to track


In [460]:
# (metric name, regex) pairs; each regex has one capture group holding the
# metric value as it appears in the training log output.
_metric_patterns = (
    ("epoch", "Epoch: ([0-9.]*)"),
    ("step", "Step: ([0-9.]*)"),
    ("train_loss", "Train Loss: ([0-9.e-]*)"),
    ("train_perplexity", "Train Perplexity: ([0-9.e-]*)"),
    ("eval_loss", "Eval Average Loss: ([0-9.e-]*)"),
    ("eval_perplexity", "Eval Perplexity: ([0-9.e-]*)"),
)

metric_definitions = [
    {"Name": metric_name, "Regex": metric_regex}
    for metric_name, metric_regex in _metric_patterns
]

#### 2.4 Define the tensorboard configurations to track the training results

In [403]:
from sagemaker.debugger import TensorBoardOutputConfig

# Directory inside the training container that is configured as the local
# TensorBoard output path.
LOG_DIR = "/opt/ml/output/tensorboard"

# S3 prefix for this job's outputs. S3 URIs always use "/" separators, so
# they are built with f-strings instead of os.path.join, whose separator
# depends on the operating system. The resulting value is unchanged on POSIX.
output_path = f"s3://{S3_BUCKET}/sagemaker-output/training/{TRAINING_JOB_NAME}"

tensorboard_output_config = TensorBoardOutputConfig(
    s3_output_path=f"{output_path}/tensorboard",
    container_local_output_path=LOG_DIR,
)

#### 2.4 Define Estimator

In [404]:
# All estimator settings collected in one place: the training script and
# image, the instance to run on, tracking (metrics, tags, TensorBoard) and
# the torch_distributed launcher.
estimator_settings = dict(
    base_job_name=TRAINING_JOB_NAME,
    entry_point="train_hf_accelerate.py",
    source_dir="scripts/",
    instance_type="ml.g5.12xlarge",
    instance_count=1,
    image_uri=pytorch_image_uri,
    role=SAGEMAKER_EXECUTION_ROLE,
    hyperparameters=hyperparameters,
    metric_definitions=metric_definitions,
    sagemaker_session=sagemaker_session,
    distribution={"torch_distributed": {"enabled": True}},
    tags=[{"Key": "project", "Value": "genomics-model-pretraining"}],
    keep_alive_period_in_seconds=1800,
    tensorboard_output_config=tensorboard_output_config,
)

hyenaDNA_estimator = PyTorch(**estimator_settings)


#### 2.5 Start Training with Distributed Data Parallel

In [405]:
# `data_uri` was not defined by any earlier cell, so on a fresh kernel this
# cell failed with NameError. Upload the prepared FASTA files to the default
# bucket and point the "data" channel at the uploaded prefix.
# NOTE(review): the files are placed under "<prefix>/<data_dir.name>/" so the
# channel contains a "mouse/" folder matching the "species" hyperparameter.
# Confirm this is the layout train_hf_accelerate.py expects.
data_prefix = f"{TRAINING_JOB_NAME}/data"
sagemaker_session.upload_data(
    path=str(data_dir),
    bucket=S3_BUCKET,
    key_prefix=f"{data_prefix}/{data_dir.name}",
)
data_uri = f"s3://{S3_BUCKET}/{data_prefix}"

# Launch the job inside an Experiments run. wait=False returns as soon as the
# job has been created instead of blocking until training finishes.
with Run(
    experiment_name=EXPERIMENT_NAME,
    sagemaker_session=sagemaker_session,
) as run:
    hyenaDNA_estimator.fit(
        {
            "data": TrainingInput(
                s3_data=data_uri, input_mode="File"
            ),
        },
        wait=False,
    )


INFO:sagemaker:Creating training-job with name: hyenaDNA-pretraining-2024-04-11-01-11-51-388


In [461]:
# Name of the job that the estimator launched most recently; later cells use
# it to locate the TensorBoard data and the trained model.
training_job_name = hyenaDNA_estimator.latest_training_job.name
training_job_name

'hyenaDNA-pretraining-2024-04-11-01-11-51-388'

### 5. Training Results 

* During training, the results were written to TensorBoard. You can view them in the SageMaker TensorBoard application. Execute the following cell to get a link to the TensorBoard application.

In [462]:
from sagemaker.interactive_apps.tensorboard import TensorBoardApp

# Fallback profile name, used only when the Studio metadata file does not
# name a user profile. Change this to your own Studio user profile.
DEFAULT_USER_PROFILE = "shamika"

# Read the Studio domain (and, when available, the user profile) from the
# app metadata instead of relying on a hard-coded user name.
with open("/opt/ml/metadata/resource-metadata.json", "r") as f:
    app_metadata = json.load(f)

sm_domain_id = app_metadata["DomainId"]
# NOTE(review): assumes the profile is stored under "UserProfileName" when
# present. Confirm against the metadata file of your Studio version.
sm_user_profile_name = app_metadata.get("UserProfileName", DEFAULT_USER_PROFILE)
user_profile = sm_user_profile_name

# Build a presigned URL that opens TensorBoard for this training job.
# The returned URL contains an auth token, so do not share or commit the
# output of this cell.
tb_app = TensorBoardApp(REGION_NAME)
tb_app.get_app_url(
    training_job_name=training_job_name,
    create_presigned_domain_url=True,
    domain_id=sm_domain_id,
    user_profile_name=user_profile,
    open_in_default_web_browser=False,
    optional_create_presigned_url_kwargs={}
)



'https://studio-d-xgpxwyumgsdh.studio.us-east-1.sagemaker.aws/auth?token=<REDACTED>'

### 6. Deploy the trained model to a real-time endpoint

In [321]:
from sagemaker.pytorch.model import PyTorchModel
from sagemaker.serializers import JSONSerializer
from sagemaker.estimator import Estimator

# To deploy the model from an earlier run instead, set training_job_name to
# that job's name before attaching, e.g.
#   training_job_name = "hyenaDNA-pretraining-2024-04-06-06-23-26-412"
# The notebook's own session is passed so the lookup uses the same
# credentials and region as the rest of the notebook.
attached_estimator = Estimator.attach(
    training_job_name, sagemaker_session=sagemaker_session
)

# S3 location of the trained model archive.
model_data = attached_estimator.model_data
model_data


2024-04-06 17:23:11 Starting - Starting the training job
2024-04-06 17:23:11 Pending - Preparing the instances for training
2024-04-06 17:23:11 Downloading - Downloading the training image
2024-04-06 17:23:11 Training - Training image download completed. Training in progress.
2024-04-06 17:23:11 Uploading - Uploading generated training model
2024-04-06 17:23:11 Completed - Instances not retained as a result of warmpool resource limits being exceeded


's3://sagemaker-us-east-1-111918798052/hyenaDNA-pretraining-2024-04-06-06-23-26-412/output/model.tar.gz'

In [473]:
# Deploy the model to create a real-time endpoint
endpoint_name = 'hyenaDNA-mouse-pretrained-ep'  
pytorch_deployment_uri = f"763104351884.dkr.ecr.{REGION}.amazonaws.com/pytorch-inference:2.2.0-gpu-py310-cu118-ubuntu20.04-sagemaker"

hyenaDNAModel = PyTorchModel(
    model_data=model_data,
    role=SAGEMAKER_EXECUTION_ROLE,
    image_uri=pytorch_deployment_uri,
    entry_point="inference.py",
    source_dir="scripts/",
    sagemaker_session=sagemaker_session,
    name=endpoint_name,
    env = {
        'MMS_MAX_REQUEST_SIZE': '2000000000',
        'MMS_MAX_RESPONSE_SIZE': '2000000000',
        'MMS_DEFAULT_RESPONSE_TIMEOUT': '900',
        'TS_MAX_RESPONSE_SIZE':'2000000000',
        'TS_MAX_REQUEST_SIZE':'2000000000',
    }
)

In [479]:
real_time_endpoint_name = "hyenaDNA-mouse-pretrained-real-ep-v8"

# Environment variables handed to deploy().
# NOTE(review): confirm that deploy() actually applies `env`. If it does not,
# only the `env` given when the model object was constructed takes effect.
env = dict(
    SAGEMAKER_MODEL_SERVER_TIMEOUT='7200',
    TS_MAX_RESPONSE_SIZE='2000000000',
    TS_MAX_REQUEST_SIZE='2000000000',
    MMS_MAX_RESPONSE_SIZE='2000000000',
    MMS_MAX_REQUEST_SIZE='2000000000',
)

# Create the real-time endpoint on a single GPU instance.
deploy_settings = dict(
    initial_instance_count=1,
    instance_type="ml.g5.8xlarge",
    endpoint_name=real_time_endpoint_name,
    env=env,
)
realtime_predictor = hyenaDNAModel.deploy(**deploy_settings)

INFO:sagemaker:Repacking model artifact (s3://sagemaker-us-east-1-111918798052/hyenaDNA-pretraining-2024-04-06-06-23-26-412/output/model.tar.gz), script artifact (scripts/), and dependencies ([]) into single tar.gz file located at s3://sagemaker-us-east-1-111918798052/hyenaDNA-mouse-pretrained-ep-v8/model.tar.gz. This may take some time depending on model size...
INFO:sagemaker:Creating model with name: hyenaDNA-mouse-pretrained-ep-v8
INFO:sagemaker:Creating endpoint-config with name hyenaDNA-mouse-pretrained-real-ep-v8
INFO:sagemaker:Creating endpoint with name hyenaDNA-mouse-pretrained-real-ep-v8


-----------!

### 7. Test the realtime endpoint



In [None]:
# Imported here so this cell runs on its own: JSONDeserializer was never
# imported anywhere in the notebook.
from sagemaker.deserializers import JSONDeserializer
from sagemaker.serializers import JSONSerializer

# NOTE(review): `sample_genome_data` is not defined anywhere in this
# notebook. Define it (a sequence of inputs in the format scripts/inference.py
# expects) before running this cell.
data = [sample_genome_data[0]]  # fixed: the closing bracket was missing

# Send and receive JSON payloads.
realtime_predictor.serializer = JSONSerializer()
realtime_predictor.deserializer = JSONDeserializer()
realtime_predictor.predict(data=data)

### 8. Cleanup



In [None]:
# Delete the SageMaker model as well as the endpoint so nothing is left
# behind. The model is deleted first because the predictor resolves its model
# names through the endpoint configuration, which delete_endpoint() removes.
realtime_predictor.delete_model()
realtime_predictor.delete_endpoint()