In [307]:
import json
import sys
sys.path.append("../batchedRNN")
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [343]:
# %load ../batchedRNN/model/SketchRNN.py
import torch
import torch.nn as nn
from torch.autograd import Variable
from torch import optim
import torch.nn.functional as F

class SketchRNNEncoder(nn.Module):
    """LSTM encoder that turns an input sequence into a sampled latent vector.

    The final hidden state of every layer and direction is concatenated and
    projected by two linear layers into ``mu`` and ``sigma_hat``; ``z`` is then
    drawn as ``mu + exp(sigma_hat / 2) * noise``. All layer sizes are read from
    the module-level ``args`` object.
    """

    def __init__(self):
        super(SketchRNNEncoder, self).__init__()
        self.directions = 2 if args.bidirectionalEncoder else 1
        # (optionally bidirectional) lstm over the flattened channels * x_dim features:
        self.lstm = nn.LSTM(args.x_dim * args.channels, args.encoder_h_dim,
            args.n_layers, dropout=args.encoder_layer_dropout, bidirectional=args.bidirectionalEncoder)
        # create mu and sigma from the lstm's final hidden states:
        flat_hidden_dim = args.n_layers * self.directions * args.encoder_h_dim
        self.fc_mu = nn.Linear(flat_hidden_dim, args.z_dim)
        self.fc_sigma = nn.Linear(flat_hidden_dim, args.z_dim)

    def forward(self, input, hidden_cell=None):
        """Encode ``input`` and sample a latent vector.

        Args:
            input: tensor fed straight to the LSTM, i.e.
                (sequence_len, batch_size, x_dim * channels).
            hidden_cell: optional ``(hidden, cell)`` initial state; zeros sized
                for the incoming batch are used when omitted.

        Returns:
            ``(z, mu, sigma_hat)``, each of shape (batch_size, z_dim). Note that
            the third value is ``sigma_hat`` (the pre-exponential output of
            ``fc_sigma``), not the standard deviation itself.
        """
        if hidden_cell is None:
            # Size the initial state from the actual batch rather than
            # args.batch_size so batches of any size can be encoded.
            hidden_cell = self.init_hidden_cell(input.size(1))
        _, (hidden, cell) = self.lstm(input, hidden_cell)
        assert hidden.size(0) == self.directions * args.n_layers
        # convert hidden size from (n_layers * directions, batch_size, h_dim)
        #                       to (batch_size, n_layers * directions * h_dim)
        hidden_cat = torch.cat(torch.unbind(hidden, 0), 1)
        mu = self.fc_mu(hidden_cat)
        sigma_hat = self.fc_sigma(hidden_cat)
        sigma = torch.exp(sigma_hat / 2)
        # randn_like draws standard-normal noise directly on mu's device/dtype.
        z = mu + sigma * torch.randn_like(mu)
        return z, mu, sigma_hat

    def init_hidden_cell(self, batch_size=None):
        """Return an all-zero ``(hidden, cell)`` pair for the LSTM.

        Args:
            batch_size: batch dimension of the state; defaults to
                ``args.batch_size`` when omitted.

        Returns:
            Two tensors of shape (directions * n_layers, batch_size,
            encoder_h_dim), moved to CUDA when ``args.cuda`` is set.
        """
        if batch_size is None:
            batch_size = args.batch_size
        state_shape = (self.directions * args.n_layers, batch_size, args.encoder_h_dim)
        hidden = torch.zeros(state_shape)
        cell = torch.zeros(state_shape)
        if args.cuda:
            return (hidden.cuda(), cell.cuda())
        return (hidden, cell)

