In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from torch.utils.data import Dataset, DataLoader


import numpy as np
from typing import List, Tuple
from collections import OrderedDict

from tqdm import tqdm

In [None]:
# Sanity check: whether PyTorch can see a CUDA device (displayed as the cell output).
torch.cuda.is_available()

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# The first part of this notebook will be dedicated to the basic VAE architecture

# The second part will be dedicated to the VAE with skip-connections

# Both models will be tested by passing a torch.randn tensor of shape [64, 3, 64, 64]

# General VAEs

## Model

In [None]:
class Conv2dBlock(nn.Module):
    """Encoder building block: Conv2d -> MaxPool2d -> BatchNorm2d -> ReLU.

    The max-pooling layer is built with exactly the same kernel size, stride,
    padding and dilation as the convolution that precedes it.

    Args:
        conv_filters_in: number of input channels.
        conv_filters_out: number of output channels.
        conv_kernels, conv_strides, paddings, dilations: geometry shared by the
            convolution and the pooling layer (int or (h, w) pair).
    """

    def __init__(self,
                conv_filters_in: int,
                conv_filters_out: int,
                conv_kernels: Tuple[int],
                conv_strides: Tuple[int],
                paddings: Tuple[int],
                dilations: Tuple[int],
                **kwargs):

        super(Conv2dBlock, self).__init__(**kwargs)

        # Hyper-parameters are kept as public attributes for later inspection.
        self.conv_filters_in, self.conv_filters_out = conv_filters_in, conv_filters_out
        self.conv_kernels, self.conv_strides = conv_kernels, conv_strides
        self.paddings, self.dilations = paddings, dilations

        geometry = dict(kernel_size=conv_kernels,
                        stride=conv_strides,
                        padding=paddings,
                        dilation=dilations)

        # Sub-modules are registered conv -> pool -> norm; this fixes the order of
        # entries in parameters() / state_dict().
        self.conv2d = nn.Conv2d(in_channels=conv_filters_in,
                                out_channels=conv_filters_out,
                                **geometry)
        self.maxpool2d = nn.MaxPool2d(**geometry)
        self.batchnorm = nn.BatchNorm2d(num_features=conv_filters_out)

    def forward(self, x):
        """Run ``x`` through conv, max-pool and batch-norm, then apply ReLU."""
        for stage in (self.conv2d, self.maxpool2d, self.batchnorm):
            x = stage(x)
        return F.relu(x)



In [None]:
class ConvTranspose2dBlock(nn.Module):
    """Decoder building block: ConvTranspose2d -> bilinear upsampling -> BatchNorm2d -> ReLU.

    The bilinear upsampling factor equals ``conv_strides`` per axis, so the block
    first transposes-convolves with that stride and then upsamples by it again.

    Args:
        conv_filters_in: number of input channels.
        conv_filters_out: number of output channels.
        conv_kernels, conv_strides, paddings, output_paddings, dilations:
            forwarded to ``nn.ConvTranspose2d``; each may be an int or an (h, w) pair.
    """

    def __init__(self, 
                conv_filters_in: int,
                conv_filters_out: int,
                conv_kernels: Tuple[int],
                conv_strides: Tuple[int],
                paddings: Tuple[int],
                output_paddings: Tuple[int],
                dilations: Tuple[int],
                **kwargs):

        super(ConvTranspose2dBlock, self).__init__(**kwargs)

        self.conv_filters_in = conv_filters_in
        self.conv_filters_out = conv_filters_out
        self.conv_kernels = conv_kernels
        self.conv_strides = conv_strides
        self.paddings = paddings
        self.output_paddings = output_paddings
        self.dilations = dilations

        self.convtranspose2d = nn.ConvTranspose2d(in_channels=self.conv_filters_in, 
                                                out_channels=self.conv_filters_out, 
                                                kernel_size=self.conv_kernels,
                                                stride=self.conv_strides,
                                                padding=self.paddings,
                                                output_padding=self.output_paddings,
                                                dilation=self.dilations)

        # Upsample each spatial axis by its own stride component; a bare int
        # stride is treated as (s, s). Using only conv_strides[0] would scale the
        # width by the height stride and fail outright for int strides.
        stride_hw = (conv_strides, conv_strides) if isinstance(conv_strides, int) else tuple(conv_strides)
        # nn.Upsample(mode="bilinear", align_corners=True) is the non-deprecated
        # equivalent of nn.UpsamplingBilinear2d (no parameters or buffers either way).
        self.upsampling = nn.Upsample(scale_factor=tuple(float(s) for s in stride_hw),
                                      mode="bilinear",
                                      align_corners=True)

        self.batchnorm = nn.BatchNorm2d(num_features=self.conv_filters_out)

    def forward(self, x):
        """Transposed conv, upsample, batch-norm, then ReLU."""
        x = self.convtranspose2d(x)
        x = self.upsampling(x)
        x = self.batchnorm(x)
        return F.relu(x)


        

