In [1]:
import numpy as np
import pandas as pd
import matplotlib.image as img
import warnings
warnings.filterwarnings('ignore')

import os
import glob
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
random_seed = 42
learning_rate = 0.001
batch_size = 32
epochs = 50

cuda


In [4]:
from tqdm import tqdm
X = []       
labels = [] # this is not used in the training
for path in ['/kaggle/input/cifar10/cifar10/train/*', '/kaggle/input/cifar10/cifar10/test/*']:
    for (i, dirname) in enumerate(glob.glob(path)):
        for filename in tqdm(glob.glob(os.path.join(dirname, '*'))):
            file = img.imread(filename)
            X.append(file)
            labels.append(i)

In [3]:
X = np.array(X)

print(X.shape)
plt.imshow(X[11451,:,:,:])

(0,)


IndexError: too many indices for array: array is 1-dimensional, but 4 were indexed

In [None]:
from torch.utils.data import DataLoader
X = torch.tensor(X)
X = X.permute(0, 3,1,2)
image_loader = DataLoader(X, batch_size = batch_size, shuffle = True)
# We only care about reconstruction, so we don't need to compute accuracy

In [None]:
!pip install torchsummary

In [None]:
from torchsummary import summary
import torch.nn.functional as F

class Reshape(nn.Module):
    def __init__(self, *args):
        super(Reshape, self).__init__()
        self.shape = args
        
    def forward(self, x):
        return x.view(self.shape)
    
class Trim(nn.Module):
    def __init__(self):
        super(Trim, self).__init__()
    
    def forward(self, x):
        return x[:, :, :32, :32]
    
class VAE(nn.Module):
    def __init__(self):
        
        # [(W−K+2P)/S]+1
        super(VAE, self).__init__()
        # Unfortunately, we cannot have maxunpool2d in nn.Sequential, so we have to write them out

        self.conv1 = nn.Conv2d(3, 32, 3, stride = 2, padding = 1)
        self.conv2 = nn.Conv2d(32, 64, 3, stride = 2, padding = 1)
        self.conv3 = nn.Conv2d(64, 128, 3, stride = 2, padding = 1)
        self.flatten = nn.Flatten()
        
        self.linear_mean = nn.Linear(2048, 100)
        self.linear_logvar = nn.Linear(2048, 100)

        self.linear = nn.Linear(100, 2048)
        self.reshape = Reshape(-1, 128, 4, 4)
        self.deconv1 = nn.ConvTranspose2d(128, 64, 3, stride = 2, padding=0)
        self.deconv2 = nn.ConvTranspose2d(64, 32, 3, stride = 2, padding=1)
        self.deconv3 = nn.ConvTranspose2d(32, 3, 3, stride = 2, padding=1)
        self.trim = Trim()
    
    def reparameterized(self, mean, var):
        eps = torch.randn(mean.size(0), mean.size(1)).to(device)
        z = mean + eps * torch.exp(var / 2.)
        return z

    def encode(self, x): # Using silu instead of relu here
        x = F.silu(self.conv1(x))
        x = F.silu(self.conv2(x))
        x = self.conv3(x)
        x = self.flatten(x)
        mean = self.linear_mean(x)
        var = self.linear_logvar(x)
        z = self.reparameterized(mean, var)
        return mean, var, z

    def decode(self, z):
        z = self.linear(z)
        z = self.reshape(z)
        z = F.silu(self.deconv1(z))
        z = F.silu(self.deconv2(z))
        z = F.silu(self.deconv3(z))
        z = self.trim(z)
        z = F.sigmoid(z)
        return z
    
    def forward(self, x):
        mean, var, z = self.encode(x)
        z = self.decode(z)
        return mean, var, z

model = VAE().to(device)
model2 = VAE().to(device)
# print(X.shape)

x = torch.randn(1, 3, 32, 32).to(device)
model(x)
summary(model, (3, 32, 32))

In [None]:
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr = learning_rate)
KL_weight = 0.000075 # The KL divergence is much larger than MAE loss, scaling it down