class SketchRNNDecoder(nn.Module):
    """LSTM decoder that emits Gaussian-mixture parameters for every time step.

    For each output dimension the decoder predicts ``args.n_gaussians`` mixture
    weights, means and standard deviations. The initial LSTM state is derived
    from the latent vector ``z`` unless one is supplied. All layer sizes are
    read from the module-level ``args`` object.
    """

    def __init__(self):
        super(SketchRNNDecoder, self).__init__()
        # to init hidden and cell from z:
        self.fc_hc = nn.Linear(args.z_dim, 2 * args.n_layers * args.decoder_h_dim)
        # unidirectional lstm:
        self.lstm = nn.LSTM(args.z_dim + args.output_dim, args.decoder_h_dim, args.n_layers, dropout=args.decoder_layer_dropout)
        self.muLayer = nn.Linear(args.decoder_h_dim, args.output_dim * args.n_gaussians)
        self.sigmaLayer = nn.Linear(args.decoder_h_dim, args.output_dim * args.n_gaussians)
        self.piLayer = nn.Linear(args.decoder_h_dim, args.output_dim * args.n_gaussians)

    def forward(self, inputs, z, hidden_cell=None):
        """Run the decoder LSTM and project its outputs to mixture parameters.

        Args:
            inputs: LSTM input of shape (seq_len, batch, z_dim + output_dim).
            z: latent vectors of shape (batch, z_dim); only used to build the
                initial state when ``hidden_cell`` is None.
            hidden_cell: optional ``(hidden, cell)`` state to continue from.

        Returns:
            ``((pi, mu, sigma), (hidden, cell))`` where pi, mu and sigma have
            shape (seq_len, batch, output_dim, n_gaussians); pi is
            softmax-normalised over the last axis and sigma is exponentiated.
        """
        if hidden_cell is None:
            # fc_hc yields 2 * n_layers chunks of decoder_h_dim: the first half
            # become the per-layer hidden states, the second half the cells.
            layers = torch.split(torch.tanh(self.fc_hc(z)), args.decoder_h_dim, 1)
            half = len(layers) // 2
            hidden = torch.stack(layers[:half], dim=0)
            cell = torch.stack(layers[half:], dim=0)
            hidden_cell = (hidden.contiguous(), cell.contiguous())
        outputs, (hidden, cell) = self.lstm(inputs, hidden_cell)
        # outputs size: (seq_len, batch, num_directions * hidden_size)
        # hidden size: (num_layers * num_directions, batch, hidden_size)
        # cell size: (num_layers * num_directions, batch, hidden_size)
        # Take seq_len and batch from the LSTM output instead of
        # args.batch_size so a differently sized batch cannot be mis-shaped.
        mixture_shape = (outputs.size(0), outputs.size(1), args.output_dim, args.n_gaussians)
        mu = self.muLayer(outputs).view(mixture_shape)
        sigma = torch.exp(self.sigmaLayer(outputs).view(mixture_shape))
        pi = F.softmax(self.piLayer(outputs).view(mixture_shape), 3)
        return (pi, mu, sigma), (hidden, cell)

class SketchRNN(nn.Module):
    """Encoder/decoder pair: encodes a batch to ``z`` and decodes mixture parameters.

    In training mode the decoder is teacher-forced on the target sequence
    (``allSteps``); otherwise it is unrolled on its own sampled predictions
    (``oneStepAtATime``). Sizes and the CUDA flag come from the module-level
    ``args`` object.
    """

    def __init__(self):
        super(SketchRNN, self).__init__()
        if args.cuda:
            self.encoder = SketchRNNEncoder().cuda()
            self.decoder = SketchRNNDecoder().cuda()
        else:
            self.encoder = SketchRNNEncoder()
            self.decoder = SketchRNNDecoder()

    def scheduleSample(self, epoch):
        """Draw a 0/1 sample whose success probability decays linearly with ``epoch``.

        The probability goes from ``args.scheduling_start`` towards
        ``args.scheduling_end`` over ``args.n_epochs`` and is floored at
        ``args.scheduling_end``.
        """
        # Local import: numpy is not among the imports of the model cell/module.
        import numpy as np
        decayed = (args.scheduling_start -
            (args.scheduling_start - args.scheduling_end) * epoch / args.n_epochs)
        eps = max(decayed, args.scheduling_end)
        return np.random.binomial(1, eps)

    def generatePred(self, pi, mu, sigma):
        """Sample every mixture component and return the pi-weighted sum over components.

        Args:
            pi, mu, sigma: tensors of identical shape whose axis 3 indexes the
                mixture components.

        Returns:
            Tensor with axis 3 summed out.
        """
        # randn_like draws the noise on the same device/dtype as pi.
        noise = torch.randn_like(pi)
        clusterPredictions = mu + sigma * noise
        weightedClusterPredictions = clusterPredictions * pi
        return torch.sum(weightedClusterPredictions, dim=3)

    def allSteps(self, target, z):
        """Teacher-forced decoding: feed the target shifted right by one step.

        Args:
            target: (sequence_len, batch, output_dim) ground-truth sequence.
            z: (batch, z_dim) latent vectors, appended to every time step.

        Returns:
            ``(pi, mu, sigma)`` from the decoder for all time steps.
        """
        sos = self.getStartOfSequence(target.size(1))
        batch_init = torch.cat([sos, target[:-1, ...]], 0)
        # Repeat z once per decoder input step (taken from the target length
        # rather than args.sequence_len so the two can never disagree).
        z_stack = torch.stack([z] * batch_init.size(0))
        inp = torch.cat([batch_init, z_stack], 2)
        (pi, mu, sigma), _ = self.decoder(inp, z)
        return (pi, mu, sigma)

    def oneStepAtATime(self, z):
        """Autoregressive decoding for ``args.sequence_len`` steps.

        Each step feeds the previous sampled prediction (initially the
        start-of-sequence zeros) together with ``z``. The decoder state is
        carried from step to step; only the first step initialises it from
        ``z``.

        Returns:
            ``(Pi, Mu, Sigma)`` concatenated along the time axis.
        """
        sos = self.getStartOfSequence(z.size(0))
        inp = torch.cat([sos, z.unsqueeze(0)], 2)
        piList, muList, sigmaList = [], [], []
        hidden_cell = None  # None -> decoder derives the initial state from z
        for _ in range(args.sequence_len):
            (pi, mu, sigma), hidden_cell = self.decoder(inp, z, hidden_cell)
            pred = self.generatePred(pi, mu, sigma)
            inp = torch.cat([pred, z.unsqueeze(0)], 2)
            piList.append(pi)
            muList.append(mu)
            sigmaList.append(sigma)
        Pi = torch.cat(piList, 0)
        Mu = torch.cat(muList, 0)
        Sigma = torch.cat(sigmaList, 0)
        return (Pi, Mu, Sigma)

    def getStartOfSequence(self, batch_size=None):
        """Return the all-zero first decoder input of shape (1, batch_size, output_dim).

        ``batch_size`` defaults to ``args.batch_size``; the tensor is placed on
        CUDA when ``args.cuda`` is set.
        """
        if batch_size is None:
            batch_size = args.batch_size
        sos = torch.zeros(1, batch_size, args.output_dim)
        return sos.cuda() if args.cuda else sos

    def doEncoding(self, batch):
        """Flatten the channel axes of ``batch`` and run the encoder.

        Returns:
            ``(z, mu, sigma_hat, embedded)`` where ``embedded`` is the
            flattened encoder input.
        """
        # convert input from [sequence_len, batch_size, channels, x_dim]
        #                 to [sequence_len, batch_size, channels * x_dim]
        embedded = batch.contiguous().view(batch.size(0), batch.size(1), args.x_dim * args.channels)
        z, mu, sigma_hat = self.encoder(embedded)
        return z, mu, sigma_hat, embedded

    def forward(self, batch, target):
        """Encode ``batch`` and decode mixture parameters for the target sequence.

        Returns:
            ``(Pi, Mu, Sigma, latentMean, latentStd)``.
            NOTE(review): ``latentStd`` is the encoder's ``sigma_hat`` output,
            which the encoder turns into a standard deviation via
            ``exp(sigma_hat / 2)``; confirm downstream code expects that.
        """
        z, latentMean, latentStd, embedded = self.doEncoding(batch)
        if self.training:
            (Pi, Mu, Sigma) = self.allSteps(target, z)
        else:
            (Pi, Mu, Sigma) = self.oneStepAtATime(z)
        return Pi, Mu, Sigma, latentMean, latentStd


