In [1]:
import os, sys
import shutil
import json
import logging
from pathlib import Path
import multiprocessing
import argparse
import functools
from datetime import datetime
from typing import *

import numpy as np
import pandas as pd
from matplotlib import pyplot as plt

import torch
import torch.nn as nn
from torch.utils.data import Dataset, Subset
from torch.utils.data.dataloader import DataLoader
import torch.nn.functional as F
from einops import rearrange

import pytorch_lightning as pl
from pytorch_lightning.strategies.ddp import DDPStrategy

from transformers import BertConfig

from foldingdiff import datasets
from foldingdiff import modelling
from foldingdiff import losses
from foldingdiff import beta_schedules
from foldingdiff import plotting
from foldingdiff import utils
from foldingdiff import custom_metrics as cm

# Fail fast when no CUDA device is available.
# NOTE(review): nothing visible in this notebook moves data to a GPU; confirm
# whether this requirement is still needed for the preprocessing done here.
assert torch.cuda.is_available(), "Requires CUDA to train"
# Reproducibility: fix the PyTorch RNG seed.
torch.manual_seed(6489)
# Left disabled on purpose; uncomment to request deterministic algorithms.
# torch.use_deterministic_algorithms(True)
torch.backends.cudnn.benchmark = False

# Typing literal enumerating the accepted angle-definition names.
ANGLES_DEFINITIONS = Literal[
    "canonical", "canonical-full-angles", "canonical-minimal-angles", "cart-coords"
]

  from .autonotebook import tqdm as notebook_tqdm


In [32]:
# --- Dataset configuration -------------------------------------------------
angles_definitions = "canonical-full-angles"  # one of the ANGLES_DEFINITIONS values
dataset_key = "cath"
max_seq_len = 128  # forwarded as `pad` to the dataset class
min_seq_len = 60  # forwarded as `min_length` to the dataset class
seq_trim_strategy = "leftalign"
toy = False
train_only = False

# Map every supported angle definition to its dataset class and pick the
# configured one.
clean_dset_class = {
    "canonical": datasets.CathCanonicalAnglesDataset,
    "canonical-full-angles": datasets.CathCanonicalAnglesOnlyDataset,
    "canonical-minimal-angles": datasets.CathCanonicalMinimalAnglesDataset,
    "cart-coords": datasets.CathCanonicalCoordsDataset,
}[angles_definitions]
# NOTE(review): these INFO records are only shown if logging has been
# configured at INFO level or lower.
logging.info(f"Clean dataset class: {clean_dset_class}")

splits = ["train"] if train_only else ["train", "validation", "test"]
logging.info(f"Creating data splits: {splits}")
clean_dsets = [
    clean_dset_class(
        pdbs=dataset_key,
        split=s,
        pad=max_seq_len,
        min_length=min_seq_len,
        trim_strategy=seq_trim_strategy,
        # Zero-centering is switched off only for the cartesian-coordinate variant.
        zero_center=angles_definitions != "cart-coords",
        toy=toy,
    )
    for s in splits
]
# With train_only=True only one dataset is built; pad the list with None so
# the three-way unpack below does not raise "not enough values to unpack".
train_dataset, valid_dataset, test_dataset = clean_dsets + [None] * (3 - len(clean_dsets))

In [33]:
# Number of examples in the train / validation / test splits.
len(train_dataset), len(valid_dataset), len(test_dataset)

(22780, 2847, 2849)

In [34]:
def process_dataset(ds, min_len=60, max_len=128):
    """Materialise a dataset into a DataFrame sorted and filtered by length.

    Every item of ``ds`` is read once and turned into one DataFrame row, so
    items are expected to be dict-like records that include a ``'lengths'``
    entry.

    Args:
        ds: Indexable, sized dataset (supports ``len(ds)`` and ``ds[i]``).
        min_len: Smallest length to keep (inclusive). Defaults to 60.
        max_len: Largest length to keep (inclusive). Defaults to 128.

    Returns:
        A DataFrame with one row per retained item, ordered by ascending
        ``'lengths'``. The ``'lengths'`` column holds plain Python scalars,
        and the original item positions are kept as the index.
    """
    records = [ds[i] for i in range(len(ds))]
    if not records:
        # An empty dataset would otherwise fail with KeyError on 'lengths'.
        return pd.DataFrame(columns=['lengths'])

    df = pd.DataFrame(records)
    # Lengths may be stored as 0-dim tensors; unwrap them so that sorting and
    # comparisons operate on plain numbers. Plain ints pass through untouched.
    df['lengths'] = df['lengths'].apply(
        lambda x: x.item() if hasattr(x, "item") else x
    )
    df = df.sort_values(by='lengths')
    # `between` is inclusive on both ends, matching `>= min_len` and `<= max_len`.
    return df[df['lengths'].between(min_len, max_len)]

