In [1]:
import os

import torch
from torch import nn

In [2]:
class VAE(nn.Module):
    def __init__(self, input_dim, hidden_dim=200, z_dim=20):
        super().__init__()
        # Encoder
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear_mu = nn.Linear(hidden_dim, z_dim)
        self.linear_sigma = nn.Linear(hidden_dim, z_dim)
        
        # Decoder
        self.linear_2h = nn.Linear(z_dim, hidden_dim)
        self.linear_2img = nn.Linear(hidden_dim, input_dim)
        
        # Multipurpose
        self.relu = nn.ReLU() # LeakyReLU
        self.flat = nn.Flatten()
    
    def encode(self, x):
        #q_phi(z|x)
        #x = self.flat(x)
        z = self.linear1(x)
        z = self.relu(z)
        mu = self.linear_mu(z)
        sigma = self.linear_sigma(z)
        
        return mu, sigma
    
    def decode(self, z):
        # p_theta(x|z)
        h = self.linear_2h(z)
        h = self.relu(h)
        img = self.linear_2img(h)
        
        return torch.sigmoid(img)
    
    def forward(self, x):
        mu, sigma = self.encode(x)
        e = torch.randn_like(sigma)
        z_reparametrized = mu+sigma*e
        x_reconstructed = self.decode(z_reparametrized)
        return x_reconstructed, mu, sigma
    

In [3]:
### Test

x = torch.randn(4, 28*28)
vae = VAE(input_dim=28*28)
x_reconstructed, mu, sigma = vae(x)
print(x_reconstructed.size())
print(mu.size())
print(sigma.size())

torch.Size([4, 784])
torch.Size([4, 20])
torch.Size([4, 20])


In [4]:
import numpy as np

def split_indices(size, pct):
    n_val = int(pct*size)
    idxs = np.random.permutation(size)
    return idxs[n_val:], idxs[:n_val]

In [5]:
import torchvision.datasets as datasets
from tqdm import tqdm
from torchvision import transforms
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

device = 'cuda' if torch.cuda.is_available() else 'cpu'
batch_size = 128

dataset = datasets.MNIST(root="./data", train=True, transform=transforms.ToTensor(), download=True)
test_dataset = datasets.MNIST(root="./data", train=False, transform=transforms.ToTensor(), download=True)

# train_indices, val_indices = split_indices(len(dataset), 0.2)
# train_sampler = SubsetRandomSampler(train_indices)
# train_loader = DataLoader(dataset, batch_size, sampler=train_sampler)
# # Validation sampler and data loader
# val_sampler = SubsetRandomSampler(val_indices)
# val_loader = DataLoader(dataset, batch_size, sampler=val_sampler)

train_loader = DataLoader(dataset, batch_size, shuffle=True)


In [6]:
# Training
lr = 1e-4 # Karpathy constant
EPOCHS = 30
INPUT_DIM = 28*28
H_DIM = 200
Z_DIM = 20

model = VAE(INPUT_DIM, H_DIM, Z_DIM)
model = model.to(device)
opt = torch.optim.Adam(model.parameters(), lr=lr) # try probably smaller lr
loss_fn = nn.BCELoss(reduction='sum') # MSE ?
#loss_fn = nn.MSELoss(reduction='sum')

losses = []

for epoch in tqdm(range(EPOCHS)):
    for x, _ in train_loader:
        x = x.to(device)
        x = x.view(x.size(0), INPUT_DIM)
        opt.zero_grad()
        
        out, mu, sigma = model(x)
        reconstruction_loss = loss_fn(out, x)
        #KL_div = -0.5 * torch.sum(1 + torch.log(sigma**2) - mu**2 - sigma**2, dim=1)
        #KL_div = -0.5*torch.sum(1+torch.log(torch.pow(sigma, 2)) - torch.pow(mu, 2) - torch.pow(sigma, 2)) # Minimize KL_div
        KL_div = -0.5 * torch.sum(1 + torch.log(sigma**2) - mu**2 - sigma**2, dim=1).mean()
        
        # backward
        loss = reconstruction_loss + KL_div
        loss.backward()
        opt.step()
        
        losses.append(loss.item())
    
    print("Epoch {}/{} - Loss: {:.4f}".format(epoch+1, EPOCHS, loss.item()))

  3%|██▏                                                                | 1/30 [00:07<03:41,  7.63s/it]

