# In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import os
import numpy as np
import cv2

from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from sklearn.model_selection import train_test_split
from collections import Counter

# Configuration
IMG_SIZE = 64  # Side length in pixels; every image is resized to IMG_SIZE x IMG_SIZE
DATA_DIR = "./archive/simpsons_dataset"  # Dataset root: one sub-folder per character
MAX_IMAGES_PER_CLASS = 500  # Cap on images loaded per character, to bound memory use
BATCH_SIZE = 32
NUM_EPOCHS = 20
INITIAL_LR = 0.001  # Initial learning rate
LR_STEP_SIZE = 5  # Reduce learning rate every 5 epochs
LR_GAMMA = 0.1  # Multiply learning rate by 0.1 at each step (NOTE(review): an older comment said 0.5 — confirm intent)
VALID_EXTENSIONS = {'.jpg', '.jpeg', '.png'}  # Lower-case extensions treated as images (names are lower-cased before matching)

# 1. Define data augmentation and preprocessing.
# Both pipelines take an H x W x 3 uint8 numpy image, convert it to a PIL image,
# resize it to IMG_SIZE x IMG_SIZE and return a [C, H, W] float tensor whose
# channels are mapped from [0, 1] to [-1, 1] via (x - 0.5) / 0.5.
# Only the training pipeline inserts random augmentation in the middle.
def _pipeline(*augmentations):
    """Return a Compose: PIL conversion + resize, *augmentations*, tensor + normalisation."""
    return transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((IMG_SIZE, IMG_SIZE)),
        *augmentations,
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ])


train_transforms = _pipeline(
    transforms.RandomHorizontalFlip(p=0.5),  # mirror roughly half of the samples
    transforms.RandomRotation(degrees=15),  # rotate by a random angle in [-15, 15]
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
)

# Evaluation images get the deterministic part of the pipeline only.
test_transforms = _pipeline()

# 2. Map-style Dataset over in-memory images/labels with an optional per-sample transform
class SimpsonsDataset(Dataset):
    """Position-indexed collection of ``(image, label)`` pairs.

    Args:
        images: Sequence of images indexable by position (in this script, the
            uint8 array of stacked H x W x 3 images).
        labels: Sequence of labels aligned index-for-index with ``images``.
        transform: Optional callable applied to each image as it is fetched;
            labels are always returned unchanged.
    """

    def __init__(self, images, labels, transform=None):
        self.images = images
        self.labels = labels
        self.transform = transform

    def __len__(self):
        # The dataset size is defined by the image collection.
        return len(self.images)

    def __getitem__(self, idx):
        sample = self.images[idx]
        target = self.labels[idx]
        if not self.transform:
            return sample, target
        return self.transform(sample), target

# 3. Load and preprocess the data
# Every sub-directory of DATA_DIR is one character class; sorting gives a
# stable name -> index assignment.
label_names = sorted(
    name for name in os.listdir(DATA_DIR) if os.path.isdir(os.path.join(DATA_DIR, name))
)
# NOTE(review): label_map indexes *all* character folders, but only the top-10
# subset chosen below is loaded, so the integer labels that get stored are
# generally not the contiguous range 0..9. A classifier head sized to
# len(top_characters) would need these remapped — confirm against the model.
label_map = {name: idx for idx, name in enumerate(label_names)}

# Restrict to the 10 characters with the most image files (counted by
# extension only). Counter.most_common orders equal counts by first insertion,
# i.e. alphabetically here, so the selection is deterministic.
image_counts = {
    label: sum(
        1
        for f in os.listdir(os.path.join(DATA_DIR, label))
        if os.path.splitext(f)[1].lower() in VALID_EXTENSIONS
    )
    for label in label_names
}
top_characters = [c for c, _ in Counter(image_counts).most_common(10)]

# Upper bound on how many images will be loaded: files that fail to decode
# are skipped later, so the real count can be lower. Nothing in this section
# pre-allocates from it.
total_images = sum(min(image_counts[label], MAX_IMAGES_PER_CLASS) for label in top_characters)
images = []  # accumulates one RGB image array per loaded file
labels = []  # matching integer class indices taken from label_map

# Load up to MAX_IMAGES_PER_CLASS images for each selected character.
# Directory entries are visited in sorted order: os.listdir returns them in an
# arbitrary, filesystem-dependent order, so without sorting the per-class subset
# (and therefore the seeded train/test split) could differ between machines.
for label in top_characters:
    folder_path = os.path.join(DATA_DIR, label)
    img_count = 0
    for img_name in sorted(os.listdir(folder_path)):
        if img_count >= MAX_IMAGES_PER_CLASS:
            break
        # Skip files without a recognised image extension
        if os.path.splitext(img_name)[1].lower() not in VALID_EXTENSIONS:
            print(f"Warning: Skipping non-image file {os.path.join(folder_path, img_name)}")
            continue
        img_path = os.path.join(folder_path, img_name)
        # IMREAD_COLOR always decodes to a 3-channel image in BGR channel order
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:
            # cv2.imread reports unreadable/corrupt files by returning None, not raising
            print(f"Warning: Failed to load image {img_path}")
            continue
        # OpenCV uses BGR order; convert to RGB for PIL/torchvision
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        # INTER_AREA is OpenCV's preferred interpolation when shrinking images
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
        # Defensive check so the later np.array() stack sees uniform shapes
        if img.shape != (IMG_SIZE, IMG_SIZE, 3):
            print(f"Warning: Image {img_path} has invalid shape {img.shape}, expected ({IMG_SIZE}, {IMG_SIZE}, 3)")
            continue
        images.append(img)
        labels.append(label_map[label])
        img_count += 1

# Convert to numpy arrays
# Fail fast if nothing was collected: np.array([]) would silently yield a
# shape-(0,) array and the problem would only surface later, far from its cause.
if not images:
    raise RuntimeError(f"No images were loaded from {DATA_DIR}; check the dataset path and contents")
try:
    images = np.array(images, dtype=np.uint8)  # Keep as uint8 for PIL compatibility
    labels = np.array(labels, dtype=np.int64)
except ValueError as e:
    # np.array raises ValueError when, e.g., the images cannot be stacked into
    # one regular array; print some diagnostics, then re-raise the original error.
    print(f"Error converting to NumPy array: {e}")
    print("Shapes of first few images:")
    for i, img in enumerate(images[:5]):
        print(f"Image {i}: shape {np.shape(img)}")
    print("Total images collected:", len(images))
    raise

# Hold out 20% of the data for evaluation. Stratifying on the labels keeps each
# character's share equal across both splits; the fixed seed makes the split repeatable.
X_train, X_test, y_train, y_test = train_test_split(
    images,
    labels,
    test_size=0.2,
    stratify=labels,
    random_state=42,
)

# Wrap each split in a Dataset; only the training split gets random augmentation.
train_dataset, test_dataset = (
    SimpsonsDataset(X_train, y_train, transform=train_transforms),
    SimpsonsDataset(X_test, y_test, transform=test_transforms),
)

# Batch the data; only the training loader reshuffles at every epoch.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Summary
print(f"Loaded {len(images)} images across {len(top_characters)} characters")
print(f"Training set: {len(X_train)} images, Test set: {len(X_test)} images")