In [344]:
# %load ../batchedRNN/utils.py
import logging, sys
import torch
import h5py
import os
import numpy as np
import torch.utils.data as torchUtils
import torch.optim as optim
from functools import partial
import torch.nn as nn
import json
from shutil import copy2, copyfile, copytree
import argparse
import matplotlib
# matplotlib.use('Agg')
import matplotlib.pyplot as plt

# parser = argparse.ArgumentParser(description='Batched Sequence to Sequence')
# parser.add_argument('--h_dim', type=int, default=256)
# parser.add_argument("--z_dim", type=int, default=128)
# parser.add_argument('--no_cuda', action='store_true', default=False,
#                                         help='disables CUDA training')
# parser.add_argument("--no_attn", action="store_true", default=True, help="Do not use AttnDecoder")
# parser.add_argument("--n_epochs", type=int, default=200)
# parser.add_argument("--batch_size", type=int, default= 64)
# parser.add_argument("--n_layers", type=int, default=2)
# parser.add_argument("--initial_lr", type=float, default=1e-4)
# parser.add_argument("--lr_decay_every", type=int, default=10)
# parser.add_argument("--lr_decay_factor", type=float, default=.10)
# parser.add_argument("--lr_decay_beginning", type=int, default=20)
# parser.add_argument("--print_every", type=int, default = 200)
# parser.add_argument("--criterion", type=str, default="L1Loss")
# parser.add_argument("--save_freq", type=int, default=10)
# parser.add_argument("--down_sample", type=float, default=0.0, help="Keep this fraction of the training data")
# # parser.add_argument("--data_dir", type=str, default="./data/reformattedTraffic/")
# parser.add_argument("--model", type=str, default="sketch-rnn")
# parser.add_argument("--lambda_l1", type=float, default=0)
# parser.add_argument("--lambda_l2", type=float, default=5e-4)
# parser.add_argument("--no_schedule_sampling", action="store_true", default=False)
# parser.add_argument("--scheduling_start", type=float, default=1.0)
# parser.add_argument("--scheduling_end", type=float, default=0.0)
# parser.add_argument("--tries", type=int, default=12)
# parser.add_argument("--kld_warmup_until", type=int, default=5)
# parser.add_argument("--kld_weight_max", type=float, default=0.10)
# parser.add_argument("--no_shuffle_after_epoch", action="store_true", default=False)
# parser.add_argument("--clip", type=int, default=10)
# parser.add_argument("--dataset", type=str, default="traffic")
# parser.add_argument("--predictOnTest", action="store_true", default=True)
# parser.add_argument("--encoder_input_dropout", type=float, default=0.5)
# parser.add_argument("--encoder_layer_dropout", type=float, default=0.5)
# parser.add_argument("--decoder_input_dropout", type=float, default=0.5)
# parser.add_argument("--decoder_layer_dropout", type=float, default=0.5)
# parser.add_argument("--noEarlyStopping", action="store_true", default=False)
# parser.add_argument("--earlyStoppingPatients", type=int, default=3)
# parser.add_argument("--earlyStoppingMinDelta", type=float, default=0.0001)
# parser.add_argument("--bidirectionalEncoder", type=bool, default=True)
# parser.add_argument("--local", action="store_true", default=False)
# parser.add_argument("--debugDataset", action="store_true", default=False)
# parser.add_argument("--encoder_h_dim", type=int, default=256)
# parser.add_argument("--decoder_h_dim", type=int, default=512)
# parser.add_argument("--num_mixtures", type=int, default=20)
# args = parser.parse_args()
# Route log records of level DEBUG and above to stderr.
logging.basicConfig(stream=sys.stderr,level=logging.DEBUG)

