In [1]:
from omegaconf import OmegaConf
import os
import torch

import hydra
from omegaconf import DictConfig, OmegaConf

# Pytorch lightning imports
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from data.pdb_dataloader import PdbDataModule
from models.flow_module import FlowModule
from experiments import utils as eu
cfg = OmegaConf.load("configs/base.yaml")
_cfg = cfg
_data_cfg = cfg.data
_exp_cfg = cfg.experiment
_datamodule: LightningDataModule = PdbDataModule(_data_cfg)
_datamodule.setup(stage="fit")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
cfg = OmegaConf.load("configs/base.yaml")
_cfg = cfg
_data_cfg = cfg.data
_exp_cfg = cfg.experiment
_datamodule: LightningDataModule = PdbDataModule(_data_cfg)
_datamodule.setup(stage="fit")

# Dataloaders

In [4]:
train_loader = _datamodule.train_dataloader()



In [5]:
bs = next(iter(train_loader))
print(bs.keys())
bs["trans_1"].shape

dict_keys(['aatype', 'res_idx', 'rotmats_1', 'trans_1', 'res_mask', 'csv_idx'])


torch.Size([2, 86, 3])

In [6]:
bs["aatype"]

tensor([[ 3, 19,  9, 16, 19, 18, 11,  3,  4,  2, 18, 16,  7, 13, 15,  7,  7, 10,
         16,  9,  7,  3, 18,  2, 10,  0,  1, 10,  2, 15, 10,  7, 19, 10,  2,  3,
          3,  9, 15, 15, 10,  1,  9, 16,  5,  7, 18,  5,  0,  9, 10, 18,  5,  3,
          3,  2, 13,  7,  7,  0, 15, 16, 19,  9,  2, 15,  3,  2, 15,  4, 10,  2,
         16, 16, 17,  2,  3, 11, 19, 15, 15,  9,  1, 19,  9,  0],
        [11, 19, 13,  6, 19,  8, 19,  1, 14, 11, 11, 10,  0, 19,  6, 14, 11,  7,
         15, 10,  6, 19,  2,  4, 15, 16, 16,  4,  2,  5, 14,  6, 19,  7,  7, 10,
          6, 16, 15, 10,  2, 11,  9, 10, 10,  3,  6,  5,  0,  5, 17, 11,  8, 18,
         10, 19, 15,  2,  9, 15,  8,  3, 16, 19, 10,  5,  4,  8, 13, 16,  4, 15,
          7, 11,  5,  6, 15, 12,  2, 15,  2, 19, 15, 19, 18,  5]])

I'm curious to see if ```res_mask``` are all 1's for all data points and it turns out they do.

In [7]:
from tqdm import tqdm
for bs in tqdm(train_loader):
    res_mask = bs['res_mask']
    assert not (res_mask != 1).any()

# OK, all of res_mask contains all 1's
# From the implementations, res_masks of value 1's don't make any difference, we can simplify some code

100%|██████████| 3938/3938 [00:35<00:00, 109.50it/s]


# Sample condition geodesic paths

## My interpolant

In [8]:
from data.my_interpolant import Interpolant 
interpolant = Interpolant(cfg.interpolant)
interpolant.set_device(bs['res_mask'].device)

Use ```corrupt_batch``` to get trajectories

In [9]:
noisy_batch = interpolant.corrupt_batch(bs)
noisy_batch.keys()

Computing igso3_expansion: 100%|██████████| 1000/1000 [00:32<00:00, 30.41it/s]


dict_keys(['aatype', 'res_idx', 'rotmats_1', 'trans_1', 'res_mask', 'csv_idx', 't', 'trans_t', 'rotmats_t'])

Among them, most useful ones are:
- ```rotmats_t```: shape [B, l, max_num_res, 3, 3]
- ```trans_t```: shape [B, l, max_num_res, 3]
- ```t```: shape [B, l], the time points, created by interpolating 0 and 1
- ```res_mask```: shape [B, l, max_num_res], mask of residues, generally all 1's

