In [1]:
"""
Fashion-MNIST CNN Classifier - Improved Version
===============================================
This notebook implements an improved CNN model for Fashion-MNIST classification
with better code organization, data augmentation, learning rate scheduling,
early stopping, and comprehensive visualizations.
"""

# ============================================================================
# 1. IMPORTS AND CONFIGURATION
# ============================================================================
import copy
import random
import warnings
from pathlib import Path
from typing import Optional, Tuple, Dict, List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from sklearn.metrics import classification_report, confusion_matrix
from torch.optim.lr_scheduler import ReduceLROnPlateau, CosineAnnealingLR
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
warnings.filterwarnings('ignore')

# Configuration
class Config:
    """Single source of truth for paths, hyperparameters and training options."""

    # --- Data locations (change DATA_DIR to match the environment) ---
    DATA_DIR = Path("./data")
    TRAIN_CSV = DATA_DIR / "fashion-mnist_train.csv"
    TEST_CSV = DATA_DIR / "fashion-mnist_test.csv"

    # When running on Kaggle, use the mounted dataset instead:
    # TRAIN_CSV = Path("/kaggle/input/fashion-mnist/fashion-mnist_train.csv")
    # TEST_CSV = Path("/kaggle/input/fashion-mnist/fashion-mnist_test.csv")

    # --- Optimization hyperparameters ---
    BATCH_SIZE = 128
    LEARNING_RATE = 0.001
    EPOCHS = 50
    VALIDATION_SPLIT = 0.1   # fraction of the training set held out for validation
    DROPOUT_RATE = 0.5
    WEIGHT_DECAY = 1e-4

    # --- Training loop behaviour ---
    EARLY_STOPPING_PATIENCE = 10
    MIN_DELTA = 0.001
    NUM_WORKERS = 0  # Keep at 0 on Windows to avoid multiprocessing issues
    PIN_MEMORY = torch.cuda.is_available()  # pinned host memory only when a GPU is present

    # --- Checkpointing ---
    SAVE_BEST_MODEL = True
    MODEL_SAVE_PATH = Path("./models")
    # Executed when the class body runs, so the directory exists before training
    MODEL_SAVE_PATH.mkdir(exist_ok=True)

    # --- Fashion-MNIST label index -> human readable name ---
    CLASS_NAMES = [
        'T-shirt/top',
        'Trouser',
        'Pullover',
        'Dress',
        'Coat',
        'Sandal',
        'Shirt',
        'Sneaker',
        'Bag',
        'Ankle boot',
    ]

# Set random seeds for reproducibility
def set_seed(seed: int = 42):
    """Seed every random number generator used by the notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), then switches cuDNN to deterministic mode.

    Args:
        seed: Value applied to each generator.
    """
    # Python's own generator was previously left unseeded, so anything drawing
    # from ``random`` differed between runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Trade cuDNN autotuning speed for run-to-run determinism
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Seed the random number generators before any data or model is created
set_seed(42)

# Device configuration: use the GPU when CUDA is available, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if torch.cuda.is_available():
    # Report the name of CUDA device 0 and the CUDA version PyTorch reports
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")

# Configuration instance passed to the data-loading and training functions below
config = Config()


Using device: cpu


