In [None]:
#! git clone https://github.com/Tikquuss/sag_torch

In [None]:
#%cd sag_torch

In [1]:
#! pip install -r requirements.txt

In [2]:
import os

from src.modeling import Model

In [3]:
# Load the TensorBoard notebook extension so that the `%tensorboard` magic used further down is available.
%load_ext tensorboard

# Train

In [None]:
#! wandb login $som_key

## cmd

##### One run

In [None]:
# Make the single-run training script executable, then launch it from the notebook.
! chmod +x train.sh 
! ./train.sh 

In [None]:
# Open TensorBoard on the `lightning_logs` folder of a run.
# NOTE(review): the log directory is hard-coded (absolute /content path and one specific
# hyperparameter combination) — adjust it if train.sh writes its logs elsewhere.
%tensorboard --logdir /content/log_files/0/classification_tdf=80-wd=0.0-r_lr=0.001-d_lr=0.001-r_d=0.0-d_d=0.0-opt=adam/lightning_logs

In [None]:
# Reload a trained model from a checkpoint file.
# NOTE(review): the path is hard-coded, including the epoch and val_loss in the file name —
# confirm that it matches a checkpoint that actually exists before running this cell.
pretrained_filename = "/content/log_files/0/classification_tdf=80-wd=0.0-r_lr=0.001-d_lr=0.001-r_d=0.0-d_d=0.0-opt=adam/epoch=1-val_loss=5.2339.ckpt"
model = Model.load_from_checkpoint(pretrained_filename)

In [None]:
#! rm -r /content/log_files/0

##### Multiple runs (for phase diagram)

In [None]:
# Make the multi-run training script executable, then launch it from the notebook.
! chmod +x train_loop.sh
! ./train_loop.sh 

## Without cmd (see multiple_runs.py): lets you visualize the embedding evolution directly in the notebook output

In [None]:
from src.utils import AttrDict
from src.dataset import get_dataloader
from src.trainer import train

In [None]:
# --- Hyperparameters that identify a run (they are all encoded in `group_name`) ---
train_pct = 80
weight_decay = 0.0
representation_lr = 0.001
decoder_lr = 0.001
representation_dropout = 0.0
decoder_dropout = 0.0
opt = "adam"

# One "key=value" tag per hyperparameter, joined with "-".
group_name = "-".join([
    f"tdf={train_pct}",
    f"wd={weight_decay}",
    f"r_lr={representation_lr}",
    f"d_lr={decoder_lr}",
    f"r_d={representation_dropout}",
    f"d_d={decoder_dropout}",
    f"opt={opt}",
])

# --- Seed and dataset definition ---
random_seed = 0
operator = "+"
modular = False

# Root folder for the logs.
log_dir = "../log_files"

p = 100
task = "classification"

# Complete experiment configuration; this is the object later passed to `train`.
params = AttrDict({
    # --- Main parameters ---
    "task": task,
    "exp_id": f"{task}_{group_name}",
    "log_dir": f"{log_dir}/{random_seed}",

    # --- Model ---
    "emb_dim": 256,
    "hidden_dim": 512,
    "n_layers": 1,
    "representation_dropout": representation_dropout,
    "decoder_dropout": decoder_dropout,
    "pad_index": None,
    "p": p,

    # --- Dataset ---
    "operator": operator,
    "modular": modular,
    "train_pct": train_pct,
    "batch_size": 512,

    # --- Optimizer (name followed by comma-separated key=value options) ---
    "optimizer": f"{opt},weight_decay={weight_decay},beta1=0.9,beta2=0.99,eps=0.00000001",
    "representation_lr": representation_lr,
    "decoder_lr": decoder_lr,

    # --- LR scheduler ---
    "lr_scheduler": None,
    # "lr_scheduler": "reduce_lr_on_plateau,factor=0.2,patience=20,min_lr=0.00005,mode=min,monitor=val_loss",

    # --- Training ---
    "max_epochs": 2,
    "validation_metrics": "val_loss",
    "checkpoint_path": None,
    "model_name": "",
    "every_n_epochs": 1,
    "every_n_epochs_show": 1,
    "early_stopping_patience": 1e9,
    "save_top_k": -1,

    # --- Wandb ---
    "use_wandb": False,
    "wandb_entity": "grokking_ppsp",
    "wandb_project": f"toy_model_grokking_op={operator}-p={p}-task={task}-mod={modular}",
    "group_name": group_name,

    "group_vars": None,

    # --- Intrinsic dimension estimation ---
    # "ID_params": {},
    # "ID_params": {"method": "mle", "k": 2},
    "ID_params": {"method": "twonn"},

    # --- Devices & seed ---
    "accelerator": "auto",
    "devices": "auto",
    "random_seed": random_seed,

    # --- Early stopping (for grokking): stop the training `patience` epochs after
    #     the `metric` has reached the value `metric_threshold` ---
    # "early_stopping_grokking": None,
    "early_stopping_grokking": "patience=int(1000),metric=str(val_acc),metric_threshold=float(90.0)",
})
# Entries derived from the settings above.
params["weight_decay"] = weight_decay
params["regression"] = (task == "regression")

