<h1>Table of Contents<span class="tocSkip"></span></h1>
<div class="toc" style="margin-top: 1em;"><ul class="toc-item"><li><span><a href="#Import-Section" data-toc-modified-id="Import-Section-1"><span class="toc-item-num">1&nbsp;&nbsp;</span>Import Section</a></span></li><li><span><a href="#VAE-training-section" data-toc-modified-id="VAE-training-section-2"><span class="toc-item-num">2&nbsp;&nbsp;</span>VAE training section</a></span></li><li><span><a href="#Stop-importing-here" data-toc-modified-id="Stop-importing-here-3"><span class="toc-item-num">3&nbsp;&nbsp;</span>Stop importing here</a></span></li><li><span><a href="#Core" data-toc-modified-id="Core-4"><span class="toc-item-num">4&nbsp;&nbsp;</span>Core</a></span></li><li><span><a href="#Joint-Training" data-toc-modified-id="Joint-Training-5"><span class="toc-item-num">5&nbsp;&nbsp;</span>Joint Training</a></span></li><li><span><a href="#Interactive-z-sampling" data-toc-modified-id="Interactive-z-sampling-6"><span class="toc-item-num">6&nbsp;&nbsp;</span>Interactive z sampling</a></span></li><li><span><a href="#Clustering" data-toc-modified-id="Clustering-7"><span class="toc-item-num">7&nbsp;&nbsp;</span>Clustering</a></span></li><li><span><a href="#Experimental" data-toc-modified-id="Experimental-8"><span class="toc-item-num">8&nbsp;&nbsp;</span>Experimental</a></span></li><li><span><a href="#Old-Code" data-toc-modified-id="Old-Code-9"><span class="toc-item-num">9&nbsp;&nbsp;</span>Old Code</a></span></li></ul></div>

# Import Section

In [1]:
%reload_ext autoreload
%autoreload 2

from IPython.display import display
from ipywidgets import interact, interactive, fixed, interact_manual, widgets, Layout
from IPython.display import Image as WImage
import sys
import os.path
import pathlib

import pickle
import random
import numpy as np
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import torch.utils.data as data
import torch.optim as optim
from torch.autograd import Variable
import torch.nn.functional as F

import matplotlib.pyplot as plt
import scipy.misc

from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances_argmin_min

In [2]:
if '/scratch/Jack/projects/Chamber' not in sys.path:
    sys.path.insert(0,'/scratch/Jack/projects/Chamber')
if '/scratch/Jack/projects/Explanations' not in sys.path:
    sys.path.insert(0,'/scratch/Jack/projects/Explanations')

from pythreejs import *

from datasets import CustomClusteringDataset
from models import VAE, ConvVAE, W2V, CVAE

from chamber import Chamber, Misc, Historian, MasterOracle, Oracle, Commander, CodeFolder

torch.multiprocessing.set_sharing_strategy('file_system')

historian = Historian()
commander = Commander(historian)
oracle = Oracle(historian)

In [3]:
model_types = ["Vanilla","Convolutional","Conditional"]
scheduler_types = {}

encoder_choice, decoder_choice, multipliers_choice,\
batchnorm_choice, z_choice, linear_choice, kernel_choice,\
batch_choice, sizes_choice, window_context_choice, scheduler_choice \
= {},{},{},{},{},{},{},{},{},{},{}


# Scheduler types

scheduler_types['Vanilla A'] = {
            'epochs' : 3000,
            'lr_epoch' : 30000,
            'lr_mod' : 1.0,
            'lr_change' : 0.1,
            'lr_base' : 1e-3,
            'wd' : 0}

scheduler_types['Conditional A'] = {
            'lr_base' : 1e-3,
            'wd' : 0}

# Convolutional model types

model_type = 'A'+' '+model_types[1]
encoder_choice[model_type] = ['CM','CM','C','CM']
decoder_choice[model_type] = list(reversed(encoder_choice[model_type]))
multipliers_choice[model_type] = [1,1,2,2,4,4,4,4]
batchnorm_choice[model_type] = False
z_choice[model_type] = 4
linear_choice[model_type] = 16
kernel_choice[model_type] = 16
batch_choice[model_type] = 32
sizes_choice[model_type] = (128,128)
scheduler_choice[model_type] = scheduler_types['Vanilla A']

# Let's try to increase size of parameters, until we can overfit the small subset
model_type = 'B'+' '+model_types[1]
encoder_choice[model_type] = ['CM','CM','C','CM']
decoder_choice[model_type] = list(reversed(encoder_choice[model_type]))
multipliers_choice[model_type] = [1,1,2,2,4,4,4,4]
batchnorm_choice[model_type] = False
z_choice[model_type] = 128
linear_choice[model_type] = 512
kernel_choice[model_type] = 16
batch_choice[model_type] = 32
sizes_choice[model_type] = (128,128)
scheduler_choice[model_type] = scheduler_types['Vanilla A']
# Good, that overfit the toyset

model_type = 'C'+' '+model_types[1]
encoder_choice[model_type] = ['CM','CM','C','CM']
decoder_choice[model_type] = list(reversed(encoder_choice[model_type]))
multipliers_choice[model_type] = [1,1,2,2,4,4,4,4]
batchnorm_choice[model_type] = False
z_choice[model_type] = 64
linear_choice[model_type] = 512
kernel_choice[model_type] = 16
batch_choice[model_type] = 32
sizes_choice[model_type] = (128,128)
scheduler_choice[model_type] = scheduler_types['Vanilla A']
# Excellent fit at 2k

model_type = 'D'+' '+model_types[1]
encoder_choice[model_type] = ['CM','CM','C','CM']
decoder_choice[model_type] = list(reversed(encoder_choice[model_type]))
multipliers_choice[model_type] = [1,1,2,2,4,4,4,4]
batchnorm_choice[model_type] = False
z_choice[model_type] = 32
linear_choice[model_type] = 512
kernel_choice[model_type] = 16
batch_choice[model_type] = 32
sizes_choice[model_type] = (128,128)
scheduler_choice[model_type] = scheduler_types['Vanilla A']
# Fantastic fit at 2k


model_type = 'E'+' '+model_types[1]
encoder_choice[model_type] = ['CM','CM','C','CM']
decoder_choice[model_type] = list(reversed(encoder_choice[model_type]))
multipliers_choice[model_type] = [1,1,2,2,4,4,4,4]
batchnorm_choice[model_type] = False
z_choice[model_type] = 8
linear_choice[model_type] = 512
kernel_choice[model_type] = 16
batch_choice[model_type] = 32
sizes_choice[model_type] = (128,128)
scheduler_choice[model_type] = scheduler_types['Vanilla A']


model_type = 'F'+' '+model_types[1]
encoder_choice[model_type] = ['CM','CM','CM','CM']
decoder_choice[model_type] = list(reversed(encoder_choice[model_type]))
multipliers_choice[model_type] = [1,1,2,2]
batchnorm_choice[model_type] = False
z_choice[model_type] = 32
linear_choice[model_type] = 512
kernel_choice[model_type] = 96
batch_choice[model_type] = 128
sizes_choice[model_type] = (64,64)
scheduler_choice[model_type] = scheduler_types['Vanilla A']
# Make sure we have enough kernels to not get visual artifacts, fantastic

model_type = 'F2'+' '+model_types[1]
encoder_choice[model_type] = ['CM','CM','CM','CM']
decoder_choice[model_type] = list(reversed(encoder_choice[model_type]))
multipliers_choice[model_type] = [2,2,4,4]
batchnorm_choice[model_type] = False
z_choice[model_type] = 64
linear_choice[model_type] = 756
kernel_choice[model_type] = 96
batch_choice[model_type] = 128
sizes_choice[model_type] = (64,64)
scheduler_choice[model_type] = scheduler_types['Vanilla A']
# Make sure we have enough kernels to not get visual artifacts, fantastic



model_type = 'Faster'+' '+model_types[1]
encoder_choice[model_type] = ['CM','CM','C','CM']
decoder_choice[model_type] = list(reversed(encoder_choice[model_type]))
multipliers_choice[model_type] = [1,1,2,2,4,4,4,4]
batchnorm_choice[model_type] = False
z_choice[model_type] = 128
linear_choice[model_type] = 512
kernel_choice[model_type] = 96
batch_choice[model_type] = 32
sizes_choice[model_type] = (128,128)
scheduler_choice[model_type] = scheduler_types['Vanilla A']
# Make sure we have enough kernels to not get visual artifacts, fantastic


model_type = 'Zen'+' '+model_types[1]
encoder_choice[model_type] = ['CM','CM','C','CM','C','CM']
decoder_choice[model_type] = list(reversed(encoder_choice[model_type]))
multipliers_choice[model_type] = [2,3,4,4,6,6]
batchnorm_choice[model_type] = False
z_choice[model_type] = 128
linear_choice[model_type] = 512
kernel_choice[model_type] = 32
batch_choice[model_type] = 32
sizes_choice[model_type] = (64,64)
scheduler_choice[model_type] = scheduler_types['Vanilla A']
# Guess best parameters to get perfect reconstruction with as few epochs/param dimensions as possible

model_type = 'Proof'+' '+model_types[1]
encoder_choice[model_type] = ['CM','CM','CM','CM']
decoder_choice[model_type] = list(reversed(encoder_choice[model_type]))
multipliers_choice[model_type] = [2,2,3,3]
batchnorm_choice[model_type] = True
z_choice[model_type] = 512
linear_choice[model_type] = 512
kernel_choice[model_type] = 32
batch_choice[model_type] = 32
sizes_choice[model_type] = (64,64)
scheduler_choice[model_type] = scheduler_types['Vanilla A']
# Guess best parameters to get perfect reconstruction with as few epochs/param dimensions as possible

model_type = 'ZenSimple'+' '+model_types[1]
encoder_choice[model_type] = ['C','CM','C','CM','C','CM','CM','CM']
decoder_choice[model_type] = list(reversed(encoder_choice[model_type]))
multipliers_choice[model_type] = [2,2,2,2,2,2,2,2]
batchnorm_choice[model_type] = False
z_choice[model_type] = 8
linear_choice[model_type] = 128
kernel_choice[model_type] = 32
batch_choice[model_type] = 32
sizes_choice[model_type] = (128,128)
scheduler_choice[model_type] = scheduler_types['Vanilla A']

# Vanilla model types

model_type = 'whole A'+' '+model_types[0]
z_choice[model_type] = 128
linear_choice[model_type] = 512
batch_choice[model_type] = 32
sizes_choice[model_type] = (128,128)
scheduler_choice[model_type] = scheduler_types['Vanilla A']

model_type = 'Grand'+' '+model_types[0]
z_choice[model_type] = 32
linear_choice[model_type] = 512
batch_choice[model_type] = 128
sizes_choice[model_type] = (64,64)
scheduler_choice[model_type] = scheduler_types['Vanilla A']

# Conditional model types

model_type = 'A'+' '+model_types[2]
z_choice[model_type] = 4 # latent z* of the CVAE
linear_choice[model_type] = 128
batch_choice[model_type] = 32 # sample this many samples from a window around the input
sizes_choice[model_type] = 128 # latent z dimension in input and output
window_context_choice[model_type] = (2,True) # size of the window, and whether it is bidirectional
scheduler_choice[model_type] = scheduler_types['Conditional A']

model_type = 'Z'+' '+model_types[2]
z_choice[model_type] = 128 # latent z* of the CVAE
linear_choice[model_type] = 512
batch_choice[model_type] = 4 # sample this many samples from a window around the input
sizes_choice[model_type] = 128 # latent z dimension in input and output
window_context_choice[model_type] = (2,True) # size of the window, and whether it is bidirectional
scheduler_choice[model_type] = scheduler_types['Conditional A']


model_type = 'Jumper'+' '+model_types[2]
z_choice[model_type] = 128 # latent z* of the CVAE
linear_choice[model_type] = 512
batch_choice[model_type] = 4 # sample this many samples from a window around the input
sizes_choice[model_type] = 128 # latent z dimension in input and output
window_context_choice[model_type] = (6,True) # size of the window, and whether it is bidirectional
scheduler_choice[model_type] = scheduler_types['Conditional A']


model_type = 'Jumper Slow'+' '+model_types[2]
z_choice[model_type] = 128 # latent z* of the CVAE
linear_choice[model_type] = 512
batch_choice[model_type] = 32 # sample this many samples from a window around the input
sizes_choice[model_type] = 128 # latent z dimension in input and output
window_context_choice[model_type] = (8,True) # size of the window, and whether it is bidirectional
scheduler_choice[model_type] = scheduler_types['Conditional A']


