# Info

Based on https://github.com/lucidrains/alphafold2 and https://github.com/lucidrains/egnn-pytorch with help from https://github.com/hypnopump.

# Setup

In [1]:
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F
from einops import rearrange

In [2]:
import os

# Data

In [3]:
import sidechainnet as scn
#from sidechainnet.sequence.utils import VOCAB
from sidechainnet.utils.sequence import ProteinVocabulary as VOCAB # From https://github.com/lucidrains/egnn-pytorch/blob/main/examples/egnn_test.ipynb
VOCAB = VOCAB()
from sidechainnet.structure.build_info import NUM_COORDS_PER_RES

In [4]:
scn.__version__

'v0.4.0'

In [5]:
VOCAB

ProteinVocabulary[size=21]

# Models

In [6]:
from alphafold2_pytorch import Alphafold2
import alphafold2_pytorch.constants as constants

from se3_transformer_pytorch import SE3Transformer
from alphafold2_pytorch.utils import *

# Constants

In [7]:
FEATURES = "esm" # one of ["esm", "msa", None]
DEVICE = None#  torch.device("cuda" if torch.cuda.is_available() else "cpu") # defaults to cuda if available, else cpu
NUM_BATCHES = int(1e5)
GRADIENT_ACCUMULATE_EVERY = 1 #16
LEARNING_RATE = 3e-4
IGNORE_INDEX = -100
THRESHOLD_LENGTH = 250
TO_PDB = False
SAVE_DIR = ""

In [8]:
FEATURES, DEVICE, NUM_BATCHES, GRADIENT_ACCUMULATE_EVERY, LEARNING_RATE, IGNORE_INDEX, THRESHOLD_LENGTH, TO_PDB, SAVE_DIR

('esm', None, 100000, 1, 0.0003, -100, 250, False, '')

# Set device

In [9]:
DEVICE = constants.DEVICE
DISTOGRAM_BUCKETS = constants.DISTOGRAM_BUCKETS

In [10]:
DEVICE, DISTOGRAM_BUCKETS

(device(type='cuda'), 37)

# Set embedder model from esm if appropiate - Load ESM-1b model

In [11]:
if FEATURES == "esm":
    # from pytorch hub (almost 30gb)
    embedd_model, alphabet = torch.hub.load("facebookresearch/esm", "esm1b_t33_650M_UR50S")
    ##  alternatively do
    # import esm # after installing esm
    # model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
    batch_converter = alphabet.get_batch_converter()

Using cache found in /home/mmp/.cache/torch/hub/facebookresearch_esm_master


# AF2 helpers

In [12]:
def cycle(loader, cond = lambda x: True):
    while True:
        for data in loader:
            if not cond(data):
                continue
            yield data

https://github.com/jonathanking/sidechainnet#loading-sidechainnet-with-pytorch-dataloaders<br>
`Downloaded SidechainNet to ./sidechainnet_data/sidechainnet_casp12_30.pkl.`

# Get data

In [13]:
data = scn.load(
    casp_version = 12,
    thinning = 30,
    with_pytorch = 'dataloaders',
    batch_size = 1,
    dynamic_batching = False
)

SidechainNet was loaded from ./sidechainnet_data/sidechainnet_casp12_30.pkl.


In [14]:
data.keys()

dict_keys(['train', 'train-eval', 'test', 'valid-10', 'valid-20', 'valid-30', 'valid-40', 'valid-50', 'valid-70', 'valid-90'])

In [15]:
data = iter(data['train'])
data_cond = lambda t: t[1].shape[1] < THRESHOLD_LENGTH
dl = cycle(data, data_cond)

In [16]:
#d_test = next(data)

In [17]:
#dir(d_test)

In [18]:
# d_test "keys": 'angs', 'count', 'crds', 'evos', 'index', 'int_seqs',
#                'msks', 'pids', 'ress', 'secs', 'seq_evo_sec', 'seqs'

# AF2 model

In [19]:
model = Alphafold2(
    dim = 128,
    depth = 1,
    heads = 1, # Maybe set even lower?
    dim_head = 16, # Maybe set even lower?
    predict_coords = False,
    num_backbone_atoms = 3,
).to(DEVICE)

In [20]:
#model

