# AWS-RoseTTAFold

## Introduction

This notebook runs the [RoseTTAFold](https://www.ipd.uw.edu/2021/07/rosettafold-accurate-protein-structure-prediction-accessible-to-all/) algorithm, developed by Minkyung Baek et al. and described in [M. Baek et al., Science 10.1126/science.abj8754 (2021)](https://www.ipd.uw.edu/wp-content/uploads/2021/07/Baek_etal_Science2021_RoseTTAFold.pdf), on AWS. The AWS workflow depends on an AWS Batch compute environment.

![RoseTTAFold Network Architecture](img/RF_workflow.png "RoseTTAFold Network Architecture")

![AWS Architecture](img/AWS_arch.png "AWS Architecture")

## Environment setup

In [None]:
## Install dependencies
!pip install -r requirements.txt

In [None]:
## Import helper functions at src/rfutils.py
from src import rfutils

## Load additional dependencies
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
import boto3
import glob
import json
import pandas as pd
import sagemaker

In [None]:
pd.set_option('max_colwidth', None)

In [None]:
# Get service clients
session = boto3.session.Session()
sm_session = sagemaker.session.Session()
region = session.region_name
role = sagemaker.get_execution_role()
s3 = boto3.client('s3', region_name=region)
account_id = boto3.client('sts').get_caller_identity().get('Account')

Define an S3 bucket (or use the default SageMaker bucket)

In [None]:
bucket = sm_session.default_bucket()

## Input Protein Sequence

Provide the path to a .fasta file

In [None]:
seq = SeqIO.read("data/T1078.fa", "fasta")

Alternatively, enter a protein sequence manually

In [None]:
seq = SeqRecord(
    Seq("MKQHKAMIVALIVICITAVVAALVTRKDLCEVHIRTGQTEVAVF"),
    id="YP_025292.1",
    name="HokC",
    description="toxic membrane protein, small",
)

In [None]:
print(f"Protein sequence for analysis is \n{seq}")
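
Before submitting, it can help to confirm that the sequence contains only residue codes the pipeline is likely to accept. The sketch below is not part of `rfutils`: `find_invalid_residues` is a hypothetical helper, and it assumes that only the 20 standard one-letter amino acid codes are valid input.

```python
# Hypothetical pre-submission check; not part of rfutils.
# Assumption: only the 20 standard amino acid codes are accepted.
STANDARD_AMINO_ACIDS = frozenset("ACDEFGHIKLMNPQRSTVWY")


def find_invalid_residues(sequence):
    """Return (position, character) pairs for non-standard residues.

    Positions are 1-based, matching the usual residue numbering.
    """
    return [
        (index, char)
        for index, char in enumerate(str(sequence).upper(), start=1)
        if char not in STANDARD_AMINO_ACIDS
    ]
```

Calling `find_invalid_residues(seq.seq)` should return an empty list for a clean sequence; any pairs it does return point to the residues worth checking before a job is queued.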

## Submit RoseTTAFold Job

Generate Job Name

In [None]:
job_name = rfutils.get_job_name(seq.id)
print(f"Automatically-generated job name is: {job_name}")
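
AWS Batch job names may contain only letters, numbers, hyphens, and underscores, must start with an alphanumeric character, and are limited to 128 characters, so an ID such as `YP_025292.1` needs sanitizing. The following is a minimal sketch of one way to build such a name; `make_job_name` is a hypothetical helper, and `rfutils.get_job_name` may well use a different scheme.

```python
import re
from datetime import datetime, timezone


def make_job_name(seq_id, now=None):
    """Build a Batch-safe job name from a sequence ID plus a UTC timestamp.

    Hypothetical helper; rfutils.get_job_name may use a different scheme.
    """
    now = now or datetime.now(timezone.utc)
    # Replace anything other than letters, digits, hyphens, and underscores.
    safe_id = re.sub(r"[^A-Za-z0-9_-]", "_", seq_id)
    # The first character must be alphanumeric, so strip leading separators.
    safe_id = safe_id.lstrip("_-") or "seq"
    suffix = now.strftime("%Y%m%dT%H%M%S")
    # Truncate the ID so the whole name fits in 128 characters.
    return f"{safe_id[:128 - len(suffix) - 1]}-{suffix}"
```

The timestamp suffix keeps names unique when the same sequence is submitted more than once.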

Upload fasta file to S3

In [None]:
input_uri = rfutils.upload_fasta_to_s3(seq, bucket, job_name)
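
For readers curious about what the upload step involves, here is a rough sketch using the S3 client's `put_object` call. `upload_fasta_text` is a hypothetical helper, and the key layout `{job_name}/input.fa` is an assumption inferred from the `data_prep_input_file="input.fa"` argument used at submission; the actual `rfutils.upload_fasta_to_s3` may differ.

```python
def upload_fasta_text(s3_client, bucket, job_name, seq_id, sequence,
                      file_name="input.fa"):
    """Write a single-record FASTA object and return its S3 URI.

    Hypothetical helper; the {job_name}/{file_name} key layout is an
    assumption, not confirmed rfutils behavior.
    """
    key = f"{job_name}/{file_name}"
    body = f">{seq_id}\n{sequence}\n".encode("utf-8")
    # put_object is the standard boto3 S3 client call for small uploads.
    s3_client.put_object(Bucket=bucket, Key=key, Body=body)
    return f"s3://{bucket}/{key}"
```

With the clients defined earlier, this could be called as `upload_fasta_text(s3, bucket, job_name, seq.id, str(seq.seq))`.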

Submit job to AWS Batch queue

In [None]:
two_step_response = rfutils.submit_2_step_job(
    bucket=bucket,
    job_name=job_name,
    data_prep_input_file="input.fa",
    data_prep_job_definition="AWS-RoseTTAFold-CPU",
    data_prep_queue="AWS-RoseTTAFold-CPU",
    data_prep_cpu=24,
    data_prep_mem=80,
    predict_job_definition="AWS-RoseTTAFold-GPU",
    predict_queue="AWS-RoseTTAFold-GPU",
    predict_cpu=24,
    predict_mem=80,
    predict_gpu=2
)
jobId = two_step_response[0]["jobId"]
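
The two-step submission chains a CPU data-preparation job and a GPU prediction job. One plausible way to express that with the boto3 Batch client is to pass `dependsOn` to the second `submit_job` call. The sketch below only assembles the keyword arguments; `build_submit_kwargs` is a hypothetical helper, and it assumes the memory values are given in GiB, which is not confirmed by the notebook.

```python
def build_submit_kwargs(job_name, queue, definition, cpu, mem_gib,
                        gpu=0, depends_on=None):
    """Assemble keyword arguments for the boto3 Batch client's submit_job.

    Hypothetical helper. Assumption: memory is given in GiB, whereas
    Batch expects MiB in resourceRequirements.
    """
    requirements = [
        {"type": "VCPU", "value": str(cpu)},
        {"type": "MEMORY", "value": str(mem_gib * 1024)},
    ]
    if gpu:
        requirements.append({"type": "GPU", "value": str(gpu)})
    kwargs = {
        "jobName": job_name,
        "jobQueue": queue,
        "jobDefinition": definition,
        "containerOverrides": {"resourceRequirements": requirements},
    }
    if depends_on:
        # The dependent job waits until every listed job has succeeded.
        kwargs["dependsOn"] = [{"jobId": job_id} for job_id in depends_on]
    return kwargs
```

The data-preparation job would be submitted first with `boto3.client("batch").submit_job(**kwargs)`, and its returned `jobId` passed as `depends_on=[...]` when building the prediction job.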

## Check status of RF Job

In [None]:
print(json.dumps(rfutils.get_batch_job_info(two_step_response[0]["jobId"]), indent=4, sort_keys=True))
print(json.dumps(rfutils.get_batch_job_info(two_step_response[1]["jobId"]), indent=4, sort_keys=True))

output=rfutils.get_rf_job_info(hrs_in_past=1)

## Retrieve and analyze MSA results

Pause while the job starts up

In [None]:
rfutils.wait_for_job_start(two_step_response[0]["jobId"])
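
A Batch job passes through the `SUBMITTED`, `PENDING`, `RUNNABLE`, and `STARTING` states before it reaches `RUNNING`. As an illustration of the waiting step, the sketch below polls until the job leaves those queueing states. `wait_until_started` is a hypothetical helper, not the `rfutils` implementation, and the polling interval and timeout are arbitrary defaults.

```python
import time


def wait_until_started(get_status, poll_seconds=10, timeout_seconds=3600,
                       sleep=time.sleep):
    """Poll get_status() until the job has left the Batch queueing states.

    Hypothetical helper; get_status is any zero-argument callable that
    returns the current AWS Batch status string.
    """
    queued = {"SUBMITTED", "PENDING", "RUNNABLE", "STARTING"}
    waited = 0
    while True:
        status = get_status()
        if status not in queued:
            return status  # RUNNING, SUCCEEDED, or FAILED
        if waited >= timeout_seconds:
            raise TimeoutError(f"Job still {status} after {waited} s")
        sleep(poll_seconds)
        waited += poll_seconds
```

In this notebook the status callable could be `lambda: rfutils.get_batch_job_info(job_id)["status"]`, where `job_id` is one of the IDs returned at submission.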

Get the run logs (run this cell multiple times to follow the job's progress)

In [None]:
info = rfutils.get_batch_job_info(two_step_response[0]["jobId"])
rfutils.get_batch_logs(info["logStreamName"]).tail()
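
Batch containers write their output to CloudWatch Logs, and the job's `logStreamName` identifies the stream. The sketch below formats raw log events for display. `format_log_events` is a hypothetical helper; it assumes events shaped like those returned by the CloudWatch Logs `get_log_events` call, with `timestamp` in epoch milliseconds.

```python
from datetime import datetime, timezone


def format_log_events(events, last_n=5):
    """Return the last_n events as 'ISO timestamp  message' strings.

    Hypothetical helper. Each event is a dict with 'timestamp' (epoch
    milliseconds) and 'message' keys.
    """
    lines = []
    for event in events[-last_n:]:
        stamp = datetime.fromtimestamp(event["timestamp"] / 1000,
                                       tz=timezone.utc)
        message = event["message"].rstrip()
        lines.append(f"{stamp.isoformat(timespec='seconds')}  {message}")
    return lines
```

The events could be fetched with `boto3.client("logs").get_log_events(logGroupName="/aws/batch/job", logStreamName=info["logStreamName"])["events"]`, assuming the job definition uses the default Batch log group.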

In [None]:
info = rfutils.get_batch_job_info(two_step_response[0]["jobId"])

if info["status"] == "SUCCEEDED":
    # download MSA file
    print(f"Downloading MSA file from s3://{bucket}/{info['jobName']}/{info['jobName']}.msa0.a3m")
    s3.download_file(bucket, f"{info['jobName']}/{info['jobName']}.msa0.a3m", "data/alignment.msa")
    msa_all = rfutils.parse_a3m("data/alignment.msa")
    rfutils.plot_msa_info(msa_all)
else:
    print(f"{info['jobId']} is in {info['status']} status. Please try again once the job has completed.")
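
The downloaded alignment is in A3M format, where lowercase letters mark insertions relative to the query; removing them leaves every sequence aligned to the query's length. The sketch below shows that idea together with a simple per-sequence identity measure. Both functions are hypothetical illustrations, and `rfutils.parse_a3m` may return a different structure.

```python
import string

# Lowercase letters in A3M mark insertions relative to the query.
_DROP_INSERTIONS = str.maketrans("", "", string.ascii_lowercase)


def parse_a3m_text(text):
    """Return aligned sequences from A3M text, query first."""
    sequences = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue
        if line.startswith(">"):
            sequences.append("")  # start a new record
        elif sequences:
            sequences[-1] += line.translate(_DROP_INSERTIONS)
    return sequences


def identity_to_query(sequences):
    """Fraction of query positions each aligned sequence matches exactly."""
    query = sequences[0]
    return [
        sum(a == b for a, b in zip(query, seq)) / len(query)
        for seq in sequences
    ]
```

An alignment with many sequences of moderate identity to the query generally carries more information for structure prediction than one made up of near-identical copies.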

## Retrieve and display structure

In [None]:
rfutils.wait_for_job_start(two_step_response[1]["jobId"])

In [None]:
info = rfutils.get_batch_job_info(two_step_response[1]["jobId"])
rfutils.get_batch_logs(info["logStreamName"]).tail()

In [None]:
info = rfutils.get_batch_job_info(two_step_response[1]["jobId"])

if info["status"] == "SUCCEEDED":
    s3.download_file(bucket, f"{info['jobName']}/{info['jobName']}.e2e.pdb", "data/e2e.pdb")
    color = "lDDT"  # options: "chain", "lDDT", "rainbow"
    show_sidechains = False
    show_mainchains = False
    rfutils.show_pdb("data/e2e.pdb", show_sidechains, show_mainchains, color, chains=1, vmin=0.5, vmax=0.9).show()
    if color == "lDDT":
        rfutils.plot_plddt_legend().show()
else:
    print(f"{info['jobId']} is in {info['status']} status. Please try again once the job has completed.")
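
The `lDDT` coloring suggests that the per-residue confidence is stored in the B-factor column of the PDB file, on a 0 to 1 scale given the `vmin=0.5` and `vmax=0.9` arguments. Under that assumption, which the notebook does not state outright, the scores can be read with the standard fixed-column PDB layout. The helpers below are hypothetical and not part of `rfutils`.

```python
def read_ca_bfactors(pdb_text):
    """Return (residue_number, b_factor) for each C-alpha ATOM record.

    Uses the fixed-column PDB layout: atom name in columns 13-16, residue
    number in columns 23-26, and B-factor in columns 61-66.
    Assumption: the B-factor column holds the predicted lDDT.
    """
    values = []
    for line in pdb_text.splitlines():
        if line.startswith("ATOM") and line[12:16].strip() == "CA":
            values.append((int(line[22:26]), float(line[60:66])))
    return values


def mean_confidence(pdb_text):
    """Average the C-alpha B-factor values; return None if there are none."""
    scores = [score for _, score in read_ca_bfactors(pdb_text)]
    return sum(scores) / len(scores) if scores else None
```

For example, `mean_confidence(open("data/e2e.pdb").read())` gives a single summary number that is convenient when comparing many predictions.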

## Analyze Proteins in Bulk

In [None]:
fasta_files = glob.glob('data/*.fa')
job_ids = []
for file in fasta_files:
    seq = SeqIO.read(file, "fasta")
    job_name = rfutils.get_job_name(seq.id)
    print(f"Automatically-generated job name is: {job_name}")
    input_uri = rfutils.upload_fasta_to_s3(seq, bucket, job_name)
    two_step_response = rfutils.submit_2_step_job(
        bucket=bucket,
        job_name=job_name,
        data_prep_input_file="input.fa",
        data_prep_job_definition="AWS-RoseTTAFold-CPU",
        data_prep_queue="AWS-RoseTTAFold-CPU",
        data_prep_cpu=24,
        data_prep_mem=80,
        predict_job_definition="AWS-RoseTTAFold-GPU",
        predict_queue="AWS-RoseTTAFold-GPU",
        predict_cpu=24,
        predict_mem=80,
        predict_gpu=2
    )
    # Record the (data prep, predict) job IDs so each job can be tracked later
    job_ids.append((two_step_response[0]["jobId"], two_step_response[1]["jobId"]))

In [None]:
output=rfutils.get_rf_job_info(hrs_in_past=1)
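
The `hrs_in_past` argument limits the report to recently created jobs. As a rough illustration of that kind of filter, the sketch below keeps only job summaries created within a time window. `filter_recent_jobs` is a hypothetical helper; it assumes summaries shaped like AWS Batch `list_jobs` results, where `createdAt` is an epoch timestamp in milliseconds, and it is not the `rfutils.get_rf_job_info` implementation.

```python
import time


def filter_recent_jobs(job_summaries, hrs_in_past=1, now_ms=None):
    """Keep job summaries created within the last hrs_in_past hours.

    Hypothetical helper; expects dicts with a 'createdAt' field in epoch
    milliseconds.
    """
    if now_ms is None:
        now_ms = int(time.time() * 1000)
    cutoff_ms = now_ms - hrs_in_past * 3600 * 1000
    recent = [job for job in job_summaries if job["createdAt"] >= cutoff_ms]
    # Newest first, so the latest submissions appear at the top.
    return sorted(recent, key=lambda job: job["createdAt"], reverse=True)
```

When many proteins are submitted in bulk, widening `hrs_in_past` is a simple way to bring older submissions back into view.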