model_type = 'F'+' '+model_types[2]
z_choice[model_type] = 8 # latent z* of the CVAE
linear_choice[model_type] = 256
batch_choice[model_type] = 128 # sample this many samples from a window around the input
sizes_choice[model_type] = 32 # latent z dimension in input and output
window_context_choice[model_type] = (64,True) # size of the window, and whether it is bidirectional
scheduler_choice[model_type] = scheduler_types['Conditional A']

model_type = 'local'+' '+model_types[2]
z_choice[model_type] = 8 # latent z* of the CVAE
linear_choice[model_type] = 256
batch_choice[model_type] = 128 # sample this many samples from a window around the input
sizes_choice[model_type] = 32 # latent z dimension in input and output
window_context_choice[model_type] = (2,True) # size of the window, and whether it is bidirectional
scheduler_choice[model_type] = scheduler_types['Conditional A']

# VAE training section

In [4]:

class VAE_Adventures():
        
    def load_model(self,epoch=1000):
        self.VAE.load_state_dict(torch.load(self.directory_models+"model_"+str(epoch)+".model"))
        self.inference = True         
        
    def save_model(self, epoch=0):
        torch.save(self.VAE.state_dict(), self.directory_models+"model_"+str(epoch)+".model")
        
    def __init__(self, log=False, model_version="X", model_choice="X",  save_models=True, whole_dataset=False, model_type = "X", resource_path=None, result_path=None):
        
        ## Class definitions
        
        codeFolder = CodeFolder(name = "Class Parameters")
        with codeFolder:
            self.save_models = save_models
            self.whole_dataset = whole_dataset

            self.model_version = model_version
            self.model_type = model_type
            self.model_choice = model_choice +" "+model_type        

            if self.model_type == model_types[1]:
                self.cfg_encoder = encoder_choice[self.model_choice]
                self.cfg_decoder = decoder_choice[self.model_choice]
                self.kernel_base = kernel_choice[self.model_choice]
                self.kernel_multipliers = multipliers_choice[self.model_choice]
                self.batch_norm = batchnorm_choice[self.model_choice]
                self.main_directory = "Convolutional VAE results/"
            elif self.model_type == model_types[0]:
                self.main_directory = "Vanilla VAE results/"
            elif self.model_type == model_types[2]:
                self.window_context = window_context_choice[self.model_choice]
                self.main_directory = "Conditional VAE results/"

            self.z_spatial_dim = z_choice[self.model_choice]
            self.linear_nodes = linear_choice[self.model_choice]

            self.sizes = sizes_choice[self.model_choice]
            self.batch_size = batch_choice[self.model_choice] 

            self.datasplit = (8,10) # train on 8 out of 10 batches
            self.vis_size = 8
            
            self.resource_path = resource_path
            self.result_path = result_path
            self.scheduler = scheduler_choice[self.model_choice]
        
        ### Dataset definition
        
        if self.model_type != model_types[2]:
            self.trainset = CustomClusteringDataset(sizes=self.sizes, path=self.resource_path,subset=1.0 if self.whole_dataset else 0.07,datasplit = self.datasplit, testing=False)
            self.trainloader = torch.utils.data.DataLoader(dataset=self.trainset, batch_size=self.batch_size, shuffle=True, num_workers=16)
            self.trainloader_prediction = torch.utils.data.DataLoader(dataset=self.trainset, batch_size=self.batch_size, shuffle=False, num_workers=16)
            self.testset = CustomClusteringDataset(sizes=self.sizes, path=self.resource_path,subset=1.0 if self.whole_dataset else 0.07,datasplit = self.datasplit, testing=True)
            self.testloader_prediction = torch.utils.data.DataLoader(dataset=self.testset, batch_size=self.batch_size, shuffle=False, num_workers=16)
        
        ### Directories and Logging Definition 
        
        self.root_directory = self.result_path
        self.version_directory = self.model_version+" "+str(self.model_choice)
        if self.whole_dataset:
            self.version_directory += " entire dataset"
            
        self.directory = self.root_directory+self.main_directory+self.version_directory
        print("Using directory : {}".format(self.directory))
        
        overwrite = None
        # If we turn on logging, then we want to train. Thus we need to ask the user for permission to overwrite directory
        # If logging is turned off because we are in inference mode, then self.run should be False as well, so that we cannot use the model's train method
        
        overwrite = 'y'
        if pathlib.Path(self.directory).exists() and log == True:
            overwrite = input("Overwrite directory? It already exists...")
            
        if overwrite in ['no','n','False','false',''] or not (overwrite in ['y','yes','True','true']):
            self.run = False
            return
        self.run = log           
            
        pathlib.Path(self.directory).mkdir(parents=True, exist_ok=True) 
        
        codefolder = CodeFolder(name="directories")
        with codefolder:
            self.directory_models = self.directory+"/models/"
            self.directory_visuals = self.directory+"/visuals/"
            self.directory_visuals_test = self.directory+"/visuals_test/"
            self.directory_samples = self.directory+"/samples/"
            self.directory_samples_varied = self.directory+"/samples_varied/"

            pathlib.Path(self.directory_models).mkdir(parents=True, exist_ok=True) 
            pathlib.Path(self.directory_visuals).mkdir(parents=True, exist_ok=True)
            pathlib.Path(self.directory_visuals_test).mkdir(parents=True, exist_ok=True)
            pathlib.Path(self.directory_samples).mkdir(parents=True, exist_ok=True)
            pathlib.Path(self.directory_samples_varied).mkdir(parents=True, exist_ok=True)
        
        self.oracle = Oracle()
        self.historian = None
        if log:
            self.historian = Historian()
            self.historian.logger(self.directory+"/version_metadata.txt")    
        
        if self.model_type == model_types[1]:
            self.VAE = ConvVAE(cfg_encoder=self.cfg_encoder, cfg_decoder=self.cfg_decoder, sizes=self.sizes,z_spatial_dim=self.z_spatial_dim, linear_nodes=self.linear_nodes, kernel_base=self.kernel_base, kernel_multipliers = self.kernel_multipliers, batch_norm=self.batch_norm).cuda()
        elif self.model_type == model_types[0]:
            self.VAE = VAE(sizes=self.sizes,z_spatial_dim=self.z_spatial_dim, linear_nodes=self.linear_nodes).cuda()       
        elif self.model_type == model_types[2]:
            self.VAE = CVAE(sizes=self.sizes,z_spatial_dim=self.z_spatial_dim, linear_nodes=self.linear_nodes).cuda()   
            
        if not self.historian is None:
            if self.model_type != model_types[2]:
                self.historian.log("Loaded training dataset with {} sequences of attention maps".format(len(self.trainloader)))
                self.historian.log("Loaded training prediction dataset with {} sequences of attention maps".format(len(self.trainloader_prediction)))
                self.historian.log("Loaded testing prediction dataset with {} sequences of attention maps".format(len(self.testloader_prediction)))
            
            ### Model definition

            self.historian.log("\n\n----------------- MODEL ATTRIBUTE DEFINITIONS -----------------\n\n")

            for key, val in vars(self).items():
                self.historian.log("self.{}  ===   {}".format(key,val))

            self.historian.log("\n\n----------------- MODEL ATTRIBUTE DEFINITIONS -----------------\n\n")    
                 
    def steptrain(self, mycondition,  myinput, myconcat, params=None): # batch of output context and batch of replicated input
        
        assert self.model_type==model_types[2], "Cannot train by steps on conditional inputs on non-Conditional VAE model"
        assert self.run," Experiment was halted, either trying to run without overwriting directory or inference mode is using a logfile..."
        
        lr_base = self.scheduler['lr_base']
        wd = self.scheduler['wd']
        self.optimizer = optim.Adam(params, weight_decay=wd, lr=lr_base)
        #print(list(params[1]['params'])[0].grad)
        self.VAE.inference = False
        if params == None:
            condition_datum = Variable(mycondition) # mycondition[1] should be a batch of replicated word z
            in_datum = Variable(myinput) # myinput[1] should be a batch of contextual z samples
            concat_datum = Variable(myconcat)
        else:
            condition_datum = mycondition
            in_datum = myinput
            concat_datum = myconcat

        Z, z_mu, z_var, out_datum = self.VAE(condition_datum, concat_datum) 

        loss = self.loss_function(out_datum, in_datum, z_mu, z_var)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        
        return loss.data.cpu()
    
    def epoch_train(self):
        assert self.run
        codefolder = CodeFolder(name="Hyper-parameters")
        with codefolder:
            self.VAE.inference = False
            
            lr_epoch = self.scheduler['lr_epoch']
            lr_mod = self.scheduler['lr_mod']
            lr_change = self.scheduler['lr_change']
            lr_base = self.scheduler['lr_base']
            wd = self.scheduler['wd']
            epochs = 1

            samples_vary_creation = False
            starting_epoch = 1
            
        self.optimizer = optim.Adam(self.VAE.parameters(), weight_decay=wd, lr=lr_base*lr_mod)
        
        for i, datum in enumerate(self.trainloader):
            print(len(self.trainloader))
                    
            #orig_datum = Variable(datum[0].cuda())
            in_datum = Variable(datum[1].cuda())

            Z, z_mu, z_var, out_datum = self.VAE(in_datum)

            loss = self.loss_function(out_datum, in_datum,  z_mu, z_var)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            if i == 0 and not self.historian is None:
                self.historian.log("loss: {} ,lr= {}".format(loss.data.cpu().numpy()[0], lr_mod*lr_base))
                
            if i > 32:
                break
                
    def conditional_step_train(self, v, params=None):
        
        assert self.run
        
        self.VAE.inference = False
        v.VAE.inference = True
        
        #print(v.VAE.z_mu.weight)
        for i, datum in enumerate(v.trainloader_prediction):
            # Get z's for entire sequence
            in_datum_batch = Variable(datum[1].cuda())
            Z, z_mu, z_var, out_datum = v.VAE(in_datum_batch)

            tensor_conditional = None
            # Randomly pick a z pair
            pair_list = []
            while tensor_conditional == None or tensor_conditional[0].size()[0] < self.batch_size:
                z_choice_index = random.randint(1,datum[1].size()[0])
                z_window = [0,min(z_choice_index+self.window_context[0], datum[1].size()[0])]
                if self.window_context[1]:
                    z_window[0] = max(0,z_choice_index-self.window_context[0])
                possible_vals = [x for x in range(z_window[0], z_window[1]+1) if x != z_choice_index]
                random_pair = random.choice(possible_vals)
                Zc = Z[z_choice_index-1].unsqueeze(0)
                Zn = Z[random_pair-1].unsqueeze(0)
                Zplus = torch.cat([Zc, Zn],1)
                if tensor_conditional is None:
                    tensor_conditional = [Zc,Zn , Zplus ]
                else:
                    tensor_conditional = [torch.cat([tensor_conditional[0],Zc]), torch.cat([tensor_conditional[1],Zn]), torch.cat([tensor_conditional[2],Zplus])]

            myloss = self.steptrain(tensor_conditional[0], tensor_conditional[1], tensor_conditional[2], params=params)
        
            if i > 32:
                break
        #print(v.VAE.z_mu.weight)
            
    def train(self):
        
        assert self.run," Experiment was halted, either trying to run without overwriting directory or inference mode is using a logfile..."
        
        # Training hyperparameters
        codefolder = CodeFolder(name="Hyper-parameters")
        with codefolder:
            self.VAE.inference = False
            
            lr_epoch = self.scheduler['lr_epoch']
            lr_mod = self.scheduler['lr_mod']
            lr_change = self.scheduler['lr_change']
            lr_base = self.scheduler['lr_base']
            wd = self.scheduler['wd']
            epochs = self.scheduler['epochs']

            samples_vary_creation = False
            starting_epoch = 1
            epoch_save_all = 100
        
        if self.whole_dataset:
            epoch_save_all = 25
        
        if starting_epoch != 1:
            directory_models = self.directory+"/models/"
            self.VAE.load_state_dict(torch.load(self.directory_models+"model_"+str(starting_epoch-1)+".model"))
            
        self.optimizer = optim.Adam(self.VAE.parameters(), weight_decay=wd, lr=lr_base*lr_mod)
        
        
        # Start training
        for epoch in range(starting_epoch,epochs+1):

            print("Epoch {}".format(epoch))
            if epoch % lr_epoch == 0:
                lr_mod *= lr_change
                self.optimizer = optim.Adam(self.VAE.parameters(), weight_decay=wd, lr=lr_base*lr_mod)
                
            if epoch % epoch_save_all != 0:
                for i, datum in enumerate(self.trainloader):
                    
                    orig_datum = Variable(datum[0])
                    in_datum = Variable(datum[1].cuda())

                    Z, z_mu, z_var, out_datum = self.VAE(in_datum)

                    loss = self.loss_function(out_datum, in_datum,  z_mu, z_var)

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    
                    if i == 0 and not self.historian is None:
                        self.historian.log("{} / {} loss: {} ,lr= {}".format(epoch,  epochs, loss.data.cpu().numpy()[0], lr_mod*lr_base))
                        
            else: # Save models and visualization of data
                                    
                if self.save_models:
                    self.save_model(epoch=epoch)
                    
                # Reconstruction samples
                for i, datum in enumerate(self.trainloader_prediction):
                    
                    if i > self.vis_size:
                        break
                        
                    orig_datum = Variable(datum[0])
                    in_datum = Variable(datum[1].cuda())

                    Z, z_mu, z_var, out_datum = self.VAE(in_datum)
                    
                    orig , inp, outp = orig_datum.data, in_datum.data, out_datum.data
                    if self.batch_size != 1:
                        orig, inp, outp = orig.squeeze(0), inp.squeeze(0), outp
                    orig, inp, outp = orig.cpu(), inp.cpu(), outp.cpu()
                        
                    self.oracle.visualize_tensors([orig, inp, outp],file=self.directory_visuals+"/viewpoint_result_e"+str(epoch)+"_b"+str(i).zfill(2))
                
                for i, datum in enumerate(self.testloader_prediction):
                    
                    if i > self.vis_size:
                        break
                        
                    orig_datum = Variable(datum[0])
                    in_datum = Variable(datum[1].cuda())

                    Z, z_mu, z_var, out_datum = self.VAE(in_datum)
                    
                    orig , inp, outp = [orig_datum.data, in_datum.data, out_datum.data]
                    if self.batch_size != 1:
                        orig, inp, outp = orig.squeeze(0), inp.squeeze(0), outp
                    orig, inp, outp = [orig.cpu(), inp.cpu(), outp.cpu()]
                    
                    self.oracle.visualize_tensors([orig, inp, outp],file=self.directory_visuals_test+"/viewpoint_result_e"+str(epoch)+"_b"+str(i).zfill(2)) 
                        
                # Random z samples
                self.VAE.inference = True
                
                for i in range(6):
                    sample = Variable(torch.randn(12,self.z_spatial_dim)).cuda()
                    sample = self.VAE.decode(sample)
                    self.oracle.visualize_tensors([sample.data.cpu()],file=self.directory_samples+"/viewpoint_result_e"+str(epoch)+"_b"+str(i).zfill(2))
                
                if samples_vary_creation:
                    sample = Variable(torch.randn(12,self.z_spatial_dim)).cuda()

                    for i in range(1+int(self.z_spatial_dim/8)):
                       output = self.VAE.decode(sample)
                       for j in range(12):
                           sample[j,i] = sample[j,i] + random.randint(1,30)*1.0 
                       self.oracle.visualize_tensors([output.data.cpu()],file=self.directory_samples_varied+"/viewpoint_result_e"+str(epoch)+"_b"+str(i).zfill(2))
                
                self.VAE.inference = False

    def loss_function(self, recon_x, x, mu, logvar):
        if self.model_type != model_types[2]:
            # Our inputs and outputs are bound to [0,1] since they are raw attention images.
            # Input is already in [0,1] and output has a sigmoid activation
            # We are trying to find the likelihood of every pixel location being active
            BCE = F.binary_cross_entropy(recon_x.view(-1,self.sizes[0]*self.sizes[1]), x.view(-1, self.sizes[0]*self.sizes[1]))

            # see Appendix B from VAE paper:
            # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
            # https://arxiv.org/abs/1312.6114
            # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
            KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            # Normalise by same number of elements as in reconstruction
            KLD /= x.size()[0] * self.sizes[0]*self.sizes[1]
            KLD = torch.clamp(KLD,0,1000) # tiny floating point errors with values like -5e-14
            if random.random() > 0.9:
                print(BCE.data.cpu().numpy()[0], KLD.data.cpu().numpy()[0])
            return BCE + KLD
        else:
            # We have latent space be our input, so the loss function should be a regressor instead
            # Our input and output samples are in (-inf, +inf) and follow a gaussian
            
            MSE = F.mse_loss(recon_x.view(-1,self.sizes), x.view(-1, self.sizes))
            KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
            KLD /= recon_x.size()[0] * self.sizes
            KLD = torch.clamp(KLD,0,1000) # tiny floating point errors with values like -5e-14
            
            if random.random() > 0.9:
                print(MSE.data.cpu().numpy()[0], KLD.data.cpu().numpy()[0])
            return MSE + KLD
        

