# Multivariate Time Series Generation - Train the model

This notebook demonstrates the ability of our model to generate dependent multi-channel physiological signals.

This repository is an artifact for the paper under review "Multivariate Generative Adversarial Networks and their Loss Functions for Synthesis of Multichannel ECGs" submitted to IEEE Pattern Recognition and Machine Intelligence 2020.

The code has been supplied as Jupyter Notebooks and set up to run in Google Colaboratory. The dataset used is open source and freely available from PhysioNet.

## If Using Google Colab

Mount your Google Drive if you are running this notebook on Colab.

In [None]:
# Colab only: mount Google Drive so that the repository path used in the next
# cell ('/content/drive/My Drive/...') is reachable from this runtime.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Directories

In [None]:
import os

# Work from the repository root so relative paths such as './Data/' resolve.
os.chdir('/content/drive/My Drive/MV_GAN_Journal/Multivariate_time_series_gen/')

# Make Directories to store Training Results.
# makedirs(exist_ok=True) also creates a missing 'Results/' parent, and is safe
# to re-run after an interrupted run (os.mkdir would raise or skip sub-folders).
training_dir = '/content/drive/My Drive/MV_GAN_Journal/Multivariate_time_series_gen/Results/Loss-SenseGAN_NSR'
os.makedirs(training_dir, exist_ok=True)

# Directories for different MBD Layers (losses and models)
# -- MBD not used in this work but can be implemented.
# The training functions save checkpoints to MBD_<i>/gen and MBD_<i>/disc.
minibatch_layer = [0, 3, 5, 8, 10]
for i in minibatch_layer:
  mbd_dir = os.path.join(training_dir, 'MBD_' + str(i))
  for sub_dir in ('gen', 'disc'):
    # Created even when mbd_dir already exists but a sub-folder is missing.
    os.makedirs(os.path.join(mbd_dir, sub_dir), exist_ok=True)

Data and Save Directories

In [None]:
# Checkpoints, losses and parameters.json are written under the training directory.
savepath = training_dir
# Relative to the repository root selected with os.chdir in the cell above.
datapath = './Data/'

# Choose which dataset you want here (files produced by the preprocessing step).
#datafile = 'ECG_Arr.pt' # ECG Arrhythmia Dataset
datafile = 'ecg_mit_nsnr.pt' # ECG Normal Sinus Rhythm Dataset

## Import necessary dependencies

In [None]:
import json as js
import pickle

from tqdm import tqdm
import numpy as np

import torch
from torch.autograd.variable import Variable
import torch.autograd as autograd

from model import Generator, Discriminator

from matplotlib import pyplot as plt
import math

In [None]:
!pip install fastdtw

from scipy.spatial.distance import sqeuclidean
from fastdtw import fastdtw



The R `dtw` package is used here for multivariate DTW (MVDTW) because it computes quickly. A dependent MVDTW is also implemented in pure Python in this notebook, but it takes a long time to execute.

In [None]:
import rpy2.robjects.numpy2ri
from rpy2.robjects.packages import importr
rpy2.robjects.numpy2ri.activate()
import rpy2.robjects as robj

In [None]:
%load_ext rpy2.ipython

  from pandas.core.index import Index as PandasIndex


In [None]:
%%R
install.packages("dtw")

In [None]:
# Set up our R namespaces
R = rpy2.robjects.r
DTW = importr('dtw')

## GPU

In [None]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# `cuda` must be defined on BOTH branches: it selects the legacy tensor type
# below (previously it was only set when a GPU existed -> NameError on CPU).
cuda = torch.cuda.is_available()
if cuda:
  print('Using: ' + str(torch.cuda.get_device_name(device)))
else:
  print('Using: CPU')

# Legacy tensor constructor used by the training loops for target labels.
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

Using: Tesla T4


## Initialise Parameters

In [None]:
"""Initialising Parameters"""
def init_params(data='ecg_mit_nsnr.pt'):
  """Build the default hyper-parameter dictionary for a training run.

  Args:
    data: name of the preprocessed dataset file; stored under the 'data' key.

  Returns:
    A new dict on every call (so callers may mutate it safely). It is read by
    load_model and the train_* functions and serialised by save_params.
    NOTE(review): 'layers', 'D_rounds', 'G_rounds', 'lambda_gp', 'num_cvs' and
    the cv*/p* entries are not read by load_model or by the training loops
    visible here (load_model hard-codes its own conv/pool sizes).
  """
  return {
      # Dataset file produced by the preprocessing step
      'data': data,
      # Length of each generated/real sequence and samples per batch
      'seq_len': 500,
      'batch_size': 50,
      # Generator settings
      'hidden_nodes_g': 50,
      'layers': 2,
      'tanh_layer': False,
      # Training schedule: rounds per step, epochs and optimiser learning rate
      'D_rounds': 3,
      'G_rounds': 1,
      'epochs': 50,
      'learning_rate': 0.0002,
      # Loss weight for the gradient penalty
      'lambda_gp': 10,
      # Discriminator settings (MBD layers and conv/pool sizes)
      'minibatch_layer': [0, 3, 5, 8, 10],
      'minibatch_normal_init': False,
      'num_cvs': 2,
      'cv1_out': 10,
      'cv1_k': 3,
      'cv1_s': 1,
      'p1_k': 3,
      'p1_s': 2,
      'cv2_out': 10,
      'cv2_k': 3,
      'cv2_s': 1,
      'p2_k': 3,
      'p2_s': 2,
  }

## Function Definitions


*   Load Data
*   Noise Function
*   Gradient Penalty
*   Evaluation
*   Save Params to JSON



In [None]:
def load_data(filename, batch_size):
    """Load a preprocessed tensor file and wrap it in a (non-shuffling) DataLoader.

    Args:
        filename: path to a file written with torch.save. torch.load unpickles,
            so only load trusted files.
        batch_size: samples per batch.

    Returns:
        (data_loader, num_batches), where num_batches == len(data_loader).
    """
    dataset = torch.load(filename)
    # Windows stored with 501 points along axis 1 are trimmed to 500 so they
    # match seq_len.
    if len(dataset[0, :, 0]) == 501:
        dataset = dataset[:, :-1, :]
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    return loader, len(loader)

In [None]:
def noise(batch_size, features):
    """Sample standard-normal generator input.

    Returns a tensor of shape (2, batch_size, features) on the global `device`.
    """
    z = torch.randn(2, batch_size, features)
    return z.to(device)

In [None]:
def compute_gradient_penalty(D, real_samples, fake_samples):
    """Calculate the WGAN-GP gradient penalty (Gulrajani et al., 2017).

    Args:
        D: critic; any callable mapping a batch to a tensor of scores.
        real_samples: real batch, first dimension is the batch.
        fake_samples: generated batch with the same shape as real_samples.

    Returns:
        Scalar tensor mean((||grad_x D(x_hat)||_2 - 1)^2), where x_hat are random
        per-sample interpolations of real and fake and the norm is taken over
        all non-batch dimensions.
    """
    batch_size = real_samples.size(0)
    # One interpolation weight per sample, broadcast over every other dim
    # (works for any input rank, not only 4-D).
    alpha_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    alpha = torch.rand(alpha_shape, device=real_samples.device)
    interpolates = (alpha * real_samples + (1 - alpha) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates)
    # Match the critic's actual output shape instead of assuming (batch, 2).
    grad_outputs = torch.ones_like(d_interpolates)
    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=grad_outputs,
        create_graph=True,   # the penalty itself must be differentiable
        retain_graph=True,
    )[0]
    gradients = gradients.reshape(batch_size, -1)
    return ((gradients.norm(2, dim=1) - 1) ** 2).mean()

In [None]:
def save_params(params, filename):
    """Write `params` as JSON to '<filename>/parameters.json'.

    Args:
        params: JSON-serialisable dict (e.g. from init_params).
        filename: existing directory to write into.
    """
    # `with` guarantees the file is closed even if serialisation fails.
    with open(filename + '/parameters.json', 'w') as f:
        f.write(js.dumps(params))

### Load Model

*Load and Initialise*

In [None]:
def load_model(params, minibatch_layer):
    """Build a generator/discriminator pair and their RMSprop optimisers.

    Args:
        params: dict from init_params (reads seq_len, batch_size,
            hidden_nodes_g, tanh_layer, minibatch_normal_init, learning_rate).
        minibatch_layer: value passed to the Discriminator's minibatch_layer.

    Returns:
        (generator, discriminator, g_optimizer, d_optimizer); both networks are
        on the global `device` and in training mode.
    """
    seq_len = params['seq_len']
    generator = Generator(
        seq_len,
        params['batch_size'],
        hidden_dim=params['hidden_nodes_g'],
        tanh_output=params['tanh_layer'],
    ).to(device)

    # NOTE(review): these conv/pool sizes are hard-coded and ignore the cv*/p*
    # entries of `params` (e.g. p1_s is 1 here but 2 in init_params).
    disc_kwargs = dict(
        in_channels=2,
        cv1_k=3, cv1_s=1, p1_k=3, p1_s=1,
        cv2_k=3, cv2_s=1, p2_k=3, p2_s=2,
        cv3_k=3, cv3_s=2, p3_k=3, p3_s=2,
        cv4_k=5, cv4_s=2, p4_k=5, p4_s=2,
        minibatch_layer=minibatch_layer,
        minibatch_init=params['minibatch_normal_init'],
    )
    discriminator = Discriminator(seq_len, **disc_kwargs).to(device)

    for net in (generator, discriminator):
        net.train()

    lr = params['learning_rate']
    g_optimizer = torch.optim.RMSprop(generator.parameters(), lr=lr)
    d_optimizer = torch.optim.RMSprop(discriminator.parameters(), lr=lr)
    return generator, discriminator, g_optimizer, d_optimizer

## DTW Functions

### R_DTW Package

The R `dtw` package is used here only because it computes multivariate DTW much faster than the pure-Python implementation below.

In [None]:
def R_DTW(real, fake):
  """Per-sample multivariate DTW distance computed with R's `dtw` package.

  Args:
    real, fake: 4-D tensors indexed [sample, :, :, 0]; sample i of `real` is
      aligned with sample i of `fake`. Presumably laid out as
      (batch, channels, seq_len, 1) -- TODO confirm.

  Returns:
    list with one float distance per sample of `real`, from an open-begin /
    open-end alignment using Rabiner-Juang step pattern (4, "c").
  """
  step = R.rabinerJuangStepPattern(4, "c")
  dtw_dist = []
  for idx in range(len(real)):
    # Transpose so each row is one time step and each column one channel.
    template = real[idx, :, :, 0].detach().cpu().numpy().T
    query = fake[idx, :, :, 0].detach().cpu().numpy().T

    # Convert the numpy matrices to R matrices.
    template_r = R.matrix(template, nrow=template.shape[0], ncol=template.shape[1])
    query_r = R.matrix(query, nrow=query.shape[0], ncol=query.shape[1])

    alignment = R.dtw(template_r, query_r, keep=True, step_pattern=step,
                      open_begin=True, open_end=True)
    dtw_dist.append(alignment.rx('distance')[0][0])

  return dtw_dist

def FAST_DTW(real, fake):
  """Per-sample, per-channel univariate FastDTW distances.

  Args:
    real, fake: 4-D tensors indexed [sample, channel, time, 0] with the same
      number of samples and channels.

  Returns:
    list (one entry per sample of `real`) of lists of per-channel distances,
    using squared Euclidean point distance.
  """
  dtw_dist = []
  # Iterate over the actual batch instead of a hard-coded 50 samples.
  for i in range(len(real)):
    x = real[i, :, :, 0]
    y = fake[i, :, :, 0]
    per_channel = []
    for ch in range(x.shape[0]):
      # Shape (time, 1): one 1-D point per time step. The previous
      # view(1, -1) gave fastdtw a single 500-D point, i.e. no time warping.
      xs = x[ch].reshape(-1, 1).detach().cpu().numpy()
      ys = y[ch].reshape(-1, 1).detach().cpu().numpy()
      d, _ = fastdtw(xs, ys, dist=sqeuclidean)
      per_channel.append(d)
    dtw_dist.append(per_channel)

  return dtw_dist

### DTW Programmatically (pure Python)

In [None]:
def distance_matrix(Q, C):
    """Pairwise squared-difference matrix between two 1-D sequences.

    Args:
        Q: query sequence of length n (indexes the columns).
        C: candidate sequence of length m (indexes the rows).

    Returns:
        float64 np.matrix of shape (m, n) with entry [i, j] = (Q[j] - C[i]) ** 2.
    """
    # Vectorised outer difference replaces the O(m*n) pure-Python double loop.
    diffs = np.subtract.outer(np.asarray(C), np.asarray(Q))
    return np.asmatrix((diffs ** 2).astype(float))


## Plot the Distance Cost Plot
def distance_cost_plot(distances):
    """Draw a cost matrix as a red heat map with the origin at the bottom-left.

    distances: 2-D array-like (e.g. from distance_matrix or accumulated_costs).
    Draws into the current matplotlib figure and adds a colour bar.
    """
    plt.imshow(distances, interpolation='nearest', cmap='Reds')
    current_axes = plt.gca()
    current_axes.invert_yaxis()
    plt.xlabel("X")
    plt.ylabel("Y")
    plt.grid()
    plt.colorbar()

def accumulated_costs(Q,C, distances):
    """Classic DTW accumulated-cost table.

    Args:
        Q, C: sequences; only len(Q) (columns) and len(C) (rows) are used.
        distances: (len(C), len(Q)) local-cost matrix.

    Returns:
        ndarray acc with acc[i, j] = distances[i, j] + min of the three
        predecessors (diagonal, above, left); the first row/column are
        running sums.
    """
    n_rows = len(C)
    n_cols = len(Q)
    acc = np.zeros((n_rows, n_cols))
    acc[0, 0] = distances[0, 0]

    # First row: only horizontal moves are possible.
    for col in range(1, n_cols):
        acc[0, col] = distances[0, col] + acc[0, col - 1]

    for row in range(1, n_rows):
        # First column: only vertical moves are possible.
        acc[row, 0] = distances[row, 0] + acc[row - 1, 0]
        for col in range(1, n_cols):
            best_prev = min(acc[row - 1, col - 1], acc[row - 1, col], acc[row, col - 1])
            acc[row, col] = best_prev + distances[row, col]

    return acc

def backtrack(Q, C, accumulated_cost, plotting=True):
    """Recover the optimal DTW warping path from an accumulated-cost table.

    Args:
        Q: query sequence; only len(Q) (columns of accumulated_cost) is used.
        C: candidate sequence; only len(C) (rows of accumulated_cost) is used.
        accumulated_cost: (len(C), len(Q)) table from accumulated_costs.
        plotting: if true, draw the table and the path via distance_cost_plot.

    Returns:
        list of [j, i] pairs (column into Q, row into C) from
        [len(Q)-1, len(C)-1] back to [0, 0], each cell visited once.

    Raises:
        ValueError: if Q or C is empty.
    """
    if len(Q) == 0 or len(C) == 0:
        raise ValueError("backtrack needs non-empty sequences")
    i = len(C) - 1
    j = len(Q) - 1
    path = [[j, i]]
    # Continue until BOTH indices reach 0. The old `and` condition stopped as
    # soon as one hit 0 (making the edge branches unreachable), skipped the
    # remaining edge cells and then appended a duplicate [0, 0].
    while i > 0 or j > 0:
        if i == 0:
            j -= 1
        elif j == 0:
            i -= 1
        else:
            best = min(accumulated_cost[i-1, j-1], accumulated_cost[i-1, j], accumulated_cost[i, j-1])
            # Tie-breaking preference unchanged: vertical, horizontal, diagonal.
            if accumulated_cost[i-1, j] == best:
                i -= 1
            elif accumulated_cost[i, j-1] == best:
                j -= 1
            else:
                i -= 1
                j -= 1
        path.append([j, i])

    if plotting:
        distance_cost_plot(accumulated_cost)
        plt.plot([point[0] for point in path], [point[1] for point in path])

    return path

def path_cost(Q, C, accumulated_cost, distances):
    """Optimal DTW warping path and the sum of local costs along it.

    Args:
        Q: query sequence; only len(Q) (columns) is used.
        C: candidate sequence; only len(C) (rows) is used.
        accumulated_cost: (len(C), len(Q)) table from accumulated_costs.
        distances: (len(C), len(Q)) local-cost matrix.

    Returns:
        (path, cost): path is a list of [j, i] pairs from the end cell back to
        [0, 0], each cell once; cost is sum(distances[i, j]) over the path.

    Raises:
        ValueError: if Q or C is empty.
    """
    if len(Q) == 0 or len(C) == 0:
        raise ValueError("path_cost needs non-empty sequences")
    i = len(C) - 1
    j = len(Q) - 1
    path = [[j, i]]
    # Walk until both indices are 0, so edge cells are included and [0, 0]
    # is counted exactly once (the old `and` loop dropped or duplicated cells).
    while i > 0 or j > 0:
        if i == 0:
            j -= 1
        elif j == 0:
            i -= 1
        else:
            best = min(accumulated_cost[i-1, j-1], accumulated_cost[i-1, j], accumulated_cost[i, j-1])
            if accumulated_cost[i-1, j] == best:
                i -= 1
            elif accumulated_cost[i, j-1] == best:
                j -= 1
            else:
                i -= 1
                j -= 1
        path.append([j, i])

    cost = 0
    # Path entries are [column, row]; don't shadow the Q/C parameters.
    for col, row in path:
        cost = cost + distances[row, col]

    return (path, cost)

def distance_DTWd(Q, C):
    """Local-cost matrix for dependent multivariate DTW.

    Args:
        Q, C: channel-major multivariate sequences, i.e. Q[m][j] is channel m
            at time j. Every channel of Q must have a matching channel in C.

    Returns:
        float64 np.matrix of shape (len(C[0]), len(Q[0])) with
        [i, j] = sum over channels m of (Q[m][j] - C[m][i]) ** 2.
    """
    n_rows = len(C[0])
    n_cols = len(Q[0])
    # Accumulate one vectorised outer difference per channel instead of the
    # original triple Python loop (channels are summed in the same order).
    total = 0
    for m in range(len(Q)):
        c_m = np.asarray(C[m])[:n_rows]
        q_m = np.asarray(Q[m])[:n_cols]
        total = total + np.subtract.outer(c_m, q_m) ** 2
    return np.asmatrix(np.asarray(total, dtype=float))

def DTW_i(Q, C):
    """Independent multivariate DTW: sum of per-channel univariate DTW costs.

    Args:
        Q, C: channel-major multivariate sequences (Q[m] is channel m); each
            channel is aligned with its own warping path.

    Returns:
        Total DTW cost over all channels of Q.
    """
    total = 0
    for ch in range(len(Q)):
        distance = distance_matrix(Q[ch], C[ch])
        acc_costs = accumulated_costs(Q[ch], C[ch], distance)
        # path_cost performs the backtracking itself; the separate, unused
        # backtrack() call that used to be here was pure overhead.
        _, cost = path_cost(Q[ch], C[ch], acc_costs, distance)
        total += cost
    return total

def DTW_d(Q, C):
    """Dependent multivariate DTW cost: one warping path shared by all channels.

    Args:
        Q, C: channel-major multivariate sequences (Q[m][j] = channel m, time j).

    Returns:
        DTW cost along the optimal shared path over the distance_DTWd matrix.

    Raises:
        ValueError: if Q has no channels.
    """
    if len(Q) == 0:
        raise ValueError("DTW_d needs at least one channel")
    # The channel sum happens inside distance_DTWd, and accumulated_costs /
    # path_cost only use the sequence lengths, so the matrix and path are the
    # same for every channel. The old per-channel loop repeated identical work
    # len(Q) times and took the min of equal values.
    distance = distance_DTWd(Q, C)
    acc_costs = accumulated_costs(Q[0], C[0], distance)
    _, cost = path_cost(Q[0], C[0], acc_costs, distance)
    return cost

In [None]:
def evaluate_dtw_sample(gen, real):
    """Dependent multivariate DTW between the first generated and the first real sample.

    Args:
        gen, real: 4-D tensors indexed [sample, channel, time, 0], as produced
            by the `.view(batch_size, -1, seq_len, 1)` calls in the training loops.

    Returns:
        numpy float DTW_d cost of sample 0 (computed on detached copies, so it
        carries no gradient). The pure-Python DP is O(seq_len^2) and slow.
    """
    # DTW_d / distance_DTWd index inputs as [channel][time], and [0, :, :, 0]
    # is already (channels, seq_len). The old .permute(1, 0) swapped the axes,
    # so "time warping" ran over the 2-element channel axis instead of time.
    gen_data = gen[0, :, :, 0].detach().cpu().numpy()
    sample = real[0, :, :, 0].detach().cpu().numpy()
    # DTW_i(gen_data, sample) is the independent-DTW alternative.
    return np.mean([DTW_d(gen_data, sample)])

### SOFTDTW

Optionally, you can use this differentiable soft-DTW as the loss in the DTWGAN below.

In [None]:
import numpy as np
import torch
from numba import jit
from torch.autograd import Function

@jit(nopython = True)
def compute_softdtw(D, gamma):
  """Forward soft-DTW recursion (Cuturi & Blondel, 2017), numba-compiled.

  D: (B, N, M) array of pairwise costs, one matrix per batch element.
  gamma: smoothing parameter; must be non-zero (it divides). As gamma -> 0
    the soft-min below approaches a hard min, i.e. classic DTW.
  Returns R of shape (B, N + 2, M + 2): R[k, 0, 0] = 0, R[k, i, j] for
  1 <= i <= N, 1 <= j <= M is the soft cumulative cost, and all other border
  cells stay +inf. The soft-DTW value is R[:, N, M] (== R[:, -2, -2]).
  Note: the local name R shadows the module-level rpy2 `R` only inside here.
  """
  B = D.shape[0]
  N = D.shape[1]
  M = D.shape[2]
  R = np.ones((B, N + 2, M + 2)) * np.inf
  R[:, 0, 0] = 0
  for k in range(B):
    for j in range(1, M + 1):
      for i in range(1, N + 1):
        # Negated, gamma-scaled predecessors: diagonal, above, left.
        r0 = -R[k, i - 1, j - 1] / gamma
        r1 = -R[k, i - 1, j] / gamma
        r2 = -R[k, i, j - 1] / gamma
        # Subtract the max before exp for a numerically stable log-sum-exp.
        rmax = max(max(r0, r1), r2)
        rsum = np.exp(r0 - rmax) + np.exp(r1 - rmax) + np.exp(r2 - rmax)
        softmin = - gamma * (np.log(rsum) + rmax)
        R[k, i, j] = D[k, i - 1, j - 1] + softmin
  return R

@jit(nopython = True)
def compute_softdtw_backward(D_, R, gamma):
  """Backward pass of soft-DTW: gradient of R[:, N, M] w.r.t. the cost matrix.

  D_: (B, N, M) cost array used in the forward pass.
  R: (B, N + 2, M + 2) table returned by compute_softdtw. It is MODIFIED IN
    PLACE: the last row/column become -inf and R[:, -1, -1] = R[:, -2, -2].
  gamma: same smoothing parameter as in the forward pass.
  Returns E of shape (B, N, M), the derivative of the soft-DTW value with
  respect to each entry of D_ (used as the autograd gradient in _SoftDTW).
  """
  B = D_.shape[0]
  N = D_.shape[1]
  M = D_.shape[2]
  # Pad D and E with one extra cell on every side so the i+1 / j+1 reads below
  # never go out of bounds.
  D = np.zeros((B, N + 2, M + 2))
  E = np.zeros((B, N + 2, M + 2))
  D[:, 1:N + 1, 1:M + 1] = D_
  # Seed: d R[N, M] / d R[N, M] = 1, propagated backwards from the corner.
  E[:, -1, -1] = 1
  R[:, : , -1] = -np.inf
  R[:, -1, :] = -np.inf
  R[:, -1, -1] = R[:, -2, -2]
  for k in range(B):
    for j in range(M, 0, -1):
      for i in range(N, 0, -1):
        # Soft-min weights of cell (i, j) within each successor's update.
        a0 = (R[k, i + 1, j] - R[k, i, j] - D[k, i + 1, j]) / gamma
        b0 = (R[k, i, j + 1] - R[k, i, j] - D[k, i, j + 1]) / gamma
        c0 = (R[k, i + 1, j + 1] - R[k, i, j] - D[k, i + 1, j + 1]) / gamma
        a = np.exp(a0)
        b = np.exp(b0)
        c = np.exp(c0)
        E[k, i, j] = E[k, i + 1, j] * a + E[k, i, j + 1] * b + E[k, i + 1, j + 1] * c
  return E[:, 1:N + 1, 1:M + 1]

class _SoftDTW(Function):
  """Autograd Function wrapping the numba soft-DTW recursions (runs on CPU).

  forward(D, gamma): D is a (B, N, M) cost tensor; returns the (B,) soft-DTW
  values. backward returns the gradient for D and None for gamma.
  """

  @staticmethod
  def forward(ctx, D, gamma):
    dev = D.device
    dtype = D.dtype
    # Build gamma directly in D's dtype (torch.Tensor([...]) rounded to float32 first).
    gamma = torch.tensor([gamma], device=dev, dtype=dtype)
    D_ = D.detach().cpu().numpy()
    g_ = gamma.item()
    # from_numpy keeps the float64 result exact until the final cast;
    # torch.Tensor(ndarray) silently rounded it to float32 first.
    R = torch.from_numpy(compute_softdtw(D_, g_)).to(device=dev, dtype=dtype)
    ctx.save_for_backward(D, R, gamma)
    return R[:, -2, -2]

  @staticmethod
  def backward(ctx, grad_output):
    dev = grad_output.device
    dtype = grad_output.dtype
    D, R, gamma = ctx.saved_tensors
    D_ = D.detach().cpu().numpy()
    # compute_softdtw_backward edits R in place; on CPU .numpy() shares memory
    # with the saved tensor, so work on a copy to leave the saved R intact.
    R_ = R.detach().cpu().numpy().copy()
    g_ = gamma.item()
    E = torch.from_numpy(compute_softdtw_backward(D_, R_, g_)).to(device=dev, dtype=dtype)
    # Chain rule: scale each batch element's alignment gradient by its upstream grad.
    return grad_output.view(-1, 1, 1).expand_as(E) * E, None

class SoftDTW(torch.nn.Module):
  """Soft-DTW loss module (usable like nn.MSELoss).

  Args:
    gamma: soft-min smoothing parameter passed to _SoftDTW.
    normalize: if True return the soft-DTW divergence
      sdtw(x, y) - (sdtw(x, x) + sdtw(y, y)) / 2, else the raw sdtw(x, y).

  forward(x, y): x and y are (batch, seq, features) tensors of equal rank;
  2-D inputs are treated as a single unbatched sequence and the batch axis is
  squeezed off the result.
  """

  def __init__(self, gamma=1.0, normalize=False):
    super().__init__()
    self.normalize = normalize
    self.gamma = gamma
    self.func_dtw = _SoftDTW.apply

  def calc_distance_matrix(self, x, y):
    """Squared Euclidean distances: (B, n, d) x (B, m, d) -> (B, n, m)."""
    len_x, len_y, feat = x.size(1), y.size(1), x.size(2)
    x_grid = x.unsqueeze(2).expand(-1, len_x, len_y, feat)
    y_grid = y.unsqueeze(1).expand(-1, len_x, len_y, feat)
    return torch.pow(x_grid - y_grid, 2).sum(3)

  def forward(self, x, y):
    assert len(x.shape) == len(y.shape)
    unbatched = len(x.shape) < 3
    if unbatched:
      x, y = x.unsqueeze(0), y.unsqueeze(0)

    out_xy = self.func_dtw(self.calc_distance_matrix(x, y), self.gamma)
    if not self.normalize:
      result = out_xy  # discrepancy
    else:
      out_xx = self.func_dtw(self.calc_distance_matrix(x, x), self.gamma)
      out_yy = self.func_dtw(self.calc_distance_matrix(y, y), self.gamma)
      result = out_xy - 1/2 * (out_xx + out_yy)  # distance

    if unbatched:
      return result.squeeze(0)
    return result

## GAN Functions

### LSGAN

Standard Least Squares GAN

In [None]:
# !!! Minimizes MSE instead of BCE
adversarial_loss = torch.nn.MSELoss()

def train_LSgan(filename, params, savepath):
    """Train a least-squares GAN once for every entry of params['minibatch_layer'].

    For each MBD value a fresh generator/discriminator pair is built with
    load_model and trained for params['epochs'] epochs: one generator step
    followed by one discriminator step per batch (params['D_rounds'] and
    params['G_rounds'] are not read here). Once per epoch, at batch index
    num_batches - 2, the current losses are recorded and both state dicts are
    saved to savepath/MBD_<mbd>/{gen,disc}/. The loss lists are pickled to
    savepath/MBD_<mbd>/generator_losses.txt and discriminator_losses.txt.

    Args:
        filename: tensor file passed to load_data.
        params: hyper-parameter dict (see init_params); written to
            savepath/parameters.json.
        savepath: results directory; the MBD_<mbd>/gen and /disc folders must
            already exist (torch.save does not create them).
    """
    # Load data
    data_loader, num_batches = load_data(filename, params['batch_size'])
    # Save Parameters
    save_params(params, savepath)
    
    # Iterate through the MBD layers
    for mbd in params['minibatch_layer']:
        print("MBD_Layer: "+str(mbd))
        G_losses = []
        D_losses = []
        
        # Load model for this MBD layer
        generator, discriminator, g_optimizer, d_optimizer = load_model(params, int(mbd))
        for n in tqdm(range(params['epochs'])):
            
            for n_batch, sample_data in enumerate(data_loader):
                # Stop at a trailing partial batch: the .view calls below
                # hard-code params['batch_size'].
                if len(sample_data[:,0,0]) < params['batch_size']:
                    break
                else:
                    # Adversarial GT: (batch, 2) targets, presumably matching the
                    # discriminator's output width -- TODO confirm in model.py.
                    valid = Variable(Tensor(sample_data.size(0), 2).fill_(1.0), requires_grad=False)
                    fake = Variable(Tensor(sample_data.size(0), 2).fill_(0.0), requires_grad=False)

                    # -----------------
                    #  Train Generator
                    # -----------------
                    g_optimizer.zero_grad()
                    h_g = generator.init_hidden()
                    
                    # Sample noise as generator input
                    noise_sample = Variable(noise(len(sample_data), params['seq_len'])).to(device)

                    # Generate a batch of fake data, reshaped to (batch, -1, seq_len, 1)
                    fake_data = generator.forward(noise_sample,h_g)#.detach()
                    fake_data = fake_data.view(params['batch_size'], -1, params['seq_len'], 1)

                    # Loss measures generator's ability to fool the discriminator
                    g_loss = adversarial_loss(discriminator(fake_data), valid)

                    # These gradients also reach the discriminator's .grad; they are
                    # cleared by d_optimizer.zero_grad() before the D step.
                    g_loss.backward()
                    g_optimizer.step()  
                    
                    # ---------------------
                    #  Train Discriminator
                    # ---------------------
                    d_optimizer.zero_grad()

                    # Generate a batch of real data   
                    real_data = Variable(sample_data.float()).to(device)  
                    # NOTE(review): load_data trims axis 1 (length 501 -> 500), which
                    # suggests real batches are (batch, seq_len, channels). .view
                    # reinterprets memory rather than transposing, so in that layout
                    # this interleaves the channels -- confirm the on-disk layout.
                    real_data = real_data.view(params['batch_size'], -1, params['seq_len'], 1)

                    # Measure discriminator's ability to classify real from generated samples
                    real_loss = adversarial_loss(discriminator(real_data), valid)
                    fake_loss = adversarial_loss(discriminator(fake_data.detach()), fake)

                    d_loss = (0.5 * (real_loss + fake_loss))

                    d_loss.backward()
                    d_optimizer.step()


                # Record/checkpoint at the second-to-last batch (the last one may be partial).
                if n_batch == (num_batches - 2):
                    G_losses.append(g_loss.item())
                    D_losses.append(d_loss.item())
                  
                    torch.save(generator.state_dict(), savepath+'/MBD_' +str(mbd)+ '/gen/generator_state_'+str(n)+'.pt')
                    torch.save(discriminator.state_dict(), savepath+'/MBD_' +str(mbd)+ '/disc/discriminator_state_'+str(n)+'.pt')
                                   
        # Dumping the errors for each training epoch.
        with open(savepath+'/MBD_'+str(mbd)+'/generator_losses.txt', 'wb') as fp:
            pickle.dump(G_losses, fp)
        with open(savepath+'/MBD_' +str(mbd)+ '/discriminator_losses.txt', 'wb') as fp:
            pickle.dump(D_losses, fp) 
                

### LSGAN-DTW 
Implemented with the R multivariate DTW package. You can replace the `R_DTW(G(z), x)` call with `evaluate_dtw_sample(G(z), x)` to use our adapted DTW method, but it is much slower.

In [None]:
# !!! Minimizes MSE instead of BCE
adversarial_loss = torch.nn.MSELoss()

def train_RDTWGAN(filename, params, savepath):
    """Train an LSGAN whose discriminator loss adds an R-package DTW term.

    Same loop as train_LSgan (one G step then one D step per batch, one
    checkpoint per epoch at batch index num_batches - 2). In addition,
    d_loss gets 1 - 1/log(mean DTW distance) from R_DTW.

    NOTE(review): the DTW term is a plain Python float computed on detached
    data, so it shifts d_loss but contributes no gradient. math.log raises if
    the mean distance is 0, and the term divides by zero if it equals 1.

    Args:
        filename: tensor file passed to load_data.
        params: hyper-parameter dict (see init_params); written to
            savepath/parameters.json.
        savepath: results directory with existing MBD_<mbd>/gen and /disc folders.
    """
    # Load data
    data_loader, num_batches = load_data(filename, params['batch_size'])
    # Save Parameters
    save_params(params, savepath)
    
    # Iterate through the MBD layers
    for mbd in params['minibatch_layer']:
        print("MBD_Layer: "+str(mbd))
        print("DL_Length = " +str(len((data_loader))))
        G_losses = []
        D_losses = []
        
        # Load model for this MBD layer
        generator, discriminator, g_optimizer, d_optimizer = load_model(params, int(mbd))
        for n in tqdm(range(params['epochs'])):
            for n_batch, sample_data in enumerate(data_loader):
                # Stop at a trailing partial batch (views below hard-code batch_size).
                if len(sample_data[:,0,0]) < params['batch_size']:
                    break
                else:
                    # Adversarial GT
                    valid = Variable(Tensor(sample_data.size(0), 2).fill_(1.0), requires_grad=False)
                    fake = Variable(Tensor(sample_data.size(0), 2).fill_(0.0), requires_grad=False)

                    # -----------------
                    #  Train Generator
                    # -----------------
                    g_optimizer.zero_grad()
                    h_g = generator.init_hidden()
                    
                    # Sample noise as generator input
                    noise_sample = Variable(noise(len(sample_data), params['seq_len'])).to(device)

                    # Generate a batch of fake data
                    fake_data = generator.forward(noise_sample,h_g)
                    fake_data = fake_data.view(params['batch_size'], -1, params['seq_len'], 1)

                    # Loss measures generator's ability to fool the discriminator
                    g_loss = adversarial_loss(discriminator(fake_data), valid)

                    g_loss.backward()
                    g_optimizer.step()  
                    
                    # ---------------------
                    #  Train Discriminator
                    # ---------------------
                    d_optimizer.zero_grad()

                    # Generate a batch of real data   
                    real_data = Variable(sample_data.float()).to(device)  
                    # NOTE(review): .view does not transpose; if batches are stored as
                    # (batch, seq_len, channels) this interleaves channels -- confirm.
                    real_data = real_data.view(params['batch_size'], -1, params['seq_len'], 1)

                    # Measure discriminator's ability to classify real from generated samples
                    real_loss = adversarial_loss(discriminator(real_data), valid)
                    fake_loss = adversarial_loss(discriminator(fake_data.detach()), fake)
                    # R_DTW returns a LIST with one distance per sample; math.log on
                    # the list raised TypeError, so summarise the batch by its mean.
                    dtw = R_DTW(fake_data, real_data)
                    mean_dtw = float(np.mean(dtw))
                    dtw_loss = 1 - (1.0/math.log(mean_dtw))

                    d_loss = (0.5 * (real_loss + fake_loss) + dtw_loss)

                    d_loss.backward()
                    d_optimizer.step()

                    if n_batch % 50 == 0:
                      print("Batch: " +str(n_batch)+ "/"+str(len(data_loader)))
                      print("D_loss: "+str(d_loss))


                # Record/checkpoint at the second-to-last batch (the last one may be partial).
                if n_batch == (num_batches - 2):
                    G_losses.append(g_loss.item())
                    D_losses.append(d_loss.item())
                  
                    torch.save(generator.state_dict(), savepath+'/MBD_' +str(mbd)+ '/gen/generator_state_'+str(n)+'.pt')
                    torch.save(discriminator.state_dict(), savepath+'/MBD_' +str(mbd)+ '/disc/discriminator_state_'+str(n)+'.pt')
                                   
        # Dumping the errors for each training epoch.
        with open(savepath+'/MBD_'+str(mbd)+'/generator_losses.txt', 'wb') as fp:
            pickle.dump(G_losses, fp)
        with open(savepath+'/MBD_' +str(mbd)+ '/discriminator_losses.txt', 'wb') as fp:
            pickle.dump(D_losses, fp) 
                

### DTWGAN

Any DTW implementation can be used here. For the paper we used our `evaluate_dtw_sample`; here `SoftDTW` is used instead because it is differentiable.

In [None]:
from timeit import default_timer as timer
import math
import chainer.functions as F

In [None]:
# !!! Minimizes DTW distance
criterion = SoftDTW(gamma=1.0, normalize=True) # just like nn.MSELoss()


def train_LSoftDTWgan(filename, params, savepath):
    """Train a GAN whose adversarial loss is the soft-DTW divergence (`criterion`).

    Same loop structure as train_LSgan (one G step then one D step per batch,
    one checkpoint per epoch at batch index num_batches - 2), but the
    discriminator outputs are compared with the 0/1 targets using SoftDTW
    instead of MSE.

    NOTE(review): the targets are 2-D (batch, 2). SoftDTW.forward unsqueezes
    inputs with fewer than 3 dims, so (assuming the discriminator output is
    also (batch, 2)) the whole batch is aligned as ONE sequence of length
    `batch` with 2 features -- confirm this is intended.

    Args:
        filename: tensor file passed to load_data.
        params: hyper-parameter dict (see init_params); written to
            savepath/parameters.json.
        savepath: results directory with existing MBD_<mbd>/gen and /disc folders.
    """
    # Load data
    data_loader, num_batches = load_data(filename, params['batch_size'])
    # Save Parameters
    save_params(params, savepath)
    
    # Iterate through the MBD layers
    for mbd in params['minibatch_layer']:
        print("MBD_Layer: "+str(mbd))
        G_losses = []
        D_losses = []
        
        # Load model for this MBD layer
        generator, discriminator, g_optimizer, d_optimizer = load_model(params, int(mbd))
        for n in tqdm(range(params['epochs'])):
            for n_batch, sample_data in enumerate(data_loader):
                # Stop at a trailing partial batch (views below hard-code batch_size).
                if len(sample_data[:,0,0]) < params['batch_size']:
                    break
                else:
                    # Adversarial GT
                    valid = Variable(Tensor(sample_data.size(0), 2).fill_(1.0), requires_grad=False)
                    fake = Variable(Tensor(sample_data.size(0), 2).fill_(0.0), requires_grad=False)
                    
                    # -----------------
                    #  Train Generator
                    # -----------------
                    g_optimizer.zero_grad()
                    h_g = generator.init_hidden()
                    
                    # Sample noise as generator input
                    noise_sample = Variable(noise(len(sample_data), params['seq_len'])).to(device)

                    # Generate a batch of fake data
                    fake_data = generator.forward(noise_sample,h_g)#.detach()
                    fake_data = fake_data.view(params['batch_size'], -1, params['seq_len'], 1)

                    # Generator wants D(G(z)) to align with the "real" targets.
                    g_loss = 0.5*criterion(discriminator(fake_data), valid)
                    g_loss.backward()
                    g_optimizer.step()  
                    
                    # ---------------------
                    #  Train Discriminator
                    # ---------------------
                    d_optimizer.zero_grad()

                    # Generate a batch of real data   
                    real_data = Variable(sample_data.float()).to(device)  
                    # NOTE(review): .view does not transpose; if batches are stored as
                    # (batch, seq_len, channels) this interleaves channels -- confirm.
                    real_data = real_data.view(params['batch_size'], -1, params['seq_len'], 1)
  
                    real_loss = criterion(discriminator(real_data), valid)
                    fake_loss = criterion(discriminator(fake_data.detach()), fake)


                    if n_batch % 50 == 0:
                      print("Batch: " +str(n_batch)+ "/"+str(len(data_loader)))
                    
                    d_loss = (0.5 * (real_loss + fake_loss))
                    #d_loss = torch.tensor(d_loss, requires_grad=True).cuda()

                    d_loss.backward()
                    d_optimizer.step()


                # Record/checkpoint at the second-to-last batch (the last one may be partial).
                if n_batch == (num_batches - 2):
                    print(g_loss, d_loss)
                    G_losses.append(g_loss.item())
                    D_losses.append(d_loss.item())
                  
                    torch.save(generator.state_dict(), savepath+'/MBD_' +str(mbd)+ '/gen/generator_state_'+str(n)+'.pt')
                    torch.save(discriminator.state_dict(), savepath+'/MBD_' +str(mbd)+ '/disc/discriminator_state_'+str(n)+'.pt')
                                   
        # Dumping the errors for each training epoch.
        with open(savepath+'/MBD_'+str(mbd)+'/generator_losses.txt', 'wb') as fp:
            pickle.dump(G_losses, fp)
        with open(savepath+'/MBD_' +str(mbd)+ '/discriminator_losses.txt', 'wb') as fp:
            pickle.dump(D_losses, fp) 
                

### LS-DTWGAN

In [None]:
# !!! Minimizes MSE instead of BCE
adversarial_loss = torch.nn.MSELoss()

def train_LSDTWgan(filename, params, savepath):
    """Train an LSGAN whose discriminator loss adds 1e-4 * DTW_d(sample 0).

    Same loop as train_LSgan (one G step then one D step per batch, one
    checkpoint per epoch at batch index num_batches - 2), plus the
    pure-Python dependent DTW from evaluate_dtw_sample.

    NOTE(review): evaluate_dtw_sample works on detached numpy copies, so the
    DTW term is a constant offset on d_loss and contributes no gradient.

    Args:
        filename: tensor file passed to load_data.
        params: hyper-parameter dict (see init_params); written to
            savepath/parameters.json.
        savepath: results directory with existing MBD_<mbd>/gen and /disc folders.
    """
    # Load data
    data_loader, num_batches = load_data(filename, params['batch_size'])
    # Save Parameters
    save_params(params, savepath)
    
    # Iterate through the MBD layers
    for mbd in params['minibatch_layer']:
        print("MBD_Layer: "+str(mbd))
        print("DL_Length = " +str(len((data_loader))))
        G_losses = []
        D_losses = []
        
        # Load model for this MBD layer
        generator, discriminator, g_optimizer, d_optimizer = load_model(params, int(mbd))
        for n in tqdm(range(params['epochs'])):
            for n_batch, sample_data in enumerate(data_loader):
                # Stop at a trailing partial batch (views below hard-code batch_size).
                if len(sample_data[:,0,0]) < params['batch_size']:
                    break
                else:
                    # Adversarial GT
                    valid = Variable(Tensor(sample_data.size(0), 2).fill_(1.0), requires_grad=False)
                    fake = Variable(Tensor(sample_data.size(0), 2).fill_(0.0), requires_grad=False)

                    # -----------------
                    #  Train Generator
                    # -----------------
                    g_optimizer.zero_grad()
                    h_g = generator.init_hidden()
                    
                    # Sample noise as generator input
                    noise_sample = Variable(noise(len(sample_data), params['seq_len'])).to(device)

                    # Generate a batch of fake data
                    fake_data = generator.forward(noise_sample,h_g)#.detach()
                    fake_data = fake_data.view(params['batch_size'], -1, params['seq_len'], 1)

                    # Loss measures generator's ability to fool the discriminator
                    g_loss = adversarial_loss(discriminator(fake_data), valid)

                    g_loss.backward()
                    g_optimizer.step()  
                    
                    # ---------------------
                    #  Train Discriminator
                    # ---------------------
                    d_optimizer.zero_grad()

                    # Generate a batch of real data   
                    real_data = Variable(sample_data.float()).to(device)  
                    # NOTE(review): .view does not transpose; if batches are stored as
                    # (batch, seq_len, channels) this interleaves channels -- confirm.
                    real_data = real_data.view(params['batch_size'], -1, params['seq_len'], 1)

                    # Measure discriminator's ability to classify real from generated samples
                    real_loss = adversarial_loss(discriminator(real_data), valid)
                    fake_loss = adversarial_loss(discriminator(fake_data.detach()), fake)
                    # DTW_d between the first fake and first real sample (numpy float).
                    dtw_loss = evaluate_dtw_sample(fake_data, real_data)
                    d_loss = (0.5 * (real_loss + fake_loss)) + (0.0001 * dtw_loss)
                    #d_loss = (0.3 * (real_loss + fake_loss + math.log(dtw_loss)))
 
                    d_loss.backward()
                    d_optimizer.step()

                    if n_batch % 50 == 0:
                      print("Batch: " +str(n_batch)+ "/"+str(len(data_loader)))


                # Record/checkpoint at the second-to-last batch (the last one may be partial).
                if n_batch == (num_batches - 2):
                    G_losses.append(g_loss.item())
                    D_losses.append(d_loss.item())
                  
                    torch.save(generator.state_dict(), savepath+'/MBD_' +str(mbd)+ '/gen/generator_state_'+str(n)+'.pt')
                    torch.save(discriminator.state_dict(), savepath+'/MBD_' +str(mbd)+ '/disc/discriminator_state_'+str(n)+'.pt')
                                   
        # Dumping the errors for each training epoch.
        with open(savepath+'/MBD_'+str(mbd)+'/generator_losses.txt', 'wb') as fp:
            pickle.dump(G_losses, fp)
        with open(savepath+'/MBD_' +str(mbd)+ '/discriminator_losses.txt', 'wb') as fp:
            pickle.dump(D_losses, fp) 
                

### LS-GAN

Loss-Sensitive GAN (LS-GAN) with an L2 distance. Not to be confused with the least-squares LSGAN defined earlier in this notebook.

In [None]:
# -------- Init tensor --------
adversarial_loss = torch.nn.MSELoss()

# NOTE(review): slope, LeakyReLU, lamb and l2dist are not used in the visible
# part of this cell; presumably LS_GAN below consumes them -- confirm.
# A LeakyReLU with negative slope 0.0 behaves like a plain ReLU.
slope=0.0
LeakyReLU = torch.nn.LeakyReLU(slope).to(device)
lamb = 2e-4
# p=2 pairwise (L2) distance between corresponding rows of two tensors.
l2dist = torch.nn.PairwiseDistance(2)


def get_direct_gradient_penalty(netD, x, gamma, cuda):
    """Penalty on the critic's input-gradient magnitude at the given inputs.

    Args:
        netD: critic/discriminator callable.
        x: input batch; moved to the GPU first when `cuda` is true.
        gamma: weight multiplying the mean gradient norm.
        cuda: whether to run on the GPU.

    Returns:
        Scalar tensor gamma * mean(||d netD(x) / dx||_2), where the L2 norm is
        taken over dim 1 only.
    """
    if cuda:
        x = x.cuda()
    x = autograd.Variable(x, requires_grad=True)
    output = netD(x)

    ones = torch.ones(output.size())
    if cuda:
        ones = ones.cuda()

    grads = torch.autograd.grad(
        outputs=output,
        inputs=x,
        grad_outputs=ones,
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]
    # NOTE(review): for a (batch, channels, seq, 1) input this is a norm over
    # channels only, not a per-sample norm over all non-batch dims -- confirm.
    return grads.norm(2, dim=1).mean() * gamma

def LS_GAN(filename, params, savepath):
    """Train a Loss-Sensitive GAN (LS-GAN) whose margin is a scaled L2 distance.

    For every entry ``mbd`` in ``params['minibatch_layer']`` a generator /
    discriminator pair is obtained from ``load_model(params, int(mbd))`` and
    trained for ``params['epochs']`` epochs. Each full batch performs:

    * a discriminator step minimising
      ``mean(LeakyReLU(D(real) - D(fake) + lamb * l2dist(real, fake)))``
      (``LeakyReLU``, ``lamb`` and ``l2dist`` are module-level globals), and
    * a generator step minimising ``mean(D(G(z)))``.

    At batch index ``num_batches - 2`` of each epoch (when that index is
    reached) the current G/D losses are recorded and both state dicts are saved
    to ``<savepath>/MBD_<mbd>/gen/generator_state_<epoch>.pt`` and
    ``<savepath>/MBD_<mbd>/disc/discriminator_state_<epoch>.pt``. After the last
    epoch the loss histories are pickled (binary, despite the ``.txt``
    extension) to ``generator_losses.txt`` and ``discriminator_losses.txt``.

    Args:
        filename: dataset path forwarded to ``load_data``.
        params: hyper-parameter dict. This function reads ``'batch_size'``,
            ``'seq_len'``, ``'epochs'`` and ``'minibatch_layer'`` directly and
            forwards the dict to ``save_params`` / ``load_model``.
        savepath: root output directory. The ``MBD_<mbd>/gen`` and
            ``MBD_<mbd>/disc`` sub-directories must already exist, because
            ``torch.save`` does not create them.
    """
    batch_size = params['batch_size']

    # Load data
    data_loader, num_batches = load_data(filename, batch_size)
    # Save Parameters
    save_params(params, savepath)

    seq_len = params['seq_len']

    # Iterate through the MBD layers
    for mbd in params['minibatch_layer']:
        print("MBD_Layer: "+str(mbd))
        G_losses = []
        D_losses = []
        mbd_dir = savepath + '/MBD_' + str(mbd)

        # Load model for this MBD layer
        generator, discriminator, g_optimizer, d_optimizer = load_model(params, int(mbd))
        for n in tqdm(range(params['epochs'])):

            for n_batch, sample_data in enumerate(data_loader):
                # Stop at the first short (trailing) batch: the .view() calls below hard-code batch_size.
                if sample_data.size(0) < batch_size:
                    break

                # ---------------------
                #  Train Discriminator
                # ---------------------
                for p in discriminator.parameters():
                    p.requires_grad = True

                d_optimizer.zero_grad()
                h_g = generator.init_hidden()
                # Batch of real data
                real_data = sample_data.float().to(device)
                real_data = real_data.view(batch_size, -1, seq_len, 1)

                # Sample noise as generator input
                noise_sample = noise(len(sample_data), seq_len).to(device)

                # Batch of fake data. Detached so d_loss.backward() does not backpropagate into the
                # generator; those gradients would be cleared by g_optimizer.zero_grad() anyway.
                fake_data = generator.forward(noise_sample, h_g).detach()
                fake_data = fake_data.view(batch_size, -1, seq_len, 1)

                LossR = discriminator(real_data)  # loss-function value on real data
                LossF = discriminator(fake_data)  # loss-function value on fake data

                # NOTE(review): .view(batch_size, -1, 2) reinterprets memory rather than transposing;
                # if the layout is (batch, channels, seq) this pairs neighbouring samples of one channel
                # instead of the two channels at one time step — confirm against the data layout.
                pdist = l2dist(real_data[:, :, :, 0].view(batch_size, -1, 2),
                               fake_data[:, :, :, 0].view(batch_size, -1, 2)).mul(lamb).to(device)

                # Hinge on the margin-adjusted difference (the settings cell builds LeakyReLU with slope 0.0).
                d_loss = LeakyReLU(LossR - LossF + pdist).mean()
                d_loss.backward()
                # Optional gradient penalty (disabled):
                # get_direct_gradient_penalty(discriminator, real_data, 10, True).backward()
                d_optimizer.step()

                # -----------------
                #  Train Generator
                # -----------------
                # Freeze D so the generator loss only produces gradients for G.
                for p in discriminator.parameters():
                    p.requires_grad = False

                g_optimizer.zero_grad()
                h_g = generator.init_hidden()

                # Sample noise as generator input
                noise_sample = noise(len(sample_data), seq_len).to(device)

                # Batch of fake data, this time with the graph kept for the generator update.
                fake_data = generator.forward(noise_sample, h_g)
                fake_data = fake_data.view(batch_size, -1, seq_len, 1)

                # The generator minimises the loss-function value of its samples.
                g_loss = discriminator(fake_data).mean()
                g_loss.backward()
                g_optimizer.step()

                # Record losses and checkpoint both networks once per epoch.
                if n_batch == (num_batches - 2):
                    G_losses.append(g_loss.item())
                    D_losses.append(d_loss.item())

                    torch.save(generator.state_dict(), mbd_dir + '/gen/generator_state_' + str(n) + '.pt')
                    torch.save(discriminator.state_dict(), mbd_dir + '/disc/discriminator_state_' + str(n) + '.pt')

        # Dump the per-epoch loss histories for this MBD layer.
        with open(mbd_dir + '/generator_losses.txt', 'wb') as fp:
            pickle.dump(G_losses, fp)
        with open(mbd_dir + '/discriminator_losses.txt', 'wb') as fp:
            pickle.dump(D_losses, fp)


# The DTW cell further down rebinds ``LS_GAN``; keep a stable handle on this L2 variant.
LS_GAN_L2 = LS_GAN

### LS-GAN-DTW

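Loss-Sensitive GAN (LS-GAN) that uses the Dynamic Time Warping (DTW) distance between real and generated samples as the margin.

Note: this cell redefines `LS_GAN`, so the **Main** cell calls whichever of the two definitions was executed most recently.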

In [None]:
# -------- LS-GAN (DTW margin) settings --------
# MSE criterion; the LS_GAN function defined in this cell does not use it.
adversarial_loss = torch.nn.MSELoss()

# A LeakyReLU with a zero negative slope behaves like a ReLU, i.e. the hinge (.)_+ of the LS-GAN loss.
slope = 0.0
LeakyReLU = torch.nn.LeakyReLU(negative_slope=slope).to(device)
# lamb and l2dist mirror the L2 cell; the DTW LS_GAN below does not read them.
lamb = 2e-4
l2dist = torch.nn.PairwiseDistance(p=2)


def get_direct_gradient_penalty(netD, x, gamma, cuda):
    """Return ``gamma`` times the mean L2 norm of the gradient of ``netD`` w.r.t. its input.

    The norm is taken over dimension 1 of the input gradient and then averaged
    over every remaining position.

    Args:
        netD: critic/discriminator; any callable mapping a tensor to a tensor.
        x: input batch at which the gradient is evaluated. It is detached from
            any existing graph, so the caller's tensor is left untouched.
        gamma: scalar weight applied to the penalty.
        cuda: if True, ``x`` is moved to the current CUDA device first.

    Returns:
        A scalar tensor built with ``create_graph=True``, so calling
        ``.backward()`` on it propagates into ``netD``'s parameters.
    """
    if cuda:
        x = x.cuda()

    # Fresh leaf tensor so autograd can differentiate the critic w.r.t. its input.
    x = x.detach().requires_grad_(True)
    output = netD(x)
    # ones_like follows output's device *and* dtype (ones(...).cuda() only matched the device).
    grad_outputs = torch.ones_like(output)

    gradient = torch.autograd.grad(
        outputs=output,
        inputs=x,
        grad_outputs=grad_outputs,
        create_graph=True,
        retain_graph=True,
    )[0]
    # NOTE(review): for 4-D inputs (batch, channels, seq, 1) dim=1 is the channel axis, so this
    # averages per-position norms rather than per-sample norms — confirm this is intended.
    return gradient.norm(2, dim=1).mean() * gamma

def LS_GAN(filename, params, savepath):
    """Train a Loss-Sensitive GAN (LS-GAN) whose margin is a DTW distance.

    For every entry ``mbd`` in ``params['minibatch_layer']`` a generator /
    discriminator pair is obtained from ``load_model(params, int(mbd))`` and
    trained for ``params['epochs']`` epochs. Each full batch performs:

    * a discriminator step minimising
      ``mean(LeakyReLU(D(real) - D(fake) + R_DTW(fake, real)))``
      (``LeakyReLU`` and ``R_DTW`` are defined elsewhere in the notebook), and
    * a generator step minimising ``mean(D(G(z)))``.

    At batch index ``num_batches - 2`` of each epoch (when that index is
    reached) the current G/D losses are recorded and both state dicts are saved
    to ``<savepath>/MBD_<mbd>/gen/generator_state_<epoch>.pt`` and
    ``<savepath>/MBD_<mbd>/disc/discriminator_state_<epoch>.pt``. After the last
    epoch the loss histories are pickled (binary, despite the ``.txt``
    extension) to ``generator_losses.txt`` and ``discriminator_losses.txt``.

    Args:
        filename: dataset path forwarded to ``load_data``.
        params: hyper-parameter dict. This function reads ``'batch_size'``,
            ``'seq_len'``, ``'epochs'`` and ``'minibatch_layer'`` directly and
            forwards the dict to ``save_params`` / ``load_model``.
        savepath: root output directory. The ``MBD_<mbd>/gen`` and
            ``MBD_<mbd>/disc`` sub-directories must already exist, because
            ``torch.save`` does not create them.
    """
    batch_size = params['batch_size']

    # Load data
    data_loader, num_batches = load_data(filename, batch_size)
    # Save Parameters
    save_params(params, savepath)

    seq_len = params['seq_len']

    # Iterate through the MBD layers
    for mbd in params['minibatch_layer']:
        print("MBD_Layer: "+str(mbd))
        G_losses = []
        D_losses = []
        mbd_dir = savepath + '/MBD_' + str(mbd)

        # Load model for this MBD layer
        generator, discriminator, g_optimizer, d_optimizer = load_model(params, int(mbd))
        for n in tqdm(range(params['epochs'])):

            for n_batch, sample_data in enumerate(data_loader):
                # Stop at the first short (trailing) batch: the .view() calls below hard-code batch_size.
                if sample_data.size(0) < batch_size:
                    break

                # ---------------------
                #  Train Discriminator
                # ---------------------
                for p in discriminator.parameters():
                    p.requires_grad = True

                d_optimizer.zero_grad()
                h_g = generator.init_hidden()
                # Batch of real data
                real_data = sample_data.float().to(device)
                real_data = real_data.view(batch_size, -1, seq_len, 1)

                # Sample noise as generator input
                noise_sample = noise(len(sample_data), seq_len).to(device)

                # Batch of fake data. Detached so d_loss.backward() does not backpropagate into the
                # generator; those gradients would be cleared by g_optimizer.zero_grad() anyway.
                fake_data = generator.forward(noise_sample, h_g).detach()
                fake_data = fake_data.view(batch_size, -1, seq_len, 1)

                LossR = discriminator(real_data)  # loss-function value on real data
                LossF = discriminator(fake_data)  # loss-function value on fake data

                # DTW margin between the fake and real batches (R_DTW is defined elsewhere in the notebook).
                pdist = R_DTW(fake_data, real_data)

                # Hinge on the margin-adjusted difference (the settings cell builds LeakyReLU with slope 0.0).
                d_loss = LeakyReLU(LossR - LossF + pdist).mean()
                d_loss.backward()
                # Optional gradient penalty (disabled):
                # get_direct_gradient_penalty(discriminator, real_data, 10, True).backward()
                d_optimizer.step()

                # -----------------
                #  Train Generator
                # -----------------
                # Freeze D so the generator loss only produces gradients for G.
                for p in discriminator.parameters():
                    p.requires_grad = False

                g_optimizer.zero_grad()
                h_g = generator.init_hidden()

                # Sample noise as generator input
                noise_sample = noise(len(sample_data), seq_len).to(device)

                # Batch of fake data, this time with the graph kept for the generator update.
                fake_data = generator.forward(noise_sample, h_g)
                fake_data = fake_data.view(batch_size, -1, seq_len, 1)

                # The generator minimises the loss-function value of its samples.
                g_loss = discriminator(fake_data).mean()
                g_loss.backward()
                g_optimizer.step()

                # Record losses and checkpoint both networks once per epoch.
                if n_batch == (num_batches - 2):
                    G_losses.append(g_loss.item())
                    D_losses.append(d_loss.item())

                    torch.save(generator.state_dict(), mbd_dir + '/gen/generator_state_' + str(n) + '.pt')
                    torch.save(discriminator.state_dict(), mbd_dir + '/disc/discriminator_state_' + str(n) + '.pt')

        # Dump the per-epoch loss histories for this MBD layer.
        with open(mbd_dir + '/generator_losses.txt', 'wb') as fp:
            pickle.dump(G_losses, fp)
        with open(mbd_dir + '/discriminator_losses.txt', 'wb') as fp:
            pickle.dump(D_losses, fp)


# This cell rebinds ``LS_GAN``; keep an explicit handle on the DTW variant.
LS_GAN_DTW = LS_GAN

## Main

In [None]:
# Initialise the hyper-parameters for the selected dataset.
parameters = init_params(datafile)

# Train. NOTE: both the L2 and the DTW cells define ``LS_GAN``; the name resolves to
# whichever of those cells ran last (the DTW variant under Restart & Run All).
LS_GAN(
    filename=datapath + datafile,
    params=parameters,
    savepath=savepath,
)