In [None]:
%%bash
# Export models.ipynb as a plain script; presumably this is what lets the
# `from models import Encoder, Decoder` import work -- confirm.
jupyter nbconvert --to script models.ipynb

In [None]:
import torch
from torch import nn
from torch.autograd import Variable
import pickle

import losswise

from prettytable import PrettyTable
from tqdm import tqdm
import numpy as np
import os
import sys
import random

from datasets import BurstDataset, ShuffledBatchSequentialSampler, FakeBurstDataset
from prep_dataset import BurstDatasetStandardizer
from models import Encoder, Decoder
from train_functions import *
from eval_functions import plot_autoencoding, autoencode

In [None]:
# TODO: implement CUDA support end-to-end (the encoder and decoder are already moved to the GPU when one is available).

### Initialize and prep datasets

In [None]:
# Seed every random number generator imported by this notebook so that a
# Restart & Run All gives repeatable results.
torch.manual_seed(1)
np.random.seed(1)
# The original seeded NumPy twice and never seeded Python's own `random`
# module, which is imported above and would otherwise stay non-deterministic.
random.seed(1)

In [None]:
# --- Dataset filtering -------------------------------------------------------
# Bounds forwarded to BurstDataset; a None bound presumably means "no limit"
# (TODO confirm against BurstDataset).
MIN_BURST_SECS, MAX_BURST_SECS = 0.1, 3
MIN_EPISODE_MINS, MAX_EPISODE_MINS = 10, None

# --- Padding / batching ------------------------------------------------------
# NOTE(review): 3*200 = 600 samples. An earlier comment here said this
# "corresponds to 5 sec long burst"; with the factor of 3 (equal to
# MAX_BURST_SECS) it presumably means 3 s at 200 samples/s -- confirm the
# sampling rate.
PAD_LENGTH = 3 * 200
BATCH_BY_LEN = True

# --- Data source -------------------------------------------------------------
DATA_DIR = '/home/alice-eeg/NFS/script_output/describe_bs/'
MAX_NUM_PATIENTS = 40

# Build the full dataset, then load it from disk with the chosen padding.
dataset = BurstDataset(
    min_burst_secs=MIN_BURST_SECS,
    max_burst_secs=MAX_BURST_SECS,
    min_episode_mins=MIN_EPISODE_MINS,
    max_episode_mins=MAX_EPISODE_MINS,
    sort_len=False,
)
dataset.init_dataset(DATA_DIR, PAD_LENGTH, max_num_patients=MAX_NUM_PATIENTS)

# Only the train and dev fractions are given explicitly; the test split is
# whatever `dataset.split` returns as its third element.
train_split, dev_split = 0.6, 0.2
train_dataset, dev_dataset, test_dataset = dataset.split(
    train_split,
    dev_split,
    split_sort_len=BATCH_BY_LEN,
)

# Fit the standardizer on the training split only, then apply the same fitted
# transform to dev and test. Return values are discarded, so the datasets are
# presumably modified in place (TODO confirm).
standardizer = BurstDatasetStandardizer()
standardizer.fit_transform(train_dataset)
standardizer.transform(dev_dataset)
standardizer.transform(test_dataset)

In [None]:
# Sanity check: display the size of the training split.
len(train_dataset)

### Initialize model

In [None]:
# Architecture hyper-parameters shared by the encoder and the decoder.
INPUT_SIZE = 1 # This CANNOT be changed!
HIDDEN_SIZE = 100
NUM_LAYERS = 1
BIDIRECTIONAL = True
EXTRA_INPUT_DIM = False

# The decoder is told whether the encoder is bidirectional and is built with
# the same hidden size and layer count as the encoder.
encoder = Encoder(
    INPUT_SIZE,
    HIDDEN_SIZE,
    bidirectional=BIDIRECTIONAL,
    num_layers=NUM_LAYERS,
)
decoder = Decoder(
    HIDDEN_SIZE,
    INPUT_SIZE,
    extra_input_dim=EXTRA_INPUT_DIM,
    encoder_bidirectional=BIDIRECTIONAL,
    num_layers=NUM_LAYERS,
)

# Move both networks to the GPU when CUDA is available; otherwise stay on CPU.
if torch.cuda.is_available():
    encoder, decoder = encoder.cuda(), decoder.cuda()

### Define training params

In [None]:
# Optimisation settings, all forwarded to train_model.
# NUM_EPOCHS: normally use 50, but can stop early at 20.
BATCH_SIZE, NUM_EPOCHS = 30, 100
LR, WEIGHT_DECAY = 1e-3, 1e-4

# Sequence-training options, also forwarded to train_model.
TEACHER_FORCING_SLOPE, TRAIN_REVERSED = 0.001, True

