<a href="https://colab.research.google.com/github/DobryVecher1/dl-phys-vsu/blob/main/lectures/08_autoencoders.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torchmetrics

In [2]:
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import torch
from torch.utils.data.dataset import Dataset
from torchvision import transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torchmetrics
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split

In [3]:
class DatasetCIFAR(Dataset):
    """Map-style dataset that pairs images with their labels.

    Args:
        x_data: Indexable collection of images; ``x_data[i]`` is passed to
            ``transform`` as-is.
        y_data: Indexable collection of labels aligned with ``x_data``. Each
            entry may be a NumPy array (e.g. shape ``(1,)``), a NumPy scalar
            or a plain Python number.
        transform: Optional callable applied to every image on access.
    """

    def __init__(self, x_data, y_data, transform=None):
        self.x_data = x_data
        self.y_data = y_data
        self.transform = transform

    def __getitem__(self, index):
        """Load and return the ``(image, label)`` sample at the given index."""
        img = self.x_data[index]

        # augmentations
        if self.transform is not None:
            img = self.transform(img)

        # ``torch.from_numpy`` accepts ndarrays only, so it breaks on 1-D label
        # arrays (whose items are NumPy scalars); ``as_tensor`` handles both
        # and keeps the dtype of ndarray labels.
        label = torch.as_tensor(self.y_data[index])

        return img, label

    def __len__(self):
        """Return the number of samples in dataset."""
        return len(self.x_data)