In [2]:
# ============================================================================
# 2. DATASET LOADER WITH PROPER NORMALIZATION
# ============================================================================
class FashionMNISTDataset(Dataset):
    """
    Fashion-MNIST dataset read from a CSV file.

    The first CSV column holds the label and the remaining columns hold the
    pixels. Pixels are scaled from [0, 255] to [0, 1] once at load time and
    kept as a float32 array of shape (N, 28, 28).

    Args:
        csv_file: Path to the CSV file (a ``str`` or ``pathlib.Path``). If the
            file is missing, a file with the same name under
            ``/kaggle/input/fashion-mnist`` is tried before giving up.
        transform: Optional callable applied to each (28, 28) float32 NumPy
            image. When given it is fully responsible for tensor conversion
            and normalization, and ``normalize`` is ignored.
        normalize: Only used when ``transform`` is None: whether to standardize
            the tensor with mean 0.2860 / std 0.3530 (default: True).

    Raises:
        FileNotFoundError: If neither ``csv_file`` nor the Kaggle fallback exists.
    """
    def __init__(
        self,
        csv_file: str,
        transform: Optional[transforms.Compose] = None,
        normalize: bool = True
    ):
        try:
            data = pd.read_csv(csv_file)
        except FileNotFoundError:
            # The requested file is missing: fall back to the Kaggle mount.
            # (The previous extra condition was always true at this point, so
            # its "else: raise" branch could never run.)
            fallback = Path("/kaggle/input/fashion-mnist") / Path(csv_file).name
            if not fallback.exists():
                raise FileNotFoundError(f"Could not find CSV file: {csv_file}")
            data = pd.read_csv(fallback)

        self.labels = torch.tensor(data.iloc[:, 0].values, dtype=torch.long)
        # Convert pixel values from [0, 255] to [0, 1] and reshape to images
        images = data.iloc[:, 1:].values.astype("float32") / 255.0
        self.images = images.reshape(-1, 28, 28)
        self.transform = transform
        self.normalize = normalize
        # Built once here instead of on every __getitem__ call
        self._default_normalize = transforms.Normalize(mean=[0.2860], std=[0.3530])

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.labels)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(image, label)`` for sample ``idx``."""
        # Copy so that transforms can never modify the stored array
        image = self.images[idx].copy()  # (28, 28), float32, range [0, 1]

        if self.transform is not None:
            # The transform handles tensor conversion and normalization
            image = self.transform(image)
        else:
            image = torch.from_numpy(image)
            # Add channel dimension: (28, 28) -> (1, 28, 28)
            if image.dim() == 2:
                image = image.unsqueeze(0)
            if self.normalize:
                image = self._default_normalize(image)

        return image, self.labels[idx]


# ============================================================================
# 3. DATA AUGMENTATION AND TRANSFORMS
# ============================================================================
class ToPILImageWrapper:
    """Convert a float NumPy image in [0, 1] to an 8-bit grayscale PIL image.

    Non-array inputs are returned untouched. Arrays that are neither 2-D nor
    single-channel 3-D (1, H, W) are returned as the converted uint8 array.
    """
    def __call__(self, img):
        if not isinstance(img, np.ndarray):
            return img

        # Round to the nearest integer and clip before casting. A plain
        # astype(uint8) truncates, so a value such as 254.9999 (a float32
        # round-trip of 255) became 254, and out-of-range values wrapped around.
        img = np.clip(np.rint(img * 255), 0, 255).astype(np.uint8)

        if img.ndim == 3 and img.shape[0] == 1:
            img = img.squeeze(0)
        if img.ndim == 2:
            # Imported only when a PIL image is actually produced
            from PIL import Image
            return Image.fromarray(img, mode='L')
        return img

def get_transforms(augment: bool = False):
    """
    Build the preprocessing pipeline for one image.

    Args:
        augment: If True, return the training pipeline with random rotation
            and random affine jitter. If False, return a deterministic
            function that only converts to a (C, H, W) float tensor and
            normalizes.

    Returns:
        A ``transforms.Compose`` (augment=True) or a plain function
        (augment=False); both map a single image to a normalized tensor.
    """
    # One Normalize instance shared by both pipelines. It used to be rebuilt
    # for every sample inside the evaluation transform.
    normalize = transforms.Normalize(mean=[0.2860], std=[0.3530])

    if augment:
        # Training transforms with augmentation
        # Note: Horizontal flip is left out as it may not be appropriate for clothing items
        return transforms.Compose([
            ToPILImageWrapper(),
            transforms.RandomRotation(degrees=10),
            transforms.RandomAffine(degrees=0, translate=(0.1, 0.1), scale=(0.9, 1.1)),
            transforms.ToTensor(),
            normalize,
        ])

    # Validation/test transforms (no augmentation)
    def transform_func(img):
        if isinstance(img, np.ndarray):
            img = torch.from_numpy(img.copy()).float()
        elif isinstance(img, torch.Tensor):
            img = img.clone().float()
        # Ensure a leading channel dimension exists
        if img.dim() == 2:
            img = img.unsqueeze(0)
        elif img.dim() == 3 and img.shape[0] != 1:
            # Treated as (H, W, C); convert to (C, H, W)
            img = img.permute(2, 0, 1)
        return normalize(img)

    return transform_func


# ============================================================================
# 4. LOAD AND PREPARE DATA
# ============================================================================
def load_data(config: Config) -> Tuple[DataLoader, DataLoader, DataLoader]:
    """
    Load the CSV files and build train/validation/test data loaders.

    Paths are resolved in this order: the configured path, then the Kaggle
    mount. If no training CSV exists at either place, the test CSV is used for
    training as well (a warning is printed).

    The training and validation subsets come from one seeded ``random_split``.
    The training subset is augmented; the validation subset reads the same
    underlying samples through the deterministic evaluation transform.

    Args:
        config: Provides TRAIN_CSV, TEST_CSV, VALIDATION_SPLIT, BATCH_SIZE,
            NUM_WORKERS and PIN_MEMORY.

    Returns:
        Tuple of (train_loader, val_loader, test_loader)

    Raises:
        FileNotFoundError: Propagated from ``FashionMNISTDataset`` when a CSV
            cannot be found.
    """
    kaggle_dir = Path("/kaggle/input/fashion-mnist")
    kaggle_train = kaggle_dir / "fashion-mnist_train.csv"
    kaggle_test = kaggle_dir / "fashion-mnist_test.csv"

    # Explicit existence checks replace the former bare ``except:`` blocks
    test_path = config.TEST_CSV if config.TEST_CSV.exists() else kaggle_test

    if config.TRAIN_CSV.exists():
        train_path = config.TRAIN_CSV
    elif kaggle_train.exists():
        train_path = kaggle_train
    else:
        # NOTE: with this fallback the model trains on the test samples, so
        # the reported test accuracy is not a held-out estimate.
        print("Warning: Train file not found. Using test file for training.")
        train_path = test_path

    train_full = FashionMNISTDataset(
        train_path,
        transform=get_transforms(augment=True),
        normalize=False  # Normalization handled in transforms
    )
    # A shallow copy shares the image/label storage but carries its own
    # transform. The validation subset is drawn from this copy; previously it
    # inherited the augmenting transform, so validation metrics were computed
    # on randomly distorted images.
    train_full_eval = copy.copy(train_full)
    train_full_eval.transform = get_transforms(augment=False)

    test_dataset = FashionMNISTDataset(
        test_path,
        transform=get_transforms(augment=False),
        normalize=False
    )

    # Split training data into train and validation (fixed seed: same split each run)
    val_size = int(len(train_full) * config.VALIDATION_SPLIT)
    train_size = len(train_full) - val_size
    train_dataset, val_split = random_split(
        train_full,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)
    )
    # Same indices as the split above, read through the non-augmenting view
    val_dataset = torch.utils.data.Subset(train_full_eval, val_split.indices)

    # Settings common to all three loaders
    loader_kwargs = dict(
        batch_size=config.BATCH_SIZE,
        num_workers=config.NUM_WORKERS,
        pin_memory=config.PIN_MEMORY
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)
    test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    return train_loader, val_loader, test_loader

# Build the three data loaders once; the visualization, training and
# evaluation cells below all reuse them.
train_loader, val_loader, test_loader = load_data(config)




FileNotFoundError: Could not find CSV file: /kaggle/input/fashion-mnist/fashion-mnist_test.csv

In [None]:
# ============================================================================
# 5. IMPROVED CNN MODEL ARCHITECTURE
# ============================================================================
class ImprovedCNN(nn.Module):
    """
    CNN for Fashion-MNIST classification (1x28x28 input, 10 output logits).

    Architecture:
    - Three convolutional blocks (32, 64 and 128 channels), each ending in
      2x2 max pooling and channel dropout
    - Batch normalization after every conv layer and after the first FC layer
    - Two fully connected layers (1152 -> 256 -> 10)

    Args:
        dropout_rate: Dropout probability before the final linear layer.
        conv_dropout_rate: Dropout2d probability after each conv block.
    """
    def __init__(self, dropout_rate: float = 0.5, conv_dropout_rate: float = 0.25):
        super().__init__()

        # First conv block: 28x28 -> 14x14
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.conv2 = nn.Conv2d(32, 32, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout2d(conv_dropout_rate)

        # Second conv block: 14x14 -> 7x7
        self.conv3 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(64)
        self.conv4 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.dropout2 = nn.Dropout2d(conv_dropout_rate)

        # Third conv block: 7x7 -> 3x3
        self.conv5 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.bn5 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.dropout3 = nn.Dropout2d(conv_dropout_rate)

        # Fully connected layers
        self.fc1 = nn.Linear(128 * 3 * 3, 256)
        self.bn_fc = nn.BatchNorm1d(256)
        self.dropout_fc = nn.Dropout(dropout_rate)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a batch of shape (N, 1, 28, 28) to logits of shape (N, 10)."""
        # First block
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.dropout1(self.pool1(x))

        # Second block
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.dropout2(self.pool2(x))

        # Third block
        x = F.relu(self.bn5(self.conv5(x)))
        x = self.dropout3(self.pool3(x))

        # Flatten while keeping the batch dimension. view(-1, 1152) could
        # silently merge samples when the spatial size was not 3x3; now a
        # wrong input size fails with a shape error in fc1 instead.
        x = torch.flatten(x, 1)

        # Fully connected layers
        x = F.relu(self.bn_fc(self.fc1(x)))
        x = self.dropout_fc(x)
        return self.fc2(x)

    def get_model_size(self) -> int:
        """Return the number of trainable parameters"""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# ============================================================================
