In [1]:
from transformers import pipeline, AutoTokenizer
import pandas as pd
import json
from random import Random
from torch.utils.data import Dataset
from tqdm import tqdm
import transformers
import traceback
from Annotation import annotate
from copy import deepcopy

In [6]:
# Restrict the transformers library's logging to error-level messages.
transformers.logging.set_verbosity_error()

# Number of accepted continuations to collect for each experimental condition.
ITEMS_PER_CONDITION = 10

class PromptDataset(Dataset):
    """Thin ``Dataset`` wrapper around an indexable collection of prompts."""

    def __init__(self, prompts):
        # The collection is stored by reference; no copy is made.
        self.prompts = prompts

    def __len__(self):
        """Return the number of stored prompts."""
        return len(self.prompts)

    def __getitem__(self, idx):
        """Return the prompt stored at position ``idx``."""
        return self.prompts[idx]
        
def do_one_condition(model, row, forced_tokens):
    """Annotate ``row`` and return the list of collected rows.

    The body is still a placeholder (``...``): nothing is ever added to the
    result, so the returned list is empty, and ``model`` and ``forced_tokens``
    are not used yet.  A ``StopIteration`` escaping from the body is reported
    on stdout instead of being propagated.
    """
    collected = []
    try:
        ...
        annotate(row)
    except StopIteration:
        print("Reached the end!")
    return collected

# Models to evaluate, one tuple per model:
# (model name, batch_size, device, device mapping)
models = [
    ("stefan-it/german-gpt2-larger", 64, 0, None),
]

# Name inventory.  Each gender's entries are repeated once, so every name
# occurs twice in the resulting list.
with open("../items/names.json", encoding="utf-8") as nfile:
    namedict = json.load(nfile)
male_names = list(namedict["male"]) * 2
female_names = list(namedict["female"]) * 2

# Verb inventory, split into the "es" and "se" groups used by the conditions.
with open("../items/verbs_forced_reference.json", encoding="utf-8") as nfile:
    verbdict = json.load(nfile)
es_verbs = verbdict["es"]
se_verbs = verbdict["se"]

# Shuffle a copy of each name list with a fixed seed so that the pairings are
# reproducible across runs.
female_shuffled = list(female_names)
Random(42).shuffle(female_shuffled)
male_shuffled = list(male_names)
Random(84).shuffle(male_shuffled)

# (NP1, NP2, NP1-is-female) triples.  zip() stops at the shorter name list, so
# surplus names of the longer list are left unpaired.
male_pairing = [
    (np1, np2, False) for np1, np2 in zip(male_names, female_shuffled)
]
female_pairing = [
    (np1, np2, True) for np1, np2 in zip(female_names, male_shuffled)
]

# One tuple per experimental condition:
# (condition number, verb list, name pairing, forced referent)
conditions = [
    (2, es_verbs, female_pairing, "NP1"),
    (3, es_verbs, male_pairing, "NP1"),
    (6, se_verbs, female_pairing, "NP1"),
    (7, se_verbs, male_pairing, "NP1"),
    (10, es_verbs, female_pairing, "NP2"),
    (11, es_verbs, male_pairing, "NP2"),
    (14, se_verbs, female_pairing, "NP2"),
    (15, se_verbs, male_pairing, "NP2"),
]

# Candidate items: one shuffled list of row dicts per entry of `conditions`,
# in the same order as `conditions`.
items_per_condition = []

for condition, verbs, pairing, forced_reference in conditions:
    rows = []
    # The loop variable is deliberately NOT called `verbdict`: that name holds
    # the loaded verb inventory at module level and must not be rebound to a
    # single verb entry.
    for verb_entry in verbs:
        verb = verb_entry["verb"]
        filler = verb_entry["filler"]
        verbclass = verb_entry["verbclass"]
        for np1, np2, female in pairing:
            # Every prompt ends in ", weil" so the model continues with the
            # explanation clause.
            prompt = f"{np1} {verb} {np2}{filler}, weil"
            rows.append({
                "condition": condition,
                "type": "Experiment",
                "prompt": prompt,
                "NP1": np1,
                "NP2": np2,
                "NP1gender": "f" if female else "m",
                "verb": verb,
                "verbclass": verbclass,
                "forced": forced_reference,
            })
    # Fixed seed: the order in which candidates are tried is reproducible.
    Random(168).shuffle(rows)
    items_per_condition.append(rows)