In [None]:
class Encoder(nn.Module):
    """Convolutional VAE encoder.

    A stack of ``Conv2dBlock``s is followed by two linear heads producing the
    latent mean ``mu`` and ``log_sigma``. ``forward`` returns
    ``(z, (mu, log_sigma))`` with ``z = mu + exp(log_sigma / 2) * eps`` and
    ``eps ~ N(0, I)``. Because the standard deviation is ``exp(log_sigma / 2)``,
    ``log_sigma`` plays the role of a log-variance despite its name.

    Args:
        input_shape: shape of a single input image without the batch dim,
            e.g. ``[3, 64, 64]``.
        conv_filters: channel counts; must have exactly one more element than the
            per-layer lists (the first entry is the input channel count, e.g. 3 for RGB).
        conv_kernels, conv_strides, paddings, dilations: one entry per conv block.
        latent_space_dim: size of the latent vector.

    Attributes:
        shape_before_bottleneck: ``torch.Size`` of the last conv feature map for a
            single image, including a leading batch dim of 1.
        flatten_shape: number of features fed to the linear heads.
    """

    def __init__(self,
                input_shape: List[int],
                conv_filters: List[int],
                conv_kernels: List[Tuple[int]],
                conv_strides: List[Tuple[int]],
                paddings: List[Tuple[int]],
                dilations: List[Tuple[int]],
                latent_space_dim: int,
                **kwargs):

        super(Encoder, self).__init__(**kwargs)

        self.conv_filters = conv_filters
        self.conv_kernels = conv_kernels
        self.conv_strides = conv_strides
        self.paddings = paddings
        self.dilations = dilations
        self.latent_space_dim = latent_space_dim

        # Every per-layer list must describe the same number of blocks, and
        # conv_filters needs one extra (input-channel) entry to chain them.
        assert len(self.conv_kernels) == len(self.conv_strides) == len(self.paddings) == len(self.dilations), \
            "conv_kernels, conv_strides, paddings and dilations must have the same length"
        assert len(self.conv_filters) == len(self.conv_kernels) + 1, \
            "conv_filters must have exactly one more element than conv_kernels"

        self.convblocks = nn.Sequential(
            OrderedDict(
                [
                    (
                        f"Convolution_Block_{i+1}",
                        Conv2dBlock(conv_filters_in=self.conv_filters[i],
                                    conv_filters_out=self.conv_filters[i+1],
                                    conv_kernels=self.conv_kernels[i],
                                    conv_strides=self.conv_strides[i],
                                    paddings=self.paddings[i],
                                    dilations=self.dilations[i]).float()
                    )
                    for i in range(len(self.conv_filters) - 1)
                ]
            )
        )

        # A single, side-effect-free dummy pass yields both the shape and the size.
        probe = self._calculate_shape_before_bottleneck(input_shape)
        self.shape_before_bottleneck = probe.shape
        self.flatten_shape = torch.numel(probe)

        self.flatten = nn.Flatten()

        self.mu = nn.Linear(self.flatten_shape, self.latent_space_dim)

        self.log_sigma = nn.Linear(self.flatten_shape, self.latent_space_dim)

    def _calculate_shape_before_bottleneck(self, input_shape: List[int]) -> torch.Tensor:
        """Push one all-ones image through ``convblocks`` and return the output.

        The probe is created on the same device/dtype as the conv weights (so it
        works before or after the module is moved), runs under ``no_grad`` and in
        eval mode so BatchNorm running statistics are not updated, and the
        previous train/eval mode is restored afterwards.
        """
        reference = next(self.convblocks.parameters(), None)
        probe_device = reference.device if reference is not None else torch.device("cpu")
        probe_dtype = reference.dtype if reference is not None else torch.float32

        x = torch.ones(input_shape, dtype=probe_dtype, device=probe_device)
        x = torch.unsqueeze(x, 0)  # add a batch dim of 1

        was_training = self.convblocks.training
        self.convblocks.eval()
        try:
            with torch.no_grad():
                x = self.convblocks(x)
        finally:
            self.convblocks.train(was_training)

        return x

    def _reparameterized(self, mu, log_sigma):
        """Reparameterisation trick: ``mu + exp(log_sigma / 2) * eps``.

        ``eps`` is drawn with ``randn_like`` so it matches ``mu``'s shape, device
        and dtype.
        """
        eps = torch.randn_like(mu)
        return mu + torch.exp(log_sigma / 2) * eps

    def forward(self, x):
        """Encode a batch ``x`` and return ``(z, (mu, log_sigma))``."""
        x = self.convblocks(x)
        x = self.flatten(x)
        mu = self.mu(x)
        log_sigma = self.log_sigma(x)
        z = self._reparameterized(mu, log_sigma)
        return z, (mu, log_sigma)


