In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader

# import SummaryWriter
from torch.utils.tensorboard import SummaryWriter

# import the MNIST dataset
import torchvision
import torchvision.datasets as datasets
from torchvision import transforms

import torchsummary

from sklearn.model_selection import train_test_split

# from torchdiffeq import odeint, odeint_adjoint

import cv2
import numpy as np

from matplotlib.animation import FuncAnimation
import matplotlib.pyplot as plt

from tqdm.notebook import trange

# use the ball dataset
from utils import create_gaussian_dataset, add_spatial_encoding, stack_dataset

In [3]:
# Synthetic single-ball dataset (fixed radius 3.0) with spatial-encoding channels added.
# Seeds are fixed so the train/test split and DataLoader shuffling are reproducible on
# Restart & Run All; seeding numpy/torch also covers the project helpers *if* they draw
# from those global RNGs (NOTE(review): confirm in utils).
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

dataset = create_gaussian_dataset(r_min=3.0, r_max=3.0, n_samples=50000, size=28, margin=2, n_balls=1)
dataset = add_spatial_encoding(dataset)
# dataset = stack_dataset(dataset)

train_dataset, test_dataset = train_test_split(dataset, test_size=0.3, random_state=SEED)
assert len(train_dataset) + len(test_dataset) == len(dataset)

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
# Shuffled on purpose: VAE.reconstruction plots the first test batch each epoch.
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=True)

In [2]:
# Alternative data source: MNIST digits passed through the project's `stack_dataset`.
# This cell rebinds the same names as the ball-dataset cell above, so running it
# unconditionally (e.g. Restart & Run All) would silently replace the ball data used
# for training below. It is therefore opt-in.
USE_MNIST = False  # set to True to train on MNIST instead of the ball dataset

if USE_MNIST:
    # Download the MNIST dataset if it is not present
    mnist_train = datasets.MNIST('./data', train=True, download=True, transform=transforms.ToTensor())
    mnist_test = datasets.MNIST('./data', train=False, download=True, transform=transforms.ToTensor())

    train_dataset = stack_dataset(mnist_train)
    test_dataset = stack_dataset(mnist_test)

    train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=True)


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


In [4]:
# Shared sizes consumed by the encoder/decoder classes defined below.
dec_in_channels = 1                              # channel count used to size the decoder input
n_latent = 7                                     # dimensionality of the latent vector z
reshaped_dim = [-1, 7, 7, dec_in_channels]       # NOTE(review): not referenced by the classes below
inputs_decoder = (7 * 7 * dec_in_channels) // 2  # width of the decoder's first hidden layer
print(inputs_decoder)

class encoder(nn.Module):
    """ResNet-18 feature extractor followed by Gaussian posterior heads.

    Produces a reparameterized latent sample ``z`` together with the posterior
    mean ``mn`` and the log standard deviation ``sd`` of q(z | x).

    Args:
        keep_prob: Accepted for interface compatibility with ``VAE``; it is
            currently unused (no dropout is applied).
    """

    def __init__(self, keep_prob):
        super(encoder, self).__init__()
        # Pretrained ResNet-18 with its final classification layer removed; the
        # pooled 512-d feature map is flattened to a (B, 512) vector.
        # Note: pretrained=True downloads ImageNet weights on first use.
        self.encoder_model = nn.Sequential(*list(torchvision.models.resnet18(pretrained=True).children())[:-1],
                                           nn.Flatten())
        self.mean = nn.Linear(512, n_latent)
        self.sd = nn.Linear(512, n_latent)

    def forward(self, X_in):
        """Encode ``X_in`` and draw one latent sample per item.

        Args:
            X_in: Image batch accepted by ResNet-18 — presumably (B, 3, H, W); TODO confirm.

        Returns:
            ``(z, mn, sd)`` where ``z = mn + eps * exp(sd)`` with eps ~ N(0, I).
        """
        latent = self.encoder_model(X_in)
        mn = self.mean(latent)
        # `sd` is used as a log standard deviation (z scales eps by exp(sd)), so the
        # raw head output, halved here, plays the role of a log-variance.
        sd = self.sd(latent) * 0.5
        # randn_like puts the noise on mn's device and dtype; torch.randn(shape)
        # always allocated on the default (CPU) device and broke GPU training.
        epsilon = torch.randn_like(mn)
        z = mn + epsilon * torch.exp(sd)
        return z, mn, sd

