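### Simplex ETF prototypes ($K \le d$)

The $K$ class prototypes are the columns of $M^\star = \sqrt{\tfrac{K}{K-1}}\; P \left(I_K - \tfrac{1}{K}\mathbf{1}_K \mathbf{1}_K^\top\right)$, where $P \in \mathbb{R}^{d \times K}$ has orthonormal columns. This is why $K \le d$ is required. Every prototype has unit norm, and every pair has cosine similarity $-\tfrac{1}{K-1}$.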

Reference implementation: https://github.com/NeuralCollapseApplications/ImbalancedLearning/blob/main/models/resnet.py#L326

$K$ = number of classes

$d$ = representation space dimension

In [2]:
import numpy as np
import torch
import random

# Seed every RNG this notebook touches so prototype generation is
# reproducible: torch (CPU), NumPy, Python's `random`, then the current and
# all CUDA devices.
seed = 42
for _seed_rng in (
    torch.manual_seed,
    np.random.seed,
    random.seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_rng(seed)
del _seed_rng

def generate_random_orthogonal_matrix(feat_dim, num_classes):
    """Return a random float32 ``(feat_dim, num_classes)`` matrix with orthonormal columns.

    A uniform random matrix drawn with ``np.random.random`` (so it follows the
    NumPy seed) is orthonormalised with a reduced QR decomposition, giving
    ``P.T @ P == I_num_classes`` up to float32 rounding.

    Args:
        feat_dim: Number of rows (the representation dimension ``d``).
        num_classes: Number of columns (the number of classes ``K``).

    Returns:
        ``torch.Tensor`` of shape ``(feat_dim, num_classes)``, dtype float32.

    Raises:
        ValueError: If ``num_classes > feat_dim``. A ``feat_dim``-dimensional
            space has at most ``feat_dim`` orthonormal vectors.
    """
    if num_classes > feat_dim:
        raise ValueError(
            "need num_classes <= feat_dim for orthonormal columns, got "
            f"num_classes={num_classes}, feat_dim={feat_dim}"
        )
    a = np.random.random(size=(feat_dim, num_classes))
    P, _ = np.linalg.qr(a)
    # Check orthonormality in float64, *before* the float32 cast. After the
    # cast, a 1e-7 absolute tolerance sits at float32 resolution and can fail
    # spuriously as the matrices grow.
    gram_err = np.max(np.abs(P.T @ P - np.eye(num_classes)))
    assert gram_err < 1e-07, gram_err
    return torch.tensor(P).float()

d = 16
K = 4

# Simplex ETF: M* = sqrt(K / (K - 1)) * P (I_K - (1/K) 1 1^T), where P is a
# (d, K) matrix with orthonormal columns. M* is therefore (d, K), with one
# class prototype per column.
P = generate_random_orthogonal_matrix(feat_dim=d, num_classes=K)
I = torch.eye(K)
one = torch.ones(K, K)
# Plain Python float scale factor, so no numpy scalar is mixed into torch math.
M_star = (K / (K - 1)) ** 0.5 * torch.matmul(P, I - ((1 / K) * one))
# Fall back to CPU so the cell also runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
M_star = M_star.to(device)

print(M_star)
print(M_star.shape)

def test_ETF_prototype(M_star):
    """Print each row's L2 norm and the angle between every ordered pair of rows.

    For a simplex ETF with K rows, every norm should be 1 and every pairwise
    angle should equal arccos(-1 / (K - 1)), about 1.9106 rad for K = 4.

    Args:
        M_star: 2-D float tensor whose rows are the vectors to inspect.

    Returns:
        A ``(K, K)`` tensor of pairwise angles in radians, on the same device
        as ``M_star``. Entry ``[i, j]`` is the printed angle between ``m_i``
        and ``m_j``. The diagonal (a row with itself) is not printed.
    """
    # Inputs may be nn.Parameters (e.g. learned prototypes). This is a pure
    # diagnostic, so don't build an autograd graph.
    with torch.no_grad():
        # Each row norm is computed once, not twice per (i, j) pair.
        norms = torch.linalg.vector_norm(M_star, ord=2, dim=1)
        # All pairwise cosine similarities from a single matrix product.
        cosine_sim = (M_star @ M_star.T) / torch.outer(norms, norms)
        # Rounding can push |cos| slightly above 1, where acos returns NaN.
        angles = torch.acos(torch.clamp(cosine_sim, -1.0, 1.0))

    # One device->host transfer per tensor instead of one sync per printed value.
    for i, norm in enumerate(norms.tolist()):
        print(f"m_{i} l2 norm: {norm}.")
    for i, row in enumerate(angles.tolist()):
        for j, angle in enumerate(row):
            if i != j:
                print(f"Pairwise angle between m_{i} and m_{j} is: {angle}.")
    return angles


# The name starts with ``test_``, but this is a notebook helper that takes an
# argument. Keep pytest from collecting it, because it would error looking for
# an ``M_star`` fixture.
test_ETF_prototype.__test__ = False

# M_star is (d, K) with one prototype per column (printed shape: (16, 4)).
# Transpose it so the helper iterates over the K prototypes as rows.
test_ETF_prototype(M_star=M_star.T)

tensor([[-0.2537,  0.4420,  0.1176, -0.3060],
        [-0.2382, -0.1224, -0.2090,  0.5695],
        [-0.2723,  0.2513, -0.3602,  0.3812],
        [-0.1654,  0.0027,  0.0830,  0.0798],
        [-0.1469,  0.2387,  0.1004, -0.1922],
        [-0.2090, -0.0974,  0.1086,  0.1978],
        [-0.1633,  0.4138, -0.1584, -0.0920],
        [-0.2177, -0.1740,  0.3836,  0.0081],
        [-0.3088,  0.3803,  0.1920, -0.2635],
        [-0.2600, -0.1840,  0.3300,  0.1140],
        [-0.2156,  0.1455, -0.3207,  0.3908],
        [-0.1686,  0.3161, -0.0727, -0.0748],
        [-0.4177, -0.2744,  0.4265,  0.2656],
        [-0.3872,  0.2776,  0.0057,  0.1040],
        [-0.0870,  0.0567, -0.0946,  0.1249],
        [-0.2610, -0.0577,  0.4087, -0.0899]], device='cuda:0')
torch.Size([16, 4])
m_0 l2 norm: 1.0.
m_1 l2 norm: 0.9999999403953552.
m_2 l2 norm: 1.0.
m_3 l2 norm: 1.0.
Pairwise angle between m_0 and m_1 is: 1.9106332063674927.
Pairwise angle between m_0 and m_2 is: 1.9106332063674927.
Pairwise angle betwee

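### Prototype generation (Hyperspherical Prototype Networks, HPN, NeurIPS 2019)

The $K$ prototypes are optimised on the unit hypersphere in $\mathbb{R}^d$. The objective is to minimise each prototype's largest cosine similarity to any other prototype. Unlike the ETF construction above, this also works when $K > d$.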

Reference implementation: https://github.com/psmmettes/hpn/blob/master/prototypes.py

$K$ = number of classes

$d$ = representation space dimension

In [3]:
from torch import nn
import torch.nn.functional as F
import torch.optim as optim

# Representation dimension d and number of classes K. Here K > d, so the
# simplex-ETF construction above does not apply.
d, K = 16, 100

def prototype_loss(prototypes):
    """Return the HPN separation loss for a ``(K, d)`` matrix of prototype rows.

    Every pairwise dot product is shifted by +1. For unit-norm rows (as in this
    notebook) that is ``cos + 1``, which lies in ``[0, 2]``. The diagonal entry
    ``x`` is then replaced by ``-x``, so a row's maximum always comes from
    another prototype.

    Returns:
        A tuple ``(loss, sep)``. ``loss`` is the mean over rows of each row's
        largest shifted similarity to another prototype. ``sep`` is the
        largest shifted similarity over all pairs.
    """
    shifted = prototypes @ prototypes.t() + 1
    # x - 2x = -x on the diagonal; off-diagonal entries are untouched.
    self_terms = torch.diag(torch.diag(shifted))
    shifted = shifted - 2.0 * self_terms
    nearest = shifted.max(dim=1).values
    return nearest.mean(), shifted.max()

# Optimise K unit-norm prototypes in R^d with projected gradient descent.
# Each step is a plain SGD update on prototype_loss, followed by
# renormalisation of every row back onto the unit hypersphere.
#
# Momentum is explicitly 0. Re-creating the nn.Parameter and the SGD optimizer
# after every step (as this loop was originally written) starts a fresh
# momentum buffer each time, equal to the current gradient. A configured
# momentum of 0.9 therefore never had any effect, and the update was already
# plain SGD. This form performs the same update with a single parameter and a
# single optimizer, instead of allocating new ones on every iteration.
lr = 0.1
epochs = 1_000_000
prototypes = torch.randn(K, d)
prototypes = nn.Parameter(F.normalize(prototypes, p=2, dim=1))
optimizer = optim.SGD([prototypes], lr=lr, momentum=0.0)
for i in range(epochs):
    optimizer.zero_grad()
    loss, sep = prototype_loss(prototypes)
    loss.backward()
    optimizer.step()
    # Project back onto the unit sphere, in place and outside autograd.
    with torch.no_grad():
        prototypes.copy_(F.normalize(prototypes, p=2, dim=1))

In [4]:
# Learned prototype matrix: one row per class, i.e. (K, d) = (100, 16) here.
prototypes.shape

torch.Size([100, 16])

In [5]:
# Rows of `prototypes` are the K learned unit-norm prototypes, so no transpose
# is needed. A simplex ETF of K vectors needs d >= K - 1. With K = 100 and
# d = 16 that condition fails, so the pairwise angles are not expected to be
# all equal.
test_ETF_prototype(M_star=prototypes)

m_0 l2 norm: 1.0.
m_1 l2 norm: 1.0.
m_2 l2 norm: 0.9999999403953552.
m_3 l2 norm: 1.0.
m_4 l2 norm: 1.0.
m_5 l2 norm: 1.0.
m_6 l2 norm: 0.9999999403953552.
m_7 l2 norm: 1.0.
m_8 l2 norm: 0.9999999403953552.
m_9 l2 norm: 0.9999999403953552.
m_10 l2 norm: 1.0.
m_11 l2 norm: 0.9999999403953552.
m_12 l2 norm: 1.0.
m_13 l2 norm: 1.0.
m_14 l2 norm: 0.9999999403953552.
m_15 l2 norm: 0.9999999403953552.
m_16 l2 norm: 0.9999999403953552.
m_17 l2 norm: 0.9999999403953552.
m_18 l2 norm: 0.9999999403953552.
m_19 l2 norm: 0.9999999403953552.
m_20 l2 norm: 1.0.
m_21 l2 norm: 1.0.
m_22 l2 norm: 0.9999999403953552.
m_23 l2 norm: 1.0000001192092896.
m_24 l2 norm: 1.0.
m_25 l2 norm: 0.9999999403953552.
m_26 l2 norm: 1.0.
m_27 l2 norm: 1.0.
m_28 l2 norm: 1.0.
m_29 l2 norm: 0.9999999403953552.
m_30 l2 norm: 0.9999999403953552.
m_31 l2 norm: 1.0.
m_32 l2 norm: 0.9999999403953552.
m_33 l2 norm: 1.0.
m_34 l2 norm: 0.9999999403953552.
m_35 l2 norm: 1.0.
m_36 l2 norm: 1.0.
m_37 l2 norm: 0.9999999403953552.
m_3