<a href="https://colab.research.google.com/github/NikNord174/Autoencoder/blob/main/Autoencoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Autoencoder

## Packages and data import

In [None]:
import os
import sys

from time import time

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import Compose, ToTensor, Normalize
from torch.nn.modules.container import Sequential


In [None]:
# Loader settings shared by the train and test DataLoaders.
BATCH_SIZE = 64
NUM_WORKERS = 2
# Prefer the GPU when PyTorch can see one; otherwise run on the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


In [None]:
# ToTensor yields values in [0, 1]; Normalize((0.5,)*3, (0.5,)*3) maps them to
# [-1, 1] via (x - 0.5) / 0.5 per channel.
transform = transforms.Compose([
    ToTensor(),
    Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# Named *_dataset so that the train()/test() functions defined further down
# do not silently overwrite them.
train_dataset = datasets.CIFAR10('../data',
                                 train=True,
                                 download=True,
                                 transform=transform)
test_dataset = datasets.CIFAR10('../data',
                                train=False,
                                download=True,
                                transform=transform)

train_loader = DataLoader(train_dataset,
                          batch_size=BATCH_SIZE,
                          num_workers=NUM_WORKERS,
                          shuffle=True)
# Evaluation only averages a loss over the whole set, so shuffling buys
# nothing; a fixed order keeps evaluation batches reproducible.
test_loader = DataLoader(test_dataset,
                         batch_size=BATCH_SIZE,
                         num_workers=NUM_WORKERS,
                         shuffle=False)


## Models

### Autoencoder_Initial

In [None]:
class Autoencoder_Initial(nn.Module):
    """Convolutional autoencoder with a fully connected bottleneck.

    The encoder stacks stride-2 ``Conv2d`` stages (kernel 4, padding 1), one
    per consecutive pair in ``self.channels`` (3 -> 60 -> 120 -> 240), each
    halving the spatial size. Its output is flattened, projected to a
    ``hidden_state``-sized code by ``linear``, mapped back by ``rev_linear``,
    reshaped to ``lin_neurons`` and decoded by mirrored ``ConvTranspose2d``
    stages. ``lin_neurons`` assumes 32x32 inputs (32 -> 16 -> 8 -> 4).

    The last encoder and decoder stages are bare convolutions (no
    BatchNorm / Dropout / activation), so the output is not range-limited.
    """

    def __init__(self):
        super(Autoencoder_Initial, self).__init__()
        self.channels = [3, 60, 120, 240]
        self.hidden_state = 300
        self.encoder = self.encoder_layers()
        self.flatten = nn.Flatten()
        # Encoder output shape (C, H, W) for a 32x32 input. C is tied to
        # self.channels so the two cannot drift apart.
        self.lin_neurons = [self.channels[-1], 4, 4]
        self.enc_neurons = int(np.prod(self.lin_neurons))
        self.linear = nn.Linear(self.enc_neurons,
                                self.hidden_state)
        self.rev_linear = nn.Linear(self.hidden_state,
                                    self.enc_neurons)
        self.decoder = self.decoder_layers()

    def simple_enc_block(self,
                         input_channels: int = 3,
                         output_channels: int = 3,
                         kernel_size: int = 4,
                         stride: int = 2,
                         padding: int = 1,
                         final_layer: bool = False) -> nn.Sequential:
        """Build one encoder stage.

        Returns ``Conv2d`` alone when ``final_layer`` is true, otherwise
        ``Conv2d -> BatchNorm2d -> Dropout(0.2) -> LeakyReLU(0.2)``.
        """
        conv = nn.Conv2d(in_channels=input_channels,
                         out_channels=output_channels,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding)
        if final_layer:
            return nn.Sequential(conv)
        return nn.Sequential(conv,
                             nn.BatchNorm2d(output_channels),
                             nn.Dropout(p=0.2),
                             nn.LeakyReLU(0.2))

    def encoder_layers(self) -> nn.Sequential:
        """Chain one encoder stage per consecutive pair in ``self.channels``."""
        # The final stage is chosen by position. Comparing channel *values*
        # (as before) misfires when a width repeats, e.g. [3, 64, 64, 64].
        last = len(self.channels) - 2
        pairs = zip(self.channels[:-1], self.channels[1:])
        return nn.Sequential(*[
            self.simple_enc_block(input_channels=c_in,
                                  output_channels=c_out,
                                  final_layer=(i == last))
            for i, (c_in, c_out) in enumerate(pairs)
        ])

    def simple_dec_block(self,
                         input_channels: int = 3,
                         output_channels: int = 3,
                         kernel_size: int = 4,
                         stride: int = 2,
                         padding: int = 1,
                         final_layer: bool = False) -> nn.Sequential:
        """Build one decoder stage.

        Returns ``ConvTranspose2d`` alone when ``final_layer`` is true,
        otherwise ``ConvTranspose2d -> BatchNorm2d -> Dropout(0.2) ->
        LeakyReLU(0.2)``.
        """
        deconv = nn.ConvTranspose2d(in_channels=input_channels,
                                    out_channels=output_channels,
                                    kernel_size=kernel_size,
                                    stride=stride,
                                    padding=padding)
        if final_layer:
            return nn.Sequential(deconv)
        return nn.Sequential(deconv,
                             nn.BatchNorm2d(output_channels),
                             nn.Dropout(p=0.2),
                             nn.LeakyReLU(0.2))

    def decoder_layers(self) -> nn.Sequential:
        """Mirror the encoder: one stage per pair in reversed ``self.channels``."""
        dec_channels = list(reversed(self.channels))
        last = len(dec_channels) - 2
        pairs = zip(dec_channels[:-1], dec_channels[1:])
        return nn.Sequential(*[
            self.simple_dec_block(input_channels=c_in,
                                  output_channels=c_out,
                                  final_layer=(i == last))
            for i, (c_in, c_out) in enumerate(pairs)
        ])

    def forward(self, x):
        """Reconstruct ``x`` through the conv encoder, linear bottleneck and decoder."""
        encoded = self.encoder(x)
        hidden_layer = self.flatten(encoded)
        hidden_layer = self.linear(hidden_layer)
        hidden_layer = self.rev_linear(hidden_layer)
        hidden_layer = hidden_layer.view(-1, *self.lin_neurons)
        return self.decoder(hidden_layer)


### Autoencoder_ConvTranspose

In [None]:
class Autoencoder_ConvTranspose(nn.Module):
    """Fully convolutional autoencoder that decodes with transposed convolutions.

    For a 3x32x32 input, the encoder reduces the spatial size as follows:
      Conv k3 (32 -> 30) + MaxPool 3 (30 -> 10),
      Conv k3 (10 -> 8)  + MaxPool 3 (8 -> 2),
      Conv k2 (2 -> 1),
    giving a ``channels[-1]`` x 1 x 1 code.

    The decoder mirrors the channels with ``ConvTranspose2d``:
      k3, s2, dilation 2, output_padding 1 (1 -> 6, 6 -> 16),
      final k3, s2, padding 1, output_padding 1 (16 -> 32).
    """

    def __init__(self):
        super(Autoencoder_ConvTranspose, self).__init__()
        self.channels = [3, 100, 200, 300]
        self.encoder = self.encoder_layers()
        self.decoder = self.decoder_layers()

    def simple_enc_block(self,
                         input_channels: int = 3,
                         output_channels: int = 3,
                         kernel_size: int = 3,
                         final_layer: bool = False) -> nn.Sequential:
        """Build one encoder stage.

        Non-final: ``Conv2d(kernel_size) -> BatchNorm2d -> MaxPool2d(3) ->
        Dropout(0.2) -> LeakyReLU(0.2)``.
        Final: a single ``Conv2d`` with kernel 2, which collapses the 2x2 map
        left by the previous stage (for 32x32 inputs) to 1x1. ``kernel_size``
        is intentionally not used for the final stage.
        """
        if final_layer:
            # Previously out_channels was hard-coded to 300, which ignored
            # output_channels and broke any other self.channels setting.
            return nn.Sequential(
                nn.Conv2d(in_channels=input_channels,
                          out_channels=output_channels,
                          kernel_size=2),
            )
        return nn.Sequential(
            nn.Conv2d(in_channels=input_channels,
                      out_channels=output_channels,
                      kernel_size=kernel_size),
            nn.BatchNorm2d(output_channels),
            nn.MaxPool2d(kernel_size=3),
            nn.Dropout(p=0.2),
            nn.LeakyReLU(0.2),
        )

    def encoder_layers(self) -> nn.Sequential:
        """Chain one encoder stage per consecutive pair in ``self.channels``."""
        # The final stage is chosen by position, not by comparing channel
        # values, so repeated widths (e.g. [3, 64, 64, 64]) still work.
        last = len(self.channels) - 2
        pairs = zip(self.channels[:-1], self.channels[1:])
        return nn.Sequential(*[
            self.simple_enc_block(input_channels=c_in,
                                  output_channels=c_out,
                                  final_layer=(i == last))
            for i, (c_in, c_out) in enumerate(pairs)
        ])

    def simple_dec_block(self,
                         input_channels: int = 3,
                         output_channels: int = 3,
                         kernel_size: int = 3,
                         stride: int = 2,
                         dilation: int = 2,
                         output_padding: int = 1,
                         final_layer: bool = False) -> nn.Sequential:
        """Build one decoder stage.

        Non-final: ``ConvTranspose2d(kernel_size, stride, dilation,
        output_padding) -> BatchNorm2d -> Dropout(0.2) -> LeakyReLU(0.2)``.
        Final: a single ``ConvTranspose2d`` with dilation 1, padding 1 and
        output_padding 1, which exactly doubles the spatial size for k3/s2.
        """
        if final_layer:
            return nn.Sequential(
                nn.ConvTranspose2d(in_channels=input_channels,
                                   out_channels=output_channels,
                                   kernel_size=kernel_size,
                                   stride=stride,
                                   dilation=1,
                                   padding=1,
                                   output_padding=1),
            )
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels=input_channels,
                               out_channels=output_channels,
                               kernel_size=kernel_size,
                               stride=stride,
                               dilation=dilation,
                               output_padding=output_padding),
            nn.BatchNorm2d(output_channels),
            nn.Dropout(p=0.2),
            nn.LeakyReLU(0.2),
        )

    def decoder_layers(self) -> nn.Sequential:
        """Mirror the encoder: one stage per pair in reversed ``self.channels``."""
        dec_channels = list(reversed(self.channels))
        last = len(dec_channels) - 2
        pairs = zip(dec_channels[:-1], dec_channels[1:])
        return nn.Sequential(*[
            self.simple_dec_block(input_channels=c_in,
                                  output_channels=c_out,
                                  final_layer=(i == last))
            for i, (c_in, c_out) in enumerate(pairs)
        ])

    def forward(self, x):
        """Encode ``x`` to the 1x1 code and decode it back."""
        encoded = self.encoder(x)
        return self.decoder(encoded)