class decoder(nn.Module):
    """Map latent vectors (B, n_latent) to 3-channel 28x28 images with values in (0, 1).

    Two LeakyReLU MLP layers expand ``z`` into a ``dec_in_channels x 7 x 7`` seed
    map. Two stride-2 transposed convolutions upsample it 7 -> 14 -> 28, and a
    stride-1 transposed convolution projects it to 3 channels. A dense layer with
    a sigmoid then produces the output pixels.
    """

    def __init__(self):
        super(decoder, self).__init__()
        # Channels of the 7x7 seed map. These used to be hard-coded as 1 (and the
        # seed size as `inputs_decoder * 2 + 1`, which is 49 only when
        # dec_in_channels == 1). Deriving them from the config keeps the layers
        # consistent; the parameter layout is unchanged for the default config.
        self.in_channels = dec_in_channels
        seed_size = 7 * 7 * self.in_channels
        self.linear = nn.Sequential(
            nn.Linear(n_latent, inputs_decoder),
            nn.LeakyReLU(0.3),
            nn.Linear(inputs_decoder, seed_size),
            nn.LeakyReLU(0.3))
        self.decoder_model = nn.Sequential(
            nn.ConvTranspose2d(self.in_channels, 64, 4, 2, 1),  # 7x7 -> 14x14
            nn.ReLU(),
            nn.ConvTranspose2d(64, 64, 4, 2, 1),                 # 14x14 -> 28x28
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 3, 1, 1),                  # 28x28, 64 -> 3 channels
            nn.ReLU(),
            nn.Flatten()
        )
        self.last_linear = nn.Sequential(
            nn.Linear(3*28*28, 3*28*28),
            nn.Sigmoid()
        )

    def forward(self, sampled_z):
        """Decode ``sampled_z`` of shape (B, n_latent) into images of shape (B, 3, 28, 28)."""
        decoded = self.linear(sampled_z)
        decoded = decoded.view(-1, self.in_channels, 7, 7)
        img = self.decoder_model(decoded)
        img = self.last_linear(img)
        return img.view(-1, 3, 28, 28)


class VAE(nn.Module):
    """Variational autoencoder composed of ``encoder`` and ``decoder``.

    Also provides notebook helpers: training (``fit``), sampling (``generate``)
    and plotting (``plot``, ``reconstruction``).

    Args:
        keep_prob: Forwarded unchanged to ``encoder``.
    """

    def __init__(self, keep_prob):
        super(VAE, self).__init__()
        self.encoder = encoder(keep_prob)
        self.decoder = decoder()

    def forward(self, X_in):
        """Encode, sample and decode ``X_in``.

        Returns:
            ``(img, mn, sd)``: the reconstruction plus the encoder's posterior
            mean and log standard deviation.
        """
        z, mn, sd = self.encoder(X_in)
        img = self.decoder(z)
        return img, mn, sd

    def _device(self):
        # Device holding the model's parameters (CPU if it has none). Batches and
        # sampled noise are moved here so the helpers also work after `.to("cuda")`.
        param = next(self.parameters(), None)
        return param.device if param is not None else torch.device("cpu")

    def fit(self, dataloader_train, dataloader_test, optimizer, scheduler, criterion, epochs=10, display_step=1, n_plot=10):
        """Train for ``epochs`` epochs, evaluating on ``dataloader_test`` after each one.

        At the start of every epoch, ``n_plot`` generated samples and one test-batch
        reconstruction are plotted. ``scheduler.step()`` is called once per epoch.

        Args:
            dataloader_train, dataloader_test: Iterables of ``(image, label)``
                batches with a length; labels are ignored.
            optimizer: Torch optimizer over this model's parameters.
            scheduler: LR scheduler stepped once per epoch.
            criterion: Callable ``criterion(x, reconstruction, mn, sd)`` returning a scalar loss.
            epochs: Number of training epochs.
            display_step: Refresh the progress-bar text every this many batches.
            n_plot: Number of generated samples plotted per epoch.

        Returns:
            ``(losses_train, losses_test)``: per-epoch mean batch losses.
        """
        iterator = trange(epochs)
        losses_train = []
        losses_test = []
        for _ in iterator:
            self.plot(n_samples=n_plot)
            self.reconstruction(dataloader_test)

            losses_train.append(self._train_epoch(dataloader_train, optimizer, criterion, iterator, display_step))
            losses_test.append(self._evaluate(dataloader_test, criterion, iterator, display_step))

            scheduler.step()

        return losses_train, losses_test

    def _train_epoch(self, dataloader, optimizer, criterion, iterator, display_step):
        # One optimisation pass over `dataloader`; returns the mean batch loss.
        self.train()
        device = self._device()
        loss_sum = 0.0
        for i, (input_image, _) in enumerate(dataloader):
            input_image = input_image.float().to(device)
            optimizer.zero_grad()
            img, mn, sd = self(input_image)
            loss = criterion(input_image, img, mn, sd)
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
            if i % display_step == 0:
                iterator.set_description(f'Batch: {i}/{len(dataloader)}, Loss: {loss_sum/(i+1):.6f}')
        return loss_sum / max(len(dataloader), 1)

    def _evaluate(self, dataloader, criterion, iterator, display_step):
        # Gradient-free pass over `dataloader` in eval mode; returns the mean batch loss.
        self.eval()
        device = self._device()
        loss_sum = 0.0
        with torch.no_grad():
            for i, (input_image, _) in enumerate(dataloader):
                input_image = input_image.float().to(device)
                img, mn, sd = self(input_image)
                loss_sum += criterion(input_image, img, mn, sd).item()
                if i % display_step == 0:
                    iterator.set_postfix_str(f'Test Batch: {i}/{len(dataloader)}, Loss: {loss_sum/(i+1):.6f}')
        return loss_sum / max(len(dataloader), 1)

    def generate(self, n_samples=1):
        """Decode ``n_samples`` latent vectors drawn from N(0, I); leaves the model in eval mode."""
        self.eval()
        with torch.no_grad():
            z = torch.randn(n_samples, n_latent, device=self._device())
            return self.decoder(z)

    def plot(self, n_samples=1):
        """Show channel 0 of ``n_samples`` generated images side by side."""
        if n_samples <= 0:
            return
        imgs = self.generate(n_samples)  # generate() switches to eval mode itself
        imgs = imgs[:, 0].cpu().detach().numpy()
        imgs = np.reshape(imgs, (n_samples, 28, 28))
        # One axes per image (squeeze=False keeps a 2-D grid even for a single image).
        # Mixing plt.subplots() with plt.subplot() left an extra empty, ticked axes
        # behind the images.
        fig, axes = plt.subplots(1, n_samples, figsize=(10, 5), squeeze=False)
        for ax, image in zip(axes[0], imgs):
            ax.imshow(image, cmap='gray')
            ax.axis('off')
        plt.show()
        plt.close(fig)

    def reconstruction(self, plot_loader):
        """Plot channel 0 of the first batch of ``plot_loader``.

        Reconstructions are shown in the top row and the inputs in the bottom row.
        """
        self.eval()
        batch = next(iter(plot_loader), None)
        if batch is None:
            return
        input_image, _ = batch
        input_image = input_image.float().to(self._device())
        batch_size = input_image.shape[0]
        with torch.no_grad():  # plotting only: no autograd graph needed
            img, _, _ = self(input_image)
        recon = np.reshape(img[:, 0].cpu().numpy(), (batch_size, 28, 28))
        originals = np.reshape(input_image[:, 0].cpu().numpy(), (batch_size, 28, 28))

        fig, axes = plt.subplots(2, batch_size, figsize=(20, 2), squeeze=False)
        for col in range(batch_size):
            for row, images in enumerate((recon, originals)):
                axes[row, col].imshow(images[col], cmap='gray')
                axes[row, col].axis('off')
        plt.show()
        plt.close(fig)
                

