In [None]:
import itertools
import pickle
import random
import re

import numpy as np
import pandas as pd
import torch
from transformers import BertForNextSentencePrediction, BertTokenizer

from tqdm.notebook import tqdm
tqdm.pandas()

In [None]:
# Pseudoword embedding shards: 15 .npy files whose names encode the ranges
# 0-37, 37-74, ..., 518-555. They are stacked in that order.
pseudoword_shards = [
    np.load(f"../../data/pseudowords/bert/pseudowords_comapp_bert_{i*37}_{i*37+37}.npy")
    for i in range(15)
]
pseudowords = np.concatenate(pseudoword_shards)

# Matching "order" files (numbered 1..15): ';'-separated, '|'-quoted, no header.
# The first column becomes the index ("order"), leaving "label" as the only data column.
order_frames = [
    pd.read_csv(
        f"../../data/pseudowords/bert/order_bert_{i}.csv",
        sep=";",
        index_col=0,
        header=None,
        quotechar="|",
        names=["order", "label"],
    )
    for i in range(1, 16)
]
csv_data = pd.concat(order_frames)

# First (and only) data column of every row, i.e. the pseudoword labels in file order.
bert_tokens = [row[0] for row in csv_data.values]
bert_tokens

In [None]:
# German BERT with its next-sentence-prediction head, plus the matching tokenizer.
model = BertForNextSentencePrediction.from_pretrained("bert-base-german-cased", return_dict=True)
tokenizer = BertTokenizer.from_pretrained('bert-base-german-cased')

# Append the pseudoword vectors below the pretrained word-embedding rows.
# The pseudowords are cast to the dtype of the pretrained matrix so that a float64
# .npy file cannot change the dtype of the whole embedding table.
base_weight = model.bert.embeddings.word_embeddings.weight.detach()
pseudoword_weight = torch.tensor(pseudowords, dtype=base_weight.dtype)
combined_embeddings = torch.cat((base_weight, pseudoword_weight), dim=0)
model.bert.embeddings.word_embeddings = torch.nn.Embedding.from_pretrained(combined_embeddings)

# Added tokens get ids after the existing vocabulary, in the order given. A label that
# is skipped (duplicate, or already in the vocabulary) would shift every later
# pseudoword onto the wrong embedding row, so fail loudly instead of mis-aligning.
num_added = tokenizer.add_tokens(bert_tokens)
if num_added != len(pseudoword_weight):
    raise ValueError(
        f"tokenizer accepted {num_added} new tokens but {len(pseudoword_weight)} "
        "pseudoword embeddings were appended; token ids and embedding rows would not line up"
    )
model.resize_token_embeddings(len(tokenizer))

In [None]:
# Load the pickled definitions and example sentences produced elsewhere in the project.
# pickle.load executes arbitrary code from the file, so these must be trusted, locally
# generated artefacts.
pickle_paths = ("../../out/definitions.pickle", "../../out/sentences.pickle")
loaded = []
for pickle_path in pickle_paths:
    with open(pickle_path, "rb") as handle:
        loaded.append(pickle.load(handle))
definitions, sentences = loaded

In [None]:
def find_examples(definition, examples):
    """Pick the example that the NSP head ranks highest as a continuation of the definition.

    Each candidate is scored by feeding ``definition + " Zum Beispiel: "`` as the
    first segment and the candidate as the second segment to the module-level
    ``tokenizer`` / ``model`` pair.

    Args:
        definition: Text of the construction definition (str).
        examples: Indexable, non-empty collection of candidate sentences (str).

    Returns:
        The element of ``examples`` with the highest score; on ties, the first one.

    Raises:
        ValueError: If ``examples`` is empty.
    """
    if not examples:
        raise ValueError("find_examples() needs at least one candidate example")

    connective = " Zum Beispiel: "
    ellipsis = "…"
    max_length = 512

    scores = []
    for example in examples:
        if len(definition) + len(connective) + len(example) > max_length:
            # Characters left for the definition once the ellipsis, the connective and
            # the example are accounted for. Clamped at zero: a negative bound would
            # slice from the END of the definition and keep almost all of it.
            budget = max(0, max_length - len(ellipsis) - len(connective) - len(example))
            prompt = definition[:budget] + ellipsis + connective
        else:
            prompt = definition + connective  # TODO German

        # The budget above is counted in characters, which is only a proxy for the
        # model's limit in tokens; truncation enforces the limit on the token level.
        inputs = tokenizer(prompt, example, return_tensors="pt", truncation=True, max_length=max_length)
        with torch.no_grad():
            outputs = model(**inputs)
        # Raw logit of class 0 ("second segment continues the first"). This is not a
        # probability; NOTE(review): ranking by softmax probability could order
        # candidates differently because the class-1 logit is ignored here.
        scores.append(outputs.logits[0, 0].item())

    best = max(range(len(scores)), key=scores.__getitem__)
    return examples[best]