# Stop importing here

In [None]:
if __name__ != "__main__":
    raise ImportError

# Core

In [None]:
resource_path = "/scratch/Jack/resources/Attention Maps/view maps/"
result_path =  "/scratch/Jack/research lab/NFL Viewpoint VAEs/"
v = VAE_Adventures(log=True, model_version='Serenity',model_choice='Grand', model_type='Vanilla', whole_dataset=True, resource_path=resource_path, result_path = result_path)
v.train()

In [None]:
resource_path = "/scratch/Jack/resources/Attention Maps/view maps/"
result_path =  "/scratch/Jack/research lab/NFL Viewpoint VAEs/"
v = VAE_Adventures(log=True, model_version='Serenity',model_choice='F', model_type='Convolutional',whole_dataset=True, resource_path=resource_path, result_path = result_path)
v.train()

# Joint Training

In [14]:

starting_cycle = 0

model_choice = 'F'
model_version = "Serenity Joint Su"
model_type = "Convolutional"
resource_path = "/scratch/Jack/resources/Attention Maps/playtype maps/"
result_path =  "/scratch/Jack/research lab/NFL Playtype VAEs/"
v = VAE_Adventures(log=True, model_version=model_version,model_choice=model_choice, model_type=model_type,whole_dataset=False, resource_path=resource_path, result_path = result_path)
#v.load_model(epoch = starting_cycle)

model_choice = 'local'
model_version = "Serenity Joint Su"
model_type = "Conditional"

cVAE = VAE_Adventures(log=True, model_version=model_version, model_choice=model_choice, model_type=model_type, resource_path=resource_path, result_path=result_path)
#cVAE.load_model(epoch = starting_cycle)
cVAE.trainloader_prediction = v.trainloader_prediction

if starting_cycle != 0:
    #v.load_model(epoch = starting_cycle)
    cVAE.load_model(epoch = starting_cycle) 
params = [{'params': cVAE.VAE.parameters(), 'lr':1e-3}, {'params': v.VAE.parameters(), 'lr': 1e-4} ]
cycles = 4000
cycle_save = 100 

for cycle in range(starting_cycle+1, cycles+1):
    historian.log("My cycle is {}".format(cycle))
    historian.log("One train cycle for convolutional VAE")
    v.epoch_train()
    historian.log("One train cycle for conditional VAE")
    cVAE.conditional_step_train(v, params)
    
    if cycle % cycle_save == 0:
        cVAE.save_model(epoch=cycle)
        v.save_model(epoch=cycle)

Using directory : /scratch/Jack/research lab/NFL Playtype VAEs/Convolutional VAE results/Serenity Joint Su F Convolutional
Using directory : /scratch/Jack/research lab/NFL Playtype VAEs/Conditional VAE results/Serenity Joint Su local Conditional


KeyboardInterrupt: 

In [10]:
 
#starting_cycle = 0
starting_cycle = 0
model_choice = 'F'
#model_choice = 'F2'
#model_version = "Azimuth Joint"
model_version = "Zalamity Joint"
model_type = "Convolutional"
resource_path = "/scratch/Jack/resources/Attention Maps/playtype maps/"
result_path =  "/scratch/Jack/research lab/NFL Playtype VAEs/"
v = VAE_Adventures(log=True, model_version=model_version,model_choice=model_choice, model_type=model_type,whole_dataset=False, resource_path=resource_path, result_path = result_path)

v.load_model(epoch = 1000)
model_choice = 'local'
model_version = "Zalamity Joint"
model_type = "Conditional"

cVAE = VAE_Adventures(log=True, model_version=model_version, model_choice=model_choice, model_type=model_type, resource_path=resource_path, result_path=result_path)
cVAE.trainloader_prediction = v.trainloader_prediction
#{'params': cVAE.VAE.parameters(), 'lr':1e-3}
params = [{'params': cVAE.VAE.parameters(), 'lr': 1e-4} ]
cycles = 4000
cycle_save = 20  

if starting_cycle != 0:
    #v.load_model(epoch = starting_cycle)
    cVAE.load_model(epoch = starting_cycle)  

for cycle in range(starting_cycle+1, cycles+1):
    historian.log("My cycle is {}".format(cycle))
    historian.log("One train cycle for convolutional VAE")
    #v.epoch_train()
    #historian.log("One train cycle for conditional VAE")
    cVAE.conditional_step_train(v, params)
    
    if cycle % cycle_save == 0:
        cVAE.save_model(epoch=cycle)
        #v.save_model(epoch=cycle)

Using directory : /scratch/Jack/research lab/NFL Playtype VAEs/Convolutional VAE results/Zalamity Joint F Convolutional
Overwrite directory? It already exists...y
Using directory : /scratch/Jack/research lab/NFL Playtype VAEs/Conditional VAE results/Zalamity Joint local Conditional
Overwrite directory? It already exists...y
Historian -> My cycle is 1
Historian -> One train cycle for convolutional VAE
0.0527128 0.0
0.0750555 0.0
0.0642181 0.0
Historian -> My cycle is 2
Historian -> One train cycle for convolutional VAE
0.0350283 0.0
0.0867339 1.25146e-09
Historian -> My cycle is 3
Historian -> One train cycle for convolutional VAE
0.0582405 2.18279e-11
0.0910049 0.0
0.0439656 0.0
Historian -> My cycle is 4
Historian -> One train cycle for convolutional VAE
0.0627255 3.07773e-09
Historian -> My cycle is 5
Historian -> One train cycle for convolutional VAE
0.102757 1.45374e-08
Historian -> My cycle is 6
Historian -> One train cycle for convolutional VAE
0.0328821 1.32131e-08
0.0658872 2.0

Historian -> My cycle is 72
Historian -> One train cycle for convolutional VAE
0.0423049 8.45457e-06
Historian -> My cycle is 73
Historian -> One train cycle for convolutional VAE
0.0410874 9.85241e-06
Historian -> My cycle is 74
Historian -> One train cycle for convolutional VAE
0.019489 1.82395e-05
Historian -> My cycle is 75
Historian -> One train cycle for convolutional VAE
Historian -> My cycle is 76
Historian -> One train cycle for convolutional VAE
0.0190209 2.28844e-05
Historian -> My cycle is 77
Historian -> One train cycle for convolutional VAE
Historian -> My cycle is 78
Historian -> One train cycle for convolutional VAE
0.042572 2.06639e-05
0.0420443 2.54902e-05
0.0265437 2.50148e-05
Historian -> My cycle is 79
Historian -> One train cycle for convolutional VAE
0.0377066 1.63372e-05
0.01254 1.24944e-05
0.010912 8.39349e-06
0.0552256 3.50147e-05
0.0219331 1.97195e-05
Historian -> My cycle is 80
Historian -> One train cycle for convolutional VAE
Historian -> My cycle is 81
Hi

RuntimeError: Traceback (most recent call last):
  File "/scratch/Jack/projects/Environments/torch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 40, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/scratch/Jack/projects/Environments/torch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 109, in default_collate
    return [default_collate(samples) for samples in transposed]
  File "/scratch/Jack/projects/Environments/torch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 109, in <listcomp>
    return [default_collate(samples) for samples in transposed]
  File "/scratch/Jack/projects/Environments/torch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 89, in default_collate
    storage = batch[0].storage()._new_shared(numel)
  File "/scratch/Jack/projects/Environments/torch/lib/python3.6/site-packages/torch/storage.py", line 111, in _new_shared
    return cls._new_using_filename(size)
RuntimeError: unable to write to file </torch_94115_3263109217> at /pytorch/torch/lib/TH/THAllocator.c:271