def plotLosses(trainLosses, valLosses, trainKLDLosses=None, valKLDLosses=None):
    """Persist the loss histories under ``args.save_dir`` and plot the reconstruction curves.

    The KLD histories are saved (only when both are given and non-empty) but
    are not drawn. The figure is written to ``train_val_loss_plot.png`` in
    ``args.save_dir``.
    """
    outDir = args.save_dir
    torch.save(trainLosses, outDir + "plot_train_recon_losses")
    torch.save(valLosses, outDir + "plot_val_recon_losses")
    haveKLD = trainKLDLosses and valKLDLosses
    if haveKLD:
        torch.save(trainKLDLosses, outDir + "plot_train_KLD_losses")
        torch.save(valKLDLosses, outDir + "plot_val_KLD_losses")

    plt.rcParams.update({'font.size': 8})
    fig, ax1 = plt.subplots()
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel(args.criterion, color="r")
    ax1.tick_params('y', colors='r')
    # Epochs are numbered from 1 on the x axis.
    trainEpochs = np.arange(1, len(trainLosses) + 1)
    valEpochs = np.arange(1, len(valLosses) + 1)
    ax1.plot(trainEpochs, trainLosses, "r--", label="train reconstruction loss")
    ax1.plot(valEpochs, valLosses, color="red", label="validation reconstruction loss")
    ax1.legend(loc="upper left")
    ax1.grid()
    ax1.set_title("Losses for {}".format(args.model))
    fig.savefig(outDir + "train_val_loss_plot.png")

def getSaveDir():
    """Create and return the first unused ``model<N>/`` directory.

    The directory lives under ``../save/local/models/`` when ``args.local`` is
    set and under ``../save/models/`` otherwise; N starts at 0 and is
    incremented until the path does not exist yet. Missing parent directories
    are created as well.

    Returns:
        The path of the newly created directory, with a trailing slash.
    """
    root = '../save/local/models/' if args.local else '../save/models/'
    index = 0
    saveDir = '{}model{}/'.format(root, index)
    while os.path.isdir(saveDir):
        index += 1
        saveDir = '{}model{}/'.format(root, index)
    # makedirs (rather than mkdir) so a fresh checkout without the parent
    # "models" directory does not fail.
    os.makedirs(saveDir)
    return saveDir

def saveUsefulData():
    """Snapshot the run configuration and source files into ``args.save_dir``.

    Writes ``vars(args)`` as JSON to ``args.txt`` and copies the training
    scripts and the ``./model`` directory next to it.
    """
    target = args.save_dir
    with open(target + "args.txt", "w") as f:
        f.write(json.dumps(vars(args)))
    # (source path, file name inside the save directory)
    scripts = (
        ("./train.py", "train.py"),
        ("./utils.py", "utils.py"),
        ("./gridSearchOptimize.py", "gridsearchOptimize.py"),
    )
    for source, savedName in scripts:
        copy2(source, target + savedName)
    copytree("./model", target + "model/")