# 6. VISUALIZATION FUNCTIONS
# ============================================================================
def show_sample_images(
    dataloader: DataLoader, 
    class_names: List[str], 
    count: int = 9
) -> None:
    """Display up to ``count`` images from the first batch of ``dataloader``.

    The grid is sized from the number of images actually shown (3x3 for the
    default of 9), so any ``count`` works and no empty axes frames are drawn.

    Args:
        dataloader: Yields (images, labels) batches of normalized images.
        class_names: Label index -> display name.
        count: Maximum number of images to show.
    """
    images, labels = next(iter(dataloader))

    n_shown = min(count, len(images))
    if n_shown <= 0:
        return

    # Near-square grid; each cell keeps the size it had in the 3x3 / 10in layout
    n_cols = int(np.ceil(np.sqrt(n_shown)))
    n_rows = int(np.ceil(n_shown / n_cols))
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(10 * n_cols / 3, 10 * n_rows / 3),
        squeeze=False
    )
    axes = axes.ravel()

    for i in range(n_shown):
        img = images[i].squeeze().cpu().numpy()
        # Denormalize for display
        img = np.clip(img * 0.3530 + 0.2860, 0, 1)

        label_idx = labels[i].item()
        axes[i].imshow(img, cmap='gray')
        axes[i].set_title(f"Label: {label_idx} - {class_names[label_idx]}")

    # Hide ticks/frames everywhere, including cells left empty
    for ax in axes:
        ax.axis('off')

    plt.suptitle("Sample Images from Dataset", fontsize=16)
    plt.tight_layout()
    plt.show()


