<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc" style="margin-top: 1em;"><ul class="toc-item"></ul></div>

In [None]:
from esper.prelude import *
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
import pyro.optim as optim
import pyro.infer as infer
from torch.utils.data import DataLoader
from custom_mlp import MLP, Exp
from transcript_utils import *
from timeit import default_timer as now

In [None]:
# compute_vectors(video_list(), vocabulary, SEGMENT_SIZE, SEGMENT_STRIDE)

In [None]:
# define a PyTorch module for the VAE
class VAE(nn.Module):
    """Semi-supervised VAE with a discrete label ``y`` and a continuous latent ``z``.

    Generative model: p(x | y, z) p(y) p(z), with a standard-normal ``z``, a
    uniform one-hot ``y`` over ``categories`` classes and a Bernoulli decoder
    over the ``input_size``-dimensional input.
    Guide: q(y | x) q(z | x, y); ``y`` is only sampled when it is not observed.

    Args:
        z_dim: dimensionality of the latent code ``z`` (default 50).
        hidden_layers: hidden-layer sizes shared by all three MLPs
            (default: one layer of 500 units).
        categories: number of label classes, i.e. the size of one-hot ``y``.
        use_cuda: if True, move every parameter to the GPU.
        input_size: dimensionality of ``x``. When None (the default) the
            module-level ``vocab_size`` is used, as before; presumably the
            inputs are per-segment vectors over the vocabulary -- confirm.
    """

    def __init__(self, z_dim=50, hidden_layers=(500,), categories=2, use_cuda=False,
                 input_size=None):
        super(VAE, self).__init__()

        # Fall back to the global vocabulary size so existing callers are unaffected.
        self.input_size = vocab_size if input_size is None else input_size
        self.output_size = categories
        self.use_cuda = use_cuda
        self.z_dim = z_dim
        self.hidden_layers = list(hidden_layers)

        # q(y | x): softmax probabilities over the label classes.
        self.encoder_y = MLP([self.input_size] + self.hidden_layers + [self.output_size],
                             activation=nn.Softplus,
                             output_activation=nn.Softmax,
                             use_cuda=self.use_cuda)

        # q(z | x, y): two heads, a location (no activation) and a scale (Exp keeps it positive).
        self.encoder_z = MLP([self.input_size + self.output_size] +
                             self.hidden_layers + [[self.z_dim, self.z_dim]],
                             activation=nn.Softplus,
                             output_activation=[None, Exp],
                             use_cuda=self.use_cuda)

        # p(x | y, z): sigmoid outputs used as Bernoulli probabilities.
        self.decoder = MLP([self.z_dim + self.output_size] +
                           self.hidden_layers + [self.input_size],
                           activation=nn.Softplus,
                           output_activation=nn.Sigmoid,
                           use_cuda=self.use_cuda)

        if use_cuda:
            # Move the parameters of all three networks into GPU memory.
            self.cuda()

    def model(self, xs, ys=None):
        """Generative model p(x | y, z) p(y) p(z).

        Args:
            xs: batch of inputs, scored under a Bernoulli likelihood (values are
                presumably in [0, 1] -- confirm against the dataset).
            ys: optional one-hot labels, one row per example; when None, ``y``
                is a latent variable sampled from its uniform prior.

        Returns:
            The decoder's Bernoulli probabilities for the batch.
        """
        # Register the whole module (both encoders and the decoder) with Pyro.
        pyro.module("ss_vae", self)
        batch_size = xs.shape[0]
        with pyro.iarange("data", batch_size):
            # Standard-normal prior p(z); the guide supplies the value during ELBO computation.
            z_loc = xs.new_zeros(torch.Size((batch_size, self.z_dim)))
            z_scale = xs.new_ones(torch.Size((batch_size, self.z_dim)))
            zs = pyro.sample("latent", dist.Normal(z_loc, z_scale).independent(1))

            # Uniform prior p(y); observed when labels are supplied.
            alpha_prior = xs.new_ones([batch_size, self.output_size]) / (1.0 * self.output_size)
            ys = pyro.sample("y", dist.OneHotCategorical(alpha_prior), obs=ys)

            # Decode (z, y) into per-feature probabilities and score the observed x.
            loc = self.decoder.forward([zs, ys])
            pyro.sample("obs", dist.Bernoulli(loc).independent(1), obs=xs)
            return loc

    def guide(self, xs, ys=None):
        """Variational distribution q(y | x) q(z | x, y).

        When ``ys`` is None, ``y`` is sampled from q(y | x) given by
        ``encoder_y``. Otherwise the supplied labels are used and ``encoder_y``
        is not called, so this guide gives the classifier no gradient from
        labelled batches.
        """
        with pyro.iarange("data", xs.size(0)):
            if ys is None:
                alpha = self.encoder_y.forward(xs)
                ys = pyro.sample("y", dist.OneHotCategorical(alpha))

            # q(z | x, y): location and positive scale from encoder_z.
            z_loc, z_scale = self.encoder_z.forward([xs, ys])
            pyro.sample("latent", dist.Normal(z_loc, z_scale).independent(1))

    def classifier(self, xs):
        """Return one-hot predictions: the most probable class of q(y | x) per row."""
        alpha = self.encoder_y.forward(xs)
        _, ind = torch.topk(alpha, 1)
        ys = xs.new_zeros(alpha.size())
        return ys.scatter_(1, ind, 1.0)

