In [17]:

from experiment_information import *
from data import *
from helpers import *
from _version import __version__
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import logging
%load_ext autoreload
%autoreload 2

import matplotlib as mpl
mpl.rcParams['lines.linewidth'] = 1
mpl.rcParams['legend.fontsize'] = 13
mpl.rcParams['axes.titlesize'] = 15
mpl.rcParams['axes.labelsize'] = 13
mpl.rcParams['xtick.labelsize'] = 10
mpl.rcParams['ytick.labelsize'] = 10

# Get a logger
logger = logging.getLogger()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [27]:
from state_evolution import adversarial_generalization_error_logistic, adversarial_generalization_error_overlaps, adversarial_generalization_error_overlaps_test

In [19]:
data_model = VanillaGaussianDataModel(1000,logger,source_pickle_path="")

In [20]:
# Let's load some ms and qs

with DatabaseHandler(logger,"experiments/experiments.db") as dbHandler:
    df_state_evolution = dbHandler.get_state_evolutions()

In [21]:
ms = df_state_evolution["m"]
qs = df_state_evolution["q"]

In [219]:
tau = 0
rho = 1
n = 1000
epsilon = 0.5

errors = []
errors_test = []

for m, q in zip(ms, qs):

    standard = adversarial_generalization_error_logistic(m,q,rho,tau,epsilon*np.sqrt(q))

    overlaps = OverlapSet()
    overlaps.m = m
    overlaps.q = q
    overlaps.N = q
    overlaps.A = q

    task = Task(None,None,None,None,None,epsilon,None,tau,None,None,None,None,None)

    data_model.rho = rho

    new = adversarial_generalization_error_overlaps(overlaps, task, data_model)    

    test = adversarial_generalization_error_overlaps_test(overlaps, task, data_model)

    print(standard, new, test)

    squared_difference = (standard - new)**2
    # print(squared_difference)
    errors.append(squared_difference)

    squared_difference_test = (standard - test)**2
    # print(squared_difference_test)
    errors_test.append(squared_difference_test)

# print the mean squared error and the standard deviation
print("Mean squared error of standard and new")
print(np.mean(errors))
print(np.std(errors))

print("Mean squared error of standard and test")
print(np.mean(errors_test))
print(np.std(errors_test))



0.5480542559636891 0.5480542559675968 0.5480542559675969
0.5370315396530518 0.5370315396571691 0.5370315396571691
0.5480542559636891 0.5480542559675968 0.5480542559675969
0.5370315396530518 0.5370315396571691 0.5370315396571691
0.548361744558192 0.5483617445620939 0.5483617445620939
0.5370710014947846 0.5370710014989006 0.5370710014989006
0.39351746466966747 0.3935174646383093 0.3935174646383093
0.42566947554123213 0.42566947555031254 0.42566947555031265
0.42858836601062267 0.4285883660194063 0.42858836601940625
0.40515749669271245 0.40515749670469187 0.4051574967046919
0.38368947377385343 0.38368947378728685 0.3836894737872868
0.4002356214725854 0.4002356214855553 0.40023562148555525
0.40359020110723026 0.4035902011195104 0.4035902011195105
0.39021396144351883 0.390213961444217 0.39021396144421716
0.3907293484616198 0.3907293484769721 0.39072934847697205
0.38297395927623884 0.3829739592765435 0.38297395927654343
0.39365767106795896 0.3936576710405707 0.3936576710405708
0.3858809646547

In [172]:
from scipy.special import erfc
from scipy.integrate import quad
def exp_erf_integral(a,b):
    def integrand(x):
        return np.exp(-x**2)*erfc(a*x)
    return quad(integrand,0,b)[0]

In [173]:
from scipy.special import owens_t
from scipy.special import erf
def solution(a,b):
    r = -4*np.pi*np.sqrt(a**2)*owens_t(np.sqrt(2)*np.sqrt(a**2)*b,1/np.sqrt(a**2))
    r += np.pi*a*erf(b)*erfc(a*b)
    r += 2*a*np.arctan(1/a)
    r /= 2*np.sqrt(np.pi)*a
    return r

In [174]:
epsilons = np.linspace(0,1,100)
activations = -ms/np.sqrt(qs) / (np.sqrt(2 * (1- ms**2/qs)))

for eps in epsilons:
    for act in activations:
        i = exp_erf_integral(act,eps)
        s = solution(act,eps)
        if np.abs(i-s) > 1e-15:
            print("Error in integral")
            print(i,s)
            print(act,eps)

In [175]:
def custom_owen(x,a):
    def integrand(t):
        return np.exp( -x**2*(1+t**2)*0.5 )/(1+t**2)
    return quad(integrand,0,a)[0]/(2*np.pi)

In [176]:
for eps in epsilons:
    for act in activations:
        i = custom_owen(act,eps)
        s = owens_t(act,eps)
        if np.abs(i-s) > 1e-15:
            print("Error in integral")
            print(i,s)
            print(act,eps)

In [177]:
quad(lambda x: x**2/2,0,1)[0]

0.16666666666666666

In [178]:
quad(lambda x: x**2*np.sqrt(2),0,1/np.sqrt(2))[0]

0.16666666666666663

In [204]:
from helpers import gaussian
def adv_term_integral(m,q,epsilon_term, rho):


    def integrand(xi):
        
        z_0 = erfc(-m*xi / np.sqrt( 2 * ( rho*q - m**2 ) )  )

        return z_0 * gaussian(xi)


    I = quad(lambda xi: integrand(xi),0,epsilon_term,limit=500)[0]

    # return exp_erf_integral(-m/np.sqrt(2*(rho*q-m**2)),epsilon_term/np.sqrt(2))/np.sqrt(2*np.pi)

    return I

def adv_term_solution(m,q,epsilon_term,rho):

    # return exp_erf_integral(-m/np.sqrt((rho*q-m**2)),epsilon_term/np.sqrt(2))/np.sqrt(np.pi)

    a = -m / np.sqrt((q* rho - m**2))
    b = epsilon_term/np.sqrt(2)
    
    return solution(a,b)/np.sqrt(np.pi)

In [205]:
for eps in epsilons:
    for m,q in zip(ms,qs):
        i = adv_term_integral(m,q,eps,rho)
        s = adv_term_solution(m,q,eps,rho)
        if np.abs(i-s) > 1e-15:
            print("Error in integral")
            print(i,s)
            print(m,q,eps)

In [224]:
for m, q in zip(ms, qs):
    angle = np.arccos(m/np.sqrt(q))/np.pi

    alternative = np.arctan( m/np.sqrt(q - m**2) )/np.pi

    print(angle, alternative)

0.3354700036587427 0.16452999634125726
0.32241213119550816 0.17758786880449184
0.3354700036587427 0.16452999634125726
0.32241213119550816 0.17758786880449184
0.3358327561412239 0.16416724385877607
0.3224590771222636 0.17754092287773643
0.11526710256302988 0.3847328974369701
0.17701585999376337 0.32298414000623665
0.18154146499135848 0.3184585350086415
0.14156438256887668 0.35843561743112334
0.0713085253956467 0.42869147460435325
0.13141623971557945 0.3685837602844206
0.13844110505474552 0.36155889494525456
0.10551373624876552 0.3944862637512346
0.10716218330137016 0.39283781669862977
0.05307986465496306 0.4469201353450369
0.11564552650814602 0.38435447349185403
0.08800887116436455 0.4119911288356356
