In [2]:
"""
Train a baseline autoregressive model that uses a causal LM approach to generating
series of angles
"""
import sys
# caution: path[0] is reserved for script path (or '' in REPL)
sys.path.append('/bin/')
sys.path.append('bin')

import os
from pathlib import Path
import json
import argparse
from datetime import datetime
import logging
import multiprocessing
from typing import *

import numpy as np
import torch
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.strategies.ddp import DDPStrategy

from transformers import BertConfig

from foldingdiff import datasets, modelling, losses, plotting, utils
from foldingdiff import custom_metrics as cm

from train import ANGLES_DEFINITIONS, build_callbacks, record_args_and_metadata

### Meat of the Code, Datasets

In [3]:
def get_train_valid_test_sets(
    angles_definitions: ANGLES_DEFINITIONS = "canonical-full-angles",
    max_seq_len: int = 128,
    min_seq_len: int = 0,
    seq_trim_strategy: datasets.TRIM_STRATEGIES = "leftalign",
) -> Tuple[
    datasets.AutoregressiveCausalDataset,
    datasets.AutoregressiveCausalDataset,
    datasets.AutoregressiveCausalDataset,
]:
    """
    Get the train/valid/test splits using the autoregressive wrapper on the datsets
    """

    clean_dset_class = {
        "canonical": datasets.CathCanonicalAnglesDataset,
        "canonical-full-angles": datasets.CathCanonicalAnglesOnlyDataset,
        "canonical-minimal-angles": datasets.CathCanonicalMinimalAnglesDataset,
        "cart-coords": datasets.CathCanonicalCoordsDataset,
    }[angles_definitions]
    logging.info(f"Clean dataset class: {clean_dset_class}")

    splits = ["train", "validation", "test"]
    logging.info(f"Creating data splits: {splits}")
    clean_dsets = [
        clean_dset_class(
            split=s,
            pad=max_seq_len,
            min_length=min_seq_len,
            trim_strategy=seq_trim_strategy,
            zero_center=False if angles_definitions == "cart-coords" else True,
        )
        for s in splits
    ]

    # Set the training set mean to the validation set mean
    if len(clean_dsets) > 1 and clean_dsets[0].means is not None:
        logging.info(f"Updating valid/test mean offset to {clean_dsets[0].means}")
        for i in range(1, len(clean_dsets)):
            clean_dsets[i].means = clean_dsets[0].means

    causal_dsets = [
        datasets.AutoregressiveCausalDataset(
            d, dset_key="coords" if angles_definitions == "cart-coords" else "angles"
        )
        for d in clean_dsets
    ]
    for dsname, ds in zip(splits, causal_dsets):
        logging.info(f"{dsname}: {ds}")
    return causal_dsets

def train(
    ### Well, really only returns a dataset. 
    results_dir: str = "./results",
    angles_definitions: ANGLES_DEFINITIONS = "canonical-full-angles",
    max_seq_len: int = 128,
    min_seq_len: int = 40,
    trim_strategy: datasets.TRIM_STRATEGIES = "randomcrop",
    # Related to model architecture
    seq_len_encoding: modelling.TIME_ENCODING = "gaussian_fourier",  # Embeds the total sequence length
    num_hidden_layers: int = 12,  # Default 12
    hidden_size: int = 384,  # Default 768
    intermediate_size: int = 768,  # Default 3072
    num_heads: int = 12,  # Default 12
    position_embedding_type: Literal[
        "absolute", "relative_key_query", "relative_key"
    ] = "absolute",  # Default absolute
    dropout_p: float = 0.1,
    decoder: modelling.DECODER_HEAD = "mlp",
    # Related to training strategy
    gradient_clip: float = 1.0,
    batch_size: int = 256,
    lr: float = 5e-5,
    l2_norm: float = 0.01,
    loss: modelling.LOSS_KEYS = "smooth_l1",
    min_epochs: Optional[int] = None,
    max_epochs: int = 10000,  # 10000, set to 100 for debug
    early_stop_patience: int = 0,  # Set to 0 to disable early stopping
    lr_scheduler: modelling.LR_SCHEDULE = "LinearWarmup",  # Try LinearWarmup?
    use_swa: bool = False,
):
    """
    Train the model
    """
    func_args = locals()

    ft_key = "coords" if angles_definitions == "cart-coords" else "angles"
    dsets = get_train_valid_test_sets(
        angles_definitions=angles_definitions,
        max_seq_len=max_seq_len,
        min_seq_len=min_seq_len,
        seq_trim_strategy=trim_strategy,
    )
    assert len(dsets) == 3

    # Calculate effective batch size
    # https://pytorch-lightning.readthedocs.io/en/1.4.0/advanced/multi_gpu.html#batch-size
    # Under DDP, effective batch size is batch_size * num_gpus * num_nodes
    effective_batch_size = batch_size
    if torch.cuda.is_available():
        effective_batch_size = int(batch_size / torch.cuda.device_count())
    pl.utilities.rank_zero_info(
        f"Given batch size: {batch_size} --> effective batch size with {torch.cuda.device_count()} GPUs: {effective_batch_size}"
    )

    # Create data loaders
    train_dataloader, valid_dataloader, test_dataloader = [
        DataLoader(
            dataset=ds,
            batch_size=effective_batch_size,
            shuffle=i == 0,  # Shuffle only train loader
            num_workers=multiprocessing.cpu_count(),
            pin_memory=True,
        )
        for i, ds in enumerate(dsets)
    ]

    logging.info(f"Using loss function: {loss}")
    
    return train_dataloader, valid_dataloader, test_dataloader