def getTrafficDataset(dataDir, category):
    """Load ``<category>.npz`` from ``dataDir`` as a TensorDataset plus metadata.

    Args:
        dataDir: directory containing the ``.npz`` archives.
        category: archive stem, e.g. the split name.

    Returns:
        ``(dataset, scaler, input_sequence_len, target_sequence_len, x_dim,
        channels)``. Both sequence lengths are axis 1 of ``inputs``; ``x_dim``
        and ``channels`` are its axes 2 and 3.
    """
    # Read each array exactly once and close the archive afterwards; indexing
    # an NpzFile re-loads the array from disk on every access.
    with np.load(os.path.join(dataDir, category + '.npz')) as archive:
        inputs = archive["inputs"]
        targets = archive["targets"]
    my_dataset = torchUtils.TensorDataset(torch.Tensor(inputs), torch.Tensor(targets))
    scaler = getScaler(inputs)
    sequence_len = inputs.shape[1]
    x_dim = inputs.shape[2]
    channels = inputs.shape[3]
    return my_dataset, scaler, sequence_len, sequence_len, x_dim, channels

def getHumanDataset(dataDir, category):
    """Load ``<category>.h5`` from ``dataDir`` as a TensorDataset plus metadata.

    Args:
        dataDir: directory containing the ``.h5`` files.
        category: file stem, e.g. the split name.

    Returns:
        ``(dataset, scaler, input_sequence_len, target_sequence_len, x_dim,
        channels)``. The sequence lengths are axis 1 of ``input2d`` and
        ``target2d``; ``x_dim`` and ``channels`` are axes 2 and 3 of
        ``input2d``.
    """
    # Pull both datasets fully into memory, then close the file so the handle
    # is not leaked.
    with h5py.File(os.path.join(dataDir, category + ".h5"), "r") as f:
        inputs = f["input2d"][...]
        targets = f["target2d"][...]
    my_dataset = torchUtils.TensorDataset(torch.Tensor(inputs), torch.Tensor(targets))
    scaler = getScaler(inputs)
    input_sequence_len = inputs.shape[1]
    target_sequence_len = targets.shape[1]
    x_dim = inputs.shape[2]
    channels = inputs.shape[3]
    return my_dataset, scaler, input_sequence_len, target_sequence_len, x_dim, channels

def getLoaderAndScaler(dataDir, category):
    """Build the DataLoader and scaler for one data split.

    The traffic loader is used when ``args.dataset == "traffic"``, the human
    loader otherwise. Only the ``"train"`` split is shuffled; incomplete final
    batches are dropped.

    Returns:
        ``(loader, scaler, input_sequence_len, target_sequence_len, x_dim,
        channels)``.
    """
    logging.info("Getting {} loader".format(category))
    buildDataset = getTrafficDataset if args.dataset == "traffic" else getHumanDataset
    (my_dataset, scaler, input_sequence_len,
     target_sequence_len, x_dim, channels) = buildDataset(dataDir, category)
    loader = torchUtils.DataLoader(
        my_dataset,
        batch_size=args.batch_size,
        shuffle=(category == "train"),
        num_workers=0,
        pin_memory=False,
        drop_last=True,
    )
    return loader, scaler, input_sequence_len, target_sequence_len, x_dim, channels

def getDataLoaders(dataDir, debug=False):
    """Collect the loaders for every split plus the metadata of the reference split.

    With ``debug`` only the ``"test"`` split is loaded and it also supplies the
    scaler and shape metadata; otherwise ``"train"``, ``"val"`` and ``"test"``
    are loaded and ``"train"`` supplies them.

    Returns:
        dict with one loader per split name and the keys ``"scaler"``,
        ``"input_sequence_len"``, ``"target_sequence_len"``, ``"x_dim"`` and
        ``"channels"``.
    """
    logging.info("Getting data from {}".format(dataDir))
    categories, scalerSet = (["test"], "test") if debug else (["train", "val", "test"], "train")
    metaKeys = ("scaler", "input_sequence_len", "target_sequence_len", "x_dim", "channels")
    loaders = {}
    for category in categories:
        loader, *meta = getLoaderAndScaler(dataDir, category)
        if category == scalerSet:
            # Metadata is taken from the reference split only.
            loaders.update(zip(metaKeys, meta))
        loaders[category] = loader
    return loaders

class StandardScaler:
    """
    Standardize the first two entries of the last axis with fixed statistics.
    """

    def __init__(self, mean0, std0, mean1, std1):
        # (mean, std) for index 0 and index 1 of the last axis.
        self.mean0, self.std0 = mean0, std0
        self.mean1, self.std1 = mean1, std1

    def transform(self, data):
        """Return ``(data - mean) / std``.

        Index 0 and 1 of the last axis use (mean0, std0) and (mean1, std1);
        every other index is left unchanged (mean 0, std 1).
        """
        shift = torch.zeros(data.size())
        scale = torch.ones(data.size())
        stats = ((self.mean0, self.std0), (self.mean1, self.std1))
        for channel, (center, spread) in enumerate(stats):
            shift[..., channel] = center
            scale[..., channel] = spread
        return (data - shift) / scale