def plot_class_distribution(dataloader: DataLoader, class_names: List[str]) -> None:
    """Plot the label distribution of ``dataloader`` as a pie and a bar chart.

    Iterates the whole loader once and keeps only the labels.

    Args:
        dataloader: Yields (images, labels) batches; labels are integer class indices.
        class_names: Label index -> display name; its length fixes the number of bins.
    """
    label_batches = [labels.numpy() for _, labels in dataloader]
    if label_batches:
        all_labels = np.concatenate(label_batches)
    else:
        all_labels = np.array([], dtype=np.int64)

    # minlength keeps one count per class even when the highest-numbered
    # classes are absent; otherwise counts and class_names differ in length
    # and both charts fail.
    label_counts = np.bincount(all_labels, minlength=len(class_names))

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

    # Pie chart
    ax1.pie(label_counts, labels=class_names, autopct='%1.1f%%', startangle=90)
    ax1.set_title("Class Distribution (Pie Chart)")

    # Bar chart
    positions = range(len(class_names))
    ax2.bar(positions, label_counts)
    ax2.set_xlabel("Class")
    ax2.set_ylabel("Count")
    ax2.set_title("Class Distribution (Bar Chart)")
    ax2.set_xticks(positions)
    ax2.set_xticklabels(class_names, rotation=45, ha='right')

    plt.tight_layout()
    plt.show()


# Show sample images and the class distribution of the training loader.
# The training loader applies random augmentation, so the displayed images
# are augmented versions of the stored samples.
print("Sample Images from Training Set:")
show_sample_images(train_loader, config.CLASS_NAMES, count=9)