In [4]:
train_dataloader, valid_dataloader, test_dataloader = train()

Given batch size: 256 --> effective batch size with 1 GPUs: 256


In [5]:
type(train_dataloader)

torch.utils.data.dataloader.DataLoader

In [6]:
for batch_idx, data in enumerate(train_dataloader):
    #print(data.keys())
    print(batch_idx)
    print(data['angles'].size())

0
torch.Size([150, 128, 6])


In [7]:
a = data['angles']
b = data['coords']
c = data['lengths']
d = data['attn_mask']
e = data['position_ids']
f = data['lengths']
g = data['causal_attn_mask']
h = data['causal_idx']

In [8]:
print(a.size())
print(b.size())
print(c.size())
print(d.size())
print(e.size())
print(f.size())
print(g.size())
print(h.size())

torch.Size([250, 128, 6])
torch.Size([250, 128, 3])
torch.Size([250])
torch.Size([250, 128])
torch.Size([250, 128])
torch.Size([250])
torch.Size([250, 128])
torch.Size([250])


In [9]:
print(a[3, :, :])
print(c)

tensor([[-2.4451e-01,  1.7895e-01, -4.4647e-01, -7.4402e-02, -1.5677e-02,
         -1.6079e-02],
        [ 4.4061e-01,  2.4549e+00,  1.7343e-01,  2.9227e-02,  4.9996e-04,
          1.8555e-02],
        [ 1.0225e-01,  2.2345e+00, -1.2736e-02,  2.6463e-02,  1.4782e-04,
         -2.0862e-03],
        [ 4.1460e-01, -6.1419e-01,  8.6740e-02,  4.7227e-02, -7.2865e-03,
         -6.5846e-03],
        [ 1.1934e-01, -9.1851e-01,  2.0334e-01,  5.4105e-02,  1.2675e-02,
         -1.2477e-02],
        [-5.0638e-01, -2.2511e-01,  1.4172e-01,  7.1030e-02,  3.1419e-03,
         -3.2218e-02],
        [-6.0244e-01, -1.1064e-01, -1.6993e-02,  1.4916e-02, -6.6385e-03,
          2.2501e-02],
        [ 4.3340e-01, -9.0748e-01,  1.3629e-01,  2.9383e-02,  1.3731e-03,
          2.1537e-02],
        [ 4.1791e-01, -9.3883e-01, -2.8774e-02, -1.5001e-03,  1.8919e-03,
          1.1503e-02],
        [ 4.1861e-01, -7.9223e-01,  7.5116e-03,  3.3424e-03, -1.6312e-02,
          2.5062e-02],
        [ 4.4650e-01, -9.9393e

In [10]:
print(c)

tensor([ 88, 128, 128, 128, 128, 128, 128,  80, 113,  83, 128, 128, 124, 128,
        128, 128, 115,  85, 128, 128, 126,  81,  66, 121, 121,  89,  88, 128,
        128,  65, 128, 128, 128, 128, 128,  50,  82, 128, 116, 128, 128, 128,
        128,  60, 128, 128,  86,  71, 128,  99, 128, 128, 128, 128, 128, 128,
        128, 120,  99, 103,  97, 128,  87, 128,  62, 128, 110, 128, 128, 127,
        106,  87, 125,  74,  86, 104, 128, 114,  96, 128,  59, 128,  78,  76,
         67, 121,  53, 128,  84, 128,  65, 116,  76,  94, 128, 105,  92,  91,
        106,  86,  65, 124, 128,  77, 128,  67,  61, 102,  70, 119, 128, 128,
        128, 128, 107, 104, 100,  80, 128,  81, 127, 110,  46, 128, 128, 123,
        116, 128, 128, 128, 108, 128, 128, 127,  83, 128, 128, 128, 128, 128,
         69, 106, 128,  56, 120,  95,  94, 106, 128, 128, 128, 110, 128, 128,
        128, 128,  89, 128, 128, 128,  67, 128, 128, 127, 104, 128, 128, 113,
        125,  90,  90, 128, 102, 128, 128, 128, 128,  86, 128, 1

In [11]:
"""
Train a baseline autoregressive model that uses a causal LM approach to generating
series of angles
"""

def build_parser() -> argparse.ArgumentParser:
    """
    Build CLI parser
    """
    parser = argparse.ArgumentParser(
        usage=__doc__,
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "config", nargs="?", default="", type=str, help="json of params"
    )
    parser.add_argument(
        "-o",
        "--outdir",
        type=str,
        default=os.path.join(os.getcwd(), "results"),
        help="Directory to write model training outputs",
    )
    return parser


def main():
    """
    Run the training script based on params in the given json file
    """
    parser = build_parser()
    args = parser.parse_args()

    # Load in parameters and run training loop
    config_args = {}
    if args.config:
        with open(args.config, "r") as f:
            config_args = json.load(f)

    config_args = utils.update_dict_nonnull(config_args, {"results_dir": args.outdir})

    train(**config_args)
    
'''
if __name__ == "__main__":
    curr_time = datetime.now().strftime("%y%m%d_%H%M%S")
    logging.basicConfig(
        level=logging.INFO,
        handlers=[
            logging.FileHandler(f"training_{curr_time}.log"),
            logging.StreamHandler(),
        ],
    )

    main()
'''

'\nif __name__ == "__main__":\n    curr_time = datetime.now().strftime("%y%m%d_%H%M%S")\n    logging.basicConfig(\n        level=logging.INFO,\n        handlers=[\n            logging.FileHandler(f"training_{curr_time}.log"),\n            logging.StreamHandler(),\n        ],\n    )\n\n    main()\n'