[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/WilliamLockeIV/Spline-Theory-and-Neural-Collapse/blob/main/Theory/Background_on_Neural_Collapse.ipynb)

In [1]:
import torch

# Neural Collapse

Neural collapse is an inductive bias observed across multiple model types and datasets, first described in Papyan et al.'s "Prevalence of neural collapse during the terminal phase of deep learning training" (2020) and expanded upon in subsequent papers. This notebook describes and implements the four metrics used to measure neural collapse, summarized by Hong & Ling (2024) as:


> *   NC1 [Variability Collapse]: the feature of samples from the same class converge to a unique mean feature vector;
> *   NC2 [Convergence to simplex ETF]: these feature vectors (after centering by their global mean) form a simplex equiangular tight frame (ETF), i.e., they share the same pairwise angles and length and have max pairwise distance;
> *   NC3 [Self-duality]: the weight of the linear classifier converges to the corresponding feature mean (up to scalar product);
> *   NC4 [Nearest Class Center]: the trained DNN classifies the sample by finding the closest mean feature vectors to the sample feature.

(Hong & Ling, 2024)

## NC1: Variability Collapse

Papyan et al. (2020) define within-class covariance as

$\Sigma_{W} = \mathrm{Ave}_{i,c}\{(\vec{h}_{i,c}-\vec{\mu}_{c})(\vec{h}_{i,c}-\vec{\mu}_{c})^{T}\}$

and between-class covariance as

$\Sigma_{B} = \mathrm{Ave}_{c}\{(\vec{\mu}_{c}-\vec{\mu}_{G})(\vec{\mu}_{c}-\vec{\mu}_{G})^{T}\}$

where $\vec{h}_{i,c}$ is the feature representation of the $i^{th}$ sample of class $c$, $\vec{\mu}_{c}$ is the mean feature representation of class $c$, and $\vec{\mu}_{G}$ is the mean feature representation of the entire dataset.

They then formalize NC1 as within-class covariance going to zero during the terminal phase of training (TPT), i.e.

$\Sigma_{W} \to 0$

They also point out that the sum of within-class covariance and between-class covariance is equal to the total covariance of the dataset:

$\Sigma_{W} + \Sigma_{B} = \Sigma_{T}$

We use this last observation to check our calculation of within-class covariance and between-class covariance for both balanced and imbalanced classes.

In [8]:
# Covariance Matrix with balanced classes

# C is an m x n data matrix, where m is the number of data points and n is the number of variables / features.
# This is transposed from the notation used in Papyan et al. (2020), but it matches my code implementation elsewhere.
C = torch.rand((6,10))

# Split the dataset C into equal-sized classes A and B
A = C[:3]
B = C[3:]

# Calculate the global mean and class means
C_mean = C.mean(dim=0, keepdim=True)
A_mean = A.mean(dim=0, keepdim=True)
B_mean = B.mean(dim=0, keepdim=True)
class_means = torch.cat([A_mean, B_mean])

# Calculate within-class covariance, between-class covariance, and total covariance
within_class_cov = torch.stack([A.T.cov(correction=0), B.T.cov(correction=0)]).mean(dim=0)
between_class_cov = class_means.T.cov(correction=0)
total_cov = C.T.cov(correction=0)

# Show that the sum of within-class covariance + between-class covariance = total covariance
print('Within-Class + Between-Class = Total Covariance')
print(torch.allclose(within_class_cov + between_class_cov, total_cov))

Within-Class + Between-Class = Total Covariance
True


In [9]:
# Covariance Matrix with imbalanced classes

# C = m x n data matrix
C = torch.rand((6,10))

# Split the dataset C into unequal classes A and B
A = C[:2]
B = C[2:]

# Balance the within-class covariances by weighting each class' covariance by the number of samples in that class
class_weights = torch.tensor([A.shape[0], B.shape[0]])
class_weights = class_weights / class_weights.sum()

# Balance the between-class covariance by repeating each class' mean by the number of samples in that class
C_mean = C.mean(dim=0, keepdim=True)
A_mean = A.mean(dim=0, keepdim=True).repeat((A.shape[0],1))
B_mean = B.mean(dim=0, keepdim=True).repeat((B.shape[0],1))
class_means = torch.cat([A_mean, B_mean])

# Calculate within-class covariance, between-class covariance, and total covariance
within_class_cov = torch.stack([A.T.cov(correction=0), B.T.cov(correction=0)])
within_class_cov = (within_class_cov * class_weights[:,None,None]).sum(dim=0)
between_class_cov = class_means.T.cov(correction=0)
total_cov = C.T.cov(correction=0)

# Show that the sum of within-class covariance + between-class covariance = total covariance
print('Within-Class + Between-Class = Total Covariance')
print(torch.allclose(within_class_cov + between_class_cov, total_cov))

Within-Class + Between-Class = Total Covariance
True
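The two cells above can be folded into one helper that handles any number of classes and any class sizes. The following is a minimal sketch, not code from Papyan et al. (2020): the function name `covariance_decomposition` and the integer `labels` tensor are assumptions introduced for illustration, and each class is weighted by its share of the samples, as in the imbalanced cell.

```python
import torch

def covariance_decomposition(features, labels):
    """Return (within, between, total) covariance matrices.

    features: (m, n) tensor with one sample per row (same layout as C above).
    labels:   (m,) integer tensor of class ids (hypothetical input, for illustration).
    """
    m, n = features.shape
    global_mean = features.mean(dim=0, keepdim=True)
    within = torch.zeros((n, n), dtype=features.dtype)
    between = torch.zeros((n, n), dtype=features.dtype)
    for c in labels.unique():
        class_feats = features[labels == c]
        weight = class_feats.shape[0] / m  # share of samples in this class
        class_mean = class_feats.mean(dim=0, keepdim=True)
        centered = class_feats - class_mean
        # Biased (correction=0) covariance of this class, weighted by class share
        within += weight * (centered.T @ centered) / class_feats.shape[0]
        # Outer product of the centered class mean, weighted by class share
        shift = class_mean - global_mean
        between += weight * (shift.T @ shift)
    centered_all = features - global_mean
    total = centered_all.T @ centered_all / m
    return within, between, total

torch.manual_seed(0)
features = torch.rand((12, 5), dtype=torch.float64)
labels = torch.tensor([0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2])  # three unequal classes
within, between, total = covariance_decomposition(features, labels)
print(torch.allclose(within + between, total))
```

Weighting by class share is what makes the identity hold for imbalanced classes; with equal class sizes it reduces to the plain average used in the balanced cell.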


In their paper, rather than report the entire within-class covariance matrix, Papyan et al. (2020) report the trace of the within-class covariance matrix multiplied by the Moore-Penrose pseudoinverse of the between-class covariance, divided by the number of classes, i.e.

$\mathrm{Tr}(\Sigma_{W}\Sigma_{B}^{\dagger})/C$

They define this as the inverse signal-to-noise ratio for classification problems, and explain that it scales and rotates the within-class covariance (noise) by the pseudoinverse of the between-class covariance (signal).

However, I have a hard time relating this metric to either of the simpler quantities $\mathrm{Tr}(\Sigma_{W})$ or $\mathrm{Tr}(\Sigma_{B})$. It is sometimes close to zero even while $\mathrm{Tr}(\Sigma_{W}) > \mathrm{Tr}(\Sigma_{B})$, so I am unsure whether it can be used to show that within-class covariance goes to zero.

In [17]:
# Inverse Signal-to-Noise Ratio

# C = m x n data matrix
C = torch.rand((6,10))

# Split the dataset C into unequal classes A and B
A = C[:2]
B = C[2:]

# Balance the within-class covariances by weighting each class' covariance by the number of samples in that class
class_weights = torch.tensor([A.shape[0], B.shape[0]])
class_weights = class_weights / class_weights.sum()

# Balance the between-class covariance by repeating each class' mean by the number of samples in that class
C_mean = C.mean(dim=0, keepdim=True)
A_mean = A.mean(dim=0, keepdim=True).repeat((A.shape[0],1))
B_mean = B.mean(dim=0, keepdim=True).repeat((B.shape[0],1))
class_means = torch.cat([A_mean, B_mean])

# Calculate within-class covariance, between-class covariance, and noise-to-signal ratio
within_class_cov = torch.stack([A.T.cov(correction=0), B.T.cov(correction=0)])
within_class_cov = (within_class_cov * class_weights[:,None,None]).sum(dim=0)
between_class_cov = class_means.T.cov(correction=0)
noise_to_signal = torch.matmul(within_class_cov, torch.linalg.pinv(between_class_cov))

# Print the trace of each covariance matrix divided by the number of classes
within_class_trace = torch.trace(within_class_cov) / 2
between_class_trace = torch.trace(between_class_cov) / 2
noise_to_signal_trace = torch.trace(noise_to_signal) / 2

print(f'Within-class:    {within_class_trace:.2f}')
print(f'Between-class:   {between_class_trace:.2f}')
print(f'Noise-to-Signal: {noise_to_signal_trace:.2f}')

Within-class:    0.25
Between-class:   0.08
Noise-to-Signal: 0.06
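One way to relate the metric to the plain traces, at least in the two-class case used here: with two classes, the between-class covariance has rank one, so its pseudoinverse only acts along the line joining the two class means. The sketch below checks numerically that $\mathrm{Tr}(\Sigma_{W}\Sigma_{B}^{\dagger})$ then equals the within-class variance along that direction divided by the between-class variance along it. The variable names (`d`, `u`, `p_a`, `p_b`) and the explicit `rtol` passed to `torch.linalg.pinv` are choices made for this illustration, not part of the paper.

```python
import torch

torch.manual_seed(0)
C = torch.rand((6, 10), dtype=torch.float64)
A, B = C[:2], C[2:]
p_a, p_b = A.shape[0] / C.shape[0], B.shape[0] / C.shape[0]

# Sample-weighted within-class covariance, as in the imbalanced cell above
within_class_cov = p_a * A.T.cov(correction=0) + p_b * B.T.cov(correction=0)

# With two classes, the weighted between-class covariance is a rank-one matrix
d = A.mean(dim=0) - B.mean(dim=0)
between_class_cov = p_a * p_b * torch.outer(d, d)

# The metric before dividing by the number of classes.
# rtol is set explicitly (an assumption of this sketch) so that the numerically
# zero singular values of the rank-one matrix are discarded.
metric = torch.trace(within_class_cov @ torch.linalg.pinv(between_class_cov, rtol=1e-10))

# Ratio of within-class to between-class variance along the unit direction u
u = d / torch.linalg.vector_norm(d)
ratio = (u @ within_class_cov @ u) / (u @ between_class_cov @ u)

print(torch.allclose(metric, ratio))
```

Read this way, the metric ignores any within-class variance orthogonal to the class-mean direction, which may be why it can be small even when the trace of the within-class covariance exceeds that of the between-class covariance. More generally, the between-class covariance has rank at most $C-1$.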


## NC2: Convergence to Simplex ETF

Papyan et al. (2020) define the standard simplex ETF as the collection of points in $\mathbb{R}^{C}$ specified by the columns of

$M^{\star} = \sqrt{\frac{C}{C-1}}\left(I-\frac{1}{C}\mathbf{1}\right)$

where $C$ is the number of classes, $I \in \mathbb{R}^{C \times C}$ is the identity matrix, and $\mathbf{1} \in \mathbb{R}^{C \times C}$ is a matrix of all ones. They then loosen this condition to allow for other poses and rescaling, so that a general simplex ETF is $M = \alpha UM^{\star}$, where $\alpha \in \mathbb{R}_{+}$ is a scalar and $U \in \mathbb{R}^{P \times C}$ is a partial orthogonal matrix s.t. $U^{T}U=I$.
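To make the definition concrete, the sketch below builds $M^{\star}$ for an arbitrary number of classes and checks that its columns have unit norm and pairwise cosine similarity $-1/(C-1)$, and that these angles survive rescaling by $\alpha$ and a change of pose by a partial orthogonal $U$. The values of `num_classes`, `P`, and `alpha` are illustrative choices, not taken from the paper.

```python
import math
import torch

num_classes = 4  # illustrative choice
identity = torch.eye(num_classes, dtype=torch.float64)
ones = torch.ones((num_classes, num_classes), dtype=torch.float64)
M_star = math.sqrt(num_classes / (num_classes - 1)) * (identity - ones / num_classes)

# Columns of M_star have unit norm, so the Gram matrix holds cosine similarities
col_norms = torch.linalg.vector_norm(M_star, dim=0)
gram = M_star.T @ M_star
off_diagonal = gram[~torch.eye(num_classes, dtype=torch.bool)]
print(col_norms)
print(off_diagonal)  # every entry should be -1/(C-1)

# A rescaled simplex ETF in another pose: M = alpha * U @ M_star with U^T U = I
torch.manual_seed(0)
P, alpha = 10, 2.5
U, _ = torch.linalg.qr(torch.rand((P, num_classes), dtype=torch.float64), mode='reduced')
M = alpha * U @ M_star
posed_gram = M.T @ M
print(torch.allclose(posed_gram, alpha**2 * gram))
```

Because $U^{T}U=I$, the Gram matrix of $M$ is just $\alpha^{2}$ times that of $M^{\star}$: the norms scale by $\alpha$ and the angles are unchanged.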

NC2 states that, in the terminal phase of training (TPT), the centered class means form a simplex ETF. This is measured by three metrics: 1) convergence of centered class means to the same norm; 2) convergence of centered class means to the same cosine similarity with each other; and 3) convergence of all centered class means to a cosine similarity of -1/(C-1), where C is the number of classes.

In [18]:
# We show elsewhere that orthogonal class means with the same norm form an ETF when centered.
# We compare three random "class means" and three orthogonal "class means".
rand_means = torch.rand((3,10))
ortho_means, _ = torch.linalg.qr(rand_means.T, mode='reduced')
ortho_means = ortho_means.T

# Center means
rand_means = rand_means - rand_means.mean(dim=0, keepdim=True)
ortho_means = ortho_means - ortho_means.mean(dim=0, keepdim=True)

# Calculate norms
rand_norms = torch.linalg.vector_norm(rand_means, dim=1)
rand_print = [f'{norm:.2f}' for norm in rand_norms]
ortho_norms = torch.linalg.vector_norm(ortho_means, dim=1)
ortho_print = [f'{norm:.2f}' for norm in ortho_norms]
print('Norms of centered random class means:', rand_print)
print('Norms of centered orthonormal means: ', ortho_print)
print()

# Calculate angles
rand_angles = torch.nn.functional.normalize(rand_means, dim=1)  # unit-normalize each row
rand_angles = torch.matmul(rand_angles, rand_angles.T)
print('Cosine similarity matrix of centered random class means:')
print(rand_angles)
print()
ortho_angles = torch.nn.functional.normalize(ortho_means, dim=1)  # unit-normalize each row
ortho_angles = torch.matmul(ortho_angles, ortho_angles.T)
print('Cosine similarity matrix of centered orthogonal means:')
print(ortho_angles)
print()

Norms of centered random class means: ['0.59', '0.66', '0.89']
Norms of centered orthonormal means:  ['0.82', '0.82', '0.82']

Cosine similarity matrix of centered random class means:
tensor([[ 1.0000,  0.0254, -0.6793],
        [ 0.0254,  1.0000, -0.7509],
        [-0.6793, -0.7509,  1.0000]])

Cosine similarity matrix of centered orthogonal means:
tensor([[ 1.0000, -0.5000, -0.5000],
        [-0.5000,  1.0000, -0.5000],
        [-0.5000, -0.5000,  1.0000]])



Rather than report all class norms and angles, Papyan et al. (2020) report three summary statistics: the standard deviation of the centered class-mean norms divided by their average, the standard deviation of the pairwise cosine similarities, and the average absolute difference between each pairwise cosine similarity and -1/(C-1), the cosine of the maximal angle.

In [19]:
# Calculate standard deviation / mean of class norms
rand_norms_metric = (torch.std(rand_norms) / torch.mean(rand_norms)).item()
ortho_norms_metric = (torch.std(ortho_norms) / torch.mean(ortho_norms)).item()
print(f'STD / Avg of random norms: {rand_norms_metric:.2f}')
print(f'STD / Avg of orthogonal norms: {ortho_norms_metric:.2f}')
print()

# Calculate standard deviation of interclass angles
rows, cols = torch.triu_indices(rand_angles.shape[0], rand_angles.shape[1], offset=1)
rand_angles_metric = torch.std(rand_angles[rows, cols]).item()
ortho_angles_metric = torch.std(ortho_angles[rows, cols]).item()
print(f'STD of random angles: {rand_angles_metric:.2f}')
print(f'STD of orthogonal angles: {ortho_angles_metric:.2f}')
print()

# Calculate average absolute difference between the interclass cosine similarities
# and -1/(C-1), the cosine of the max angle (C = 3 classes here); reuses rows, cols from above
rand_shift_angles = torch.mean(torch.abs(rand_angles[rows,cols]+1/(3-1))).item()
ortho_shift_angles = torch.mean(torch.abs(ortho_angles[rows,cols]+1/(3-1))).item()
print(f'Avg difference of random angles and max angles: {rand_shift_angles:.2f}')
print(f'Avg difference of orthogonal angles and max angles: {ortho_shift_angles:.2f}')

STD / Avg of random norms: 0.22
STD / Avg of orthogonal norms: 0.00

STD of random angles: 0.43
STD of orthogonal angles: 0.00

Avg difference of random angles and max angles: 0.32
Avg difference of orthogonal angles and max angles: 0.00
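The three statistics above can be collected into one helper that infers the number of classes from its input. This is a minimal sketch: the function name `nc2_metrics` and the dictionary keys are made up for illustration, and centering by the mean of the class means matches the global mean only when classes are balanced, the same assumption the cells above make.

```python
import torch

def nc2_metrics(class_means):
    """Summarize NC2 for a (C, n) tensor of uncentered class means, one per row."""
    num_classes = class_means.shape[0]
    # Assumes balanced classes, so the mean of class means is the global mean
    centered = class_means - class_means.mean(dim=0, keepdim=True)
    norms = torch.linalg.vector_norm(centered, dim=1)
    unit = torch.nn.functional.normalize(centered, dim=1)
    cosines = unit @ unit.T
    # Keep each pair of distinct classes once (strict upper triangle)
    rows, cols = torch.triu_indices(num_classes, num_classes, offset=1)
    pair_cosines = cosines[rows, cols]
    return {
        'equinorm': (norms.std() / norms.mean()).item(),
        'equiangular': pair_cosines.std().item(),
        'max_angle': (pair_cosines + 1 / (num_classes - 1)).abs().mean().item(),
    }

# Five orthonormal "class means" in eight dimensions: all three values should be near zero
ortho_metrics = nc2_metrics(torch.eye(5, 8, dtype=torch.float64))

# Five random "class means" for comparison
torch.manual_seed(0)
rand_metrics = nc2_metrics(torch.rand((5, 8), dtype=torch.float64))

print(ortho_metrics)
print(rand_metrics)
```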


## NC3: Self-Duality

Self-duality refers to the centered class means and the weights of the linear classifier approaching scaled versions of each other. To measure this, Papyan et al. (2020) normalize the class mean matrix and the linear weight matrix by their respective Frobenius norms, take the difference, and report the Frobenius norm of that difference, which should go to zero.

In [22]:
# Calculate the difference between class means and linear weights

# Create "class means" and corresponding "dual weights" (i.e. a scaled version of the means),
# as well as a separate "random weights" to compare. We make all of these orthogonal to fit NC2.
class_means = torch.rand((3,10))
class_means = torch.linalg.qr(class_means.T, mode='reduced')[0].T
class_means = class_means - class_means.mean(dim=0, keepdim=True)
dual_weights = class_means * torch.rand((1))
rand_weights = torch.rand((3,10))
rand_weights = torch.linalg.qr(rand_weights.T, mode='reduced')[0].T

# Normalize matrices by their Frobenius norms
class_means = class_means / torch.linalg.matrix_norm(class_means, ord='fro')
dual_weights = dual_weights / torch.linalg.matrix_norm(dual_weights, ord='fro')
rand_weights = rand_weights / torch.linalg.matrix_norm(rand_weights, ord='fro')

# Calculate difference between the class means and dual weights, and between the class means and random weights
self_dual = class_means - dual_weights
non_dual = class_means - rand_weights

# Calculate the Frobenius norm of both difference matrices
print(f'Self-Dual Matrix Norm: {torch.linalg.matrix_norm(self_dual).item():.2f}')
print(f'Non-Dual Matrix Norm:  {torch.linalg.matrix_norm(non_dual).item():.2f}')

Self-Dual Matrix Norm: 0.00
Non-Dual Matrix Norm:  1.27
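The same measurement can be wrapped in a small function. This is a sketch under the notebook's row-wise layout (one class per row for both matrices); the name `nc3_metric` is hypothetical. The usage lines also show a property worth keeping in mind: the metric is zero only for a positive rescaling of the centered class means, and reaches its maximum of 2 when the weights point in exactly the opposite direction.

```python
import torch

def nc3_metric(class_means, weights):
    """Frobenius distance between the normalized centered class means and weights.

    class_means, weights: (C, n) tensors with one class per row.
    """
    centered = class_means - class_means.mean(dim=0, keepdim=True)
    centered = centered / torch.linalg.matrix_norm(centered, ord='fro')
    weights = weights / torch.linalg.matrix_norm(weights, ord='fro')
    return torch.linalg.matrix_norm(centered - weights, ord='fro').item()

# Three orthonormal "class means" in ten dimensions
means = torch.eye(3, 10, dtype=torch.float64)
centered_means = means - means.mean(dim=0, keepdim=True)

aligned = nc3_metric(means, 0.3 * centered_means)    # positive rescaling
opposed = nc3_metric(means, -0.3 * centered_means)   # negative rescaling

print(f'Positive scaling: {aligned:.2f}')
print(f'Negative scaling: {opposed:.2f}')
```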


## NC4: Nearest Class Center

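In the terminal phase of training (TPT), the model's prediction for a given input becomes equivalent to selecting the class mean nearest to that input's final-layer activations. This follows from NC1-NC3: since $\|\vec{h}-\vec{\mu}_{c}\|^{2} = \|\vec{h}\|^{2} - 2\vec{\mu}_{c}^{T}\vec{h} + \|\vec{\mu}_{c}\|^{2}$ and the centered class means share the same norm (NC2), the nearest class mean is the one with the largest inner product $\vec{\mu}_{c}^{T}\vec{h}$, which is also the largest output of a classifier whose weights are a positive multiple of the class means (NC3). The cell below illustrates this.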

In [25]:
# Following NC1 and NC2, we assume that the features for each class have converged to
# their class means, and the centered class means form an ETF.
class_means = torch.rand((6,10))
class_means = torch.linalg.qr(class_means.T, mode='reduced')[0].T
class_means = class_means - class_means.mean(dim=0, keepdim=True)

# Following NC3, we assume the linear classifier is a scaled version of the class means
linear_weights = class_means * torch.rand((1))

# Given a new set of input features, we show that max linear output and the max class
# mean similarity (as measured by min L2 distance) agree
input_features = torch.rand(8,10)
linear_outputs = torch.argmax(torch.matmul(linear_weights, input_features.T), dim=0)
nearest_class = input_features[None,:,:] - class_means[:,None,:]
nearest_class = torch.argmin(torch.linalg.vector_norm(nearest_class, dim=2),dim=0)
print('Linear outputs:     ', linear_outputs)
print('Nearest class means:', nearest_class)

Linear outputs:      tensor([5, 2, 2, 1, 5, 5, 2, 2])
Nearest class means: tensor([5, 2, 2, 1, 5, 5, 2, 2])


Note that the bias vector of the linear classifier (left out, i.e. set to zero, in this example) could cause disagreement between the linear outputs and nearest class means.

Papyan et al. (2020) measure NC4 as the proportion of samples on which the linear classifier and the nearest class center (NCC) rule disagree.

In [26]:
samples = 100
input_features = torch.rand(samples,10)
linear_outputs = torch.argmax(torch.matmul(linear_weights, input_features.T), dim=0)
nearest_class = input_features[None,:,:] - class_means[:,None,:]
nearest_class = torch.argmin(torch.linalg.vector_norm(nearest_class, dim=2),dim=0)
disagreement = (linear_outputs != nearest_class).sum() / samples
print(f'Disagreement between linear output and NCC: {disagreement:.1%}')

Disagreement between linear output and NCC: 0.0%
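The earlier remark about the bias vector can be illustrated with a deliberately extreme case. In the sketch below, the class means are built from rows of the identity matrix rather than a QR decomposition, and the bias of 100 on class 0 is an arbitrary value chosen to dominate every logit; both are assumptions made for this illustration only.

```python
import torch

torch.manual_seed(0)
num_classes, dim, samples = 6, 10, 100

# Orthonormal class means (rows of the identity), centered so they form a simplex ETF
class_means = torch.eye(num_classes, dim, dtype=torch.float64)
class_means = class_means - class_means.mean(dim=0, keepdim=True)
linear_weights = 0.5 * class_means  # NC3: a positive multiple of the class means

input_features = torch.rand((samples, dim), dtype=torch.float64)

# Nearest class center by L2 distance
distances = torch.linalg.vector_norm(input_features[None, :, :] - class_means[:, None, :], dim=2)
nearest_class = distances.argmin(dim=0)

# Linear classifier without a bias, then with a bias that strongly favors class 0
logits = input_features @ linear_weights.T
bias = torch.zeros(num_classes, dtype=torch.float64)
bias[0] = 100.0  # arbitrary large value
without_bias = logits.argmax(dim=1)
with_bias = (logits + bias).argmax(dim=1)

disagreement_without = (without_bias != nearest_class).sum().item() / samples
disagreement_with = (with_bias != nearest_class).sum().item() / samples
print(f'Disagreement without bias: {disagreement_without:.1%}')
print(f'Disagreement with bias:    {disagreement_with:.1%}')
```

With the bias in place, every sample is assigned to class 0, so the disagreement is simply the share of samples whose nearest class center is some other class.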


# References

Papyan, V., Han, X. Y., & Donoho, D. L. Prevalence of Neural Collapse during the Terminal Phase of Deep Learning Training. PNAS, 2020.

Hong, W., & Ling, S. Neural Collapse for Unconstrained Feature Model under Cross-entropy Loss with Imbalanced Data. JMLR, 2024. 