In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
import numpy as np
import torch
import torch.nn as nn

from utils.util import print_size, training_loss, calc_diffusion_hyperparams
from utils.util import get_mask_mnr, get_mask_bm, get_mask_rm

from imputers.SSSDS4Imputer import SSSDS4Imputer


CUDA extension for cauchy multiplication not found. Install by going to extensions/cauchy/ and running `python setup.py install`. This should speed up end-to-end training by 10-50%
Falling back on slow Cauchy kernel. Install at least one of pykeops or the CUDA extension for efficiency.


In [4]:
# Experiment configuration for training the SSSDS4 imputer on the MuJoCo data.
# Every tunable value lives in this one dictionary; the names bound at the
# bottom of the cell are what the rest of the notebook reads.

# Directory holding the MuJoCo .npy files. Set the MUJOCO_DATA_DIR environment
# variable to point somewhere else without editing the notebook; the default is
# the location this notebook was originally run against.
MUJOCO_DATA_DIR = os.environ.get("MUJOCO_DATA_DIR", "/home/hanyuji/data/mujoco_dataset")

config = {
    # Passed as keyword arguments to calc_diffusion_hyperparams below.
    "diffusion_config": {
        "T": 200,
        "beta_0": 0.0001,
        "beta_T": 0.02
    },
    # Passed as keyword arguments to the SSSDS4Imputer constructor.
    "wavenet_config": {
        "in_channels": 14,
        "out_channels": 14,
        "num_res_layers": 36,
        "res_channels": 256,
        "skip_channels": 256,
        "diffusion_step_embed_dim_in": 128,
        "diffusion_step_embed_dim_mid": 512,
        "diffusion_step_embed_dim_out": 512,
        "s4_lmax": 100,
        "s4_d_state": 64,
        "s4_dropout": 0.0,
        "s4_bidirectional": 1,
        "s4_layernorm": 1
    },
    "train_config": {
        "output_directory": "./results/mujoco",
        # The training loop starts at ckpt_iter + 1, so -1 means "start at iteration 0".
        "ckpt_iter": -1,
        # NOTE(review): a checkpoint (model + optimizer state) is written every
        # 100 iterations; for a model of this size that is a lot of disk over
        # 150000 iterations -- confirm this cadence is intended.
        "iters_per_ckpt": 100,
        "iters_per_logging": 100,
        "n_iters": 150000,
        "learning_rate": 2e-4,
        "only_generate_missing": 1,
        # NOTE(review): use_model is not read by the cells visible in this
        # notebook, which always build SSSDS4Imputer -- presumably kept for
        # parity with a script version; confirm.
        "use_model": 2,
        # One of 'rm', 'mnr', 'bm'; selects the mask helper used during training.
        "masking": "rm",
        "missing_k": 90
    },
    "trainset_config": {
        "train_data_path": os.path.join(MUJOCO_DATA_DIR, "train_mujoco.npy"),
        "test_data_path": os.path.join(MUJOCO_DATA_DIR, "test_mujoco.npy"),
        "segment_length": 100,
        "sampling_rate": 100
    },
    "gen_config": {
        "output_directory": "./results/mujoco",
        "ckpt_path": "./results/mujoco/"
    }
}

train_config = config["train_config"]  # training parameters
trainset_config = config["trainset_config"]  # to load trainset
model_config = config['wavenet_config']  # keyword arguments for the model constructor
diffusion_hyperparams = calc_diffusion_hyperparams(**config["diffusion_config"])  # dictionary of all diffusion hyperparameters

In [5]:
# train(**train_config)

# Bind each training setting from train_config to a notebook-level name in a
# single unpacking step; the order of names matches the order of keys below.
(
    output_directory,
    ckpt_iter,
    iters_per_ckpt,
    iters_per_logging,
    n_iters,
    learning_rate,
    only_generate_missing,
    masking,
    missing_k,
) = (
    train_config[setting]
    for setting in (
        'output_directory',
        'ckpt_iter',
        'iters_per_ckpt',
        'iters_per_logging',
        'n_iters',
        'learning_rate',
        'only_generate_missing',
        'masking',
        'missing_k',
    )
)


# The string below is the cell's last expression, so it is echoed as the cell output.
"""
Train Diffusion Models

Parameters:
output_directory (str):         save model checkpoints to this path
ckpt_iter (int or 'max'):       the pretrained checkpoint to be loaded; 
                                automatically selects the maximum iteration if 'max' is selected
data_path (str):                path to dataset, numpy array.
n_iters (int):                  number of iterations to train
iters_per_ckpt (int):           number of iterations to save checkpoint, 
                                default is 10k, for models with residual_channel=64 this number can be larger
iters_per_logging (int):        number of iterations to save training log and compute validation loss, default is 100
learning_rate (float):          learning rate

use_model (int):                0:DiffWave. 1:SSSDSA. 2:SSSDS4.
only_generate_missing (int):    0:all sample diffusion.  1:only apply diffusion to missing portions of the signal
masking(str):                   'mnr': missing not at random, 'bm': blackout missing, 'rm': random missing
missing_k (int):                k missing time steps for each feature across the sample length.
"""

"\nTrain Diffusion Models\n\nParameters:\noutput_directory (str):         save model checkpoints to this path\nckpt_iter (int or 'max'):       the pretrained checkpoint to be loaded; \n                                automatically selects the maximum iteration if 'max' is selected\ndata_path (str):                path to dataset, numpy array.\nn_iters (int):                  number of iterations to train\niters_per_ckpt (int):           number of iterations to save checkpoint, \n                                default is 10k, for models with residual_channel=64 this number can be larger\niters_per_logging (int):        number of iterations to save training log and compute validation loss, default is 100\nlearning_rate (float):          learning rate\n\nuse_model (int):                0:DiffWave. 1:SSSDSA. 2:SSSDS4.\nonly_generate_missing (int):    0:all sample diffusion.  1:only apply diffusion to missing portions of the signal\nmasking(str):                   'mnr': missing not at ran

