# Notebook 1: Data Preprocessing & Loading

## Imports

In [None]:

import os
import torch
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from PIL import Image
import matplotlib.pyplot as plt
from collections import Counter
import numpy as np


## Configuration & Paths

In [None]:

# Root folder of the dataset (a relative path).
data_root = r"DATA/DATASET"

# Each split lives in its own sub-folder directly under the dataset root.
train_dir, test_dir = (
    os.path.join(data_root, split) for split in ('train', 'test')
)


## Data Cleaning & Validation

In [None]:

def clean_and_count(root_dir, split_name):
    """Scan ``root_dir`` for readable images and tally them per class.

    Walks the directory tree, verifies every file whose extension is
    ``.jpg``, ``.jpeg`` or ``.png`` (case-insensitive) with Pillow, and prints
    a short summary. Nothing on disk is modified or deleted; files that fail
    verification are only reported and left out of the returned results.

    Args:
        root_dir: Directory to walk recursively.
        split_name: Label used in the printed summary header (e.g. ``'Train'``).

    Returns:
        A ``(valid_images, class_counts)`` tuple: the list of paths that
        passed verification, and a dict mapping the name of each file's
        immediate parent directory to the number of valid images in it.
    """
    # Leading dots so that only genuine extensions match; a file literally
    # named "readmejpg" is not an image candidate.
    image_extensions = ('.jpg', '.jpeg', '.png')
    valid_images = []
    class_counts = {}

    for root, _dirs, files in os.walk(root_dir):
        # The class label is the folder that directly contains the file.
        class_name = os.path.basename(root)
        for file in files:
            if not file.lower().endswith(image_extensions):
                continue

            img_path = os.path.join(root, file)
            try:
                # The context manager releases the file handle even when
                # verify() raises, which a trailing close() call would miss.
                with Image.open(img_path) as img:
                    img.verify()
            except Exception as e:
                print(f'Corrupt image skipped: {img_path} | {e}')
                continue

            valid_images.append(img_path)
            class_counts[class_name] = class_counts.get(class_name, 0) + 1

    print(f'--- {split_name} Dataset Summary ---')
    print(f'Total valid images: {len(valid_images)}')
    print(f'Class distribution: {class_counts}\n')

    return valid_images, class_counts


## Run Cleaning

In [None]:

# Validate both splits; each call prints its own summary.
train_valid, train_counts = clean_and_count(root_dir=train_dir, split_name='Train')
test_valid, test_counts = clean_and_count(root_dir=test_dir, split_name='Test')


## Data Transformations

In [None]:

# Training pipeline: resize, light random augmentation, then tensor conversion
# and per-channel normalization with mean 0.5 / std 0.5.
train_transform = transforms.Compose([
    transforms.Resize(size=(128, 128)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(degrees=15),
    transforms.ColorJitter(brightness=0.1, contrast=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# Evaluation pipeline: same resize and normalization, no random augmentation.
test_transform = transforms.Compose([
    transforms.Resize(size=(128, 128)),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])


## Dataset & DataLoader

In [None]:

# Build the datasets straight from the split folders.
# NOTE(review): the lists returned by clean_and_count are not used here, so any
# file it reported as corrupt is presumably still indexed by ImageFolder — confirm.
train_dataset = ImageFolder(train_dir, transform=train_transform)
test_dataset = ImageFolder(test_dir, transform=test_transform)

batch_size = 32

# Only the training loader shuffles; the test loader keeps dataset order.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

print(f'Classes: {train_dataset.classes}')
print(f'Number of classes: {len(train_dataset.classes)}')


## Visualization

In [None]:

def show_sample(dataset, n=5):
    """Display the first ``n`` samples of ``dataset`` side by side.

    Each sample is expected to be a ``(image_tensor, label_index)`` pair with
    the image in channel-first layout, normalized with mean 0.5 / std 0.5 as
    done by the transforms in this notebook.

    Args:
        dataset: Indexable dataset exposing ``__len__`` and a ``classes``
            list that maps label indices to class names.
        n: Maximum number of samples to show. Capped at ``len(dataset)`` so a
            small dataset does not raise ``IndexError``.
    """
    count = min(n, len(dataset))

    plt.figure(figsize=(15, 5))
    for i in range(count):
        img, label = dataset[i]
        # Channel-first -> channel-last for matplotlib, then undo the
        # normalization. Clamp so float round-off or out-of-range values do
        # not trigger matplotlib's clipping warning.
        img = (img.permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1)
        plt.subplot(1, count, i + 1)
        plt.imshow(img)
        plt.title(dataset.classes[label])
        plt.axis('off')
    plt.show()

# Preview the first few training samples; train_dataset applies train_transform,
# so the random augmentations are visible in the output.
show_sample(train_dataset)