Epoch 1/30 - Loss: 19696.8105


  7%|████▍                                                              | 2/30 [00:17<04:03,  8.69s/it]

Epoch 2/30 - Loss: 15251.9199


 10%|██████▋                                                            | 3/30 [00:26<04:03,  9.04s/it]

Epoch 3/30 - Loss: 13623.1084


 13%|████████▉                                                          | 4/30 [00:36<04:01,  9.27s/it]

Epoch 4/30 - Loss: 12637.5615


 17%|███████████▏                                                       | 5/30 [00:45<03:54,  9.38s/it]

Epoch 5/30 - Loss: 11897.9336


 20%|█████████████▍                                                     | 6/30 [00:55<03:47,  9.48s/it]

Epoch 6/30 - Loss: 11686.8984


 23%|███████████████▋                                                   | 7/30 [01:04<03:38,  9.49s/it]

Epoch 7/30 - Loss: 11085.6543


 27%|█████████████████▊                                                 | 8/30 [01:14<03:30,  9.56s/it]

Epoch 8/30 - Loss: 10304.7412


 30%|████████████████████                                               | 9/30 [01:24<03:21,  9.59s/it]

Epoch 9/30 - Loss: 9971.4238


 33%|██████████████████████                                            | 10/30 [01:34<03:14,  9.72s/it]

Epoch 10/30 - Loss: 10544.6338


 37%|████████████████████████▏                                         | 11/30 [01:44<03:04,  9.74s/it]

Epoch 11/30 - Loss: 10125.8955


 40%|██████████████████████████▍                                       | 12/30 [01:53<02:53,  9.66s/it]

Epoch 12/30 - Loss: 9576.2656


 43%|████████████████████████████▌                                     | 13/30 [02:03<02:44,  9.65s/it]

Epoch 13/30 - Loss: 9359.9150


 47%|██████████████████████████████▊                                   | 14/30 [02:12<02:34,  9.63s/it]

Epoch 14/30 - Loss: 9196.8633


 50%|█████████████████████████████████                                 | 15/30 [02:22<02:23,  9.59s/it]

Epoch 15/30 - Loss: 8872.4971


 53%|███████████████████████████████████▏                              | 16/30 [02:31<02:13,  9.55s/it]

Epoch 16/30 - Loss: 8789.1816


 57%|█████████████████████████████████████▍                            | 17/30 [02:41<02:04,  9.56s/it]

Epoch 17/30 - Loss: 8911.6074


 60%|███████████████████████████████████████▌                          | 18/30 [02:50<01:55,  9.60s/it]

Epoch 18/30 - Loss: 8612.1318


 63%|█████████████████████████████████████████▊                        | 19/30 [03:00<01:45,  9.61s/it]

Epoch 19/30 - Loss: 9237.5732


 67%|████████████████████████████████████████████                      | 20/30 [03:10<01:36,  9.68s/it]

Epoch 20/30 - Loss: 8643.4648


 70%|██████████████████████████████████████████████▏                   | 21/30 [03:20<01:27,  9.71s/it]

Epoch 21/30 - Loss: 8575.8271


 73%|████████████████████████████████████████████████▍                 | 22/30 [03:29<01:17,  9.70s/it]

Epoch 22/30 - Loss: 8474.5576


 77%|██████████████████████████████████████████████████▌               | 23/30 [03:39<01:07,  9.69s/it]

Epoch 23/30 - Loss: 8040.7783


 80%|████████████████████████████████████████████████████▊             | 24/30 [03:49<00:58,  9.67s/it]

Epoch 24/30 - Loss: 8634.2852


 83%|███████████████████████████████████████████████████████           | 25/30 [03:58<00:48,  9.65s/it]

Epoch 25/30 - Loss: 8824.1338


 87%|█████████████████████████████████████████████████████████▏        | 26/30 [04:08<00:38,  9.63s/it]

Epoch 26/30 - Loss: 8142.3940


 90%|███████████████████████████████████████████████████████████▍      | 27/30 [04:17<00:28,  9.60s/it]

Epoch 27/30 - Loss: 8269.7607


 93%|█████████████████████████████████████████████████████████████▌    | 28/30 [04:27<00:19,  9.60s/it]