Process Process-2283:
Process Process-2287:
Process Process-2281:
Process Process-2276:
Process Process-2275:
Process Process-2278:
Process Process-2277:
Process Process-2279:
Process Process-2284:
Process Process-2274:
Process Process-2286:
Process Process-2285:
Process Process-2280:
Traceback (most recent call last):
Process Process-2273:
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
Traceback (most recent call last):
  File "/scratch/Jack/projects/Python/myPython3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Traceback (most recent call last):
Traceback (most recent call last):
  File "/scratch/Jack/projects/Python/myPython3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/scratch/Jack/projects/Python/myPython3/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
  

  File "/scratch/Jack/projects/Python/myPython3/lib/python3.6/multiprocessing/synchronize.py", line 96, in __enter__
    return self._semlock.__enter__()
KeyboardInterrupt
  File "/scratch/Jack/projects/Python/myPython3/lib/python3.6/multiprocessing/queues.py", line 334, in get
    with self._rlock:
  File "/scratch/Jack/projects/Python/myPython3/lib/python3.6/multiprocessing/synchronize.py", line 96, in __enter__
    return self._semlock.__enter__()
  File "/scratch/Jack/projects/Python/myPython3/lib/python3.6/multiprocessing/queues.py", line 334, in get
    with self._rlock:
KeyboardInterrupt
  File "/scratch/Jack/projects/Environments/torch/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 34, in _worker_loop
    r = index_queue.get()
  File "/scratch/Jack/projects/Python/myPython3/lib/python3.6/multiprocessing/connection.py", line 407, in _recv_bytes
    buf = self._recv(4)
  File "/scratch/Jack/projects/Python/myPython3/lib/python3.6/multiprocessing/synchronize.py"

In [None]:
# Conditional VAE


In [23]:
# Run it on sequences of batch 32 but only extract an input and a window context
# Run the CVAE on the window context

### Model definition

model_choice = 'F'
model_version = "Serenity"
model_type = "Convolutional"

### ------------------------------------------------------------------------------------

resource_path = "/scratch/Jack/resources/Attention Maps/view maps/"
result_path =  "/scratch/Jack/research lab/NFL Viewpoint VAEs/"
v = VAE_Adventures(log=False, model_version=model_version, model_choice=model_choice, whole_dataset=True, model_type=model_type, resource_path=resource_path, result_path=result_path)
v.load_model(epoch=1000)

### ------------------------------------------------------------------------------------

Using directory : /scratch/Jack/research lab/NFL Viewpoint VAEs/Convolutional VAE results/Serenity F Convolutional entire dataset


In [None]:
model_choice = 'local'
model_version = "SerenityLarge"
model_type = "Conditional"

resource_path = "/scratch/Jack/resources/Attention Maps/view maps/"
result_path =  "/scratch/Jack/research lab/NFL Viewpoint VAEs/"
cVAE = VAE_Adventures(log=True, model_version=model_version, model_choice=model_choice, model_type=model_type, resource_path=resource_path, result_path=result_path)

epochs = 4000
for epoch in range(1,epochs+1):
    print("epoch ",epoch)
    average_loss = 0
    batches = 0.0
    for i, datum in enumerate(v.trainloader_prediction):
        # Get z's for entire sequence
        in_datum_batch = Variable(datum[1].cuda(), volatile=True)
        Z, z_mu, z_var, out_datum = v.VAE(in_datum_batch)

        tensor_conditional = None
        # Randomly pick a z pair
        pair_list = []
        while tensor_conditional == None or tensor_conditional[0].size()[0] < cVAE.batch_size:
            z_choice_index = random.randint(1,datum[1].size()[0])
            z_window = [0,min(z_choice_index+cVAE.window_context[0], datum[1].size()[0])]
            if cVAE.window_context[1]:
                z_window[0] = max(0,z_choice_index-cVAE.window_context[0])
            possible_vals = [x for x in range(z_window[0], z_window[1]+1) if x != z_choice_index]
            random_pair = random.choice(possible_vals)
            Zc = Z[z_choice_index-1].unsqueeze(0)
            Zn = Z[random_pair-1].unsqueeze(0)
            Zplus = torch.cat([Zc, Zn],1)
            if tensor_conditional is None:
                tensor_conditional = [Zc,Zn , Zplus ]
            else:
                tensor_conditional = [torch.cat([tensor_conditional[0],Zc]), torch.cat([tensor_conditional[1],Zn]), torch.cat([tensor_conditional[2],Zplus])]

        myloss = cVAE.steptrain(tensor_conditional[0].data, tensor_conditional[1].data, tensor_conditional[2].data)
        average_loss += myloss.numpy()[0]
        batches += cVAE.batch_size
    if epoch % 100 == 0:
        print("Loss is : {}".format(average_loss/batches))
        cVAE.save_model(epoch=epoch)
        
        

# Interactive z sampling

In [None]:
### Model definition

model_choice = 'F'
model_version = "Serenity"
model_type = "Convolutional"
resource_path = "/scratch/Jack/resources/Attention Maps/view maps/"
result_path =  "/scratch/Jack/research lab/NFL Viewpoint VAEs/"
z_sub = 32 # how many z dimensions are interactive
z_spatial_dim = 32
### ------------------------------------------------------------------------------------

z_sub = min(z_spatial_dim,z_sub)

v = VAE_Adventures(log=False,whole_dataset=True, model_version=model_version, model_choice=model_choice, model_type=model_type,resource_path=resource_path, result_path=result_path)
v.load_model(epoch=1000)

### ------------------------------------------------------------------------------------



def visualize_z_sample(**z):
    sample_z = torch.zeros([1,z_spatial_dim])
    for i, (arg, val) in enumerate(z.items()):
        if arg != "wimage":
            sample_z[0,i] = val

    v.inference = True
    sample_var = Variable(sample_z).cuda()
    sample = v.VAE.decode(sample_var)
    
    to_pil = transforms.ToPILImage()
    sample_cpu = sample.data.cpu()
    oracle.visualize_tensors([sample_cpu],file="./temp_image")
    wimage.value = open("./temp_image.png", "rb").read()
    display(wimage)
    

def widget_slider(name):
    mystyle = {'handle_color': 'white'}
    w = widgets.FloatSlider(
        value=0,
        min=-3.0,
        max=3.0,
        step=0.1,
        disabled=False,
        continuous_update=False,
        orientation='vertical',
        readout=True,
        style=mystyle)
    
    return w

z_interactive = {'z-' + str(k): widget_slider('z-' + str(k)) for k in range(z_sub)}

wimage =widgets.Image(
    format='png',
    width=512,
    height=512,
)

hlay = Layout(display='flex',
                    flex_flow='row',
                    align_items='stretch',
                    border='solid',
                    width='150%')

w = interactive(visualize_z_sample, **z_interactive)
w.layout = hlay

### This acts as a predictor for the z-mu values that produce every frame

# for i, datum in enumerate(v.trainloader_prediction):
#     orig_datum = Variable(datum[0])
#     in_datum = Variable(datum[1].cuda())
#     v.inference = True
#     Z, z_mu, z_var, out_datum = v.VAE(in_datum)
#     print(Z)

display(w)

In [12]:
### Which model to visualize, sets include {[4,64,512]} and epochs in [100,200,...,1000]

model_choice = 'F'
model_version = "Zalamity Joint"
model_type = "Convolutional"
resource_path = "/scratch/Jack/resources/Attention Maps/playtype maps/"
result_path =  "/scratch/Jack/research lab/NFL Playtype VAEs/"

v = VAE_Adventures(log=False,whole_dataset=False, model_version=model_version, model_choice=model_choice, model_type=model_type, resource_path=resource_path, result_path=result_path)
v.load_model(epoch=1000)
v.VAE.cpu()
v.VAE.cpu_mode_sampling()

model_choice = 'local'
model_version = "Zalamity Joint"
model_type = "Conditional"

cVAE = VAE_Adventures(log=False, model_version=model_version, model_choice=model_choice, model_type=model_type, resource_path=resource_path, result_path=result_path)
cVAE.load_model(epoch=120)
cVAE.VAE.cpu()
cVAE.VAE.cpu_mode_sampling()

'''
model_choice = 'F'
model_version = "Serenity Joint"
model_type = "Convolutional"
resource_path = "/scratch/Jack/resources/Attention Maps/view maps/"
result_path =  "/scratch/Jack/research lab/NFL Viewpoint VAEs/"

v = VAE_Adventures(log=False,whole_dataset=True, model_version=model_version, model_choice=model_choice, model_type=model_type, resource_path=resource_path, result_path=result_path)
v.load_model(epoch=1200)
v.VAE.cpu()
v.VAE.cpu_mode_sampling()

model_choice = 'local'
model_version = "Serenity Joint"
model_type = "Conditional"

cVAE = VAE_Adventures(log=False, model_version=model_version, model_choice=model_choice, model_type=model_type, resource_path=resource_path, result_path=result_path)
cVAE.load_model(epoch=1200)
cVAE.VAE.cpu()
cVAE.VAE.cpu_mode_sampling()

'''
'''
model_choice = 'F'
model_version = "Serenity"
model_type = "Convolutional"
resource_path = "/scratch/Jack/resources/Attention Maps/view maps/"
result_path =  "/scratch/Jack/research lab/NFL Viewpoint VAEs/"

v = VAE_Adventures(log=False,whole_dataset=True, model_version=model_version, model_choice=model_choice, model_type=model_type, resource_path=resource_path, result_path=result_path)
v.load_model(epoch=1000)
v.VAE.cpu()
v.VAE.cpu_mode_sampling()

model_choice = 'local'
model_version = "SerenityLarge"
model_type = "Conditional"

cVAE = VAE_Adventures(log=False, model_version=model_version, model_choice=model_choice, model_type=model_type, resource_path=resource_path, result_path=result_path)
cVAE.load_model(epoch=100)
cVAE.VAE.cpu()
cVAE.VAE.cpu_mode_sampling()

'''
z_star_dim = 8
z_sub = 8 # how many z dimensions are interactive)
z_sub = min(z_sub, z_star_dim)
### ------------------------------------------------------------------------------------

from bokeh.io import  push_notebook, show, output_notebook, output_file
from bokeh.layouts import column, layout, row, widgetbox
from bokeh.plotting import figure, curdoc
from bokeh.models import CustomJS
from bokeh.events import ButtonClick
from bokeh.models.widgets import Tabs, Panel, Slider, Button


output_notebook()

def grab_random_image_context():
    for i, datum_batch in enumerate(v.trainloader_prediction):
        if random.random() < 1.0/len(v.trainloader_prediction) or i == len(v.trainloader_prediction)-2:
            print(i, len(v.trainloader_prediction))
            return datum_batch
    
sample_datum_batch = grab_random_image_context() 
v.VAE.inference = True
cVAE.VAE.inference = True    

bokeh_handle = None
imgs_contextual = []

def prepare_contextual_sample():
    sampled_context=np.zeros((1, 8),np.uint8)
    attention_near=np.ones((128, 128),np.float32)
    imgs_contextual.append(figure(plot_width=128*2, plot_height=12*2, x_range=[0, 128], y_range=[128, 0],  x_axis_location="above",  title=""))
    imgs_contextual.append(figure(plot_width=128*2, plot_height=128*2, x_range=[0, 128], y_range=[128, 0],  x_axis_location="above",  title=""))
    imgs_contextual[0].image([sampled_context], x=0, y=128, dw=128, dh=128)
    imgs_contextual[0].axis.visible = False
    imgs_contextual[0].toolbar_location = None
    imgs_contextual[1].image([attention_near], x=0, y=128, dw=128, dh=128)
    imgs_contextual[1].axis.visible = False
    imgs_contextual[1].toolbar_location = None

def update_contextual_sample(name, value, random):
    
    z_star = Variable(torch.from_numpy(np.random.normal(0, 1.0, z_star_dim)).float().unsqueeze(0))
    if random == "notrandom":
        for k in range(z_sub):
            if z_interactive[k].title == name:
                z_interactive[k].value = float(value)

    
        for k in range(z_sub):
            z_star[0,k] = z_interactive[k].value 
    
    
    
    sample_z1 = Variable(sample_datum_batch[1][25].unsqueeze(0)) # condition z (input frame z)
    
    z_condition,_,_,_ = v.VAE(sample_z1)
    
    
    z_sampled_near = cVAE.VAE.decode(z_star, z_condition)
    
    attention_near = v.VAE.decode(z_sampled_near)
    attention_near = np.flip(attention_near[0][0].data.cpu().numpy(),0)
    
    sampled_context = np.flip(z_sampled_near.data.cpu().numpy(),0)
    #sampled_context[sampled_context < 0] = -1
    #sampled_context[sampled_context > 0] = 1
    imgs_contextual[0].image([sampled_context], x=0, y=128, dw=128, dh=128)
    imgs_contextual[1].image([attention_near], x=0, y=128, dw=128, dh=128)
    push_notebook(handle=bokeh_handle)
    