In [None]:
# 1-vs-N experiment: for every construction, draw one of its own sentences plus
# `false_positives` sentences from other constructions and let find_examples pick.
random.seed(15)
attempts = 10
for false_positives in range(2, 8):
    result = []
    for key, definition in tqdm(definitions.items()):
        construction_id = int(key)

        # Everything that depends only on the construction is prepared once here
        # instead of once per attempt; the order of RNG calls below is unchanged.
        try:
            own_sentences = list(sentences[construction_id])
        except KeyError:
            # No sentences for this construction: one placeholder row per attempt.
            result.extend(
                pd.Series({"constr": key, "definition": definition, "example": None, "prediction": None, "correct": None})
                for _ in range(attempts)
            )
            continue
        distractor_pool = list(itertools.chain.from_iterable(
            sentence_list for constr, sentence_list in sentences.items() if int(constr) != construction_id
        ))

        for attempt in range(attempts):
            sentence = random.choice(own_sentences)
            # NOTE(review): random.choices samples WITH replacement and the set union
            # removes duplicates, so a trial can hold fewer than `false_positives`
            # distractors (also when a distractor equals the target sentence).
            others = random.choices(distractor_pool, k=false_positives)
            examples = list(set(others) | {sentence})
            prediction = find_examples(definition, examples)
            result.append(pd.Series({"constr": key, "definition": definition, "example": sentence, "prediction": prediction, "correct": prediction == sentence}))
    result = pd.DataFrame(result)
    result.to_csv(f"../../out/comapp/result_1_vs_{false_positives}_{attempts}attempts_bert.tsv", sep="\t")

In [None]:
# Map each construction number to the set of pseudoword labels that carry it.
# The construction number is the first run of digits found in the label.
kelex_frame = csv_data.copy()
# Raw string: '\d' in a normal string literal is an invalid escape sequence.
kelex_frame['constr'] = csv_data['label'].str.extract(r'(\d+)', expand=False).astype(int)
kelex = kelex_frame.groupby('constr')['label'].apply(set).to_dict()
kelex

In [32]:
# 1-vs-N experiment with pseudowords: tokens of the target sentence that match a
# pseudoword's non-digit prefix are replaced by the pseudoword label. Constructions
# without any pseudoword are skipped.
random.seed(15)
attempts = 10
for false_positives in range(2, 8):
    result = []
    for key, definition in tqdm(definitions.items()):
        construction_id = int(key)

        try:
            own_sentences = list(sentences[construction_id])
        except KeyError:
            # No sentences for this construction: one placeholder row per attempt
            # (these rows carry no "example_kelex" entry).
            result.extend(
                pd.Series({"constr": key, "definition": definition, "example": None, "prediction": None, "correct": None})
                for _ in range(attempts)
            )
            continue

        # kelex is keyed by integer construction numbers, so the key is normalised
        # with int() exactly like the lookups into `sentences`.
        pseudoword_labels = kelex.get(construction_id)
        has_kelex = bool(pseudoword_labels)

        # surface form (first non-digit run of the label) -> pseudoword label.
        # setdefault keeps the first label seen for a surface form, matching a
        # first-match scan over the same set.
        surface_to_pseudoword = {}
        for pseudoword in pseudoword_labels or ():
            non_digit_runs = re.findall(r'\D+', pseudoword)
            if non_digit_runs:
                surface_to_pseudoword.setdefault(non_digit_runs[0], pseudoword)

        # The distractor pool depends only on the construction; build it once, and
        # only when the construction is actually evaluated.
        if has_kelex:
            distractor_pool = list(itertools.chain.from_iterable(
                sentence_list for constr, sentence_list in sentences.items() if int(constr) != construction_id
            ))

        for attempt in range(attempts):
            # Drawn before the skip so the RNG stream does not depend on which
            # constructions are skipped.
            sentence = random.choice(own_sentences)
            if not has_kelex:
                continue  # skip constructions without kelex
            sentence_kelex = " ".join(surface_to_pseudoword.get(token, token) for token in sentence.split())
            # NOTE(review): sampling is with replacement and the set union removes
            # duplicates, so fewer than `false_positives` distractors are possible.
            others = random.choices(distractor_pool, k=false_positives)
            examples = list(set(others) | {sentence_kelex})
            prediction = find_examples(definition, examples)
            result.append(pd.Series({"constr": key, "definition": definition, "example": sentence, "example_kelex": sentence_kelex, "prediction": prediction, "correct": prediction == sentence_kelex}))
    result = pd.DataFrame(result)
    result.to_csv(f"../../out/comapp/result_1_vs_{false_positives}_kelex_{attempts}attempts_bert.tsv", sep="\t")

  0%|          | 0/211 [00:00<?, ?it/s]

  0%|          | 0/211 [00:00<?, ?it/s]

  0%|          | 0/211 [00:00<?, ?it/s]

  0%|          | 0/211 [00:00<?, ?it/s]

  0%|          | 0/211 [00:00<?, ?it/s]

  0%|          | 0/211 [00:00<?, ?it/s]

