In [None]:
!pip install brokenaxes
import matplotlib.pyplot as plt
from scipy.optimize import newton
from math import sqrt, log, exp
from functools import partial
import numpy as np
from brokenaxes import brokenaxes

In [None]:
def f(x, delta):
  """Residual whose root in ``x`` is reported as m* for a confidence level ``delta``.

  Evaluates
      -2*log(1/delta) + (2*log(2*sqrt(x)/delta) - 1) * (2*sqrt(x))**(-1/x)
  where the last factor is computed as exp((-1/x) * log(2*sqrt(x))).
  The notebook finds its root with ``newton(partial(f, delta=...), 5)`` and
  labels that root ``$m^*$`` on the plots.

  Args:
    x: candidate value of m; must be > 0, otherwise math.sqrt/math.log raise.
    delta: confidence parameter; must be > 0.

  Returns:
    The scalar residual (float).
  """
  baseline = -2*log(1/delta)
  two_root_x = 2*sqrt(x)
  bracket = 2*log(two_root_x/delta) - 1
  # (2*sqrt(x)) ** (-1/x), written via exp/log as in the derivation.
  damping = exp((-1/x)*log(two_root_x))
  return baseline + bracket*damping

# Grid of confidence levels: 1e-5, 2e-5, ..., 9999e-5 (just under 0.1).
deltas = [step*1e-5 for step in range(1, int(1e4))]

# For every delta, solve f(m, delta) = 0 with Newton/secant iterations from m = 5.
soluce = [newton(partial(f, delta=delta), 5) for delta in deltas]

In [None]:
plt.rcParams.update({'font.size': 15})

# Confidence levels that get a red marker and an "m*=..." label on the curve.
MARKED_DELTAS = [0.1, 0.01, 0.05, 0.001, 0.0001, 0.00001]

fig, ax = plt.subplots(figsize=(8, 5), dpi=400)

# Curve of the Newton solutions m*(delta) over the `deltas` grid.
ax.plot(deltas, soluce)

# Highlighted points are re-solved directly: e.g. 0.1 lies outside the grid (max 0.09999).
for marked_delta in MARKED_DELTAS:
  m_star = newton(partial(f, delta=marked_delta), 5)
  ax.annotate(f"$m^*=${m_star:.2f}", size=15,
              xy=(marked_delta, m_star), xytext=(5, 5), textcoords='offset points')
  ax.plot(marked_delta, m_star, 'ro')

ax.set_xscale('log')
# Raw strings: "\d" is an invalid escape sequence (SyntaxWarning on Python >= 3.12).
ax.set_xlabel(r"$\delta$")
ax.set_ylabel(r"$m^*$")
ax.set_xlim(1e-5 - 1e-6, 0.11)
# Bottom/left spines are visible by default; only hide top/right.
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
fig.savefig("fig_mstar.jpg", bbox_inches='tight')

In [None]:
def k(m, delta):
  """Return K(m, delta) = delta**(1/m) - (delta / (2*sqrt(m)))**(1/m).

  Both powers are evaluated through exp/log:
  exp(-(1/m)*log(1/delta)) and exp(-(1/m)*log(2*sqrt(m)/delta)).
  The notebook plots this as K(m*, delta) and as K(m, delta) versus dataset size.

  Args:
    m: positive size parameter (nonzero; m=0 raises ZeroDivisionError).
    delta: positive confidence parameter.

  Returns:
    The float difference of the two terms.
  """
  neg_inv_m = -1/m
  first_term = exp(neg_inv_m*log(1/delta))
  second_term = exp(neg_inv_m*log(2*sqrt(m)/delta))
  return first_term - second_term

# K(m*, delta) along the whole delta grid, pairing each delta with its Newton solution m*.
ks = [k(m_star, delta) for m_star, delta in zip(soluce, deltas)]

fig, ax = plt.subplots(figsize=(8, 5), dpi=400)

ax.plot(deltas, ks)

# Highlighted deltas -> (x, y) label offset in points.
# Offsets are presumably hand-tuned so the labels stay readable.
K_LABEL_OFFSETS = {
    0.1: (5, 0),
    0.01: (5, -6),
    0.05: (5, -5),
    0.001: (5, -10),
    0.0001: (5, -10),
    0.00001: (5, -12),
}

