In [17]:
# -*- coding: utf-8 -*-
"""
Created on Tue Apr 16 19:05:29 2019

@author: karm2204
"""

"""
References:
"""
#%%

# https://towardsdatascience.com/model-summary-in-pytorch-b5a1e4b64d25
# https://github.com/pytorch/examples/blob/master/mnist/main.py
# https://discuss.pytorch.org/t/text-autoencoder-nan-loss-after-first-batch/22730
# https://discuss.pytorch.org/t/understanding-output-padding-cnn-autoencoder-input-output-not-the-same/22743
# https://github.com/rtqichen/beta-tcvae/blob/master/vae_quant.py
# https://github.com/pytorch/examples/blob/master/vae/main.py
# https://www.groundai.com/project/isolating-sources-of-disentanglement-in-variational-autoencoders/
# https://arogozhnikov.github.io/einops/pytorch-examples.html

# Note   :    https://www.cs.toronto.edu/~lczhang/360/lec/w03/convnet.html
#%%

import argparse
import torch
import torch.nn as nn 
from torch import cuda
import torch.utils.dataF
from torch import optim, autograd
from torch.nn import functional as F
import torchvision
from torch.utils.data import dataset
from torchvision import datasets, transforms
from torchvision.utils import save_image
import numpy as np
import matplotlib.pyplot as plt
from GAN_q3 import Generator


batch_size = 64

#%%    
image_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((.5, .5, .5),
                         (.5, .5, .5))
])
    
def get_data_loader(dataset_location, batch_size):
    trainvalid = torchvision.datasets.SVHN(
        dataset_location, split='train',
        download=True,
        transform=image_transform
    )

    trainset_size = int(len(trainvalid) * 0.9)
    trainset, validset = dataset.random_split(
        trainvalid,
        [trainset_size, len(trainvalid) - trainset_size]
    )

    train = torch.utils.data.DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2
    )

    valid = torch.utils.data.DataLoader(
        validset,
        batch_size=batch_size,
    )

    test = torch.utils.data.DataLoader(
        torchvision.datasets.SVHN(
            dataset_location, split='test',
            download=True,
            transform=image_transform
        ),
        batch_size=batch_size,
    )

    return train, valid, test

#%%

def get_data_loader_1(dataset_location, batch_size):
    trainvalid = torchvision.datasets.SVHN(
        dataset_location, split='train',
        download=True,
        transform=image_transform
    )

    trainset_size = int(len(trainvalid) * 0.9)
    trainset, validset = dataset.random_split(
        trainvalid,
        [trainset_size, len(trainvalid) - trainset_size]
    )

    train = torch.utils.data.DataLoader(
        trainset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2
    )

    valid = torch.utils.data.DataLoader(
        validset,
        batch_size=batch_size,
    )

    test = torch.utils.data.DataLoader(
        torchvision.datasets.SVHN(
            dataset_location, split='test',
            download=True,
            transform=image_transform_1
        ),
        batch_size=batch_size,
    )

    return train, valid, test

image_transform_1 = transforms.Compose([
    transforms.ToTensor()
])

#%%
def imshow(img):
    img = 0.5*(img + 1)
    npimg = img.numpy()
    # npimg = (255*npimg).astype(np.uint8) # to be a int in (0,...,255)
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()                 
                    
#%% 
class View(nn.Module):
    def __init__(self, shape, *shape_):
        super().__init__()
        if isinstance(shape, list):
            self.shape = shape
        else:
            self.shape = (shape,) + shape_      
            
def forward(self, x):
        return x.view(self.shape)
                    
#%%                
# https://discuss.pytorch.org/t/text-autoencoder-nan-loss-after-first-batch/22730
                    
