# Experiment 3: Train WikiArt Art Style Classifier

This notebook trains a ResNet-18-based classifier to recognize 27 art styles from WikiArt.
The classifier is used to evaluate how well generated images match the prompted art style.

**Architecture:**
- ResNet-18 backbone (pretrained on ImageNet)
- Custom 512-dim embedding layer
- Classification head for 27 art styles

**Training:**
- HuggingFace WikiArt dataset
- Train/test split (80/20)
- Data augmentation (flip, rotation, color jitter)
- 20 epochs, learning rate scheduling

## 1. Setup and Configuration

In [None]:
# Project configuration - use absolute paths
from pathlib import Path
import sys

# Absolute, machine-specific location of the project checkout.
PROJECT_ROOT = Path("/home/doshlom4/work/final_project")

# Make the project root the first entry of sys.path (presumably so the
# project-local modules imported later resolve from there).
# The insert is skipped when the root is already at the front, so re-running
# this cell does not pile up duplicate sys.path entries.
_project_root_str = str(PROJECT_ROOT)
if sys.path[:1] != [_project_root_str]:
    sys.path.insert(0, _project_root_str)

print(f"Project root: {PROJECT_ROOT}")

In [None]:
# Project-local configuration constants and classifier helpers.
# NOTE(review): `config` and `wikiart_classifier` are not installed packages;
# presumably they resolve through the project root on sys.path - confirm.
from config import (
    WIKIART_STYLES,
    DATASET_CACHE_DIR,
    CHECKPOINTS_DIR,
)

from wikiart_classifier import (
    WikiArtClassifier,
    train_wikiart_classifier,
    evaluate_wikiart_classifier,
    compute_per_class_accuracy,
    get_wikiart_classifier_checkpoint_path,
    get_wikiart_transforms,
)

# Deep learning framework (PyTorch) and progress bars
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset, random_split
from tqdm import tqdm

# Third-party numerics, plotting and image handling
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

# HuggingFace datasets and torchvision image transforms
from datasets import load_dataset
import torchvision.transforms as transforms

print("Libraries imported successfully")

In [None]:
# Hyperparameters read by the data pipeline and the training call in later cells.
TRAINING_CONFIG = {
    "num_epochs": 20,
    "batch_size": 32,
    "learning_rate": 0.001,
    "image_size": 128,
    "train_split": 0.8,
    "num_workers": 4,
}

# Echo every setting so the notebook output records how the run was configured.
print("Training Configuration:")
for key in TRAINING_CONFIG:
    value = TRAINING_CONFIG[key]
    print(f"  {key}: {value}")

# One class per entry of WIKIART_STYLES.
print(f"\nNumber of art styles: {len(WIKIART_STYLES)}")

## 2. Load WikiArt Dataset from HuggingFace

In [None]:
# Fetch the "train" split of huggan/wikiart, caching it under the project's dataset cache
print("Loading WikiArt dataset from HuggingFace...")

_hf_cache_dir = DATASET_CACHE_DIR / "huggingface"
wikiart_hf = load_dataset("huggan/wikiart", split="train", cache_dir=str(_hf_cache_dir))

# Report the size and schema of what was loaded.
print(f"\nDataset loaded: {len(wikiart_hf)} images")
print(f"Features: {wikiart_hf.features}")

In [None]:
# Work out which key of a dataset row holds the style id:
# use 'style' when the first row has it, otherwise fall back to 'label'.
sample = wikiart_hf[0]
if 'style' in sample:
    style_column = 'style'
else:
    style_column = 'label'

print(f"Style column: {style_column}")

In [None]:
# Count images per style
# Start every known style at zero so styles without images still show up.
_num_known_styles = len(WIKIART_STYLES)
style_counts = {i: 0 for i in range(_num_known_styles)}

# Read only the label column. Iterating over whole rows would decode every
# image in the dataset just to look at one integer per row.
for style_idx in tqdm(wikiart_hf[style_column], desc="Counting styles"):
    # Ignore ids outside the known style list (including negative ids).
    if 0 <= style_idx < _num_known_styles:
        style_counts[style_idx] += 1

