In [1]:
from omegaconf import OmegaConf
import os
import torch

import hydra
from omegaconf import DictConfig, OmegaConf
from models.together_model import ProteinVAELLMmodel, ProteinVAELLM_FrameDiff_first_model
from data import all_atom
# Pytorch lightning imports
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from data.pdb_dataloader import PdbDataModule
from models.flow_module import FlowModule
from experiments import utils as eu
import openfold.utils.rigid_utils as ru
import torch.nn.functional as F
# Load the base experiment config and keep handles to the sub-configs used later.
_cfg = cfg = OmegaConf.load("configs/base.yaml")
_data_cfg, _exp_cfg = cfg.data, cfg.experiment

# Instantiate the data module from the data config and run its "fit"-stage setup.
_datamodule: LightningDataModule = PdbDataModule(_data_cfg)
_datamodule.setup(stage="fit")

# Debugging aid: turn on autograd anomaly detection so the backward pass reports
# the function that produced NaNs. Kept as the last expression of the cell so its
# repr is shown as the cell output.
torch.autograd.set_detect_anomaly(True)

  from .autonotebook import tqdm as notebook_tqdm


<torch.autograd.anomaly_mode.set_detect_anomaly at 0x148b9e11ebc0>

# Set up the config, data module, interpolant, model, and optimizer

In [2]:
# Reload the base config; `training_cfg` is the experiment's training section
# and is what the loss computation reads its hyper-parameters from.
cfg = _cfg = OmegaConf.load("configs/base.yaml")
_data_cfg, _exp_cfg = cfg.data, cfg.experiment
training_cfg = _exp_cfg.training

# Build the data module from the data config and prepare it for the "fit" stage.
_datamodule: LightningDataModule = PdbDataModule(_data_cfg)
_datamodule.setup(stage="fit")

In [3]:
# Training DataLoader obtained from the data module that was set up above.
train_loader = _datamodule.train_dataloader()



In [4]:
from data.my_interpolant import Interpolant

# Build the interpolant from the `interpolant` section of the config and tell it
# to work on the GPU; it is used below to corrupt a training batch.
interpolant = Interpolant(cfg.interpolant)
interpolant.set_device("cuda")

In [5]:
# Instantiate the model from the full config, then move it to the GPU.
model = ProteinVAELLM_FrameDiff_first_model(cfg)
model = model.to("cuda")

# Register the model's backward hooks (presumably the source of the per-module
# "No NaNs detected ..." lines printed during the backward pass - confirm in the
# model code).
model.attach_backward_hooks()

In [6]:
# AdamW over every model parameter with a learning rate of 1e-4.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

In [7]:
# Bare expression: the notebook displays the model's repr (its module tree).
model