In [None]:
# Unlabelled segment vectors for every video, plus the labelled dataset built
# on top of them from the cached 'labeled_segments' (presumably a subset).
unsup_dataset = SegmentVectorDataset(video_list(), vocab_size=vocab_size)
sup_dataset = LabeledSegmentDataset(unsup_dataset, pcache.get('labeled_segments'), categories=2)

# Both loaders reshuffle on every pass and yield batches of 8.
loader_params = dict(shuffle=True)
unsup_loader = DataLoader(unsup_dataset, batch_size=8, **loader_params)
sup_loader = DataLoader(sup_dataset, batch_size=8, **loader_params)
data_loaders = dict(unsup=unsup_loader, sup=sup_loader)

In [None]:
# Start from an empty Pyro parameter store with a freshly built model and optimiser.
pyro.clear_param_store()
vae = VAE()
optimizer = optim.ClippedAdam(dict(lr=0.001, betas=[0.9, 0.999]))
# NOTE(review): config_enumerate marks the guide's discrete "y" site for
# enumeration, but Trace_ELBO does not enumerate (Pyro is expected to warn), so
# y is sampled instead. Moving to TraceEnum_ELBO would also require the MLP
# input concatenation to broadcast over the enumeration dimension -- verify first.
svi = infer.SVI(vae.model, infer.config_enumerate(vae.guide), optimizer, loss=infer.Trace_ELBO())
loss_history = list()

In [None]:
def _next_batch(loader, iterator):
    """Fetch the next batch, restarting ``loader`` once if ``iterator`` is exhausted.

    Returns a ``(batch, iterator)`` pair; the iterator is a fresh one when a
    restart happened. An empty loader still raises StopIteration.
    """
    try:
        return next(iterator), iterator
    except StopIteration:
        iterator = iter(loader)
        return next(iterator), iterator


def run_inference_for_epoch(data_loaders, losses, sup_batches, unsup_batches, periodic_interval_batches):
    """Run one epoch of semi-supervised inference.

    The epoch has ``sup_batches + unsup_batches`` steps. Step ``i`` is
    supervised when ``i % periodic_interval_batches == 1 % periodic_interval_batches``
    and fewer than ``sup_batches`` supervised steps have run; every other step
    is unsupervised. A loader that runs out of batches mid-epoch is restarted.

    Args:
        data_loaders: mapping with a ``"sup"`` iterable yielding ``(xs, ys, _)``
            and an ``"unsup"`` iterable yielding ``(xs, _)``.
        losses: objects exposing ``step(xs)`` / ``step(xs, ys)`` that return a
            numeric loss (e.g. ``pyro.infer.SVI``).
        sup_batches: maximum number of supervised steps in the epoch.
        unsup_batches: remaining step budget; together with ``sup_batches`` it
            fixes the epoch length.
        periodic_interval_batches: spacing between supervised steps; 1 makes
            the first ``sup_batches`` steps all supervised.

    Returns:
        ``(epoch_losses_sup, epoch_losses_unsup)``: for each loss, the sum of
        its step values over supervised and over unsupervised steps.
    """
    num_losses = len(losses)
    batches_per_epoch = sup_batches + unsup_batches
    # `1 % p` is 1 for p > 1 (the original schedule) and 0 for p == 1, so an
    # interval of 1 makes every step eligible instead of none.
    sup_phase = 1 % periodic_interval_batches

    epoch_losses_sup = [0.] * num_losses
    epoch_losses_unsup = [0.] * num_losses

    sup_loader = data_loaders["sup"]
    unsup_loader = data_loaders["unsup"]
    sup_iter = iter(sup_loader)
    unsup_iter = iter(unsup_loader)

    ctr_sup = 0  # supervised steps taken so far this epoch
    for i in range(batches_per_epoch):
        is_supervised = (i % periodic_interval_batches == sup_phase) and ctr_sup < sup_batches

        if is_supervised:
            (xs, ys, _), sup_iter = _next_batch(sup_loader, sup_iter)
            ctr_sup += 1
            for loss_id, loss in enumerate(losses):
                epoch_losses_sup[loss_id] += loss.step(xs, ys)
        else:
            (xs, _), unsup_iter = _next_batch(unsup_loader, unsup_iter)
            for loss_id, loss in enumerate(losses):
                epoch_losses_unsup[loss_id] += loss.step(xs)

    return epoch_losses_sup, epoch_losses_unsup

