In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import torch
import os
import sys
sys.path.append("../")
from dynaprot.evaluation.visualizer import plot_3d_gaussian_ellipsoids,plot_3d_gaussian_comparison
from openfold.utils.rigid_utils import  Rigid
from tqdm import tqdm
import plotly.express as px

# Root directory of the ATLAS dynamics labels.
# NOTE(review): no cell visible in this notebook reads `data_dir`, and the data
# config printed further down reports a different `data_dir` (suffix `_calpha`)
# -- confirm which location is intended or drop this variable.
data_dir = "/data/cb/scratch/datasets/atlas_dynamics_labels"
# Directory holding the YAML configs; later cells append "/data/..." and
# "/model/..." to this string to locate the data and model config files.
config_dir = "../configs"


In [2]:
# Target device passed to every `.to(device)` call below (model and input tensors).
# NOTE(review): a bare int is presumably treated by PyTorch as a CUDA device
# index (i.e. GPU 7) -- confirm that GPU exists on the machine running this.
device = 7

In [3]:
import yaml
from pathlib import Path

def _read_yaml(yaml_path):
    """Parse the YAML file at ``yaml_path`` with the safe loader and return the result."""
    with open(yaml_path, "r") as handle:
        return yaml.safe_load(handle)


# Data and model configs live under `config_dir`; the data config is nested
# into the model config under the "data_config" key.
dataconfig = _read_yaml(config_dir+"/data/atlas_config.yaml")
modelconfig = _read_yaml(config_dir+"/model/dynaprot_simple.yaml")
modelconfig["data_config"] = dataconfig

print(modelconfig)

from dynaprot.data.datasets import DynaProtDataset, OpenFoldBatchCollator

# Test split, one protein per batch, in dataset order.
dataset = DynaProtDataset(dataconfig, split="test")
print(len(dataset))

loader_options = dict(
    batch_size=1,
    collate_fn=OpenFoldBatchCollator(),
    num_workers=12,
    shuffle=False,
)
dataloader = torch.utils.data.DataLoader(dataset, **loader_options)


