In [2]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [5]:
data = pd.read_csv('data/BO_3w.csv')

# Diagonal overlap entries Q11, Q22, Q33 and the alpha grid they were computed for,
# one row per alpha value.
Q_11, Q_22, Q_33 = (data[column] for column in ('Q11', 'Q22', 'Q33'))
alphas_read = data['alpha']

In [9]:
from numba import njit
from tqdm import tqdm
from scipy.linalg import sqrtm


@njit
def softmax(x, beta=1):
    """Numerically stable row-wise softmax of a 2-D array.

    Parameters
    ----------
    x : 2-D array
        Logits; the softmax is taken independently over each row.
    beta : number, default 1
        Inverse temperature multiplying the logits.

    Returns
    -------
    Array of the same shape as ``x`` whose rows are non-negative and sum to 1.
    """
    scaled = beta * x
    # Shift each row by the maximum of the *scaled* logits so the largest exponent
    # is exactly 0. Shifting before scaling overflows to inf (and then nan) when beta < 0.
    row_max = np.array([np.max(row) for row in scaled])
    P = np.exp(scaled - row_max.reshape(-1, 1))
    return (P.T / np.sum(P, 1)).T


@njit
def A(z1):
    """Return softmax(outer(z1, z1)) (row-wise) with the identity matrix added."""
    n = len(z1)
    scores = np.outer(z1, z1)
    return np.eye(n) + softmax(scores)


@njit
def B(z1, z2):
    """Compose two A-matrices: returns A(A(z1) @ z2) @ A(z1)."""
    first = A(z1)
    return A(first @ z2) @ first


@njit
def g(z1, z2, z3):
    """Row-wise softmax of outer(b, b), where b = B(z1, z2) @ z3."""
    b = B(z1, z2) @ z3
    scores = np.outer(b, b)
    return softmax(scores)


# Monte Carlo estimate, for each alpha, of E ||g(Z_true) - g(omega)||_F^2 with
# omega = sqrt(Q) Z and Z_true = sqrt(Q) Z + sqrt(I - Q) U for i.i.d. standard normal Z, U.
samples = 100000   # Monte Carlo draws per alpha
SEED = 0           # fixed seed so a re-run reproduces the printed estimates
rng = np.random.default_rng(SEED)
gen_error_list = np.zeros(len(Q_11))

for i in tqdm(range(len(Q_11))):
    Q = np.diag([Q_11[i], Q_22[i], Q_33[i]])

    # Q is fixed for the whole inner loop: take both matrix square roots once here
    # instead of recomputing them for every sample.
    sqrt_Q = sqrtm(Q)
    sqrt_one_minus_Q = sqrtm(np.eye(Q.shape[0]) - Q)

    gen_error = 0.0
    for _ in range(samples):
        Z = rng.standard_normal((3, 2))
        U = rng.standard_normal((3, 2))

        omega = sqrt_Q @ Z
        # Shares the sqrt(Q) Z component with omega, plus an independent sqrt(I - Q) U part.
        Z_true = omega + sqrt_one_minus_Q @ U

        y_teacher = g(Z_true[0], Z_true[1], Z_true[2])
        y_student = g(omega[0], omega[1], omega[2])

        gen_error += np.linalg.norm(y_teacher - y_student) ** 2

    gen_error_list[i] = gen_error / samples

    # tqdm.write prints above the progress bar instead of breaking it like print().
    tqdm.write(f"alpha = {alphas_read[i]}: {gen_error_list[i]}")


  0%|          | 0/49 [00:00<?, ?it/s]

  2%|▏         | 1/49 [00:28<22:34, 28.21s/it]

alpha = 0.0: 0.5849756569798746


  4%|▍         | 2/49 [00:50<19:29, 24.89s/it]

alpha = 0.0625: 0.5843271788812764


  6%|▌         | 3/49 [01:13<18:19, 23.90s/it]

alpha = 0.125: 0.5834411704371758


  8%|▊         | 4/49 [01:36<17:31, 23.36s/it]

alpha = 0.1875: 0.5841458338923836


 10%|█         | 5/49 [01:59<17:08, 23.37s/it]

alpha = 0.25: 0.5784702115464376


 12%|█▏        | 6/49 [02:22<16:36, 23.17s/it]

alpha = 0.3125: 0.5874128951216233


 14%|█▍        | 7/49 [02:44<16:03, 22.95s/it]

alpha = 0.375: 0.5908169279305303


 16%|█▋        | 8/49 [03:07<15:35, 22.83s/it]

