In [1]:
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.optim import Adam
import pyro.distributions as dist
import pyro.poutine as poutine

from RAVDESS_dataset_util import *
from EmoClassCNN import *

torch.set_default_dtype(torch.float64)

In [2]:
folder_path = '/home/studenti/ballerini/datasets/RAVDESS_frames'

In [3]:
NUM_CLASSES = len(emocat)
IMG_SIZE = 128
BATCH_SIZE = 8
DEFAULT_Z_DIM = 50

face_dataset = FaceEmotionDataset(root_dir=folder_path,
                                    transform=transforms.Compose([
                                        Rescale(IMG_SIZE), 
                                        CenterCrop(IMG_SIZE), 
                                        ToTensor()
                                    ]))        

trainingset_len = len(face_dataset) // 100 * 10
testset_len = len(face_dataset) - trainingset_len

print('training set size: ', trainingset_len)
print('test set size: ', testset_len)

train_set, test_set = torch.utils.data.random_split(face_dataset, 
                                                    [trainingset_len, testset_len], 
                                                    generator=torch.Generator().manual_seed(42)
                                                   )

trainset_loader = DataLoader(train_set, batch_size=BATCH_SIZE,
                        shuffle=True, num_workers=4)

testset_loader = DataLoader(test_set, batch_size=BATCH_SIZE,
                        shuffle=True, num_workers=4)

dataset_loader = (trainset_loader, testset_loader)

training set size:  720
test set size:  6480


In [4]:
def emotion_rating_conversion(cat):
    ratings = torch.zeros(NUM_CLASSES)
    ratings[cat] = 1
    return ratings
    
#torch.argmax(emotion_rating_conversion(3))

In [5]:
# helper functions
class Swish(nn.Module):
    """https://arxiv.org/abs/1710.05941"""
    def forward(self, x):
        return x * torch.sigmoid(x)

def swish(x):
    return x * torch.sigmoid(x)

In [6]:
class ProductOfExperts(nn.Module):
    """
    Return parameters for product of independent experts.
    See https://arxiv.org/pdf/1410.7827.pdf for equations.

    @param loc: M x D for M experts
    @param scale: M x D for M experts
    """
    def forward(self, loc, scale, eps=1e-8):
        scale = scale + eps # numerical constant for stability
        # precision of i-th Gaussian expert (T = 1/sigma^2)
        T = 1. / scale
        product_loc = torch.sum(loc * T, dim=0) / torch.sum(T, dim=0)
        product_scale = 1. / torch.sum(T, dim=0)
        return product_loc, product_scale

