# Info

Based on https://github.com/lucidrains/alphafold2 and https://github.com/lucidrains/egnn-pytorch with help from https://github.com/hypnopump.

# Setup

In [1]:
import time

import numpy as np
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch import nn
import torch.nn.functional as F
from einops import rearrange

# Data

In [2]:
import sidechainnet as scn
# VOCAB used to come from the (commented-out) `sidechainnet.sequence.utils` import;
# here we instantiate ProteinVocabulary ourselves.
# From https://github.com/lucidrains/egnn-pytorch/blob/main/examples/egnn_test.ipynb
from sidechainnet.utils.sequence import ProteinVocabulary
from sidechainnet.structure.build_info import NUM_COORDS_PER_RES

# Keep the class under its own name: `VOCAB = VOCAB()` rebinds the class to its
# instance, so that statement cannot be re-run on its own.
VOCAB = ProteinVocabulary()

In [3]:
# Show the vocabulary instance (its repr reports the vocabulary size)
VOCAB

ProteinVocabulary[size=21]

# Models

In [4]:
from alphafold2_pytorch import Alphafold2
import alphafold2_pytorch.constants as constants

from se3_transformer_pytorch import SE3Transformer
from alphafold2_pytorch.utils import *

# Constants

In [5]:
# --- run configuration -------------------------------------------------------
FEATURES = "esm"          # one of ["esm", "msa", None]
DEVICE = None             # defaults to cuda if available, else cpu

# optimisation
NUM_BATCHES = 100_000
GRADIENT_ACCUMULATE_EVERY = 1   # 16
LEARNING_RATE = 3e-4
IGNORE_INDEX = -100

# data filtering / output
THRESHOLD_LENGTH = 250    # proteins must be shorter than this many residues
TO_PDB = False
SAVE_DIR = ""

In [6]:
# Echo the run configuration so it is recorded in the cell output
FEATURES, DEVICE, NUM_BATCHES, GRADIENT_ACCUMULATE_EVERY, LEARNING_RATE, IGNORE_INDEX, THRESHOLD_LENGTH, TO_PDB, SAVE_DIR

('esm', None, 100000, 1, 0.0003, -100, 250, False, '')

# Set device

In [7]:
# Replace the DEVICE placeholder with the library's device and read the distogram bucket count
DEVICE, DISTOGRAM_BUCKETS = constants.DEVICE, constants.DISTOGRAM_BUCKETS

In [8]:
# Check the resolved device and the number of distogram buckets
DEVICE, DISTOGRAM_BUCKETS

(device(type='cuda'), 37)

# Set up the ESM embedder if appropriate - load the ESM-1b model

In [9]:
if FEATURES == "esm":
    # from pytorch hub (almost 30gb)
    embedd_model, alphabet = torch.hub.load("facebookresearch/esm", "esm1b_t33_650M_UR50S")
    ##  alternatively do
    # import esm # after installing esm
    # embedd_model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
    # (do not name it `model`: that name is the AF2 model defined later)

    # Inference only: disable dropout so the embeddings are deterministic.
    embedd_model.eval()
    batch_converter = alphabet.get_batch_converter()

Using cache found in /home/mmp/.cache/torch/hub/facebookresearch_esm_master


# AF2 helpers

In [10]:
def cycle(loader, cond = lambda x: True):
    """ Yields items of `loader` forever, restarting it after every pass.

        * loader: a re-iterable (e.g. a DataLoader). A one-shot iterator only
          works for a single pass.
        * cond: predicate; items for which it returns False are skipped.
        Raises RuntimeError when a whole pass yields no accepted item. This
        happens with an empty/exhausted loader or a cond that rejects
        everything, and would otherwise be a silent infinite loop.
    """
    while True:
        produced = False
        for data in loader:
            if not cond(data):
                continue
            produced = True
            yield data
        if not produced:
            raise RuntimeError(
                "cycle: a full pass over the loader produced no item accepted by cond "
                "(empty/exhausted loader or overly strict filter)")

