### I) define hyper_parameters, datasets, and IO operations

In [13]:
# Convention: every new input parameter added to the notebook through papermill
# gets a default value equal to the one that would be used in
# order to obtain the same results as the scripts.

# Two types of inputs:
# inputs that influence the training (e.g. hyper-parameters, dataset selection and splitting)
# inputs that control the state of training (reset it or load it from a backup)
import torch

import argparse
from pathlib import Path



from dataclasses import dataclass
import argparse
@dataclass
class NotebookRun:
    """Bundle of the parameter groups that identify one run of this notebook.

    Each field is a namespace of parameters. The hash is computed from the
    (name, value) pairs of every group, in field order, so two runs built
    from the same parameter values hash to the same value.
    """
    simple_hyp_params : argparse.Namespace
    optimization_control : argparse.Namespace
    dataset_control : argparse.Namespace
    state_train_control : argparse.Namespace
    paths_from_training : argparse.Namespace

    @staticmethod
    def _freeze(value):
        """Return a hashable equivalent of ``value``.

        Namespaces and dicts become tuples of (key, frozen value) pairs,
        lists/tuples become tuples, sets become frozensets; any other value
        is returned unchanged.
        """
        if isinstance(value, argparse.Namespace):
            value = vars(value)
        if isinstance(value, dict):
            return tuple((key, NotebookRun._freeze(item)) for key, item in value.items())
        if isinstance(value, (list, tuple)):
            return tuple(NotebookRun._freeze(item) for item in value)
        if isinstance(value, (set, frozenset)):
            return frozenset(NotebookRun._freeze(item) for item in value)
        return value

    def __hash__(self):
        """Hash the run from the contents of its five parameter groups.

        Raises:
            TypeError: if a parameter value is still unhashable after freezing.
        """
        # Read the dataclass fields themselves: the notebook-level variable
        # names (opt_params, dset_truncation, ...) are not attributes of self.
        groups = (
            self.simple_hyp_params,
            self.optimization_control,
            self.dataset_control,
            self.state_train_control,
            self.paths_from_training,
        )
        return hash(tuple(self._freeze(vars(group)) for group in groups))
    
# Gather the notebook's parameter groups into one run descriptor,
# naming each field explicitly so the mapping is visible at the call site.
notebook_run = NotebookRun(
    simple_hyp_params=simple_hyp_params,
    optimization_control=opt_params,
    dataset_control=dset_truncation,
    state_train_control=train_state_control,
    paths_from_training=paths,
)

In [14]:
from translation_machine.models import transformer_mod
from translation_machine import sentence_mod

# Constructor arguments of the seq2seq transformer: embedding width plus
# the source (English) and target (French) vocabularies.
model_inputs = dict(
    d_model=simple_hyp_params.d_model,
    vocab_src=sentence_mod.EnglishSentence.vocab,
    vocab_tgt=sentence_mod.FrenchSentence.vocab,
)

model = transformer_mod.TransformerForSeq2Seq(**model_inputs)


In [15]:
# Early-stop threshold: the half-cycle length of the cyclic learning rate
# multiplied by the per-half-cycle factor. This must be a plain number, so the
# assignment must not end with a trailing comma (which would build a 1-tuple).
simple_hyp_params.early_stop_thresh = (
    opt_params.half_period_cycle * simple_hyp_params.early_stop_steps_per_half_clr_cycle
)

optimizer = torch.optim.NAdam(model.parameters(), lr=opt_params.base_lr)

# Cyclic learning rate between base_lr and max_lr; momentum cycling is disabled.
scheduler = torch.optim.lr_scheduler.CyclicLR(
    optimizer,
    step_size_up=opt_params.half_period_cycle,
    base_lr=opt_params.base_lr,
    max_lr=opt_params.max_lr,
    cycle_momentum=False,
)


# II) load the vocabulary

In [4]:
from translation_machine import dataset_mod,sentence_mod

import torch,numpy as np

# Load the serialized language information and pull out both vocabularies.
language_info = torch.load(paths.path_language_info)

vocab_french, vocab_english = (
    language_info[language]["vocab"] for language in ("french", "english")
)

# Display the two vocabulary sizes.
len(vocab_french), len(vocab_english)

(722, 694)

# III) Load the dataset

In [5]:
# The number of samples used is capped by the limit stored with the language
# info; when no limit was requested, the stored limit is used as is.
requested_limit = dset_truncation.limit_length
dset_truncation.limit_length = (
    language_info["limit_length"]
    if requested_limit is None
    else min(language_info["limit_length"], requested_limit)
)

whole_dataset = dataset_mod.DatasetFromTxt(paths.path_dataset)

# Keep only the first `limit_length` samples of the text dataset.
idxs_whole = np.arange(dset_truncation.limit_length)

# Wrap the truncated dataset into English -> French sentence pairs and
# materialize it as a list.
dataset = list(
    dataset_mod.SentenceDataSet(
        torch.utils.data.Subset(whole_dataset, idxs_whole),
        sentence_type_src=sentence_mod.EnglishSentence,
        sentence_type_dst=sentence_mod.FrenchSentence,
    )
)
len(dataset)

