In [1]:
import scipy
import numpy as np
import torch 
import torch.nn as nn
# Load the dataset array from disk; the path is relative to the current working directory.
# NOTE(review): presumably the bouncing-MNIST test split (judging by the file name); the
# array layout and dtype are not established here — confirm before feeding it to the model.
data = np.load('data/bouncing_mnist_test.npy')

In [3]:
class VRNN(nn.Module):
    """Variational recurrent neural network over time-major sequences.

    At every time step the model
      1. encodes ``(x_t, h_{t-1})`` into the approximate posterior of ``z_t``,
      2. samples ``z_t`` with the reparametrization trick,
      3. decodes ``(z_t, h_{t-1})`` into the reconstruction mean of ``x_t``,
      4. advances the recurrent state with ``(x_t, z_t)``.

    After the loop the prior of every ``z_t`` is computed from ``h_{t-1}``.

    Args:
        input_size: Feature size of one observation ``x_t``.
        hidden_size: Width of the encoder / prior / decoder hidden layers.
        latent_size: Size of the latent variable ``z_t``.
        RNN_dim: Size of the recurrent hidden state ``h_t``.

    Attributes set by :meth:`forward` (all shaped ``(seq_len, batch, latent_size)``):
        z_mean, z_logvar: Posterior parameters of ``z``.
        z: Sampled latent variables.
        z_prior_mean, z_prior_logvar: Prior parameters of ``z``.
    """

    def __init__(self, input_size, hidden_size, latent_size, RNN_dim):
        super(VRNN, self).__init__()

        self.input_size = input_size
        self.hidden_size = hidden_size
        self.latent_size = latent_size
        self.rnn_dim = RNN_dim

        # Inference network q(z_t | x_t, h_{t-1}).
        self.encoder = nn.Sequential(
            nn.Linear(input_size + RNN_dim, hidden_size),
            nn.Tanh()
        )
        self.enc_mean = nn.Linear(hidden_size, latent_size)
        self.enc_logvar = nn.Linear(hidden_size, latent_size)

        # Prior network p(z_t | h_{t-1}).
        self.prior_encoder = nn.Sequential(
            nn.Linear(RNN_dim, hidden_size),
            nn.Tanh()
        )
        self.prior_mean = nn.Linear(hidden_size, latent_size)
        self.prior_logvar = nn.Linear(hidden_size, latent_size)

        # Generative network p(x_t | z_t, h_{t-1}).
        self.decoder = nn.Sequential(
            nn.Linear(latent_size + RNN_dim, hidden_size),
            nn.Tanh()
        )
        self.dec_mean = nn.Linear(hidden_size, input_size)
        # Not used by forward() yet (the decoder output is not sampled); kept so
        # the parameter set / state_dict layout stays unchanged.
        self.dec_logvar = nn.Linear(hidden_size, input_size)

        self.rnn = nn.RNN(input_size + latent_size, RNN_dim)  # could try with a GRU or an LSTM

    def reparametrization(self, mean, logvar):
        """Draw ``mean + std * eps`` with ``eps ~ N(0, I)`` and ``std = exp(logvar / 2)``."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mean + std * eps

    def forward(self, x):
        """Run the model over a whole sequence.

        Args:
            x: Tensor of shape ``(seq_len, batch_size, input_size)``.

        Returns:
            Tensor of shape ``(seq_len, batch_size, input_size)`` holding the
            decoder mean for every time step.

        Raises:
            ValueError: If ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                "expected x of shape (seq_len, batch_size, input_size), "
                "got {}".format(tuple(x.shape))
            )
        seq_len, batch_size, _ = x.shape

        # Allocate every buffer from ``x`` so it lives on the same device and
        # has the same dtype as the input (plain torch.zeros would pin to CPU).
        self.z_mean = x.new_zeros(seq_len, batch_size, self.latent_size)
        self.z_logvar = x.new_zeros(seq_len, batch_size, self.latent_size)
        self.z = x.new_zeros(seq_len, batch_size, self.latent_size)

        # h_prev[t] holds h_{t-1}, the state every network at step t conditions on.
        h_prev = x.new_zeros(seq_len, batch_size, self.rnn_dim)
        # The initial hidden state could be different from 0 (e.g. random or learned).
        h_t = x.new_zeros(batch_size, self.rnn_dim)

        # Observation reconstruction; its feature size is the input size.
        y = x.new_zeros(seq_len, batch_size, self.input_size)

        for t in range(seq_len):
            x_t = x[t]

            # Record the state *before* the update: the prior of z_t must not
            # see x_t or z_t, otherwise the KL term is computed against a
            # prior that already knows the sample.
            h_prev[t] = h_t

            # Encode.
            encoder_output = self.encoder(torch.cat((x_t, h_t), dim=1))
            mean_zt = self.enc_mean(encoder_output)
            logvar_zt = self.enc_logvar(encoder_output)

            # Sample the latent variable.
            z_t = self.reparametrization(mean_zt, logvar_zt)

            # Decode; for now the decoder mean is used directly (no sampling).
            decoder_output = self.decoder(torch.cat((z_t, h_t), dim=1))
            y_t = self.dec_mean(decoder_output)

            # Update the hidden state. nn.RNN expects a leading sequence axis on
            # the input and a leading layer axis on the state, hence the
            # unsqueeze/squeeze around a single-step call.
            rnn_input = torch.cat((x_t, z_t), dim=1)
            _, h_next = self.rnn(rnn_input.unsqueeze(0), h_t.unsqueeze(0))
            h_t = h_next.squeeze(0)

            # Save the per-step variables.
            y[t] = y_t
            self.z_mean[t] = mean_zt
            self.z_logvar[t] = logvar_zt
            self.z[t] = z_t

        # Prior of z_t for every step (for the KL divergence), from h_{t-1}.
        prior_encoder_output = self.prior_encoder(h_prev)
        self.z_prior_mean = self.prior_mean(prior_encoder_output)
        self.z_prior_logvar = self.prior_logvar(prior_encoder_output)

        return y