alpha = 0.4375: 0.5549054593102759


 18%|█▊        | 9/49 [03:29<15:09, 22.73s/it]

alpha = 0.5: 0.4725599407910743


 20%|██        | 10/49 [03:52<14:49, 22.81s/it]

alpha = 0.5625: 0.3770265772449534


 22%|██▏       | 11/49 [04:15<14:25, 22.76s/it]

alpha = 0.625: 0.3001588075886262


 24%|██▍       | 12/49 [04:38<14:02, 22.76s/it]

alpha = 0.6875: 0.2404559458460032


 27%|██▋       | 13/49 [05:01<13:40, 22.79s/it]

alpha = 0.75: 0.2054596889922188


 29%|██▊       | 14/49 [05:23<13:16, 22.77s/it]

alpha = 0.8125: 0.1840042022135637


 31%|███       | 15/49 [05:46<12:54, 22.77s/it]

alpha = 0.875: 0.1628542563306462


 33%|███▎      | 16/49 [06:10<12:46, 23.21s/it]

alpha = 0.9375: 0.15204252475218893


 35%|███▍      | 17/49 [06:33<12:19, 23.11s/it]

alpha = 1.0: 0.14370831268090817


 37%|███▋      | 18/49 [06:56<11:53, 23.03s/it]

alpha = 1.0625: 0.1361410805308254


 39%|███▉      | 19/49 [07:19<11:28, 22.96s/it]

alpha = 1.125: 0.12999479499456443


 41%|████      | 20/49 [07:42<11:06, 22.97s/it]

alpha = 1.1875: 0.12128936048124529


 43%|████▎     | 21/49 [08:05<10:42, 22.94s/it]

alpha = 1.25: 0.11667110363062329


 45%|████▍     | 22/49 [08:27<10:17, 22.88s/it]

alpha = 1.3125: 0.11161167314817247


 47%|████▋     | 23/49 [08:50<09:55, 22.91s/it]

alpha = 1.375: 0.1061489926757768


 49%|████▉     | 24/49 [09:13<09:32, 22.89s/it]

alpha = 1.4375: 0.10057491332648912


 51%|█████     | 25/49 [09:38<09:21, 23.38s/it]

alpha = 1.5: 0.0977731757365553


 53%|█████▎    | 26/49 [10:00<08:51, 23.13s/it]

alpha = 1.5625: 0.09551362671395895


 55%|█████▌    | 27/49 [10:23<08:23, 22.87s/it]

alpha = 1.625: 0.09371748702223848


 57%|█████▋    | 28/49 [10:45<07:57, 22.72s/it]

alpha = 1.6875: 0.090564976108122


 59%|█████▉    | 29/49 [11:08<07:34, 22.71s/it]

alpha = 1.75: 0.0891716449742239


 61%|██████    | 30/49 [11:30<07:09, 22.61s/it]

alpha = 1.8125: 0.08665048212259818


 63%|██████▎   | 31/49 [11:52<06:46, 22.56s/it]

alpha = 1.875: 0.08643279862713567


 65%|██████▌   | 32/49 [12:16<06:27, 22.79s/it]

alpha = 1.9375: 0.08432369870237369


 67%|██████▋   | 33/49 [12:42<06:20, 23.77s/it]

alpha = 2.0: 0.08330255198251221


 69%|██████▉   | 34/49 [13:07<06:02, 24.13s/it]

alpha = 2.0625: 0.08331761630331404


 71%|███████▏  | 35/49 [13:32<05:40, 24.35s/it]

alpha = 2.125: 0.08226766152639862


 73%|███████▎  | 36/49 [13:54<05:09, 23.83s/it]

alpha = 2.1875: 0.08216093166618896


 76%|███████▌  | 37/49 [14:17<04:41, 23.45s/it]

alpha = 2.25: 0.08055844264457905


 78%|███████▊  | 38/49 [14:39<04:14, 23.15s/it]

alpha = 2.3125: 0.07949326612151697


 80%|███████▉  | 39/49 [15:04<03:55, 23.54s/it]

alpha = 2.375: 0.07809164174904266


 82%|████████▏ | 40/49 [15:30<03:40, 24.49s/it]

alpha = 2.4375: 0.0782956170581082


 84%|████████▎ | 41/49 [15:57<03:20, 25.07s/it]

alpha = 2.5: 0.07890963145747046


 86%|████████▌ | 42/49 [16:21<02:54, 24.92s/it]

alpha = 2.5625: 0.07772346115684943


 88%|████████▊ | 43/49 [16:44<02:25, 24.28s/it]

