### I) Define hyper-parameters, datasets, and I/O settings

In [1]:
# Notebook parameters (the cell papermill overrides).
# Convention: every input parameter progressively added to this notebook
# (injected through papermill) has a default value equal to the one that
# would be used to obtain the same results as the original scripts.
#
# Two kinds of inputs:
#   - inputs that influence the training (e.g. hyper-parameters, dataset and its splitting)
#   - inputs that control the state of the training (reset it or resume from a backup)
#
# NOTE: papermill injects overridden values in a new cell placed *after* the
# parameters cell, so derived values below (e.g. early_stop_thresh) are computed
# from the defaults, not from injected values. Override them explicitly if needed.

batch_size = 64
d_model = 512
early_stopping_activated = False
half_period_cycle = 5  # used as CyclicLR step_size_up (half of one learning-rate cycle)
early_stop_thresh = 3 * half_period_cycle
nb_epochs = 300


path_dataset = "../../data/french_english_dataset/fra.txt"
limit_length = None  # None: use the limit stored in the language info file
use_splitting = True
path_language_info = "../../models/language_info.pth"
path_dataset_splitting = "../../dataset_splitting"
max_length_from_file = False  # False: compute max sentence lengths from the loaded dataset

optimizer_option = "AdamW"  # e.g. "AdamW" or "SGD"

base_lr = 10**(-6)
max_lr = 0.0005
momentum = 0.9  # only used by the SGD optimizer

load_from_backup = True
# parts of the backup to restore when load_from_backup is True
restore_from_backup = ("model_params", "scheduler", "optimizer", "losses", "metrics")



### II) load the vocabulary

In [2]:
import torch

# language_info holds one entry per language (each with a "vocab") and a global
# "limit_length" used to cap how many dataset rows are loaded.
# NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which may
# reject pickled vocabulary objects; pass weights_only=False for this trusted file if so.
language_info = torch.load(path_language_info)
vocab_french, vocab_english = (language_info[lang]["vocab"] for lang in ("french", "english"))

# The stored limit is an upper bound: a user-provided limit can only shrink it.
stored_limit_length = language_info["limit_length"]
limit_length = stored_limit_length if limit_length is None else min(stored_limit_length, limit_length)

len(vocab_french), len(vocab_english)

(722, 694)

### III) Prepare notebook parameters for serialization (to associate this run with its training results)

In [3]:
import hashlib
from dataclasses import dataclass

# Group the notebook parameters (defined in the parameters cell) so the run can
# be serialized alongside its training results.
# At notebook (module) level, locals() is the global namespace.
_local_variable = locals()
simple_hyper_parameters = {key: _local_variable[key] for key in ["batch_size", "d_model",
                                                                 "early_stopping_activated", "half_period_cycle",
                                                                 "early_stop_thresh", "nb_epochs"]}

# path_dataset_splitting and max_length_from_file also change the data the model
# is trained on, so they are recorded with the other dataset inputs.
dataset_control = {key: _local_variable[key] for key in ["path_dataset", "path_language_info",
                                                         "limit_length", "use_splitting",
                                                         "path_dataset_splitting", "max_length_from_file"]}

optimization_control = {key: _local_variable[key] for key in ["optimizer_option", "base_lr",
                                                              "max_lr", "momentum"]}

state_train_control = {key: _local_variable[key] for key in ["load_from_backup", "restore_from_backup"]}


@dataclass
class NotebookRun:
    """Parameters of one notebook run, grouped by role.

    Each field maps parameter names to their values. The hash is derived from
    the content only: it does not depend on dict insertion order (consistent
    with the dataclass ``__eq__``). It is also stable across Python processes,
    unlike the built-in ``hash`` of strings, which is salted per process. This
    lets the hash identify a run in saved results.

    Assumes parameter values have a deterministic ``repr`` (str, int, float,
    bool, None, tuples of those). Do not mutate the dicts after hashing: the
    hash would change.
    """
    simple_hyper_parameters: dict
    optimization_control: dict
    dataset_control: dict
    state_train_control: dict

    def _canonical_repr(self) -> str:
        """Text representation of all fields, independent of dict insertion order."""
        groups = (self.simple_hyper_parameters, self.optimization_control,
                  self.dataset_control, self.state_train_control)
        # keys are unique strings, so sorting never has to compare values
        return repr(tuple(tuple(sorted(group.items())) for group in groups))

    def __hash__(self):
        digest = hashlib.sha256(self._canonical_repr().encode("utf-8")).digest()
        # signed 64-bit integer derived from the digest (Python reduces it further if needed)
        return int.from_bytes(digest[:8], "big", signed=True)