def get_esm_embedd(seq):
    """ Per-residue ESM-1b embeddings (representation layer 33) for ONE protein.

        * seq: integer residue tokens of a single protein (e.g. shape (1, L));
          it is squeezed and decoded to a 1-letter string with VOCAB.int2char.
        Returns the layer-33 representations on DEVICE. The special token the
        batch converter puts before the sequence (and the one after it) is
        dropped, so the residue axis (dim 1) has exactly len(sequence) entries.
    """
    str_seq = "".join([VOCAB.int2char(x) for x in seq.squeeze(0).cpu().numpy()])
    _, _, batch_tokens = batch_converter([(0, str_seq)])
    with torch.no_grad():
        results = embedd_model(batch_tokens, repr_layers=[33], return_contacts=False)
    # ESM-1b tokenises as <cls> seq <eos>: token 0 is not a residue, so keep 1..L
    return results["representations"][33][:, 1 : len(str_seq) + 1].to(DEVICE)

https://github.com/jonathanking/sidechainnet#loading-sidechainnet-with-pytorch-dataloaders<br>
`Downloaded SidechainNet to ./sidechainnet_data/sidechainnet_casp12_30.pkl.`

# Get data

In [26]:
# SidechainNet CASP12 at 30% thinning, wrapped in PyTorch DataLoaders (one protein per batch)
data = scn.load(casp_version=12,
                thinning=30,
                with_pytorch='dataloaders',
                dynamic_batching=False,
                batch_size=1)

SidechainNet was loaded from ./sidechainnet_data/sidechainnet_casp12_30.pkl.


In [27]:
# Available splits: train, train-eval, test and several validation sets
data.keys()

dict_keys(['train', 'train-eval', 'test', 'valid-10', 'valid-20', 'valid-30', 'valid-40', 'valid-50', 'valid-70', 'valid-90'])

In [28]:
# Pass the DataLoader itself: cycle() restarts it every epoch. A one-shot iter()
# over it would be exhausted after the first epoch. `data` stays the split dict,
# so this cell can be re-run.
train_loader = data['train']
# seqs is (batch, length, vocab) one-hot: keep proteins shorter than THRESHOLD_LENGTH
data_cond = lambda t: t.seqs.shape[1] < THRESHOLD_LENGTH
dl = cycle(train_loader, data_cond)

In [29]:
#d_test = next(data)

In [35]:
#dir(d_test)

In [None]:
# d_test "keys": 'angs', 'count', 'crds', 'evos', 'index', 'int_seqs',
#                'msks', 'pids', 'ress', 'secs', 'seq_evo_sec', 'seqs'

# AF2 model

In [14]:
# Deliberately tiny AF2 trunk (depth 1, a single attention head)
model = Alphafold2(
    # trunk size
    dim = 128,
    depth = 1,
    heads = 1,      # Maybe set even lower?
    dim_head = 16,  # Maybe set even lower?
    # the training loop treats the output as a distance prediction for MDS
    predict_coords = False,
    num_backbone_atoms = 3,
    # structure module
    structure_module_dim = 8,
    structure_module_depth = 2,
    structure_module_heads = 4,
    structure_module_dim_head = 16,
    structure_module_refinement_iters = 2,
).to(DEVICE)

In [15]:
#model

# AF2 optimizer

In [16]:
# AF2 loss pieces and optimizer (currently referenced only by the commented-out legacy loop)
criterion = nn.MSELoss()
dispersion_weight = 0.1
optim = Adam(model.parameters(), lr = LEARNING_RATE)

# EGNN helpers

Based on: https://github.com/lucidrains/egnn-pytorch/blob/main/examples/egnn_test.ipynb