class custom_loss(nn.Module):
    """Negative ELBO for a VAE with a diagonal-Gaussian posterior.

    Per sample:
        sum over all pixels/channels of (dec - x)^2                  (reconstruction)
      + -0.5 * sum over latent dims of (1 + 2*sd - mu^2 - exp(2*sd))  (KL to N(0, I))
    averaged over the batch. ``sd`` is a log standard deviation, so ``2*sd`` is
    the log-variance.
    """

    def __init__(self):
        super(custom_loss, self).__init__()

    def forward(self, x, dec, mu, sd):
        """Compute the batch-mean loss.

        Args:
            x: Target images, (B, ...).
            dec: Reconstructions with the same number of elements per sample as ``x``.
            mu: Posterior mean, (B, n_latent).
            sd: Posterior log standard deviation, (B, n_latent).

        Returns:
            Scalar loss tensor.
        """
        # Flatten per *sample*. The previous reshape to [-1, 28*28] split each
        # multi-channel image into one row per channel, so the batch mean was really
        # a per-channel mean (and the code only worked for 28x28 inputs).
        x_flat = x.flatten(start_dim=1)
        dec_flat = dec.flatten(start_dim=1)
        img_loss = torch.sum((dec_flat - x_flat) ** 2, dim=1)
        # The KL divergence of a diagonal Gaussian from N(0, I) is a SUM over latent
        # dimensions; averaging it (as before) silently scaled it by 1/n_latent.
        latent_loss = -0.5 * torch.sum(1 + 2 * sd - mu ** 2 - torch.exp(2 * sd), dim=1)
        return torch.mean(img_loss + latent_loss)


loss_fn = custom_loss()

24


In [5]:
# Seed torch so the freshly initialised layers (encoder heads and decoder) are
# reproducible across Restart & Run All.
torch.manual_seed(0)
# NOTE(review): keep_prob is forwarded to the encoder, which does not use it.
vae = VAE(keep_prob=0.2)

In [6]:
# test_input = torch.zeros(3, 3, 28, 28)
# print(len(vae(test_input)), vae(test_input)[0].shape, vae(test_input)[1].shape, vae(test_input)[2].shape)  


In [7]:
# Adam over every VAE parameter, including the pretrained ResNet-18 backbone.
optimizer = optim.Adam(vae.parameters(), lr=1e-2)
# Alternative tried earlier: optim.RMSprop(vae.parameters(), lr=1e-5)
# fit() calls scheduler.step() once per epoch, so gamma=0.9999 shrinks the lr by
# only about 1% over 100 epochs (0.9999 ** 100 ~= 0.990).
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.9999)

In [1]:
# Train and keep the per-epoch mean train/test losses that fit() returns
# (previously discarded, so the loss curves were lost after training).
losses_train, losses_test = vae.fit(train_loader, test_loader, optimizer, scheduler, loss_fn,
                                    epochs=100, display_step=10, n_plot=2)

NameError: name 'vae' is not defined

In [71]:
# torch.save(vae.state_dict(), 'models/VAE/vae_resnet18_1_Ball_latent_4.pth')