In [None]:
# Fetch the project sources; the following cells run from inside this clone.
! git clone https://github.com/Tikquuss/sag_torch

In [None]:
# Work from the repository root so train.sh, requirements.txt and src/ resolve.
%cd sag_torch

In [None]:
# Optional reset: go back to /content and delete the clone to start over.
# %cd /content
# !rm -r -f sag_torch

In [None]:
! pip install -r requirements.txt

In [None]:
import os
import torch
from src.modeling import Model
from src.dataset import TORCH_SET

In [None]:
# Enables the %tensorboard magic used by later cells.
%load_ext tensorboard

# Train

In [None]:
# W&B login, only needed when use_wandb=True. Never write the API key in the notebook:
# export WANDB_API_KEY in the environment or let `wandb login` prompt for it.
# (A key that was ever committed to this file must be considered leaked and revoked.)
#! wandb login

## From the command line (shell scripts)

##### One run

In [None]:
# Single training run; its hyper-parameters are presumably set inside train.sh — check the script.
! chmod +x train.sh 
! ./train.sh 

In [None]:
# Optional cleanup: deletes ALL logs and checkpoints under /content/log_files.
# Kept commented out so that "Run all" does not wipe the directory that the next
# cells (tensorboard, checkpoint listing, params/data loading) read from.
# ! rm -r /content/log_files

In [None]:
# TensorBoard for the single run. The path is hard-coded: confirm it matches the
# hyper-parameters (wd / lr / dropout / optimizer) that train.sh actually used.
# %tensorboard --logdir /content/log_files/0/classification_wd=0.1-lr=0.001-d=0.1-opt=adam
%tensorboard --logdir /content/log_files/0/iris/wd=0.0-lr=0.001-d=0.5-opt=adam/lightning_logs

In [None]:
# Directory of the run to inspect (hard-coded; must match the run written by train.sh).
# `run_id` rather than `id` so the builtin id() is not shadowed in the notebook namespace;
# os.path.join avoids the doubled "/" that string concatenation produced.
logdir = "/content/log_files/0"
run_id = "iris/wd=0.0-lr=0.001-d=0.5-opt=adam"
logdir = os.path.join(logdir, run_id)

In [None]:
##
import re
import os 

def sorted_nicely(l):
    """Return a new list with the strings of ``l`` in natural ("human") order.

    Each string is cut into alternating non-digit / digit runs; digit runs are
    compared as integers and the rest as plain strings, so ``a2`` sorts before
    ``a10``.
    https://stackoverflow.com/a/2669120/11814682
    """
    def _natural_key(name):
        return [int(part) if part.isdigit() else part
                for part in re.split('([0-9]+)', name)]

    return sorted(l, key=_natural_key)

pretrained_folder = logdir

# Checkpoint file names such as "epoch=12-val_loss=0.4321.ckpt".
# Raw strings so that "\." is a regex escape (a literal dot) and not an invalid
# string escape; both dots are escaped so they only match ".".
#pattern = r'^epoch_[0-9]+\.ckpt$'
pattern = r'^epoch=[0-9]+-val_loss=[0-9]+\.[0-9]+\.ckpt$'

# Matching checkpoints in natural order, i.e. by increasing epoch number.
model_files = sorted_nicely(f for f in os.listdir(pretrained_folder) if re.match(pattern, f))
#model_files = ["init.ckpt"] + model_files
model_files = [os.path.join(pretrained_folder, f) for f in model_files]

L = len(model_files)
print(L)

model_files[-10:]

In [None]:
# Load the last checkpoint in natural order (highest epoch). Fail with an explicit
# message rather than an IndexError when no file matched the checkpoint pattern.
if not model_files:
    raise FileNotFoundError(f"No checkpoint matching {pattern!r} found in {pretrained_folder}")
model = Model.load_from_checkpoint(model_files[-1])

In [None]:
# params.pt / data.pt hold pickled Python objects (presumably the run's hyper-parameter
# object and its data module), not plain tensors, so they need weights_only=False
# (torch.load defaults to weights_only=True since PyTorch 2.6).
# Unpickling can execute arbitrary code: only load files produced by your own runs.
params = torch.load(os.path.join(logdir, "params.pt"), weights_only=False)
data_module = torch.load(os.path.join(logdir, "data.pt"), weights_only=False)

In [None]:
# Pull the first batch of the test loader (enumerate index 0) to inspect it below.
examples = enumerate(data_module.test_dataloader())
batch_idx, first_batch = next(examples)
example_data, example_targets, indexes = first_batch
batch_idx, example_data.shape