In [46]:
def encode_whole_protein(seq, true_coords, padding_seq,
                         needed_info = None, free_mem=False):
    """ Encodes a whole protein. In points + vectors.

        Inputs:
        * seq: 1-letter AA sequence. Each residue is used as a key into
          SUPREME_INFO and AAS2NUM, so a str (or list of letters) is expected.
        * true_coords: atom coordinates, reshaped here as
          '(l c) d -> l c d' with c = 14.
        * padding_seq: int. Number of trailing padding residues to drop (0 = none).
        * needed_info: dict. Must contain "aa_pos_scales" (scales of the residue
          position encoding). It is also forwarded to encode_whole_bonds. When
          omitted, {"cutoffs": [2, 5, 10], "bond_scales": [0.5, 1, 2]} is used.
          That default has no "aa_pos_scales", so it raises KeyError.
        * free_mem: bool. Delete intermediate tensors before building the outputs.
        Outputs: (whole_point_enc, whole_bond_idxs, whole_bond_enc, embedd_info)
        * whole_point_enc: per-atom [aa position encoding, atom id, aa id].
          The last 2 columns are ids still to be embedded.
        * whole_bond_idxs, whole_bond_enc: as returned by encode_whole_bonds.
        * embedd_info: point_n_vectors / point_n_scalars merged with the
          bond embedding info from encode_whole_bonds.
    """
    if needed_info is None:
        # Built per call: a dict default would be one object shared by every call.
        needed_info = {"cutoffs": [2, 5, 10], "bond_scales": [0.5, 1, 2]}
    if "aa_pos_scales" not in needed_info:
        raise KeyError("needed_info must contain 'aa_pos_scales'")

    device = true_coords.device
    # Drop trailing padding once; `or None` keeps everything when padding_seq == 0.
    unpadded_seq = seq[:-padding_seq or None]

    #################
    # encode points #
    #################
    cloud_mask = torch.tensor(scn_cloud_mask(unpadded_seq)).bool().to(device)
    coords_wrap = rearrange(true_coords, '(l c) d -> l c d', c=14)[:-padding_seq or None]

    # position in backbone embedding
    aa_pos = encode_dist(torch.arange(len(unpadded_seq), device=device).float(),
                         scales=needed_info["aa_pos_scales"])
    atom_pos = chain2atoms(aa_pos)[cloud_mask]

    # atom identity embedding
    atom_id_embedds = torch.stack([SUPREME_INFO[k]["atom_id_embedd"] for k in unpadded_seq],
                                  dim=0)[cloud_mask].to(device)
    # aa embedding
    seq_int = torch.tensor([AAS2NUM[aa] for aa in unpadded_seq], device=device).long()
    aa_id_embedds = chain2atoms(seq_int, mask=cloud_mask)

    ################
    # encode bonds #
    ################
    whole_bond_idxs, whole_bond_enc, bond_embedd_info = encode_whole_bonds(
        x = coords_wrap[cloud_mask],
        x_format = "coords",
        embedd_info = {},
        needed_info = needed_info)

    #########
    # merge #
    #########
    # concat so that final is [vector_dims, scalar_dims]
    point_n_vectors = 0
    # encode_dist gives 2n+1 columns; the last 2 (atom id, aa id) are to be embedded yet
    point_n_scalars = 2 * len(needed_info["aa_pos_scales"]) + 1 + 2

    whole_point_enc = torch.cat([atom_pos,                          # 2n+1
                                 atom_id_embedds.unsqueeze(-1),
                                 aa_id_embedds.unsqueeze(-1)], dim=-1)
    if free_mem:
        del cloud_mask, atom_pos, atom_id_embedds, aa_id_embedds

    # record embedding dimensions
    point_embedd_info = {"point_n_vectors": point_n_vectors,
                         "point_n_scalars": point_n_scalars}
    embedd_info = {**point_embedd_info, **bond_embedd_info}

    return whole_point_enc, whole_bond_idxs, whole_bond_enc, embedd_info

In [47]:
# Encoding options for encode_whole_protein / encode_whole_bonds
NEEDED_INFO = dict(
    cutoffs = [],                                     # e.g. "15_closest"
    bond_scales = [1, 2],
    aa_pos_scales = [2 ** k for k in range(1, 8)],    # 2 .. 128
    atom_pos_scales = [2 ** k for k in range(6)],     # 1 .. 32
    dist2ca_norm_scales = [1, 2, 4],
    bb_norms_atoms = [0.5],                           # will encode 3 vectors with this
    adj_degree = 2,                                   # nn-degree connection
)
# Model sizes can be derived from an encoded protein (previously done with
# train_examples_storer[-1] and prot_covalent_bond(seq)).

In [48]:
### adjust for egnn: 
#embedd_info["bond_n_scalars"] -= 2*len(NEEDED_INFO["bond_scales"])+1
#embedd_info["bond_n_vectors"] = 0
#embedd_info

