In [1]:
from omegaconf import OmegaConf
import os
import torch

import hydra
from omegaconf import DictConfig, OmegaConf
from models.together_model import ProteinVAELLMmodel
from data import all_atom
# Pytorch lightning imports
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from data.pdb_dataloader import PdbDataModule
from models.flow_module import FlowModule
from experiments import utils as eu
import openfold.utils.rigid_utils as ru
import torch.nn.functional as F
# NOTE: configuration loading and data-module setup are intentionally NOT done in
# this import cell. The following cell loads "configs/base.yaml", derives
# `_cfg`, `_data_cfg`, `_exp_cfg`, `training_cfg`, and builds/sets up `_datamodule`.
# Repeating that work here only ran `PdbDataModule.setup(stage="fit")` twice on
# every Restart & Run All, and every name it bound was immediately re-bound.

  from .autonotebook import tqdm as notebook_tqdm


# Setup: configuration, data module, interpolant, model, and optimizer

In [2]:
# Load the experiment configuration once and expose the sub-sections that the
# rest of the notebook reads.
cfg = OmegaConf.load("configs/base.yaml")
_cfg = cfg
_data_cfg, _exp_cfg = cfg.data, cfg.experiment
training_cfg = _exp_cfg.training

# Build the data module from the data section of the config and prepare it
# for the "fit" stage.
_datamodule: LightningDataModule = PdbDataModule(_data_cfg)
_datamodule.setup(stage="fit")

In [3]:
# Training DataLoader provided by the data module that was set up for the "fit" stage.
train_loader = _datamodule.train_dataloader()



In [4]:
# Interpolant configured from the `interpolant` section of the config; it is used
# further down (`corrupt_batch`) to produce the noisy batch fed to the model.
from data.my_interpolant import Interpolant 
interpolant = Interpolant(cfg.interpolant)
# Tell the interpolant to use the GPU, the same device the batch and model are moved to.
interpolant.set_device("cuda")

In [5]:
# Instantiate the model from the full config, then move it to the GPU.
model = ProteinVAELLMmodel(cfg)
model = model.to("cuda")

# Register the model's backward hooks (project helper on ProteinVAELLMmodel).
model.attach_backward_hooks()

In [6]:
# AdamW over every model parameter with a fixed learning rate of 1e-4.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Training: one manual forward pass, loss computation, and optimizer step

In [7]:
# Pull a single batch from the training loader and show its amino-acid types.
batch = next(iter(train_loader))
print(batch["aatype"])

# Move every entry of the batch to the GPU, updating the same dict in place.
for name in list(batch.keys()):
    batch[name] = batch[name].to("cuda")

# Let the interpolant corrupt the batch; the result is what the model consumes.
noisy_batch = interpolant.corrupt_batch(batch)

tensor([[ 3, 19,  9, 16, 19, 18, 11,  3,  4,  2, 18, 16,  7, 13, 15,  7,  7, 10,
         16,  9,  7,  3, 18,  2, 10,  0,  1, 10,  2, 15, 10,  7, 19, 10,  2,  3,
          3,  9, 15, 15, 10,  1,  9, 16,  5,  7, 18,  5,  0,  9, 10, 18,  5,  3,
          3,  2, 13,  7,  7,  0, 15, 16, 19,  9,  2, 15,  3,  2, 15,  4, 10,  2,
         16, 16, 17,  2,  3, 11, 19, 15, 15,  9,  1, 19,  9,  0],
        [11, 19, 13,  6, 19,  8, 19,  1, 14, 11, 11, 10,  0, 19,  6, 14, 11,  7,
         15, 10,  6, 19,  2,  4, 15, 16, 16,  4,  2,  5, 14,  6, 19,  7,  7, 10,
          6, 16, 15, 10,  2, 11,  9, 10, 10,  3,  6,  5,  0,  5, 17, 11,  8, 18,
         10, 19, 15,  2,  9, 15,  8,  3, 16, 19, 10,  5,  4,  8, 13, 16,  4, 15,
          7, 11,  5,  6, 15, 12,  2, 15,  2, 19, 15, 19, 18,  5]])


In [8]:
# ---------------------------------------------------------------------------
# Loss computation for one batch: next-token (shifted) frame prediction with
# an MSE term, an auxiliary backbone/pair-distance term, and a VAE KL term.
# ---------------------------------------------------------------------------

