In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd
import torch
import os
import sys
sys.path.append("../")

In [2]:
atlas_data_dir = "/data/cb/scratch/datasets/atlas"
atlas_dynamicslabels_dir = "/data/cb/scratch/datasets/atlas_dynamics_labels"
config_dir = "../trained/configs"

In [92]:
import yaml
from pathlib import Path

# Load data config
with open(config_dir+"/data/atlas_config.yaml", "r") as file:
    dataconfig = yaml.safe_load(file)
    
# Load train config
with open(config_dir+"/model/dynaprot_simple.yaml", "r") as file:
    modelconfig = yaml.safe_load(file)
    
modelconfig["data_config"] = dataconfig
    
from DynaProt.data.datasets import DynaProtDataset, OpenFoldBatchCollator

dataset = DynaProtDataset(dataconfig)
dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=10,
        collate_fn=OpenFoldBatchCollator(),
        num_workers=12,
        shuffle=False,
    )

batch_prots = next(iter(dataloader))

for k in batch_prots.keys():
    # print(k,batch_prots[k])
    print(f"{k}: {batch_prots[k].shape}")

aatype: torch.Size([10, 512, 21])
residue_index: torch.Size([10, 512])
all_atom_positions: torch.Size([10, 512, 37, 3])
all_atom_mask: torch.Size([10, 512, 37])
resi_pad_mask: torch.Size([10, 512])
dynamics_means: torch.Size([10, 512, 3])
dynamics_covars: torch.Size([10, 512, 3, 3])
frames: torch.Size([10, 512, 4, 4])


In [93]:
from DynaProt.model.DynaProt import DynaProt

model = DynaProt(modelconfig)
model

DynaProt(
  (sequence_embedding): Embedding(21, 128)
  (ipa_blocks): ModuleList(
    (0-7): 8 x InvariantPointAttention(
      (linear_q): Linear(in_features=128, out_features=64, bias=True)
      (linear_kv): Linear(in_features=128, out_features=128, bias=True)
      (linear_q_points): Linear(in_features=128, out_features=48, bias=True)
      (linear_kv_points): Linear(in_features=128, out_features=144, bias=True)
      (linear_b): Linear(in_features=128, out_features=4, bias=True)
      (linear_out): Linear(in_features=704, out_features=128, bias=True)
      (softmax): Softmax(dim=-1)
      (softplus): Softplus(beta=1.0, threshold=20.0)
    )
  )
  (mean_predictor): Linear(in_features=128, out_features=3, bias=True)
  (covars_predictor): Linear(in_features=128, out_features=6, bias=True)
)

In [94]:
model.training_step(batch_prots,0)

RuntimeError: The size of tensor a (512) must match the size of tensor b (3) at non-singleton dimension 2

In [89]:
import torch.nn.functional as F
b = 10
n = 2
preds = torch.repeat_interleave(torch.arange(6)+1,n).reshape(-1,n).T.unsqueeze(0).repeat(b,1,1)
# print(preds,preds.shape)


L = torch.zeros(b,n,3,3)

i = 0
for c in range(3):
    for r in range(c,3):
        L[:,:,r,c] = preds[:,:,i] 
        if r == c:
            L[:,:,r,c] = F.softplus(L[:,:,r,c])
        i+=1

covars = L @ L.transpose(2,3)
print(covars)
# # print(L)
# print(covars)

tensor([[[[ 1.7247,  2.6265,  3.9398],
          [ 2.6265, 20.1455, 26.0907],
          [ 3.9398, 26.0907, 70.0297]],

         [[ 1.7247,  2.6265,  3.9398],
          [ 2.6265, 20.1455, 26.0907],
          [ 3.9398, 26.0907, 70.0297]]],


        [[[ 1.7247,  2.6265,  3.9398],
          [ 2.6265, 20.1455, 26.0907],
          [ 3.9398, 26.0907, 70.0297]],

         [[ 1.7247,  2.6265,  3.9398],
          [ 2.6265, 20.1455, 26.0907],
          [ 3.9398, 26.0907, 70.0297]]],


        [[[ 1.7247,  2.6265,  3.9398],
          [ 2.6265, 20.1455, 26.0907],
          [ 3.9398, 26.0907, 70.0297]],

         [[ 1.7247,  2.6265,  3.9398],
          [ 2.6265, 20.1455, 26.0907],
          [ 3.9398, 26.0907, 70.0297]]],


        [[[ 1.7247,  2.6265,  3.9398],
          [ 2.6265, 20.1455, 26.0907],
          [ 3.9398, 26.0907, 70.0297]],

         [[ 1.7247,  2.6265,  3.9398],
          [ 2.6265, 20.1455, 26.0907],
          [ 3.9398, 26.0907, 70.0297]]],


        [[[ 1.7247,  2.6265,  3.9398],
 