### Autoencoder_Upsampling

In [None]:
class Autoencoder_Upsampling(nn.Module):
    """Fully convolutional autoencoder that decodes with 1x1 convs + upsampling.

    For a 3x32x32 input, the encoder reduces the spatial size as follows:
      Conv k3 (32 -> 30) + MaxPool 3 (30 -> 10),
      Conv k3 (10 -> 8)  + MaxPool 3 (8 -> 2),
      Conv k2 (2 -> 1),
    giving a ``channels[-1]`` x 1 x 1 code.

    The decoder mirrors the channels with ``Conv2d(k1)`` followed by
    ``nn.Upsample``: x4 (1 -> 4), x4 (4 -> 16), then a final x2 (16 -> 32).
    """

    def __init__(self):
        super(Autoencoder_Upsampling, self).__init__()
        self.channels = [3, 100, 200, 300]
        self.encoder = self.encoder_layers()
        self.decoder = self.decoder_layers()

    def simple_enc_block(self,
                         input_channels: int = 3,
                         output_channels: int = 3,
                         kernel_size: int = 3,
                         final_layer: bool = False) -> nn.Sequential:
        """Build one encoder stage.

        Non-final: ``Conv2d(kernel_size) -> BatchNorm2d -> MaxPool2d(3) ->
        Dropout(0.2) -> LeakyReLU(0.2)``.
        Final: a single ``Conv2d`` with kernel 2, which collapses the 2x2 map
        left by the previous stage (for 32x32 inputs) to 1x1.
        """
        if final_layer:
            return nn.Sequential(
                nn.Conv2d(in_channels=input_channels,
                          out_channels=output_channels,
                          kernel_size=2),
            )
        return nn.Sequential(
            nn.Conv2d(in_channels=input_channels,
                      out_channels=output_channels,
                      kernel_size=kernel_size),
            nn.BatchNorm2d(output_channels),
            nn.MaxPool2d(kernel_size=3),
            nn.Dropout(p=0.2),
            nn.LeakyReLU(0.2),
        )

    def encoder_layers(self) -> nn.Sequential:
        """Chain one encoder stage per consecutive pair in ``self.channels``."""
        # The final stage is chosen by position, not by comparing channel
        # values, so repeated widths (e.g. [3, 64, 64, 64]) still work.
        last = len(self.channels) - 2
        pairs = zip(self.channels[:-1], self.channels[1:])
        return nn.Sequential(*[
            self.simple_enc_block(input_channels=c_in,
                                  output_channels=c_out,
                                  final_layer=(i == last))
            for i, (c_in, c_out) in enumerate(pairs)
        ])

    def simple_dec_block(self,
                         input_channels: int = 3,
                         output_channels: int = 3,
                         kernel_size: int = 1,
                         scale_factor: int = 4,
                         mode: str = 'nearest',
                         final_layer: bool = False) -> nn.Sequential:
        """Build one decoder stage.

        Non-final: ``Conv2d(kernel_size) -> Upsample(scale_factor, mode) ->
        BatchNorm2d -> Dropout(0.2) -> LeakyReLU(0.2)``.
        Final: ``Conv2d(kernel_size) -> Upsample(2, mode)``. The fixed factor
        2 brings 16x16 back to 32x32 for the default channel layout.
        """
        conv = nn.Conv2d(in_channels=input_channels,
                         out_channels=output_channels,
                         kernel_size=kernel_size)
        if final_layer:
            return nn.Sequential(conv, nn.Upsample(scale_factor=2, mode=mode))
        return nn.Sequential(
            conv,
            nn.Upsample(scale_factor=scale_factor, mode=mode),
            nn.BatchNorm2d(output_channels),
            nn.Dropout(p=0.2),
            nn.LeakyReLU(0.2),
        )

    def decoder_layers(self) -> nn.Sequential:
        """Mirror the encoder: one stage per pair in reversed ``self.channels``."""
        dec_channels = list(reversed(self.channels))
        last = len(dec_channels) - 2
        pairs = zip(dec_channels[:-1], dec_channels[1:])
        return nn.Sequential(*[
            self.simple_dec_block(input_channels=c_in,
                                  output_channels=c_out,
                                  final_layer=(i == last))
            for i, (c_in, c_out) in enumerate(pairs)
        ])

    def forward(self, x):
        """Encode ``x`` to the 1x1 code and decode it back."""
        encoded = self.encoder(x)
        return self.decoder(encoded)