# `res_mask` is unpacked as [B, l, num_res]; `l` is the axis that is shifted
# below for causal (next-token) prediction.
B, l, num_res = noisy_batch['res_mask'].shape

# Flattened mask over every (batch, token) row. Built out-of-place: `reshape`
# may return a view, and multiplying it in place would silently modify
# noisy_batch['res_mask'] and make this cell non-idempotent on re-run.
loss_mask = noisy_batch['res_mask'].reshape(B * l, num_res)

if training_cfg.min_plddt_mask is not None:
    # assumes `res_plddt` holds one value per (batch, token, residue) — TODO confirm
    plddt_mask = (
        noisy_batch['res_plddt'].reshape(B * l, num_res) > training_cfg.min_plddt_mask
    )
    loss_mask = loss_mask * plddt_mask

# Shift the mask exactly like the targets: drop the FIRST token of EVERY sequence.
# This has to happen in [B, l, ...] layout; slicing `[B:]` off the batch-major
# flattened tensor would instead drop the first B tokens of sequence 0 and
# misalign the mask with the shifted predictions/targets.
flat_shifted_loss_mask = (
    loss_mask.reshape(B, l, num_res)[:, 1:, :].reshape(B * (l - 1), num_res)
)

# Ground truth labels
gt_trans = noisy_batch['trans_t']
gt_rotmats = noisy_batch['rotmats_t']

# Model output predictions, reshaped to [B, l, num_res, ...]. `num_res` (taken
# from the mask) replaces the previously hard-coded 128 so that all reshapes in
# this cell agree on the residue dimension.
framediff_out = model(noisy_batch)
pred_trans = framediff_out["pred_T"]['pred_trans'].reshape(B, l, num_res, 3)
pred_rotmats = framediff_out["pred_T"]['pred_rotmats'].reshape(B, l, num_res, 3, 3)

# Shift for CausalLM: the prediction at token i is compared with the target at token i+1.
shifted_pred_trans = pred_trans[:, :-1, :, :]
shifted_pred_rotmats = pred_rotmats[:, :-1, :, :, :]
shifted_gt_trans = gt_trans[:, 1:, :, :]
shifted_gt_rotmats = gt_rotmats[:, 1:, :, :, :]

# Reshape back to [B*(l-1), num_res, *]
flat_shifted_pred_trans = shifted_pred_trans.reshape(B * (l - 1), num_res, 3)
flat_shifted_pred_rotmats = shifted_pred_rotmats.reshape(B * (l - 1), num_res, 3, 3)

flat_shifted_gt_trans = shifted_gt_trans.reshape(B * (l - 1), num_res, 3)
flat_shifted_gt_rotmats = shifted_gt_rotmats.reshape(B * (l - 1), num_res, 3, 3)

# Timestep used for normalization; the first time point of every sequence is
# discarded to stay aligned with the shifted targets.
t = noisy_batch['t'].reshape(B, l, 1)[:, 1:, :].reshape(-1, 1)
norm_scale = 1 - torch.min(
    t[..., None], torch.tensor(training_cfg.t_normalize_clip))

# Backbone atom loss (first three atoms of the atom37 representation).
gt_bb_atoms = all_atom.to_atom37(flat_shifted_gt_trans, flat_shifted_gt_rotmats)[:, :, :3]
pred_bb_atoms = all_atom.to_atom37(flat_shifted_pred_trans, flat_shifted_pred_rotmats)[:, :, :3]

# Same scale for ground truth and prediction; applied out-of-place.
atom_scale = training_cfg.bb_atom_scale / norm_scale[..., None]
gt_bb_atoms = gt_bb_atoms * atom_scale
pred_bb_atoms = pred_bb_atoms * atom_scale

# A single scalar denominator (mean number of unmasked residues per row, times
# 3 backbone atoms). The mean over rows is kept from the original notebook.
loss_denom = torch.sum(flat_shifted_loss_mask, dim=-1, dtype=torch.float).mean() * 3
# Masked-out residues contribute nothing to the backbone loss, consistent with
# the pairwise-distance loss below. With an all-ones mask this is a no-op.
bb_atom_loss = torch.sum(
    (gt_bb_atoms - pred_bb_atoms) ** 2 * flat_shifted_loss_mask[:, :, None, None],
    dim=(-1, -2, -3)
) / loss_denom

