<a href="https://colab.research.google.com/github/Tikquuss/grokking_beyong_l2_norm/blob/main/algorithmic_dataset_MLP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ...

In [None]:
# Optional bootstrap, presumably for Colab: clone the repository and install its requirements.
# !git clone https://github.com/Tikquuss/grokking_beyong_l2_norm
# %cd grokking_beyong_l2_norm
# # #! ls
# ! pip install -r requirements.txt
# Root directory for all experiment outputs; later cells build their `exp_dir` paths under it.
LOG_DIR="/content/LOGS"

In [None]:
import os
# Create the log root up front; exist_ok=True makes re-running this cell harmless.
os.makedirs(LOG_DIR, exist_ok=True)

In [None]:
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import os

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f'using device: {device}')

In [None]:
from neural_nets.mlp import MLP, Encoder_Decoder
from neural_nets.data import split_and_create_data_loader
from neural_nets.trainer import get_loss, eval_model_classification, run_experiments
from neural_nets.checkpointing import get_all_checkpoints
from utils.norms import l0_norm_model, l_p_norm_model, nuclear_norm_model
from plotters.utils import plot_loss_accs
from sparse_recovery.utils import find_memorization_generalization_steps, find_stable_step_final_value, plot_t1_t2
from plotters.utils import get_twin_axis, FIGSIZE, LINEWIDTH, FIGSIZE_SMALL, FIGSIZE_LARGE, FIGSIZE_MEDIUM, FONTSIZE, LABEL_FONTSIZE, TICK_LABEL_FONTSIZE, MARKERSIZE

# Data

In [None]:
import itertools
import torch.nn.functional as F

In [None]:
# Modular-arithmetic dataset: every ordered pair (x1, x2) in [0, p)^2 is one sample,
# labelled with (x1 <operator> x2) % p.
p=97
operator = "+" # "+", "-", "*", etc
r_train=0.4
batch_size=2**11
eval_batch_size=2**13

data = list(itertools.product(range(p), range(p)))

# One-hot encode both operands of every pair in a single call instead of building
# 2*p^2 small tensors in a Python loop: (p^2, 2) integer pairs -> (p^2, 2, p) floats.
X = F.one_hot(torch.tensor(data), num_classes=p).float() # (p^2, 2, p)

# NOTE(review): `eval` is kept so that any Python binary operator string keeps working.
# `operator` is a constant set in this cell; never feed it untrusted text.
Y = torch.tensor([eval(f"({x1} {operator} {x2}) % {p}") for x1, x2 in data]) # (p^2,)

# Hand the full dataset to the project helper, which returns three loaders: one for training,
# one named `train_loader_for_eval`, and one for the test split.
# NOTE(review): r_train is presumably the fraction of samples used for training -- confirm in neural_nets.data.
train_loader, train_loader_for_eval, test_loader = split_and_create_data_loader(
    X, Y, r_train=r_train, batch_size=batch_size, eval_batch_size=eval_batch_size, random_state=0, balance=False)

In [None]:
# Bar charts of the number of samples per class: full data, train split, test split.
rows, cols = 1, 3
panel_width, panel_height = 8, 6
figsize = (cols*panel_width, rows*panel_height)
fig = plt.figure(figsize=figsize)

# Labels of each split, read back from the datasets wrapped by the loaders.
y_train = train_loader_for_eval.dataset.tensors[1].cpu().numpy() # (N,)
y_test = test_loader.dataset.tensors[1].cpu().numpy() # (N,)

# (labels, panel title, title font size) for each panel, left to right.
panels = [
    (Y, 'Class distribution in data', 26),
    (y_train, 'Class distribution in train data', 22),
    (y_test, 'Class distribution in test data', 22),
]

for position, (labels, title, title_fontsize) in enumerate(panels, start=1):
    ax = fig.add_subplot(rows, cols, position)
    ax.bar(range(p), [(labels==k).sum() for k in range(p)])
    ax.set_title(title, fontsize=title_fontsize)

plt.show()

# $\beta  h(\theta)$

In [None]:
# Architecture hyper-parameters passed to Encoder_Decoder further down.
aggregation_mode = 'matrix_product' # 'sum', 'concat', 'matrix_product', 'hadamard_product'
embedding_dim = 64
num_hidden_layers_mlp = 1 # int : number of hidden layer for the mlp (0 for linear model, ...)
width_multiplier_mlp = 1. # float : the embedding dimension is multiplied by this number to have the hidden dimension
widths_encoder = [p, embedding_dim] # embedding layer
# Decoder widths: input (embedding) -> num_hidden_layers_mlp hidden layers -> p outputs.
widths_decoder = [
    embedding_dim,
    *(int(embedding_dim*width_multiplier_mlp) for _ in range(num_hidden_layers_mlp)),
    p,
]

In [None]:
# Experiment configuration handed to `run_experiments`: output location, data loaders,
# model, optimizer, loss/eval hooks and the logging/checkpoint schedule.
args = {}
args['fileName'] = "2layers_nn"
args['exp_dir'] = f"{LOG_DIR}/{args['fileName']}"
os.makedirs(args['exp_dir'], exist_ok=True)

########################################################################################
########################################################################################

args["device"] = device
args['train_loader'], args['train_loader_for_eval'], args['test_loader'] = train_loader, train_loader_for_eval, test_loader
args['verbose'] = True

########################################################################################
########################################################################################

model = Encoder_Decoder(
        aggregation_mode,
        widths_encoder,
        widths_decoder,
        activation_class_encoder=None,
        activation_class_decoder=nn.ReLU,
        bias_encoder=False,
        bias_decoder=False,
        bias_classifier=False,
        init_params=True,
        type_init='normal',
        seed=None)

args['model'] = model
print(model)