## Losses

In [None]:
def MSE(fake: torch.Tensor,
        image: torch.Tensor) -> torch.Tensor:
    """Return the mean squared error between ``fake`` and ``image``.

    The mean runs over every element of the (broadcast) difference, so the
    result is a 0-dim tensor that can be back-propagated directly.
    """
    residual = fake - image
    return residual.pow(2).mean()

In [None]:
def SSIM(fake: torch.Tensor, real: torch.Tensor,
         data_range: float = 1.0) -> torch.Tensor:
    """Return the batch-mean *global* SSIM between ``fake`` and ``real``.

    Means, variances and covariance are taken over dims 1-3 of each sample
    (the whole image), not over local sliding windows. The result is therefore
    one SSIM value per image, averaged over the batch.

    Args:
        fake: reconstructed batch, indexed (batch, d1, d2, d3). Presumably
            (N, C, H, W); confirm against the caller.
        real: reference batch with the same shape as ``fake``.
        data_range: dynamic range L of the pixel values, used for the
            stabilising constants c1 = (0.01 L)^2 and c2 = (0.03 L)^2. The
            default 1.0 keeps the previous constants. For images normalised
            to [-1, 1] (as by this notebook's transform), 2.0 is the
            matching value.

    Returns:
        A 0-dim tensor with the mean SSIM over the batch.
    """
    k1 = 0.01
    k2 = 0.03
    dims = [1, 2, 3]
    c1 = (k1 * data_range) ** 2
    c2 = (k2 * data_range) ** 2

    mean_fake = fake.mean(dim=dims)
    mean_real = real.mean(dim=dims)
    fake_dif = fake - mean_fake[:, None, None, None]
    real_dif = real - mean_real[:, None, None, None]

    # Use population (1/N) statistics throughout, so that the variances and
    # the covariance share one normalisation (torch.var defaults to 1/(N-1)).
    var_fake = fake_dif.pow(2).mean(dim=dims)
    var_real = real_dif.pow(2).mean(dim=dims)
    covariance = (fake_dif * real_dif).mean(dim=dims)

    ssim_numerator = (2 * mean_fake * mean_real + c1) * (2 * covariance + c2)
    # The SSIM contrast/structure term is sigma_x^2 + sigma_y^2. var_* already
    # are sigma^2, so they must not be squared again.
    ssim_denominator = (
        (mean_fake ** 2 + mean_real ** 2 + c1) * (var_fake + var_real + c2)
    )
    return torch.mean(ssim_numerator / ssim_denominator)


