## Basic Setup

Run the cells below for the basic setup of this notebook.

In [1]:
try:
    from google.colab import drive # type: ignore
    IN_COLAB = True
except ImportError:
    IN_COLAB = False
    print('No colab environment, assuming local setup.')

if IN_COLAB:
    drive.mount('/content/drive')

    # TODO: Enter the foldername in your Drive where you have saved the unzipped
    # tutorials folder, e.g. 'alphafold-decoded/tutorials'
    FOLDERNAME = None
    assert FOLDERNAME is not None, "[!] Enter the foldername."

    # Now that we've mounted your Drive, this ensures that
    # the Python interpreter of the Colab VM can load
    # python files from within it.
    import sys
    sys.path.append('/content/drive/My Drive/{}'.format(FOLDERNAME))
    %cd /content/drive/My\ Drive/$FOLDERNAME

    print('Connected COLAB to Google Drive.')

import os
    
base_folder = 'structure_module'
control_folder = f'{base_folder}/control_values'

assert os.path.isdir(control_folder), 'Folder "control_values" not found, make sure that FOLDERNAME is set correctly.' if IN_COLAB else 'Folder "control_values" not found, make sure that your root folder is set correctly.'

In [2]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
from matplotlib import pyplot as plt
import torch
import math

In [4]:
torch.set_grad_enabled(False)

# Structure Module

The Structure Module is the final part of AlphaFold. It takes the single and pair representations from the Evoformer and predicts the 3D positions of all heavy atoms, by way of per-residue backbone transforms and torsion angles. So far, the model has used very little geometric information: the Evoformer mostly applies grid-based, row-wise, or column-wise operations to its inputs. The Structure Module is different in this regard, since it explicitly reasons about positions in 3D space.

## Invariant Point Attention

Invariant point attention (IPA) is the core geometric mechanism of the Structure Module. The idea is as follows: the Structure Module runs several iterations per pass of the full model, refining its current guess of the backbone transforms in each iteration. It starts from a so-called 'black-hole initialization', where all backbone frames sit at the origin with identity rotations. In each iteration, the IPA module uses the transforms from the previous iteration and places its query and key points in the local coordinate frame of each residue. As a result, the attention mechanism inherently favors residues that are close to each other according to the latest guess of the residue positions.
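To make "query and key points" concrete, here is a minimal sketch of how such per-head scalars and points could be produced from the single representation (lines 1 to 3 of Algorithm 22, as we read them). The sizes, the function name `prepare_qkv_sketch`, and the `(..., N_head, n_points, 3)` layout are assumptions for illustration; your `prepare_qkv` must follow the shapes expected by the control values.

```python
import torch
from torch import nn

# Hypothetical sizes for illustration only; the notebook imports its own
# c_s, c, N_head, n_qp and n_pv from structure_module_checks.
c_s, c, N_head, n_qp, n_pv = 32, 16, 4, 4, 8

# Bias-free linear projections of the single representation.
linear_q = nn.Linear(c_s, N_head * c, bias=False)
linear_k = nn.Linear(c_s, N_head * c, bias=False)
linear_v = nn.Linear(c_s, N_head * c, bias=False)
linear_qp = nn.Linear(c_s, N_head * n_qp * 3, bias=False)
linear_kp = nn.Linear(c_s, N_head * n_qp * 3, bias=False)
linear_vp = nn.Linear(c_s, N_head * n_pv * 3, bias=False)

def prepare_qkv_sketch(s):
    """s: (*, N_res, c_s) -> per-head scalar features and per-head 3D points."""
    batch = s.shape[:-1]
    q = linear_q(s).view(*batch, N_head, c)
    k = linear_k(s).view(*batch, N_head, c)
    v = linear_v(s).view(*batch, N_head, c)
    # Each point is an (x, y, z) triple, read as local coordinates in the
    # backbone frame of its residue.
    qp = linear_qp(s).view(*batch, N_head, n_qp, 3)
    kp = linear_kp(s).view(*batch, N_head, n_qp, 3)
    vp = linear_vp(s).view(*batch, N_head, n_pv, 3)
    return q, k, v, qp, kp, vp

s = torch.randn(10, c_s)  # 10 residues
q, k, v, qp, kp, vp = prepare_qkv_sketch(s)
```

At this stage the points are just numbers; they only acquire geometric meaning once they are warped through the backbone transforms.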

The module is described in Algorithm 22. Take a first look at it and start by implementing the `__init__` method and `prepare_qkv` in `ipa.py`. Check your code by running the following cell.

In [5]:
from structure_module.control_values.structure_module_checks import test_module_shape, test_module_method
from structure_module.control_values.structure_module_checks import c_s, c_z, n_qp, n_pv, N_head, c
from structure_module.ipa import InvariantPointAttention

ipa = InvariantPointAttention(c_s, c_z, n_qp, n_pv, N_head, c)

test_module_shape(ipa, 'ipa', control_folder)

test_module_method(ipa, 'ipa_prep', 's', ('q', 'k', 'v', 'qp', 'kp', 'vp'), control_folder, lambda x: ipa.prepare_qkv(x))

Next, we will go through the computation of the attention scores, which is lines 5 to 7 in Algorithm 22. This is where the magic happens. Don't focus too much on `w_C` and `w_L` here: they are chosen so that, under standard initialization, the three sources of attention (q/k, the pair bias, and qp/kp) contribute about equally to the attention scores, and gamma is a learnable per-head parameter that can adjust this weighting.
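For reference, here is how lines 5 to 7 of Algorithm 22 combine the three sources, as we read them from the AlphaFold supplementary information. Here $\mathbf{q}_i^h, \mathbf{k}_j^h$ are the scalar queries and keys, $b_{ij}^h$ is the bias computed from the pair representation, $\vec{\mathbf{q}}_i^{hp}, \vec{\mathbf{k}}_j^{hp}$ are the query and key points, and $N_\text{query points}$ is the number of query points per head:

$$
w_C = \sqrt{\frac{2}{9 N_\text{query points}}}, \qquad w_L = \sqrt{\frac{1}{3}}
$$

$$
a_{ij}^h = \operatorname{softmax}_j\left( w_L \left( \frac{1}{\sqrt{c}} {\mathbf{q}_i^h}^\top \mathbf{k}_j^h + b_{ij}^h - \frac{\gamma^h w_C}{2} \sum_p \left\| T_i \circ \vec{\mathbf{q}}_i^{hp} - T_j \circ \vec{\mathbf{k}}_j^{hp} \right\|^2 \right) \right)
$$

The factor $w_L$ rescales the sum of three roughly unit-variance terms back to roughly unit variance.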

The interesting part is the contribution of the attention points. `qp` and `kp`, as we have implemented them so far, are just linear embeddings of the single representation. In machine learning, we often assume such values to be roughly standard-normally distributed, an assumption reinforced by the layer normalization applied at several stages of the model. This means that `qp` and `kp` roughly follow a spherical normal distribution around the coordinate origin.

However, the attention contribution of `qp` and `kp` is not computed from these raw positions. Instead, they are warped through the backbone transforms `T` before their difference is computed. Warping them through the backbone transforms means interpreting the query and key points as local coordinates in each residue's backbone frame; the results are the global positions of the query and key points.

**In this sense, AlphaFold calculates the key and query points by adding an offset to the backbone positions, as inferred by the model so far.**

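The squared distance between each query point and its matching key point, summed over the points of a head and scaled by gamma and `w_C`, is subtracted from the attention scores. Pairs of residues whose points lie far apart therefore contribute little to the update, while close pairs contribute strongly.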
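Before implementing `compute_attention_scores`, it may help to see the point term in isolation. The sketch below assumes frames are given as a rotation `R` of shape `(N_res, 3, 3)` and a translation `t` of shape `(N_res, 3)`; the notebook's `T` may use a different representation (for example 4x4 matrices), and the helper name `point_attention_term` is hypothetical. It starts from black-hole initialization and also checks that moving all frames by the same global rigid motion leaves the term unchanged.

```python
import math
import torch

def point_attention_term(qp, kp, R, t, gamma, n_qp):
    """Point contribution to the attention logits (hypothetical helper).

    qp, kp: (N_res, N_head, n_qp, 3) points in local residue frames.
    R: (N_res, 3, 3) rotations, t: (N_res, 3) translations of the frames.
    gamma: (N_head,) per-head weights.
    Returns a (N_head, N_res, N_res) tensor to add inside the softmax.
    """
    # T_i o x = R_i x + t_i: warp local points into global coordinates.
    qp_global = torch.einsum('iab,ihpb->ihpa', R, qp) + t[:, None, None, :]
    kp_global = torch.einsum('iab,ihpb->ihpa', R, kp) + t[:, None, None, :]
    # Squared distance of every query point of i to the matching key point of j,
    # summed over the n_qp points: shape (N_res_i, N_res_j, N_head).
    diff = qp_global[:, None] - kp_global[None, :]
    sq_dist = (diff ** 2).sum(dim=(-2, -1))
    w_C = math.sqrt(2 / (9 * n_qp))
    return -(gamma[:, None, None] * w_C / 2) * sq_dist.permute(2, 0, 1)

torch.manual_seed(0)
N_res, N_head, n_qp = 5, 2, 4
qp = torch.randn(N_res, N_head, n_qp, 3)
kp = torch.randn(N_res, N_head, n_qp, 3)
# Black-hole initialization: identity rotations, all frames at the origin.
R = torch.eye(3).expand(N_res, 3, 3)
t = torch.zeros(N_res, 3)
gamma = torch.ones(N_head)
logits = point_attention_term(qp, kp, R, t, gamma, n_qp)

# A global rigid motion of all frames does not change the distances.
Q = torch.linalg.qr(torch.randn(3, 3)).Q
t_g = torch.randn(3)
logits_moved = point_attention_term(qp, kp, Q @ R, t @ Q.T + t_g, gamma, n_qp)
```

In the full score, this term is added to the scaled q/k product and the pair bias, the sum is multiplied by `w_L`, and a softmax over `j` follows.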

Implement the method `compute_attention_scores` and check your implementation by running the following cell.

In [6]:
from structure_module.control_values.structure_module_checks import test_module_method
from structure_module.control_values.structure_module_checks import c_s, c_z, n_qp, n_pv, N_head, c
from structure_module.ipa import InvariantPointAttention

ipa = InvariantPointAttention(c_s, c_z, n_qp, n_pv, N_head, c)

test_module_method(ipa, 'ipa_att_scores', ('q', 'k', 'qp', 'kp', 'z', 'T'), 'att', control_folder, lambda *x: ipa.compute_attention_scores(*x))

The attention weights are then used to aggregate three different features into the output: the pair representation `z`, the value vectors `v`, and the value points `vp`. Computing these outputs directly is straightforward, but it's good practice for your einsum strings.
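As einsum practice, here is a sketch of the two non-geometric outputs (lines 8 and 9 of Algorithm 22, as we read them): $\tilde{o}_i^h = \sum_j a_{ij}^h z_{ij}$ and $o_i^h = \sum_j a_{ij}^h v_j^h$. The axis order of `att`, `v` and `z` below is an assumption; adapt the einsum strings to the layout your own `compute_attention_scores` produces.

```python
import torch

torch.manual_seed(0)
N_res, N_head, c, c_z = 6, 3, 8, 5
# Assumed layouts: att (h, i, j), v (j, h, c), z (i, j, c_z).
att = torch.softmax(torch.randn(N_head, N_res, N_res), dim=-1)
v = torch.randn(N_res, N_head, c)
z = torch.randn(N_res, N_res, c_z)

# o_i^h = sum_j a_ij^h v_j^h
v_out = torch.einsum('hij,jhc->ihc', att, v)
# o~_i^h = sum_j a_ij^h z_ij
pairwise_out = torch.einsum('hij,ijc->ihc', att, z)
```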

Look at line 10 of Algorithm 22 carefully: the value points are mapped through the transforms before the attention weighting and mapped back afterwards. Let's first assume the attention scores are close to a one-hot vector, meaning that for fixed indices $i$ and $h$, the score is 1 for one index $j$ and 0 for all others. Then this line simplifies to $T_i^{-1} \circ T_j \circ \vec{v}_j^{hp}$. That means the value point is sampled around the backbone transform of residue $j$, and for the update, the coordinates of this point are expressed with respect to the backbone transform of residue $i$. If the attention scores are not one-hot, the value points are sampled around several of the backbone transforms and averaged before being localized to the transform of residue $i$.
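The one-hot argument can be checked numerically. The sketch below, using a hypothetical helper `aggregate_value_points` and again assuming frames as `(R, t)` pairs rather than the notebook's own `T` format, computes $T_i^{-1} \circ \sum_j a_{ij}^h (T_j \circ \vec{v}_j^{hp})$ and compares the one-hot case against $T_i^{-1} \circ T_j \circ \vec{v}_j^{hp}$ computed by hand. The inverse uses $T_i^{-1} \circ \vec{x} = R_i^\top (\vec{x} - \vec{t}_i)$, which holds because rotations are orthogonal.

```python
import torch

def aggregate_value_points(att, vp, R, t):
    """o_i^hp = T_i^{-1} o sum_j a_ij^h (T_j o v_j^hp)  (hypothetical helper).

    att: (N_head, N_res, N_res), vp: (N_res, N_head, n_pv, 3),
    R: (N_res, 3, 3), t: (N_res, 3).
    """
    # T_j o v_j: move value points from local frame j into global coordinates.
    vp_global = torch.einsum('jab,jhpb->jhpa', R, vp) + t[:, None, None, :]
    # Attention-weighted average of the global points for each residue i.
    avg = torch.einsum('hij,jhpa->ihpa', att, vp_global)
    # T_i^{-1} o x = R_i^T (x - t_i): express the result in frame i.
    return torch.einsum('iba,ihpb->ihpa', R, avg - t[:, None, None, :])

torch.manual_seed(0)
N_res, N_head, n_pv = 5, 2, 3
vp = torch.randn(N_res, N_head, n_pv, 3)
# Random orthonormal frames (QR may return reflections; fine for this check).
R = torch.linalg.qr(torch.randn(N_res, 3, 3)).Q
t = torch.randn(N_res, 3)

# One-hot attention: residue 0 attends only to residue 3 in every head.
att = torch.zeros(N_head, N_res, N_res)
att[:, 0, 3] = 1.0
out = aggregate_value_points(att, vp, R, t)

# By hand: T_0^{-1} o T_3 o v_3.
expected = ((vp[3] @ R[3].T + t[3]) - t[0]) @ R[0]
```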

Implement the method `compute_outputs` and check your implementation by running the following cell.

In [7]:
from structure_module.control_values.structure_module_checks import test_module_method
from structure_module.control_values.structure_module_checks import c_s, c_z, n_qp, n_pv, N_head, c
from structure_module.ipa import InvariantPointAttention

ipa = InvariantPointAttention(c_s, c_z, n_qp, n_pv, N_head, c)

test_module_method(ipa, 'ipa_att_outputs', ('att_scores', 'z', 'v', 'vp', 'T'), ('v_out', 'vp_out', 'vp_outnorm', 'pairwise_out'), control_folder, lambda *x: ipa.compute_outputs(*x))

Now, we've got all the complicated parts together. Assemble them in the forward method of `InvariantPointAttention` to finalize the module. Then, check your code with the following cell.

In [8]:
from structure_module.control_values.structure_module_checks import test_module_forward
from structure_module.control_values.structure_module_checks import c_s, c_z, n_qp, n_pv, N_head, c
from structure_module.ipa import InvariantPointAttention

ipa = InvariantPointAttention(c_s, c_z, n_qp, n_pv, N_head, c)

test_module_forward(ipa, 'ipa', ('s', 'z', 'T'), 'out', control_folder)

## Structure Module

With IPA, the hard part of the Structure Module is already done. The rest is mostly about stitching together a few modules. 

We'll start with `StructureModuleTransition`, which corresponds to lines 8-9 in Algorithm 20. Implement the initialization and forward pass, then check your implementation with the following cell.
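Before running the check, here is a minimal sketch of what lines 8-9 of Algorithm 20 describe, as we read them: a residual three-layer MLP with ReLUs, followed by dropout and layer norm. The class name `TransitionSketch` and its attribute names are hypothetical and will not match the parameter names the shape test expects.

```python
import torch
from torch import nn

class TransitionSketch(nn.Module):
    """Hypothetical sketch of Algorithm 20, lines 8-9 (not the notebook's class)."""

    def __init__(self, c_s, dropout=0.1):
        super().__init__()
        self.linear_1 = nn.Linear(c_s, c_s)
        self.linear_2 = nn.Linear(c_s, c_s)
        self.linear_3 = nn.Linear(c_s, c_s)
        self.dropout = nn.Dropout(dropout)
        self.layer_norm = nn.LayerNorm(c_s)

    def forward(self, s):
        # Line 8: residual MLP, ReLU between consecutive linear layers.
        s = s + self.linear_3(torch.relu(self.linear_2(torch.relu(self.linear_1(s)))))
        # Line 9: dropout (inactive in eval mode), then layer norm.
        return self.layer_norm(self.dropout(s))

transition = TransitionSketch(16).eval()
out = transition(torch.randn(4, 16))
```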

In [9]:
from structure_module.control_values.structure_module_checks import test_module_forward, test_module_shape, c_s
from structure_module.structure_module import StructureModuleTransition

transition = StructureModuleTransition(c_s)

test_module_shape(transition, 'sm_transition', control_folder)

test_module_forward(transition, 'sm_transition', 's', 's_out', control_folder)

Next up is `BackboneUpdate`. All it does is embed the single representation into a 6-value vector, which is split into three values for a quaternion (completed to four values by prepending a 1 as the real part) and three values for a translation. The quaternion is normalized and converted into a rotation matrix. We already implemented this conversion in the geometry section.
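The quaternion handling can be sketched as follows, assuming the real part comes first, i.e. $(1, b, c, d)$ normalized to unit length. The function names are hypothetical stand-ins for the conversion you wrote in the geometry section. A useful property of prepending 1: a zero output of the linear layer yields the identity rotation, so an untrained update starts close to "no change".

```python
import torch

def quat_to_rotation(quat):
    """Unit quaternion (a, b, c, d), real part first -> (..., 3, 3) rotation matrix."""
    a, b, c, d = quat.unbind(-1)
    rows = [
        [a * a + b * b - c * c - d * d, 2 * (b * c - a * d), 2 * (b * d + a * c)],
        [2 * (b * c + a * d), a * a - b * b + c * c - d * d, 2 * (c * d - a * b)],
        [2 * (b * d - a * c), 2 * (c * d + a * b), a * a - b * b - c * c + d * d],
    ]
    return torch.stack([torch.stack(r, dim=-1) for r in rows], dim=-2)

def backbone_update_from_vector(x):
    """x: (..., 6) output of the linear layer -> rotation (..., 3, 3), translation (..., 3)."""
    bcd, trans = x[..., :3], x[..., 3:]
    # Prepend 1 as the real part, then normalize to a unit quaternion.
    quat = torch.cat([torch.ones_like(bcd[..., :1]), bcd], dim=-1)
    quat = quat / torch.linalg.norm(quat, dim=-1, keepdim=True)
    return quat_to_rotation(quat), trans

R0, t0 = backbone_update_from_vector(torch.zeros(6))
R, t = backbone_update_from_vector(torch.randn(5, 6))
```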

Implement the initialization and forward pass for `BackboneUpdate`, then check your implementation by running the following cell.

In [10]:
from structure_module.control_values.structure_module_checks import test_module_forward, test_module_shape, c_s
from structure_module.structure_module import BackboneUpdate

bb_update = BackboneUpdate(c_s)

test_module_shape(bb_update, 'bb_update', control_folder)

test_module_forward(bb_update, 'bb_update', 's', 'T_out', control_folder)

Next, we will compute the prediction of the side-chain torsion angles. These are lines 11-14 in Algorithm 20. We'll start by implementing one layer of this so-called AngleResNet, which means either line 12 or line 13 (the ResNet has two layers). Implement the initialization and forward pass for `AngleResNetLayer`. Check your code with the following cell.
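Before running the check, here is a sketch of one such layer as we read line 12 (or 13) of Algorithm 20, $a \leftarrow a + \text{Linear}(\text{relu}(\text{Linear}(\text{relu}(a))))$. Note that the ReLU comes *before* each linear layer. The class and attribute names are hypothetical.

```python
import torch
from torch import nn

class AngleResNetLayerSketch(nn.Module):
    """Hypothetical sketch of one AngleResNet layer (Algorithm 20, line 12 or 13)."""

    def __init__(self, c):
        super().__init__()
        self.linear_1 = nn.Linear(c, c)
        self.linear_2 = nn.Linear(c, c)

    def forward(self, a):
        # a + Linear(relu(Linear(relu(a)))): residual connection around the block.
        return a + self.linear_2(torch.relu(self.linear_1(torch.relu(a))))

layer = AngleResNetLayerSketch(8)
a = torch.randn(3, 8)
a_out = layer(a)
```

Because of the residual connection, zeroing the weights and bias of `linear_2` turns the layer into the identity, which is a handy sanity check.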

In [11]:
from structure_module.control_values.structure_module_checks import test_module_forward, test_module_shape, c
from structure_module.structure_module import AngleResNetLayer

resnet_layer = AngleResNetLayer(c)

test_module_shape(resnet_layer, 'resnet_layer', control_folder)

test_module_forward(resnet_layer, 'resnet_layer', 'a', 'a_out', control_folder)

The AngleResNet combines two of these layers with additional input and output layers. The output layer predicts the torsion angles as unnormalized (cos(phi), sin(phi)) pairs. These are mapped onto the unit circle by normalization and used directly in this form, without ever computing phi itself. Implement the initialization and forward pass for `AngleResNet`, then check your code by running the following cell.
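The normalization step can be sketched as below, assuming the output layer emits `2 * 7` values per residue (AlphaFold predicts seven torsion angles) that are reshaped into `(..., 7, 2)` pairs. The helper name and the `eps` guard are assumptions. The `atan2` call only appears in the check to confirm that the direction of each pair is preserved; the model itself keeps the (cos, sin) form.

```python
import torch

def normalize_angles(raw, n_torsions=7, eps=1e-12):
    """raw: (..., 2 * n_torsions) unnormalized (cos, sin) pairs -> (..., n_torsions, 2)."""
    alpha = raw.reshape(*raw.shape[:-1], n_torsions, 2)
    # Project each pair onto the unit circle; eps guards against division by zero.
    norm = torch.sqrt(torch.clamp((alpha ** 2).sum(dim=-1, keepdim=True), min=eps))
    return alpha / norm

# Pairs with known angles but arbitrary (positive) lengths.
theta = torch.linspace(-3.0, 3.0, 7)
radius = torch.linspace(0.5, 4.0, 7)
raw = torch.stack([radius * torch.cos(theta), radius * torch.sin(theta)], dim=-1).reshape(1, 14)
alpha = normalize_angles(raw)
phi = torch.atan2(alpha[..., 1], alpha[..., 0])
```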

In [12]:
from structure_module.control_values.structure_module_checks import test_module_forward, test_module_shape, c_s, c
from structure_module.structure_module import AngleResNet

angle_resnet = AngleResNet(c_s, c)

test_module_shape(angle_resnet, 'angle_resnet', control_folder)

test_module_forward(angle_resnet, 'angle_resnet', ('s', 's_initial'), 'alpha', control_folder)

We've got all the parts for the Structure Module. Now, we put them all together. Start by implementing the `__init__` method of the Structure Module and check your code with the following cell.

In [13]:
from structure_module.control_values.structure_module_checks import test_module_shape, c_s, c_z, c, n_layer
from structure_module.structure_module import StructureModule

sm = StructureModule(c_s, c_z, n_layer, c)

test_module_shape(sm, 'structure_module', control_folder)

Now, we'll implement `process_outputs`. It has two tasks: First, it calls `compute_all_atom_coordinates` to compute the heavy-atom positions from the backbone transforms and the torsion angles. Second, it selects the pseudo-beta positions from all atom positions. These are used by the recycling embedder for the next iteration of the network. They are the positions of the C-beta atoms (for each amino acid except glycine), or the C-alpha atoms (for glycine, which doesn't have a C-beta atom). Implement the method and check your implementation by running the following cell. 

Supporting batched inputs is optional: leaving it out makes the selections easier, and we didn't require batched support in `compute_all_atom_coordinates` either. If you want to implement and test batched support, remove the `include_batched=False` flag.
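Before running the check, here is a sketch of the pseudo-beta selection only. It rests on assumptions you should verify against the notebook's own conventions: atom positions in a 37-atom layout where CA has index 1 and CB index 3, and `F` holding one-hot residue types in the order `'ARNDCQEGHILKMFPSTWYV'`, so that glycine has index 7. The helper name and constants are hypothetical.

```python
import torch

# Assumed conventions (hypothetical; check them against the notebook):
# 37-atom layout with CA at index 1 and CB at index 3, and residue types
# one-hot encoded in the order 'ARNDCQEGHILKMFPSTWYV' (glycine = index 7).
CA_INDEX, CB_INDEX, GLY_INDEX = 1, 3, 7

def select_pseudo_beta(positions, restype_onehot):
    """positions: (N_res, 37, 3), restype_onehot: (N_res, n_types) -> (N_res, 3)."""
    is_gly = restype_onehot.argmax(dim=-1) == GLY_INDEX  # (N_res,)
    # C-alpha for glycine, C-beta for every other residue.
    return torch.where(is_gly[:, None], positions[:, CA_INDEX], positions[:, CB_INDEX])

positions = torch.randn(3, 37, 3)
onehot = torch.nn.functional.one_hot(torch.tensor([0, 7, 1]), 20).float()  # Ala, Gly, Arg
pseudo_beta = select_pseudo_beta(positions, onehot)
```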

In [14]:
from structure_module.control_values.structure_module_checks import test_module_method, c_s, c_z, c, n_layer
from structure_module.structure_module import StructureModule

sm = StructureModule(c_s, c_z, n_layer, c)

test_module_method(sm, 'sm_process_outputs', ('T', 'alpha', 'F'), ('pos', 'pos_mask', 'pseudo_beta'), control_folder, lambda *x: sm.process_outputs(*x), include_batched=False)

As the very last step for the Structure Module, we'll implement the forward pass, which chains together all the modules we implemented earlier. Implement `forward` and check your implementation with the following cell.
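One step of the forward loop that is easy to get wrong is line 10 of Algorithm 20, $T_i \leftarrow T_i \circ \text{BackboneUpdate}(s_i)$: the update is applied *in the local frame* of the current transform, i.e. it is composed on the right. The sketch below shows this composition, assuming `(R, t)` pairs instead of the notebook's `T` format; the helper name `compose` is hypothetical.

```python
import torch

def compose(R1, t1, R2, t2):
    """(R1, t1) o (R2, t2): apply (R2, t2) first, then (R1, t1)."""
    R = R1 @ R2
    t = (R1 @ t2.unsqueeze(-1)).squeeze(-1) + t1
    return R, t

torch.manual_seed(0)
R1 = torch.linalg.qr(torch.randn(4, 3, 3)).Q   # current frames T_i
t1 = torch.randn(4, 3)
R2 = torch.linalg.qr(torch.randn(4, 3, 3)).Q   # predicted updates
t2 = torch.randn(4, 3)
R, t = compose(R1, t1, R2, t2)

# Applying the composed transform equals applying the two transforms in sequence.
x = torch.randn(4, 3)
step = (R2 @ x.unsqueeze(-1)).squeeze(-1) + t2
expected = (R1 @ step.unsqueeze(-1)).squeeze(-1) + t1
actual = (R @ x.unsqueeze(-1)).squeeze(-1) + t
```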

In [15]:
from structure_module.control_values.structure_module_checks import test_module_method, c_s, c_z, c, n_layer
from structure_module.structure_module import StructureModule

sm = StructureModule(c_s, c_z, n_layer, c)

def check(*args):
    output = sm(*args)
    return output['angles'], output['frames'], output['final_positions'], output['position_mask'], output['pseudo_beta_positions']


test_module_method(sm, 'structure_module', ('s', 'z', 'F'), ('angles', 'frames', 'final_positions', 'position_mask', 'pseudo_beta_positions'), control_folder, check)

## Conclusion