# Fine-tuning B-cos PIP-Net with Scoring-Sheet Classification

This notebook fine-tunes the pre-trained B-cos PIP-Net model for classification tasks using a scoring-sheet approach.

## Key Features:
- **Scoring-Sheet Classification**: Linear classifier where weights indicate prototype-class relevance
- **Negative Log-Likelihood Loss**: Standard classification loss with softmax activation
- **Prototype Purity Evaluation**: Measures consistency of prototype activations
- **Interpretable Results**: Class scores computed as weighted sum of prototype presences

## Mathematical Foundation:
- Class score for class j: `score_j = Σ_i (p_i * w_ij)`, where p_i is the presence of prototype i and w_ij is the relevance weight of prototype i for class j
- Final predictions: `softmax(scores)` to get class confidence scores
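
As a worked instance of the two formulas above, the following sketch computes class scores and softmax confidences for a toy scoring sheet in plain Python. The presences, the weights, and the helper names `class_scores` and `softmax` are illustrative assumptions, not values or functions from the repository.

```python
import math

def class_scores(presences, weights):
    """score_j = sum_i p_i * w_ij, where weights[i][j] is the relevance of prototype i to class j."""
    num_classes = len(weights[0])
    return [
        sum(p * row[j] for p, row in zip(presences, weights))
        for j in range(num_classes)
    ]

def softmax(scores):
    """Numerically stable softmax: subtract the maximum before exponentiating."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Toy example with 3 prototypes and 2 classes (made-up numbers).
presences = [0.9, 0.1, 0.5]
weights = [
    [2.0, 0.0],  # prototype 0 is evidence for class 0
    [0.0, 3.0],  # prototype 1 is evidence for class 1
    [1.0, 1.0],  # prototype 2 is shared by both classes
]

scores = class_scores(presences, weights)
probs = softmax(scores)
print("scores:", scores)
print("probabilities:", probs)
```

Because prototype 0 is strongly present and only counts toward class 0, class 0 receives the higher score and therefore the higher confidence.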

## 1. Setup and Installation

In [None]:
# Check GPU availability
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

In [None]:
# Install required packages
!pip install torch torchvision torchaudio
!pip install tqdm tensorboard
!pip install matplotlib seaborn
!pip install scikit-learn
!pip install pillow pandas

In [None]:
# Clone the repository (if not already done)
import os
if not os.path.exists('improved-Bcos-PIPNet'):
    !git clone https://github.com/your-username/improved-Bcos-PIPNet.git
%cd improved-Bcos-PIPNet

# Initialize submodules
!git submodule update --init --recursive

In [None]:
import sys
import os

# Add source directories to Python path
sys.path.append('src')
sys.path.append('B-cos')
sys.path.append('PIPNet')

# Verify paths
print("Current directory:", os.getcwd())
print("Fine-tuning script exists:", os.path.exists('src/train_finetune.py'))
print("Classifier module exists:", os.path.exists('src/finetune_classifier.py'))

## 2. Upload Pre-trained Model

Upload your pre-trained B-cos PIP-Net checkpoint from the previous pre-training step.

In [None]:
from google.colab import files
import os

# Create checkpoints directory
os.makedirs('./checkpoints/pretrained', exist_ok=True)

print("Please upload your pre-trained B-cos PIP-Net checkpoint (.pth file):")
uploaded = files.upload()

# Move uploaded file to checkpoints directory
for filename in uploaded.keys():
    if filename.endswith('.pth'):
        os.rename(filename, f'./checkpoints/pretrained/{filename}')
        pretrained_path = f'./checkpoints/pretrained/{filename}'
        print(f"Pre-trained model saved to: {pretrained_path}")
        break
else:
    # Use a default path if no file uploaded (for testing)
    pretrained_path = './checkpoints/pretrained/final_model.pth'
    print(f"No pre-trained model uploaded. Will use: {pretrained_path}")
    print("Note: Make sure to upload your pre-trained checkpoint for actual fine-tuning.")

## 3. Import Modules and Setup Model

In [None]:
# Import fine-tuning modules
from src.finetune_classifier import (
    create_scoring_sheet_classifier,
    FineTuningLoss,
    evaluate_model,
    compute_prototype_purity
)
from src.datasets import SixChannelDataset

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
import numpy as np
import json

print("All modules imported successfully!")

In [None]:
# Test loading pre-trained model (if checkpoint exists)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

if os.path.exists(pretrained_path):
    print(f"Loading pre-trained model from: {pretrained_path}")
    
    # Load checkpoint to check contents
    checkpoint = torch.load(pretrained_path, map_location='cpu')
    print(f"Checkpoint keys: {list(checkpoint.keys())}")
    
    if 'args' in checkpoint:
        pretrained_args = checkpoint['args']
        print(f"Pre-trained model config:")
        print(f"  Backbone: {pretrained_args.get('backbone', 'unknown')}")
        print(f"  Prototypes: {pretrained_args.get('num_prototypes', 'unknown')}")
        print(f"  Dataset: {pretrained_args.get('dataset', 'unknown')}")
    
    # Test model creation for CIFAR-10 (10 classes)
    try:
        test_model = create_scoring_sheet_classifier(
            pretrained_path=pretrained_path,
            num_classes=10,  # CIFAR-10
            freeze_prototypes=True
        )
        print(f"✓ Successfully created classifier model")
        print(f"  Total parameters: {sum(p.numel() for p in test_model.parameters()):,}")
        print(f"  Trainable parameters: {sum(p.numel() for p in test_model.parameters() if p.requires_grad):,}")
        del test_model  # Free memory
    except Exception as e:
        print(f"✗ Error creating model: {e}")
else:
    print(f"Pre-trained model not found at: {pretrained_path}")
    print("Please upload your pre-trained checkpoint first.")

## 4. Dataset Setup for Fine-tuning

In [None]:
# Configuration for fine-tuning
class FineTuningConfig:
    # Dataset
    dataset = 'cifar10'  # Change to 'cub' for CUB-200-2011
    data_dir = './data'
    batch_size = 64  # Reduced for Colab
    num_workers = 2
    img_size = 224  # For CUB dataset
    
    # Model
    freeze_prototypes = True  # Freeze prototype learning components
    
    # Training
    epochs = 30  # Reduced for demo
    lr = 1e-4  # Learning rate for classifier
    lr_backbone = 1e-5  # Learning rate for backbone (if not frozen)
    weight_decay = 1e-4
    warmup_epochs = 3
    
    # Loss weights
    nll_weight = 1.0
    l1_weight = 0.0001  # L1 regularization on classifier weights
    orthogonal_weight = 0.0  # Orthogonal regularization
    
    # Device and logging
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    log_interval = 20
    eval_interval = 5
    
    # Directories
    log_dir = './logs/colab_finetune'
    save_dir = './checkpoints/colab_finetune'
    
    # Other
    seed = 42

args = FineTuningConfig()
print(f"Fine-tuning configuration:")
print(f"  Dataset: {args.dataset}")
print(f"  Device: {args.device}")
print(f"  Epochs: {args.epochs}")
print(f"  Batch size: {args.batch_size}")
print(f"  Freeze prototypes: {args.freeze_prototypes}")
print(f"  Learning rate: {args.lr}")

In [None]:
# Create dataloaders for CIFAR-10
print("Setting up CIFAR-10 dataset for fine-tuning...")

# Data transforms with augmentation for training
train_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomCrop(32, padding=4),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], 
                       std=[0.2023, 0.1994, 0.2010])
])

test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], 
                       std=[0.2023, 0.1994, 0.2010])
])

# Load CIFAR-10 datasets
train_dataset_base = torchvision.datasets.CIFAR10(
    root=args.data_dir, train=True, download=True, transform=train_transform
)
test_dataset_base = torchvision.datasets.CIFAR10(
    root=args.data_dir, train=False, download=True, transform=test_transform
)

# Wrap with 6-channel transformation
train_dataset = SixChannelDataset(train_dataset_base)
test_dataset = SixChannelDataset(test_dataset_base)

# Create dataloaders
train_loader = DataLoader(
    train_dataset, batch_size=args.batch_size, shuffle=True,
    num_workers=args.num_workers, pin_memory=True
)
test_loader = DataLoader(
    test_dataset, batch_size=args.batch_size, shuffle=False,
    num_workers=args.num_workers, pin_memory=True
)

num_classes = 10  # CIFAR-10
class_names = ['airplane', 'automobile', 'bird', 'cat', 'deer', 
               'dog', 'frog', 'horse', 'ship', 'truck']

print(f"Train samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Number of classes: {num_classes}")
print(f"Classes: {class_names}")

# Test data loading
sample_input, sample_label = next(iter(train_loader))
print(f"Sample input shape: {sample_input.shape} (6-channel)")
print(f"Sample label shape: {sample_label.shape}")

In [None]:
# Create directories
os.makedirs(args.log_dir, exist_ok=True)
os.makedirs(args.save_dir, exist_ok=True)
print(f"Created directories: {args.log_dir}, {args.save_dir}")

## 5. Model Setup and Architecture

In [None]:
# Set seed for reproducibility
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
np.random.seed(args.seed)

# Create the fine-tuning model
print("Creating scoring-sheet classifier...")
model = create_scoring_sheet_classifier(
    pretrained_path=pretrained_path,
    num_classes=num_classes,
    freeze_prototypes=args.freeze_prototypes
).to(args.device)

print(f"Model created successfully!")
print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")
print(f"Number of prototypes: {model.num_prototypes}")
print(f"Number of classes: {model.num_classes}")

# Test forward pass
with torch.no_grad():
    test_input = torch.randn(2, 6, 32, 32).to(args.device)
    proto_features, pooled_features, logits, class_scores = model(test_input)
    
    print(f"\nForward pass test:")
    print(f"  Prototype features: {proto_features.shape}")
    print(f"  Pooled features: {pooled_features.shape}")
    print(f"  Logits: {logits.shape}")
    print(f"  Class scores: {class_scores.shape}")
    print(f"  Class scores sum: {class_scores.sum(dim=1)}")

In [None]:
# Visualize the scoring sheet concept
print("Scoring-Sheet Classification Explanation:")
print("==========================================")
print("1. Input image → 6-channel encoding [r,g,b,1-r,1-g,1-b]")
print("2. B-cos backbone → Extract interpretable features")
print("3. Prototype layer → Generate prototype activations p_i")
print("4. Global pooling → Get prototype presence scores")
print("5. Linear classifier → Compute class scores: score_j = Σ(p_i * w_ij)")
print("6. Softmax → Convert to class probabilities")
print("\nWhere:")
print("  - p_i: presence score of prototype i")
print("  - w_ij: learned relevance weight of prototype i to class j")
print("  - The classifier weights form a 'scoring sheet' showing prototype-class relationships")

# Show current classifier weights (before training)
classifier_weights = model.get_prototype_class_relevance()
print(f"\nClassifier weight matrix shape: {classifier_weights.shape}")
print(f"(rows = prototypes, columns = classes)")
print(f"Weight statistics:")
print(f"  Mean: {classifier_weights.mean().item():.4f}")
print(f"  Std: {classifier_weights.std().item():.4f}")
print(f"  Min: {classifier_weights.min().item():.4f}")
print(f"  Max: {classifier_weights.max().item():.4f}")
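
Step 1 of the explanation above appends the complement of each colour channel. A minimal sketch for a single pixel follows. It assumes channel values already lie in [0, 1], and `to_six_channels` is a hypothetical helper, not the `SixChannelDataset` implementation from `src/datasets.py`.

```python
def to_six_channels(rgb):
    """Map (r, g, b) with values in [0, 1] to [r, g, b, 1-r, 1-g, 1-b]."""
    if not all(0.0 <= c <= 1.0 for c in rgb):
        raise ValueError("channel values must lie in [0, 1]")
    return list(rgb) + [1.0 - c for c in rgb]

pixel = (0.25, 0.5, 1.0)
encoded = to_six_channels(pixel)
print(encoded)
```

Each channel and its complement sum to 1. Values that have been mean/std-normalised fall outside [0, 1], so it is worth checking in the wrapper's source whether it expects inputs before or after `transforms.Normalize`.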

## 6. Training Setup

In [None]:
# Create loss function
criterion = FineTuningLoss(
    nll_weight=args.nll_weight,
    l1_weight=args.l1_weight,
    orthogonal_weight=args.orthogonal_weight
)

# Create optimizer
if args.freeze_prototypes:
    # Only optimize classifier
    optimizer = optim.AdamW(model.classifier.parameters(), 
                          lr=args.lr, weight_decay=args.weight_decay)
    print("Optimizer: Only training classifier (prototypes frozen)")
else:
    # Different learning rates for classifier and backbone
    optimizer = optim.AdamW([
        {'params': model.classifier.parameters(), 'lr': args.lr},
        {'params': list(model.backbone.parameters()) + list(model.prototype_layer.parameters()), 
         'lr': args.lr_backbone}
    ], weight_decay=args.weight_decay)
    print("Optimizer: Training classifier + backbone with different learning rates")

print(f"Loss function: NLL + L1({args.l1_weight}) + Orthogonal({args.orthogonal_weight})")
print(f"Learning rate: {args.lr}")
print(f"Weight decay: {args.weight_decay}")
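
The objective printed above combines a negative log-likelihood term with an L1 penalty on the classifier weights. The sketch below shows that combination for one sample in plain Python. It omits the orthogonal term (weight 0.0 in the configuration), sums absolute weights as an assumption, and uses `nll_plus_l1` as a hypothetical stand-in rather than the actual `FineTuningLoss` code.

```python
import math

def log_softmax(logits):
    """Stable log-softmax via the log-sum-exp trick."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def nll_plus_l1(logits, target, weights, nll_weight=1.0, l1_weight=1e-4):
    """Weighted sum of NLL for the target class and the L1 norm of the weights."""
    nll = -log_softmax(logits)[target]
    l1 = sum(abs(w) for row in weights for w in row)
    return {"nll": nll, "l1": l1, "total_loss": nll_weight * nll + l1_weight * l1}

logits = [2.3, 0.8]                              # class scores for one sample
weights = [[2.0, 0.0], [0.0, 3.0], [1.0, 1.0]]   # toy scoring sheet
loss = nll_plus_l1(logits, target=0, weights=weights)
print(loss)
```

The L1 term pushes small relevance weights toward zero, which tends to leave each class with fewer contributing prototypes in the scoring sheet.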

In [None]:
def adjust_learning_rate(optimizer, epoch, args):
    """Cosine learning rate schedule with warmup"""
    if epoch < args.warmup_epochs:
        lr_mult = (epoch + 1) / args.warmup_epochs  # start above zero so epoch 0 still trains
    else:
        lr_mult = 0.5 * (1. + np.cos(np.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
    
    for i, param_group in enumerate(optimizer.param_groups):
        if i == 0:  # Classifier parameters
            param_group['lr'] = args.lr * lr_mult
        else:  # Backbone parameters (if not frozen)
            param_group['lr'] = args.lr_backbone * lr_mult
    
    return args.lr * lr_mult
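
To see what a warmup-plus-cosine schedule looks like numerically, the standalone sketch below tabulates the learning rate for a few epochs. It ramps with `(epoch + 1) / warmup_epochs` so that the first epoch already uses a nonzero rate. That choice and the helper name `lr_multiplier` are assumptions of the sketch.

```python
import math

def lr_multiplier(epoch, warmup_epochs, total_epochs):
    """Linear warmup followed by cosine decay; returns a factor in (0, 1]."""
    if epoch < warmup_epochs:
        return (epoch + 1) / warmup_epochs
    progress = (epoch - warmup_epochs) / (total_epochs - warmup_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * progress))

base_lr = 1e-4
schedule = [
    base_lr * lr_multiplier(e, warmup_epochs=3, total_epochs=30)
    for e in range(30)
]

for e in (0, 1, 2, 3, 16, 29):
    print(f"epoch {e:2d}: lr = {schedule[e]:.2e}")
```

The rate reaches `base_lr` at the end of warmup and then decays smoothly toward zero over the remaining epochs.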

## 7. Fine-tuning Training Loop

In [None]:
# Training loop
print("Starting fine-tuning...")

# Lists to store metrics
train_losses = []
train_accuracies = []
test_accuracies = []
prototype_purities = []
learning_rates = []

best_accuracy = 0.0

for epoch in range(args.epochs):
    # Adjust learning rate
    lr = adjust_learning_rate(optimizer, epoch, args)
    learning_rates.append(lr)
    
    # Training phase
    model.train()
    
    # Freeze prototype components if specified
    if args.freeze_prototypes:
        model.backbone.eval()
        model.prototype_layer.eval()
    
    epoch_loss = 0.0
    epoch_correct = 0
    epoch_total = 0
    
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{args.epochs}')
    
    for batch_idx, (inputs, targets) in enumerate(pbar):
        inputs, targets = inputs.to(args.device), targets.to(args.device)
        
        optimizer.zero_grad()
        
        # Forward pass
        proto_features, pooled_features, logits, class_scores = model(inputs)
        
        # Compute loss
        loss_dict = criterion(logits, targets, model.classifier.weight)
        loss = loss_dict['total_loss']
        
        # Backward pass
        loss.backward()
        optimizer.step()
        
        # Statistics
        epoch_loss += loss.item()
        predicted = torch.argmax(class_scores, dim=1)
        epoch_total += targets.size(0)
        epoch_correct += (predicted == targets).sum().item()
        
        # Update progress bar
        current_acc = 100. * epoch_correct / epoch_total
        pbar.set_postfix({
            'Loss': f'{loss.item():.4f}',
            'Acc': f'{current_acc:.2f}%',
            'LR': f'{lr:.6f}'
        })
    
    # Calculate epoch metrics
    avg_loss = epoch_loss / len(train_loader)
    train_accuracy = 100. * epoch_correct / epoch_total
    
    train_losses.append(avg_loss)
    train_accuracies.append(train_accuracy)
    
    # Evaluation phase
    if epoch % args.eval_interval == 0 or epoch == args.epochs - 1:
        print("\nEvaluating model...")
        eval_metrics = evaluate_model(model, test_loader, args.device, num_classes)
        purity_metrics = compute_prototype_purity(model, test_loader, args.device, num_classes)
        
        test_accuracy = eval_metrics['accuracy']
        mean_purity = purity_metrics['mean_purity']
        
        test_accuracies.append(test_accuracy)
        prototype_purities.append(mean_purity)
        
        # Check if best model
        if test_accuracy > best_accuracy:
            best_accuracy = test_accuracy
            print(f"★ New best accuracy: {best_accuracy:.2f}%")
        
        print(f"Epoch {epoch+1}: Train Loss={avg_loss:.4f}, Train Acc={train_accuracy:.2f}%, "
              f"Test Acc={test_accuracy:.2f}%, Purity={mean_purity:.3f}")
        print(f"  Active prototypes/sample: {eval_metrics['active_prototypes_per_sample']:.1f}")
        print(f"  High purity prototypes: {purity_metrics['high_purity_prototypes']}")
        print(f"  Unused prototypes: {eval_metrics['num_unused_prototypes']}")
    else:
        print(f"Epoch {epoch+1}: Train Loss={avg_loss:.4f}, Train Acc={train_accuracy:.2f}%, LR={lr:.6f}")

print(f"\nFine-tuning completed!")
print(f"Best test accuracy: {best_accuracy:.2f}%")

## 8. Results Analysis and Visualization

In [None]:
# Plot training curves
fig, axes = plt.subplots(2, 2, figsize=(15, 10))

# Training loss
axes[0,0].plot(train_losses)
axes[0,0].set_title('Training Loss')
axes[0,0].set_xlabel('Epoch')
axes[0,0].set_ylabel('Loss')
axes[0,0].grid(True)

# Training accuracy
axes[0,1].plot(train_accuracies, label='Train', color='blue')
if test_accuracies:
    # Evaluation ran every `eval_interval` epochs and once more on the final epoch
    n_epochs = len(train_accuracies)
    eval_epochs = [e for e in range(n_epochs)
                   if e % args.eval_interval == 0 or e == n_epochs - 1]
    axes[0,1].plot(eval_epochs, test_accuracies, label='Test', color='red', marker='o')
axes[0,1].set_title('Accuracy')
axes[0,1].set_xlabel('Epoch')
axes[0,1].set_ylabel('Accuracy (%)')
axes[0,1].legend()
axes[0,1].grid(True)

# Prototype purity
if prototype_purities:
    axes[1,0].plot(eval_epochs[:len(prototype_purities)], prototype_purities, color='green', marker='s')
    axes[1,0].set_title('Prototype Purity')
    axes[1,0].set_xlabel('Epoch')
    axes[1,0].set_ylabel('Mean Purity')
    axes[1,0].grid(True)

# Learning rate
axes[1,1].plot(learning_rates, color='purple')
axes[1,1].set_title('Learning Rate')
axes[1,1].set_xlabel('Epoch')
axes[1,1].set_ylabel('Learning Rate')
axes[1,1].grid(True)

plt.tight_layout()
plt.show()

print(f"Final metrics:")
print(f"  Training accuracy: {train_accuracies[-1]:.2f}%")
if test_accuracies:
    print(f"  Test accuracy: {test_accuracies[-1]:.2f}%")
if prototype_purities:
    print(f"  Prototype purity: {prototype_purities[-1]:.3f}")
print(f"  Best test accuracy: {best_accuracy:.2f}%")

In [None]:
# Analyze the learned scoring sheet
print("Analyzing Learned Scoring Sheet")
print("==============================")

model.eval()
classifier_weights = model.get_prototype_class_relevance()  # Shape: (prototypes, classes)

print(f"Classifier weight matrix: {classifier_weights.shape}")
print(f"Weight statistics:")
print(f"  Mean: {classifier_weights.mean().item():.4f}")
print(f"  Std: {classifier_weights.std().item():.4f}")
print(f"  Min: {classifier_weights.min().item():.4f}")
print(f"  Max: {classifier_weights.max().item():.4f}")

# Find most important prototypes for each class
print(f"\nTop 3 prototypes for each class:")
for class_idx, class_name in enumerate(class_names):
    class_weights = classifier_weights[:, class_idx]
    top_prototypes = torch.topk(class_weights, k=3)
    print(f"  {class_name}:")
    for i, (weight, proto_idx) in enumerate(zip(top_prototypes.values, top_prototypes.indices)):
        print(f"    {i+1}. Prototype {proto_idx.item()}: {weight.item():.4f}")

# Visualize scoring sheet as heatmap
plt.figure(figsize=(12, 8))
weights_np = classifier_weights.detach().cpu().numpy()

# Show only top prototypes for clarity
n_top_prototypes = min(50, model.num_prototypes)
prototype_importance = weights_np.max(axis=1)  # Max weight across all classes
top_prototype_indices = np.argsort(prototype_importance)[-n_top_prototypes:]

plt.imshow(weights_np[top_prototype_indices].T, aspect='auto', cmap='viridis')
plt.colorbar(label='Relevance Weight')
plt.xlabel('Top Prototypes')
plt.ylabel('Classes')
plt.title(f'Scoring Sheet: Prototype-Class Relevance Weights (Top {n_top_prototypes} Prototypes)')
plt.yticks(range(num_classes), class_names)
plt.tight_layout()
plt.show()

# Prototype usage statistics
active_prototypes = (weights_np.max(axis=1) > 0.1).sum()
print(f"\nPrototype usage:")
print(f"  Active prototypes (weight > 0.1): {active_prototypes}/{model.num_prototypes}")
print(f"  Utilization rate: {100 * active_prototypes / model.num_prototypes:.1f}%")

In [None]:
# Show example predictions with explanations
print("Example Predictions with Scoring-Sheet Explanations")
print("==================================================")

model.eval()
with torch.no_grad():
    # Get a batch of test samples
    test_inputs, test_targets = next(iter(test_loader))
    test_inputs, test_targets = test_inputs.to(args.device), test_targets.to(args.device)
    
    # Get predictions and explanations
    proto_features, pooled_features, logits, class_scores = model(test_inputs, inference=True)
    
    # Show explanations for first few samples
    n_examples = min(3, test_inputs.size(0))
    
    for i in range(n_examples):
        sample_input = test_inputs[i:i+1]
        true_class = test_targets[i].item()
        
        # Get explanation
        explanation = model.get_scoring_sheet_explanation(sample_input)
        
        predicted_class = explanation['predicted_class'][0].item()
        class_scores_sample = explanation['class_scores'][0]
        contributions = explanation['contributions'][0]
        presences = explanation['prototype_presences'][0]
        
        print(f"\nExample {i+1}:")
        print(f"  True class: {class_names[true_class]}")
        print(f"  Predicted class: {class_names[predicted_class]}")
        print(f"  Confidence: {class_scores_sample[predicted_class]:.3f}")
        print(f"  Correct: {'✓' if predicted_class == true_class else '✗'}")
        
        # Show top contributing prototypes
        top_contributions = torch.topk(contributions, k=5)
        print(f"  Top contributing prototypes:")
        for j, (contrib, proto_idx) in enumerate(zip(top_contributions.values, top_contributions.indices)):
            proto_presence = presences[proto_idx].item()
            class_weight = classifier_weights[proto_idx, predicted_class].item()
            print(f"    {j+1}. P{proto_idx.item()}: presence={proto_presence:.3f} × weight={class_weight:.3f} = {contrib.item():.3f}")
        
        # Show class scores breakdown
        print(f"  Class scores:")
        sorted_scores, sorted_indices = torch.sort(class_scores_sample, descending=True)
        for j in range(min(3, num_classes)):
            class_idx = sorted_indices[j].item()
            score = sorted_scores[j].item()
            print(f"    {class_names[class_idx]}: {score:.3f}")

## 9. Save Results and Model

In [None]:
# Run final comprehensive evaluation
print("Running final evaluation...")

final_eval_metrics = evaluate_model(model, test_loader, args.device, num_classes)
final_purity_metrics = compute_prototype_purity(model, test_loader, args.device, num_classes)

print(f"\n" + "="*50)
print(f"FINAL RESULTS")
print(f"="*50)
print(f"Test Accuracy: {final_eval_metrics['accuracy']:.2f}%")
print(f"Per-class accuracies:")
for i, (class_name, acc) in enumerate(zip(class_names, final_eval_metrics['class_accuracies'])):
    print(f"  {class_name}: {acc:.1f}%")

print(f"\nPrototype Analysis:")
print(f"  Mean prototype purity: {final_purity_metrics['mean_purity']:.3f}")
print(f"  High purity prototypes (>0.5): {final_purity_metrics['high_purity_prototypes']}")
print(f"  Active prototypes per sample: {final_eval_metrics['active_prototypes_per_sample']:.1f}")
print(f"  Unused prototypes: {final_eval_metrics['num_unused_prototypes']}")
print(f"  Mean prototype activation: {final_eval_metrics['mean_prototype_activation']:.3f}")

print(f"\nModel Interpretability:")
active_prototypes = (classifier_weights.max(dim=1)[0] > 0.1).sum().item()
print(f"  Scoring sheet utilization: {100 * active_prototypes / model.num_prototypes:.1f}%")
print(f"  Average prototypes per class: {classifier_weights.gt(0.1).sum().item() / num_classes:.1f}")

print(f"\nTraining Summary:")
print(f"  Training epochs: {args.epochs}")
print(f"  Best test accuracy: {best_accuracy:.2f}%")
print(f"  Final training accuracy: {train_accuracies[-1]:.2f}%")
print(f"  Prototypes frozen: {args.freeze_prototypes}")

In [None]:
# Save the fine-tuned model
final_checkpoint = {
    'epoch': args.epochs,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'best_accuracy': best_accuracy,
    'final_eval_metrics': final_eval_metrics,
    'final_purity_metrics': final_purity_metrics,
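    # vars(args) would be empty here: the settings are class attributes, not instance attributes
    'training_args': {k: v for k, v in vars(type(args)).items() if not k.startswith('_')},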
    'class_names': class_names,
    'num_classes': num_classes,
    'num_prototypes': model.num_prototypes
}

model_path = os.path.join(args.save_dir, 'finetuned_model_final.pth')
torch.save(final_checkpoint, model_path)
print(f"Final model saved to: {model_path}")

# Save training history and results
results = {
    'train_losses': train_losses,
    'train_accuracies': train_accuracies,
    'test_accuracies': test_accuracies,
    'prototype_purities': prototype_purities,
    'learning_rates': learning_rates,
    'best_accuracy': best_accuracy,
    'final_accuracy': final_eval_metrics['accuracy'],
    'final_purity': final_purity_metrics['mean_purity'],
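    # vars(args) would be empty here: the settings are class attributes, not instance attributes
    'training_config': {k: v for k, v in vars(type(args)).items() if not k.startswith('_')},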
    'class_names': class_names
}

with open(os.path.join(args.save_dir, 'training_results.json'), 'w') as f:
    # Convert numpy arrays to lists
    results_serializable = {}
    for k, v in results.items():
        if hasattr(v, 'tolist'):
            results_serializable[k] = v.tolist()
        else:
            results_serializable[k] = v
    json.dump(results_serializable, f, indent=2)

print("Training results saved!")
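
The loop above only converts top-level values that expose `tolist()`. If metrics nested inside lists or dictionaries are NumPy or tensor scalars, `json.dump` accepts a `default` callback for objects it cannot encode. A minimal sketch follows, where `to_serializable` and the stand-in `FakeScalar` class are hypothetical.

```python
import json

def to_serializable(obj):
    """Fallback for json.dump: unwrap array-like or scalar-like objects."""
    if hasattr(obj, "tolist"):   # e.g. NumPy arrays and scalars, torch tensors
        return obj.tolist()
    if hasattr(obj, "item"):     # scalar wrappers exposing only item()
        return obj.item()
    raise TypeError(f"{type(obj).__name__} is not JSON serializable")

class FakeScalar:
    """Stand-in for a NumPy/tensor scalar so the example has no dependencies."""
    def __init__(self, value):
        self.value = value

    def item(self):
        return self.value

history = {"test_accuracies": [FakeScalar(91.5), 92.0]}
encoded = json.dumps(history, default=to_serializable)
print(encoded)
```

Passing the same callback as `json.dump(results, f, indent=2, default=to_serializable)` would handle nested values without a manual conversion loop.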

In [None]:
# Create downloadable package
import zipfile

zip_path = 'bcos_pipnet_finetuned_results.zip'
with zipfile.ZipFile(zip_path, 'w') as zipf:
    # Add model and results
    for file in os.listdir(args.save_dir):
        if file.endswith(('.pth', '.json')):
            zipf.write(os.path.join(args.save_dir, file), f'finetuned/{file}')
    
    # Add this notebook
    if os.path.exists('BcosPIPNet_FineTuning.ipynb'):
        zipf.write('BcosPIPNet_FineTuning.ipynb', 'BcosPIPNet_FineTuning.ipynb')

print(f"Results packaged in: {zip_path}")
print(f"Download this file to save your fine-tuning results!")

# Show file sizes
if os.path.exists(model_path):
    size_mb = os.path.getsize(model_path) / (1024*1024)
    print(f"Fine-tuned model size: {size_mb:.1f} MB")

if os.path.exists(zip_path):
    zip_size_mb = os.path.getsize(zip_path) / (1024*1024)
    print(f"Results archive size: {zip_size_mb:.1f} MB")

## 10. Summary and Next Steps

### What we accomplished:

1. **✅ Scoring-Sheet Classification**: Implemented linear classifier where weights represent prototype-class relevance
2. **✅ Standard NLL Loss**: Used negative log-likelihood with softmax for classification training
3. **✅ Interpretable Weights**: Learned weights show which prototypes are relevant to each class
4. **✅ Weighted Sum Computation**: Class scores = Σ(prototype_presence × class_weight)
5. **✅ Softmax Activation**: Applied during training to convert logits to class probabilities
6. **✅ Purity Evaluation**: Implemented prototype purity metrics inspired by PIP-Net

### Key Results:
- The model learns interpretable prototype-class relationships through the scoring sheet
- Each prediction can be explained as a weighted combination of prototype activations
- Prototype purity measures show how consistently prototypes activate for specific classes
- The approach is designed to keep predictions interpretable without giving up accuracy; the final evaluation in section 9 reports both
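
One simple way to quantify how consistently a prototype activates for a class is the share of its total activation that comes from its dominant class. The sketch below uses that definition as an assumption. `compute_prototype_purity` in the repository may define purity differently, and `prototype_purity` here is a hypothetical helper.

```python
from collections import defaultdict

def prototype_purity(activations, labels):
    """Fraction of a prototype's total activation that falls on its dominant class.

    activations: one non-negative presence value per sample, for a single prototype
    labels: the class label of each sample
    """
    mass = defaultdict(float)
    for a, y in zip(activations, labels):
        mass[y] += a
    total = sum(mass.values())
    return max(mass.values()) / total if total > 0 else 0.0

labels = ["cat", "cat", "dog", "dog"]
pure = prototype_purity([0.9, 0.7, 0.0, 0.0], labels)   # fires only on cats
mixed = prototype_purity([0.5, 0.5, 0.5, 0.5], labels)  # fires equally on both classes
print(pure, mixed)
```

Under this definition a prototype that never activates gets purity 0.0, which keeps unused prototypes from inflating the mean.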

### How to use the fine-tuned model:

```python
# Load the fine-tuned model (saved under args.save_dir in section 9)
checkpoint = torch.load('./checkpoints/colab_finetune/finetuned_model_final.pth', map_location='cpu')
model = create_scoring_sheet_classifier(
    pretrained_path='path_to_pretrained.pth',
    num_classes=checkpoint['num_classes']
)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()  # switch to inference mode before requesting explanations

# Get predictions with explanations
explanation = model.get_scoring_sheet_explanation(input_image)
print(f"Predicted class: {explanation['predicted_class']}")
print(f"Prototype contributions: {explanation['contributions']}")
```
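
To turn raw contribution values into a short explanation, they can be ranked per prototype. The sketch below does this in plain Python for one sample. It assumes each contribution is `presence × weight` for the predicted class, and `top_contributions` is a hypothetical helper rather than part of the model's API.

```python
def top_contributions(presences, class_weights, k=3):
    """Return the k largest (prototype index, presence, weight, contribution) tuples."""
    rows = [
        (i, p, w, p * w)
        for i, (p, w) in enumerate(zip(presences, class_weights))
    ]
    return sorted(rows, key=lambda r: r[3], reverse=True)[:k]

presences = [0.9, 0.1, 0.5, 0.0]
class_weights = [2.0, 3.0, 1.0, 5.0]  # scoring-sheet column for the predicted class

ranked = top_contributions(presences, class_weights, k=2)
for idx, p, w, c in ranked:
    print(f"P{idx}: presence={p:.3f} x weight={w:.3f} = {c:.3f}")
```

Note that prototype 3 has the largest weight but contributes nothing because it is absent from the sample, which is why the ranking uses the product rather than the weight alone.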

### Next Steps:
1. **Prototype Visualization**: Create visualizations of what each prototype represents
2. **Attention Maps**: Generate heatmaps showing where prototypes activate in images
3. **Cross-Dataset Transfer**: Test the model on other datasets
4. **Human Evaluation**: Study how interpretable the prototype explanations are to humans

The fine-tuned model now returns predictions together with scoring-sheet explanations that show how much each prototype contributed to the predicted class.