In [None]:
class Decoder(nn.Module):
    """Convolutional VAE decoder.

    A dense layer maps a latent vector to ``shape_before_bottleneck`` features,
    which are reshaped into a (C, H, W) map per sample. The map then passes
    through a stack of ``ConvTranspose2dBlock``s, a final transposed convolution
    with ``out_channel`` outputs, and ``tanh``, so outputs lie in [-1, 1].

    Args:
        latent_space_dim: size of the latent vector.
        shape_before_bottleneck: shape of the encoder's last feature map; only its
            last three dims (C, H, W) define the per-sample layout (a leading
            batch dim of 1, if present, is ignored).
        conv_filters: channel counts; must have exactly one more element than the
            per-layer lists. The last entry is the input channel count of the
            output convolution.
        conv_kernels, conv_strides, paddings, output_paddings, dilations: one entry
            per transposed-conv block.
        out_channel: number of channels of the produced image.
    """

    def __init__(self,
                latent_space_dim :int,
                shape_before_bottleneck : torch.Size,
                conv_filters : List[int],
                conv_kernels : List[Tuple[int]],
                conv_strides : List[Tuple[int]],
                paddings : List[Tuple[int]],
                output_paddings : List[Tuple[int]],
                dilations : List[Tuple[int]],
                out_channel : int,
                **kwargs):

        super(Decoder, self).__init__(**kwargs)

        self.conv_filters = conv_filters
        self.conv_kernels = conv_kernels
        self.conv_strides = conv_strides
        self.paddings = paddings
        self.output_paddings = output_paddings
        self.dilations = dilations
        self.out_channel = out_channel
        self.shape_before_bottleneck = shape_before_bottleneck

        # Element count of one feature map; torch.Size.numel avoids allocating a tensor.
        self.flatten_shape = torch.Size(self.shape_before_bottleneck).numel()
        self.fc = nn.Linear(latent_space_dim, self.flatten_shape)

        # dim assertions
        assert len(self.conv_kernels) == len(self.conv_strides) == len(self.paddings) == len(self.output_paddings) == len(self.dilations), \
            "conv_kernels, conv_strides, paddings, output_paddings and dilations must have the same length"
        assert len(self.conv_filters) == len(self.conv_kernels) + 1, \
            "conv_filters must have exactly one more element than conv_kernels"

        self.convtransposes = nn.Sequential(
            OrderedDict(
                [
                    (
                        f"Convolution_Transpose_Block{i+1}",
                        ConvTranspose2dBlock(conv_filters_in=self.conv_filters[i],
                                             conv_filters_out=self.conv_filters[i+1],
                                             conv_kernels=self.conv_kernels[i],
                                             conv_strides=self.conv_strides[i],
                                             paddings=self.paddings[i],
                                             output_paddings=self.output_paddings[i],
                                             dilations=self.dilations[i])
                    )
                    for i in range(len(self.conv_filters) - 1)
                ]
            )
        )

        # Fixed final layer: stride 1, padding 1, dilation 2 with the first kernel
        # size; for a 3x3 kernel this grows each spatial dim by 2.
        self.output_convolution = nn.ConvTranspose2d(in_channels=self.conv_filters[-1],
                                                     out_channels=self.out_channel,
                                                     kernel_size=self.conv_kernels[0],
                                                     stride=(1, 1),
                                                     padding=(1, 1),
                                                     output_padding=(0, 0),
                                                     dilation=(2, 2))

    def forward(self, x):
        """Decode latent vectors ``x`` of shape (batch, latent_space_dim) into images."""
        x = self.fc(x)

        # Reshape to (batch, C, H, W); -1 keeps whatever batch size came in instead
        # of the batch dim stored in shape_before_bottleneck.
        x = x.view(-1, *self.shape_before_bottleneck[-3:])

        x = self.convtransposes(x)

        x = self.output_convolution(x)

        return torch.tanh(x)