# Counting labels iterates the full training loader once
print("\nClass Distribution:")
plot_class_distribution(train_loader, config.CLASS_NAMES)

Model saved to fashion_mnist_cnn.pth


In [None]:
# ============================================================================
# 7. TRAINING AND EVALUATION FUNCTIONS
# ============================================================================
def evaluate_model(
    model: nn.Module, 
    dataloader: DataLoader, 
    device: torch.device,
    return_predictions: bool = False
) -> Tuple[float, Optional[np.ndarray], Optional[np.ndarray]]:
    """
    Evaluate model accuracy on a dataloader.

    Puts the model in eval mode and runs without gradient tracking.

    Args:
        model: Network producing per-class scores of shape (N, num_classes).
        dataloader: Iterable of (inputs, labels) batches.
        device: Device the batches are moved to before the forward pass.
        return_predictions: Also collect predicted and true labels.

    Returns:
        Tuple of (accuracy, predictions, true_labels) if return_predictions=True,
        else (accuracy, None, None). Accuracy is a percentage in [0, 100];
        it is 0.0 when the dataloader yields no samples.
    """
    model.eval()
    correct = 0
    total = 0
    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            preds = model(inputs).argmax(dim=1)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

            if return_predictions:
                pred_chunks.append(preds.cpu().numpy())
                label_chunks.append(labels.cpu().numpy())

    # Guard against an empty loader instead of raising ZeroDivisionError
    accuracy = 100.0 * correct / total if total else 0.0

    if return_predictions:
        all_preds = np.concatenate(pred_chunks) if pred_chunks else np.array([])
        all_labels = np.concatenate(label_chunks) if label_chunks else np.array([])
        return accuracy, all_preds, all_labels
    return accuracy, None, None


class EarlyStopping:
    """Signal when a monitored metric has stopped improving.

    Call the instance once per epoch with the latest score. An epoch counts as
    an improvement only if the score beats the best score so far by at least
    ``min_delta``. After ``patience`` consecutive epochs without improvement
    the instance returns True, and keeps returning True on later calls.

    Args:
        patience: Number of consecutive non-improving epochs tolerated.
        min_delta: Minimum change that counts as an improvement.
        mode: 'max' if higher scores are better (e.g. accuracy),
            'min' if lower scores are better (e.g. loss).

    Raises:
        ValueError: If ``mode`` is neither 'max' nor 'min'.
    """
    def __init__(self, patience: int = 10, min_delta: float = 0.001, mode: str = 'max'):
        # Any unknown mode used to be treated silently as 'min'
        if mode not in ('max', 'min'):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def __call__(self, score: float) -> bool:
        """Record ``score`` and return True once training should stop."""
        if self.best_score is None:
            # The first score only establishes the baseline
            self.best_score = score
            return self.early_stop

        if self.mode == 'max':
            improved = score >= self.best_score + self.min_delta
        else:  # mode == 'min'
            improved = score <= self.best_score - self.min_delta

        if improved:
            self.best_score = score
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True

        return self.early_stop


def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    optimizer: torch.optim.Optimizer,
    device: torch.device
) -> Tuple[float, float]:
    """Run one optimization pass over ``dataloader``.

    Returns:
        (loss, accuracy): the mean of the per-batch losses and the
        percentage of correctly classified samples seen during the pass.
    """
    model.train()

    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    for batch_images, batch_labels in dataloader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)

        # Standard step: clear gradients, forward, backward, update
        optimizer.zero_grad()
        logits = model(batch_images)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Statistics use the logits computed before this step's update
        loss_sum += batch_loss.item()
        predicted = logits.max(dim=1).indices
        n_correct += (predicted == batch_labels).sum().item()
        n_seen += batch_labels.size(0)

    return loss_sum / len(dataloader), 100.0 * n_correct / n_seen


