In [None]:
import json
import pickle 
import pandas as pd
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast

from tqdm.notebook import tqdm
tqdm.pandas()

In [None]:
# Load the CoMaPP annotations. JSON text is UTF-8 and the examples are German;
# without an explicit encoding, open() falls back to the platform/locale default.
with open("../../data/pseudowords/CoMaPP_all.json", encoding="utf-8") as json_file:
    data = json.load(json_file)

# Keep only entries whose token at `query_idx` occurs in the pseudoword label.
# The cue is every token up to AND including that token (query_idx + 1): the
# ke-lex is kept in the cue this time instead of being replaced.
data = [
    {
        "example": d["target1"],
        "cue": " ".join(d["target1"].split()[: d["query_idx"] + 1]),
        "pseudoword": d["label"],
    }
    for d in data
    if d["target1"].split()[d["query_idx"]] in d["label"]
]
df = pd.DataFrame.from_dict(data).drop_duplicates(ignore_index=True)
df

In [None]:
# The construction id is the first run of digits in the pseudoword label.
# Raw string: "\d" in a plain literal is an invalid escape sequence.
# .astype(int) raises if a label contains no digits (NaN cannot become int).
df['index'] = df['pseudoword'].str.extract(r'(\d+)', expand=False).astype(int)
df.set_index('index', inplace=True)

df

In [None]:
# Left-context table; rows are matched on "construction_id" and "text", and the
# "contextleft" column holds the text to prepend.
# NOTE: unpickling can execute code from the file -- only load trusted data.
contextleft = pd.read_pickle(filepath_or_buffer="../../data/pseudowords/contextleft_text.pickle")

def update_cue(row, context_df=None):
    """Prepend the left context to a row's example and cue when the cue is a single token.

    A one-token cue contains only the ke-lex, so there is no context before the
    pseudoword. In that case the row is looked up in ``context_df`` by
    construction id (``row.name``) and example text, and the first matching
    left context is prepended to both the example and the cue.

    Parameters
    ----------
    row : pandas.Series
        Must provide "example" and "cue"; ``row.name`` is the construction id.
    context_df : pandas.DataFrame, optional
        Columns "construction_id", "text" and "contextleft". Defaults to the
        notebook-level ``contextleft`` table.

    Returns
    -------
    pandas.Series
        Always indexed ["example", "cue"], so ``DataFrame.apply(update_cue, axis=1)``
        consistently expands to two columns. (Returning a plain list for some
        rows made apply produce a Series of lists when the first row matched.)
    """
    if context_df is None:
        context_df = contextleft

    example = row["example"]
    cue = row["cue"]
    if len(cue.split()) == 1:
        matches = context_df.loc[
            (context_df["construction_id"] == row.name) & (context_df["text"] == example),
            "contextleft",
        ].tolist()
        if matches:
            example = matches[0] + " " + example
            cue = matches[0] + " " + cue
    return pd.Series({"example": example, "cue": cue}, name=row.name)

# Add the left context if there is no cue up until the pseudoword.
# result_type="expand" always yields a two-column frame, even if update_cue
# returns a list rather than a Series for some rows; without it, a list on the
# first row makes apply return a Series of lists and this assignment fails.
df[["example", "cue"]] = df.apply(update_cue, axis=1, result_type="expand")
df

In [None]:
# Turn the construction-id index back into a regular column named "construction".
df = df.reset_index().rename(columns={"index": "construction"})

# Collect all examples and cues of each (construction, pseudoword) pair into lists.
result_df = df.groupby(["construction", "pseudoword"]).agg(
    {"example": list, "cue": list}
)

result_df

In [None]:
# The pickle holds a mapping whose keys become the index and whose values fill a
# single "definition" column (presumably keyed by construction id; it is joined
# on "construction" below -- confirm).
# NOTE: pickle.load can execute arbitrary code; only load trusted files.
with open("../../out/definitions.pickle", "rb") as definitions_file:
    definitions = pd.DataFrame.from_dict(
        data=pickle.load(definitions_file),
        orient="index",
        columns=["definition"],
    )

definitions

In [None]:
# Attach each construction's definition; the inner join drops constructions without one.
examples = result_df.merge(definitions, how="inner", left_on="construction", right_index=True)
examples

### Generation of new sentences:

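Load the vanilla (pretrained, not fine-tuned) mBART-50 model and its tokenizer, with German (`de_DE`) as both source and target language: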

In [None]:
# Pretrained mBART-50 checkpoint; German on both the source and the target side.
model = MBartForConditionalGeneration.from_pretrained("facebook/mbart-large-50", return_dict=True)
tokenizer = MBart50TokenizerFast.from_pretrained("facebook/mbart-large-50", src_lang="de_DE", tgt_lang="de_DE")
model = model.to("cuda:1")  # nn.Module.to returns the module itself

# Displayed as the cell output: the encoder's token-embedding layer.
model.model.encoder.embed_tokens

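Complete the cues: for each cue, mBART generates a sentence from the input `</s> <cue> <mask> <last two words of the example> </s> de_DE`: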

In [None]:
def _first_value(value):
    """Return ``value`` itself, or its first element if it is a pandas Series.

    When a frame has the same column label twice, ``row[label]`` inside
    ``apply(axis=1)`` is a Series rather than a scalar.
    """
    return value.iloc[0] if isinstance(value, pd.Series) else value


def _complete_single_cue(cue, example, mt_model, mt_tokenizer):
    """Generate one completion for ``cue`` and return ``(text, score)``.

    The model input is ``"</s> <cue> <mask> <last two words of example> </s> de_DE"``,
    tokenized without adding further special tokens. ``score`` is
    ``exp(sequences_scores)`` of the single beam-search result, as a float.
    """
    input_text = "</s> " + cue + " <mask> " + " ".join(example.split()[-2:]) + " </s> de_DE"

    # max_length counts tokens, but this bound is 1.5x the *character* length
    # of the original example, i.e. a generous upper limit.
    target_length = int(1.5 * len(example))

    # Put the inputs on whatever device the model lives on (no hard-coded GPU).
    input_ids = mt_tokenizer([input_text], add_special_tokens=False, return_tensors="pt")["input_ids"].to(mt_model.device)
    outputs = mt_model.generate(
        input_ids,
        max_length=target_length,
        num_return_sequences=1,
        num_beams=20,
        output_scores=True,
        return_dict_in_generate=True,
    )
    output_text = mt_tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    score = float(torch.exp(outputs.sequences_scores))
    # Collapse runs of whitespace left over after decoding.
    return " ".join(output_text[0].split()), score


def complete_cues(row, mt_model=None, mt_tokenizer=None):
    """Complete every cue of one (construction, pseudoword) row with mBART.

    Parameters
    ----------
    row : pandas.Series
        Needs "construction", "pseudoword" (a scalar, or a Series when the
        column was selected twice -- its first value is used), and the
        parallel lists "example" and "cue".
    mt_model, mt_tokenizer : optional
        Model and tokenizer to use; default to the notebook-level ``model``
        and ``tokenizer``.

    Returns
    -------
    pandas.Series
        Keys construction, pseudoword, orig_example, generated, scores.
        "generated" and "scores" are the ``str()`` of the completion list and
        of the score list. If any step raises, "generated" holds the error
        message and "scores" is "[-1.0]" (best effort: one failing row does not
        abort the whole apply).

    Prints "." for each successful row and ":" for each failed row.
    """
    pseudoword = _first_value(row["pseudoword"])
    try:
        # Resolved inside the try so a missing global is reported like any other failure.
        if mt_model is None:
            mt_model = model
        if mt_tokenizer is None:
            mt_tokenizer = tokenizer

        output_texts = []
        scores = []
        for cue, example in zip(row["cue"], row["example"]):
            text, score = _complete_single_cue(cue, example, mt_model, mt_tokenizer)
            output_texts.append(text)
            scores.append(score)
        print(".", end="")
        return pd.Series({"construction": row["construction"], "pseudoword": pseudoword, "orig_example": row["example"], "generated": str(output_texts), "scores": str(scores)})
    except Exception as e:
        print(":", end="")
        return pd.Series({"construction": row["construction"], "pseudoword": pseudoword, "orig_example": row["example"], "generated": str(e), "scores": "[-1.0]"})

# Move the (construction, pseudoword) index back into ordinary columns.
examples_reset = examples.reset_index()

# "pseudoword" appears twice in this column selection, so inside complete_cues
# row["pseudoword"] holds two entries; complete_cues takes the first one.
pseudoword_output_scores = examples_reset.loc[
    :, ["construction", "pseudoword", "example", "cue", "pseudoword"]
].progress_apply(complete_cues, axis=1)

pseudoword_output_scores

In [None]:
# Reduced view for export: pseudoword plus its generated completions and scores.
# NOTE(review): this rebinds `examples` (previously the merged examples/definitions
# frame), so re-running the generation cell after this one raises a KeyError.
examples = pseudoword_output_scores.loc[:, ["pseudoword", "generated", "scores"]]

examples

In [None]:
# Reduced export as tab-separated values with "," as the decimal mark.
examples.to_csv("../../out/comapp/mbart/data_mbart_vanilla.tsv", sep="\t", decimal=",")
# Optional Excel export of the same table:
#examples.to_excel("../../out/comapp/mbart/data_mbart_vanilla.xlsx")

In [None]:
# Complete export: construction id, pseudoword, original examples, completions, scores.
pseudoword_output_scores.to_csv(
    "../../out/comapp/mbart/data_mbart_vanilla_complete.tsv",
    sep="\t",
    decimal=",",
)