In [2]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [3]:
data = pd.read_csv('data/BO_2w_smooth.csv')

# Diagonal overlap entries used below to build Q = diag(Q11, Q22), one row per alpha.
Q_11, Q_22 = data['Q11'], data['Q22']

# alpha values at which the overlaps were tabulated (used only for labelling results).
alphas_read = data['alpha']

In [None]:
from numba import njit
from tqdm import tqdm
from scipy.linalg import sqrtm

@njit
def softmax(x, beta=1):
    """Numerically stable row-wise softmax of a 2-D array.

    Parameters
    ----------
    x : 2-D array
        Each row is normalised independently.
    beta : scalar, default 1
        Inverse temperature: row r becomes exp(beta*x[r]) / sum(exp(beta*x[r])).

    Returns
    -------
    2-D array with the same shape as ``x`` whose rows each sum to 1.
    """
    # Scale first, then subtract the per-row max of the *scaled* logits. Shifting by
    # max(x) before scaling is only stable for beta >= 0: with beta < 0 the exponent
    # beta*(x - max x) is >= 0 and can overflow. For beta == 1 the result is unchanged.
    z = beta * x
    row_max = np.array([np.max(row) for row in z])

    P = np.exp(z - row_max.reshape(-1, 1))
    return (P.T / np.sum(P, 1)).T


@njit
def A(z1):
    """Row-softmax of the outer product z1 z1^T, with the identity matrix added.

    The added identity presumably acts as a skip connection on the attention
    weights -- NOTE(review): confirm against the model definition.
    """
    attention = softmax(np.outer(z1, z1))
    identity = np.eye(len(z1))
    return attention + identity


@njit
def g(z1, z2):
    """Output map: mix z2 with the matrix A(z1), then row-softmax its outer product."""
    mixed = A(z1) @ z2
    scores = np.outer(mixed, mixed)
    return softmax(scores)


samples = 100000  # Monte Carlo draws per alpha
SEED = 0          # fixed seed so the curve is reproducible across kernel restarts
rng = np.random.default_rng(SEED)
gen_error_list = np.zeros(len(Q_11))

# For each overlap pair Q = diag(Q11, Q22), estimate E ||g(Z_true) - g(omega)||_F^2 with
#   omega  = sqrt(Q) Z                       (student preactivations)
#   Z_true = sqrt(Q) Z + sqrt(I - Q) U       (teacher preactivations)
# where Z, U are independent 2x2 standard-normal draws.
for i in tqdm(range(len(Q_11))):
    Q = np.array([[Q_11[i], 0], [0, Q_22[i]]])

    # Q is fixed for this alpha, so both matrix square roots are loop-invariant.
    # Computing them once here, rather than 2 * samples times, is the main cost saving.
    sqrt_Q = sqrtm(Q)
    sqrt_one_minus_Q = sqrtm(np.eye(Q.shape[0]) - Q)

    gen_error = 0.0
    for _ in range(samples):
        Z = rng.normal(0, 1, (2, 2))
        U = rng.normal(0, 1, (2, 2))

        omega = sqrt_Q @ Z
        Z_true = omega + sqrt_one_minus_Q @ U   # reuse omega == sqrt_Q @ Z
        y_teacher = g(Z_true[0], Z_true[1])
        y_student = g(omega[0], omega[1])

        gen_error += np.linalg.norm(y_teacher - y_student) ** 2

    gen_error_list[i] = gen_error / samples

    # tqdm.write keeps the progress bar intact instead of interleaving it with print output.
    tqdm.write(f"alpha = {alphas_read[i]}: {gen_error_list[i]}")


  1%|          | 1/181 [00:27<1:22:22, 27.46s/it]

alpha = 0.0: 0.3848642790595911


  1%|          | 2/181 [00:50<1:14:57, 25.13s/it]

alpha = 0.1435844654940499: 0.3863464685630178


  2%|▏         | 3/181 [01:19<1:18:53, 26.59s/it]