In [49]:
def prot_covalent_bond(seq, adj_degree=1, cloud_mask=None):
    """ Returns the covalent-bond graph of a protein, optionally extended to
        atoms up to `adj_degree` bonds apart.
        Inputs
        * seq: str. Protein sequence in 1-letter AA code (keys into GVP_DATA).
        * adj_degree: int. Max number of bonds separating two connected atoms
          (1 = direct covalent bonds only, < 1 = no edges).
        * cloud_mask: (L, atoms_per_residue) bool mask selecting the present
          atoms. Inferred with scn_cloud_mask(seq) when not given.
        Outputs: (edge_idxs, edge_attrs)
        * edge_idxs: (2, E) long tensor of atom index pairs, both directions.
        * edge_attrs: (E,) float tensor with the shortest bond distance
          (1..adj_degree) of each pair. Self-loops are never included.
    """
    # create or infer cloud_mask
    if cloud_mask is None:
        cloud_mask = scn_cloud_mask(seq).bool()
    device = cloud_mask.device

    # flat index (among present atoms) of the first atom slot of every residue
    first_atom = torch.zeros_like(cloud_mask)
    first_atom[:, 0] = 1
    idxs = first_atom[cloud_mask].nonzero().view(-1)
    if idxs.numel() == 0:
        return (torch.zeros(2, 0, dtype=torch.long, device=device),
                torch.zeros(0, device=device))

    n_atoms = int(idxs.amax()) + 14  # room for every atom slot of the last residue
    adj_mat = torch.zeros(n_atoms, n_atoms, device=device)
    for i, idx in enumerate(idxs):
        pairs = list(GVP_DATA[seq[i]]['bonds'])
        if i < idxs.shape[0] - 1:
            # atom slot 2 bonds to the first atom of the next residue
            # (presumably the C-N peptide bond)
            pairs.append([2, (idxs[i + 1] - idx).item()])
        if not pairs:
            continue
        bonds = idx + torch.tensor(pairs, device=device).long().t()
        adj_mat[bonds[0], bonds[1]] = 1.

    # undirected, and kept binary even if a bond is listed in both directions
    adj_mat = ((adj_mat + adj_mat.t()) > 0).float()
    not_self = ~torch.eye(n_atoms, dtype=torch.bool, device=device)

    attr_mat = adj_mat.clone() if adj_degree >= 1 else torch.zeros_like(adj_mat)
    walk = adj_mat.clone()
    for degree in range(2, adj_degree + 1):
        # pairs joined by a walk of exactly `degree` bonds
        walk = (walk @ adj_mat).bool().float()
        # not reached at a smaller degree -> shortest bond distance is `degree`
        newly_reached = walk.bool() & ~attr_mat.bool() & not_self
        attr_mat[newly_reached] = float(degree)

    edge_idxs = attr_mat.nonzero().t().long()
    edge_attrs = attr_mat[edge_idxs[0], edge_idxs[1]]

    return edge_idxs, edge_attrs

# EGNN model

In [None]:
from egnn_pytorch import EGNN_Sparse_Network

In [None]:
# define model
model_egnn = EGNN_Sparse_Network(
    n_layers = 4,
    # node input = 3 coordinate dims + 2 scalar feature columns
    # (the training loop concatenates coords with x[:, -2:])
    pos_dim = 3,
    feats_dim = 2,
    # one edge-attribute column (the training loop keeps edge_attrs[:, -1:])
    edge_attr_dim = 1,
    m_dim = 32,
    fourier_features = 4,
    # presumably the sizes of the lookup tables for the id-valued columns
    embedding_nums = [36, 20], embedding_dims = [16, 16],
    edge_embedding_nums = [3], edge_embedding_dims = [2],
    update_coors = True, update_feats = True,
    norm_feats = False, norm_coors = False,
    recalc = False,
)

In [None]:
# Move the EGNN itself to the device (previously this rebound model_egnn to the AF2 `model`)
model_egnn = model_egnn.to(DEVICE)

In [None]:
# std of the Gaussian noise added to the input coordinates of the EGNN
noise = 1
# optimizer over the EGNN parameters only
optimizer = Adam(model_egnn.parameters(), lr=1e-3)

# Training loop

In [None]:
# Progress bookkeeping used by the print-out at the end of each iteration
iteration = 0
epoch_losses, baseline_losses = [], []
n_print = 10                                          # print every n_print iterations
total_iters = NUM_BATCHES * GRADIENT_ACCUMULATE_EVERY
tac = time.time()

# Stage 2 needs N, CA and C masks, i.e. at least 3 backbone atoms (loop-invariant: check once)
assert model.num_backbone_atoms >= 3, 'must constitute to at least 3 atomic coordinates for backbone'