notebook_run = NotebookRun(simple_hyper_parameters, optimization_control, dataset_control, state_train_control)

In [4]:
from translation_machine import dataset_mod, sentence_mod

import numpy as np
import torch


# Load the raw sentence pairs; optionally keep only the first `limit_length` rows.
whole_dataset = dataset_mod.DatasetFromTxt(path_dataset)
if limit_length is not None:
    # clamp to the file size: asking for more rows than exist would presumably
    # fail (or misbehave) while the subset is materialized below
    idxs_whole = np.arange(min(limit_length, len(whole_dataset)))
    dataset = torch.utils.data.Subset(whole_dataset, idxs_whole)
else:
    idxs_whole = np.arange(len(whole_dataset))
    dataset = whole_dataset


# Wrap the pairs as (English source, French target) sentence objects and
# materialize them once, so later subsets index a plain list.
dataset = list(dataset_mod.SentenceDataSet(dataset,
                                           sentence_type_src=sentence_mod.EnglishSentence,
                                           sentence_type_dst=sentence_mod.FrenchSentence))
len(dataset)

20000

In [5]:
# Maximum source/target sentence lengths, passed to the collate function below.
# `max_length_from_file` comes from the parameters cell; it must NOT be
# reassigned here, otherwise a value injected by papermill would be ignored.
if max_length_from_file:
    max_length_french = language_info["french"]["max_sentence_train_val"]
    max_length_english = language_info["english"]["max_sentence_train_val"]
else:  # get max length from current dataset, which is preferred
    # each element is a (source English sentence, target French sentence) pair
    max_length_english = max(len(pair[0]) for pair in dataset)
    max_length_french = max(len(pair[1]) for pair in dataset)

max_length_english, max_length_french

(7, 20)

### IV) Dataset splitting (train / validation / test)

In [6]:
# Remark: splitting the dataset is the responsibility of code outside this
# notebook; here we only load the precomputed index files and restrict them to
# the rows actually loaded.
from pathlib import Path

if use_splitting:
    path_idxs_train = str(Path(path_dataset_splitting).joinpath("idxs_train.npy"))
    path_idxs_val = str(Path(path_dataset_splitting).joinpath("idxs_val.npy"))
    path_idxs_test = str(Path(path_dataset_splitting).joinpath("idxs_test.npy"))

    # Indices usable as positions in `dataset`: idxs_whole starts at 0, so row i
    # of the whole file is element i of `dataset`.
    nb_rows_whole = len(whole_dataset)
    usable_idxs = {int(idx) for idx in idxs_whole if idx < nb_rows_whole}

    # sorted() gives a deterministic order, independent of set iteration order
    idxs_train, idxs_val, idxs_test = [
        sorted(usable_idxs.intersection(np.load(path_idxs).tolist()))
        for path_idxs in (path_idxs_train, path_idxs_val, path_idxs_test)
    ]

    # the splits must not share samples, otherwise validation/test leak into training
    assert not set(idxs_train) & set(idxs_val), "train and validation splits overlap"
    assert not set(idxs_train) & set(idxs_test), "train and test splits overlap"
    assert not set(idxs_val) & set(idxs_test), "validation and test splits overlap"

else:
    # no splitting: every subset is the whole loaded dataset
    idxs_train = idxs_whole
    idxs_val = idxs_whole
    idxs_test = idxs_whole

train_dataset = torch.utils.data.Subset(dataset, idxs_train)
val_dataset = torch.utils.data.Subset(dataset, idxs_val)
test_dataset = torch.utils.data.Subset(dataset, idxs_test)
len(train_dataset), len(val_dataset), len(test_dataset)

