In [None]:
# Fetch the project sources (shell scripts, requirements.txt and the `src` package imported below).
! git clone https://github.com/Tikquuss/lwd_grokking

In [None]:
# Work from inside the repository so `src`, the shell scripts and requirements.txt resolve by relative path.
%cd lwd_grokking

In [None]:
! pip install -r requirements.txt

In [None]:
import os

from src.modeling import Model

In [None]:
%load_ext tensorboard

# Train

In [None]:
# Only needed when `use_wandb` is set to True in the configuration below.
# Never hard-code the API key in the notebook: `wandb login` prompts for it interactively,
# or picks it up from the WANDB_API_KEY environment variable.
# (A key was previously committed on this line; it should be revoked.)
#! wandb login

## From the command line (shell scripts)

##### One run

In [None]:
# Make the launcher executable, then run it with the Styblinski-Tang function as argument.
! chmod +x train.sh 
! ./train.sh Styblinski-Tang

In [None]:
# TensorBoard for the run launched by train.sh; the path is hard-coded to that run's <log_dir>/<seed>/<exp_id> directory.
%tensorboard --logdir /content/log_files/0/Styblinski-Tang:ndim=2-tdf=80-wd=0.0-lr=0.001-d=0.0-opt=adam-alpha=1.0-beta=1.0/lightning_logs

In [None]:
import glob

# Run directory written by `./train.sh Styblinski-Tang` (same run as the TensorBoard cell above).
run_dir = "/content/log_files/0/Styblinski-Tang:ndim=2-tdf=80-wd=0.0-lr=0.001-d=0.0-opt=adam-alpha=1.0-beta=1.0"
# Checkpoint names embed the epoch and the validation loss (e.g. `epoch=0-val_loss_y=1997.3793.ckpt`),
# so they differ on every run; pick the most recently written checkpoint instead of hard-coding one.
checkpoints = glob.glob(os.path.join(glob.escape(run_dir), "*.ckpt"))
assert checkpoints, f"no .ckpt file found in {run_dir}"
pretrained_filename = max(checkpoints, key=os.path.getmtime)
print(pretrained_filename)
model = Model.load_from_checkpoint(pretrained_filename)

In [None]:
# Uncomment to delete /content/log_files/0 (the run directory used above). Irreversible.
#! rm -r /content/log_files/0

##### Multiple runs (for the phase diagram)

In [None]:
# Make the sweep launcher executable, then run it with the Styblinski-Tang function as argument.
! chmod +x train_loop.sh 
! ./train_loop.sh Styblinski-Tang

In [None]:
# Uncomment to delete every log/checkpoint under /content/log_files. Irreversible.
#! rm -r /content/log_files

## Without the command line (see multiple_runs.py): lets you visualize the embedding evolution directly in the notebook output

In [None]:
from src.utils import AttrDict
from src.dataset import get_dataloader
from src.trainer import train
from src.functions import get_function

In [None]:
# Target function; one of: Styblinski-Tang, Ackley, Beale, Booth, Bukin, McCormick, Rosenbrock, Sum, Prod
f_name="Styblinski-Tang"
train_pct=80
weight_decay=0.0
lr=0.001
dropout=0.0
opt="adam"
ndim=2
# Run name encoding the main hyper-parameters (extended with alpha/beta below when `alpha_beta` is set).
group_name=f"{f_name}:ndim={ndim}-tdf={train_pct}-wd={weight_decay}-lr={lr}-d={dropout}-opt={opt}"
random_seed=0
# Relative to the repository directory entered with `%cd` (presumably /content/log_files on Colab).
log_dir="../log_files"
alpha=1.0
beta=1.0
params = AttrDict({
    ### Main parameters
    "exp_id" : f"{group_name}",
    "log_dir" : f"{log_dir}/{random_seed}",

    ### Model
    "hidden_dim" : 512,  
    "n_layers" : 1,
    "dropout" : dropout,

    ### Dataset
    "func_params": AttrDict({"f_name" : f_name}),
    #"func_params" : AttrDict({"f_name" : f_name, "min_x": -5, "max_x" : 5, "min_y" : -5, "max_y" : 5, "step_x" : 0.25, "step_y" : 0.25}),
    # Use the `ndim` variable (it was hard-coded to 2) so the dataset matches the `ndim=` encoded in group_name.
    "ndim" : ndim,
    "num_samples" : 1000,
    "noise_params" : None,
    #"noise_params" : {"distribution" : "normal", "loc" : 0.0, "scale" : 1.0},
    "normalize" : False,
    "train_pct" : train_pct,
    "batch_size" : 512,
       
    ### Optimizer
    "optimizer" : f"{opt},lr={lr},weight_decay={weight_decay},beta1=0.9,beta2=0.99,eps=0.00000001",
    #"alpha_beta" : None, 
    "alpha_beta" : {"alpha" : alpha, "beta" : beta}, 
 
    ### LR Scheduler
    "lr_scheduler" : None,
    #"lr_scheduler" : "reduce_lr_on_plateau,factor=0.2,patience=20,min_lr=0.00005,mode=min,monitor=val_loss_y",
        
    ### Training
    "max_epochs" : 10000, 
    "validation_metrics" : "val_loss_y",
    "checkpoint_path" : None, 
    "model_name": "None", 
    "every_n_epochs":100, 
    "every_n_epochs_show":200, 
    "early_stopping_patience":1e9, 
    "save_top_k":-1,

    # Wandb 
    "use_wandb" : False,
    "wandb_entity" : "grokking_ppsp",
    "wandb_project" : f"lerning_with_derivative",
    "group_name" : group_name,

    "group_vars" : None,

    ### Intrinsic Dimension Estimation
    #"ID_params" : {},
    #"ID_params": {"method" : "mle", "k":2},
    "ID_params": {"method" : "twonn"},
        
    ### Devices & Seed
    "accelerator" : "auto",
    "devices" : "auto",
    "random_seed": random_seed,

    ### Early_stopping (for grokking) : Stop the training `patience` epochs after the `metric` has reached the value `metric_threshold` 
    #"early_stopping_grokking" : None,
    "early_stopping_grokking" : "patience=int(1000),metric=str(val_loss_y),metric_threshold=float(0.0)"
})
# When `alpha_beta` is set, append alpha/beta to the run name (this is the naming the TensorBoard paths above use).
if params.alpha_beta is not None : 
    params.group_name=f"{params.group_name}-alpha={alpha}-beta={beta}"
    params.exp_id=params.group_name