# AF2 optimizer

In [21]:
dispersion_weight = 0.1
criterion = nn.MSELoss()
optim = Adam(model.parameters(), lr = LEARNING_RATE)

# EGNN helpers

Based on: https://github.com/lucidrains/egnn-pytorch/blob/main/examples/egnn_test.ipynb

In [22]:
NEEDED_INFO = {"cutoffs": [], # "15_closest"
               "bond_scales": [1, 2],
               "aa_pos_scales": [2,4,8,16,32,64,128],
               "atom_pos_scales": [1,2,4,8,16,32],
               "dist2ca_norm_scales": [1,2,4],
               "bb_norms_atoms": [0.5], # will encode 3 vectors with this
               # nn-degree connection
               "adj_degree": 2
              }
# get model sizes from encoded protein
#seq, true_coords, angles, padding_seq, mask, id = train_examples_storer[-1] 
#NEEDED_INFO["seq"] = seq[:-padding_seq or None]
#NEEDED_INFO["covalent_bond"] = prot_covalent_bond(seq)

In [23]:
### adjust for egnn: 
#embedd_info["bond_n_scalars"] -= 2*len(NEEDED_INFO["bond_scales"])+1
#embedd_info["bond_n_vectors"] = 0
#embedd_info

# EGNN model

In [24]:
#from egnn_pytorch.egnn_pytorch import EGNN_Sparse_Network

In [25]:
os.sys.path.append('/home/mmp/projects/EleutherAI/egnn_pytorch_git')

In [26]:
!ls /home/mmp/projects/EleutherAI/egnn_pytorch_git