In [7]:
class ImageEncoder(nn.Module):
    """
    define the PyTorch module that parametrizes q(z|image).
    This goes from images to the latent z
    
    This is the standard DCGAN architecture.

    @param z_dim: integer
                  size of the tensor representing the latent random variable z
    """
    def __init__(self, z_dim, img_size):
        super(ImageEncoder, self).__init__()
        #torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, 
        #                padding=0, dilation=1, groups=1, bias=True)
        # H_out = floor( (H_in + 2*padding - dilation(kernel_size-1) -1) / stride    +1)
        self.img_size = img_size
        
        self.features = nn.Sequential(
            nn.Conv2d(3, 32, 3, 1, 1, bias=False),
            Swish(),
            
            nn.Conv2d(32, 64, 3, 1, 1, bias=False),
            Swish(),
            nn.BatchNorm2d(64),
            nn.MaxPool2d(2,2),
            
            nn.Conv2d(64, 128, 3, 1, 1, bias=False),
            Swish(),
            nn.BatchNorm2d(128),
            
            nn.Conv2d(128, 256, 3, 1, 1, bias=False),
            Swish(),
            nn.BatchNorm2d(256),
            nn.MaxPool2d(2,2)
        )
        
        # Here, we define two layers, one to give z_loc and one to give z_scale
        self.z_loc_layer = nn.Sequential(
            nn.Linear(256 * (self.img_size // 4)**2, 512),
            Swish(),
            nn.Dropout(p=0.1),
            nn.Linear(512, z_dim))
        
        self.z_scale_layer = nn.Sequential(
            nn.Linear(256 * (self.img_size // 4)**2, 512),
            Swish(),
            nn.Dropout(p=0.1),
            nn.Linear(512, z_dim))
        self.z_dim = z_dim

    def forward(self, image):
        hidden = self.features(image)
        hidden = hidden.view(-1, 256 * (self.img_size // 4)**2)
        z_loc = self.z_loc_layer(hidden)
        z_scale = torch.exp(self.z_scale_layer(hidden)) #add exp so it's always positive
        return z_loc, z_scale
    
class ImageDecoder(nn.Module):
    """
    define the PyTorch module that parametrizes p(image|z).
    This goes from the latent z to the images
    
    This is the standard DCGAN architecture.

    @param z_dim: integer
                  size of the tensor representing the latent random variable z
    """
    def __init__(self, z_dim, img_size):
        super(ImageDecoder, self).__init__()
        self.img_size = img_size
        
        self.upsample = nn.Sequential(
            nn.Linear(z_dim, 256 * (self.img_size**2)),
            Swish())
        
        self.hallucinate = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 3, 1, 1, bias=False),
            nn.BatchNorm2d(128),
            Swish(),
            nn.ConvTranspose2d(128, 64, 3, 1, 1, bias=False),
            nn.BatchNorm2d(64),
            Swish(),
            nn.ConvTranspose2d(64, 32, 3, 1, 1, bias=False),
            nn.BatchNorm2d(32),
            Swish(),
            nn.ConvTranspose2d(32, 3, 3, 1, 1, bias=False))

    def forward(self, z):
        # the input will be a vector of size |z_dim|
        z = self.upsample(z)
        z = z.view(-1, 256, self.img_size, self.img_size)
        
        # but if 100x100, the output image size is 96x96
        image = self.hallucinate(z) # this is the image
        return image  # NOTE: no sigmoid here. See train.py

In [8]:
class EmotionEncoder(nn.Module):
    """
    define the PyTorch module that parametrizes q(z|emotion category).
    This goes from ratings to the latent z

    @param z_dim: integer
                  size of the tensor representing the latent random variable z
    """
    def __init__(self, z_dim):
        super(EmotionEncoder, self).__init__()
        self.net = nn.Linear(NUM_CLASSES, 512)
        
        self.z_loc_layer = nn.Sequential(
            nn.Linear(512, 512),
            Swish(),
            nn.Linear(512, z_dim))
        
        self.z_scale_layer = nn.Sequential(
            nn.Linear(512, 512),
            Swish(),
            nn.Linear(512, z_dim))
        self.z_dim = z_dim

    def forward(self, emocat):
        hidden = self.net(emocat)
        z_loc = self.z_loc_layer(hidden)
        z_scale = torch.exp(self.z_scale_layer(hidden))
        return z_loc, z_scale


class EmotionDecoder(nn.Module):
    """
    define the PyTorch module that parametrizes p(emotion category|z).
    This goes from the latent z to the ratings

    @param z_dim: integer
                  size of the tensor representing the latent random variable z
    """
    def __init__(self, z_dim):
        super(EmotionDecoder, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(z_dim, 512),
            Swish())
        
        self.emotion_loc_layer = nn.Sequential(
            nn.Linear(512, 512),
            Swish(),
            nn.Linear(512, len(emocat)))
        
        self.emotion_scale_layer = nn.Sequential(
            nn.Linear(512, 512),
            Swish(),
            nn.Linear(512, NUM_CLASSES))

    def forward(self, z):
        #batch_size = z.size(0)
        hidden = self.net(z)
        emotion_loc = self.emotion_loc_layer(hidden)
        emotion_scale = torch.exp(self.emotion_scale_layer(hidden))
        # rating is going to be a |emotions| * 9 levels
        #rating = h.view(batch_size, EMOTION_VAR_DIM, 9)
        return emotion_loc, emotion_scale  # NOTE: no softmax here. See train.py

In [9]:
class MVAE(nn.Module):
    """
    This class encapsulates the parameters (neural networks), models & guides needed to train a
    multimodal variational auto-encoder.
    Modified from https://github.com/mhw32/multimodal-vae-public
    Multimodal Variational Autoencoder.

    @param z_dim: integer
                  size of the tensor representing the latent random variable z
                  
    Currently all the neural network dimensions are hard-coded; 
    in a future version will make them be inputs into the constructor
    """
    def __init__(self, z_dim, img_size=128, use_cuda=True):
        super(MVAE, self).__init__()
        self.z_dim = z_dim
        self.img_size = img_size
        self.experts = ProductOfExperts()
        self.image_encoder = ImageEncoder(z_dim, img_size)
        self.image_decoder = ImageDecoder(z_dim, img_size)
        self.emotion_encoder = EmotionEncoder(z_dim)
        self.emotion_decoder =EmotionDecoder(z_dim)
        
        self.use_cuda = use_cuda
        # relative weights of losses in the different modalities
        self.LAMBDA_IMAGES = 1.0
        self.LAMBDA_RATINGS = 50.0
        self.LAMBDA_OUTCOMES = 100.0
        
        # using GPUs for faster training of the networks
        if self.use_cuda:
            self.cuda()
            
    def model(self, images=None, emotions=None, annealing_beta=1.0):
        # register this pytorch module and all of its sub-modules with pyro
        pyro.module("mvae", self)
        
        batch_size = 0
        if images is not None:
            batch_size = images.size(0)
        elif emotions is not None:
            batch_size = emotions.size(0)
        
        with pyro.plate("data"):      
            
            # sample the latent z from the (constant) prior, z ~ Normal(0,I)
            z_prior_loc  = torch.zeros(size=[BATCH_SIZE, self.z_dim])
            z_prior_scale = torch.exp(torch.zeros(size=[BATCH_SIZE, self.z_dim]))                
            
            # sample from prior (value will be sampled by guide when computing the ELBO)
            with poutine.scale(scale=annealing_beta):
                z = pyro.sample("z", dist.Normal(z_prior_loc, z_prior_scale))

            # decode the latent code z (image decoder)
            img_loc = self.image_decoder.forward(z)
            
            # score against actual images
            if images is not None:
                with poutine.scale(scale=self.LAMBDA_IMAGES):
                    print('image loc: ', img_loc.shape)
                    print('image shape: ', images.shape)
                    pyro.sample("obs_img", dist.Bernoulli(img_loc), obs=images)
            
            # decode the latent code z (emotion decoder)
            emotion_loc, emotion_scale = self.emotion_decoder.forward(z)
            if emotions is not None:
                with poutine.scale(scale=self.LAMBDA_RATINGS):
                    pyro.sample("obs_emotion", 
                                dist.Normal(emotion_loc, emotion_scale), 
                                obs=emotions)

            # return the loc so we can visualize it later
            return img_loc, emotion_loc
        
    def guide(self, images=None, emotions=None, annealing_beta=1.0):
        # register this pytorch module and all of its sub-modules with pyro
        pyro.module("mvae", self)
        
        batch_size = 0
        if images is not None:
            batch_size = images.size(0)
        elif emotions is not None:
            batch_size = emotions.size(0)
            
        with pyro.plate("data"):
            # use the encoder to get the parameters used to define q(z|x)
                        
            # initialize the prior expert.
            # we initalize an additional dimension, along which we concatenate all the 
            #   different experts.
            # self.experts() then combines the information from these different modalities
            #   by multiplying the gaussians together
            
            z_loc = torch.zeros(torch.Size((1, batch_size, self.z_dim))) + 0.5
            z_scale = torch.ones(torch.Size((1, batch_size, self.z_dim))) * 0.1
            if self.use_cuda:
                z_loc, z_scale = z_loc.cuda(), z_scale.cuda()
                
            if images is not None:
                image_z_loc, image_z_scale = self.image_encoder.forward(images)
                z_loc = torch.cat((z_loc, image_z_loc.unsqueeze(0)), dim=0)
                z_scale = torch.cat((z_scale, image_z_scale.unsqueeze(0)), dim=0)
            
            if emotions is not None:
                emotion_z_loc, emotion_z_scale = self.emotion_encoder.forward(emotions)
                z_loc = torch.cat((z_loc, emotion_z_loc.unsqueeze(0)), dim=0)
                z_scale = torch.cat((z_scale, emotion_z_scale.unsqueeze(0)), dim=0)
            
            z_loc, z_scale = self.experts(z_loc, z_scale)
            # sample the latent z
            with poutine.scale(scale=annealing_beta):
                pyro.sample("latent", dist.Normal(z_loc, z_scale))
                
                
    def forward(self, image=None, emotion=None):
        z_loc, z_scale  = self.infer(image, emotion)
        z = pyro.sample("latent", dist.Normal(z_loc, z_scale).independent(1))
        # reconstruct inputs based on that gaussian
        image_recon = self.image_decoder(z)
        rating_recon = self.emotion_decoder(z)
        return image_recon, rating_recon, z_loc, z_scale
    
    
    def infer(self, images=None, emotions=None):
        batch_size = 0
        if images is not None:
            batch_size = images.size(0)
        elif emotions is not None:
            batch_size = emotions.size(0)
            
        # initialize the prior expert
        # we initalize an additional dimension, along which we concatenate all the 
        #   different experts.
        # self.experts() then combines the information from these different modalities
        #   by multiplying the gaussians together
        z_loc = torch.zeros(torch.Size((1, BATCH_SIZE, self.z_dim))) + 0.5
        z_scale = torch.ones(torch.Size((1, BATCH_SIXE, self.z_dim))) * 0.1
        if self.use_cuda:
            z_loc, z_scale = z_loc.cuda(), z_scale.cuda()

        if images is not None:
            image_z_loc, image_z_scale = self.image_encoder.forward(images)
            z_loc = torch.cat((z_loc, image_z_loc.unsqueeze(0)), dim=0)
            z_scale = torch.cat((z_scale, image_z_scale.unsqueeze(0)), dim=0)

        if emotions is not None:
            emotion_z_loc, emotion_z_scale = self.emotion_encoder.forward(emotions)
            z_loc = torch.cat((z_loc, emotion_z_loc.unsqueeze(0)), dim=0)
            z_scale = torch.cat((z_scale, emotion_z_scale.unsqueeze(0)), dim=0)

        z_loc, z_scale = self.experts(z_loc, z_scale)
        return z_loc, z_scale

    
    # define a helper function for reconstructing images
    def reconstruct_img(self, images):
        # encode image x
        z_loc, z_scale = self.image_encoder(images)
        # sample in latent space
        z = dist.Normal(z_loc, z_scale).sample()
        # decode the image (note we don't sample in image space)
        img_loc = self.image_decoder.forward(z)
        return img_loc

    
    # define a helper function for reconstructing images without sampling
    def reconstruct_img_nosample(self, images):
        # encode image x
        z_loc, z_scale = self.image_encoder(images)
        ## sample in latent space
        #z = dist.Normal(z_loc, z_scale).sample()
        # decode the image (note we don't sample in image space)
        img_loc = self.image_decoder.forward(z_loc)
        return img_loc

In [10]:
pyro.clear_param_store()

class Args:
    learning_rate = 5e-6
    num_epochs = 2 #500
    z_dim = DEFAULT_Z_DIM
    img_size = IMG_SIZE
    seed = 30
    cuda = False
    
args = Args()

# setup the VAE
mvae = MVAE(z_dim=args.z_dim, img_size=args.img_size, use_cuda=args.cuda)

# setup the optimizer
adam_args = {"lr": args.learning_rate}
optimizer = Adam(adam_args)

# setup the inference algorithm
svi = SVI(mvae.model, mvae.guide, optimizer, loss=Trace_ELBO())

In [11]:
'''sample = next(iter(testset_loader))
images = sample['image']
emotions = torch.stack([emotion_rating_conversion(emo) for emo in sample['cat']])
pyro.render_model(mvae.model, model_args=(images, emotions))'''

"sample = next(iter(testset_loader))\nimages = sample['image']\nemotions = torch.stack([emotion_rating_conversion(emo) for emo in sample['cat']])\npyro.render_model(mvae.model, model_args=(images, emotions))"

In [12]:
import time
from tqdm import tqdm

train_elbo = []
trainingTimes = [time.time()]
# training loop
for epoch in range(args.num_epochs):
    # initialize loss accumulator
    epoch_loss = 0.
    # do a training epoch over each mini-batch returned
    # by the data loader
    for batch_num, sample in tqdm(enumerate(trainset_loader)):
        faces, emotions = sample['image'], sample['cat']
        
        emotions = torch.stack([emotion_rating_conversion(emo) for emo in emotions])
        
        # if on GPU put mini-batch into CUDA memory
        if args.cuda:
            faces, ratings, outcomes = faces.cuda(), ratings.cuda(), outcomes.cuda()
        
        # do ELBO gradient and accumulate loss
        #print("Batch: ", batch_num, "out of", len(train_loader))
        epoch_loss += svi.step(images=faces, emotions=emotions)
        epoch_loss += svi.step(images=faces, emotions=None)
        epoch_loss += svi.step(images=None, emotions=emotions)
        epoch_loss += svi.step(images=None, emotions=None)

    # report training diagnostics
    normalizer_train = len(trainset_loader)
    total_epoch_loss_train = epoch_loss / normalizer_train
    train_elbo.append(total_epoch_loss_train)
    
    # report training diagnostics
    trainingTimes.append(time.time())
    epoch_time = trainingTimes[-1] - trainingTimes[-2]
    print("[epoch %03d]  time: %.2f, average training loss: %.4f" % (epoch, epoch_time, total_epoch_loss_train))
    #if ((epoch+1) % 50 == 0):
        #pyro.get_param_store().save('trained_models/checkpoints/tutorial_mvae_pretrained_' + str(epoch) + '.save')
        

0it [00:08, ?it/s]

image loc:  torch.Size([8, 3, 128, 128])
image shape:  torch.Size([8, 3, 128, 128])





ValueError: Expected parameter probs (Tensor of shape (8, 3, 128, 128)) of distribution Bernoulli(probs: torch.Size([8, 3, 128, 128])) to satisfy the constraint Interval(lower_bound=0.0, upper_bound=1.0), but found invalid values:
tensor([[[[-3.9017e-02, -7.2188e-01,  2.9513e-01,  ...,  7.0698e-01,
           -4.0400e-01,  1.1426e+00],
          [ 9.4293e-01, -4.2566e-01, -1.4508e+00,  ...,  4.9947e-01,
           -2.5893e-02,  2.5175e-01],
          [-2.8492e-02,  8.6708e-01, -8.8784e-01,  ..., -3.9668e-01,
           -3.1076e-01, -7.0688e-01],
          ...,
          [ 1.3940e+00, -3.6874e-01, -4.3644e-01,  ...,  5.4321e-01,
            1.7384e-01, -2.0410e-01],
          [-1.0094e+00,  5.3915e-01, -5.9737e-02,  ...,  1.2016e+00,
            3.0995e-01, -2.8569e-01],
          [ 4.6490e-01, -7.2065e-01, -1.7153e+00,  ...,  7.9064e-02,
            3.3914e-01,  2.3666e-01]],

         [[-2.7222e-01,  2.8789e-01,  2.0775e+00,  ...,  2.8436e-01,
            3.6272e-01,  2.2622e-01],
          [ 8.0380e-01, -1.9356e-01,  1.0928e+00,  ...,  2.8926e-01,
            4.8483e-01, -5.4913e-02],
          [-4.5173e-01,  1.3452e-01,  1.2322e+00,  ...,  2.3791e+00,
           -8.5602e-02,  8.2655e-02],
          ...,
          [ 8.3299e-01, -2.5927e-01, -1.6366e+00,  ..., -8.3680e-02,
           -4.8813e-01,  1.5064e+00],
          [ 4.6731e-01, -2.6187e-01, -4.0458e-01,  ...,  1.1724e-01,
           -1.0972e+00,  1.8863e-02],
          [ 4.0596e-01, -2.4951e-01, -1.5650e+00,  ..., -4.2743e-01,
           -5.7526e-02, -9.2211e-01]],

         [[ 9.8827e-02,  3.6070e-01, -7.6037e-01,  ..., -7.8843e-01,
            1.2897e+00,  1.6489e-01],
          [ 3.5030e-01,  6.3007e-01,  3.2396e-01,  ...,  1.2788e-01,
            6.0136e-02,  7.2336e-01],
          [-4.6332e-01,  4.1395e-01, -4.8604e-01,  ..., -5.4677e-02,
           -1.0098e+00,  4.1479e-01],
          ...,
          [-4.2682e-01,  8.7209e-01,  3.5469e-01,  ..., -5.6740e-01,
           -5.8024e-01, -1.1075e+00],
          [ 7.4761e-02,  1.1156e+00,  6.0063e-01,  ..., -4.3437e-01,
           -5.6222e-01,  7.8487e-02],
          [-9.6210e-02,  2.1367e+00,  2.4944e-01,  ...,  5.7961e-01,
           -8.0608e-01,  4.8292e-01]]],


        [[[ 2.5746e-02, -6.8964e-01,  3.0840e-01,  ...,  1.6439e-01,
            5.7820e-01,  8.0057e-01],
          [-1.6396e-01, -4.4783e-01,  7.2353e-01,  ...,  1.2401e-01,
            1.4810e+00, -3.7449e-01],
          [-3.9761e-01,  9.5291e-01, -5.1215e-02,  ..., -6.3679e-01,
            1.1281e+00,  8.2581e-01],
          ...,
          [-1.9110e-02,  7.2256e-01,  3.0355e-01,  ..., -3.8227e-01,
           -3.8176e-01, -5.8665e-01],
          [ 2.3904e-01,  3.2932e-01,  2.6372e-01,  ...,  5.6213e-01,
            8.7643e-01,  3.2533e-01],
          [ 7.6803e-01,  6.2389e-01, -3.6977e-01,  ..., -3.1811e-03,
            6.8925e-01, -1.6380e-01]],

         [[ 4.5422e-01, -1.2126e-01,  1.7723e-03,  ...,  1.4019e-01,
           -2.3005e-01,  8.7588e-02],
          [-1.1691e-01,  6.0980e-01,  6.2077e-01,  ...,  3.9957e-02,
            5.8375e-01,  5.4512e-01],
          [ 9.6048e-02,  3.3861e-01, -1.2552e-01,  ..., -8.0926e-01,
            1.1439e+00,  8.8580e-01],
          ...,
          [ 4.9670e-01, -5.3180e-01,  1.0598e+00,  ..., -4.0450e-01,
            9.5854e-01,  2.3169e-01],
          [-4.2756e-01,  1.3666e-01, -3.2273e-01,  ...,  4.3074e-01,
            5.4341e-01, -1.0114e-01],
          [ 5.4346e-02,  1.0331e+00, -1.1682e-01,  ..., -6.9097e-01,
           -2.8787e-02, -2.3072e-02]],

         [[-2.0442e-01,  2.5657e-01, -3.3531e-01,  ...,  1.2308e-01,
           -6.5697e-01,  2.2010e-02],
          [-8.1277e-01, -1.3558e+00, -2.1419e-01,  ..., -9.4393e-01,
           -8.5543e-01, -9.2960e-01],
          [ 4.6743e-01,  2.6072e-01,  2.9905e-01,  ..., -1.7642e-02,
           -2.4371e+00,  2.9656e-02],
          ...,
          [-6.4899e-01,  1.1591e+00,  1.5636e-01,  ..., -1.1162e+00,
           -6.4924e-01,  9.2095e-02],
          [ 3.9833e-01,  3.8702e-01, -1.1339e-01,  ..., -1.6524e-01,
           -1.2188e+00,  7.1567e-01],
          [-1.3767e-01, -1.6588e-01, -4.0546e-01,  ...,  1.0771e-01,
            7.1223e-01, -4.5778e-01]]],


        [[[ 9.9965e-01, -3.0301e-01,  8.8343e-02,  ..., -3.1927e-01,
            5.9005e-01,  1.8413e-01],
          [ 6.5223e-01, -4.7682e-01, -1.1774e+00,  ..., -3.5193e-02,
           -1.1047e+00, -1.2156e+00],
          [-9.2792e-01,  1.3876e-01,  1.2122e-01,  ...,  8.2273e-01,
           -3.3880e-01,  1.3593e-01],
          ...,
          [ 8.6760e-02,  1.7300e+00, -4.2076e-01,  ..., -2.6218e+00,
            1.0819e-01,  6.8241e-01],
          [ 1.0043e+00,  8.6387e-01,  1.7983e+00,  ..., -8.7246e-01,
            2.0311e+00,  1.0811e+00],
          [ 6.1507e-01, -2.8476e-02, -1.0656e+00,  ..., -5.1055e-01,
           -7.3141e-01, -5.9831e-01]],

         [[ 4.7994e-01, -7.9971e-02,  3.6339e-01,  ...,  1.3162e-02,
            1.8769e-01,  3.3493e-02],
          [-5.1917e-01,  4.4698e-01, -6.3489e-01,  ...,  5.7594e-01,
            2.7531e-01,  7.4175e-01],
          [-1.5960e-01, -1.2524e+00, -5.0536e-01,  ...,  7.3972e-01,
            1.3050e+00,  1.4190e+00],
          ...,
          [ 5.7906e-01, -3.1778e-01, -7.4224e-01,  ...,  7.5666e-01,
           -4.8276e-01,  1.7860e+00],
          [ 9.1892e-03,  2.0764e-01,  8.6349e-03,  ..., -1.0164e-01,
            8.2697e-01,  3.9829e-01],
          [ 1.7038e-01,  4.6167e-01, -2.0694e-01,  ..., -7.7880e-01,
            5.9491e-01,  3.4319e-01]],

         [[ 1.0492e-01, -1.8227e-01, -1.8404e-02,  ...,  6.7217e-01,
           -7.3851e-01, -2.3027e-01],
          [-1.1986e-01, -1.8996e-01,  1.9871e-01,  ...,  8.3738e-01,
           -9.8994e-01, -6.1400e-01],
          [-1.0485e+00,  8.9330e-01,  2.2865e-01,  ...,  3.4802e-01,
           -8.6168e-01,  1.1340e+00],
          ...,
          [-2.3495e-01,  1.9848e-01,  8.0362e-01,  ...,  1.1575e+00,
           -1.2087e+00, -2.8792e-01],
          [ 9.2640e-01,  1.1055e+00,  2.6067e-01,  ..., -1.1524e+00,
            1.1371e+00, -8.4982e-01],
          [ 6.1988e-01,  3.4533e-02,  3.1822e-01,  ...,  7.8782e-01,
           -6.0970e-01, -4.8516e-01]]],


        ...,


        [[[-1.4784e-01,  2.7171e-01, -3.8212e-01,  ...,  1.0399e-01,
           -2.5261e-01,  9.2444e-02],
          [ 1.7680e-01, -6.6754e-01, -7.7115e-02,  ...,  2.1930e+00,
            2.3621e-01,  1.1916e+00],
          [-2.6279e-02, -4.1784e-01, -8.3798e-01,  ...,  8.0678e-02,
            7.1473e-01,  2.9554e-01],
          ...,
          [ 9.0389e-01,  1.9255e-01, -1.5749e-01,  ..., -5.9989e-01,
            1.2586e+00, -6.7517e-01],
          [ 2.7228e-01,  1.6775e+00,  7.5060e-01,  ...,  1.6830e+00,
           -3.8423e-01,  1.4338e-01],
          [ 2.4646e-01, -3.8633e-01, -7.3344e-02,  ...,  1.1645e+00,
            4.7968e-01, -1.2715e-01]],

         [[ 4.3851e-02,  2.6447e-01,  1.1777e-01,  ..., -3.8761e-02,
           -2.0988e-01, -4.7920e-02],
          [ 3.9321e-01,  2.6125e-01,  1.8839e-01,  ...,  5.7798e-01,
            3.9192e-01,  6.7977e-01],
          [ 4.0889e-01,  3.9318e-01,  2.7429e-01,  ..., -1.9893e-01,
            1.2501e+00, -4.4598e-01],
          ...,
          [-7.0468e-01,  1.1383e+00, -9.6469e-01,  ...,  1.2142e+00,
           -5.5559e-01,  6.0136e-01],
          [-4.2415e-01,  7.2701e-01,  2.7348e-01,  ...,  9.7252e-01,
            5.3545e-02,  8.9845e-01],
          [-3.5599e-02, -9.9147e-01,  1.0962e-01,  ..., -4.2500e-01,
            5.9879e-03,  6.6595e-02]],

         [[ 1.3414e-01,  1.5417e-01,  8.2356e-01,  ..., -6.4451e-01,
           -8.2499e-01, -1.1627e-01],
          [ 4.8968e-01,  6.5762e-01,  1.0189e-01,  ..., -5.5968e-01,
           -1.2620e+00, -5.7268e-01],
          [-5.3173e-01, -4.4683e-02,  1.6379e-01,  ..., -1.4492e+00,
           -3.4923e-01,  1.3164e-01],
          ...,
          [-3.4171e-01, -6.5994e-01, -2.3982e-01,  ...,  8.3372e-01,
           -1.5697e-02, -1.0999e+00],
          [-8.5438e-01,  7.9687e-01,  1.7074e-01,  ...,  4.1932e-01,
            4.3002e-01, -5.2091e-01],
          [ 1.2794e-01,  7.3050e-01,  4.6818e-01,  ..., -2.1569e-01,
           -1.1805e-01, -6.0279e-02]]],


        [[[-2.6010e-01,  3.3723e-01, -2.1189e-01,  ...,  5.4050e-01,
            1.1406e+00, -2.3828e-02],
          [ 3.8747e-01, -3.9276e-01, -6.5514e-01,  ...,  4.9900e-02,
           -2.9840e-02,  1.8531e-02],
          [ 7.1674e-01,  2.7109e-01, -1.4508e+00,  ..., -2.6001e+00,
           -1.2457e-01, -3.4741e-01],
          ...,
          [ 6.7617e-01,  7.4334e-01,  8.9434e-01,  ...,  9.9241e-01,
            5.7569e-01,  3.0964e-01],
          [ 3.0716e-01,  9.5590e-01,  1.5069e+00,  ...,  1.3535e+00,
           -5.4177e-01, -3.1954e-01],
          [-5.0877e-01, -1.0374e+00,  1.2553e-01,  ..., -3.2210e-01,
            5.3504e-01, -4.4945e-01]],

         [[-6.6355e-01,  1.5975e-02, -1.1028e+00,  ...,  4.2119e-01,
            1.4300e-01, -1.0500e+00],
          [ 2.9872e-02,  8.8631e-01, -6.1298e-01,  ...,  7.4856e-01,
           -2.5026e-01, -6.9316e-01],
          [-3.7613e-01, -1.6566e+00, -4.7379e-01,  ...,  1.3844e+00,
           -2.1917e-01, -3.4775e-02],
          ...,
          [ 6.4191e-03, -6.4329e-01, -3.4590e-01,  ...,  7.1099e-01,
           -8.6027e-01,  1.9707e+00],
          [ 4.9977e-01,  8.7028e-01, -1.2418e+00,  ...,  4.7148e-02,
           -2.1899e+00,  5.0688e-02],
          [-4.3150e-01, -9.5305e-01,  6.0477e-03,  ...,  4.6361e-02,
            1.2867e+00, -5.2287e-01]],

         [[ 2.1083e-01, -3.4744e-01, -8.8930e-01,  ..., -1.5877e+00,
            7.6081e-01,  8.1642e-01],
          [-7.2559e-01,  1.0759e-03,  5.0080e-01,  ..., -9.0819e-01,
           -1.3664e+00, -5.0009e-01],
          [ 9.8828e-01,  2.0458e-01, -3.2092e-01,  ..., -1.0505e+00,
            7.6977e-01, -1.0954e+00],
          ...,
          [ 3.1315e-01, -1.8013e-01,  6.6712e-03,  ..., -3.4387e-01,
            6.7799e-02, -7.9891e-01],
          [ 9.3795e-01,  8.4810e-01, -3.3065e-02,  ..., -1.4951e+00,
            1.8487e-01, -7.2444e-01],
          [-2.9354e-01,  4.9934e-01,  1.5927e+00,  ...,  6.8559e-01,
           -5.8828e-02,  7.8406e-02]]],


        [[[-2.7788e-01, -5.9287e-01, -1.0850e+00,  ...,  1.1521e-01,
           -1.0350e-01, -1.0359e-01],
          [ 9.7773e-02, -6.6024e-01,  1.3205e+00,  ...,  5.3065e-01,
            1.4022e+00, -1.3554e+00],
          [ 1.2476e-01, -3.5420e-01, -7.5356e-01,  ..., -1.4845e+00,
           -8.2950e-01, -3.1953e-01],
          ...,
          [-2.2498e-01,  1.1432e+00, -6.3102e-02,  ...,  2.2542e+00,
           -4.3919e-02, -4.5381e-01],
          [-1.1767e-01,  3.8476e-01,  2.8890e-01,  ...,  5.0568e-01,
            1.9877e+00, -4.0198e-01],
          [ 8.3461e-01, -3.8473e-01, -6.2184e-01,  ...,  1.2079e-01,
            1.2926e+00, -9.2583e-01]],

         [[ 2.2536e-01, -3.0062e-01,  1.0106e-01,  ..., -1.1121e+00,
            2.6827e-01, -5.5333e-02],
          [ 7.0521e-02, -3.1558e-01,  2.6445e-01,  ...,  1.1511e+00,
            5.9738e-01, -1.5190e-01],
          [-9.1370e-02,  1.0820e+00,  7.4802e-01,  ...,  1.9068e-01,
           -4.5860e-01,  7.4459e-01],
          ...,
          [ 1.8985e-02,  1.2807e+00, -1.5447e+00,  ..., -6.0090e-01,
            7.2441e-01,  8.2941e-01],
          [ 3.5783e-01, -9.7884e-02,  1.9831e+00,  ..., -8.2288e-01,
            3.5600e-01,  1.3519e-01],
          [-3.1872e-01,  5.2968e-01, -1.1851e+00,  ..., -2.3451e-01,
            6.3058e-01,  5.7916e-01]],

         [[-1.9151e-01, -8.0938e-01, -7.4878e-01,  ...,  8.7690e-02,
           -6.5998e-02, -8.7660e-02],
          [-1.4635e-02,  4.4938e-01, -1.3227e+00,  ..., -1.3625e+00,
           -1.4243e+00, -8.3266e-01],
          [ 1.1510e+00,  5.9788e-01, -1.3344e+00,  ..., -2.9452e-01,
           -1.1426e+00,  4.3165e-04],
          ...,
          [ 3.5482e-01,  2.9180e-01, -4.7351e-01,  ..., -1.0602e+00,
           -9.5300e-01, -1.4889e-01],
          [ 1.1904e+00,  3.3002e-01,  7.0450e-01,  ...,  5.8851e-01,
           -1.3167e+00,  1.1598e+00],
          [ 3.0639e-01,  2.4975e-01,  2.0160e-01,  ..., -6.6445e-01,
            7.7924e-01,  3.1040e-01]]]], grad_fn=<ConvolutionBackward0>)
                                      Trace Shapes:                        
                                       Param Sites:                        
             mvae$$$image_encoder.features.0.weight  32   3       3       3
             mvae$$$image_encoder.features.2.weight  64  32       3       3
             mvae$$$image_encoder.features.4.weight                      64
               mvae$$$image_encoder.features.4.bias                      64
             mvae$$$image_encoder.features.6.weight 128  64       3       3
             mvae$$$image_encoder.features.8.weight                     128
               mvae$$$image_encoder.features.8.bias                     128
             mvae$$$image_encoder.features.9.weight 256 128       3       3
            mvae$$$image_encoder.features.11.weight                     256
              mvae$$$image_encoder.features.11.bias                     256
          mvae$$$image_encoder.z_loc_layer.0.weight             512  262144
            mvae$$$image_encoder.z_loc_layer.0.bias                     512
          mvae$$$image_encoder.z_loc_layer.3.weight              50     512
            mvae$$$image_encoder.z_loc_layer.3.bias                      50
        mvae$$$image_encoder.z_scale_layer.0.weight             512  262144
          mvae$$$image_encoder.z_scale_layer.0.bias                     512
        mvae$$$image_encoder.z_scale_layer.3.weight              50     512
          mvae$$$image_encoder.z_scale_layer.3.bias                      50
             mvae$$$image_decoder.upsample.0.weight         4194304      50
               mvae$$$image_decoder.upsample.0.bias                 4194304
          mvae$$$image_decoder.hallucinate.0.weight 256 128       3       3
          mvae$$$image_decoder.hallucinate.1.weight                     128
            mvae$$$image_decoder.hallucinate.1.bias                     128
          mvae$$$image_decoder.hallucinate.3.weight 128  64       3       3
          mvae$$$image_decoder.hallucinate.4.weight                      64
            mvae$$$image_decoder.hallucinate.4.bias                      64
          mvae$$$image_decoder.hallucinate.6.weight  64  32       3       3
          mvae$$$image_decoder.hallucinate.7.weight                      32
            mvae$$$image_decoder.hallucinate.7.bias                      32
          mvae$$$image_decoder.hallucinate.9.weight  32   3       3       3
                  mvae$$$emotion_encoder.net.weight             512       8
                    mvae$$$emotion_encoder.net.bias                     512
        mvae$$$emotion_encoder.z_loc_layer.0.weight             512     512
          mvae$$$emotion_encoder.z_loc_layer.0.bias                     512
        mvae$$$emotion_encoder.z_loc_layer.2.weight              50     512
          mvae$$$emotion_encoder.z_loc_layer.2.bias                      50
      mvae$$$emotion_encoder.z_scale_layer.0.weight             512     512
        mvae$$$emotion_encoder.z_scale_layer.0.bias                     512
      mvae$$$emotion_encoder.z_scale_layer.2.weight              50     512
        mvae$$$emotion_encoder.z_scale_layer.2.bias                      50
                mvae$$$emotion_decoder.net.0.weight             512      50
                  mvae$$$emotion_decoder.net.0.bias                     512
  mvae$$$emotion_decoder.emotion_loc_layer.0.weight             512     512
    mvae$$$emotion_decoder.emotion_loc_layer.0.bias                     512
  mvae$$$emotion_decoder.emotion_loc_layer.2.weight               8     512
    mvae$$$emotion_decoder.emotion_loc_layer.2.bias                       8
mvae$$$emotion_decoder.emotion_scale_layer.0.weight             512     512
  mvae$$$emotion_decoder.emotion_scale_layer.0.bias                     512
mvae$$$emotion_decoder.emotion_scale_layer.2.weight               8     512
  mvae$$$emotion_decoder.emotion_scale_layer.2.bias                       8
                                      Sample Sites:                        
                                             z dist       8      50       |
                                              value       8      50       |