In [1]:
"""
Train a baseline autoregressive model that uses a causal LM approach to generating
series of angles
"""
import sys
# caution: path[0] is reserved for script path (or '' in REPL)
sys.path.append('/bin/')
sys.path.append('bin')
sys.path.append('bin/')
sys.path.append('/bin')

import os
from pathlib import Path
import json
import argparse
from datetime import datetime
import logging
import multiprocessing
from typing import *

import numpy as np
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.strategies.ddp import DDPStrategy

from transformers import BertConfig

from foldingdiff import datasets, modelling, losses, plotting, utils
from foldingdiff import custom_metrics as cm

from train import ANGLES_DEFINITIONS, build_callbacks, record_args_and_metadata

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributions as dist
import torch.optim as optim
from torch import nn, einsum
from torch.utils.data import DataLoader

# Project-local modules
import transformer
import utilities 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def get_train_valid_test_sets(
    angles_definitions: ANGLES_DEFINITIONS = "canonical-full-angles",
    max_seq_len: int = 512,
    min_seq_len: int = 0,
    seq_trim_strategy: datasets.TRIM_STRATEGIES = "leftalign",
) -> Tuple[
    datasets.AutoregressiveCausalDataset,
    datasets.AutoregressiveCausalDataset,
    datasets.AutoregressiveCausalDataset,
]:
    """
    Build the train/validation/test datasets, each wrapped in
    ``datasets.AutoregressiveCausalDataset``.

    Args:
        angles_definitions: Selects the underlying dataset class; must be one of
            "canonical", "canonical-full-angles", "canonical-minimal-angles" or
            "cart-coords" (any other value raises ``KeyError``).
        max_seq_len: Forwarded to the dataset class as ``pad``.
        min_seq_len: Forwarded to the dataset class as ``min_length``.
        seq_trim_strategy: Forwarded to the dataset class as ``trim_strategy``.

    Returns:
        The three wrapped datasets in train, validation, test order. Note that
        the value is a ``list`` even though the annotation says ``Tuple``.
    """
    uses_cart_coords = angles_definitions == "cart-coords"

    dataset_classes = {
        "canonical": datasets.CathCanonicalAnglesDataset,
        "canonical-full-angles": datasets.CathCanonicalAnglesOnlyDataset,
        "canonical-minimal-angles": datasets.CathCanonicalMinimalAnglesDataset,
        "cart-coords": datasets.CathCanonicalCoordsDataset,
    }
    base_cls = dataset_classes[angles_definitions]
    logging.info(f"Clean dataset class: {base_cls}")

    split_names = ["train", "validation", "test"]
    logging.info(f"Creating data splits: {split_names}")
    base_dsets = []
    for split_name in split_names:
        base_dsets.append(
            base_cls(
                split=split_name,
                pad=max_seq_len,
                min_length=min_seq_len,
                trim_strategy=seq_trim_strategy,
                # Cartesian coordinates are the only variant not zero-centered
                zero_center=not uses_cart_coords,
            )
        )

    # Reuse the training split's mean offset for the validation and test splits
    train_dset, *heldout_dsets = base_dsets
    if heldout_dsets and train_dset.means is not None:
        logging.info(f"Updating valid/test mean offset to {train_dset.means}")
        for heldout in heldout_dsets:
            heldout.means = train_dset.means

    # Wrap every split so it exposes the feature tensor under the matching key
    feature_key = "coords" if uses_cart_coords else "angles"
    causal_dsets = [
        datasets.AutoregressiveCausalDataset(base, dset_key=feature_key)
        for base in base_dsets
    ]
    for split_name, wrapped in zip(split_names, causal_dsets):
        logging.info(f"{split_name}: {wrapped}")
    return causal_dsets

