## Basic Setup

Run the cells below for the basic setup of this notebook.

In [2]:
try:
    from google.colab import drive # type: ignore
    IN_COLAB = True
except ImportError:
    IN_COLAB = False
    print('No colab environment, assuming local setup.')

if IN_COLAB:
    drive.mount('/content/drive')

    # TODO: Enter the foldername in your Drive where you have saved the unzipped
    # tutorials folder, e.g. 'alphafold-decoded/tutorials'
    FOLDERNAME = None
    assert FOLDERNAME is not None, "[!] Enter the foldername."

    # Now that we've mounted your Drive, this ensures that
    # the Python interpreter of the Colab VM can load
    # python files from within it.
    import sys
    sys.path.append('/content/drive/My Drive/{}'.format(FOLDERNAME))
    %cd /content/drive/My\ Drive/$FOLDERNAME
    %pip install py3dmol
    %pip install modelcif

    print('Connected COLAB to Google Drive.')

import os
    
base_folder = 'model'
control_folder = f'{base_folder}/control_values'

# Pick the hint that matches the environment, then fail early if the folder is missing.
hint = 'FOLDERNAME is set correctly' if IN_COLAB else 'your root folder is set correctly'
assert os.path.isdir(control_folder), f'Folder "control_values" not found, make sure that {hint}.'

In [3]:
# Run this to download the OpenFold weights
import subprocess

if not os.path.isdir('model/openfold_params') or not os.listdir('model/openfold_params'):
    print('Download Openfold weights...')
    %pip install awscli
    subprocess.call(['bash', 'model/download_openfold_params.sh', 'model'])
else:
    print('Weights folder already exists.')


In [4]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [5]:
from matplotlib import pyplot as plt
import torch
import math

In [6]:
torch.set_grad_enabled(False)

# Model

We are ready. All the parts of AlphaFold are implemented, and all that's left is to stitch them together into the full model.

The full model is described in Algorithm 2 of the paper. We omit the template stack and don't do ensembling, so you can skip lines 3, 4, 7-13, 19, and 20.

Go to `model.py` and implement the initialization and the forward pass. After that, check your code by running the following cell. You might have to fix some dtype mismatches that slipped through earlier checks. If they come up, go back to the affected methods and make sure that any tensors you create in them have the right dtype (you can grab the dtype from one of the method's arguments).

In [7]:
from model.model import Model
from model.control_values.model_checks import c_m, c_z, c_e, f_e, tf_dim, c_s, num_blocks_evoformer, num_blocks_extra_msa
from model.control_values.model_checks import test_module_shape, test_module_method

model = Model(c_m, c_z, c_e, f_e, tf_dim, c_s, num_blocks_extra_msa, num_blocks_evoformer)

test_module_shape(model, 'model', control_folder)

def test_method(*args):
    outputs = model(*args)
    return outputs['final_positions'], outputs['position_mask'], outputs['angles'], outputs['frames']


test_module_method(model, 'model', 'batch', ('final_positions', 'position_mask', 'angles', 'frames'), control_folder, test_method)
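
The dtype advice above can be illustrated with a small sketch. `rotate_2d` is a hypothetical helper and not part of `model.py`. It builds a tensor inside the function and takes dtype and device from its argument, so the matrix product works for both `float32` and `float64` inputs. Without `dtype=points.dtype`, the rotation matrix would default to `float32`, and the product with `float64` points would raise an error.

```python
import math
import torch

def rotate_2d(points: torch.Tensor, angle: float) -> torch.Tensor:
    """Hypothetical helper: rotate (N, 2) points counter-clockwise by `angle` radians."""
    c, s = math.cos(angle), math.sin(angle)
    # Take dtype and device from the argument instead of relying on the defaults.
    rot = torch.tensor([[c, -s], [s, c]], dtype=points.dtype, device=points.device)
    return points @ rot.T

for dtype in (torch.float32, torch.float64):
    pts = torch.tensor([[1.0, 0.0]], dtype=dtype)
    out = rotate_2d(pts, math.pi / 2)
    print(dtype, '->', out.dtype)
```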

## Building the Real Model

Our model passed the check for the toy input. Now it's time to load the model at full size, with real weights, and do a real evaluation pass. This is much faster on a GPU. If you are in Colab, make sure to select a GPU under "Runtime -> Change runtime type" if you haven't done so before.

In [9]:
if torch.cuda.is_available():
    device = 'cuda'
    print('Compatible GPU available.')
elif torch.backends.mps.is_available():
    device = 'mps'
    print('MPS (Metal Performance Shaders) available.')
else:
    device = 'cpu'
    print('No compatible GPU, fallback to CPU.')


We use the OpenFold weights for our model, since OpenFold is also written in PyTorch. However, we made some tweaks to the data without changing the actual weight values: a few modules are renamed, and the angles are swapped so that they are predicted as (cos(phi), sin(phi)), as described in the paper, instead of (sin(phi), cos(phi)). You can look at the tweaks in the method `load_openfold_weights` in `utils.py`.
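
One way to picture the angle swap is as a permutation of the output channels of the layer that predicts the angles. The sketch below is an assumption about how such a tweak can work in general, not a copy of `load_openfold_weights`. `old_layer` and `new_layer` are made-up stand-ins for a layer whose two outputs are (sin(phi), cos(phi)).

```python
import torch
from torch import nn

torch.manual_seed(0)

old_layer = nn.Linear(8, 2)  # stand-in: outputs (sin, cos)
new_layer = nn.Linear(8, 2)  # should output (cos, sin)

with torch.no_grad():
    # Row i of the weight (and entry i of the bias) produces output channel i,
    # so reversing the rows swaps the two outputs without changing any value.
    new_layer.weight.copy_(old_layer.weight.flip(0))
    new_layer.bias.copy_(old_layer.bias.flip(0))

    x = torch.randn(5, 8)
    sin_cos = old_layer(x)
    cos_sin = new_layer(x)

# The new layer's outputs are the old ones with the last dimension reversed.
swapped_ok = torch.allclose(cos_sin, sin_cos.flip(-1), atol=1e-6)
print(swapped_ok)
```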

If you named all your modules according to the descriptions, the following cell should run cleanly and tell you that all keys in the state dictionary matched successfully.

In [10]:
from model.utils import load_openfold_weights
model = Model()
openfold_weights = load_openfold_weights('model/openfold_params/finetuning_2.pt')

model.load_state_dict(openfold_weights)
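
If the keys do not match, it can help to see exactly which names differ. The following toy example uses a made-up two-layer model rather than `Model`, with one deliberately misnamed key. It relies on `load_state_dict(..., strict=False)`, which reports mismatches instead of raising an error.

```python
import torch
from torch import nn

toy_model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))

# Build a checkpoint with one renamed key to simulate a misnamed module.
state = dict(toy_model.state_dict())
state['1.wrong_name'] = state.pop('1.bias')

# strict=False collects the mismatches in the returned result.
result = toy_model.load_state_dict(state, strict=False)
print('missing:', result.missing_keys)
print('unexpected:', result.unexpected_keys)
```

Here `missing_keys` lists names the model expects but the checkpoint lacks, and `unexpected_keys` lists names in the checkpoint that the model does not know.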

As the first step in the forward pass, we'll call `create_features_from_a3m` four times in a row to get four different input batches (different because of the random selection of cluster centers) for the four different cycles.

In [12]:
from feature_extraction.feature_extraction import create_features_from_a3m

batch = None

##########################################################################
# TODO: Compute four single batches with create_features_from_a3m.       #
#   Then, stack all the tensors in the list to get the complete batch.   #
##########################################################################

single_cycle_batches = []
for i in range(4):
    single_cycle_batch = create_features_from_a3m('feature_extraction/alignment_tautomerase.a3m')
    single_cycle_batches.append(single_cycle_batch)

batch = {
    key: torch.stack([single_batch[key] for single_batch in single_cycle_batches], dim=-1)
    for key in single_cycle_batches[0].keys()
}

##########################################################################
#               END OF YOUR CODE                                         #
##########################################################################

expected_shapes = {
    'msa_feat': (512, 59, 49, 4),
    'extra_msa_feat': (5120, 59, 25, 4),
    'target_feat': (59, 21, 4),
    'residue_index': (59, 4),
}

shapes = { key: value.shape for key, value in batch.items() }
assert set(expected_shapes.keys()) == set(shapes.keys())
for key, shape in shapes.items():
    assert expected_shapes[key] == shape, f'Shape mismatch for {key}: {shape} vs {expected_shapes[key]}'




For a faster forward pass, we'll map the model and the inputs to the GPU, if a compatible one is available.

In [13]:
model.to(device)
for key, value in batch.items():
    batch[key] = value.to(device)

Now it's time for the forward pass. If you weren't careful with mapping all the tensors you created within the methods to the correct device (which you can grab from any of the input tensors of the methods), you'll have to do some debugging before this runs through.

In [14]:
with torch.no_grad():
    outputs = model(batch)
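
If the forward pass fails with a device error, a quick first step is to rule out the inputs and parameters. The helper below is a hypothetical sketch, not part of the tutorial code. It lists the names of tensors whose device type differs from the expected one. It cannot see tensors created inside your methods, which still need `device=x.device` taken from an input tensor.

```python
import torch

def tensors_off_device(tensors: dict, device: str) -> list:
    """Hypothetical helper: names of tensors whose device type differs from `device`."""
    expected = torch.device(device).type
    return [name for name, t in tensors.items() if t.device.type != expected]

# Toy stand-in for a batch; both tensors live on the CPU.
toy_batch = {'msa_feat': torch.zeros(2, 3), 'residue_index': torch.arange(3)}
print(tensors_off_device(toy_batch, 'cpu'))
```

In the notebook, the same check could be applied to `batch` and to `dict(model.named_parameters())`.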

We wrote a little wrapper around the modelcif Python library to create an mmCIF file from our atom positions. If you are interested, you can inspect it in `utils.py`. The following cell uses this method to create `prediction.mmcif`.

In [15]:
from model.utils import to_modelcif
from geometry.residue_constants import restypes

atom_positions = outputs['final_positions'][..., -1]
atom_mask = outputs['position_mask'][..., -1]
seq_inds = batch['target_feat'].cpu()[..., -1].argmax(dim=-1).numpy()
seq = ''.join([restypes[ind] for ind in seq_inds])

cif_str = to_modelcif(atom_positions, atom_mask, seq)

with open('model/prediction.mmcif', 'w') as f:
    f.write(cif_str)

The py3Dmol library lets us visualize our prediction in Jupyter notebooks. To display the prediction alongside the crystal structure from the PDB, uncomment the second `view = ...` line and comment out the first one (if you do so, the little sticks in the background are the second part of the dimer from the PDB).

In [16]:
import py3Dmol

view = py3Dmol.view()
# view = py3Dmol.view(query='2OP8')

with open('model/prediction.mmcif', 'r') as f:
    data = f.read()
view.addModel(data, 'mmcif')
view.addModel('2OP8')
view.setStyle({'chain': 'A'}, {'cartoon': {'color':'spectrum'}})
view.zoomTo()
view

If you are working in Colab, you can run the following cell to download the resulting mmCIF file. Open the file in a protein viewer like ChimeraX on your local machine and load the crystal structure as well. You can use 'Tools -> Structure Analysis -> Matchmaker' to align your prediction with the crystal structure. The root mean square deviation (the error) of the alignment should be really small.

In [14]:
# If you're in Colab, you can run this cell to download your prediction
if IN_COLAB:
    from google.colab import files # type: ignore
    files.download('model/prediction.mmcif')
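
For reference, the root mean square deviation mentioned above is the square root of the mean squared distance between matched atoms. The sketch below computes it for coordinates that are already superposed. It does not perform the alignment that Matchmaker does, and the coordinates are made up for illustration.

```python
import math

def rmsd(coords_a, coords_b):
    """RMSD between two equally long lists of (x, y, z) points that are already superposed."""
    assert len(coords_a) == len(coords_b) and len(coords_a) > 0
    squared = sum(
        (ax - bx) ** 2 + (ay - by) ** 2 + (az - bz) ** 2
        for (ax, ay, az), (bx, by, bz) in zip(coords_a, coords_b)
    )
    return math.sqrt(squared / len(coords_a))

# Every atom is shifted by exactly 1 unit along z, so the RMSD is 1.
predicted = [(0.0, 0.0, 0.0), (1.5, 0.0, 0.0), (3.0, 0.0, 0.0)]
reference = [(0.0, 0.0, 1.0), (1.5, 0.0, 1.0), (3.0, 0.0, 1.0)]
print(rmsd(predicted, reference))  # 1.0
```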