In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets,transforms
from torch.distributions import Normal, Multinomial

In [None]:
# FCNN: a fully connected network mapping a W_DIM-dimensional input to Z_DIM*2*num_clusters outputs through a single ReLU hidden layer of 500 neurons.

class FCNN(nn.Module):
    """Two-layer MLP mapping a W_DIM vector to Z_DIM*2*num_clusters outputs.

    Args:
        W_DIM: size of the last dimension of the input.
        Z_DIM: latent size; the output has Z_DIM*2*num_clusters features
            (the factor 2 presumably holds a mean and a scale per cluster —
            confirm against the loss that consumes it).
        num_clusters: number of mixture components.
        hidden_dim: width of the single ReLU hidden layer (default 500, the
            value previously hard-coded).
    """

    def __init__(self, W_DIM, Z_DIM, num_clusters, hidden_dim=500):
        super(FCNN, self).__init__()
        # Attribute names fc1/fc2 are kept so existing state_dicts still load.
        self.fc1 = nn.Linear(W_DIM, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, Z_DIM * 2 * num_clusters)

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``; the last dimension of ``x`` must be W_DIM."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)


In [None]:
class GMVAE(nn.Module):
    """Gaussian-mixture VAE.

    The encoder maps an input ``x`` to the mean and standard deviation of a
    diagonal Gaussian over the latent ``z``, and maps an auxiliary sample ``w``
    to two feature vectors that an FCNN turns into per-cluster parameters.
    The decoder maps ``z`` back to INPUT_DIM features squashed into (0, 1) by
    a sigmoid.

    Args:
        INPUT_DIM: size of one input vector (the notebook uses 28*28,
            presumably flattened images).
        H_DIM: width of the encoder/decoder hidden layers.
        Z_DIM: size of the latent ``z``.
        W_DIM: size of the auxiliary latent ``w``.
        num_clusters: number of mixture components.
    """

    def __init__(self, INPUT_DIM, H_DIM, Z_DIM, W_DIM, num_clusters):
        super(GMVAE, self).__init__()

        self.INPUT_DIM = INPUT_DIM  # e.g. 28*28
        self.H_DIM = H_DIM  # e.g. 200
        self.Z_DIM = Z_DIM  # e.g. 20
        self.W_DIM = W_DIM  # e.g. 150
        self.num_clusters = num_clusters  # e.g. 10

        # Stored as an attribute: a local named ``FCNN`` would shadow the class
        # (UnboundLocalError on the right-hand side) and, as a local, would not
        # be registered as a submodule, hiding its parameters from optimizers.
        self.fcnn = FCNN(W_DIM=self.W_DIM, Z_DIM=self.Z_DIM, num_clusters=self.num_clusters)

        # encoder
        self.img_2hid = nn.Linear(INPUT_DIM, H_DIM)
        # ``w`` has W_DIM features (see forward), not INPUT_DIM.
        self.w_2hid = nn.Linear(W_DIM, H_DIM)
        self.hid_2mu = nn.Linear(H_DIM, Z_DIM)
        self.hid_2sigma = nn.Linear(H_DIM, Z_DIM)
        self.hid_2w1 = nn.Linear(H_DIM, W_DIM)
        self.hid_2w2 = nn.Linear(H_DIM, W_DIM)

        # decoder
        self.z_2hid = nn.Linear(Z_DIM, H_DIM)
        self.hid_2img = nn.Linear(H_DIM, INPUT_DIM)

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def encode(self, x, w):
        """Encode ``x`` and ``w``.

        Returns:
            (mu, sigma, w1, w2): ``mu`` is the unconstrained mean of q(z|x),
            ``sigma`` its strictly positive standard deviation, and ``w1``/``w2``
            are non-negative W_DIM-sized features derived from ``w``.
        """
        h_x = self.relu(self.img_2hid(x))
        h_w = self.relu(self.w_2hid(w))
        # No activation on the mean: a ReLU would forbid negative means.
        mu = self.hid_2mu(h_x)
        # softplus (+ a small floor, since softplus can underflow to 0) keeps the
        # std strictly positive; a ReLU can return exactly 0, which
        # torch.distributions.Normal rejects as a scale.
        sigma = F.softplus(self.hid_2sigma(h_x)) + 1e-6
        w1 = self.relu(self.hid_2w1(h_w))
        w2 = self.relu(self.hid_2w2(h_w))
        return mu, sigma, w1, w2

    def decode(self, z):
        """Map latent ``z`` to INPUT_DIM values in (0, 1)."""
        h = self.relu(self.z_2hid(z))
        return self.sigmoid(self.hid_2img(h))

    def forward(self, x):
        """Run one encode/sample/decode pass.

        Returns:
            (x_reconstructed, mu, sigma, w, z, q_phi_x, q_phi_w, p_beta_z) where
            ``z`` is a reparameterized sample from ``q_phi_x = Normal(mu, sigma)``,
            ``w`` is the noised auxiliary sample with ``q_phi_w = Normal(w, 1)``,
            and ``p_beta_z`` is a uniform one-draw Multinomial over clusters.
        """
        batch_size = x.size(0)
        # Sample w from N(0, I), where I is the W_DIM x W_DIM identity matrix.
        w = torch.randn((batch_size, self.W_DIM), device=x.device)
        mu, sigma, w1, w2 = self.encode(x, w)

        # Uniform cluster prior; probs must be a tensor (a list raises) and is
        # created on x's device rather than moving a CPU sample afterwards.
        cluster_probs = torch.full((self.num_clusters,), 1.0 / self.num_clusters, device=x.device)
        p_beta_z = Multinomial(1, cluster_probs)
        # One one-hot cluster assignment per example, under its own name
        # (previously stored in ``z`` and immediately overwritten).
        c = p_beta_z.sample((batch_size,))

        # NOTE(review): ``c`` and the per-cluster prior parameters below are not
        # yet used or returned; presumably they feed the conditional prior
        # p(z | w, c) in the loss — confirm and wire them in.
        prior_mus = self.fcnn(w1)
        prior_sigmas = self.fcnn(w2)

        z, w = self.reparameterize(mu, sigma, w)
        x_reconstructed = self.decode(z)

        # Variational distributions and the cluster prior.
        q_phi_x = Normal(mu, sigma)
        q_phi_w = Normal(w, torch.ones_like(w))

        return x_reconstructed, mu, sigma, w, z, q_phi_x, q_phi_w, p_beta_z

    def reparameterize(self, mu, sigma, w):
        """Reparameterized draws: ``z = mu + sigma * eps_z`` and ``w + eps_w``.

        ``sigma`` is a standard deviation: ``forward`` uses it as the scale of
        ``Normal(mu, sigma)``, so it multiplies the noise directly (scaling by
        ``sqrt(sigma)`` made ``z`` inconsistent with that distribution).
        """
        eps_z = torch.randn_like(sigma)
        eps_w = torch.randn_like(w)
        z = mu + eps_z * sigma
        # Unit-variance noise (the former sqrt(ones_like(w)) factor was just 1).
        w = w + eps_w
        return z, w