def return_dataset(
    results_dir: str = "./results",
    angles_definitions: ANGLES_DEFINITIONS = "canonical-full-angles",
    max_seq_len: int = 128,
    min_seq_len: int = 0,
    trim_strategy: datasets.TRIM_STRATEGIES = "randomcrop",
    # Related to model architecture
    seq_len_encoding: modelling.TIME_ENCODING = "gaussian_fourier",  # Embeds the total sequence length
    num_hidden_layers: int = 12,  # Default 12
    hidden_size: int = 384,  # Default 768
    intermediate_size: int = 768,  # Default 3072
    num_heads: int = 12,  # Default 12
    position_embedding_type: Literal[
        "absolute", "relative_key_query", "relative_key"
    ] = "absolute",  # Default absolute
    dropout_p: float = 0.1,
    decoder: modelling.DECODER_HEAD = "mlp",
    # Related to training strategy
    gradient_clip: float = 1.0,
    batch_size: int = 32,
    lr: float = 5e-5,
    l2_norm: float = 0.01,
    loss: modelling.LOSS_KEYS = "smooth_l1",
    min_epochs: Optional[int] = None,
    max_epochs: int = 10000,  # 10000, set to 100 for debug
    early_stop_patience: int = 0,  # Set to 0 to disable early stopping
    lr_scheduler: modelling.LR_SCHEDULE = "LinearWarmup",  # Try LinearWarmup?
    use_swa: bool = False,
):
    """
    Build the train/validation/test ``DataLoader`` objects. No model is created
    or trained here despite the training-style signature.

    Only ``angles_definitions``, ``max_seq_len``, ``min_seq_len``,
    ``trim_strategy`` and ``batch_size`` influence the result; ``loss`` is only
    logged. Every other parameter is accepted for signature compatibility and
    is ignored.

    Args:
        angles_definitions: Dataset variant, forwarded to
            ``get_train_valid_test_sets``.
        max_seq_len: Forwarded to ``get_train_valid_test_sets``.
        min_seq_len: Forwarded to ``get_train_valid_test_sets``.
        trim_strategy: Forwarded as ``seq_trim_strategy``.
        batch_size: Requested batch size. When CUDA is available it is divided
            by the number of visible GPUs (truncated, but never below 1).

    Returns:
        ``(train_dataloader, valid_dataloader, test_dataloader)``; only the
        training loader shuffles.
    """
    dsets = get_train_valid_test_sets(
        angles_definitions=angles_definitions,
        max_seq_len=max_seq_len,
        min_seq_len=min_seq_len,
        seq_trim_strategy=trim_strategy,
    )
    assert len(dsets) == 3

    # Calculate effective batch size
    # https://pytorch-lightning.readthedocs.io/en/1.4.0/advanced/multi_gpu.html#batch-size
    # Under DDP, effective batch size is batch_size * num_gpus * num_nodes
    num_gpus = torch.cuda.device_count()
    effective_batch_size = batch_size
    if torch.cuda.is_available() and num_gpus > 0:
        # Floor at 1: a batch size smaller than the GPU count would otherwise
        # truncate to 0, which DataLoader rejects.
        effective_batch_size = max(1, int(batch_size / num_gpus))
    pl.utilities.rank_zero_info(
        f"Given batch size: {batch_size} --> effective batch size with {num_gpus} GPUs: {effective_batch_size}"
    )

    # Create data loaders
    train_dataloader, valid_dataloader, test_dataloader = [
        DataLoader(
            dataset=ds,
            batch_size=effective_batch_size,
            shuffle=i == 0,  # Shuffle only train loader
            num_workers=multiprocessing.cpu_count(),
            pin_memory=True,
        )
        for i, ds in enumerate(dsets)
    ]

    logging.info(f"Using loss function: {loss}")

    return train_dataloader, valid_dataloader, test_dataloader

In [3]:
# Build the three DataLoaders using return_dataset()'s defaults
# (batch_size=32, max_seq_len=128, trim_strategy="randomcrop").
train_dataloader, valid_dataloader, test_dataloader = return_dataset()

Given batch size: 32 --> effective batch size with 1 GPUs: 32


In [4]:
# Bin count: passed to the model as `num_bins` and to utilities.custom_bucketize
# when building targets.
GLOBAL_NUM_BINS = 600
# Batch size handed to the model constructor.
# NOTE(review): the DataLoaders built by return_dataset() default to
# batch_size=32, so this value may not match the real batch size — confirm.
GLOBAL_BATCH_SIZE = 157

In [5]:
# Prefer the first CUDA device, but fall back to the CPU so this cell still runs
# on a machine without a GPU instead of failing in `.to(device)`.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    # Release cached allocator blocks left over from earlier cells.
    torch.cuda.empty_cache()

