# Libraries

In [38]:
import sys

sys.path.append("../")

from tqdm.notebook import tqdm

import numpy as np

from src.model import VisionModel
from src.noises import add_noise

import matplotlib.pyplot as plt

import cv2
from PIL import Image

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

# 1. Transform function

In [43]:
class AddNoise(torch.nn.Module):
    """Transform module that corrupts an image with two noise passes.

    The input is converted to a NumPy array, passed through ``add_noise``
    with ``noise_type='exponential'`` and then again with
    ``noise_type='gaussian'`` (both using ``scale=100`` and ``sigma=150``),
    and the result is turned back into a PIL image.
    """

    def forward(self, img):
        """Return a noisy PIL image built from ``img``.

        Args:
            img: Anything ``np.array`` accepts; the dataset in this notebook
                passes PIL images.

        Returns:
            The PIL image produced by ``Image.fromarray`` from the noisy array.
        """
        pixels = np.array(img)
        # Order matters: exponential noise first, gaussian noise on top of it.
        for kind in ('exponential', 'gaussian'):
            pixels = add_noise(pixels, noise_type=kind, scale=100, sigma=150)

        return Image.fromarray(pixels)

In [44]:
# Preprocessing applied to an image before it reaches the model:
# resize to 28x28, convert to a tensor, then normalize.
_preprocessing_steps = [
    transforms.Resize((28, 28)),
    transforms.ToTensor(),
    # A single mean/std entry, i.e. one channel.
    # NOTE(review): 0.485 / 0.229 are not derived anywhere in this notebook;
    # confirm they are the intended statistics for this data.
    transforms.Normalize(mean=[0.485], std=[0.229]),
]
transform = transforms.Compose(_preprocessing_steps)

# 2. Load Train and Test Dataset

## 2.1 MNIST Dataset

In [74]:
class MNISTCustomDataset(MNIST):
    """MNIST variant that yields (clean image, noisy image) pairs.

    The class label is not part of the returned sample: each item is the
    transformed original digit together with a transformed, noise-corrupted
    copy of the same digit.
    """

    def __init__(self, root_dir, train, transform_function, noise_function):
        """
        Args:
            root_dir: Dataset root forwarded to ``MNIST``; the parent is
                constructed with ``download=True``.
            train: Forwarded to ``MNIST`` as ``train``.
            transform_function: Callable applied to both the clean and the
                noisy PIL image.
            noise_function: Callable applied to a copy of the clean PIL image
                before ``transform_function``.
        """
        super().__init__(root_dir, download=True, train=train)
        self.transform_function = transform_function
        self.noise_function = noise_function

    def __getitem__(self, index: int):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (original_img, noise_img) where ``original_img`` is the
            transformed clean image and ``noise_img`` is the transformed
            noisy copy of the same image.
        """
        img = self.data[index]

        # Go through a PIL image so the noise function and the transform
        # receive the same kind of input as with other image datasets.
        original_img = Image.fromarray(img.numpy(), mode="L")
        # Copy before transforming so the noise is applied to the raw image,
        # not to the already transformed one.
        noise_img = original_img.copy()

        original_img = self.transform_function(original_img)

        noise_img = self.noise_function(noise_img)
        noise_img = self.transform_function(noise_img)

        return original_img, noise_img

In [82]:
# Train and test splits share the same root, transform and noise settings;
# each split gets its own AddNoise instance.
mnist_dataset_train = MNISTCustomDataset(
    root_dir="./",
    train=True,
    transform_function=transform,
    noise_function=AddNoise(),
)
mnist_dataset_test = MNISTCustomDataset(
    root_dir="./",
    train=False,
    transform_function=transform,
    noise_function=AddNoise(),
)

In [83]:
# Report which fraction of all samples falls into each split.
train_count = len(mnist_dataset_train)
test_count = len(mnist_dataset_test)
total = train_count + test_count

train_share = round(train_count / total, 2)
test_share = round(test_count / total, 2)
print(f"Train: {train_share}\nTest: {test_share}")

Train: 0.86
Test: 0.14


## 2.2 MNIST Dataloader

In [113]:
batch_size = 16
num_workers = 1

# Options common to both loaders; only the shuffling differs.
_loader_options = {"batch_size": batch_size, "num_workers": num_workers}

# shuffle=True for training, shuffle=False for the test split.
mnist_dataloader_train = DataLoader(
    mnist_dataset_train, shuffle=True, **_loader_options
)
mnist_dataloader_test = DataLoader(
    mnist_dataset_test, shuffle=False, **_loader_options
)

# 3. Vision Transformer Model

## 3.1 Model

In [114]:
# NOTE(review): img_size looks like (batch, channels, height, width) for
# single-channel 28x28 images - confirm against VisionModel.
_img_size = (batch_size, 1, 28, 28)

# Build the model and move it to the GPU in one step.
model = VisionModel(img_size=_img_size, patch_size=4, token_len=512).cuda()

## 3.2 Loss Function

In [115]:
# Mean squared error between two tensors.
# "mean" is the default reduction; it is spelled out here for readability.
loss_fn = nn.MSELoss(reduction="mean")

## 3.3 Optimizer

In [116]:
# Adam over all model parameters with a fixed learning rate.
lr = 1e-3
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

# 4. Train model

In [121]:
epochs = 64

model = model.train()

# Compute the loss on the device the model lives on. Moving the target there
# avoids copying every model output back to the CPU just to compare it.
device = next(model.parameters()).device

for epoch in tqdm(range(1, epochs + 1)):
    total_loss = 0.0   # sum of per-batch losses seen so far in this epoch
    num_batches = 0

    pbar = tqdm(mnist_dataloader_train)
    for ori_img, noi_img in pbar:
        # The clean image is the regression target, the noisy one the input.
        ori_img = ori_img.to(device)
        noi_img = noi_img.to(device)

        denoised_img = model(noi_img)
        loss = loss_fn(denoised_img, ori_img)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # nn.MSELoss with its default reduction already averages over the
        # batch, so the per-batch value is accumulated as-is (no extra
        # division by batch_size) and reported as a running mean.
        total_loss += loss.item()
        num_batches += 1
        pbar.set_description(f"Loss: {total_loss / num_batches}")

    # Mean loss over the epoch; guard against an empty loader.
    print(total_loss / num_batches if num_batches else total_loss)
    

  0%|          | 0/64 [00:00<?, ?it/s]

  0%|          | 0/3750 [00:00<?, ?it/s]

KeyboardInterrupt: 