{'model_params': {'num_ipa_blocks': 8, 'd_model': 128}, 'train_params': {'precision': 32, 'batch_size': 20, 'epochs': 10000, 'learning_rate': 0.0001, 'grad_clip_norm': 1.0, 'accelerator': 'gpu', 'strategy': 'ddp', 'num_devices': [3, 4, 5, 6, 7], 'num_nodes': 1, 'project': 'openprot/dynamics', 'neptune_api_key': 'INSERT YOUR API TOKEN HERE', 'tags': ['dynaprot', 'debugging', 'dropout'], 'log_model_checkpoints': True}, 'eval_params': {'loss_weights': {'resi_gaussians': {'mse_means': 0.0, 'mse_covs': 0.0, 'kldiv': 0.0, 'eigen_penalty': None, 'cond_penalty': None, 'frob_norm': 0.0, 'log_frob_norm': 0.0, 'affine_invariant_dist': 0.0, 'bures_dist': 1.0}, 'resi_rmsf': None, 'resi_rmsd': None, 'resi_rg': None}}, 'checkpoint_path': '', 'logs': '/path/to/logs', 'results': '/path/to/results', 'data_config': {'repo_dir': '/data/cb/mihirb14/projects/DynaProt', 'data_dir': '/data/cb/scratch/datasets/atlas_dynamics_labels_calpha', 'protein_chains_path': '/data/cb/mihirb14/projects/dynaprot/dynaprot/d

In [15]:
from dynaprot.model.architecture import DynaProt

# Previously evaluated checkpoints, kept for reference (uncomment one to swap it in):
# model = DynaProt.load_from_checkpoint("../.neptune/DYNAMICS-126/DYNAMICS-126/checkpoints/step13112.ckpt", cfg=modelconfig).to(device)
# model = DynaProt.load_from_checkpoint("../.neptune/DYNAMICS-126/DYNAMICS-126/checkpoints/step=174798.ckpt", cfg=modelconfig).to(device).eval()
# model = DynaProt.load_from_checkpoint("../.neptune/DYNAMICS-134/DYNAMICS-134/checkpoints/step=30000.ckpt", cfg=modelconfig).to(device).eval()
# model = DynaProt.load_from_checkpoint("../.neptune/DYNAMICS-135/DYNAMICS-135/checkpoints/step=4017.ckpt", cfg=modelconfig).to(device).eval()
# model = DynaProt.load_from_checkpoint("../.neptune/DYNAMICS-135/DYNAMICS-135/checkpoints/step=31000.ckpt", cfg=modelconfig).to(device).eval()
# model = DynaProt.load_from_checkpoint("../.neptune/DYNAMICS-154/DYNAMICS-154/checkpoints/step=26143.ckpt", cfg=modelconfig).to(device).eval()     # Best to date (rmwd 1.04) with centroids, bures distance loss, atlas no replicates?
# model = DynaProt.load_from_checkpoint("../.neptune/DYNAMICS-150/DYNAMICS-150/checkpoints/step=47060.ckpt",cfg=modelconfig).to(device).eval()        # Centroids, MSE Loss, atlas no replicates

# Active checkpoint: best to date with centroids, bures distance loss, atlas with replicates
checkpoint_path = "../.neptune/DYNAMICS-167/DYNAMICS-167/checkpoints/epoch=2400-step=31213.ckpt"
model = DynaProt.load_from_checkpoint(checkpoint_path, cfg=modelconfig)
model = model.to(device)
model = model.eval()  # evaluation mode for the metric cells below

model

DynaProt(
  (sequence_embedding): Embedding(21, 128)
  (ipa_blocks): ModuleList(
    (0-7): 8 x InvariantPointAttention(
      (linear_q): Linear(in_features=128, out_features=64, bias=True)
      (linear_kv): Linear(in_features=128, out_features=128, bias=True)
      (linear_q_points): Linear(in_features=128, out_features=48, bias=True)
      (linear_kv_points): Linear(in_features=128, out_features=144, bias=True)
      (linear_b): Linear(in_features=128, out_features=4, bias=True)
      (linear_out): Linear(in_features=704, out_features=128, bias=True)
      (softmax): Softmax(dim=-1)
      (softplus): Softplus(beta=1, threshold=20)
    )
  )
  (dropout): Dropout(p=0.2, inplace=False)
  (covars_predictor): Linear(in_features=128, out_features=6, bias=True)
  (loss): DynaProtLoss()
)

## Test RMWD variance contribution (Bures distance)

In [16]:
from dynaprot.evaluation import metrics

# Per-protein RMWD of the DynaProt predictions on the test dataloader.
# Each batch holds a single protein (the dataloader above uses batch_size=1);
# the score stored for it is sqrt(metrics.bures_distance(pred, true)) over the
# residues selected by the padding mask.
rmwds = []

# Evaluation only: run under no_grad so autograd does not record a graph for
# every forward pass. The values collected through .item() are unaffected.
with torch.no_grad():
    for prot in tqdm(dataloader):
        pred = model(
            prot["aatype"].argmax(dim=-1).to(device),
            Rigid.from_tensor_4x4(prot["frames"].to(device)),
            prot["resi_pad_mask"].to(device),
        )

        # Boolean mask used to drop padded residues from both covariance sets.
        mask = prot["resi_pad_mask"].bool().to(device)
        true_covars = prot["dynamics_covars"].to(device).float()[mask]
        predicted_covars = pred["covars"][mask]

        # NOTE(review): .item() requires bures_distance to return a single
        # value here -- presumably already reduced over residues; confirm.
        rmwds.append(torch.sqrt(metrics.bures_distance(predicted_covars, true_covars)).item())


  0%|          | 0/82 [00:00<?, ?it/s]

100%|██████████| 82/82 [00:06<00:00, 13.01it/s]


In [None]:
# Distribution of per-protein RMWD for the DynaProt predictions on the test
# split; titled so it can be told apart from the other RMWD box plot below.
px.box(rmwds, title="DynaProt: per-protein RMWD (test split)")

: 

In [13]:
import mdtraj

# def compute_gaussians_per_residue(traj):
#     num_residues = traj.topology.n_residues
#     means = np.zeros((num_residues, 3))       # Shape (n_residues, 3) for (x, y, z)
#     covariances = np.zeros((num_residues, 3,3))   # Shape (n_residues, 3) for (x, y, z)

#     for i, residue in enumerate(traj.topology.residues):

#         atom_indices = [atom.index for atom in residue.atoms]
        
#         # Extract xyz coordinates for all atoms in the residue across all frames
#         # scale nanometers to angstroms (x10)
#         xyz = np.mean(traj.xyz[:, atom_indices, :],axis=1) # shape (n_frames, 3)  frames by residue i's position (mean pos of atoms) 

#         # Compute mean and variance across all frames for each atom
#         means[i] = np.mean(xyz, axis=0)  # shape (1, 3)

#         centered_data = xyz - means[i]
        
#         covariances[i] = centered_data.T @ centered_data /(centered_data.shape[0] - 1)  # shape (3, 3) 
        
#     return torch.from_numpy(covariances)

def compute_gaussians_per_residue(traj, calpha = True):
    """Estimate one 3x3 positional covariance matrix per residue over a trajectory.

    Every residue is reduced to a single point per frame:

    * its first atom named ``'CA'`` when ``calpha`` is true and such an atom
      exists in the residue;
    * otherwise the unweighted mean of all of the residue's atom coordinates.

    Coordinates are multiplied by 10 before the statistics are taken
    (nanometers -> angstroms), and all arithmetic is done in float64.

    Args:
        traj: Trajectory-like object exposing ``topology.n_residues``,
            ``topology.residues`` (each residue iterable through ``.atoms``,
            whose items have ``.index`` and ``.name``) and ``xyz`` indexable as
            ``[frame, atom, coordinate]``.
        calpha: If true, prefer the CA atom as the residue position and fall
            back to the all-atom mean only for residues without a CA atom.
            If false, always use the all-atom mean.

    Returns:
        torch.Tensor: float64 tensor of shape ``(n_residues, 3, 3)`` holding
        the unbiased sample covariance (normalised by ``T - 1``, with ``T`` the
        number of frames) of each residue's position.

    Note:
        With a single frame ``T - 1`` is zero, so the result is not finite.
    """
    num_residues = traj.topology.n_residues
    covariances = np.zeros((num_residues, 3, 3), dtype=np.float64)   # (n_residues, 3, 3)

    for i, residue in enumerate(traj.topology.residues):
        # Only look for a CA atom when the caller asked for CA positions.
        ca_indices = [atom.index for atom in residue.atoms if atom.name == 'CA'] if calpha else []

        if ca_indices:
            xyz = traj.xyz[:, ca_indices[0], :].astype(np.float64) * 10    # shape (T, 3)
        else:
            # No CA available (or not requested): use the mean position of
            # all atoms of the residue in every frame.
            atom_indices = [atom.index for atom in residue.atoms]
            xyz = np.mean(traj.xyz[:, atom_indices, :].astype(np.float64), axis=1) * 10   # shape (T, 3)

        centered_xyz = xyz - np.mean(xyz, axis=0)   # shape (T, 3)
        covariances[i] = centered_xyz.T @ centered_xyz / (centered_xyz.shape[0] - 1)   # shape (3, 3)

    return torch.from_numpy(covariances)

        
# Per-protein RMWD of the AlphaFlow ensembles against the ATLAS replicate labels.
path = "/data/cb/scratch/datasets/alphaflow_ensembles/alphaflow_md_templates_base_202402"
ref_path = "/data/cb/scratch/datasets/atlas/"   # only used by the commented-out reference superposition below
labels_dir = "/data/cb/scratch/datasets/atlas_dynamics_labels_replicates"
af_rmwds = []

for prot in tqdm(np.load("../dynaprot/data/preprocessing/protein_lists/atlas_chains_test.npy")):
    traj = mdtraj.load_pdb(os.path.join(path,f"{prot}.pdb"))
    # ref = mdtraj.load_pdb(os.path.join(ref_path,prot,f"{prot}.pdb"))
    # traj.superpose(ref)
    # Align every frame of the ensemble onto its own first frame.
    traj.superpose(traj,frame=0)
    af_covs = compute_gaussians_per_residue(traj)

    # NOTE: torch.load unpickles the file -- only point this at trusted label files.
    gt = torch.load(os.path.join(labels_dir, prot, f"{prot}.pt"))
    gt_covs = gt["dynamics_covars"]

    # NOTE(review): af_covs is float64; the dtype of gt_covs and whether both
    # cover the same residues in the same order are not checked here -- confirm.
    af_rmwds.append(torch.sqrt(metrics.bures_distance(af_covs,gt_covs)).item())



100%|██████████| 82/82 [13:22<00:00,  9.79s/it]


In [14]:
# Distribution of per-protein RMWD for the AlphaFlow ensembles; titled so it
# can be told apart from the DynaProt RMWD box plot above.
px.box(af_rmwds, title="AlphaFlow ensembles: per-protein RMWD vs. ATLAS replicate labels")