def _validate(
    model: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    device: torch.device
) -> Tuple[float, float]:
    """Return (mean per-batch loss, accuracy in %) from one pass over ``dataloader``."""
    model.eval()
    loss_sum = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss_sum += criterion(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return loss_sum / len(dataloader), 100.0 * correct / total


def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    config: Config,
    device: torch.device
) -> Dict[str, List[float]]:
    """
    Train the model with early stopping and learning rate scheduling.

    Uses AdamW with cross-entropy loss. Validation accuracy drives the
    ReduceLROnPlateau scheduler, early stopping and best-checkpoint
    selection. When training ends, the weights of the best epoch are loaded
    back into ``model``.

    Args:
        model: Network to train; must provide ``get_model_size()``.
        train_loader: Batches used for optimization.
        val_loader: Batches used for validation after every epoch.
        config: Provides LEARNING_RATE, WEIGHT_DECAY, EPOCHS,
            EARLY_STOPPING_PATIENCE, MIN_DELTA, SAVE_BEST_MODEL, MODEL_SAVE_PATH.
        device: Device batches are moved to.

    Returns:
        Dictionary with per-epoch lists under the keys 'train_loss',
        'train_acc', 'val_loss', 'val_acc' and 'learning_rate'.
    """
    # Initialize optimizer and criterion
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.LEARNING_RATE,
        weight_decay=config.WEIGHT_DECAY
    )

    # Halve the LR after 5 epochs without a validation-accuracy gain.
    # The `verbose` flag is deprecated in newer PyTorch releases; the LR is
    # already printed every epoch in the table below.
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='max',
        factor=0.5,
        patience=5
    )

    # Early stopping
    early_stopping = EarlyStopping(
        patience=config.EARLY_STOPPING_PATIENCE,
        min_delta=config.MIN_DELTA,
        mode='max'
    )

    # Training history
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': [],
        'learning_rate': []
    }

    best_val_acc = 0.0
    best_model_state = None

    print(f"Training model with {model.get_model_size():,} parameters")
    print(f"{'Epoch':<8} {'Train Loss':<12} {'Train Acc':<12} {'Val Loss':<12} {'Val Acc':<12} {'LR':<10}")
    print("-" * 70)

    for epoch in range(config.EPOCHS):
        # Train
        train_loss, train_acc = train_epoch(
            model, train_loader, criterion, optimizer, device
        )

        # Validate: loss and accuracy from a single pass (previously two passes)
        val_loss, val_acc = _validate(model, val_loader, criterion, device)

        # Update learning rate
        scheduler.step(val_acc)
        current_lr = optimizer.param_groups[0]['lr']

        # Save history
        history['train_loss'].append(train_loss)
        history['train_acc'].append(train_acc)
        history['val_loss'].append(val_loss)
        history['val_acc'].append(val_acc)
        history['learning_rate'].append(current_lr)

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Deep copy: state_dict() holds references to the live parameter
            # tensors, so the former shallow dict.copy() kept tracking later
            # updates and the "best" state was really the last one.
            best_model_state = copy.deepcopy(model.state_dict())
            if config.SAVE_BEST_MODEL:
                torch.save(
                    best_model_state,
                    config.MODEL_SAVE_PATH / "best_model.pth"
                )

        # Print progress
        print(f"{epoch+1:<8} {train_loss:<12.4f} {train_acc:<12.2f} {val_loss:<12.4f} {val_acc:<12.2f} {current_lr:<10.6f}")

        # Early stopping
        if early_stopping(val_acc):
            print(f"\nEarly stopping triggered at epoch {epoch+1}")
            print(f"Best validation accuracy: {best_val_acc:.2f}%")
            break

    # Load best model
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print(f"\nLoaded best model with validation accuracy: {best_val_acc:.2f}%")

    return history


# ============================================================================
# 8. INITIALIZE AND TRAIN MODEL
# ============================================================================
# Build the network with the configured dropout rate and move it to the selected device
model = ImprovedCNN(dropout_rate=config.DROPOUT_RATE).to(device)
print(f"Model initialized with {model.get_model_size():,} trainable parameters\n")

# Train the model; `history` holds the per-epoch metrics used by the plots below
history = train_model(model, train_loader, val_loader, config, device)