In [None]:
class VAE(nn.Module):
    """Convolutional variational auto-encoder.

    Builds an ``Encoder`` from the given layer lists and a ``Decoder`` from the
    same lists reversed (``output_paddings`` is passed through unreversed), with
    3 output channels. ``forward`` returns ``(z, mu, log_sigma, x_prime)``.
    """

    def __init__(self,
                input_shape : List[int],
                conv_filters : List[int],
                conv_kernels : List[Tuple[int]],
                conv_strides : List[Tuple[int]],
                paddings : List[Tuple[int]],
                output_paddings : List[Tuple[int]],
                dilations: List[Tuple[int]],
                latent_space_dim : int,
                **kwargs):

        super(VAE, self).__init__(**kwargs)

        self.input_shape = input_shape
        
        self.latent_space_dim = latent_space_dim
        
        self.encoder = Encoder(input_shape=input_shape,
                                conv_filters=conv_filters,
                                conv_kernels=conv_kernels,
                                conv_strides=conv_strides,
                                paddings=paddings,
                                dilations=dilations,
                                latent_space_dim=latent_space_dim
                                )

        self.shape_before_bottleneck = self.encoder.shape_before_bottleneck

        self.decoder = Decoder(latent_space_dim=latent_space_dim,
                                shape_before_bottleneck=self.encoder.shape_before_bottleneck,
                                conv_filters=conv_filters[::-1],
                                conv_kernels=conv_kernels[::-1],
                                conv_strides=conv_strides[::-1],
                                paddings=paddings[::-1],
                                output_paddings=output_paddings,
                                dilations=dilations[::-1],
                                out_channel=3
                                )

    def forward(self, x):
        """Encode ``x``, sample ``z`` and decode it; returns ``(z, mu, log_sigma, x_prime)``."""
        z, (mu, log_sigma) = self.encoder(x)
        x_prime = self.decoder(z)
        return z, mu, log_sigma, x_prime

    def sample(self, eps=None):
        """Decode latent vectors ``eps``; if ``eps`` is None, draw one from N(0, I)."""
        if eps is None:
            # Match the model's device and dtype (e.g. after .to("cuda") or .double()).
            reference = next(self.parameters())
            eps = torch.randn(1, self.latent_space_dim,
                              device=reference.device, dtype=reference.dtype)
        return self.decoder(eps)

    def reconstruct(self, images):
        """Encode then decode ``images``; returns ``(reconstructed_images, latent_representations)``."""
        # The encoder returns (z, (mu, log_sigma)); only the sampled z goes to the decoder.
        latent_representations, _ = self.encoder(images)
        reconstructed_images = self.decoder(latent_representations)

        return reconstructed_images, latent_representations

    @staticmethod
    def kl_div(mu, log_sigma):
        """Per-sample KL(N(mu, exp(log_sigma)) || N(0, I)), summed over dim 1.

        ``log_sigma`` is treated as a log-variance, consistent with the encoder's
        ``exp(log_sigma / 2)`` standard deviation. Returns one value per row of ``mu``.
        """
        return -0.5 * torch.sum(1 + log_sigma - torch.square(mu) - torch.exp(log_sigma), 1)

    def loss_fn(self, x, x_prime, mu, log_sigma):
        """VAE objective: batch-mean KL divergence plus mean-squared reconstruction error.

        Returns:
            ``(loss, kld, recon)``, all scalar tensors, so ``loss.backward()``
            works for any batch size.
        """
        # NOTE(review): recon is a per-element mean while KL is summed over latent
        # dims; the relative weighting of the two terms may need tuning.
        kld = self.kl_div(mu, log_sigma).mean()
        recon = F.mse_loss(x_prime, x.to(x_prime.dtype))

        loss = kld + recon

        return loss, kld, recon


    
                  

