In [77]:
import itertools
import os
import csv
import json

import torch
import numpy as np
from transformers import AutoTokenizer, MBart50Tokenizer, MBartForConditionalGeneration, Text2TextGenerationPipeline
from tqdm.notebook import tqdm, trange

In [78]:
pseudowords = np.load("../../out/pseudowords_comapp.npy")
pseudowords

array([[-3.66079696e-02, -1.11742102e-05,  5.81527799e-02, ...,
         6.40575811e-02,  2.65836474e-02, -1.56662956e-01],
       [-5.31983189e-02, -1.80779640e-02, -6.33598045e-02, ...,
         2.10250225e-02, -6.49476610e-03,  5.84359504e-02],
       [-9.96073708e-02,  3.98329496e-02, -1.11172702e-02, ...,
         1.56649694e-01,  2.40581349e-01, -9.96441580e-03],
       ...,
       [-1.07323132e-01,  9.40747932e-02,  8.12232345e-02, ...,
        -1.49555236e-01,  8.22361112e-02, -5.65737225e-02],
       [-4.59361728e-03,  1.85684375e-02,  2.77958419e-02, ...,
         2.84781586e-02,  2.85724066e-02, -5.74652366e-02],
       [-1.82020497e-02, -2.03991937e-03, -3.41283977e-02, ...,
        -2.80237161e-02, -2.23183036e-02,  5.36667835e-03]])

In [79]:
with open("../../data/pseudowords/CoMaPP_all.json") as json_file:
    data = json.load(json_file)
data