In [None]:
def get_accuracy(data_loader, classifier_fn):
    """Return the fraction of examples whose prediction row exactly matches the label row.

    Args:
        data_loader: iterable of ``(xs, ys, _)`` batches; ``ys`` must have the
            same shape as ``classifier_fn(xs)`` (one row per example).
        classifier_fn: callable mapping a batch ``xs`` to predictions, e.g.
            one-hot class indicators.

    Returns:
        Accuracy in [0, 1] over every example in ``data_loader``, counting a
        short final batch by its real size.

    Raises:
        ValueError: if ``data_loader`` yields no examples.
    """
    accurate_preds = 0
    num_examples = 0
    # Evaluation only: no autograd graph is needed for the predictions.
    with torch.no_grad():
        for xs, ys, _ in data_loader:
            preds = classifier_fn(xs)
            # A row is correct only when every entry agrees with the label row.
            agreeing = (preds == ys).long().sum(dim=1)
            accurate_preds += int((agreeing == preds.size(1)).long().sum().item())
            # Count actual rows so a partial last batch is not over-counted.
            num_examples += preds.size(0)

    if num_examples == 0:
        raise ValueError("get_accuracy: data_loader yielded no examples")
    return accurate_preds / num_examples

In [None]:
losses = [svi]

# Both counts depend only on the labelled loader, so compute them once:
# each labelled batch is paired with a 100x larger unlabelled step budget.
sup_batches = len(sup_loader)
unsup_batches = 100 * len(sup_loader)

for epoch in range(1000):
    epoch_losses_sup, epoch_losses_unsup = run_inference_for_epoch(
        data_loaders,
        losses,
        sup_batches=sup_batches,
        unsup_batches=unsup_batches,
        periodic_interval_batches=100,
    )

    # Mean loss per batch (summed step losses / number of batches), not per example.
    avg_epoch_losses_sup = [loss_sum / sup_batches for loss_sum in epoch_losses_sup]
    avg_epoch_losses_unsup = [loss_sum / unsup_batches for loss_sum in epoch_losses_unsup]
    loss_history.append((avg_epoch_losses_sup, avg_epoch_losses_unsup))

    if epoch % 10 != 0:
        continue

    # Every 10th epoch: report losses and accuracy on the labelled set, then checkpoint.
    str_loss_sup = " ".join('{:.04f}'.format(value) for value in avg_epoch_losses_sup)
    str_loss_unsup = " ".join('{:.04f}'.format(value) for value in avg_epoch_losses_unsup)
    str_print = "{} epoch: avg losses {}".format(epoch, "{} {}".format(str_loss_sup, str_loss_unsup))

    test_accuracy = get_accuracy(data_loaders["sup"], vae.classifier)
    str_print += ", sup accuracy {:.04f}".format(test_accuracy)
    print(str_print)

    checkpoint_path = '/app/data/models/transcript_ssvae_weights_epoch{:05d}.pt'.format(epoch)
    pyro.get_param_store().save(checkpoint_path)

In [None]:
# Display the VAE's parameter tensors (encoder_y, encoder_z and decoder weights).
vae.state_dict()

In [None]:
# Supervised per-batch losses by epoch, skipping the first 10 epochs
# (presumably to hide the large early values).
plt.plot([record[0] for record in loss_history[10:]])

In [None]:
# Unsupervised per-batch losses by epoch, skipping the first 10 epochs
# (presumably to hide the large early values).
plt.plot([record[1] for record in loss_history[10:]])

In [None]:
# Exact-match accuracy of vae.classifier on the labelled loader (the same data used for supervised training).
get_accuracy(data_loaders['sup'], vae.classifier)