In [2]:
import torch
import torch.nn as nn
from model import Transformer
from config import get_config, get_weights_path_path
from train import get_model, get_ds, greedy_decode
import altair as plt
import pandas as pd
import numpy as np
import warnings
warnings.filterwarnings("ignore")

In [4]:
# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device {device}")

Using device cuda


In [None]:
config = get_config()
train_dataloader, val_dataloader, vocab_src, vocab_tgt = get_ds(config)
model = get_model(config, vocab_src.get_vocab_size(), vocab_tgt.get_vocab_size()).to(device)

# Load the pretrained weights. The epoch tag is passed as a string, exactly as before.
PRETRAINED_EPOCH = "29"
model_filename = get_weights_path_path(config, PRETRAINED_EPOCH)
# map_location remaps stored tensors onto `device`, so a checkpoint saved on a GPU
# can still be loaded on a CPU-only machine (and vice versa).
state = torch.load(model_filename, map_location=device)
model.load_state_dict(state["model_state_dict"])
# Put the model in inference mode (disables dropout etc.) before it is used for decoding.
# The trailing ';' suppresses the notebook's display of the returned module.
model.eval();

In [None]:
def load_next_batch():
    """Fetch one validation batch, move it to `device`, and run greedy decoding on it.

    Reads the notebook globals `val_dataloader`, `device`, `vocab_src`, `vocab_tgt`,
    `model` and `config`.

    Returns:
        tuple: ``(batch, encoder_input, decoder_input_tokens)``. ``batch`` is the raw
        batch dict from the dataloader. ``encoder_input`` is the ``"encoder_input"``
        tensor moved to ``device``. ``decoder_input_tokens`` is the list of
        target-vocabulary token strings for the first sample's decoder input.
    """
    # A fresh iterator is created on every call, so unless the loader shuffles,
    # each call yields the same first batch.
    batch = next(iter(val_dataloader))
    encoder_input = batch["encoder_input"].to(device)
    encoder_mask = batch["encoder_mask"].to(device)
    decoder_input = batch["decoder_input"].to(device)

    # Only the first sample of the batch ([0]) is converted to token strings.
    # The encoder side uses the source vocabulary. The decoder input is the
    # target-side sequence, so it must be detokenized with the target vocabulary.
    encoder_input_tokens = [vocab_src.id_to_token(idx) for idx in encoder_input[0].tolist()]
    decoder_input_tokens = [vocab_tgt.id_to_token(idx) for idx in decoder_input[0].tolist()]
    # NOTE(review): encoder_input_tokens is computed but not returned; the second
    # return value is the encoder_input tensor. Confirm whether callers expected the
    # token list (parallel to decoder_input_tokens).

    # The decoded output is not returned. The call is presumably kept for its side
    # effects on `model` (e.g. attention scores cached for visualization) — confirm.
    # no_grad avoids building an autograd graph during pure inference.
    with torch.no_grad():
        greedy_decode(model, encoder_input, encoder_mask, vocab_src, vocab_tgt, config["seq_len"], device)

    return batch, encoder_input, decoder_input_tokens