print("\nImages per style:")
for style_idx in sorted(style_counts):
    print(f"  {WIKIART_STYLES[style_idx]}: {style_counts[style_idx]}")

total_valid = sum(style_counts.values())
print(f"\nTotal valid images: {total_valid}")

In [None]:
# Bar chart of how many images each style contributes
_bar_positions = range(len(WIKIART_STYLES))

# Tick labels: underscores become spaces and names are cut to 15 characters.
style_names = []
for _raw_name in WIKIART_STYLES:
    style_names.append(_raw_name.replace('_', ' ')[:15])

# Bar heights, in the same order as WIKIART_STYLES.
counts = [style_counts[_pos] for _pos in _bar_positions]

plt.figure(figsize=(14, 6))
plt.bar(_bar_positions, counts, color='steelblue')
plt.xticks(_bar_positions, style_names, rotation=45, ha='right', fontsize=8)
plt.xlabel('Art Style')
plt.ylabel('Number of Images')
plt.title('WikiArt Dataset: Images per Style')
plt.tight_layout()
plt.show()

## 3. Create PyTorch Dataset

In [None]:
# Image pipelines for training (with augmentation) and testing (deterministic)
_side = TRAINING_CONFIG["image_size"]
_target_size = (_side, _side)

# Per-channel normalization constants (the usual ImageNet statistics).
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Training: resize, then random flip / small rotation / mild color jitter.
train_transform = transforms.Compose([
    transforms.Resize(_target_size),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
])

