### 1) Loading the dataset

In [2]:
from translation_machine import dataset_mod, sentence_mod

import numpy as np
import torch

# Pre-computed vocabularies and maximum sentence lengths, one entry per language.
# NOTE(review): this file stores vocabulary objects, not only tensors; on torch>=2.6
# `torch.load` defaults to weights_only=True and may refuse it — confirm the torch version.
language_info = torch.load("../../models/language_info.pth")

french_info, english_info = language_info["french"], language_info["english"]
vocab_french, vocab_english = french_info["vocab"], english_info["vocab"]
max_length_french = french_info["max_sentence_train_val"]
max_length_english = english_info["max_sentence_train_val"]

# Sentence pairs with English as the source language and French as the target.
whole_dataset = list(
    dataset_mod.SentenceDataSet(
        dataset_mod.DatasetFromTxt("../../data/french_english_dataset/fra.txt"),
        sentence_type_src=sentence_mod.EnglishSentence,
        sentence_type_dst=sentence_mod.FrenchSentence,
    )
)

# Remark: the dataset split is computed outside of this notebook; only its indices are loaded here.
idxs_train = np.load("../../dataset_splitting/idx_train.npy")
idxs_val = np.load("../../dataset_splitting/idx_val.npy")
idxs_test = np.load("../../dataset_splitting/idx_test.npy")

train_dataset, val_dataset, test_dataset = (
    torch.utils.data.Subset(whole_dataset, split_idxs)
    for split_idxs in (idxs_train, idxs_val, idxs_test)
)

### 2) Creating the data loaders

In [3]:
from translation_machine import collate_fn_mod

import torch
from torch.utils.data import DataLoader
import numpy as np

# Batching function built from the maximum source (English) and target (French) lengths.
collate_fn = collate_fn_mod.get_collate_fn(max_length_english, max_length_french)

batch_size = 128

# Only the training loader is shuffled: evaluation does not need a random order, and a
# fixed order keeps validation/test results reproducible from one run to the next.
train_data_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
val_data_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)
test_data_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)


In [4]:
# Vocabulary sizes as a (French, English) tuple.
tuple(len(vocab.vocab.itos_) for vocab in (vocab_french, vocab_english))

(5407, 4076)

In [5]:
from translation_machine import model_bidirectionnal_mod

# Model hyper-parameters (English -> French).
bidirectional_encoder = True
en_embeddings_size = 128
fr_embeddings_size = 128

hidden_size_encoder = 256

nb_directions = 2 if bidirectional_encoder else 1

# Decoder hidden size = encoder hidden size x number of encoder directions.
# Derived from hidden_size_encoder (instead of repeating the literal 256) so that
# both sizes stay in sync when the encoder size is changed.
hidden_size_decoder = hidden_size_encoder * nb_directions

model_inputs = {
    "embeddings_src_size": en_embeddings_size,
    "embeddings_tgt_size": fr_embeddings_size,
    "hidden_size_encoder": hidden_size_encoder,
    "hidden_size_decoder": hidden_size_decoder,
    "vocab_src": vocab_english,
    "vocab_tgt": vocab_french,
    "length_src_sentence": max_length_english,
    "length_tgt_sentence": max_length_french,
    "bidirectional_encoder": bidirectional_encoder,
}

sequence_translator = model_bidirectionnal_mod.SequenceTranslator(**model_inputs)


ImportError: cannot import name 'model_mod' from 'translation_machine' (/root/nmt/src/translation_machine/__init__.py)

In [None]:
from torch import nn, optim

# Alias the module so the `model_trainer` instance created below does not shadow it;
# otherwise re-running this cell fails because the name no longer refers to the module.
from translation_machine import model_trainer as model_trainer_mod

# Summed (not averaged) cross-entropy; the training loop normalises by the number of words.
baseline_loss = nn.CrossEntropyLoss(reduction="sum")

# NOTE(review): lr=0.5 is far above NAdam's default (2e-3) — confirm this is intended.
optimizer = optim.NAdam(params=sequence_translator.parameters(), lr=0.5)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)

model_trainer = model_trainer_mod.ModelTrainer(
    sequence_translator,
    optimizer,
    scheduler,
    train_data_loader,
    val_data_loader,
    baseline_loss,
)


In [None]:
from tqdm import tqdm

# Switch the model to training mode before optimisation starts.
sequence_translator.train()

# Per-epoch history filled by the training loop: word-normalised losses and the
# returned metric (plotted as the BLEU score), for the train and validation sets.
losses_on_train, losses_on_val = [], []
metrics_on_train, metrics_on_val = [], []

In [None]:
from pathlib import Path
import torch