denoise_sparse.py  [0m[01;34megnn_pytorch_git[0m/  LICENSE    setup.cfg  [01;34mtests[0m/
[01;35megnn.png[0m           [01;34mexamples[0m/          README.md  setup.py


In [27]:
from egnn_pytorch_git.egnn_pytorch_git import EGNN_Sparse_Network

In [28]:
# define model
model_egnn = EGNN_Sparse_Network(n_layers=4,
                                 feats_dim=2,
                                 pos_dim = 3,
                                 edge_attr_dim = 1,
                                 m_dim = 32,
                                 fourier_features = 4,
                                 embedding_nums=[36,20],
                                 embedding_dims=[16,16],
                                 edge_embedding_nums=[3],
                                 edge_embedding_dims=[2],
                                 update_coors=True,
                                 update_feats=True, 
                                 norm_feats=False,
                                 norm_coors=False,
                                 recalc=False)

In [29]:
#??EGNN_Sparse_Network

In [30]:
model_egnn = model.to(DEVICE)

In [31]:
optimizer = torch.optim.Adam( list(model_egnn.parameters()) + \
                              list(model.parameters()), lr=1e-3)

  super(Adam, self).__init__(params, defaults)


# Import from geometric-vector-perceptron

In [32]:
os.sys.path.append('/home/mmp/projects/EleutherAI/geometric-vector-perceptron')
os.sys.path.append('/home/mmp/projects/EleutherAI/geometric-vector-perceptron/examples')

In [33]:
!ls /home/mmp/projects/EleutherAI/geometric-vector-perceptron/

[0m[01;35mdiagram.png[0m  [01;34mgeometric_vector_perceptron[0m/  README.md  setup.py
[01;34mexamples[0m/    LICENSE                       setup.cfg  [01;34mtests[0m/


In [34]:
from examples.data_utils import encode_whole_protein, prot_covalent_bond, encode_whole_bonds 

# Training loop

In [35]:
import pdb

In [36]:
#from alphafold2_pytorch.utils import get_esm_embedd

In [37]:
def get_esm_embedd_TEST(seq, embedd_model, batch_converter, embedd_type="per_tok"):
    """ Returns the ESM embeddings for a protein. 
        Inputs: 
        * seq: (L,) tensor of ints (in sidechainnet int-char convention)
        * embedd_model: ESM model (see train_end2end.py for an example)
        * batch_converter: ESM batch converter (see train_end2end.py for an example)
        * embedd_type: one of ["mean", "per_tok"]. 
                       "per_tok" is recommended if working with sequences.
    """
    str_seq = "".join([VOCAB._int2char[x]for x in seq.cpu().numpy()])
    print(f"len(str_seq): {len(str_seq)}")
    batch_labels, batch_strs, batch_tokens = batch_converter( [(0, str_seq)] )
    print(f"len(batch_labels): {len(batch_labels)}")
    print(f"len(batch_strs): {len(batch_strs)}")
    with torch.no_grad():
        results = embedd_model(batch_tokens, repr_layers=[33], return_contacts=False)
        print(f"results.keys(): {results.keys()}")
        print(f"results['representations'][33].shape: {results['representations'][33].shape}")
    # index 0 is for start token. so take from 1 one
    print(f"len(str_seq) + 1: {len(str_seq) + 1}")
    #pdb.set_trace()
    token_reps = results['representations'][33][:, 1 : len(str_seq) + 1].to(seq.device)
    if embedd_type == "mean":
        token_reps = token_reps.mean(dim=0)
    print(f"token_reps.shape: {token_reps.shape}")
    return token_reps

In [38]:
epoch_losses = []
baseline_losses = []

for _ in range(NUM_BATCHES):
    for _ in range(GRADIENT_ACCUMULATE_EVERY):
        
        ### Stage 1
        
        batch = next(dl)
        seq, coords, mask = batch.seqs, batch.crds, batch.msks
        mask = mask.bool() # Needs to be set to bool

        b, l, _ = seq.shape

        # prepare data and mask labels
        seq, coords, mask = seq.argmax(dim = -1).to(DEVICE), coords.to(DEVICE), mask.to(DEVICE)

        # sequence embedding (msa / esm / attn / or nothing)
        msa, embedds = None, None

        # get embedds
        if FEATURES == "esm":
            print(f"seq.shape: {seq.shape}, type(embedd_model) {type(embedd_model)}, type(batch_converter) {type(batch_converter)}")
            #embedds = get_esm_embedd(seq, embedd_model, batch_converter, embedd_type="per_tok").unsqueeze(0)
            embedds = get_esm_embedd_TEST(seq[0], embedd_model, batch_converter, embedd_type="per_tok").unsqueeze(0)
            print(f"embedds.shape: {embedds.shape}")
            #pdb.set_trace()
            msa_mask = None
            #msa_mask = torch.ones_like(embedds).bool()
            #msa_mask = torch.ones_like(embedds[..., -1]).bool()
        # get msa here
        elif FEATURES == "msa":
            pass
        # no embeddings
        else:
            pass

        # predict - out is (batch, L * 3, 3)

        distogram = model(
            seq,
            msa = msa,
            embedds = embedds,
            mask = mask,
            msa_mask = msa_mask
            )
        
        ### Stage 2 - get 3d structure
        
        assert model.num_backbone_atoms > 1, 'must constitute to at least 3 atomic coordinates for backbone'


        N_mask, CA_mask, C_mask = scn_backbone_mask(seq, boolean = True, n_aa = model.num_backbone_atoms)

        cloud_mask = scn_cloud_mask(seq, boolean = True)
        chain_mask = (mask.unsqueeze(-1) * cloud_mask)
        
        bb_flat_mask = rearrange(chain_mask[..., :model.num_backbone_atoms], 'b l c -> b (l c)')
        bb_flat_mask_crossed = rearrange(bb_flat_mask, 'b i -> b i ()') * rearrange(bb_flat_mask, 'b j -> b () j')

        # structural refinement

        if model.predict_real_value_distances:
            distances, distance_std = distogram.unbind(dim = -1)
            weights = (1 / (1 + distance_std)) # could also do a distance_std.sigmoid() here
        else:
            distances, weights = center_distogram_torch(distogram)

        # set unwanted atoms to weight=0 (like C-beta in glycine)
        weights.masked_fill_( torch.logical_not(bb_flat_mask_crossed), 0.)

        coords_3d, _ = MDScaling(distances, 
            weights = weights,
            iters = model.mds_iters,
            fix_mirror = True,
            N_mask = N_mask,
            CA_mask = CA_mask,
            C_mask = C_mask
        )
        
        pred_coords = rearrange(coords_3d, 'b c n -> b n c')
        
        #Â add sidechain
        pred_coords = sidechain_container(pred_coords, n_aa = model.num_backbone_atoms, cloud_mask=cloud_mask)
        pred_coords = rearrange(pred_coords, 'b n l d -> b (n l) d')
        
        ### Stage 3 - refinement        
        
        # We need as shown in line:
        # seq, true_coords, angles, padding_seq, mask, pid = get_prot(dataloader_=dataloaders_,
        
        # From batch "keys":
        # 'angs', 'count', 'crds', 'evos', 'index', 'int_seqs', 'msks', 'pids', 'ress', 'secs', 'seq_evo_sec', 'seqs'
        
        # Assume batch dim is 1
        true_coords = coords # get the coords label from above
        angles = batch.angs.to(DEVICE)
        seq_str = ''.join([VOCAB.int2char(aa.item()) for aa in seq[0]])
        padding_seq = (seq[0] == 20).sum()
        mask = batch.msks.to(DEVICE)
        pid = batch.pids #.to(DEVICE)
        
        # encode as needed
        
        print(f"len(seq_str): {len(seq_str)}, pred_coords.shape: {pred_coords.shape}, angles.shape: {angles.shape}, padding_seq: {padding_seq}")
        encoded = encode_whole_protein(seq_str, true_coords[0], angles[0], padding_seq, needed_info=NEEDED_INFO, free_mem=True)
        x, edge_index, edge_attrs, embedd_info = encoded
        
        # add position coords - better mask accounting for missing atoms
        cloud_mask_naive = scn_cloud_mask(seq_str).bool()
        cloud_mask = scn_cloud_mask(seq_str, coords=true_coords).bool()
        if padding_seq:
            cloud_mask[-padding_seq:] = 0.
        # cloud is all points, chain is all for which we have labels
        chain_mask = mask.unsqueeze(-1) * cloud_mask
        flat_chain_mask = rearrange(chain_mask, 'l c -> (l c)')
        flat_cloud_mask = rearrange(cloud_mask, 'l c -> (l c)')
        # slice useless norm and vector embeddings
        masked_coords = true_coords[flat_cloud_mask]

        #############
        # MASK EDGES AND NODES ACCOUNTING FOR SCN MISSING ATOMS
        #############
        # NODES
        x = torch.cat([masked_coords, x[:, -2:][cloud_mask[cloud_mask_naive]] ], dim=-1)
        # EDGES: delete all edges with masked-out atoms

        # pick all current indexes and turn them to 1.
        to_mask_edges = torch.zeros(edge_index.amax()+1, edge_index.amax()+1).to(edge_index.device)
        to_mask_edges[edge_index[0], edge_index[1]] = 1.
        # delete erased bonds
        masked_out_atoms = (-1*(cloud_mask[cloud_mask_naive].float() - 1)).bool()
        to_mask_edges[masked_out_atoms] *= 0.
        to_mask_edges = to_mask_edges * to_mask_edges.t()
        # get mask for the edge_attrs
        attr_mask = to_mask_edges[edge_index[0], edge_index[1]].bool()
        edge_attrs = edge_attrs[attr_mask, :]
        # delete unwanted rows and cols
        wanted = to_mask_edges.sum(dim=-1).bool()
        edge_index = (to_mask_edges[wanted, :][:, wanted]).nonzero().t()
        #############
        # continue
        #############
        edge_attrs = edge_attrs[:, -1:]
        batch = torch.tensor([0 for i in range(x.shape[0])], device=DEVICE).long()

        if torch.amax(edge_index) >= x.shape[0]:
            print("wtf, breaking, debug, index out of bounds")
            break

        # predict
        preds = model.forward(x, edge_index, batch=batch, edge_attr=edge_attrs,
                              recalc_edge=None, verbose = False)

        # MEASURE ERROR - format pred and target
        target_coords = true_coords[flat_cloud_mask].clone()
        bae_coords    = x[:, :3]
        pred_coords   = preds[:, :3]

        # option 2: loss is RMSD on reconstructed coords  // align - sometimes svc fails - idk why
        try:
            pred_aligned, target_aligned = kabsch_torch(pred_coords.t(), target_coords.t()) # (3, N)
            base_aligned, _              = kabsch_torch(base_coords.t(), target_coords.t()) # (3, N)

            loss = ( (pred_aligned.t() - target_aligned.t())[flat_chain_mask[flat_cloud_mask]]**2 ).mean() 
            loss_base = ( (base_aligned.t() - target_aligned.t())[flat_chain_mask[flat_cloud_mask]]**2 ).mean()
        except:
            base_aligned, pred_aligned, target_aligned = None, None, None
            print("svd failed convergence, ep:", ep)
            loss = ( (pred_coords - target_coords)[flat_chain_mask[flat_cloud_mask]]**2 ).mean()
            loss_base = ( (base_coords - target_coords)[flat_chain_mask[flat_cloud_mask]]**2 ).mean()

        # back pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # records / prints
        iteration += 1
        epoch_losses.append( loss.item() )
        baseline_losses.append( loss_base.item() )

        n_print = 10
        if iteration % n_print == 1:
            tic = time.time()
            print("BATCH: {0} / {1}, loss: {2}, baseline_loss: {3}, time: {4}".format(iteration, n_per_iter,
                                                                                      np.mean(epoch_losses[-n_print:]),
                                                                                      np.mean(baseline_losses[-n_print:]),
                                                                                      tic-tac))
            tac = time.time()
            if iteration % n_per_iter == 1:
                print("---------------------------------")

seq.shape: torch.Size([1, 185]), type(embedd_model) <class 'esm.model.ProteinBertModel'>, type(batch_converter) <class 'esm.data.BatchConverter'>
len(str_seq): 185
len(batch_labels): 1
len(batch_strs): 1
results.keys(): dict_keys(['logits', 'representations'])
results['representations'][33].shape: torch.Size([1, 187, 1280])
len(str_seq) + 1: 186
token_reps.shape: torch.Size([1, 185, 1280])
embedds.shape: torch.Size([1, 1, 185, 1280])
it: 0, stress tensor([256715.1250], device='cuda:0', grad_fn=<MulBackward0>)
it: 1, stress tensor([307274.5625], device='cuda:0', grad_fn=<MulBackward0>)
breaking at iteration 1 with stress tensor([97526.3828], device='cuda:0', grad_fn=<DivBackward0>)
Corrected mirror idxs: tensor([0])
len(seq_str): 185, pred_coords.shape: torch.Size([1, 2590, 3]), angles.shape: torch.Size([1, 185, 12]), padding_seq: 0


KeyError: 'seq'

In [None]:
%debug

In [51]:
633/14

45.214285714285715

In [36]:
??Alphafold2

In [34]:
type(model)

alphafold2_pytorch.alphafold2.Alphafold2

In [33]:
model.num_backbone_atoms

1

In [None]:
        ### Old stuff:

#        # build SC container. set SC points to CA and optionally place carbonyl O
#        proto_sidechain = sidechain_container(refined, n_aa=batch,
#                                              cloud_mask=cloud_mask, place_oxygen=False)
#
#        # rotate / align
#        coords_aligned, labels_aligned = Kabsch(refined, coords[flat_cloud_mask])
#
#        # atom mask
#
#        cloud_mask = scn_cloud_mask(seq, boolean = False)
#        flat_cloud_mask = rearrange(cloud_mask, 'b l c -> b (l c)')
#
#        # chain_mask is all atoms that will be backpropped thru -> existing + trainable
#
#        chain_mask = (mask * cloud_mask)[cloud_mask]
#        flat_chain_mask = rearrange(chain_mask, 'b l c -> b (l c)')
#
#        # save pdb files for visualization
#
#        if TO_PDB:
#            # idx from batch to save prot and label
#            idx = 0
#            coords2pdb(seq[idx, :, 0], coords_aligned[idx], cloud_mask, prefix=SAVE_DIR, name="pred.pdb")
#            coords2pdb(seq[idx, :, 0], labels_aligned[idx], cloud_mask, prefix=SAVE_DIR, name="label.pdb")
#
#        # loss - RMSE + distogram_dispersion
#        loss = torch.sqrt(criterion(coords_aligned[flat_chain_mask], labels_aligned[flat_chain_mask])) + \
#                          dispersion_weight * torch.norm( (1/weights)-1 )
#
#        loss.backward()
#    print('loss:', loss.item())
#
#    optim.step()
#    optim.zero_grad()

# End