# AWS-RoseTTAFold

## I. Introduction

This notebook runs the [RoseTTAFold](https://www.ipd.uw.edu/2021/07/rosettafold-accurate-protein-structure-prediction-accessible-to-all/) algorithm on AWS. RoseTTAFold was developed by Minkyung Baek et al. and is described in [M. Baek et al., Science 10.1126/science.abj8754 2021](https://www.ipd.uw.edu/wp-content/uploads/2021/07/Baek_etal_Science2021_RoseTTAFold.pdf).

<img src="img/RF_workflow.png" alt="RoseTTAFold Workflow" width="800px" />

The AWS workflow depends on an AWS Batch compute environment.

<img src="img/AWS-RoseTTAFold-arch.png" alt="AWS-RoseTTAFold Architecture" width="800px" />

## II. Environment setup

In [None]:
## Install dependencies
%pip install -q -q -r requirements.txt

In [None]:
## Import helper functions from rfutils/rfutils.py
from rfutils import rfutils

## Load additional dependencies
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import boto3
import glob
import json
import pandas as pd
import sagemaker

pd.set_option("display.max_colwidth", None)

# Get service clients
session = boto3.session.Session()
sm_session = sagemaker.session.Session()
region = session.region_name
role = sagemaker.get_execution_role()
s3 = boto3.client("s3", region_name=region)
account_id = boto3.client("sts").get_caller_identity().get("Account")

bucket = sm_session.default_bucket()
print(f"S3 bucket name is {bucket}")

## III. Input Protein Sequence

Enter a protein sequence manually

In [None]:
seq = SeqRecord(
    Seq("MKQHKAMIVALIVICITAVVAALVTRKDLCEVHIRTGQTEVAVF"),
    id="YP_025292.1",
    name="HokC",
    description="toxic membrane protein, small",
)

Or read the sequence from a FASTA file. Running this cell replaces the `seq` defined above.

In [None]:
seq = SeqIO.read("data/T1078.fa", "fasta")

In [None]:
print(f"Protein sequence for analysis is \n{seq}")
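
Before submitting a sequence, it can help to check that it contains only the 20 standard amino acid letters. The following is a minimal sketch of such a check; `find_nonstandard_residues` is a hypothetical helper, not part of `rfutils`, and the source does not say how the pipeline treats ambiguity codes such as `X` or `B`.

```python
# Hypothetical helper for illustration; not part of rfutils.
STANDARD_AMINO_ACIDS = set("ACDEFGHIKLMNPQRSTVWY")


def find_nonstandard_residues(sequence: str) -> list:
    """Return (1-based position, letter) for every non-standard residue."""
    return [
        (position, letter)
        for position, letter in enumerate(sequence.upper(), start=1)
        if letter not in STANDARD_AMINO_ACIDS
    ]


# The HokC sequence entered manually above.
hokc = "MKQHKAMIVALIVICITAVVAALVTRKDLCEVHIRTGQTEVAVF"
problems = find_nonstandard_residues(hokc)
print(f"Non-standard residues: {problems}")
```

In the notebook, the same check could be applied to `str(seq.seq)`.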

## IV. Submit RoseTTAFold Jobs

### Generate Job Name

In [None]:
job_name = rfutils.create_job_name(seq.id)
print(f"Automatically-generated job name is: {job_name}")
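
The implementation of `rfutils.create_job_name` is not shown here. As a rough sketch of what a job-name generator has to handle, the hypothetical `make_job_name` below combines a sanitized sequence ID with a UTC timestamp. It assumes the AWS Batch naming rules: at most 128 characters, only letters, numbers, hyphens, and underscores, and an alphanumeric first character.

```python
import re
from datetime import datetime, timezone
from typing import Optional

MAX_BATCH_JOB_NAME_LENGTH = 128


def make_job_name(sequence_id: str, now: Optional[datetime] = None) -> str:
    """Hypothetical stand-in for a job-name generator; not the rfutils code."""
    now = now or datetime.now(timezone.utc)
    suffix = now.strftime("%Y%m%dT%H%M%S")
    # Replace every character AWS Batch does not allow with an underscore.
    prefix = re.sub(r"[^A-Za-z0-9_-]", "_", sequence_id)
    # The first character of a Batch job name must be alphanumeric.
    if not prefix or not prefix[0].isalnum():
        prefix = "job" + prefix
    # Trim the prefix so that prefix + "_" + suffix fits within the limit.
    prefix = prefix[: MAX_BATCH_JOB_NAME_LENGTH - len(suffix) - 1]
    return f"{prefix}_{suffix}"


example_name = make_job_name("YP_025292.1", datetime(2021, 9, 1, 12, 0, 0))
print(example_name)
```

The timestamp suffix keeps repeated runs on the same sequence from colliding under one S3 prefix.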

### Upload fasta file to S3

In [None]:
input_uri = rfutils.upload_fasta_to_s3(seq, bucket, job_name)
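
To illustrate what an upload step like this involves, here is a minimal sketch that formats a FASTA record and writes it to S3 with `put_object`. The helper names, the `example-job/input.fa` key layout, and the `RecordingS3Client` stand-in are assumptions made for illustration; the actual behavior of `rfutils.upload_fasta_to_s3` may differ.

```python
def format_fasta(seq_id: str, description: str, sequence: str, width: int = 60) -> str:
    """Render one record as FASTA text with fixed-width sequence lines."""
    header = f">{seq_id} {description}".rstrip()
    lines = [sequence[i : i + width] for i in range(0, len(sequence), width)]
    return "\n".join([header, *lines]) + "\n"


def upload_text(s3_client, bucket: str, key: str, text: str) -> str:
    """Upload text with the S3 PutObject call and return its S3 URI."""
    s3_client.put_object(Bucket=bucket, Key=key, Body=text.encode("utf-8"))
    return f"s3://{bucket}/{key}"


class RecordingS3Client:
    """Offline stand-in for boto3.client("s3") so the sketch runs without AWS."""

    def __init__(self):
        self.calls = []

    def put_object(self, **kwargs):
        self.calls.append(kwargs)
        return {}


fasta_text = format_fasta(
    "YP_025292.1",
    "toxic membrane protein, small",
    "MKQHKAMIVALIVICITAVVAALVTRKDLCEVHIRTGQTEVAVF",
)
client = RecordingS3Client()
# The "<job name>/input.fa" key layout is an assumption, not taken from rfutils.
uri = upload_text(client, "example-bucket", "example-job/input.fa", fasta_text)
print(uri)
```

With real credentials, passing `boto3.client("s3")` in place of `RecordingS3Client()` would perform the actual upload.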

### Submit jobs to AWS Batch queues

Get the names of the AWS Batch resources deployed in your account.

In [None]:
batch_resources = rfutils.get_rosettafold_batch_resources(region=region)
batch_resources

In [None]:
two_step_response = rfutils.submit_2_step_job(
    bucket=bucket,
    job_name=job_name,
    data_prep_input_file="input.fa",
    data_prep_job_definition=batch_resources["dataPrepJobDefinition"][0],
    data_prep_queue=batch_resources["dataPrepJobQueue"][0],
    data_prep_cpu=16,
    data_prep_mem=60,
    predict_job_definition=batch_resources["predictJobDefinition"][0],
    predict_queue=batch_resources["predictJobQueue"][0],
    predict_cpu=24,
    predict_mem=90,
    predict_gpu=1,
)
data_prep_jobId = two_step_response[0]["jobId"]
predict_jobId = two_step_response[1]["jobId"]
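
The two-step submission relies on the prediction job waiting for the data prep job. One way to express that with the AWS Batch `submit_job` API is the `dependsOn` parameter, sketched below. The helper names and the `FakeBatchClient` stand-in are hypothetical, and the conversion of the values 60 and 90 to MiB assumes they are GiB, which the notebook does not state; `rfutils.submit_2_step_job` may be implemented differently.

```python
def build_submit_request(job_name, queue, definition, vcpus, memory_mib, gpus=0):
    """Build keyword arguments for the AWS Batch SubmitJob call."""
    requirements = [
        {"type": "VCPU", "value": str(vcpus)},
        {"type": "MEMORY", "value": str(memory_mib)},  # Batch expects MiB
    ]
    if gpus:
        requirements.append({"type": "GPU", "value": str(gpus)})
    return {
        "jobName": job_name,
        "jobQueue": queue,
        "jobDefinition": definition,
        "containerOverrides": {"resourceRequirements": requirements},
    }


def submit_two_steps(batch_client, first_request, second_request):
    """Submit two jobs so that the second starts only after the first succeeds."""
    first = batch_client.submit_job(**first_request)
    chained = dict(second_request, dependsOn=[{"jobId": first["jobId"]}])
    second = batch_client.submit_job(**chained)
    return first, second


class FakeBatchClient:
    """Offline stand-in for boto3.client("batch"); records each request."""

    def __init__(self):
        self.requests = []

    def submit_job(self, **request):
        self.requests.append(request)
        return {"jobName": request["jobName"], "jobId": f"job-{len(self.requests)}"}


batch = FakeBatchClient()
# Queue and job definition names below are placeholders.
data_prep = build_submit_request("demo", "cpu-queue", "data-prep-def", 16, 60 * 1024)
predict = build_submit_request("demo", "gpu-queue", "predict-def", 24, 90 * 1024, gpus=1)
responses = submit_two_steps(batch, data_prep, predict)
print(responses)
```

Until the first job succeeds, a dependent job submitted this way stays in the `PENDING` state.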

## V. Check Status of Data Prep and Prediction Jobs

In [None]:
rfutils.get_rf_job_info(
    cpu_queue=batch_resources["dataPrepJobQueue"][0],
    gpu_queue=batch_resources["predictJobQueue"][0],
    hrs_in_past=3,
)

## VI. View Data Prep Results

Pause while the data prep job starts up

In [None]:
rfutils.wait_for_job_start(data_prep_jobId)
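
Waiting for a job to start usually means polling its status until it leaves the pre-start states. The sketch below shows that pattern with a hypothetical `wait_until_started` function; it is not the implementation of `rfutils.wait_for_job_start`. The status source is passed in as a callable so the example runs without AWS access.

```python
import time

# AWS Batch job states that precede RUNNING.
PRE_START_STATES = {"SUBMITTED", "PENDING", "RUNNABLE", "STARTING"}


def wait_until_started(get_status, poll_seconds=10.0, max_polls=360, sleep=time.sleep):
    """Poll get_status() until the job is running or has finished."""
    for _ in range(max_polls):
        status = get_status()
        if status not in PRE_START_STATES:
            return status  # RUNNING, SUCCEEDED, or FAILED
        sleep(poll_seconds)
    raise TimeoutError(f"Job did not start after {max_polls} polls")


# Offline demonstration with a scripted sequence of states and no real sleeping.
scripted = iter(["SUBMITTED", "RUNNABLE", "STARTING", "RUNNING"])
final_status = wait_until_started(lambda: next(scripted), sleep=lambda seconds: None)
print(final_status)
```

Against a real queue, `get_status` could be `lambda: batch.describe_jobs(jobs=[job_id])["jobs"][0]["status"]`, where `batch` is a `boto3.client("batch")` instance.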

Get the logs for the data prep job. Run this cell multiple times to see how the job progresses.

In [None]:
data_prep_logStreamName = rfutils.get_batch_job_info(data_prep_jobId)["logStreamName"]
rfutils.get_batch_logs(data_prep_logStreamName).tail(n=5)
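
AWS Batch jobs write their container output to CloudWatch Logs, by default in the `/aws/batch/job` log group. As a rough sketch of how the last few log lines could be fetched with the `get_log_events` API, consider the hypothetical `tail_log_messages` helper below. It is not the `rfutils.get_batch_logs` implementation, and `FakeLogsClient` only imitates the response shape so the example runs offline.

```python
def tail_log_messages(logs_client, log_stream_name, n=5, log_group_name="/aws/batch/job"):
    """Return the messages of the most recent n events in a log stream."""
    response = logs_client.get_log_events(
        logGroupName=log_group_name,
        logStreamName=log_stream_name,
        limit=n,
        startFromHead=False,  # ask for the latest events rather than the earliest
    )
    return [event["message"] for event in response["events"]]


class FakeLogsClient:
    """Offline stand-in for boto3.client("logs") with a fixed set of events."""

    def __init__(self, messages):
        self._events = [
            {"timestamp": index, "message": message}
            for index, message in enumerate(messages)
        ]

    def get_log_events(self, logGroupName, logStreamName, limit, startFromHead):
        return {"events": self._events[-limit:]}


fake_logs = FakeLogsClient(["step 1", "step 2", "step 3"])
latest = tail_log_messages(fake_logs, "example-stream", n=2)
print(latest)
```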

Retrieve and Display Multiple Sequence Alignment (MSA) Results

In [None]:
rfutils.display_msa(data_prep_jobId, bucket)

## VII. View Prediction Results

Pause while the predict job starts up

In [None]:
rfutils.wait_for_job_start(predict_jobId)

Get the logs for the prediction job. Run this cell multiple times to see how the job progresses.

In [None]:
predict_logStreamName = rfutils.get_batch_job_info(predict_jobId)["logStreamName"]
rfutils.get_batch_logs(predict_logStreamName).tail(n=5)

## VIII. View Job Metrics

In [None]:
metrics = rfutils.get_rf_job_metrics(job_name, bucket, region)

print(f'Number of sequences in MSA: {metrics["DATA_PREP"]["MSA_COUNT"]}')
print(f'Number of templates: {metrics["DATA_PREP"]["TEMPLATE_COUNT"]}')
print(f'MSA duration (sec): {metrics["DATA_PREP"]["MSA_DURATION"]}')
print(f'SS duration (sec): {metrics["DATA_PREP"]["SS_DURATION"]}')
print(f'Template search duration (sec): {metrics["DATA_PREP"]["TEMPLATE_DURATION"]}')
print(
    f'Total data prep duration (sec): {metrics["DATA_PREP"]["TOTAL_DATA_PREP_DURATION"]}'
)
print(f'Total predict duration (sec): {metrics["PREDICT"]["TOTAL_PREDICT_DURATION"]}')

## IX. Retrieve and Display Predicted Structure

In [None]:
rfutils.display_structure(predict_jobId, bucket, vmin=0.5, vmax=0.9)