### 0: IMPORTING LIBRARIES AND SETTING THE SEEDS

In [None]:
# Importing necessary libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms
from torchvision.datasets import ImageFolder
from pathlib import Path
import pytorch_lightning as pl
from typing import Tuple
import PIL
from PIL import Image
from pytorch_lightning.callbacks.progress import TQDMProgressBar
import csv
from torchmetrics.functional import accuracy
import numpy as np
import torchvision
from torchvision.transforms import ToPILImage
import matplotlib.pyplot as plt
import numpy as np
from torchvision import transforms
import cv2
from torchvision import datasets
from torchmetrics.classification import Accuracy, MulticlassF1Score, MulticlassPrecision, MulticlassRecall
from sklearn.metrics import classification_report
import matplotlib.pyplot as plt
import seaborn as sns
from torch.utils.data import DataLoader
from torch.utils.data import DataLoader, random_split
from pytorch_lightning.loggers import TensorBoardLogger, CSVLogger
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.utils.class_weight import compute_class_weight
from torch.nn import CrossEntropyLoss
from pytorch_lightning.callbacks import TQDMProgressBar, LearningRateMonitor, ModelCheckpoint
import math

In [None]:
# Single seed shared by every random number generator in this notebook.
SEED = 31

# Standalone NumPy RandomState seeded with the same value; it keeps its own
# state, separate from the global np.random.* functions.
rng = np.random.RandomState(SEED)

# pl.seed_everything seeds Python's `random`, NumPy's global RNG and torch in a
# single call, so a separate np.random.seed(...) beforehand is redundant.
pl.seed_everything(SEED)

# NOTE: for bit-exact GPU reproducibility you would additionally need
# torch.backends.cudnn.deterministic = True; it is intentionally not enabled here.

### 1: DATA INSPECTION

#### 1.1: CREATION OF THE LABEL DICTIONARY

In [None]:
# Map each encoded label (left column of words.txt) to its human-readable
# label (right column).
mapping_dict = {}

# Explicit encoding so the result does not depend on the platform default.
with open('/kaggle/input/tiny-imagenet/tiny-imagenet-200/words.txt', 'r', encoding='utf-8') as file:
    for line in file:
        # Columns are tab-separated (not split on arbitrary whitespace, since
        # the human-readable label itself may contain spaces).
        tokens = line.strip().split('\t')

        # Skip blank or malformed lines that have no label column.
        if len(tokens) >= 2:
            encoded_label, actual_label = tokens[0], tokens[1]
            mapping_dict[encoded_label] = actual_label


#### 1.2: DISPLAYING EXAMPLES OF THE DATASET

In [None]:
# Load the raw training set (transform=None, so each item is a PIL image).
dataset0 = datasets.ImageFolder(root="/kaggle/input/tiny-imagenet/tiny-imagenet-200/train/", transform=None)

# Extract class names and the number of images per class.
class_names = dataset0.classes
# One bincount pass over the targets instead of a list.count() per class,
# which rescans every target once for each class.
class_counts = np.bincount(dataset0.targets, minlength=len(class_names)).tolist()

# Re-seed so the same images are drawn every time this cell runs.
np.random.seed(31)

# Create a 2x5 grid of randomly chosen images titled with their class folder name.
plt.figure(figsize=(15, 8))
for i in range(10):
    index = np.random.randint(len(dataset0))
    image, label = dataset0[index]

    plt.subplot(2, 5, i + 1)
    plt.imshow(np.array(image))  # PIL Image -> numpy array for imshow
    plt.title(f"Label: {class_names[label]}")
    plt.axis('off')

plt.tight_layout()
plt.show()

In [None]:
# Reload the raw training set (transform=None) so this cell runs on its own.
dataset0 = datasets.ImageFolder(root="/kaggle/input/tiny-imagenet/tiny-imagenet-200/train/", transform=None)

# Extract class names (encoded folder names, translated via mapping_dict below)
# and the number of images per class.
class_names = dataset0.classes
# One bincount pass instead of a list.count() rescan of all targets per class.
class_counts = np.bincount(dataset0.targets, minlength=len(class_names)).tolist()

# Re-seed so the sample drawn below is reproducible.
np.random.seed(31)

# Create a 2x5 grid of random images titled with their human-readable label.
plt.figure(figsize=(15, 8))
for i in range(10):
    # Randomly select an image and its class index.
    index = np.random.randint(len(dataset0))
    image, encoded_label = dataset0[index]
    # Translate the class folder name into a readable label.
    actual_label = mapping_dict.get(class_names[encoded_label], "Unknown Label")

    # Trim labels longer than 15 characters so the titles stay short.
    actual_label_trimmed = actual_label[:15] + '...' if len(actual_label) > 15 else actual_label

    plt.subplot(2, 5, i + 1)
    plt.imshow(np.array(image))  # PIL Image -> numpy array for imshow
    plt.title(f"Label: {actual_label_trimmed}", wrap=True)
    plt.axis('off')

plt.tight_layout()
plt.show()