In [None]:
# ============================================================================
# 9. VISUALIZE TRAINING HISTORY
# ============================================================================
def plot_training_history(history: Dict[str, List[float]]) -> None:
    """Draw a 2x2 dashboard of the training run.

    Panels: loss curves, accuracy curves, learning rate (log scale) and the
    validation-minus-training accuracy gap.

    Args:
        history: Per-epoch lists under 'train_loss', 'val_loss', 'train_acc',
            'val_acc' and 'learning_rate'.
    """
    fig, axes = plt.subplots(2, 2, figsize=(15, 10))
    (loss_ax, acc_ax), (lr_ax, gap_ax) = axes

    epochs = range(1, len(history['train_loss']) + 1)
    title_style = dict(fontsize=14, fontweight='bold')

    # Top-left: loss curves
    loss_ax.plot(epochs, history['train_loss'], 'b-', label='Train Loss', linewidth=2)
    loss_ax.plot(epochs, history['val_loss'], 'r-', label='Validation Loss', linewidth=2)
    loss_ax.set_title('Training and Validation Loss', **title_style)
    loss_ax.set_ylabel('Loss')
    loss_ax.legend()

    # Top-right: accuracy curves
    acc_ax.plot(epochs, history['train_acc'], 'b-', label='Train Accuracy', linewidth=2)
    acc_ax.plot(epochs, history['val_acc'], 'r-', label='Validation Accuracy', linewidth=2)
    acc_ax.set_title('Training and Validation Accuracy', **title_style)
    acc_ax.set_ylabel('Accuracy (%)')
    acc_ax.legend()

    # Bottom-left: learning rate on a log axis
    lr_ax.plot(epochs, history['learning_rate'], 'g-', linewidth=2)
    lr_ax.set_title('Learning Rate Schedule', **title_style)
    lr_ax.set_ylabel('Learning Rate')
    lr_ax.set_yscale('log')

    # Bottom-right: validation minus training accuracy, with a zero reference line
    gap = [v_acc - t_acc for t_acc, v_acc in zip(history['train_acc'], history['val_acc'])]
    gap_ax.plot(epochs, gap, 'purple', linewidth=2)
    gap_ax.axhline(y=0, color='k', linestyle='--', alpha=0.5)
    gap_ax.set_title('Validation - Training Accuracy (Overfitting Indicator)', **title_style)
    gap_ax.set_ylabel('Accuracy Difference (%)')

    # Formatting shared by all four panels
    for panel in (loss_ax, acc_ax, lr_ax, gap_ax):
        panel.set_xlabel('Epoch')
        panel.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

# Plot the metrics recorded by train_model
plot_training_history(history)


In [None]:
# ============================================================================
# 10. EVALUATE ON TEST SET
# ============================================================================
# Accuracy plus per-sample predictions; the confusion-matrix cell reuses them
test_acc, test_preds, test_labels = evaluate_model(
    model, test_loader, device, return_predictions=True
)

# Results banner
print(
    f"\n{'='*70}",
    "Test Set Results",
    f"{'='*70}",
    f"Test Accuracy: {test_acc:.2f}%",
    f"{'='*70}\n",
    sep="\n",
)

# Per-class precision / recall / F1 with four decimals
print("Classification Report:")
print(
    classification_report(
        test_labels,
        test_preds,
        target_names=config.CLASS_NAMES,
        digits=4,
    )
)


In [None]:
# ============================================================================
# 11. CONFUSION MATRIX
# ============================================================================
def plot_confusion_matrix(
    y_true: np.ndarray,
    y_pred: np.ndarray,
    class_names: List[str]
) -> None:
    """Plot the confusion matrix as a heatmap and print per-class accuracy.

    Args:
        y_true: True class indices.
        y_pred: Predicted class indices.
        class_names: Label index -> display name. Its length fixes the matrix
            size, so classes missing from the data still get a row and column.
    """
    # Pin the label set so rows/columns always line up with class_names.
    # Without it the matrix only covers labels present in the data.
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(class_names))))

    plt.figure(figsize=(12, 10))
    sns.heatmap(
        cm, 
        annot=True, 
        fmt='d', 
        cmap='Blues',
        xticklabels=class_names,
        yticklabels=class_names,
        cbar_kws={'label': 'Count'}
    )
    plt.title('Confusion Matrix', fontsize=16, fontweight='bold', pad=20)
    plt.ylabel('True Label', fontsize=12)
    plt.xlabel('Predicted Label', fontsize=12)
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    plt.tight_layout()
    plt.show()

    # Per-class accuracy = diagonal / number of true samples of that class.
    # Classes without any true sample are reported as nan.
    support = cm.sum(axis=1)
    with np.errstate(divide='ignore', invalid='ignore'):
        class_accuracies = np.where(support > 0, cm.diagonal() / support * 100, np.nan)

    print("\nPer-Class Accuracy:")
    for class_name, acc in zip(class_names, class_accuracies):
        print(f"  {class_name:<20}: {acc:.2f}%")

