In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import torch.backends.cudnn as cudnn
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
import matplotlib.pyplot as plt
import time
import os
from PIL import Image
from tempfile import TemporaryDirectory
from torch.utils.data import DataLoader

### Data Loading

In [None]:
data_dir = 'Dataset'

# Preprocessing / loading configuration (tune here, not inside the pipeline).
IMAGE_SIZE = (224, 224)
IMAGENET_MEAN = [0.485, 0.456, 0.406]  # standard ImageNet per-channel mean
IMAGENET_STD = [0.229, 0.224, 0.225]   # standard ImageNet per-channel std
TRAIN_FRACTION = 0.8  # share of 'Training' used for training; the rest is validation
SPLIT_SEED = 42       # makes the train/val split reproducible across runs
BATCH_SIZE = 32
NUM_WORKERS = 0

# Validation and test share one deterministic pipeline (resize + normalize,
# no augmentation).
eval_transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

data_transforms = {
    'train': transforms.Compose([
        transforms.Resize(IMAGE_SIZE),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomRotation(degrees=15),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
    'val': eval_transform,
    'test': eval_transform,
}

train_dir = os.path.join(data_dir, 'Training')
test_dir = os.path.join(data_dir, 'Testing')

# Build the full training folder with the augmenting train transform up front,
# rather than patching its `.transform` attribute after the split.
full_train_dataset = datasets.ImageFolder(train_dir, transform=data_transforms['train'])

# Split training data into train and validation (TRAIN_FRACTION train, rest val)
# with a seeded generator so the split is identical on every run.
generator = torch.Generator().manual_seed(SPLIT_SEED)
train_size = int(TRAIN_FRACTION * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_split = torch.utils.data.random_split(
    full_train_dataset, [train_size, val_size], generator=generator
)

# `val_split` indexes into `full_train_dataset`, which applies train-time
# augmentation. Re-wrap the same indices around a second ImageFolder over the
# same directory that uses the val transform instead. This relies on both
# ImageFolder instances enumerating the directory in the same order.
val_dataset = torch.utils.data.Subset(
    datasets.ImageFolder(train_dir, transform=data_transforms['val']),
    val_split.indices,
)

test_dataset = datasets.ImageFolder(test_dir, transform=data_transforms['test'])

# ImageFolder assigns label indices from the class-folder names, so a missing
# or extra folder under 'Testing' would silently shift the test labels.
assert test_dataset.classes == full_train_dataset.classes, (
    f"class mismatch: Training={full_train_dataset.classes}, "
    f"Testing={test_dataset.classes}"
)

image_datasets = {
    'train': train_dataset,
    'val': val_dataset,
    'test': test_dataset,
}

print(full_train_dataset.classes)
print('training images:', len(image_datasets['train']))
print('validation images:', len(image_datasets['val']))
print('testing images:', len(image_datasets['test']))

dataloaders = {
    split: DataLoader(
        image_datasets[split],
        batch_size=BATCH_SIZE,
        shuffle=(split == 'train'),  # shuffle training batches only
        num_workers=NUM_WORKERS,
    )
    for split in ('train', 'val', 'test')
}

### Show Sample Images

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Helpers to undo the ImageNet normalization applied by the transforms and
# display a C x H x W image tensor.
def unnormalize_image(inp, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Convert a normalized C x H x W image tensor into an H x W x C array in [0, 1].

    Args:
        inp: Image tensor of shape (C, H, W) that was normalized as
            ``(x - mean) / std``. It may live on any device or require grad.
        mean: Per-channel means used for normalization (defaults to ImageNet).
        std: Per-channel standard deviations used for normalization
            (defaults to ImageNet).

    Returns:
        numpy.ndarray of shape (H, W, C) with values clipped to [0, 1].
    """
    # .detach().cpu() lets this accept GPU tensors and tensors that require
    # grad; for a plain CPU tensor both calls are no-ops.
    img = inp.detach().cpu().numpy().transpose((1, 2, 0))  # C x H x W -> H x W x C
    img = np.asarray(std) * img + np.asarray(mean)  # invert Normalize
    return np.clip(img, 0, 1)  # rounding can push values slightly outside [0, 1]


def imshow(inp, title=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Display a normalized C x H x W image tensor with matplotlib.

    Args:
        inp: Image tensor of shape (C, H, W), normalized with ``mean``/``std``.
        title: Optional title. A list or tuple of labels is joined with ", "
            so it does not render as a Python repr.
        mean: Per-channel normalization means (defaults to ImageNet).
        std: Per-channel normalization standard deviations (defaults to ImageNet).
    """
    plt.imshow(unnormalize_image(inp, mean=mean, std=std))
    if title is not None:
        if isinstance(title, (list, tuple)):
            title = ", ".join(str(t) for t in title)
        plt.title(title)
    plt.axis("off")

# Plot the first few images of one batch drawn from a DataLoader.
def show_random_batch(dataloader, class_names, num_images=6):
    """Show a grid of the first ``num_images`` images from one batch of ``dataloader``.

    The batch is random only if the loader shuffles. Augmentations appear only
    if the underlying dataset's transform applies them.

    Args:
        dataloader: Iterable whose batches unpack as ``(images, labels)``.
            The images are assumed to be normalized C x H x W tensors that
            ``imshow`` can display.
        class_names: Sequence mapping an integer label to its class name.
        num_images: Maximum number of images to show. Fewer are shown if the
            batch is smaller.
    """
    inputs, classes = next(iter(dataloader))
    grid = torchvision.utils.make_grid(inputs[:num_images])  # take first N images
    # Build a plain string title. Passing the raw list to plt.title would render
    # its Python repr (brackets and quotes).
    labels = [str(class_names[int(label)]) for label in classes[:num_images]]
    imshow(grid, title=", ".join(labels))

# Preview five images from one shuffled training batch (train augmentations applied).
show_random_batch(
    dataloader=dataloaders['train'],
    class_names=image_datasets['train'].dataset.classes,
    num_images=5,
)