In [None]:
from mg import MixtureGaussian,XORGaussian
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Mixture model shared by the cells below.
# NOTE(review): the positional arguments are presumably the two component
# means followed by a common covariance (5 * I), and pi_0 the weight of the
# first component -- confirm against mg.MixtureGaussian.
mg = MixtureGaussian(
    [3, 0],
    [10, 3],
    5 * np.eye(2),
    pi_0=1 / 4,
)


In [None]:
# Draw a sample from the mixture and overlay it on a filled contour map of
# mg._prob evaluated over a square grid.
data = mg.generate(1e4)

# Evaluation grid: n_grid points per axis on [-12, 12].
n_grid = 1000
line = np.linspace(-12, 12, num=n_grid)
X, Y = np.meshgrid(line, line)

# One (x, y) row per grid node goes into _prob; the flat result is folded back
# into the grid shape so it can be contoured.
Z = mg._prob(np.vstack((X.flatten(), Y.flatten())).T).reshape(X.shape)

plt.contourf(X, Y, Z, cmap='plasma')
plt.colorbar()

# The sample is plotted from its columns 0 and 1, coloured by 'target' and
# with the marker style chosen by 'group'.
sns.scatterplot(
    data=data,
    x=0,
    y=1,
    hue='target',
    style='group',
    palette='bright',
    alpha=0.3,
)

In [None]:
from models import LogisticRegression as gdLR

# Fit the project's LogisticRegression (aliased gdLR) with the 'square' loss on
# the current sample: every column except 'target' and 'group' is a feature,
# and 'target' is the label.
lr = gdLR(loss='square', device=None).fit(
    data.drop(columns=['target', 'group']).values,
    data['target'].values,
    lr=1e-6,
    batch_size=512,
)

In [None]:
def MSE(y_pred: np.ndarray, y_true:np.ndarray) -> float:
    """Return the mean squared error between predictions and targets.

    Parameters
    ----------
    y_pred : array-like
        Predicted values.
    y_true : array-like
        Reference values.

    Returns
    -------
    float
        Mean of the element-wise squared differences.

    Notes
    -----
    When the two inputs differ only by singleton axes (for example a column
    vector of shape ``(n, 1)`` against a flat vector of shape ``(n,)``) they
    are squeezed to a common shape first. Without this, NumPy broadcasting
    would silently build an ``(n, n)`` matrix of pairwise differences and the
    result would not be the mean squared error. Any other shape mismatch is
    left to ordinary NumPy broadcasting, so a scalar target still works and
    incompatible shapes still raise.
    """
    y_pred = np.asarray(y_pred)
    y_true = np.asarray(y_true)
    if y_pred.shape != y_true.shape:
        pred_squeezed = np.squeeze(y_pred)
        true_squeezed = np.squeeze(y_true)
        # Only realign when dropping length-1 axes makes the shapes agree.
        if pred_squeezed.shape == true_squeezed.shape:
            y_pred, y_true = pred_squeezed, true_squeezed
    return ((y_pred - y_true)**2).mean()

In [None]:
# Sweep the scale of the mixture's `sigma` attribute over several seeds and
# record, for each (seed, sigma) pair:
#   mse     - squared error of the fitted model on its own training sample
#   delta_A - mean squared gap between the model's predict_proba and mg._prob
#             on a fresh 1e6-point sample
#   mmse    - value returned by mg.mmse_estimate(1e6)
#   lambda  - placeholder, always None in this sweep
#
# Rows are collected in a list and turned into a DataFrame once at the end:
# enlarging a frame row by row with .loc is slow and gives pandas no chance to
# infer clean numeric column dtypes up front.
#
# NOTE(review): only mg.rng is seeded here; nothing in this cell seeds gdLR's
# own training, so the fit may not be reproducible -- confirm in models.py.
# NOTE(review): this loop overwrites `data` and leaves mg.sigma / mg.rng at
# their last values, so earlier cells see different state if re-run afterwards.
l = None  # value stored in the 'lambda' column; no regularisation strength is set here
records = []
for seed in range(10):
    mg.rng = np.random.default_rng(seed)
    for sigma in np.linspace(1, 10, 50):
        print(sigma)
        mg.sigma = sigma * np.eye(2)

        # Training sample: features are every column except 'target'/'group'.
        data = mg.generate(1e4)
        X_train = data.drop(['target', 'group'], axis=1).values
        y_train = data['target'].values
        LR = gdLR(loss='square', device='cpu').fit(X_train, y_train, lr=1e-6, batch_size=256)
        lin_mse = MSE(LR.predict_proba(X_train), y_train)

        # Fresh sample for the gap to mg._prob. The feature matrix is built
        # once because dropping columns copies the whole 1e6-row frame.
        test = mg.generate(1e6)
        X_test = test.drop(['target', 'group'], axis=1).values
        delta_a = np.power(LR.predict_proba(X_test) - mg._prob(X_test), 2).mean()

        # mmse_estimate is called last, after both samples have been drawn, so
        # the order of calls on mg matches the original cell.
        records.append({
            'mse': lin_mse,
            'delta_A': delta_a,
            'sigma': sigma,
            'seed': seed,
            'mmse': mg.mmse_estimate(1e6),
            'lambda': l,
        })

sigma_results = pd.DataFrame(records, columns=['mse', 'delta_A', 'sigma', 'seed', 'mmse', 'lambda'])

In [None]:
# Plot the sweep results against sigma. seaborn aggregates the repeated sigma
# values (one per seed) into a mean line with an error band.
subset = sigma_results

# Deviation term subtracted from the training MSE.
# NOTE(review): this looks like a Hoeffding-style bound sqrt(log(1/delta) / (2 n))
# with delta = 0.05 and n = 1e4 (the training sample size used in the sweep) -- confirm.
epsilon = np.sqrt(np.log(1/0.05)/2e4)

# Labels are raw strings: '\e' and '\D' are not valid Python escape sequences,
# so the non-raw form triggers a SyntaxWarning on recent Python versions.
# NOTE(review): the first curve is mse - delta_A - epsilon, but its label only
# mentions epsilon -- confirm whether the label should include \Delta_A.
sns.lineplot(data=subset, x='sigma', y=subset['mse'] - subset['delta_A'] - epsilon, label=r'MSE - $\epsilon$')
sns.lineplot(data=subset, x='sigma', y='mse', label='MSE')
sns.lineplot(data=subset, x='sigma', y='delta_A', label=r'$\Delta_A$')
sns.lineplot(data=subset, x='sigma', y='mmse', label='MMSE')
plt.ylabel('MSE')
plt.legend()