# Build the data loaders, then keep the dataset description inside `params`.
train_loader, val_loader, dataloader, data_infos = get_dataloader(
    p,
    train_pct,
    regression=params.regression,
    operator=params.operator,
    modular=params.modular,
    batch_size=params.batch_size,
    num_workers=2,
)
print(data_infos)
params["data_infos"] = data_infos

##### One run

In [None]:
# Train one model with the configuration in `params`; `train` returns the trained model and its results.
model, result = train(params, train_loader, val_loader)

In [None]:
%tensorboard --logdir /content/log_files/0/classification_tdf=80-wd=0.0-r_lr=0.001-d_lr=0.001-r_d=0.0-d_d=0.0-opt=adam/lightning_logs

In [None]:
#! rm -r /content/log_files/0

##### Multiple runs (for phase diagram): see multiple_runs.py or train_parallel.py

In [None]:
#! python multiple_runs.py
#! python train_parallel.py --parallel False

In [None]:
# `itertools` is a standard-library module: import it directly rather than relying on
# `multiple_runs` happening to expose it in its own namespace.
import itertools

from multiple_runs import plot_results
from src.utils import get_group_name

In [None]:
# Learning rates of the sweep; both names refer to the same list object.
representation_lrs = decoder_lrs = [1e-2, 1e-3]
# Alternative: decoder_lrs = representation_lrs = np.linspace(start=1e-1, stop=1e-5, num=10)

# Weight-decay values of the sweep (only used when `flag` is False).
weight_decays = list(range(20))
# Alternative: weight_decays = np.linspace(start=0, stop=20, num=21)

# flag = True  -> the second swept hyperparameter is the decoder learning rate
# flag = False -> the second swept hyperparameter is the weight decay
flag = True
s = "decoder_lr" if flag else "weight_decay"
print(representation_lrs, decoder_lrs if flag else weight_decays)

In [None]:
# Train one model per (representation_lr, second hyperparameter) pair of the grid and keep
# every trained model together with its results, keyed by a readable run name.
# After the loop, `model` and `result` hold the last run.
model_dict = {}
second_values = decoder_lrs if flag else weight_decays
for i, (a, b) in enumerate(itertools.product(representation_lrs, second_values)):

    params["representation_lr"] = a
    # Keep the plain entry in sync for both kinds of sweep.
    # NOTE(review): presumably `get_group_name` reads `params[s]` — confirm.
    params[s] = b
    if not flag:
        # The weight decay is also encoded in the comma-separated optimizer spec string.
        # Rewrite whichever value is currently there: replacing the literal
        # f"{s}={weight_decay}" only matches on the first iteration, after which every
        # further value of the sweep would be silently ignored.
        params["optimizer"] = ",".join(
            f"{s}={b}" if option.startswith(f"{s}=") else option
            for option in params["optimizer"].split(",")
        )

    name = f"representation_lr={a}, {s}={b}"
    params.exp_id = name

    #group_vars = GROUP_VARS + ["representation_lr", s]
    # Ordered de-duplication: unlike list(set(...)), the order is the same on every run.
    group_vars = list(dict.fromkeys(["representation_lr", s]))
    params["group_name"] = get_group_name(params, group_vars=group_vars)

    print("*"*10, i, name, "*"*10)

    model, result = train(params, train_loader, val_loader)

    model_dict[name] = {"model": model, "result": result}