Epoch 28/30 - Loss: 8251.9854


 97%|███████████████████████████████████████████████████████████████▊  | 29/30 [04:37<00:09,  9.64s/it]

Epoch 29/30 - Loss: 8432.2617


100%|██████████████████████████████████████████████████████████████████| 30/30 [04:46<00:00,  9.57s/it]

Epoch 30/30 - Loss: 7863.8105





In [20]:
count = 1
from torchvision.utils import save_image
from torchvision.models import inception_v3
import torch.nn.functional as F
import numpy as np
from scipy.linalg import sqrtm

original_images = []
generated_images = []

transform = transforms.Compose([
    #transforms.Grayscale(3)
    transforms.Lambda(lambda x: x.repeat(3, 1, 1))
])

model.eval()
with torch.no_grad():
    for x, y in dataset:
        # print(y)
        original = x
        original_images.append(transform(original.squeeze()))
        x = x.view(1, 784)
        x = x.to(device)
        mu, sigma = model.encode(x)
        # Reparametrization
        z = mu + sigma*torch.randn_like(sigma)
        reconstruction = model.decode(z)
        reconstruction = reconstruction.view(-1, 28, 28).squeeze()
        generated_images.append(transform(reconstruction))
        reconstruction = reconstruction.detach().cpu()
#         plt.figure()
#         fig, ax = plt.subplots(1,2)
#         ax[0].imshow(original.squeeze(), cmap='gray')
#         ax[0].set_title('Original')
#         ax[1].imshow(reconstruction, cmap='gray')
#         ax[1].set_title('Reconstructed')
#         plt.show()
        
        # Save the image
        # save_image(reconstruction, f"./data/MNIST_generated_VAE/BCE_1e4_30/generated_{count}.png")
        # save_image(original, f"./data/MNIST_original/original_{count}.png")
        
        count+=1
#         if count == 20:
#             break  

original_images = torch.stack(original_images)
generated_images = torch.stack(generated_images)



In [53]:
from inception import InceptionV3

def calculate_fid(mu1, sigma1, mu2, sigma2):
    diff = mu1-mu2
    covmean = sqrtm(sigma1.dot(sigma2))
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    fid = diff.dot(diff) + np.trace(sigma1+sigma2- 2*covmean)
    return fid

def get_activations(images, model, batch_size=32, dims=2048, device='cuda'):
    model.eval()
    pred_arr = np.empty((len(images), dims))
    with torch.no_grad():
        for i in range(0, len(images), batch_size):
            start = i
            end = i+batch_size
            batch = images[start:end].to(device)
            #print(batch.size())
            #raise
            pred = model(batch)[0]
            pred = F.adaptive_avg_pool2d(pred, output_size=(1,1))
            pred = pred.squeeze(3).squeeze(2).cpu().numpy()
            pred_arr[start:end] = pred
    return pred_arr

#inception_model = inception_v3(pretrained=True, transform_input=False).to(device)
#inception_model.fc = torch.nn.Identity()
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[2048]
inception_model = InceptionV3([block_idx]).to(device)

In [42]:
# Transform to match the inception v3

# transform = transforms.Compose([
#     #transforms.Resize((299,299)),
#     transforms.Lambda(lambda x: x.repeat(3, 1, 1))
# ])

# original_transformed = []
# generated_transformed = []

# for i in range(len(generated_images)):
#     original_transformed.append(transform(original_images[i].squeeze()))
#     generated_transformed.append(transform(generated_images[i].squeeze()))


In [43]:
# original_transformed = torch.stack(original_transformed)
# generated_transformed = torch.stack(generated_transformed)

In [54]:
original_activations = get_activations(original_images, inception_model, device=device)
generated_activations = get_activations(generated_images, inception_model, device=device)

# Calculate mean and covariance of the activations
mu1 = np.mean(original_activations, axis=0)
sigma1 = np.cov(original_activations, rowvar=False)
mu2 = np.mean(generated_activations, axis=0)
sigma2 = np.cov(generated_activations, rowvar=False)

# Calculate FID score
fid_score = calculate_fid(mu1, sigma1, mu2, sigma2)
print('FID score:', fid_score)

FID score: 33.00564486414716
