In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np

In [None]:
def binKL(p, q):
    """Binary KL divergence kl(p || q) between Bernoulli(p) and Bernoulli(q).

    Evaluates p*log(p/q) + (1-p)*log((1-p)/(1-q)) in nats, using the
    convention 0*log(0/q) = 0 so that p = 0 and p = 1 give finite values
    instead of NaN.

    Args:
        p: First Bernoulli parameter in [0, 1]; scalar or array-like.
        q: Second Bernoulli parameter in [0, 1]; scalar or array-like that
            broadcasts against ``p``.

    Returns:
        kl(p || q) as a NumPy scalar for scalar inputs or an array otherwise.
        The value is ``inf`` where q assigns zero mass to an outcome that has
        positive mass under p.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    # Both branches of np.where are evaluated, so silence the warnings raised
    # by the entries that are masked out (or that are legitimately infinite).
    with np.errstate(divide='ignore', invalid='ignore'):
        success_term = np.where(p > 0, p * np.log(p / q), 0.0)
        failure_term = np.where(p < 1, (1 - p) * np.log((1 - p) / (1 - q)), 0.0)
    return success_term + failure_term

In [None]:
def classicLB(x, b):
    """Return 2*(b - x)**2, plotted in this notebook as "Classic Pinsker (3)".

    Used as a lower bound on kl(x || b); works on scalars or NumPy arrays.
    """
    gap = b - x
    return 2 * gap**2

def quadLB(x, b):
    """Return (b - x)**2 / (2*b), plotted in this notebook as "Refined Pinsker (4)".

    Used as a lower bound on kl(x || b); ``b`` must be non-zero.
    """
    squared_gap = (b - x)**2
    return squared_gap / (2 * b)

def newLB(x, b):
    """Return (b - x)**2 / (2*b*(1 - x)), plotted as "Tighter R. Pinsker (6)".

    Used as a lower bound on kl(x || b); undefined for b == 0 or x == 1.
    """
    # Earlier variant, kept for reference only:
    #   1.5*((b-x)**2 - b**2*x**2)/(2*b*(1-x))
    numerator = (b - x)**2
    denominator = 2 * b * (1 - x)
    return numerator / denominator
#(1-x)*(b + b**2/2 + b**3/3 - x - x**2/2 - x**3/3) + x*(np.log(x/b))

def reversedTS(x, b):
    """Tolstikhin-Seldin bound solved for k, plotted as "Inverted TS (5)".

    Returns b/2 - x/4 - sqrt(b*x/4 - 3*x**2/16), i.e. the k for which
    ``TSbound(x, k) == b``. Algebraically equal to
    (sqrt(b/2 - 3*x/8) - sqrt(x/8))**2.
    """
    linear_part = b/2 - x/4
    root_part = np.sqrt(b*x/4 - 3 * x**2 / 16)
    return linear_part - root_part

def reversedRTS(x, b):
    """Refined TS bound solved for k, plotted as "Refined TS bound (7)".

    Returns b - sqrt(2*b*x - x**2), i.e. the k for which
    ``RefinedTS(x, k) == b``. Algebraically equal to
    (sqrt(b - x/2) - sqrt(x/2))**2.
    """
    root_part = np.sqrt(2*b*x - x**2)
    return b - root_part

In [None]:
# Global matplotlib styling shared by every figure in this notebook.

# Alternative colorblind-friendly palette (not active). Assign it to `colors`
# before the rcParams update below to use it instead.
colors_colorblind = ['#0072B2', '#D55E00', '#009E73', '#CC79A7', '#F0E442', '#56B4E9']

# Active palette: Matplotlib's default color cycle with a custom dark green
# ("#1C400B") in the sixth slot. Later cells index into this list directly
# (e.g. colors[5]), so the order matters.
colors = ['#1f77b4', '#ff7f0e', '#2ca02c', '#d62728', '#9467bd', "#1C400B", '#e377c2', '#7f7f7f', '#bcbd22', '#17becf']

plt.rcParams.update({
    # Color cycle used by successive plot calls.
    'axes.prop_cycle': plt.cycler(color=colors),
    # Thicker lines.
    'lines.linewidth': 3,
    # Larger base text, axis labels/titles and tick labels.
    'font.size': 12,
    'axes.labelsize': 16,
    'axes.titlesize': 16,
    'xtick.labelsize': 14,
    'ytick.labelsize': 14,
    # Light dashed grid on every axes.
    'axes.grid': True,
    'grid.alpha': 0.5,
    'grid.linestyle': '--',
})



In [None]:
# Lower bounds on kl(x || b) as a function of the true risk b, for a small
# empirical loss x.
x = 0.02
bs = np.arange(x, 0.04, 0.00001)

# (bound function, legend label). Plot order determines the color each curve
# takes from the color cycle, so keep it stable across figures.
lower_bounds = [
    (classicLB, "Classic Pinsker (3)"),
    (reversedTS, "Inverted TS (5)"),
    (newLB, "Tighter R. Pinsker (6)"),
    (reversedRTS, "Refined TS bound (7)"),
    (quadLB, "Refined Pinsker (4)"),
    (binKL, "Binary KL"),
]

fig, ax = plt.subplots()
for bound, label in lower_bounds:
    # Every bound is built from NumPy ufuncs, so it accepts the whole array.
    ax.plot(bs, bound(x, bs), label=label)

ax.legend()
# Raw strings: "\h" is not a valid Python escape sequence.
ax.set_title(r'$\hat{L}_S(Q)=0.02$')
ax.set_xlabel('$L(Q)$')
ax.set_ylabel(r'Lower bound on kl$(\hat{L}_S(Q)||L(Q))$')
ax.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
ax.ticklabel_format(style='sci', axis='x', scilimits=(0,0))
plt.show()

In [None]:
# Lower bounds on kl(x || b) as a function of the true risk b, for a larger
# empirical loss x.
x = 0.2
bs = np.arange(x, 0.4, 0.00001)

# (bound function, legend label). Plot order determines the color each curve
# takes from the color cycle, so keep it stable across figures.
lower_bounds = [
    (classicLB, "Classic Pinsker (3)"),
    (reversedTS, "Inverted TS (5)"),
    (newLB, "Tighter R. Pinsker (6)"),
    (reversedRTS, "Refined TS bound (7)"),
    (quadLB, "Refined Pinsker (4)"),
    (binKL, "Binary KL"),
]

fig, ax = plt.subplots()
for bound, label in lower_bounds:
    # Every bound is built from NumPy ufuncs, so it accepts the whole array.
    ax.plot(bs, bound(x, bs), label=label)

ax.legend()
# Raw strings: "\h" is not a valid Python escape sequence.
ax.set_title(r'$\hat{L}_S(Q)=0.2$')
ax.set_xlabel('$L(Q)$')
ax.set_ylabel(r'Lower bound on kl$(\hat{L}_S(Q)||L(Q))$')
ax.ticklabel_format(style='sci', axis='y', scilimits=(0,0))
plt.show()

In [None]:
def KLinv(x, k, d=1e-10):
    """Upper inverse of the binary KL: sup{b in [x, 1] | kl(x||b) < k}.

    Found by bisection on [x, 1] using ``binKL``.

    Args:
        x: Left end of the search interval (first argument of the kl).
            It is not validated; the search assumes x <= 1.
        k: Threshold on kl(x || b).
        d: Bisection stops once the bracket is no wider than this.

    Returns:
        The lower end of the final bracket, i.e. a value within ``d`` below
        the supremum.
    """
    lo, hi = x, 1
    while hi - lo > d:
        mid = (lo + hi) / 2
        # Raise the lower end only while the midpoint is strictly inside the
        # kl-ball; any other outcome of the comparison shrinks from above.
        if binKL(x, mid) < k:
            lo = mid
        else:
            hi = mid

    return lo

def TSbound(x, k):
    """Return x + sqrt(2*x*k) + 2*k, plotted as "Tolstikhin-Seldin (5)".

    Used in this notebook as an upper bound on L(Q) given empirical loss x
    and complexity term k.
    """
    cross_term = np.sqrt(2*x*k)
    return x + cross_term + 2*k

# b - 3*x/8 - np.sqrt(b*x/2 - x**2/4) > k
def RefinedTS(x, k):
    """Return x + sqrt(2*x*k) + k, plotted as "Refined TS bound (7)".

    Same form as ``TSbound`` but with k instead of 2*k as the last term.
    """
    cross_term = np.sqrt(2*x*k)
    return x + cross_term + k

def ClassicInv(x, k):
    """Return x + sqrt(k/2), plotted as "Inv. Classic Pinsker (3)".

    This is the largest b satisfying ``classicLB(x, b) <= k``.
    """
    slack = np.sqrt(k/2)
    return x + slack
    
def QuadInv(x, k):
    """Return (sqrt(x + k/2) + sqrt(k/2))**2, plotted as "Inv. Refined Pinsker (4)".

    This is the largest b satisfying ``quadLB(x, b) <= k``; expanded it reads
    x + k + sqrt(2*x*k + k**2).
    """
    root_sum = np.sqrt(x + k/2) + np.sqrt(k/2)
    return root_sum**2

def RefinedQuad(x, k):
    """Return ``QuadInv(x, k)`` reduced by k/2."""
    quad_bound = QuadInv(x, k)
    return quad_bound - k/2

def NewInv(x, k):
    """Return the bound plotted as "Inv. Tighter R. Pinsker (6)".

    Computes (sqrt(x + t) + sqrt(t))**2 with t = k*(1-x)/2, the largest b
    satisfying ``newLB(x, b) <= k``. Expanded form:
    x + (1-x)*k + sqrt(2*x*(1-x)*k + k**2 * (1-x)**2).
    """
    root_sum = np.sqrt(x + k*(1-x)/2) + np.sqrt(k*(1-x)/2)
    return root_sum**2

# def RefinedNew(x, k):
#     return NewInv(x, k) - k/2

In [None]:
# Small empirical loss: upper bounds on L(Q) as a function of the complexity
# term K.
x = 0.02
ks = np.arange(0, 0.4, 0.00001)

# (closed-form bound, legend label). Plot order determines the color each
# curve takes from the color cycle, so keep it stable across figures.
closed_form_bounds = [
    (ClassicInv, "Inv. Classic Pinsker (3)"),
    (TSbound, "Tolstikhin-Seldin (5)"),
    (NewInv, "Inv. Tighter R. Pinsker (6)"),
    (RefinedTS, "Refined TS bound (7)"),
    (QuadInv, "Inv. Refined Pinsker (4)"),
]

fig, ax = plt.subplots()
for bound, label in closed_form_bounds:
    # Closed-form bounds are built from NumPy ufuncs and accept the array.
    ax.plot(ks, bound(x, ks), label=label)
# KLinv runs a scalar bisection, so it has to be evaluated point by point.
ax.plot(ks, [KLinv(x, k) for k in ks], label="kl$^{-1}$")

ax.legend()
# Raw string: "\h" is not a valid Python escape sequence.
ax.set_title(r'$\hat{L}_S(Q)=0.02$')
ax.set_ylabel('Upper bound on $L(Q)$')
ax.set_xlabel('$K$')
plt.show()

In [None]:
# Larger empirical loss: upper bounds on L(Q) as a function of the complexity
# term K.
x = 0.2
ks = np.arange(0, 0.4, 0.00001)

# (closed-form bound, legend label). Plot order determines the color each
# curve takes from the color cycle, so keep it stable across figures.
closed_form_bounds = [
    (ClassicInv, "Inv. Classic Pinsker (3)"),
    (TSbound, "Tolstikhin-Seldin (5)"),
    (NewInv, "Inv. Tighter R. Pinsker (6)"),
    (RefinedTS, "Refined TS bound (7)"),
    (QuadInv, "Inv. Refined Pinsker (4)"),
]

fig, ax = plt.subplots()
for bound, label in closed_form_bounds:
    # Closed-form bounds are built from NumPy ufuncs and accept the array.
    ax.plot(ks, bound(x, ks), label=label)
# KLinv runs a scalar bisection, so it has to be evaluated point by point.
ax.plot(ks, [KLinv(x, k) for k in ks], label="kl$^{-1}$")

# Axis labels are set larger than the rcParams default for this figure.
ax.set_xlabel('$K$', fontsize=18)
ax.set_ylabel('Upper bound on $L(Q)$', fontsize=16)

ax.legend()
# Raw string: "\h" is not a valid Python escape sequence.
ax.set_title(r'$\hat{L}_S(Q)=0.2$')
plt.show()

In [None]:
# Region comparison: for every (empirical loss, K) pair on a 1000 x 1000 grid,
# color the point by whichever closed-form bound gives the smallest upper
# bound on L(Q).

# Legend label for each row index of the stacked bounds below.
dct={
    0: 'Classic Pinsker (3)',
    1: 'TS (5)',
    2: 'Tighter R. Pinsker (6)',
    3: 'Refined TS bound (7)',
    4: 'R. Pinsker (4)',
}

grid = np.linspace(0.01,1,1000)
# indexing='ij' makes x the slow axis, matching a nested "for x: for k:" loop.
x_grid, k_grid = np.meshgrid(grid, grid, indexing='ij')
xs = x_grid.ravel()
ks = k_grid.ravel()

# Row order must match the keys of dct. The bounds are built from NumPy
# ufuncs, so each is evaluated on the whole grid in one call.
bounds = np.stack([
    ClassicInv(xs, ks),
    TSbound(xs, ks),
    NewInv(xs, ks),
    RefinedTS(xs, ks),
    QuadInv(xs, ks),
])
# On ties argmin returns the first (lowest-index) bound.
cs = np.argmin(bounds, axis=0)

fig, ax = plt.subplots()
for c in np.unique(cs):
    ix = cs == c
    ax.scatter(xs[ix], ks[ix], color=colors[c], s = 0.05)
    # The data markers are tiny; an empty scatter with a large square marker
    # provides a readable legend entry in the same color.
    ax.scatter([], [], color=colors[c], label=dct[c], marker='s', s=100)

ax.legend()

# Raw string: "\h" is not a valid Python escape sequence.
ax.set_xlabel(r'$\hat{L}_S(Q)$')
ax.set_ylabel('$K$')

plt.show()

In [None]:
# Region comparison restricted to two bounds: for every (empirical loss, K)
# pair on a 1000 x 1000 grid, color the point by whichever of Classic Pinsker
# and Tolstikhin-Seldin gives the smaller upper bound on L(Q).

# Legend label for each row index of the stacked bounds below.
dct={
    0:'Classic Pinsker (3)',
    1:'Tolstikhin-Seldin (5)'
}

grid = np.linspace(0.01,1,1000)
# indexing='ij' makes x the slow axis, matching a nested "for x: for k:" loop.
x_grid, k_grid = np.meshgrid(grid, grid, indexing='ij')
xs = x_grid.ravel()
ks = k_grid.ravel()

# Row order must match the keys of dct; on ties argmin picks the first row.
bounds = np.stack([ClassicInv(xs, ks), TSbound(xs, ks)])
cs = np.argmin(bounds, axis=0)

fig, ax = plt.subplots()
for c in np.unique(cs):
    ix = cs == c
    ax.scatter(xs[ix], ks[ix], color=colors[c], s = 0.05)
    # The data markers are tiny; an empty scatter with a large square marker
    # provides a readable legend entry in the same color.
    ax.scatter([], [], color=colors[c], label=dct[c], marker='s', s=100)

ax.legend()

# Raw string: "\h" is not a valid Python escape sequence.
ax.set_xlabel(r'$\hat{L}_S(Q)$')
ax.set_ylabel('$K$')
plt.show()

### MNIST RESULTS

In [None]:
# Load the multi-seed MNIST experiment results.
# (A previous run used 'experiment_output.csv' instead.)
df = pd.read_csv('experiment_output_mnist-multiseed.csv')

# Drop constant columns: keep a column only if at least one row differs from
# the first row.
df = df.loc[:, (df != df.iloc[0]).any()]
# Optional filters, currently disabled:
#df = df[df['mc_samples']==15000]
#df = df[df['kl_penalty']==1]
display(df.sort_values(['Risk_01']))

In [None]:
# Cross-entropy risk certificate against the empirical stochastic loss, one
# scatter series per training objective.

# (value of df['objective'], legend label, color). The 'fnew' objective
# ("Tighter R. Pinsker") is deliberately not plotted.
series = [
    ('fquad', 'Refined Pinsker', colors[2]),
    ('f_rts', 'Refined TS bound', colors[3]),
    ('fgrad', 'Maurer bound', colors[5]),
]

fig, ax = plt.subplots()
for objective, label, color in series:
    runs = df[df['objective']==objective]
    # Attach the label to the artist so the legend cannot drift out of sync
    # with the plotting order.
    ax.scatter(runs['Stch loss'], runs['Risk_CE'], s=30, alpha = 0.5, color=color, label=label)

# Raw string: "\h" is not a valid Python escape sequence.
ax.set_xlabel(r'$\hat{L}_S^{xe}(Q)$')
ax.set_ylabel('Risk certificate on $L^{xe}(Q)$')
ax.legend()
plt.show()

In [None]:
# 0-1 risk certificate against the empirical stochastic 0-1 error, one scatter
# series per training objective.

# (value of df['objective'], legend label, color). The 'fnew' objective
# ("Tighter R. Pinsker") is deliberately not plotted.
series = [
    ('fquad', 'Refined Pinsker', colors[2]),
    ('f_rts', 'Refined TS bound', colors[3]),
    ('fgrad', 'Maurer bound', colors[5]),
]

fig, ax = plt.subplots()
for objective, label, color in series:
    runs = df[df['objective']==objective]
    # Attach the label to the artist so the legend cannot drift out of sync
    # with the plotting order.
    ax.scatter(runs['Stch 01 error'], runs['Risk_01'], s=30, alpha = 0.5, color=color, label=label)

# Raw string: "\h" is not a valid Python escape sequence.
ax.set_xlabel(r'$\hat{L}_S^{01}(Q)$')
ax.set_ylabel('Risk certificate on $L^{01}(Q)$')
ax.legend()
plt.show()

In [None]:
# 0-1 risk certificate against the empirical stochastic 0-1 error for the
# '*_acc' training objectives, one scatter series per objective.

# (value of df['objective'], legend label, color). The 'fnew_acc' objective
# ("Tighter R. Pinsker") is deliberately not plotted.
series = [
    ('fquad_acc', 'Refined Pinsker', colors[2]),
    ('f_rts_acc', 'Refined TS bound', colors[3]),
    ('fgrad_acc', 'Maurer bound', colors[5]),
]

fig, ax = plt.subplots()
for objective, label, color in series:
    runs = df[df['objective']==objective]
    # Attach the label to the artist so the legend cannot drift out of sync
    # with the plotting order.
    ax.scatter(runs['Stch 01 error'], runs['Risk_01'], s=30, alpha = 0.5, color=color, label=label)

# Raw string: "\h" is not a valid Python escape sequence.
ax.set_xlabel(r'$\hat{L}_S^{01}(Q)$')
ax.set_ylabel('Risk certificate on $L^{01}(Q)$')
ax.legend()
plt.show()

In [None]:
# Empirical cross-entropy loss against empirical 0-1 error for both datasets,
# with one linear fit over the pooled runs.

# Read each results file once and pool the frames for the regression.
_df_mnist = pd.read_csv('experiment_output_mnist-multiseed.csv')
_df_cifar = pd.read_csv('experiment_output_cifar10-multiseed.csv')
_df = pd.concat([_df_mnist, _df_cifar]).reset_index(drop=True)

# Optional filters on the pooled frame, currently disabled:
# _df = _df[_df['Stch loss'] < 0.15]
# _df = _df[_df['kl_penalty'] == 1]
# _df = _df[_df['Sigma'] == 0.04]

# Degree-1 least-squares fit: 'Stch 01 error' ~ m * 'Stch loss' + b.
m, b = np.polyfit(_df['Stch loss'], _df['Stch 01 error'], 1)

fig, ax = plt.subplots()
ax.scatter(_df_mnist['Stch loss'], _df_mnist['Stch 01 error'], s=15, alpha = 0.5, label='MNIST')
ax.scatter(_df_cifar['Stch loss'], _df_cifar['Stch 01 error'], s=15, alpha = 0.5, label='CIFAR10')

# A straight line only needs its two end points over the observed loss range.
loss_range = np.array([_df['Stch loss'].min(), _df['Stch loss'].max()])
ax.plot(loss_range, m*loss_range + b, color='black', label='Linear regression', alpha=0.5)

# Raw strings: "\h" is not a valid Python escape sequence.
ax.set_xlabel(r'$\hat{L}_S^{xe}(Q)$')
ax.set_ylabel(r'$\hat{L}_S^{01}(Q)$')
# Labels live on the artists, so every entry (including the regression line)
# appears in the legend regardless of how Matplotlib orders the handles.
ax.legend()
plt.show()


### CIFAR10


In [None]:
# df = pd.concat([pd.read_csv('experiment_output_cifar10-4l-150k.csv'),
#                 pd.read_csv('experiment_output_cifar10-4l-small-150k.csv'),
#                 pd.read_csv('experiment_output_cifar10-5l-150k.csv'),
#                 pd.read_csv('experiment_output_cifar10-5l-small-150k.csv'),
#                 ]).reset_index(drop=True)


# print(df.columns)

# df = df.loc[:, (df != df.iloc[0]).any()]  #quitamos columnas ctes.
# display(df.sort_values(['Risk_01']))#[['objective', 'sigma_prior', 'Risk_CE', 'Risk_01', 'KL']])

In [None]:
# plt.scatter(df[df['objective']=='fquad_acc']['Stch 01 error'], df[df['objective']=='fquad_acc']['Risk_01'], s=30, alpha = 0.5)
# #plt.scatter(df[df['objective']=='fnew_acc']['Stch 01 error'], df[df['objective']=='fnew_acc']['Risk_01'], s=30, alpha = 0.5)
# plt.scatter(df[df['objective']=='f_rts_acc']['Stch 01 error'], df[df['objective']=='f_rts_acc']['Risk_01'], s=30, alpha = 0.5)
# plt.scatter(df[df['objective']=='fgrad_acc']['Stch 01 error'], df[df['objective']=='fgrad_acc']['Risk_01'], s=30, alpha = 0.5)
# plt.xlabel('$\hat{L}_S^{01}(Q)$')
# plt.ylabel('Risk certificate on $L^{01}(Q)$')
# plt.legend(['Refined Pinsker', 'Refined TS bound', 'Maurer bound'])  #, 'Tighter R. Pinsker'
# #plt.title('01 error and certificate for some experiments')
# plt.show()

# _df = df[df['Stch 01 error'] < 0.5] 
# plt.scatter(_df[_df['objective']=='fquad_acc']['Stch 01 error'], _df[_df['objective']=='fquad_acc']['Risk_01'], s=30, alpha = 0.5)
# #plt.scatter(_df[_df['objective']=='fnew_acc']['Stch 01 error'], _df[_df['objective']=='fnew_acc']['Risk_01'], s=30, alpha = 0.5)
# plt.scatter(_df[_df['objective']=='f_rts_acc']['Stch 01 error'], _df[_df['objective']=='f_rts_acc']['Risk_01'], s=30, alpha = 0.5)
# plt.scatter(_df[_df['objective']=='fgrad_acc']['Stch 01 error'], _df[_df['objective']=='fgrad_acc']['Risk_01'], s=30, alpha = 0.5)
# plt.xlabel('$\hat{L}_S^{01}(Q)$')
# plt.ylabel('Risk certificate on $L^{01}(Q)$')
# plt.legend(['$\\tilde{f}_{rp}$: Refined Pinsker', '$\\tilde{f}_{rts}$: Refined TS bound', '$\\tilde{f}_{mb}$: Maurer bound'])  #, 'Tighter R. Pinsker'
# #plt.title('01 error and certificate for some experiments - zoom in')
# plt.show()