[{'label': 'geschweige10',
  'target1': 'Und dann ist da noch das generelle Problem mit Hamas , dass nicht jeder Sprecher und Führer , der redet , auch etwas zu sagen , geschweige denn das letzte Wort hat .',
  'target1_idx': 26,
  'query': 'Und dann ist da noch das generelle Problem mit Hamas , <mask> nicht jeder Sprecher und Führer , der redet , auch etwas zu sagen , geschweige denn das letzte Wort hat .',
  'query_idx': 26},
 {'label': 'denn10',
  'target1': 'Und dann ist da noch das generelle Problem mit Hamas , dass nicht jeder Sprecher und Führer , der redet , auch etwas zu sagen , geschweige denn das letzte Wort hat .',
  'target1_idx': 27,
  'query': 'Und dann ist da noch das generelle Problem mit Hamas , <mask> nicht jeder Sprecher und Führer , der redet , auch etwas zu sagen , geschweige denn das letzte Wort hat .',
  'query_idx': 27},
 {'label': 'geschweige10',
  'target1': 'Und dann ist da noch das generelle Problem mit Hamas , dass nicht jeder Sprecher und Führer , der red

In [80]:
data.sort(key=lambda x: x["label"])  # Grouping doesn't work without sorting first!
data = [list(group) for _, group in itertools.groupby(data, key=lambda x: x["label"])]
mbart_tokens = [d[0]['label'] for d in data]  # TODO Check whether the labels are in correct order!

mbart_tokens, len(mbart_tokens)

(['""Was13',
  '"647',
  '"Wir-äh-spielen-äh-in-der-äh-Champions-League647',
  '(1597',
  '(1600',
  '(1602',
  '(1637',
  '(1641',
  '(1643',
  '(379',
  '(579',
  '(581',
  '(584',
  '(590',
  '(592',
  '(600',
  '(886',
  '(892',
  '(900',
  '(905',
  '(907',
  '(909',
  '(917',
  '(919',
  '(921',
  ')1597',
  ')1600',
  ')1602',
  ')1637',
  ')1641',
  ')1643',
  ')1792',
  ')379',
  ')579',
  ')581',
  ')584',
  ')590',
  ')592',
  ')600',
  ')886',
  ')892',
  ')900',
  ')905',
  ')907',
  ')909',
  ')917',
  ')919',
  ')921',
  ')«579',
  ',1459',
  ',973',
  '-128',
  '-651',
  '-654',
  '-875',
  '-973',
  ':595',
  ':875',
  ':973',
  'Abstand683',
  'Allein20',
  'Aller1630',
  'Als1315',
  'Als133',
  'Als1770',
  'Am488',
  'Am492',
  'Am500',
  'Amerika605',
  'Anstatt320',
  'Art129',
  'Arzt1509',
  'Augenblick1301',
  'Ausmaß1777',
  'BRUTAL1503',
  'Besser1762',
  'Bis559',
  'Brutal1503',
  'Buche1346',
  'Das1313',
  'Das1461',
  'Dass21',
  'Dasselbe104',
  'Der13

Load the vanilla mBART-50 model:

In [81]:
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50", return_dict=True)
tokenizer = MBart50Tokenizer.from_pretrained("facebook/mbart-large-50", src_lang="de_DE", tgt_lang="de_DE")
model.model.encoder.embed_tokens

Embedding(250054, 1024, padding_idx=1)

Add to existing embeddings:

In [82]:
combined_embeddings = torch.cat((model.model.encoder.embed_tokens.weight, torch.tensor(pseudowords)), dim=0)
model.model.encoder.embed_tokens = torch.nn.Embedding.from_pretrained(combined_embeddings)
model.model.encoder.embed_tokens

Embedding(250327, 1024)

Add to existing tokens:

In [83]:
tokenizer.add_tokens(mbart_tokens[:len(pseudowords)])
model.resize_token_embeddings(len(tokenizer))

You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding dimension will be 250327. This might induce some performance reduction as *Tensor Cores* will not be available. For more details about this, or help on choosing the correct value for resizing, refer to this guide: https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc


Embedding(250327, 1024)

In [118]:
test_text = "de_DE <s> Auf ein Internat wollte sie unter keinen Umständen! Geschweige10 denn <mask>!</s>"

tokenized_text = tokenizer.tokenize(test_text)
masked_index = tokenized_text.index("<mask>")
tokenized_text

['de_DE',
 '<s>',
 '▁Auf',
 '▁ein',
 '▁Interna',
 't',
 '▁wollte',
 '▁sie',
 '▁unter',
 '▁keinen',
 '▁Um',
 'stände',
 'n',
 '!',
 'Geschweige10',
 '▁denn',
 '<mask>',
 '▁!',
 '</s>']

Convert the tokens to indices:

In [119]:
input_ids = tokenizer.convert_tokens_to_ids(tokenized_text)
input_ids = torch.tensor([input_ids])
input_ids

tensor([[250003,      0,   7601,    599, 106745,     18,  34485,   1329,   3993,
          38951,   2793,  61211,     19,     38, 250149,  13808, 250053,    711,
              2]])

Predict the token:

In [120]:
with torch.no_grad():
    outputs = model(input_ids)
    predictions = outputs.logits

predicted_token_id = torch.argmax(predictions[0, masked_index]).item()
predicted_token = tokenizer.convert_ids_to_tokens([predicted_token_id])[0]

predicted_token

'▁sie'

Predict the top 100 tokens:

In [121]:
top_k = 100
predicted_token_ids = torch.topk(predictions[0, masked_index], top_k).indices
predicted_token_probs = torch.topk(predictions[0, masked_index], top_k).values

# Convert the predicted token IDs back to tokens
predicted_tokens = tokenizer.convert_ids_to_tokens(predicted_token_ids)

# Print the top 5 predictions and their probabilities
for token, prob in zip(predicted_tokens, predicted_token_probs):
    print(token, prob)

▁sie tensor(52.3367)
▁das tensor(51.5574)
▁es tensor(51.0821)
▁auch tensor(50.8859)
▁die tensor(50.7422)
▁nicht tensor(50.5111)
▁ich tensor(50.4021)
▁auf tensor(50.2961)
▁was tensor(50.1377)
▁nur tensor(50.0148)
▁ihr tensor(49.8954)
▁ein tensor(49.7167)
▁in tensor(49.5490)
▁Sie tensor(49.5281)
▁er tensor(49.5217)
▁so tensor(49.4793)
▁wir tensor(49.4491)
▁ihre tensor(49.2473)
▁schon tensor(49.1941)
▁eine tensor(49.1412)
▁mal tensor(49.0997)
▁ tensor(49.0247)
▁keine tensor(48.9854)
▁der tensor(48.9668)
▁nichts tensor(48.8838)
▁zu tensor(48.8826)
▁wie tensor(48.8659)
▁kein tensor(48.8560)
▁man tensor(48.8392)
▁alles tensor(48.8315)
▁du tensor(48.8011)
▁für tensor(48.7561)
▁hier tensor(48.7476)
▁da tensor(48.7195)
▁sonst tensor(48.6750)
▁doch tensor(48.6585)
▁nun tensor(48.6465)
▁nie tensor(48.6216)
▁mit tensor(48.6142)
▁jetzt tensor(48.6032)
▁immer tensor(48.6012)
▁diese tensor(48.5855)
▁... tensor(48.5716)
▁wohl tensor(48.5687)
▁gar tensor(48.5074)
▁ja tensor(48.4864)
▁den tensor(48.4008

Predict the most probable word that is not part of the new embeddings:

In [122]:
predicted_token_ids = torch.argmax(predictions[0, masked_index])
vocab_size = len(tokenizer)
# Find the highest predicted token with an ID lower than 28997
for i in range(vocab_size):
    if predicted_token_ids < 28996:  # if no [PAD]: <= 28996
        break
    predicted_token_ids = torch.argsort(predictions[0, masked_index], descending=True)[i]

# Convert the predicted token ID back to a token
predicted_token = tokenizer.convert_ids_to_tokens([predicted_token_ids])[0]

print(predicted_token)

▁sie


Predict the top 5 words that are not part of the new embeddings:

In [123]:
# Get the predicted token IDs and their probabilities
predicted_token_probs = predictions[0, masked_index]
vocab_size = len(tokenizer)
# Create a list to store the top 5 predictions and their probabilities
top_5_predictions = []

# Find the top 5 predicted tokens with IDs lower than 28997
for i in range(vocab_size):
    if len(top_5_predictions) >= 5 or i >= vocab_size:
        break
    token_id = torch.argsort(predicted_token_probs, descending=True)[i].item()
    if token_id < 28996:  # if no [PAD]: <= 28996
        predicted_token = tokenizer.convert_ids_to_tokens([token_id])[0]
        top_5_predictions.append((predicted_token, predicted_token_probs[token_id].item()))

# Print the top 5 predictions and their probabilities
for token, prob in top_5_predictions:
    print(token, prob)

▁sie 52.33668899536133
▁das 51.55738830566406
▁es 51.08205795288086
▁auch 50.88592529296875
▁die 50.74215316772461


In [124]:
# TODO Generate instead of predicting!

In [127]:
test_text = "de_DE <s> Geschweige10 denn <mask>.</s>"

outputs = tokenizer(test_text, return_tensors="pt")
outputs = model.generate(outputs["input_ids"], max_length=100, num_return_sequences=5,
                         num_beams=20, output_scores=True, return_dict_in_generate=True)
output_strings = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)
output_probs = torch.exp(outputs.sequences_scores)

print([f'output: {output}, score: {score}'
       for output, score in zip(output_strings, output_probs)])

['output: Was ist denn das?.., score: 0.4156593382358551', 'output: Was ist denn das..., score: 0.39563310146331787', 'output: Was ist denn nun..., score: 0.3922627568244934', 'output: Was ist denn eigentlich.., score: 0.38697561621665955', 'output: ist es denn nicht..., score: 0.3854128420352936']