learning_rate = 5e-3
args["optimizer"] = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0)
args["criterion"] = nn.CrossEntropyLoss()
args["eval_model"] = eval_model_classification
args['r_train'] = r_train
# The modulus is bound here, as a default argument. Later cells rebind the global `p` to the
# regularisation type (1, 2, 'nuc'); a late-bound lookup of `p` would silently change what the
# `p=` slot of the experiment name means.
# NOTE: sweep checkpoints written before this change carry the regularisation type in that slot.
args["get_exp_name_function"] = lambda args, modulus=p : f"id={args['exp_id']}-p={modulus}-rtrain={args['r_train']}"

########################################################################################
########################################################################################

# Training length and evaluation/printing/checkpointing schedule.
args.update({
    "n_epochs" : 10**4,
    "eval_first": 10**2 * 1,
    "eval_period": 10**1 * 1,
    "print_step": 10**2 * 1,
    "save_model_step":10**3 * 1,
    "save_statistic_step":10**3 * 1,
    "verbose": True,
})

########################################################################################
########################################################################################

# l1, l2, l*
args['beta_dic'] = {1 : 0.0, 2 : 1e-6, "nuc" : 0.0} # {p : beta_p}

def get_get_loss(beta_dic):
    """Build a loss function that adds norm penalties to the base loss.

    Parameters
    ----------
    beta_dic : dict
        Maps a norm order accepted by ``torch.norm`` (e.g. 1, 2, "nuc") to its
        coefficient. Entries whose coefficient is 0 are skipped.

    Returns
    -------
    callable
        ``(model, batch_x, batch_y, criterion) -> (loss, scores)``. ``loss`` is the
        value returned by ``get_loss`` plus ``beta * ||W||`` for every trainable
        parameter whose name contains 'weight' and every non-zero beta; ``scores``
        is passed through unchanged.
    """
    def get_loss_func(model, batch_x, batch_y, criterion) :
        loss, scores = get_loss(model, batch_x, batch_y, criterion)

        # Penalise weight tensors only; other parameters are left out.
        trainable_weights = (
            tensor for tensor_name, tensor in model.named_parameters()
            if 'weight' in tensor_name and tensor.requires_grad
        )
        for tensor in trainable_weights:
            for order, coefficient in beta_dic.items():
                if coefficient == 0:
                    continue
                loss = loss + coefficient * torch.norm(tensor, p=order)

        return loss, scores

    return get_loss_func

# Install the regularised loss, built from the betas stored in args['beta_dic'].
args['get_loss'] = get_get_loss(beta_dic=args['beta_dic'])

########################################################################################
########################################################################################

# Placeholder only: overwritten with the get_other_metrics function defined just below.
args['get_other_metrics']=None
def get_other_metrics(model, X, Y, Y_hat, loss):
    """Return the model's l0, l1, l2 and nuclear norms in a dict.

    Only ``model`` is used; ``X``, ``Y``, ``Y_hat`` and ``loss`` are accepted but ignored.
    The keys are "l0_norm", "l1_norm", "l2_norm" and "l*_norm". The last three are
    converted with ``.item()``; the l0 value is stored as returned by ``l0_norm_model``.
    """
    # Options shared by every norm helper call.
    shared = dict(only_weights=True, requires_grad=False)
    with torch.no_grad():
        l0_value = l0_norm_model(model, threshold=1e-4, proportion=False, **shared)
        l1_value = l_p_norm_model(model, p=1, concat_first=True, **shared).item()
        l2_value = l_p_norm_model(model, p=2, concat_first=True, **shared).item()
        nuclear_value = nuclear_norm_model(model, **shared).item()
    return {
        "l0_norm": l0_value,
        "l1_norm": l1_value,
        "l2_norm": l2_value,
        "l*_norm": nuclear_value,
    }
# Register the norm-tracking hook defined above.
args['get_other_metrics'] = get_other_metrics


########################################################################################
########################################################################################

# Single reference run: no experiment id, fixed seed.
args['exp_id'] = None
args['seed'] = 42

args["n_epochs"] = 10**4 * 1 + 1 # overrides the n_epochs value set earlier in this cell
args["verbose"] = True

# run_experiments returns the (possibly updated) args, the model and the collected metrics.
args, model, all_metrics = run_experiments(args)

In [None]:
# all_models, all_metrics = get_all_checkpoints(checkpoint_path=args['checkpoint_path'], exp_name=args['fileName'], just_files=False)

In [None]:
# Plot the metrics of the run above: train/test accuracy and loss, plus the four
# norms recorded by get_other_metrics. The x axis is logarithmic; nothing is saved to disk
# (fileName and filePath are None).
plot_loss_accs(
    all_metrics,
    train_test_metrics_names = ["accuracy", "loss"],
    other_metrics_names = ["l0_norm", "l1_norm", "l2_norm", "l*_norm"],
    multiple_runs=False, log_x=True, log_y=False,
    figsize=FIGSIZE, linewidth=LINEWIDTH, fontsize=FONTSIZE,
    fileName=None, filePath=None, show=True)

# Scaling wrt $\alpha \beta$

In [None]:
# Regularisation strengths (beta) swept for each penalty, keyed by the norm identifier.
all_beta_dic = {
    1: [5e-9, 1e-8, 1e-7, 1e-6, 5e-6, 1e-5],
    2: [1e-7, 5e-7, 1e-6, 5e-6, 1e-5, 5e-5],
    'nuc': [5e-8, 1e-7, 1e-6, 1e-5, 5e-5, 1e-4],
}

# Learning rates (alpha), kept in ascending order.
all_alpha = sorted([1e-3, 1e-2, 1e-1])

In [None]:
# NOTE: this rebinds the global `p`, which the data cell used as the modulus (97).
# From here on `p` selects the regularisation type, and any code that reads the global `p`
# after this cell has run sees this value instead of the modulus.
p=1 # l_p regularization : 1, 2, 'nuc'
# Beta grid for the selected penalty.
all_beta = all_beta_dic[p]
# Number of epochs used for every run of the sweep.
n_epochs = 10**4 * 1 + 1