params["weight_decay"] = weight_decay
params["f_name"] = f_name
# Resolve the function spec (callable, optional derivative, x-range) and build the data loaders from it.
func_params = get_function(params.func_params)
params.func_params = func_params
train_loader, val_loader, dataloader, data_infos, data_config = get_dataloader(
    func_params.callable_function, params.ndim, func_params.min_x, func_params.max_x, params.num_samples, params.train_pct, 
    deriv_function = getattr(func_params, "callable_function_deriv", None), noise_params=params.noise_params, 
    batch_size=params.batch_size, num_workers=2, normalize=params.normalize
)
params["data_infos"] = data_infos
params["data_config"] = data_config

##### One run

In [None]:
# Train a single model with the configuration above; `model` and `result` are reused by later cells.
model, result = train(params, train_loader, val_loader)

In [None]:
%tensorboard --logdir /content/log_files/0/Styblinski-Tang:ndim=2-tdf=80-wd=0.0-lr=0.001-d=0.0-opt=adam-alpha=1.0-beta=1.0/lightning_logs

In [None]:
# Uncomment to delete /content/log_files/0 (the run directory used above). Irreversible.
#! rm -r /content/log_files/0

##### Multiple runs (for the phase diagram): see multiple_runs.py or train_parallel.py

In [None]:
# Script-based alternatives to the in-notebook sweep below (uncomment one to run it).
#! python multiple_runs.py
#! python train_parallel.py --parallel False --f_name Styblinski-Tang --ndim 2

In [None]:
# Uncomment to delete every log/checkpoint under /content/log_files. Irreversible.
#! rm -r /content/log_files

In [None]:
import numpy as np

from multiple_runs import plot_results, itertools
from src.utils import get_group_name

In [None]:
# Hyper-parameter grid of the sweep; the commented lines are alternative, larger grids.
lrs = [0.001]  # learning rates
# lrs = [1e-2, 1e-3, 1e-4, 1e-5]
# lrs = np.linspace(start=1e-1, stop=1e-5, num=10)

weight_decays = [0.0]  # weight-decay values
# weight_decays = list(range(20))
# weight_decays = np.linspace(start=0, stop=20, num=21)

print(lrs, weight_decays)

In [None]:
# Train one model per (learning rate, weight decay) pair; results are stored in `model_dict[name]`.
model_dict = {}
for i, (a, b) in enumerate(itertools.product(lrs, weight_decays)) :
    params["lr"] = a
    params["weight_decay"] = b
    # Rebuild the optimizer spec from scratch for every run. The previous in-place
    # `.replace(f"weight_decay={weight_decay}", ...)` only matched on the first iteration
    # (afterwards the string held the previous `b`), and `lr=` was never updated at all.
    params["optimizer"] = f"{opt},lr={a},weight_decay={b},beta1=0.9,beta2=0.99,eps=0.00000001"

    # NOTE: later cells must build the same key format to look runs up in `model_dict`.
    name = f"lr={a},weight_decay={b}"
    params.exp_id = name
    params["group_name"] = get_group_name(params, group_vars = None)

    print("*"*10, i, name, "*"*10)

    model, result = train(params, train_loader, val_loader)

    model_dict[name] = {"model": model, "result": result}