20000

### IV) Preparation of notebook params for serialization (to associate the run with its training results)

In [6]:
# Maximum sentence length per language, used later to build the collate function.
max_length_from_file = False
if not max_length_from_file:
    # Preferred: measure the lengths on the dataset actually in use.
    # Element 0 of each pair is the English sentence, element 1 the French one.
    import itertools
    max_length_english = max(len(pair[0]) for pair in dataset)
    max_length_french = max(len(pair[1]) for pair in dataset)
else:
    # Alternative: reuse the lengths stored with the language information.
    max_length_french = language_info["french"]["max_sentence_train_val"]
    max_length_english = language_info["english"]["max_sentence_train_val"]

max_length_english, max_length_french

(7, 20)

### V) dataset splitting

In [7]:
# Remark: splitting the dataset is the responsibility of code outside this notebook
from pathlib import Path

def _restrict_to_truncated_dataset(raw_idxs):
    """Keep the indices that are valid for `whole_dataset` and present in `idxs_whole`."""
    in_range = [idx for idx in raw_idxs if idx < len(whole_dataset)]
    return list(set(idxs_whole).intersection(set(in_range)))


if not dset_truncation.use_splitting:
    # No splitting requested: train, validation and test all use every sample.
    idxs_train = idxs_val = idxs_test = idxs_whole
else:
    # Load the precomputed split indices from disk.
    path_dataset_splitting = paths.path_dataset_splitting
    split_dir = Path(path_dataset_splitting)
    path_idxs_train, path_idxs_val, path_idxs_test = (
        str(split_dir.joinpath(file_name))
        for file_name in ("idxs_train.npy", "idxs_val.npy", "idxs_test.npy")
    )

    idxs_train = _restrict_to_truncated_dataset(np.load(path_idxs_train))
    idxs_val = _restrict_to_truncated_dataset(np.load(path_idxs_val))
    idxs_test = _restrict_to_truncated_dataset(np.load(path_idxs_test))

train_dataset, val_dataset, test_dataset = (
    torch.utils.data.Subset(dataset, split_idxs)
    for split_idxs in (idxs_train, idxs_val, idxs_test)
)
len(train_dataset), len(val_dataset), len(test_dataset)

(15556, 2222, 2222)

### VI) dataloader construction

In [8]:
from translation_machine import collate_fn_mod
from torch.utils.data import DataLoader

import torch
import numpy as np

# Collate function built from the maximum English / French sentence lengths.
collate_fn = collate_fn_mod.get_collate_fn(max_length_english, max_length_french)

# Both loaders share the same batch size, shuffling and collate function.
loader_options = dict(
    batch_size=simple_hyp_params.batch_size,
    shuffle=True,
    collate_fn=collate_fn,
)
train_data_loader = DataLoader(train_dataset, **loader_options)
val_data_loader = DataLoader(val_dataset, **loader_options)

In [9]:
# Number of index-to-token entries in the French and English vocabularies.
tuple(len(vocabulary.vocab.itos_) for vocabulary in (vocab_french, vocab_english))

(722, 694)

In [10]:
from torch import optim,nn

from translation_machine import model_trainer

# Cross-entropy loss with reduction="sum": the per-element losses are summed
# rather than averaged.
baseline_loss = nn.CrossEntropyLoss(reduction="sum")


# Earlier scheduler construction, kept for reference only; the `scheduler`
# passed to the trainer below is the one built in a previous cell.
#scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,step_size_up=half_period_cycle , base_lr=base_lr, max_lr=max_lr)
# NOTE: this rebinds the name `model_trainer` from the imported module to a
# ModelTrainer instance; from here on the module is only reachable by re-importing it.
model_trainer = model_trainer.ModelTrainer(model,optimizer,scheduler,train_data_loader,val_data_loader,baseline_loss)


In [11]:
from tqdm import tqdm
# Switch the model to training mode and start with empty per-split histories.
model.train()
losses = {split: [] for split in ("train", "val")}
metrics = {split: [] for split in ("train", "val")}


In [12]:
from pathlib import Path
import torch

# Optionally resume from a backup file: each entry named in
# `train_state_control.restore_from_backup` is restored from the backup.
path_model_and_dependencies = paths.path_model_and_dependencies
if train_state_control.load_from_backup and Path(path_model_and_dependencies).exists():
    # NOTE(review): torch.load unpickles the file, so only load backups from a trusted source.
    back_up = torch.load(path_model_and_dependencies)

    # Objects restored through load_state_dict, keyed by their entry in the backup.
    # An explicit mapping keeps every key paired with the right object.
    stateful_objects = {
        "model_params": model,
        "scheduler": scheduler,
        "optimizer": optimizer,
    }
    for backup_key, target in stateful_objects.items():
        if backup_key in train_state_control.restore_from_backup:
            target.load_state_dict(back_up[backup_key])

    # Training histories are plain containers: rebind them to the stored values.
    if "losses" in train_state_control.restore_from_backup:
        losses = back_up["losses"]
    if "metrics" in train_state_control.restore_from_backup:
        metrics = back_up["metrics"]

    print("model loaded")