In [None]:
# Quantity drawn by the figure cells, and the axis label that goes with it.
to_plot = "accuracy"
to_plot_and_label_dic = {name: name.capitalize() for name in ("loss", "error", "accuracy")}
to_plot_label = to_plot_and_label_dic[to_plot]

**Train**

In [None]:
# Sweep over learning rates (alpha) and regularisation strengths (beta) for the selected
# penalty `p`: every (alpha, beta) pair trains a freshly built model with its own optimizer.
for i, alpha in enumerate(all_alpha):

    # One experiment directory per learning rate.
    sweep_name = f"mlp_algorithmic_dataset_l{p}_alpha={alpha}"
    args['fileName'] = sweep_name
    args['exp_dir'] = f"{LOG_DIR}/{args['fileName']}"
    os.makedirs(args['exp_dir'], exist_ok=True)

    for j, beta in enumerate(all_beta) :
        print(f"alpha = {alpha}, {(i+1)}/{len(all_alpha)}, beta_{p}={beta}, {(j+1)}/{len(all_beta)}")

        # Only the selected penalty is switched on.
        penalties = {1 : 0.0, 2 : 0.0, "nuc" : 0.0} # {p : beta_p}
        penalties[p] = beta

        model = Encoder_Decoder(
            aggregation_mode,
            widths_encoder,
            widths_decoder,
            activation_class_encoder=None,
            activation_class_decoder=nn.ReLU,
            bias_encoder=False,
            bias_decoder=False,
            bias_classifier=False,
            init_params=True,
            type_init='normal',
            seed=None
        )
        learning_rate = alpha

        args.update({
            'beta_dic': penalties,
            'get_loss': get_get_loss(beta_dic=penalties),
            'exp_id': j,
            'seed': 42,
            "n_epochs": n_epochs,
            "verbose": False,
            'model': model,
            "optimizer": optim.Adam(model.parameters(), lr=learning_rate, weight_decay=0),
            "criterion": nn.CrossEntropyLoss(),
        })
        args, model, all_metrics = run_experiments(args)

**Load stats**

In [None]:
# Gather the saved statistics of every (alpha, beta) run of the sweep into nested dicts
# indexed as all_statistics[metric][alpha][beta].
metrics_names = ['train_loss', 'test_loss', 'train_accuracy', 'test_accuracy', 'all_models', 'all_steps', 'l0_norm', 'l1_norm', 'l2_norm', 'l*_norm']
all_statistics = {key : {} for key in metrics_names  }

for i, alpha in enumerate(all_alpha):

    # Same directory naming as the training sweep.
    args['fileName'] = f"mlp_algorithmic_dataset_l{p}_alpha={alpha}"
    args['exp_dir'] = f"{LOG_DIR}/{args['fileName']}"

    for key in metrics_names :
        all_statistics[key][alpha] = {}

    for j, beta in enumerate(all_beta) :
        print(f"alpha = {alpha}, {(i+1)}/{len(all_alpha)}, beta_{p}={beta}, {(j+1)}/{len(all_beta)}")

        # Rebuild the checkpoint path of run j from the experiment-name function.
        args['exp_id'] = j
        exp_name = args['get_exp_name_function'](args)
        args['checkpoint_path'] = os.path.join(args['exp_dir'], exp_name)

        all_models, statistics = get_all_checkpoints(checkpoint_path=args['checkpoint_path'], exp_name=args['fileName'], just_files=True)

        # Train/test metrics are nested one level deeper than the other statistics.
        for metric in ('loss', 'accuracy'):
            for split in ('train', 'test'):
                all_statistics[f'{split}_{metric}'][alpha][beta] = statistics[split][metric]

        all_statistics['all_models'][alpha][beta] = all_models
        for key in ['all_steps', 'l0_norm', 'l1_norm', 'l2_norm', 'l*_norm']:
            all_statistics[key][alpha][beta] = statistics[key]

## Figure 30

In [None]:
import matplotlib.colors as mcolors

# Figure 30: train (solid) and test (dashed) curves for every beta, one panel per learning
# rate alpha. For each run a dashed vertical line marks the step t_2 returned by
# find_memorization_generalization_steps.
if to_plot not in ("loss", "error", "accuracy"):
    raise ValueError(f"to_plot must be 'loss', 'error' or 'accuracy', got {to_plot!r}")

L=len(all_alpha)
cols = min(3, L)
rows = L // cols + 1 * (L % cols != 0)  # ceil(L / cols)

figsize=FIGSIZE_SMALL
figsize=(cols*figsize[0], rows*figsize[1])
fig = plt.figure(figsize=figsize)

log_x=False
log_y=False

# Colour each beta with the same log-scaled normalisation the colour bar uses, so the
# colour of a curve is exactly the colour shown at its tick on the bar.
beta_norm = mcolors.LogNorm(vmin=min(all_beta), vmax=max(all_beta))
colors = plt.cm.viridis(beta_norm(np.asarray(all_beta, dtype=float)))