alpha = 0.1666666665: 0.38243430065185313


  2%|▏         | 4/181 [01:42<1:14:15, 25.17s/it]

alpha = 0.177952755727559: 0.38494889337698923


  3%|▎         | 5/181 [02:05<1:11:56, 24.53s/it]

alpha = 0.1892388449551181: 0.38216482116516787


  3%|▎         | 6/181 [02:29<1:10:58, 24.34s/it]

alpha = 0.2005249341826771: 0.3796747693071608


  4%|▍         | 7/181 [02:54<1:11:00, 24.49s/it]

alpha = 0.2118110234102362: 0.37788133215908853


  4%|▍         | 8/181 [03:14<1:06:56, 23.22s/it]

alpha = 0.2230971126377952: 0.3738615613394652


  5%|▍         | 9/181 [03:35<1:04:07, 22.37s/it]

alpha = 0.2343832018653543: 0.3705097337104113


  6%|▌         | 10/181 [03:56<1:02:22, 21.88s/it]

alpha = 0.2456692910929133: 0.36598942325673023


  6%|▌         | 11/181 [04:16<1:00:17, 21.28s/it]

alpha = 0.2569553803204724: 0.36271541147853115


  7%|▋         | 12/181 [04:37<59:42, 21.20s/it]  

alpha = 0.2682414695480315: 0.3587412942983997


  7%|▋         | 13/181 [04:58<59:33, 21.27s/it]

alpha = 0.2795275587755905: 0.3550496477939876


  8%|▊         | 14/181 [05:19<58:44, 21.11s/it]

alpha = 0.2908136480031495: 0.351641710260913


  8%|▊         | 15/181 [05:42<1:00:22, 21.82s/it]

alpha = 0.3020997372307086: 0.3486670492408768


  9%|▉         | 16/181 [06:04<1:00:11, 21.89s/it]

alpha = 0.3133858264582677: 0.3434888771737394


  9%|▉         | 17/181 [06:27<1:00:13, 22.04s/it]

alpha = 0.3246719156858267: 0.340623158861121


 10%|▉         | 18/181 [06:48<59:37, 21.95s/it]  

alpha = 0.3359580049133858: 0.33242607227107396


 10%|█         | 19/181 [07:10<58:54, 21.82s/it]

alpha = 0.3472440941409448: 0.32727638797375097


 11%|█         | 20/181 [07:31<57:48, 21.55s/it]

alpha = 0.3585301833685039: 0.32341601776432427


 12%|█▏        | 21/181 [07:52<57:16, 21.48s/it]

alpha = 0.369816272596063: 0.3144559890412132


 12%|█▏        | 22/181 [08:12<55:46, 21.05s/it]

alpha = 0.381102361823622: 0.30738526326072113


 13%|█▎        | 23/181 [08:33<55:08, 20.94s/it]

alpha = 0.3923884510511811: 0.2987083693504203


 13%|█▎        | 24/181 [08:55<55:38, 21.26s/it]

alpha = 0.4036745402787401: 0.2910213203526811


 14%|█▍        | 25/181 [09:20<58:28, 22.49s/it]

alpha = 0.4149606295062992: 0.27786065341361693


In [8]:
# Dump the (alpha, generalisation error) curve as CSV lines. A plain loop avoids
# building, and displaying as cell output, a throwaway list of None values.
assert len(alphas_read) == len(gen_error_list), "alpha grid and error curve differ in length"
for alpha, err in zip(alphas_read, gen_error_list):
    print(f"{alpha},{err}")

1