# TODO: stage 3 does not depend on the AF2 output yet and `optim` (AF2) is never
# stepped, so currently only the EGNN is trained.
for _ in range(NUM_BATCHES):
    for _ in range(GRADIENT_ACCUMULATE_EVERY):

        ### Stage 1: AF2 trunk

        batch = next(dl)
        # seqs are one-hot -> integer tokens; msks must be boolean
        seq = batch.seqs.argmax(dim = -1).to(DEVICE)
        coords = batch.crds.to(DEVICE)        # ground-truth coordinates: never overwritten below
        mask = batch.msks.bool().to(DEVICE)

        # sequence embedding (msa / esm / attn / or nothing); every branch defines msa_mask
        msa, embedds, msa_mask = None, None, None
        if FEATURES == "esm":
            embedds = get_esm_embedd(seq).unsqueeze(0)
        elif FEATURES == "msa":
            pass  # TODO: load MSA features here

        distance_pred = model(
            seq,
            msa = msa,
            embedds = embedds,
            mask = mask,
            msa_mask = msa_mask
        )
        # NOTE(review): with predict_coords = False this output is used as the
        # distance prediction below - confirm its shape matches bb_mask_crossed.

        ### Stage 2: distances -> backbone coordinates via MDS

        N_mask, CA_mask, C_mask = scn_backbone_mask(seq, boolean = True, n_aa = model.num_backbone_atoms)

        bb_cloud_mask = scn_cloud_mask(seq, boolean = True).to(DEVICE)
        bb_chain_mask = mask.unsqueeze(-1) * bb_cloud_mask
        bb_mask = rearrange(bb_chain_mask[:, :, :model.num_backbone_atoms], 'b l c -> b (l c)')
        bb_mask_crossed = rearrange(bb_mask, 'b i -> b i ()') * rearrange(bb_mask, 'b j -> b () j')

        if model.predict_real_value_distances:
            distances, distance_std = distance_pred.unbind(dim = -1)
            weights = (1 / (1 + distance_std)) # could also do a distance_std.sigmoid() here
        else:
            distances, weights = center_distogram_torch(distance_pred)

        # zero the weights of pairs that involve unknown atoms (keep the known ones)
        weights.masked_fill_(~bb_mask_crossed.bool(), 0.)

        coords_3d, _ = MDScaling(distances,
            weights = weights,
            iters = model.mds_iters,
            fix_mirror = True,
            N_mask = N_mask,
            CA_mask = CA_mask,
            C_mask = C_mask
        )
        mds_coords = rearrange(coords_3d, 'b c n -> b n c')
        # TODO: mds_coords (backbone only) is not fed to stage 3 yet; it must first be
        # expanded to the full-atom cloud (e.g. with sidechain_container).

        ### Stage 3: EGNN refinement on a single, unbatched protein (DataLoader batch_size = 1)

        # encode_whole_protein indexes SUPREME_INFO / AAS2NUM per residue -> needs a 1-letter string
        seq_str = "".join(VOCAB.int2char(aa) for aa in seq[0].cpu().numpy())
        true_coords = coords[0]     # presumably (L * 14, 3) - TODO confirm
        residue_mask = mask[0]      # residues with labelled coordinates
        padding_seq = 0             # batch_size = 1, so presumably no padding - TODO confirm

        # a noisy copy of the ground truth is the EGNN input; the clean one is the target
        masked_coords = true_coords + noise * torch.randn_like(true_coords)
        x, edge_index, edge_attrs, embedd_info = encode_whole_protein(
            seq_str, true_coords, padding_seq, needed_info = NEEDED_INFO, free_mem = True)

        # NOTE(review): stage 2 calls scn_cloud_mask with an int tensor, here (as inside
        # encode_whole_protein) with a string - confirm the imported helper accepts both.
        cloud_mask_naive = scn_cloud_mask(seq_str).bool().to(DEVICE)
        cloud_mask = scn_cloud_mask(seq_str, coords = true_coords).bool().to(DEVICE)
        if padding_seq:
            cloud_mask[-padding_seq:] = 0.
        # cloud is all points, chain is all for which we have labels
        chain_mask = residue_mask.unsqueeze(-1) * cloud_mask
        flat_chain_mask = rearrange(chain_mask, 'l c -> (l c)')
        flat_cloud_mask = rearrange(cloud_mask, 'l c -> (l c)')
        masked_coords = masked_coords[flat_cloud_mask]

        # NODES: noisy coords + the 2 id columns, for the atoms actually present
        present = cloud_mask[cloud_mask_naive]
        x = torch.cat([masked_coords, x[:, -2:][present]], dim = -1)

        # EDGES: delete all edges that touch a masked-out atom
        n_nodes = int(edge_index.amax()) + 1
        to_mask_edges = torch.zeros(n_nodes, n_nodes, device = edge_index.device)
        to_mask_edges[edge_index[0], edge_index[1]] = 1.
        to_mask_edges[~present] = 0.
        to_mask_edges = to_mask_edges * to_mask_edges.t()
        # keep the attributes of the surviving edges, then re-index without dropped atoms
        attr_mask = to_mask_edges[edge_index[0], edge_index[1]].bool()
        edge_attrs = edge_attrs[attr_mask, :]
        wanted = to_mask_edges.sum(dim = -1).bool()
        edge_index = to_mask_edges[wanted, :][:, wanted].nonzero().t()
        edge_attrs = edge_attrs[:, -1:]

        batch_idx = torch.zeros(x.shape[0], dtype = torch.long, device = DEVICE)  # single graph

        if edge_index.numel() and int(edge_index.amax()) >= x.shape[0]:
            raise RuntimeError("edge_index refers to a node outside x (index out of bounds)")

        preds = model_egnn(x, edge_index, batch = batch_idx, edge_attr = edge_attrs,
                           recalc_edge = None, verbose = False)

        # MEASURE ERROR - format pred and target
        target_coords = true_coords[flat_cloud_mask].clone()
        pred_coords = preds[:, :3]
        base_coords = x[:, :3]
        labelled = flat_chain_mask[flat_cloud_mask]

        # loss: MSE on Kabsch-aligned coords; if the SVD fails to converge,
        # fall back to the unaligned MSE
        try:
            pred_aligned, target_aligned = kabsch_torch(pred_coords.t(), target_coords.t()) # (3, N)
            loss = ((pred_aligned.t() - target_aligned.t())[labelled] ** 2).mean()
        except RuntimeError:  # torch.linalg errors subclass RuntimeError
            print("svd failed convergence, iteration:", iteration)
            loss = ((pred_coords - target_coords)[labelled] ** 2).mean()
        # baseline: error of the noisy input itself
        loss_base = ((base_coords - target_coords) ** 2).mean()

        loss.backward()

        # records / prints
        iteration += 1
        epoch_losses.append(loss.item())
        baseline_losses.append(loss_base.item())

        if iteration % n_print == 1:
            tic = time.time()
            print("BATCH: {0} / {1}, loss: {2}, baseline_loss: {3}, time: {4}".format(
                iteration, total_iters,
                np.mean(epoch_losses[-n_print:]),
                baseline_losses[-1],
                tic - tac))
            tac = time.time()

    # one optimizer step per GRADIENT_ACCUMULATE_EVERY backward passes
    optimizer.step()
    optimizer.zero_grad()