for i, alpha in enumerate(all_alpha):

    ax = fig.add_subplot(rows, cols, i+1)
    _, ax, _ = get_twin_axis(ax=ax, no_twin=True)

    ax.set_title(f'$\\alpha={alpha}$', fontsize=LABEL_FONTSIZE)

    for j, beta in enumerate(all_beta) :

        all_steps = all_statistics['all_steps'][alpha][beta]
        if log_x : all_steps = np.array(all_steps) + 1  # shift so that step 0 survives the log axis

        # Curves that are drawn.
        if to_plot == "loss" :
            test_errors, train_errors = all_statistics['test_loss'][alpha][beta], all_statistics['train_loss'][alpha][beta]
        elif to_plot == "error" :
            test_errors, train_errors = 1-np.array(all_statistics['test_accuracy'][alpha][beta]), 1-np.array(all_statistics['train_accuracy'][alpha][beta])
        else : # "accuracy"
            test_errors, train_errors = all_statistics['test_accuracy'][alpha][beta], all_statistics['train_accuracy'][alpha][beta]

        ax.plot(all_steps, test_errors, '--', color=colors[j], linewidth=LINEWIDTH)
        ax.plot(all_steps, train_errors, '-', label=f'$\\beta={beta}$', color=colors[j], linewidth=LINEWIDTH)

        # Times are located on curves that decrease as the model improves: the loss, or
        # the error 1 - accuracy (also when accuracy is what is drawn).
        if to_plot == "loss" :
            test_errors, train_errors = all_statistics['test_loss'][alpha][beta], all_statistics['train_loss'][alpha][beta]
        else :
            test_errors, train_errors = 1-np.array(all_statistics['test_accuracy'][alpha][beta]), 1-np.array(all_statistics['train_accuracy'][alpha][beta])
        t_1, t_2 = find_memorization_generalization_steps(train_errors, test_errors, all_steps, train_threshold=min(train_errors), test_threshold=min(test_errors))
        t = t_2
        if t is not None :
            ax.axvline(x=t, ymin=0.01, ymax=1., color=colors[j], linestyle='--', lw=1.)
            ax.plot([t, t], [0, 0], 'o', color='b')

    # Axis labels only on the bottom row / first column.
    if (rows-1)*cols <= i < rows*cols : ax.set_xlabel('Steps (t)', fontsize=LABEL_FONTSIZE)
    if i%cols==0 : ax.set_ylabel(to_plot_label, fontsize=LABEL_FONTSIZE)
    ax.tick_params(axis='both', labelsize=TICK_LABEL_FONTSIZE)

    ########### Color bar
    sm = plt.cm.ScalarMappable(cmap='viridis', norm=beta_norm)
    sm.set_array([])  # We only need the colormap here, no actual data
    cbar = plt.colorbar(sm, ax=ax)
    if i==cols-1: cbar.set_label('$\\beta$', fontsize=LABEL_FONTSIZE)
    cbar.set_ticks(all_beta)  # Sets tick positions based on `all_beta`

    if log_x : ax.set_xscale('log')
    if log_y : ax.set_yscale('log')

    # Line-style legend; the colours are explained by the colour bar.
    legend_elements = [
        Line2D([0], [0], color='k', linestyle='-', label='Train'),
        Line2D([0], [0], color='k', linestyle='--', label='Test')
        ]
    ax.legend(handles=legend_elements, fontsize=LABEL_FONTSIZE*0.8)


## Adjust layout and add padding
fig.tight_layout(pad=2)  # Adjust padding between plots
plt.subplots_adjust(right=0.85)  # Adjust right boundary of the plot to fit color bar

##
#plt.savefig(f"{LOG_DIR}/mlp_algorithmic_dataset_scaling_alpha_and_beta_{p}"  + '.pdf', dpi=300, bbox_inches='tight', format='pdf')

# plt.show()

In [None]:
# Curves next to parameter norms: one row per learning rate alpha. The first panel of a row
# shows the train/test curves, each remaining panel shows one norm of the model over time.
# A dashed vertical line marks, for every beta, the step t_2 returned by
# find_memorization_generalization_steps; the same t_2 is drawn in all panels of a row.
if to_plot not in ("loss", "error", "accuracy"):
    raise ValueError(f"to_plot must be 'loss', 'error' or 'accuracy', got {to_plot!r}")

norm_panels = [('l0_norm', 0), ('l1_norm', 1), ('l2_norm', 2), ('l*_norm', '*')]  # (statistic key, subscript in the title)
cols = 1 + len(norm_panels)
rows = len(all_alpha)

figsize=FIGSIZE_SMALL
figsize=(cols*figsize[0], rows*figsize[1])

# plt.subplots creates the figure itself, so no separate plt.figure call is needed.
# squeeze=False always yields a 2-D array, which makes flattening safe for any rows/cols.
fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)
axes = list(axes.ravel())

log_x=False
log_y=False

color_indices = np.linspace(0, 1, len(all_beta))
colors = plt.cm.viridis(color_indices)

