In [None]:
import os

# Run from the repository root so that `src` and `config` are importable.
if os.getcwd().endswith('notebooks'):
    os.chdir('..')

import torch
import matplotlib.pyplot as plt
import torch.nn.functional as F
import numpy as np

from src.model import get_model
from src.data import get_data, get_ood_data
from src.ood_scores import OODEvaluator
from src.neural_collapse import compute_nc_metrics

from config import MODELS_DIR, BATCH_SIZE

from matplotlib.colors import TwoSlopeNorm
from sklearn.decomposition import PCA
from torchvision import transforms

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed the RNGs used below: torch drives DataLoader shuffling and the torchvision
# random augmentations of the NC5 section. numpy's global RNG is what scikit-learn
# falls back to when no random_state is given (e.g. PCA's randomized solver).
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [None]:
train_loader, test_loader = get_data(BATCH_SIZE)
ood_loader = get_ood_data(BATCH_SIZE)

# map_location lets a checkpoint saved from a GPU run be loaded on a CPU-only machine.
model_state = torch.load(
    os.path.join(MODELS_DIR, "resnet18_cifar100.pth"),
    map_location=device,
)
model = get_model().to(device)
model.load_state_dict(model_state["model_state_dict"])
model.eval();

In [None]:
evaluator = OODEvaluator(model)
features, logits = evaluator.get_features_and_logits(test_loader, device)

# Labels are gathered in a second pass over the loader, so they only line up with
# `features` if the loader yields samples in a fixed order.
# NOTE(review): assumes test_loader is built with shuffle=False — confirm in src/data.py.
test_labels = torch.cat([batch_labels for _, batch_labels in test_loader])
assert len(test_labels) == len(features), "feature/label count mismatch"

weights = model.base_model.fc.weight.detach().cpu()

metrics = compute_nc_metrics(features.cpu(), test_labels.cpu(), weights)
metrics


In [None]:
def plot_nc1(features, labels, means):
    """
    Visualizes NC1 Analysis: Intra-class variability collapse.

    For every sample, measures the Euclidean distance between its features and
    its own class centroid, divided by the average distance of the class
    centroids from their global mean. The histogram of these relative distances
    and their mean (red line) are plotted. Under Neural Collapse the mean
    converges toward zero.

    Formula (Relative intra-class distance):
        mu_G = 1/K * sum_{c=1}^K mu_c
        D_B  = 1/K * sum_{c=1}^K ||mu_c - mu_G||_2
        r_i  = ||f_i - mu_{y_i}||_2 / D_B
        NC1 score = 1/N * sum_i r_i

    Note: this is a distance-based proxy. The trace-ratio metric
    Tr(Sigma_W * pinv(Sigma_B)) / K is not computed here.

    Args:
        features (torch.Tensor): Latent representations (embeddings), shape [N, d].
        labels (torch.Tensor): Ground truth class indices, shape [N].
        means (torch.Tensor): Computed class centroids, shape [K, d].

    Returns:
        float: The NC1 score (mean relative distance).

    Raises:
        ValueError: if no label indexes a row of `means`, or if the class
            centroids coincide (zero inter-class spread).
    """

    num_classes = means.shape[0]

    # Only samples whose label indexes a row of `means` contribute.
    valid = (labels >= 0) & (labels < num_classes)
    if not valid.any():
        raise ValueError("No sample label matches a row of `means`.")
    all_intra = torch.norm(features[valid] - means[labels[valid]], dim=1)

    # Classes without samples have NaN centroids; keep them out of the global statistics.
    present = torch.isfinite(means).all(dim=1)
    centroids = means[present]
    mu_g = centroids.mean(dim=0)
    inter_dist = torch.norm(centroids - mu_g, dim=1).mean()
    if not inter_dist > 0:
        raise ValueError("Class centroids coincide; the inter-class spread is zero.")

    # Ratio < 1 = sample lies closer to its centroid than centroids lie to the global mean
    collapsed_distances = (all_intra / inter_dist).cpu().numpy()
    avg_nc1 = float(collapsed_distances.mean())

    plt.figure(figsize=(10, 7), facecolor='white')

    n, bins, patches = plt.hist(
        collapsed_distances, bins=80,
        color='#4a69bd', alpha=0.7,
        edgecolor='#2f3542', linewidth=0.5
    )

    plt.axvline(avg_nc1, color='#e74c3c', linestyle='--', linewidth=2, label=f'Mean NC1 Ratio: {avg_nc1:.4f}')

    plt.title("NC1: Intra-class Variability Collapse", fontsize=16, fontweight='normal', pad=20)
    plt.xlabel("Relative Distance (Intra-class spread / Inter-class separation)", fontsize=12)
    plt.ylabel("Number of Samples", fontsize=12)

    # Clip the long right tail so the bulk of the distribution stays readable.
    plt.xlim(0, min(collapsed_distances.max(), avg_nc1 * 3))

    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)
    plt.gca().yaxis.grid(True, linestyle='--', alpha=0.3)

    plt.text(avg_nc1 * 1.1, max(n) * 0.8,
             f"Ideal NC1 $\\rightarrow$ 0\nMeasured: {avg_nc1:.4f}",
             fontsize=11, color='#c0392b', bbox=dict(facecolor='white', alpha=0.5))

    plt.legend(frameon=True)
    plt.tight_layout()
    plt.show()

    return avg_nc1

In [None]:
def plot_nc2(means, grid_step=10):
    """
    Visualizes NC2 Analysis: Convergence of class means to a Simplex ETF.

    Plots the cosine similarity matrix between the class means after centering
    them on their global mean. In a Neural Collapse regime, the off-diagonal
    entries converge to the ETF constant, reflecting maximal mutual separation.

    Formula (Inter-class Cosine Similarity):
        S_{i,j} = <mu_i - mu_G, mu_j - mu_G> / (||mu_i - mu_G|| * ||mu_j - mu_G||)
        Ideal Target (NC2) = -1 / (K - 1)

    Args:
        means (torch.Tensor): Matrix of class centroids, shape [K, d].
        grid_step (int): Interval for drawing visual grid lines on the heatmap
            (default: 10 for CIFAR-100).

    Returns:
        float: Mean of the off-diagonal cosine similarities (also printed).

    Raises:
        ValueError: if fewer than two class means are given.
    """

    num_classes = means.shape[0]
    if num_classes < 2:
        raise ValueError("NC2 needs at least two class means.")

    mu_g = means.mean(dim=0)
    norm_means = F.normalize(means - mu_g, p=2, dim=1)

    cosine_sim = torch.mm(norm_means, norm_means.T).cpu().numpy()
    etf_target = -1 / (num_classes - 1)

    plt.figure(figsize=(11, 9), facecolor='white')

    # Diverging normalization centred on the ETF target. TwoSlopeNorm requires
    # vmin < vcenter < vmax; for K = 2 the target equals -1, so nudge it inward.
    norm = TwoSlopeNorm(vcenter=max(etf_target, -1.0 + 1e-6), vmin=-1.0, vmax=1.0)
    img = plt.imshow(cosine_sim, cmap='RdBu_r', norm=norm, interpolation='nearest')

    # Off-diagonal pairs only (the diagonal is trivially 1). A NumPy mask indexes the NumPy matrix.
    off_diag = ~np.eye(num_classes, dtype=bool)
    off_diag_avg = cosine_sim[off_diag].mean().item()
    print(f"Experimental Mean Similarity: {off_diag_avg:.6f}")

    for i in range(0, num_classes, grid_step):
        plt.axhline(i - 0.5, color='black', linewidth=0.5, alpha=0.3)
        plt.axvline(i - 0.5, color='black', linewidth=0.5, alpha=0.3)

    cbar = plt.colorbar(img)
    cbar.set_label("Cosine Similarity", fontsize=12)

    plt.title("NC2: Simplex ETF Correlation Matrix",
              fontsize=16, fontweight='normal', pad=20)
    plt.xlabel("Class Index", fontsize=12)
    plt.ylabel("Class Index", fontsize=12)

    plt.xlim(-0.5, num_classes - 0.5)
    plt.ylim(num_classes - 0.5, -0.5)

    plt.tight_layout()
    plt.show()

    return off_diag_avg

In [None]:
def plot_nc3(model, means):
    """
    Visualizes NC3 Analysis: Self-Duality (Weight-Centroid Alignment).

    For every class c, computes the cosine similarity between the classifier
    weight row W_c (`model.base_model.fc.weight[c]`) and the centered mean of the
    same class, mu_c - mu_G, where mu_G is the average of the class means. Under
    self-duality the two vectors coincide up to a scalar, so every similarity
    tends to 1.

    Formula (Alignment Metric):
        cos_c = <W_c, mu_c - mu_G> / (||W_c||_2 * ||mu_c - mu_G||_2)
        NC3_c = || W_c / ||W_c||_2 - (mu_c - mu_G) / ||mu_c - mu_G||_2 ||_2
              = sqrt(2 - 2 * cos_c)
        As Neural Collapse progresses, cos_c -> 1 and NC3_c -> 0 for all classes c=1...K.

    Args:
        model (torch.nn.Module): The trained neural network containing
            the final linear classification layer `base_model.fc`.
        means (torch.Tensor): Computed class centroids in the feature
            space, shape [K, d].

    Returns:
        float: Mean of cos_c over the classes (higher = stronger self-duality).

    Raises:
        ValueError: if the classifier weight matrix and `means` differ in shape.
    """

    weights = model.base_model.fc.weight.detach().cpu()
    means = means.detach().cpu()
    if weights.shape != means.shape:
        raise ValueError(
            f"Classifier weights {tuple(weights.shape)} and class means "
            f"{tuple(means.shape)} must have the same shape [K, d]."
        )

    w_norm = F.normalize(weights, p=2, dim=1)
    m_norm = F.normalize(means - means.mean(dim=0), p=2, dim=1)

    # Row-wise dot product = diagonal of w_norm @ m_norm.T: each W_c is compared with
    # the mean of its OWN class. A max over all centroids would hide mismatched rows.
    alignment = (w_norm * m_norm).sum(dim=1)
    avg_nc3 = alignment.mean().item()

    plt.figure(figsize=(10, 7), facecolor='white')
    plt.hist(alignment.numpy(), bins=30, color='#27ae60',
             alpha=0.6, edgecolor='#1e8449', linewidth=1.2)

    plt.axvline(avg_nc3, color='#c0392b', linestyle='--', linewidth=2,
                label=f'Mean Alignment: {avg_nc3:.4f}')

    plt.title("NC3: Classifier Weight and Class Mean Alignment (Self-Duality)",
              fontsize=15, fontweight='normal', pad=20)
    plt.xlabel("Cosine Similarity", fontsize=12)
    plt.ylabel("Number of Classes", fontsize=12)

    # Extend the axis to the left if some classes are anti-aligned, instead of clipping them.
    plt.xlim(min(0.0, alignment.min().item() - 0.05), 1.1)
    plt.gca().yaxis.grid(True, linestyle='--', alpha=0.3)
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)

    plt.legend(frameon=True, fontsize=11)
    plt.tight_layout()
    plt.show()

    return avg_nc3

In [None]:
def extract_layer_features(model, layer, loader, device):
    """
    Extracts latent features from a specific model layer and their associated labels.

    A forward hook on `layer` records its output for every batch of `loader`
    while the model runs in eval mode under `torch.no_grad()`. The activations
    are used to analyze the geometry of the latent space (Neural Collapse) and
    to compute OOD scores.

    4-D outputs (presumably [B, C, H, W] feature maps) are globally
    average-pooled to [B, C]. Any other output, e.g. [B, C] from a Linear
    layer, is flattened to [B, -1]. The hook is always removed, even if the
    forward pass raises.

    Args:
        model (torch.nn.Module): The neural network to extract features from.
        layer (torch.nn.Module): The specific layer to hook (e.g., model.base_model.avgpool).
        loader (Iterable[tuple[torch.Tensor, torch.Tensor]]): Yields (images, labels) batches,
            typically a torch.utils.data.DataLoader.
        device (torch.device): The computing device (cuda or cpu).

    Returns:
        tuple (torch.Tensor, torch.Tensor): A pair of CPU tensors (features, labels)
            - features: The extracted activations, shape [N, d].
            - labels: The corresponding ground truth class indices, shape [N].

    Raises:
        ValueError: if nothing was recorded (empty loader, or `layer` is not run
            by the forward pass), or if the number of feature rows differs from
            the number of labels (e.g. the layer runs several times per forward).
    """

    model.eval()
    feature_batches = []
    label_batches = []

    def hook_fn(module, inputs, output):
        # Convolutional maps are pooled over the spatial dims; flat outputs pass through.
        if output.dim() == 4:
            output = F.adaptive_avg_pool2d(output, (1, 1))
        feature_batches.append(output.detach().flatten(1).cpu())

    handle = layer.register_forward_hook(hook_fn)
    try:
        with torch.no_grad():
            for images, labels in loader:
                model(images.to(device))
                label_batches.append(labels.cpu())
    finally:
        # A leaked hook would keep firing (and accumulating memory) on later forward passes.
        handle.remove()

    if not feature_batches:
        raise ValueError(
            "No activations were recorded: the loader is empty or `layer` "
            "is not part of the forward pass."
        )

    features = torch.cat(feature_batches, dim=0)
    all_labels = torch.cat(label_batches, dim=0)
    if features.shape[0] != all_labels.shape[0]:
        raise ValueError(
            f"Recorded {features.shape[0]} feature rows for {all_labels.shape[0]} labels; "
            "the hooked layer probably runs more than once per forward pass."
        )

    return features, all_labels

In [None]:
def plot_nc4(model, test_loader, class_means, device, metric="euclidean"):
    """
    Visualizes NC4 Analysis: Equivalence to Nearest Class Center (NCC) Rule.

    Compares the test accuracy of the network's own classifier with that of a
    Nearest Class Center (NCC) rule built from `class_means`. In the NC regime,
    the decision boundaries of the network simplify into Voronoi cells centered
    around the class centroids, so both accuracies converge.

    Features and logits:
        - If the model returns a tuple, it is read as (features, logits).
        - If it returns a single tensor, that tensor is taken as the logits and
          the features are captured as the input of `model.base_model.fc` (when
          that layer exists).
        - Without such a layer, the output is used as features and no softmax
          accuracy is available.

    Formula (Decision Equivalence):
        y_softmax = argmax_c (W f + b)_c
        y_ncc     = argmin_c || f - mu_c ||_2                  (metric="euclidean")
        y_ncc     = argmax_c cos(f - mu_G, mu_c - mu_G)         (metric="cosine")
        Metric: | Accuracy(y_softmax) - Accuracy(y_ncc) | -> 0

    Args:
        model (torch.nn.Module): The trained neural network.
        test_loader (torch.utils.data.DataLoader): Data to evaluate accuracy.
        class_means (torch.Tensor): Computed class centroids, shape [K, d].
        device (torch.device): Computing device (cuda or cpu).
        metric (str): NCC rule. "euclidean" (default) follows the formula above.
            "cosine" uses cosine similarity after centering by mu_G, the mean of
            the class means.

    Returns:
        tuple (float, float): (softmax accuracy, NCC accuracy) in percent. The
            softmax accuracy is 0.0 when no logits could be obtained.

    Raises:
        ValueError: on an unknown `metric`, or when the feature dimension does
            not match the dimension of `class_means`.
    """

    if metric not in ("euclidean", "cosine"):
        raise ValueError(f"Unknown NCC metric {metric!r}; use 'euclidean' or 'cosine'.")

    model.eval()
    all_logits = []
    all_features = []
    all_labels = []

    # Capture the classifier input (penultimate features) for models that only return logits.
    fc = getattr(getattr(model, "base_model", None), "fc", None)
    fc_inputs = []
    handle = None
    if fc is not None:
        handle = fc.register_forward_hook(
            lambda module, args, output: fc_inputs.append(args[0].detach())
        )

    try:
        with torch.no_grad():
            for images, labels in test_loader:
                fc_inputs.clear()
                output = model(images.to(device))
                if isinstance(output, tuple):
                    features, logits = output[0], output[1]
                elif fc_inputs:
                    features, logits = fc_inputs[-1].flatten(1), output
                else:
                    features, logits = output, None

                all_features.append(features.cpu())
                if logits is not None:
                    all_logits.append(logits.cpu())
                all_labels.append(labels.cpu())
    finally:
        if handle is not None:
            handle.remove()

    test_features = torch.cat(all_features)
    test_labels = torch.cat(all_labels)
    means = class_means.detach().cpu()
    if test_features.shape[1] != means.shape[1]:
        raise ValueError(
            f"Feature dimension {test_features.shape[1]} does not match "
            f"class-mean dimension {means.shape[1]}."
        )

    # Softmax acc (network's own decision rule)
    if all_logits:
        test_logits = torch.cat(all_logits)
        acc_softmax = (test_logits.argmax(dim=1) == test_labels).float().mean().item() * 100
    else:
        acc_softmax = 0.0

    # NCC acc
    if metric == "euclidean":
        preds_ncc = torch.cdist(test_features, means).argmin(dim=1)
    else:
        mu_g = means.mean(dim=0)
        feat_norm = F.normalize(test_features - mu_g, p=2, dim=1)
        means_norm = F.normalize(means - mu_g, p=2, dim=1)
        preds_ncc = torch.mm(feat_norm, means_norm.T).argmax(dim=1)
    acc_ncc = (preds_ncc == test_labels).float().mean().item() * 100

    plt.figure(figsize=(9, 7), facecolor='white')
    bar_labels = ['Softmax (Network)', 'NCC (Geometric)']
    accuracies = [acc_softmax, acc_ncc]

    colors = ['#4a69bd', '#7f8c8d']
    bars = plt.bar(bar_labels, accuracies, color=colors, alpha=0.8, width=0.4)

    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height + 1.5,
                 f'{height:.2f}%', ha='center', va='bottom', fontsize=12)

    plt.ylim(0, 110)
    gap = abs(acc_softmax - acc_ncc)
    plt.title(f"NC4: Decision Rule Convergence (Gap: {gap:.4f}%)",
              fontsize=15, fontweight='normal', pad=20)
    plt.ylabel("Test Accuracy (%)", fontsize=12)

    plt.gca().yaxis.grid(True, linestyle='--', alpha=0.3)
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)

    plt.tight_layout()
    plt.show()

    return acc_softmax, acc_ncc

In [None]:
def plot_progressive_pca(model, loader, device, target_classes=(0, 10, 20, 30, 40)):
    """
    Visualizes the latent space organization using a 2x2 PCA grid.

    For each of the four stages layer1 ... layer4 of `model.base_model`, the
    stage output is globally average-pooled (via `extract_layer_features`). It
    is then restricted to samples of `target_classes` and projected onto its
    first two principal components. Comparing the panels shows how class
    clustering emerges with depth.

    Formula (PCA Projection):
        f_centered = f - E[f]
        Components = Eigenvectors of (f_centered^T * f_centered)
        f_2D = f_centered @ V_{1:2}

    Args:
        model (torch.nn.Module): The trained model to extract features from.
        loader (torch.utils.data.DataLoader): The dataset used for visualization.
        device (torch.device): Computing device (cuda or cpu).
        target_classes (Sequence[int]): Class indices to display
            (default: (0, 10, 20, 30, 40)).

    Raises:
        ValueError: if `target_classes` is empty.
    """

    target_classes = [int(c) for c in target_classes]
    if not target_classes:
        raise ValueError("target_classes must contain at least one class index.")

    layers = [
        model.base_model.layer1,
        model.base_model.layer2,
        model.base_model.layer3,
        model.base_model.layer4
    ]

    fig, axes = plt.subplots(2, 2, figsize=(12, 12), facecolor='white')
    axes = axes.flatten()

    # tab10 begins with the five colours previously hard-coded here and extends to more classes.
    palette = plt.get_cmap('tab10')
    colors = [palette(i % 10) for i in range(len(target_classes))]
    target_tensor = torch.tensor(target_classes)

    for idx, (ax, layer) in enumerate(zip(axes, layers)):
        # One full pass over `loader` per stage.
        feats, lbls = extract_layer_features(model, layer, loader, device)

        # Vectorized membership test instead of a per-sample Python loop.
        mask = torch.isin(lbls, target_tensor)
        f_sub = feats[mask].numpy()
        l_sub = lbls[mask].numpy()

        projected = PCA(n_components=2).fit_transform(f_sub)

        for color, c in zip(colors, target_classes):
            c_mask = l_sub == c
            ax.scatter(
                projected[c_mask, 0],
                projected[c_mask, 1],
                color=color,
                label=f'Class {c}',
                s=25,
                alpha=0.7,
                edgecolors='w',
                linewidths=0.3
            )

        ax.set_box_aspect(1)

        ax.set_title(f"Layer {idx+1}", fontsize=16, fontweight='normal', pad=15)

        ax.set_xticks([])
        ax.set_yticks([])

        for spine in ax.spines.values():
            spine.set_visible(True)
            spine.set_color('#888888')
            spine.set_linewidth(1.0)

    plt.suptitle("Geometric Evolution and Neural Collapse Across Layers",
                 fontsize=22, fontweight='normal', y=0.96)

    handles, legend_labels = axes[0].get_legend_handles_labels()
    fig.legend(handles, legend_labels, loc='lower center', bbox_to_anchor=(0.5, 0.05),
               ncol=min(len(target_classes), 5), fontsize=13, frameon=True)

    plt.tight_layout(rect=[0, 0.08, 1, 0.93])
    plt.show()

In [None]:
# Globally average-pooled layer4 activations of the test set (the classifier's input space).
features_final, labels_final = extract_layer_features(model, model.base_model.layer4, test_loader, device)

fc_weights = model.base_model.fc.weight.detach().cpu()
# Number of classes taken from the classifier head rather than a hard-coded 100.
num_classes = fc_weights.shape[0]

# A class with no sample would get a NaN centroid and silently poison every NC metric.
class_counts = torch.bincount(labels_final, minlength=num_classes)[:num_classes]
assert (class_counts > 0).all(), "every class needs at least one sample to define its centroid"

class_means = torch.stack(
    [features_final[labels_final == c].mean(dim=0) for c in range(num_classes)]
)

## NC1: Intra-class Variability Collapse

This plot shows, for each test sample, the distance between its features and its class centroid, divided by the average distance of the class centroids from their global mean. 

* **Interpretation**: A distribution that is both narrow and centered near zero indicates that the network has successfully "collapsed" intra-class variations. 
* **Theoretical Goal**: In the limit, this process erases non-discriminative information, mapping all samples of the same category onto a single point (their class centroid) in the latent space.

In [None]:
plot_nc1(features=features_final, labels=labels_final, means=class_means)

## NC2: Convergence to Simplex ETF

This heatmap displays the cosine similarity matrix between centered class means. 

* **Observation**: We look for a uniform off-diagonal structure.
* **Theoretical Target**: In a state of Neural Collapse, the class centers are maximally and equally separated from one another. We expect the off-diagonal entries to converge toward the theoretical constant $-1/(K-1)$ ($-1/99 \approx -0.0101$ for CIFAR-100, where $K = 100$), forming a perfectly symmetrical **Simplex Equiangular Tight Frame (ETF)**.

In [None]:
plot_nc2(means=class_means)

## NC3 & NC4: Structural Alignment and Decision Rule

These visualizations compare the learned classifier weights with the feature centroids and the equivalence of the decision boundaries.

* **NC3 (Self-Duality)**: As the latent space reaches its optimal geometric configuration, each classifier weight vector $W_c$ aligns, up to a scaling factor, with the centered mean of its own class, $\mu_c - \mu_G$.
* **NC4 (NCC Equivalence)**: The complex Softmax decision boundaries simplify into **Voronoi cells** around the class means. At this stage, a simple **Nearest Class Center (NCC)** rule performs as well as the fully trained linear head. This confirms that the network's decisions are governed by the geometry of the class means.

In [None]:
plot_nc3(model=model, means=class_means)

In [None]:
plot_nc4(model=model, test_loader=test_loader, class_means=class_means, device=device)

## Visual Interpretation (PCA)

The 2D projection of the high-dimensional feature space provides a qualitative confirmation of the NC phenomenon.

> **Key takeaway**: We observe the transition from a disorganized cloud of points to highly concentrated, equidistant clusters. This structural prior is what we leverage to improve Out-of-Distribution detection: by forcing In-Distribution data into these rigid "anchors," OOD samples are more likely to fall into the "geometric void" between them.

In [None]:
plot_progressive_pca(model=model, loader=test_loader, device=device)

## NC5: Feature Invariance under Augmentation

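NC5 measures the stability of the latent representations when the input undergoes stochastic transformations (random crops, flips, rotations and color jittering). "NC5" is this notebook's label; the original Neural Collapse paper defines only NC1–NC4.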

* **Observation**: In the "Star Plot", each cluster represents a single source image and its augmented variations, centered on the mean of their projected embeddings. The smaller the "star" radius, the higher the invariance.
* **Interpretation**: A successful Neural Collapse implies that the network has learned to ignore the "manifold of variations" (noise, pose, lighting). It collapses all possible versions of an image into a single, robust semantic point.
* **Link to OOD**: High NC5 invariance ensures that the In-Distribution (ID) anchors are extremely stable. This makes any Out-of-Distribution (OOD) sample—which won't map to these stable points—much easier to distinguish as it falls outside these tightly collapsed zones.

In [None]:
def analyze_nc5_invariance(model, image, transform_pipeline, n_augmentations=50):
    """
    Analyzes NC5 Invariance: Feature representation stability under augmentation.

    Applies `transform_pipeline` to the same seed image `n_augmentations` times.
    Every augmented copy is embedded, the embeddings are L2-normalized, and the
    function measures how tightly they cluster. Two panels are shown: the
    pairwise cosine-similarity matrix of the augmented embeddings, and the
    histogram of their distances to the re-normalized mean embedding.

    Formula (Invariance Metric):
        Let z_j = f_j / ||f_j||_2 be the normalized embedding of the j-th augmentation
        and c = mean_j(z_j) / ||mean_j(z_j)||_2 their normalized centroid.
        NC5_score = (1 / n) * sum_{j=1}^n || z_j - c ||_2
        The un-augmented image itself is not embedded; the centroid of the
        augmentations stands in for it.

    Embeddings:
        - If the model returns a tuple, its first element is used when it is
          2-D, otherwise its second element.
        - If it returns a single tensor, the input of `model.base_model.fc`
          (penultimate features) is captured when that layer exists.
        - Otherwise the raw output is used.

    Args:
        model (torch.nn.Module): The trained neural network.
        image (torch.Tensor): A single seed image, shape [C, H, W].
        transform_pipeline (callable): A stochastic augmentation pipeline.
        n_augmentations (int): Number of augmented samples to generate.

    Returns:
        numpy float: The NC5 score, i.e. the mean Euclidean distance between the
            normalized augmented embeddings and their normalized centroid.

    Raises:
        ValueError: if `n_augmentations` is smaller than 1.
    """

    if n_augmentations < 1:
        raise ValueError("n_augmentations must be at least 1.")

    model.eval()
    device = next(model.parameters()).device

    augmented_images = torch.stack(
        [transform_pipeline(image) for _ in range(n_augmentations)]
    ).to(device)

    # Capture the classifier input (penultimate features) for models that only return logits.
    fc = getattr(getattr(model, "base_model", None), "fc", None)
    fc_inputs = []
    handle = None
    if fc is not None:
        handle = fc.register_forward_hook(
            lambda module, args, output: fc_inputs.append(args[0].detach())
        )
    try:
        with torch.no_grad():
            output = model(augmented_images)
    finally:
        if handle is not None:
            handle.remove()

    if isinstance(output, tuple):
        features = output[0] if output[0].dim() == 2 else output[1]
    elif fc_inputs:
        features = fc_inputs[-1].flatten(1)
    else:
        features = output
    features = F.normalize(features, p=2, dim=1)

    sim_matrix = torch.matmul(features, features.T).cpu().numpy()

    centroid = F.normalize(features.mean(dim=0, keepdim=True), p=2, dim=1)
    dist_to_centroid = torch.norm(features - centroid, dim=1).cpu().numpy()

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    # Left: cosine similarity between augmentations (close to 1 = strong NC5)
    im = ax1.imshow(sim_matrix, cmap='viridis', vmin=0.8, vmax=1.0)
    ax1.set_title("Similarité Cosinus entre Augmentations\n(Proche de 1 = NC5 fort)")
    plt.colorbar(im, ax=ax1)

    # Right: distribution of distances to the augmentation centroid
    ax2.hist(dist_to_centroid, bins=20, color='salmon', edgecolor='black', alpha=0.7)
    ax2.axvline(dist_to_centroid.mean(), color='red', linestyle='--',
        label=f'Dispersion moyenne: {dist_to_centroid.mean():.4f}')
    ax2.set_title("Dispersion Intra-échantillon (NC5)")
    ax2.set_xlabel("Distance Euclidienne au centroïde de l'image")
    ax2.set_ylabel("Nombre d'augmentations")
    ax2.legend()

    plt.tight_layout()
    plt.show()

    return dist_to_centroid.mean()

# CIFAR-100 per-channel statistics (same values as the final Normalize step below).
CIFAR100_MEAN = (0.5071, 0.4867, 0.4408)
CIFAR100_STD = (0.2675, 0.2565, 0.2761)

# The seed images come straight from the data loaders. They are presumably already
# normalized with these statistics, since the pipeline re-normalizes with them before
# the images reach the model. ToPILImage scales float tensors by 255 and casts them to
# uint8, which corrupts negative or >1 values. So the normalization is undone and the
# result is clamped to [0, 1] first.
# NOTE(review): confirm that src/data.py normalizes with exactly these mean/std values.
test_invariance_transform = transforms.Compose([
    transforms.Normalize(
        mean=[-m / s for m, s in zip(CIFAR100_MEAN, CIFAR100_STD)],
        std=[1 / s for s in CIFAR100_STD],
    ),
    transforms.Lambda(lambda t: t.clamp(0.0, 1.0)),
    transforms.ToPILImage(),
    transforms.RandomResizedCrop(32, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(30),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR100_MEAN, CIFAR100_STD),
])

sample_image = next(iter(train_loader))[0].cpu()[0]

avg_dispersion = analyze_nc5_invariance(
    model=model,
    image=sample_image,
    transform_pipeline=test_invariance_transform,
    n_augmentations=100
)

print(f"Dispersion moyenne NC5 : {avg_dispersion:.4f}")

In [None]:
def plot_nc5_stars(model, dataloader, transform_pipeline, n_images=6, n_augmentations=40):
    """
    Visualizes NC5 (Feature Invariance) using a 2D projection of augmented samples.

    Takes the first `n_images` images of one batch from `dataloader` and embeds
    `n_augmentations` stochastic augmentations of each. The embeddings are
    L2-normalized and projected jointly to 2-D with PCA. Each image becomes a
    "star": its augmented points joined by rays to their mean (the large
    marker). Shorter rays indicate higher invariance to the transformations.

    Formula (Geometric Interpretation):
        Augmented: z_j = normalize(Encoder(T_j(x))),  p_j = PCA_2D(z_j)
        Star center: c = mean_j(p_j)
        Star Radius: d_j = || p_j - c ||_2

    Embeddings:
        - If the model returns a tuple, its first element is used when it is
          2-D, otherwise its second element.
        - If it returns a single tensor, the input of `model.base_model.fc`
          (penultimate features) is captured when that layer exists.
        - Otherwise the raw output is used.

    Args:
        model (torch.nn.Module): The trained neural network.
        dataloader (torch.utils.data.DataLoader): Source for the seed images (one batch is used).
        transform_pipeline (callable): Stochastic augmentation pipeline (e.g., RandAugment).
        n_images (int): Number of distinct stars (source images) to plot; capped
            at the batch size.
        n_augmentations (int): Number of augmented samples per source image.

    Raises:
        ValueError: if fewer than one image or one augmentation is requested.
    """

    model.eval()
    device = next(model.parameters()).device

    images, labels = next(iter(dataloader))
    # Only one batch is drawn, so it bounds the number of stars.
    n_images = min(n_images, len(images))
    if n_images < 1 or n_augmentations < 1:
        raise ValueError("n_images and n_augmentations must both be at least 1.")

    # Capture the classifier input (penultimate features) for models that only return logits.
    fc = getattr(getattr(model, "base_model", None), "fc", None)
    fc_inputs = []
    handle = None
    if fc is not None:
        handle = fc.register_forward_hook(
            lambda module, args, output: fc_inputs.append(args[0].detach())
        )

    all_features = []
    class_names = []
    try:
        with torch.no_grad():
            for i in range(n_images):
                aug_imgs = torch.stack(
                    [transform_pipeline(images[i]) for _ in range(n_augmentations)]
                ).to(device)

                fc_inputs.clear()
                output = model(aug_imgs)
                if isinstance(output, tuple):
                    f = output[0] if output[0].dim() == 2 else output[1]
                elif fc_inputs:
                    f = fc_inputs[-1].flatten(1)
                else:
                    f = output

                all_features.append(F.normalize(f, p=2, dim=1).cpu())
                class_names.append(labels[i].item())
    finally:
        if handle is not None:
            handle.remove()

    features_concat = torch.cat(all_features).numpy()
    pca = PCA(n_components=2)
    features_2d = pca.fit_transform(features_concat)

    plt.figure(figsize=(12, 8), facecolor='white')
    # pyplot.get_cmap replaces matplotlib.cm.get_cmap, which was removed in Matplotlib 3.9.
    colors = plt.get_cmap('Dark2', n_images)

    for i in range(n_images):
        # Augmentations of image i occupy a contiguous slice of `features_2d`.
        points = features_2d[i * n_augmentations:(i + 1) * n_augmentations]
        center = points.mean(axis=0)

        for p in points:
            plt.plot([center[0], p[0]], [center[1], p[1]],
                     color=colors(i), alpha=0.15, lw=1, zorder=1)

        plt.scatter(points[:, 0], points[:, 1],
                    color=colors(i), s=25, alpha=0.6, edgecolors='none', zorder=2)

        plt.scatter(center[0], center[1],
                    color=colors(i), s=150, marker='o', edgecolors='black',
                    linewidth=1.5, label=f"Image {i} (Class {class_names[i]})", zorder=3)

    plt.title("NC5 Visualization: Feature Invariance to Augmentations", fontsize=16, pad=20)
    plt.xlabel("PCA Principal Component 1", fontsize=12)
    plt.ylabel("PCA Principal Component 2", fontsize=12)

    plt.grid(True, linestyle='--', alpha=0.4)
    plt.gca().spines['top'].set_visible(False)
    plt.gca().spines['right'].set_visible(False)

    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', title="Samples", frameon=True)

    plt.tight_layout()
    plt.show()

plot_nc5_stars(model, train_loader, test_invariance_transform, n_images=5, n_augmentations=50)