class StandardScalerTraffic(StandardScaler):
    """Scaler for the traffic data: only index 0 of the last axis is standardized."""

    def __init__(self, mean0, std0):
        # Index 1 gets mean 0 / std 1, i.e. it passes through transform unchanged.
        super().__init__(mean0, std0, 0.0, 1.0)

    def inverse_transform(self, data):
        """
        Undo the standardization with (mean0, std0) on every element and swap
        the first two axes. Applied to output and target, which hold only the
        index-0 quantity.
        """
        shape = data.size()
        offset = torch.ones(shape) * self.mean0
        scale = torch.ones(shape) * self.std0
        if args.cuda:
            offset, scale = offset.cuda(), scale.cuda()
        restored = data * scale + offset
        return restored.permute(1, 0, 2)

    def transformBatchForEpoch(self, batch):
        """
        Standardize an (inputs, targets) batch. Inputs are permuted with
        (1, 0, 3, 2); targets keep only index 0 of the last axis and have
        their first two axes swapped. Both are moved to CUDA when
        ``args.cuda`` is set.
        """
        x = self.transform(batch[0]).permute(1, 0, 3, 2)
        y = self.transform(batch[1])[..., 0].permute(1, 0, 2)
        return (x.cuda(), y.cuda()) if args.cuda else (x, y)

class StandardScalerHuman(StandardScaler):
    """Scaler for the human data: both entries of the last axis are standardized."""

    def __init__(self, mean0, std0, mean1, std1):
        super().__init__(mean0, std0, mean1, std1)

    def inverse_transform(self, data):
        """
        Applied to output and target: split axis 2 back into a trailing axis of
        two, undo the standardization per entry, and permute with (1, 0, 3, 2).
        """
        restored = self.restoreDim(data)
        offset = torch.zeros(restored.size())
        scale = torch.ones(restored.size())
        if args.cuda:
            offset, scale = offset.cuda(), scale.cuda()
        offset[..., 0], offset[..., 1] = self.mean0, self.mean1
        scale[..., 0], scale[..., 1] = self.std0, self.std1
        return (restored * scale + offset).permute(1, 0, 3, 2)

    def restoreDim(self, data):
        """Split axis 2 into two halves and stack them on a new axis 3."""
        half = data.size(2) // 2
        firstHalf, secondHalf = data.split(half, dim=2)
        return torch.stack((firstHalf, secondHalf), dim=3)

    def removeDim(self, data):
        """Inverse of ``restoreDim``: lay the two slices of axis 3 side by side on axis 2."""
        first, second = data.unbind(dim=3)
        return torch.cat((first, second), dim=2)

    def transformBatchForEpoch(self, batch):
        """
        Standardize an (inputs, targets) batch. Inputs are permuted with
        (1, 0, 3, 2); targets are flattened with ``removeDim`` and have their
        first two axes swapped. Both are moved to CUDA when ``args.cuda`` is
        set.
        """
        x = self.transform(batch[0]).permute(1, 0, 3, 2)
        wideY = self.removeDim(self.transform(batch[1])).permute(1, 0, 2)
        return (x.cuda(), wideY.cuda()) if args.cuda else (x, wideY)

def getScaler(trainX):
    """Build the scaler matching ``args.dataset`` from the statistics of ``trainX``.

    Mean and standard deviation are computed over index 0 and index 1 of the
    last axis. ``"traffic"`` uses only the index-0 statistics; ``"human"``
    uses both. Any other dataset name trips an assertion.
    """
    stats = []
    for channel in (0, 1):
        stats.append(np.mean(trainX[..., channel]))
        stats.append(np.std(trainX[..., channel]))
    mean0, std0, mean1, std1 = stats
    if args.dataset == "traffic":
        return StandardScalerTraffic(mean0, std0)
    if args.dataset == "human":
        return StandardScalerHuman(mean0, std0, mean1, std1)
    assert False, "bad dataset"

def getReconLoss(output, target, scaler):
    """Reconstruction loss between ``output`` and ``target`` in the original units.

    Both tensors are first passed through ``scaler.inverse_transform``.
    ``args.criterion`` selects the loss: ``"RMSE"`` (square root of the MSE) or
    ``"L1Loss"``; any other value trips an assertion.
    """
    output, target = scaler.inverse_transform(output), scaler.inverse_transform(target)
    assert output.size() == target.size(), "output size: {}, target size: {}".format(output.size(), target.size())
    if args.criterion == "RMSE":
        return torch.sqrt(nn.MSELoss()(output, target))
    if args.criterion == "L1Loss":
        return nn.L1Loss()(output, target)
    assert False, "bad loss function"

def getKLDWeight(epoch):
    """Return the KLD loss weight, currently the constant ``args.kld_weight_max``.

    ``epoch`` is unused; the linear warm-up below is kept only as a comment.
    """
    # kldLossWeight = args.kld_weight_max * min((epoch / (args.kld_warmup_until)), 1.0)
    return args.kld_weight_max

