# In [None]:
import os
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"  # Make CUDA errors synchronous
os.environ["TORCH_USE_CUDA_DSA"] = "1"    # Enable device-side assertions

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
import cv2
from sklearn.model_selection import train_test_split
from collections import Counter
from torchvision import transforms

# Configuration
IMG_SIZE = 64                            # Side length (pixels) every image is resized to
DATA_DIR = "./archive/simpsons_dataset"  # Root folder; each sub-folder is one character
MAX_IMAGES_PER_CLASS = 500               # Per-character cap that keeps memory bounded
BATCH_SIZE = 16                          # Kept small to limit memory use
NUM_EPOCHS = 20
INITIAL_LR = 0.001                       # Constant Adam learning rate (no scheduler)
VALID_EXTENSIONS = {'.jpg', '.jpeg', '.png'}  # Lower-case suffixes accepted as images


# 1. Preprocessing pipelines.
# Both pipelines take an H x W x 3 uint8 array, resize it to IMG_SIZE x IMG_SIZE,
# and finish with ToTensor (-> [C, H, W] floats in [0, 1]) followed by
# Normalize(0.5, 0.5), which maps values to [-1, 1]. Only the training pipeline
# inserts random augmentation between those two stages.
def _resize_steps():
    """Return fresh leading steps: numpy array -> PIL image -> IMG_SIZE x IMG_SIZE."""
    return [transforms.ToPILImage(), transforms.Resize((IMG_SIZE, IMG_SIZE))]


def _tensor_steps():
    """Return fresh trailing steps: PIL image -> tensor normalized to [-1, 1]."""
    return [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]


_augmentation_steps = [
    transforms.RandomHorizontalFlip(p=0.5),                                # mirror half the time
    transforms.RandomRotation(degrees=15),                                 # rotate within +/-15 degrees
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),  # mild colour noise
]

train_transforms = transforms.Compose(_resize_steps() + _augmentation_steps + _tensor_steps())
test_transforms = transforms.Compose(_resize_steps() + _tensor_steps())

# 2. Custom Dataset wrapping the in-memory image/label arrays
class SimpsonsDataset(Dataset):
    """Map-style dataset pairing in-memory images with integer labels.

    Args:
        images: Indexable collection of images. In this script it is a uint8
            NumPy array of shape (N, IMG_SIZE, IMG_SIZE, 3).
        labels: Indexable collection of labels aligned index-for-index with
            ``images``.
        transform: Optional callable applied to each image on access. It is
            skipped whenever it is falsy (e.g. ``None``).
    """

    def __init__(self, images, labels, transform=None):
        self.images, self.labels, self.transform = images, labels, transform

    def __len__(self):
        # Dataset size is defined by the image collection.
        return len(self.images)

    def __getitem__(self, idx):
        sample, target = self.images[idx], self.labels[idx]
        if not self.transform:
            return sample, target
        return self.transform(sample), target

# 3. Load and preprocess the data
# Character names are the sub-directories of DATA_DIR, in sorted order.
label_names = sorted(
    entry for entry in os.listdir(DATA_DIR)
    if os.path.isdir(os.path.join(DATA_DIR, entry))
)


def _count_images(character):
    """Return how many files in DATA_DIR/<character> have an accepted image suffix."""
    folder = os.path.join(DATA_DIR, character)
    return sum(
        1 for fname in os.listdir(folder)
        if os.path.splitext(fname)[1].lower() in VALID_EXTENSIONS
    )


# Keep only the 10 characters with the most image files (ties keep sorted-name order).
image_counts = {name: _count_images(name) for name in label_names}
top_characters = [name for name, _ in Counter(image_counts).most_common(10)]

# Class index = position in top_characters (0 = character with the most images).
label_map = dict(zip(top_characters, range(len(top_characters))))

# Collect up to MAX_IMAGES_PER_CLASS images for each top character.
#   images: list of (IMG_SIZE, IMG_SIZE, 3) uint8 RGB arrays
#   labels: parallel list of class indices taken from label_map
images = []
labels = []

