## Summary

### Version 1: Baseline
- Basic data augmentation (crop, flip, perspective, rotation, color jitter)
- Established baseline performance

### Version 2: Enhanced Augmentation
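- Stronger perspective warping (0.5 distortion) and more rotation (20°)
- More aggressive color jitter (including hue), plus random grayscale and Gaussian blur
- Added a ReduceLROnPlateau learning-rate scheduler
- Goal: Help distinguish similar characters (Geralt/Vesemir)
- Result: Too aggressive; hurt training stability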

### Version 3: Balanced Approach
- Sweet spot between V1 and V2
- Moderate augmentation (0.45 perspective, 15° rotation, reduced color jitter; Gaussian blur kept)
- Removed random grayscale (hurt performance)
- Replaced V2's ReduceLROnPlateau with a StepLR scheduler (halves the LR at epochs 10 and 20)
- Result: Better convergence

### Version 4: Dataset Improvement #1
- **Key Change**: Enhanced 'Other' class with more white-haired characters
- Same transforms as V3
- Reason: Confusion matrix showed 'Other' frequently misclassified as Vesemir/Geralt
- Result: Improved 'Other' class accuracy

### Version 5: Adaptive Learning Rate
- **Key Change**: Switched back to the ReduceLROnPlateau scheduler (first tried in V2)
- Further dataset refinement
- LR reduces when validation loss plateaus (patience=3)
- Result: More intelligent fine-tuning

### Version 6: Final Refinement
- **Key Changes**: Additional dataset improvements + 30 epochs
- Same architecture and scheduler as V5
- **Final Performance**: ~85% validation accuracy
- Best per-class results:
  - Yennefer: 95%
  - Ciri: 87%
  - Triss: 86%
  - Vesemir: 86%
  - Other: 85%
  - Geralt: 75%

In [None]:
import os
import splitfolders

# Project root: point this at the folder that contains the "dataset" directory.
base_path = '/path/to/your/project'

# Raw images (one sub-folder per class) are read from "dataset";
# the train/val split is written to "split_data".
input_folder, output_folder = (
    os.path.join(base_path, name) for name in ("dataset", "split_data")
)

if not os.path.exists(input_folder):
    print("Error: Could not find the input folder!")
else:
    print(f"Found input folder at: {input_folder}")
    # Every sub-directory of the input folder is treated as one class.
    classes = [
        entry
        for entry in os.listdir(input_folder)
        if os.path.isdir(os.path.join(input_folder, entry))
    ]
    print(f"Classes found: {classes}")

    # Report how many files each class folder holds.
    for cls in classes:
        class_dir = os.path.join(input_folder, cls)
        count = len(os.listdir(class_dir))
        print(f"  {cls}: {count} images")

In [None]:
import shutil

# Start from an empty output folder. splitfolders copies files into the
# train/val directories without removing what is already there, so re-running
# this cell after the dataset changes would keep stale copies around and could
# place the same image in both train and val (data leakage).
if os.path.isdir(output_folder):
    shutil.rmtree(output_folder)

# Split with 80/20 ratio
# seed=42 ensures reproducibility
splitfolders.ratio(input_folder, output=output_folder, seed=42, ratio=(.8, .2))

print("--- Split Complete ---")
print(f"Check: {output_folder} should now contain 'train' and 'val' folders.")

# Verify the split: per-class file counts in each subset
# (sorted, matching the class order ImageFolder will use).
print("\n--- Verifying Split ---")
for split in ['train', 'val']:
    split_path = os.path.join(output_folder, split)
    if not os.path.exists(split_path):
        print(f"\nWarning: expected '{split}' folder not found at {split_path}")
        continue
    print(f"\n{split.upper()} set:")
    classes = sorted(d for d in os.listdir(split_path) if os.path.isdir(os.path.join(split_path, d)))
    for cls in classes:
        count = len(os.listdir(os.path.join(split_path, cls)))
        print(f"  {cls}: {count} images")

In [None]:
import torch
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

# Train: V1 baseline random augmentation. Val: deterministic resize + center
# crop. Both normalize with the standard ImageNet mean/std used by the
# pretrained ResNet weights.
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomPerspective(distortion_scale=0.4, p=0.5),
        transforms.RandomRotation(15),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Point to the split folders
data_dir = os.path.join(base_path, "split_data")

# Load datasets
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x])
                  for x in ['train', 'val']}

# ImageFolder derives label indices from each split's own sorted folder names.
# If a class folder is missing from one split, the indices silently disagree,
# so fail loudly instead of training/evaluating on misaligned labels.
if image_datasets['train'].classes != image_datasets['val'].classes:
    raise ValueError(
        f"Train/val class folders differ: {image_datasets['train'].classes} "
        f"vs {image_datasets['val'].classes}"
    )

# Create dataloaders (batch_size=16 like the binary model).
# Only the training set needs shuffling; validation order is irrelevant.
dataloaders = {x: DataLoader(image_datasets[x], batch_size=16, shuffle=(x == 'train'))
               for x in ['train', 'val']}

# Store class names and dataset sizes
class_names = image_datasets['train'].classes
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}

print(f"Classes found: {class_names}")
print(f"Training images: {dataset_sizes['train']}")
print(f"Validation images: {dataset_sizes['val']}")

In [None]:
import torch
import numpy as np

# Count training images per class from the labels ImageFolder actually loads.
# (os.listdir on the class folders would also count non-image files such as
# .DS_Store, which ImageFolder skips.)
train_counts = np.bincount(image_datasets['train'].targets, minlength=len(class_names))
for class_name, count in zip(class_names, train_counts):
    print(f"{class_name}: {count} training images")

# Inverse-frequency weights are undefined for a class with no samples.
missing = [name for name, count in zip(class_names, train_counts) if count == 0]
if missing:
    raise ValueError(f"No training images found for class(es): {missing}")

# Calculate weights inversely proportional to class frequency
# Classes with fewer samples get higher weight
class_weights = 1.0 / train_counts
# Normalize so the weights sum to the number of classes (i.e. average 1.0)
class_weights = class_weights / class_weights.sum() * len(class_names)

# Convert to a float32 tensor for nn.CrossEntropyLoss
class_weights = torch.tensor(class_weights, dtype=torch.float32)

print("\n--- Class Weights ---")
for i, class_name in enumerate(class_names):
    print(f"{class_name}: {class_weights[i]:.4f}")

print("\nThese weights will be applied to the loss function.")
print("Higher weight = model penalized more for getting that class wrong")

## Version 1: Baseline Model
ResNet18 with all layers unfrozen, class weights, and basic augmentation

In [None]:
import torch.nn as nn
import torch.optim as optim

# Load pretrained ResNet18
model = models.resnet18(weights="DEFAULT")

# UNFREEZE all layers (like V4 of the binary model)
for param in model.parameters():
    param.requires_grad = True

# Replace final layer: one output per character class, sized from the
# dataset rather than hard-coded so adding/removing a class folder just works.
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, len(class_names))

# Move to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

print(f"Model loaded on: {device}")
print(f"Final layer output: {model.fc.out_features} classes")

In [None]:
# Weighted cross-entropy: mistakes on rarer classes cost more.
criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))

# The whole network is trainable, so take small steps (1e-5, as in V4).
learning_rate = 1e-5
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Loss function: CrossEntropyLoss with class weights")
print("Optimizer: Adam with lr=1e-5")
print(f"Trainable parameters: {trainable:,}")

In [None]:
import time
import copy

def train_model(model, criterion, optimizer, num_epochs=15):
    """Fine-tune ``model`` and return it with its best-validation weights loaded.

    Each epoch makes one gradient-updating pass over ``dataloaders['train']``
    and one evaluation pass over ``dataloaders['val']``. Loss and accuracy
    for each phase are printed, normalised by ``dataset_sizes[phase]``. The
    weights from the epoch with the highest validation accuracy are restored
    into ``model`` before it is returned.

    Relies on the notebook globals ``dataloaders``, ``dataset_sizes`` and
    ``device``.
    """
    start_time = time.time()
    best_weights = copy.deepcopy(model.state_dict())
    best_val_acc = 0.0

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ('train', 'val'):
            is_training = phase == 'train'
            # train(True) == train(); train(False) == eval()
            model.train(is_training)

            loss_sum = 0.0
            correct_count = 0

            for batch_inputs, batch_labels in dataloaders[phase]:
                batch_inputs = batch_inputs.to(device)
                batch_labels = batch_labels.to(device)

                optimizer.zero_grad()

                # Build the autograd graph only when we will backpropagate.
                with torch.set_grad_enabled(is_training):
                    logits = model(batch_inputs)
                    _, predicted = torch.max(logits, 1)
                    loss = criterion(logits, batch_labels)
                    if is_training:
                        loss.backward()
                        optimizer.step()

                # Assumes criterion uses mean reduction (PyTorch default), so
                # scale by batch size to accumulate a per-sample sum.
                loss_sum += loss.item() * batch_inputs.size(0)
                correct_count += torch.sum(predicted == batch_labels.data)

            epoch_loss = loss_sum / dataset_sizes[phase]
            epoch_acc = correct_count.double() / dataset_sizes[phase]
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Keep a snapshot of the best-performing validation epoch.
            if not is_training and epoch_acc > best_val_acc:
                best_val_acc = epoch_acc
                best_weights = copy.deepcopy(model.state_dict())

        print()

    elapsed = time.time() - start_time
    print(f'Training complete in {elapsed // 60:.0f}m {elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_val_acc:.4f}')

    model.load_state_dict(best_weights)
    return model

In [None]:
# V1 run: 20 epochs; the returned model carries its best-val-acc weights.
model_multiclass_v1 = train_model(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    num_epochs=20,
)

In [None]:
# Persist only the learned weights (state_dict) of the best V1 model
# under <project>/models/.
models_dir = os.path.join(base_path, "models")
os.makedirs(models_dir, exist_ok=True)

model_path = os.path.join(models_dir, "witcher_multiclass_v1.pth")
best_v1_weights = model_multiclass_v1.state_dict()
torch.save(best_v1_weights, model_path)

print(f"Model saved to: {model_path}")

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns

def evaluate_model(model, dataloader, class_names, device=None):
    """
    Evaluate model and return predictions and true labels.

    Puts ``model`` in eval mode and runs it over every batch of ``dataloader``
    without tracking gradients, taking the arg-max class per sample.

    Args:
        model: classifier returning one logit per class.
        dataloader: iterable of ``(inputs, labels)`` batches.
        class_names: unused; kept for call-site compatibility.
        device: where to run inference. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters), so the helper
            no longer depends on a notebook-global ``device``.

    Returns:
        ``(all_preds, all_labels)`` as two equally long lists of numpy integers.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            # Labels are only compared, never fed to the model: keep them on CPU.
            all_labels.extend(labels.cpu().numpy())

    return all_preds, all_labels

# Score the best V1 model on the validation split.
from sklearn.metrics import accuracy_score

print("Evaluating on validation set...")
val_preds, val_labels = evaluate_model(model_multiclass_v1, dataloaders['val'], class_names)

# Fraction of validation images whose top-1 prediction matches the label.
val_accuracy = accuracy_score(val_labels, val_preds)
accuracy_pct = val_accuracy * 100
print(f"\nOverall Validation Accuracy: {val_accuracy:.4f} ({accuracy_pct:.2f}%)")

In [None]:
# 1. CONFUSION MATRIX
print("\n" + "="*60)
print("CONFUSION MATRIX")
print("="*60)

# Pin the label set so the matrix is always len(class_names) x len(class_names)
# and its rows/columns line up with the tick labels, even when a class has no
# validation samples or is never predicted.
label_ids = list(range(len(class_names)))
cm = confusion_matrix(val_labels, val_preds, labels=label_ids)

# Plot confusion matrix as heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=class_names, yticklabels=class_names,
            cbar_kws={'label': 'Count'})
plt.title('Confusion Matrix - Witcher Character Classifier', fontsize=14, pad=20)
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.tight_layout()
plt.show()

# 2. PER-CLASS ACCURACY (correct predictions / true samples of that class)
print("\n" + "="*60)
print("PER-CLASS ACCURACY")
print("="*60)

for i, class_name in enumerate(class_names):
    class_correct = cm[i, i]
    class_total = cm[i, :].sum()
    class_acc = class_correct / class_total if class_total > 0 else 0
    print(f"{class_name:12s}: {class_correct:2d}/{class_total:2d} = {class_acc:.2%}")

# 3. DETAILED CLASSIFICATION REPORT
# Passing labels= keeps target_names aligned even if a class is absent.
print("\n" + "="*60)
print("CLASSIFICATION REPORT (Precision, Recall, F1)")
print("="*60)
print(classification_report(val_labels, val_preds, labels=label_ids,
                            target_names=class_names, digits=4))

## Version 2: Enhanced Augmentation
Stronger perspective warping, more aggressive color jitter, random grayscale, and Gaussian blur to help distinguish similar characters, plus a ReduceLROnPlateau learning-rate scheduler

In [None]:
# V2: ENHANCED TRANSFORMS
# Goal: Help distinguish similar characters (like Geralt and Vesemir)

data_transforms_v2 = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
        transforms.RandomHorizontalFlip(),
        
        # Stronger perspective warping (V1: distortion 0.4, p=0.5)
        transforms.RandomPerspective(distortion_scale=0.5, p=0.6),
        
        # More aggressive color jittering (V1: 0.2, no hue) to force focus on facial features
        transforms.ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.1),
        
        # Random grayscale: forces model to learn structure, not just hair color
        transforms.RandomGrayscale(p=0.15),
        
        # More rotation than V1 (20 vs 15 degrees) and a new occasional blur
        # (V1 had no blur)
        transforms.RandomRotation(20),
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=5)], p=0.2),
        
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Reload datasets with new transforms
image_datasets_v2 = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms_v2[x])
                     for x in ['train', 'val']}

# Only the training set needs shuffling
dataloaders_v2 = {x: DataLoader(image_datasets_v2[x], batch_size=16, shuffle=(x == 'train'))
                  for x in ['train', 'val']}

print("V2 Transforms loaded!")
print("Key changes:")
print("  - Stronger perspective warping (0.5 distortion)")
print("  - More color jittering (brightness/contrast/saturation 0.3, hue 0.1)")
print("  - Random grayscale (15% chance)")
print("  - Increased rotation (20 degrees)")
print("  - Gaussian blur (20% chance)")

In [None]:
# V2: Fresh model with learning rate scheduler
model_v2 = models.resnet18(weights="DEFAULT")

# Unfreeze all layers
for param in model_v2.parameters():
    param.requires_grad = True

# One output per class, sized from the dataset rather than hard-coded
num_ftrs = model_v2.fc.in_features
model_v2.fc = nn.Linear(num_ftrs, len(class_names))
model_v2 = model_v2.to(device)

# Loss with class weights (same as before)
criterion_v2 = nn.CrossEntropyLoss(weight=class_weights.to(device))

# Optimizer
optimizer_v2 = optim.Adam(model_v2.parameters(), lr=0.00001)

# LEARNING RATE SCHEDULER
# Halves the LR when the metric passed to scheduler.step() (validation loss in
# train_model_v2) has not improved for 3 epochs.
# `verbose` is intentionally omitted: it is deprecated in recent PyTorch
# releases, and the training loop already prints the current LR every epoch.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_v2,
    mode='min',           # Monitor validation loss (minimize it)
    factor=0.5,           # Reduce LR by half
    patience=3,           # Wait 3 epochs before reducing
    min_lr=1e-7,          # Don't go below this
)

print("Model V2 initialized!")
print(f"Initial learning rate: {optimizer_v2.param_groups[0]['lr']}")
print("Scheduler: ReduceLROnPlateau (patience=3, factor=0.5)")

In [None]:
def train_model_v2(model, criterion, optimizer, scheduler, num_epochs=30):
    """Train on the V2 loaders, stepping a plateau scheduler on validation loss.

    Each epoch makes one gradient-updating pass over ``dataloaders_v2['train']``
    and one evaluation pass over ``dataloaders_v2['val']``, printing loss and
    accuracy normalised by ``len(image_datasets_v2[phase])``. After both
    passes, ``scheduler.step`` receives that epoch's validation loss and the
    current learning rate is printed.

    Relies on the notebook globals ``dataloaders_v2``, ``image_datasets_v2``
    and ``device``.

    Returns:
        ``(model, history)``: the model with the weights of its best
        validation-accuracy epoch loaded, and per-epoch metric lists under
        ``'train_loss'``, ``'val_loss'``, ``'train_acc'`` and ``'val_acc'``.
    """
    start_time = time.time()
    best_weights = copy.deepcopy(model.state_dict())
    best_val_acc = 0.0

    # Per-epoch metrics, kept for later plotting/analysis.
    history = {key: [] for key in ('train_loss', 'val_loss', 'train_acc', 'val_acc')}

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ('train', 'val'):
            is_training = phase == 'train'
            model.train(is_training)  # False is equivalent to model.eval()

            loss_sum = 0.0
            correct_count = 0

            for batch_inputs, batch_labels in dataloaders_v2[phase]:
                batch_inputs = batch_inputs.to(device)
                batch_labels = batch_labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(is_training):
                    logits = model(batch_inputs)
                    _, predicted = torch.max(logits, 1)
                    loss = criterion(logits, batch_labels)
                    if is_training:
                        loss.backward()
                        optimizer.step()

                loss_sum += loss.item() * batch_inputs.size(0)
                correct_count += torch.sum(predicted == batch_labels.data)

            n_samples = len(image_datasets_v2[phase])
            epoch_loss = loss_sum / n_samples
            epoch_acc = correct_count.double() / n_samples
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc.item())

            if not is_training and epoch_acc > best_val_acc:
                best_val_acc = epoch_acc
                best_weights = copy.deepcopy(model.state_dict())

        # The plateau scheduler needs the monitored value: this epoch's val loss.
        scheduler.step(history['val_loss'][-1])

        print(f"Current LR: {optimizer.param_groups[0]['lr']:.2e}")
        print()

    elapsed = time.time() - start_time
    print(f'Training complete in {elapsed // 60:.0f}m {elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_val_acc:.4f}')

    model.load_state_dict(best_weights)
    return model, history

print("Training function V2 ready!")
print("New features:")
print("  - Tracks training history (loss & acc)")
print("  - Uses learning rate scheduler")
print("  - Prints current LR each epoch")

In [None]:
# V2 run: 30 epochs; the plateau scheduler halves the LR after 3 stagnant epochs.
print("Starting V2 training with 30 epochs...")
print("Watch for LR reductions when validation stops improving!\n")

v2_components = (model_v2, criterion_v2, optimizer_v2, scheduler)
model_v2_trained, history_v2 = train_model_v2(*v2_components, num_epochs=30)

## Version 3: Balanced Approach
Sweet spot between V1 and V2 - moderate augmentation with StepLR scheduler

In [None]:
# V3: BALANCED TRANSFORMS
# Sweet spot between V1 and V2

data_transforms_v3 = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
        transforms.RandomHorizontalFlip(),
        
        # Perspective between V1 (0.4) and V2 (0.5); helps with viewing angles
        transforms.RandomPerspective(distortion_scale=0.45, p=0.5),
        
        # Moderate color jittering (less than V2, more than V1), no hue jitter
        # Helps distinguish hair colors without being too extreme
        transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.2),
        
        # REMOVED: Random grayscale (it hurt more than helped)
        
        # Rotation back to V1's 15 degrees (V2 used 20)
        transforms.RandomRotation(15),
        
        # Keep V2's occasional Gaussian blur
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=5)], p=0.2),
        
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Reload datasets with V3 transforms
image_datasets_v3 = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms_v3[x])
                     for x in ['train', 'val']}

# Only the training set needs shuffling
dataloaders_v3 = {x: DataLoader(image_datasets_v3[x], batch_size=16, shuffle=(x == 'train'))
                  for x in ['train', 'val']}

print("V3 Transforms loaded!")
print("\nChanges from V2:")
print("  Reduced: Perspective distortion (0.5 -> 0.45)")
print("  Reduced: Color jitter (0.25 brightness/contrast, 0.2 saturation, no hue)")
print("  Removed: Random grayscale")
print("  Reduced: Rotation back to 15 degrees")
print("  Kept: Gaussian blur (20% chance)")
print("\nV3 should learn faster than V2 but be more robust than V1")

In [None]:
# V3: Fresh model 
model_v3 = models.resnet18(weights="DEFAULT")

# Unfreeze all layers
for param in model_v3.parameters():
    param.requires_grad = True

# One output per class, sized from the dataset rather than hard-coded
num_ftrs = model_v3.fc.in_features
model_v3.fc = nn.Linear(num_ftrs, len(class_names))
model_v3 = model_v3.to(device)

# Loss with class weights
criterion_v3 = nn.CrossEntropyLoss(weight=class_weights.to(device))

# Optimizer - same LR as V1 and V2
optimizer_v3 = optim.Adam(model_v3.parameters(), lr=0.00001)

# StepLR scheduler instead of ReduceLROnPlateau
# Reduces LR every 10 epochs regardless of performance.
# `verbose` is intentionally omitted: it is deprecated in recent PyTorch
# releases, and the training loop already prints the current LR every epoch.
scheduler_v3 = optim.lr_scheduler.StepLR(
    optimizer_v3,
    step_size=10,    # Reduce every 10 epochs
    gamma=0.5,       # Multiply LR by 0.5
)

print("Model V3 initialized!")
print(f"Initial learning rate: {optimizer_v3.param_groups[0]['lr']}")
print("Scheduler: StepLR (reduces LR at epochs 10, 20)")
print("\nThis ensures LR actually decreases during training")

In [None]:
def train_model_v3(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train on the V3 loaders with an epoch-based LR scheduler.

    Each epoch makes one gradient-updating pass over ``dataloaders_v3['train']``
    and one evaluation pass over ``dataloaders_v3['val']``, printing loss and
    accuracy normalised by ``len(image_datasets_v3[phase])``. ``scheduler.step()``
    is called once per epoch, after both phases, and the current LR is printed.

    Relies on the notebook globals ``dataloaders_v3``, ``image_datasets_v3``
    and ``device``.

    Returns:
        ``(model, history)``: the model with the weights of its best
        validation-accuracy epoch loaded, and per-epoch metric lists under
        ``'train_loss'``, ``'val_loss'``, ``'train_acc'`` and ``'val_acc'``.
    """
    start_time = time.time()
    best_weights = copy.deepcopy(model.state_dict())
    best_val_acc = 0.0

    history = {key: [] for key in ('train_loss', 'val_loss', 'train_acc', 'val_acc')}

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ('train', 'val'):
            is_training = phase == 'train'
            model.train(is_training)  # False is equivalent to model.eval()

            loss_sum = 0.0
            correct_count = 0

            for batch_inputs, batch_labels in dataloaders_v3[phase]:
                batch_inputs = batch_inputs.to(device)
                batch_labels = batch_labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(is_training):
                    logits = model(batch_inputs)
                    _, predicted = torch.max(logits, 1)
                    loss = criterion(logits, batch_labels)
                    if is_training:
                        loss.backward()
                        optimizer.step()

                loss_sum += loss.item() * batch_inputs.size(0)
                correct_count += torch.sum(predicted == batch_labels.data)

            n_samples = len(image_datasets_v3[phase])
            epoch_loss = loss_sum / n_samples
            epoch_acc = correct_count.double() / n_samples
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc.item())

            if not is_training and epoch_acc > best_val_acc:
                best_val_acc = epoch_acc
                best_weights = copy.deepcopy(model.state_dict())

        # Epoch-based schedule: advance once both phases are done.
        scheduler.step()

        print(f"Current LR: {optimizer.param_groups[0]['lr']:.2e}")
        print()

    elapsed = time.time() - start_time
    print(f'Training complete in {elapsed // 60:.0f}m {elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_val_acc:.4f}')

    model.load_state_dict(best_weights)
    return model, history

print("Training function V3 ready!")

In [None]:
# V3 run: 25 epochs; StepLR halves the LR every 10 epochs.
print("Starting V3 training with 25 epochs...")
print("LR will drop at epochs 10 and 20\n")

model_v3_trained, history_v3 = train_model_v3(
    model=model_v3,
    criterion=criterion_v3,
    optimizer=optimizer_v3,
    scheduler=scheduler_v3,
    num_epochs=25,
)

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns

def evaluate_model(model, dataloader, class_names, device=None):
    """
    Evaluate model and return predictions and true labels.

    Puts ``model`` in eval mode and runs it over every batch of ``dataloader``
    without tracking gradients, taking the arg-max class per sample.

    Args:
        model: classifier returning one logit per class.
        dataloader: iterable of ``(inputs, labels)`` batches.
        class_names: unused; kept for call-site compatibility.
        device: where to run inference. Defaults to the device of the model's
            first parameter (CPU if the model has no parameters), so the helper
            no longer depends on a notebook-global ``device``.

    Returns:
        ``(all_preds, all_labels)`` as two equally long lists of numpy integers.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs = inputs.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            all_preds.extend(preds.cpu().numpy())
            # Labels are only compared, never fed to the model: keep them on CPU.
            all_labels.extend(labels.cpu().numpy())

    return all_preds, all_labels

# Score the V3 model (train_model_v3 updated it in place) on the V3 val split.
from sklearn.metrics import accuracy_score

print("Evaluating on validation set...")
val_preds, val_labels = evaluate_model(model_v3, dataloaders_v3['val'], class_names)

# Fraction of validation images whose top-1 prediction matches the label.
val_accuracy = accuracy_score(val_labels, val_preds)
accuracy_pct = val_accuracy * 100
print(f"\nOverall Validation Accuracy: {val_accuracy:.4f} ({accuracy_pct:.2f}%)")

In [None]:
# 1. CONFUSION MATRIX
print("\n" + "="*60)
print("CONFUSION MATRIX")
print("="*60)

# Pin the label set so the matrix is always len(class_names) x len(class_names)
# and its rows/columns line up with the tick labels, even when a class has no
# validation samples or is never predicted.
label_ids = list(range(len(class_names)))
cm = confusion_matrix(val_labels, val_preds, labels=label_ids)

# Plot confusion matrix as heatmap
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', 
            xticklabels=class_names, yticklabels=class_names,
            cbar_kws={'label': 'Count'})
plt.title('Confusion Matrix - Witcher Character Classifier', fontsize=14, pad=20)
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.tight_layout()
plt.show()

# 2. PER-CLASS ACCURACY (correct predictions / true samples of that class)
print("\n" + "="*60)
print("PER-CLASS ACCURACY")
print("="*60)

for i, class_name in enumerate(class_names):
    class_correct = cm[i, i]
    class_total = cm[i, :].sum()
    class_acc = class_correct / class_total if class_total > 0 else 0
    print(f"{class_name:12s}: {class_correct:2d}/{class_total:2d} = {class_acc:.2%}")

# 3. DETAILED CLASSIFICATION REPORT
# Passing labels= keeps target_names aligned even if a class is absent.
print("\n" + "="*60)
print("CLASSIFICATION REPORT (Precision, Recall, F1)")
print("="*60)
print(classification_report(val_labels, val_preds, labels=label_ids,
                            target_names=class_names, digits=4))

## Version 4: Updated Dataset
Same as V3, but with additional images in the 'Other' class to reduce confusion with Geralt/Vesemir. Because the dataset changed, it has to be re-split before training.

In [None]:
import os
import splitfolders

# Project root: point this at the folder that contains the "dataset" directory.
base_path = '/path/to/your/project'

# Raw images (one sub-folder per class) are read from "dataset";
# the train/val split is written to "split_data".
input_folder, output_folder = (
    os.path.join(base_path, name) for name in ("dataset", "split_data")
)

if not os.path.exists(input_folder):
    print("Error: Could not find the input folder!")
else:
    print(f"Found input folder at: {input_folder}")
    # Every sub-directory of the input folder is treated as one class.
    classes = [
        entry
        for entry in os.listdir(input_folder)
        if os.path.isdir(os.path.join(input_folder, entry))
    ]
    print(f"Classes found: {classes}")

    # Report how many files each class folder holds.
    for cls in classes:
        class_dir = os.path.join(input_folder, cls)
        count = len(os.listdir(class_dir))
        print(f"  {cls}: {count} images")

In [None]:
import shutil

# Start from an empty output folder. splitfolders copies files into the
# train/val directories without removing what is already there. This cell
# re-splits an updated dataset into the same folder used before, so without
# clearing it stale copies from the previous split would remain and the same
# image could end up in both train and val (data leakage).
if os.path.isdir(output_folder):
    shutil.rmtree(output_folder)

# Split with 80/20 ratio
# seed=42 ensures reproducibility
splitfolders.ratio(input_folder, output=output_folder, seed=42, ratio=(.8, .2))

print("--- Split Complete ---")
print(f"Check: {output_folder} should now contain 'train' and 'val' folders.")

# Verify the split: per-class file counts in each subset
# (sorted, matching the class order ImageFolder will use).
print("\n--- Verifying Split ---")
for split in ['train', 'val']:
    split_path = os.path.join(output_folder, split)
    if not os.path.exists(split_path):
        print(f"\nWarning: expected '{split}' folder not found at {split_path}")
        continue
    print(f"\n{split.upper()} set:")
    classes = sorted(d for d in os.listdir(split_path) if os.path.isdir(os.path.join(split_path, d)))
    for cls in classes:
        count = len(os.listdir(os.path.join(split_path, cls)))
        print(f"  {cls}: {count} images")

In [None]:
# V4: BALANCED TRANSFORMS, SAME AS V3
# Sweet spot between V1 and V2

data_transforms_v4 = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
        transforms.RandomHorizontalFlip(),
        
        # Perspective between V1 (0.4) and V2 (0.5); helps with viewing angles
        transforms.RandomPerspective(distortion_scale=0.45, p=0.5),
        
        # Moderate color jittering (less than V2, more than V1), no hue jitter
        # Helps distinguish hair colors without being too extreme
        transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.2),
        
        # REMOVED: Random grayscale (it hurt more than helped)
        
        # Same 15-degree rotation as V1 and V3 (V2 used 20)
        transforms.RandomRotation(15),
        
        # Keep V2's occasional Gaussian blur
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=5)], p=0.2),
        
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Reload datasets with V4 transforms (from the freshly re-split folders)
image_datasets_v4 = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms_v4[x])
                     for x in ['train', 'val']}

# The re-split must keep the same class folders: later cells reuse class_names
# to label the confusion matrix and report, and ImageFolder derives label
# indices from each split's sorted folder names.
for split in ['train', 'val']:
    if image_datasets_v4[split].classes != class_names:
        raise ValueError(
            f"V4 '{split}' classes {image_datasets_v4[split].classes} "
            f"do not match class_names {class_names}"
        )

# Only the training set needs shuffling
dataloaders_v4 = {x: DataLoader(image_datasets_v4[x], batch_size=16, shuffle=(x == 'train'))
                  for x in ['train', 'val']}

print("V4 Transforms loaded!")
print(f"Training images: {len(image_datasets_v4['train'])}")
print(f"Validation images: {len(image_datasets_v4['val'])}")

In [None]:
import numpy as np

# V4: Fresh model 
model_v4 = models.resnet18(weights="DEFAULT")

# Unfreeze all layers
for param in model_v4.parameters():
    param.requires_grad = True

# One output per class, sized from the dataset rather than hard-coded
num_ftrs = model_v4.fc.in_features
model_v4.fc = nn.Linear(num_ftrs, len(class_names))
model_v4 = model_v4.to(device)

# The dataset was re-split for V4 (extra 'Other' images), so the global
# class_weights computed from the original split no longer match the class
# frequencies. Recompute them from the labels the V4 training set actually
# uses, with the same inverse-frequency formula (normalized to sum to the
# number of classes).
train_counts_v4 = np.bincount(image_datasets_v4['train'].targets, minlength=len(class_names))
missing_v4 = [name for name, count in zip(class_names, train_counts_v4) if count == 0]
if missing_v4:
    raise ValueError(f"No V4 training images found for class(es): {missing_v4}")
class_weights_v4 = 1.0 / train_counts_v4
class_weights_v4 = class_weights_v4 / class_weights_v4.sum() * len(class_names)
class_weights_v4 = torch.tensor(class_weights_v4, dtype=torch.float32)

print("V4 class weights (recomputed for the re-split dataset):")
for name, weight in zip(class_names, class_weights_v4.tolist()):
    print(f"  {name}: {weight:.4f}")

# Loss with the recomputed class weights
criterion_v4 = nn.CrossEntropyLoss(weight=class_weights_v4.to(device))

# Optimizer - same LR as V1 and V2
optimizer_v4 = optim.Adam(model_v4.parameters(), lr=0.00001)

# StepLR scheduler instead of ReduceLROnPlateau
# Reduces LR every 10 epochs regardless of performance.
# `verbose` is intentionally omitted: it is deprecated in recent PyTorch
# releases, and the training loop already prints the current LR every epoch.
scheduler_v4 = optim.lr_scheduler.StepLR(
    optimizer_v4,
    step_size=10,    # Reduce every 10 epochs
    gamma=0.5,       # Multiply LR by 0.5
)

print("Model V4 initialized!")
print(f"Initial learning rate: {optimizer_v4.param_groups[0]['lr']}")
print("Scheduler: StepLR (reduces LR at epochs 10, 20)")
print("\nThis ensures LR actually decreases during training")

In [None]:
def train_model_v4(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train on the V4 (re-split) loaders with an epoch-based LR scheduler.

    Each epoch makes one gradient-updating pass over ``dataloaders_v4['train']``
    and one evaluation pass over ``dataloaders_v4['val']``, printing loss and
    accuracy normalised by ``len(image_datasets_v4[phase])``. ``scheduler.step()``
    is called once per epoch, after both phases, and the current LR is printed.

    Relies on the notebook globals ``dataloaders_v4``, ``image_datasets_v4``
    and ``device``.

    Returns:
        ``(model, history)``: the model with the weights of its best
        validation-accuracy epoch loaded, and per-epoch metric lists under
        ``'train_loss'``, ``'val_loss'``, ``'train_acc'`` and ``'val_acc'``.
    """
    start_time = time.time()
    best_weights = copy.deepcopy(model.state_dict())
    best_val_acc = 0.0

    history = {key: [] for key in ('train_loss', 'val_loss', 'train_acc', 'val_acc')}

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        for phase in ('train', 'val'):
            is_training = phase == 'train'
            model.train(is_training)  # False is equivalent to model.eval()

            loss_sum = 0.0
            correct_count = 0

            for batch_inputs, batch_labels in dataloaders_v4[phase]:
                batch_inputs = batch_inputs.to(device)
                batch_labels = batch_labels.to(device)

                optimizer.zero_grad()

                with torch.set_grad_enabled(is_training):
                    logits = model(batch_inputs)
                    _, predicted = torch.max(logits, 1)
                    loss = criterion(logits, batch_labels)
                    if is_training:
                        loss.backward()
                        optimizer.step()

                loss_sum += loss.item() * batch_inputs.size(0)
                correct_count += torch.sum(predicted == batch_labels.data)

            n_samples = len(image_datasets_v4[phase])
            epoch_loss = loss_sum / n_samples
            epoch_acc = correct_count.double() / n_samples
            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc.item())

            if not is_training and epoch_acc > best_val_acc:
                best_val_acc = epoch_acc
                best_weights = copy.deepcopy(model.state_dict())

        # Epoch-based schedule: advance once both phases are done.
        scheduler.step()

        print(f"Current LR: {optimizer.param_groups[0]['lr']:.2e}")
        print()

    elapsed = time.time() - start_time
    print(f'Training complete in {elapsed // 60:.0f}m {elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_val_acc:.4f}')

    model.load_state_dict(best_weights)
    return model, history

print("Training function V4 ready!")

In [None]:
# V4 run: 25 epochs on the re-split data; StepLR halves the LR every 10 epochs.
print("Starting V4 training with 25 epochs...")
print("LR will drop at epochs 10 and 20\n")

model_v4_trained, history_v4 = train_model_v4(
    model=model_v4,
    criterion=criterion_v4,
    optimizer=optimizer_v4,
    scheduler=scheduler_v4,
    num_epochs=25,
)

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns

def evaluate_model(model, dataloader, class_names, device=None):
    """Run ``model`` over ``dataloader`` and collect predictions and true labels.

    Args:
        model: Classifier returning one row of logits per input.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        class_names: Not used here; kept so existing call sites still work.
        device: Device each input batch is moved to. Defaults to the device
            of the model's first parameter, so no module-level ``device``
            variable is required.

    Returns:
        ``(all_preds, all_labels)``: two flat lists of Python ints in
        dataloader order.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs.to(device))
            # Index of the highest logit per row = predicted class.
            preds = outputs.argmax(dim=1)

            all_preds.extend(preds.cpu().tolist())
            # Labels are only collected, so they never need to visit the device.
            all_labels.extend(labels.cpu().tolist())

    return all_preds, all_labels

from sklearn.metrics import accuracy_score

# Predict every validation image with model_v4; train_model_v4 loaded its
# best-validation weights into this same object before returning it.
print("Evaluating on validation set...")
val_preds, val_labels = evaluate_model(
    model_v4,
    dataloaders_v4['val'],
    class_names,
)

# Share of validation images whose predicted class equals the true class.
val_accuracy = accuracy_score(val_labels, val_preds)
print(f"\nOverall Validation Accuracy: {val_accuracy:.4f} ({val_accuracy*100:.2f}%)")

In [None]:
# 1. CONFUSION MATRIX
print("\n" + "="*60)
print("CONFUSION MATRIX")
print("="*60)

# Pass the full label set explicitly. Without `labels=`, sklearn sizes the
# matrix from the labels that actually occur in val_labels/val_preds, so a
# class missing from both would shrink it and misalign the tick labels, the
# cm[i, i] lookups and `target_names` below (IndexError / ValueError).
class_indices = list(range(len(class_names)))
cm = confusion_matrix(val_labels, val_preds, labels=class_indices)

# Plot confusion matrix as heatmap (rows = true class, columns = predicted class)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names,
            cbar_kws={'label': 'Count'})
plt.title('Confusion Matrix - Witcher Character Classifier', fontsize=14, pad=20)
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.tight_layout()
plt.show()

# 2. PER-CLASS ACCURACY (i.e. recall: correct / total true samples of the class)
print("\n" + "="*60)
print("PER-CLASS ACCURACY")
print("="*60)

for i, class_name in enumerate(class_names):
    class_correct = cm[i, i]
    class_total = cm[i, :].sum()
    class_acc = class_correct / class_total if class_total > 0 else 0
    print(f"{class_name:12s}: {class_correct:2d}/{class_total:2d} = {class_acc:.2%}")

# 3. DETAILED CLASSIFICATION REPORT
print("\n" + "="*60)
print("CLASSIFICATION REPORT (Precision, Recall, F1)")
print("="*60)
# zero_division=0 reports 0.0 (without a warning) for classes that were
# never predicted or have no validation samples.
print(classification_report(val_labels, val_preds, labels=class_indices,
                            target_names=class_names, digits=4, zero_division=0))

## Version 5: ReduceLROnPlateau
Further dataset refinement (requiring a re-split), combined with an adaptive learning rate that is reduced when the validation loss plateaus.

In [None]:
import os
import splitfolders

# Root folder of the project; adjust it to where "dataset" lives on your machine.
base_path = '/path/to/your/project'

# Raw class folders are read from here; the train/val split is written there.
input_folder = os.path.join(base_path, "dataset")
output_folder = os.path.join(base_path, "split_data")

if not os.path.exists(input_folder):
    print("Error: Could not find the input folder!")
else:
    print(f"Found input folder at: {input_folder}")

    # Each sub-directory of the dataset folder is treated as one class.
    classes = [
        entry
        for entry in os.listdir(input_folder)
        if os.path.isdir(os.path.join(input_folder, entry))
    ]
    print(f"Classes found: {classes}")

    # Report how many entries each class folder holds.
    for cls in classes:
        count = len(os.listdir(os.path.join(input_folder, cls)))
        print(f"  {cls}: {count} images")

In [None]:
import shutil

# Split with 80/20 ratio
# seed=42 ensures reproducibility
#
# Clear any previous train/val output first. splitfolders.ratio copies files
# into `output_folder` but does not remove files left there by an earlier run
# (verify against the split-folders docs). After the dataset is refined, stale
# copies would otherwise linger: deleted/relabelled images stay in their old
# class, and an image now assigned to 'val' could still sit in 'train',
# leaking validation images into training.
for split in ['train', 'val']:
    stale_split_path = os.path.join(output_folder, split)
    if os.path.isdir(stale_split_path):
        print(f"Removing previous split folder: {stale_split_path}")
        shutil.rmtree(stale_split_path)

splitfolders.ratio(input_folder, output=output_folder, seed=42, ratio=(.8, .2))

print("--- Split Complete ---")
print(f"Check: {output_folder} should now contain 'train' and 'val' folders.")

# Let's verify the split worked correctly
print("\n--- Verifying Split ---")
for split in ['train', 'val']:
    split_path = os.path.join(output_folder, split)
    if os.path.exists(split_path):
        print(f"\n{split.upper()} set:")
        classes = [d for d in os.listdir(split_path) if os.path.isdir(os.path.join(split_path, d))]
        for cls in classes:
            count = len(os.listdir(os.path.join(split_path, cls)))
            print(f"  {cls}: {count} images")

In [None]:
# V5: BALANCED TRANSFORMS, SAME AS V3 and V4
# Sweet spot between V1's mild and V2's aggressive augmentation

data_transforms_v5 = {
    'train': transforms.Compose([
        # Random crop covering 8-100% of the image area, resized to 224x224
        transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
        transforms.RandomHorizontalFlip(),

        # Perspective warp on half the images; 0.45 is slightly milder than
        # V2's 0.5, which hurt training stability
        transforms.RandomPerspective(distortion_scale=0.45, p=0.5),

        # Moderate color jittering (less than V2, more than V1)
        # Helps distinguish hair colors without being too extreme
        transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.2),

        # Slightly more rotation than V1 (up to +/-15 degrees)
        transforms.RandomRotation(15),

        # Occasional Gaussian blur (20% of images)
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=5)], p=0.2),

        transforms.ToTensor(),
        # ImageNet channel mean/std
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        # Deterministic evaluation preprocessing: resize, then center crop
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Reload datasets with V5 transforms.
# NOTE(review): `data_dir` comes from an earlier cell; confirm it points at the
# freshly re-split folder (`output_folder`) rather than an older split.
image_datasets_v5 = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms_v5[x])
                     for x in ['train', 'val']}

# Shuffle only the training split. The val loader is used purely for
# evaluation, where a fixed order keeps predictions reproducible run to run.
dataloaders_v5 = {x: DataLoader(image_datasets_v5[x], batch_size=16, shuffle=(x == 'train'))
                  for x in ['train', 'val']}

print("V5 Transforms loaded!")

In [None]:
# V5: Fresh ResNet-18 with torchvision's default pretrained weights
model_v5 = models.resnet18(weights="DEFAULT")

# Unfreeze all layers so the whole network is fine-tuned
for param in model_v5.parameters():
    param.requires_grad = True

# Replace the final layer with a head sized from the dataset's class folders
# (previously hard-coded to 6), so it always matches ImageFolder's labels.
num_classes_v5 = len(image_datasets_v5['train'].classes)
num_ftrs = model_v5.fc.in_features
model_v5.fc = nn.Linear(num_ftrs, num_classes_v5)
model_v5 = model_v5.to(device)

# Class-weighted cross-entropy.
# NOTE(review): `class_weights` is defined in an earlier cell and is NOT
# recomputed here. The dataset was refined and re-split for V5, so confirm the
# weights reflect the current train-split class counts and have
# num_classes_v5 entries.
criterion_v5 = nn.CrossEntropyLoss(weight=class_weights.to(device))

# Adam with a small learning rate (1e-5) for fine-tuning all layers
optimizer_v5 = optim.Adam(model_v5.parameters(), lr=0.00001)

# Using ReduceLROnPlateau for smarter fine-tuning
scheduler_v5 = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_v5,
    mode='min',    # Monitor when validation loss stops decreasing
    factor=0.5,    # Reduce LR by half
    patience=3     # Wait 3 epochs of no improvement before dropping
)

print("Model V5 initialized!")
print(f"Initial learning rate: {optimizer_v5.param_groups[0]['lr']}")

In [None]:
def train_model_v5(model, criterion, optimizer, scheduler, num_epochs=25,
                   dataloaders=None, dataset_sizes=None):
    """Fine-tune ``model`` with a train/val loop driven by a plateau scheduler.

    Each epoch runs a training pass followed by a validation pass. After the
    validation pass the scheduler is stepped with the validation loss, and the
    weights with the highest validation accuracy seen so far are remembered.

    Args:
        model: Network to train; it is updated in place.
        criterion: Loss called as ``criterion(outputs, labels)``; assumed to
            return a batch mean (the default for torch losses).
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Scheduler whose ``step`` takes the validation loss
            (e.g. ``ReduceLROnPlateau``).
        num_epochs: Number of train+val epochs.
        dataloaders: Mapping with ``'train'`` and ``'val'`` loaders yielding
            ``(inputs, labels)``. Defaults to the notebook's ``dataloaders_v5``.
        dataset_sizes: Mapping phase -> sample count used to average epoch
            statistics. Defaults to ``len(loader.dataset)`` for each phase.

    Returns:
        ``(model, history)``: ``model`` with the best validation-accuracy
        weights loaded, and ``history`` with per-epoch float lists under
        ``train_loss``, ``val_loss``, ``train_acc`` and ``val_acc``.
    """
    if dataloaders is None:
        dataloaders = dataloaders_v5
    if dataset_sizes is None:
        dataset_sizes = {phase: len(loader.dataset) for phase, loader in dataloaders.items()}

    # Batches are sent wherever the model's parameters live.
    device = next(model.parameters()).device

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train(False) is equivalent to eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                if is_train:
                    optimizer.zero_grad()

                # Build the autograd graph only while training
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # The loss is a batch mean; re-weight by batch size so the
                # epoch value is an average over samples.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += int((preds == labels).sum())

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc)

            if not is_train:
                # Step the ReduceLROnPlateau scheduler based on val_loss
                scheduler.step(epoch_loss)

                # Keep a copy of the weights with the best validation accuracy
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())

        # Print current learning rate
        current_lr = optimizer.param_groups[0]['lr']
        print(f'Current LR: {current_lr:.2e}')
        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model, history

print("Training function V5 (with Plateau Scheduler) ready!")

In [None]:
# Launch V5 training for 25 epochs; scheduler_v5 (ReduceLROnPlateau, factor 0.5,
# patience 3) is stepped with the validation loss inside train_model_v5.
print("Starting V5 training with 25 epochs...\nLR will drop if validation loss plateaus for 3 epochs\n")
model_v5_trained, history_v5 = train_model_v5(model_v5, criterion_v5, optimizer_v5, scheduler_v5, num_epochs=25)

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns

def evaluate_model(model, dataloader, class_names, device=None):
    """Run ``model`` over ``dataloader`` and collect predictions and true labels.

    Args:
        model: Classifier returning one row of logits per input.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        class_names: Not used here; kept so existing call sites still work.
        device: Device each input batch is moved to. Defaults to the device
            of the model's first parameter, so no module-level ``device``
            variable is required.

    Returns:
        ``(all_preds, all_labels)``: two flat lists of Python ints in
        dataloader order.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs.to(device))
            # Index of the highest logit per row = predicted class.
            preds = outputs.argmax(dim=1)

            all_preds.extend(preds.cpu().tolist())
            # Labels are only collected, so they never need to visit the device.
            all_labels.extend(labels.cpu().tolist())

    return all_preds, all_labels

from sklearn.metrics import accuracy_score

# Predict every validation image with model_v5; train_model_v5 loaded its
# best-validation weights into this same object before returning it.
print("Evaluating on validation set...")
val_preds, val_labels = evaluate_model(
    model_v5,
    dataloaders_v5['val'],
    class_names,
)

# Share of validation images whose predicted class equals the true class.
val_accuracy = accuracy_score(val_labels, val_preds)
print(f"\nOverall Validation Accuracy: {val_accuracy:.4f} ({val_accuracy*100:.2f}%)")

In [None]:
# 1. CONFUSION MATRIX
print("\n" + "="*60)
print("CONFUSION MATRIX")
print("="*60)

# Pass the full label set explicitly. Without `labels=`, sklearn sizes the
# matrix from the labels that actually occur in val_labels/val_preds, so a
# class missing from both would shrink it and misalign the tick labels, the
# cm[i, i] lookups and `target_names` below (IndexError / ValueError).
class_indices = list(range(len(class_names)))
cm = confusion_matrix(val_labels, val_preds, labels=class_indices)

# Plot confusion matrix as heatmap (rows = true class, columns = predicted class)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names,
            cbar_kws={'label': 'Count'})
plt.title('Confusion Matrix - Witcher Character Classifier', fontsize=14, pad=20)
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.tight_layout()
plt.show()

# 2. PER-CLASS ACCURACY (i.e. recall: correct / total true samples of the class)
print("\n" + "="*60)
print("PER-CLASS ACCURACY")
print("="*60)

for i, class_name in enumerate(class_names):
    class_correct = cm[i, i]
    class_total = cm[i, :].sum()
    class_acc = class_correct / class_total if class_total > 0 else 0
    print(f"{class_name:12s}: {class_correct:2d}/{class_total:2d} = {class_acc:.2%}")

# 3. DETAILED CLASSIFICATION REPORT
print("\n" + "="*60)
print("CLASSIFICATION REPORT (Precision, Recall, F1)")
print("="*60)
# zero_division=0 reports 0.0 (without a warning) for classes that were
# never predicted or have no validation samples.
print(classification_report(val_labels, val_preds, labels=class_indices,
                            target_names=class_names, digits=4, zero_division=0))

### Helper function to test the model on a single image

In [None]:
from PIL import Image

def predict_image(model, image_path, class_names, transform=None):
    """Classify one image file, print the prediction and display the image.

    Args:
        model: Trained classifier; it is switched to eval mode here.
        image_path: Path of the image file to classify.
        class_names: Class names indexed by the model's output index
            (e.g. ``image_datasets_v5['train'].classes``).
        transform: Preprocessing applied to the PIL image; must return a
            CHW tensor. Defaults to ``data_transforms_v5['val']`` so inference
            uses the same preprocessing as validation.

    Returns:
        ``(predicted_class_name, confidence)``, where confidence is the
        softmax probability of the predicted class, in [0, 1].
    """
    if transform is None:
        transform = data_transforms_v5['val']
    # Send the input wherever the model's parameters live.
    device = next(model.parameters()).device

    # The context manager closes the file handle; convert() returns a new
    # in-memory image that stays usable after the file is closed.
    with Image.open(image_path) as raw_img:
        img = raw_img.convert('RGB')
    img_t = transform(img).unsqueeze(0).to(device)  # add a batch dimension

    model.eval()
    with torch.no_grad():
        outputs = model(img_t)
        # Apply softmax to get probabilities
        probs = torch.nn.functional.softmax(outputs, dim=1)
        conf, preds = torch.max(probs, 1)

    label = class_names[preds.item()]
    confidence = conf.item()
    print(f"Prediction: {label}")
    print(f"Confidence: {confidence*100:.2f}%")

    # Display the image
    plt.imshow(img)
    plt.axis('off')
    plt.show()

    return label, confidence

# Example usage (update path to your test image):
# predict_image(model_v5_trained, 'path/to/your/test/image.jpg', image_datasets_v5['train'].classes)

In [None]:
# Write the V5 weights (state_dict only) to <base_path>/models/.
# model_v5 is the object train_model_v5 returned, so it already carries the
# best-validation-accuracy weights loaded at the end of training.
models_dir = os.path.join(base_path, "models")
model_path = os.path.join(models_dir, "witcher_multiclass_v5.pth")

os.makedirs(models_dir, exist_ok=True)  # create the folder only if missing
torch.save(model_v5.state_dict(), model_path)
print(f"Model saved to: {model_path}")

## Version 6: Final Refinement
Additional dataset improvements and a longer, 30-epoch training run for the final model.

In [None]:
import os
import splitfolders

# Root folder of the project; adjust it to where "dataset" lives on your machine.
base_path = '/path/to/your/project'

# Raw class folders are read from here; the train/val split is written there.
input_folder = os.path.join(base_path, "dataset")
output_folder = os.path.join(base_path, "split_data")

if not os.path.exists(input_folder):
    print("Error: Could not find the input folder!")
else:
    print(f"Found input folder at: {input_folder}")

    # Each sub-directory of the dataset folder is treated as one class.
    classes = [
        entry
        for entry in os.listdir(input_folder)
        if os.path.isdir(os.path.join(input_folder, entry))
    ]
    print(f"Classes found: {classes}")

    # Report how many entries each class folder holds.
    for cls in classes:
        count = len(os.listdir(os.path.join(input_folder, cls)))
        print(f"  {cls}: {count} images")

In [None]:
import shutil

# Split with 80/20 ratio
# seed=42 ensures reproducibility
#
# Clear any previous train/val output first. splitfolders.ratio copies files
# into `output_folder` but does not remove files left there by an earlier run
# (verify against the split-folders docs). After the dataset is refined, stale
# copies would otherwise linger: deleted/relabelled images stay in their old
# class, and an image now assigned to 'val' could still sit in 'train',
# leaking validation images into training.
for split in ['train', 'val']:
    stale_split_path = os.path.join(output_folder, split)
    if os.path.isdir(stale_split_path):
        print(f"Removing previous split folder: {stale_split_path}")
        shutil.rmtree(stale_split_path)

splitfolders.ratio(input_folder, output=output_folder, seed=42, ratio=(.8, .2))

print("--- Split Complete ---")
print(f"Check: {output_folder} should now contain 'train' and 'val' folders.")

# Let's verify the split worked correctly
print("\n--- Verifying Split ---")
for split in ['train', 'val']:
    split_path = os.path.join(output_folder, split)
    if os.path.exists(split_path):
        print(f"\n{split.upper()} set:")
        classes = [d for d in os.listdir(split_path) if os.path.isdir(os.path.join(split_path, d))]
        for cls in classes:
            count = len(os.listdir(os.path.join(split_path, cls)))
            print(f"  {cls}: {count} images")

In [None]:
# V6: BALANCED TRANSFORMS, SAME AS V3, V4 and V5
# Sweet spot between V1's mild and V2's aggressive augmentation

data_transforms_v6 = {
    'train': transforms.Compose([
        # Random crop covering 8-100% of the image area, resized to 224x224
        transforms.RandomResizedCrop(224, scale=(0.08, 1.0)),
        transforms.RandomHorizontalFlip(),

        # Perspective warp on half the images; 0.45 is slightly milder than
        # V2's 0.5, which hurt training stability
        transforms.RandomPerspective(distortion_scale=0.45, p=0.5),

        # Moderate color jittering (less than V2, more than V1)
        # Helps distinguish hair colors without being too extreme
        transforms.ColorJitter(brightness=0.25, contrast=0.25, saturation=0.2),

        # Slightly more rotation than V1 (up to +/-15 degrees)
        transforms.RandomRotation(15),

        # Occasional Gaussian blur (20% of images)
        transforms.RandomApply([transforms.GaussianBlur(kernel_size=5)], p=0.2),

        transforms.ToTensor(),
        # ImageNet channel mean/std
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        # Deterministic evaluation preprocessing: resize, then center crop
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}

# Reload datasets with V6 transforms.
# NOTE(review): `data_dir` comes from an earlier cell; confirm it points at the
# freshly re-split folder (`output_folder`) rather than an older split.
image_datasets_v6 = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms_v6[x])
                     for x in ['train', 'val']}

# Shuffle only the training split. The val loader is used purely for
# evaluation, where a fixed order keeps predictions reproducible run to run.
dataloaders_v6 = {x: DataLoader(image_datasets_v6[x], batch_size=16, shuffle=(x == 'train'))
                  for x in ['train', 'val']}

print("V6 Transforms loaded!")

In [None]:
# V6: Fresh ResNet-18 with torchvision's default pretrained weights
model_v6 = models.resnet18(weights="DEFAULT")

# Unfreeze all layers so the whole network is fine-tuned
for param in model_v6.parameters():
    param.requires_grad = True

# Replace the final layer with a head sized from the dataset's class folders
# (previously hard-coded to 6), so it always matches ImageFolder's labels.
num_classes_v6 = len(image_datasets_v6['train'].classes)
num_ftrs = model_v6.fc.in_features
model_v6.fc = nn.Linear(num_ftrs, num_classes_v6)
model_v6 = model_v6.to(device)

# Class-weighted cross-entropy.
# NOTE(review): `class_weights` is defined in an earlier cell and is NOT
# recomputed here. The dataset was refined and re-split for V6, so confirm the
# weights reflect the current train-split class counts and have
# num_classes_v6 entries.
criterion_v6 = nn.CrossEntropyLoss(weight=class_weights.to(device))

# Adam with a small learning rate (1e-5) for fine-tuning all layers
optimizer_v6 = optim.Adam(model_v6.parameters(), lr=0.00001)

# ReduceLROnPlateau, configured exactly as in V5
scheduler_v6 = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer_v6,
    mode='min',    # Monitor when validation loss stops decreasing
    factor=0.5,    # Reduce LR by half
    patience=3     # Wait 3 epochs of no improvement before dropping
)

print("Model V6 initialized!")
print(f"Initial learning rate: {optimizer_v6.param_groups[0]['lr']}")

In [None]:
def train_model_v6(model, criterion, optimizer, scheduler, num_epochs=25,
                   dataloaders=None, dataset_sizes=None):
    """Fine-tune ``model`` with a train/val loop driven by a plateau scheduler.

    Each epoch runs a training pass followed by a validation pass. After the
    validation pass the scheduler is stepped with the validation loss, and the
    weights with the highest validation accuracy seen so far are remembered.

    Args:
        model: Network to train; it is updated in place.
        criterion: Loss called as ``criterion(outputs, labels)``; assumed to
            return a batch mean (the default for torch losses).
        optimizer: Optimizer over ``model``'s parameters.
        scheduler: Scheduler whose ``step`` takes the validation loss
            (e.g. ``ReduceLROnPlateau``).
        num_epochs: Number of train+val epochs.
        dataloaders: Mapping with ``'train'`` and ``'val'`` loaders yielding
            ``(inputs, labels)``. Defaults to the notebook's ``dataloaders_v6``.
        dataset_sizes: Mapping phase -> sample count used to average epoch
            statistics. Defaults to ``len(loader.dataset)`` for each phase.

    Returns:
        ``(model, history)``: ``model`` with the best validation-accuracy
        weights loaded, and ``history`` with per-epoch float lists under
        ``train_loss``, ``val_loss``, ``train_acc`` and ``val_acc``.
    """
    if dataloaders is None:
        dataloaders = dataloaders_v6
    if dataset_sizes is None:
        dataset_sizes = {phase: len(loader.dataset) for phase, loader in dataloaders.items()}

    # Batches are sent wherever the model's parameters live.
    device = next(model.parameters()).device

    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    history = {'train_loss': [], 'val_loss': [], 'train_acc': [], 'val_acc': []}

    for epoch in range(num_epochs):
        print(f'Epoch {epoch}/{num_epochs - 1}')
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            is_train = phase == 'train'
            model.train(is_train)  # train(False) is equivalent to eval()

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)

                if is_train:
                    optimizer.zero_grad()

                # Build the autograd graph only while training
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if is_train:
                        loss.backward()
                        optimizer.step()

                # The loss is a batch mean; re-weight by batch size so the
                # epoch value is an average over samples.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += int((preds == labels).sum())

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            history[f'{phase}_loss'].append(epoch_loss)
            history[f'{phase}_acc'].append(epoch_acc)

            if not is_train:
                # Step the ReduceLROnPlateau scheduler based on val_loss
                scheduler.step(epoch_loss)

                # Keep a copy of the weights with the best validation accuracy
                if epoch_acc > best_acc:
                    best_acc = epoch_acc
                    best_model_wts = copy.deepcopy(model.state_dict())

        # Print current learning rate
        current_lr = optimizer.param_groups[0]['lr']
        print(f'Current LR: {current_lr:.2e}')
        print()

    time_elapsed = time.time() - since
    print(f'Training complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model, history

print("Training function V6 (with Plateau Scheduler) ready!")

In [None]:
# Launch V6 training for 30 epochs; scheduler_v6 (ReduceLROnPlateau, factor 0.5,
# patience 3) is stepped with the validation loss inside train_model_v6.
print("Starting V6 training with 30 epochs...\nLR will drop if validation loss plateaus for 3 epochs\n")
model_v6_trained, history_v6 = train_model_v6(model_v6, criterion_v6, optimizer_v6, scheduler_v6, num_epochs=30)

In [None]:
from sklearn.metrics import confusion_matrix, classification_report
import matplotlib.pyplot as plt
import seaborn as sns

def evaluate_model(model, dataloader, class_names, device=None):
    """Run ``model`` over ``dataloader`` and collect predictions and true labels.

    Args:
        model: Classifier returning one row of logits per input.
        dataloader: Iterable of ``(inputs, labels)`` batches.
        class_names: Not used here; kept so existing call sites still work.
        device: Device each input batch is moved to. Defaults to the device
            of the model's first parameter, so no module-level ``device``
            variable is required.

    Returns:
        ``(all_preds, all_labels)``: two flat lists of Python ints in
        dataloader order.
    """
    if device is None:
        device = next(model.parameters()).device

    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            outputs = model(inputs.to(device))
            # Index of the highest logit per row = predicted class.
            preds = outputs.argmax(dim=1)

            all_preds.extend(preds.cpu().tolist())
            # Labels are only collected, so they never need to visit the device.
            all_labels.extend(labels.cpu().tolist())

    return all_preds, all_labels

from sklearn.metrics import accuracy_score

# Predict every validation image with model_v6; train_model_v6 loaded its
# best-validation weights into this same object before returning it.
print("Evaluating on validation set...")
val_preds, val_labels = evaluate_model(
    model_v6,
    dataloaders_v6['val'],
    class_names,
)

# Share of validation images whose predicted class equals the true class.
val_accuracy = accuracy_score(val_labels, val_preds)
print(f"\nOverall Validation Accuracy: {val_accuracy:.4f} ({val_accuracy*100:.2f}%)")

In [None]:
# 1. CONFUSION MATRIX
print("\n" + "="*60)
print("CONFUSION MATRIX")
print("="*60)

# Pass the full label set explicitly. Without `labels=`, sklearn sizes the
# matrix from the labels that actually occur in val_labels/val_preds, so a
# class missing from both would shrink it and misalign the tick labels, the
# cm[i, i] lookups and `target_names` below (IndexError / ValueError).
class_indices = list(range(len(class_names)))
cm = confusion_matrix(val_labels, val_preds, labels=class_indices)

# Plot confusion matrix as heatmap (rows = true class, columns = predicted class)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names,
            cbar_kws={'label': 'Count'})
plt.title('Confusion Matrix - Witcher Character Classifier', fontsize=14, pad=20)
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.tight_layout()
plt.show()

# 2. PER-CLASS ACCURACY (i.e. recall: correct / total true samples of the class)
print("\n" + "="*60)
print("PER-CLASS ACCURACY")
print("="*60)

for i, class_name in enumerate(class_names):
    class_correct = cm[i, i]
    class_total = cm[i, :].sum()
    class_acc = class_correct / class_total if class_total > 0 else 0
    print(f"{class_name:12s}: {class_correct:2d}/{class_total:2d} = {class_acc:.2%}")

# 3. DETAILED CLASSIFICATION REPORT
print("\n" + "="*60)
print("CLASSIFICATION REPORT (Precision, Recall, F1)")
print("="*60)
# zero_division=0 reports 0.0 (without a warning) for classes that were
# never predicted or have no validation samples.
print(classification_report(val_labels, val_preds, labels=class_indices,
                            target_names=class_names, digits=4, zero_division=0))

In [None]:
# Write the V6 weights (state_dict only) to <base_path>/models/.
# model_v6 is the object train_model_v6 returned, so it already carries the
# best-validation-accuracy weights loaded at the end of training.
models_dir = os.path.join(base_path, "models")
model_path = os.path.join(models_dir, "witcher_multiclass_v6.pth")

os.makedirs(models_dir, exist_ok=True)  # create the folder only if missing
torch.save(model_v6.state_dict(), model_path)
print(f"Model saved to: {model_path}")