def kld_gauss(mean_1, std_1, mean_2, std_2):
    """Summed KL divergence between element-wise Gaussians, parameterised by std.

    Computes ``0.5 * sum(2*log(std_2) - 2*log(std_1)
    + (std_1^2 + (mean_1 - mean_2)^2) / std_2^2 - 1)`` over all elements.
    """
    squared_gap = (mean_1 - mean_2).pow(2)
    scaled_spread = (std_1.pow(2) + squared_gap) / std_2.pow(2)
    per_element = 2 * torch.log(std_2) - 2 * torch.log(std_1) + scaled_spread - 1
    return 0.5 * per_element.sum()

def sketchRNNKLD(latentMean, latentStd):
    """KL divergence of the latent Gaussian from a zero-mean, unit-std Gaussian."""
    return kld_gauss(
        latentMean,
        latentStd,
        torch.zeros_like(latentMean),
        torch.ones_like(latentStd),
    )

def getLoss(model, output, target, scaler, epoch):
    """Return ``(reconstruction loss, KLD loss)`` for one batch.

    For ``args.model == "rnn"`` the output is compared to the target directly
    and the KLD term is 0. Otherwise ``output`` must unpack into six values
    (latent mean, latent std, z, prediction, prediction mean, prediction std),
    of which the first two and the prediction are used. ``model`` and
    ``epoch`` are not read.
    """
    if args.model != "rnn":
        latentMean, latentStd, _z, predOut, _predMeanOut, _predStdOut = output
        return getReconLoss(predOut, target, scaler), sketchRNNKLD(latentMean, latentStd)
    return getReconLoss(output, target, scaler), 0

def saveModel(modelWeights, epoch):
    """Save ``modelWeights`` to ``<save_dir><model>_state_dict_<epoch>.pth`` and log the path."""
    fileName = '{}_state_dict_'.format(args.model) + str(epoch) + '.pth'
    fn = args.save_dir + fileName
    torch.save(modelWeights, fn)
    logging.info('Saved model to '+fn)

class EarlyStoppingObject(object):
    """Tracks the best loss seen so far and decides when training should stop.

    Attributes are updated by ``checkStop``; the thresholds come from the
    module-level ``args`` (``noEarlyStopping``, ``earlyStoppingMinDelta``,
    ``earlyStoppingPatients``).
    """
    def __init__(self):
        super(EarlyStoppingObject, self).__init__()
        self.bestLoss = None      # lowest loss accepted as an improvement
        self.bestEpoch = None     # 1-based call index at which bestLoss was set
        self.counter = 0          # consecutive calls without improvement
        self.epochCounter = 0     # total number of checkStop calls

    def checkStop(self, previousLoss):
        """Record one epoch's loss and report whether training should stop.

        A loss counts as an improvement when
        ``previousLoss + args.earlyStoppingMinDelta < bestLoss`` (or when no
        best loss exists yet). Nothing is tracked besides the call count when
        ``args.noEarlyStopping`` is set.

        Returns:
            True once ``args.earlyStoppingPatients`` consecutive calls brought
            no improvement; False in every other case.
        """
        self.epochCounter += 1
        if args.noEarlyStopping:
            return False

        improved = (self.bestLoss is None or
                    previousLoss + args.earlyStoppingMinDelta < self.bestLoss)
        if improved:
            self.bestLoss = previousLoss
            self.bestEpoch = self.epochCounter
            self.counter = 0
            return False

        self.counter += 1
        if self.counter >= args.earlyStoppingPatients:
            logging.info("Stopping Early, haven't beaten best loss {:.4f} @ Epoch {} in {} epochs".format(
                self.bestLoss,
                self.bestEpoch,
                args.earlyStoppingPatients))
            return True
        # No improvement, but patience is not exhausted yet: return an explicit
        # bool rather than falling off the end with None.
        return False


In [345]:
# Execute GetLossObj.ipynb in this kernel.
# NOTE(review): presumably this is where `Bunch` (used in the next cell) is defined — TODO confirm.
%run GetLossObj.ipynb

In [346]:
# Rebuild the global `args` from the JSON saved with a previous run, then
# override / add the fields the SketchRNN classes read.
baseDir = "../save/local/models/model21/"
with open(baseDir + "args.txt") as f:
    args = Bunch(json.load(f))

# The saved run only has a single h_dim; use it for both encoder and decoder.
args.encoder_h_dim = args.decoder_h_dim = args.h_dim
args.z_dim = 128
args.cuda = False
# Fields read by the model code that are set by hand here.
args.sequence_len = 12
args.n_gaussians = 20

In [347]:
print(args)