# Pairwise distance loss
num_batch = gt_bb_atoms.shape[0]
gt_flat_atoms = gt_bb_atoms.reshape([num_batch, num_res*3, 3])
gt_pair_dists = torch.linalg.norm(
    gt_flat_atoms[:, :, None, :] - gt_flat_atoms[:, None, :, :], dim=-1)
pred_flat_atoms = pred_bb_atoms.reshape([num_batch, num_res*3, 3])
pred_pair_dists = torch.linalg.norm(
    pred_flat_atoms[:, :, None, :] - pred_flat_atoms[:, None, :, :], dim=-1)

# Per-atom mask: each residue's mask value repeated for its 3 backbone atoms,
# built from the correctly shifted mask so rows line up with `num_batch`.
flat_loss_mask = torch.tile(flat_shifted_loss_mask[:, :, None], (1, 1, 3))
flat_loss_mask = flat_loss_mask.reshape([num_batch, num_res*3])
# Both masks were derived from the same loss mask in the original cell.
flat_res_mask = flat_loss_mask

gt_pair_dists = gt_pair_dists * flat_loss_mask[..., None]
pred_pair_dists = pred_pair_dists * flat_loss_mask[..., None]
pair_dist_mask = flat_loss_mask[..., None] * flat_res_mask[:, None, :]

dist_mat_loss = torch.sum(
    (gt_pair_dists - pred_pair_dists)**2 * pair_dist_mask,
    dim=(1, 2))
dist_mat_loss /= (torch.sum(pair_dist_mask, dim=(1, 2)) - num_res)

# The auxiliary loss is only active for rows whose timestep exceeds the threshold.
auxiliary_loss = (bb_atom_loss + dist_mat_loss) * (
    t[:, 0] > training_cfg.aux_loss_t_pass
)
auxiliary_loss *= training_cfg.aux_loss_weight

auxiliary_loss = auxiliary_loss.mean()

# KL divergence of the VAE posterior N(mu, sigma^2) from N(0, 1). The last
# token's mu / log-sigma are dropped because that token predicts nothing.
kl_div = (
    1
    + 2 * framediff_out["vae_log_sigma"]
    - framediff_out["vae_mu"].pow(2)
    - framediff_out["vae_log_sigma"].exp().pow(2)
)[:, :-1, :]
kl_div = - 0.5 * kl_div.sum(dim=-1).mean()

# Plain MSE on shifted translations and rotation matrices.
mse_loss = F.mse_loss(flat_shifted_pred_trans, flat_shifted_gt_trans) + F.mse_loss(flat_shifted_pred_rotmats, flat_shifted_gt_rotmats)

Parameters with NaNs: []
Parameters without NaNs: ['framediff_model.node_embedder.linear.weight', 'framediff_model.node_embedder.linear.bias', 'framediff_model.edge_embedder.linear_s_p.weight', 'framediff_model.edge_embedder.linear_s_p.bias', 'framediff_model.edge_embedder.linear_relpos.weight', 'framediff_model.edge_embedder.linear_relpos.bias', 'framediff_model.edge_embedder.edge_embedder.0.weight', 'framediff_model.edge_embedder.edge_embedder.0.bias', 'framediff_model.edge_embedder.edge_embedder.2.weight', 'framediff_model.edge_embedder.edge_embedder.2.bias', 'framediff_model.edge_embedder.edge_embedder.4.weight', 'framediff_model.edge_embedder.edge_embedder.4.bias', 'framediff_model.edge_embedder.edge_embedder.5.weight', 'framediff_model.edge_embedder.edge_embedder.5.bias', 'framediff_model.trunk.ipa_0.head_weights', 'framediff_model.trunk.ipa_0.linear_q.weight', 'framediff_model.trunk.ipa_0.linear_q.bias', 'framediff_model.trunk.ipa_0.linear_kv.weight', 'framediff_model.trunk.ipa_



In [None]:
# Clear gradients left over from any previous step before accumulating new ones.
optimizer.zero_grad()

# Total objective: reconstruction MSE + auxiliary structural loss + VAE KL term.
loss = mse_loss + auxiliary_loss + kl_div
print(loss)

# Backpropagate and apply one AdamW update.
loss.backward()
optimizer.step()