button_callback = CustomJS(code="""
if (IPython.notebook.kernel !== undefined) {
    var kernel = IPython.notebook.kernel;
    cmd = "update_contextual_sample(0,0,random)";
    kernel.execute(cmd, {}, {});}
""")

callback = CustomJS(code="""
    if (IPython.notebook.kernel !== undefined) {
    var kernel = IPython.notebook.kernel;
    cmd = "update_contextual_sample('"+cb_obj.title+"','"+cb_obj.value+"', 'notrandom')";
    kernel.execute(cmd, {}, {});}
""")

                           
prepare_contextual_sample()
z_interactive = [Slider(start=-3.0, end=3.0, value=0.0, step=0.01, title="z_{}".format(k), callback=callback, callback_policy="mouseup") for k in range(z_sub)]

#update_contextual_sample("z_3",3, 'notrandom')
#z_interactive[0].value = 0.1
#update_contextual_sample(3)

button = Button(label="Random sampling")
button.js_on_event(ButtonClick, button_callback)

# Visualize original images and their attention maps
imgs = []
imgs_next = []
imgs_z_recon = []
imgs_samples = []
imgs_samples_nouvo = []

for i in range(3):
    
    sample_z1 = Variable(sample_datum_batch[1][25*i].unsqueeze(0)) # condition z (input frame z)
    
    z,_,_,_ = v.VAE(sample_z1)
    r = z.data
    # why is z scaling invariant?
    #r[0,0:128] *= 0.2
    # and why is it so dependent on only a few values
    #r[0,0:1] *= 0
    z = Variable(r)
    recon_z1 = v.VAE.decode(z)
    orig = torch.div(sample_datum_batch[0][25*i],torch.max(sample_datum_batch[0][25*i]))
    orig = torch.clamp(orig,0,1).numpy()
    orig = np.swapaxes(np.swapaxes(np.multiply(orig,255), 0, 2), 1,0).astype(np.uint8)[::-1]
    orig = np.dstack([orig, 255*np.ones(orig.shape[:2], np.uint8)])
    imgs.append(figure(plot_width=128*2, plot_height=128*2, x_range=[0, 128], y_range=[128, 0],  x_axis_location="above",  title=""))
    imgs[i].image_rgba([orig], x=0, y=128, dw=128, dh=128)
    imgs[i].axis.visible = False
    imgs[i].toolbar_location = None
    
    att = np.flip(sample_datum_batch[1][25*i][0].numpy(),0)
    att[0][0] = 0.001
    imgs_next.append(figure(plot_width=128*2, plot_height=128*2, x_range=[0, 128], y_range=[128, 0],  x_axis_location="above",  title=""))
    imgs_next[i].image([att], x=0, y=128, dw=128, dh=128)
    imgs_next[i].axis.visible = False
    imgs_next[i].toolbar_location = None
    
    attrr = np.flip(recon_z1[0][0].data.cpu().numpy(),0)
    imgs_z_recon.append(figure(plot_width=128*2, plot_height=128*2, x_range=[0, 128], y_range=[128, 0],  x_axis_location="above",  title=""))
    imgs_z_recon[i].image([attrr], x=0, y=128, dw=128, dh=128)
    imgs_z_recon[i].axis.visible = False
    imgs_z_recon[i].toolbar_location = None
    
    sample_zz = np.flip(z.data.cpu().numpy(),0)
    #sample_zz[sample_zz < 0] = -1
    #sample_zz[sample_zz > 0] = 1
        
    imgs_samples.append(figure(plot_width=128*2, plot_height=12*2, x_range=[0, 128], y_range=[128, 0],  x_axis_location="above",  title=""))
    imgs_samples[i].image([sample_zz], x=0, y=128, dw=128, dh=128)
    imgs_samples[i].axis.visible = False
    imgs_samples[i].toolbar_location = None
    
    
    imgs_samples_nouvo.append(figure(plot_width=128*2, plot_height=12*2, x_range=[0, 128], y_range=[128, 0],  x_axis_location="above",  title=""))
    imgs_samples_nouvo[i].image([sample_zz], x=0, y=128, dw=128, dh=128)
    imgs_samples_nouvo[i].axis.visible = False
    imgs_samples_nouvo[i].toolbar_location = None
    
    

w_tab_images = Panel(child=column(row(imgs),row(imgs_next),row(imgs_z_recon), row(imgs_samples)), title="image comparisons")
w_tab = Panel(child=row(column(z_interactive), button,  column(column(imgs_samples_nouvo[0], imgs_samples_nouvo[2], imgs_contextual[0],  imgs_contextual[1]))), title="z sampling")

tabs = Tabs(tabs=[w_tab, w_tab_images])
# show results

bokeh_handle = show(tabs, notebook_handle=True)

#display(w0)
#display(w)

### This acts as a predictor for the z-mu values that produce every frame

# for i, datum in enumerate(v.trainloader_prediction):
#     orig_datum = Variable(datum[0])
#     in_datum = Variable(datum[1].cuda())
#     v.inference = True
#     Z, z_mu, z_var, out_datum = v.VAE(in_datum)
#     print(Z)

#   Validation set

# validate = True
#
# if validate:
#     for i, datum in enumerate(v.testloader_prediction):
#
#         orig_datum = Variable(datum[0])
#         in_datum = Variable(datum[1].cuda())
#         v.inference = True
#         Z, z_mu, z_var, out_datum = v.VAE(in_datum)
#         loss = v.loss_function(out_datum, in_datum,  z_mu, z_var)
#         print(Z, loss)


Using directory : /scratch/Jack/research lab/NFL Playtype VAEs/Convolutional VAE results/Zalamity Joint F Convolutional
Using directory : /scratch/Jack/research lab/NFL Playtype VAEs/Conditional VAE results/Zalamity Joint local Conditional


9 48


# Clustering

In [52]:
%matplotlib inline

import matplotlib.pyplot as plt
import skfuzzy as fuzz

load_models = True
cluster_type = "raw"
cluster_version = "play"
cluster_append = True

if load_models:
    model_choice = 'F'
    model_version = "Serenity Joint Su"
    #model_version = "Zalamity Joint"
    model_type = "Convolutional"
    resource_path = "/scratch/Jack/resources/Attention Maps/view maps/"
    result_path =  "/scratch/Jack/research lab/NFL Viewpoint VAEs/"

    v = VAE_Adventures(log=False, whole_dataset=False, model_version=model_version, model_choice=model_choice, model_type=model_type, resource_path=resource_path, result_path=result_path)
    v.load_model(epoch=1000)
    v.VAE.cpu()
    v.VAE.cpu_mode_sampling()

    model_choice = 'local'
    model_type = "Conditional"

    cVAE = VAE_Adventures(log=False,  model_version=model_version, model_choice=model_choice, model_type=model_type, resource_path=resource_path, result_path=result_path)
    cVAE.load_model(epoch=120)
    cVAE.VAE.cpu()
    cVAE.VAE.cpu_mode_sampling()

    z_star_dim = 8


    v.VAE.inference = True
    cVAE.VAE.inference = True

    window_context_choice[model_type] = (2,True) 

if cluster_append:
    clustering_data = {0 : [], 1: [], 2: []}
    if cluster_type == "raw":
        for i, datum in enumerate(v.trainloader_prediction):
            for j in range(datum[0].size()[0]):
                
                clustering_data[0].append(datum[0][j])
                clustering_data[1].append(datum[1][j])
                clustering_data[2].append(datum[1][j].view(1,64*64))   
            if i == 30:
                break
    elif cluster_type == "spatial":
        for i, datum in enumerate(v.trainloader_prediction):
            print(i)
            for j in range(datum[0].size()[0]):
                print("datum size", j, datum[1].size()[0], datum[1].size())
                
                z, _, _, _ = v.VAE(Variable(datum[1][j].unsqueeze(0)))      
                z = z.squeeze(0)
                z = z.unsqueeze(1)
                
                if z.size()[0] == 32:
                    clustering_data[0].append(datum[0][j])
                    clustering_data[1].append(datum[1][j])
                    clustering_data[2].append(z)  
                    print(z.size())
            if i == 30:
                break
            print(i)
    else:
        for i, datum in enumerate(v.trainloader_prediction):
            print(i)
            for j in range(datum[0].size()[0]):
                #print("datum size", j, datum[1].size()[0], datum[1].size())
            
                
                z_stars = None
                for k in range(-2,3):
                    if k != 0:
                        if j+k >= 0 and j+k < datum[1].size()[0]:
                            z, _, _, _ = v.VAE(Variable(datum[1][j].unsqueeze(0)))
                            z_near, _, _, _ = v.VAE(Variable(datum[1][j+k].unsqueeze(0)))
                            z_c = torch.cat([z, z_near],1)

                            z_star,_,_,_ = cVAE.VAE(z, z_c)
                            z_star = z_star.squeeze(0)
                            z_star = z_star.unsqueeze(1)
                            if z_stars is None:
                                z_stars = z_star
                            else:
                                z_stars = torch.cat([z_stars, z_star], 0)
                z = z.squeeze(0)
                z = z.unsqueeze(1)
                z_stars = torch.cat([z_stars, z], 0)
                
                if z_stars.size()[0] == 64:
                    clustering_data[0].append(datum[0][j])
                    clustering_data[1].append(datum[1][j])
                    clustering_data[2].append(z_stars)  
            if i == 10:
                break
            print(i)

# Done appending clusters

print("Done appending")
if cluster_type == "raw":
    numpy_cluster_data = torch.cat(clustering_data[2], 0).numpy()
else:
    numpy_cluster_data = torch.cat(clustering_data[2], 1).data.numpy()

print(numpy_cluster_data.shape)
fpcs = []

colors = ['b', 'orange', 'g', 'r', 'c', 'm', 'y', 'k', 'Brown', 'ForestGreen']
for ncenters in range(5,35,5):
    cntr, u, u0, d, jm, p, fpc = fuzz.cluster.cmeans(numpy_cluster_data, ncenters, 2, error=0.005, maxiter=1000, init=None)

    # Store fpc values for later
    fpcs.append(fpc)

    # Plot assigned clusters, for each data point in training set
    cluster_membership = np.argmax(u, axis=0)
    print(cluster_membership)

    # Mark the center of each fuzzy cluster

    directory = "/scratch/Jack/research lab/clusters_final/"+cluster_version+"/"+cluster_type+"/centers_"+str(ncenters)+"/"
    for k in range(ncenters):
        if not os.path.exists(directory+str(k)):
            os.makedirs(directory+str(k))

    # Directory finished
    print("Preparing to save images")
    
    my_clustering_data = [cntr, u, u0, d, jm, p, fpc]
    pickle.dump(my_clustering_data, open(directory+"/my_clustering_data.pkl","wb"))

    oracle = Oracle()
    for i in range(len(clustering_data[0])):

        datum_orig = clustering_data[0][i]
        datum_attention = clustering_data[1][i]
        cluster_folder = cluster_membership[i]
        dir_save = directory+str(cluster_folder)+"/"
        oracle.visualize_tensors([datum_orig.unsqueeze(0)],file=dir_save+str(i).zfill(3))
        oracle.visualize_tensors([datum_attention.unsqueeze(0)],file=dir_save+str(i).zfill(3)+str("att"), normalize=True)




Using directory : /scratch/Jack/research lab/NFL Viewpoint VAEs/Convolutional VAE results/Serenity Joint Su F Convolutional


FileNotFoundError: [Errno 2] No such file or directory: '/scratch/Jack/research lab/NFL Viewpoint VAEs/Convolutional VAE results/Serenity Joint Su F Convolutional/models/model_1000.model'

In [20]:
%matplotlib inline

import matplotlib.pyplot as plt
import skfuzzy as fuzz

load_models = True
cluster_type = "temporal"
cluster_version = "play"
cluster_append = True

