In [19]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [20]:
import logging
import os
import re
from argparse import Namespace
from pathlib import Path
from typing import Dict, Type, Union

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import yaml
from torch.utils.data import DataLoader

from double_jig_gen.data import (
    ABCDataset,
    get_folkrnn_dataloaders,
    pad_batch,
    get_oneills_dataloaders
)
from double_jig_gen.tokenizers import Tokenizer
from double_jig_gen.models import SimpleRNN, Transformer
from double_jig_gen.utils import get_model_from_checkpoint

# Install the default root handler, then make this notebook's own logger
# emit everything from DEBUG upwards.
logging.basicConfig()
LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(logging.DEBUG)

In [21]:
def _get_most_recent_path(paths):
    """Returns the most recently created path from a list of paths.

    Args:
        paths: a list of paths to check.

    Returns:
        the most recently created path.
    """
    return max(paths, key=os.path.getctime)

In [22]:
# Lightning `version_<id>` runs to compare: ids 23 through 30 inclusive.
# (Remove an id from this list to leave a run out of the comparison.)
expt_ids = [*range(23, 31)]

In [23]:
def _parse_checkpoint_epoch(ckpt_name):
    """Extracts the epoch number from a checkpoint file name.

    Args:
        ckpt_name: file name of the checkpoint.

    Returns:
        the epoch as an int: the digits following "epoch=" if that marker is
        present, otherwise the first run of digits in the name.

    Raises:
        ValueError: if the name contains no digits.
    """
    # Read exactly one number. Joining every digit run in the name would glue
    # suffixes such as "-step=2000" or "_v0" onto the epoch.
    match = re.search(r"epoch=(\d+)", ckpt_name) or re.search(r"(\d+)", ckpt_name)
    if match is None:
        raise ValueError(f"No epoch number found in checkpoint name {ckpt_name!r}")
    return int(match.group(1))


def get_args(expt_dirpath):
    """Loads the saved arguments of an experiment plus its latest checkpoint.

    Args:
        expt_dirpath: experiment directory containing `experiment_args.yaml`
            and a `checkpoints` subdirectory holding `.ckpt` files.

    Returns:
        a Namespace with every entry of the yaml file, plus
        `latest_checkpoint` (Path of the most recent `.ckpt` file) and
        `checkpoint_epoch` (int parsed from that file's name).

    Raises:
        ValueError: if the checkpoint directory holds no `.ckpt` file, or the
            latest checkpoint's name contains no epoch number.
    """
    checkpoint_dirpath = Path(expt_dirpath, "checkpoints")
    ckpt_paths = [
        path for path in checkpoint_dirpath.iterdir() if path.name.endswith(".ckpt")
    ]
    if not ckpt_paths:
        raise ValueError(f"No .ckpt files found in {checkpoint_dirpath}")
    latest_ckpt_path = _get_most_recent_path(ckpt_paths)

    experiment_args_path = Path(expt_dirpath, "experiment_args.yaml")
    # The yaml file has a lowercase trainer in the tag
    # python/name:pytorch_lightning.trainer.trainer._gpus_arg_default
    # so loading fails with SafeLoader and BaseLoader is used instead.
    # BaseLoader builds no Python objects from tags and leaves every scalar as
    # a string, so numeric / boolean arguments must be cast by the caller.
    with experiment_args_path.open("r") as fh:
        # An empty yaml document loads as None; treat it as "no arguments".
        args_dict = yaml.load(fh, Loader=yaml.BaseLoader) or {}
    args_dict["latest_checkpoint"] = latest_ckpt_path
    args_dict["checkpoint_epoch"] = _parse_checkpoint_epoch(latest_ckpt_path.name)
    args = Namespace()
    vars(args).update(args_dict)
    return args

In [24]:
# Root of the Lightning logs on the machine's fast scratch disk.
scratch_path = "/disk/scratch_fast"
expt_dirpath = Path(scratch_path, "s0816700", "logs", "lightning_logs")
# One `version_<id>` directory per experiment id.
expt_dirs = [expt_dirpath / f"version_{expt_id}" for expt_id in expt_ids]
# Argument dict of every experiment, keyed by its directory path as a string.
args = {str(expt_dir): vars(get_args(expt_dir)) for expt_dir in expt_dirs}

In [25]:
# One row per experiment directory (indexed by its path string), one column
# per entry of that experiment's argument dict.
expt_configs = pd.DataFrame.from_dict(args, orient='index')

In [29]:
# Columns shown in the experiment overview table, in display order.
cols = [
    "dataset", "checkpoint_epoch", "early_stopping_patience",
    "val_prop", "val_shuffle", "batch_size",
    "model_load_from_checkpoint", "latest_checkpoint",
    "seed", "model",
]