In [None]:
class CustomValidationDataset(torch.utils.data.Dataset):
    """Validation split stored as flat ``val_*.JPEG`` files plus an annotation file.

    Args:
        root: directory containing the ``val_*.JPEG`` images.
        transform: optional callable applied to each RGB PIL image.
        label_path: tab-separated annotation file whose first column is the
            image file name and second column is its label. Defaults to the
            Kaggle Tiny ImageNet ``val_annotations.txt``.
        class_to_idx: optional mapping from label string to integer index
            (e.g. a training ``ImageFolder.class_to_idx``). When given, labels
            are returned as these integers instead of raw strings.
    """

    DEFAULT_LABEL_PATH = "/kaggle/input/tiny-imagenet/tiny-imagenet-200/val/val_annotations.txt"

    def __init__(self, root, transform=None, label_path=None, class_to_idx=None):
        self.root = Path(root)
        self.transform = transform
        self.label_path = Path(label_path if label_path is not None else self.DEFAULT_LABEL_PATH)
        self.class_to_idx = class_to_idx
        # Sorted for a deterministic index -> file order.
        self.image_paths = sorted(self.root.glob("val_*.JPEG"))
        self.labels = self.load_labels()

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        """Return ``(image, label)`` for the image at ``index``."""
        image_path = self.image_paths[index]
        # Context manager closes the underlying file once converted.
        with Image.open(image_path) as img:
            image = img.convert("RGB")

        if self.transform:
            image = self.transform(image)

        # load_labels keys by file stem, so this matches regardless of whether
        # the annotation file lists names with or without the extension.
        label = self.labels[image_path.stem]

        return image, label

    def load_labels(self):
        """Parse ``self.label_path`` into a ``{image stem: label}`` dict."""
        labels = {}

        with open(self.label_path, "r", encoding="utf-8") as f:
            for line in f:
                parts = line.strip().split("\t")
                # Skip blank or malformed lines (e.g. a trailing newline at EOF).
                if len(parts) < 2:
                    continue
                image_name, label = parts[0], parts[1]
                if self.class_to_idx is not None:
                    label = self.class_to_idx[label]
                # Normalize to the stem: __getitem__ looks up by Path.stem,
                # which drops the ".JPEG" extension present on the image files.
                labels[Path(image_name).stem] = label

        return labels


In [None]:
class AViT_DataModule(pl.LightningDataModule):
    """DataModule pairing an ImageFolder training set with a flat validation dir.

    Args:
        train_data_dir: root of the class-per-subfolder training images.
        val_data_dir: directory passed to ``CustomValidationDataset`` as ``root``.
        batch_size: batch size for both loaders.
        num_workers: number of DataLoader worker processes.
    """

    def __init__(self, train_data_dir, val_data_dir, batch_size, num_workers=4):
        # Zero-argument super() binds to this class; naming a different class
        # (one this instance does not inherit from) raises TypeError.
        super().__init__()
        self.train_data_dir = train_data_dir
        self.val_data_dir = val_data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers

        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

    def setup(self, stage=None):
        # Load Train dataset.
        self.train_dataset = ImageFolder(self.train_data_dir, transform=self.transform)

        # Load Validation dataset.
        # NOTE(review): labels come from CustomValidationDataset as-is; confirm
        # they are integer indices consistent with train_dataset.class_to_idx.
        self.val_dataset = CustomValidationDataset(root=self.val_data_dir, transform=self.transform)

    def train_dataloader(self):
        """Shuffled DataLoader over the training images."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        """Unshuffled DataLoader over the validation images."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)


class HW2_DataModule(pl.LightningDataModule):
    """DataModule that holds out a validation split from an ImageFolder train dir.

    Args:
        train_data_dir: root of the class-per-subfolder training images.
        test_data_dir: stored for a future test split; ``setup`` does not build
            a test dataset, so ``test_dataloader`` raises until one is assigned.
        batch_size: batch size for every loader.
        num_workers: number of DataLoader worker processes.
        val_split: fraction of training images held out for validation, in [0, 1).
        split_seed: seed for the train/validation split so it is reproducible.

    Raises:
        ValueError: if ``val_split`` is outside [0, 1).
    """

    def __init__(self, train_data_dir, test_data_dir, batch_size, num_workers=4, val_split=0.2, split_seed=31):
        super().__init__()
        if not 0 <= val_split < 1:
            raise ValueError(f"val_split must be in [0, 1), got {val_split}")
        self.train_data_dir = train_data_dir
        self.test_data_dir = test_data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.val_split = val_split
        self.split_seed = split_seed

        self.transform = transforms.Compose([
            transforms.ToTensor()
        ])

        # Populated by setup().
        self.train_dataset = None
        self.val_dataset = None

    def setup(self, stage=None):
        """Split the training folder into train/validation subsets (once)."""
        # setup() may be invoked for several stages; splitting only once keeps
        # the partition fixed so validation images never leak into training.
        if self.train_dataset is not None:
            return

        full_train = ImageFolder(self.train_data_dir, transform=self.transform)
        n_val = int(len(full_train) * self.val_split)
        n_train = len(full_train) - n_val
        # Dedicated generator: the split does not depend on global RNG state.
        generator = torch.Generator().manual_seed(self.split_seed)
        self.train_dataset, self.val_dataset = random_split(full_train, [n_train, n_val], generator=generator)

    def train_dataloader(self):
        """Shuffled DataLoader over the training subset."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, num_workers=self.num_workers, shuffle=True)

    def val_dataloader(self):
        """Unshuffled DataLoader over the held-out validation subset."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    def test_dataloader(self):
        """DataLoader over ``self.test_dataset``, which must be assigned externally."""
        test_dataset = getattr(self, "test_dataset", None)
        if test_dataset is None:
            raise AttributeError("HW2_DataModule.test_dataset is not set: setup() does not build a test split")
        return DataLoader(test_dataset, batch_size=self.batch_size, num_workers=self.num_workers)

    