if load_models:
    model_choice = 'F'
    model_version = "Zalamity Joint"
    model_type = "Convolutional"
    resource_path = "/scratch/Jack/resources/Attention Maps/playtype maps/"
    result_path =  "/scratch/Jack/research lab/NFL Playtype VAEs/"

    v = VAE_Adventures(log=False, whole_dataset=False, model_version=model_version, model_choice=model_choice, model_type=model_type, resource_path=resource_path, result_path=result_path)
    v.load_model(epoch=1000)
    v.VAE.cpu()
    v.VAE.cpu_mode_sampling()

    model_choice = 'local'
    model_version = "Zalamity Joint"
    model_type = "Conditional"

    cVAE = VAE_Adventures(log=False,  model_version=model_version, model_choice=model_choice, model_type=model_type, resource_path=resource_path, result_path=result_path)
    cVAE.load_model(epoch=120)
    cVAE.VAE.cpu()
    cVAE.VAE.cpu_mode_sampling()

    z_star_dim = 8


    v.VAE.inference = True
    cVAE.VAE.inference = True

    window_context_choice[model_type] = (2,True) 

if cluster_append:
    clustering_data = {0 : [], 1: [], 2: []}
    if cluster_type == "raw":
        for i, datum in enumerate(v.trainloader_prediction):
            for j in range(datum[0].size()[0]):
                
                clustering_data[0].append(datum[0][j])
                clustering_data[1].append(datum[1][j])
                clustering_data[2].append(datum[1][j].view(1,64*64))   
            if i == 32:
                break
    elif cluster_type == "spatial":
        for i, datum in enumerate(v.trainloader_prediction):
            print(i)
            for j in range(datum[0].size()[0]):
                #print("datum size", j, datum[1].size()[0], datum[1].size())
                
                z, _, _, _ = v.VAE(Variable(datum[1][j].unsqueeze(0)))      
                z = z.squeeze(0)
                z = z.unsqueeze(1)
                
                if z.size()[0] == 32:
                    clustering_data[0].append(datum[0][j])
                    clustering_data[1].append(datum[1][j])
                    clustering_data[2].append(z)  
                    #print(z.size())
            if i == 32:
                break
            print(i)
    else:
        for i, datum in enumerate(v.trainloader_prediction):
            print(i)
            for j in range(datum[0].size()[0]):
                print(j)
                #print("datum size", j, datum[1].size()[0], datum[1].size())
            
                
                z_stars = None
                for k in range(-2,3):
                    print(k)
                    if k != 0:
                        if j+k >= 0 and j+k < datum[1].size()[0]:
                            z, _, _, _ = v.VAE(Variable(datum[1][j].unsqueeze(0)))
                            z_near, _, _, _ = v.VAE(Variable(datum[1][j+k].unsqueeze(0)))
                            z_c = torch.cat([z, z_near],1)

                            z_star,_,_,_ = cVAE.VAE(z, z_c)
                            z_star = z_star.squeeze(0)
                            z_star = z_star.unsqueeze(1)
                            if z_stars is None:
                                z_stars = z_star
                            else:
                                z_stars = torch.cat([z_stars, z_star], 0)
                z = z.squeeze(0)
                z = z.unsqueeze(1)
                z_stars = torch.cat([z_stars, z], 0)
                
                if z_stars.size()[0] == 64:
                    clustering_data[0].append(datum[0][j])
                    clustering_data[1].append(datum[1][j])
                    clustering_data[2].append(z_stars)  
            if i == 10:
                break
            print(i)

# Done appending clusters

print("Done appending")
if cluster_type == "raw":
    numpy_cluster_data = torch.cat(clustering_data[2], 0).numpy()
else:
    numpy_cluster_data = torch.cat(clustering_data[2], 1).data.numpy()

print(numpy_cluster_data.shape)
fpcs = []

colors = ['b', 'orange', 'g', 'r', 'c', 'm', 'y', 'k', 'Brown', 'ForestGreen']
for ncenters in range(30,35,5):
    cntr, u, u0, d, jm, p, fpc = fuzz.cluster.cmeans(numpy_cluster_data, ncenters, 2, error=0.005, maxiter=1000, init=None)

    # Store fpc values for later
    fpcs.append(fpc)

    # Plot assigned clusters, for each data point in training set
    cluster_membership = np.argmax(u, axis=0)
    print(cluster_membership)

    # Mark the center of each fuzzy cluster

    directory = "/scratch/Jack/research lab/clusters_final2/"+cluster_version+"/"+cluster_type+"/centers_"+str(ncenters)+"/"
    for k in range(ncenters):
        if not os.path.exists(directory+str(k)):
            os.makedirs(directory+str(k))

    # Directory finished
    print("Preparing to save images")
    
    my_clustering_data = [cntr, u, u0, d, jm, p, fpc]
    pickle.dump(my_clustering_data, open(directory+"/my_clustering_data.pkl","wb"))

    oracle = Oracle()
    
    for i in range(len(cluster_membership)):

        datum_orig = clustering_data[0][i]
        datum_attention = clustering_data[1][i]
        cluster_folder = cluster_membership[i]
        dir_save = directory+str(cluster_folder)+"/"
        oracle.visualize_tensors([datum_orig.unsqueeze(0)],file=dir_save+str(i).zfill(3), normalize=True)
        oracle.visualize_tensors([datum_attention.unsqueeze(0)],file=dir_save+str(i).zfill(3)+str("_att"))



Using directory : /scratch/Jack/research lab/NFL Playtype VAEs/Convolutional VAE results/Zalamity Joint F Convolutional
Using directory : /scratch/Jack/research lab/NFL Playtype VAEs/Conditional VAE results/Zalamity Joint local Conditional
0
0
-2
-1
0
1
2
1
-2
-1
0
1
2
2
-2
-1
0
1
2
3
-2
-1
0
1
2
4
-2
-1
0
1
2
5
-2
-1
0
1
2
6
-2
-1
0
1
2
7
-2
-1
0
1
2
8
-2
-1
0
1
2
9
-2
-1
0
1
2
10
-2
-1
0
1
2
11
-2
-1
0
1
2
12
-2
-1
0
1
2
13
-2
-1
0
1
2
14
-2
-1
0
1
2
15
-2
-1
0
1
2
16
-2
-1
0
1
2
17
-2
-1
0
1
2
18
-2
-1
0
1
2
19
-2
-1
0
1
2
20
-2
-1
0
1
2
21
-2
-1
0
1
2
22
-2
-1
0
1
2
23
-2
-1
0
1
2
24
-2
-1
0
1
2
25
-2
-1
0
1
2
26
-2
-1
0
1
2
27
-2
-1
0
1
2
28
-2
-1
0
1
2
29
-2
-1
0
1
2
30
-2
-1
0
1
2
31
-2
-1
0
1
2
32
-2
-1
0
1
2
33
-2
-1
0
1
2
34
-2
-1
0
1
2
35
-2
-1
0
1
2
36
-2
-1
0
1
2
37
-2
-1
0
1
2
38
-2
-1
0
1
2
39
-2
-1
0
1
2
40
-2
-1
0
1
2
41
-2
-1
0
1
2
42
-2
-1
0
1
2
43
-2
-1
0
1
2
44
-2
-1
0
1
2
45
-2
-1
0
1
2
46
-2
-1
0
1
2
47
-2
-1
0
1
2
48
-2
-1
0
1
2
49
-2
-1
0
1
2
50
-2
-1
0
1
2
51


2
14
-2
-1
0
1
2
15
-2
-1
0
1
2
16
-2
-1
0
1
2
17
-2
-1
0
1
2
18
-2
-1
0
1
2
19
-2
-1
0
1
2
20
-2
-1
0
1
2
21
-2
-1
0
1
2
22
-2
-1
0
1
2
23
-2
-1
0
1
2
24
-2
-1
0
1
2
25
-2
-1
0
1
2
26
-2
-1
0
1
2
27
-2
-1
0
1
2
28
-2
-1
0
1
2
29
-2
-1
0
1
2
30
-2
-1
0
1
2
31
-2
-1
0
1
2
32
-2
-1
0
1
2
33
-2
-1
0
1
2
34
-2
-1
0
1
2
35
-2
-1
0
1
2
36
-2
-1
0
1
2
37
-2
-1
0
1
2
38
-2
-1
0
1
2
39
-2
-1
0
1
2
40
-2
-1
0
1
2
41
-2
-1
0
1
2
42
-2
-1
0
1
2
43
-2
-1
0
1
2
44
-2
-1
0
1
2
45
-2
-1
0
1
2
46
-2
-1
0
1
2
47
-2
-1
0
1
2
48
-2
-1
0
1
2
49
-2
-1
0
1
2
50
-2
-1
0
1
2
51
-2
-1
0
1
2
52
-2
-1
0
1
2
53
-2
-1
0
1
2
54
-2
-1
0
1
2
55
-2
-1
0
1
2
56
-2
-1
0
1
2
57
-2
-1
0
1
2
58
-2
-1
0
1
2
59
-2
-1
0
1
2
60
-2
-1
0
1
2
61
-2
-1
0
1
2
62
-2
-1
0
1
2
63
-2
-1
0
1
2
64
-2
-1
0
1
2
65
-2
-1
0
1
2
66
-2
-1
0
1
2
67
-2
-1
0
1
2
68
-2
-1
0
1
2
69
-2
-1
0
1
2
70
-2
-1
0
1
2
71
-2
-1
0
1
2
72
-2
-1
0
1
2
73
-2
-1
0
1
2
74
-2
-1
0
1
2
75
-2
-1
0
1
2
76
-2
-1
0
1
2
77
-2
-1
0
1
2
78
-2
-1
0
1
2
79
-2
-1
0
1
2
80
-2
-1

2
43
-2
-1
0
1
2
44
-2
-1
0
1
2
45
-2
-1
0
1
2
46
-2
-1
0
1
2
47
-2
-1
0
1
2
48
-2
-1
0
1
2
49
-2
-1
0
1
2
50
-2
-1
0
1
2
51
-2
-1
0
1
2
52
-2
-1
0
1
2
53
-2
-1
0
1
2
54
-2
-1
0
1
2
55
-2
-1
0
1
2
56
-2
-1
0
1
2
57
-2
-1
0
1
2
58
-2
-1
0
1
2
59
-2
-1
0
1
2
60
-2
-1
0
1
2
61
-2
-1
0
1
2
62
-2
-1
0
1
2
63
-2
-1
0
1
2
64
-2
-1
0
1
2
65
-2
-1
0
1
2
66
-2
-1
0
1
2
67
-2
-1
0
1
2
68
-2
-1
0
1
2
69
-2
-1
0
1
2
70
-2
-1
0
1
2
71
-2
-1
0
1
2
72
-2
-1
0
1
2
73
-2
-1
0
1
2
74
-2
-1
0
1
2
75
-2
-1
0
1
2
76
-2
-1
0
1
2
77
-2
-1
0
1
2
78
-2
-1
0
1
2
79
-2
-1
0
1
2
80
-2
-1
0
1
2
81
-2
-1
0
1
2
82
-2
-1
0
1
2
83
-2
-1
0
1
2
84
-2
-1
0
1
2
85
-2
-1
0
1
2
86
-2
-1
0
1
2
87
-2
-1
0
1
2
88
-2
-1
0
1
2
89
-2
-1
0
1
2
90
-2
-1
0
1
2
91
-2
-1
0
1
2
92
-2
-1
0
1
2
93
-2
-1
0
1
2
94
-2
-1
0
1
2
95
-2
-1
0
1
2
96
-2
-1
0
1
2
97
-2
-1
0
1
2
98
-2
-1
0
1
2
99
-2
-1
0
1
2
100
-2
-1
0
1
2
101
-2
-1
0
1
2
102
-2
-1
0
1
2
103
-2
-1
0
1
2
104
-2
-1
0
1
2
105
-2
-1
0
1
2
106
-2
-1
0
1
2
107
-2
-1
0
1
2
108
-2
-1
0
1
2

In [50]:
load_models = True

if load_models:
    model_choice = 'F'
    model_version = "Zalamity Joint"
    #model_version = "Serenity Joint"
    model_type = "Convolutional"
    resource_path = "/scratch/Jack/resources/Attention Maps/playtype maps/"
    result_path =  "/scratch/Jack/research lab/NFL Playtype VAEs/"

    v = VAE_Adventures(log=False, whole_dataset=False, model_version=model_version, model_choice=model_choice, model_type=model_type, resource_path=resource_path, result_path=result_path)
    v.load_model(epoch=1000)
    v.VAE.cpu()
    v.VAE.cpu_mode_sampling()

    model_choice = 'local'
    model_type = "Conditional"

    cVAE = VAE_Adventures(log=False,  model_version=model_version, model_choice=model_choice, model_type=model_type, resource_path=resource_path, result_path=result_path)
    cVAE.load_model(epoch=120)
    cVAE.VAE.cpu()
    cVAE.VAE.cpu_mode_sampling()

    z_star_dim = 8


    v.VAE.inference = True
    cVAE.VAE.inference = True

    window_context_choice[model_type] = (2,True)
    