#%% 
class convencoder(nn.Module):
    def __init__(self, latent_dim=100):
        super(convencoder, self).__init__()

        self.latent_dim = latent_dim

        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3),
            nn.ELU(),
            nn.AvgPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3),
            nn.ELU(),
            nn.AvgPool2d(kernel_size = 2, stride = 2),
            nn.Conv2d(64, 256, kernel_size=6),
            nn.ELU()
        )

        self.final = nn.Linear(256, self.latent_dim * 2)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(x.size(0), -1)
        x = self.final(x)
        # Return mu and logvar
        return x[..., self.latent_dim:], x[..., :self.latent_dim]         
  
#%% 
class convdecoder(Generator): # call generator from GAN

    def forward(self, x):
        x = x.view(x.size(0), -1, 1, 1)
        x = self.main(x)
        return x
#%%
class VAE(nn.Module):
    def __init__(self):
        super(VAE, self).__init__()

        self.encoder = convencoder()
        self.decoder = convdecoder()

    def forward(self, x):
        mu, logvar = self.encoder(x)
#        mu, logvar = convencoder[:, :self.latent_dim], convencoder[:, self.latent_dim:]
        std = torch.exp(logvar / 2)
        z = mu + std * torch.randn_like(std)
        return self.decoder(z), mu, logvar

def ELBO(output, target, mu, logvar):
    elbo = -torch.nn.functional.mse_loss(output, target, reduction='sum')
    elbo += 0.5 * torch.sum(1 + logvar - mu.pow(2) - torch.exp(logvar))
    return elbo / output.size(0)

# https://www.groundai.com/project/isolating-sources-of-disentanglement-in-variational-autoencoders/
#%%
def visual_samples(vae, dimensions, device, svhn_loader):
    z = torch.randn(64, dimensions, device = device)
    generated = vae.decoder(z)
    torchvision.utils.save_image(generated, 'images/vae/3.1vae-generated.png', normalize=False)
    
#%%
def disentangled_representation(vae, dimensions, device, epsilon = 3):
    z = torch.randn(dimensions, device = device)
    z = z.repeat(dimensions+1, 1)
    for i, sample in enumerate(z[1:]):
        sample[i] += epsilon

    generated = vae.decoder(z)
    torchvision.utils.save_image(generated, 'images/vae/3_2positive_eps.png', normalize=False)
    epsilon = -2*epsilon
    for i, sample in enumerate(z[1:]):
        sample[i] += epsilon

    generated = vae.decoder(z)
    torchvision.utils.save_image(generated, 'images/vae/3_2negative_eps.png', normalize=False)

#%%
    
def interpolation(vae, dimensions, device):
    # Interpolate in the latent space between z_0 and z_1
    z_0 = torch.randn(1,dimensions, device=device)
    z_1 = torch.randn(1,dimensions, device=device)
    z_a = torch.zeros([11,dimensions], device=device)

    for i in range(11):
        a = i/10
        z_a[i] = a*z_0 + (1-a)*z_1

    generated = vae.decoder(z_a)
    torchvision.utils.save_image(generated, 'images/vae/3_3latent.png', normalize=False)
    
    # Interpolate in the data space between x_0 and x_1
    x_0 = vae.decoder(z_0)
    x_1 = vae.decoder(z_1)
    x_a = torch.zeros(11,x_0.size()[1],x_0.size()[2],x_0.size()[3], device = device)

    for i in range(11):
        a = i/10
        x_a[i] = a*x_0 + (1-a)*x_1

    torchvision.utils.save_image(x_a, 'images/vae/3_3data.png', normalize=False)   
    
    
#%%
def save_images(img_dir: str):
    import os
    vae = VAE()
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    vae.load_state_dict(torch.load('VAE_q3_save.pth', map_location=device))
    vae = vae.to(device)
    vae.eval()
    
    for p in vae.parameters():
        p.requires_grad = False
        os.makedirs(f"{img_dir}/img/", exist_ok=True)
    for i in range(10):
        print(i)
        latents = torch.randn(100, 100, device=device)
        images = vae.decoder(latents)
        for j, image in enumerate(images):
            filename = f"images/vae/fid/img/{i * 100 + j:03d}.png"
            torchvision.utils.save_image(image, filename, normalize=True)
            
