In [None]:
import torch,inspect
import argparse
from pathlib import Path
from functools import partial
import numpy as np
from argparse import Namespace
import torch
from dev import namespace_tools


In [None]:
# Nested namespace argument containing all elements associated with the training setup.
notebook_run = Namespace(
    # Basic hyper-parameters of the run.
    simple_hp=Namespace(
        batch_size=16,
        d_model=64,
        early_stop_thresh=np.inf,
        nb_epochs=200,
        warm_up_epochs=20,
    ),
    # Parameters to limit the size of the dataset.
    dset_truncation=Namespace(
        limit_length=64,
        # Set to False if you want to overfit the model on the training set
        # (presumably describes `use_splitting` -- TODO confirm).
        use_splitting=False,
        max_length_from_file=False,
        recompute_vocabulary=True,
    ),
    # Parameters for the optimization algorithm.
    opt_params=Namespace(
        # Both entries are partially-applied constructors: only the settings
        # listed here are fixed, the remaining arguments are left unbound
        # (presumably supplied by the training-setup notebook -- TODO confirm).
        unlinked_optimizer=partial(torch.optim.NAdam, lr=0.0001),
        unlinked_scheduler=partial(
            torch.optim.lr_scheduler.ReduceLROnPlateau,
            mode='min',
            factor=0.8,
            patience=5,
            min_lr=10 ** (-6),
        ),
    ),
    # Parameters to reload the model.
    train_state_control=Namespace(
        load_from_backup=True,
        restore_optimizer=True,
    ),
    # Paths from root.
    paths=namespace_tools.Paths(
        path_dataset="data/french_english_dataset/fra.txt",
        path_language_info="models/language_info.pth",
        path_dataset_splitting="dataset_splitting",
        path_model_and_dependencies="models/sequence_translator_transformer_over_fitted_next.pth",
        root="../..",
    ),
)

In [None]:
# Wrap the plain Namespace in the project's NameSpaceAggregation helper, then
# hand this notebook's globals() to its `diffuse` method.
# NOTE(review): presumably `diffuse(globals())` publishes the sub-namespaces
# (simple_hp, paths, ...) as notebook-level names -- later cells read
# `simple_hp` and `paths` directly; confirm against dev.namespace_tools.
# NOTE(review): `notebook_run` is rebound to a wrapper of itself, so running
# this cell twice wraps the wrapper -- confirm that is harmless, or restart the
# kernel before re-running.
notebook_run = namespace_tools.NameSpaceAggregation(notebook_run)
notebook_run.diffuse(globals())

In [None]:
from ploomber_engine.ipython import PloomberClient
from pathlib import Path
from argparse import Namespace

from translation_machine.models import transformer_mod
from translation_machine import sentence_mod


# Initialize the ploomber-engine client on the training-setup notebook and
# fetch its namespace, seeded with the values returned by
# `notebook_run.diffuse()`.
# NOTE(review): presumably `get_namespace` executes training_setup.ipynb and
# returns its variables as a mapping -- confirm against the ploomber-engine docs.
client = PloomberClient.from_path(Path("./training_setup.ipynb"), cwd=Path("./"))
train_setup = client.get_namespace(notebook_run.diffuse())

# Expose every name produced by the setup notebook as a global of this notebook.
# A single `update` call is used instead of a manual loop so that no loop
# variables are left behind in the notebook's global namespace.
globals().update(train_setup)

In [None]:
# revert to train mode
# NOTE(review): `model` is not defined in this notebook's visible cells;
# presumably it is injected from the training-setup namespace -- confirm.
model.train()
# Bare last expression: the notebook shows the value of `model.training` as the
# cell output, which lets the reader check the mode switch.
model.training

In [None]:
# Build the trainer object whose `train_on_epoch()` / `validate_on_epoch()`
# methods drive the epoch loop further down.
# NOTE(review): the six positional arguments are not defined in the visible
# cells; presumably they come from the training-setup namespace -- confirm
# their order against ModelTrainer's signature.
from translation_machine import model_trainer_mod
model_trainer = model_trainer_mod.ModelTrainer(model,optimizer,train_data_loader,val_data_loader,baseline_loss,device)

In [None]:
# Stack the second element of every training batch and count how often each
# distinct token id occurs.
token_ids, counts = np.unique(
    np.vstack([batch[1] for batch in train_data_loader]),
    return_counts=True,
)

# Occurrence count per token id, leaving out id 0
# (presumably the padding id -- TODO confirm).
token_id_to_count = {
    token_id: count
    for token_id, count in zip(token_ids, counts)
    if token_id != 0
}
nb_tokens = sum(token_id_to_count.values())