0.0,0.38790104430716293
0.0333333333,0.3863906533072164
0.0666666666,0.3873364436237081
0.0999999999,0.386496188506728
0.1333333332,0.387630446654175
0.1666666665,0.3853618503105733
0.1999999998,0.37943140255793223
0.2333333331,0.3714551620391603
0.2666666664,0.36006952922175506
0.2999999997,0.34930170951738093
0.333333333,0.3367345044354671
0.3666666663,0.31950025703925283
0.3999999996,0.29281596862582837
0.4333333329,0.26521979324409845
0.4666666662,0.2290801256030667
0.4999999995,0.18290987814475676
0.5333333328,0.14144536786699358
0.5666666661,0.10963932437731191
0.5999999994,0.0936929839179832
0.6333333327,0.07715788296286483
0.666666666,0.06711189035625276
0.6999999993,0.06360388166891703
0.7333333326,0.05551124290970098
0.7666666659,0.05330829745085989
0.7999999992,0.04656697225578734
0.8333333325,0.0356888556706998
0.8666666658,0.0340239109828602
0.8999999991,0.029007817411543086
0.9333333324,0.025145068887988502
0.9666666657,0.021862992937828875
0.999999999,0.02222946822097011

1

In [9]:
from numba import njit
from tqdm import tqdm
from scipy.linalg import sqrtm

@njit
def softmax(x, beta=1):
    """Numerically stable row-wise softmax of a 2-D array.

    NOTE(review): this cell re-defines softmax/A/g, which an earlier cell already
    defines; consider keeping a single definition so the two cannot drift apart.

    Parameters
    ----------
    x : 2-D array
        Each row is normalised independently.
    beta : scalar, default 1
        Inverse temperature: row r becomes exp(beta*x[r]) / sum(exp(beta*x[r])).

    Returns
    -------
    2-D array with the same shape as ``x`` whose rows each sum to 1.
    """
    # Scale first, then subtract the per-row max of the *scaled* logits. Shifting by
    # max(x) before scaling is only stable for beta >= 0: with beta < 0 the exponent
    # beta*(x - max x) is >= 0 and can overflow. For beta == 1 the result is unchanged.
    z = beta * x
    row_max = np.array([np.max(row) for row in z])

    P = np.exp(z - row_max.reshape(-1, 1))
    return (P.T / np.sum(P, 1)).T


@njit
def A(z1):
    """Return row-softmax(z1 z1^T) plus the identity matrix of matching size."""
    n = len(z1)
    weights = softmax(np.outer(z1, z1))
    return weights + np.eye(n)


@njit
def g(z1, z2):
    """Row-softmax of the outer product of h = A(z1) @ z2 with itself."""
    h = A(z1) @ z2
    return softmax(np.outer(h, h))


samples = 100000  # Monte Carlo draws per alpha
SEED = 0          # fixed seed so the curve is reproducible across kernel restarts
rng = np.random.default_rng(SEED)
gen_error_list = np.zeros(len(Q_11))

# Same Monte Carlo estimate of E ||g(Z_true) - g(omega)||_F^2 as the full-overlap run,
# but with Q11 forced to 0: the first row of sqrt(Q) is zero, so omega[0] is identically
# zero and Z_true[0] == U[0] (pure independent noise).
for i in tqdm(range(len(Q_11))):
    Q = np.array([[0, 0], [0, Q_22[i]]])

    # Q is fixed for this alpha, so both matrix square roots are loop-invariant.
    # Computing them once here, rather than 2 * samples times, is the main cost saving.
    sqrt_Q = sqrtm(Q)
    sqrt_one_minus_Q = sqrtm(np.eye(Q.shape[0]) - Q)

    gen_error = 0.0
    for _ in range(samples):
        Z = rng.normal(0, 1, (2, 2))
        U = rng.normal(0, 1, (2, 2))

        omega = sqrt_Q @ Z
        Z_true = omega + sqrt_one_minus_Q @ U   # reuse omega == sqrt_Q @ Z
        y_teacher = g(Z_true[0], Z_true[1])
        y_student = g(omega[0], omega[1])

        gen_error += np.linalg.norm(y_teacher - y_student) ** 2

    gen_error_list[i] = gen_error / samples

    # tqdm.write keeps the progress bar intact instead of interleaving it with print output.
    tqdm.write(f"alpha = {alphas_read[i]}: {gen_error_list[i]}")


  2%|▏         | 1/48 [00:25<19:40, 25.12s/it]

