In [1]:
import torch,inspect
import argparse
from pathlib import Path
from functools import partial
import numpy as np
from argparse import Namespace
from dev import namespace_tools
# Nested namespace argument containing all elements associated with the training setup

# Run configuration: one nested Namespace with a sub-namespace per group of
# settings, wrapped afterwards by namespace_tools.NameSpaceAggregation.
notebook_run = Namespace(
    # scalar hyper-parameters
    simple_hp = Namespace(
        batch_size = 32,
        d_model = 64,
        early_stop_thresh = np.inf,  # default to np.inf
        nb_epochs = 500,
        warm_up_epochs = 20,
    ),
    # parameters to limit the size of the dataset
    dset_truncation = Namespace(
        limit_length = 1,
        use_splitting = False,
        max_length_from_file = False,
    ),
    # parameters for the optimization algorithm, stored as factories (partials)
    opt_params = Namespace(
        optimizer = partial(torch.optim.NAdam, lr=0.001),
        # T_0 is bound by keyword. Bound positionally, the 15 would take the
        # scheduler's first positional parameter (the optimizer) as soon as the
        # factory is called as `scheduler(optimizer)`.
        scheduler = partial(
            torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
            T_0=15,
            T_mult=2,
            eta_min=10**(-6),
        ),
    ),
    # parameters to reload the model
    train_state_control = Namespace(
        load_from_backup = False,
        restore_optimizer = False,
    ),
    # paths from root
    paths = Namespace(
        path_dataset = "data/french_english_dataset/fra.txt",
        path_language_info = "models/language_info.pth",
        path_dataset_splitting = "dataset_splitting",
        path_model_and_dependencies = "models/sequence_translator_transformer_new.pth",
    ),
)

# Rebind the name to the aggregated form of the same configuration.
notebook_run = namespace_tools.NameSpaceAggregation(notebook_run)

In [2]:
# Class used as the sample input for the module_io serialization round trip below.
request_class = torch.optim.NAdam

In [25]:
from dev import module_io

In [None]:
# Round trip: serialize the class, then resolve it again from the serialized pair.
serialization = module_io.serialize(request_class)
# The serialized form unpacks into two items, named here class name and module name.
class_name,module_name = serialization
loaded_class = module_io.get_callable(class_name,module_name)
# Last expression of the cell: shows whether the reloaded object equals the original class.
loaded_class == request_class

True

In [24]:
import importlib
# Import the submodule from its dotted name; the bare name on the last line shows it as the cell output.
module = importlib.import_module('torch.optim.nadam')
module

<module 'torch.optim.nadam' from '/root/miniconda/lib/python3.9/site-packages/torch/optim/nadam.py'>

In [20]:
# Same serialization round trip, this time for a learning-rate scheduler class.
the_class = torch.optim.lr_scheduler.CosineAnnealingLR
serialization = module_io.serialize(the_class)
# NOTE(review): the NAdam round-trip cell resolves the class with module_io.get_callable,
# this one with module_io.get_module_from_name — confirm both helpers exist.
loaded_class = module_io.get_module_from_name(*serialization)
# Bare expression that is not the last line of the cell, so it is not displayed.
loaded_class,serialization
assert loaded_class == the_class

In [None]:
# NOTE(review): get_module and get_name_module are not defined or imported in any cell
# visible here — presumably earlier names of the module_io helpers; confirm or remove.
get_module(request_class.__name__,get_name_module(request_class))

In [None]:
# Shows whether the two run configurations compare equal.
# NOTE(review): notebook_run_new is not assigned in any cell visible here — confirm where it is built.
notebook_run_new == notebook_run

In [None]:
# Shows whether the optimizer entry of the state dict has exactly the keys 'func', 'args' and 'keywords'.
# NOTE(review): state_dict is not assigned in any cell visible here — confirm its origin.
set(state_dict["opt_params"]["optimizer"].keys()) == {'func','args','keywords'}

In [None]:
from ploomber_engine.ipython import PloomberClient
from ploomber import DAG
from pathlib import Path
from ploomber.products import File

# Initialize a PloomberClient from the training_setup notebook (path given relative to the working directory).
client = PloomberClient.from_path(Path("./training_setup.ipynb"))#,cwd=Path("../../"))
from argparse import Namespace