for marked_delta, label_offset in K_LABEL_OFFSETS.items():
  m_star = newton(partial(f, delta=marked_delta), 5)
  k_sol = k(m_star, marked_delta)
  # rf-string: keeps "\delta" literal (avoids the invalid "\d" escape) while formatting k_sol.
  ax.annotate(rf"$K(m^*,\delta)=${k_sol:.2f}", size=15,
              xy=(marked_delta, k_sol), xytext=label_offset, textcoords='offset points')
  ax.plot(marked_delta, k_sol, 'ro')

ax.set_xscale('log')
ax.set_xlabel(r"$\delta$")
ax.set_ylabel(r"$K(m^*,\delta)$")
ax.set_xlim(1e-5 - 1e-6, 0.11)
ax.set_ylim(0.05, 0.17)
# Bottom/left spines are visible by default; only hide top/right.
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
fig.savefig("fig_k_n_delta.jpg", bbox_inches='tight')

In [None]:
# Keep a handle on this figure: it is the one that must be drawn and saved below.
fig = plt.figure(figsize=(9, 5), dpi=400)

datasets = ["MNIST56", "MNIST08", "MNIST49", "MNIST23", "MNIST17", "MNIST (50%)", "MNIST (20%)", "MNIST (10%)", "AMAZON"]
ns = [10206, 10597, 10612, 10881, 11707, 24000, 42000, 48000, 144000]

# K(m, delta=0.01) for each dataset size m.
k_datasets = [k(n, 0.01) for n in ns]

# Broken x-axis: most sizes are below 60k, while AMAZON sits alone near 144k.
ax = brokenaxes(xlims=((1000, 60000), (120000, 150000),), ylims=((0.0, max(k_datasets) + 0.00005),), width_ratios=[2, 1])
ax.scatter(ns, k_datasets, color="red")

# Continuous K(m, 0.01) curve sampled every 100 sizes.
m_grid = range(1, 160000, 100)
k_all = [k(m, 0.01) for m in m_grid]
ax.plot(m_grid, k_all)

# The five MNIST pairs have nearly equal sizes (~10k-12k), so their labels are
# stacked at x = LABEL_X (data coords) and linked to the point by an arrow.
# Entries are dataset -> (label y in data coords, arrow connectionstyle).
LABEL_X = 20000
ARROW_LABELS = {
    "MNIST56": (0.00055, "arc3,rad=0.2"),
    "MNIST08": (0.00051, "arc3,rad=0.2"),
    "MNIST49": (0.00047, "arc3"),
    "MNIST23": (0.00043, "arc3,rad=-0.2"),
    "MNIST17": (0.00039, "arc3,rad=-0.2"),
}

for dataset, n, k_value in zip(datasets, ns, k_datasets):
  if dataset in ARROW_LABELS:
    label_y, connection = ARROW_LABELS[dataset]
    placement = dict(xytext=(LABEL_X, label_y), textcoords="data", rotation=0,
                     arrowprops=dict(arrowstyle="->", connectionstyle=connection))
  else:
    # Well-separated points: small rotated label right next to the marker, no arrow.
    placement = dict(xytext=(5, 5), textcoords='offset points', rotation=45, arrowprops=None)
  ax.annotate(dataset, size=15, xy=(n, k_value), **placement)

# Draw THIS figure (not a stale `fig` from an earlier cell) so tick labels are populated.
fig.canvas.draw()

# Re-apply each sub-axes' own tick labels, rotated 45 degrees and right-aligned.
for sub_ax in ax.axs:
  labels = [item.get_text() for item in sub_ax.get_xticklabels()]
  sub_ax.set_xticks(sub_ax.get_xticks(), labels, rotation=45, ha='right')
ax.set_xlabel("Size of the dataset (m)", labelpad=60)
ax.set_ylabel(r"$K(m,\delta)$", labelpad=60)
fig.savefig("fig_kn_datasets.jpg", bbox_inches='tight')