In [1]:
import os
from os.path import exists
import torch
import torch.nn as nn
from torch.nn.functional import log_softmax, pad
import math
import copy
import time
from torch.optim.lr_scheduler import LambdaLR
import pandas as pd
import altair as alt
from torchtext.data.functional import to_map_style_dataset
from torch.utils.data import DataLoader
from torchtext.vocab import build_vocab_from_iterator
import torchtext.datasets as datasets
import spacy
import GPUtil
import warnings
from torch.utils.data.distributed import DistributedSampler
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

# Silence every Python warning from here on in this kernel session.
# NOTE(review): a blanket "ignore" also hides deprecation warnings; consider
# narrowing it by category or module.
warnings.filterwarnings("ignore")

# Execute the companion notebooks in this kernel's namespace. Names used below
# but not defined in this notebook (make_model, Batch, subsequent_mask,
# create_dataloaders, vocab_src, vocab_tgt, spacy_de, spacy_en) are presumably
# defined by them — TODO confirm.
%run Model_structure.ipynb
%run Dataset.ipynb


Finished.
Vocabulary sizes:
8315
6384


In [2]:
### BLEU evaluation helpers
from nltk.translate.bleu_score import sentence_bleu

class BleuScore:
    """Running average of sentence-level BLEU over a stream of examples.

    Each call to :meth:`add` scores one candidate against its reference(s)
    with ``nltk.translate.bleu_score.sentence_bleu`` and accumulates the
    result; :meth:`get_score` reports the mean as a percentage.
    """

    def __init__(self):
        self.total_bleu = 0  # sum of per-sentence BLEU scores added so far
        self.total_num = 0  # number of sentences added so far

    @staticmethod
    def cal_single_bleu(reference, candidate):
        """Return the sentence BLEU of ``candidate`` against ``reference``.

        ``reference`` is a list of reference token lists and ``candidate`` a
        single token list, the argument layout ``sentence_bleu`` expects.
        """
        return sentence_bleu(reference, candidate)

    def add(self, reference, candidate):
        """Score one example, add it to the running total and return its score."""
        score = self.cal_single_bleu(reference, candidate)
        self.total_bleu += score
        self.total_num += 1
        return score

    def get_score(self):
        """Return the mean BLEU of all added examples, scaled by 100.

        Returns 0.0 when no example has been added yet instead of raising
        ``ZeroDivisionError``.
        """
        if self.total_num == 0:
            return 0.0
        return (self.total_bleu / self.total_num) * 100


bleu = BleuScore()

In [3]:
def greedy_decode(model, src, src_mask, max_len, start_symbol, end_symbol=None):
    """Greedily decode one target sequence from an encoder-decoder model.

    At each step the most probable next token (argmax of the generator output
    for the last position) is appended to the running prefix.

    Args:
        model: object exposing ``encode``, ``decode`` and ``generator``
            (presumably the model built by ``make_model`` — TODO confirm).
        src: source token ids. Only batch element 0 drives the prediction
            (``next_word[:1]``), so a batch of one is assumed.
        src_mask: source mask, passed through to ``encode`` and ``decode``.
        max_len: maximum length of the returned sequence, start symbol included.
        start_symbol: token id used to seed the target prefix.
        end_symbol: optional token id; when given, decoding stops right after
            this token is produced. ``None`` (default) always runs
            ``max_len - 1`` steps.

    Returns:
        Tensor of shape ``(1, L)`` with ``L <= max_len``, on the same device and
        with the same dtype as ``src``.
    """
    # Pure inference: do not build an autograd graph across decoding steps.
    with torch.no_grad():
        memory = model.encode(src, src_mask)
        ys = torch.full((1, 1), start_symbol, dtype=src.dtype, device=src.device)
        for _ in range(max_len - 1):
            tgt_mask = subsequent_mask(ys.size(1)).type_as(src)
            out = model.decode(memory, src_mask, ys, tgt_mask)
            prob = model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_token = next_word[:1].view(1, 1).to(ys.dtype)
            ys = torch.cat([ys, next_token], dim=1)
            if end_symbol is not None and next_token.item() == end_symbol:
                break
    return ys

In [4]:
# Module-level BLEU accumulator, re-created here so that re-running this cell
# starts again from an empty total.
bleu = BleuScore()


def _ids_to_tokens(ids, itos, pad_idx):
    """Map a 1-D sequence of token ids to vocabulary strings, skipping ``pad_idx``."""
    return [itos[i] for i in ids if i != pad_idx]


def _bleu_tokens(tokens, eos_string):
    """Tokens to score with BLEU: drop the first (start) token and cut at the
    first ``eos_string`` (exclusive)."""
    body = list(tokens[1:])
    if eos_string in body:
        body = body[: body.index(eos_string)]
    return body