In [27]:
# Load the latest checkpoint of every experiment as a SimpleRNN and store the
# model objects in a new "model" column (expt_configs is modified in place).
# NOTE(review): every checkpoint is loaded with SimpleRNN.load_from_checkpoint;
# confirm that none of these runs used a different model class.
expt_configs["model"] = expt_configs.latest_checkpoint.apply(
    lambda ckpt_path: SimpleRNN.load_from_checkpoint(checkpoint_path=str(ckpt_path))
)

In [30]:
# Overview table: the columns selected in `cols`, one row per experiment.
expt_configs[cols]

Unnamed: 0,dataset,checkpoint_epoch,early_stopping_patience,val_prop,val_shuffle,batch_size,model_load_from_checkpoint,latest_checkpoint,seed,model
/disk/scratch_fast/s0816700/logs/lightning_logs/version_23,oneills,108,100,0.05,False,16,,/disk/scratch_fast/s0816700/logs/lightning_log...,101505917,"SimpleRNN(\n (dropout_layer): Dropout(p=0.5, ..."
/disk/scratch_fast/s0816700/logs/lightning_logs/version_24,oneills,100,100,0.05,False,16,,/disk/scratch_fast/s0816700/logs/lightning_log...,3583301105,"SimpleRNN(\n (dropout_layer): Dropout(p=0.5, ..."
/disk/scratch_fast/s0816700/logs/lightning_logs/version_25,oneills,393,100,1.0,True,16,,/disk/scratch_fast/s0816700/logs/lightning_log...,645574720,"SimpleRNN(\n (dropout_layer): Dropout(p=0.5, ..."
/disk/scratch_fast/s0816700/logs/lightning_logs/version_26,oneills,408,100,1.0,True,16,,/disk/scratch_fast/s0816700/logs/lightning_log...,3513404735,"SimpleRNN(\n (dropout_layer): Dropout(p=0.5, ..."
/disk/scratch_fast/s0816700/logs/lightning_logs/version_27,oneills,91,100,0.2,False,16,/disk/scratch_fast/s0816700/logs/lightning_log...,/disk/scratch_fast/s0816700/logs/lightning_log...,2729551693,"SimpleRNN(\n (dropout_layer): Dropout(p=0.5, ..."
/disk/scratch_fast/s0816700/logs/lightning_logs/version_28,oneills,87,100,0.2,False,16,/disk/scratch_fast/s0816700/logs/lightning_log...,/disk/scratch_fast/s0816700/logs/lightning_log...,515078650,"SimpleRNN(\n (dropout_layer): Dropout(p=0.5, ..."
/disk/scratch_fast/s0816700/logs/lightning_logs/version_29,oneills,275,100,1.0,True,16,/disk/scratch_fast/s0816700/logs/lightning_log...,/disk/scratch_fast/s0816700/logs/lightning_log...,88681166,"SimpleRNN(\n (dropout_layer): Dropout(p=0.5, ..."
/disk/scratch_fast/s0816700/logs/lightning_logs/version_30,oneills,86,100,1.0,True,16,/disk/scratch_fast/s0816700/logs/lightning_log...,/disk/scratch_fast/s0816700/logs/lightning_log...,845257221,"SimpleRNN(\n (dropout_layer): Dropout(p=0.5, ..."


In [12]:
# Build the three dataloaders returned by get_oneills_dataloaders; only `tst`
# is used by the evaluation cells in this notebook.
# NOTE(review): the first path looks like the O'Neill's ABC tune file and the
# second like the folk-rnn vocabulary file -- confirm against the function's
# signature. batch_size=16 matches the batch_size column of the experiments.
trn, vld, tst = get_oneills_dataloaders(
    "/disk/scratch_fast/s0816700/data/oneills/oneills_reformat.abc",
    "/disk/scratch_fast/s0816700/data/folk-rnn/data_v3_vocabulary.txt",
    batch_size=16,
    num_workers=4,
    pin_memory=True,
)

In [13]:
# Lightning trainer restricted to GPU index 7 (the captured log reports
# CUDA_VISIBLE_DEVICES: [7]).
# NOTE(review): no active code in the visible cells uses `lightning_trainer`;
# it only appears in a commented-out `trainer.test(...)` attempt -- confirm
# whether this cell is still needed.
lightning_trainer = pl.Trainer(
    gpus='7,',
)

GPU available: True, used: True
INFO:lightning:GPU available: True, used: True
TPU available: False, using: 0 TPU cores
INFO:lightning:TPU available: False, using: 0 TPU cores
CUDA_VISIBLE_DEVICES: [7]
INFO:lightning:CUDA_VISIBLE_DEVICES: [7]


