In [None]:
# %load ../batchedRNN/model/Data.py
import torch
import numpy as np
from torch.utils import data
 
class DataLoader(object):
    """Mini-batch iterator over two parallel numpy arrays.

    ``xs`` and ``ys`` are indexed by sample along axis 0.  Batches are cut
    in order from the (optionally shuffled) arrays; only full batches are
    produced because ``num_batch`` is ``size // batch_size``.

    Args:
        xs: input samples; indexed as a 4-D array by ``get_iterator``.
        ys: target samples; indexed as a 4-D array by ``get_iterator``.
        batch_size: number of samples per batch.
        pad_with_last_sample: when true, the final sample is repeated until
            the sample count is a multiple of ``batch_size``.
        shuffle: when true, samples are permuted once at construction time.
    """

    def __init__(self, xs, ys, batch_size, pad_with_last_sample=True, shuffle=True):
        self.batch_size = batch_size
        self.current_ind = 0
        if pad_with_last_sample:
            # How many copies of the last sample are missing to fill the final batch.
            shortfall = -len(xs) % batch_size
            xs = np.concatenate([xs, np.repeat(xs[-1:], shortfall, axis=0)], axis=0)
            ys = np.concatenate([ys, np.repeat(ys[-1:], shortfall, axis=0)], axis=0)
        self.size = len(xs)
        self.num_batch = int(self.size // self.batch_size)
        self.xs = xs
        self.ys = ys
        if shuffle:
            self.shuffle()

    def get_iterator(self):
        """Reset the batch cursor and return a generator of ``(x, y)`` batches.

        ``x`` is the batch slice of ``xs`` with axes reordered to
        ``(1, 0, 3, 2)``; ``y`` is channel 0 of the batch slice of ``ys``
        with axes reordered to ``(1, 0, 2)``.  The cursor lives on the
        instance (``current_ind``), so iterators obtained from the same
        loader share their position.
        """
        self.current_ind = 0

        def _batches():
            while self.current_ind < self.num_batch:
                lo = self.current_ind * self.batch_size
                hi = min(lo + self.batch_size, self.size)
                x_batch = np.transpose(self.xs[lo:hi], (1, 0, 3, 2))
                y_batch = np.transpose(self.ys[lo:hi, :, :, 0], (1, 0, 2))
                yield (x_batch, y_batch)
                self.current_ind += 1

        return _batches()

    def shuffle(self):
        """Apply one random permutation to ``xs`` and ``ys`` together."""
        order = np.random.permutation(self.size)
        self.xs = self.xs[order]
        self.ys = self.ys[order]

class DataLoaderWithTime(object):
    """Mini-batch iterator over four parallel numpy arrays.

    Same batching scheme as a plain ``(xs, ys)`` loader, with two extra
    arrays ``tx`` and ``ty`` that are padded, shuffled and sliced in
    lockstep with ``xs`` and ``ys``.

    Args:
        xs: input samples; indexed as a 4-D array by ``get_iterator``.
        ys: target samples; indexed as a 4-D array by ``get_iterator``.
        tx: per-sample array accompanying ``xs``; transposed as 2-D.
        ty: per-sample array accompanying ``ys``; transposed as 2-D.
        batch_size: number of samples per batch.
        pad_with_last_sample: when true, the final sample of every array is
            repeated until the sample count is a multiple of ``batch_size``.
        shuffle: when true, samples are permuted once at construction time.
    """

    def __init__(self, xs, ys, tx, ty, batch_size, pad_with_last_sample=True, shuffle=True):
        self.batch_size = batch_size
        self.current_ind = 0
        if pad_with_last_sample:
            # The pad length is derived from xs and applied to all four arrays.
            shortfall = -len(xs) % batch_size
            xs, ys, tx, ty = [
                np.concatenate([arr, np.repeat(arr[-1:], shortfall, axis=0)], axis=0)
                for arr in (xs, ys, tx, ty)
            ]
        self.size = len(xs)
        self.num_batch = int(self.size // self.batch_size)
        self.xs = xs
        self.ys = ys
        self.tx = tx
        self.ty = ty
        if shuffle:
            self.shuffle()

    def get_iterator(self):
        """Reset the batch cursor and return a generator of 4-tuples.

        Each item is ``(x, y, tx, ty)``: ``x`` has axes reordered to
        ``(1, 0, 3, 2)``, ``y`` is channel 0 with axes ``(1, 0, 2)``, and the
        two time arrays are transposed to ``(1, 0)``.  The cursor lives on
        the instance (``current_ind``).
        """
        self.current_ind = 0

        def _batches():
            while self.current_ind < self.num_batch:
                lo = self.current_ind * self.batch_size
                hi = min(lo + self.batch_size, self.size)
                x_batch = np.transpose(self.xs[lo:hi], (1, 0, 3, 2))
                y_batch = np.transpose(self.ys[lo:hi, :, :, 0], (1, 0, 2))
                tx_batch = np.transpose(self.tx[lo:hi], (1, 0))
                ty_batch = np.transpose(self.ty[lo:hi], (1, 0))
                yield (x_batch, y_batch, tx_batch, ty_batch)
                self.current_ind += 1

        return _batches()

    def shuffle(self):
        """Apply one random permutation to all four arrays together."""
        order = np.random.permutation(self.size)
        self.xs = self.xs[order]
        self.ys = self.ys[order]
        self.tx = self.tx[order]
        self.ty = self.ty[order]


In [None]:
# %load ../batchedRNN/utils.py
import logging, sys
import torch
import h5py
import os
import numpy as np
import torch.utils.data as torchUtils
import torch.optim as optim
from functools import partial
import torch.nn as nn
import json
from shutil import copy2, copyfile, copytree
import argparse
import matplotlib
# matplotlib.use('Agg')
import matplotlib.pyplot as plt

# parser = argparse.ArgumentParser(description='Batched Sequence to Sequence')
# parser.add_argument('--h_dim', type=int, default=256)
# parser.add_argument("--z_dim", type=int, default=128)
# parser.add_argument('--no_cuda', action='store_true', default=False,
#                                         help='disables CUDA training')
# parser.add_argument("--no_attn", action="store_true", default=True, help="Do not use AttnDecoder")
# parser.add_argument("--n_epochs", type=int, default=200)
# parser.add_argument("--batch_size", type=int, default= 64)
# parser.add_argument("--n_layers", type=int, default=2)
# parser.add_argument("--initial_lr", type=float, default=1e-4)
# parser.add_argument("--lr_decay_every", type=int, default=10)
# parser.add_argument("--lr_decay_factor", type=float, default=.10)
# parser.add_argument("--lr_decay_beginning", type=int, default=20)
# parser.add_argument("--print_every", type=int, default = 200)
# parser.add_argument("--criterion", type=str, default="L1Loss")
# parser.add_argument("--save_freq", type=int, default=10)
# parser.add_argument("--down_sample", type=float, default=0.0, help="Keep this fraction of the training data")
# # parser.add_argument("--data_dir", type=str, default="./data/reformattedTraffic/")
# parser.add_argument("--model", type=str, default="sketch-rnn")
# parser.add_argument("--lambda_l1", type=float, default=0)
# parser.add_argument("--lambda_l2", type=float, default=5e-4)
# parser.add_argument("--no_schedule_sampling", action="store_true", default=False)
# parser.add_argument("--scheduling_start", type=float, default=1.0)
# parser.add_argument("--scheduling_end", type=float, default=0.0)
# parser.add_argument("--tries", type=int, default=12)
# parser.add_argument("--kld_warmup_until", type=int, default=5)
# parser.add_argument("--kld_weight_max", type=float, default=0.10)
# parser.add_argument("--no_shuffle_after_epoch", action="store_true", default=False)
# parser.add_argument("--clip", type=int, default=10)
# parser.add_argument("--dataset", type=str, default="traffic")
# parser.add_argument("--predictOnTest", action="store_true", default=True)
# parser.add_argument("--encoder_input_dropout", type=float, default=0.5)
# parser.add_argument("--encoder_layer_dropout", type=float, default=0.5)
# parser.add_argument("--decoder_input_dropout", type=float, default=0.5)
# parser.add_argument("--decoder_layer_dropout", type=float, default=0.5)
# parser.add_argument("--noEarlyStopping", action="store_true", default=False)
# parser.add_argument("--earlyStoppingPatients", type=int, default=3)
# parser.add_argument("--earlyStoppingMinDelta", type=float, default=0.0001)
# parser.add_argument("--bidirectionalEncoder", type=bool, default=True)
# parser.add_argument("--local", action="store_true", default=False)
# parser.add_argument("--debugDataset", action="store_true", default=False)
# parser.add_argument("--encoder_h_dim", type=int, default=256)
# parser.add_argument("--decoder_h_dim", type=int, default=512)
# parser.add_argument("--num_mixtures", type=int, default=20)
# args = parser.parse_args()
# Configure the root logger at import time: records of level DEBUG and above go to stderr.
logging.basicConfig(stream=sys.stderr,level=logging.DEBUG)

def plotLosses(trainLosses, valLosses, trainKLDLosses=None, valKLDLosses=None):
    """Persist the loss histories and save a train/validation loss plot.

    The raw sequences are written with ``torch.save`` under
    ``args.save_dir``; the KLD histories are only written when both are
    truthy.  The figure plots the reconstruction losses against a
    1-based epoch index and is saved as ``train_val_loss_plot.png`` in the
    same directory.  The KLD losses are stored but not plotted.

    Relies on the module-level ``args`` (``save_dir``, ``criterion``, ``model``).
    """
    out_dir = args.save_dir
    torch.save(trainLosses, out_dir + "plot_train_recon_losses")
    torch.save(valLosses, out_dir + "plot_val_recon_losses")
    if trainKLDLosses and valKLDLosses:
        torch.save(trainKLDLosses, out_dir + "plot_train_KLD_losses")
        torch.save(valKLDLosses, out_dir + "plot_val_KLD_losses")

    plt.rcParams.update({'font.size': 8})
    fig, ax1 = plt.subplots()
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel(args.criterion, color="r")
    ax1.tick_params('y', colors='r')

    train_epochs = np.arange(1, len(trainLosses) + 1)
    val_epochs = np.arange(1, len(valLosses) + 1)
    ax1.plot(train_epochs, trainLosses, "r--", label="train reconstruction loss")
    ax1.plot(val_epochs, valLosses, color="red", label="validation reconstruction loss")

    ax1.legend(loc="upper left")
    ax1.grid()
    plt.title("Losses for {}".format(args.model))
    plt.savefig(out_dir + "train_val_loss_plot.png")

def getSaveDir():
    """Create and return the first unused ``modelN/`` directory for this run.

    The root is ``../save/local/models/`` when ``args.local`` is set and
    ``../save/models/`` otherwise.  Indices are probed upwards from 0 until
    a path that is not an existing directory is found.

    Returns:
        The newly created directory path, with a trailing ``/``.

    Raises:
        OSError: if the directory cannot be created (for example because
            another process created it between the probe and the creation).
    """
    root = '../save/local/models/' if args.local else '../save/models/'
    index = 0
    saveDir = '{}model{}/'.format(root, index)
    while os.path.isdir(saveDir):
        index += 1
        saveDir = '{}model{}/'.format(root, index)
    # makedirs (rather than mkdir) so a fresh checkout without the
    # ../save/... parents does not fail on the very first run.
    os.makedirs(saveDir)
    return saveDir

def saveUsefulData():
    """Snapshot the run configuration and source files into ``args.save_dir``.

    Writes ``vars(args)`` as JSON to ``args.txt`` and copies the training
    script, this utilities module, the grid-search script and the whole
    ``./model`` directory (paths are relative to the current working
    directory).
    """
    dest = args.save_dir
    with open(dest + "args.txt", "w") as handle:
        handle.write(json.dumps(vars(args)))

    # (source path, file name inside the save directory)
    snapshots = (
        ("./train.py", "train.py"),
        ("./utils.py", "utils.py"),
        ("./gridSearchOptimize.py", "gridsearchOptimize.py"),
    )
    for source, target_name in snapshots:
        copy2(source, dest + target_name)
    copytree("./model", dest + "model/")

def getTrafficDataset(dataDir, category):
    """Load ``<dataDir>/<category>.npz`` and wrap it as a ``TensorDataset``.

    The archive must contain ``inputs`` and ``targets`` arrays.

    Args:
        dataDir: directory holding the ``.npz`` files.
        category: file stem to load (callers pass e.g. "train", "val", "test").

    Returns:
        ``(dataset, scaler, input_len, target_len, x_dim, channels)`` where
        the three sizes are axes 1, 2 and 3 of ``inputs`` and the target
        length is reported equal to the input length.
    """
    # NpzFile keeps the archive open and re-reads an array on every key
    # access, so read each array exactly once and close the file promptly.
    with np.load(os.path.join(dataDir, category + '.npz')) as f:
        inputs = f["inputs"]
        targets = f["targets"]
    my_dataset = torchUtils.TensorDataset(torch.Tensor(inputs), torch.Tensor(targets))
    scaler = getScaler(inputs)
    sequence_len = inputs.shape[1]
    x_dim = inputs.shape[2]
    channels = inputs.shape[3]
    return my_dataset, scaler, sequence_len, sequence_len, x_dim, channels

def getHumanDataset(dataDir, category):
    """Load ``<dataDir>/<category>.h5`` and wrap it as a ``TensorDataset``.

    The file must contain ``input2d`` and ``target2d`` datasets.

    Args:
        dataDir: directory holding the ``.h5`` files.
        category: file stem to load (callers pass e.g. "train", "val", "test").

    Returns:
        ``(dataset, scaler, input_len, target_len, x_dim, channels)`` where
        ``input_len``/``x_dim``/``channels`` are axes 1, 2 and 3 of
        ``input2d`` and ``target_len`` is axis 1 of ``target2d``.
    """
    # Read both datasets fully into memory inside the context manager so the
    # HDF5 handle is released instead of staying open for the process lifetime.
    with h5py.File(os.path.join(dataDir, category + ".h5"), "r") as f:
        inputs = f["input2d"][()]
        targets = f["target2d"][()]
    my_dataset = torchUtils.TensorDataset(torch.Tensor(inputs), torch.Tensor(targets))
    scaler = getScaler(inputs)
    input_sequence_len = inputs.shape[1]
    target_sequence_len = targets.shape[1]
    x_dim = inputs.shape[2]
    channels = inputs.shape[3]
    return my_dataset, scaler, input_sequence_len, target_sequence_len, x_dim, channels

def getLoaderAndScaler(dataDir, category):
    """Build the torch ``DataLoader`` for one data split.

    The dataset reader is chosen by ``args.dataset`` ("traffic" uses the
    traffic reader, anything else the human-motion reader).  Only the
    "train" split is shuffled; incomplete final batches are dropped.

    Returns:
        ``(loader, scaler, input_sequence_len, target_sequence_len, x_dim, channels)``.
    """
    logging.info("Getting {} loader".format(category))
    read_split = getTrafficDataset if args.dataset == "traffic" else getHumanDataset
    my_dataset, scaler, in_len, out_len, x_dim, channels = read_split(dataDir, category)
    loader = torchUtils.DataLoader(
        my_dataset,
        batch_size=args.batch_size,
        shuffle=(category == "train"),
        num_workers=0,
        pin_memory=False,
        drop_last=True,
    )
    return loader, scaler, in_len, out_len, x_dim, channels

def getDataLoaders(dataDir, debug=False):
    """Collect the loaders for every split plus shared metadata in one dict.

    In debug mode only the "test" split is loaded and it also supplies the
    metadata; otherwise "train", "val" and "test" are loaded and the
    metadata comes from "train".

    Returns:
        A dict with one loader per split name and the keys ``scaler``,
        ``input_sequence_len``, ``target_sequence_len``, ``x_dim`` and
        ``channels`` taken from the metadata split.
    """
    logging.info("Getting data from {}".format(dataDir))
    if debug:
        categories, scalerSet = ["test"], "test"
    else:
        categories, scalerSet = ["train", "val", "test"], "train"

    meta_keys = ("scaler", "input_sequence_len", "target_sequence_len", "x_dim", "channels")
    loaders = {}
    for category in categories:
        loader, scaler, in_len, out_len, x_dim, channels = getLoaderAndScaler(dataDir, category)
        if category == scalerSet:
            loaders.update(zip(meta_keys, (scaler, in_len, out_len, x_dim, channels)))
        loaders[category] = loader
    return loaders

class StandardScaler:
    """Standardize the first two channels of a tensor's last axis.

    Channel 0 is shifted/scaled by ``(mean0, std0)`` and channel 1 by
    ``(mean1, std1)``; any further channels pass through unchanged
    (mean 0, std 1).
    """

    def __init__(self, mean0, std0, mean1, std1):
        self.mean0 = mean0
        self.std0 = std0
        self.mean1 = mean1
        self.std1 = std1

    def transform(self, data):
        """Return ``(data - mean) / std`` with per-channel statistics.

        Args:
            data: tensor whose last axis has at least two entries.

        Returns:
            A new tensor of the same shape as ``data``.

        Raises:
            IndexError: if the last axis of ``data`` has fewer than two entries.
        """
        n_channels = data.size(-1)
        # One small vector per statistic, broadcast over the leading axes and
        # created on the input's device, instead of two full-size CPU tensors.
        mean = torch.zeros(n_channels, device=data.device)
        std = torch.ones(n_channels, device=data.device)
        mean[0] = self.mean0
        mean[1] = self.mean1
        std[0] = self.std0
        std[1] = self.std1
        return torch.div(torch.sub(data, mean), std)

class StandardScalerTraffic(StandardScaler):
    """Scaler for the traffic data: only channel 0 carries real statistics.

    Channel 1 is given mean 0 and std 1, so ``transform`` leaves it unchanged.
    """

    def __init__(self, mean0, std0):
        super(StandardScalerTraffic, self).__init__(mean0, std0, 0.0, 1.0)

    def inverse_transform(self, data):
        """Undo the channel-0 standardization and swap the first two axes.

        Applied to model output and to the target, which only contain the
        channel-0 quantity, so only ``mean0``/``std0`` are used.

        Args:
            data: 3-D tensor (the result is permuted with ``(1, 0, 2)``).

        Returns:
            ``data * std0 + mean0`` with axes 0 and 1 exchanged.
        """
        # Scalar arithmetic works on whatever device `data` lives on and
        # avoids materialising two tensors the size of `data`.
        restored = torch.add(torch.mul(data, float(self.std0)), float(self.mean0))
        return restored.permute(1, 0, 2)

    def transformBatchForEpoch(self, batch):
        """Standardize an ``(inputs, targets)`` batch and reorder its axes.

        Inputs are permuted with ``(1, 0, 3, 2)``; targets keep only channel
        0 and are permuted with ``(1, 0, 2)``.  Both are moved to the GPU
        when ``args.cuda`` is set.
        """
        x = self.transform(batch[0]).permute(1, 0, 3, 2)
        y = self.transform(batch[1])[..., 0].permute(1, 0, 2)
        if args.cuda:
            return x.cuda(), y.cuda()
        return x, y

class StandardScalerHuman(StandardScaler):
    """Scaler for the human-motion data, where both channels are standardized.

    Targets are handled in a "wide" layout in which the two channels are
    concatenated along axis 2; ``removeDim``/``restoreDim`` convert between
    that layout and a trailing channel axis of size 2.
    """

    def __init__(self, mean0, std0, mean1, std1):
        super(StandardScalerHuman, self).__init__(mean0, std0, mean1, std1)

    def inverse_transform(self, data):
        """Undo the standardization of a wide-layout output or target.

        Args:
            data: 3-D tensor whose axis 2 holds channel 0 followed by channel 1.

        Returns:
            The de-standardized values with the channel axis restored and the
            axes permuted with ``(1, 0, 3, 2)``.
        """
        transed = self.restoreDim(data)
        # restoreDim always yields a trailing axis of size 2, so a 2-vector
        # per statistic broadcasts correctly; building it on transed.device
        # keeps this independent of the global cuda flag.
        mean = torch.tensor([float(self.mean0), float(self.mean1)], device=transed.device)
        std = torch.tensor([float(self.std0), float(self.std1)], device=transed.device)
        transformed = torch.add(torch.mul(transed, std), mean)
        return transformed.permute(1, 0, 3, 2)

    def restoreDim(self, data):
        """Split axis 2 in half and stack the halves on a new trailing axis."""
        l1, l2 = torch.split(data, data.size(2) // 2, 2)
        return torch.cat((l1.unsqueeze(3), l2.unsqueeze(3)), dim=3)

    def removeDim(self, data):
        """Inverse of ``restoreDim``: concatenate the two channels along axis 2."""
        layer0, layer1 = torch.split(data, 1, dim=3)
        return torch.cat((layer0.squeeze(3), layer1.squeeze(3)), dim=2)

    def transformBatchForEpoch(self, batch):
        """Standardize an ``(inputs, targets)`` batch and reorder its axes.

        Inputs are permuted with ``(1, 0, 3, 2)``; targets are converted to
        the wide layout and permuted with ``(1, 0, 2)``.  Both are moved to
        the GPU when ``args.cuda`` is set.
        """
        x = self.transform(batch[0]).permute(1, 0, 3, 2)
        y = self.transform(batch[1])
        wideY = self.removeDim(y).permute(1, 0, 2)
        if args.cuda:
            return x.cuda(), wideY.cuda()
        return x, wideY

def getScaler(trainX):
    """Build the scaler matching ``args.dataset`` from the statistics of ``trainX``.

    Args:
        trainX: array-like whose last axis is the channel axis; channel 0 is
            always used, channel 1 only for the "human" dataset.

    Returns:
        A ``StandardScalerTraffic`` or ``StandardScalerHuman`` instance.

    Raises:
        AssertionError: if ``args.dataset`` is neither "traffic" nor "human".
    """
    dataset = args.dataset
    if dataset not in ("traffic", "human"):
        # Raised explicitly (same exception type as before) so the check
        # survives ``python -O``, where assert statements are stripped.
        raise AssertionError("bad dataset")
    mean0 = np.mean(trainX[..., 0])
    std0 = np.std(trainX[..., 0])
    if dataset == "traffic":
        return StandardScalerTraffic(mean0, std0)
    # Channel-1 statistics are only needed for the human dataset.
    mean1 = np.mean(trainX[..., 1])
    std1 = np.std(trainX[..., 1])
    return StandardScalerHuman(mean0, std0, mean1, std1)

def getReconLoss(output, target, scaler):
    """Reconstruction loss between prediction and target in original units.

    Both tensors are first passed through ``scaler.inverse_transform``.
    ``args.criterion`` selects the loss: "RMSE" (square root of the mean
    squared error) or "L1Loss" (mean absolute error).

    Raises:
        AssertionError: if the two tensors differ in size after the inverse
            transform, or if ``args.criterion`` is not one of the two names.
    """
    pred = scaler.inverse_transform(output)
    truth = scaler.inverse_transform(target)
    assert pred.size() == truth.size(), "output size: {}, target size: {}".format(pred.size(), truth.size())
    if args.criterion == "L1Loss":
        return nn.L1Loss()(pred, truth)
    if args.criterion == "RMSE":
        return torch.sqrt(nn.MSELoss()(pred, truth))
    assert False, "bad loss function"

def getKLDWeight(epoch):
    """Return the weight applied to the KLD term.

    Currently the constant ``args.kld_weight_max``; ``epoch`` is accepted
    but unused because the linear warm-up is disabled.
    """
    # Disabled warm-up: args.kld_weight_max * min(epoch / args.kld_warmup_until, 1.0)
    return args.kld_weight_max

def kld_gauss(mean_1, std_1, mean_2, std_2):
    """KL divergence KL(N(mean_1, std_1^2) || N(mean_2, std_2^2)), summed.

    All four arguments are tensors of standard deviations / means (not
    variances or log-variances); the element-wise divergences are summed
    over every element into a single scalar tensor.
    """
    log_term = 2 * torch.log(std_2) - 2 * torch.log(std_1)
    ratio_term = (std_1.pow(2) + (mean_1 - mean_2).pow(2)) / std_2.pow(2)
    per_element = log_term + ratio_term - 1
    return 0.5 * torch.sum(per_element)

def sketchRNNKLD(latentMean, latentStd):
    """Summed KL divergence of the latent Gaussian from a standard normal.

    ``latentStd`` is passed to ``kld_gauss`` as a standard deviation.
    """
    prior_mean = torch.zeros_like(latentMean)
    prior_std = torch.ones_like(latentStd)
    return kld_gauss(latentMean, latentStd, prior_mean, prior_std)

def getLoss(model, output, target, scaler, epoch):
    """Return ``(reconstruction_loss, kld_loss)`` for one batch.

    For ``args.model == "rnn"`` the output is the prediction itself and the
    KLD term is the integer 0.  For any other model the output is unpacked
    as a 6-tuple ``(latentMean, latentStd, z, predOut, predMeanOut,
    predStdOut)``; the reconstruction loss uses ``predOut`` and the KLD term
    uses the latent statistics.  ``model`` and ``epoch`` are not used.
    """
    if args.model != "rnn":
        latentMean, latentStd, z, predOut, predMeanOut, predStdOut = output
        recon = getReconLoss(predOut, target, scaler)
        kld = sketchRNNKLD(latentMean, latentStd)
        return recon, kld
    return getReconLoss(output, target, scaler), 0

def saveModel(modelWeights, epoch):
    """Save ``modelWeights`` under ``args.save_dir`` and log the path.

    The file name is ``<args.model>_state_dict_<epoch>.pth``.
    """
    stem = args.save_dir + '{}_state_dict_'.format(args.model)
    fn = stem + str(epoch) + '.pth'
    torch.save(modelWeights, fn)
    logging.info('Saved model to ' + fn)

class EarlyStoppingObject(object):
    """Track the best loss seen so far and decide when to stop training.

    State:
        bestLoss / bestEpoch: best loss so far and the 1-based call index at
            which it was recorded (``None`` before the first call).
        counter: consecutive calls without a sufficient improvement.
        epochCounter: total number of ``checkStop`` calls.
    """

    def __init__(self):
        super(EarlyStoppingObject, self).__init__()
        self.bestLoss = None
        self.bestEpoch = None
        self.counter = 0
        self.epochCounter = 0

    def checkStop(self, previousLoss):
        """Record the latest loss and report whether training should stop.

        A loss counts as an improvement only if it beats the best loss by
        more than ``args.earlyStoppingMinDelta``.  Training stops once
        ``args.earlyStoppingPatients`` consecutive calls bring no
        improvement.  With ``args.noEarlyStopping`` set, the call only
        advances ``epochCounter``.

        Returns:
            ``True`` to stop, ``False`` otherwise (always a bool).
        """
        self.epochCounter += 1
        if args.noEarlyStopping:
            return False

        stalled = (self.bestLoss is not None
                   and previousLoss + args.earlyStoppingMinDelta >= self.bestLoss)
        if not stalled:
            self.bestLoss = previousLoss
            self.bestEpoch = self.epochCounter
            self.counter = 0
            return False

        self.counter += 1
        if self.counter >= args.earlyStoppingPatients:
            logging.info("Stopping Early, haven't beaten best loss {:.4f} @ Epoch {} in {} epochs".format(
                self.bestLoss,
                self.bestEpoch,
                args.earlyStoppingPatients))
            return True
        # Previously this path fell off the end and returned None.
        return False


In [None]:
# %load ../batchedRNN/model/RoseSeq2Seq
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import Parameter
from torch.autograd import Variable
import numpy as np

class EncoderRNN(nn.Module):
    """Single-step GRU encoder with a linear embedding per channel.

    Each call to ``forward`` consumes one time step.  The input's last axis
    (``input_size``) is embedded to ``hidden_size``; the result is flattened
    to ``(1, batch_size, -1)`` and fed to a GRU whose input width is
    ``hidden_size * args.channels``.

    Args:
        input_size: size of the input's last axis.
        hidden_size: embedding width and GRU hidden width.
        n_layers: number of stacked GRU layers.
        bidirectional: whether the GRU is bidirectional.
        args: configuration object; ``channels``, ``encoder_layer_dropout``,
            ``encoder_input_dropout``, ``batch_size`` and ``cuda`` are read.
    """

    def __init__(self, input_size, hidden_size, n_layers=2, bidirectional=True, args=None):
        super(EncoderRNN, self).__init__()
        self.args = args
        self.input_size = input_size
        self.n_layers = n_layers
        self.hidden_size = hidden_size
        # Remembered so initHidden sizes the state for the GRU that was
        # actually built, not for a separately configured flag.
        self.bidirectional = bidirectional
        self.embedding = nn.Linear(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size * self.args.channels, hidden_size, n_layers,
                          dropout=self.args.encoder_layer_dropout, bidirectional=bidirectional)
        self.input_dropout = nn.Dropout(p=self.args.encoder_input_dropout)

    def forward(self, input, hidden):
        """Encode one time step.

        Args:
            input: tensor for one step; after embedding it must contain
                ``batch_size * hidden_size * channels`` elements.
            hidden: GRU state as returned by ``initHidden`` or a previous call.

        Returns:
            ``(output, hidden)`` from the GRU for a length-1 sequence.
        """
        embedded = self.input_dropout(self.embedding(input))
        # Collapse every non-batch axis into the feature axis and add the
        # length-1 sequence axis expected by nn.GRU.
        embedded = embedded.view(1, self.args.batch_size, -1)
        output, hidden = self.gru(embedded, hidden)
        return output, hidden

    def initHidden(self):
        """Return an all-zero initial GRU state on the configured device.

        Shape is ``(n_layers * directions, args.batch_size, hidden_size)``.
        """
        directions = 2 if self.bidirectional else 1
        result = torch.zeros(self.n_layers * directions, self.args.batch_size, self.hidden_size)
        if self.args.cuda:
            return result.cuda()
        return result

class DecoderRNN(nn.Module):
    """Single-step GRU decoder.

    The previous output is embedded to ``hidden_size``, passed through
    dropout and ReLU, and fed to a unidirectional GRU whose hidden width is
    ``hidden_size`` times the number of encoder directions, so that a
    direction-concatenated encoder state can be used as its initial state.
    A linear layer maps the GRU output back to ``output_size``.

    Args:
        hidden_size: embedding width (GRU input width).
        output_size: width of the decoder's input and output vectors.
        n_layers: number of stacked GRU layers.
        args: configuration object; ``decoder_input_dropout``,
            ``decoder_layer_dropout`` and ``bidirectionalEncoder`` are read.
    """

    def __init__(self, hidden_size, output_size, n_layers=2, args=None):
        super(DecoderRNN, self).__init__()
        self.args = args
        self.n_layers = n_layers
        self.hidden_size = hidden_size

        self.embedding = nn.Linear(output_size, hidden_size)
        self.input_dropout = nn.Dropout(p=self.args.decoder_input_dropout)
        # Encoder state arrives as (layers, batch, hidden_size * directions).
        state_width = hidden_size * (2 if self.args.bidirectionalEncoder else 1)
        self.gru = nn.GRU(hidden_size, state_width, n_layers, dropout=self.args.decoder_layer_dropout)
        # GRU output is (seq_len, batch, state_width).
        self.out = nn.Linear(state_width, output_size)

    def forward(self, input, hidden):
        """Decode one step; returns ``(output, hidden)``.

        ``input`` gets a leading length-1 sequence axis for the GRU, which is
        removed again before the output projection.
        """
        step = F.relu(self.input_dropout(self.embedding(input)))
        gru_out, hidden = self.gru(step.unsqueeze(0), hidden)
        return self.out(gru_out.squeeze(0)), hidden

class Seq2Seq(nn.Module):
    """GRU encoder-decoder that encodes step by step and decodes autoregressively.

    Args:
        args: configuration object.  Read here: ``x_dim``, ``h_dim``,
            ``output_dim``, ``n_layers``, ``bidirectionalEncoder``,
            ``use_schedule_sampling``, ``scheduling_start``,
            ``scheduling_end``, ``n_epochs``, ``input_sequence_len``,
            ``target_sequence_len``, ``batch_size`` and ``cuda``.
    """

    def __init__(self, args):
        super(Seq2Seq, self).__init__()
        self.args = args

        self.enc = EncoderRNN(self.args.x_dim, self.args.h_dim, n_layers=self.args.n_layers, bidirectional=args.bidirectionalEncoder, args=args)

        self.dec = DecoderRNN(self.args.h_dim, self.args.output_dim, n_layers=self.args.n_layers, args=args)

        self.use_schedule_sampling = args.use_schedule_sampling
        self.scheduling_start = args.scheduling_start
        self.scheduling_end = args.scheduling_end

    def _cat_directions(self, h):
        """Merge the two directions of a bidirectional encoder state.

        (#directions * #layers, #batch, hidden_size) -> (#layers, #batch, #directions * hidden_size)
        """
        # Even rows go first, odd rows second, joined along the feature axis.
        return torch.cat([h[0:h.size(0):2], h[1:h.size(0):2]], 2)

    def parameters(self):
        """Return the encoder and decoder parameters as one list."""
        return list(self.enc.parameters()) + list(self.dec.parameters())

    def scheduleSample(self, epoch):
        """Draw the teacher-forcing flag for this forward pass.

        The probability of returning 1 moves linearly from
        ``scheduling_start`` at epoch 0 towards ``scheduling_end`` at
        ``n_epochs`` and is never allowed below ``scheduling_end``.

        Returns:
            1 to feed ground-truth targets to the decoder, 0 to feed its own
            predictions.
        """
        start = self.args.scheduling_start
        end = self.args.scheduling_end
        eps = max(start - (start - end) * epoch / self.args.n_epochs, end)
        return np.random.binomial(1, eps)

    def forward(self, x, target, epoch):
        """Encode ``x`` and decode ``target_sequence_len`` steps.

        Args:
            x: input sequence indexed by time step along axis 0.
            target: ground-truth sequence indexed by time step along axis 0;
                only read when teacher forcing is active.
            epoch: current epoch, used for the teacher-forcing schedule.

        Returns:
            The decoder outputs stacked along a new leading time axis.
        """
        encoder_hidden = self.enc.initHidden()
        for t in range(self.args.input_sequence_len):
            # Only the final hidden state is needed; per-step outputs are dropped.
            _, encoder_hidden = self.enc(x[t].squeeze(), encoder_hidden)
        if self.args.bidirectionalEncoder:
            decoder_hidden = self._cat_directions(encoder_hidden)
        else:
            decoder_hidden = encoder_hidden

        # The first decoder input is an all-zero start token.
        inp = torch.zeros(self.args.batch_size, self.args.output_dim)
        if self.args.cuda:
            inp = inp.cuda()

        # Teacher forcing is decided once per forward pass and never used in eval mode.
        sample = self.scheduleSample(epoch) if self.training else 0

        ys = []
        for t in range(self.args.target_sequence_len):
            decoder_output, decoder_hidden = self.dec(inp, decoder_hidden)
            # The output just produced predicts step t, so the next input is
            # the ground truth for step t (previously target[t-1], which fed
            # target[-1] after the first step).
            inp = target[t] if sample else decoder_output
            ys.append(decoder_output)
        return torch.stack(ys, dim=0)


In [None]:
# %load ../batchedRNN/model/SketchRNN.py
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F

class SketchRNNEncoder(nn.Module):
    """LSTM encoder that maps a whole sequence to a sampled latent vector.

    The final hidden states of every layer and direction are concatenated
    per batch element and projected to ``mu`` and ``sigma_hat``; the latent
    sample is ``z = mu + exp(sigma_hat / 2) * N`` with ``N`` standard normal.

    Args:
        args: configuration object; ``bidirectionalEncoder``, ``x_dim``,
            ``channels``, ``encoder_h_dim``, ``n_layers``,
            ``encoder_layer_dropout``, ``z_dim``, ``batch_size`` and ``cuda``
            are read.
    """

    def __init__(self, args):
        super(SketchRNNEncoder, self).__init__()
        self.args = args
        self.directions = 2 if self.args.bidirectionalEncoder else 1
        # Sequence encoder (bidirectional when configured).
        self.lstm = nn.LSTM(
            self.args.x_dim * self.args.channels,
            self.args.encoder_h_dim,
            self.args.n_layers,
            dropout=self.args.encoder_layer_dropout,
            bidirectional=self.args.bidirectionalEncoder,
        )
        # mu and sigma_hat are linear in the concatenated final hidden states.
        summary_width = self.args.n_layers * self.directions * self.args.encoder_h_dim
        self.fc_mu = nn.Linear(summary_width, self.args.z_dim)
        self.fc_sigma = nn.Linear(summary_width, self.args.z_dim)

    def forward(self, input, hidden_cell=None):
        """Encode ``input`` and sample a latent vector.

        Args:
            input: sequence tensor accepted by the LSTM.
            hidden_cell: optional initial ``(hidden, cell)``; zeros when omitted.

        Returns:
            ``(z, mu, sigma_hat)``, each of shape ``(batch, z_dim)``.
        """
        if hidden_cell is None:
            hidden_cell = self.init_hidden_cell()
        _, (hidden, cell) = self.lstm(input, hidden_cell)
        if self.directions == 2 and self.args.n_layers == 2:
            assert hidden.size(0) == 4
        # (n_layers * directions, batch, h_dim) -> (batch, n_layers * directions * h_dim)
        hidden_cat = torch.cat(hidden.unbind(0), 1)
        mu = self.fc_mu(hidden_cat)
        sigma_hat = self.fc_sigma(hidden_cat)
        noise = torch.normal(torch.zeros(mu.size()), torch.ones(mu.size()))
        if self.args.cuda:
            noise = noise.cuda()
        N = Variable(noise)
        z = mu + torch.exp(sigma_hat / 2) * N
        return z, mu, sigma_hat

    def init_hidden_cell(self):
        """Return all-zero ``(hidden, cell)`` states on the configured device."""
        state_shape = (self.directions * self.args.n_layers, self.args.batch_size, self.args.encoder_h_dim)
        hidden = Variable(torch.zeros(*state_shape))
        cell = Variable(torch.zeros(*state_shape))
        if self.args.cuda:
            return (hidden.cuda(), cell.cuda())
        return (hidden, cell)

class SketchRNNDecoder(nn.Module):
    """LSTM decoder that emits Gaussian-mixture parameters per output dimension.

    The initial hidden and cell states are derived from the latent vector
    ``z`` through a linear layer and ``tanh``.  Three linear heads map each
    LSTM output to mixture weights, means and standard deviations.

    Args:
        args: configuration object; ``z_dim``, ``n_layers``,
            ``decoder_h_dim``, ``output_dim``, ``decoder_layer_dropout``,
            ``n_gaussians`` and ``batch_size`` are read.
    """

    def __init__(self, args):
        super(SketchRNNDecoder, self).__init__()
        self.args = args
        # Produces hidden and cell state for every layer from z.
        self.fc_hc = nn.Linear(self.args.z_dim, 2 * self.args.n_layers * self.args.decoder_h_dim)
        # Unidirectional LSTM over [previous output, z].
        self.lstm = nn.LSTM(self.args.z_dim + self.args.output_dim, self.args.decoder_h_dim,
                            self.args.n_layers, dropout=self.args.decoder_layer_dropout)
        head_width = self.args.output_dim * self.args.n_gaussians
        self.muLayer = nn.Linear(self.args.decoder_h_dim, head_width)
        self.sigmaLayer = nn.Linear(self.args.decoder_h_dim, head_width)
        self.piLayer = nn.Linear(self.args.decoder_h_dim, head_width)

    def forward(self, inputs, z, hidden_cell=None):
        """Run the LSTM and return mixture parameters plus the final state.

        Args:
            inputs: decoder input sequence, ``(seq_len, batch, z_dim + output_dim)``.
            z: latent vectors, ``(batch, z_dim)``; only used to build the
                initial state when ``hidden_cell`` is omitted.
            hidden_cell: optional ``(hidden, cell)`` to continue from.

        Returns:
            ``((pi, mu, sigma), (hidden, cell))`` where the three mixture
            tensors have shape ``(seq_len, batch_size, output_dim,
            n_gaussians)``; ``pi`` is soft-maxed over the last axis and
            ``sigma`` is the exponential of the raw head output.
        """
        if hidden_cell is None:
            # fc_hc output is cut into 2 * n_layers pieces of width decoder_h_dim:
            # the first half are hidden states, the second half cell states.
            pieces = torch.tanh(self.fc_hc(z)).split(self.args.decoder_h_dim, 1)
            half = int(len(pieces) / 2)
            initial_hidden = torch.stack(pieces[:half], dim=0).contiguous()
            initial_cell = torch.stack(pieces[half:], dim=0).contiguous()
            hidden_cell = (initial_hidden, initial_cell)
        outputs, (hidden, cell) = self.lstm(inputs, hidden_cell)
        mixture_shape = (-1, self.args.batch_size, self.args.output_dim, self.args.n_gaussians)
        mu = self.muLayer(outputs).view(*mixture_shape)
        sigma = torch.exp(self.sigmaLayer(outputs).view(*mixture_shape))
        pi = F.softmax(self.piLayer(outputs).view(*mixture_shape), 3)
        return (pi, mu, sigma), (hidden, cell)

class SketchRNN(nn.Module):
    """Variational sequence model: LSTM encoder to latent ``z``, mixture-density LSTM decoder.

    In training mode the decoder is run over the whole target sequence at
    once with the ground truth as input; in eval mode it is unrolled one
    step at a time on its own sampled predictions.

    Args:
        args: configuration object.  Read here: ``cuda``,
            ``scheduling_start``, ``scheduling_end``, ``n_epochs``,
            ``target_sequence_len``, ``batch_size``, ``output_dim``,
            ``x_dim`` and ``channels``.
    """

    def __init__(self, args):
        super(SketchRNN, self).__init__()
        self.args = args
        if self.args.cuda:
            self.encoder = SketchRNNEncoder(args).cuda()
            self.decoder = SketchRNNDecoder(args).cuda()
        else:
            self.encoder = SketchRNNEncoder(args)
            self.decoder = SketchRNNDecoder(args)

    def scheduleSample(self, epoch):
        """Draw a Bernoulli flag whose probability decays linearly with ``epoch``.

        The probability moves from ``scheduling_start`` at epoch 0 towards
        ``scheduling_end`` at ``n_epochs`` and never drops below
        ``scheduling_end``.
        """
        start = self.args.scheduling_start
        end = self.args.scheduling_end
        # Fixed attribute path: this used to read self.args.self.args.n_epochs.
        eps = max(start - (start - end) * epoch / self.args.n_epochs, end)
        return np.random.binomial(1, eps)

    def generatePred(self, pi, mu, sigma):
        """Sample from every mixture component and combine with the weights.

        Args:
            pi, mu, sigma: mixture weights, means and standard deviations
                with the component axis at index 3.

        Returns:
            ``sum_k pi_k * (mu_k + sigma_k * N_k)`` over the component axis.
        """
        # Noise is created directly on the device of the mixture parameters.
        N = torch.randn(pi.size(), device=pi.device)
        clusterPredictions = mu + sigma * N
        weightedClusterPredictions = clusterPredictions * pi
        return torch.sum(weightedClusterPredictions, dim=3)

    def allSteps(self, target, z):
        """Teacher-forced decoding of the full sequence in one LSTM call.

        The decoder input at step t is the target at step t-1 (a zero start
        token at t=0) concatenated with ``z``.
        """
        sos = self.getStartOfSequence()
        batch_init = torch.cat([sos, target[:-1, ...]], 0)
        z_stack = torch.stack([z] * (self.args.target_sequence_len))
        inp = torch.cat([batch_init, z_stack], 2)
        (pi, mu, sigma), _ = self.decoder(inp, z)
        return (pi, mu, sigma)

    def oneStepAtATime(self, z):
        """Autoregressive decoding on the model's own sampled predictions.

        Returns:
            ``(Pi, Mu, Sigma)`` concatenated over time and detached from the
            autograd graph.
        """
        sos = self.getStartOfSequence()
        inp = torch.cat([sos, z.unsqueeze(0)], 2)
        piList, muList, sigmaList = [], [], []
        # None makes the decoder derive its initial state from z; afterwards
        # the LSTM state is carried forward so each step continues the
        # sequence instead of restarting from z.
        hidden_cell = None
        for timeStep in range(self.args.target_sequence_len):
            (pi, mu, sigma), hidden_cell = self.decoder(inp, z, hidden_cell)
            pred = self.generatePred(pi, mu, sigma)
            inp = torch.cat([pred, z.unsqueeze(0)], 2)
            piList.append(pi.detach())
            muList.append(mu.detach())
            sigmaList.append(sigma.detach())
        Pi = torch.cat(piList, 0)
        Mu = torch.cat(muList, 0)
        Sigma = torch.cat(sigmaList, 0)
        return (Pi, Mu, Sigma)

    def getStartOfSequence(self):
        """Return the all-zero start token of shape ``(1, batch_size, output_dim)``."""
        sos = torch.zeros(1, self.args.batch_size, self.args.output_dim)
        if self.args.cuda:
            return sos.cuda()
        return sos

    def doEncoding(self, batch):
        """Flatten the per-step features of ``batch`` and run the encoder.

        The batch is viewed as ``(-1, batch_size, x_dim * channels)``.

        Returns:
            ``(z, mu, sigma_hat)`` from the encoder.
        """
        embedded = batch.contiguous().view(-1, self.args.batch_size, self.args.x_dim * self.args.channels)
        z, mu, sigma_hat = self.encoder(embedded)
        return z, mu, sigma_hat

    def forward(self, batch, target, epoch):
        """Encode ``batch`` and decode a sequence of mixture parameters.

        ``epoch`` is accepted but not used here.

        Returns:
            ``(Pi, Mu, Sigma, latentMean, latentStd)`` where the last two are
            the encoder's ``mu`` and ``sigma_hat``.
        """
        z, latentMean, latentStd = self.doEncoding(batch)
        if self.training:
            (Pi, Mu, Sigma) = self.allSteps(target, z)
        else:
            (Pi, Mu, Sigma) = self.oneStepAtATime(z)
        return Pi, Mu, Sigma, latentMean, latentStd


In [None]:
import ast

class Bunch(object):
    """Expose the entries of a mapping as attributes of an object."""

    def __init__(self, adict):
        self.__dict__.update(adict)

    def __str__(self):
        """Render every attribute as ``"name: value, "`` joined together.

        Each entry, including the last one, ends with ``", "``.
        """
        return "".join(
            "{}: {}, ".format(name, value) for name, value in self.__dict__.items()
        )
import json

In [None]:
class PostProcess(object):
    """Evaluate a saved Seq2Seq / SketchRNN checkpoint on a data split.

    All configuration is read from ``self.args`` (the ``args`` object passed
    to the constructor), never from a module-level global, so the class works
    the same inside and outside the notebook.
    """

    def __init__(self, modelPath,args, chooseModel="rnn", dataDict=None):
        """Build the model in eval mode and attach the data dictionary.

        Args:
            modelPath: prefix that is string-concatenated with the state-dict
                file name in ``prep``.
            args: configuration object; forwarded to the model constructor.
            chooseModel: ``"rnn"`` for ``Seq2Seq`` or ``"sketch-rnn"`` for
                ``SketchRNN``.
            dataDict: pre-built data dictionary; when falsy, one is created
                with ``GetDataLoader`` from the default traffic data path.

        Raises:
            ValueError: if ``chooseModel`` is not one of the two known names.
        """
        self.args = args
        self.modelPath = modelPath
        if dataDict:
            self.dataDict = dataDict
        else:
            self.dataDict = GetDataLoader("../data/traffic/trafficWithTime/")
        if chooseModel=="rnn":
            self.model = Seq2Seq(self.args)
        elif chooseModel=="sketch-rnn":
            self.model = SketchRNN(self.args)
        else:
            # Previously fell through and crashed on self.model.eval() with AttributeError.
            raise ValueError("Unknown chooseModel: {}".format(chooseModel))
        self.model.eval()

    def getReconLoss(self, output, target):
        """Compute the reconstruction loss on inverse-scaled values.

        Both tensors are passed through ``dataDict["scaler"].inverse_transform``
        before the loss named by ``self.args.criterion`` is evaluated.

        Returns:
            ``(loss, output, target)`` where ``loss`` is a Python float and the
            two tensors are the inverse-transformed versions.

        Raises:
            ValueError: if ``self.args.criterion`` is neither ``"RMSE"`` nor
                ``"L1Loss"``.
        """
        scaler = self.dataDict["scaler"]
        output = scaler.inverse_transform(output)
        target = scaler.inverse_transform(target)
        assert output.size() == target.size(), "output size: {}, target size: {}".format(output.size(), target.size())

        criterion = self.args.criterion
        if criterion == "RMSE":
            mse = nn.MSELoss()
            loss = torch.sqrt(mse(output, target)).item()
        elif criterion == "L1Loss":
            l1loss = nn.L1Loss()
            loss = l1loss(output, target).item()
        else:
            # Previously left `loss` unbound and failed with UnboundLocalError.
            raise ValueError("Unsupported criterion: {}".format(criterion))
        return loss, output, target

    def sketchRNNKLD(self, latentMean, latentStd, trainingMode, epoch):
        """Compute the KL term, normalised by ``z_dim * batch_size``.

        ``latentStd`` enters the formula both linearly and as
        ``exp(latentStd)``.  # NOTE(review): that is the form used for a
        log-variance-like quantity — confirm against the encoder.

        In training mode the value is floored at ``args.KL_min`` and scaled by
        the annealing weight ``1 - (1 - eta_min) * R ** epoch``; otherwise the
        raw term is returned.
        """
        args = self.args
        LKL = -0.5*torch.sum(1+latentStd-latentMean**2-torch.exp(latentStd))\
                /float(args.z_dim*args.batch_size)
        if not trainingMode:
            return LKL
        # Annealing weight for the KL term at this epoch.
        eta_step = 1-(1-args.eta_min)*args.R**epoch
        KL_min = torch.Tensor([args.KL_min])
        if args.cuda:
            KL_min = KL_min.cuda()
        return eta_step * torch.max(LKL, KL_min)

    def getLoss(self, output, target, epoch):
        """Turn a model output into losses plus inverse-scaled pred/target.

        For ``args.model == "rnn"`` the output is a prediction tensor and the
        KLD loss is 0. Otherwise the output is unpacked as
        ``(Pi, Mu, Sigma, latentMean, latentStd)`` and those five values are
        also returned under the ``"additional"`` key.

        Returns:
            dict with keys ``reconLoss``, ``kldLoss``, ``pred``, ``target``
            and ``additional``.
        """
        if self.args.model == "rnn":
            reconLoss, pred, target = self.getReconLoss(output, target)
            kldLoss = 0
            addtlDict = {}
        else:
            Pi, Mu, Sigma, latentMean, latentStd = output
            scaler = self.dataDict["scaler"]
            # The mixture loss is computed on the still-scaled target.
            reconLoss = self.sketchRNNReconLoss(target, Pi, Mu, Sigma)
            kldLoss = self.sketchRNNKLD(latentMean, latentStd, self.model.training, epoch)
            pred = scaler.inverse_transform(self.model.generatePred(Pi, Mu, Sigma))
            target = scaler.inverse_transform(target)
            addtlDict = {
                "Pi" : Pi,
                "Mu" : Mu,
                "Sigma": Sigma,
                "latentMean" : latentMean,
                "latentStd" : latentStd
            }
        outputDict = {
            "reconLoss" : reconLoss,
            "kldLoss" : kldLoss,
            "pred" : pred,
            "target" : target,
            "additional" : addtlDict
        }
        return outputDict

    def sketchRNNReconLoss(self, target, Pi, Mu, Sigma):
        """Mean negative log-likelihood of ``target`` under a Normal mixture.

        ``target`` is repeated ``Mu.size(3)`` times along a new dim 3 so it
        lines up with the mixture components, which are weighted by ``Pi`` and
        summed over that dim.
        """
        stackedTarget = torch.stack([target] * Mu.size(3), dim=3)
        m = torch.distributions.Normal(loc=Mu, scale=Sigma)
        # NOTE(review): exp() of a very negative log-prob underflows to 0 and
        # yields an infinite loss; a logsumexp formulation would avoid that.
        loss = torch.exp(m.log_prob(stackedTarget))
        loss = torch.sum(loss * Pi, dim=3)
        loss= -torch.log(loss)
        return loss.mean()

    def getEpochLoss(self, dataset, epoch):
        """Run the model over every batch of ``dataset`` without gradients.

        Batches come from ``self.dataDict[dataset]`` after being mapped through
        ``scaler.transformBatchForEpoch``.

        Returns:
            dict with the batch-averaged ``reconLoss`` and ``kldLoss``, the
            concatenated NumPy arrays ``preds``, ``targets`` and ``datas``
            (axis 0), and ``Pi``, ``Mu``, ``Sigma`` (concatenated on axis 1
            for ``"sketch-rnn"``, empty lists otherwise).
        """
        scaler = self.dataDict["scaler"]
        isSketch = self.args.model == "sketch-rnn"
        datas = []
        preds = []
        targets = []
        Pi = []
        Mu = []
        Sigma = []
        epoch_recon_val_loss = 0.0
        epoch_kld_val_loss = 0.0
        nValBatches = 0
        with torch.no_grad():
            for inputData, target in map(scaler.transformBatchForEpoch, self.dataDict[dataset]):
                nValBatches += 1
                output = self.model(inputData, target, epoch)
                lossOutputs = self.getLoss(output, target, epoch)
                epoch_recon_val_loss += lossOutputs["reconLoss"]
                epoch_kld_val_loss += lossOutputs["kldLoss"]
                preds.append(lossOutputs["pred"].cpu().detach().numpy())
                targets.append(lossOutputs["target"].cpu().detach().numpy())
                # Only index 0 of dim 2 of the input is kept for the returned "datas".
                datas.append(scaler.inverse_transform(inputData[:,:,0,:]).cpu().detach().numpy())
                if isSketch:
                    # Convert to NumPy on the CPU like preds/targets; np.concatenate
                    # cannot consume CUDA tensors directly.
                    extras = lossOutputs["additional"]
                    Pi.append(extras["Pi"].cpu().detach().numpy())
                    Mu.append(extras["Mu"].cpu().detach().numpy())
                    Sigma.append(extras["Sigma"].cpu().detach().numpy())

        datas = np.concatenate(datas, axis=0)
        preds = np.concatenate(preds, axis=0)
        targets = np.concatenate(targets, axis=0)
        if isSketch:
            Pi = np.concatenate(Pi, axis=1)
            Mu = np.concatenate(Mu, axis=1)
            Sigma = np.concatenate(Sigma, axis=1)
        retVals = {
            "reconLoss" : epoch_recon_val_loss / nValBatches,
            "kldLoss" : epoch_kld_val_loss / nValBatches,
            "preds" : preds,
            "targets" : targets,
            "datas" : datas,
            "Pi" : Pi,
            "Mu" : Mu,
            "Sigma" : Sigma
        }
        return retVals

    def prep(self, stateDictFile, dataset, epoch):
        """Load a saved state dict into the model and switch to eval mode.

        ``dataset`` must be ``"train"``, ``"val"`` or ``"test"``; it is checked
        first so a bad name fails before any file is read. ``epoch`` is
        accepted but not used here.
        """
        assert dataset in ["train", "val", "test"]
        # SECURITY: torch.load unpickles the file — only load trusted checkpoints.
        # map_location keeps every tensor on the CPU regardless of where it was saved.
        desired_state_dict = torch.load(self.modelPath +stateDictFile, map_location=lambda storage, loc: storage)
        self.model.load_state_dict(desired_state_dict)
        self.model.eval()

    def getLossAtEpoch(self, stateDictFile, dataset, epoch= -1):
        """Load ``stateDictFile`` and return ``getEpochLoss`` for ``dataset``."""
        self.prep(stateDictFile, dataset, epoch)
        retVals = self.getEpochLoss(dataset, epoch)
        return retVals

In [None]:
def plotTrainValCurve(trainLosses, valLosses, trainKLDLosses=None, valKLDLosses=None):
    """Plot train/validation reconstruction losses, optionally with KLD losses.

    Reconstruction losses go on the left (red) axis, labelled with the global
    ``args.criterion``. When ``trainKLDLosses`` is given and non-empty, a
    second (blue) right-hand axis shows the KLD series; the validation KLD
    curve is drawn only when ``valKLDLosses`` is also provided.

    Args:
        trainLosses: per-epoch training reconstruction losses.
        valLosses: per-epoch validation reconstruction losses.
        trainKLDLosses: optional per-epoch training KLD losses.
        valKLDLosses: optional per-epoch validation KLD losses.
    """
    plot_every = 1
    plt.rcParams.update({'font.size': 8})
    # plt.subplots() creates its own figure; a separate plt.figure() call
    # beforehand would only leave an empty figure behind.
    _fig, ax1 = plt.subplots()
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel(args.criterion, color="r")
    ax1.tick_params('y', colors='r')
    ax1.plot(np.arange(1, len(trainLosses)+1)*plot_every, trainLosses, "r--", label="train reconstruction loss")
    ax1.plot(np.arange(1, len(valLosses)+1)*plot_every, valLosses, color="red", label="validation reconstruction loss")
    ax1.legend(loc="upper left")
    ax1.grid()
    # Explicit None/length test: plain truthiness raises for NumPy arrays.
    if trainKLDLosses is not None and len(trainKLDLosses) > 0:
        ax2 = ax1.twinx()
        ax2.set_ylabel("KLD Loss", color="b")
        ax2.tick_params('y', colors='b')
        ax2.plot(np.arange(1, len(trainKLDLosses)+1)*plot_every, trainKLDLosses, "b--", label="train KLD loss")
        if valKLDLosses is not None:
            ax2.plot(np.arange(1, len(valKLDLosses)+1)*plot_every, valKLDLosses, color="blue", label="val KLD loss")
        ax2.legend(loc="upper right")
        ax2.grid()
    plt.title("Losses for {}".format(args.model))

In [None]:
def plotNHours(means, stds, targets, datas, dataset, dataMean, dataStd, targetTimes, N=24):
    assert False, "Need to fix"
    instance = np.random.randint(targets.shape[1])
    sensor = np.random.randint(targets.shape[2])
    sequenceTrueMean = []
    sequenceTrueStd = []
    sequenceSampleMean = []
    sequenceSampleStd = []
    sequenceTarget = []
    sequenceTimes = []
    shouldMask = []
    maskindex = []
    lastTime = None
    for tStep in range(N):
        realIndex = instance + 12 * tStep
        if realIndex >= means.shape[1]:
            break
        if lastTime and inMinutes(targetTimes[realIndex, -1] - lastTime) > 5:
            shouldMask += [True]
        else:
            shouldMask += [False]
        lastTime = targetTimes[realIndex, -1]
        maskindex += [len(sequenceTrueMean)]
        m = means[:, realIndex, sensor]
        std = stds[:, realIndex, sensor]
        predSamples, sampleMean, sampleStd = getScaledSamples(m, std, dataMean, dataStd)
        sequenceTrueMean += list(m)
        sequenceTrueStd += list(std)
        sequenceSampleMean += list(sampleMean)
        sequenceSampleStd += list(sampleStd)
        sequenceTarget += list(targets[:, realIndex, sensor])
        sequenceTimes += list(targetTimes[realIndex])
        
    #f, ax = plt.subplots(2, sharex=True)
    #f.subplots_adjust(hspace=.5)
    """
    maskedSampleMean = ma.array(sequenceSampleMean)
    maskedTarget = ma.array(sequenceTarget)
    print(maskindex)
    print(shouldMask)
    print(maskedSampleMean.shape)
    for idx, should in zip(maskindex, shouldMask):
        if should:
            maskedSampleMean[idx] = ma.masked
            maskedTarget[idx] = ma.masked
    """
    #print(np.max(sequenceSampleStd), sequenceTimes[np.argmax(sequenceSampleStd)])
    #print(sequenceSampleMean)
    f, ax = plt.subplots()
    f.set_figwidth(15)
    plt.plot(sequenceTimes, sequenceSampleMean, label="pred")
    plt.plot(sequenceTimes, sequenceTarget, label="target")
    plt.fill_between(sequenceTimes,np.array(sequenceSampleMean)-1.96*np.array(sequenceSampleStd), np.array(sequenceSampleMean)+1.96*np.array(sequenceSampleStd), alpha=0.5)
    plt.xticks(rotation=90)
    xfmt = md.DateFormatter('%Y-%m-%d %H:%M:%S')
    ax=plt.gca()
    ax.xaxis.set_major_formatter(xfmt)
    plt.legend()
    plt.ylabel("mile/h")
    plt.title("{} Hour Sample Prediction {}".format(N, dataset))
    yMin = np.min((np.min(sequenceSampleMean)-10, np.min(sequenceTarget)-10, 10))
    yMax = np.max((np.max(sequenceSampleMean)+10, np.max(sequenceTarget)+10, 70))
    #plt.ylim((yMin,yMax))