## Train and test functions

In [None]:
def train(model: nn.Module,
          train_loader: DataLoader,
          optimizer: optim.Optimizer,
          device: torch.device) -> None:
    """Run one training epoch of ``model`` over ``train_loader``.

    Each batch is reconstructed by the model and optimised against the input
    images with ``MSE``. Dataset labels are ignored.

    Args:
        model: autoencoder to train. It is switched to training mode here.
        train_loader: yields ``(images, labels)`` batches.
        optimizer: optimizer over ``model``'s parameters.
        device: device the images are moved to. It should match the model's
            device.
    """
    # test() puts the model in eval mode. Without this call, every epoch
    # after the first would train with Dropout disabled and BatchNorm frozen.
    model.train()
    for images, _ in train_loader:
        images = images.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        batch_loss = MSE(outputs, images)
        batch_loss.backward()
        optimizer.step()

In [None]:
def test(model: nn.Module,
         test_loader: DataLoader,
         device: torch.device) -> float:
    """Return the mean reconstruction MSE of ``model`` over ``test_loader``.

    The model is switched to eval mode (and left there), and gradients are
    disabled. Each batch's mean loss is weighted by its batch size, so a
    short final batch does not skew the average.

    Args:
        model: autoencoder to evaluate.
        test_loader: yields ``(images, labels)`` batches. Labels are ignored.
        device: device the images are moved to.

    Returns:
        The per-element MSE averaged over every image in the loader.

    Raises:
        ZeroDivisionError: if the loader yields no images.
    """
    model.eval()
    total_loss = 0.0
    n_images = 0
    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)
            output = model(images)
            batch_size = images.size(0)
            total_loss += MSE(output, images).item() * batch_size
            n_images += batch_size
    return total_loss / n_images