from translation_machine.models import transformer_mod
from translation_machine import sentence_mod

# Collect the configuration objects that already exist as notebook globals,
# keyed by their variable names.
initial_namespace_as_dict = dict(
    (key, globals()[key])
    for key in (
        "simple_hyp_params", "dset_truncation",
        "optimizer_creator", "scheduler_creator",
        "train_state_control", "paths",
    )
)

# Hand that mapping to the client, then publish every entry of the namespace
# it returns as a notebook global.
train_setup = client.get_namespace(initial_namespace_as_dict)
for key, val in train_setup.items():
    globals()[key] = val

In [None]:
# Development aid: prints the documentation of PloomberClient.from_path.
help(PloomberClient.from_path)

In [None]:
# Revert the model to train mode.
model.train()
# Last expression of the cell: shows the model's `training` attribute.
model.training

In [None]:
from translation_machine import model_trainer_mod

# Trainer built from the model, optimizer, scheduler, both data loaders and the baseline loss;
# the training loop further down calls its train_on_epoch / validate_on_epoch methods.
model_trainer = model_trainer_mod.ModelTrainer(model,optimizer,scheduler,train_data_loader,val_data_loader,baseline_loss)

In [None]:
from dataclasses import dataclass
import argparse

@dataclass
class NotebookRun:
    """Bundle of the configuration namespaces that describe one training run.

    Instances can be exported to / restored from a plain dict
    (``state_dict`` / ``load_state_dict``), compared field by field, and hashed.
    """
    simple_hyp_params : argparse.Namespace
    optimization_control : argparse.Namespace
    dataset_control : argparse.Namespace
    state_train_control : argparse.Namespace
    paths_from_training : argparse.Namespace

    # Field names in declaration order. Deliberately left unannotated so that
    # @dataclass does not turn this tuple into an additional field.
    _STATE_KEYS = ("simple_hyp_params", "optimization_control",
                   "dataset_control", "state_train_control",
                   "paths_from_training")

    @staticmethod
    def _freeze(value):
        """Return a hashable stand-in for ``value`` (helper for ``__hash__``).

        Namespaces and dicts become frozensets of (key, frozen value) pairs,
        lists and tuples become tuples, sets become frozensets. Any other
        hashable value is returned unchanged; any other unhashable value is
        replaced by its type name, which is coarse but never raises.
        """
        if isinstance(value, argparse.Namespace):
            value = vars(value)
        if isinstance(value, dict):
            # frozenset: dict equality ignores insertion order, so must the hash.
            return frozenset((key, NotebookRun._freeze(item))
                             for key, item in value.items())
        if isinstance(value, (list, tuple)):
            return tuple(NotebookRun._freeze(item) for item in value)
        if isinstance(value, (set, frozenset)):
            return frozenset(value)
        try:
            hash(value)
        except TypeError:
            return type(value).__name__
        return value

    def __hash__(self):
        """Hash derived from the current content of the five fields.

        The fields are mutable, so the hash changes if they are modified after
        the instance has been placed in a set or used as a dict key.
        """
        return hash(tuple(self._freeze(getattr(self, key))
                          for key in self._STATE_KEYS))

    def state_dict(self):
        """Return a new dict mapping each field name to the object it holds.

        The values are the stored objects themselves, not copies.
        """
        return {key: getattr(self, key) for key in self._STATE_KEYS}

    def load_state_dict(self,state_dict):
        """Overwrite instance attributes with the entries of ``state_dict``."""
        vars(self).update(state_dict)

    def __eq__(self,other):
        """Field-by-field comparison.

        Returns ``NotImplemented`` when ``other`` lacks one of the fields, so
        comparing with an unrelated object yields ``False`` instead of raising.
        The check is attribute-based rather than ``isinstance``-based so that
        instances created before the class was re-defined in the notebook still
        compare equal.
        """
        missing = object()
        for key in self._STATE_KEYS:
            theirs = getattr(other, key, missing)
            if theirs is missing:
                return NotImplemented
            if getattr(self, key) != theirs:
                return False
        return True
        
# Group the optimizer and scheduler objects into a single namespace.
opt_params = argparse.Namespace(
    optimizer=optimizer,
    scheduler=scheduler,
)

