Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: Apache-2.0

# Using the AWS Batch Architecture for Protein Folding

This notebook predicts the structures of multiple protein sequences from the CAMEO data set released between 2022-04-08 and 2022-07-02.

## Table of Contents
0. [Install Dependencies](#0.-Install-Dependencies)
1. [Get target list](#1.-Get-target-list)
2. [Run MSA generation and folding jobs](#2.-Run-MSA-generation-and-folding-jobs) 
3. [Download results](#3.-Download-results)
4. [Visualize results](#4.-Visualize-results)
5. [Compare result to experimental structure](#5.-Compare-result-to-experimental-structure)

## 0. Install Dependencies

In [None]:
# Install the tmscoring package from GitHub; it is used in section 5 to compute TM scores.
%pip install -U git+https://github.com/brianloyal/tmscoring.git

In [None]:
# Import required Python packages

import boto3
from datetime import datetime
from batchfold.batchfold_environment import BatchFoldEnvironment
from batchfold.jackhmmer_job import JackhmmerJob
from batchfold.openfold_job import OpenFoldJob
from batchfold.alphafold2_job import AlphaFold2Job
from batchfold.batchfold_target import BatchFoldTarget
from batchfold.mmseqs2_job import MMseqs2Job
import matplotlib.pyplot as plt
from nbhelpers import nbhelpers
import os

# Create one boto3 session and share it with the BatchFold environment so that
# every later AWS call uses the same credentials and region.
boto_session = boto3.session.Session()
batch_environment = BatchFoldEnvironment(boto_session=boto_session)

# Bucket that holds the uploaded FASTA files, MSAs and predictions.
S3_BUCKET = batch_environment.default_bucket
print(f" S3 bucket name is {S3_BUCKET}")

## 1. Get target list

In [None]:
# CAMEO target chains, written as <PDB ID>_<chain>.
# Later cells read FASTA files from data/fasta and experimental structures from
# data/pdb/<target_id>.pdb -- presumably both are produced by get_pdb_data (confirm).
pdb_list = [
    "7Q4L_A",
    "7DUV_A",
    "7PP2_A",
    "7OIO_A",
    "7T9X_A",
]
nbhelpers.get_pdb_data(pdb_list)

## 2. Run MSA generation and folding jobs

In [None]:
from pathlib import Path
data_dir = "data/fasta"

# For every FASTA file in data_dir, upload it and submit two pipelines to AWS Batch:
#   * Jackhmmer MSA generation (CPU spot queue) -> AlphaFold2 prediction (GPU queue)
#   * MMseqs2 MSA generation (CPU spot queue) -> OpenFold prediction (GPU queue)
# Sorting makes the submission order deterministic (os.listdir order is arbitrary).
fasta_files = sorted(
    file for file in os.listdir(data_dir) if Path(file).suffix in [".fa", ".fasta"]
)

for file in fasta_files:
    fasta_path = os.path.join(data_dir, file)
    print(fasta_path)
    # The file name up to the first "." is the target ID, e.g. "7Q4L_A.fasta" -> "7Q4L_A".
    target_id = file.split(".")[0]
    target = BatchFoldTarget(target_id=target_id, s3_bucket=S3_BUCKET, boto_session=boto_session)
    target.add_fasta(fasta_path)
    target.upload_fasta()

    # One timestamp per target. "%H%M%S" replaces the previously used "%s" (epoch seconds),
    # which is a platform-specific strftime extension that is not available everywhere
    # (e.g. Windows rejects it).
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")

    jackhmmer_job_name = target.target_id + "_JackhmmerJob_" + timestamp
    jackhmmer_job = JackhmmerJob(
        job_name=jackhmmer_job_name,
        target_id=target.target_id,
        fasta_s3_uri=target.get_fasta_s3_uri(),
        output_s3_uri=target.get_msas_s3_uri(),
        boto_session=boto_session,
        cpu=16,
        memory=32,
    )

    alphafold2_job_name = target.target_id + "_AlphaFold2Job_" + timestamp
    alphafold2_job = AlphaFold2Job(
        job_name=alphafold2_job_name,
        boto_session=boto_session,
        target_id=target.target_id,
        fasta_s3_uri=target.get_fasta_s3_uri(),
        # Reads the Jackhmmer MSAs produced by jackhmmer_job.
        msa_s3_uri=target.get_msas_s3_uri() + "/jackhmmer",
        output_s3_uri=target.get_predictions_s3_uri() + "/" + alphafold2_job_name,
        max_template_date="2022-01-01",
        use_precomputed_msas=True,
        model_preset="monomer_ptm",
        benchmark=True,
        cpu=4,
        memory=16,
        gpu=1,
    )

    jackhmmer_submission = batch_environment.submit_job(jackhmmer_job, job_queue_name="GravitonSpotJobQueue")
    alphafold2_submission = batch_environment.submit_job(
        alphafold2_job, job_queue_name="G4dnJobQueue", depends_on=[jackhmmer_submission]
    )

    mmseqs2_job_name = target.target_id + "_MMseqs2Job_" + timestamp
    mmseqs2_job = MMseqs2Job(
        job_name=mmseqs2_job_name,
        target_id=target.target_id,
        fasta_s3_uri=target.get_fasta_s3_uri(),
        output_s3_uri=target.get_msas_s3_uri(),
        boto_session=boto_session,
        cpu=64,
        memory=500,
    )

    openfold_job_name = target.target_id + "_OpenFoldJob_" + timestamp
    openfold_job = OpenFoldJob(
        job_name=openfold_job_name,
        boto_session=boto_session,
        target_id=target.target_id,
        fasta_s3_uri=target.get_fasta_s3_uri(),
        # OpenFold also reads the Jackhmmer MSAs, so it must wait for jackhmmer_job (see depends_on).
        msa_s3_uri=target.get_msas_s3_uri() + "/jackhmmer/",
        output_s3_uri=target.get_predictions_s3_uri() + "/" + openfold_job_name,
        max_template_date="2022-01-01",
        use_precomputed_msas=True,
        config_preset="finetuning_ptm",
        openfold_checkpoint_path="openfold_params/finetuning_ptm_2.pt",
        save_outputs=True,
        cpu=4,
        memory=16,
        gpu=1,
    )

    mmseqs2_submission = batch_environment.submit_job(mmseqs2_job, job_queue_name="GravitonSpotJobQueue")
    # Depending only on the MMseqs2 job let OpenFold start before the Jackhmmer MSAs it
    # reads existed. Waiting on both jobs removes that race.
    # NOTE(review): if OpenFold was meant to consume the MMseqs2 MSAs instead, point
    # msa_s3_uri at the MMseqs2 output and drop jackhmmer_submission here.
    openfold_submission = batch_environment.submit_job(
        openfold_job,
        job_queue_name="G4dnJobQueue",
        depends_on=[jackhmmer_submission, mmseqs2_submission],
    )

Once all jobs have finished, download the results.

## 3. Download results

In [None]:
# Inspect a single example target: rebind `target` to it and download its results
# under the local "data" directory. Later cells read data/<target_id>/msas/...
# and data/<target_id>/predictions/...
example_target_id = "7OIO_A"
target = BatchFoldTarget(
    target_id=example_target_id,
    s3_bucket=S3_BUCKET,
    boto_session=boto_session,
)
target.download_all(local_path="data")

## 4. Visualize results

### Alignment Data

In [None]:
# Plot the Jackhmmer MSA downloaded for the selected target.
sto_path = f"data/{target.target_id}/msas/jackhmmer"
# Pass the target ID. The previous code passed Python's built-in `id` function.
# The first argument is presumably used as a label -- confirm against nbhelpers.msa_plot.
# Keep the result in its own name so the matplotlib `plt` imported at the top is not shadowed.
msa_figure = nbhelpers.msa_plot(target.target_id, sto_path)
msa_figure.show()

### Structure Data

In [None]:
# First OpenFold job name reported for the target (presumably the most recent, per the
# variable name). `last_job_name` stays a notebook global because later cells read it.
last_job_name = target.list_job_names(job_type="OpenFold")[0]
openfold_prediction_dir = f"data/{target.target_id}/predictions/{last_job_name}"
structure_plot = nbhelpers.pdb_plot(pdb_path=openfold_prediction_dir, show_sidechains=False)
structure_plot.show()

## 5. Compare result to experimental structure

### Align OpenFold to experimental structure

In [None]:
import tmscoring
import py3Dmol

# Use this cell's own OpenFold job name rather than a variable from another cell,
# so the cell still works when run on its own.
last_openfold_job_name = target.list_job_names(job_type="OpenFold")[0]

openfold_alignment = tmscoring.TMscoring(
    f"data/{target.target_id}/predictions/{last_openfold_job_name}/{target.target_id}_finetuning_ptm_relaxed.pdb",
    f"data/pdb/{target.target_id}.pdb",
)

# Find the optimal alignment
openfold_alignment.optimise()

# Get the TM score:
tm_score = openfold_alignment.tmscore(**openfold_alignment.get_current_values())
print(f"TM score is {tm_score}")

# Save the aligned files:
openfold_alignment.write(outputfile='openfold_alignment.pdb', appended=True)

with open("openfold_alignment.pdb") as f:
    aligned_pdb = f.read()

view = py3Dmol.view(width=800, height=600)
view.addModels(aligned_pdb)

# Color chains A and B; presumably these are the two aligned structures written by
# tmscoring (confirm against write(appended=True)). Zoom once, after styling.
for chain, color in zip(["A", "B"], ["red", "blue"]):
    view.setStyle({"chain": chain}, {"cartoon": {"color": color}})
view.zoomTo()
view.show()


### Align AlphaFold to experimental structure

In [None]:
# First AlphaFold job name reported for the target (presumably the most recent).
last_alphafold_job_name = target.list_job_names(job_type="AlphaFold")[0]

# Compare AlphaFold's ranked_0.pdb with the experimental structure.
alphafold_alignment = tmscoring.TMscoring(
    f"data/{target.target_id}/predictions/{last_alphafold_job_name}/ranked_0.pdb",
    f"data/pdb/{target.target_id}.pdb",
)

# Find the optimal alignment
alphafold_alignment.optimise()

# Get the TM score:
tm_score = alphafold_alignment.tmscore(**alphafold_alignment.get_current_values())
print(f"TM score is {tm_score}")

# Save the aligned files:
alphafold_alignment.write(outputfile='alphafold_alignment.pdb', appended=True)

with open("alphafold_alignment.pdb") as f:
    aligned_pdb = f.read()

view = py3Dmol.view(width=800, height=600)
view.addModels(aligned_pdb)

# Color chains A and B; presumably these are the two aligned structures written by
# tmscoring (confirm against write(appended=True)). Zoom once, after styling.
for chain, color in zip(["A", "B"], ["red", "blue"]):
    view.setStyle({"chain": chain}, {"cartoon": {"color": color}})
view.zoomTo()
view.show()


### Align OpenFold to AlphaFold

In [None]:
# Resolve both job names here so the cell does not depend on variables from other cells.
last_openfold_job_name = target.list_job_names(job_type="OpenFold")[0]
last_alphafold_job_name = target.list_job_names(job_type="AlphaFold")[0]

# Compare the OpenFold prediction directly with AlphaFold's ranked_0.pdb.
oa_alignment = tmscoring.TMscoring(
    f"data/{target.target_id}/predictions/{last_openfold_job_name}/{target.target_id}_finetuning_ptm_relaxed.pdb",
    f"data/{target.target_id}/predictions/{last_alphafold_job_name}/ranked_0.pdb",
)

# Find the optimal alignment
oa_alignment.optimise()

# Get the TM score:
tm_score = oa_alignment.tmscore(**oa_alignment.get_current_values())
print(f"TM score is {tm_score}")

# Save the aligned files:
oa_alignment.write(outputfile='oa_alignment.pdb', appended=True)

with open("oa_alignment.pdb") as f:
    aligned_pdb = f.read()

view = py3Dmol.view(width=800, height=600)
view.addModels(aligned_pdb)

# Color chains A and B; presumably these are the two aligned structures written by
# tmscoring (confirm against write(appended=True)). Zoom once, after styling.
for chain, color in zip(["A", "B"], ["red", "blue"]):
    view.setStyle({"chain": chain}, {"cartoon": {"color": color}})
view.zoomTo()
view.show()