In [None]:
# Show the names of the runs stored in `model_dict`.
print(model_dict.keys())

In [None]:
# Collect the validation metrics of every run, in the insertion order of `model_dict`.
val_results = [run["result"]["val"] for run in model_dict.values()]
val_loss = [metrics["val_loss"] for metrics in val_results]
# Runs without a "val_acc" entry are reported as 0.
val_acc = [metrics.get("val_acc", 0) for metrics in val_results]
print(val_loss, val_acc)

In [None]:
# Plot the sweep results against the two swept hyperparameters.
plot_results(
    params,
    model_dict,
    hparms_1=representation_lrs,
    hparms_2=decoder_lrs if flag else weight_decays,
    s1="representation_lr",
    s2=s,
)

In [None]:
# for a, b in itertools.product(representation_lrs, decoder_lrs if flag else weight_decays) :
#     name = f"representation_lr={a}, {s}={b}"
#     model = model_dict[name]["model"]

## Visualize embedding 2&3D with plotly (This and the following sections only need `model`)

In [None]:
# Plot the embeddings of `model` in 2-D (switch to the commented line for 3-D).
# NOTE(review): `display_pca_scatterplot` is not imported in any of the cells visible here —
# confirm which module provides it and import it, otherwise this cell raises NameError
# on a fresh kernel.
#_ = display_pca_scatterplot(model.hparams.p, model, dim=3)
_ = display_pca_scatterplot(model.hparams.p, model, dim=2, title="Embeddings")

In [None]:
# Same plot, applied to the weight matrix of the last layer of `model.mlp`.
word_vectors = model.mlp[-1].weight
_ = display_pca_scatterplot(
    word_vectors.size(0),
    model,
    word_vectors=word_vectors,
    dim=2,
    title="last_layer_weight",
)

## Video animation (visualize the evolution of embedding during training)

In [None]:
# Turn every sub-folder of "<log_dir>/<exp_id>/images" into a video written to
# "<log_dir>/<exp_id>/<sub-folder name>.avi".
run_dir = os.path.join(model.hparams.log_dir, model.hparams.exp_id)
root_dir = os.path.join(run_dir, "images")
for dirname in os.listdir(root_dir):
    image_folder = os.path.join(root_dir, dirname)
    if not os.path.isdir(image_folder):
        continue
    print(image_folder)
    video_path = os.path.join(run_dir, f'{dirname}.avi')
    try:
        images_to_vidoe(image_folder, video_path, format="png")
    except IndexError as err:  # e.g. "list index out of range"
        # Best effort: say which folder failed and why, then go on with the next one.
        print(f"Error: could not build {video_path} from {image_folder} ({err})")
    else:
        print(video_path)

In [None]:
# Build "grid.avi" in the run folder from the png files of the run's "images" folder.
run_dir = os.path.join(model.hparams.log_dir, model.hparams.exp_id)
root_dir = os.path.join(run_dir, "images")
video_path = os.path.join(run_dir, "grid.avi")
print(video_path)
images_to_vidoe(root_dir, video_path, format="png")

## Visualize the learned set of embeddings (only if `emb_dim` = 2)

In [None]:
# This visualization is only run when the embeddings are 2-dimensional.
if model.hparams.emb_dim == 2:
    img = visualize_embeddings_good(
        model,
        A=None,
        B=None,
        N=500,
        interpolation=None,  # alternative: 'hermite'
        figsize=(5, 5),
        title="learned_set_of_embeddings",
        save_to='/content/learned_set_of_embeddings.png',
    )

## Visualize embedding 2D (good)

In [None]:
# !pip install folium==0.2.1
# !pip install pdflatex
# !sudo apt-get install texlive-latex-recommended 
# !sudo apt install texlive-latex-extra
# !sudo apt install dvipng

In [None]:
# ! python src/analyze_embedding.py

In [None]:
from src.analyze_embedding import display_pca_scatterplot_simple
import numpy as np
import torch