# Testing: same resize and normalization, no random augmentation.
test_transform = transforms.Compose([
    transforms.Resize(_target_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
])

print("Transforms defined")

In [None]:
# PyTorch-facing view over selected rows of the HuggingFace WikiArt dataset
class WikiArtHFDataset(Dataset):
    """
    Map-style PyTorch Dataset exposing a subset of a HuggingFace dataset.

    Each item is an ``(image, label)`` pair: the row's ``'image'`` entry,
    converted to RGB when needed and optionally transformed, together with
    the value stored in the row's style column.
    """

    def __init__(self, hf_dataset, indices, transform=None):
        """
        Store the backing dataset, the row subset and the image transform.

        Args:
            hf_dataset: Indexable dataset whose rows support ``in`` and key
                lookup; row 0 is read once to detect the style column.
            indices: Row positions of ``hf_dataset`` exposed by this subset.
            transform: Optional callable applied to each RGB image.
        """
        self.hf_dataset = hf_dataset
        self.indices = indices
        self.transform = transform

        # Peek at the first row to learn which key stores the style id.
        first_row = hf_dataset[0]
        if 'style' in first_row:
            self.style_column = 'style'
        else:
            self.style_column = 'label'

    def __len__(self):
        """Return the number of rows exposed by this subset."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Return the ``(image, label)`` pair for position ``idx`` of the subset."""
        # Translate the subset position into a row of the backing dataset.
        row = self.hf_dataset[int(self.indices[idx])]

        # Images are expected to expose `.mode` / `.convert` (presumably PIL).
        image = row['image']
        if image.mode != 'RGB':
            image = image.convert('RGB')

        label = row[self.style_column]

        # Ids at or beyond the style count are not skipped: they are wrapped
        # back into the valid range with a modulo.
        num_styles = len(WIKIART_STYLES)
        if label >= num_styles:
            label = label % num_styles

        if self.transform:
            image = self.transform(image)

        return image, label


print("Dataset class defined")

In [None]:
# Filter to valid indices (only style ids that exist in WIKIART_STYLES)
# Read only the label column: enumerating whole rows would decode every image
# just to inspect one integer per row.
_num_known_styles = len(WIKIART_STYLES)
valid_indices = [
    row_idx
    for row_idx, row_style in enumerate(
        tqdm(wikiart_hf[style_column], desc="Filtering valid samples")
    )
    # Negative ids are excluded too; they cannot index a style name.
    if 0 <= row_style < _num_known_styles
]

print(f"\nValid samples: {len(valid_indices)}")

# Shuffle and split
# Seeding NumPy's global RNG makes the split reproducible across runs.
np.random.seed(42)
np.random.shuffle(valid_indices)

# The first `train_split` fraction goes to training, the remainder to testing.
split_idx = int(len(valid_indices) * TRAINING_CONFIG["train_split"])
train_indices = valid_indices[:split_idx]
test_indices = valid_indices[split_idx:]

print(f"Training samples: {len(train_indices)}")
print(f"Test samples: {len(test_indices)}")

In [None]:
# Wrap each index subset in a dataset; only the training one gets the augmenting transform
train_dataset = WikiArtHFDataset(
    hf_dataset=wikiart_hf,
    indices=train_indices,
    transform=train_transform,
)
test_dataset = WikiArtHFDataset(
    hf_dataset=wikiart_hf,
    indices=test_indices,
    transform=test_transform,
)

print(f"Train dataset: {len(train_dataset)} samples")
print(f"Test dataset: {len(test_dataset)} samples")

In [None]:
# Batch both datasets; the two loaders differ only in whether they shuffle
_shared_loader_args = {
    "batch_size": TRAINING_CONFIG["batch_size"],
    "num_workers": TRAINING_CONFIG["num_workers"],
    "pin_memory": True,
}

# Training batches are reshuffled; test batches keep a fixed order.
train_loader = DataLoader(train_dataset, shuffle=True, **_shared_loader_args)
test_loader = DataLoader(test_dataset, shuffle=False, **_shared_loader_args)

print(f"Train batches: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")

In [None]:
# Show the first ten images of one training batch with their style names
images, labels = next(iter(train_loader))

# Constants that undo the Normalize step, shaped to broadcast over (3, H, W).
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

fig, axes = plt.subplots(2, 5, figsize=(15, 6))

for i, ax in enumerate(axes.flat):
    # Denormalize, move channels last for matplotlib, and clamp to [0, 1].
    img = (images[i] * std + mean).permute(1, 2, 0).numpy().clip(0, 1)
    _style_name = WIKIART_STYLES[labels[i].item()]

    ax.imshow(img)
    ax.set_title(_style_name[:20], fontsize=8)
    ax.axis('off')

plt.suptitle('Training Samples', fontsize=14)
plt.tight_layout()
plt.show()

## 4. Create Model

In [None]:
# Pick the compute device: CUDA when available, otherwise CPU
_cuda_available = torch.cuda.is_available()
device = torch.device("cuda") if _cuda_available else torch.device("cpu")
print(f"Using device: {device}")

# Name the first GPU when one is being used.
if _cuda_available:
    print(f"GPU: {torch.cuda.get_device_name(0)}")

In [None]:
# Create model
# One output class per entry of WIKIART_STYLES, with a 512-dim embedding;
# the instance is moved to the selected device right away.
model = WikiArtClassifier(
    num_classes=len(WIKIART_STYLES),
    embedding_dim=512,
    pretrained=True
).to(device)

# Parameter counts: every parameter vs. only those that receive gradients.
num_params = sum(p.numel() for p in model.parameters())
num_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)

# The check mark was previously mis-encoded (UTF-8 bytes read as cp1252).
print("\n✓ Created WikiArt Classifier")
print(f"  Total parameters: {num_params:,}")
print(f"  Trainable parameters: {num_trainable:,}")
print(f"  Number of classes: {model.num_classes}")
print(f"  Embedding dimension: {model.embedding_dim}")

## 5. Train Classifier

In [None]:
# Training setup
# Ensure the checkpoint directory exists (no error if it already does),
# then ask the project helper where the classifier checkpoint should be written.
# NOTE(review): the helper presumably returns a path inside CHECKPOINTS_DIR - confirm.
CHECKPOINTS_DIR.mkdir(parents=True, exist_ok=True)
save_path = get_wikiart_classifier_checkpoint_path()

print(f"Checkpoint will be saved to: {save_path}")

