In [1]:
import numpy as np,torch
from torch.utils.data import DataLoader
from pathlib import Path
from translation_machine import collate_fn_mod
from translation_machine import dataset_mod,sentence_mod


In [2]:
# Input files: the parallel French/English corpus, the trained model checkpoint,
# and the saved vocabulary / sentence-length info.
path_dataset = "../data/french_english_dataset/fra.txt"
path_model_and_dependencies = "../models/sequence_translator_transformer_over_fitted_adamw.pth"
path_language_info = "../models/language_info.pth"

# Batch size passed to the DataLoader.
batch_size = 10
# Number of leading sentence pairs of the corpus kept for this qualitative check.
limit_length = 10

# Fail fast, naming the offending path, if any input file is missing.
for required_path in (path_dataset, path_model_and_dependencies, path_language_info):
    assert Path(required_path).exists(), f"missing file: {required_path}"

In [3]:
# Vocabularies and maximum sentence lengths saved alongside the model.
language_info = torch.load("../models/language_info.pth")

vocab_french = language_info["french"]["vocab"]
vocab_english = language_info["english"]["vocab"]

# NOTE(review): the key name suggests these maxima were computed over the train+val splits — confirm.
max_length_french = language_info["french"]["max_sentence_train_val"]
max_length_english = language_info["english"]["max_sentence_train_val"]

# Reuse the configured path instead of repeating the literal, so the two cannot drift apart.
whole_dataset = dataset_mod.DatasetFromTxt(path_dataset)
# Keep only the first `limit_length` examples for a quick qualitative check.
dataset = torch.utils.data.Subset(whole_dataset, np.arange(limit_length))


In [4]:
# Load the checkpoint onto the CPU so it also opens on machines without the GPU it
# may have been saved from; the model is moved to its target device after its
# weights are restored.
back_up = torch.load(path_model_and_dependencies, map_location="cpu")

# Build the sentence-pair dataset from `whole_dataset` rather than from the current
# value of `dataset`, so re-running this cell does not wrap the dataset a second time.
dataset = dataset_mod.SentenceDataSet(
    torch.utils.data.Subset(whole_dataset, np.arange(limit_length)),
    sentence_mod.EnglishSentence,
    sentence_mod.FrenchSentence,
)

collate_fn = collate_fn_mod.get_collate_fn(max_length_english, max_length_french)

# NOTE(review): `dataloader` is not used by the visible cells, which iterate `dataset` directly.
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)

# Show which entries the checkpoint contains.
back_up.keys()

dict_keys(['model_params', 'model_inputs', 'optimizer', 'scheduler', 'losses', 'metrics'])

In [5]:
from translation_machine.models import transformer_mod

# Use the GPU when one is available; otherwise fall back to CPU so the notebook still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_inputs = back_up["model_inputs"]

# Rebuild the architecture from the saved constructor kwargs, then restore the trained weights.
model = transformer_mod.TransformerForSeq2Seq(**model_inputs)
model.load_state_dict(back_up["model_params"])

# Switch off training-only behaviour and move the model to the chosen device.
model = model.eval().to(device)
model.training

False

In [6]:
import translation_machine.translator_mod as translator_mod

# Wrap the restored model in the project's translation helper.
translator = translator_mod.Translator(model)

In [7]:
# Put the wrapped model in inference mode; `train(False)` is what `eval()` does.
translator.model = translator.model.train(False)

In [8]:
# Limit passed to the translator as `limit_sentence`.
# NOTE(review): presumably the maximum number of generated tokens — confirm in translator_mod.
max_translation_length = 10

# Inference only: disable autograd so no computation graph is built during decoding.
with torch.no_grad():
    for english_sentence, french_sentence in dataset:
        translation = translator(english_sentence, limit_sentence=max_translation_length)
        print(f"input sentence :  {english_sentence}")
        print(f"target sentence {french_sentence}")
        print(f"predicted sentence {translation}")
        print("\n")

  return torch._native_multi_head_attention(


input sentence :  Go .
target sentence <sos> Va ! <eos>
predicted sentence <sos> Qui ? <eos>


input sentence :  Hi .
target sentence <sos> Salut ! <eos>
predicted sentence <sos> Salut ! <eos>


input sentence :  Hi .
target sentence <sos> Salut . <eos>
predicted sentence <sos> Salut ! <eos>


input sentence :  Run !
target sentence <sos> Cours ! <eos>
predicted sentence <sos> ! <eos>


input sentence :  Run !
target sentence <sos> Courez ! <eos>
predicted sentence <sos> ! <eos>


input sentence :  Who ?
target sentence <sos> Qui ? <eos>
predicted sentence <sos> Qui ? <eos>


input sentence :  Wow !
target sentence <sos> Ça alors ! <eos>
predicted sentence <sos> ! <eos>


input sentence :  Fire !
target sentence <sos> Au feu ! <eos>
predicted sentence <sos> ! <eos>


input sentence :  Help !
target sentence <sos> À l' aide ! <eos>
predicted sentence <sos> Va ! <eos>


input sentence :  Jump .
target sentence <sos> Saute . <eos>
predicted sentence <sos> Salut ! <eos>