(15556, 2222, 2222)

### V) Dataloader construction

In [7]:
from translation_machine import collate_fn_mod
from torch.utils.data import DataLoader

import torch
import numpy as np

# collate function built from the dataset-wide maximum sentence lengths
collate_fn = collate_fn_mod.get_collate_fn(max_length_english, max_length_french)

train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
# validation needs no shuffling: a fixed order keeps validation passes reproducible
val_data_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

In [8]:
# sanity check: size of the index-to-string tables (French, English)
tuple(len(vocab.vocab.itos_) for vocab in (vocab_french, vocab_english))

(722, 694)

In [9]:
from translation_machine.models import transformer_mod

# Keyword arguments of the seq2seq transformer (English source -> French target).
# NOTE(review): these are the class-level vocabularies of the sentence types,
# not vocab_english/vocab_french loaded above; presumably the same objects, confirm.
model_inputs = dict(
    d_model=d_model,
    vocab_src=sentence_mod.EnglishSentence.vocab,
    vocab_tgt=sentence_mod.FrenchSentence.vocab,
)

model = transformer_mod.TransformerForSeq2Seq(**model_inputs)


In [10]:
from torch import optim, nn

# imported under an alias so the trainer instance below can keep the
# `model_trainer` name without shadowing the module
from translation_machine import model_trainer as model_trainer_mod

# cross-entropy summed (not averaged) over the batch
baseline_loss = nn.CrossEntropyLoss(reduction="sum")


# NOTE: earlier versions of this cell built NAdam when "AdamW" was requested;
# use optimizer_option="NAdam" to reproduce those runs or resume their backups.
if optimizer_option == "AdamW":
    optimizer = torch.optim.AdamW(model.parameters(), lr=base_lr)
elif optimizer_option == "NAdam":
    optimizer = torch.optim.NAdam(model.parameters(), lr=base_lr)
elif optimizer_option == "SGD":
    optimizer = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=momentum)
else:
    raise ValueError(
        f"unknown optimizer_option {optimizer_option!r}: expected 'AdamW', 'NAdam' or 'SGD'"
    )

# Cyclic learning rate between base_lr and max_lr; it rises during
# half_period_cycle scheduler steps. Only the learning rate is cycled
# (cycle_momentum=False).
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, step_size_up=half_period_cycle,
                                              base_lr=base_lr, max_lr=max_lr, cycle_momentum=False)
model_trainer = model_trainer_mod.ModelTrainer(model, optimizer, scheduler,
                                               train_data_loader, val_data_loader, baseline_loss)


In [11]:
from tqdm import tqdm

# put the model in training mode and start empty loss / metric histories per split
model.train()
losses = {split: [] for split in ("train", "val")}
metrics = {split: [] for split in ("train", "val")}


In [14]:
from pathlib import Path
import torch
path_model_and_dependencies = "../../models/sequence_translator_transformer.pth"

if load_from_backup and Path(path_model_and_dependencies).exists():
    # Backup entries restored in place through `load_state_dict`.
    stateful_entries = {"model_params": model, "scheduler": scheduler, "optimizer": optimizer}
    # All known entries, in restoration order (same order as before).
    known_entries = [*stateful_entries, "losses", "metrics"]

    # Fail fast on typos instead of silently skipping part of the restore.
    unknown_entries = set(restore_from_backup) - set(known_entries)
    if unknown_entries:
        raise ValueError(
            f"unknown entries in restore_from_backup: {sorted(unknown_entries, key=str)}; "
            f"expected a subset of {known_entries}"
        )

    # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True; if the
    # loss/metric histories hold non-plain objects, loading may need weights_only=False.
    back_up = torch.load(path_model_and_dependencies)
    for entry in known_entries:
        if entry not in restore_from_backup:
            continue
        if entry == "losses":
            losses = back_up[entry]
        elif entry == "metrics":
            metrics = back_up[entry]
        else:
            stateful_entries[entry].load_state_dict(back_up[entry])

    print("model loaded")