ProteinVAELLM_FrameDiff_first_model(
  (framediff_model): FlowModel(
    (node_embedder): NodeEmbedder(
      (linear): Linear(in_features=256, out_features=128, bias=True)
    )
    (edge_embedder): EdgeEmbedder(
      (linear_s_p): Linear(in_features=128, out_features=64, bias=True)
      (linear_relpos): Linear(in_features=64, out_features=64, bias=True)
      (edge_embedder): Sequential(
        (0): Linear(in_features=236, out_features=64, bias=True)
        (1): ReLU()
        (2): Linear(in_features=64, out_features=64, bias=True)
        (3): ReLU()
        (4): Linear(in_features=64, out_features=64, bias=True)
        (5): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      )
    )
    (trunk): ModuleDict(
      (ipa_0): InvariantPointAttention(
        (linear_q): Linear(in_features=128, out_features=1024, bias=True)
        (linear_kv): Linear(in_features=128, out_features=2048, bias=True)
        (linear_q_points): Linear(in_features=128, out_features=192, bias=True

# Run a single training step on one batch

In [8]:
# Pull one batch from the training loader and move each of its entries to the GPU.
batch = next(iter(train_loader))
for key in list(batch.keys()):
    batch[key] = batch[key].to("cuda")

# Corrupt the batch with the interpolant (called with pad=False).
noisy_batch = interpolant.corrupt_batch(batch, pad=False)

# 'res_mask' is 3-D; its first two axes are reported below as batch size and
# number of time points. The residue count is read from axis 1 of 'aatype'.
B, l, _ = noisy_batch['res_mask'].shape
num_res = noisy_batch["aatype"].shape[1]
print(f"Batch size: {B}, number of time points: {l}, number of residues: {num_res}")

Batch size: 2, number of time points: 16, number of residues: 86


In [10]:
# Compute the training objective for one corrupted batch:
#   * auxiliary_loss - backbone-atom loss + pairwise-distance loss (weighted, gated on t)
#   * kl_div         - KL term on the VAE latents
#   * mse_loss       - MSE on translations plus MSE on rotation matrices
# The prediction at time index i is scored against the target at time index i+1
# (causal-LM style shift). Every per-time-point tensor is therefore sliced along
# the time axis FIRST and only then flattened to B*(l-1) rows.

# Per-residue mask flattened to [B*l, num_res]; row index is b*l + time index.
loss_mask = noisy_batch['res_mask'].reshape(B*l, num_res)

if training_cfg.min_plddt_mask is not None:
    # NOTE(review): assumes 'res_plddt' broadcasts against [B*l, num_res] - confirm.
    # The in-place multiply may also write through to noisy_batch['res_mask'] if
    # the reshape above returned a view.
    plddt_mask = noisy_batch['res_plddt'] > training_cfg.min_plddt_mask
    loss_mask *= plddt_mask

# Ground truth labels
gt_trans = noisy_batch['trans_t']
gt_rotmats = noisy_batch['rotmats_t']

# Model output predictions.
framediff_out = model(noisy_batch)
pred_trans = framediff_out["pred_T"]['pred_trans']
pred_rotmats = framediff_out["pred_T"]['pred_rotmats']

# Shift for CausalLM: drop the last prediction and the first target along axis 1.
shifted_pred_trans = pred_trans[:, :-1, :, :]
shifted_pred_rotmats = pred_rotmats[:, :-1, :, :, :]
shifted_gt_trans = gt_trans[:, 1:, :, :]
shifted_gt_rotmats = gt_rotmats[:, 1:, :, :, :]

# Flatten batch and time into one axis: [B*(l-1), num_res, *]
flat_shifted_pred_trans = shifted_pred_trans.reshape(B*(l-1), num_res, 3)
flat_shifted_pred_rotmats = shifted_pred_rotmats.reshape(B*(l-1), num_res, 3, 3)

flat_shifted_gt_trans = shifted_gt_trans.reshape(B*(l-1), num_res, 3)
flat_shifted_gt_rotmats = shifted_gt_rotmats.reshape(B*(l-1), num_res, 3, 3)

# Timestep used for normalization; the first time point of every sample is dropped
# so the rows line up with the shifted targets.
t = noisy_batch['t'].reshape(B, l, 1)[:, 1:, :].reshape(-1, 1)
norm_scale = 1 - torch.min(
    t[..., None], torch.tensor(training_cfg.t_normalize_clip))

# Backbone atom loss (only the first three atom slots of the atom37 output are kept)
gt_bb_atoms = all_atom.to_atom37(flat_shifted_gt_trans, flat_shifted_gt_rotmats)[:, :, :3]
pred_bb_atoms = all_atom.to_atom37(flat_shifted_pred_trans, flat_shifted_pred_rotmats)[:, :, :3]

gt_bb_atoms *= training_cfg.bb_atom_scale / norm_scale[..., None]
pred_bb_atoms *= training_cfg.bb_atom_scale / norm_scale[..., None]

# A mean over rows was added here; this doesn't matter while the mask is all 1's.
loss_denom = torch.sum(loss_mask, dim=-1, dtype=torch.float).mean() * 3
bb_atom_loss = torch.sum(
    (gt_bb_atoms - pred_bb_atoms) ** 2,
    dim=(-1, -2, -3)
) / loss_denom

# Pairwise distance loss
num_batch = gt_bb_atoms.shape[0]
gt_flat_atoms = gt_bb_atoms.reshape([num_batch, num_res*3, 3])
gt_pair_dists = torch.linalg.norm(
    gt_flat_atoms[:, :, None, :] - gt_flat_atoms[:, None, :, :], dim=-1)
pred_flat_atoms = pred_bb_atoms.reshape([num_batch, num_res*3, 3])
pred_pair_dists = torch.linalg.norm(
    pred_flat_atoms[:, :, None, :] - pred_flat_atoms[:, None, :, :], dim=-1)

# Align the mask with the shifted tensors: restore the [B, l, num_res] layout, drop
# the first time point of every sample, then flatten. Slicing the flattened mask
# with [B:] would instead drop the first B rows of sample 0 and misalign every
# remaining row. The result is identical whenever the mask is all ones.
shifted_loss_mask = loss_mask.reshape(B, l, num_res)[:, 1:, :].reshape(num_batch, num_res)
flat_loss_mask = torch.tile(shifted_loss_mask[:, :, None], (1, 1, 3))
flat_loss_mask = flat_loss_mask.reshape([num_batch, num_res*3])
# Same values as flat_loss_mask; kept as a separate tensor under its original name.
flat_res_mask = flat_loss_mask.clone()

gt_pair_dists = gt_pair_dists * flat_loss_mask[..., None]
pred_pair_dists = pred_pair_dists * flat_loss_mask[..., None]
pair_dist_mask = flat_loss_mask[..., None] * flat_res_mask[:, None, :]

dist_mat_loss = torch.sum(
    (gt_pair_dists - pred_pair_dists)**2 * pair_dist_mask,
    dim=(1, 2))
dist_mat_loss /= (torch.sum(pair_dist_mask, dim=(1, 2)) - num_res)

# The auxiliary loss only counts for rows whose timestep exceeds aux_loss_t_pass.
auxiliary_loss = (bb_atom_loss + dist_mat_loss) * (
    t[:, 0] > training_cfg.aux_loss_t_pass
)
auxiliary_loss *= training_cfg.aux_loss_weight
auxiliary_loss = auxiliary_loss.mean()

# KL term; the last time point's mu/log_sigma are dropped because they are not
# used to predict anything.
kl_div = (
    1
    + 2 * framediff_out["vae_log_sigma"]
    - framediff_out["vae_mu"].pow(2)
    - framediff_out["vae_log_sigma"].exp().pow(2)
)[:, :-1, :]
kl_div = - 0.5 * kl_div.sum(dim=-1).mean()

mse_loss = F.mse_loss(flat_shifted_pred_trans, flat_shifted_gt_trans) + F.mse_loss(flat_shifted_pred_rotmats, flat_shifted_gt_rotmats)

Eigvals: tensor([[[-0.3333, -0.3333, -0.3333,  1.0000],
         [-0.3333, -0.3333, -0.3333,  1.0000],
         [-0.3333, -0.3333, -0.3333,  1.0000],
         ...,
         [-0.3333, -0.3333, -0.3333,  1.0000],
         [-0.3333, -0.3333, -0.3333,  1.0000],
         [-0.3333, -0.3333, -0.3333,  1.0000]],

        [[-0.3333, -0.3333, -0.3333,  1.0000],
         [-0.3333, -0.3333, -0.3333,  1.0000],
         [-0.3333, -0.3333, -0.3333,  1.0000],
         ...,
         [-0.3333, -0.3333, -0.3333,  1.0000],
         [-0.3333, -0.3333, -0.3333,  1.0000],
         [-0.3333, -0.3333, -0.3333,  1.0000]],

        [[-0.3333, -0.3333, -0.3333,  1.0000],
         [-0.3333, -0.3333, -0.3333,  1.0000],
         [-0.3333, -0.3333, -0.3333,  1.0000],
         ...,
         [-0.3333, -0.3333, -0.3333,  1.0000],
         [-0.3333, -0.3333, -0.3333,  1.0000],
         [-0.3333, -0.3333, -0.3333,  1.0000]],

        ...,

        [[-0.3333, -0.3333, -0.3333,  1.0000],
         [-0.3333, -0.3333, -0.3333,



Eigvals: tensor([[[[-0.3333, -0.3333, -0.3333,  1.0000],
          [-0.3333, -0.3333, -0.3333,  1.0000],
          [-0.3333, -0.3333, -0.3333,  1.0000],
          ...,
          [-0.3333, -0.3333, -0.3333,  1.0000],
          [-0.3333, -0.3333, -0.3333,  1.0000],
          [-0.3333, -0.3333, -0.3333,  1.0000]],

         [[-0.3333, -0.3333, -0.3333,  1.0000],
          [-0.3333, -0.3333, -0.3333,  1.0000],
          [-0.3333, -0.3333, -0.3333,  1.0000],
          ...,
          [-0.3333, -0.3333, -0.3333,  1.0000],
          [-0.3333, -0.3333, -0.3333,  1.0000],
          [-0.3333, -0.3333, -0.3333,  1.0000]],

         [[-0.3333, -0.3333, -0.3333,  1.0000],
          [-0.3333, -0.3333, -0.3333,  1.0000],
          [-0.3333, -0.3333, -0.3333,  1.0000],
          ...,
          [-0.3333, -0.3333, -0.3333,  1.0000],
          [-0.3333, -0.3333, -0.3333,  1.0000],
          [-0.3333, -0.3333, -0.3333,  1.0000]],

         ...,

         [[-0.3333, -0.3333, -0.3333,  1.0000],
          [-0



In [10]:
# Total objective: reconstruction MSE + auxiliary loss + KL term from the cell above.
loss = mse_loss + auxiliary_loss + kl_div
print(loss)
# Single optimization step: clear old gradients, backpropagate, update parameters.
optimizer.zero_grad()
# NOTE(review): in the recorded run this backward call raises
# "Function 'LinalgEighBackward0' returned nan values" under anomaly detection.
# The eigh call is not in this notebook. The forward pass printed eigenvalues with
# a triple-repeated value, and eigh gradients are not numerically stable for
# repeated eigenvalues - presumably the cause; confirm in the model/interpolant code.
loss.backward()
optimizer.step()

tensor(3333295.2500, device='cuda:0', grad_fn=<AddBackward0>)
No NaNs detected in any gradient inputs of Linear
No NaNs detected in any gradient outputs of Linear
No NaNs detected in any gradient inputs of Linear
No NaNs detected in any gradient outputs of Linear
No NaNs detected in any gradient inputs of Linear
No NaNs detected in any gradient outputs of Linear
No NaNs detected in any gradient inputs of LayerNorm
No NaNs detected in any gradient outputs of LayerNorm
No NaNs detected in any gradient inputs of Dropout
No NaNs detected in any gradient outputs of Dropout
No NaNs detected in any gradient inputs of Conv1D
No NaNs detected in any gradient outputs of Conv1D
No NaNs detected in any gradient inputs of NewGELUActivation
No NaNs detected in any gradient outputs of NewGELUActivation
No NaNs detected in any gradient inputs of Conv1D
No NaNs detected in any gradient outputs of Conv1D
No NaNs detected in any gradient inputs of GPT2MLP
No NaNs detected in any gradient outputs of GPT2M

  File "/gpfs/gibbs/project/dijk/sh2748/conda_envs/fm/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/gpfs/gibbs/project/dijk/sh2748/conda_envs/fm/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/gpfs/gibbs/project/dijk/sh2748/conda_envs/fm/lib/python3.10/site-packages/ipykernel_launcher.py", line 17, in <module>
    app.launch_new_instance()
  File "/gpfs/gibbs/project/dijk/sh2748/conda_envs/fm/lib/python3.10/site-packages/traitlets/config/application.py", line 1043, in launch_instance
    app.start()
  File "/gpfs/gibbs/project/dijk/sh2748/conda_envs/fm/lib/python3.10/site-packages/ipykernel/kernelapp.py", line 736, in start
    self.io_loop.start()
  File "/gpfs/gibbs/project/dijk/sh2748/conda_envs/fm/lib/python3.10/site-packages/tornado/platform/asyncio.py", line 195, in start
    self.asyncio_loop.run_forever()
  File "/gpfs/gibbs/project/dijk/sh2748/conda_envs/fm/lib/python3.10/

RuntimeError: Function 'LinalgEighBackward0' returned nan values in its 0th output.