alpha = 2.625: 0.07600163678117194


 90%|████████▉ | 44/49 [17:07<01:58, 23.78s/it]

alpha = 2.6875: 0.07585326177863456


 92%|█████████▏| 45/49 [17:30<01:34, 23.59s/it]

alpha = 2.75: 0.076657021455768


 94%|█████████▍| 46/49 [17:53<01:10, 23.34s/it]

alpha = 2.8125: 0.07612236335551027


 96%|█████████▌| 47/49 [18:15<00:46, 23.09s/it]

alpha = 2.875: 0.07511244319613518


 98%|█████████▊| 48/49 [18:38<00:22, 22.97s/it]

alpha = 2.9375: 0.074726974485367


100%|██████████| 49/49 [19:01<00:00, 23.29s/it]

alpha = 3.0: 0.07403265889743416





In [10]:
# One "alpha,gen_error" CSV line per alpha. A plain loop avoids building (and
# displaying) a throwaway list of None values returned by print().
for alpha, error in zip(alphas_read, gen_error_list):
    print(f"{alpha},{error}")

1

0.0,0.5849756569798746
0.0625,0.5843271788812764
0.125,0.5834411704371758
0.1875,0.5841458338923836
0.25,0.5784702115464376
0.3125,0.5874128951216233
0.375,0.5908169279305303
0.4375,0.5549054593102759
0.5,0.4725599407910743
0.5625,0.3770265772449534
0.625,0.3001588075886262
0.6875,0.2404559458460032
0.75,0.2054596889922188
0.8125,0.1840042022135637
0.875,0.1628542563306462
0.9375,0.15204252475218893
1.0,0.14370831268090817
1.0625,0.1361410805308254
1.125,0.12999479499456443
1.1875,0.12128936048124529
1.25,0.11667110363062329
1.3125,0.11161167314817247
1.375,0.1061489926757768
1.4375,0.10057491332648912
1.5,0.0977731757365553
1.5625,0.09551362671395895
1.625,0.09371748702223848
1.6875,0.090564976108122
1.75,0.0891716449742239
1.8125,0.08665048212259818
1.875,0.08643279862713567
1.9375,0.08432369870237369
2.0,0.08330255198251221
2.0625,0.08331761630331404
2.125,0.08226766152639862
2.1875,0.08216093166618896
2.25,0.08055844264457905
2.3125,0.07949326612151697
2.375,0.07809164174904266
2.4

1

In [11]:
from numba import njit
from tqdm import tqdm
from scipy.linalg import sqrtm

@njit
def softmax(x, beta=1):
    """Numerically stable row-wise softmax of a 2-D array.

    Redefinition identical to the earlier cell's ``softmax`` so this cell runs on its own.

    Parameters
    ----------
    x : 2-D array
        Logits; the softmax is taken independently over each row.
    beta : number, default 1
        Inverse temperature multiplying the logits.

    Returns
    -------
    Array of the same shape as ``x`` whose rows are non-negative and sum to 1.
    """
    scaled = beta * x
    # Shift each row by the maximum of the *scaled* logits so the largest exponent
    # is exactly 0. Shifting before scaling overflows to inf (and then nan) when beta < 0.
    row_max = np.array([np.max(row) for row in scaled])
    P = np.exp(scaled - row_max.reshape(-1, 1))
    return (P.T / np.sum(P, 1)).T


@njit
def A(z1):
    """Return softmax(outer(z1, z1)) (row-wise) with the identity matrix added."""
    n = len(z1)
    scores = np.outer(z1, z1)
    return np.eye(n) + softmax(scores)


@njit
def B(z1, z2):
    """Compose two A-matrices: returns A(A(z1) @ z2) @ A(z1)."""
    first = A(z1)
    return A(first @ z2) @ first


@njit
def g(z1, z2, z3):
    """Row-wise softmax of outer(b, b), where b = B(z1, z2) @ z3."""
    b = B(z1, z2) @ z3
    scores = np.outer(b, b)
    return softmax(scores)


# Same Monte Carlo estimate as the full-Q experiment, but with the first two diagonal
# overlaps fixed to 0: only Q33 varies with alpha.
samples = 100000   # Monte Carlo draws per alpha
SEED = 0           # fixed seed so a re-run reproduces the printed estimates
rng = np.random.default_rng(SEED)
gen_error_list = np.zeros(len(Q_11))