# Output directory handed to train_model. While it is None the
# `if SAVE_DIR is not None` check skips writing params_dict.pkl.
SAVE_DIR = None

# Run training

In [None]:
# batch_sampler = ShuffledBatchSequentialSampler(dataset, batch_size=40, drop_last=False)
# data_loader = torch.utils.data.DataLoader(dataset, batch_sampler=batch_sampler)
# for batch in data_loader:
#     break

In [None]:
# Toggle for losswise experiment tracking: when False, no losswise session is
# created and `losswise_graph` is set to None.
USE_LOSSWISE = True

In [None]:
# Snapshot of every setting used for this run. It is handed to the losswise
# session and to train_model, and pickled next to the model outputs when a
# save directory is configured.
params_dict = {
    # dataset filtering
    'min_burst_secs':MIN_BURST_SECS, 'max_burst_secs':MAX_BURST_SECS, 
    'min_episode_mins':MIN_EPISODE_MINS, 'max_episode_mins':MAX_EPISODE_MINS, 
    # dataset size
    'max num patients':MAX_NUM_PATIENTS, 'len(train_data)':len(train_dataset), 
    'train split':train_split, 'dev split':dev_split, 
    # dataset padding
    'pad length': PAD_LENGTH, 
    # model params
    'hidden size': HIDDEN_SIZE, 'bidirectional':BIDIRECTIONAL, 'num layers':NUM_LAYERS, 
    'extra input dim':EXTRA_INPUT_DIM, 
    # training params
    'batch by length': BATCH_BY_LEN, 'batch size':BATCH_SIZE, 'num epochs': NUM_EPOCHS, 
    'learning rate': LR, 'weight decay': WEIGHT_DECAY, 
    'teacher forcing slope': TEACHER_FORCING_SLOPE, 'train reversed':TRAIN_REVERSED, 
    'save dir': SAVE_DIR}
if SAVE_DIR is not None:
    # pickle writes bytes, so the file must be opened in binary mode ("w"
    # raises TypeError on Python 3). The `with` block also guarantees the
    # handle is flushed and closed, which the bare open() call did not.
    params_path = os.path.join(SAVE_DIR, "params_dict.pkl")
    with open(params_path, "wb") as params_file:
        pickle.dump(params_dict, params_file)

In [None]:
# Set up losswise experiment tracking when enabled. Both `session` and
# `losswise_graph` are always bound afterwards (to None when tracking is off),
# so later cells never hit an undefined name or a stale session left over
# from a previous run of this cell.
if USE_LOSSWISE:
    # NOTE(review): a credential committed to the notebook should be rotated
    # and supplied through the LOSSWISE_API_KEY environment variable. The
    # literal is kept only as a backward-compatible fallback.
    losswise.set_api_key(os.environ.get('LOSSWISE_API_KEY', 'W2TAMB3SZ')) # api_key for "coma-eeg"
    # The tag used to be the fixed text 'Run with num layers 3', which did not
    # match NUM_LAYERS; derive it from the constant so the label stays accurate.
    session = losswise.Session(tag='Run with num layers {}'.format(NUM_LAYERS), max_iter=NUM_EPOCHS,
                               params=params_dict)
    losswise_graph = session.graph('loss', kind='min')
else:
    session = None
    losswise_graph = None

In [None]:
# Train the autoencoder; loss values go to `losswise_graph` (None when
# losswise tracking is disabled).
train_model(train_dataset, dev_dataset, test_dataset, encoder, decoder, SAVE_DIR,
            num_epochs=NUM_EPOCHS, 
            batch_size=BATCH_SIZE, lr=LR, weight_decay=WEIGHT_DECAY, 
            teacher_forcing_slope=TEACHER_FORCING_SLOPE, train_reversed=TRAIN_REVERSED, batch_by_len=BATCH_BY_LEN, 
            losswise_graph=losswise_graph, params_dict=params_dict)
# A losswise session only exists when tracking is enabled. Calling
# session.done() unconditionally raised NameError on a fresh kernel with
# USE_LOSSWISE = False.
if USE_LOSSWISE:
    session.done()

## Plot the autoencoding

In [None]:
# Pick a single training example to inspect (index 8; no reason for this
# particular index is visible in the notebook).
sample = train_dataset[8]
# The plotting call below is currently disabled; presumably it plots the
# autoencoder's reconstruction of `sample` and returns its MSE -- confirm
# against eval_functions.plot_autoencoding before re-enabling.
#mse = plot_autoencoding(sample, encoder, decoder, toss_encoder_output=False, reverse=TRAIN_REVERSED)