In [None]:
from model.Data import DataLoader
import numpy as np
import torch
from utils import kld_gauss, unNormalize, load_dataset
from model.vrnn.model import VRNN
from joblib import Parallel, delayed

In [None]:
# Batch size for the loaders built by load_dataset; presumably should match
# "batch_size" in the run configuration further down — confirm.
batchSize = 10
dataDict = load_dataset("./data/reformattedTraffic/",
                        batchSize)

In [None]:
def getEpochLoss(model, dataLoader, args, dataDict, limit=100):
    """Evaluate ``model`` on up to ``limit`` batches and return average losses.

    For every evaluated batch two quantities are computed:

    * the KL divergence between the encoder and prior distributions, summed
      over time steps, weighted by ``args.kld_weight`` and divided by
      ``args.sequence_len``;
    * the reconstruction error (``args.criterion``: ``"RMSE"`` or ``"L1Loss"``)
      between the model's samples and the target, after both are passed
      through ``unNormalize`` with ``dataDict["train_mean"]`` /
      ``dataDict["train_std"]``.

    Args:
        model: callable ``model(data, target, epoch)`` returning the 7-tuple
            ``(encoder_means, encoder_stds, decoder_means, decoder_stds,
            prior_means, prior_stds, all_samples)`` of per-time-step sequences.
        dataLoader: object whose ``get_iterator()`` yields
            ``(data, target, dataT, targetT)`` batches.
        args: namespace providing ``kld_weight``, ``sequence_len`` and
            ``criterion``.
        dataDict: mapping with ``"train_mean"`` and ``"train_std"``.
        limit: maximum number of batches to evaluate; ``None`` evaluates
            every batch the loader yields.

    Returns:
        ``(avg_kld_loss, avg_recon_loss)`` as Python floats, averaged over the
        batches that were actually evaluated.

    Raises:
        ValueError: if ``args.criterion`` is unsupported or no batch was
            evaluated (empty loader or ``limit=0``).
    """
    if args.criterion not in ("RMSE", "L1Loss"):
        # Fail fast instead of hitting an unbound loss variable mid-loop.
        raise ValueError("unsupported criterion: {!r}".format(args.criterion))

    epoch = 0  # forwarded to the model unchanged; its meaning is defined by the model — confirm
    totalKLDLoss = 0.0
    totalReconLoss = 0.0
    numBatches = 0
    # NOTE(review): the model's train/eval mode is not changed here; call
    # model.eval() beforehand if it contains train-only layers.
    # Evaluation only: no_grad keeps the running totals from retaining the
    # autograd graph of every evaluated batch.
    with torch.no_grad():
        for batch_idx, (data, target, _dataT, _targetT) in enumerate(dataLoader.get_iterator()):
            if batch_idx == limit:
                break
            if batch_idx % 100 == 0:
                print("batch", batch_idx)
            # Swap the first two axes so the model sees the time axis first
            # (assumes loader batches are (batch, seq, features) — confirm).
            data = torch.as_tensor(data, dtype=torch.float, device="cpu").transpose(0, 1)
            target = torch.as_tensor(target, dtype=torch.float, device="cpu").transpose(0, 1)
            (encoder_means, encoder_stds, _decoder_means, _decoder_stds,
             prior_means, prior_stds, all_samples) = model(data, target, epoch)
            del data

            # Weighted KL divergence for this batch, summed over time steps,
            # then divided by the sequence length.
            batchKLDLoss = 0.0
            for enc_mean_t, enc_std_t, prior_mean_t, prior_std_t in zip(
                    encoder_means, encoder_stds, prior_means, prior_stds):
                batchKLDLoss += args.kld_weight * kld_gauss(enc_mean_t, enc_std_t, prior_mean_t, prior_std_t)
            totalKLDLoss += float(batchKLDLoss) / args.sequence_len

            # Reconstruction error on un-normalized values.
            pred = torch.stack(list(all_samples))
            assert pred.size() == target.size()
            unNPred = unNormalize(pred, dataDict["train_mean"], dataDict["train_std"])
            unNTarget = unNormalize(target, dataDict["train_mean"], dataDict["train_std"])
            if args.criterion == "RMSE":
                reconLoss = torch.sqrt(torch.mean((unNPred - unNTarget) ** 2))
            else:  # "L1Loss"
                reconLoss = torch.mean(torch.abs(unNPred - unNTarget))
            totalReconLoss += float(reconLoss)
            numBatches += 1

    if numBatches == 0:
        raise ValueError("no batches were evaluated (empty loader or limit=0)")
    return totalKLDLoss / numBatches, totalReconLoss / numBatches