In [None]:
        ### Old stuff:

#        # build SC container. set SC points to CA and optionally place carbonyl O
#        proto_sidechain = sidechain_container(refined, n_aa=batch,
#                                              cloud_mask=cloud_mask, place_oxygen=False)
#
#        # rotate / align
#        coords_aligned, labels_aligned = Kabsch(refined, coords[flat_cloud_mask])
#
#        # atom mask
#
#        cloud_mask = scn_cloud_mask(seq, boolean = False)
#        flat_cloud_mask = rearrange(cloud_mask, 'b l c -> b (l c)')
#
#        # chain_mask is all atoms that will be backpropped thru -> existing + trainable
#
#        chain_mask = (mask * cloud_mask)[cloud_mask]
#        flat_chain_mask = rearrange(chain_mask, 'b l c -> b (l c)')
#
#        # save pdb files for visualization
#
#        if TO_PDB:
#            # idx from batch to save prot and label
#            idx = 0
#            coords2pdb(seq[idx, :, 0], coords_aligned[idx], cloud_mask, prefix=SAVE_DIR, name="pred.pdb")
#            coords2pdb(seq[idx, :, 0], labels_aligned[idx], cloud_mask, prefix=SAVE_DIR, name="label.pdb")
#
#        # loss - RMSE + distogram_dispersion
#        loss = torch.sqrt(criterion(coords_aligned[flat_chain_mask], labels_aligned[flat_chain_mask])) + \
#                          dispersion_weight * torch.norm( (1/weights)-1 )
#
#        loss.backward()
#    print('loss:', loss.item())
#
#    optim.step()
#    optim.zero_grad()

# End