for label in top_characters:
    folder_path = os.path.join(DATA_DIR, label)
    img_count = 0
    # Sort the listing: os.listdir order is arbitrary and filesystem-dependent,
    # so without sorting, the capped subset of images (and therefore the seeded
    # train/test split later on) could differ between machines and runs.
    for img_name in sorted(os.listdir(folder_path)):
        if img_count >= MAX_IMAGES_PER_CLASS:
            break
        img_path = os.path.join(folder_path, img_name)
        # Only accept files whose suffix is a known image extension.
        if os.path.splitext(img_name)[1].lower() not in VALID_EXTENSIONS:
            print(f"Warning: Skipping non-image file {img_path}")
            continue
        # Per the OpenCV docs, IMREAD_COLOR always decodes to a 3-channel BGR
        # image (grayscale files included), so no separate grayscale branch is
        # needed. The former 2-D check also sat after BGR2RGB, which would have
        # failed first on 2-D input.
        img = cv2.imread(img_path, cv2.IMREAD_COLOR)
        if img is None:  # unreadable or corrupt file
            print(f"Warning: Failed to load image {img_path}")
            continue
        # OpenCV returns BGR; ToPILImage in the transforms treats channels as RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = cv2.resize(img, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
        # Defensive: stacking with np.array later requires identical shapes.
        if img.shape != (IMG_SIZE, IMG_SIZE, 3):
            print(f"Warning: Image {img_path} has invalid shape {img.shape}, expected ({IMG_SIZE}, {IMG_SIZE}, 3)")
            continue
        images.append(img)
        labels.append(label_map[label])
        img_count += 1

# Convert to numpy arrays.
# Fail early with an actionable message: an empty list would otherwise only
# surface later as an opaque "zero-size array" error from labels.max().
if not images:
    raise ValueError(
        f"No images were loaded from {DATA_DIR}; check DATA_DIR and VALID_EXTENSIONS"
    )

try:
    # uint8 keeps the arrays compatible with transforms.ToPILImage.
    images = np.array(images, dtype=np.uint8)
    # int64 class indices, used as targets by the loss in the training loop.
    labels = np.array(labels, dtype=np.int64)
except ValueError as e:
    # Most likely ragged shapes; dump a few entries to help locate the culprit.
    print(f"Error converting to NumPy array: {e}")
    print("Shapes of first few images:")
    for i, img in enumerate(images[:5]):
        print(f"Image {i}: shape {np.array(img).shape if isinstance(img, np.ndarray) else 'Not an array'}")
    print("Total images collected:", len(images))
    raise

# Validate labels: every index must address one of the num_classes outputs.
num_classes = len(top_characters)
if labels.max() >= num_classes or labels.min() < 0:
    print(f"Error: Labels contain invalid indices. Max label: {labels.max()}, Min label: {labels.min()}, Expected range: [0, {num_classes-1}]")
    raise ValueError("Invalid label indices detected")

# Verify array shape
print(f"Images array shape: {images.shape}, Labels array shape: {labels.shape}")

# 80/20 train/test split. It is stratified on the labels so each character keeps
# its proportion in both halves; random_state pins the shuffle for identical inputs.
X_train, X_test, y_train, y_test = train_test_split(
    images,
    labels,
    test_size=0.2,
    stratify=labels,
    random_state=42,
)

# Random augmentation is applied only on the training side.
train_dataset = SimpsonsDataset(images=X_train, labels=y_train, transform=train_transforms)
test_dataset = SimpsonsDataset(images=X_test, labels=y_test, transform=test_transforms)

# Training batches are reshuffled every epoch; test batches keep a fixed order.
train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE, shuffle=False)

# Summary of what was loaded and how it was split.
n_loaded, n_train, n_test = len(images), len(X_train), len(X_test)
print(f"Loaded {n_loaded} images across {len(top_characters)} characters")
print(f"Training set: {n_train} images, Test set: {n_test} images")

# 4. Define the CNN model
class SimpsonsCNN(nn.Module):
    """Small three-stage CNN classifier for square RGB images.

    Each stage is Conv3x3 (padding=1, size-preserving) -> ReLU -> MaxPool 2x2.
    The three stages together shrink the spatial size to ``img_size // 8``
    before the two fully connected layers.

    Args:
        num_classes: Number of output logits.
        img_size: Side length of the square input images. The default of 64
            reproduces the original fixed ``128 * 8 * 8`` input size of fc1.

    Raises:
        ValueError: If ``img_size`` is smaller than 8, since three 2x2 poolings
            would leave no spatial extent.

    forward(x): ``x`` has shape (batch, 3, img_size, img_size). Returns raw
    logits of shape (batch, num_classes).
    """

    def __init__(self, num_classes, img_size=64):
        super().__init__()
        if img_size < 8:
            raise ValueError(f"img_size must be at least 8 (three 2x2 poolings), got {img_size}")
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        # Three floor-halvings equal a single floor division by 8.
        feature_side = img_size // 8
        self.fc1 = nn.Linear(128 * feature_side * feature_side, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        # Flatten per sample and keep the batch dimension. The former
        # x.view(-1, 128 * 8 * 8) silently turned oversized feature maps into
        # extra "samples"; fc1 now raises a shape error instead.
        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)