In [None]:
# Output folder for the figures of this section: "<log_dir>/<exp_id>/Visualize_embedding"
# (created if it does not exist yet).
save_path = os.path.join(model.hparams.log_dir, model.hparams.exp_id, "Visualize_embedding")
os.makedirs(save_path, exist_ok=True)

### Embedding

In [None]:
# Get the 2-D coordinates of the embeddings and their labels instead of only displaying them.
data_dim, words = display_pca_scatterplot(model.hparams.p, model, dim=2, return_data=True)
# Alternative, passing the vectors explicitly:
# word_vectors = model.embeddings.weight
# data_dim, words = display_pca_scatterplot(model.hparams.p, model, word_vectors = word_vectors, dim=2, return_data = True)

# Number of points to draw.
N = data_dim.shape[0]
filename = f"{model.hparams.p}_structured_embedding"
display_pca_scatterplot_simple(
    data_dim, words, N, save_path, filename,
    plot_line=False,  # alternative: True
    cmap='viridis',
    eps=0.01,
)

### Embedding + Prediction (before PCA)

In [None]:
# Overlay, on the 2-D embedding plot, a prediction of each point computed from the other
# points. Here the prediction is computed directly on the 2-D coordinates in `data_dim`.
data_predict = np.zeros_like(data_dim) # same shape as data_dim
n_points = data_dim.shape[0]

# False -> every interior point is predicted as the midpoint of its two neighbours.
# True  -> every point is extrapolated from the first two: x_i = x_0 + i * (x_1 - x_0).
linear_extrapolation = False

if linear_extrapolation:
    # Sizes are taken from the data itself (this branch used to reference an undefined
    # `model_ld`, so it could not run).
    step = data_dim[1] - data_dim[0]
    data_predict[:2] = data_dim[:2]
    data_predict[2:] = data_dim[0] + np.arange(2, n_points)[:, None] * step
else:
    # The two end points have a single neighbour: they are copied unchanged.
    data_predict[0] = data_dim[0]
    data_predict[-1] = data_dim[-1]
    data_predict[1:-1] = (data_dim[:-2] + data_dim[2:]) / 2

N = data_predict.shape[0]
filename = f"{model.hparams.p}_structured_embedding_1"

display_pca_scatterplot_simple(
    data_dim, words, N, save_path, filename,
    plot_line=False,  # alternative: True
    cmap='viridis',
    eps=0.01,
    preicted_data_dim=data_predict,  # (sic) keyword spelled exactly as this notebook has always passed it
    legend_loc="best",  # other values tried: "center", "upper center"
)

### Embedding + Prediction (after PCA)

In [None]:
# Same neighbour-based prediction, but computed on the full embedding vectors; the predicted
# vectors are then passed through `display_pca_scatterplot` to obtain 2-D points.
data_dim_tmp = model.embeddings.weight.detach().cpu().numpy()
data_predict = np.zeros_like(data_dim_tmp) # same shape as the embedding matrix
n_points = data_dim_tmp.shape[0]

# False -> every interior vector is predicted as the midpoint of its two neighbours.
# True  -> every vector is extrapolated from the first two: x_i = x_0 + i * (x_1 - x_0).
linear_extrapolation = False

if linear_extrapolation:
    # Sizes are taken from the data itself (this branch used to reference an undefined
    # `model_ld`, so it could not run).
    step = data_dim_tmp[1] - data_dim_tmp[0]
    data_predict[:2] = data_dim_tmp[:2]
    data_predict[2:] = data_dim_tmp[0] + np.arange(2, n_points)[:, None] * step
else:
    # The two end vectors have a single neighbour: they are copied unchanged.
    data_predict[0] = data_dim_tmp[0]
    data_predict[-1] = data_dim_tmp[-1]
    data_predict[1:-1] = (data_dim_tmp[:-2] + data_dim_tmp[2:]) / 2

word_vectors = torch.from_numpy(data_predict)
data_predict, words = display_pca_scatterplot(word_vectors.size(0), model, word_vectors=word_vectors, dim=2, return_data=True)

N = data_predict.shape[0]
filename = f"{model.hparams.p}_structured_embedding_2"