k=0
for i, alpha in enumerate(all_alpha):

    ax = axes[k]
    k+=1
    _, ax, _ = get_twin_axis(ax=ax, no_twin=True)

    ax.set_title(f'$\\alpha={alpha} \\ ({to_plot_label})$', fontsize=LABEL_FONTSIZE)

    t2_per_beta = []  # t_2 of every run of this row, reused by the norm panels
    for j, beta in enumerate(all_beta) :

        all_steps = all_statistics['all_steps'][alpha][beta]
        if log_x : all_steps = np.array(all_steps) + 1  # shift so that step 0 survives the log axis

        # Curves that are drawn.
        if to_plot == "loss" :
            test_errors, train_errors = all_statistics['test_loss'][alpha][beta], all_statistics['train_loss'][alpha][beta]
        elif to_plot == "error" :
            test_errors, train_errors = 1-np.array(all_statistics['test_accuracy'][alpha][beta]), 1-np.array(all_statistics['train_accuracy'][alpha][beta])
        else : # "accuracy"
            test_errors, train_errors = all_statistics['test_accuracy'][alpha][beta], all_statistics['train_accuracy'][alpha][beta]

        ax.plot(all_steps, test_errors, '--', color=colors[j], linewidth=LINEWIDTH)
        ax.plot(all_steps, train_errors, '-', label=f'$\\beta={beta}$', color=colors[j], linewidth=LINEWIDTH)

        # Times are located on curves that decrease as the model improves: the loss, or
        # the error 1 - accuracy (also when accuracy is what is drawn).
        if to_plot == "loss" :
            test_errors, train_errors = all_statistics['test_loss'][alpha][beta], all_statistics['train_loss'][alpha][beta]
        else :
            test_errors, train_errors = 1-np.array(all_statistics['test_accuracy'][alpha][beta]), 1-np.array(all_statistics['train_accuracy'][alpha][beta])
        t_1, t_2 = find_memorization_generalization_steps(train_errors, test_errors, all_steps, train_threshold=min(train_errors), test_threshold=min(test_errors))
        t2_per_beta.append(t_2)
        t = t_2
        if t is not None :
            ax.axvline(x=t, ymin=0.01, ymax=1., color=colors[j], linestyle='--', lw=1.)
            ax.plot([t, t], [0, 0], 'o', color='b')

    ax.tick_params(axis='both', labelsize=TICK_LABEL_FONTSIZE)

    if log_x : ax.set_xscale('log')
    if log_y : ax.set_yscale('log')

    # Line-style legend for the curve panel.
    legend_elements = [
        Line2D([0], [0], color='k', linestyle='-', label='Train'),
        Line2D([0], [0], color='k', linestyle='--', label='Test')
        ]
    ax.legend(handles=legend_elements, fontsize=LABEL_FONTSIZE*0.8)

    for norm_name, p_label in norm_panels:
        ax = axes[k]
        k+=1
        _, ax, _ = get_twin_axis(ax=ax, no_twin=True)

        ax.set_title(f'$\\alpha={alpha} \\ (\\ell_{p_label})$', fontsize=LABEL_FONTSIZE)

        for j, beta in enumerate(all_beta) :

            all_steps = all_statistics['all_steps'][alpha][beta]
            if log_x : all_steps = np.array(all_steps) + 1

            ax.plot(all_steps, all_statistics[norm_name][alpha][beta], "-", color=colors[j], label=f'$\\beta={beta}$', linewidth=LINEWIDTH)

            # Same t_2 as in the curve panel of this row.
            t = t2_per_beta[j]
            if t is not None :
                ax.axvline(x=t, ymin=0.01, ymax=1., color=colors[j], linestyle='--', lw=1.)
                ax.plot([t, t], [0, 0], 'o', color='b')

        ax.tick_params(axis='both', labelsize=TICK_LABEL_FONTSIZE)

        if log_x : ax.set_xscale('log')
        if log_y : ax.set_yscale('log')

        # The beta legend is shown once per row, on the first norm panel.
        if (k-2)%cols==0:
            ax.legend(fontsize=LABEL_FONTSIZE*0.8)

## Adjust layout and add padding
fig.tight_layout(pad=2)  # Adjust padding between plots
plt.subplots_adjust(right=0.85)  # Adjust right boundary of the plot to fit color bar

##
#plt.savefig(f"{LOG_DIR}/mlp_algorithmic_dataset_scaling_alpha_and_beta_{p}_with_norms"  + '.pdf', dpi=300, bbox_inches='tight', format='pdf')

#plt.show()

## Figure 2

In [None]:
# Figure 2 settings.
# Entry i is used as the slice end for the curves of the i-th learning rate (None keeps everything).
all_T_max_dic = [None] * 3
#all_T_max_dic = [None, 500, 310]
# Scale factor applied to line widths and font sizes.
kappa = 1.5

In [None]:
import matplotlib.colors as mcolors

# Figure 2: same layout as Figure 30 (train solid, test dashed, one panel per alpha) with
# larger fonts/lines (scaled by kappa) and curves optionally truncated to the first
# all_T_max_dic[i] recorded points of panel i.
if to_plot not in ("loss", "error", "accuracy"):
    raise ValueError(f"to_plot must be 'loss', 'error' or 'accuracy', got {to_plot!r}")

L=len(all_alpha)
cols = min(3, L)
rows = L // cols + 1 * (L % cols != 0)  # ceil(L / cols)

figsize=FIGSIZE_SMALL
figsize=(cols*figsize[0], rows*figsize[1])
fig = plt.figure(figsize=figsize)

log_x=False
log_y=False

# Colour each beta with the same log-scaled normalisation the colour bar uses, so the
# colour of a curve is exactly the colour shown at its tick on the bar.
beta_norm = mcolors.LogNorm(vmin=min(all_beta), vmax=max(all_beta))
colors = plt.cm.viridis(beta_norm(np.asarray(all_beta, dtype=float)))

scaled_label_fontsize = LABEL_FONTSIZE*(3*kappa/4)