train_losses_total, train_losses_total_avg = [], []
train_losses_reconstruction, train_losses_reconstruction_avg = [], []
train_losses_KL, train_losses_KL_avg = [], []
for epoch in range(epochs):
    train_loss_total = 0.0
    train_loss_reconstruction = 0.0
    train_loss_KL = 0.0
    for images in tqdm(image_loader):
        images = images.to(device)
        mean, var, outputs = model(images)

        optimizer.zero_grad()
        loss1 = criterion(outputs, images)
        loss2 = torch.mean(-0.5 * torch.sum(1 + var - mean**2 - torch.exp(var),axis=1),axis=0) # sum over latent dimension, not batch
        # -0.5 * torch.sum(1 + var - mean**2 - torch.exp(var), axis = 1).mean()
        # loss2 = torch.atan(loss2) / (np.pi / 2) # scaling it to [0,1]        
        loss = loss1  + KL_weight * loss2
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            train_losses_reconstruction.append(loss1.item())
            train_losses_KL.append(loss2.item())
            train_losses_total.append(loss.item())
            train_loss_reconstruction += loss1.item() * images.size(0)
            train_loss_KL += loss2.item() * images.size(0)
            train_loss_total += loss.item() * images.size(0)
    
    train_loss_total /= len(image_loader)
    train_loss_reconstruction /= len(image_loader)
    train_loss_KL /= len(image_loader)
    
    train_losses_total_avg.append(train_loss_total)
    train_losses_reconstruction_avg.append(train_loss_reconstruction)
    train_losses_KL_avg.append(train_loss_KL)
    
    print('-----------------------------------------------------')
    print(f'Epoch{epoch + 1}')
    print(f'Total loss = {train_loss_total:.3f}')
    print(f'Reconstruction loss = {train_loss_reconstruction:.3f}, KL = {train_loss_KL:.3f}')

    log = {
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict()
    }

    torch.save(log, f'model1_log_{epoch + 1}.pth')

In [None]:
fig, axs = plt.subplots(3, 1, figsize = (6, 15))

train_losses_reconstruction_avg = [np.mean(train_losses_reconstruction[i:i+25]) for i in range(3750)]
axs[0].plot(np.arange(epochs * 1875), train_losses_reconstruction, color = 'b', label = 'mini-batches')
axs[0].plot(np.arange(0, epochs * 1875, 25), train_losses_reconstruction_avg, color = 'r', label = 'running average')
axs[0].set_ylim(0,0.1)
axs[0].set_title('Reconstruction losses')

train_losses_KL_avg = [np.mean(train_losses_KL[i:i+25]) for i in range(3750)]
axs[1].plot(np.arange(epochs * 1875), train_losses_KL, color = 'b', label = 'mini-batches')
axs[1].plot(np.arange(0, epochs * 1875, 25), train_losses_KL_avg, color = 'r', label = 'running average')
axs[1].set_title('KL losses')

train_losses_total_avg = [np.mean(train_losses_total[i:i+20]) for i in range(3750)]
axs[2].plot(np.arange(epochs * 1875), train_losses_total, color = 'b', label = 'mini-batches')
axs[2].plot(np.arange(0, epochs * 1875, 25), train_losses_total_avg, color = 'r', label = 'running average')
axs[2].set_ylim(0,0.1)
axs[2].set_title('Scaled losses')

plt.show()

In [None]:
fig, axs = plt.subplots(2, 6, figsize = (20, 5))
for i in range(6):
    axs[0][i].imshow(X[i+100,:,:,:].permute(1,2,0))
_, _, new_fig = model(X[100:106, :, :, :].to(device))
for i in range(6):
    axs[1][i].imshow(new_fig[i].permute(1,2,0).cpu().detach().numpy())

In [None]:
# Generating 60 images using model

fig, axs = plt.subplots(10,6, figsize = (50, 30))

names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck']
sixty_images_latent = torch.randn(60, 100).to(device)
sixty_images = model.decode(sixty_images_latent)
print(sixty_images.shape)
sixty_images = sixty_images.permute(0, 3, 2, 1).cpu().detach().numpy()

for i in range(10):
    for j in range(6):
        idx = (6) * (i-1) + (j)
        axs[i][j].imshow(sixty_images[idx])