In [None]:
if params.dataset_name in TORCH_SET:
    import matplotlib.pyplot as plt

    # Up to 6 test images in a 2x3 grid with their ground-truth labels.
    # Only channel 0 is drawn (assumes (C, H, W) image tensors — TODO confirm).
    n_show = min(6, len(example_data))
    fig, axes = plt.subplots(2, 3)
    for i, ax in enumerate(axes.flat):
        ax.set_xticks([])
        ax.set_yticks([])
        if i >= n_show:
            ax.axis("off")  # batch smaller than the grid
            continue
        ax.imshow(example_data[i][0], cmap='gray', interpolation='none')
        ax.set_title("Ground Truth: {}".format(example_targets[i]))
    # Lay out once, after every title exists, so titles do not overlap.
    fig.tight_layout()
    plt.show()

In [None]:
# Inference on the example batch. eval() switches off dropout and other train-only
# behaviour (the Lightning docs call it after load_from_checkpoint); without it a
# model with dropout (this run's directory name says d=0.5) gives random predictions.
model.eval()
with torch.no_grad():
    output = model(example_data.to(model.device))

In [None]:
if params.dataset_name in TORCH_SET:
    import matplotlib.pyplot as plt

    # Predicted class = index of the largest output for each sample
    # (assumes output is (batch, n_classes) — TODO confirm).
    predictions = output.argmax(dim=1)
    n_show = min(6, len(example_data))
    fig, axes = plt.subplots(2, 3)
    for i, ax in enumerate(axes.flat):
        ax.set_xticks([])
        ax.set_yticks([])
        if i >= n_show:
            ax.axis("off")  # batch smaller than the grid
            continue
        ax.imshow(example_data[i][0], cmap='gray', interpolation='none')
        ax.set_title("Prediction: {}".format(predictions[i].item()))
    # Lay out once, after every title exists, so titles do not overlap.
    fig.tight_layout()
    plt.show()

##### Multiple runs

In [None]:
# Several runs launched by train_loop.sh (the swept values are defined in that script).
! chmod +x train_loop.sh
! ./train_loop.sh 

## Without the command line (see multiple_runs.py): lets you visualize the embedding evolution directly in the notebook output

In [None]:
import pytorch_lightning as pl
import torch

from src.utils import AttrDict
from src.dataset import LMLightningDataModule
from src.trainer import train

In [None]:
weight_decay = 0.0
lr = 0.001
dropout = 0.5

# Optimizer actually used for this run: one of "adam", "sgd", "sag".
opt_name = "sag"
optimizer_specs = {
    "adam": f"adam,weight_decay={weight_decay},beta1=0.9,beta2=0.99,eps=0.00000001",
    "sgd": f"sgd,weight_decay={weight_decay}",
    "sag": f"sag,weight_decay={weight_decay},batch_mode=False,init_y_i=True",
}
opt = optimizer_specs[opt_name]
# Derived from opt_name so the run label always names the optimizer really trained.
group_name = f"wd={weight_decay}-lr={lr}-d={dropout}-opt={opt_name}"

random_seed = 0
log_dir = "../log_files"

dataset_name = "iris"
train_pct = 80

#val_metric = "val_acc"
val_metric = "val_loss"


params = AttrDict({
    ### Main parameters
    "exp_id": f"{dataset_name}",
    "log_dir": f"{log_dir}",

    ### Model
    "c_out": [10, 10],
    "hidden_dim": [50],
    "kernel_size": [5],
    "kernel_size_maxPool": 2,
    "dropout": dropout,

    ### Dataset
    "dataset_name": dataset_name,
    "train_batch_size": 512,
    "val_batch_size": 512,
    "train_pct": train_pct,
    "val_pct": 100,

    ### Optimizer
    "optimizer": opt,
    "lr": lr,

    ### LR Scheduler
    "lr_scheduler": None,
    #"lr_scheduler": "reduce_lr_on_plateau,factor=0.2,patience=20,min_lr=0.00005,mode=min,monitor=val_loss",

    ### Training
    "max_epochs": 10,
    # Taken from val_metric so the two settings cannot drift apart.
    "validation_metrics": val_metric,
    "checkpoint_path": None,
    "model_name": "",
    "every_n_epochs": 1,
    "every_n_epochs_show": 1,
    "early_stopping_patience": 1e9,
    "save_top_k": -1,

    # Wandb
    "use_wandb": False,
    "wandb_entity": "grokking_ppsp",
    "wandb_project": f"dataset={dataset_name}",
    "group_name": group_name,

    "group_vars": None,

    # Devices & Seed
    "accelerator": "auto",
    "devices": "auto",
    "random_seed": random_seed,

    ### Early_stopping (for grokking): stop the training `patience` epochs after the `metric` has reached the value `metric_threshold`
    # NOTE(review): metric_threshold=90.0 looks like an accuracy (%) threshold; with
    # val_metric="val_loss" it is probably not what was intended — confirm how the
    # callback compares the metric against the threshold.
    #"early_stopping_grokking": None,
    "early_stopping_grokking": f"patience=int(1000),metric=str({val_metric}),metric_threshold=float(90.0)",

})

