In [1]:
import pandas as pd
import numpy as np
import torch
import torchvision 
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch import optim
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

This is a set of convenience functions to use instead of using the layers with thier hyperparamters.

In [2]:
def conv(channels_in, channels_out): 
    return nn.Conv2d(channels_in, channels_out, kernel_size = 3, stride = 1, padding = 'same', bias = False)

def pool(): 
    return nn.MaxPool2d(kernel_size = 2, stride = 2)

def conv1x1(channels_in, channels_out): 
    return nn.Conv2d(channels_in, channels_out, kernel_size = 1, stride = 1, padding = 'same')

def bn(channels_in): 
    return nn.BatchNorm2d(channels_in)

def relu():
    return nn.ReLU(inplace = True)

def up():
    return nn.Upsample(scale_factor = 2)

def convUp(channels_in, channels_out): 
    return nn.Conv2d(channels_in, channels_out, kernel_size = 3, stride = 1, padding = 'same', bias = False)

This a Sequential Module for a single block in the encoder.

In [3]:
class down_Sample_Block(nn.Sequential):
    def __init__(self, channels_in, channels_out):
        super(down_Sample_Block, self).__init__()
        self.add_module('conv1', conv(channels_in, channels_out))
        self.add_module('conv2', conv(channels_out, channels_out))
        self.add_module('conv3', conv(channels_out, channels_out))
        self.add_module('conv4', conv(channels_out, channels_out))
        self.add_module('norm', bn(channels_out))
        self.add_module('relu', relu())
        self.add_module('pool', pool())

This a Sequential Module for a single block in the decoder.

In [4]:
class up_Sample_Block(nn.Sequential):
    def __init__(self, channels_in, channels_out):
        super(up_Sample_Block, self).__init__()
        self.add_module('upSample', up())
        self.add_module('conv1', convUp(channels_in, channels_out))
        self.add_module('conv2', convUp(channels_out, channels_out))
        self.add_module('conv3', convUp(channels_out, channels_out))
        self.add_module('conv4', convUp(channels_out, channels_out))
        self.add_module('norm', bn(channels_out))
        self.add_module('relu', relu())

Sequential module for the encoder block as a whole, it is designed to halve the height and width of the image, and double the number of channels, until it reaches (batch, channels, 1, 1).

It is then passed through a conv_1x1 to reduce channels to 128, then it is flattened to remove the height and width dimensions (last two).

In [5]:
class Encoder(nn.Sequential):
    def __init__(self, encoder_output_length = 128):
        super(Encoder, self).__init__()
        self.add_module('layer_1', down_Sample_Block(3, 8)) # 64
        self.add_module('layer_2', down_Sample_Block(8, 16)) # 32
        self.add_module('layer_3', down_Sample_Block(16, 32)) # 16
        self.add_module('layer_4', down_Sample_Block(32, 64)) # 8
        self.add_module('layer_5', down_Sample_Block(64, 128)) # 4
        self.add_module('layer_6', down_Sample_Block(128, 256)) # 2
        self.add_module('layer_7', down_Sample_Block(256, 256)) # 1

        self.add_module('conv1x1', conv1x1(256, encoder_output_length))
        self.add_module('flatten', nn.Flatten())

Module for the bottleneck block, it is the part where the mean and standard deviation layers are found, and where the reparameterization occurs.

The output of the reconstruction has dimensions of (batch_size, channels), so inorder to make it suitable for the decoder, it is unsqueezed at dimensions 2 and 3 (the height and width dimensions are added again just like they were removed at the end of the encoder)