## Training

In [None]:
epochs = 100
lr = 0.001
test_loss_list = []
# Early stopping: stop once the test loss has varied by no more than `eps`
# across the last `window` epochs.
eps = 1e-3
window = 5

model = Autoencoder_Upsampling().to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
# Lower the learning rate when the monitored test loss stops improving.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

t0 = time()

for epoch in range(epochs):
    train(model, train_loader, optimizer, DEVICE)
    test_loss = test(model, test_loader, DEVICE)
    test_loss_list.append(test_loss)
    # ReduceLROnPlateau only acts when stepped with the monitored metric;
    # previously it was created but never stepped.
    scheduler.step(test_loss)
    elapsed_min = (time() - t0) / 60
    print('Epoch: {}, test loss: {:.5f}, '.format(epoch + 1, test_loss)
          + 'time: {:.2f} min'.format(elapsed_min))
    recent = test_loss_list[-window:]
    # Only judge a plateau once a full window of epochs is available.
    if len(recent) == window and max(recent) - min(recent) <= eps:
        break
print('Finish!')


In [None]:
# Test loss per epoch; epochs are numbered from 1 to match the training log.
epoch_numbers = range(1, len(test_loss_list) + 1)
plt.plot(epoch_numbers, test_loss_list)
plt.show()

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
from torchvision.utils import save_image

