In [None]:
# Imports
import sys
sys.path.append("../")
from src import get_worst_instance
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
import pickle

# Constants
# NOTE(review): L is presumably the smoothness constant of the function class
# (it normalizes the step-size axis and appears in gamma = 2 / (L + mu)) — confirm in src.
L = 1
mus = np.linspace(0.1, 0.8, 50)  # strong-convexity parameters (y-axis of the heatmaps)
GAMMAS = np.linspace(0.1, 1.9, 50)  # step sizes (x-axis of the heatmaps)


def get(gamma, mu, N):
    """Return the worst-case instance for step size ``gamma``, parameter ``mu`` and ``N`` iterations.

    Thin wrapper around ``src.get_worst_instance`` that fixes ``L``, the
    ``"e_sum"`` objective and the MOSEK solver. Note that the argument order
    differs from ``get_worst_instance(gamma, N, mu, L, ...)``.
    """
    return get_worst_instance(gamma, N, mu, L, objective="e_sum", solver="MOSEK")

In [None]:
# Worst-case instances on the full (mu, gamma) grid for T = 1..4 iterations.
# Layout: instances_all[T - 1][i_mu][i_gamma].
N_ITERATIONS = (1, 2, 3, 4)
instances_all = [
    [[get(gamma, mu, n_its) for gamma in GAMMAS] for mu in tqdm(mus, desc=f"T={n_its}")]
    for n_its in N_ITERATIONS
]
# Keep the per-T names available for any cell that refers to them directly.
instances1, instances2, instances3, instances4 = instances_all

# Floor applied to the denominator of the theoretical variance, so it can never
# reach zero or flip sign. NOTE(review): the formula presumably assumes
# 2 * phi - (L - mu) * gamma > 0 — confirm the clamp only triggers outside that regime.
VARIANCE_DENOM_FLOOR = 1e-3


def phi(gamma, mu):
    """Theoretical rate: ``1 - gamma*mu`` if ``gamma*(L + mu) <= 2``, else ``gamma*L - 1``."""
    if gamma * (L + mu) <= 2:
        return 1 - gamma * mu
    return gamma * L - 1


def variance_theory(gamma, mu, N):
    """Theoretical variance after ``N`` iterations with step size ``gamma`` and parameter ``mu``."""
    rho = phi(gamma, mu)
    # (1 - rho^(2N)) / (1 - rho^2) == sum_{k=0}^{N-1} rho^(2k)
    geometric_sum = (1 - rho ** (2 * N)) / (1 - rho ** 2)
    denom = max(2 * rho - (L - mu) * gamma, VARIANCE_DENOM_FLOOR)
    return geometric_sum * 2 * gamma ** 2 * rho / denom


def rel_error(inst):
    """Relative gap between the computed variance ``inst.param.variance.value`` and the theoretical one."""
    expected = variance_theory(inst.gamma, inst.mu, inst.NB_ITS)
    return abs(inst.param.variance.value - expected) / expected


# errors[T - 1, i_mu, i_gamma]
errors = np.array([[[rel_error(inst) for inst in row] for row in per_T] for per_T in instances_all])

In [None]:
def _half_cell(grid):
    """Return half the spacing of a uniform 1-D grid (0.5 for a single point)."""
    if len(grid) < 2:
        return 0.5
    return (grid[-1] - grid[0]) / (2 * (len(grid) - 1))


def plot_heatmap(ax, N, errors):
    """Draw one relative-error heatmap panel on ``ax``.

    Parameters
    ----------
    ax : matplotlib.axes.Axes
        Axes to draw on; the colorbar is attached to its figure.
    N : int
        Number of iterations, used only for the panel title.
    errors : 2-D array
        Values on the ``mus`` x ``GAMMAS`` grid, indexed ``[i_mu, i_gamma]``
        (rows are drawn bottom-to-top because of ``origin="lower"``).

    Also overlays, as a dashed red line, the curve ``gamma = 2 / (L + mu)``
    wherever it falls inside the plotted ``mu`` range; its legend label is
    left for the caller to collect.
    """
    step_sizes = GAMMAS * L  # x-axis shows the normalized step size gamma * L
    # imshow's extent gives the outer pixel edges; pad by half a grid cell so
    # each pixel is centred on the (gamma, mu) value it was computed at.
    dx = _half_cell(step_sizes)
    dy = _half_cell(mus)
    extent = [step_sizes[0] - dx, step_sizes[-1] + dx, mus[0] - dy, mus[-1] + dy]
    heatmap = ax.imshow(errors, cmap="Blues", interpolation="nearest", extent=extent, aspect="auto", origin="lower")
    # Single quotes inside the f-string keep this valid before Python 3.12 (PEP 701).
    ax.set_title(f"T={N} Iteration{'s' if N > 1 else ''}")
    ax.set_xlabel(r"Normalized Step-Size ($\gamma L$)")
    ax.set_ylabel(r"Strong Convexity Parameter ($\mu$)")
    ax.figure.colorbar(heatmap, ax=ax, label="Relative error")

    # gamma = 2 / (L + mu)  <=>  mu = 2 / gamma - L; blank out points outside the plotted mu range.
    mu_opt = 2 / GAMMAS - L
    mu_opt = np.where((mus[0] <= mu_opt) & (mu_opt <= mus[-1]), mu_opt, np.nan)
    ax.plot(step_sizes, mu_opt, "r--", label=r"Optimal Step-Size $\gamma=\frac{2}{L+\mu}$")

fig, axs = plt.subplots(2, 2, figsize=(10, 8))

# One panel per iteration count: errors[k] holds the results for T = k + 1.
for n_its, (ax, errors_T) in enumerate(zip(axs.flat, errors), start=1):
    plot_heatmap(ax, n_its, errors_T)

# Every panel draws the same labelled curve, so collect legend entries from one panel only.
handles, labels = axs[0, 0].get_legend_handles_labels()
fig.tight_layout()
# Shared legend, centred just below the panels.
fig.legend(handles, labels, loc="lower center", ncol=2, bbox_to_anchor=(0.5, -0.05))

plt.show()