In [10]:
noisy_batch['rotmats_t'].shape, noisy_batch['trans_t'].shape, noisy_batch['t'].shape, noisy_batch['res_mask'].shape

(torch.Size([2, 16, 128, 3, 3]),
 torch.Size([2, 16, 128, 3]),
 torch.Size([2, 16]),
 torch.Size([2, 16, 128]))

We convert 3-by-3 rotation matrices to 4d unit quaternions.

Then, we convert the quaternions to 3d and concatenate with the 3d translation vector to form a 6d vector for each residue.

In [11]:
import openfold.utils.rigid_utils as ru
quats = ru.rot_to_quat(noisy_batch['rotmats_t'])
print("shape of quaternions: {}".format(quats.shape))

# Step 1: Check if the first element in the last dimension is negative
condition = quats[..., 0] < 0
# Step 2: Conditionally negate the entire sub-tensor
quats[condition] *= -1
# Step 3: Discard the first element from the last dimension
quats = quats[..., 1:] 

print("Processed quaternion shape: {}".format(quats.shape))

concat_quats_trans = torch.cat((quats, noisy_batch['trans_t']), dim=-1)
print("Concatenated with translation vectors: {}".format(concat_quats_trans.shape))

shape of quaternions: torch.Size([2, 16, 128, 4])
Processed quaternion shape: torch.Size([2, 16, 128, 3])
Concatenated with translation vectors: torch.Size([2, 16, 128, 6])


We concatenate the vectors of each residue of a protein together to form a token for the protein backbone.
- shape [B, l, 6*max_num_res]

In [12]:
backbone = concat_quats_trans.reshape(*concat_quats_trans.shape[:-2], -1)
backbone.shape

torch.Size([2, 16, 768])

Our model

In [13]:
from models.vaellm_model import VAE_GPT2
from transformers import GPT2Model
gpt2 = GPT2Model.from_pretrained('gpt2')
vae_gpt2 = VAE_GPT2(base_model=gpt2, emb_dim=768, z_dim=768)

In [14]:
llm_out = vae_gpt2(backbone)
llm_out.keys()

NaN in model parameters: []


dict_keys(['z_sampled', 'mu', 'log_sigma'])

```z_sampled``` is the representation for the protein backbone. 
- shape [B, l, 6 * max_num_res]

In [15]:
llm_out["z_sampled"].shape

torch.Size([2, 16, 768])

In [16]:
llm_out["z_sampled"]

