In [22]:
import torch
import numpy as np
import matplotlib.pyplot as plt

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")

d      = 200
k      = 5
U, V   = k ** (-1/4) * torch.randn(2, d, k, device=device)
A_star = U @ V.T
E      = torch.randn(d, d, device=device)

def run(gamma_sigma_sq, gamma_w):
    w        = d ** gamma_w
    sigma_sq = d ** gamma_sigma_sq

    def eta():
        if gamma_w + gamma_sigma_sq > 1:
            return 1 / (50 * w * sigma_sq)
        else:
            return 1 / (50 * torch.linalg.norm(A_star, ord=2).cpu())

    w = int(w)
    W1 = sigma_sq ** (1/2) * torch.randn(w, d, device=device)
    W2 = sigma_sq ** (1/2) * torch.randn(d, w, device=device)

    losses, gen_losses, svs, ratios, ranks, stable_ranks  = [], [], [], [], [], []    
    
    for i in range(750):
        A = W2 @ W1
        
        loss = torch.norm(A - (A_star + E)) ** 2 / d ** 2
        gen_loss = torch.norm(A - A_star) ** 2 / d ** 2
        losses.append(loss.cpu())
        gen_losses.append(gen_loss.cpu())
        svs.append(torch.svd(A, compute_uv=False)[1].cpu().numpy())
        ratios.append(svs[-1][k-1]/svs[-1][k])
        ranks.append(torch.linalg.matrix_rank(A, atol = svs[-1][0]/5).cpu().numpy())
        stable_ranks.append((torch.linalg.norm(A)/torch.linalg.norm(A, ord=2)).cpu().numpy() ** 2)
        
        W1 = W1 - eta() * W2.T @ (A - (A_star + E))
        W2 = W2 - eta() * (A - (A_star + E)) @ W1.T
        
        pp = lambda: print(i, loss.cpu().numpy(), gen_loss.cpu().numpy())
        if i % 50 == 0: pp()
            
        if loss < 1e-3:
            pp(); break      
            
        if i > 3:
            if gen_losses[-3] < gen_losses[-2] < gen_losses[-1]:
                pp(); break
                
            if np.abs(losses[-1] - losses[-2]) < 2e-4 and losses[0] - losses[-1] > 1e-2:
                pp(); break
    
    active = torch.linalg.norm(W2 @ W1, ord=2).cpu().numpy() / w * sigma_sq
    return losses, gen_losses, svs, ratios, ranks, stable_ranks, active, i

In [23]:
xs, ys, zs = [], [], []
gss = np.linspace(-2.2, -0.8, 30)
gws = np.linspace(1.75, 2.8, 30)

for gs in gss:
    for gw in gws:
        print('\n', gs, gw)
        g = run(gs, gw)
        xs.append(gs)
        ys.append(gw)
        zs.append(g)


 -2.2 1.75
0 2.0022235 1.0121952
50 1.992288 1.0060841
100 1.953832 0.9732312
150 1.7649351 0.7982297
200 1.2789859 0.3492563
250 0.96560115 0.081441864
300 0.90654516 0.052496336
304 0.9045082 0.0524916

 -2.2 1.7862068965517242
0 2.002222 1.0122068
50 1.990295 1.0049024
100 1.9446328 0.96598643
150 1.7280236 0.76577306
200 1.2268336 0.30528098
250 0.94789433 0.07263268
298 0.89773184 0.052704472

 -2.2 1.8224137931034483
0 2.0022428 1.0122154
50 1.9877261 1.0033085
100 1.9324113 0.95618886
150 1.6812738 0.7246251
200 1.1723814 0.26046708
250 0.93119013 0.066068456
292 0.88994277 0.053017315

 -2.2 1.8586206896551725
0 2.0022876 1.0122361
50 1.984778 1.0015043
100 1.9187618 0.9453399
150 1.6331139 0.6826535
200 1.124037 0.22212665
250 0.9155194 0.06138955
285 0.88195527 0.05328789

 -2.2 1.8948275862068966
0 2.0021925 1.0121818
50 1.9808285 0.99905366
100 1.9006847 0.9308911
150 1.5743095 0.6317728
200 1.0747076 0.1848894
250 0.8990263 0.0581647
278 0.8723619 0.0537323

 -2.2 1.93103

In [24]:
import pickle

with open('mse_exp.pickle', 'wb') as handle:
    pickle.dump([xs, ys, zs, gss, gws], handle, protocol=pickle.HIGHEST_PROTOCOL)