In [6]:
# Move every diffusion hyperparameter except the entry under "T" onto the GPU,
# writing the results back into the same dictionary.
diffusion_hyperparams.update({
    name: value.cuda()
    for name, value in diffusion_hyperparams.items()
    if name != "T"
})

# Build the imputer from the wavenet settings, place it on the GPU and print
# its parameter count.
net = SSSDS4Imputer(**model_config)
net = net.cuda()
print_size(net)

# Adam over all model parameters, using the configured learning rate.
optimizer = torch.optim.Adam(params=net.parameters(), lr=learning_rate)


SSSDS4Imputer Parameters: 48.371726M


In [7]:
### data loading and reshaping ###

# Read the training array from disk and show its shape as stored.
samples = np.load(trainset_config['train_data_path'])
print(samples.shape)

# Cut the array into 160 equal pieces along axis 0, then stack the pieces into
# one array with a new leading axis (one entry per piece).
pieces = np.split(samples, 160, axis=0)
batched = np.stack(pieces)
print(batched.shape)

# Hand the result to torch as a float tensor on the GPU; the training cell
# iterates over `training_data` along its first axis.
training_data = torch.from_numpy(batched)
training_data = training_data.float().cuda()
print('Data loaded')


(8000, 100, 14)
(160, 50, 100, 14)
Data loaded


In [8]:
# training
#
# Runs optimisation steps numbered ckpt_iter + 1 .. n_iters (inclusive),
# cycling over the pre-batched `training_data` as many times as needed.
# Logs the loss every `iters_per_logging` steps and writes a checkpoint
# (model + optimizer state) every `iters_per_ckpt` steps.

# Resolve the mask helper once. An unknown masking name fails here with a clear
# message instead of surfacing inside the loop as an undefined (or stale)
# `transposed_mask`.
mask_generators = {'rm': get_mask_rm, 'mnr': get_mask_mnr, 'bm': get_mask_bm}
if masking not in mask_generators:
    raise ValueError(
        "masking must be one of 'rm', 'mnr' or 'bm', got {!r}".format(masking))
get_mask = mask_generators[masking]

# With no batches the inner loop would never advance n_iter and the outer
# loop would spin forever.
if len(training_data) == 0:
    raise ValueError("training_data contains no batches")

# Make sure the checkpoint directory exists before the first torch.save.
os.makedirs(output_directory, exist_ok=True)

# The loss module does not change between steps, so build it once.
loss_fn = nn.MSELoss()

n_iter = ckpt_iter + 1
while n_iter < n_iters + 1:
    for batch in training_data:
        # Re-check the budget per step so training stops exactly at n_iters
        # rather than finishing the current pass over the data.
        if n_iter > n_iters:
            break

        # The mask is generated from the first sample only and then repeated,
        # so every sample in the batch shares the same mask.
        transposed_mask = get_mask(batch[0], missing_k)

        mask = transposed_mask.permute(1, 0)
        mask = mask.repeat(batch.size()[0], 1, 1).float().cuda()
        loss_mask = ~mask.bool()  # logical complement of `mask`
        batch = batch.permute(0, 2, 1)  # swap the last two axes to match the mask layout

        assert batch.size() == mask.size() == loss_mask.size()

        # back-propagation
        optimizer.zero_grad()
        X = batch, batch, mask, loss_mask
        loss = training_loss(net, loss_fn, X, diffusion_hyperparams,
                             only_generate_missing=only_generate_missing)

        loss.backward()
        optimizer.step()

        if n_iter % iters_per_logging == 0:
            print("iteration: {} \tloss: {}".format(n_iter, loss.item()))

        # save checkpoint (iteration 0 is skipped)
        if n_iter > 0 and n_iter % iters_per_ckpt == 0:
            checkpoint_name = '{}.pkl'.format(n_iter)
            torch.save({'model_state_dict': net.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict()},
                       os.path.join(output_directory, checkpoint_name))
            print('model at iteration %s is saved' % n_iter)

        n_iter += 1


iteration: 0 	loss: 0.9102503061294556
iteration: 100 	loss: 0.48290690779685974
model at iteration 100 is saved
iteration: 200 	loss: 0.2387528270483017
model at iteration 200 is saved
iteration: 300 	loss: 0.11651907861232758
model at iteration 300 is saved
iteration: 400 	loss: 0.05717376619577408
model at iteration 400 is saved
iteration: 500 	loss: 0.04107207432389259
model at iteration 500 is saved
iteration: 600 	loss: 0.03194050118327141
model at iteration 600 is saved
iteration: 700 	loss: 0.028316810727119446
model at iteration 700 is saved
iteration: 800 	loss: 0.046355199068784714
model at iteration 800 is saved
iteration: 900 	loss: 0.02934068627655506
model at iteration 900 is saved
iteration: 1000 	loss: 0.026989130303263664
model at iteration 1000 is saved
iteration: 1100 	loss: 0.022299429401755333
model at iteration 1100 is saved
iteration: 1200 	loss: 0.021887024864554405
model at iteration 1200 is saved
iteration: 1300 	loss: 0.02610797993838787
model at iteration 1

KeyboardInterrupt: 