In [1]:
import torch
from torch import nn
from MSE_VAE2 import Encoder
from MSE_VAE2 import Decoder
from MSE_VAE2 import posteriors
import sklearn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

device = 'cuda'
path = '/home/adarsh/ADRL/datasets/img_align_celeba'
model_path = '/home/adarsh/ADRL/assignment_1/VAE/model_celeba2318.pt'

model = nn.Sequential(
    Encoder(),
    Decoder()
)

model.to(device)
model.load_state_dict(torch.load(model_path))

celebA_transform = transforms.Compose(
        [transforms.Resize(128), transforms.CenterCrop(128) , transforms.ToTensor()])
celebA_dataset = datasets.ImageFolder(
    root=path, transform=celebA_transform)
dataloader = DataLoader(celebA_dataset, batch_size=4096)


  from .autonotebook import tqdm as notebook_tqdm


AttributeError: 'ImageFolder' object has no attribute 'shape'

In [2]:
Z = torch.zeros((0, 128))
x = torch.zeros((0, 49152))

for X,_ in dataloader:
    X = X.to(device)
    mean, std = model[0](X)
    z = posteriors(mean, std, 1)
    bs = X.shape[0]
    Z = torch.cat((Z, z.detach().cpu().view(bs, -1)))
    x = torch.cat((x, X.detach().cpu().view(bs,-1)))

print(Z.shape, x.shape)

torch.Size([202599, 128]) torch.Size([202599, 49152])


In [3]:
from sklearn.decomposition import PCA

z_PCA = PCA(2)
x_PCA = PCA(4)

z_PCA = z_PCA.fit(Z.numpy())
x_PCA = x_PCA.fit(x.numpy())

all_z_PCA = z_PCA.transform(Z.numpy())
print(all_z_PCA.shape)

(202599, 2)


In [4]:
import numpy as np
index = np.random.choice(Z.shape[0], Z.shape[0]//10, replace=False) 

z = all_z_PCA[index]

from sklearn.mixture import GaussianMixture as GMM
gmm = GMM(n_components=10)
gmm.fit(z)
labels = gmm.predict(z)
print(labels.shape)

(20259,)


In [5]:
import matplotlib.pyplot as plt 

# z1 = gmm.sample()[0].squeeze()
# Y = model[1](torch.tensor(z1).float().to(device))
# plt.imshow(Y.squeeze().detach().cpu().numpy().T)
# plt.show()

In [12]:
from torch.distributions.multivariate_normal import MultivariateNormal as MN
import math

def q(z):
    z = z.reshape(1,-1).cpu().detach().numpy()
    z = z_PCA.transform(z)
    out = gmm.score(z)
    return np.exp(out)

def p_theta(z):
    z = z.cpu().detach().numpy()
    z = z_PCA.transform(z)
    p_theta = MN(torch.zeros(z.shape[1]), torch.eye(z.shape[1]))
    out = p_theta.log_prob(torch.tensor(z))

    return np.exp(out)

def p_theta_given_z(x, z):
    z = z.to(device)
    mean = model[1](z.float()).cpu().detach().numpy()

    mean = mean.flatten().reshape(1, -1)
    x = x.cpu().detach().numpy().flatten().reshape(1, -1)
    mean = x_PCA.transform(mean)
    x = x_PCA.transform(x)

    p_theta = MN(torch.tensor(mean), torch.eye(mean.shape[1]))
    log_prob = p_theta.log_prob(torch.tensor(x))

    return np.exp(log_prob)

def get_marginal_likelihood(x_i):
    L = 1000
    sum = 0
    mean, std = model[0](x_i.to(device))
    for i in range(L):
        z = posteriors(mean, std, 1)
        a = q(z)
        b = p_theta(z)
        c = p_theta_given_z(x_i, z)
        sum += a/(b*c)

    return L/sum

for X,_ in dataloader:
    for i in range(10):
        input = X[i].reshape(1, 3, 128, 128)
        print(get_marginal_likelihood(input))
    break


tensor([1.4381e-23])
tensor([2.9066e-24])
tensor([6.0882e-28])
tensor([6.6265e-24])
tensor([2.0169e-25])
tensor([1.3046e-29])
tensor([1.0164e-21])
tensor([5.5556e-22])
tensor([6.4115e-29])
tensor([7.4224e-28])


In [10]:
z = torch.randn(1, 128)
p_theta = MN(torch.zeros(128), torch.eye(128))
out = p_theta.log_prob(z.float())
print("p_theta", out, torch.exp(out))

p_theta tensor([-187.2930]) tensor([0.])
