Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
SPDX-License-Identifier: Apache-2.0

# Quick Start: Running OpenFold on AWS Batch

## Table of Contents
0. [Install Dependencies](#0.-Install-Dependencies)
1. [Create Target](#1.-Create-Target)
2. [Submit Sequence Alignment and Folding Jobs](#2.-Submit-Sequence-Alignment-and-Folding-Jobs) 
3. [Download Results](#3.-Download-Results)
4. [Visualize Results](#4.-Visualize-Results)

## 0. Install Dependencies

In [None]:
# Import required Python packages

import boto3
from datetime import datetime
from batchfold.batchfold_environment import BatchFoldEnvironment
from batchfold.openfold_job import OpenFoldJob
from batchfold.batchfold_target import BatchFoldTarget
from batchfold.mmseqs2_job import MMseqs2Job
import matplotlib.pyplot as plt
from nbhelpers import nbhelpers

# Open one boto3 session and hand it to the BatchFold environment wrapper;
# the same session object is reused by the job and target cells below.
boto_session = boto3.Session()
batch_environment = BatchFoldEnvironment(boto_session=boto_session)

# Bucket reported by the environment; passed to BatchFoldTarget as s3_bucket.
S3_BUCKET = batch_environment.default_bucket
print(f" S3 bucket name is {S3_BUCKET}")

## 1. Create Target

In [None]:
# Identifier used for the target and as the id of its single sequence.
target_id = "7VNA"

# Pass the shared boto3 session explicitly, matching how the target is
# re-created in the "Download Results" cell, so both use the same session.
target = BatchFoldTarget(
    target_id=target_id,
    s3_bucket=S3_BUCKET,
    boto_session=boto_session,
)

# Register the amino-acid sequence to fold under this target.
target.add_sequence(
    seq_id=target_id,
    seq="GSIPHKENMFKSKHKLDFSLVSMDQRGKHILGYADAELVNMGGYDLVHYDDLAYVASAHQELLKTGASGMIAYRYQKKDGEWQWLQTSSRLVYKNSKPDFVICTHRQLMDEEGHDLLGKR",
    description="Chain A|Ahr homolog spineless|Drosophila melanogaster",
)

## 2. Submit Sequence Alignment and Folding Jobs

In [None]:
# Job names get a timestamp suffix so repeated runs do not collide.
# Only documented, portable strftime codes are used here: "%s" (epoch seconds)
# is a platform-specific extension and is not supported everywhere.
job_name = target.target_id + "_MMseqs2Job_" + datetime.now().strftime("%Y%m%dT%H%M%S")
mmseqs2_job = MMseqs2Job(
    job_name=job_name,
    target_id=target.target_id,
    fasta_s3_uri=target.get_fasta_s3_uri(),
    output_s3_uri=target.get_msas_s3_uri(),
    boto_session=boto_session,
    cpu=64,
    memory=500,
)

job_name = target.target_id + "_OpenFoldJob_" + datetime.now().strftime("%Y%m%dT%H%M%S")
openfold_job = OpenFoldJob(
    job_name=job_name,
    boto_session=boto_session,
    target_id=target.target_id,
    fasta_s3_uri=target.get_fasta_s3_uri(),
    # NOTE(review): the alignment job above is MMseqs2 and the plotting cell
    # reads "msas/mmseqs2", yet this URI points at the "jackhmmer" prefix.
    # Confirm this is where OpenFold should look for precomputed MSAs.
    msa_s3_uri=target.get_msas_s3_uri() + "/jackhmmer/",
    # Each folding run writes under its own job-name prefix.
    output_s3_uri=target.get_predictions_s3_uri() + "/" + job_name,
    use_precomputed_msas=True,
    config_preset="finetuning_ptm",
    openfold_checkpoint_path="openfold_params/finetuning_ptm_2.pt",
    save_outputs=True,
    cpu=4,
    memory=15,  # Why not 16? ECS needs about 1 GB for container services
    gpu=1,
)

# Submit the alignment job first; the folding job is submitted with a
# dependency on it via depends_on.
mmseqs2_submission = batch_environment.submit_job(
    mmseqs2_job, job_queue_name="GravitonOnDemandJobQueue"
)
openfold_submission = batch_environment.submit_job(
    openfold_job,
    job_queue_name="G4dnJobQueue",
    depends_on=[mmseqs2_submission],
)

Check on job statuses

In [None]:
# Print the current name and status of each submitted job.
for job in [mmseqs2_job, openfold_job]:
    # Call describe_job() once per job so the name and status come from the
    # same response instead of two separate calls.
    job_description = job.describe_job()[0]
    print(f"Job {job_description['jobName']} is in status {job_description['status']}")

## 3. Download Results

Once the jobs are finished, download the results

In [None]:
# Re-create the target handle (so this cell also works in a fresh kernel once
# target_id, S3_BUCKET and boto_session are defined) and fetch its predictions.
# Later cells read the downloaded files from paths under "data/<target_id>/".
target = BatchFoldTarget(
    target_id=target_id,
    s3_bucket=S3_BUCKET,
    boto_session=boto_session,
)
target.download_predictions(local_path="data")

## 4. Visualize Results

### Plot Alignment Data

In [None]:
# Plot the alignment data read from data/<target_id>/msas/mmseqs2.
# The result is bound to its own name so that the `matplotlib.pyplot as plt`
# import at the top of the notebook is not overwritten.
msa_figure = nbhelpers.msa_plot(target_id, f"data/{target_id}/msas/mmseqs2")
msa_figure.show()

### Plot Predicted Structure

In [None]:
# Pick the last OpenFold job name listed for this target and plot its output.
openfold_job_names = target.list_job_names(job_type="OpenFold")
if len(openfold_job_names) == 0:
    # Indexing an empty result with [-1] would raise a bare IndexError;
    # fail with a message that says what is missing instead.
    raise RuntimeError(
        f"No OpenFold job names were returned for target {target_id}; "
        "make sure a folding job has finished and its results were downloaded."
    )
last_job_name = openfold_job_names[-1]
print(last_job_name)

structure_view = nbhelpers.pdb_plot(
    pdb_path=f"data/{target_id}/predictions/{last_job_name}",
    show_sidechains=False,
)
structure_view.show()