for i in tqdm(range(len(Q_11))):
    Q = np.diag([0.0, 0.0, Q_33[i]])

    # Q is fixed for the whole inner loop: take both matrix square roots once here
    # instead of recomputing them for every sample.
    sqrt_Q = sqrtm(Q)
    sqrt_one_minus_Q = sqrtm(np.eye(Q.shape[0]) - Q)

    gen_error = 0.0
    for _ in range(samples):
        Z = rng.standard_normal((3, 2))
        U = rng.standard_normal((3, 2))

        omega = sqrt_Q @ Z
        # Shares the sqrt(Q) Z component with omega, plus an independent sqrt(I - Q) U part.
        Z_true = omega + sqrt_one_minus_Q @ U

        y_teacher = g(Z_true[0], Z_true[1], Z_true[2])
        y_student = g(omega[0], omega[1], omega[2])

        gen_error += np.linalg.norm(y_teacher - y_student) ** 2

    gen_error_list[i] = gen_error / samples

    # tqdm.write prints above the progress bar instead of breaking it like print().
    tqdm.write(f"alpha = {alphas_read[i]}: {gen_error_list[i]}")

  2%|▏         | 1/49 [00:28<23:00, 28.77s/it]

alpha = 0.0: 0.5841452984069229


  4%|▍         | 2/49 [00:52<20:07, 25.70s/it]

alpha = 0.0625: 0.5840106244987391


  6%|▌         | 3/49 [01:14<18:33, 24.20s/it]

alpha = 0.125: 0.5833305157070101


  8%|▊         | 4/49 [01:36<17:33, 23.40s/it]

alpha = 0.1875: 0.5851799457682222


 10%|█         | 5/49 [01:59<16:52, 23.00s/it]

alpha = 0.25: 0.5794285750184182


 12%|█▏        | 6/49 [02:21<16:20, 22.80s/it]

alpha = 0.3125: 0.5860315865001978


 14%|█▍        | 7/49 [02:44<15:52, 22.67s/it]

alpha = 0.375: 0.5854370582668691


 16%|█▋        | 8/49 [03:06<15:25, 22.58s/it]

alpha = 0.4375: 0.5488822019991927


 18%|█▊        | 9/49 [03:29<15:03, 22.60s/it]

alpha = 0.5: 0.47451091815640634


 20%|██        | 10/49 [03:51<14:36, 22.48s/it]

alpha = 0.5625: 0.38075774675651775


 22%|██▏       | 11/49 [04:13<14:14, 22.48s/it]

alpha = 0.625: 0.30137274190244573


 24%|██▍       | 12/49 [04:36<13:53, 22.53s/it]

alpha = 0.6875: 0.2411445893582055


 27%|██▋       | 13/49 [04:59<13:32, 22.56s/it]

alpha = 0.75: 0.20552106278741253


 29%|██▊       | 14/49 [05:26<14:06, 24.19s/it]

alpha = 0.8125: 0.18110879807634797


 31%|███       | 15/49 [05:52<13:55, 24.58s/it]

alpha = 0.875: 0.16365309393769106


 33%|███▎      | 16/49 [06:17<13:34, 24.69s/it]

alpha = 0.9375: 0.15398749486824034


 35%|███▍      | 17/49 [06:39<12:45, 23.93s/it]

alpha = 1.0: 0.14440068405413048


 37%|███▋      | 18/49 [07:00<11:55, 23.09s/it]

alpha = 1.0625: 0.13874995950267413


 39%|███▉      | 19/49 [07:21<11:14, 22.47s/it]

alpha = 1.125: 0.13389340602128216


 41%|████      | 20/49 [07:42<10:36, 21.96s/it]

alpha = 1.1875: 0.1317563043301489


 43%|████▎     | 21/49 [08:04<10:15, 21.97s/it]

alpha = 1.25: 0.12906308710810135


 45%|████▍     | 22/49 [08:27<10:00, 22.23s/it]

alpha = 1.3125: 0.12575084139374487


 47%|████▋     | 23/49 [08:52<09:56, 22.96s/it]

alpha = 1.375: 0.12323934912571877


 49%|████▉     | 24/49 [09:15<09:39, 23.19s/it]

alpha = 1.4375: 0.12121165810981263


 51%|█████     | 25/49 [09:41<09:32, 23.84s/it]

alpha = 1.5: 0.12074590135326058


 53%|█████▎    | 26/49 [10:07<09:22, 24.47s/it]

alpha = 1.5625: 0.11696317449049205


 55%|█████▌    | 27/49 [10:33<09:09, 24.99s/it]