for i, alpha in enumerate(all_alpha):

    ax = fig.add_subplot(rows, cols, i+1)
    _, ax, _ = get_twin_axis(ax=ax, no_twin=True)

    ax.set_title(f'$\\alpha={alpha}$', fontsize=scaled_label_fontsize)

    T_max = all_T_max_dic[i]  # slice end for this panel; None keeps the full curves

    for j, beta in enumerate(all_beta) :

        all_steps = all_statistics['all_steps'][alpha][beta]
        if log_x : all_steps = np.array(all_steps) + 1  # shift so that step 0 survives the log axis
        all_steps = all_steps[:T_max]

        # Curves that are drawn.
        if to_plot == "loss" :
            test_errors, train_errors = all_statistics['test_loss'][alpha][beta], all_statistics['train_loss'][alpha][beta]
        elif to_plot == "error" :
            test_errors, train_errors = 1-np.array(all_statistics['test_accuracy'][alpha][beta]), 1-np.array(all_statistics['train_accuracy'][alpha][beta])
        else : # "accuracy"
            test_errors, train_errors = all_statistics['test_accuracy'][alpha][beta], all_statistics['train_accuracy'][alpha][beta]
        test_errors, train_errors = test_errors[:T_max], train_errors[:T_max]

        ax.plot(all_steps, test_errors, '--', color=colors[j], linewidth=LINEWIDTH*kappa)
        ax.plot(all_steps, train_errors, '-', label=f'$\\beta={beta}$', color=colors[j], linewidth=LINEWIDTH*kappa)

        # Times are located on curves that decrease as the model improves: the loss, or
        # the error 1 - accuracy (also when accuracy is what is drawn).
        if to_plot == "loss" :
            test_errors, train_errors = all_statistics['test_loss'][alpha][beta], all_statistics['train_loss'][alpha][beta]
        else :
            test_errors, train_errors = 1-np.array(all_statistics['test_accuracy'][alpha][beta]), 1-np.array(all_statistics['train_accuracy'][alpha][beta])
        test_errors, train_errors = test_errors[:T_max], train_errors[:T_max]
        t_1, t_2 = find_memorization_generalization_steps(train_errors, test_errors, all_steps, train_threshold=min(train_errors), test_threshold=min(test_errors))
        t = t_2
        if t is not None :
            ax.axvline(x=t, ymin=0.01, ymax=1., color=colors[j], linestyle='--', lw=1.)
            ax.plot([t, t], [0, 0], 'o', color='b')

    # Axis labels only on the bottom row / first column.
    if (rows-1)*cols <= i < rows*cols : ax.set_xlabel('Steps (t)', fontsize=scaled_label_fontsize)
    if i%cols==0 : ax.set_ylabel(to_plot_label, fontsize=scaled_label_fontsize)
    ax.tick_params(axis='both', labelsize=TICK_LABEL_FONTSIZE*(3*kappa/4))

    ########### Color bar
    sm = plt.cm.ScalarMappable(cmap='viridis', norm=beta_norm)
    sm.set_array([])  # We only need the colormap here, no actual data
    cbar = plt.colorbar(sm, ax=ax)
    if i==cols-1: cbar.set_label('$\\beta$', fontsize=scaled_label_fontsize)
    cbar.set_ticks(all_beta)  # Sets tick positions based on `all_beta`
    cbar.ax.tick_params(labelsize=TICK_LABEL_FONTSIZE*(kappa/2))

    if log_x : ax.set_xscale('log')
    if log_y : ax.set_yscale('log')

    # Line-style legend, shown on the first panel only.
    legend_elements = [
        Line2D([0], [0], color='k', linestyle='-', label='Train'),
        Line2D([0], [0], color='k', linestyle='--', label='Test')
        ]
    if i==0 :ax.legend(handles=legend_elements, fontsize=scaled_label_fontsize)


## Adjust layout and add padding
fig.tight_layout(pad=2)  # Adjust padding between plots
plt.subplots_adjust(right=0.85)  # Adjust right boundary of the plot to fit color bar

##
#plt.savefig(f"{LOG_DIR}/mlp_algorithmic_dataset_scaling_alpha_and_beta_{p}_small_plot"  + '.pdf', dpi=300, bbox_inches='tight', format='pdf')

# plt.show()

## Figure 1

In [None]:
# args, LOG_DIR
def get_all_statistics_mlp(p, all_alpha, all_beta):
    """Load the saved statistics of every (alpha, beta) run for regularisation type ``p``.

    Reads the notebook globals ``args`` and ``LOG_DIR`` and overwrites
    ``args['fileName']``, ``args['exp_dir']``, ``args['exp_id']`` and
    ``args['checkpoint_path']`` while iterating. ``p`` enters the directory name
    only; the run name comes from ``args['get_exp_name_function'](args)``.

    Parameters
    ----------
    p : norm identifier used in the directory name (e.g. 1, 2, 'nuc').
    all_alpha : iterable of learning rates.
    all_beta : iterable of regularisation strengths; position j is used as ``exp_id``.

    Returns
    -------
    dict
        ``result[metric][alpha][beta]`` for the metrics 'train_loss', 'test_loss',
        'train_accuracy', 'test_accuracy', 'all_models', 'all_steps', 'l0_norm',
        'l1_norm', 'l2_norm' and 'l*_norm'.
    """
    metrics_names = ['train_loss', 'test_loss', 'train_accuracy', 'test_accuracy', 'all_models', 'all_steps', 'l0_norm', 'l1_norm', 'l2_norm', 'l*_norm']
    flat_keys = ['all_steps', 'l0_norm', 'l1_norm', 'l2_norm', 'l*_norm']
    all_statistics = {key : {} for key in metrics_names}

    for i, alpha in enumerate(all_alpha):

        args['fileName'] = f"mlp_algorithmic_dataset_l{p}_alpha={alpha}"
        args['exp_dir'] = f"{LOG_DIR}/{args['fileName']}"

        for key in metrics_names :
            all_statistics[key][alpha] = {}

        for j, beta in enumerate(all_beta) :
            print(f"alpha = {alpha}, {(i+1)}/{len(all_alpha)}, beta_{p}={beta}, {(j+1)}/{len(all_beta)}")

            # Rebuild the checkpoint path of run j.
            args['exp_id'] = j
            run_name = args['get_exp_name_function'](args)
            args['checkpoint_path'] = os.path.join(args['exp_dir'], run_name)

            all_models, statistics = get_all_checkpoints(checkpoint_path=args['checkpoint_path'], exp_name=args['fileName'], just_files=True)

            # Flatten this run's statistics into {metric name: value} before storing them.
            run_values = {
                'train_loss': statistics['train']['loss'],
                'test_loss': statistics['test']['loss'],
                'train_accuracy': statistics['train']['accuracy'],
                'test_accuracy': statistics['test']['accuracy'],
                'all_models': all_models,
            }
            for key in flat_keys:
                run_values[key] = statistics[key]

            for key, value in run_values.items():
                all_statistics[key][alpha][beta] = value

    return all_statistics

