# CLIP with MLP Head and Dropout Probabilistic Inference

This notebook demonstrates how to:
1. Add a trainable MLP layer to a frozen CLIP model
2. Train the MLP on a classification task
3. Keep dropout active at inference time (Monte Carlo dropout) to obtain probabilistic predictions and estimate uncertainty

In [1]:
import sys

# Make the repository root (home of `bayesvlm`) and the sibling `models`
# directory (home of `clip_mlp`) importable from this notebook.
sys.path.extend(['..', '../models'])

In [17]:
# Alternative import order to avoid circular imports
import random

import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm  # renders as a notebook widget, falls back to text
from tabulate import tabulate

# Import torch modules
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader

# Import torchvision first to avoid circular imports
import torchvision
import torchvision.transforms as transforms

# Import torchmetrics after torchvision is fully loaded. The metrics cells
# below cannot run without it, so fail here with an actionable message rather
# than with a NameError after the (long) training run.
try:
    from torchmetrics.classification import MulticlassCalibrationError, MulticlassAccuracy
except ImportError as e:
    raise ImportError(
        f"Error importing torchmetrics: {e}. Try: pip install torchmetrics"
    ) from e

# Import custom modules (found via the sys.path entries added earlier)
from clip_mlp import CLIPWithMLP
from bayesvlm.data.factory import DataModuleFactory
from bayesvlm.utils import get_transform

# Set random seeds for reproducibility: seed every RNG the notebook may touch.
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

## Configuration

In [13]:
# Configuration: every knob a reader may want to tune lives in this cell.
# (torch comes from the imports cell; it is not re-imported here.)

# Data
dataset = 'cifar10'  # or 'food101', 'cifar100', etc.
batch_size = 32
num_workers = 4

# Model
clip_model_name = "clip-base"  # Use bayesvlm naming: 'clip-base', 'clip-large', 'clip-huge'
mlp_hidden_dim = 512
mlp_dropout_rate = 0.3
mlp_num_layers = 3

# Optimisation
learning_rate = 1e-3
num_epochs = 10

# Number of stochastic forward passes requested from predict_with_uncertainty
n_uncertainty_samples = 100

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

Using device: cpu



## Load Dataset

In [5]:
# Create data module. No augmentation: the same CLIP preprocessing is applied
# to the training and test splits.
transform = get_transform('clip', 224)  # CLIP input size

data_factory = DataModuleFactory(
    batch_size=batch_size,
    num_workers=num_workers,
    train_transform=transform,
    test_transform=transform,
    shuffle_train=True,
)
dm = data_factory.create(dataset)
dm.setup()

# Get number of classes
num_classes = len(dm.class_prompts)
assert num_classes > 0, f"dataset {dataset!r} exposes no class prompts"

print(f"Dataset: {dataset}")
print(f"Number of classes: {num_classes}")
print(f"Train samples: {len(dm.train_ds)}")
print(f"Test samples: {len(dm.test_ds)}")

Dataset: cifar10
Number of classes: 10
Train samples: 40000
Test samples: 10000

Number of classes: 10
Train samples: 40000
Test samples: 10000


## Initialize Model

In [15]:
# Build the CLIP backbone + MLP classification head. freeze_clip=True means
# only the MLP head's parameters remain trainable.
model = CLIPWithMLP(
    clip_model_name=clip_model_name,
    num_classes=num_classes,
    mlp_hidden_dim=mlp_hidden_dim,
    mlp_num_layers=mlp_num_layers,
    mlp_dropout_rate=mlp_dropout_rate,
    freeze_clip=True,  # Only train the MLP head
    device=device,
)
model = model.to(device)

# Parameter bookkeeping: everything vs. the subset that receives gradients.
total_params = sum(param.numel() for param in model.parameters())
trainable_params = sum(param.numel() for param in model.parameters() if param.requires_grad)
frozen_params = total_params - trainable_params

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Frozen parameters: {frozen_params:,}")



Total parameters: 151,807,756
Trainable parameters: 530,442
Frozen parameters: 151,277,314


## Training

In [26]:
# Initialize optimizer and loss function.
# CLIP is frozen (freeze_clip=True), so only hand Adam the parameters that
# actually require gradients.
trainable_parameters = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.Adam(trainable_parameters, lr=learning_rate)
criterion = nn.CrossEntropyLoss()

# Build the training loader once; it is re-iterated every epoch.
train_loader = dm.train_dataloader()

train_losses = []      # mean per-sample training loss, one entry per epoch
train_accuracies = []  # training accuracy in percent, one entry per epoch

for epoch in range(num_epochs):
    # (Re-)enter train mode every epoch so re-running this cell after the
    # evaluation cells (which call model.eval()) still trains with dropout on.
    model.train()
    epoch_loss = 0.0  # running sum of per-sample losses
    correct = 0
    total = 0

    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')

    for batch in pbar:
        # NOTE(review): the recorded run failed on the first batch with
        # "IndexError: too many indices for tensor of dimension 4". That is
        # consistent with the loader yielding a bare 4-D image tensor instead
        # of a dict. Confirm what dm.train_dataloader() yields. Until then,
        # fail with a clear message instead of the cryptic indexing error.
        if not hasattr(batch, 'keys'):
            raise TypeError(
                "Expected each training batch to be a dict with 'image' and "
                f"'class_id' keys, got {type(batch).__name__}"
            )
        images = batch['image'].to(device)
        labels = batch['class_id'].to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Statistics. CrossEntropyLoss returns the batch mean, so weight it by
        # the batch size to get a true per-sample epoch average even when the
        # last batch is smaller than the others.
        batch_loss = loss.item()
        n_in_batch = labels.size(0)
        epoch_loss += batch_loss * n_in_batch
        total += n_in_batch
        correct += outputs.argmax(dim=1).eq(labels).sum().item()

        # Update progress bar
        pbar.set_postfix({
            'Loss': f'{batch_loss:.4f}',
            'Acc': f'{100.*correct/total:.2f}%'
        })

    if total == 0:
        raise RuntimeError("The training dataloader yielded no samples.")

    avg_loss = epoch_loss / total
    accuracy = 100. * correct / total

    train_losses.append(avg_loss)
    train_accuracies.append(accuracy)

    print(f'Epoch {epoch+1}: Loss = {avg_loss:.4f}, Accuracy = {accuracy:.2f}%')

Epoch 1/10:   0%|          | 0/1250 [00:00<?, ?it/s]
Epoch 1/10:   0%|          | 0/1250 [00:00<?, ?it/s]


IndexError: too many indices for tensor of dimension 4

## Plot Training Progress

In [None]:
# Plot training curves against 1-based epoch numbers so the x-axis matches the
# "Epoch N" lines printed by the training loop.
epochs = range(1, len(train_losses) + 1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

ax1.plot(epochs, train_losses, marker='o')
ax1.set_title('Training Loss')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.grid(True)

ax2.plot(epochs, train_accuracies, marker='o')
ax2.set_title('Training Accuracy')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.grid(True)

plt.tight_layout()
plt.show()

## Evaluation with Uncertainty Quantification

In [23]:
def evaluate_with_uncertainty(model, dataloader, device, n_samples=100):
    """Run deterministic and sampling-based inference over every batch of a loader.

    Each batch is read as a dict with an ``'image'`` tensor and a ``'class_id'``
    tensor. For every batch the model is put in eval mode and run once, and the
    result is softmaxed into deterministic probabilities. Then
    ``model.predict_with_uncertainty(images, n_samples=n_samples)`` supplies
    mean probabilities plus epistemic and aleatoric uncertainty.

    Args:
        model: module exposing ``forward`` and ``predict_with_uncertainty``.
        dataloader: iterable of batch dicts.
        device: device the inputs are moved to.
        n_samples: number of samples forwarded to ``predict_with_uncertainty``.

    Returns:
        Tuple ``(predictions, uncertainties_epistemic, uncertainties_aleatoric,
        labels, det_predictions)``. Each entry holds the per-batch results
        moved to the CPU and concatenated along dim 0.
    """
    model.eval()

    # One accumulator per returned tensor, in the same order as the result tuple.
    buckets = ([], [], [], [], [])

    with torch.no_grad():
        for batch in tqdm(dataloader, desc='Evaluating'):
            images = batch['image'].to(device)
            targets = batch['class_id'].to(device)

            # Force eval mode (no dropout) before the deterministic pass on every
            # batch: predict_with_uncertainty may leave dropout switched on.
            model.eval()
            deterministic_probs = F.softmax(model(images), dim=-1)

            mc_mean, mc_epistemic, mc_aleatoric = model.predict_with_uncertainty(
                images, n_samples=n_samples
            )

            batch_results = (mc_mean, mc_epistemic, mc_aleatoric, targets, deterministic_probs)
            for bucket, result in zip(buckets, batch_results):
                bucket.append(result.cpu())

    return tuple(torch.cat(bucket, dim=0) for bucket in buckets)

In [None]:
# Evaluate on the test set: deterministic pass and dropout-based sampling
# for every test image.
predictions, epistemic_unc, aleatoric_unc, labels, det_predictions = evaluate_with_uncertainty(
    model=model,
    dataloader=dm.test_dataloader(),
    device=device,
    n_samples=n_uncertainty_samples,
)

## Compute Metrics

In [None]:
# Compute metrics
def compute_metrics(predictions, labels, num_classes, n_bins=20, eps=1e-8):
    """Compute accuracy, expected calibration error (ECE) and NLPD.

    Args:
        predictions: (N, num_classes) tensor of class probabilities.
        labels: (N,) tensor of integer class ids.
        num_classes: number of classes, forwarded to the ECE metric.
        n_bins: number of confidence bins used for ECE (default 20).
        eps: lower bound applied to probabilities before the log, so a zero
            probability on the true class yields a finite NLPD.

    Returns:
        Tuple ``(accuracy, ece, nlpd)`` of Python floats. ``accuracy`` is the
        fraction of samples whose argmax prediction equals the label.
    """
    labels = labels.long()  # gather() requires int64 indices

    # Accuracy
    predicted_classes = predictions.argmax(dim=-1)
    accuracy = (predicted_classes == labels).float().mean().item()

    # ECE (L1 norm)
    ece_metric = MulticlassCalibrationError(num_classes=num_classes, n_bins=n_bins, norm='l1')
    ece = ece_metric(predictions, labels).item()

    # NLPD (Negative Log Predictive Density): mean negative log-probability of
    # the true class. clamp_min only touches probabilities below eps, unlike
    # adding eps to every entry, so well-defined probabilities are unbiased.
    log_probs = predictions.clamp_min(eps).log()
    # squeeze(1) (not squeeze()) keeps a 1-D vector even when N == 1.
    nlpd = -log_probs.gather(1, labels.unsqueeze(1)).squeeze(1).mean().item()

    return accuracy, ece, nlpd

# Score both predictive distributions with exactly the same metrics.
acc_dropout, ece_dropout, nlpd_dropout = compute_metrics(predictions, labels, num_classes)
acc_det, ece_det, nlpd_det = compute_metrics(det_predictions, labels, num_classes)

print("\n=== Results ===")
# One row per metric: (label, dropout value, deterministic value)
metric_pairs = [
    ("Accuracy (↑)", acc_dropout, acc_det),
    ("ECE (↓)", ece_dropout, ece_det),
    ("NLPD (↓)", nlpd_dropout, nlpd_det),
]
data = [[name, f"{bayes_val:.4f}", f"{det_val:.4f}"] for name, bayes_val, det_val in metric_pairs]

table_headers = ["Metric", "Dropout (Bayesian)", "Deterministic"]
print(tabulate(data, headers=table_headers, tablefmt="simple"))

## Uncertainty Analysis

In [None]:
# Analyze uncertainty: split per-sample uncertainties by whether the
# dropout-averaged prediction was correct.
predicted_classes = predictions.argmax(dim=-1)
correct_predictions = (predicted_classes == labels)

# Uncertainty statistics
epistemic_correct = epistemic_unc[correct_predictions]
epistemic_incorrect = epistemic_unc[~correct_predictions]
aleatoric_correct = aleatoric_unc[correct_predictions]
aleatoric_incorrect = aleatoric_unc[~correct_predictions]


def format_mean_std(values):
    """Format a tensor of uncertainty values as 'mean ± std'.

    An empty group (e.g. no incorrect predictions) and a single-element group
    (where torch's default unbiased std is undefined) are reported explicitly
    instead of printing 'nan'.
    """
    n_values = values.numel()
    if n_values == 0:
        return "n/a (no samples)"
    mean = values.float().mean().item()
    if n_values == 1:
        return f"{mean:.4f} (single sample)"
    return f"{mean:.4f} ± {values.float().std().item():.4f}"


n_correct = int(correct_predictions.sum())
n_incorrect = correct_predictions.numel() - n_correct

print("\n=== Uncertainty Analysis ===")
print(f"Correct / incorrect predictions: {n_correct} / {n_incorrect}")
print("Epistemic Uncertainty:")
print(f"  Correct predictions: {format_mean_std(epistemic_correct)}")
print(f"  Incorrect predictions: {format_mean_std(epistemic_incorrect)}")
print("\nAleatoric Uncertainty:")
print(f"  Correct predictions: {format_mean_std(aleatoric_correct)}")
print(f"  Incorrect predictions: {format_mean_std(aleatoric_incorrect)}")

## Visualize Uncertainty Distributions

In [None]:
# Plot uncertainty distributions, split by prediction correctness.
# Convert the correctness mask to NumPy once instead of iterating over the
# tensor element by element.
correct_mask = correct_predictions.cpu().numpy()

fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, figsize=(12, 8))


def plot_split_hist(ax, correct_values, incorrect_values, name):
    """Overlay density histograms of one uncertainty type for correct vs. incorrect predictions.

    Args:
        ax: matplotlib Axes to draw on.
        correct_values: uncertainty tensor for correctly classified samples.
        incorrect_values: uncertainty tensor for misclassified samples.
        name: quantity name, used for the title and x-axis label.

    Empty groups are skipped: a density histogram of zero samples is undefined.
    """
    for values, group in ((correct_values, 'Correct'), (incorrect_values, 'Incorrect')):
        if values.numel() > 0:
            ax.hist(values.cpu().numpy(), bins=50, alpha=0.7, label=group, density=True)
    ax.set_title(f'{name} Distribution')
    ax.set_xlabel(name)
    ax.set_ylabel('Density')
    ax.legend()
    ax.grid(True, alpha=0.3)


# Epistemic and aleatoric uncertainty histograms
plot_split_hist(ax1, epistemic_correct, epistemic_incorrect, 'Epistemic Uncertainty')
plot_split_hist(ax2, aleatoric_correct, aleatoric_incorrect, 'Aleatoric Uncertainty')

# Uncertainty scatter plot: one scatter per group so the colours get a legend
# (incorrect points are drawn last so they stay visible).
epistemic_np = epistemic_unc.cpu().numpy()
aleatoric_np = aleatoric_unc.cpu().numpy()
ax3.scatter(epistemic_np[correct_mask], aleatoric_np[correct_mask],
            c='green', alpha=0.5, s=1, label='Correct')
ax3.scatter(epistemic_np[~correct_mask], aleatoric_np[~correct_mask],
            c='red', alpha=0.5, s=1, label='Incorrect')
ax3.set_xlabel('Epistemic Uncertainty')
ax3.set_ylabel('Aleatoric Uncertainty')
ax3.set_title('Epistemic vs Aleatoric Uncertainty')
ax3.legend(markerscale=10)
ax3.grid(True, alpha=0.3)

# Confidence vs accuracy
confidence = predictions.max(dim=-1)[0]
ax4.scatter(confidence.numpy(), correct_predictions.float().numpy(), alpha=0.5, s=1)
ax4.set_xlabel('Prediction Confidence')
ax4.set_ylabel('Correct (1) / Incorrect (0)')
ax4.set_title('Confidence vs Correctness')
ax4.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

## Sample Predictions with Uncertainty

In [None]:
# Show the first few test images with their dropout-based predictions,
# confidence and uncertainties.
n_examples = 8  # number of images to display (the grid below is 2 x 4)
batch = next(iter(dm.test_dataloader()))
# Guard against a first batch smaller than the grid (e.g. batch_size < 8).
n_examples = min(n_examples, batch['image'].shape[0])
test_images = batch['image'][:n_examples].to(device)
test_labels = batch['class_id'][:n_examples]

# Get predictions and uncertainties
with torch.no_grad():
    sample_preds, sample_epistemic, sample_aleatoric = model.predict_with_uncertainty(
        test_images, n_samples=n_uncertainty_samples
    )

# Convert to numpy for plotting
test_images_np = test_images.cpu().numpy()
sample_preds_np = sample_preds.cpu().numpy()
sample_epistemic_np = sample_epistemic.cpu().numpy()
sample_aleatoric_np = sample_aleatoric.cpu().numpy()
test_labels_np = test_labels.cpu().numpy()

# Plot examples
fig, axes = plt.subplots(2, 4, figsize=(16, 8))
axes = axes.flatten()

for i, ax in enumerate(axes):
    if i >= n_examples:
        ax.axis('off')  # hide panels left empty when fewer images are available
        continue

    # CHW -> HWC, then min-max rescale to [0, 1] for display. A constant image
    # has zero range and would otherwise divide by zero.
    img = test_images_np[i].transpose(1, 2, 0)
    value_range = img.max() - img.min()
    img = (img - img.min()) / value_range if value_range > 0 else np.zeros_like(img)

    predicted_class = sample_preds_np[i].argmax()
    # Separate name so the previous cell's global `confidence` is not clobbered.
    top_prob = sample_preds_np[i].max()
    true_class = test_labels_np[i]

    ax.imshow(img)
    ax.set_title(
        f'True: {true_class}, Pred: {predicted_class}\n'
        f'Conf: {top_prob:.3f}\n'
        f'Epist: {sample_epistemic_np[i]:.3f}\n'
        f'Aleat: {sample_aleatoric_np[i]:.3f}'
    )
    ax.axis('off')

plt.tight_layout()
plt.show()

## Summary

This notebook demonstrated:

1. **Model Architecture**: Added a trainable MLP head to a frozen CLIP model
2. **Training**: Trained only the MLP parameters while keeping CLIP frozen
3. **Uncertainty Quantification**: Used dropout to estimate both epistemic and aleatoric uncertainty
4. **Evaluation**: Compared deterministic vs probabilistic predictions
5. **Analysis**: Visualized uncertainty distributions and their correlation with prediction correctness

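Expected observations. These are not yet confirmed by the recorded outputs: the training cell above stopped with an `IndexError`, so the evaluation cells have not produced results.
- The Bayesian (dropout) approach is expected to give better calibration (lower ECE) than the single deterministic pass, though this is not guaranteed
- Higher uncertainty is generally associated with incorrect predictions
- Epistemic uncertainty captures model uncertainty (reducible with more data), while aleatoric uncertainty captures noise inherent in the data
- The approach allows for uncertainty-aware decision making in downstream applications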