In [None]:
# Run training and keep the returned model together with its metric history
print(f"\n{'='*70}")
print(f"Starting WikiArt Classifier Training")
print(f"{'='*70}")
print(f"Epochs: {TRAINING_CONFIG['num_epochs']}")
print(f"Batch size: {TRAINING_CONFIG['batch_size']}")
print(f"Learning rate: {TRAINING_CONFIG['learning_rate']}")
print(f"{'='*70}\n")

# Gather the call arguments in one place so the run is easy to inspect.
_train_kwargs = {
    "model": model,
    "train_loader": train_loader,
    "test_loader": test_loader,
    "device": device,
    "num_epochs": TRAINING_CONFIG["num_epochs"],
    "lr": TRAINING_CONFIG["learning_rate"],
    "save_path": save_path,
}
model, history = train_wikiart_classifier(**_train_kwargs)

# `history['test_acc']` is read as a sequence of accuracy values; report its maximum.
print(f"\n{'='*70}")
print(f"Training Complete!")
print(f"{'='*70}")
print(f"Best test accuracy: {max(history['test_acc']):.2f}%")

## 6. Visualize Training Progress

In [None]:
# Two side-by-side panels: training loss (left) and train/test accuracy (right)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
_loss_ax, _acc_ax = axes

# Left panel: loss against a 1-based position in the recorded series.
_train_loss = history['train_loss']
_loss_ax.plot(range(1, len(_train_loss) + 1), _train_loss, 'b-', linewidth=2)
_loss_ax.set_xlabel('Epoch')
_loss_ax.set_ylabel('Loss')
_loss_ax.set_title('Training Loss')
_loss_ax.grid(True, alpha=0.3)

# Right panel: one line per accuracy series, train in blue and test in red.
for _legend, _history_key, _line_fmt in (
    ('Train', 'train_acc', 'b-'),
    ('Test', 'test_acc', 'r-'),
):
    _series = history[_history_key]
    _acc_ax.plot(range(1, len(_series) + 1), _series, _line_fmt, linewidth=2, label=_legend)
_acc_ax.set_xlabel('Epoch')
_acc_ax.set_ylabel('Accuracy (%)')
_acc_ax.set_title('Training and Test Accuracy')
_acc_ax.legend()
_acc_ax.grid(True, alpha=0.3)

plt.suptitle('WikiArt Classifier Training Progress', fontsize=14)
plt.tight_layout()
plt.show()

## 7. Evaluate Per-Class Accuracy

In [None]:
# Evaluate the trained model on the test loader, one accuracy figure per style
per_class_acc = compute_per_class_accuracy(model, test_loader, device)

# The result is consumed as (style name, accuracy) pairs.
print("Per-class accuracy:")
for _entry in per_class_acc.items():
    style_name, acc = _entry
    print(f"  {style_name}: {acc:.1f}%")

In [None]:
# Bar chart of per-style accuracy, colour-coded by accuracy band
_bar_positions = range(len(WIKIART_STYLES))
style_names = [s.replace('_', ' ')[:15] for s in WIKIART_STYLES]

# Accuracies in WIKIART_STYLES order, looked up by style name.
accuracies = []
for _style in WIKIART_STYLES:
    accuracies.append(per_class_acc[_style])

# Green at 50% and above, orange from 30% up to 50%, red otherwise.
colors = []
for _value in accuracies:
    if _value >= 50:
        colors.append('green')
    elif _value >= 30:
        colors.append('orange')
    else:
        colors.append('red')

plt.figure(figsize=(14, 6))
plt.bar(_bar_positions, accuracies, color=colors)
plt.xticks(_bar_positions, style_names, rotation=45, ha='right', fontsize=8)
plt.axhline(y=50, color='gray', linestyle='--', alpha=0.5, label='50%')
plt.xlabel('Art Style')
plt.ylabel('Accuracy (%)')
plt.title('WikiArt Classifier: Per-Class Accuracy')
plt.legend()
plt.tight_layout()
plt.show()

