In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import torch
import os
import sys
sys.path.append("../")
from dynaprot.evaluation.visualizer import plot_3d_gaussian_ellipsoids,plot_3d_gaussian_comparison
from openfold.utils.rigid_utils import  Rigid
from tqdm import tqdm
import plotly.express as px

# Dataset root. Defaults to the original cluster location so existing runs are
# unchanged; set ATLAS_DATA_DIR to use the notebook on another machine.
data_dir = os.environ.get("ATLAS_DATA_DIR", "/data/cb/scratch/datasets/atlas_dynamics_labels")
# Directory holding the YAML configs read below (data/ and model/ subfolders);
# override with DYNAPROT_CONFIG_DIR if the notebook is run from another location.
config_dir = os.environ.get("DYNAPROT_CONFIG_DIR", "../configs")


In [2]:
# Device handed to every `.to(device)` call in the cells below (model and batch tensors).
# NOTE(review): hard-coded to index 7 — presumably a CUDA GPU index; adjust on
# machines that expose fewer devices.
device = 7

In [3]:
import yaml
from pathlib import Path

def _read_yaml(path):
    """Parse the YAML file at `path` with `yaml.safe_load` and return the result."""
    with open(path, "r") as handle:
        return yaml.safe_load(handle)


# Data and model settings live in separate YAML files under `config_dir`.
dataconfig = _read_yaml(config_dir+"/data/atlas_config.yaml")
modelconfig = _read_yaml(config_dir+"/model/dynaprot_simple.yaml")

# The model config carries the data settings under the "data_config" key.
modelconfig["data_config"] = dataconfig

print(modelconfig)

from dynaprot.data.datasets import DynaProtDataset, OpenFoldBatchCollator

# Test split of the dataset described by `dataconfig`; its size is printed as a sanity check.
dataset = DynaProtDataset(dataconfig, split="test")
print(len(dataset))

# One sample per batch, in dataset order (no shuffling), loaded by 12 worker processes.
collator = OpenFoldBatchCollator()
dataloader = torch.utils.data.DataLoader(
    dataset,
    collate_fn=collator,
    batch_size=1,
    shuffle=False,
    num_workers=12,
)


{'model_params': {'num_ipa_blocks': 8, 'd_model': 128}, 'train_params': {'precision': 32, 'batch_size': 20, 'epochs': 10000, 'learning_rate': 0.0001, 'grad_clip_norm': 1.0, 'accelerator': 'gpu', 'strategy': 'ddp', 'num_devices': [0, 1, 2, 3, 4], 'num_nodes': 1, 'project': 'openprot/dynamics', 'neptune_api_key': 'INSERT YOUR API TOKEN HERE', 'tags': ['dynaprot', 'debugging', 'dropout'], 'log_model_checkpoints': True}, 'eval_params': {'loss_weights': {'resi_gaussians': {'mse_means': 0.0, 'mse_covs': 0.0, 'kldiv': 0.0, 'eigen_penalty': 0.0, 'cond_penalty': 0.0, 'frob_norm': 0.0, 'log_frob_norm': 0.0, 'affine_invariant_dist': 0.0, 'bures_dist': 1.0}, 'resi_rmsf': None, 'resi_rmsd': None, 'resi_rg': None}}, 'checkpoint_path': '', 'logs': '/path/to/logs', 'results': '/path/to/results', 'data_config': {'repo_dir': '/data/cb/mihirb14/projects/DynaProt', 'data_dir': '/data/cb/scratch/datasets/atlas_dynamics_labels', 'protein_chains_path': '/data/cb/mihirb14/projects/dynaprot/dynaprot/data/prepr

In [11]:
from dynaprot.model.architecture import DynaProt

# Checkpoints tried previously in this cell (kept for reference):
#   ../.neptune/DYNAMICS-126/DYNAMICS-126/checkpoints/step13112.ckpt
#   ../.neptune/DYNAMICS-126/DYNAMICS-126/checkpoints/step=174798.ckpt
#   ../.neptune/DYNAMICS-134/DYNAMICS-134/checkpoints/step=30000.ckpt
ckpt_path = "../.neptune/DYNAMICS-135/DYNAMICS-135/checkpoints/step=4017.ckpt"

# Restore the weights with the merged model/data config, then move the module
# to the selected device and switch it to evaluation mode.
model = DynaProt.load_from_checkpoint(ckpt_path, cfg=modelconfig)
model = model.to(device).eval()

model

DynaProt(
  (sequence_embedding): Embedding(21, 128)
  (ipa_blocks): ModuleList(
    (0-7): 8 x InvariantPointAttention(
      (linear_q): Linear(in_features=128, out_features=64, bias=True)
      (linear_kv): Linear(in_features=128, out_features=128, bias=True)
      (linear_q_points): Linear(in_features=128, out_features=48, bias=True)
      (linear_kv_points): Linear(in_features=128, out_features=144, bias=True)
      (linear_b): Linear(in_features=128, out_features=4, bias=True)
      (linear_out): Linear(in_features=704, out_features=128, bias=True)
      (softmax): Softmax(dim=-1)
      (softplus): Softplus(beta=1, threshold=20)
    )
  )
  (dropout): Dropout(p=0.2, inplace=False)
  (covars_predictor): Linear(in_features=128, out_features=6, bias=True)
  (loss): DynaProtLoss()
)

## Test RMWD variance contribution (Bures distance)

In [12]:
from dynaprot.evaluation import metrics

# Collect one value per test batch: the square root of the Bures distance between
# predicted and reference covariances (the dataloader above uses batch_size=1).
rmwds = []

# Pure inference: disable autograd so no computation graph is built or kept alive
# for each forward pass (the model is already in eval mode).
with torch.no_grad():
    for prot in tqdm(dataloader):
        # Move the padding mask to the device once; it is used both as a model
        # input and, as a boolean mask, to select entries below.
        pad_mask = prot["resi_pad_mask"].to(device)

        pred = model(
            prot["aatype"].argmax(dim=-1).to(device),  # argmax over the last dim -> integer indices
            Rigid.from_tensor_4x4(prot["frames"].to(device)),
            pad_mask,
        )

        mask = pad_mask.bool()
        true_covars = prot["dynamics_covars"].to(device).float()[mask]
        predicted_covars = pred["covars"][mask]

        # `.item()` requires a single-element tensor, so bures_distance is expected
        # to return a scalar here — NOTE(review): confirm its reduction over residues.
        rmwds.append(torch.sqrt(metrics.bures_distance(predicted_covars, true_covars)).item())


100%|██████████| 82/82 [00:06<00:00, 12.12it/s]


In [13]:
# Distribution of the per-batch values collected in `rmwds`. A bare list is
# treated by Plotly Express as wide-form data, whose axes are named "variable"
# and "value"; relabel them so the figure stands on its own.
px.box(
    rmwds,
    title="RMWD variance contribution (sqrt of Bures distance) on the test split",
    labels={"value": "sqrt(Bures distance)", "variable": "test proteins"},
)