In [113]:
import codecs
from collections import defaultdict
import os

import torch
import transformers

In [102]:
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

mbart_model = "facebook/mbart-large-50-many-to-many-mmt"

model = MBartForConditionalGeneration.from_pretrained(mbart_model)
tokenizer = MBart50TokenizerFast.from_pretrained(mbart_model, src_lang="de_DE", tgt_lang="en_XX")

In [75]:
# Read in data
DATA_DIR = "./data/wmt19/ende"

In [76]:
references = defaultdict(list)

for ref_path in os.listdir(DATA_DIR):
    with codecs.open(os.path.join(DATA_DIR, ref_path), "rb", "utf-8") as ref_file:
        for i, line in enumerate(ref_file.readlines()):
            references[i].append(line.strip())

In [77]:
references

defaultdict(list,
            {0: ['Abgeordnete walisischen Ursprungs machen sich Sorgen, „wie Idioten auszusehen“',
              'Walisische Abgeordnete befürchten als ,Idioten’ dazustehen.',
              'Walisische Abgeordnete befürchten als ,Idioten’ dazustehen.',
              'Abgeordnete aus Wales hatten Angst, das Bild von „Dummköpfen“ abzugeben.',
              'Walisische Abgeordnete befürchten als ,Idioten’ dazustehen.',
              'Abgeordnete walisischen Ursprungs machen sich Sorgen, „wie Idioten auszusehen“'],
             1: ['Der Vorschlag, den Namen der Versammlung in MWP (Mitglied des walisischen Parlaments) zu ändern, rief bei einigen ihrer Angehörigen Fassungslosigkeit hervor.',
              'Bei einigen AMs herrscht Fassungslosigkeit über einen Vorschlag, ihr Titel solle in MWPs (Members of the Welsh Parliament) geändert werden.',
              'Bei einigen AMs herrscht Fassungslosigkeit über einen Vorschlag, ihr Titel solle in MWPs (Members of the Welsh Parl

In [196]:
def extract_latents(model, tokenizer, sentences):
    # Tokenize inputs and feed to model
    generation_config = transformers.GenerationConfig.from_pretrained(mbart_model)
    generation_config.decoder_hidden_states = True
    generation_config.return_dict_in_generate = True
    generation_config.num_beams = 1  # TODO: Debug
    print(generation_config)
    
    tokenized_sequences = tokenizer(sentences, return_tensors="pt", padding=True)
    in_shape = tokenized_sequences["input_ids"].shape[0]
    
    outputs = model.generate(
        input_ids=tokenized_sequences["input_ids"], 
        attention_mask=tokenized_sequences["attention_mask"],
        output_hidden_states=True,
        forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"],
        generation_config=generation_config,
        return_dict_in_generate=True,
    )
    print("Generated shape", outputs.sequences.shape)

    decoded = tokenizer.batch_decode(outputs.sequences)
    print(decoded)
    
    # outputs.decoder_hidden_states has the following structure:
    # One tuple entry per time step
    # Inside that: 13 entries for every decoder layer
    # Inside the uppermost decoder layer: Shape [30, 1, 1024]
    # 30: batch size (?), 1: sequence length, 1024: Embedding size
    
    
    print(outputs.keys())
    
    print(len(outputs.decoder_hidden_states)) # One tuple entry per time step
    print(outputs.decoder_hidden_states[0][0].shape)

In [197]:
extract_latents(model, tokenizer, references[0])

GenerationConfig {
  "_from_model_config": true,
  "bos_token_id": 0,
  "decoder_hidden_states": true,
  "decoder_start_token_id": 2,
  "early_stopping": true,
  "eos_token_id": 2,
  "forced_eos_token_id": 2,
  "max_length": 200,
  "pad_token_id": 1,
  "return_dict_in_generate": true,
  "transformers_version": "4.26.1"
}

Generated shape torch.Size([6, 26])
['</s>en_XX Welsh-born MPs worry about "looking like idiots"</s><pad><pad><pad><pad><pad><pad><pad><pad>', '</s>en_XX Welsh MPs fear that they are ‘idiots’.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>', '</s>en_XX Welsh MPs fear that they are ‘idiots’.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>', '</s>en_XX Members of the House of Representatives from Wales were afraid to show the image of "dummy heads".</s>', '</s>en_XX Welsh MPs fear that they are ‘idiots’.</s><pad><pad><pad><pad><pad><pad><pad><pad><pad><pad>', '</s>en_XX Welsh-born MPs worry about "looking like idiots"</s><pad><pad><pad><pad><pad><pad><pad><