# Uses the labels and predictions collected by the test-set evaluation cell
plot_confusion_matrix(test_labels, test_preds, config.CLASS_NAMES)


In [None]:
# ============================================================================
# 12. VISUALIZE PREDICTIONS
# ============================================================================
def show_predictions(
    model: nn.Module,
    dataloader: DataLoader,
    class_names: List[str],
    device: torch.device,
    count: int = 16,
    show_errors_only: bool = False
) -> None:
    """Visualize model predictions on images from ``dataloader``.

    Each image is titled with its true label, predicted label and softmax
    confidence; the title is green for correct and red for wrong predictions.
    The grid is sized from ``count`` (4x4 for the default of 16).

    Args:
        model: Trained network; switched to eval mode.
        dataloader: Yields (images, labels) batches of normalized images.
        class_names: Label index -> display name.
        device: Device batches are moved to.
        count: Maximum number of images to show.
        show_errors_only: If True, skip correctly classified samples.
    """
    if count <= 0:
        return

    model.eval()

    # Near-square grid with 3x3-inch cells (12x12 inches for 16 images)
    n_cols = int(np.ceil(np.sqrt(count)))
    n_rows = int(np.ceil(count / n_cols))
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(3 * n_cols, 3 * n_rows),
        squeeze=False
    )
    axes = axes.ravel()
    images_shown = 0

    with torch.no_grad():
        for images, labels in dataloader:
            if images_shown >= count:
                break

            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)
            probabilities = F.softmax(outputs, dim=1)

            for i in range(images.size(0)):
                if images_shown >= count:
                    break

                true_label = labels[i].item()
                pred_label = preds[i].item()

                # Skip if showing errors only and prediction is correct
                if show_errors_only and true_label == pred_label:
                    continue

                confidence = probabilities[i][pred_label].item() * 100

                img = images[i].cpu().squeeze().numpy()
                # Denormalize for display
                img = np.clip(img * 0.3530 + 0.2860, 0, 1)

                ax = axes[images_shown]
                ax.imshow(img, cmap='gray')
                color = 'green' if true_label == pred_label else 'red'
                title = f"True: {class_names[true_label]}\nPred: {class_names[pred_label]}\nConf: {confidence:.1f}%"
                ax.set_title(title, color=color, fontsize=9)

                images_shown += 1

    # Hide ticks/frames on every cell, including all unused ones
    for ax in axes:
        ax.axis('off')

    plt.suptitle('Model Predictions on Test Set', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()

# Show the first predictions regardless of correctness (green title = correct,
# red title = wrong). The former label said "Correct", but nothing is filtered
# when show_errors_only=False.
print("Sample Predictions:")
show_predictions(model, test_loader, config.CLASS_NAMES, device, count=16, show_errors_only=False)

# Show only misclassified samples, if any
print("\nIncorrect Predictions (if any):")
show_predictions(model, test_loader, config.CLASS_NAMES, device, count=16, show_errors_only=True)


In [None]:
# ============================================================================
# 13. SAVE MODEL AND FINAL SUMMARY
# ============================================================================
# Persist the weights currently held by the model
final_model_path = config.MODEL_SAVE_PATH / "fashion_mnist_cnn_final.pth"
torch.save(model.state_dict(), final_model_path)
print(f"Final model saved to: {final_model_path}")

# Persist the training history next to the checkpoints
import json
history_path = config.MODEL_SAVE_PATH / "training_history.json"
# Cast every entry to a built-in float so the JSON encoder accepts it
json_history = {k: [float(v) for v in vals] for k, vals in history.items()}
with open(history_path, 'w') as f:
    json.dump(json_history, f, indent=2)
print(f"Training history saved to: {history_path}")

# Final summary of the run
print(
    f"\n{'='*70}",
    "TRAINING SUMMARY",
    f"{'='*70}",
    f"Model Parameters: {model.get_model_size():,}",
    f"Best Validation Accuracy: {max(history['val_acc']):.2f}%",
    f"Final Test Accuracy: {test_acc:.2f}%",
    f"Total Epochs Trained: {len(history['train_loss'])}",
    f"Final Learning Rate: {history['learning_rate'][-1]:.6f}",
    f"{'='*70}\n",
    sep="\n",
)
