In [1]:

from data.data_loader import load_data, collate_fn, data_process_pipeline
from functools import partial
from torch.utils.data import DataLoader
import torch
from trainer import Trainer
from model.transformer_encoder import Encoder
from model.transformer_decoder import Decoder
from model.transformer import Transformer
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence


In [2]:
device = "cuda" if torch.cuda.is_available() else "cpu"

# Reproducibility: seed torch's global RNG (used e.g. for parameter init and
# DataLoader shuffling) so Restart & Run All gives repeatable results.
SEED = 42
torch.manual_seed(SEED)

# hyperparameters
num_epochs = 2
learning_rate = 0.001
batch_size = 512
encoder_embedding_size = 512
decoder_embedding_size = 512
# NOTE(review): hidden_size, encoder/decoder_n_layers and encoder/decoder_dropout
# are not passed to Encoder/Decoder/Transformer in this notebook; they look like
# leftovers from an RNN seq2seq setup. Confirm nothing else reads them before removing.
hidden_size = 1024
encoder_n_layers = 2
decoder_n_layers = 2
encoder_dropout = 0.5
decoder_dropout = 0.5
teacher_forcing_ratio = 0.5
num_heads = 8
forward_expansion = 4
num_encoders = 6  # the model cell also uses this as the decoder stack depth

# Standard multi-head attention needs the embedding size to split evenly across
# heads; fail fast here rather than deep inside the model
# (assumes Encoder/Decoder follow that convention).
assert encoder_embedding_size % num_heads == 0, "encoder_embedding_size must be divisible by num_heads"
assert decoder_embedding_size % num_heads == 0, "decoder_embedding_size must be divisible by num_heads"


In [3]:
raw_train_data, raw_val_data = load_data(data_path="fra.txt", train_percent=0.8)
train_data, eng_vocab, fra_vocab = data_process_pipeline(raw_train_data)
# The validation split is processed with the vocabularies built from the training
# split, so both splits use the same token ids.
val_data, _, _ = data_process_pipeline(
    raw_val_data, eng_vocab=eng_vocab, fra_vocab=fra_vocab
)

# One collate function shared by both loaders. It binds each vocab's "<pad>"
# index as the source/target padding value for collate_fn.
pad_collate = partial(
    collate_fn, src_pad_val=eng_vocab["<pad>"], tgt_pad_val=fra_vocab["<pad>"]
)

# Shuffle training batches each epoch so the model does not see a fixed order.
# Keep validation order fixed so evaluation is comparable across epochs.
train_loader = DataLoader(
    train_data,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=pad_collate,
)
val_loader = DataLoader(
    val_data,
    batch_size=batch_size,
    collate_fn=pad_collate,
)

In [4]:
# Encoder and decoder stacks; the decoder reuses num_encoders as its depth.
# NOTE(review): batch_size is passed to both modules, yet inference below runs
# with batch 1 — confirm it is not used to hard-wire tensor shapes.
encoder = Encoder(
    embed_size=encoder_embedding_size,
    num_heads=num_heads,
    batch_size=batch_size,
    forward_expansion=forward_expansion,
    num_encoders=num_encoders,
    vocab_size=len(eng_vocab),
)

decoder = Decoder(
    embed_size=decoder_embedding_size,
    batch_size=batch_size,
    num_heads=num_heads,
    forward_expansion=forward_expansion,
    num_decoders=num_encoders,
    output_size=len(fra_vocab),
    teacher_force_ratio=teacher_forcing_ratio,
    vocab_size=len(fra_vocab),
)

# Move the weights to `device` explicitly, and do it before the optimizer is built
# from transformer.parameters(). transcribe() sends its inputs to `device`, so the
# weights must live there too.
# NOTE(review): this is a harmless no-op if Transformer already relocates itself
# via its `device` argument — confirm and keep only one.
transformer = Transformer(
    encoder=encoder,
    decoder=decoder,
    src_pad_idx=eng_vocab["<pad>"],
    tgt_pad_idx=fra_vocab["<pad>"],
    device=device,
).to(device)


In [5]:
# Loss: target positions equal to the French "<pad>" index are excluded via ignore_index.
criterion = nn.CrossEntropyLoss(ignore_index=fra_vocab["<pad>"])
# Optimizer: plain Adam over every model parameter at the configured learning rate.
optimizer = torch.optim.Adam(params=transformer.parameters(), lr=learning_rate)

In [6]:
def tokenizer(text):
    """Wrap ``text`` in ``<sos>``/``<eos>`` markers and split it on single spaces.

    Empty pieces (from leading, trailing or repeated spaces) are dropped. Only
    the space character is a separator; tabs and newlines stay inside tokens.

    Returns:
        list[str]: ``["<sos>", *words, "<eos>"]``.
    """
    wrapped = f"<sos> {text} <eos>"
    # filter(None, ...) keeps only truthy, i.e. non-empty, pieces.
    return list(filter(None, wrapped.split(" ")))

In [7]:
def predict_pipeline(txt):
    """Turn raw English text into a padded tensor of source token ids.

    Args:
        txt: a single sentence (str) or an iterable of sentences.

    Returns:
        LongTensor padded with ``eng_vocab['<pad>']``. Because ``pad_sequence``
        is called without ``batch_first``, the layout is (max_len, num_sentences).
    """
    sentences = [txt] if isinstance(txt, str) else txt
    sequences = []
    for sentence in sentences:
        token_ids = eng_vocab.forward(tokenizer(sentence))
        sequences.append(torch.LongTensor(token_ids))
    return pad_sequence(sequences, padding_value=eng_vocab['<pad>'])

