In [1]:
import torch
import whitening as wh

In [2]:
n = 512
v = 100
k = 2
d = 5

In [3]:
#generate a random symmetric positive definite matrix
A = torch.randn(v, k, d, d)
Corr = torch.einsum('vkde,vkfe->vkdf', A, A) + 1

#generate random points correlated with Corr
X = torch.randn(n, v, k, d)
X = X - X.mean(0)

In [4]:
U, S, V = torch.svd(Corr)
S = torch.diag_embed(torch.sqrt(S))
W = torch.einsum('vkde,vkef->vkdf', U, S)
W = torch.einsum('vkdf,vkgf->vkdg', W, U)
Y = torch.einsum('nvkd,vkdg->nvkg', X, W)

Y += X.mean(0)

In [13]:
Z = wh.whitening(Y[:, 0, 0])

torch.Size([512, 5]) torch.Size([5, 5])
tensor([ 5.1223e-09, -5.1223e-09,  1.8626e-09,  4.4238e-09, -1.8626e-09])
tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


In [6]:
for method in ['zca', 'zca_cor', 'pca', 'pca_cor', 'cholesky']:
    whitening = wh.WhitenNorm(method=method)
    Z = whitening(Y)
    cov = torch.einsum("nvkd, nvkD -> vkdD", Z, Z) / (n - 1)
    cov = cov.view(-1, d, d)
    print(method, 'max abs difference from I:', torch.max(torch.abs(cov - torch.eye(5))))

zca max abs difference from I: tensor(0.0002)
zca_cor max abs difference from I: tensor(9.7871e-05)
pca max abs difference from I: tensor(0.0002)
pca_cor max abs difference from I: tensor(0.0002)
cholesky max abs difference from I: tensor(0.0003)


  warn("Cholesky whitening does not output a symmetric matrix. If you want to train a probe\


In [7]:
random_tensor = torch.randn(1000,100,2,64)

In [8]:
v1 = random_tensor - random_tensor.mean(0)

v2 = random_tensor.view(1000, 100*2, 64)
v2 = v2 - v2.mean(0)

print(torch.max(torch.abs(v1 - v2.view(1000, 100, 2, 64))))

tensor(0.)


In [9]:
covs = torch.einsum('nvd, nvD -> vdD', v2, v2)

In [10]:
U, S, V = torch.svd(covs)

In [11]:
print(U.shape, S.shape, V.shape)

Ss = 1 / torch.sqrt(S)
# Ss is 200 x 64, and we need to transform it to a batch of diagonal matrices of size 200 x 64 x 64
print(Ss.shape)
print(Ss[0])

Ss = torch.diag_embed(Ss)
print(Ss.shape)
print(Ss[0])


torch.Size([200, 64, 64]) torch.Size([200, 64]) torch.Size([200, 64, 64])
torch.Size([200, 64])
tensor([0.0254, 0.0255, 0.0258, 0.0264, 0.0265, 0.0267, 0.0269, 0.0274, 0.0276,
        0.0278, 0.0280, 0.0281, 0.0283, 0.0285, 0.0288, 0.0290, 0.0291, 0.0292,
        0.0296, 0.0296, 0.0298, 0.0299, 0.0301, 0.0304, 0.0306, 0.0308, 0.0310,
        0.0312, 0.0314, 0.0314, 0.0317, 0.0317, 0.0320, 0.0323, 0.0325, 0.0328,
        0.0329, 0.0330, 0.0334, 0.0335, 0.0337, 0.0338, 0.0341, 0.0344, 0.0348,
        0.0349, 0.0352, 0.0354, 0.0357, 0.0362, 0.0366, 0.0367, 0.0370, 0.0372,
        0.0377, 0.0381, 0.0382, 0.0386, 0.0391, 0.0395, 0.0400, 0.0405, 0.0414,
        0.0418])
torch.Size([200, 64, 64])
tensor([[0.0254, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0255, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0258,  ..., 0.0000, 0.0000, 0.0000],
        ...,
        [0.0000, 0.0000, 0.0000,  ..., 0.0405, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000,

In [12]:
x_normalized = torch.randn(1000, 200, 64)

torch.std(x_normalized, dim=0).shape

torch.Size([200, 64])