In [None]:
# Runs shown in Figure 1, as (p, alpha) pairs. The first assignment is an alternative
# selection that is immediately overridden by the second one.
selected_runs = [(1, 0.01), (2, 0.01),  ('nuc', 0.001)] # (p, alpha)
selected_runs = [(1, 0.01), (2, 0.01),  ('nuc', 0.01)] # (p, alpha)
# selected_statistics[p] holds the statistics dict returned by get_all_statistics_mlp for that p.
selected_statistics = {}
# NOTE: this top-level loop rebinds the notebook globals `p` and `alpha`; after the cell has
# run they hold the values of the last selected run.
for p, alpha in selected_runs :
    print(f"p={p}, alpha={alpha}")
    selected_statistics[p] = get_all_statistics_mlp(p, all_alpha=[alpha], all_beta=all_beta_dic[p])

In [None]:
# Figure 1 settings.
# Slice end applied to the curves of each regularisation type, keyed by the norm identifier.
all_T_max_dic = {
    1: 700,
    2: 500,
    'nuc': 500,
}
# Scale factor applied to line widths and font sizes.
kappa = 1.5

In [None]:
rows, cols = 2, len(selected_runs)

figsize=FIGSIZE_SMALL
figsize=(cols*figsize[0], rows*figsize[1])
fig = plt.figure(figsize=figsize)

log_x=False
log_y=False

color_indices = np.linspace(0, 1, len(all_beta)+1*0)
colors = plt.cm.viridis(color_indices)


for i, (p, alpha) in enumerate(selected_runs) :
    all_statistics = selected_statistics[p]
    all_beta = all_beta_dic[p]
    T_max = all_T_max_dic[p]

    ax = fig.add_subplot(rows, cols, i+1)
    _, ax, _ = get_twin_axis(ax=ax, no_twin=True)
    #_, ax, ax1 = get_twin_axis(ax=ax, no_twin=False)

    #ax.set_title(f'$\\alpha={alpha}$', fontsize=LABEL_FONTSIZE)

    ax.set_title(f"$\ell_{{{'*' if p=='nuc' else p}}}$ reg.", fontsize=LABEL_FONTSIZE*(3*kappa/4))

    for j, beta in enumerate(all_beta) :

        all_steps = all_statistics['all_steps'][alpha][beta]
        if log_x : all_steps = np.array(all_steps) + 1

        if to_plot == "loss" :
            test_errors, train_errors = all_statistics['test_loss'][alpha][beta], all_statistics['train_loss'][alpha][beta]
        elif to_plot == "error" :
            test_errors, train_errors = 1-np.array(all_statistics['test_accuracy'][alpha][beta]), 1-np.array(all_statistics['train_accuracy'][alpha][beta])
        elif to_plot == "accuracy"  :
            test_errors, train_errors = all_statistics['test_accuracy'][alpha][beta], all_statistics['train_accuracy'][alpha][beta]

        all_steps, test_errors, train_errors = all_steps[:T_max], test_errors[:T_max], train_errors[:T_max]
        ax.plot(all_steps, test_errors, '--', color=colors[j], linewidth=LINEWIDTH*kappa)
        ax.plot(all_steps, train_errors, '-', label=f'$\\beta={beta}$', color=colors[j], linewidth=LINEWIDTH*kappa)

        # Plot times

        if to_plot == "loss" :
            test_errors, train_errors = all_statistics['test_loss'][alpha][beta], all_statistics['train_loss'][alpha][beta]
        elif to_plot == "error" or to_plot == "accuracy" :
            test_errors, train_errors = 1-np.array(all_statistics['test_accuracy'][alpha][beta]), 1-np.array(all_statistics['train_accuracy'][alpha][beta])
        all_steps, test_errors, train_errors = all_steps[:T_max], test_errors[:T_max], train_errors[:T_max]
        # elif to_plot == "accuracy"  :
        #     test_errors, train_errors = all_statistics['test_accuracy'][alpha][beta], all_statistics['train_accuracy'][alpha][beta]
        # t_2, t_2_index = find_stable_step_final_value(all_steps, test_errors, K=3, tolerance_fraction=0.05, M=2)
        t_1, t_2 = find_memorization_generalization_steps(train_errors, test_errors, all_steps, train_threshold=min(train_errors), test_threshold=min(test_errors))
        #plot_t1_t2(ax, t_1, t_2, log_x, log_y, plot_Delta=True)
        t = t_2
        if t is not None :
            ax.axvline(x=t, ymin=0.01, ymax=1., color=colors[j], linestyle='--', lw=1.)
            ax.plot([t, t], [0, 0], 'o', color='b')

    #if (rows-1)*cols <= i < rows*cols : ax.set_xlabel('Steps (t)', fontsize=LABEL_FONTSIZE*(3*kappa/4))
    if i%cols==0 : ax.set_ylabel(to_plot_label, fontsize=LABEL_FONTSIZE*(3*kappa/4))
    ax.tick_params(axis='both', labelsize=TICK_LABEL_FONTSIZE*(3*kappa/4))

    ########### Color bar
    sm = plt.cm.ScalarMappable(cmap='viridis', norm=plt.Normalize(vmin=min(all_beta), vmax=max(all_beta)))
    import matplotlib.colors as mcolors
    sm = plt.cm.ScalarMappable(cmap='viridis', norm=mcolors.LogNorm(vmin=min(all_beta), vmax=max(all_beta)))
    sm.set_array([])  # We only need the colormap here, no actual data
    cbar = plt.colorbar(sm, ax=ax)
    if i==cols-1: cbar.set_label(f'$\\beta$', fontsize=LABEL_FONTSIZE*(3*kappa/4))
    # # Set the ticks to correspond to the values in `all_beta_1`
    cbar.set_ticks(all_beta)  # Sets tick positions based on `all_beta`
    # cbar.set_ticklabels([str(beta) for beta in all_beta])  # Sets tick labels to match `all_beta`
    cbar.ax.tick_params(labelsize=TICK_LABEL_FONTSIZE*(kappa/2))  #

    if log_x : ax.set_xscale('log')
    if log_y : ax.set_yscale('log')

    legend_elements = [
        Line2D([0], [0], color='k', linestyle='-', label='Train'),
        Line2D([0], [0], color='k', linestyle='--', label='Test')
        ]
    if i==0 :ax.legend(handles=legend_elements, fontsize=LABEL_FONTSIZE*(3*kappa/4))

    #if (k-2)%cols==0:
    #ax.legend(fontsize=LABEL_FONTSIZE*0.8)