In [8]:
# Bundle model, loss, optimizer and the inference text pipeline into the project Trainer.
trainer = Trainer(
    model=transformer,
    num_epochs=num_epochs,
    batch_size=batch_size,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    print_stats=True,
    tgt_vocab=fra_vocab,
    text_pipeline=predict_pipeline,
)

In [9]:
def transcribe(model,inputs,pipeline,max_tokens,start_token,tgt_vocab,end_token,device):
    """Greedily decode one translation per source sentence.

    Args:
        model: module exposing ``src_mask(src)``, ``encoder(src, mask)`` and
            ``decoder(tgt, encoder_states, mask)``.
        inputs: a sentence (str) or list of sentences, handed to ``pipeline``.
        pipeline: callable returning a LongTensor of source ids with the batch
            on dim 1 (``predict_pipeline``'s ``pad_sequence`` layout).
        max_tokens: maximum number of tokens generated per sentence.
        start_token: target-vocab index used to seed each decode.
        tgt_vocab: target vocab providing ``lookup_token(index) -> str``.
        end_token: target-vocab index that stops decoding. It is included in
            the returned string.
        device: device that the source and decoder inputs are moved to.

    Returns:
        list[str]: one space-joined token string per input sentence.
    """
    # no_grad() only disables autograd; eval() is what turns off dropout-style
    # training behaviour. The caller's mode is restored afterwards.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            src = pipeline(inputs).to(device)
            src_mask = model.src_mask(src)
            encoder_states = model.encoder(src, src_mask)

            translations = []
            for i in range(src.shape[1]):
                # The decoder input y has batch size 1, so decode sentence i against
                # its own slice of the encoder output only.
                # NOTE(review): assumes encoder_states is (seq, batch, embed), matching
                # the printed torch.Size([27, 1, 512]). Padding positions of shorter
                # sentences are still visible because the decoder mask is None.
                memory = encoder_states[:, i:i + 1]
                y = torch.tensor([[start_token]], dtype=torch.long, device=device)

                tokens = []
                while True:
                    logits = model.decoder(y, memory, None)
                    # Greedy pick from the last time step; shape (1, 1) so it can be appended to y.
                    next_token = logits[-1, :, :].argmax(-1).unsqueeze(0)
                    next_id = next_token.item()
                    tokens.append(tgt_vocab.lookup_token(next_id))
                    y = torch.cat((y, next_token), dim=0)
                    if next_id == end_token or len(tokens) >= max_tokens:
                        break
                translations.append(" ".join(tokens))
    finally:
        model.train(was_training)
    return translations

In [10]:
# Smoke-test translation of a single sentence.
test_txt="In this story an old man sets out to ask an Indian king to dig some well in his village when their water runs dry"
# Reference translation for eyeballing the output; it is not scored.
expected_translation="Dans cette histoire, un vieil homme entreprend de demander à un roi indien de creuser un puits dans son village lorsque leur eau sera à sec."
predicted_translation = transcribe(
    model=transformer,
    inputs=test_txt,
    pipeline=predict_pipeline,
    max_tokens=50,
    start_token=fra_vocab['<sos>'],
    tgt_vocab=fra_vocab,
    end_token=fra_vocab['<eos>'],
    device=device,
)
# End with the value so the cell actually displays it; an assignment alone shows nothing.
predicted_translation
               

encoder embeds shape :  torch.Size([27, 1, 512])
['L\'alligator décidiez grasse touchait tourne d\'entraînement amende avides prioritaire éclairer l\'infirmier trousse l\'éventail Relâche-le n\'attendra parlent-ils d\'historien Juste connûmes provoque mîmes diapo embarquer douche renseignes papillon agressif insistant "Tais-toi" connaissais-tu l\'envoyer ouvrage m\'énerves apportez-moi bleu l\'Inde surannée Monte ressaisir potentiel limpide câlin trouvâmes séparés pointer négocie furent-elles amputé Résides-tu escroc']


In [11]:
l = torch.rand(27, 1, 50)  # scratch 27 x 1 x 50 tensor mirroring the leading dims of the encoder output above

In [12]:
seq, n, em = l.size()  # names suggest (sequence, batch, embedding) — scratch exploration

In [13]:
l.transpose(0, 1)  # (seq, n, em) -> (n, seq, em); reshape(n, seq, em) only matches this when n == 1 and otherwise mixes elements across sequences

tensor([[[0.7594, 0.2821, 0.0959,  ..., 0.9308, 0.7382, 0.4747],
         [0.5889, 0.0044, 0.6634,  ..., 0.2076, 0.5865, 0.8355],
         [0.4105, 0.1199, 0.9600,  ..., 0.0649, 0.1990, 0.4409],
         ...,
         [0.7169, 0.5462, 0.5273,  ..., 0.3649, 0.0948, 0.8626],
         [0.3013, 0.6027, 0.5353,  ..., 0.5057, 0.2280, 0.9400],
         [0.1132, 0.8083, 0.5205,  ..., 0.2114, 0.0280, 0.5455]]])