#%%        
if __name__ == "__main__":
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    print("Let's use {}".format(device))
    vae = VAE()
    vae = vae.to(device)
    running_loss = 0
    optimizer = optim.Adam(vae.parameters(), lr=3e-4)
    train, valid, test = get_data_loader("svhn", batch_size = 64)
    try: 
        vae.load_state_dict(torch.load('VAE_q3_save.pth', map_location=device))
        print('----Using saved model----')
    except FileNotFoundError:
        for epoch in range(20):
            print(f"------- EPOCH {epoch} --------")
            for i, (x, _) in enumerate(train):
                vae.train()
                optimizer.zero_grad()
                x = x.to(device)
                y, mu, logvar = vae(x)
                loss = -ELBO(y, x, mu, logvar)
                running_loss += loss
                loss.backward()
                optimizer.step()
                if(i%10 == 0):
                    visual_samples(vae, 100, device, test)

                if (i + 1) % 100 == 0:
                    print(f"Training example {i + 1} / {len(train)}. Loss: {running_loss}")
                    running_loss = 0

        torch.save(vae.state_dict(), 'VAE_q3_save.pth')

    dimensions = 100
    
    
    visual_samples(vae, dimensions, device, test)
    disentangled_representation(vae, dimensions, device, epsilon=10)
    interpolation(vae, dimensions, device)
    
    img_dir = "images/vae/fid"
    save_images(img_dir)


Let's use cuda:0
Using downloaded and verified file: svhn/train_32x32.mat
Using downloaded and verified file: svhn/test_32x32.mat
------- EPOCH 0 --------
Training example 100 / 1031. Loss: 38834.5625
Training example 200 / 1031. Loss: 17966.80078125
Training example 300 / 1031. Loss: 15262.626953125
Training example 400 / 1031. Loss: 13869.8291015625
Training example 500 / 1031. Loss: 13133.416015625
Training example 600 / 1031. Loss: 12274.4892578125
Training example 700 / 1031. Loss: 12025.474609375
Training example 800 / 1031. Loss: 11558.083984375
Training example 900 / 1031. Loss: 11418.11328125
Training example 1000 / 1031. Loss: 11212.109375
------- EPOCH 1 --------
Training example 100 / 1031. Loss: 14517.662109375
Training example 200 / 1031. Loss: 10933.8408203125
Training example 300 / 1031. Loss: 10979.2841796875
Training example 400 / 1031. Loss: 10837.2236328125
Training example 500 / 1031. Loss: 10604.6171875
Training example 600 / 1031. Loss: 10609.8515625
Training exa

Training example 400 / 1031. Loss: 8472.234375
Training example 500 / 1031. Loss: 8408.7021484375
Training example 600 / 1031. Loss: 8401.4755859375
Training example 700 / 1031. Loss: 8391.625
Training example 800 / 1031. Loss: 8531.9501953125
Training example 900 / 1031. Loss: 8502.255859375
Training example 1000 / 1031. Loss: 8483.0361328125
------- EPOCH 16 --------
Training example 100 / 1031. Loss: 10969.6826171875
Training example 200 / 1031. Loss: 8355.63671875
Training example 300 / 1031. Loss: 8398.3671875
Training example 400 / 1031. Loss: 8413.916015625
Training example 500 / 1031. Loss: 8413.9521484375
Training example 600 / 1031. Loss: 8429.25390625
Training example 700 / 1031. Loss: 8469.0859375
Training example 800 / 1031. Loss: 8436.5732421875
Training example 900 / 1031. Loss: 8329.845703125
Training example 1000 / 1031. Loss: 8384.6884765625
------- EPOCH 17 --------
Training example 100 / 1031. Loss: 11096.28515625
Training example 200 / 1031. Loss: 8315.9072265625
T