cluster_type = "temporal"
cluster_version = "play"
ncenters=30

directory = "/scratch/Jack/research lab/clusters_final/"+cluster_version+"/"+cluster_type+"/centers_"+str(ncenters)+"/"
    
my_clustering_data=pickle.load(open(directory+"/my_clustering_data.pkl","rb"))
first_seq = my_clustering_data[1]

item_list = []
for i in range(32):
    print(np.argmax(first_seq[:,i]))
    item_list.append(int(np.argmax(first_seq[:,i])))
print(item_list)
print(len(v.trainloader_prediction))
for i, datum in enumerate(v.trainloader_prediction):
    print(i)
    dirr = "/scratch/Jack/VISAVIS/"
    
    oracle.visualize_tensors([datum[0][:]],file=dirr+"/image_seq2"+str(i)+"_"+str(j).zfill(2), normalize=True) 
                        
        
    break

Using directory : /scratch/Jack/research lab/NFL Playtype VAEs/Convolutional VAE results/Zalamity Joint F Convolutional
Using directory : /scratch/Jack/research lab/NFL Playtype VAEs/Conditional VAE results/Zalamity Joint local Conditional
11
18
18
18
18
14
14
14
14
14
14
0
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
12
[11, 18, 18, 18, 18, 14, 14, 14, 14, 14, 14, 0, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12]
48
0


In [None]:
# membership for a sample sequence over the spatiotemporal codes
# load from my_clustering_data.pkl for memberships over inference video

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage as ndi
from bokeh.plotting import figure, show, output_file
from skimage.morphology import watershed
from skimage.feature import peak_local_max
import torch
from skimage.filters import sobel
from skimage.morphology import watershed

import random
from skimage.feature import canny
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage as ndi

from skimage.morphology import watershed
from skimage.feature import peak_local_max
import sys

while True:
    i = random.randint(0,108)
    if cluster_membership[i] != 1:
        continue
    else:
        datum = [clustering_data[0][i], clustering_data[1][i]]
    
    attention_image = datum[1][0].numpy()
    original_image = datum[0]
    
    markers = np.zeros_like(attention_image)
    thres = 0.0001
    markers[attention_image < thres] = 1
    markers[attention_image >= thres] = 2
    
    elevation_map = sobel(attention_image)
    
    segmentation = watershed(elevation_map, markers)
    segmentation = ndi.binary_fill_holes(segmentation - 1)
    
    segmentor, _ = ndi.label(segmentation)
    
    
    
    from bokeh.io import output_file, show
    from bokeh.layouts import column
    from bokeh.plotting import figure
    from bokeh.layouts import row
    orig = torch.clamp(original_image,0,1).numpy()
    orig = np.swapaxes(np.swapaxes(np.multiply(orig,255), 0, 2), 1,0).astype(np.uint8)[::-1]
    orig = np.dstack([orig, 255*np.ones(orig.shape[:2], np.uint8)])
    p = []
    
    p.append(figure(plot_width=124*2, plot_height=128*2, x_range=[0, 128], y_range=[128, 0],  x_axis_location="above",  title=""))
    p[0].image_rgba([orig], x=0, y=128, dw=128, dh=128)
    p.append(figure(plot_width=124*2, plot_height=128*2, x_range=[0, 128], y_range=[128, 0],  x_axis_location="above",  title=""))
    attention_image = attention_image[::-1]
    p[1].image([attention_image], x=0, y=128, dw=128, dh=128)
    segmentor = segmentor[::-1]
    p.append(figure(plot_width=124*2, plot_height=128*2, x_range=[0, 128], y_range=[128, 0],  x_axis_location="above",  title=""))
    p[2].image([segmentor], x=0, y=128, dw=128, dh=128)
    
    seg_threshold = 20
    p_attention_clusters = []
    p_attention_boxes = []
    thres_vis = 0.1
    for i in range(1,_+1):#_):
        seg = np.copy(segmentor)
        att = np.copy(attention_image)
        att[seg != i] = 0
        
        im = np.copy(attention_image)
        im[attention_image == 0] = 0
        im[att == 0] = 0
        
        x_ranges = [128, 0]
        y_ranges = [128, 0]
        for x in range(128):
            for y in range(128):
                if np.sum(im[y,x]) > 0:
                    if x > x_ranges[1]:
                        x_ranges[1] = x
                    if 128-y > y_ranges[1]:
                        y_ranges[1] = 128-y
                    if x < x_ranges[0]:
                        x_ranges[0] = x
                    if 128-y < y_ranges[0]:
                        y_ranges[0] = 128-y
        
        p_attention_clusters.append((figure(plot_width=128*2, plot_height=128*2, x_range=[0, 128], y_range=[128, 0],  x_axis_location="above",  title=""),(x_ranges, y_ranges)))
        dd = np.copy(orig)
        dd[im < thres_vis] = 0
        p_attention_clusters[i-1][0].image([im], x=0, y=128, dw=128, dh=128, alpha=0.2) 
        p_attention_clusters[i-1][0].image_rgba([dd], x=0, y=128, dw=128, dh=128, alpha=0.2) 
        
        p_attention_clusters[i-1][0].quad(x_ranges[0], x_ranges[1], y_ranges[0], y_ranges[1], fill_alpha=0.5)
        if int(np.sum(im)) < seg_threshold:
            p_attention_clusters[i-1] = None
            
    p_attention_cluster_figures = [x[0] for x in p_attention_clusters if x != None]
    p_attention_cluster_ranges = [x[1] for x in p_attention_clusters if x != None]
    
    
    orig_picture_final = figure(plot_width=128*2, plot_height=128*2, x_range=[0, 128], y_range=[128, 0],  x_axis_location="above",  title="")
    orig_picture_final.image_rgba([orig], x=0, y=128, dw=128, dh=128)
    for i in range(len(p_attention_cluster_ranges)):
        x_ranges, y_ranges = p_attention_cluster_ranges[i]
        orig_picture_final.quad(x_ranges[0], x_ranges[1], y_ranges[0], y_ranges[1],line_color="black",fill_color=(random.randint(0,255), random.randint(0,255), random.randint(0,255), 0.3))
        orig_picture_final.axis.visible = False
    

    orig_vanish = np.copy(orig)
    orig_vanish[attention_image < thres_vis] = 0
    orig_vanish[segmentor == 0] = 0
    
    orig_picture_final_extracted = figure(plot_width=128*2, plot_height=128*2, x_range=[0, 128], y_range=[128, 0],  x_axis_location="above",  title="")
    orig_picture_final_extracted.image_rgba([orig_vanish], x=0, y=128, dw=128, dh=128)
    for i in range(len(p_attention_cluster_ranges)):
        x_ranges, y_ranges = p_attention_cluster_ranges[i]
        orig_picture_final_extracted.quad(x_ranges[0], x_ranges[1], y_ranges[0], y_ranges[1],line_color="black",fill_color=(random.randint(0,255), random.randint(0,255), random.randint(0,255), 0.3))
    
    show(column(row(*p),row(*p_attention_cluster_figures),row([orig_picture_final]), row([orig_picture_final_extracted])))
    break

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage as ndi
from bokeh.plotting import figure, show, output_file
from skimage.morphology import watershed
from skimage.feature import peak_local_max
import torch
from skimage.filters import sobel
from skimage.morphology import watershed

import random
from skimage.feature import canny
import numpy as np
import matplotlib.pyplot as plt
from scipy import ndimage as ndi

from skimage.morphology import watershed
from skimage.feature import peak_local_max
import sys


for i in range(14+18,14+32, 3):
    datum = [clustering_data[0][i], clustering_data[1][i]]
    z_star = Variable(torch.from_numpy(np.random.normal(0, 2.0, 128)).float().unsqueeze(0))
    z, _,_,_ = v.VAE(Variable(clustering_data[1][i].unsqueeze(0)))
    z_out = cVAE.VAE.decode(z_star,z)
    print(z_out.size())
    map_out = v.VAE.decode(z_out).squeeze(0).squeeze(0)
    
    
    attention_image = datum[1][0].numpy()
    original_image = datum[0]
    
    # BRILLIANT
    attention_image = map_out.data.numpy()
    
    markers = np.zeros_like(attention_image)
    thres = 0.0001
    markers[attention_image < thres] = 1
    markers[attention_image >= thres] = 2
    
    elevation_map = sobel(attention_image)
    
    segmentation = watershed(elevation_map, markers)
    segmentation = ndi.binary_fill_holes(segmentation - 1)
    
    segmentor, _ = ndi.label(segmentation)
    
    
    
    from bokeh.io import output_file, show
    from bokeh.layouts import column
    from bokeh.plotting import figure
    from bokeh.layouts import row
    orig = torch.clamp(original_image,0,1).numpy()
    orig = np.swapaxes(np.swapaxes(np.multiply(orig,255), 0, 2), 1,0).astype(np.uint8)[::-1]
    orig = np.dstack([orig, 255*np.ones(orig.shape[:2], np.uint8)])
    p = []
    
    p.append(figure(plot_width=124*2, plot_height=128*2, x_range=[0, 128], y_range=[128, 0],  x_axis_location="above",  title=""))
    p[0].image_rgba([orig], x=0, y=128, dw=128, dh=128)
    p[0].axis.visible = False
    p[0].toolbar.logo = None
    #p[0].toolbar_location = None
    p.append(figure(plot_width=124*2, plot_height=128*2, x_range=[0, 128], y_range=[128, 0],  x_axis_location="above",  title=""))
    attention_image = attention_image[::-1]
    p[1].image([attention_image], x=0, y=128, dw=128, dh=128)
    p[1].axis.visible = False
    p[1].toolbar.logo = None
    #p[1].toolbar_location = None
    segmentor = segmentor[::-1]
    p.append(figure(plot_width=124*2, plot_height=128*2, x_range=[0, 128], y_range=[128, 0],  x_axis_location="above",  title=""))
    p[2].image([segmentor], x=0, y=128, dw=128, dh=128)
    
    seg_threshold = 20
    p_attention_clusters = []
    p_attention_boxes = []
    thres_vis = 0.1
    for i in range(1,_+1):#_):
        seg = np.copy(segmentor)
        att = np.copy(attention_image)
        att[seg != i] = 0
        
        im = np.copy(attention_image)
        im[attention_image == 0] = 0
        im[att == 0] = 0
        
        x_ranges = [128, 0]
        y_ranges = [128, 0]
        for x in range(128):
            for y in range(128):
                if np.sum(im[y,x]) > 0:
                    if x > x_ranges[1]:
                        x_ranges[1] = x
                    if 128-y > y_ranges[1]:
                        y_ranges[1] = 128-y
                    if x < x_ranges[0]:
                        x_ranges[0] = x
                    if 128-y < y_ranges[0]:
                        y_ranges[0] = 128-y
        
        p_attention_clusters.append((figure(plot_width=128*2, plot_height=128*2, x_range=[0, 128], y_range=[128, 0],  x_axis_location="above",  title=""),(x_ranges, y_ranges)))
        dd = np.copy(orig)
        dd[im < thres_vis] = 0
        p_attention_clusters[i-1][0].image([im], x=0, y=128, dw=128, dh=128, alpha=0.2) 
        p_attention_clusters[i-1][0].image_rgba([dd], x=0, y=128, dw=128, dh=128, alpha=0.2) 
        p_attention_clusters[i-1][0].axis.visible = False
        p_attention_clusters[i-1][0].toolbar.logo = None
        
        p_attention_clusters[i-1][0].quad(x_ranges[0], x_ranges[1], y_ranges[0], y_ranges[1], fill_alpha=0.5)
        if int(np.sum(im)) < seg_threshold:
            p_attention_clusters[i-1] = None
            
    p_attention_cluster_figures = [x[0] for x in p_attention_clusters if x != None]
    p_attention_cluster_ranges = [x[1] for x in p_attention_clusters if x != None]
    
    
    orig_picture_final = figure(plot_width=128*2, plot_height=128*2, x_range=[0, 128], y_range=[128, 0],  x_axis_location="above",  title="")
    orig_picture_final.image_rgba([orig], x=0, y=128, dw=128, dh=128)
    for i in range(len(p_attention_cluster_ranges)):
        x_ranges, y_ranges = p_attention_cluster_ranges[i]
        orig_picture_final.quad(x_ranges[0], x_ranges[1], y_ranges[0], y_ranges[1],line_color="black",fill_color=(random.randint(0,255), random.randint(0,255), random.randint(0,255), 0.3))
        orig_picture_final.axis.visible = False
        orig_picture_final.toolbar.logo = None
        #orig_picture_final.toolbar_location = None

    orig_vanish = np.copy(orig)
    orig_vanish[attention_image < thres_vis] = 0
    orig_vanish[segmentor == 0] = 0
    
    orig_picture_final_extracted = figure(plot_width=128*2, plot_height=128*2, x_range=[0, 128], y_range=[128, 0],  x_axis_location="above",  title="")
    orig_picture_final_extracted.image_rgba([orig_vanish], x=0, y=128, dw=128, dh=128)
    for i in range(len(p_attention_cluster_ranges)):
        x_ranges, y_ranges = p_attention_cluster_ranges[i]
        orig_picture_final_extracted.quad(x_ranges[0], x_ranges[1], y_ranges[0], y_ranges[1],line_color="black",fill_color=(random.randint(0,255), random.randint(0,255), random.randint(0,255), 0.3))
    
    show(column(row([p[0], p[1], orig_picture_final]),row(*p_attention_cluster_figures)))
    break