# NOTE(review): seq_len=768 presumably equals 128 residues x 6 angles after
# flattening (return_dataset defaults to max_seq_len=128) — confirm against
# the shape of the batches actually fed to the model.
model = transformer.UnconditionalTransformer(
    seq_len=768,
    hidden_size=512,
    num_bins=GLOBAL_NUM_BINS,
    dropout=0.05,
    dnlayers=4,
    batch_size=GLOBAL_BATCH_SIZE,
    ffn_hidden_size=1024,
    num_heads=1,
    qk_depth=128,
    v_depth=128,
    pseudolikelihood=True,
    device=device,
).to(device)

In [6]:
# Debug pass over the training loader: print each batch index and the size of
# its 'lengths' tensor. `k` is left bound to the last batch's tensor afterwards.
for batch_idx, data_dict in enumerate(train_dataloader):
    print(batch_idx)
    
    k = data_dict['lengths']
    print(k.size())

0
torch.Size([32])
1
torch.Size([32])
2
torch.Size([32])
3
torch.Size([32])
4
torch.Size([27])


In [34]:
# Show the size of whatever `k` currently refers to.
# NOTE(review): the recorded output (157, 256, 6) does not match the 'lengths'
# tensors printed by the loop cell, so `k` was presumably rebound in a cell
# that is no longer present — confirm on a fresh kernel.
k.size()

torch.Size([157, 256, 6])

In [25]:
# History of logged losses; train_loop appends to this list.
loss_over_time = []
model.train()
optimizer = optim.Adam(model.parameters(), lr=0.0001)
# Cosine annealing from the initial LR down to eta_min over T_max scheduler steps.
# The `verbose` keyword is intentionally not passed: it is deprecated since
# PyTorch 2.2 and False was its default anyway.
# NOTE(review): the schedule only takes effect if scheduler.step() is called
# during training.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=10, eta_min=0, last_epoch=-1
)

def train_loop(epoch):
    """
    Run one training epoch over ``train_dataloader``.

    Uses the notebook-level globals ``model``, ``optimizer``,
    ``train_dataloader``, ``device``, ``utilities``, ``GLOBAL_NUM_BINS`` and
    ``loss_over_time``.

    Args:
        epoch: Epoch index; only used in the progress printout.

    Returns:
        The mean loss per batch over the epoch (0.0 if the loader is empty).
    """
    num_batches = len(train_dataloader)
    train_loss = 0.0
    for batch_idx, data_dict in enumerate(train_dataloader):
        # Flatten everything after the batch dimension into one axis
        data = data_dict['angles'].to(device)
        data = torch.flatten(data, start_dim=1)

        # Discretised targets for the same values
        target = utilities.custom_bucketize(data, GLOBAL_NUM_BINS).to(device)

        optimizer.zero_grad()

        softmax_X_pred, attn_mats = model(decoder_input=data, return_attention=True)

        # Model loss plus a very small L1 penalty on the attention matrices
        loss = model.loss(X=softmax_X_pred, Y=target) + 1e-9 * torch.sum(torch.abs(attn_mats))

        loss.backward()
        # Read the scalar once; every .item() call forces a device sync
        batch_loss = loss.item()
        train_loss += batch_loss

        # Clamp each gradient element to [-1, 1] (value clipping, not norm clipping)
        torch.nn.utils.clip_grad_value_(model.parameters(), 1)

        optimizer.step()

        if batch_idx % 50 == 0:
            print('Train Epoch:', epoch, '[{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                batch_idx,  # index of the batch we are on
                num_batches,  # how many batches are in the data loader
                100. * batch_idx / num_batches,  # progress percentage
                # Normalise by the real size of this batch; the last batch can
                # be smaller and the loader's batch size is set elsewhere.
                batch_loss / data.size(0)
                ))
            # NOTE(review): this stores a single batch's loss divided by the
            # number of batches; kept as-is, confirm that scale is intended.
            loss_over_time.append(batch_loss / num_batches)
            torch.set_printoptions(threshold=10_000)

    return train_loss / max(num_batches, 1)
            
import time

NUM_EPOCHS = 20

start = time.time()
try:
    for epoch in range(NUM_EPOCHS):
        train_loop(epoch)
        # Advance the cosine schedule once per epoch; without this call the
        # scheduler never changes the optimizer's learning rate.
        # NOTE(review): with T_max=10 over 20 epochs the rate reaches zero at
        # epoch 10 and rises again — align T_max with NUM_EPOCHS if unintended.
        scheduler.step()
finally:
    # Record the wall-clock time even when training is interrupted by hand.
    end = time.time()
    elapsed_time = end - start



KeyboardInterrupt: 