In [None]:
!pip install pytorch_lightning

In [None]:
import numpy as np
import torch
import torch.nn as nn

import pytorch_lightning as pl
import torchvision
from torchvision.datasets import MNIST, CIFAR10, ImageNet
import os

from torchvision.transforms import ToTensor
from matplotlib import pyplot as plt

from torch.utils.data import DataLoader, random_split, Dataset
from torch.optim.lr_scheduler import StepLR

from PIL import Image


In [None]:
class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to a flat latent vector.

    Pipeline: conv -> ReLU -> conv -> 2x2 max-pool -> flatten -> linear.
    Both convolutions use ``kernel_size=3, padding=1`` and therefore keep the
    spatial size; the pooling halves it.

    Args:
        input_channels: Number of channels of the input images.
        output_channels: Number of feature maps produced by each convolution.
        linear_dim: Size of the latent vector returned by ``forward``.
        image_size: Height/width of the (square) input images. The default of
            32 together with ``output_channels=3`` gives the 768 flattened
            features this layer was originally hard-coded for.
    """

    def __init__(self, input_channels, output_channels, linear_dim, image_size=32):
        super().__init__()
        self.linear_dim = linear_dim
        self.output_channels = output_channels
        self.conv = nn.Conv2d(input_channels, output_channels, kernel_size=3, padding=1)
        # The second convolution consumes the feature maps of the first one, so
        # it needs its own (out -> out) layer. Re-applying `self.conv` only
        # worked when input_channels == output_channels.
        self.conv2 = nn.Conv2d(output_channels, output_channels, kernel_size=3, padding=1)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.flatten = nn.Flatten()
        # After the 2x2 pooling each feature map is (image_size // 2) squared.
        pooled_size = image_size // 2
        self.linear = nn.Linear(output_channels * pooled_size * pooled_size, linear_dim)

    def forward(self, x):
        """Encode ``x`` of shape (batch, input_channels, image_size, image_size).

        Returns:
            Tensor of shape (batch, linear_dim).
        """
        x = self.conv(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.pool(x)
        x = self.flatten(x)
        x = self.linear(x)
        return x


In [None]:
class Decoder(nn.Module):
    """Decoder mapping a flat latent vector back to an image batch.

    Pipeline: linear -> reshape to feature maps -> deconv -> ReLU -> deconv -> ReLU.
    Both transposed convolutions use ``kernel_size=3, padding=1`` and therefore
    keep the spatial size.

    Args:
        input_channels: Number of channels of the reconstructed images.
        output_channels: Number of feature maps the latent vector is reshaped to.
        linear_dim: Size of the latent vector received by ``forward``.
        image_size: Height/width of the (square) reconstructed images. With the
            defaults used so far (``output_channels=3``, ``linear_dim=3072``,
            ``image_size=32``) the layer sizes match the previous hard-coded
            ``Linear(3072, 3072)`` followed by a ``(3, 32, 32)`` reshape.
    """

    def __init__(self, input_channels, output_channels, linear_dim, image_size=32):
        super().__init__()
        self.output_channels = output_channels
        self.image_size = image_size
        # Project the latent vector to exactly as many values as the feature
        # maps need, so any linear_dim can be reshaped safely.
        self.linear = nn.Linear(linear_dim, output_channels * image_size * image_size)
        self.relu = nn.ReLU()
        # Hidden (out -> out) layer followed by the output (out -> in) layer.
        # Applying a single out -> in layer twice only worked when both channel
        # counts were equal.
        self.deconv_hidden = nn.ConvTranspose2d(output_channels, output_channels, kernel_size=3, padding=1)
        self.deconv = nn.ConvTranspose2d(output_channels, input_channels, kernel_size=3, padding=1)

    def forward(self, z):
        """Decode ``z`` of shape (batch, linear_dim).

        Returns:
            Tensor of shape (batch, input_channels, image_size, image_size).
        """
        x = self.linear(z)
        # Turn the flat vector back into feature maps for the transposed convs.
        x = x.view(x.size(0), self.output_channels, self.image_size, self.image_size)
        x = self.deconv_hidden(x)
        x = self.relu(x)
        x = self.deconv(x)
        # NOTE(review): the final ReLU limits reconstructions to values >= 0, so
        # targets containing negative values (e.g. after mean/std normalisation)
        # cannot be matched exactly.
        x = self.relu(x)
        return x

In [None]:

# Create a PyTorch Lightning class
class AutoEncoder(pl.LightningModule):
    """Denoising auto-encoder: reconstructs clean images from noisy inputs.

    Every step adds Gaussian noise to the batch, encodes and decodes the noisy
    images and minimises the MSE against the clean images.

    Args:
        input_channels: Number of channels of the images.
        output_channels: Number of feature maps used by encoder and decoder.
        linear_dim: Size of the latent vector between encoder and decoder.
        noise_std: Standard deviation of the Gaussian noise added to the
            inputs. The default of 1.0 keeps the previous behaviour.
    """

    def __init__(self, input_channels, output_channels, linear_dim, noise_std=1.0):
        super().__init__()
        self.input_shape = input_channels
        self.num_hidden = output_channels
        self.latent_dim = linear_dim
        self.noise_std = noise_std
        self.encoder = Encoder(input_channels, output_channels, linear_dim)
        self.decoder = Decoder(input_channels, output_channels, linear_dim)
        # Filled by test_step (first batch only) and consumed by on_test_end.
        self.original = None
        self.images_noisy = None
        self.reconstructions = None

    def forward(self, x):
        """Encode and decode ``x`` (no noise is added here)."""
        x = self.encoder(x)
        x = self.decoder(x)
        return x

    def configure_optimizers(self):
        """Adam optimiser with a StepLR schedule (lr x0.1 every 30 scheduler steps)."""
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        scheduler = StepLR(optimizer, step_size=30, gamma=0.1)
        return [optimizer], [scheduler]

    def add_noise(self, x):
        """Return ``x`` corrupted with zero-mean Gaussian noise of std ``noise_std``."""
        noise = torch.randn_like(x)
        x_noisy = x + self.noise_std * noise
        return x_noisy

    def _denoise_step(self, batch, stage):
        """Shared train/val/test logic: corrupt, reconstruct, log ``<stage>_loss``.

        Returns:
            Tuple ``(loss, x_noisy, x_hat)``.
        """
        x = batch
        x_noisy = self.add_noise(x)
        x_hat = self(x_noisy)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log(f"{stage}_loss", loss, prog_bar=True, logger=True)
        return loss, x_noisy, x_hat

    def training_step(self, batch, batch_idx):
        """Training step: returns the reconstruction loss."""
        loss, _, _ = self._denoise_step(batch, "train")
        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step: returns the reconstruction loss."""
        loss, _, _ = self._denoise_step(batch, "val")
        return loss

    def test_step(self, batch, batch_idx):
        """Test step: returns the loss and keeps 8 images of the first batch for plotting."""
        loss, x_noisy, x_hat = self._denoise_step(batch, "test")
        if batch_idx == 0:
            # Keep detached CPU copies so plotting does not hold on to device
            # memory or the autograd graph.
            self.original = batch[:8].detach().cpu()
            self.images_noisy = x_noisy[:8].detach().cpu()
            self.reconstructions = x_hat[:8].detach().cpu()
        return loss

    def on_test_end(self):
        """Plot original, noisy and reconstructed images side by side (one row per image)."""
        # Nothing to plot if the test loop never produced a first batch.
        if self.original is None or len(self.original) == 0:
            return

        columns = (
            ("Original Image", self.original),
            ("Image with noise", self.images_noisy),
            ("Reconstruction", self.reconstructions),
        )
        n_rows = len(self.original)
        fig, axes = plt.subplots(n_rows, 3, figsize=(6, 2 * n_rows), squeeze=False)

        for col, (title, images) in enumerate(columns):
            axes[0, col].set_title(title)
            for row, image in enumerate(images):
                ax = axes[row, col]
                # imshow expects (H, W, C); float RGB data outside [0, 1] would
                # be clipped by matplotlib anyway, so clamp explicitly.
                ax.imshow(image.permute(1, 2, 0).clamp(0, 1).numpy())
                ax.axis("off")

        plt.show()




