In [2]:
import os
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import pytorch_lightning as pl
import scipy.stats as ss

import numpy as np

from tqdm.notebook import trange

In [149]:
class PUVAE(pl.LightningModule):
    def __init__(self,input_dim, h_dim, z_dim,EMSTEPS=10):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(input_dim, h_dim),nn.ReLU(),
                                     nn.Linear(h_dim, h_dim),nn.ReLU(),
                                     nn.Linear(h_dim, z_dim * 2))
        self.decoder = nn.Sequential(nn.Linear(z_dim, h_dim), nn.ReLU(),
                                     nn.Linear(h_dim, h_dim), nn.ReLU(),
                                     nn.Linear(h_dim, input_dim))
        self.pi = torch.ones(1) * .5
        self.mu = torch.eye(z_dim) 
        self.z_dim = z_dim
        self.EMSTEPS = EMSTEPS
        
        
    def forward(self, x):
        # in lightning, forward defines the prediction/inference actions
        embedding = self.encoder(x)
        return embedding

    def training_step(self, batch, batch_idx):
        # training_step defined the train loop.
        # It is independent of forward
        x, y = batch
        N = x.shape[0]
        x = x.view(x.size(0), -1)
        z_1_2 = self.encoder(x)
        z_1 = z_1_2[...,:self.z_dim] + torch.normal(torch.zeros((N,self.z_dim)),1)
        z_2 = z_1_2[...,self.z_dim:] + torch.normal(torch.zeros((N,self.z_dim)),1)
        z = self.pi * z_1 + (1 - self.pi) * z_2
        zd = z.detach().numpy()
        for em_step in trange(self.EMSTEPS,leave=False):
            #################### E-STEP #######################################
            gamma_1 = self.pi * ss.multivariate_normal.pdf(zd, mean=self.mu[0], cov=np.eye(self.z_dim))
            gamma_2 = (1 - self.pi) * ss.multivariate_normal.pdf(zd, mean=self.mu[1], cov=np.eye(self.z_dim))
            gamma_1_2 = gamma_1 + gamma_2
            gamma_n_1 = gamma_1 / gamma_1_2
            gamma_n_2 = gamma_2 / gamma_1_2
            ################### M-STEP #########################################
            n1 = gamma_n_1.sum()
            n2 = gamma_n_2.sum()
            self.mu[0] = np.dot(gamma_n_1.float(), zd.squeeze()) / n1
            self.mu[1] = np.dot(gamma_n_2.float(), zd.squeeze()) / n2
            self.pi = n1 / (n1 + n2)
        
            
    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [151]:
puvae = PUVAE(16,32,2)

In [152]:
puvae.training_step((torch.Tensor(np.random.multivariate_normal(np.zeros(16),np.eye(16),size=100)), None),0)

HBox(children=(FloatProgress(value=0.0, max=10.0), HTML(value='')))

