### 1) loading the dataset

In [13]:
limit_length= 10
batch_size= 10
d_model = 32
load_from_backup = True
use_splitting = False
early_stopping_activated = False
half_period_cycle = 5
early_stop_thresh = 3*half_period_cycle
nb_epochs = 300

optimizer_option = "AdamW"
base_lr = 0.00001
max_lr = 0.001
momentum=0.9

restore_from_backup = ["model_params","scheduler","optimizer","losses","metrics"]

In [14]:
from translation_machine import dataset_mod,sentence_mod

import numpy as np
import torch

language_info = torch.load("../../models/language_info.pth")

vocab_french = language_info["french"]["vocab"]
vocab_english = language_info["english"]["vocab"]

max_length_french = language_info["french"]["max_sentence_train_val"]
max_length_english = language_info["english"]["max_sentence_train_val"]


whole_dataset = dataset_mod.DatasetFromTxt("../../data/french_english_dataset/fra.txt")

idxs_whole = np.arange(limit_length)
dataset = torch.utils.data.Subset(whole_dataset,idxs_whole)

In [15]:
dataset = list(dataset_mod.SentenceDataSet(dataset,sentence_type_src=sentence_mod.EnglishSentence,sentence_type_dst=sentence_mod.FrenchSentence))
len(dataset)

10

In [16]:
# Remark : the responsability to split the dataset is done outside of this notebook
if use_splitting:
    idxs_train = np.load("../../dataset_splitting/idx_train.npy")
    idxs_val = np.load("../../dataset_splitting/idx_val.npy")
    idxs_test = np.load("../../dataset_splitting/idx_test.npy")

    idxs_train,idxs_val,idxs_test = [[idx for idx in idxs if idx<len(whole_dataset)] for idxs in [idxs_train,idxs_val,idxs_test]]
    idxs_train = set(idxs_whole).intersection(set(idxs_train))
    idxs_val = set(idxs_whole).intersection(set(idxs_val))
    idxs_test = set(idxs_whole).intersection(set(idxs_test))
    
else:
    idxs_train = idxs_whole
    idxs_val = idxs_whole
    idxs_test = idxs_whole

train_dataset = torch.utils.data.Subset(dataset,idxs_train)
val_dataset = torch.utils.data.Subset(dataset,idxs_val)
test_dataset = torch.utils.data.Subset(dataset,idxs_test)

### 2) creation the vocabulary

In [17]:
from translation_machine import collate_fn_mod

import torch
from torch.utils.data import DataLoader
import numpy as np

collate_fn = collate_fn_mod.get_collate_fn(max_length_english,max_length_french)


train_data_loader = DataLoader(train_dataset,batch_size=batch_size,shuffle=True,collate_fn=collate_fn)
val_data_loader = DataLoader(val_dataset,batch_size=batch_size,shuffle=True,collate_fn=collate_fn)

In [18]:
len(vocab_french.vocab.itos_),len(vocab_english.vocab.itos_)

(5407, 4076)

In [19]:
from translation_machine.models import transformer_mod



model_inputs = {
    "d_model":d_model,
    "vocab_src":sentence_mod.EnglishSentence.vocab,
    "vocab_tgt":sentence_mod.FrenchSentence.vocab,
}

model = transformer_mod.TransformerForSeq2Seq(**model_inputs)


In [33]:
optimizer = torch.optim.NAdam(model.parameters(), lr=base_lr)

In [27]:
optimizer = torch.optim.NAdam(model.parameters(), lr=base_lr)
torch.optim.lr_scheduler.CyclicLR(optimizer,base_lr,100*base_lr)

ValueError: optimizer must support momentum with `cycle_momentum` option enabled

In [24]:
from torch import optim
from translation_machine import model_trainer

from torch import nn
baseline_loss = nn.CrossEntropyLoss(reduction="sum")


if optimizer_option == "AdamW":
    optimizer = torch.optim.NAdam(model.parameters(), lr=base_lr)
elif optimizer_option == "SGD":
    optimizer = torch.optim.SGD(model.parameters(), lr=base_lr, momentum=momentum)
else:
    raise ValueError
scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer,step_size_up=half_period_cycle , base_lr=base_lr, max_lr=max_lr)
model_trainer = model_trainer.ModelTrainer(model,optimizer,scheduler,train_data_loader,val_data_loader,baseline_loss)


ValueError: optimizer must support momentum with `cycle_momentum` option enabled

In [None]:
from tqdm import tqdm
model.train()
losses = {"train":[],"val":[]}
metrics = {"train":[],"val":[]}


In [None]:
from pathlib import Path
import torch
path_model_and_dependencies = "../../models/sequence_translator_transformer_over_fitted_adamw.pth"

if load_from_backup and Path(path_model_and_dependencies).exists():
    back_up = torch.load(path_model_and_dependencies)
    for el1,el2 in zip([model,scheduler,optimizer,scheduler,losses,metrics],
                      ["model_params","scheduler","optimizer","losses","metrics"]):
        if el2 in restore_from_backup:
            if el2 == "losses":
                losses = back_up[el2]
            elif el2 == "metrics":
                metrics = back_up[el2]
            else:
                el1.load_state_dict(back_up[el2])
                
    print("model loaded")

In [None]:
import matplotlib.pyplot as plt

best_loss_val_mean = np.inf
best_epoch = scheduler.last_epoch

for epoch in tqdm(range(nb_epochs)):
    #import time
    #start = time.time()
    print(f"optimizing for epoch {epoch}")
    print("training_step")
    loss_train,nb_words_per_batch_train,metric_train = model_trainer.train_on_epoch()
    
    print("validation_step")
    loss_val,nb_words_per_batch_val,metric_val = model_trainer.validate_on_epoch()

    loss_train = np.array([float(el) for el in loss_train])
    loss_val = np.array([float(el) for el in loss_val])
    train_weights = 1/sum(nb_words_per_batch_train)
    val_weights = 1/sum(nb_words_per_batch_val)
    
    
        
    losses["train"].append(np.sum(loss_train)/sum(nb_words_per_batch_train))
    losses["val"].append(np.sum(loss_val)/sum(nb_words_per_batch_val))
    metrics["train"].append(metric_train)
    metrics["val"].append(metric_val)
    
    current_loss_val_mean = np.mean(loss_val)
    
    if (current_loss_val_mean < best_loss_val_mean):
        best_epoch = scheduler.last_epoch
        best_loss_val_mean = current_loss_val_mean

        state_dict_extended = {"model_params":model_trainer.model.state_dict(),
                               "model_inputs":model_inputs,
                              "optimizer":optimizer.state_dict(),
                              "scheduler":scheduler.state_dict(),
                              "losses":losses,
                              "metrics":metrics
                              }
        
        torch.save(state_dict_extended,path_model_and_dependencies)
        print(f"saving for epoch {epoch}")

        plt.plot(losses["train"],"b*")
        plt.plot(losses["val"],"g*")
        plt.title("losses")
        plt.show()        
        #plt.figure()
        #plt.plot(metrics["train"],"b*")
        #plt.plot(metrics["val"],"g*")
        #plt.title("bleu score")
        #plt.show()
    elif epoch - best_epoch > early_stop_thresh  and early_stopping_activated:
        print("Early stopped training at epoch %d" % epoch)
        break  # terminate the training loop
    #stop = time.time()
    #print(stop-start)

    del loss_train,nb_words_per_batch_train,metric_train

    del loss_val,nb_words_per_batch_val,metric_val


In [None]:
import matplotlib.pyplot as plt
plt.plot(losses["train"],"b*")
plt.plot(losses["val"],"g*")
plt.title("losses")
plt.show()        