# Run the forced-coreference experiment once per configured model and write
# one CSV file per model.
# NOTE: `batch_size` is part of every `models` tuple but is not used below;
# prompts are passed to the pipeline one at a time.
for model_name, batch_size, device, device_map in models:

    print(f"now loading: {model_name}")
    model = pipeline("text-generation", model=model_name, device=device, device_map=device_map)
    # Use the EOS token id as padding id and pad on the left.
    model.tokenizer.pad_token_id = model.model.config.eos_token_id
    model.tokenizer.padding_side = "left"

    # Token-id sequences of the masculine / feminine anaphoric forms.
    male_tokens = [model.tokenizer.encode(form) for form in (" er", " dieser", " jener", " der")]
    female_tokens = [model.tokenizer.encode(form) for form in (" sie", " diese", " jene", " die")]

    data = []

    # Each element of items_per_condition is the list of candidate rows of one
    # condition (it is not an (index, rows) pair).
    for condition_rows in items_per_condition:
        # Rows are mutated below, so work on a copy and keep the shared
        # candidates pristine for the next model.
        items = deepcopy(condition_rows)
        item_iter = iter(items)
        rows = []
        # The context manager closes the progress bar even if generation fails.
        with tqdm(total=ITEMS_PER_CONDITION) as bar:
            while len(rows) < ITEMS_PER_CONDITION:
                # Only exhaustion of the candidates ends the loop early; a
                # StopIteration raised inside the model or the annotator is
                # no longer mistaken for it.
                row = next(item_iter, None)
                if row is None:
                    condition_label = items[0]["condition"] if items else "<no items>"
                    print(f"Run out of data in condition {condition_label}")
                    break

                forced_name = row["NP1"] if row["forced"] == "NP1" else row["NP2"]
                tokenized_name = model.tokenizer.encode(f" {forced_name}")

                # The forced referent is male when NP1 is male and NP1 is
                # forced, or when NP1 is female and NP2 is forced.
                forced_is_male = (
                    (row["NP1gender"] == "m" and row["forced"] == "NP1")
                    or (row["NP1gender"] == "f" and row["forced"] == "NP2")
                )
                pronoun_tokens = male_tokens if forced_is_male else female_tokens
                # Every alternative must itself be a list of token ids, so the
                # tokenized name is appended as ONE element rather than being
                # spliced in as bare ints.
                forced_tokens = pronoun_tokens + [tokenized_name]

                # A list of alternatives inside force_words_ids requests a
                # disjunctive constraint (see the transformers generation docs).
                continuation = model(
                    row["prompt"],
                    force_words_ids=[forced_tokens],
                    remove_invalid_values=True,
                    early_stopping=True,
                    do_sample=False,
                    num_beams=10,
                    max_new_tokens=25,
                )[0]["generated_text"]

                # Drop the echoed prompt and the single character after it.
                generated = continuation[len(row["prompt"]) + 1:]
                # The output table reads the continuation from "cont".
                row["cont"] = generated
                # NOTE(review): "con" is kept as well because the key that
                # annotate() reads is not visible here - confirm and drop one.
                row["con"] = generated

                res = annotate(row, False)
                # Keep the item only if the annotated coreference is the
                # forced referent.
                if res["Koreferenz"] == row["forced"]:
                    row.update(res)
                    rows.append(row)
                    bar.update(1)
        data += rows

    exp3 = pd.DataFrame(data, columns = ["condition", "type", "prompt", "cont", "NP1", "NP2", "NP1gender", "verb", "verbclass", "forced", "Koreferenz", "Anaphorische Form"])
    exp3.to_csv(f"../data/forced_coreference--{model_name.replace('/', '--')}.csv", sep=";", index=False)

    # Drop the references to the model and the result table before the next
    # model is loaded.  `rows` / `items` are not deleted: they do not exist
    # when no condition was processed.
    del model
    del exp3
    

KeyError: 'verblass'