# Share of each remaining token id among all counted tokens.
token_id_to_freq = {
    token_id: count / nb_tokens
    for token_id, count in token_id_to_count.items()
}

# Same frequencies, keyed by the entry of the French vocabulary's `itos_`
# table for each token id.
token_to_freq = {
    sentence_mod.FrenchSentence.vocab.itos_[token_id]: freq
    for token_id, freq in token_id_to_freq.items()
}

In [None]:
# Display the token frequencies ordered from the rarest to the most frequent token.
{token: freq for token, freq in sorted(token_to_freq.items(), key=lambda pair: pair[1])}

In [None]:
from tqdm import tqdm
import matplotlib.pyplot as plt

# Training loop: one training pass and one validation pass per epoch, learning
# rate scheduling on the mean validation loss, checkpointing whenever the
# validation loss improves, and optional early stopping.
# NOTE(review): `scheduler`, `optimizer`, `losses`, `metrics`, `model_inputs`
# and `best_loss_val_mean` are not defined in the visible cells; presumably they
# come from the training-setup namespace -- confirm.

# Scheduler epoch counter at the last improvement of the validation loss.
# The early-stopping test below measures the elapsed epochs with the same
# counter (`scheduler.last_epoch`) rather than with the loop index, so the two
# values stay comparable even when the scheduler state was restored from an
# earlier run and its counter does not start at zero.
best_epoch = scheduler.last_epoch

# Configurations of all runs that contributed to the checkpoint. The current
# run is appended exactly once, here, instead of once per saved epoch, so the
# stored history does not fill up with copies of the same run.
if "back_up" in globals():
    notebook_runs_history = back_up["notebook_runs"] + (notebook_run.state_dict(),)
else:
    notebook_runs_history = (notebook_run.state_dict(),)

for epoch in tqdm(range(simple_hp.nb_epochs)):
    print(f"training for epoch {epoch}")
    print(f"for epoch {epoch} learning rate is {optimizer.param_groups[0]['lr']}" )
    print("training_step")
    loss_train,nb_words_per_batch_train,metric_train = model_trainer.train_on_epoch()
    print("validation_step")
    loss_val,nb_words_per_batch_val,metric_val = model_trainer.validate_on_epoch()

    # Mean loss = summed per-batch losses divided by the summed per-batch word counts.
    sum_loss_train = torch.tensor(loss_train).sum()
    sum_loss_val = torch.tensor(loss_val).sum()
    mean_train_loss = sum_loss_train/sum(nb_words_per_batch_train)
    mean_val_loss = sum_loss_val/sum(nb_words_per_batch_val)

    # The scheduler is driven by the validation loss; this call also advances
    # `scheduler.last_epoch`.
    scheduler.step(mean_val_loss)

    print(f"for epoch {epoch} mean loss on train {mean_train_loss}")
    print(f"for epoch {epoch} mean loss on val {mean_val_loss}")

    losses["train"].append(mean_train_loss)
    losses["val"].append(mean_val_loss)
    metrics["train"].append(metric_train)
    metrics["val"].append(metric_val)

    if mean_val_loss < best_loss_val_mean:
        # New best validation loss: remember it and write a checkpoint.
        best_epoch = scheduler.last_epoch
        best_loss_val_mean = mean_val_loss

        model_training_state = {
            "model_params": model_trainer.model.state_dict(),
            "model_inputs": model_inputs,
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
        }
        results = {"losses": losses, "metrics": metrics}

        back_up = {
            "notebook_runs": notebook_runs_history,
            "results": results,
            "model_training_state": model_training_state,
        }
        torch.save(back_up,paths.path_model_and_dependencies)
        print(f"saving for epoch {epoch}")

        # Clear the current figure first so that the curves drawn at earlier
        # saves do not pile up on top of each other.
        plt.clf()
        plt.plot(losses["train"],"b*")
        plt.plot(losses["val"],"g*")
        plt.title("losses")
        plt.savefig("loss_curve")
    elif (scheduler.last_epoch - best_epoch > simple_hp.early_stop_thresh
          and epoch > simple_hp.warm_up_epochs):
        # No improvement for more than `early_stop_thresh` epochs and the
        # warm-up period of this run is over.
        print("Early stopped training at epoch %d" % epoch)
        break  # terminate the training loop

    # Drop the per-epoch outputs before the next iteration.
    del loss_train,nb_words_per_batch_train,metric_train
    del loss_val,nb_words_per_batch_val,metric_val


In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt
plt.plot(results["losses"]["train"],"b*")
plt.plot(results["losses"]["val"],"g*")
plt.title("losses")
plt.savefig(f'test.png', bbox_inches='tight')