#################################################################################
#################################################################################


# ---------------------------------------------------------------------------
# Norm panels: for the runs regularized with l_p (p = 1 below), draw one
# subplot per tracked norm (l1, l2, l*) with one curve per beta.
# Relies on names created earlier in the notebook: fig, rows, cols, alpha,
# colors, kappa, log_x, to_plot, selected_statistics, all_beta_dic,
# all_T_max_dic.
# ---------------------------------------------------------------------------
import matplotlib.colors as mcolors  # LogNorm for the colour bar (imported once, not per subplot)

# NOTE(review): `p` is also the name used for the modulus in the data cell
# (p=97). Running this cell leaves p == 1 in the kernel; the name is kept
# here so that any later cell sees the same state as before this edit.
p=1
all_statistics = selected_statistics[p]
all_beta = all_beta_dic[p]
T_max = all_T_max_dic[p]
for i, (norm_name, p_label) in enumerate(zip(['l1_norm', 'l2_norm', 'l*_norm'], [1, 2, '*'])):

    # Subplot slots 4, 5, 6 of the rows x cols grid (presumably the second
    # row of a 3-column layout -- confirm against where rows/cols are set).
    ax = fig.add_subplot(rows, cols, 3+i+1)
    _, ax, _ = get_twin_axis(ax=ax, no_twin=True)

    # Both backslashes are escaped: a bare "\e" is an invalid escape sequence.
    ax.set_title(f'$\\ell_{p}$ reg., $\\ell_{p_label}$ norm', fontsize=LABEL_FONTSIZE*(3*kappa/4))

    for j, beta in enumerate(all_beta) :

        all_steps = all_statistics['all_steps'][alpha][beta]
        # Shift by one so that step 0 can be placed on a logarithmic x axis.
        if log_x : all_steps = np.array(all_steps) + 1

        # Norm trajectory, rescaled by a fixed factor of 1000 and truncated to T_max points.
        norms = np.array(all_statistics[norm_name][alpha][beta])
        norms = norms / 1000
        ax.plot(all_steps[:T_max], norms[:T_max], "-", color=colors[j], label=f'$\\beta={beta}$', linewidth=LINEWIDTH*kappa)

        # Pick the train/test curves used to locate the step t_2 that is marked below.
        # "accuracy" is treated like "error": both use 1 - accuracy.
        if to_plot == "loss" :
            test_errors, train_errors = all_statistics['test_loss'][alpha][beta], all_statistics['train_loss'][alpha][beta]
        elif to_plot == "error" or to_plot == "accuracy" :
            test_errors, train_errors = 1-np.array(all_statistics['test_accuracy'][alpha][beta]), 1-np.array(all_statistics['train_accuracy'][alpha][beta])
        all_steps, test_errors, train_errors = all_steps[:T_max], test_errors[:T_max], train_errors[:T_max]

        # Thresholds are the best values reached by each curve within the first T_max points.
        t_1, t_2 = find_memorization_generalization_steps(train_errors, test_errors, all_steps, train_threshold=min(train_errors), test_threshold=min(test_errors))
        t = t_2
        if t is not None :
            # Vertical dashed line in the curve's colour, plus a blue dot at y = 0.
            ax.axvline(x=t, ymin=0.01, ymax=1., color=colors[j], linestyle='--', lw=1.)
            ax.plot([t, t], [0, 0], 'o', color='b')

    ax.set_xlabel('Steps (t)', fontsize=LABEL_FONTSIZE*(3*kappa/4))
    # Only the first panel of a row carries the y label.
    if i%cols==0 : ax.set_ylabel('Norm ($\\times 10^{-3}$)', fontsize=LABEL_FONTSIZE*(3*kappa/4))
    ax.tick_params(axis='both', labelsize=TICK_LABEL_FONTSIZE*(3*kappa/4))

    ########## Color bar
    # Log-scaled viridis colour bar over the beta range (LogNorm needs strictly
    # positive values). The mappable carries no data; it only drives the bar.
    sm = plt.cm.ScalarMappable(cmap='viridis', norm=mcolors.LogNorm(vmin=min(all_beta), vmax=max(all_beta)))
    sm.set_array([])
    # Attach the bar through `fig` rather than the pyplot "current figure".
    cbar = fig.colorbar(sm, ax=ax)
    # Only the last panel of the row labels its colour bar.
    if i==cols-1: cbar.set_label('$\\beta$', fontsize=LABEL_FONTSIZE*(3*kappa/4))
    cbar.set_ticks(all_beta)  # one tick per beta value
    cbar.ax.tick_params(labelsize=TICK_LABEL_FONTSIZE*(kappa/2))

    # The y axis is intentionally left linear for the norm panels.
    if log_x : ax.set_xscale('log')

#################################################################################
#################################################################################

## Adjust layout and add padding
fig.tight_layout(pad=2)  # Adjust padding between plots
fig.subplots_adjust(right=0.85)  # Adjust right boundary of the plot to fit color bar

##
#plt.savefig(f"{LOG_DIR}/mlp_algorithmic_dataset_scaling_beta_l1_l2_lnuc"  + '.pdf', dpi=300, bbox_inches='tight', format='pdf')

# plt.show()