In [40]:
# 1-vs-N experiment with pseudowords, all constructions: tokens of the target sentence
# that match a pseudoword's non-digit prefix are replaced by the pseudoword label.
# Constructions without any pseudoword are evaluated with the unchanged sentence.
random.seed(15)
attempts = 10
for false_positives in range(2, 8):
    result = []
    for key, definition in tqdm(definitions.items()):
        construction_id = int(key)

        try:
            own_sentences = list(sentences[construction_id])
        except KeyError:
            # No sentences for this construction: one placeholder row per attempt
            # (these rows carry no "example_kelex" entry).
            result.extend(
                pd.Series({"constr": key, "definition": definition, "example": None, "prediction": None, "correct": None})
                for _ in range(attempts)
            )
            continue

        # kelex is keyed by integer construction numbers, so the key is normalised
        # with int() exactly like the lookups into `sentences`.
        pseudoword_labels = kelex.get(construction_id)

        # surface form (first non-digit run of the label) -> pseudoword label.
        # setdefault keeps the first label seen for a surface form, matching a
        # first-match scan over the same set.
        surface_to_pseudoword = {}
        for pseudoword in pseudoword_labels or ():
            non_digit_runs = re.findall(r'\D+', pseudoword)
            if non_digit_runs:
                surface_to_pseudoword.setdefault(non_digit_runs[0], pseudoword)

        # The distractor pool depends only on the construction; build it once.
        distractor_pool = list(itertools.chain.from_iterable(
            sentence_list for constr, sentence_list in sentences.items() if int(constr) != construction_id
        ))

        for attempt in range(attempts):
            sentence = random.choice(own_sentences)
            if pseudoword_labels:
                sentence_kelex = " ".join(surface_to_pseudoword.get(token, token) for token in sentence.split())
            else:
                # No pseudowords for this construction: evaluate the plain sentence.
                sentence_kelex = sentence
            # NOTE(review): sampling is with replacement and the set union removes
            # duplicates, so fewer than `false_positives` distractors are possible.
            others = random.choices(distractor_pool, k=false_positives)
            examples = list(set(others) | {sentence_kelex})
            prediction = find_examples(definition, examples)
            result.append(pd.Series({"constr": key, "definition": definition, "example": sentence, "example_kelex": sentence_kelex, "prediction": prediction, "correct": prediction == sentence_kelex}))
    result = pd.DataFrame(result)
    result.to_csv(f"../../out/comapp/result_1_vs_{false_positives}_kelex_all_{attempts}attempts_bert.tsv", sep="\t")

  0%|          | 0/211 [00:00<?, ?it/s]

  0%|          | 0/211 [00:00<?, ?it/s]

  0%|          | 0/211 [00:00<?, ?it/s]

  0%|          | 0/211 [00:00<?, ?it/s]

  0%|          | 0/211 [00:00<?, ?it/s]

  0%|          | 0/211 [00:00<?, ?it/s]