In [None]:
# Load the full pickled model object saved by the training run in model513.
# map_location="cpu" lets a checkpoint saved from a GPU run load on a
# CPU-only machine; this notebook evaluates on CPU (device="cpu" in getEpochLoss).
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# Recent PyTorch releases default to weights_only=True, which rejects
# full-model pickles like this one; pass weights_only=False there if loading fails.
modelOld = torch.load("./save/models/model513/vrnn_full_model.pth", map_location="cpu")

In [None]:
class Bunch(object):
    """Expose the entries of a mapping as instance attributes.

    ``Bunch({"a": 1}).a == 1``. Anything accepted by ``dict.update`` (a
    mapping or an iterable of key/value pairs) can be passed in.
    """

    def __init__(self, adict):
        # vars(self) is the instance __dict__, so every entry becomes an attribute.
        vars(self).update(adict)
import json

In [None]:
# Hyperparameters of the training run (save_dir points at model513),
# stored as a JSON string; the pieces below concatenate into one literal.
argsStr = (
    '{"h_dim": 512, "z_dim": 128, "no_cuda": true, "no_attn": true, '
    '"n_epochs": 500, "batch_size": 10, "n_layers": 2, "initial_lr": 0.001, '
    '"no_lr_decay": true, "lr_decay_ratio": 0.1, "lr_decay_beginning": 20, '
    '"lr_decay_every": 10, "print_every": 20, "plot_every": 1, '
    '"criterion": "RMSE", "save_freq": 10, "down_sample": 0, '
    '"data_dir": "./data", "model": "vrnn", "weight_decay": 5e-05, '
    '"no_schedule_sampling": false, "scheduling_start": 1.0, '
    '"scheduling_end": 0.0, "tries": 10, "kld_weight": 0.1, '
    '"save_dir": "./save/models/model513/", "cuda": false, "_device": "cpu", '
    '"use_attn": false, "x_dim": 207, "sequence_len": 12, '
    '"use_schedule_sampling": true}'
)

In [None]:
# Decode the JSON configuration string into a plain dict.
argsD = json.loads(argsStr)

In [None]:
# Attribute-style view of the configuration (args.kld_weight, args.criterion, ...).
args = Bunch(adict=argsD)

In [None]:
# Build a fresh VRNN from the decoded configuration; its weights are replaced
# by the loaded model's state_dict afterwards.
modelNew = VRNN(args)

In [None]:
# Copy the loaded model's weights into the freshly built VRNN (presumably so
# evaluation runs against the current VRNN class definition — confirm).
# load_state_dict is strict by default and raises if parameter names differ.
# NOTE(review): the train/eval mode of modelNew is not changed here; consider
# modelNew.eval() if VRNN has dropout or similar train-only layers.
modelNew.load_state_dict(modelOld.state_dict())

In [None]:
# Average KLD and reconstruction loss over at most the first 100 training batches.
avgkldLoss, avgreconLoss = getEpochLoss(
    modelNew, dataDict["train_loader"], args, dataDict, limit=100
)
print(avgkldLoss, avgreconLoss)

In [None]:
# Same metrics on at most the first 100 validation batches.
avgkldLoss, avgreconLoss = getEpochLoss(
    modelNew, dataDict["val_loader"], args, dataDict, limit=100
)
print(avgkldLoss, avgreconLoss)