In [35]:
# Materialise each split as a length-sorted, length-filtered DataFrame.
filtered_train_ds = process_dataset(ds=train_dataset)
filtered_val_ds = process_dataset(ds=valid_dataset)
filtered_test_ds = process_dataset(ds=test_dataset)

In [36]:
# Preview the first rows of the filtered training frame.
filtered_train_ds.head()

Unnamed: 0,angles,coords,attn_mask,position_ids,lengths
13945,"[[tensor(0.), tensor(2.9529), tensor(0.0445), ...","[[tensor(22.5130), tensor(72.0890), tensor(63....","[tensor(1.), tensor(1.), tensor(1.), tensor(1....","[tensor(0), tensor(1), tensor(2), tensor(3), t...",60
15539,"[[tensor(0.), tensor(2.6073), tensor(0.0185), ...","[[tensor(7.4850), tensor(15.6570), tensor(73.4...","[tensor(1.), tensor(1.), tensor(1.), tensor(1....","[tensor(0), tensor(1), tensor(2), tensor(3), t...",60
6425,"[[tensor(0.), tensor(1.6765), tensor(0.0135), ...","[[tensor(8.8570), tensor(-1.3570), tensor(11.3...","[tensor(1.), tensor(1.), tensor(1.), tensor(1....","[tensor(0), tensor(1), tensor(2), tensor(3), t...",60
11673,"[[tensor(0.), tensor(2.3735), tensor(0.0204), ...","[[tensor(105.3070), tensor(10.6050), tensor(64...","[tensor(1.), tensor(1.), tensor(1.), tensor(1....","[tensor(0), tensor(1), tensor(2), tensor(3), t...",60
6442,"[[tensor(0.), tensor(1.1725), tensor(-0.0099),...","[[tensor(30.4300), tensor(1.1830), tensor(2.79...","[tensor(1.), tensor(1.), tensor(1.), tensor(1....","[tensor(0), tensor(1), tensor(2), tensor(3), t...",60


In [37]:
def save_dataset(df, split: str):
    """Persist one split's angles and lengths as two ``.pt`` files.

    The ``'angles'`` column is stacked into a single tensor and the
    ``'lengths'`` column is converted to a tensor; both are written to the
    current working directory as ``protein_angles_<split>.pt`` and
    ``protein_lengths_<split>.pt``.

    Args:
        df: DataFrame with an ``'angles'`` column of equally shaped tensors
            and a ``'lengths'`` column of numbers.
        split: Name of the split, used in the output file names.
    """
    stacked_angles = torch.stack(df['angles'].tolist())
    length_tensor = torch.tensor(df['lengths'].tolist())
    outputs = (
        (stacked_angles, f"protein_angles_{split}.pt"),
        (length_tensor, f"protein_lengths_{split}.pt"),
    )
    # Both tensors are built before anything is written to disk.
    for payload, filename in outputs:
        torch.save(payload, filename)

In [38]:
# Write the angle and length tensors of every split to disk.
save_dataset(df=filtered_train_ds, split="train")
save_dataset(df=filtered_val_ds, split="val")
save_dataset(df=filtered_test_ds, split="test")

In [39]:
class CustomDataset(Dataset):
    """Dataset backed by two tensors previously written with ``torch.save``.

    Item ``idx`` is the pair ``(data[idx], lengths[idx])``.

    Args:
        data_path: Path to the saved data tensor.
        lengths_path: Path to the saved lengths tensor; it must have as many
            entries as the data tensor.

    Raises:
        ValueError: If the two loaded objects differ in length.
    """

    def __init__(self, data_path, lengths_path):
        # weights_only=True restricts unpickling to tensors and basic
        # containers, and avoids the FutureWarning that recent PyTorch
        # versions emit when the argument is omitted.
        self.data = torch.load(data_path, weights_only=True)
        self.lengths = torch.load(lengths_path, weights_only=True)
        if len(self.data) != len(self.lengths):
            raise ValueError(
                f"data has {len(self.data)} entries but lengths has "
                f"{len(self.lengths)}"
            )

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.lengths[idx]

# Reload every split from the tensors written to disk above.
train_ds = CustomDataset(
    data_path="protein_angles_train.pt", lengths_path="protein_lengths_train.pt"
)
val_ds = CustomDataset(
    data_path="protein_angles_val.pt", lengths_path="protein_lengths_val.pt"
)
test_ds = CustomDataset(
    data_path="protein_angles_test.pt", lengths_path="protein_lengths_test.pt"
)

  self.data = torch.load(data_path)
  self.lengths = torch.load(lengths_path)