In [6]:
class Bottleneck(nn.Module):
    def __init__(
        self, 
        latent_vec_len = 32,
        encoder_output_length = 128,
        decoder_input_length = 256
    ):
        super().__init__()
        
        self.add_module('mean_layer', nn.Linear(encoder_output_length, latent_vec_len)),
        self.add_module('standard_deviation_layer', nn.Linear(encoder_output_length, latent_vec_len)),
        self.add_module('output_linear_layer', nn.Linear(latent_vec_len, decoder_input_length))
        
    def forward(self, x):
        mean = self.mean_layer(x)
        standard_deviation = self.standard_deviation_layer(x)
        epsilon = torch.randn_like(standard_deviation)
        
        x_reparameterized = mean + standard_deviation*epsilon
        x_reconstructed = self.output_linear_layer(x_reparameterized)
        
        x_reconstructed = torch.unsqueeze(x_reconstructed, 2)
        x_reconstructed = torch.unsqueeze(x_reconstructed, 3)
        
        return x_reconstructed, mean, standard_deviation

Sequential module for the decoder block as a whole, works in opposite fashion to the encoder.

It is designed to take a tensor of dimensions (batch, channels, 1, 1) and double its height and width, and halve the number of channels, until it reaches (batch, 3, 128, 128).

In [7]:
class Decoder(nn.Sequential):
    def __init__(self):
        super(Decoder, self).__init__()
        self.add_module('layer_1', up_Sample_Block(256, 256)) # 2
        self.add_module('layer_2', up_Sample_Block(256, 128)) # 4
        self.add_module('layer_3', up_Sample_Block(128, 64)) # 8
        self.add_module('layer_4', up_Sample_Block(64, 32)) # 16
        self.add_module('layer_5', up_Sample_Block(32, 16)) # 32
        self.add_module('layer_6', up_Sample_Block(16, 8)) # 64
        self.add_module('layer_7', up_Sample_Block(8, 3)) # 128

Module for the VAE itself.

In [8]:
class VAE(nn.Module):
    def __init__(self):
        super().__init__()
        self.add_module('encoder', Encoder())
        self.add_module('bottleneck', Bottleneck())
        self.add_module('decoder', Decoder())
    
    def forward(self, x):
        x = self.encoder(x)
        x, mean, standard_deviation = self.bottleneck(x)
        x = self.decoder(x) 
        return x, mean, standard_deviation
    
    def generate(self, device):
        with torch.no_grad():
            x = torch.randn(1, 128)
            x = x.to(device)
            x, _, _ = self.bottleneck(x)
            x = self.decoder(x)
        return x

In [9]:
model = VAE()

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [11]:
from torchinfo import summary

model = model
batch_size = 128
summary(model, input_size=(batch_size, 3, 128, 128))

Layer (type:depth-idx)                   Output Shape              Param #
VAE                                      [128, 3, 128, 128]        --
├─Encoder: 1-1                           [128, 128]                --
│    └─down_Sample_Block: 2-1            [128, 8, 64, 64]          --
│    │    └─Conv2d: 3-1                  [128, 8, 128, 128]        216
│    │    └─Conv2d: 3-2                  [128, 8, 128, 128]        576
│    │    └─Conv2d: 3-3                  [128, 8, 128, 128]        576
│    │    └─Conv2d: 3-4                  [128, 8, 128, 128]        576
│    │    └─BatchNorm2d: 3-5             [128, 8, 128, 128]        16
│    │    └─ReLU: 3-6                    [128, 8, 128, 128]        --
│    │    └─MaxPool2d: 3-7               [128, 8, 64, 64]          --
│    └─down_Sample_Block: 2-2            [128, 16, 32, 32]         --
│    │    └─Conv2d: 3-8                  [128, 16, 64, 64]         1,152
│    │    └─Conv2d: 3-9                  [128, 16, 64, 64]         2,304
│    

In [12]:
import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset

class dataset(Dataset):
    def __init__(
        self, img_dir,
        transform = transforms.Compose([
            transforms.Resize(size=(128, 128)),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]),
        target_transform = transforms.Resize(size=(128, 128))
    ):
        self.img_dir = img_dir
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self):
        return len(os.listdir(self.img_dir))

    def __getitem__(self, idx):
        os.chdir(self.img_dir)
        img_path = self.img_dir + "\\\\" + str(idx+1) + ".jpg"
        # img_path = os.path.join(self.img_dir, str(idx) + ".JPEG")
        image = read_image(img_path)
        image = image.to(torch.float32)
        image = self.transform(image)
        label = image        
        return image, label

