Practical Work: Out-of-Distribution Detection, OOD Scoring Methods, and Neural Collapse.

# 1. Training a ResNet18 classifier on CIFAR-100 with PyTorch

In [32]:
import torch
import torchvision
from torchvision.models import resnet18
import torchvision.transforms as transforms
import torch.nn.functional as F
import matplotlib.pyplot as plt


In [None]:
# Hyperparameters
batch_size = 8
num_workers = 8
lr = 1e-3
momentum = 0.95
weight_decay = 5e-4
epochs = 200

# Skip training and load model?
skip_training = False  # True to skip

# Random seed
torch.manual_seed(42)


In [34]:

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Data transformations
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
    transforms.RandomErasing(p=0.25),  # increased RandomErasing probability
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5071, 0.4867, 0.4408), (0.2675, 0.2565, 0.2761)),
])


# Load CIFAR-100
train_set = torchvision.datasets.CIFAR100(root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True, num_workers=num_workers)

test_set = torchvision.datasets.CIFAR100(root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers)

# Model configuration
model = resnet18(weights=None)  # randomly initialized, no pretrained weights
model.fc = torch.nn.Sequential(
    torch.nn.Dropout(p=0.5),  # Dropout increased from 0.3 to 0.5 to reduce overfitting
    torch.nn.Linear(model.fc.in_features, 100)
)
model = model.to(device)  # also works on CPU-only machines

# Optim, Loss, Scheduler
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay)
criterion = torch.nn.CrossEntropyLoss()
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[60, 120, 160], gamma=0.2)

# Per-epoch metric history
train_losses = []
train_accs = []
test_losses = []
test_accs = []

# Training loop
def train(epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    for i, data in enumerate(train_loader, 0):
        inputs, labels = data
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

    epoch_loss = running_loss / len(train_loader)
    epoch_acc = 100. * correct / total
    print(f"Epoch {epoch}: Loss: {epoch_loss:.4f} | Acc: {epoch_acc:.2f}%")
    return epoch_loss, epoch_acc

# Simple Test Loop
def test(epoch):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            running_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    epoch_loss = running_loss / len(test_loader)
    epoch_acc = 100. * correct / total
    print(f"Test Accuracy: {epoch_acc:.2f}%")
    return epoch_loss, epoch_acc

Using device: cuda


In [35]:

# Load model
if skip_training:  # set skip_training = True in the hyperparameter cell to load a saved model
    try:
        model.load_state_dict(torch.load('resnet18_cifar100.pth', map_location=device))
        # plot_path = "training_curves.png"
        # img = plt.imread(plot_path)
        # plt.figure(figsize=(12, 5))
        # plt.imshow(img)
        # plt.axis("off")
        # plt.show()
    except FileNotFoundError:
        raise FileNotFoundError("Model checkpoint 'resnet18_cifar100.pth' not found. Please train the model first.")
else:
    # Main training and testing loop
    for epoch in range(epochs):
        train_loss, train_acc = train(epoch)
        test_loss, test_acc = test(epoch)
        
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        test_losses.append(test_loss)
        test_accs.append(test_acc)
        
        scheduler.step()

    # Save model
    torch.save(model.state_dict(), 'resnet18_cifar100.pth')

    # Plotting
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(train_losses, label='Train Loss')
    plt.plot(test_losses, label='Test Loss')
    plt.title('Loss over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(train_accs, label='Train Acc')
    plt.plot(test_accs, label='Test Acc')
    plt.title('Accuracy over Epochs')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()

    plt.tight_layout()
    plt.savefig("training_curves.png", dpi=200, bbox_inches="tight")
    plt.show()
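As a quick sanity check of the `MultiStepLR` configuration used above (`lr = 1e-3`, milestones 60/120/160, `gamma = 0.2`), the learning rate in each epoch can be worked out by hand. The plain-Python sketch below mirrors the step rule (multiply by `gamma` once for every milestone already reached) and assumes, as in the loop above, that `scheduler.step()` is called once at the end of every epoch; `multistep_lr` is a hypothetical helper, not a PyTorch function.

```python
from bisect import bisect_right

def multistep_lr(epoch, base_lr=1e-3, milestones=(60, 120, 160), gamma=0.2):
    """Learning rate used during `epoch` (0-indexed) under a MultiStepLR-style
    schedule, assuming scheduler.step() runs once at the end of each epoch."""
    # Number of milestones already reached at this epoch
    num_decays = bisect_right(sorted(milestones), epoch)
    return base_lr * gamma ** num_decays

for start in (0, 60, 120, 160):
    print(f"epochs {start}+: lr = {multistep_lr(start):.1e}")
```

With these settings the final phase runs at a learning rate 125 times smaller than the initial one.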





# 2. Implement and compare OOD scores
Implement and compare the following OOD scores:

· Max Softmax Probability (MSP)

· Maximum Logit Score

· Mahalanobis

· Energy Score

· ViM
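Comparing these scores requires both ID and OOD test data and a threshold-free metric. The cells in this section score only the CIFAR-100 test set, and the OOD dataset is not specified here. As a sketch, two commonly used metrics, AUROC and FPR at 95% TPR, can be computed from the two score arrays with NumPy. `auroc` and `fpr_at_95_tpr` are hypothetical helper names. The code assumes that a higher score means "more in-distribution", with ID as the positive class.

```python
import numpy as np

def auroc(id_scores, ood_scores):
    """AUROC with ID as the positive class (higher score = more ID).
    Uses the rank-sum (Mann-Whitney U) formulation; tied scores get average ranks."""
    id_scores = np.asarray(id_scores, dtype=float)
    ood_scores = np.asarray(ood_scores, dtype=float)
    all_scores = np.concatenate([id_scores, ood_scores])
    # 1-based average rank of every score (ties share the mean of their ranks)
    _, inverse, counts = np.unique(all_scores, return_inverse=True, return_counts=True)
    avg_rank = np.cumsum(counts) - (counts - 1) / 2.0
    ranks = avg_rank[inverse]
    n_id, n_ood = len(id_scores), len(ood_scores)
    u_stat = ranks[:n_id].sum() - n_id * (n_id + 1) / 2.0
    return u_stat / (n_id * n_ood)

def fpr_at_95_tpr(id_scores, ood_scores):
    """Fraction of OOD samples still accepted as ID when the threshold keeps ~95% of ID samples."""
    threshold = np.percentile(np.asarray(id_scores, dtype=float), 5)
    return float((np.asarray(ood_scores, dtype=float) >= threshold).mean())

# Synthetic example: ID scores tend to be higher than OOD scores
rng = np.random.default_rng(0)
id_s = rng.normal(2.0, 1.0, size=1000)
ood_s = rng.normal(0.0, 1.0, size=1000)
print(f"AUROC = {auroc(id_s, ood_s):.3f}, FPR@95TPR = {fpr_at_95_tpr(id_s, ood_s):.3f}")
```

If a score follows the opposite convention (higher means OOD), negate it before passing it to these helpers.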

# 2.1 Max Softmax Probability (MSP)
A low maximum softmax probability means the classifier is uncertain about its prediction, which may indicate that the input is OOD.

Score: the maximum softmax probability
$$s_{\text{max-prob}}(x)=\max_{c}\, p(y=c\mid x)=\max_{c}\,\mathrm{softmax}(z(x))_c=\max_{c}\,\frac{e^{z_c(x)}}{\sum_{c'} e^{z_{c'}(x)}}$$
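The softmax in this formula is usually computed stably by subtracting the largest logit before exponentiating. This does not change the result, because the softmax is invariant to adding a constant to all logits. Below is a small NumPy sketch on made-up toy logits, independent of the trained model; `msp_from_logits` is a hypothetical helper name.

```python
import numpy as np

def msp_from_logits(logits):
    """Max softmax probability per row of a [batch, num_classes] logit array."""
    logits = np.asarray(logits, dtype=float)
    # Subtract the row-wise max for numerical stability (softmax is shift-invariant)
    shifted = logits - logits.max(axis=1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
    return probs.max(axis=1)

toy_logits = np.array([
    [8.0, 1.0, 0.5],   # one dominant class: MSP close to 1
    [1.0, 1.0, 1.0],   # uniform logits: MSP = 1/3
])
print(msp_from_logits(toy_logits))
```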

In [36]:
def msp_score(model, loader):
    model.eval()
    all_scores = []

    with torch.no_grad():
        for data in loader:
            inputs, _ = data
            inputs = inputs.to(device)
            logits = model(inputs)
            probs = F.softmax(logits, dim=1)

            max_probs, _ = torch.max(probs, dim=1)
            
            all_scores.append(max_probs.cpu())
            
    return torch.cat(all_scores)

msp_scores = msp_score(model, test_loader)

print(f"MSP scores: {msp_scores}")

MSP scores: tensor([0.2913, 0.3016, 0.4502,  ..., 1.0000, 0.4367, 0.3209])


# 2.2 Max Logit Score
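Logits carry the model's raw evidence before softmax normalization. Scoring on logits avoids the saturation of the softmax: since $\mathrm{softmax}(z + a\mathbf{1}) = \mathrm{softmax}(z)$ for any constant $a$, the softmax discards the overall magnitude of the logits, which may differ between ID and OOD inputs.
$$s_{\text{max-logit}}(x)= \max_{c}\, z_c(x)$$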
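To illustrate the shift-invariance argument, the NumPy sketch below uses toy logits rather than model outputs. It adds the same constant to every logit. The max softmax probability stays exactly the same, while the max logit increases by that constant.

```python
import numpy as np

def softmax(z):
    # Numerically stable softmax over the last axis
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)

z = np.array([4.0, 2.0, 1.0])
z_shifted = z + 5.0  # same gaps between logits, larger overall magnitude

msp, msp_shifted = softmax(z).max(), softmax(z_shifted).max()
max_logit, max_logit_shifted = z.max(), z_shifted.max()
print(f"MSP: {msp:.4f} vs {msp_shifted:.4f}")
print(f"Max logit: {max_logit:.1f} vs {max_logit_shifted:.1f}")
```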

In [37]:
def max_logit_score(model, loader):
    model.eval()
    all_scores = []
    with torch.no_grad():
        for data in loader:
            inputs, _ = data
            inputs = inputs.to(device)
            logits = model(inputs)
            max_logits, _ = torch.max(logits, dim=1)
            all_scores.append(max_logits.cpu())
    return torch.cat(all_scores)

max_logit_scores = max_logit_score(model, test_loader)
print(f"Max logit scores: {max_logit_scores}")

Max logit scores: tensor([ 6.6087,  5.7924,  8.6517,  ..., 18.1743,  7.1058,  6.6317])


# 2.3 Energy-based OOD Score

Energy provides a scalar that correlates with the model’s
total evidence across classes. Lower energy (more negative) implies
stronger evidence; higher energy (less negative or positive) can
indicate OOD. The energy is derived from the logits; a common definition is:

$$E(x) = -\log\sum_{c} e^{z_c(x)} = -\mathrm{LSE}(z(x))$$

With a temperature $T>0$, one can use:
$$E_T(x) = -T\log\sum_{c} e^{z_c(x)/T}$$
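A small NumPy check, on toy logits, of how the negative energy relates to the max logit. For $C$ classes, $-E_T(x) = T\,\mathrm{LSE}(z(x)/T)$ always lies between $\max_c z_c(x)$ and $\max_c z_c(x) + T\log C$, so it approaches the max-logit score as $T \to 0$. `neg_energy` is a hypothetical helper. Like `energy_score` in the next cell, it returns $-E_T$, so a higher value means more in-distribution.

```python
import numpy as np

def neg_energy(logits, T=1.0):
    """-E_T(x) = T * logsumexp(z / T), computed stably; higher = more ID."""
    if T <= 0:
        raise ValueError("T must be greater than 0.")
    z = np.asarray(logits, dtype=float) / T
    m = z.max(axis=-1, keepdims=True)
    # log-sum-exp trick: log sum exp(z) = m + log sum exp(z - m)
    lse = m.squeeze(-1) + np.log(np.exp(z - m).sum(axis=-1))
    return T * lse

z = np.array([[6.0, 5.5, 1.0, 0.0]])
for T in (10.0, 1.0, 0.1):
    print(f"T={T:>4}: -E_T = {neg_energy(z, T)[0]:.4f}  (max logit = {z.max():.1f})")
```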

In [38]:
def energy_score(model, loader, Temperature=1):
    # Validate once, before iterating over the data (Temperature must be > 0).
    if Temperature <= 0:
        raise ValueError("Temperature must be greater than 0.")
    model.eval()
    all_scores = []
    with torch.no_grad():
        for data in loader:
            inputs, _ = data
            inputs = inputs.to(device)
            logits = model(inputs)
            # Returns the negative energy, -E_T(x) = T * logsumexp(z / T), so that
            # higher scores mean "more in-distribution", as for MSP and max logit.
            energy = Temperature * torch.logsumexp(logits / Temperature, dim=1)
            all_scores.append(energy.cpu())
    return torch.cat(all_scores)

energy_scores = energy_score(model, test_loader, Temperature=1)
print(f"Energy scores: {energy_scores}")

Energy scores: tensor([ 7.8421,  6.9911,  9.4498,  ..., 18.1744,  7.9342,  7.7682])


# 2.4 Mahalanobis Distance-based OOD Detection

For each class $c$, compute the class mean in feature space:
$$\mu_c = \frac{1}{N_c}\sum_{i:\,y_i=c} f(x_i)$$

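Then estimate a shared covariance matrix $\Sigma$ (or one covariance per class), typically the empirical covariance of the training features. The code below uses the covariance of all training features and adds $0.01\,I$ to keep the matrix invertible.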

The Mahalanobis score is then:
$$d_{\text{Maha}}(x) = \min_{c}\,(f(x)-\mu_c)^{\top}\Sigma^{-1}(f(x)-\mu_c)$$

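For each test sample, compute the distance to all 100 class centers and take the minimum as the OOD score. The smaller the distance, the more the sample resembles in-distribution data. The code returns $-d_{\text{Maha}}(x)$, so that, as with the other scores, a higher value means more likely ID.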
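The class-conditional Gaussian view of this score usually estimates the shared covariance from class-centered features, $\Sigma = \frac{1}{N}\sum_c \sum_{i:y_i=c}(f(x_i)-\mu_c)(f(x_i)-\mu_c)^\top$. The global covariance, by contrast, also absorbs the spread between class means. The NumPy sketch below shows this tied, class-centered variant on synthetic features. All names and sizes are illustrative assumptions. The means and covariance are typically fitted on un-augmented training images, for example a loader built with `transform_test` rather than the augmenting `train_loader`.

```python
import numpy as np

def fit_class_gaussians(features, labels, num_classes, ridge=1e-2):
    """Class means and the inverse of a shared (tied) covariance estimated
    from class-centered features; `ridge * I` is added for stability."""
    dim = features.shape[1]
    means = np.stack([features[labels == c].mean(axis=0) for c in range(num_classes)])
    centered = features - means[labels]          # subtract each sample's own class mean
    cov = centered.T @ centered / len(features)  # tied within-class covariance
    return means, np.linalg.inv(cov + ridge * np.eye(dim))

def mahalanobis_ood_score(features, means, inv_cov):
    """Negative minimum squared Mahalanobis distance (higher = more ID)."""
    diff = features[:, None, :] - means[None, :, :]            # [n, C, d]
    dists = np.einsum("ncd,de,nce->nc", diff, inv_cov, diff)   # [n, C]
    return -dists.min(axis=1)

rng = np.random.default_rng(0)
num_classes, dim = 3, 8
centers = rng.normal(scale=5.0, size=(num_classes, dim))
labels = rng.integers(0, num_classes, size=600)
feats = centers[labels] + rng.normal(size=(600, dim))

means, inv_cov = fit_class_gaussians(feats, labels, num_classes)
id_like = centers[:1]                  # a point at a true class center
far_away = np.full((1, dim), 50.0)     # a point far from every class
scores = mahalanobis_ood_score(np.vstack([id_like, far_away]), means, inv_cov)
print(scores)
```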

In [39]:
features_buffer = []

def hook_fn(module, input, output):
    # output shape is [batch, 512, 1, 1], flatten it to [batch, 512]
    features_buffer.append(output.view(output.size(0), -1))

# General hook registration to the top avgpool layer (works for ResNet18/34/50/101)
handle = model.avgpool.register_forward_hook(hook_fn)

def extract_features(model, inputs):
    features_buffer.clear() 
    _ = model(inputs)       
    return features_buffer[0] 

def mahalanobis_score(model, train_loader, test_loader, num_classes=100):
    """
    Mahalanobis-distance OOD score:
    1. Compute the mean μ_c for each class (100 classes)
    2. Compute the covariance matrix Σ
    3. For each test sample, compute the distance to all classes and take the minimum
    """
    model.eval()
    
    # Step 1: Collect training features and labels
    print("Collecting training features...")
    train_features = []
    train_labels = []
    with torch.no_grad():
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            features = extract_features(model, inputs)
            train_features.append(features.cpu())
            train_labels.append(labels)
    
    train_features = torch.cat(train_features)
    train_labels = torch.cat(train_labels)
    
    # Step 2: Compute the mean for each class
    print("Computing class means...")
    class_means = []
    for c in range(num_classes):
        class_features = train_features[train_labels == c]
        class_mean = torch.mean(class_features, dim=0)
        class_means.append(class_mean)
    class_means = torch.stack(class_means)  # shape: [100, feature_dim]
    
    # Step 3: Compute the covariance matrix (using global covariance)
    print("Computing covariance matrix...")
    cov = torch.cov(train_features.T) + 0.01 * torch.eye(train_features.size(1))
    inv_cov = torch.inverse(cov)
    
    # Step 4: Compute the minimum Mahalanobis distance to each class for the test set
    print("Computing Mahalanobis distances for the test set...")
    all_scores = []
    with torch.no_grad():
        for inputs, _ in test_loader:
            inputs = inputs.to(device)
            features = extract_features(model, inputs).cpu()
            
            # Squared Mahalanobis distance to every class mean, for the whole batch at once.
            # (The Mahalanobis distance itself is sqrt(diff^T Σ^-1 diff); following the
            # formula given in the slides, we use the squared distance.)
            diff = features.unsqueeze(1) - class_means.unsqueeze(0)      # [batch, num_classes, feature_dim]
            dists = torch.einsum('bcd,de,bce->bc', diff, inv_cov, diff)   # [batch, num_classes]

            # Take the minimum distance (smaller distance means more likely to be an ID sample);
            # negate it so that a higher score means more likely to be ID
            min_dist = dists.min(dim=1).values
            all_scores.append(-min_dist)

    return torch.cat(all_scores)

mahalanobis_scores = mahalanobis_score(model, train_loader, test_loader)
print(f"Mahalanobis scores (first 10): {mahalanobis_scores[:10]}")

# Cleanup hook
if 'handle' in globals():
    handle.remove()
    features_buffer.clear()

Collecting training features...
Computing class means...
Computing covariance matrix...
Computing Mahalanobis distances for the test set...
Mahalanobis scores (first 10): tensor([-467.2338, -358.8292, -417.4225, -505.1868, -386.7666, -281.6877,
        -348.1957, -450.8566, -178.3548, -193.7439])


# 2.5 ViM (Virtual-logit Matching) Score

ViM detects OOD samples by analyzing the principal component directions in the feature space:
- ID samples' features primarily reside in a low-dimensional subspace (formed by principal components)
- OOD samples exhibit larger projections beyond the principal component directions

$$s_{ViM}(x) = -\alpha \cdot \|P_{principal}(f(x) - \mu)\|_2 + \|P_{residual}(f(x) - \mu)\|_2$$

where:
- $P_{principal}$: projection onto the principal component subspace
- $P_{residual}$: projection onto the residual subspace (orthogonal complement space)
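- $\alpha$: weighting parameter

This is a simplified variant that does not use the classifier's logits. With this sign convention, a higher score indicates a more likely OOD sample, the opposite of the MSP, max-logit, energy, and Mahalanobis scores above.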
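The original ViM method differs from the formula above: it turns the residual norm into an extra "virtual" logit and compares it with the classifier's logits. Below is a NumPy sketch of that formulation, based on our reading of the method and run on synthetic arrays (all names are illustrative). For a classifier $z = Wf + b$ with $W \in \mathbb{R}^{C\times D}$, the origin is $o = -W^{+} b$. The residual space is spanned by the eigenvectors of the training features' second-moment matrix around $o$, excluding the top $D$. The virtual logit is $\alpha\|x^{P^\perp}\|$, with $\alpha$ chosen so that its training average matches the average max logit. The score is $-\alpha\|x^{P^\perp}\| + \log\sum_c e^{z_c}$, where a higher value means more ID. On the ResNet18 above, `W` and `b` would come from `model.fc[1]`.

```python
import numpy as np

def fit_vim(train_feats, W, b, num_principal):
    """Fit ViM statistics on ID training features.
    W: [C, D] classifier weights, b: [C] bias (logits = feats @ W.T + b)."""
    origin = -np.linalg.pinv(W) @ b                       # o = -W^+ b
    centered = train_feats - origin
    second_moment = centered.T @ centered / len(centered)
    eigvals, eigvecs = np.linalg.eigh(second_moment)
    order = np.argsort(eigvals)[::-1]                     # descending eigenvalues
    residual_basis = eigvecs[:, order[num_principal:]]    # directions beyond the top D
    train_logits = train_feats @ W.T + b
    train_residual = np.linalg.norm(centered @ residual_basis, axis=1)
    # alpha matches the average virtual logit to the average max logit
    alpha = train_logits.max(axis=1).mean() / train_residual.mean()
    return origin, residual_basis, alpha

def vim_scores(feats, W, b, origin, residual_basis, alpha):
    """-virtual logit + logsumexp(logits); higher = more in-distribution."""
    logits = feats @ W.T + b
    virtual_logit = alpha * np.linalg.norm((feats - origin) @ residual_basis, axis=1)
    m = logits.max(axis=1, keepdims=True)
    log_sum_exp = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    return -virtual_logit + log_sum_exp

# Synthetic data: ID features live near a k-dimensional subspace, OOD features do not
rng = np.random.default_rng(0)
C, D, k = 4, 16, 4
W = rng.normal(size=(C, D))
b = 0.1 * rng.normal(size=C)
id_basis, ood_basis = np.eye(D)[:k], np.eye(D)[k:]
train = 3.0 * rng.normal(size=(500, k)) @ id_basis + 0.05 * rng.normal(size=(500, D))
id_test = 3.0 * rng.normal(size=(200, k)) @ id_basis + 0.05 * rng.normal(size=(200, D))
ood_test = 10.0 * rng.normal(size=(200, D - k)) @ ood_basis

origin, residual_basis, alpha = fit_vim(train, W, b, num_principal=k)
id_vim = vim_scores(id_test, W, b, origin, residual_basis, alpha)
ood_vim = vim_scores(ood_test, W, b, origin, residual_basis, alpha)
print(f"mean ViM score: ID {id_vim.mean():.2f}, OOD {ood_vim.mean():.2f}")
```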

In [None]:
features_buffer = []

def hook_fn(module, input, output):
    # output shape is [batch, 512, 1, 1], flatten it to [batch, 512]
    features_buffer.append(output.view(output.size(0), -1))

# General hook registration to the top avgpool layer (works for ResNet18/34/50/101)
handle = model.avgpool.register_forward_hook(hook_fn)

def extract_features(model, inputs):
    features_buffer.clear() 
    _ = model(inputs)       
    return features_buffer[0] 


def vim_score(model, train_loader, test_loader, num_principal_components=100, alpha=1.0):
    """
    ViM OOD detection method.
    
    Parameters:
    - num_principal_components: Number of principal components to retain
    - alpha: Weight for the principal component direction
    """
    model.eval()
    
    # Step 1: Collect training features
    print("Collecting training features for PCA...")
    train_features = []
    with torch.no_grad():
        for inputs, _ in train_loader:
            inputs = inputs.to(device)
            features = extract_features(model, inputs)
            train_features.append(features.cpu())
    
    train_features = torch.cat(train_features)
    
    # Step 2: Compute mean and center features
    print("Computing feature mean...")
    mean = torch.mean(train_features, dim=0)
    centered_features = train_features - mean
    
    # Step 3: PCA - Compute eigenvalues and eigenvectors of covariance matrix
    print("Performing PCA...")
    cov = torch.cov(centered_features.T)
    eigenvalues, eigenvectors = torch.linalg.eigh(cov)
    
    # Sort by eigenvalue in descending order
    idx = torch.argsort(eigenvalues, descending=True)
    eigenvalues = eigenvalues[idx]
    eigenvectors = eigenvectors[:, idx]
    
    # Step 4: Select principal components and residual subspace
    principal_components = eigenvectors[:, :num_principal_components]  # Principal component directions
    residual_components = eigenvectors[:, num_principal_components:]   # Residual directions
    
    print(f"Retaining top {num_principal_components} principal components")
    print(f"Explained variance ratio of principal components: {eigenvalues[:num_principal_components].sum() / eigenvalues.sum():.4f}")
    
    # Step 5: Compute ViM scores for the test set
    print("Computing ViM scores for the test set...")
    all_scores = []
    with torch.no_grad():
        for inputs, _ in test_loader:
            inputs = inputs.to(device)
            features = extract_features(model, inputs).cpu()
            
            # Center the features
            centered = features - mean
            
            # Project onto principal component subspace
            principal_proj = centered @ principal_components
            principal_norm = torch.norm(principal_proj, dim=1)
            
            # Project onto residual subspace
            residual_proj = centered @ residual_components
            residual_norm = torch.norm(residual_proj, dim=1)
            
            # ViM score: -alpha * ||principal component projection|| + ||residual projection||
            # OOD samples have larger components in the residual direction
            vim_scores = -alpha * principal_norm + residual_norm
            
            all_scores.append(vim_scores)
    
    return torch.cat(all_scores)

# Compute ViM scores
vim_scores = vim_score(model, train_loader, test_loader, num_principal_components=100, alpha=1.0)
print(f"ViM scores (first 10): {vim_scores[:10]}")

# Cleanup hook
if 'handle' in globals():
    handle.remove()
    features_buffer.clear()

Collecting training features for PCA...
Computing feature mean...
Performing PCA...
Retaining top 100 principal components
Explained variance ratio of principal components: 0.9004
Computing ViM scores for the test set...
ViM scores (first 10): tensor([-6.6161, -5.4149, -7.4246, -6.5089, -9.2754, -5.5998, -5.0616, -7.3689,
        -7.2696, -6.6848])


# 3. Study the Neural Collapse phenomenon at the end of training (NC1 to NC4)

Assume that the model is overparameterized.

NC1 (variability collapse): at the end of training, the within-class variability of the features collapses toward zero, because the features of each class converge to their class mean.

NC2 (convergence to a simplex Equiangular Tight Frame, ETF): the class means, centered at the global mean, converge to vectors of equal norm with equal pairwise angles. For $C$ classes, every pair has cosine $-1/(C-1)$, the largest possible angular separation.

NC3 (convergence to self-duality): the class means and the linear classifier weights converge to each other, up to rescaling.

NC4 (simplification to nearest class center): the classifier's decision ultimately reduces to assigning each sample to the nearest class mean in feature space.
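These four properties can be checked numerically on penultimate-layer features. The NumPy sketch below computes diagnostics in the spirit of the Neural Collapse literature. The function name and exact normalizations are our assumptions, and the bias is ignored for NC4. On a synthetic, almost perfectly collapsed configuration (a simplex ETF plus tiny noise), NC1, the NC2 spreads, NC3 and NC4 should all be close to 0, and the mean pairwise cosine close to $-1/(C-1)$. On the trained ResNet18, the features would come from the `avgpool` hook and `W` from `model.fc[1].weight`.

```python
import numpy as np

def neural_collapse_metrics(feats, labels, W, num_classes):
    """feats: [N, D] penultimate features, labels: [N], W: [C, D] classifier weights."""
    global_mean = feats.mean(axis=0)
    means = np.stack([feats[labels == c].mean(axis=0) for c in range(num_classes)])
    centered_means = means - global_mean                       # [C, D]

    # NC1: within-class scatter relative to between-class scatter
    within = feats - means[labels]
    sigma_w = within.T @ within / len(feats)
    sigma_b = centered_means.T @ centered_means / num_classes
    nc1 = np.trace(sigma_w @ np.linalg.pinv(sigma_b)) / num_classes

    # NC2: equal norms and equal pairwise angles of the centered class means
    norms = np.linalg.norm(centered_means, axis=1)
    unit = centered_means / norms[:, None]
    off_diag = (unit @ unit.T)[~np.eye(num_classes, dtype=bool)]
    nc2_equinorm = norms.std() / norms.mean()
    nc2_mean_cos = off_diag.mean()        # simplex ETF value: -1 / (C - 1)
    nc2_cos_std = off_diag.std()

    # NC3: classifier rows align with the centered class means, up to scaling
    nc3 = np.linalg.norm(W / np.linalg.norm(W) - centered_means / np.linalg.norm(centered_means))

    # NC4: fraction of samples where argmax logit disagrees with the nearest class mean
    logit_pred = (feats @ W.T).argmax(axis=1)
    ncc_pred = np.linalg.norm(feats[:, None, :] - means[None], axis=2).argmin(axis=1)
    nc4 = (logit_pred != ncc_pred).mean()

    return dict(nc1=nc1, nc2_equinorm=nc2_equinorm, nc2_mean_cos=nc2_mean_cos,
                nc2_cos_std=nc2_cos_std, nc3=nc3, nc4=nc4)

# Synthetic collapsed configuration: simplex ETF class means plus tiny noise
rng = np.random.default_rng(0)
C, D = 4, 10
U, _ = np.linalg.qr(rng.normal(size=(D, C)))                # orthonormal columns [D, C]
etf = np.sqrt(C / (C - 1)) * (np.eye(C) - np.ones((C, C)) / C)
class_means = (U @ etf).T                                    # [C, D], unit-norm ETF rows
labels = np.repeat(np.arange(C), 50)
feats = class_means[labels] + 1e-3 * rng.normal(size=(len(labels), D))
metrics = neural_collapse_metrics(feats, labels, W=class_means, num_classes=C)
print(metrics)
```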

# 4. Study the Neural Collapse phenomenon at the end of training (NC5)

NC5 (ID/OOD orthogonality): as training progresses, the clusters of OOD data become increasingly orthogonal to the configuration adopted by the ID data, that is, to the directions of the ID class means.
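One simple way to quantify this is the average absolute cosine between the mean OOD feature and each ID class mean, which tends to 0 as OOD features become orthogonal to the ID configuration. This particular metric is our assumption for illustration; the exact normalization used in the NECO work may differ. Here both means are centered with the ID global mean, which is also an assumption. The NumPy sketch below uses synthetic arrays.

```python
import numpy as np

def nc5_orthogonality(id_feats, id_labels, ood_feats, num_classes):
    """Average |cos| between the OOD feature mean and each ID class mean.
    Both are centered with the ID global mean (an assumption of this sketch)."""
    id_global_mean = id_feats.mean(axis=0)
    class_means = np.stack([id_feats[id_labels == c].mean(axis=0)
                            for c in range(num_classes)]) - id_global_mean
    ood_mean = ood_feats.mean(axis=0) - id_global_mean
    cos = class_means @ ood_mean / (np.linalg.norm(class_means, axis=1) * np.linalg.norm(ood_mean))
    return np.abs(cos).mean()

rng = np.random.default_rng(0)
C, D = 3, 12
id_centers = 5.0 * np.eye(D)[:C]                   # ID class means use dims 0..C-1
id_labels = np.repeat(np.arange(C), 100)
id_feats = id_centers[id_labels] + 0.1 * rng.normal(size=(len(id_labels), D))
base = id_feats.mean(axis=0)
ood_orthogonal = base + 3.0 * np.eye(D)[C] + 0.1 * rng.normal(size=(200, D))  # shifted along an unused dim
ood_aligned = base + 3.0 * np.eye(D)[0] + 0.1 * rng.normal(size=(200, D))     # shifted along class 0

orth = nc5_orthogonality(id_feats, id_labels, ood_orthogonal, C)
aligned = nc5_orthogonality(id_feats, id_labels, ood_aligned, C)
print(f"orthogonal OOD: {orth:.3f}, aligned OOD: {aligned:.3f}")
```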

# 5. Implementation of the NECO method (Neural Collapse Inspired OOD Detection).

$$\mathrm{NECO}(x)=\frac{\|P\,h_{\omega}(x)\|}{\|h_{\omega}(x)\|}=\frac{\sqrt{h_{\omega}(x)^\top P^\top P\,h_{\omega}(x)}}{\sqrt{h_{\omega}(x)^\top h_{\omega}(x)}}$$

where $h_{\omega}(x)$ is the penultimate-layer representation and $P$ is the projection onto the $d$ eigenvectors with the largest eigenvalues.

$P$ is fitted using a PCA on the in-distribution training features.

The NECO score is then multiplied by the maximum logit. This injects class-specific information into the otherwise class-agnostic ratio.
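Here is a minimal NumPy sketch of this score, under stated assumptions. PCA is fitted on centered ID training features, and $P$ holds the top-$d$ principal directions as rows, so $\|P h\|$ is the norm of the $d$ PCA coordinates. As in the formula, the projection is applied to the raw features $h_\omega(x)$, and the ratio is optionally multiplied by the max logit. Whether features are standardized before the PCA is not fixed here. `fit_neco_projection` and `neco_scores` are hypothetical names, and the data is synthetic.

```python
import numpy as np

def fit_neco_projection(train_feats, d):
    """Top-d principal directions of ID training features, as the rows of P [d, D]."""
    centered = train_feats - train_feats.mean(axis=0)
    # Right singular vectors of the centered data = covariance eigenvectors (descending)
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    return vt[:d]

def neco_scores(feats, P, logits=None):
    """NECO(x) = ||P h(x)|| / ||h(x)||, optionally multiplied by the max logit."""
    ratio = np.linalg.norm(feats @ P.T, axis=1) / np.linalg.norm(feats, axis=1)
    if logits is not None:
        ratio = ratio * logits.max(axis=1)
    return ratio

rng = np.random.default_rng(0)
D, d = 20, 3
train = np.zeros((500, D))
train[:, :d] = 3.0 * np.abs(rng.normal(size=(500, d)))   # ID features occupy dims 0..d-1
train += 0.01 * rng.normal(size=train.shape)
P = fit_neco_projection(train, d)

id_test = np.zeros((100, D))
id_test[:, :d] = 3.0 * np.abs(rng.normal(size=(100, d)))
ood_test = np.zeros((100, D))
ood_test[:, d:] = np.abs(rng.normal(size=(100, D - d)))  # OOD features outside that subspace
id_ratio, ood_ratio = neco_scores(id_test, P), neco_scores(ood_test, P)
print(f"mean NECO ratio: ID {id_ratio.mean():.3f}, OOD {ood_ratio.mean():.3f}")
```

Because the rows of `P` are orthonormal, the unscaled ratio always lies in $[0, 1]$.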