alpha = 1.625: 0.1174693109362527


 57%|█████▋    | 28/49 [10:59<08:55, 25.50s/it]

alpha = 1.6875: 0.11790824450203112


 59%|█████▉    | 29/49 [11:26<08:35, 25.75s/it]

alpha = 1.75: 0.1164068270837184


 61%|██████    | 30/49 [11:52<08:10, 25.80s/it]

alpha = 1.8125: 0.11684210555907215


 63%|██████▎   | 31/49 [12:17<07:44, 25.78s/it]

alpha = 1.875: 0.1156201537241975


 65%|██████▌   | 32/49 [12:40<07:03, 24.92s/it]

alpha = 1.9375: 0.11538462882535137


 67%|██████▋   | 33/49 [13:03<06:27, 24.25s/it]

alpha = 2.0: 0.11638639669674067


 69%|██████▉   | 34/49 [13:26<05:59, 23.95s/it]

alpha = 2.0625: 0.11410085331050623


 71%|███████▏  | 35/49 [13:53<05:48, 24.91s/it]

alpha = 2.125: 0.11496516691860562


 73%|███████▎  | 36/49 [14:17<05:17, 24.40s/it]

alpha = 2.1875: 0.11354679413501043


 76%|███████▌  | 37/49 [14:39<04:44, 23.69s/it]

alpha = 2.25: 0.11500764532288894


 78%|███████▊  | 38/49 [15:04<04:25, 24.10s/it]

alpha = 2.3125: 0.11302030542398575


 80%|███████▉  | 39/49 [15:28<04:02, 24.26s/it]

alpha = 2.375: 0.11180073747927649


 82%|████████▏ | 40/49 [15:52<03:37, 24.12s/it]

alpha = 2.4375: 0.11446839035113202


 84%|████████▎ | 41/49 [16:16<03:12, 24.00s/it]

alpha = 2.5: 0.11399143546466184


 86%|████████▌ | 42/49 [16:38<02:43, 23.32s/it]

alpha = 2.5625: 0.11284892104378647


 88%|████████▊ | 43/49 [17:02<02:21, 23.56s/it]

alpha = 2.625: 0.11272537640512019


 90%|████████▉ | 44/49 [17:26<01:58, 23.77s/it]

alpha = 2.6875: 0.11339353716092643


 92%|█████████▏| 45/49 [17:49<01:33, 23.50s/it]

alpha = 2.75: 0.11221060004104903


 94%|█████████▍| 46/49 [18:13<01:11, 23.75s/it]

alpha = 2.8125: 0.11250614061369821


 96%|█████████▌| 47/49 [18:38<00:48, 24.22s/it]

alpha = 2.875: 0.11146171027407747


 98%|█████████▊| 48/49 [19:01<00:23, 23.64s/it]

alpha = 2.9375: 0.11253263114575122


100%|██████████| 49/49 [19:23<00:00, 23.74s/it]

alpha = 3.0: 0.11247131164153605





In [12]:
# One "alpha,gen_error" CSV line per alpha. A plain loop avoids building (and
# displaying) a throwaway list of None values returned by print().
for alpha, error in zip(alphas_read, gen_error_list):
    print(f"{alpha},{error}")

1

0.0,0.5841452984069229
0.0625,0.5840106244987391
0.125,0.5833305157070101
0.1875,0.5851799457682222
0.25,0.5794285750184182
0.3125,0.5860315865001978
0.375,0.5854370582668691
0.4375,0.5488822019991927
0.5,0.47451091815640634
0.5625,0.38075774675651775
0.625,0.30137274190244573
0.6875,0.2411445893582055
0.75,0.20552106278741253
0.8125,0.18110879807634797
0.875,0.16365309393769106
0.9375,0.15398749486824034
1.0,0.14440068405413048
1.0625,0.13874995950267413
1.125,0.13389340602128216
1.1875,0.1317563043301489
1.25,0.12906308710810135
1.3125,0.12575084139374487
1.375,0.12323934912571877
1.4375,0.12121165810981263
1.5,0.12074590135326058
1.5625,0.11696317449049205
1.625,0.1174693109362527
1.6875,0.11790824450203112
1.75,0.1164068270837184
1.8125,0.11684210555907215
1.875,0.1156201537241975
1.9375,0.11538462882535137
2.0,0.11638639669674067
2.0625,0.11410085331050623
2.125,0.11496516691860562
2.1875,0.11354679413501043
2.25,0.11500764532288894
2.3125,0.11302030542398575
2.375,0.1118007374792

1