In [None]:

class TinyImageNetDataModule(pl.LightningDataModule):
    """Lightning data module serving the train/val/test splits of Tiny ImageNet.

    Args:
        data_dir: Root folder containing the ``train``, ``val`` and ``test``
            sub-folders.
        batch_size: Batch size used by all three dataloaders.
        num_workers: Worker processes per dataloader (0 loads in the main
            process, which is the previous behaviour).
    """

    def __init__(self, data_dir, batch_size=32, num_workers=0):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        # Populated by setup().
        self.transform = None
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def setup(self, stage=None):
        """Build the transform and one dataset per split (``stage`` is ignored)."""
        # NOTE(review): (0.1307, 0.3081) look like the single-channel MNIST
        # statistics; confirm they are intended for RGB images.
        self.transform = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize((0.1307,), (0.3081,)),
            torchvision.transforms.Resize((32, 32)),
        ])

        self.train_dataset = self._make_dataset('train')
        self.val_dataset = self._make_dataset('val')
        self.test_dataset = self._make_dataset('test')

    def _make_dataset(self, split):
        """Create the dataset for one split folder below ``data_dir``."""
        return TinyImageNetDataset(
            data_dir=os.path.join(self.data_dir, split),
            transform=self.transform,
        )

    def _make_loader(self, dataset, shuffle=False):
        """Wrap ``dataset`` in a DataLoader with the module-wide settings."""
        return torch.utils.data.DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
        )

    def _ensure_setup(self):
        """Run setup() lazily so the loaders also work when called directly."""
        if self.transform is None:
            self.setup()

    def train_dataloader(self):
        """Shuffled loader over the training split."""
        self._ensure_setup()
        # Reuse the dataset built in setup() instead of re-scanning the folder.
        # Shuffle because the files are listed class folder by class folder.
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Loader over the validation split."""
        self._ensure_setup()
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        """Loader over the test split."""
        self._ensure_setup()
        return self._make_loader(self.test_dataset)
    

In [None]:

class TinyImageNetDataset(Dataset):
    """Unlabelled image dataset over one split folder of Tiny ImageNet.

    Two folder layouts are supported and detected from the directory itself:

    * flat:      ``<data_dir>/images/<file>``
    * per class: ``<data_dir>/<class folder>/images/<file>``

    Args:
        data_dir: Folder of the split.
        transform: Optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, data_dir, transform=None):
        self.data_dir = data_dir
        self.transform = transform
        self.image_paths = self.load_data()

    def load_data(self):
        """Collect the image paths of the split in a deterministic (sorted) order.

        Returns:
            List of file paths; empty if no ``images`` folder is found.

        Raises:
            FileNotFoundError: if ``data_dir`` does not exist and has no
                ``images`` sub-folder (raised by ``os.listdir``).
        """
        flat_dir = os.path.join(self.data_dir, 'images')
        if os.path.isdir(flat_dir):
            image_dirs = [flat_dir]
        else:
            # Detect the per-class layout from the folder structure rather than
            # by comparing against one hard-coded absolute path.
            image_dirs = [
                os.path.join(self.data_dir, folder, 'images')
                for folder in sorted(os.listdir(self.data_dir))
            ]

        image_paths = []
        for image_dir in image_dirs:
            # Skip plain files and class folders without an `images` folder.
            if not os.path.isdir(image_dir):
                continue
            image_paths.extend(
                os.path.join(image_dir, image_name)
                for image_name in sorted(os.listdir(image_dir))
            )
        return image_paths

    def __len__(self):
        """Number of images found in the split."""
        return len(self.image_paths)

    def __getitem__(self, index):
        """Load image ``index`` as RGB and apply the transform, if any.

        Raises:
            FileNotFoundError: if the file vanished since the folder was
                scanned. Returning ``None`` instead would only surface later
                as an obscure collate error.
        """
        image_path = self.image_paths[index]
        if not os.path.isfile(image_path):
            raise FileNotFoundError(f"Image file not found: {image_path}")
        # convert() loads the pixel data, so the file can be closed right away.
        with Image.open(image_path) as handle:
            image = handle.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image


In [None]:
# Noise injection and weight initialisation are random: fix the seed so a
# Restart & Run All reproduces the same run.
pl.seed_everything(42)

# Create the model, the data module and the trainer, then train.
model = AutoEncoder(input_channels=3, output_channels=3, linear_dim=3072)
dataset = TinyImageNetDataModule(data_dir="/kaggle/input/tiny-imagenet/tiny-imagenet-200", batch_size=32)
# "auto" uses the available GPU(s) and falls back to the CPU when there is
# none, instead of failing on machines without a GPU.
trainer = pl.Trainer(max_epochs=3, accelerator="auto", devices="auto")
trainer.fit(model, dataset)

In [None]:
# Evaluate the trained model on the test split. The trainer calls
# `dataset.setup("test")` itself, so no manual setup call is needed; passing
# the model explicitly makes clear which weights are evaluated.
trainer.test(model, datamodule=dataset)