# Experimental

In [53]:

import numpy as np

from bokeh.io import curdoc
from bokeh.layouts import row, widgetbox
from bokeh.models.widgets import Slider, TextInput
from bokeh.plotting import figure
from bokeh.io import  push_notebook, show, output_notebook, output_file
from bokeh.layouts import column, layout, row, widgetbox
from bokeh.plotting import figure, curdoc
from bokeh.models import CustomJS
from bokeh.events import ButtonClick
from bokeh.models.widgets import Tabs, Panel, Slider, Button
from bokeh.io import show, output_notebook
from bokeh.application import Application
from bokeh.application.handlers import FunctionHandler, NotebookHandler
from bokeh.models.callbacks import CustomJS
from bokeh.models.sources import ColumnDataSource
output_notebook()

A1 = "0"
s = dict()
source = ColumnDataSource(data=s)

def update_contextual_sample(name, val):
    for k in range(5):
        if zs[k].title == name:
            zs[k].value = float(val)
    return

callback = CustomJS(code="""
    if (IPython.notebook.kernel !== undefined) {
    var kernel = IPython.notebook.kernel;
    var text = ''
    for (i = 0; i < 5; i++) { 
        text += cb_obj.title+"_";
        
    }
    console.log("'"+cb_obj.title+"','"+cb_obj.value+"'");
    cmd = "update_contextual_sample('"+cb_obj.title+"','"+cb_obj.value+"')";
    kernel.execute(cmd, {}, {});}
""")

zs = [Slider(start=-3.0, end=3.0, value=0.0, step=0.1, title="z{}".format(k), callback=callback, callback_policy="mouseup") for k in range(5)]
    
ss = column(zs)
bokeh_handle = show(ss, notebook_handle=True)


# Old Code

In [None]:
class train_VAE():
    
    def load_model(self,epoch=1000):
        
        directory_models =self.directory+"/models/"
        self.VAE.load_state_dict(torch.load(directory_models+"model_"+str(epoch)+".model"))
        self.inference = True
        
    def __init__(self, log=False, model_version="X", model_choice="X",  save_models=True, whole_dataset=False):

        self.sizes = sizes
        self.save_models = save_models
        self.batch_size = batch_size
        self.z_spatial_dim = z_spatial_dim
        self.linear_nodes = linear_nodes
        self.whole_dataset = whole_dataset
        
        self.datasplit = (8,10) # train on 8 out of 10 batches
        self.vis_size = 8
        
        self.trainset = CustomClusteringDataset(sizes=self.sizes, path="/scratch/datasets/NFLsegment/experiments/vpEB_image_dataset/all_images/",subset=1.0 if whole_dataset else 0.01,datasplit = self.datasplit, testing=False)
        self.trainloader = torch.utils.data.DataLoader(dataset=self.trainset, batch_size=batch_size, shuffle=True, num_workers=16)
        self.trainloader_prediction = torch.utils.data.DataLoader(dataset=self.trainset, batch_size=12, shuffle=False, num_workers=8)
                    
        '''
        self.trainset = CustomClusteringDataset(sizes=self.sizes, path="/scratch/datasets/NFLsegment/experiments/vpEB_image_dataset/all_images/", subset=1.0 if whole_dataset else 0.01)
        self.trainloader = torch.utils.data.DataLoader(dataset=self.trainset, batch_size=batch_size, shuffle=True, num_workers=16)
        self.trainloader_prediction = torch.utils.data.DataLoader(dataset=self.trainset, batch_size=12, shuffle=False, num_workers=8)
        '''
        
        historian.log("Loaded dataset with {} sequences of attention maps".format(len(self.trainloader)))

        self.VAE = VAE(sizes=self.sizes,z_spatial_dim=self.z_spatial_dim, linear_nodes=self.linear_nodes).cuda()
        self.VAE.inference = False
        
        self.oracle = Oracle()
        self.epochs = 1000

    def run(self):
        lr_mod = 1.0
        lr_change = 1.0 #0.99#0.995
        lr_base = 1e-3
        wd = 0
        epoch_save_all = 100
        starting_epoch = 1 # should be 1 to train a model from scratch
        
        samples_vary_creation = False
        
        if starting_epoch != 1:
            directory = str(self.z_spatial_dim)+"_"+str(self.linear_nodes)
            
            if self.whole_dataset:
                directory += "_whole_dataset"
            directory_models = "../Research Results/Vanilla VAE results/"+directory+"/models/"
            self.VAE.load_state_dict(torch.load(directory_models+"model_"+str(starting_epoch-1)+".model"))
            
        for epoch in range(starting_epoch,self.epochs+1):

            self.optimizer = optim.Adam(self.VAE.parameters(), weight_decay=wd, lr=lr_base*lr_mod)
            lr_mod *= lr_change
            if epoch % epoch_save_all != 0:
                for i, datum in enumerate(self.trainloader):

                    orig_datum = Variable(datum[0])
                    in_datum = Variable(datum[1].cuda())

                    Z, z_mu, z_var, out_datum = self.VAE(in_datum)

                    loss = self.loss_function(out_datum, in_datum,  z_mu, z_var)

                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()
                    if i == 0:
                        print(epoch, "/", self.epochs, "loss: ",loss.data.cpu().numpy()[0], ", lr=",lr_mod*lr_base)

            else:
                directory = str(self.z_spatial_dim)+"_"+str(self.linear_nodes)
                
                if self.whole_dataset:
                    directory += "_whole_dataset"
                directory_models = "../Research Results/Vanilla VAE results/"+directory+"/models/"
                directory_visuals = "../Research Results/Vanilla VAE results/"+directory+"/visuals/"
                directory_samples = "../Research Results/Vanilla VAE results/"+directory+"/samples/"
                directory_samples_varied = "../Research Results/Vanilla VAE results/"+directory+"/samples_varied/"
                
                pathlib.Path(directory_models).mkdir(parents=True, exist_ok=True) 
                pathlib.Path(directory_visuals).mkdir(parents=True, exist_ok=True)
                pathlib.Path(directory_samples).mkdir(parents=True, exist_ok=True)
                
                if samples_vary_creation:
                    pathlib.Path(directory_samples_varied).mkdir(parents=True, exist_ok=True)
                                    
                if self.save_models:
                    torch.save(self.VAE.state_dict(), directory_models+"model_"+str(epoch)+".model")
                    self.VAE.load_state_dict(torch.load(directory_models+"model_"+str(epoch)+".model"))
                for i, datum in enumerate(self.trainloader_prediction):
                    
                    orig_datum = Variable(datum[0])
                    in_datum = Variable(datum[1].cuda())

                    Z, z_mu, z_var, out_datum = self.VAE(in_datum)
                    
                    if self.batch_size != 1:
                        self.oracle.visualize_tensors([orig_datum.squeeze(0).data.cpu(), in_datum.squeeze(0).data.cpu(), out_datum.data.cpu()],file=directory_visuals+"/viewpoint_result_e"+str(epoch)+"_b"+str(i).zfill(2)) 
                    else:
                        self.oracle.visualize_tensors([orig_datum.data.cpu(), in_datum.data.cpu(), out_datum.data.cpu()],file=directory_visuals+"/viewpoint_result_e"+str(epoch)+"_b"+str(i).zfill(2))
                self.VAE.inference = True
                for i in range(6):
                    sample = Variable(torch.randn(12,self.z_spatial_dim)).cuda()
                    sample = self.VAE.decode(sample)
                    self.oracle.visualize_tensors([sample.data.cpu()],file=directory_samples+"/viewpoint_result_e"+str(epoch)+"_b"+str(i).zfill(2))
                
                if samples_vary_creation:
                    sample = Variable(torch.randn(12,self.z_spatial_dim)).cuda()

                    for i in range(1+int(self.z_spatial_dim/8)):
                       output = self.VAE.decode(sample)
                       for j in range(12):
                           sample[j,i] = sample[j,i] + random.randint(1,30)*1.0 
                       self.oracle.visualize_tensors([output.data.cpu()],file=directory_samples_varied+"/viewpoint_result_e"+str(epoch)+"_b"+str(i).zfill(2))
                
                self.VAE.inference = False

    def loss_function(self, recon_x, x, mu, logvar):
        BCE = F.binary_cross_entropy(recon_x.view(-1,self.sizes[0]*self.sizes[1]), x.view(-1, self.sizes[0]*self.sizes[1]))

        # see Appendix B from VAE paper:
        # Kingma and Welling. Auto-Encoding Variational Bayes. ICLR, 2014
        # https://arxiv.org/abs/1312.6114
        # 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        # Normalise by same number of elements as in reconstruction
        KLD /= x.size()[0] * self.sizes[0]*self.sizes[1]
        
        historian.log("BCE: {}, KLD: {}".format(BCE.data.cpu().numpy()[0], KLD.data.cpu().numpy()[0]))

        return BCE + KLD
    
### Which model to visualize, sets include {[4,64,512]} and epochs in [100,200,...,1000]

model_choice = 'A'
model_version = "Gilgamesh"
model_type = "Vanilla"
log = False
z_sub = 12 # how many z dimensions are interactive

### ------------------------------------------------------------------------------------
z_sub = min(4,12)

v = VAE_Adventures(log=log, model_version=model_version,model_choice=model_choice, model_type=model_type)
v.load_model()

def visualize_z_sample(**z):
    sample_z = torch.zeros([1,4])
    for i, (arg, val) in enumerate(z.items()):
        if arg != "wimage":
            sample_z[0,i] = val

    v.inference = True
    sample_var = Variable(sample_z).cuda()
    sample = v.VAE.decode(sample_var)
    
    to_pil = transforms.ToPILImage()
    sample_cpu = sample.data.cpu()
    oracle.visualize_tensors([sample_cpu],file="./temp_image")
    wimage.value = open("./temp_image.png", "rb").read()
    display(wimage)
    

def widget_slider(name):
    mystyle = {'handle_color': 'white'}
    w = widgets.FloatSlider(
        value=0,
        min=-3.0,
        max=3.0,
        step=0.1,
        disabled=False,
        continuous_update=False,
        orientation='vertical',
        readout=True,
        style=mystyle)
    
    return w

z_interactive = {'z-' + str(k): widget_slider('z-' + str(k)) for k in range(z_sub)}

wimage =widgets.Image(
    format='png',
    width=512,
    height=512,
)

hlay = Layout(display='flex',
                    flex_flow='row',
                    align_items='stretch',
                    border='solid',
                    width='50%')

w = interactive(visualize_z_sample, **z_interactive)
w.layout = hlay

### This acts as a predictor for the z-mu values that produce every frame

# for i, datum in enumerate(v.trainloader_prediction):
#     orig_datum = Variable(datum[0])
#     in_datum = Variable(datum[1].cuda())
#     v.inference = True
#     Z, z_mu, z_var, out_datum = v.VAE(in_datum)
#     print(Z)

#   Validation set

# validate = True
#
# if validate:
#     for i, datum in enumerate(v.testloader_prediction):
#
#         orig_datum = Variable(datum[0])
#         in_datum = Variable(datum[1].cuda())
#         v.inference = True
#         Z, z_mu, z_var, out_datum = v.VAE(in_datum)
#         loss = v.loss_function(out_datum, in_datum,  z_mu, z_var)
#         print(Z, loss)
        
display(w)