In [14]:
# Freshly constructed 3-layer LSTM with no checkpoint loaded, i.e. untrained.
# NOTE(review): ntoken=106 presumably equals the tokenizer's vocabulary size
# and embedding_padding_idx=0 the padding token id -- confirm against the
# vocabulary file.
model = SimpleRNN(
    rnn_type="LSTM",
    ntoken=106,
    ninp=256,
    nhid=512,
    nlayers=3,
    model_batch_size=16,
    dropout=.6,
    embedding_padding_idx=0,
)

In [15]:
def get_avg_loss_per_token(model, dataloader, device):
    """Computes the mean of `model.loss` over the batches of a dataloader.

    The model is moved to `device` and switched to eval mode; both changes
    persist after the call. Gradients are disabled while evaluating.

    NOTE(review): the result is the unweighted mean of the per-batch losses.
    It is a true per-token average only if `model.loss` itself averages over
    tokens and every batch holds the same number of tokens -- confirm.

    Args:
        model: module callable as `model(padded_batch, seq_lens)` that exposes
            `loss(outputs, padded_batch)` returning a scalar tensor.
        dataloader: iterable yielding `(padded_batch, seq_lens)` pairs.
        device: device to evaluate on, e.g. "cuda" or "cpu".

    Returns:
        the mean batch loss as a float.

    Raises:
        ValueError: if `dataloader` yields no batches.
    """
    model = model.to(device)
    model.eval()
    loss_total = 0.0
    num_batches = 0
    with torch.no_grad():
        for padded_batch, seq_lens in dataloader:
            padded_batch = padded_batch.to(device)
            outputs = model(padded_batch, seq_lens)
            loss_total += model.loss(outputs, padded_batch).item()
            num_batches += 1
    if num_batches == 0:
        raise ValueError("dataloader yielded no batches; cannot average the loss.")
    # Normalise by the batches actually evaluated from `dataloader`, not by the
    # length of the notebook-global `tst` loader.
    return loss_total / num_batches

In [16]:
# Baseline: mean test loss of the untrained model constructed above, for
# comparison with the trained checkpoints evaluated later.
get_avg_loss_per_token(model, tst, 'cuda')

4.6627185344696045

In [36]:
# Evaluate every loaded checkpoint on the test dataloader, keyed by log dir.
# The loss is computed with the manual loop in get_avg_loss_per_token because
# an earlier attempt with `lightning_trainer.test(model, test_dataloaders=tst,
# ckpt_path=...)` produced CUDA errors.
test_res = {}
for log_dir, model in expt_configs["model"].items():
    test_res[log_dir] = get_avg_loss_per_token(model, tst, 'cuda')

In [40]:
# test_res is keyed by the same log-dir strings that index expt_configs, so
# the Series aligns row-for-row when assigned as a new column.
expt_configs["test_loss"] = pd.Series(test_res)

In [43]:
# Columns for the results table: test loss first, then the settings that
# differ between runs. (Rebinds `cols` from the overview cell.)
cols = [
    "test_loss", "checkpoint_epoch",
    "val_prop", "val_shuffle", "batch_size",
    "model_load_from_checkpoint", "seed",
]
expt_configs[cols]

Unnamed: 0,test_loss,checkpoint_epoch,val_prop,val_shuffle,batch_size,model_load_from_checkpoint,seed
/disk/scratch_fast/s0816700/logs/lightning_logs/version_23,0.537336,108,0.05,False,16,,101505917
/disk/scratch_fast/s0816700/logs/lightning_logs/version_24,0.688422,100,0.05,False,16,,3583301105
/disk/scratch_fast/s0816700/logs/lightning_logs/version_25,0.630157,393,1.0,True,16,,645574720
/disk/scratch_fast/s0816700/logs/lightning_logs/version_26,0.509551,408,1.0,True,16,,3513404735
/disk/scratch_fast/s0816700/logs/lightning_logs/version_27,0.503171,91,0.2,False,16,/disk/scratch_fast/s0816700/logs/lightning_log...,2729551693
/disk/scratch_fast/s0816700/logs/lightning_logs/version_28,0.42524,87,0.2,False,16,/disk/scratch_fast/s0816700/logs/lightning_log...,515078650
/disk/scratch_fast/s0816700/logs/lightning_logs/version_29,0.48026,275,1.0,True,16,/disk/scratch_fast/s0816700/logs/lightning_log...,88681166
/disk/scratch_fast/s0816700/logs/lightning_logs/version_30,0.293839,86,1.0,True,16,/disk/scratch_fast/s0816700/logs/lightning_log...,845257221
