In [25]:
# Reload edited local modules automatically so changes to the helper .py files
# are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# TO DO: import transformer functions
# NOTE(review): `torch` is not imported explicitly in this cell; presumably it
# arrives through the star import below — confirm.
from utilsTokenComparison import *
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


##### The dataset download links (German, English) in the original Harvard NLP code* no longer work, so the next cell replaces them with working download links.

*https://github.com/harvardnlp/annotated-transformer

In [26]:
from torchtext.datasets import multi30k

# Point every Multi30k split at a working mirror of the archives.
multi30k.URL.update(
    train="https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/training.tar.gz",
    valid="https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/validation.tar.gz",
    test="https://raw.githubusercontent.com/neychev/small_DL_repo/master/datasets/Multi30k/mmt16_task1_test.tar.gz",
)

# Matching checksums for the three archives. Despite the attribute name, each
# value is 64 hex digits long (SHA-256 sized, not MD5 sized).
# NOTE(review): presumably torchtext verifies downloads against these — confirm.
multi30k.MD5.update(
    train="20140d013d05dd9a72dfde46478663ba05737ce983f478f960c1123c6671be5e",
    valid="a7aa20e9ebd5ba5adce7909498b94410996040857154dab029851af3a866da8c",
    test="6d1ca1dba99e2c5dd54cae1226ff11c2551e6ce63527ebb072a1f70f72a5cd36",
)

In [27]:
from utilsTransformer import *

### Load BERT tokenizer, BERT vocabulary and trained BERT model

In [28]:
from transformers import AutoTokenizer

# Hugging Face tokenizer used for the BERT-tokenized variant of the data.
tknzrBert = AutoTokenizer.from_pretrained("bert-base-multilingual-cased")

# The vocabulary file holds a (German, English) pair.
# NOTE(review): torch.load unpickles the file — only load files from a trusted source.
vocabBertDe, vocabBertEng = torch.load("vocab_Bert.pt")
print(f"Length of Bert German Vocabulary = {len(vocabBertDe)}")
print(f"Length of Bert English Vocabulary = {len(vocabBertEng)}")

# Build a model sized to the two vocabularies (N=6) and restore the trained weights.
# NOTE(review): the model is left in whatever train/eval mode make_model returns;
# if it contains dropout, consider calling .eval() before decoding — confirm.
modelBert = make_model(len(vocabBertDe), len(vocabBertEng), N=6)
model_path = "multi30k_model_bert_final.pt"
modelBert.load_state_dict(torch.load(model_path, map_location=device))

Length of Bert German Vocabulary = 8804
Length of Bert English Vocabulary = 8076


<All keys matched successfully>

### Load GPT tokenizer, GPT vocabulary, and GPT model

In [29]:
from transformers import AutoTokenizer

# Hugging Face tokenizer used for the GPT-tokenized variant of the data.
tknzrGpt = AutoTokenizer.from_pretrained("openai-community/openai-gpt")

# The vocabulary file holds a (German, English) pair.
# NOTE(review): torch.load unpickles the file — only load files from a trusted source.
vocabGptDe, vocabGptEng = torch.load("vocab_gpt.pt")
print(f"Length of GPT German Vocabulary = {len(vocabGptDe)}")
print(f"Length of GPT English Vocabulary = {len(vocabGptEng)}")

# Build a model sized to the two vocabularies (N=6) and restore the trained weights.
# NOTE(review): the model is left in whatever train/eval mode make_model returns;
# if it contains dropout, consider calling .eval() before decoding — confirm.
modelGpt = make_model(len(vocabGptDe), len(vocabGptEng), N=6)
model_path = "multi30k_model_gpt_final.pt"
modelGpt.load_state_dict(torch.load(model_path, map_location=device))

Length of GPT German Vocabulary = 4218
Length of GPT English Vocabulary = 9570


<All keys matched successfully>

In [30]:
def _ids_to_text(token_ids, itos, pad_idx):
    """Look up every non-padding id in ``itos`` and join the strings with single spaces."""
    return " ".join([itos[x] for x in token_ids if x != pad_idx])


def _report_model(
    model,
    src_ids,
    tgt_ids,
    src_itos,
    tgt_itos,
    pad_idx,
    eos_string,
    text_label,
    output_label,
):
    """Print source, ground truth and greedy translation for the first sentence of a batch.

    ``text_label`` is used in the "Source"/"Target" lines and ``output_label`` in the
    "Model ... Output" line; they are separate because the existing output spells
    them differently for GPT ("GPT" vs "Gpt").
    """
    rb = Batch(src_ids, tgt_ids, pad_idx)
    src_text = _ids_to_text(rb.src[0], src_itos, pad_idx)
    tgt_text = _ids_to_text(rb.tgt[0], tgt_itos, pad_idx)
    print(f"Source {text_label} Text (Input)        : " + src_text.replace("\n", ""))
    print(f"Target {text_label} Text (Ground Truth) : " + tgt_text.replace("\n", ""))

    # NOTE(review): 72 and 0 are presumably greedy_decode's maximum output length
    # and start-symbol id — confirm against its definition.
    model_out = greedy_decode(model, rb.src, rb.src_mask, 72, 0)[0]
    # Keep everything up to the first end-of-sentence marker, then re-append the marker.
    model_txt = _ids_to_text(model_out, tgt_itos, pad_idx).split(eos_string, 1)[0] + eos_string
    print(f"Model {output_label} Output: " + model_txt.replace("\n", ""))
    print("========\n")