tensor([[[-1.6961e+00, -1.2123e+00, -5.7993e+00,  ...,  3.0064e+00,
           3.6463e+02, -1.0220e+01],
         [-4.1796e+00, -1.0645e+00, -3.6344e+00,  ...,  2.8437e+00,
           2.1208e+02, -3.0281e+00],
         [-5.5020e+00, -1.0201e+00, -4.7652e+00,  ...,  6.5506e+00,
           1.0661e+02,  3.3202e+00],
         ...,
         [-9.7807e-02, -7.8852e-01, -5.5312e+00,  ...,  2.2035e+00,
           5.5355e+01, -3.3696e+00],
         [-3.5365e+00, -8.2009e-01, -3.8279e+00,  ...,  1.6535e+00,
           1.5321e+02, -6.3814e+00],
         [-2.6931e+00, -8.7471e-01, -2.8582e+00,  ...,  1.6319e+00,
           4.9120e+01, -3.4532e+00]],

        [[-5.3179e+00, -1.9424e+00, -7.7379e+00,  ...,  5.8520e+00,
           4.2114e+02,  4.9155e+00],
         [-3.5607e+00, -1.9420e+00, -5.2748e+00,  ...,  5.7461e+00,
          -8.3747e+02, -6.4710e+00],
         [-4.3106e+00, -1.9539e+00, -5.3806e+00,  ...,  6.7742e+00,
           1.8226e+03,  2.5803e+00],
         ...,
         [-3.9085e+00, -1

We unfold the tensor from backbone level back to residue level
- shape [B, l, max_num_res, 6]

In [17]:
B, l, _ = llm_out["z_sampled"].shape
reshaped_z = llm_out["z_sampled"].reshape(B, l, 128, -1)
reshaped_z.shape

torch.Size([2, 16, 128, 6])

We interpret the 6d vector for each residue back to 3d quaternion and 3d translation.

We append a 1 to each quaternion to create 4d quaternions and normalize to form unit quaternions.

In [18]:
import torch.nn.functional as F
recovered_reduced_quats = reshaped_z[:, :, :, :3]
recovered_trans = reshaped_z[:, :, :, 3:]
print(f"Recovered 3D quaternions x,y,z of (1,x,y,z):{recovered_reduced_quats.shape}")
print(f"Recovered translation vectors: {recovered_trans.shape}")
print("-------")

ones = torch.ones(B, l, 128, 1)
recovered_quats = torch.cat((ones, recovered_reduced_quats), dim=-1)
recovered_quats = F.normalize(recovered_quats, p=2, dim=-1)
print(f"Recovered 4D quaternions: {recovered_quats.shape}")
print("-------")

recovered_rots = ru.quat_to_rot(recovered_quats)
print(f"Recovered rotation matrices: {recovered_rots.shape}")
print(f"Recovered translation vectors: {recovered_trans.shape}")

Recovered 3D quaternions x,y,z of (1,x,y,z):torch.Size([2, 16, 128, 3])
Recovered translation vectors: torch.Size([2, 16, 128, 3])
-------
Recovered 4D quaternions: torch.Size([2, 16, 128, 4])
-------
Recovered rotation matrices: torch.Size([2, 16, 128, 3, 3])
Recovered translation vectors: torch.Size([2, 16, 128, 3])


We would want to store them in a dictionary to make FrameDiff happy

In [19]:
noisy_batch['processed_rotmats_t'] = recovered_rots
noisy_batch['processed_trans_t'] = recovered_trans

noisy_batch['processed_rotmats_t'].shape, noisy_batch['rotmats_t'].shape, noisy_batch['processed_trans_t'].shape, noisy_batch['trans_t'].shape, noisy_batch['t'].shape

(torch.Size([2, 16, 128, 3, 3]),
 torch.Size([2, 16, 128, 3, 3]),
 torch.Size([2, 16, 128, 3]),
 torch.Size([2, 16, 128, 3]),
 torch.Size([2, 16]))

FrameDiff work separately on each protein backbone. They expect a shape of [B, N, *], where * is 3,3 or 3 or other shapes for different variables.

We make use of this. We collapse the minibatch and sequence dimension into one to trick FrameDiff into thinking this is the minibatch dimension. 

We can do this because our LLM-VAE already processed temporal dynamics and all we need to do now is to integrate spatially within each protein backbone.

Important elements:
- ```processed_rotmats_t```: shape [B * l, N, 3, 3]
- ```processed_trans_t```: shape [B * l, N, 3]
- ```t```: shape [B * l, 1]. Note FrameDiff expects one time point for each protein. This is also true for us.
- ```res_mask```: shape [B * l, max_num_res]

In [20]:
B, l, N, _, _ = noisy_batch['rotmats_t'].shape
noisy_batch['processed_rotmats_t'] = noisy_batch['processed_rotmats_t'].reshape(B*l, N, 3, 3)
noisy_batch['processed_trans_t'] = noisy_batch['processed_trans_t'].reshape(B*l, N, 3)
noisy_batch['t'] = noisy_batch['t'].reshape(B*l, 1)
noisy_batch['res_mask'] = noisy_batch['res_mask'].reshape(B*l, -1)

noisy_batch['processed_rotmats_t'].shape, noisy_batch['processed_trans_t'].shape, noisy_batch['t'].shape, noisy_batch['res_mask'].shape

(torch.Size([32, 128, 3, 3]),
 torch.Size([32, 128, 3]),
 torch.Size([32, 1]),
 torch.Size([32, 128]))

Now that we processed everything, FrameDiff is happy.

In [21]:
from models.dev_flow_model import FlowModel
model = FlowModel(cfg.model)

In [22]:
framediff_out = model(noisy_batch)
framediff_out.keys()

dict_keys(['pred_trans', 'pred_rotmats'])

The outputs are
- ```pred_rotmats```: shape [B * l, max_num_res, 3, 3]
- ```pred_trans```: shape [B * l, max_num_res, 3]

In [23]:
framediff_out['pred_rotmats'].shape, framediff_out['pred_trans'].shape

(torch.Size([32, 128, 3, 3]), torch.Size([32, 128, 3]))

## Now we would want to compute loss.

In [24]:
noisy_batch["processed_rotmats_t"].shape, framediff_out['pred_rotmats'].shape, noisy_batch["processed_trans_t"].shape, framediff_out['pred_trans'].shape

(torch.Size([32, 128, 3, 3]),
 torch.Size([32, 128, 3, 3]),
 torch.Size([32, 128, 3]),
 torch.Size([32, 128, 3]))

In [25]:
_exp_cfg = cfg.experiment
training_cfg = _exp_cfg.training
training_cfg

{'min_plddt_mask': None, 'loss': 'auxiliary_loss', 'bb_atom_scale': 0.1, 'trans_scale': 0.1, 'translation_loss_weight': 2.0, 't_normalize_clip': 0.9, 'rotation_loss_weights': 1.0, 'aux_loss_weight': 1.0, 'aux_loss_t_pass': 0.25}

In [26]:
loss_mask = noisy_batch['res_mask']

if training_cfg.min_plddt_mask is not None:
    plddt_mask = noisy_batch['res_plddt'] > training_cfg.min_plddt_mask
    loss_mask *= plddt_mask

num_batch, num_res = loss_mask.shape

In [27]:
loss_mask.shape

torch.Size([32, 128])

We throw away the first time point (t = 0) here. This is because later we need to shift tokens because that's what we do in CausalLM tasks.

In [28]:
# Timestep used for normalization.
print(f"Shape of time points: {noisy_batch['t'].shape}")
t = noisy_batch['t'].reshape(B, l, 1)[:, 1:, :].reshape(-1,1)
print(f"Processed time step shape: {t.shape}")
norm_scale = 1 - torch.min(
    t[..., None], torch.tensor(training_cfg.t_normalize_clip))

t.shape, norm_scale.shape

Shape of time points: torch.Size([32, 1])
Processed time step shape: torch.Size([30, 1])


(torch.Size([30, 1]), torch.Size([30, 1, 1]))

We have the model outputs, the shapes are [B*l, max_num_res, *].

We have the sampled ground truths trajectories, the shapes are [B, l, max_num_res, *].

For CausalLM tasks, we need to shift the predicted and ground truth tokens so we first unfold them to [B, l, max_num_res, *].

In [29]:
# These are the model outputs
pred_trans = framediff_out['pred_trans'].reshape(B, l, 128, 3)
pred_rotmats = framediff_out['pred_rotmats'].reshape(B, l, 128, 3, 3)

# These are the sampled ground truths.
gt_trans = noisy_batch['trans_t']
gt_rotmats = noisy_batch['rotmats_t']

pred_rotmats.shape, gt_rotmats.shape, pred_trans.shape, gt_trans.shape

(torch.Size([2, 16, 128, 3, 3]),
 torch.Size([2, 16, 128, 3, 3]),
 torch.Size([2, 16, 128, 3]),
 torch.Size([2, 16, 128, 3]))

Now, we shift the prediction and ground truth.

In [30]:
shifted_pred_trans  = pred_trans[:, :-1, :, :]
shifted_pred_rotmats = pred_rotmats[:, :-1, :, :, :]
shifted_gt_trans = gt_trans[:, 1:, :, :]
shifted_gt_rotmats = gt_rotmats[:, 1:, :, :, :]

shifted_pred_rotmats.shape, shifted_gt_rotmats.shape, shifted_pred_trans.shape, shifted_gt_trans.shape

(torch.Size([2, 15, 128, 3, 3]),
 torch.Size([2, 15, 128, 3, 3]),
 torch.Size([2, 15, 128, 3]),
 torch.Size([2, 15, 128, 3]))

Now, we collapse the first two dimensions to make ```all_atom``` happy.

In [31]:
flat_shifted_pred_trans = shifted_pred_trans.reshape(B*(l-1), 128, 3)
flat_shifted_pred_rotmats = shifted_pred_rotmats.reshape(B*(l-1), 128, 3, 3)

flat_shifted_gt_trans = shifted_gt_trans.reshape(B*(l-1), 128, 3)
flat_shifted_gt_rotmats = shifted_gt_rotmats.reshape(B*(l-1), 128, 3, 3)

flat_shifted_pred_rotmats.shape, flat_shifted_gt_rotmats.shape, flat_shifted_pred_trans.shape, flat_shifted_gt_trans.shape

(torch.Size([30, 128, 3, 3]),
 torch.Size([30, 128, 3, 3]),
 torch.Size([30, 128, 3]),
 torch.Size([30, 128, 3]))

In [32]:
from data import all_atom
gt_bb_atoms = all_atom.to_atom37(flat_shifted_gt_trans, flat_shifted_gt_rotmats)[:, :, :3] 
pred_bb_atoms = all_atom.to_atom37(flat_shifted_pred_trans, flat_shifted_pred_rotmats)[:, :, :3]

gt_bb_atoms *= training_cfg.bb_atom_scale / norm_scale[..., None]
pred_bb_atoms *= training_cfg.bb_atom_scale / norm_scale[..., None]

In [33]:
gt_bb_atoms.shape, pred_bb_atoms.shape

(torch.Size([30, 128, 3, 3]), torch.Size([30, 128, 3, 3]))

In [34]:
loss_denom = torch.sum(loss_mask, dim=-1, dtype=torch.float).mean() * 3
bb_atom_loss = torch.sum(
    (gt_bb_atoms - pred_bb_atoms) ** 2,
    dim=(-1, -2, -3)
) / loss_denom

bb_atom_loss.shape

torch.Size([30])

In [35]:
num_batch = gt_bb_atoms.shape[0]
# Pairwise distance loss
gt_flat_atoms = gt_bb_atoms.reshape([num_batch, num_res*3, 3])
gt_pair_dists = torch.linalg.norm(
    gt_flat_atoms[:, :, None, :] - gt_flat_atoms[:, None, :, :], dim=-1)
pred_flat_atoms = pred_bb_atoms.reshape([num_batch, num_res*3, 3])
pred_pair_dists = torch.linalg.norm(
    pred_flat_atoms[:, :, None, :] - pred_flat_atoms[:, None, :, :], dim=-1)

print(gt_pair_dists.shape, pred_pair_dists.shape)

torch.Size([30, 384, 384]) torch.Size([30, 384, 384])


In [36]:
flat_loss_mask = torch.tile(loss_mask[:, :, None], (1, 1, 3))[B:,:,:] # change the shape because we shifted tokens, all entries of loss masks are 1 so don't matter, we throw away B tokens
flat_loss_mask = flat_loss_mask.reshape([num_batch, num_res*3])
flat_res_mask = torch.tile(loss_mask[:, :, None], (1, 1, 3))[B:,:,:] # change the shape because we shifted tokens, all entries of loss masks are 1 so don't matter
flat_res_mask = flat_res_mask.reshape([num_batch, num_res*3])

In [37]:
flat_loss_mask.shape, flat_res_mask.shape

(torch.Size([30, 384]), torch.Size([30, 384]))

In [38]:
gt_pair_dists = gt_pair_dists * flat_loss_mask[..., None]
pred_pair_dists = pred_pair_dists * flat_loss_mask[..., None]
pair_dist_mask = flat_loss_mask[..., None] * flat_res_mask[:, None, :]

In [39]:
dist_mat_loss = torch.sum(
    (gt_pair_dists - pred_pair_dists)**2 * pair_dist_mask,
    dim=(1, 2))
dist_mat_loss /= (torch.sum(pair_dist_mask, dim=(1, 2)) - num_res)

auxiliary_loss = (bb_atom_loss + dist_mat_loss) * (
    t[:, 0]> training_cfg.aux_loss_t_pass
)
auxiliary_loss *= _exp_cfg.training.aux_loss_weight

In [40]:
auxiliary_loss

tensor([0.0000e+00, 0.0000e+00, 0.0000e+00, 1.8200e+02, 2.6897e+02, 2.9347e+02,
        2.8871e+02, 2.5675e+02, 4.1631e+02, 6.7745e+02, 1.1647e+03, 2.1560e+03,
        4.0746e+03, 5.8246e+03, 7.0865e+03, 0.0000e+00, 0.0000e+00, 0.0000e+00,
        1.0443e+04, 1.2287e+04, 2.4929e+04, 1.6558e+04, 1.3572e+04, 1.7375e+04,
        3.1194e+04, 3.6437e+04, 1.0511e+05, 3.7607e+05, 3.2303e+05, 2.4006e+05],
       grad_fn=<MulBackward0>)

In [41]:
kl_div = (1 + 2 * llm_out["log_sigma"] - llm_out["mu"].pow(2) - llm_out["log_sigma"].exp().pow(2))[:, :-1, :] # Throw away the last mu and sigma because we're not using it to predict
kl_div = - 0.5 * kl_div.sum(dim=-1).mean()

In [42]:
mse_loss = F.mse_loss(flat_shifted_pred_trans, flat_shifted_gt_trans) + F.mse_loss(flat_shifted_pred_rotmats, flat_shifted_gt_rotmats)

In [43]:
loss = mse_loss+auxiliary_loss.sum()+kl_div

In [44]:
loss

tensor(12136346., grad_fn=<AddBackward0>)

In [45]:
optimizer = torch.optim.AdamW(model.parameters())

In [46]:
optimizer.zero_grad()
loss.backward()
optimizer.step()

In [47]:
from models.my_flow_module import FlowModule

_datamodule: LightningDataModule = PdbDataModule(_data_cfg)
_model: LightningModule = FlowModule(_cfg)

In [48]:
# logger = WandbLogger(
#                 _exp_cfg.wandb,
#             )
import GPUtil
devices = GPUtil.getAvailable(order='memory', limit = 8)[:_exp_cfg.num_devices]
trainer = Trainer(
    **_exp_cfg.trainer,
    #callbacks=callbacks,
    # logger=logger,
    use_distributed_sampler=False,
    enable_progress_bar=True,
    enable_model_summary=True,
    devices=devices,
)
trainer.fit(
    model=self._model,
    datamodule=self._datamodule,
    ckpt_path=self._exp_cfg.warm_start
)

  rank_zero_warn(


MisconfigurationException: `Trainer(strategy='ddp')` is not compatible with an interactive environment. Run your code as a script, or choose one of the compatible strategies: `Fabric(strategy='dp'|'ddp_notebook')`. In case you are spawning processes yourself, make sure to include the Trainer creation inside the worker function.

# Inference code

In [49]:
def sample(
            self,
            num_batch,
            num_res,
            model,
        ):
        res_mask = torch.ones(num_batch, num_res, device=self._device)

        # Set-up initial prior samples
        trans_0 = _centered_gaussian(
            num_batch, num_res, self._device) * du.NM_TO_ANG_SCALE
        rotmats_0 = _uniform_so3(num_batch, num_res, self._device)
        batch = {
            'res_mask': res_mask,
        }

        # Set-up time
        ts = torch.linspace(
            self._cfg.min_t, 1.0, self._sample_cfg.num_timesteps)
        t_1 = ts[0]

        prot_traj = [(trans_0, rotmats_0)]
        clean_traj = []
        for t_2 in ts[1:]:

            # Run model.
            trans_t_1, rotmats_t_1 = prot_traj[-1]
            batch['trans_t'] = trans_t_1
            batch['rotmats_t'] = rotmats_t_1
            t = torch.ones((num_batch, 1), device=self._device) * t_1
            batch['t'] = t
            with torch.no_grad():
                model_out = model(batch)

            # Process model output.
            pred_trans_1 = model_out['pred_trans']
            pred_rotmats_1 = model_out['pred_rotmats']
            clean_traj.append(
                (pred_trans_1.detach().cpu(), pred_rotmats_1.detach().cpu())
            )
            if self._cfg.self_condition:
                batch['trans_sc'] = pred_trans_1

            # Take reverse step
            d_t = t_2 - t_1
            trans_t_2 = self._trans_euler_step(
                d_t, t_1, pred_trans_1, trans_t_1)
            rotmats_t_2 = self._rots_euler_step(
                d_t, t_1, pred_rotmats_1, rotmats_t_1)
            prot_traj.append((trans_t_2, rotmats_t_2))
            t_1 = t_2

        # We only integrated to min_t, so need to make a final step
        t_1 = ts[-1]
        trans_t_1, rotmats_t_1 = prot_traj[-1]
        batch['trans_t'] = trans_t_1
        batch['rotmats_t'] = rotmats_t_1
        batch['t'] = torch.ones((num_batch, 1), device=self._device) * t_1
        with torch.no_grad():
            model_out = model(batch)
        pred_trans_1 = model_out['pred_trans']
        pred_rotmats_1 = model_out['pred_rotmats']
        clean_traj.append(
            (pred_trans_1.detach().cpu(), pred_rotmats_1.detach().cpu())
        )
        prot_traj.append((pred_trans_1, pred_rotmats_1))

        # Convert trajectories to atom37.
        atom37_traj = all_atom.transrot_to_atom37(prot_traj, res_mask)
        clean_atom37_traj = all_atom.transrot_to_atom37(clean_traj, res_mask)
        return atom37_traj, clean_atom37_traj, clean_traj


In [50]:
from omegaconf import OmegaConf
import os
import torch

import hydra
from omegaconf import DictConfig, OmegaConf

# Pytorch lightning imports
from pytorch_lightning import LightningDataModule, LightningModule, Trainer
from pytorch_lightning.loggers.wandb import WandbLogger
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from data.pdb_dataloader import PdbDataModule
from models.flow_module import FlowModule
from experiments import utils as eu
cfg = OmegaConf.load("configs/base.yaml")
_cfg = cfg
_data_cfg = cfg.data
_exp_cfg = cfg.experiment
_datamodule: LightningDataModule = PdbDataModule(_data_cfg)
_datamodule.setup(stage="fit")

In [51]:
from data import utils as du
from scipy.spatial.transform import Rotation
def _centered_gaussian(num_batch, num_res, device):
    noise = torch.randn(num_batch, num_res, 3, device=device)
    return noise - torch.mean(noise, dim=-2, keepdims=True)

def _uniform_so3(num_batch, num_res, device):
    return torch.tensor(
        Rotation.random(num_batch*num_res).as_matrix(),
        device=device,
        dtype=torch.float32,
    ).reshape(num_batch, num_res, 3, 3)

num_batch=2
num_res=20
_device = "cuda"
res_mask = torch.ones(num_batch, num_res, device=_device)
trans_0 = _centered_gaussian(
            num_batch, num_res, _device) * du.NM_TO_ANG_SCALE
rotmats_0 = _uniform_so3(num_batch, num_res, _device)
batch = {
    'res_mask': res_mask,
}

In [52]:
trans_0 = trans_0[:, :, None,:]

In [53]:
trans_0.shape # [B, N, 3] -> [B, N, 1, 3]

torch.Size([2, 20, 1, 3])

In [54]:
rotmats_0 = rotmats_0[:,:,None,:, :]

In [55]:
rotmats_0.shape # [B, N, 3, 3] -> [B, N, 1, 3, 3]

torch.Size([2, 20, 1, 3, 3])

In [56]:
batch = {
            'res_mask': res_mask,
        }
t = torch.linspace(0,1,16)[None, :].repeat(num_batch,1).to("cuda")
t.shape

torch.Size([2, 16])

In [57]:
import torch.nn.functional as F
def _pad_trans(trans_t):
    '''
            Pad rotmats_t from [B,N,l,3], 
                where N is the actual number of residues (and N \leq max_num_res) 
            to [B, max_num_res,l,3] with all 0's
    '''
    trans_t_padded = F.pad(trans_t, (0, 0, 0, 0, 0, 128 - trans_t.shape[1]), "constant", 0)
    return trans_t_padded

def _pad_rotmats(rotmats_t):
    '''
        Pad rotmats_t from [B,N,l,3,3], 
                where N is the actual number of residues (and N \leq max_num_res) 
            to [B, max_num_res,l,3,3] with all 0's
    '''
    rotmats_t_padded = F.pad(rotmats_t, (0, 0, 0, 0, 0, 0, 0, 128 - rotmats_t.shape[1]), "constant", 0)
    return rotmats_t_padded

def _pad_res_mask(self, res_mask):
    '''
            Pad rotmats from [B,N],
                where N is the actual number of residues (and N \leq max_num_res) 
            to [B, max_num_res] with all 1's

            Note: pad with 1's not 0's
    '''
    res_mask_padded = F.pad(res_mask, (0, 128 - res_mask.shape[1]), mode='constant', value=1)
    return res_mask_padded

In [58]:
rotmats_0 = _pad_rotmats(rotmats_0)
rotmats_0.shape

torch.Size([2, 128, 1, 3, 3])

In [59]:
trans_0 = _pad_trans(trans_0)
trans_0.shape

torch.Size([2, 128, 1, 3])

In [60]:
rotmats_0 = rotmats_0.permute(0, 2, 1, 3, 4)
rotmats_0.shape

torch.Size([2, 1, 128, 3, 3])

In [61]:
trans_0 = trans_0.permute(0, 2, 1, 3)
trans_0.shape

torch.Size([2, 1, 128, 3])

In [62]:
batch['res_mask'] = batch['res_mask'][:,None,:].repeat(1, 16, 1) # replace 16

In [63]:
batch["trans_t"] = trans_0 
batch["rotmats_t"] = rotmats_0
batch['t'] = t

In [64]:
batch["trans_t"].shape, batch["rotmats_t"].shape

(torch.Size([2, 1, 128, 3]), torch.Size([2, 1, 128, 3, 3]))

In [65]:
from models.together_model import ProteinVAELLMmodel
model = ProteinVAELLMmodel(cfg).to("cuda")

In [66]:
with torch.no_grad():   
    out = model.generate(batch)

NaN in model parameters: []


In [67]:
out.keys()

dict_keys(['pred_trans', 'pred_rotmats'])

In [68]:
B = num_batch
l = 16
N = num_res

In [69]:
out["pred_trans"] = out["pred_trans"].reshape(B, l, N, 3)
out["pred_rotmats"] = out["pred_rotmats"].reshape(B, l, N, 3, 3)

In [70]:
out["pred_trans"].shape, out["pred_rotmats"].shape

(torch.Size([2, 16, 20, 3]), torch.Size([2, 16, 20, 3, 3]))

In [71]:
protein_trajectory = []
for i in range(out["pred_trans"].shape[1]):
    protein_trajectory.append(
        (out["pred_trans"][:, i, :, :].detach().cpu(), out["pred_rotmats"][:, i, :, :, :].detach().cpu())
    )

In [72]:
from data import all_atom
atom37_traj = all_atom.transrot_to_atom37(protein_trajectory, res_mask)

In [73]:
atom37_traj[0].shape

torch.Size([2, 20, 37, 3])

In [74]:
from data.my_interpolant import Interpolant 
interpolant = Interpolant(cfg.interpolant)
interpolant.set_device("cuda")
out_sample = interpolant.sample(num_batch=2, num_res=20, model=model)[0][-1].numpy()

NaN in model parameters: []


In [75]:
out_sample.shape

(2, 20, 37, 3)

In [76]:
from transformers import GPT2Model
gpt2 = GPT2Model.from_pretrained('gpt2')

In [77]:
gpt2

GPT2Model(
  (wte): Embedding(50257, 768)
  (wpe): Embedding(1024, 768)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-11): 12 x GPT2Block(
      (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
)