h_dim: 256, z_dim: 128, no_cuda: False, no_attn: True, n_epochs: 20, batch_size: 64, n_layers: 2, initial_lr: 0.0001, lr_decay_every: 10, lr_decay_factor: 0.1, lr_decay_beginning: 20, print_every: 10, criterion: L1Loss, save_freq: 10, down_sample: 0.0, model: rnn, lambda_l1: 0, lambda_l2: 0.0005, no_schedule_sampling: True, scheduling_start: 1.0, scheduling_end: 0.0, tries: 12, kld_warmup_until: 5, kld_weight_max: 0.1, no_shuffle_after_epoch: False, clip: 10, dataset: traffic, predictOnTest: True, encoder_input_dropout: 0.5, encoder_layer_dropout: 0.5, decoder_input_dropout: 0.5, decoder_layer_dropout: 0.5, noEarlyStopping: True, earlyStoppingPatients: 3, earlyStoppingMinDelta: 0.0001, bidirectionalEncoder: True, local: True, debugDataset: True, cuda: False, _device: cpu, use_attn: False, use_schedule_sampling: False, x_dim: 207, input_sequence_len: 12, target_sequence_len: 12, channels: 2, output_dim: 207, save_dir: ../save/local/models/model21/, encoder_h_dim: 256, decoder_h_dim: 256

In [348]:
# Load the train/val/test loaders plus scaler and shape metadata.
# NOTE(review): absolute, machine-specific path — this cell will not run elsewhere as-is.
data = getDataLoaders("/Users/danielzeiberg/Documents/Data/Traffic/Processed/trafficWithTime/down_sample_0.1/")

INFO:root:Getting data from /Users/danielzeiberg/Documents/Data/Traffic/Processed/trafficWithTime/down_sample_0.1/
INFO:root:Getting train loader
INFO:root:Getting val loader
INFO:root:Getting test loader


In [349]:
# Build the model; its layer sizes and CUDA placement are read from the global `args`.
sketch = SketchRNN()

In [350]:
# Take the first training batch, standardized and permuted by the scaler.
# The second argument to next() is its default: with an empty loader it would
# yield 1 and the tuple unpacking would fail.
(inputData, target) = next(map(data["scaler"].transformBatchForEpoch, data["train"]), 1)

In [351]:
inputData.shape, target.shape

(torch.Size([12, 64, 2, 207]), torch.Size([12, 64, 207]))

In [352]:
# Encode the batch. The third returned value is the encoder's `sigma_hat`
# (it is exponentiated as exp(sigma_hat / 2) inside the encoder), named latentStd here.
z, latentMean, latentStd, embedded = sketch.doEncoding(inputData)

In [353]:
z.shape, latentMean.shape, latentStd.shape, embedded.shape

(torch.Size([64, 128]),
 torch.Size([64, 128]),
 torch.Size([64, 128]),
 torch.Size([12, 64, 414]))

In [354]:
sketch.getStartOfSequence().size()

torch.Size([1, 64, 207])

In [355]:
# Teacher-forced decoding: the decoder is fed the target shifted by one step.
Pi, Mu, Sigma = sketch.allSteps(target, z)

In [356]:
Pi.shape, Mu.shape, Sigma.shape

(torch.Size([12, 64, 207, 20]),
 torch.Size([12, 64, 207, 20]),
 torch.Size([12, 64, 207, 20]))

In [357]:
# Autoregressive decoding: the decoder is fed its own sampled prediction at each step.
PiOne, MuOne, SigmaOne = sketch.oneStepAtATime(z)

In [358]:
PiOne.shape, MuOne.shape, SigmaOne.shape

(torch.Size([12, 64, 207, 20]),
 torch.Size([12, 64, 207, 20]),
 torch.Size([12, 64, 207, 20]))

In [375]:
# One Normal per mixture component, parameterised by the teacher-forced decoder outputs.
m = torch.distributions.Normal(loc=Mu, scale=Sigma)


In [376]:
# Repeat the target along a new trailing axis (one copy per mixture component)
# so it lines up with Mu.
stackedTarget = torch.stack([target]*Mu.size(3), dim=3)
# Despite its name, `loss` holds per-component densities at this point
# (exp of the log-density).
loss = torch.exp(m.log_prob(stackedTarget))
loss.size()

torch.Size([12, 64, 207, 20])

In [377]:
# Weight each component density by its mixture weight ...
weightedLoss = loss * Pi
print(weightedLoss.size())
# ... and sum over the component axis to get the mixture density per element.
loss = torch.sum(weightedLoss, dim=3)
print(loss.size())

torch.Size([12, 64, 207, 20])
torch.Size([12, 64, 207])


In [378]:
# Negative log of the mixture density (per-element negative log-likelihood).
# NOTE(review): exp -> sum -> log can underflow to log(0) when a target is far
# from every component; torch.logsumexp(m.log_prob(stackedTarget) + torch.log(Pi), dim=3)
# would be the numerically stable form — consider switching.
loss= -torch.log(loss)
loss.size()

torch.Size([12, 64, 207])

In [379]:
loss.mean()

tensor(1.3814, grad_fn=<MeanBackward1>)