def check_outputs(
    valid_dataloader,
    model,
    vocab_src,
    vocab_tgt,
    n_examples=15,
    pad_idx=2,
    eos_string="</s>",
    scorer=None,
    max_len=72,
    start_symbol=0,
):
    """Translate ``n_examples`` validation batches, print them and score BLEU.

    For each example, the first sequence of the batch is shown: source text,
    ground-truth target, greedy model output and its sentence BLEU. The mean
    BLEU over the examples is printed at the end.

    Args:
        valid_dataloader: iterable of batches ``b`` where ``b[0]`` holds source
            ids and ``b[1]`` target ids. One batch is consumed per example, so
            it must yield at least ``n_examples`` batches.
        model: model passed to ``greedy_decode``. Its train/eval mode is not
            changed here; callers should put it in eval mode.
        vocab_src, vocab_tgt: vocabularies exposing ``get_itos()``.
        n_examples: number of batches to translate.
        pad_idx: id of the padding token, dropped before detokenizing.
        eos_string: end-of-sentence token string; model output is cut there.
        scorer: BLEU accumulator with ``add``, ``get_score`` and ``total_num``.
            When ``None`` a fresh ``BleuScore`` is used, so repeated calls do
            not mix their averages.
        max_len: maximum decoded length passed to ``greedy_decode``.
        start_symbol: start token id passed to ``greedy_decode``.

    Returns:
        List of ``(batch, src_tokens, tgt_tokens, model_out, model_txt)`` tuples.
    """
    if scorer is None:
        scorer = BleuScore()
    # get_itos() returns the whole id->token list; fetch it once, not per token.
    itos_src = vocab_src.get_itos()
    itos_tgt = vocab_tgt.get_itos()
    # One shared iterator, so each example takes the next batch. Restarting the
    # loader every time would repeat the first batch when it is not shuffled.
    batches = iter(valid_dataloader)

    results = []
    for idx in range(n_examples):
        print("\nExample %d ========\n" % idx)
        b = next(batches)
        rb = Batch(b[0], b[1], pad_idx)

        src_tokens = _ids_to_tokens(rb.src[0], itos_src, pad_idx)
        tgt_tokens = _ids_to_tokens(rb.tgt[0], itos_tgt, pad_idx)
        print(
            "Source Text (Input)        : "
            + " ".join(src_tokens).replace("\n", "")
        )
        print(
            "Target Text (Ground Truth) : "
            + " ".join(tgt_tokens).replace("\n", "")
        )

        with torch.no_grad():
            model_out = greedy_decode(
                model, rb.src, rb.src_mask, max_len, start_symbol
            )[0]
        out_tokens = _ids_to_tokens(model_out, itos_tgt, pad_idx)
        model_txt = " ".join(out_tokens).split(eos_string, 1)[0] + eos_string
        print("Model Output               : " + model_txt.replace("\n", ""))
        results.append((rb, src_tokens, tgt_tokens, model_out, model_txt))

        # Score on token lists (not on re-split text) so that a missing EOS in
        # either sequence does not silently drop a real word.
        reference = [_bleu_tokens(tgt_tokens, eos_string)]
        candidate = _bleu_tokens(out_tokens, eos_string)
        print('SINGLE BLEU SCORE : ' + str(BleuScore.cal_single_bleu(reference, candidate)))
        scorer.add(reference, candidate)

    if scorer.total_num:
        print(scorer.get_score())
    return results


def run_model_example(n_examples=1, model_path="./base/multi30k_best.pt"):
    """Load the trained checkpoint and print translations of validation examples.

    Reads the module-level ``vocab_src``, ``vocab_tgt``, ``spacy_de`` and
    ``spacy_en``. This notebook does not define them itself; they are
    presumably provided by the ``%run`` notebooks (TODO confirm).

    Args:
        n_examples: number of validation examples to translate.
        model_path: path of the ``state_dict`` checkpoint to load.

    Returns:
        The list of per-example tuples returned by ``check_outputs``.
    """
    print("Preparing Data ...")
    _, valid_dataloader = create_dataloaders(
        torch.device("cpu"),
        vocab_src,
        vocab_tgt,
        spacy_de,
        spacy_en,
        batch_size=1,
    )

    print("Loading Trained Model ...")
    model = make_model(len(vocab_src), len(vocab_tgt), N=6)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    model.load_state_dict(
        torch.load(model_path, map_location=torch.device("cpu"))
    )
    # Decode in inference mode so any dropout/batch-norm layers inside the
    # model (defined outside this notebook) behave deterministically.
    model.eval()

    print("Checking Model Outputs:")
    return check_outputs(
        valid_dataloader, model, vocab_src, vocab_tgt, n_examples=n_examples
    )


# Assign the result so the notebook does not echo the returned list.
example_results = run_model_example()

Preparing Data ...
Loading Trained Model ...
Checking Model Outputs:


Source Text (Input)        : <s> Eine Gruppe von Männern lädt Baumwolle auf einen Lastwagen </s>
Target Text (Ground Truth) : <s> A group of men are loading cotton onto a truck </s>
Model Output               : <s> A group of men are putting cotton into a truck . </s>
SINGLE BLEU SCORE : 0.4172261448611506
41.72261448611506