In [None]:
# Names of the runs collected by the sweep.
run_names = model_dict.keys()
print(run_names)

In [None]:
# Final validation losses of every run: on the function value (y) and on its derivative (dy/dx).
val_metrics = [run["result"]["val"] for run in model_dict.values()]
val_loss = [metrics["val_loss_y"] for metrics in val_metrics]
val_loss_dydx = [metrics["val_loss_dydx"] for metrics in val_metrics]
print(val_loss, val_loss_dydx)

In [None]:
# Plot the sweep results over the (lr, weight_decay) grid and save the figure next to the logs.
results_figure_path = f"{params.log_dir}/result_multiple_run.png"
plot_results(
    params,
    model_dict,
    hparms_1=lrs,
    hparms_2=weight_decays,
    s1='lr',
    s2="weight_decay",
    title=None,
    save_to=results_figure_path,
    show=True,
)

In [None]:
import math
import torch

# Integer code of each phase, as passed to `visualize_phases`.
phases = {"grokking" : 0, "comprehension" : 1, "memorization" : 2, "confusion" : 3}

M = len(lrs)
N = len(weight_decays)
# img[i][j] = phase code of the run (lrs[i], weight_decays[j]).
img = torch.zeros(size=(M,N), dtype=int)

# PLACEHOLDER: the per-run classification below is still commented out, so the grid is filled with
# random phases in [1, 3] (comprehension / memorization / confusion). The figure shows layout only,
# not a result. Seeded from the config cell's `random_seed` so a re-run reproduces the same figure.
rng = np.random.default_rng(random_seed)

for i in range(M) :
    for j in range(N) :
        a = lrs[i]
        b = weight_decays[j]
        # Must match the key format used when filling `model_dict` in the sweep cell.
        name = f"lr={a},weight_decay={b}"
        # model = model_dict[name]["model"]
        # if model.confusion : img[i][j] = phases["confusion"]
        # elif model.memorization : img[i][j] = phases["memorization"]
        # else : # grokking or comprehension
        #     """
        #     The difference between grokking and comprehension is the number of training steps or epochs that separate memorization and comprehension, 
        #     because the model potentially goes through the following phases : confusion > memorization > comprehension
        #     """
        #     grok = model.grok
        #     comprehension = model.comprehension
        #     # In this phase : 0 <= model.memo_epoch <= model.comp_epoch < +inf
        #     diff_epoch = model.comp_epoch - model.memo_epoch
        #     if not math.isnan(diff_epoch) : 
        #         grok = diff_epoch >= 100
        #         comprehension = not grok
        #     img[i][j] = phases["grokking"] if grok else phases["comprehension"]

        # `integers` excludes `high`, like `np.random.randint`; use low=0 / high=2+1 for other ranges.
        img[i][j] = int(rng.integers(low=1, high=3+1))

In [None]:
from src.visualize_phases import visualize_phases

In [None]:
# Render the phase grid `img` with its phase legend; the returned value is discarded.
_ = visualize_phases(
    img,
    phases,
    interpolation='hermite',
    title=None,
    save_to=None,
    show=True,
    pixel_wise_text=False,
)

In [None]:
%tensorboard --logdir '/content/log_files/0/lr=0.001, weight_decay=0.0/lightning_logs'

## Video animation (visualize the evolution of the embeddings during training)

In [None]:
from src.trainer import images_to_vidoe

In [None]:
# Turn every image sub-folder of the current `model`'s run into an .avi video next to the run.
# NOTE: `model` comes from whichever training/loading cell ran last.
run_dir = os.path.join(model.hparams.log_dir, model.hparams.exp_id)
root_dir = os.path.join(run_dir, "images")
# Sorted so the processing (and printed) order is deterministic.
for dirname in sorted(os.listdir(root_dir)) :
    image_folder = os.path.join(root_dir, dirname)
    if not os.path.isdir(image_folder):
        continue
    print(image_folder)
    video_path = os.path.join(run_dir, f'{dirname}.avi')
    try :
        images_to_vidoe(image_folder, video_path, format="png")
        print(video_path)
    except IndexError as err: # "list index out of range" (presumably an empty folder — TODO confirm)
        # Report which video failed and why instead of a bare "Error", then keep going.
        print(f"Error while building {video_path}: {err}")

In [None]:
# Build a single video ("grid.avi") from the images stored directly in the run's images/ folder.
run_dir = os.path.join(model.hparams.log_dir, model.hparams.exp_id)
root_dir = os.path.join(run_dir, "images")
video_path = os.path.join(run_dir, "grid.avi")
print(video_path)
images_to_vidoe(root_dir, video_path, format="png")