alpha = 0.0: 0.38576924742253743


  4%|▍         | 2/48 [00:45<16:54, 22.06s/it]

alpha = 0.0333333333: 0.38983397654849256


  6%|▋         | 3/48 [01:10<17:38, 23.51s/it]

alpha = 0.0666666666: 0.3850014758635755


  8%|▊         | 4/48 [01:34<17:30, 23.87s/it]

alpha = 0.0999999999: 0.385858019658417


 10%|█         | 5/48 [01:59<17:23, 24.26s/it]

alpha = 0.1333333332: 0.38620451975527176


 12%|█▎        | 6/48 [02:22<16:38, 23.77s/it]

alpha = 0.1666666665: 0.3847197827928925


 15%|█▍        | 7/48 [02:46<16:14, 23.76s/it]

alpha = 0.1999999998: 0.3801219578078816


 17%|█▋        | 8/48 [03:11<16:14, 24.37s/it]

alpha = 0.2333333331: 0.3699448507361799


 19%|█▉        | 9/48 [03:36<15:53, 24.44s/it]

alpha = 0.2666666664: 0.3605752863131721


 21%|██        | 10/48 [03:57<14:48, 23.37s/it]

alpha = 0.2999999997: 0.3489031507809933


 23%|██▎       | 11/48 [04:18<14:00, 22.71s/it]

alpha = 0.333333333: 0.33666168477253305


 25%|██▌       | 12/48 [04:43<14:00, 23.36s/it]

alpha = 0.3666666663: 0.3210228414262696


 27%|██▋       | 13/48 [05:07<13:42, 23.49s/it]

alpha = 0.3999999996: 0.2910472694203277


 29%|██▉       | 14/48 [05:30<13:16, 23.43s/it]

alpha = 0.4333333329: 0.26272098324974125


 31%|███▏      | 15/48 [05:52<12:39, 23.02s/it]

alpha = 0.4666666662: 0.2299044128310558


 33%|███▎      | 16/48 [06:14<12:07, 22.74s/it]

alpha = 0.4999999995: 0.18411141652858354


 35%|███▌      | 17/48 [06:35<11:28, 22.20s/it]

alpha = 0.5333333328: 0.14222092977268705


 38%|███▊      | 18/48 [06:56<10:56, 21.90s/it]

alpha = 0.5666666661: 0.10946260442095497


 40%|███▉      | 19/48 [07:16<10:18, 21.33s/it]

alpha = 0.5999999994: 0.09405243524869136


 42%|████▏     | 20/48 [07:36<09:46, 20.95s/it]

alpha = 0.6333333327: 0.07827819193503689


 44%|████▍     | 21/48 [07:58<09:28, 21.06s/it]

alpha = 0.666666666: 0.06684159133233106


 46%|████▌     | 22/48 [08:19<09:10, 21.15s/it]

alpha = 0.6999999993: 0.06331542854006347


 48%|████▊     | 23/48 [08:39<08:38, 20.73s/it]

alpha = 0.7333333326: 0.05512545651053185


 50%|█████     | 24/48 [08:59<08:14, 20.59s/it]

alpha = 0.7666666659: 0.05268983636418027


 52%|█████▏    | 25/48 [09:20<07:58, 20.80s/it]

alpha = 0.7999999992: 0.04667273740044811


 54%|█████▍    | 26/48 [09:41<07:38, 20.85s/it]

alpha = 0.8333333325: 0.041005449980894865


 56%|█████▋    | 27/48 [10:03<07:20, 20.97s/it]

alpha = 0.8666666658: 0.042359338227994085


 58%|█████▊    | 28/48 [10:23<06:57, 20.87s/it]

alpha = 0.8999999991: 0.04016062642240004


 60%|██████    | 29/48 [10:44<06:35, 20.84s/it]

alpha = 0.9333333324: 0.03973070191805872


 62%|██████▎   | 30/48 [11:05<06:17, 20.95s/it]

alpha = 0.9666666657: 0.03802049324635839


 65%|██████▍   | 31/48 [11:27<05:59, 21.13s/it]

