In [1]:
# Automatically re-import edited modules (e.g. the local `oblique` module) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import autograd.numpy as np
import matplotlib.pyplot as plt

from autograd import grad
from pymanopt.manifolds import Oblique
from pymanopt.solvers import SteepestDescent
from oblique import *

# Oblique Manifold

If the cost function is strict-saddle, I expect Riemannian gradient descent (RGD) with added noise to escape its strict saddle points. I'll try RGD both with and without noise.

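Add noise (perturb $\mathbf x$ by $\mathbf n$, then retract back onto the manifold):
$$
\mathbf x^+ = \mathrm{R}_{\mathbf x}(\mathbf n)
$$
where $\mathbf n$ is a small random tangent vector at $\mathbf x$; for the oblique manifold, $\mathrm{R}_{\mathbf x}(\mathbf n)$ normalizes each column of $\mathbf x + \mathbf n$.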

It seems that noise is not necessary, consistent with the professor's lecture on hybrid GD.

In [3]:
# dir(Oblique)

In [4]:
# Sanity check that autograd differentiates through autograd.numpy:
# d/dx sin(x) = cos(x), so the derivative at pi must equal cos(pi) = -1.
# The assert makes a Run-All fail loudly instead of relying on eyeballing the output.
grad_sine = grad(np.sin)
assert np.isclose(grad_sine(np.pi), np.cos(np.pi))
grad_sine(np.pi)

-1.0

## Objective 1 of "Neural Collapse with CE Loss"

In [5]:
# Objective 1: minimize lr_cost over the oblique manifold of M x N matrices (unit-norm columns).
# The solver presumably starts from a random point (no seed is set), so re-running this
# cell yields a different, rotated optimum -- see the Observations section.
solver = SteepestDescent()
M, N = 10, 5
manifold = Oblique(M, N)
prob = Problem(cost=lr_cost, manifold=manifold)
Xopt = solver.solve(prob)

Compiling cost function...
Computing gradient of cost function...
 iter		   cost val	    grad. norm
    1	+4.4580955492631968e+00	1.13171913e+00
    2	+3.8611187230605779e+00	1.98613002e-01
    3	+3.8249338031403362e+00	1.22442919e-01
    4	+3.8198016331282023e+00	6.68374376e-02
    5	+3.8185372225081076e+00	3.51035773e-02
    6	+3.8181362928263725e+00	1.29747562e-02
    7	+3.8180947559961096e+00	7.63267539e-03
    8	+3.8180907335211680e+00	6.89478057e-03
    9	+3.8180782779780991e+00	3.77997803e-03
   10	+3.8180730693266369e+00	6.18661153e-04
   11	+3.8180729286013628e+00	8.38039636e-05
   12	+3.8180729271572322e+00	5.62941063e-05
   13	+3.8180729260259962e+00	1.21977251e-05
   14	+3.8180729259742918e+00	3.27795972e-06
   15	+3.8180729259717445e+00	1.98611098e-06
   16	+3.8180729259710415e+00	1.43816566e-06
   17	+3.8180729259707809e+00	1.17198730e-06
   18	+3.8180729259702670e+00	1.57193607e-08
Terminated - min grad norm reached after 18 iterations, 0.12 seconds.



In [6]:
# Check the Objective 1 optimum; check_etf presumably comes from the star-imported `oblique`
# module and asserts that the columns form a simplex equiangular tight frame -- confirm there.
check_etf(Xopt)

Tests passed!


## Objective 2

In [7]:
# Objective 2: minimize lr_cost_weights over the oblique manifold of M x N matrices
# (unit-norm columns); later cells check each half of the columns for the ETF property.
solver = SteepestDescent()
M, N = 10, 6
manifold = Oblique(M, N)
prob = Problem(cost=lr_cost_weights, manifold=manifold)
Xopt = solver.solve(prob)

Compiling cost function...
Computing gradient of cost function...
 iter		   cost val	    grad. norm
    1	+5.2138611617459629e+00	5.08365835e+00
    2	+2.4256250596624112e+00	1.66656465e+00
    3	+1.8610699045683621e+00	1.36857888e+00
    4	+1.6101910546932048e+00	9.48082754e-01
    5	+1.4444183686721681e+00	5.66931485e-01
    6	+1.3695332246320087e+00	3.20336002e-01
    7	+1.3419833418081359e+00	1.12967273e-01
    8	+1.3388212329980105e+00	1.08933846e-02
    9	+1.3387936881207549e+00	6.70808813e-03
   10	+1.3387822450288183e+00	1.79998003e-03
   11	+1.3387813901705097e+00	1.20373131e-03
   12	+1.3387810087494350e+00	3.08086096e-04
   13	+1.3387809760345331e+00	2.25929846e-04
   14	+1.3387809630212066e+00	7.34294084e-05
   15	+1.3387809628503593e+00	8.80671405e-05
   16	+1.3387809622271318e+00	7.25667254e-05
   17	+1.3387809609115355e+00	6.04277376e-06
   18	+1.3387809608966819e+00	4.64974711e-06
   19	+1.3387809608916179e+00	1.82480787e-06
   20	+1.3387809608908561e+00	9.53532138e-07


In [8]:
# ETF check on the first half of Xopt's columns. The split point comes from Xopt's own
# width rather than the global `N`, which later cells reassign (N = 3, N = 100).
check_etf(Xopt[:, : Xopt.shape[1] // 2], verbose=True)

[1. 1. 1.]
[[ 1.  -0.5 -0.5]
 [-0.5  1.  -0.5]
 [-0.5 -0.5  1. ]]
Tests passed!


In [9]:
# ETF check on the second half of Xopt's columns. The split point comes from Xopt's own
# width rather than the global `N`, which later cells reassign (N = 3, N = 100).
check_etf(Xopt[:, Xopt.shape[1] // 2 :], verbose=True)

[1. 1. 1.]
[[ 1.  -0.5 -0.5]
 [-0.5  1.  -0.5]
 [-0.5 -0.5  1. ]]
Tests passed!


## Weight Decay

In [10]:
# Demonstrates grouping labels class by class: each class index k is repeated once per
# sample -> [0, 0, 0, 1, 1, 1, ...]. Presumably this mirrors the label layout assumed by
# make_lr_weight_decay -- confirm in `oblique`.
# Demo-only names avoid overwriting the notebook-level `num_cls` / `N` globals.
demo_num_cls = 10
demo_per_cls = 3
labels_demo = np.kron(np.arange(demo_num_cls), np.ones((demo_per_cls,)))
labels_demo

array([0., 0., 0., 1., 1., 1., 2., 2., 2., 3., 3., 3., 4., 4., 4., 5., 5.,
       5., 6., 6., 6., 7., 7., 7., 8., 8., 8., 9., 9., 9.])

In [11]:
# Weight-decay formulation: optimize (W, H, b) jointly over a product manifold.
# The factor order below is the order in which Xopt is unpacked later (W, H, b = Xopt).
num_cls, sample_per_cls = 10, 10
N = num_cls * sample_per_cls  # total number of samples
M = 15                        # ambient (feature) dimension
solver = SteepestDescent()
manifold = Product((
    Oblique(M, num_cls),  # W: one unit-norm column per class
    Oblique(M, N),        # H: one unit-norm column per sample
    Euclidean(num_cls),   # b: unconstrained bias
))
# NOTE(review): make_lr_weight_decay() is called without num_cls / sample_per_cls / M;
# confirm its defaults match the sizes defined above.
prob = Problem(manifold=manifold, cost=make_lr_weight_decay())
Xopt = solver.solve(prob)

Compiling cost function...
Computing gradient of cost function...
 iter		   cost val	    grad. norm
    1	+2.2824155917492575e+01	1.93013281e+00
    2	+2.1128124909185441e+01	1.45496375e+00
    3	+1.9241096341396510e+01	5.61531267e-01
    4	+1.8757381793514341e+01	1.23864538e-01
    5	+1.8624784241787577e+01	3.77220665e-01
    6	+1.8440315366324434e+01	1.76954171e-01
    7	+1.8423933122530446e+01	2.20533891e-01
    8	+1.8372708089508279e+01	1.29607874e-01
    9	+1.8345122680670080e+01	1.64812805e-01
   10	+1.8305494889709443e+01	1.19008554e-01
   11	+1.8278008477513918e+01	1.34206003e-01
   12	+1.8262545417378259e+01	1.72662954e-01
   13	+1.8221126604726329e+01	7.67869308e-02
   14	+1.8182812245277365e+01	1.64246756e-01
   15	+1.8179688516504758e+01	1.97470741e-01
   16	+1.8167865983068062e+01	1.76318956e-01
   17	+1.8132553065702574e+01	9.22554810e-02
   18	+1.8130500877200554e+01	1.75329053e-01
   19	+1.8122711692415987e+01	1.57544457e-01
   20	+1.8098719811156620e+01	8.96615205e-02


  231	+1.7876957729874025e+01	3.74228196e-04
  232	+1.7876946741912530e+01	2.50586211e-03
  233	+1.7876939983138076e+01	9.88741213e-04
  234	+1.7876938869126601e+01	3.95801060e-04
  235	+1.7876938647436450e+01	2.13470443e-04
  236	+1.7876938634994012e+01	4.56376228e-04
  237	+1.7876938587345165e+01	4.17684594e-04
  238	+1.7876938433070453e+01	2.65160467e-04
  239	+1.7876938312629807e+01	2.18771595e-04
  240	+1.7876938209041842e+01	2.17322962e-04
  241	+1.7876938106158956e+01	1.85877295e-04
  242	+1.7876938007049354e+01	2.16433050e-04
  243	+1.7876937908917721e+01	1.77120222e-04
  244	+1.7876937814873841e+01	2.15372908e-04
  245	+1.7876937721048254e+01	1.67227986e-04
  246	+1.7876937630628614e+01	2.15269800e-04
  247	+1.7876937540390450e+01	1.58669063e-04
  248	+1.7876937453543686e+01	2.15268052e-04
  249	+1.7876937452677357e+01	3.68525417e-04
  250	+1.7876937449227722e+01	3.65192024e-04
  251	+1.7876937435681814e+01	3.51850707e-04
  252	+1.7876937385655093e+01	2.98511230e-04
  253	+1.7

  478	+1.7876934920565947e+01	1.42935087e-06
  479	+1.7876934920565873e+01	2.49468191e-06
  480	+1.7876934920565571e+01	2.45232371e-06
  481	+1.7876934920564409e+01	2.27928821e-06
  482	+1.7876934920560480e+01	1.59686307e-06
  483	+1.7876934920558973e+01	2.14249204e-06
  484	+1.7876934920554305e+01	1.24265355e-06
  485	+1.7876934920551552e+01	1.54699455e-06
  486	+1.7876934920548031e+01	1.34887736e-06
  487	+1.7876934920544628e+01	9.68790603e-07
Terminated - min grad norm reached after 487 iterations, 1.24 seconds.



In [12]:
# Unpack in the order of the Product manifold factors: classifier W, features H, bias b.
(W, H, b) = Xopt

In [13]:
# ETF check on the classifier W. The off-diagonal Gram entries printed below deviate from
# -1/9 (= -1/(num_cls - 1)) by a few 1e-5, so an explicit atol=1e-4 is passed
# (presumably looser than check_etf's default -- confirm in `oblique`).
check_etf(W, verbose=True, atol=1e-4)

[1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]
[[ 1.         -0.11109537 -0.11113914 -0.11108789 -0.1110932  -0.11113914
  -0.11112241 -0.11111815 -0.11108726 -0.1111174 ]
 [-0.11109537  1.         -0.11109997 -0.11112964 -0.11110966 -0.11111899
  -0.11110289 -0.11109897 -0.11113087 -0.11111364]
 [-0.11113914 -0.11109997  1.         -0.11110381 -0.1111119  -0.11111644
  -0.1111302  -0.11111289 -0.11108665 -0.11109896]
 [-0.11108789 -0.11112964 -0.11110381  1.         -0.11112793 -0.11109949
  -0.11111074 -0.11110653 -0.1111268  -0.11110717]
 [-0.1110932  -0.11110966 -0.1111119  -0.11112793  1.         -0.11109432
  -0.11111453 -0.11112281 -0.11112239 -0.11110327]
 [-0.11113914 -0.11111899 -0.11111644 -0.11109949 -0.11109432  1.
  -0.11111466 -0.11109751 -0.11110284 -0.11111657]
 [-0.11112241 -0.11110289 -0.1111302  -0.11111074 -0.11111453 -0.11111466
   1.         -0.11111008 -0.11109528 -0.1110992 ]
 [-0.11111815 -0.11109897 -0.11111289 -0.11110653 -0.11112281 -0.11109751
  -0.11111008  1.         -

In [14]:
# Bias vector of the weight-decay solution; its entries come out near zero.
b

array([-8.42298980e-07, -3.64552279e-08, -9.27569626e-07, -8.32376497e-08,
        3.69042552e-10, -6.17086435e-07, -4.84012306e-07,  3.64018385e-07,
        2.34573121e-06,  2.80541586e-07])

In [15]:
# Compact array display for the feature matrices shown below: wider lines and at most
# 4 digits after the decimal point.
np.set_printoptions(linewidth=100, precision=4)

In [16]:
# Full feature matrix H; at this size numpy summarizes it with '...'.
H

array([[ 0.3432,  0.3432,  0.3432, ...,  0.4595,  0.4595,  0.4595],
       [ 0.1834,  0.1834,  0.1834, ...,  0.2312,  0.2312,  0.2312],
       [-0.0712, -0.0712, -0.0712, ..., -0.2954, -0.2954, -0.2954],
       ...,
       [-0.1813, -0.1813, -0.1813, ..., -0.189 , -0.189 , -0.189 ],
       [-0.3821, -0.3821, -0.3821, ...,  0.0253,  0.0253,  0.0253],
       [ 0.1351,  0.1351,  0.1351, ...,  0.0973,  0.0973,  0.0973]])

In [17]:
# Feature columns of the first sample_per_cls samples (presumably class 0, given the
# class-by-class label layout); identical columns indicate within-class collapse.
H[:, 0:sample_per_cls]

array([[ 0.3432,  0.3432,  0.3432,  0.3432,  0.3432,  0.3432,  0.3432,  0.3432,  0.3432,  0.3432],
       [ 0.1834,  0.1834,  0.1834,  0.1834,  0.1834,  0.1834,  0.1834,  0.1834,  0.1834,  0.1834],
       [-0.0712, -0.0712, -0.0712, -0.0712, -0.0712, -0.0712, -0.0712, -0.0712, -0.0712, -0.0712],
       [-0.3534, -0.3534, -0.3534, -0.3534, -0.3534, -0.3534, -0.3534, -0.3534, -0.3534, -0.3534],
       [ 0.1671,  0.1671,  0.1671,  0.1671,  0.1671,  0.1671,  0.1671,  0.1671,  0.1671,  0.1671],
       [ 0.1386,  0.1386,  0.1386,  0.1386,  0.1386,  0.1386,  0.1386,  0.1386,  0.1386,  0.1386],
       [ 0.038 ,  0.038 ,  0.038 ,  0.038 ,  0.038 ,  0.038 ,  0.038 ,  0.038 ,  0.038 ,  0.038 ],
       [-0.5118, -0.5118, -0.5118, -0.5118, -0.5118, -0.5118, -0.5118, -0.5118, -0.5118, -0.5118],
       [ 0.2747,  0.2747,  0.2747,  0.2747,  0.2747,  0.2747,  0.2747,  0.2747,  0.2747,  0.2747],
       [ 0.2023,  0.2023,  0.2023,  0.2023,  0.2023,  0.2023,  0.2023,  0.2023,  0.2023,  0.2023],
       [-0

### Observations
1. If I re-run the problem-solving cell, the solution changes (by an orthogonal transformation), but the results of **Thm. 3.1** still hold.
1. $\mathbf W$ and $\mathbf H$ must both lie on oblique manifolds. My understanding is that, since they are dual, they must share the same structure.
1. $\mathbf b$ vanishes, as expected.

In [26]:
# Duality check between classifier W and features H. check_duality presumably lives in
# `oblique`; its output shows it also runs an ETF test before reporting overall success.
check_duality(W, H)

ETF Tests passed!
All tests passed!