In [42]:
# Shape of the first stored training tensor together with its recorded length.
train_ds[0][0].shape, train_ds[0][1]

(torch.Size([128, 6]), tensor(60))

In [66]:
import torch
from torch.utils.data import DataLoader, Dataset, Sampler
class LengthBasedSampler(Sampler):
    """Batch sampler whose batches only contain indices of equal length.

    Meant to be passed as ``batch_sampler`` to a ``DataLoader``: iterating
    the sampler yields lists of dataset indices.

    Args:
        lengths: Iterable of per-example lengths; position ``i`` describes
            dataset index ``i``. Entries may be 0-dim tensors or plain numbers.
        batch_size: Maximum number of indices per batch. The last batch of
            each length group may be smaller.
        shuffle: When True, the order of the batches is re-drawn on every
            iteration (their contents stay fixed). Defaults to False, which
            yields batches grouped by length in order of first appearance.
        seed: Optional seed for a private generator used when shuffling; if
            omitted, the global PyTorch RNG is used.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """

    def __init__(self, lengths, batch_size, shuffle=False, seed=None):
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        self.lengths = lengths
        self.batch_size = batch_size
        self.shuffle = shuffle
        self._generator = None
        if seed is not None:
            self._generator = torch.Generator()
            self._generator.manual_seed(seed)

        # Group dataset indices by length. Tensor entries are unwrapped so
        # the dictionary keys are plain Python numbers.
        self.indices_by_length = {}
        for idx, length in enumerate(lengths):
            key = length.item() if hasattr(length, "item") else length
            self.indices_by_length.setdefault(key, []).append(idx)

        # Cut every length group into consecutive chunks of at most batch_size.
        self.batches = [
            indices[start:start + batch_size]
            for indices in self.indices_by_length.values()
            for start in range(0, len(indices), batch_size)
        ]

    def __iter__(self):
        if not self.shuffle:
            return iter(self.batches)
        # Permute the batch order only; each batch keeps a single length.
        order = torch.randperm(len(self.batches), generator=self._generator).tolist()
        return iter([self.batches[i] for i in order])

    def __len__(self):
        return len(self.batches)

In [105]:
def verify(lengths_path, batch_size, ds, df):
    """Check that length-based batching is consistent with the source frame.

    Two things are asserted: the number of indices the sampler holds per
    length equals the per-length row counts of ``df``, and every batch
    produced by a ``DataLoader`` driven by the sampler holds a single
    distinct length with as many data rows as length entries.

    Args:
        lengths_path: Path to the saved lengths tensor fed to the sampler.
        batch_size: Batch size handed to ``LengthBasedSampler``.
        ds: Dataset whose items are ``(data, length)`` pairs.
        df: DataFrame with a ``'lengths'`` column to compare against.

    Raises:
        AssertionError: If any of the checks fails.
    """
    length_sampler = LengthBasedSampler(torch.load(lengths_path), batch_size)
    batch_loader = DataLoader(ds, batch_sampler=length_sampler)

    sampled_counts = {}
    for seq_len, members in length_sampler.indices_by_length.items():
        sampled_counts[seq_len] = len(members)
    frame_counts = df['lengths'].value_counts().to_dict()
    assert sampled_counts == frame_counts, "Samples don't match"

    for batch in batch_loader:
        features, seq_lengths = batch[0], batch[1]
        assert torch.unique(seq_lengths).numel() == 1, "Non unique length in the batch"
        assert len(features) == len(seq_lengths)
    print("Verification successful!")

In [107]:
# Run the batching check on every split with a batch size of 32.
verify("protein_lengths_train.pt", 32, train_ds, filtered_train_ds)
verify("protein_lengths_val.pt", 32, val_ds, filtered_val_ds)
verify("protein_lengths_test.pt", 32, test_ds, filtered_test_ds)

  sampler = LengthBasedSampler(torch.load(lengths_path), batch_size)


Verification successful!
Verification successful!
Verification successful!


In [109]:
# Build a loader over the training split and pull its first batch.
sampler = LengthBasedSampler(
    lengths=torch.load("protein_lengths_train.pt"), batch_size=32
)
loader = DataLoader(dataset=train_ds, batch_sampler=sampler)
batch = next(iter(loader))

  sampler = LengthBasedSampler(torch.load("protein_lengths_train.pt"), 32)


In [110]:
# Split the batch into its two parts: the data tensor and the lengths tensor.
data, lengths = batch

In [112]:
# Shape of the batched data tensor.
data.shape

torch.Size([32, 128, 6])

In [114]:
# Lengths recorded for the samples of this batch.
lengths

tensor([60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60,
        60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60, 60])