alpha = 0.999999999: 0.03820710554360341


 67%|██████▋   | 32/48 [11:49<05:44, 21.56s/it]

alpha = 1.0333333323: 0.03718376980857806


 69%|██████▉   | 33/48 [12:11<05:23, 21.57s/it]

alpha = 1.0666666656: 0.0361899097783006


 71%|███████   | 34/48 [12:33<05:02, 21.63s/it]

alpha = 1.0999999989: 0.03505968669631899


 73%|███████▎  | 35/48 [12:54<04:40, 21.57s/it]

alpha = 1.1333333322: 0.036134643501620406


 75%|███████▌  | 36/48 [13:14<04:12, 21.08s/it]

alpha = 1.1666666655: 0.03565080303510884


 77%|███████▋  | 37/48 [13:34<03:47, 20.68s/it]

alpha = 1.1999999988: 0.0367576683088817


 79%|███████▉  | 38/48 [13:55<03:27, 20.72s/it]

alpha = 1.2333333321: 0.03505808452832888


 81%|████████▏ | 39/48 [14:16<03:08, 20.92s/it]

alpha = 1.2666666654: 0.03681468110950428


 83%|████████▎ | 40/48 [14:37<02:46, 20.80s/it]

alpha = 1.2999999987: 0.03493682703063279


 85%|████████▌ | 41/48 [14:56<02:21, 20.25s/it]

alpha = 1.333333332: 0.035547569146528527


 88%|████████▊ | 42/48 [15:18<02:04, 20.77s/it]

alpha = 1.3666666653: 0.034885100919110845


 90%|████████▉ | 43/48 [15:39<01:44, 20.99s/it]

alpha = 1.3999999986: 0.034770648633320904


 92%|█████████▏| 44/48 [15:59<01:22, 20.75s/it]

alpha = 1.4333333319: 0.03609927534631465


 94%|█████████▍| 45/48 [16:20<01:01, 20.60s/it]

alpha = 1.4999999985: 0.034628734177195815


 96%|█████████▌| 46/48 [16:42<00:42, 21.16s/it]

alpha = 1.5333333318: 0.03564259218913689


 98%|█████████▊| 47/48 [17:04<00:21, 21.42s/it]

alpha = 1.5666666651: 0.0351472086849386


100%|██████████| 48/48 [17:25<00:00, 21.79s/it]

alpha = 1.5999999984: 0.03496864335677616





In [10]:
# Dump the (alpha, generalisation error) curve as CSV lines. A plain loop avoids
# building, and displaying as cell output, a throwaway list of None values.
assert len(alphas_read) == len(gen_error_list), "alpha grid and error curve differ in length"
for alpha, err in zip(alphas_read, gen_error_list):
    print(f"{alpha},{err}")

1

0.0,0.38576924742253743
0.0333333333,0.38983397654849256
0.0666666666,0.3850014758635755
0.0999999999,0.385858019658417
0.1333333332,0.38620451975527176
0.1666666665,0.3847197827928925
0.1999999998,0.3801219578078816
0.2333333331,0.3699448507361799
0.2666666664,0.3605752863131721
0.2999999997,0.3489031507809933
0.333333333,0.33666168477253305
0.3666666663,0.3210228414262696
0.3999999996,0.2910472694203277
0.4333333329,0.26272098324974125
0.4666666662,0.2299044128310558
0.4999999995,0.18411141652858354
0.5333333328,0.14222092977268705
0.5666666661,0.10946260442095497
0.5999999994,0.09405243524869136
0.6333333327,0.07827819193503689
0.666666666,0.06684159133233106
0.6999999993,0.06331542854006347
0.7333333326,0.05512545651053185
0.7666666659,0.05268983636418027
0.7999999992,0.04667273740044811
0.8333333325,0.041005449980894865
0.8666666658,0.042359338227994085
0.8999999991,0.04016062642240004
0.9333333324,0.03973070191805872
0.9666666657,0.03802049324635839
0.999999999,0.0382071055436034

1