In [None]:
# NOTE(review): this string is never executed; it only records an alternative,
# smaller VAE configuration (16-128 filters) for reference. Unlike the
# configuration that instantiates `vae`, it passes `dilations` as plain ints
# rather than (h, w) tuples.
a ="""
VAE(input_shape=[3, 64, 64],
    conv_filters=[3, 16, 32, 64 , 128],
    conv_kernels=[(5, 5), (3, 3), (3, 3), (3, 3)],
    conv_strides=[(1, 1), (1, 1), (1, 1), (1, 1)],
    paddings=[(1, 1), (1, 1), (1, 1), (1, 1)],
    output_paddings=[(0, 0), (0, 0), (0, 0), (0, 0)],
    dilations=[1, 1, 1, 1],
    latent_space_dim=1024)
"""

In [None]:
# Four conv blocks: a 5x5 kernel first, then three 3x3 kernels; every block uses
# stride 1, padding 1 and dilation 1. conv_filters has one extra leading entry
# for the 3 RGB input channels.
# NOTE(review): by the encoder's shape arithmetic this gives a 256x60x60 =
# 921,600-feature bottleneck, so each of the three dense layers (mu, log_sigma,
# decoder fc) holds ~0.94B weights. Confirm this fits in memory before training.
vae = VAE(
    input_shape=[3, 64, 64],
    conv_filters=[3, 32, 64, 128, 256],
    conv_kernels=[(5, 5)] + [(3, 3)] * 3,
    conv_strides=[(1, 1)] * 4,
    paddings=[(1, 1)] * 4,
    output_paddings=[(0, 0)] * 4,
    dilations=[(1, 1)] * 4,
    latent_space_dim=1024,
)

In [None]:
class TanksDataset(Dataset):
    """Image dataset loaded from a ``.npy`` file and served channels-first.

    The array is assumed to be stored as (N, H, W, C); NOTE(review): confirm
    against how ``alltanks.npy`` was written. The channel axis is moved to
    position 1, so each item is a (C, H, W) array, optionally passed through
    ``transform``.

    Args:
        transform: optional callable applied to each item.
        path: location of the ``.npy`` file (defaults to ``"alltanks.npy"``).
    """

    def __init__(self, transform=None, path="alltanks.npy"):

        images_data = np.load(path)

        # Move the channel axis to position 1 while keeping (H, W) in order.
        # np.swapaxes(..., 3, 1) would yield (N, C, W, H) and transpose every image.
        self.data = np.moveaxis(images_data, 3, 1)

        self.transform = transform

    def __getitem__(self, index):
        item = self.data[index]

        if self.transform:
            return self.transform(item)

        # Without a transform, return the raw array instead of None.
        return item

    def __len__(self):
        return self.data.shape[0]

