In [1]:
%load_ext autoreload
%autoreload 2

## initialization

In [2]:
from beam_search import create_dataset
from data import encode_line, load_vocabulary
from model import Transformer
import torch
import torch.nn as nn
import argparse

# Paths and device, stored on an argparse.Namespace so they can be read
# attribute-style (args.ckpt, args.device, ...) like parsed CLI arguments.
args = argparse.Namespace()
args.src_vocab = '../corpus/vocab_en_fr.txt'
args.tgt_vocab = '../corpus/vocab_en_fr.txt'
args.ckpt = './averaged_checkpoint.pt'
args.device = 'cpu'

# Decoding hyperparameters.
# NOTE(review): none of these are passed to create_dataset / beam_search in
# this notebook, so decoding runs with those functions' own defaults —
# confirm the defaults match these values.
batch_size = 16
beam_size = 5
length_penalty = 1
max_length = 256

# The same vocabulary file is used for both sides (shared EN-FR vocab).
source_vocabulary, _ = load_vocabulary(args.src_vocab)
target_vocabulary, target_vocabulary_rev = load_vocabulary(args.tgt_vocab)

# Special token ids needed to start and stop decoding.
bos = target_vocabulary["<s>"]
eos = target_vocabulary["</s>"]

model = Transformer(
    len(source_vocabulary),
    len(target_vocabulary),
    share_embeddings=True,
)
# map_location remaps stored tensors onto the configured device, so a
# checkpoint saved from a GPU run can still be loaded on a CPU-only machine.
checkpoint = torch.load(args.ckpt, map_location=args.device)
model.load_state_dict(checkpoint["model"])
model.to(args.device)
model.eval()

# Total number of trainable + non-trainable parameter elements in the model.
n = sum(param.numel() for param in model.parameters())
print(n)

source_path = '../corpus/rep_test.en.tok'
dataset = create_dataset(source_path, source_vocabulary, args.device)

# Reference translations, one per line (readlines keeps the trailing '\n').
ref_path = '../corpus/rep_test.fr'
with open(ref_path, 'r', encoding='utf-8') as file:
    ref = file.readlines()
ref[0]

209129472


'Cette fois-ci, la baisse est due à la chute des actions au Wall Street.\n'

## helper functions

In [3]:
def view(hypotheses, eos_id=None, vocab_rev=None):
    """Print each hypothesis as a space-separated string of target tokens.

    Args:
        hypotheses: iterable of hypotheses; for each one, ``hypo[1]`` is the
            list of target token ids (``hypo[0]`` is presumably the score and
            is not used here).
        eos_id: token id stripped when it terminates a hypothesis. Defaults
            to the notebook-level ``eos``.
        vocab_rev: object indexable by token id that yields the token string
            (list or dict). Defaults to the notebook-level
            ``target_vocabulary_rev``.

    The hypotheses are NOT modified: the trailing EOS is dropped from a copy,
    so viewing the same results repeatedly always prints the same text.
    """
    if eos_id is None:
        eos_id = eos
    if vocab_rev is None:
        vocab_rev = target_vocabulary_rev
    for hypo in hypotheses:
        token_ids = hypo[1]
        if token_ids and token_ids[-1] == eos_id:
            # Slice rather than pop() so the caller's hypothesis stays intact.
            token_ids = token_ids[:-1]
        print(" ".join(vocab_rev[token_id] for token_id in token_ids))

## inference analysis

In [4]:
from my_beam_search import beam_search

In [5]:
# Take only the first batch the dataset yields; the decoding cell below works on it.
dataset_iter = iter(dataset)
batch = next(dataset_iter)

In [6]:
from utils import *

In [None]:
# Decode only the first two sentences of the batch to keep this test cheap.
# To decode a single sentence instead, pass batch[None, 0] (the None index
# keeps a leading batch axis, assuming `batch` is a tensor).
# The output is bound to `result` because the following cells read it.
with torch.no_grad():
    result = beam_search(model, batch[0:2], bos, eos)

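`result` holds the decoded batch: each element corresponds to one input sentence and is a list of `(score, tokens)` hypotheses, where `tokens` is a list of target token ids.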

In [18]:
# Inspect every beam hypothesis produced for the first sentence of the batch.
first_sentence_hypotheses = result[0]
view(first_sentence_hypotheses)

Cette fois ￭, la chute des actions de Wall Street est responsable de la chute ￭.
Cette fois ￭, la chute des actions de Wall Street en est la cause ￭.
Cette fois ￭-￭ ci ￭, la chute des actions de Wall Street est responsable de la chute ￭.
Cette fois ￭-￭ ci ￭, la chute des actions de Wall Street en est la cause ￭.
Cette fois ￭, la chute des actions de Wall Street est responsable de la baisse ￭.
Cette fois ￭-￭ ci ￭, la chute des actions de Wall Street est responsable de la baisse ￭.
Cette fois ￭, la chute des actions de Wall Street est responsable de la chute ￭. cours ￭.


Arguably, the hypothesis we want is

`Cette fois ￭, la chute des actions de Wall Street est responsable de la baisse ￭.`

However, beam search does not rank it first or even second: it only appears as the fifth candidate in the output above.