# NOTE(review): `data_dim` is not assigned in this cell; it is whatever an earlier cell left
# in the kernel (presumably the 2-D embedding coordinates) — confirm the cells are run in order.
display_pca_scatterplot_simple(
    data_dim, words, N, save_path, filename,
    plot_line=False,  # alternative: True
    cmap='viridis',
    eps=0.01,
    preicted_data_dim=data_predict,  # (sic) keyword spelled exactly as this notebook has always passed it
    legend_loc="upper center",  # other value tried: "center"
)

### Last layer weights

In [None]:
# Get the 2-D coordinates (and labels) of the weight matrix of the last layer of `model.mlp`.
word_vectors = model.mlp[-1].weight
data_dim, words = display_pca_scatterplot(
    word_vectors.size(0),
    model,
    word_vectors=word_vectors,
    dim=2,
    return_data=True,
)

# Number of points to draw.
N = data_dim.shape[0]
filename = f"{model.hparams.p}_structured_last_layer_weights"
display_pca_scatterplot_simple(
    data_dim, words, N, save_path, filename,
    plot_line=False,
    cmap='viridis',
    eps=0.000001,  # other value tried: 0.01
)

### Last layer weights + prediction (before PCA)

In [None]:
# Overlay, on the 2-D plot of the last-layer weights, a prediction of each point computed
# from the other points (computed directly on the 2-D coordinates in `data_dim`).
data_predict = np.zeros_like(data_dim) # same shape as data_dim
# Take the number of rows from the data: it is not guaranteed here to equal model.hparams.p.
n_points = data_dim.shape[0]

# False -> every interior point is predicted as the midpoint of its two neighbours.
# True  -> every point is extrapolated from the first two: x_i = x_0 + i * (x_1 - x_0).
linear_extrapolation = False

if linear_extrapolation:
    # This branch used to reference an undefined `model_ld`, so it could not run.
    step = data_dim[1] - data_dim[0]
    data_predict[:2] = data_dim[:2]
    data_predict[2:] = data_dim[0] + np.arange(2, n_points)[:, None] * step
else:
    # The two end points have a single neighbour: they are copied unchanged.
    data_predict[0] = data_dim[0]
    data_predict[-1] = data_dim[-1]
    data_predict[1:-1] = (data_dim[:-2] + data_dim[2:]) / 2

N = data_predict.shape[0]
filename = f"{model.hparams.p}_structured_last_layer_weights_1"

display_pca_scatterplot_simple(
    data_dim, words, N, save_path, filename,
    plot_line=False,
    cmap='viridis',
    eps=0.000001,  # other value tried: 0.01
    preicted_data_dim=data_predict,  # (sic) keyword spelled exactly as this notebook has always passed it
    legend_loc="best",  # other values tried: "center", "upper center"
)

## Analyse embedding

In [None]:
from src.analyze_embedding import analyze

In [None]:
# Output folder for this section: "<log_dir>/<exp_id>/analyse_embedding" (created if missing).
# Note that this rebinds `save_path`, which earlier cells pointed at another folder.
save_path = os.path.join(model.hparams.log_dir, model.hparams.exp_id, "analyse_embedding")
os.makedirs(save_path, exist_ok=True)

In [None]:
# Run analysis option 1 on `model`; `save_path` is passed as the output location.
# NOTE(review): what each `option` value computes is defined in src.analyze_embedding.analyze — not visible here.
analyze(model, option = 1, save_path = save_path)

In [None]:
# Run analysis option 2 on `model`; `save_path` is passed as the output location.
# NOTE(review): what each `option` value computes is defined in src.analyze_embedding.analyze — not visible here.
analyze(model, option = 2, save_path = save_path)

In [None]:
# Run analysis option 3 on `model`; `save_path` is passed as the output location.
# NOTE(review): what each `option` value computes is defined in src.analyze_embedding.analyze — not visible here.
analyze(model, option = 3, save_path = save_path)

In [None]:
# Run analysis option 4 on `model`; `save_path` is passed as the output location.
# NOTE(review): what each `option` value computes is defined in src.analyze_embedding.analyze — not visible here.
analyze(model, option = 4, save_path = save_path)