In [None]:
class ToTensor:
    """Transform that turns a NumPy array into a ``torch.Tensor``.

    Uses ``torch.from_numpy``, so the tensor shares memory and dtype with the array.
    """

    def __call__(self, sample):
        return torch.from_numpy(sample)

In [None]:
# Serve the tank images one at a time as tensors, reshuffled every epoch.
train_loader = DataLoader(
    dataset=TanksDataset(transform=ToTensor()),
    batch_size=1,
    shuffle=True,
)

In [None]:
# Peek at one batch to sanity-check the loader. Recent PyTorch DataLoader
# iterators have no `.next()` method, so use the built-in next().
iterdata = iter(train_loader)
print(next(iterdata))

In [None]:
def train(vae, dataloader, epochs=1, device=torch.device("cpu")):
    """Train ``vae`` with Adam (lr=1e-3) and report per-epoch average losses.

    The model is moved to ``device`` and cast to float64, and every batch is
    moved and cast the same way. Every 10th epoch one latent vector drawn from
    N(0, I) is decoded and saved as ``result_at_epoch_{epoch}.png``.

    Args:
        vae: model whose call returns ``(z, mu, log_sigma, x_prime)``, with
            ``loss_fn`` returning ``(loss, kld, recon)``, plus ``sample(eps)``
            and ``latent_space_dim``.
        dataloader: iterable of image-batch tensors.
        epochs: number of passes over ``dataloader``.
        device: torch device (or device string) to train on.

    Returns:
        np.ndarray with one row ``[epoch, mean_reconstruction_loss, mean_kl_divergence]`` per epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    vae = vae.to(device).double()
    optimizer = torch.optim.Adam(vae.parameters(), lr=0.001)
    to_img = T.ToPILImage()
    reported_loss = []

    for epoch in range(epochs):
        vae.train()
        collective_loss = []

        for x in tqdm(dataloader):
            # Tensor.to() returns a new tensor; it must be reassigned.
            # NOTE(review): the decoder output is tanh-bounded in [-1, 1] and no
            # normalisation is applied here; confirm the data uses the same range.
            x = x.to(device).double()

            _, mu, log_sigma, x_prime = vae(x)

            # loss_fn returns (loss, kld, recon) in this order.
            loss, kld, recon = vae.loss_fn(x, x_prime, mu, log_sigma)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            collective_loss.append([recon.item(), kld.item()])

        if not collective_loss:
            raise ValueError("dataloader yielded no batches")

        # Average over batches (axis 0): column 0 is reconstruction, column 1 is KL.
        average_loss = np.mean(np.array(collective_loss), axis=0)

        reported_loss.append(average_loss)

        print(f"Epoch {epoch+1} finished!", f"reconstruction_loss = {average_loss[0]} || KL-Divergence = {average_loss[1]}", sep="\n")

        if (epoch+1) % 10 == 0:
            # Eval mode: BatchNorm uses (and does not update) running stats while sampling.
            vae.eval()
            with torch.no_grad():
                reference = next(vae.parameters())
                eps = torch.randn(1, vae.latent_space_dim,
                                  device=reference.device, dtype=reference.dtype)
                example = vae.sample(eps)

            # Decoder output lies in [-1, 1] (tanh); ToPILImage needs a single
            # (C, H, W) CPU tensor with values in [0, 1].
            image = ((example[0] + 1) / 2).clamp(0, 1).float().cpu()
            to_img(image).save(f"result_at_epoch_{epoch+1}.png")
            vae.train()

    print("Training Finished!")

    return np.array([[epoch, *losses] for epoch, losses in enumerate(reported_loss)])

In [None]:
# Train on the device selected at the top of the notebook instead of assuming a
# GPU is present (a hard-coded "cuda" fails on CPU-only machines).
train(vae, train_loader, epochs=100, device=torch.device(device))