# Assemble the run description; the keywords spell out which existing object
# fills which NotebookRun field.
notebook_run = NotebookRun(
    simple_hyp_params=simple_hyp_params,
    optimization_control=opt_params,
    dataset_control=dset_truncation,
    state_train_control=train_state_control,
    paths_from_training=paths,
)


In [None]:
## import matplotlib.pyplot as plt,numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
# Training loop with best-checkpoint saving and optional early stopping.
#
# Relies on objects created in other cells: model_trainer, optimizer, scheduler,
# simple_hyp_params, losses, metrics, model_inputs, notebook_run and
# path_model_and_dependencies.
best_loss_val_mean = np.inf
# Scheduler epoch counter recorded at the best validation loss so far.
best_epoch = scheduler.last_epoch
# Loop index of the last improvement. Early stopping is measured against this
# value rather than against scheduler.last_epoch, because the scheduler counter
# is a separate count from `epoch`.
# NOTE(review): the two presumably diverge after restoring a scheduler from a backup — confirm.
last_improvement_epoch = 0
# Run descriptions that existed before this loop started. Read once, so that
# every checkpoint of this run stores exactly one entry for it instead of
# appending one more copy on each save.
previous_runs = back_up["notebook_runs"] if "back_up" in globals() else tuple()

for epoch in tqdm(range(simple_hyp_params.nb_epochs)):
    print(f"training for epoch {epoch}")
    print(f"for epoch {epoch} learning rate is {optimizer.param_groups[0]['lr']}" )
    print("training_step")
    loss_train,nb_words_per_batch_train,metric_train = model_trainer.train_on_epoch()
    print("validation_step")
    loss_val,nb_words_per_batch_val,metric_val = model_trainer.validate_on_epoch()

    # Mean loss = total loss divided by the total word count.
    # NOTE(review): assumes the trainer returns per-batch summed losses — confirm.
    sum_loss_train = torch.tensor(loss_train).sum()
    sum_loss_val = torch.tensor(loss_val).sum()
    mean_train_loss = sum_loss_train/sum(nb_words_per_batch_train)
    mean_val_loss = sum_loss_val/sum(nb_words_per_batch_val)

    print(f"for epoch {epoch} mean loss on train {mean_train_loss}")
    print(f"for epoch {epoch} mean loss on val {mean_val_loss}")

    losses["train"].append(mean_train_loss)
    losses["val"].append(mean_val_loss)
    metrics["train"].append(metric_train)
    metrics["val"].append(metric_val)

    if (mean_val_loss < best_loss_val_mean):
        # New best validation loss: remember it and write a checkpoint.
        best_epoch = model_trainer.scheduler.last_epoch
        last_improvement_epoch = epoch
        best_loss_val_mean = mean_val_loss

        model_training_state = {"model_params":model_trainer.model.state_dict(),
                                "model_inputs":model_inputs,
                                "optimizer":optimizer.state_dict(),
                                "scheduler":scheduler.state_dict(),
                                }
        results = {"losses":losses,
                   "metrics":metrics}

        new_back_up = dict()
        new_back_up["notebook_runs"] = previous_runs + tuple([notebook_run.state_dict()])
        new_back_up["results"] = results
        new_back_up["model_training_state"] = model_training_state

        back_up = new_back_up
        torch.save(back_up,path_model_and_dependencies)
        print(f"saving for epoch {epoch}")

        # Draw on a fresh figure and close it afterwards, so repeated
        # checkpoints do not keep stacking lines onto one implicit pyplot figure.
        fig, ax = plt.subplots()
        ax.plot(losses["train"],"b*")
        ax.plot(losses["val"],"g*")
        ax.set_title("losses")
        fig.savefig("loss_curve")
        plt.close(fig)
    elif epoch - last_improvement_epoch > simple_hyp_params.early_stop_thresh  and epoch > simple_hyp_params.warm_up_epochs:
        print("Early stopped training at epoch %d" % epoch)
        break  # terminate the training loop

    # Drop the per-epoch results before the next iteration.
    del loss_train,nb_words_per_batch_train,metric_train

    del loss_val,nb_words_per_batch_val,metric_val


In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt
# Plot the recorded losses (train: blue stars, validation: green stars) and save the figure.
for split, marker in (("train", "b*"), ("val", "g*")):
    plt.plot(losses[split], marker)
plt.title("losses")
plt.savefig("test.png", bbox_inches="tight")