# Summary: unweighted mean over styles plus the best and worst style.
print(f"\nAverage accuracy: {np.mean(accuracies):.1f}%")
print(f"Best style: {WIKIART_STYLES[np.argmax(accuracies)]} ({max(accuracies):.1f}%)")
print(f"Worst style: {WIKIART_STYLES[np.argmin(accuracies)]} ({min(accuracies):.1f}%)")

## 8. Confusion Matrix

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns

# Collect all predictions
all_preds = []
all_labels = []

model.eval()
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Computing confusion matrix"):
        images = images.to(device)
        outputs = model(images)
        # Predicted class = index of the largest output along dim 1.
        _, preds = outputs.max(1)

        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.numpy())

# Compute confusion matrix
# Pin the label set so the matrix always has one row/column per style.
# Without it, a style missing from both labels and predictions drops out and
# the tick labels below no longer line up with the rows and columns.
cm = confusion_matrix(
    all_labels,
    all_preds,
    labels=np.arange(len(WIKIART_STYLES)),
)

# Normalize each row by its number of true samples.
# Rows with no true samples stay all-zero instead of becoming NaN.
_row_totals = cm.sum(axis=1, keepdims=True)
cm_normalized = np.divide(
    cm.astype('float'),
    _row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=_row_totals > 0,
)

# Plot
plt.figure(figsize=(16, 14))
style_labels = [s.replace('_', ' ')[:12] for s in WIKIART_STYLES]
sns.heatmap(cm_normalized, annot=False, cmap='Blues',
            xticklabels=style_labels, yticklabels=style_labels)
plt.xlabel('Predicted Style')
plt.ylabel('True Style')
plt.title('WikiArt Classifier Confusion Matrix (Normalized)')
plt.tight_layout()
plt.show()

## 9. Test on Sample Images

In [None]:
# Predict 12 randomly chosen test images and show them in a 3x4 grid
model.eval()

fig, axes = plt.subplots(3, 4, figsize=(16, 12))

# Draw 12 distinct positions from the test dataset.
sample_indices = np.random.choice(len(test_dataset), 12, replace=False)

# Constants that undo the Normalize step, shaped to broadcast over (3, H, W).
mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

for i, idx in enumerate(sample_indices):
    image, label = test_dataset[idx]

    # Single-image forward pass; softmax turns the outputs into probabilities.
    with torch.no_grad():
        img_tensor = image.unsqueeze(0).to(device)
        output = model(img_tensor)
        probs = torch.softmax(output, dim=1)
        confidence, pred = probs.max(1)

    # Denormalize, move channels last for matplotlib, and clamp to [0, 1].
    img_display = (image * std + mean).permute(1, 2, 0).numpy().clip(0, 1)

    # Fill the grid row by row, four panels per row.
    ax = axes[divmod(i, 4)]
    ax.imshow(img_display)

    true_style = WIKIART_STYLES[label][:15]
    pred_style = WIKIART_STYLES[pred.item()][:15]
    conf = confidence.item() * 100

    # Green title for a correct prediction, red for a wrong one.
    if pred.item() == label:
        color = 'green'
    else:
        color = 'red'
    ax.set_title(f'True: {true_style}\nPred: {pred_style} ({conf:.0f}%)',
                 fontsize=8, color=color)
    ax.axis('off')

plt.suptitle('WikiArt Classifier Predictions', fontsize=14)
plt.tight_layout()
plt.show()

## Summary

This notebook trained a WikiArt art style classifier:

**Model:**
- ResNet-18 backbone (pretrained on ImageNet)
- Custom classification head for 27 art styles
- 512-dim embedding layer for feature extraction

**Training:**
- HuggingFace WikiArt dataset
- 80/20 train/test split
- Data augmentation (flip, rotation, color jitter)
- Saved best model checkpoint

**Usage:**
The classifier will be used in `metrics1_evaluate_wikiart.ipynb` to compute:
1. Classification accuracy on generated images (prompt adherence)
2. Feature extraction for additional metrics

**Next steps:**
- `metrics1_evaluate_wikiart.ipynb` - Compute FID and classification accuracy