In [None]:
import itertools
import json
import pickle
import re

import numpy as np
import pandas as pd
import torch
from transformers import BertForMaskedLM, BertTokenizer

from tqdm.notebook import tqdm
# Register DataFrame.progress_apply (a tqdm-wrapped apply); it is used for the mask-completion pass.
tqdm.pandas()

In [None]:
with open("../../data/pseudowords/CoMaPP_all_bert.json") as json_file:
    raw_records = json.load(json_file)


def _query_with_pseudoword(record):
    # Replace the whitespace-separated word at position query_idx with the pseudoword label.
    words = record["query"].split()
    position = record["query_idx"]
    before = " ".join(words[:position])
    after = " ".join(words[position + 1:])
    return (before + " " + record["label"] + " " + after).strip()


# One record per (example sentence, query with pseudoword, pseudoword label).
data = [
    {
        "example": record["target1"],
        "query": _query_with_pseudoword(record),
        "pseudoword": record["label"],
    }
    for record in raw_records
]
df = pd.DataFrame.from_dict(data).drop_duplicates(ignore_index=True)
df

In [None]:
# The first run of digits in each pseudoword label is taken as its construction id and used as the index.
# The pattern is a raw string: '(\d+)' as a plain string is an invalid escape (SyntaxWarning on 3.12+).
construction_ids = df["pseudoword"].str.extract(r"(\d+)", expand=False).astype(int)
df = df.assign(index=construction_ids).set_index("index")

df

In [None]:
# Collect all examples and queries of each pseudoword, keyed by (construction, pseudoword).
# The construction id is the frame's index, named "index" by the previous cell. Renaming that
# index level leaves `df` untouched. The previous reset_index(inplace=True) + rename mutated
# `df`, so a second run added a duplicate "construction" column and groupby then failed.
result_df = (
    df.rename_axis("construction")
    .groupby(["construction", "pseudoword"])
    .agg({"example": list, "query": list})
)

result_df

In [None]:
definitions_path = "../../out/definitions.pickle"
with open(definitions_path, "rb") as definitions_file:
    # NOTE: pickle.load can execute arbitrary code; only load pickles produced by this project.
    definitions_by_key = pickle.load(definitions_file)

# Each key of the pickled mapping becomes a row label with a single "definition" column.
definitions = pd.DataFrame.from_dict(definitions_by_key, orient="index", columns=["definition"])

definitions

In [None]:
# Join each (construction, pseudoword) group with the definition indexed by its construction id.
# The join is inner: constructions without a definition (or definitions without examples) are dropped.
examples = result_df.merge(
    definitions,
    how="inner",
    left_on="construction",
    right_index=True,
)

examples

### Generating new sentences

In [None]:
# The learned pseudoword embeddings are stored in 15 files whose names carry a start_end range in steps of 37.
CHUNK_SIZE = 37
chunk_starts = range(0, 15 * CHUNK_SIZE, CHUNK_SIZE)
pseudowords = np.concatenate([
    np.load(f"../../data/pseudowords/bsbbert/pseudowords_comapp_bsbbert_{start}_{start + CHUNK_SIZE}.npy")
    for start in chunk_starts
])

pseudowords

In [None]:
# Token-order files 1..15: ";"-separated, "|"-quoted, no header. The "order" column becomes the index.
order_frames = [
    pd.read_csv(
        f"../../data/pseudowords/bsbbert/order_bsbbert_{part}.csv",
        sep=";",
        index_col=0,
        header=None,
        quotechar="|",
        names=["order", "label"],
    )
    for part in range(1, 16)
]
csv_data = pd.concat(order_frames)

csv_data

In [None]:
# Pseudoword token strings: the single "label" column ("order" is the index).
bert_tokens = csv_data["label"].tolist()
# Tokens are paired with embedding rows purely by position (rows are appended, then the tokens
# are added in this order), so the counts must agree. Whether the row ORDER also matches is
# assumed. TODO confirm against how the .npy/.csv files were written.
assert len(bert_tokens) == len(pseudowords), (
    f"{len(bert_tokens)} pseudoword tokens vs {len(pseudowords)} embedding rows"
)

bert_tokens, len(bert_tokens)

Load the vanilla German BERT model (`dbmdz/bert-base-german-cased`) and its tokenizer:

In [None]:
BERT_CHECKPOINT = 'dbmdz/bert-base-german-cased'
tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
model = BertForMaskedLM.from_pretrained(BERT_CHECKPOINT, return_dict=True)

# Show the original word-embedding layer (vocabulary size x hidden size) before it is extended.
model.bert.embeddings.word_embeddings

Append the learned pseudoword embeddings to the model's existing word-embedding matrix:

In [None]:
# Append the pseudoword embeddings as extra rows after the original vocabulary rows.
# NOTE: run once per model load. Re-running this cell would append the rows a second time.
original_weight = model.bert.embeddings.word_embeddings.weight
# np.load keeps whatever dtype the arrays were saved with (possibly float64). Cast to the
# embedding's own dtype/device so the extended matrix matches the rest of the model.
pseudoword_rows = torch.as_tensor(pseudowords, dtype=original_weight.dtype, device=original_weight.device)
combined_embeddings = torch.cat((original_weight.detach(), pseudoword_rows), dim=0)
# from_pretrained freezes the weights by default (inference only).
model.bert.embeddings.word_embeddings = torch.nn.Embedding.from_pretrained(combined_embeddings)
model.bert.embeddings.word_embeddings

Register the pseudoword tokens with the tokenizer so that their new IDs point at the appended embedding rows:

In [None]:
tokenizer.add_tokens(bert_tokens)
# The pseudoword rows sit at the end of the embedding matrix, so token i must map to row
# (num_embeddings - len(bert_tokens) + i). Anything else silently pairs tokens with the wrong
# vectors, e.g. duplicate/pre-existing tokens, a tokenizer vocab whose size differs from the
# model's, or the embedding cell having been run twice.
first_pseudoword_id = model.bert.embeddings.word_embeddings.num_embeddings - len(bert_tokens)
expected_ids = list(range(first_pseudoword_id, first_pseudoword_id + len(bert_tokens)))
assert tokenizer.convert_tokens_to_ids(bert_tokens) == expected_ids, (
    "pseudoword token ids do not match the appended embedding rows"
)
model.resize_token_embeddings(len(tokenizer))

In [None]:
# Run inference on the first CUDA device.
gpu = torch.device("cuda:0")
model.to(gpu)

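Complete the masks: for every query, fill `[MASK]` with the highest-scoring token from the original vocabulary (the added pseudoword tokens are excluded):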

In [None]:
def complete_masks(row, top_k=1):
    """Fill the ``[MASK]`` slot of every query in ``row`` with the model's best regular-vocabulary tokens.

    Each query string is tokenized, wrapped in ``[CLS]``/``[SEP]`` and run through the masked LM.
    Candidates at the mask position are visited from highest to lowest logit. Two kinds of token
    are skipped: ids at or above ``tokenizer.vocab_size`` (tokens added with ``add_tokens``, i.e.
    the pseudowords) and tokens containing ``"unused_"``. The first ``top_k`` remaining candidates
    are kept.

    Args:
        row: one row of ``examples.reset_index()``. It needs a list-valued ``"query"`` entry plus
            ``"construction"``, ``"pseudoword"``, ``"example"`` and ``"definition"``.
        top_k: number of replacement tokens kept per query (default 1).

    Returns:
        pd.Series with ``construction``, ``pseudoword``, ``example`` and ``definition`` copied from
        ``row``, plus:
        - ``generated``: one token list (including ``[CLS]``/``[SEP]``) per kept prediction.
        - ``score``: the matching raw logits (unnormalised, not probabilities).
        If anything fails for the row (e.g. a query without ``[MASK]``), a "." is printed and
        ``generated=[str(error)]``, ``score=[-1.0]`` are returned instead.

    Uses the module-level ``tokenizer`` and ``model``. Inputs are created on ``model.device``.
    """
    try:
        # Ids >= vocab_size belong to tokens added after loading (the pseudowords).
        base_vocab_size = tokenizer.vocab_size
        output_texts = []
        scores = []
        # "query" and "example" are parallel lists from the same groupby; only the queries are needed.
        for query in row["query"]:
            tokenized_query = ["[CLS]"] + tokenizer.tokenize(query) + ["[SEP]"]  # adding start and end of sequence
            masked_index = tokenized_query.index("[MASK]")
            input_ids = torch.tensor([tokenizer.convert_tokens_to_ids(tokenized_query)], device=model.device)

            with torch.no_grad():
                mask_logits = model(input_ids).logits[0, masked_index]

            # Rank the vocabulary once per query, then walk the ranking until top_k tokens are kept.
            kept = 0
            for token_id in torch.argsort(mask_logits, descending=True).tolist():
                if kept >= top_k:
                    break
                if token_id >= base_vocab_size:
                    continue
                predicted_token = tokenizer.convert_ids_to_tokens([token_id])[0]
                if "unused_" in predicted_token:  # unused_token, unused_punctuation placeholders
                    continue
                kept += 1
                output_texts.append(
                    tokenized_query[:masked_index] + [predicted_token] + tokenized_query[masked_index + 1:]
                )
                scores.append(mask_logits[token_id].item())

        return pd.Series({'construction': row['construction'], 'pseudoword': row['pseudoword'], 'example': row['example'], 'generated': output_texts, 'score': [float(score) for score in scores], 'definition': row['definition']})
    except Exception as e:
        # Best effort per row: mark the failure and keep the error text in place of the output.
        print(".", end="")
        return pd.Series({'construction': row['construction'], 'pseudoword': row['pseudoword'], 'example': row['example'], 'generated': [str(e)], 'score': [-1.0], 'definition': row['definition']})

# Move the (construction, pseudoword) index back into columns so complete_masks can read them per row.
examples_reset = examples.reset_index()

# One output row per pseudoword; progress_apply is DataFrame.apply with a tqdm progress bar.
pseudoword_output_scores = examples_reset.progress_apply(
    complete_masks,
    axis=1,
)

pseudoword_output_scores

In [None]:
# Slim view for export: pseudoword, generated token lists and their scores.
# NOTE: this rebinds `examples`. Re-running the mask-completion cell afterwards would feed it this
# slim frame (no "query"/"construction"/"definition" columns) and fail; rebuild `examples` first.
export_columns = ["pseudoword", "generated", "score"]
examples = pseudoword_output_scores.loc[:, export_columns]

examples

In [None]:
# Compact export of the generated sentences as TSV and Excel.
# NOTE(review): decimal="," only reformats float columns. "score" holds lists, which are
# written in their string form with "." decimals.
tsv_path = "../../out/comapp/data_bsbbert.tsv"
xlsx_path = "../../out/comapp/data_bsbbert.xlsx"
examples.to_csv(tsv_path, sep="\t", decimal=",")
examples.to_excel(xlsx_path)

In [None]:
# Full export: every column of the mask-completion output, tab-separated.
# NOTE(review): decimal="," has no effect on the list-valued "score"/"generated" columns.
complete_output_path = "../../out/comapp/data_bsbbert_complete.tsv"
pseudoword_output_scores.to_csv(complete_output_path, sep="\t", decimal=",")