In [4]:
class DatamoduleCIFAR():
    """Create dataset and loaders, apply transforms.

    Args:
        max_samples: Optional cap on the number of samples kept in each of
            the two splits (handy for quick experiments). ``None`` keeps
            everything.
    """

    def __init__(self, max_samples=None):
        # load data; the second tuple returned by Keras serves as validation set
        (self.x_train, self.y_train), (self.x_val, self.y_val) = tf.keras.datasets.cifar10.load_data()

        # make dataset smaller if needed
        if max_samples is not None:
            self.x_train = self.x_train[:max_samples]
            self.y_train = self.y_train[:max_samples]
            self.x_val = self.x_val[:max_samples]
            self.y_val = self.y_val[:max_samples]

    def create_loaders(self, batch_size=100):
        """Create loaders both for train and test/validation datasets.

        Args:
            batch_size: Number of samples per batch in both loaders.

        Returns:
            Tuple ``(train_loader, val_loader)``. Only the training loader
            shuffles its samples.
        """
        # every image is converted to a tensor when it is accessed
        to_tensor = transforms.ToTensor()

        # train dataset
        dset_train = DatasetCIFAR(self.x_train, self.y_train, transform=to_tensor)
        # val dataset
        dset_val = DatasetCIFAR(self.x_val, self.y_val, transform=to_tensor)

        # Train and val dataloaders
        train_loader = DataLoader(dset_train, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(dset_val, batch_size=batch_size, shuffle=False)

        return train_loader, val_loader

In [5]:
class Autoencoder(nn.Module):
    """Small convolutional autoencoder for 3-channel images.

    The encoder halves the spatial size twice (``(N, 3, H, W)`` ->
    ``(N, 4, H/4, W/4)``); the decoder doubles it twice and returns 3
    channels. The input size is restored exactly when ``H`` and ``W`` are
    multiples of 4.
    """

    def __init__(self):
        super().__init__()

        # FC Autoencoder (alternative, kept for reference)
        # self.encoder = nn.Sequential(
        #     nn.Flatten(),
        #     nn.Linear(in_features=3072, out_features=512),
        #     nn.ReLU(),
        #     nn.Linear(in_features=512, out_features=128),
        #     nn.ReLU(),
        # )

        # self.decoder = nn.Sequential(
        #     nn.Linear(in_features=128, out_features=512),
        #     nn.ReLU(),
        #     nn.Linear(in_features=512, out_features=3072),
        #     nn.Sigmoid(),
        #     nn.Unflatten(1, (3, 32, 32))
        # )

        # Conv Autoencoder
        self.encoder = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=8, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),

            nn.Conv2d(in_channels=8, out_channels=4, kernel_size=3, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )

        # Upsample *before* each convolution: if the last op is a
        # nearest-neighbour upsample, every output is just a half-resolution
        # image with each pixel duplicated into a 2x2 block. With this order
        # the final convolution works at full resolution.
        self.decoder = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels=4, out_channels=8, kernel_size=3, padding='same'),
            nn.ReLU(),

            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_channels=8, out_channels=3, kernel_size=3, padding='same'),
            # Sigmoid keeps reconstructions inside (0, 1), the range of images
            # scaled to [0, 1]; a final ReLU is unbounded above.
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode ``x`` to the latent feature map and decode it back."""
        x = self.encoder(x)
        x = self.decoder(x)

        return x


In [6]:
class ModelCIFAR():
    """Train, validate and run the ``Autoencoder`` with an MSE reconstruction loss.

    Attributes:
        device: CUDA device when one is available, otherwise the CPU.
        autoencoder: The network being trained (lives on ``device``).
        loss_mse: Mean-squared-error criterion.
        optimizer: AdamW optimizer over the autoencoder parameters.
    """

    def __init__(self):
        # Fall back to the CPU so the class is usable without a GPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.autoencoder = Autoencoder().to(self.device)
        self.loss_mse = nn.MSELoss()
        self.optimizer = torch.optim.AdamW(self.autoencoder.parameters(), lr=1e-3)

    def fit(self, train_loader, val_loader, num_epoch=50):
        """Train for ``num_epoch`` epochs, validating after each one.

        Args:
            train_loader: Iterable of ``(images, labels)`` training batches;
                labels are ignored.
            val_loader: Iterable of ``(images, labels)`` validation batches;
                labels are ignored.
            num_epoch: Number of passes over ``train_loader``.

        The mean batch loss of both phases is printed after every epoch.
        """
        for ii in range(num_epoch):

            # train
            self.autoencoder.train()  # set once per epoch, not once per batch
            loss_batches = []
            for images, _ in train_loader:
                images = images.to(self.device)

                # make prediction
                img_out = self.autoencoder(images)

                # loss (prediction first, target second)
                loss = self.loss_mse(img_out, images)

                # update weights
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                # save loss
                loss_batches.append(loss.item())

            print(f"Epoch: {ii}")
            print(f"TRAIN | Loss: {np.mean(loss_batches): .4f}")

            # val
            self.autoencoder.eval()
            loss_batches_val = []
            with torch.no_grad():
                for images, _ in val_loader:
                    images = images.to(self.device)

                    img_out = self.autoencoder(images)
                    loss = self.loss_mse(img_out, images)

                    # save loss
                    loss_batches_val.append(loss.item())

            print(f"val | Loss: {np.mean(loss_batches_val): .3f}")

    def predict(self, val_loader):
        """Reconstruct the first batch of ``val_loader``.

        Args:
            val_loader: Iterable of ``(images, labels)`` batches.

        Returns:
            Tuple ``(images, reconstructions)``, both moved to the CPU.

        Raises:
            ValueError: If ``val_loader`` yields no batches.
        """
        try:
            batch = next(iter(val_loader))
        except StopIteration:
            raise ValueError("val_loader is empty: nothing to predict on") from None

        images, _ = batch
        images = images.to(self.device)

        self.autoencoder.eval()
        with torch.no_grad():
            img_out = self.autoencoder(images)

        return images.cpu(), img_out.cpu()



In [None]:
# Build the trainer and the CIFAR-10 loaders, then train for 50 epochs.
model = ModelCIFAR()

datamodule = DatamoduleCIFAR()
train_loader, val_loader = datamodule.create_loaders()

model.fit(
    train_loader=train_loader,
    val_loader=val_loader,
    num_epoch=50,
)

In [8]:
# Run the trained autoencoder on the first validation batch; predict() returns
# the input images and their reconstructions as CPU tensors.
img_orig, img_out = model.predict(val_loader)

In [9]:
def _to_displayable(batch):
    """Convert an ``(N, C, H, W)`` tensor batch to ``(N, H, W, C)`` for imshow."""
    batch = batch.detach().cpu().permute(0, 2, 3, 1)
    # imshow expects float RGB data in [0, 1]; clamp so out-of-range
    # reconstructions do not trigger clipping warnings. Integer images are
    # left untouched.
    return batch.clamp(0, 1) if batch.is_floating_point() else batch


def plot_img(img_orig, img_out, num_img=5):
    """Plot originals (top row) above their reconstructions (bottom row).

    Args:
        img_orig: Batch of original images, channels-first ``(N, C, H, W)``.
        img_out: Batch of reconstructed images with the same layout.
        num_img: Maximum number of image pairs to draw; capped at the number
            of images available in both batches.

    Raises:
        ValueError: If there is no image pair to draw.
    """
    # Never index past the end of either batch.
    num_img = min(num_img, len(img_orig), len(img_out))
    if num_img < 1:
        raise ValueError("plot_img needs at least one image to draw")

    # Do the layout conversion once instead of once per plotted image.
    rows = (_to_displayable(img_orig), _to_displayable(img_out))

    # squeeze=False keeps ``ax`` two-dimensional even when num_img == 1.
    fig, ax = plt.subplots(2, num_img, figsize=(10, 5), squeeze=False)

    for row_axes, batch in zip(ax, rows):
        for axis, img in zip(row_axes, batch[:num_img]):
            axis.imshow(img)

In [None]:
# Show five originals (top row) above their reconstructions (bottom row).
plot_img(img_orig, img_out, num_img=5)