# 5. Pick a device, build the model, and set up the loss and optimizer.
# Report the environment for debugging.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")
    try:
        print(f"GPU device: {torch.cuda.get_device_name(0)}")
    except Exception as e:
        print(f"Error getting GPU device name: {e}")
print(f"MPS available: {torch.backends.mps.is_available()}")


def _select_device():
    """Return torch.device("cuda") if CUDA initializes and a tiny allocation succeeds, else CPU."""
    if not torch.cuda.is_available():
        print("No CUDA-capable GPU available, using CPU.")
        return torch.device("cpu")
    try:
        torch.cuda.init()
        torch.ones(1, device="cuda")  # smoke test: fails fast if the CUDA runtime is unusable
    except Exception as e:
        print(f"CUDA initialization failed: {e}. Falling back to CPU.")
        return torch.device("cpu")
    print("CUDA initialized successfully, using GPU.")
    return torch.device("cuda")


def _build_model(target_device):
    """Create SimpsonsCNN on target_device and run one dummy forward pass to verify it."""
    net = SimpsonsCNN(num_classes=len(top_characters)).to(target_device)
    with torch.no_grad():  # the smoke test needs no autograd graph
        net(torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=target_device))
    return net


device = _select_device()
try:
    model = _build_model(device)
    print(f"Model successfully moved to {device} and tested.")
except Exception as e:
    # Any failure on the chosen device: retry once on CPU. If that fails too,
    # the exception propagates.
    print(f"Error moving model to device {device}: {e}")
    print("Falling back to CPU.")
    device = torch.device("cpu")
    model = _build_model(device)
    print("Model successfully moved to CPU and tested.")

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=INITIAL_LR)
print(f"Using {device} device for training")

# 6. Training loop: one pass over train_loader per epoch, then evaluation on test_loader.
# Loop variables are named `targets` so the module-level `labels` array is not clobbered.
for epoch in range(NUM_EPOCHS):
    model.train()
    running_loss = 0.0
    num_steps = 0  # batches that actually produced an optimizer step
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        try:
            # .to(device) guarantees placement. No explicit device comparison is
            # made: torch.device("cuda") != tensor.device ("cuda:0"), so such a
            # check rejected every GPU batch.
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            # Surface bad labels with a readable message instead of an opaque
            # failure inside the loss.
            if targets.max() >= num_classes or targets.min() < 0:
                raise RuntimeError(f"Invalid label values in batch: max {targets.max()}, min {targets.min()}, expected [0, {num_classes-1}]")

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Skip the update on a non-finite loss: back-propagating it would
            # write NaN/Inf into the parameters.
            if not torch.isfinite(loss):
                print(f"Warning: skipping batch {batch_idx+1}, epoch {epoch+1}: invalid loss value {loss.item()}")
                continue

            loss.backward()
            # Clip AFTER backward(): only then do the gradients of this batch exist.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
            running_loss += loss.item()
            num_steps += 1
        except Exception as e:
            # Log context, then propagate. Retrying on CPU cannot fix bad data,
            # and it left optimizer state and model parameters on different devices.
            print(f"Error in batch {batch_idx+1}, epoch {epoch+1}: {e}")
            print(f"Inputs device: {inputs.device}, Labels device: {targets.device}, Model device: {next(model.parameters()).device}")
            print(f"Input shape: {inputs.shape}, Label shape: {targets.shape}, Label values: {targets.tolist()}")
            raise

    current_lr = optimizer.param_groups[0]["lr"]
    avg_loss = running_loss / max(num_steps, 1)  # average over batches that stepped
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {avg_loss:.4f}, Learning Rate: {current_lr:.6f}")

    # 7. Evaluate on test set
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)
            predicted = model(inputs).argmax(dim=1)  # index of the highest logit
            total += targets.size(0)
            correct += (predicted == targets).sum().item()

    accuracy = 100 * correct / total if total else 0.0
    print(f"Test Accuracy: {accuracy:.2f}%")

# 8. Persist only the learned parameters (the state_dict), not the module object.
MODEL_PATH = "simpsons_cnn.pth"
torch.save(model.state_dict(), MODEL_PATH)
print(f"Model saved as {MODEL_PATH}")