In [1]:
from pathlib import Path
import torch
import torch.nn as nn
from config import get_config, latest_weights_file_path
from train import get_model, get_ds, causal_mask

In [2]:
# Prefer the GPU when CUDA is available, otherwise run on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Using device:", device)

# Build the data loaders and tokenizers, then a model sized to both vocabularies
config = get_config()
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
model = get_model(
    config,
    tokenizer_src.get_vocab_size(),
    tokenizer_tgt.get_vocab_size(),
).to(device)

# Restore the most recent checkpoint, remapping its tensors onto `device`
model_filename = latest_weights_file_path(config)
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state['model_state_dict'])

Using device: cpu
Max length of source sentence: 40
Max length of target sentence: 35


<All keys matched successfully>

In [3]:
def beam_search_decode(model, beam_size, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    """Translate one source sentence with beam search and return the best hypothesis.

    Args:
        model: object exposing ``encode``, ``decode`` and ``project``.
        beam_size: number of hypotheses kept after every step (must be >= 1).
        source: encoder input for a single sentence (batch size 1).
        source_mask: encoder mask matching ``source``.
        tokenizer_src: unused here; kept for signature parity with ``greedy_decode``.
        tokenizer_tgt: target tokenizer, queried for the ``[SOS]`` / ``[EOS]`` ids.
        max_len: decoding stops as soon as any hypothesis reaches this length.
        device: device on which the decoder inputs and masks are created.

    Returns:
        1-D tensor of token ids for the highest-scoring hypothesis, starting
        with ``[SOS]`` and ending with ``[EOS]`` when the model emitted it.

    Raises:
        ValueError: if ``beam_size`` is smaller than 1.
    """
    if beam_size < 1:
        raise ValueError("beam_size must be at least 1.")

    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # The encoder output does not depend on the decoded prefix, so compute it once
    encoder_output = model.encode(source, source_mask)
    decoder_initial_input = torch.tensor([[sos_idx]], dtype=torch.long, device=device)

    # Each hypothesis is (sequence, cumulative log-probability)
    candidates = [(decoder_initial_input, 0.0)]

    while True:
        # Stop expanding once any hypothesis has reached the length limit
        if any(cand.size(1) == max_len for cand, _ in candidates):
            break

        new_candidates = []

        for candidate, score in candidates:
            # Finished hypotheses are carried over unchanged
            if candidate[0, -1].item() == eos_idx:
                new_candidates.append((candidate, score))
                continue

            candidate_mask = causal_mask(candidate.size(1)).type_as(source_mask).to(device)
            out = model.decode(encoder_output, source_mask, candidate, candidate_mask)

            # Normalise to log-probabilities before accumulating. Summing raw
            # logits rewards longer sequences and pushes finished hypotheses
            # out of the beam. log_softmax leaves values that are already
            # log-probabilities unchanged, so this is safe either way.
            log_prob = torch.log_softmax(model.project(out[:, -1]), dim=-1)

            # topk cannot return more entries than the vocabulary holds
            k = min(beam_size, log_prob.size(-1))
            topk_log_prob, topk_idx = torch.topk(log_prob, k, dim=1)

            for i in range(k):
                token = topk_idx[0][i].unsqueeze(0).unsqueeze(0)
                token_log_prob = topk_log_prob[0][i].item()
                new_candidate = torch.cat([candidate, token], dim=1)
                new_candidates.append((new_candidate, score + token_log_prob))

        # Keep only the best `beam_size` hypotheses
        candidates = sorted(new_candidates, key=lambda x: x[1], reverse=True)[:beam_size]

        # Every surviving hypothesis is finished: nothing left to expand
        if all(cand[0, -1].item() == eos_idx for cand, _ in candidates):
            break

    # squeeze(0) drops only the batch axis, so a one-token result stays 1-D
    return candidates[0][0].squeeze(0)

In [4]:
def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    """Translate one source sentence by always taking the highest-scoring next token.

    Decoding starts from ``[SOS]`` and stops when ``[EOS]`` is produced or the
    sequence length equals ``max_len``. Returns a 1-D tensor of token ids with
    the same dtype as ``source``. ``tokenizer_src`` is not used.
    """
    start_id = tokenizer_tgt.token_to_id('[SOS]')
    stop_id = tokenizer_tgt.token_to_id('[EOS]')

    # The encoder runs once; its output is reused at every decoding step
    memory = model.encode(source, source_mask)

    # Sequence generated so far, seeded with the start token
    generated = torch.empty(1, 1).fill_(start_id).type_as(source).to(device)

    while generated.size(1) != max_len:
        # Causal mask covering the tokens generated so far
        target_mask = causal_mask(generated.size(1)).type_as(source_mask).to(device)
        hidden = model.decode(memory, source_mask, generated, target_mask)

        # Score the vocabulary from the last position and take the top entry
        scores = model.project(hidden[:, -1])
        next_id = torch.max(scores, dim=1).indices.item()

        next_token = torch.empty(1, 1).type_as(source).fill_(next_id).to(device)
        generated = torch.cat([generated, next_token], dim=1)

        if next_id == stop_id:
            break

    return generated.squeeze(0)

In [5]:
def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, num_examples=5, beam_size=3):
    """Print greedy and beam-search translations for a few validation examples.

    The model is switched to eval mode and decoding runs under
    ``torch.no_grad()``. For each batch the source text, reference text and both
    predictions are reported through ``print_msg``.

    Args:
        model: translation model passed through to both decoders.
        validation_ds: iterable of batches providing ``encoder_input``,
            ``encoder_mask``, ``src_text`` and ``tgt_text``.
        tokenizer_src: source tokenizer, forwarded to the decoders.
        tokenizer_tgt: target tokenizer, also used to turn ids back into text.
        max_len: maximum decoded sequence length, forwarded to the decoders.
        device: device the batch tensors are moved to.
        print_msg: callable receiving each output line as a string.
        num_examples: stop after this many batches; a value that the 1-based
            batch counter never equals (e.g. 0) shows the whole dataset.
        beam_size: beam width passed to ``beam_search_decode``.

    Raises:
        ValueError: if a batch holds more than one sentence.
    """
    model.eval()
    console_width = 80
    separator = '-' * console_width
    count = 0

    with torch.no_grad():
        for count, batch in enumerate(validation_ds, start=1):
            encoder_input = batch["encoder_input"].to(device)
            encoder_mask = batch["encoder_mask"].to(device)

            # Both decoders handle exactly one sentence at a time
            if encoder_input.size(0) != 1:
                raise ValueError("Batch size must be 1 for validation.")

            greedy_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)
            beam_out = beam_search_decode(model, beam_size, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            greedy_out_text = tokenizer_tgt.decode(greedy_out.detach().cpu().numpy())
            beam_out_text = tokenizer_tgt.decode(beam_out.detach().cpu().numpy())

            # Labels are right-aligned in a 20-character column
            print_msg(separator)
            print_msg(f"{'SOURCE: ':>20}{source_text}")
            print_msg(f"{'TARGET: ':>20}{target_text}")
            print_msg(f"{'PREDICTED GREEDY: ':>20}{greedy_out_text}")
            print_msg(f"{'PREDICTED BEAM: ':>20}{beam_out_text}")

            if count == num_examples:
                break

    # Close the report even when the dataset ran out before `num_examples`
    if count:
        print_msg(separator)


In [9]:
# Compare greedy and beam-search output on five validation batches (max_len=50)
run_validation(
    model,
    val_dataloader,
    tokenizer_src,
    tokenizer_tgt,
    max_len=50,
    device=device,
    print_msg=print,
    num_examples=5,
)

--------------------------------------------------------------------------------
            SOURCE: Then he took from his knapsack a bottle of wine, and drank some.

            TARGET: ثم أخذ من حقيبة ظهره زجاجة من النبيذ واحتسى القليل.

  PREDICTED GREEDY: ثم أخذ من حقيبة ظهره زجاجة من النبيذ القليل .
    PREDICTED BEAM: ثم أخذ من حقيبة ظهره زجاجة من النبيذ القليل من الطعام . القليل من الطعام . الخمر . .
--------------------------------------------------------------------------------
            SOURCE: Members from the Wimbledon area have refused to consider such a plan.

            TARGET: وقد رفض أعضاء من منطقة ويمبلدون النظر في مثل هذه الخطة.

  PREDICTED GREEDY: وقد رفض أعضاء من منطقة ويمبلدون النظر في مثل هذه الخطة .
    PREDICTED BEAM: وقد رفض أعضاء من منطقة ويمبلدون النظر في مثل هذه الخطة .
--------------------------------------------------------------------------------
            SOURCE: He gave no reason, but his motive was obvious enough.

            TARGET: لم يذكر أي