In [1]:
import time
# science
import numpy as np
import torch
from einops import repeat, rearrange

In [2]:
import joblib
import sidechainnet

In [3]:
from sidechainnet.utils.sequence import ProteinVocabulary as VOCAB
VOCAB = VOCAB()

### Load a protein in SCN format - you can skip this since a joblib file is provided

In [4]:
dataloaders = sidechainnet.load(casp_version=7, with_pytorch="dataloaders")
dataloaders.keys() # ['train', 'train_eval', 'valid-10', ..., 'valid-90', 'test']
# ProteinDataset(casp_version=12, split='train', n_proteins=81454,
#               created='Sep 20, 2020')

SidechainNet was loaded from ./sidechainnet_data/sidechainnet_casp7_30.pkl.


dict_keys(['train', 'train-eval', 'test', 'valid-10', 'valid-20', 'valid-30', 'valid-40', 'valid-50', 'valid-70', 'valid-90'])

In [5]:
min_len = 700
for batch in dataloaders['train']:
    real_seqs = [''.join([VOCAB.int2char(aa) for aa in seq]) for seq in batch.int_seqs.numpy()]
    print("seq len", len(real_seqs[0]))
    try:
        for i in range(len(batch.int_seqs.numpy())):
            # get variables
            seq     = real_seqs[i]
            int_seq = batch.int_seqs[i]
            angles  = batch.angs[i]
            # get padding
            padding_angles = (torch.abs(angles).sum(dim=-1) == 0).long().sum()
            padding_seq    = (np.array([x for x in seq]) == "_").sum()
            # only accept sequences with right dimensions and no missing coords
            # if padding_seq == padding_angles:
            # print("paddings_match")
            # print("len coords", list(batch.crds[i].shape)[0]//3, "vs int_seq", len(int_seq))
            if list(batch.crds[i].shape)[0]//14 == len(int_seq):
                if len(seq) > min_len and padding_seq == padding_angles:
                    print("stopping at sequence of length", len(seq))
                    print(len(seq), angles.shape, padding_seq == padding_angles == list(batch.crds[i].shape)[0]//3)
                    print("paddings: ", padding_seq, padding_angles)
                    raise StopIteration
                else:
                    print("found a seq of length:", len(seq), "but below the threshold:", min_len)
    except StopIteration:
        break

seq len 205
found a seq of length: 205 but below the threshold: 700
found a seq of length: 205 but below the threshold: 700
found a seq of length: 205 but below the threshold: 700
found a seq of length: 205 but below the threshold: 700
found a seq of length: 205 but below the threshold: 700
found a seq of length: 205 but below the threshold: 700
found a seq of length: 205 but below the threshold: 700
found a seq of length: 205 but below the threshold: 700
found a seq of length: 205 but below the threshold: 700
found a seq of length: 205 but below the threshold: 700
found a seq of length: 205 but below the threshold: 700
found a seq of length: 205 but below the threshold: 700
found a seq of length: 205 but below the threshold: 700
found a seq of length: 205 but below the threshold: 700
found a seq of length: 205 but below the threshold: 700
found a seq of length: 205 but below the threshold: 700
found a seq of length: 205 but below the threshold: 700
found a seq of length: 205 but below

found a seq of length: 35 but below the threshold: 700
found a seq of length: 35 but below the threshold: 700
found a seq of length: 35 but below the threshold: 700
found a seq of length: 35 but below the threshold: 700
found a seq of length: 35 but below the threshold: 700
found a seq of length: 35 but below the threshold: 700
found a seq of length: 35 but below the threshold: 700
found a seq of length: 35 but below the threshold: 700
found a seq of length: 35 but below the threshold: 700
found a seq of length: 35 but below the threshold: 700
found a seq of length: 35 but below the threshold: 700
found a seq of length: 35 but below the threshold: 700
found a seq of length: 35 but below the threshold: 700
found a seq of length: 35 but below the threshold: 700
found a seq of length: 35 but below the threshold: 700
found a seq of length: 35 but below the threshold: 700
found a seq of length: 35 but below the threshold: 700
found a seq of length: 35 but below the threshold: 700
found a se

found a seq of length: 66 but below the threshold: 700
found a seq of length: 66 but below the threshold: 700
found a seq of length: 66 but below the threshold: 700
found a seq of length: 66 but below the threshold: 700
found a seq of length: 66 but below the threshold: 700
found a seq of length: 66 but below the threshold: 700
found a seq of length: 66 but below the threshold: 700
found a seq of length: 66 but below the threshold: 700
found a seq of length: 66 but below the threshold: 700
found a seq of length: 66 but below the threshold: 700
found a seq of length: 66 but below the threshold: 700
found a seq of length: 66 but below the threshold: 700
found a seq of length: 66 but below the threshold: 700
found a seq of length: 66 but below the threshold: 700
found a seq of length: 66 but below the threshold: 700
found a seq of length: 66 but below the threshold: 700
found a seq of length: 66 but below the threshold: 700
found a seq of length: 66 but below the threshold: 700
found a se

found a seq of length: 50 but below the threshold: 700
found a seq of length: 50 but below the threshold: 700
found a seq of length: 50 but below the threshold: 700
found a seq of length: 50 but below the threshold: 700
found a seq of length: 50 but below the threshold: 700
found a seq of length: 50 but below the threshold: 700
found a seq of length: 50 but below the threshold: 700
found a seq of length: 50 but below the threshold: 700
found a seq of length: 50 but below the threshold: 700
found a seq of length: 50 but below the threshold: 700
found a seq of length: 50 but below the threshold: 700
found a seq of length: 50 but below the threshold: 700
found a seq of length: 50 but below the threshold: 700
found a seq of length: 50 but below the threshold: 700
found a seq of length: 50 but below the threshold: 700
found a seq of length: 50 but below the threshold: 700
found a seq of length: 50 but below the threshold: 700
found a seq of length: 50 but below the threshold: 700
found a se

### Load joblib file

In [6]:
# joblib.dump({"seq": seq, "int_seq": int_seq, "angles": angles,
#              "id": batch.pids[i], "true_coords": batch.crds[i]}, "experiments/727_aas_seq_and_angles.joblib")
info = joblib.load("experiments/112_aas_seq_and_angles.joblib")
seq, int_seq, angles, id_, true_coords = info["seq"], info["int_seq"], info["angles"], info["id"], info["true_coords"]

padding_angles = (torch.abs(angles).sum(dim=-1) == 0).long().sum()
padding_seq    = (np.array([x for x in seq]) == "_").sum()

### Load algo

In [7]:
from massive_pnerf import *

In [8]:
# measure time to featurize
# %timeit build_scaffolds(seq[:-padding_seq], angles[:-padding_seq])

In [9]:
# featurize
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
scaffolds = build_scaffolds_from_scn_angles(seq[:-padding_seq], angles[:-padding_seq].to(device))

In [10]:
scaffolds["bond_mask"].shape

torch.Size([107, 14])

In [11]:
%%timeit
# convert coords - fold
coords, mask = proto_fold(seq[:-padding_seq], **scaffolds, device=device)
coords_flat  = rearrange(coords, 'l c d -> (l c) d') 

10.5 ms ± 190 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


#### Profiling

In [12]:
%load_ext snakeviz
%snakeviz proto_fold(seq[:-padding_seq], **scaffolds, device=device)

 
*** Profile stats marshalled to file '/var/folders/lh/zgndpx8x755_lcsq48lp_5t40000gn/T/tmpvci7z88c'. 
Embedding SnakeViz in this document...


#### Display

In [32]:
id_, true_coords.shape, len(int_seq[:-padding_angles-1]), seq

('2F2H_d2f2hf1',
 torch.Size([1568, 3]),
 106,
 'NTLLALGNNDQRPDYVWHEGTAFHLFNLQDGHEAVCEVPAADGSVIFTLKAARTGNTITVTGAGEAKNWTLCLRNVVKVNGLQDGSQAESEQGLVVKPQGNALTITL_____')

In [33]:
coords_flat.shape

torch.Size([1498, 3])

In [34]:
sb = sidechainnet.StructureBuilder(int_seq, crd=coords_flat) 
sb.to_3Dmol()

<py3Dmol.view at 0x1a2a666550>

In [35]:
# base structure with current coords
sb = sidechainnet.StructureBuilder(int_seq[:-padding_angles], crd=true_coords[:-14*padding_angles]) # coords_flat

# scn custom nerf
# sb = sidechainnet.StructureBuilder(int_seq[:-padding_angles], angles[:-padding_seq])

# put structure coords in wrapper
sb._initialize_coordinates_and_PdbCreator()
true_coords = sb.coords
# true_coords = rearrange(scn_struct_coords, '(l c) d -> l c d', c=14)

sb.to_3Dmol()

<py3Dmol.view at 0x1a2a024bd0>

### Diagnose rotation matrix

In [36]:
# Standard import
import matplotlib.pyplot as plt
# Import 3D Axes 
from mpl_toolkits.mplot3d import axes3d

In [37]:
%matplotlib notebook

#### True backbone

In [38]:
# print init of true chain to compare
# Set up Figure and 3D Axes 
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
start_res = (torch.cat([true_coords[:3, :],
                        true_coords[14:14+3],
                        true_coords[28:28+3],
                        true_coords[42:42+3]], dim=0) - true_coords[0, :]).numpy()

ax.plot(start_res[:3, 0], start_res[:3, 1], start_res[:3, 2],  "-o", label="first res")
ax.plot(start_res[2:, 0], start_res[2:, 1], start_res[2:, 2],  "-o", label="first res")
# ax.plot(first_res[:, 0], first_res[:, 1], first_res[:, 2],  "r-o", label="second aa")
# ax.plot(destin_first[:, 0], destin_first[:, 1], destin_first[:, 2],  "g-o", label="rotated second aa")
plt.legend()
plt.xlabel("x axis")
plt.ylabel("y axis")
plt.show()

<IPython.core.display.Javascript object>

### True and predicted backbone

In [39]:
# Set up Figure and 3D Axes 
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# Plot using Axes notation and standard function calls
prev_res = coords[0, :3].numpy()
destin_first = rearrange(coords[:4, :3], 'l c d -> (l c) d').numpy()


ax.plot([1,0,0],[0,1,0],[0,0,0], 'y-o', label="algo_init_b")
ax.plot(destin_first[2:, 0], destin_first[2:, 1], destin_first[2:, 2],  "g-o", label="predicted_rotated_2nd_aa")
ax.plot(prev_res[:, 0], prev_res[:, 1], prev_res[:, 2],  "c-o", label="predicted_first_aa")

ax.plot(start_res[2:, 0], start_res[2:, 1], start_res[2:, 2],  "-o", label="true_chain_cont")
ax.plot(start_res[:3, 0], start_res[:3, 1], start_res[:3, 2],  "-o", label="true_first_res")
# ax.plot(first_res[:, 0], first_res[:, 1], first_res[:, 2],  "r-o", label="second aa")

plt.legend()
plt.xlabel("x axis")
plt.ylabel("y axis")
plt.show()

<IPython.core.display.Javascript object>

### Rotate and show superimposed to confirm

In [40]:
def kabsch_torch(X, Y):
    """ Kabsch alignment of X into Y. 
        Assumes X,Y are both (Dims x N_points). See below for wrapper.
    """
    #  center X and Y to the origin
    X_ = X - X.mean(dim=-1, keepdim=True)
    Y_ = Y - Y.mean(dim=-1, keepdim=True)
    # calculate convariance matrix (for each prot in the batch)
    C = torch.matmul(X_, Y_.t())
    # Optimal rotation matrix via SVD - warning! W must be transposed
    V, S, W = torch.svd(C.detach())
    # determinant sign for direction correction
    d = (torch.det(V) * torch.det(W)) < 0.0
    if d:
        S[-1]    = S[-1] * (-1)
        V[:, -1] = V[:, -1] * (-1)
    # Create Rotation matrix U
    U = torch.matmul(V, W.t())
    # calculate rotations
    X_ = torch.matmul(X_.t(), U).t()
    # return centered and aligned
    return X_, Y_


In [41]:
flat_mask = rearrange(scaffolds["cloud_mask"], 'l d -> (l d)')
coords_aligned, labels_aligned = kabsch_torch(coords_flat[flat_mask].t(),
                                              true_coords[flat_mask].t())
                                              # true_coords[:-14*padding_angles][flat_mask].t())
# create coord scaffolds
scaff_coords_aligned = torch.zeros(coords_flat.shape).float()
scaff_labels_aligned = torch.zeros(coords_flat.shape).float()
# fill
scaff_coords_aligned[flat_mask] = coords_aligned.t()
scaff_labels_aligned[flat_mask] = labels_aligned.t()
# replace vars
coords_aligned, labels_aligned = scaff_coords_aligned, scaff_labels_aligned

In [42]:
# Set up Figure and 3D Axes 
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# Plot using Axes notation and standard function calls
start_res    = torch.cat([labels_aligned[:3, :],
                        labels_aligned[14:14+3],
                        labels_aligned[28:28+3],
                        labels_aligned[42:42+3]], dim=0).numpy()
destin_first = torch.cat([coords_aligned[:3, :],
                        coords_aligned[14:14+3],
                        coords_aligned[28:28+3],
                        coords_aligned[42:42+3]], dim=0).numpy()

destin_first -= (destin_first[:1] - start_res[:1])

ax.plot(start_res[2:, 0], start_res[2:, 1], start_res[2:, 2],  "b-o", label="true_chain_cont")
ax.plot(start_res[:3, 0], start_res[:3, 1], start_res[:3, 2],  "k-o", label="true_first_res")
# ax.plot(first_res[:, 0], first_res[:, 1], first_res[:, 2],  "r-o", label="second aa")

ax.plot(destin_first[2:, 0], destin_first[2:, 1], destin_first[2:, 2],  "g-o", label="predicted_rotated_2nd_aa")
ax.plot(destin_first[:3, 0], destin_first[:3, 1], destin_first[:3, 2],  "c-o", label="predicted_first_aa")

plt.legend()
plt.xlabel("x axis")
plt.ylabel("y axis")
plt.show()

<IPython.core.display.Javascript object>

In [43]:
boxel_pred_coords = rearrange(coords_aligned, '(l c) d -> l c d', c=14).numpy()
boxel_true_coords = rearrange(labels_aligned, '(l c) d -> l c d', c=14).numpy()

In [44]:
for i in range(15):
    print(i, "( n-ca, ca-c, c-n, c-o)", np.linalg.norm(boxel_pred_coords[i, :4] - \
                                                  np.concatenate([boxel_pred_coords[i, 1:3], # CA , C
                                                                  boxel_pred_coords[i+1, :1], # N+1
                                                                  boxel_pred_coords[i, 2:3]], # C
                                                                  axis=0), axis=-1),
                                   np.linalg.norm(boxel_true_coords[i, :4] - \
                                                  np.concatenate([boxel_true_coords[i, 1:3],
                                                                  boxel_true_coords[i+1, :1],
                                                                  boxel_pred_coords[i, 2:3]], axis=0), axis=-1))

0 ( n-ca, ca-c, c-n, c-o) [1.4664935 1.5241193 1.3289379 1.2289996] [1.4729016 1.5119934 1.330413  1.1655873]
1 ( n-ca, ca-c, c-n, c-o) [1.4664928 1.5241185 1.3289372 1.2290006] [1.4648343 1.5186538 1.3174567 1.2755295]
2 ( n-ca, ca-c, c-n, c-o) [1.4664928 1.5241193 1.3289375 1.2290004] [1.4272082 1.5563269 1.3316144 1.2152689]
3 ( n-ca, ca-c, c-n, c-o) [1.4664932 1.524119  1.3289373 1.2290002] [1.4446503 1.5176473 1.3326255 1.1848733]
4 ( n-ca, ca-c, c-n, c-o) [1.4664936 1.5241188 1.3289372 1.2289994] [1.456298  1.522775  1.3344283 1.2961861]
5 ( n-ca, ca-c, c-n, c-o) [1.466493  1.5241199 1.3289365 1.2290007] [1.4675089 1.5232855 1.3250848 1.1746181]
6 ( n-ca, ca-c, c-n, c-o) [1.4664929 1.5241193 1.3289369 1.2290004] [1.4661055 1.5011121 1.335454  1.2651452]
7 ( n-ca, ca-c, c-n, c-o) [1.4664925 1.524121  1.3289374 1.229001 ] [1.481046  1.5653684 1.3391701 1.2023404]
8 ( n-ca, ca-c, c-n, c-o) [1.4664922 1.5241201 1.3289378 1.2289997] [1.4622474 1.5181217 1.3367028 1.2872869]
9 ( n-ca, 

In [45]:
backbones = []
for i in range(len(boxel_true_coords)-1):
    backbones.append(np.linalg.norm(boxel_true_coords[i, :4] - \
                                    np.concatenate([boxel_true_coords[i, 1:3],
                                                    boxel_true_coords[i+1, :1],
                                                    boxel_pred_coords[i, 2:3]], axis=0), axis=-1))
    
# for (=O) it's a normal distro so picking the mean although high variabiity
print("mean:", np.vstack(backbones).mean(axis=0))
print("median:", np.median(np.vstack(backbones), axis=0))
print("std:", np.vstack(backbones).std(axis=0))

mean: [1.4562691 1.5263689 1.3298723 1.230395 ]
median: [1.4569814 1.523855  1.3296897 1.2336347]
std: [0.01438655 0.01655221 0.00736604 0.04156507]


### Check error

In [46]:
def rmsd_torch(X, Y):
    """ Assumes x,y are both (B x D x N). See below for wrapper. """
    return torch.sqrt( torch.mean((X - Y)**2, axis=(-1, -2)) )

print("RMSD is:", rmsd_torch(coords_aligned, labels_aligned))

RMSD is: tensor(0.1095)


### Save oriented to manually diagnose

In [47]:
# save predicted
sb = sidechainnet.StructureBuilder(int_seq, crd=coords_aligned) 
sb.to_pdb("preds/predicted_112.pdb")
sb = sidechainnet.StructureBuilder(int_seq, crd=labels_aligned) 
sb.to_pdb("preds/labels_112_scn_nerf.pdb")
# go here: https://molstar.org/viewer/
# load chains and use superimposition tool

In [48]:
# sb = sidechainnet.StructureBuilder(int_seq, crd=coords_flat) 
# sb.to_pdb("preds/predicted_112.pdb")