def imshow(img, im_name, mean=0.5, std=0.5):
    """Save ``img`` to ``im_name`` and display it with matplotlib.

    Args:
        img: a single image tensor, channels first (it is transposed with
            axes (1, 2, 0) for display). It may live on any device and may
            require grad.
        im_name: destination path handed to torchvision's ``save_image``.
        mean: per-channel mean used by the dataset ``Normalize`` transform.
        std: per-channel std used by the dataset ``Normalize`` transform.
            The defaults match this notebook's
            ``Normalize((0.5,) * 3, (0.5,) * 3)``.
    """
    # Undo Normalize: x_norm = (x - mean) / std  =>  x = x_norm * std + mean.
    # save_image and plt.imshow expect floats in [0, 1]. Without this step,
    # values in [-1, 0) would be clipped to black.
    img = (img.detach().cpu() * std + mean).clamp(0.0, 1.0)
    save_image(img, im_name)
    npimg = np.transpose(img.numpy(), (1, 2, 0))
    plt.axis('off')
    plt.imshow(npimg)
    plt.show()


def get_file_size_in_bytes(file_path):
    """Return the on-disk size of ``file_path`` in bytes (``os.path.getsize``)."""
    return os.path.getsize(file_path)


# Qualitative check on one training batch: show an input image and its
# reconstruction, then compare JPEG's compression ratio with the size of the
# autoencoder's latent code.
model.eval()
im, lab = next(iter(train_loader))
print(im.shape)
print(lab[0])
imshow(im[0], '/content/image.jpeg')

im_device = im.to(DEVICE)
with torch.no_grad():
    fake = model(im_device)
    # Use the *trained* encoder. model.encoder_layers() would build a brand
    # new, randomly initialised encoder. NOTE: for Autoencoder_Initial, the
    # real bottleneck is the output of model.linear, not model.encoder.
    code = model.encoder(im_device)
print(fake.shape)
imshow(fake[0], '/content/fake.jpeg')

# Raw image size: 3 x 32 x 32 values at one byte (8 bits) each.
raw_bytes = 3 * 32 * 32
real_size = get_file_size_in_bytes('/content/image.jpeg')
jpeg_comp_rate = raw_bytes / real_size
# Bytes of tensor storage for one image's latent code. sys.getsizeof only
# measures the Python wrapper object, not the tensor data.
code_bytes = code[0].nelement() * code[0].element_size()
ae_comp_rate = raw_bytes / code_bytes
print('jpeg compression rate : {:.2f}'.format(jpeg_comp_rate))
print('AE compression rate : {:.2f}'.format(ae_comp_rate))

