Before doing anything below, create the conda environment with `conda env create -f tvae.yml`. Then, with that environment activated, register it as a Jupyter kernel with `python -m ipykernel install --user --name=tvae`. When launching the notebook, select the `tvae` kernel.

# Model, training, and device configurations

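```python
import os

import torch
from torch.utils.data import DataLoader

from train import *  # project module; provides TVAE, load_dataset, checkpoint_handler and train

# Example model arguments
model_config = {
    "name": "tvae",
    "z_dim": 32,
    "h_dim": 256,
    "rnn_dim": 256,
    "num_layers": 1,
    "in_state_dim": 28,
    "in_action_dim": 28,
    "out_state_dim": 14,
    "out_action_dim": 14,
}

# Example training arguments
train_config = {
    "batch_size": 128,
    "num_epochs_til_val": 10,  # run validation (and save a checkpoint) every 10 epochs
    "learning_rate": 0.0002,
    "num_epochs": 300,
    "clip": 10,
    "comet_ml_key": "",  # your Comet ML API key; with an empty key the experiment is not logged
    "project_name": "test_jupyter_notebook",
}

# Device configuration: 'cuda:1' selects the second GPU; fall back to the CPU if CUDA is unavailable
device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
```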
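Rather than pasting the Comet ML API key into the notebook, you can read it from an environment variable. A minimal sketch, assuming the key is exported as `COMET_API_KEY` (the variable name the Comet SDK itself reads, though any name works here); `with_comet_key` is a hypothetical helper, not part of the project code.

```python
import os


def with_comet_key(config: dict, env_var: str = "COMET_API_KEY") -> dict:
    """Return a copy of config whose 'comet_ml_key' comes from an environment variable.

    If the variable is unset, the key already in config (possibly empty) is kept.
    """
    updated = dict(config)  # shallow copy so the original dict is left untouched
    updated["comet_ml_key"] = os.environ.get(env_var, config.get("comet_ml_key", ""))
    return updated
```

Usage: `train_config = with_comet_key(train_config)` right after defining `train_config`.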

# Dataset configurations and loading
Before completing this step, save the input and output data arrays in `.npz` format using `np.savez()`. The input array should have shape [`num_samples`, `sequence_length`, `in_state_dim`] and the output array shape [`num_samples`, `sequence_length`, `out_state_dim`], with `in_state_dim` and `out_state_dim` matching the values in `model_config`.
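A minimal sketch of this step. The helper name `save_tvae_arrays` is hypothetical, and the file names `data_in.npz` and `data_out.npz` simply match the paths used in the next cell. Which key `load_dataset` expects inside each `.npz` file is not documented here; the sketch passes each array positionally, so NumPy stores it under its default name `arr_0`. Adjust this if the loader expects a different key.

```python
import os

import numpy as np


def save_tvae_arrays(data_in: np.ndarray, data_out: np.ndarray, out_dir: str):
    """Save inputs and outputs as data_in.npz and data_out.npz inside out_dir.

    data_in:  [num_samples, sequence_length, in_state_dim]
    data_out: [num_samples, sequence_length, out_state_dim]
    """
    if data_in.ndim != 3 or data_out.ndim != 3:
        raise ValueError("both arrays must be 3-D: [num_samples, sequence_length, dim]")
    if data_in.shape[:2] != data_out.shape[:2]:
        raise ValueError("inputs and outputs must share num_samples and sequence_length")
    os.makedirs(out_dir, exist_ok=True)
    in_path = os.path.join(out_dir, "data_in.npz")
    out_path = os.path.join(out_dir, "data_out.npz")
    # Positional arguments are stored under NumPy's default key, 'arr_0'
    np.savez(in_path, data_in)
    np.savez(out_path, data_out)
    return in_path, out_path
```

For example, `save_tvae_arrays(data_in, data_out, root_data_dir)` writes the two files the next cell reads.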

```python
# Set paths to the .npz arrays holding your inputs and outputs.
# Replace root_data_dir with the directory that contains your own arrays.
root_data_dir = '/media/storage/andrew/data/autism_dataset/10_10/both_in_forecast_res/'
data_in_path = os.path.join(root_data_dir, 'data_in.npz')
data_out_path = os.path.join(root_data_dir, 'data_out.npz')

# An optional 'val_prop' key (not set here) controls the proportion of the
# validation set relative to all data.
data_config = {
    'name': 'mouse_v1',
    'in_file': data_in_path,
    'out_file': data_out_path,
}

# Build the dataset and wrap it in a shuffling DataLoader
dataset = load_dataset(data_config)
data_loader = DataLoader(
    dataset,
    batch_size=train_config['batch_size'],
    shuffle=True,
)
```
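Before training, it can save time to confirm that the arrays on disk match `model_config`. A minimal sketch, assuming each `.npz` file holds a single array (as written by `np.savez` with one positional argument); `check_array_dims` is a hypothetical helper, not part of the project code.

```python
import numpy as np


def check_array_dims(in_path: str, out_path: str, model_config: dict):
    """Raise ValueError if the saved arrays do not match the model's state dimensions."""

    def load_single(path):
        with np.load(path) as arrays:
            # Assumes one array per file; take it under whatever key it was saved with
            return arrays[arrays.files[0]]

    data_in = load_single(in_path)
    data_out = load_single(out_path)
    if data_in.shape[-1] != model_config["in_state_dim"]:
        raise ValueError(f"input feature dim {data_in.shape[-1]} != in_state_dim {model_config['in_state_dim']}")
    if data_out.shape[-1] != model_config["out_state_dim"]:
        raise ValueError(f"output feature dim {data_out.shape[-1]} != out_state_dim {model_config['out_state_dim']}")
    if data_in.shape[:2] != data_out.shape[:2]:
        raise ValueError("inputs and outputs must share num_samples and sequence_length")
    return data_in.shape, data_out.shape
```

Usage: `check_array_dims(data_in_path, data_out_path, model_config)`. This loads both arrays fully into memory, which is fine for a one-off check.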

```python
# Find the most recent checkpointed model in your project directory and store it in train_config
train_config = checkpoint_handler(train_config)
```

# Model instantiation and preparation

```python
model = TVAE(model_config)
model = model.to(device)
model.prepare_stage(train_config)
```

# Run training
If you set `comet_ml_key` correctly in the first step, logging information will be available at the Comet ML URL printed in the cell output. Checkpoints are saved to `./checkpoints/<YOUR PROJECT NAME>/epoch #` every time validation runs, that is, every `train_config['num_epochs_til_val']` epochs.

```python
train(model, data_loader, train_config, device)
```


# Loading model checkpoints
There is currently no code for selecting the checkpoint with the lowest validation loss; that is on the to-do list. For now, look at the Comet ML graphs and pick the checkpoint with the lowest validation NLL. Once you have chosen one, load it into the model as follows:

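```python
checkpoint_path = "<YOUR CHECKPOINT PATH HERE>"

# map_location lets a checkpoint saved on one GPU load onto the current device
model.load_state_dict(torch.load(checkpoint_path, map_location=device))
```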
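Until checkpoint selection is automated, picking a checkpoint by hand amounts to an argmin over validation NLL. A minimal sketch, assuming you have copied the per-epoch validation NLL values from the Comet ML graphs into a dictionary yourself; `best_epoch` is a hypothetical helper and the numbers in the usage example are placeholders, not real results.

```python
def best_epoch(val_nll_by_epoch: dict) -> int:
    """Return the epoch whose validation NLL is lowest (earliest epoch wins ties)."""
    if not val_nll_by_epoch:
        raise ValueError("no validation results given")
    return min(val_nll_by_epoch, key=val_nll_by_epoch.get)
```

For example, `best_epoch({10: 3.0, 20: 1.0, 30: 2.0})` returns `20`, so you would load the checkpoint saved at epoch 20.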

# Generating Reconstructions and Embeddings
This step assumes you have already chosen the model checkpoint to use for generating embeddings and reconstructions. `reconstruct` draws `num_reconstructions` samples from the posterior distribution predicted by the encoder, then keeps the sample with the lowest negative log-likelihood as the reconstruction of the corresponding original. Each entry in `embeddings` is the mean of the posterior distribution used to draw those samples.
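The best-of-N selection can be illustrated in NumPy. This is a sketch of the idea, not the code inside `reconstruct`: it assumes each candidate reconstruction is the mean of a Gaussian with a fixed, shared standard deviation `sigma`, under which the lowest-NLL candidate is simply the one with the smallest squared error. The model's actual likelihood may differ, and `gaussian_nll` and `select_best_sample` are hypothetical names.

```python
import numpy as np


def gaussian_nll(sample: np.ndarray, original: np.ndarray, sigma: float = 1.0) -> float:
    """NLL of `original` under an isotropic Gaussian N(sample, sigma^2 I), summed over elements."""
    n = original.size
    sq_err = float(((original - sample) ** 2).sum())
    return 0.5 * sq_err / sigma**2 + n * np.log(sigma) + 0.5 * n * np.log(2 * np.pi)


def select_best_sample(samples: np.ndarray, original: np.ndarray, sigma: float = 1.0):
    """Return the candidate with the lowest NLL and its index.

    samples:  [num_reconstructions, sequence_length, out_state_dim]
    original: [sequence_length, out_state_dim]
    """
    nlls = np.array([gaussian_nll(s, original, sigma) for s in samples])
    best = int(np.argmin(nlls))
    return samples[best], best
```

With `sigma` fixed, the `log(sigma)` and `log(2*pi)` terms are identical for every candidate, so only the squared error decides which sample is kept.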

```python
from reconstruct import reconstruct

num_reconstructions = 10
reconstructions, originals, embeddings = reconstruct(model, data_loader, device, num_reconstructions)
```

# Plotting reconstructions against originals
After running the code below, open the GIF at the path you specified to check whether the reconstruction matches the original.

```python
import os
from util.plotting.plot_seq import plot_reconstruction

idx = 3  # arbitrary test reconstruction
os.makedirs('./gifs', exist_ok=True)  # good idea to store your GIFs in their own folder; create it if missing
plot_reconstruction(originals[idx], reconstructions[idx], path='./gifs/test.gif')
```