In [13]:
# Hyperparameters
learning_rate = 3e-7
batch_size = 1
num_epochs = 30

In [14]:
path = r"E:\\College\\FCAI-4th Year\\First Term\\Generative Adversarial Networks\\Assginments\\Assignment 4\\dataset\\Train"

train_set = dataset(img_dir = path)

train_loader = DataLoader(dataset = train_set, batch_size = batch_size, shuffle = True)

In [15]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [16]:
model = model.to(device)

In [17]:
optimizer = optim.Adam(model.parameters(), lr = learning_rate)
criterion = nn.CrossEntropyLoss()

In [18]:
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(tqdm(train_loader)):
        data = data.to(device)
        targets = targets.to(device)
        scores, mean, standard_deviation = model(data)
        
        reconstruction_loss = criterion(targets, scores)
        kl_divergence = -torch.sum(1 + torch.log(standard_deviation.pow(2)) - mean.pow(2) - standard_deviation.pow(2)) 
        loss = reconstruction_loss + kl_divergence
  
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    print(f'Epoch:{epoch+1}, Loss:{loss.item():f}')

100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 12.09it/s]


Epoch:1, Loss:104.572815


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:14<00:00,  7.66it/s]


Epoch:2, Loss:130.820389


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:12<00:00,  9.05it/s]


Epoch:3, Loss:99.222008


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.15it/s]


Epoch:4, Loss:126.529678


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:16<00:00,  6.54it/s]


Epoch:5, Loss:103.035629


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:14<00:00,  7.54it/s]


Epoch:6, Loss:67.514786


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:11<00:00,  9.30it/s]


Epoch:7, Loss:62.583881


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:16<00:00,  6.54it/s]


Epoch:8, Loss:94.166489


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:16<00:00,  6.54it/s]


Epoch:9, Loss:78.799835


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:16<00:00,  6.54it/s]


Epoch:10, Loss:56.705467


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:13<00:00,  8.44it/s]


Epoch:11, Loss:57.971195


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.14it/s]


Epoch:12, Loss:85.381927


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.12it/s]


Epoch:13, Loss:50.849785


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.04it/s]


Epoch:14, Loss:46.941994


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.12it/s]


Epoch:15, Loss:81.164429


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.11it/s]


Epoch:16, Loss:90.624046


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.08it/s]


Epoch:17, Loss:48.740295


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.08it/s]


Epoch:18, Loss:48.263474


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.10it/s]


Epoch:19, Loss:43.105839


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.12it/s]


Epoch:20, Loss:60.623535


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.08it/s]


Epoch:21, Loss:45.806664


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.13it/s]


Epoch:22, Loss:52.595612


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.15it/s]


Epoch:23, Loss:59.139187


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.05it/s]


Epoch:24, Loss:45.814358


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.11it/s]


Epoch:25, Loss:32.878094


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.14it/s]


Epoch:26, Loss:42.180889


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.08it/s]


Epoch:27, Loss:45.391171


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.12it/s]


Epoch:28, Loss:37.836121


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.14it/s]


Epoch:29, Loss:29.179348


100%|████████████████████████████████████████████████████████████████████████████████| 110/110 [00:09<00:00, 11.15it/s]

Epoch:30, Loss:37.972500





In [19]:
PATH = "E:\\College\\FCAI-4th Year\\First Term\\Generative Adversarial Networks\\Assginments\\Assignment 4\\VAE_weights.pth"
torch.save(model.state_dict(), PATH)

Generate Images from test data

In [32]:
import torchvision.transforms as T

for i in range(10):
    new_image = model.generate(device)
    new_image = torch.squeeze(new_image, 0)
    transform = T.ToPILImage()
    img = transform(new_image)
    img.save("E:\\College\\FCAI-4th Year\\First Term\\Generative Adversarial Networks\\Assginments\\Assignment 4\\Generated Images\\img"+ str(i+1)+".png")