def check_outputs(
    valid_dataloader,
    modelBert,
    vocabBertDe,
    vocabBertEng,
    modelGpt,
    vocabGptDe,
    vocabGptEng,
    n_examples=15,
    pad_idx=2,
    eos_string="</s>"
):
    """Print BERT- and GPT-tokenized translations side by side for several examples.

    Each batch ``b`` from ``valid_dataloader`` is indexed as ``b[0], b[1]`` (source and
    target ids for the BERT-tokenized model) and ``b[2], b[3]`` (source and target ids
    for the GPT-tokenized model). Only the first sentence of every batch is shown.

    Args:
        valid_dataloader: iterable of batches laid out as described above.
        modelBert, modelGpt: models handed to ``greedy_decode``.
        vocabBertDe, vocabBertEng, vocabGptDe, vocabGptEng: vocabularies exposing
            ``get_itos()`` (index -> token string).
        n_examples: maximum number of batches to display. Fewer are shown when the
            dataloader runs out first.
        pad_idx: id filtered out of every printed sequence (applied to all four
            vocabularies).
        eos_string: token string at which the model output is cut.

    Returns:
        None. Everything is written to stdout.

    NOTE(review): the models are used in whatever train/eval mode the caller left
    them in — confirm that is intended for decoding.
    """
    # Resolve the index -> string tables once instead of once per printed token.
    bert_src_itos = vocabBertDe.get_itos()
    bert_tgt_itos = vocabBertEng.get_itos()
    gpt_src_itos = vocabGptDe.get_itos()
    gpt_tgt_itos = vocabGptEng.get_itos()

    # Walk the dataloader with a single iterator so every example is a different
    # batch. Re-creating the iterator per example (next(iter(...))) restarts the
    # loader each time and can show the same batch repeatedly.
    for idx, b in zip(range(n_examples), valid_dataloader):
        print("\nExample %d ========\n" % idx)
        _report_model(
            modelBert, b[0], b[1], bert_src_itos, bert_tgt_itos,
            pad_idx, eos_string, "Bert", "Bert",
        )
        _report_model(
            modelGpt, b[2], b[3], gpt_src_itos, gpt_tgt_itos,
            pad_idx, eos_string, "GPT", "Gpt",
        )

def run_model_example(n_examples=25):
    """Build the validation dataloader and print BERT vs. GPT translations.

    Reads the notebook-level ``device``, tokenizers, vocabularies and models, builds
    a dataloader with batch size 1 and padding length 72, and hands everything to
    ``check_outputs``.

    Args:
        n_examples: number of examples forwarded to ``check_outputs``.

    NOTE(review): the value returned by ``create_dataloaders`` is used directly as
    the validation dataloader — confirm it returns a single loader, not a tuple.
    """
    # The previous `global vocab_src, vocab_tgt, spacy_de, spacy_en` declaration was
    # dropped: none of those names is read or assigned in this function.
    print("Preparing Data ...")
    valid_dataloader = create_dataloaders(
        device,
        vocabBertDe,
        vocabBertEng,
        vocabGptDe,
        vocabGptEng,
        batch_size=1,
        max_padding=72,
        is_distributed=False,
        tknzrBert=tknzrBert,
        tknzrGpt=tknzrGpt,
    )

    print("Comparing Model Outputs:")
    check_outputs(
        valid_dataloader,
        modelBert,
        vocabBertDe,
        vocabBertEng,
        modelGpt,
        vocabGptDe,
        vocabGptEng,
        n_examples=n_examples,
    )

# Run the side-by-side comparison with the default n_examples=25.
run_model_example()

Preparing Data ...
Comparing Model Outputs:


Source Bert Text (Input)        : <s> [CLS] Ein Soldat bli ##ckt durch ein Fe ##rn ##glas in die Berg ##lands ##chaft . [SEP] </s>
Target Bert Text (Ground Truth) : <s> [CLS] A soldier is looking at bin ##oc ##ular ##s into the mountain ##ous landscape . [SEP] </s>
Model Bert Output: <s> [CLS] A soldier looks through a bin ##oc ##ular ##s at the mountain ##ous area . [SEP] </s>

Source GPT Text (Input)        : <s> e in sol dat bli ck t dur ch e in fer n glas in die ber glan d scha ft . </s>
Target GPT Text (Ground Truth) : <s> a soldier is looking at binoculars into the mountainous landscape . </s>
Model Gpt Output: <s> a soldier is looking through a binoculars in the mountain . </s>



Source Bert Text (Input)        : <s> [CLS] Eine Frau mit gel ##bem Helm benutzt eine Sei ##l ##rut ##sche . [SEP] </s>
Target Bert Text (Ground Truth) : <s> [CLS] A woman wearing a yellow hel ##met is using a zi ##plin ##e . [SEP] </s>
Model Bert Output: <