pl.seed_everything(params.random_seed, workers=True)
# Ensure cuDNN operations are deterministic on GPU (if used) for reproducibility.
# The attribute must be spelled exactly: assigning a misspelled name on
# torch.backends.cudnn just creates a new attribute and has no effect.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Output directory of this run: <log_dir>/<exp_id>/<group_name>/<random_seed>
root_dir = os.path.join(
    params.log_dir,
    params.exp_id,
    params.group_name,
    str(params.random_seed),
)
os.makedirs(root_dir, exist_ok=True)

# Lightning data module; data_path is <log_dir>/data.
data_module = LMLightningDataModule(
    dataset_name=params.dataset_name,
    train_batch_size=params.train_batch_size,
    val_batch_size=params.val_batch_size,
    train_pct=params.train_pct,
    val_pct=params.val_pct,
    data_path=params.log_dir + "/data",
    #num_workers = params.num_workers,
)
# Attach dataset metadata and the training split to params
# (presumably consumed by train()/the model — confirm).
params.data_infos = data_module.data_infos
params.train_dataset = data_module.train_dataset

##### One run

In [None]:
# Train the configuration defined above; root_dir is passed as the run's output directory.
model, result = train(params, data_module, root_dir)

In [None]:
%tensorboard --logdir /content/log_files/0/classification_tdf=80-wd=0.0-r_lr=0.001-d_lr=0.001-r_d=0.0-d_d=0.0-opt=adam/lightning_logs

In [None]:
# Optional cleanup of the run logs; kept commented so "Run all" does not delete results.
#! rm -r /content/log_files/0

##### Multiple runs (for a phase diagram): see multiple_runs.py or train_parallel.py

In [None]:
# Run the hyper-parameter sweep as a script (the in-notebook loop below is the alternative).
! python multiple_runs.py

In [None]:
import itertools  # stdlib; imported directly rather than re-exported by multiple_runs

import numpy as np

from multiple_runs import plot_results
from src.utils import get_group_name

In [None]:
# Optimizer option swept on the second axis of the grid; it has to be present in
# the optimizer spec string so that the sweep can rewrite its value.
s = "weight_decay"
assert s in params["optimizer"]

# First axis: learning rates.
lrs = [1e-2, 1e-3, 1e-4, 1e-5]
#lrs = np.linspace(start=1e-1, stop=1e-5, num=10)

# Second axis: weight decays.
weight_decays = [0, 1]
#weight_decays = list(range(20))
#weight_decays = np.linspace(start=0, stop=20, num=21)

print(lrs, weight_decays)

In [None]:
import re

model_dict = {}
for i, (a, b) in enumerate(itertools.product(lrs, weight_decays)):
    params["lr"] = a
    # Overwrite whatever value `s` currently has in the optimizer spec. Matching the
    # literal text f"{s}={weight_decay}" only succeeds on the first iteration, since
    # the value has already been rewritten to an earlier `b` afterwards.
    params["optimizer"] = re.sub(rf"(?<!\w){re.escape(s)}=[^,]*", f"{s}={b}", params["optimizer"])

    name = f"lr={a}, {s}={b}"
    params.exp_id = name

    #group_vars = GROUP_VARS + ["lr", s]
    # Deduplicate while keeping a fixed order: set() iteration order of strings can
    # change between interpreter sessions, which would make the group name unstable
    # if get_group_name follows the list order.
    group_vars = list(dict.fromkeys(["lr", s]))
    setattr(params, s, b)
    params["group_name"] = get_group_name(params, group_vars = group_vars)

    # Each configuration gets its own output directory (same layout as root_dir)
    # instead of all runs sharing the single-run directory.
    run_root_dir = os.path.join(params.log_dir, params.exp_id, params.group_name, str(params.random_seed))
    os.makedirs(run_root_dir, exist_ok=True)
    # Re-seed so every configuration starts from the same RNG state.
    pl.seed_everything(params.random_seed, workers=True)

    print("*"*10, i, name, "*"*10)

    model, result = train(params, data_module, run_root_dir)

    model_dict[name] = {"model": model, "result": result}

In [None]:
# Configurations trained by the sweep; keys have the form "lr=<lr>, <s>=<value>".
print(model_dict.keys())

In [None]:
# Final validation metrics of every run, in sweep order; val_acc falls back to 0
# for runs whose result has no "val_acc" entry.
val_loss = [run["result"]["val"]["val_loss"] for run in model_dict.values()]
val_acc = [run["result"]["val"].get("val_acc", 0) for run in model_dict.values()]
print(val_loss, val_acc)

In [None]:
# Phase diagram over the (lr, s) grid trained above.
plot_results(
    params,
    model_dict,
    hparms_1=lrs,
    hparms_2=weight_decays,
    s1='lr',
    s2=s,
)