load_from_backup = True
# NOTE(review): the other artefacts of this notebook are read from "../../", while this
# checkpoint path goes up only one level — confirm the intended location.
path_model_and_dependencies = "../models/sequence_translator_bidirectionnal.pth"

if load_from_backup and Path(path_model_and_dependencies).exists():
    # NOTE(review): the checkpoint also stores `model_inputs` (vocabulary objects); on
    # torch>=2.6 `torch.load` defaults to weights_only=True and may reject it.
    back_up = torch.load(path_model_and_dependencies)
    sequence_translator.load_state_dict(back_up["model_params"])
    scheduler.load_state_dict(back_up["scheduler"])
    optimizer.load_state_dict(back_up["optimizer"])
    # Restore the loss history into the lists the training loop appends to and saves;
    # otherwise the next checkpoint would overwrite it with only the new epochs.
    losses_on_train = back_up["losses_with_weights"]["train"]
    losses_on_val = back_up["losses_with_weights"]["val"]
    # Kept under the previous names for any cell that still reads them.
    losses_train_and_weights, losses_val_and_weights = losses_on_train, losses_on_val
    metrics_on_train = back_up["metrics"]["train"]
    metrics_on_val = back_up["metrics"]["val"]
    print("model loaded")
elif load_from_backup:
    print(f"no backup found at {path_model_and_dependencies}, training from scratch")

In [None]:
import time
from pathlib import Path

import matplotlib.pyplot as plt

# Stop once the validation loss has not improved for more than this many epochs.
early_stop_thresh = 3
nb_epochs = 20


def plot_training_history(losses_train, losses_val, metrics_train, metrics_val):
    """Plot the per-epoch loss and BLEU histories (train in blue, validation in green)."""
    fig, (ax_loss, ax_bleu) = plt.subplots(1, 2, figsize=(12, 4))
    ax_loss.plot(losses_train, "b*", label="train")
    ax_loss.plot(losses_val, "g*", label="val")
    ax_loss.set(title="losses", xlabel="epoch", ylabel="loss per word")
    ax_loss.legend()
    ax_bleu.plot(metrics_train, "b*", label="train")
    ax_bleu.plot(metrics_val, "g*", label="val")
    ax_bleu.set(title="bleu score", xlabel="epoch")
    ax_bleu.legend()
    plt.show()


# Start from the best validation loss already in the history (non-empty after resuming
# from a checkpoint or re-running this cell), so a worse model never overwrites a better one.
best_loss_val_mean = min(losses_on_val) if losses_on_val else np.inf
# Loop index of the best epoch of this run; -1 stands for the starting point. It is
# compared with the loop counter only, never with scheduler.last_epoch, so early
# stopping also works after resuming.
best_epoch = -1
Path(path_model_and_dependencies).parent.mkdir(parents=True, exist_ok=True)

for epoch in tqdm(range(nb_epochs)):
    start = time.time()
    print(f"optimizing for epoch {epoch}")
    print("training_step")
    loss_train, nb_words_per_batch_train, metric_train = model_trainer.train_on_epoch()
    print("validation_step")
    loss_val, nb_words_per_batch_val, metric_val = model_trainer.validate_on_epoch()

    loss_train = np.array([float(el) for el in loss_train])
    loss_val = np.array([float(el) for el in loss_val])

    # Batch losses are summed cross-entropies: normalise by the number of words in the epoch.
    losses_on_train.append(np.sum(loss_train) / sum(nb_words_per_batch_train))
    losses_on_val.append(np.sum(loss_val) / sum(nb_words_per_batch_val))
    metrics_on_train.append(metric_train)
    metrics_on_val.append(metric_val)

    # Mean validation loss per word. It ranks epochs exactly like the mean of the summed
    # batch losses (same data every epoch) and is comparable with the stored history.
    current_loss_val_mean = losses_on_val[-1]

    if current_loss_val_mean < best_loss_val_mean:
        best_epoch = epoch
        best_loss_val_mean = current_loss_val_mean

        state_dict_extended = {
            "model_params": model_trainer.model.state_dict(),
            "model_inputs": model_inputs,
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "losses_with_weights": {"train": losses_on_train, "val": losses_on_val},
            "metrics": {"train": metrics_on_train, "val": metrics_on_val},
        }
        torch.save(state_dict_extended, path_model_and_dependencies)
        print(f"saving for epoch {epoch}")

        plot_training_history(losses_on_train, losses_on_val, metrics_on_train, metrics_on_val)
    elif epoch - best_epoch > early_stop_thresh:
        print("Early stopped training at epoch %d" % epoch)
        break  # terminate the training loop

    print(time.time() - start)