In [None]:
!pip3 install transformers==4.33.2
!pip3 install optimum==1.13.2
!pip3 install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/

In [ ]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BertForMaskedLM, BertTokenizer
import pickle
import re
import os
import itertools
import json
import numpy as np
import pandas as pd
import torch
from tqdm.notebook import tqdm

# Register tqdm's `progress_*` methods (e.g. `DataFrame.progress_apply`) on pandas objects.
tqdm.pandas()

### Models
First, let us load a Llama 2 model (a GPTQ-quantized German assistant checkpoint):

In [ ]:
# Hugging Face Hub repository of the GPTQ-quantized German Llama-2 assistant.
llama_name_or_path = 'TheBloke/Llama-2-13B-German-Assistant-v4-GPTQ'

# Load the weights from the requested repository revision; `device_map='auto'`
# leaves the placement of the model on the available devices to the library.
llama = AutoModelForCausalLM.from_pretrained(
    llama_name_or_path,
    revision="gptq-4bit-32g-actorder_True",
    trust_remote_code=False,
    device_map='auto',
)

# The matching tokenizer comes from the same repository (default revision).
llama_tokenizer = AutoTokenizer.from_pretrained(llama_name_or_path)

Then, we also need a BERT model:

In [ ]:
# German cased BERT: tokenizer first, then the model with its masked-LM head.
bert_base_tokenizer = BertTokenizer.from_pretrained('dbmdz/bert-base-german-cased')
bert_base = BertForMaskedLM.from_pretrained(
    'dbmdz/bert-base-german-cased',
    return_dict=True,
)

# Show the input word-embedding layer as the cell's output.
bert_base.bert.embeddings.word_embeddings

### Data
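Now, we load the data we need. First, we need the pseudoword queries: for every CoMaPP entry, the example sentence and the query with the token at `query_idx` replaced by the pseudoword label.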

In [ ]:
def _insert_pseudoword(query: str, query_idx: int, label: str) -> str:
    """Return `query` with the whitespace token at `query_idx` replaced by `label`.

    The query is split on whitespace and re-joined with single spaces, so runs
    of whitespace are normalised. An index past the last token appends `label`
    at the end of the query.
    """
    tokens = query.split()
    return " ".join(tokens[:query_idx] + [label] + tokens[query_idx + 1:]).strip()


# JSON text is UTF-8; pass the encoding explicitly so non-ASCII characters
# (e.g. German umlauts) do not depend on the platform's locale default.
with open("../../data/pseudowords/CoMaPP_all_bert.json", encoding="utf-8") as json_file:
    data = json.load(json_file)

# Reduce every raw entry to the three fields used below: the example
# (`target1`), the query with the pseudoword inserted, and the pseudoword itself.
data = [
    {
        "example": d["target1"],
        "query": _insert_pseudoword(d["query"], d["query_idx"], d["label"]),
        "pseudoword": d["label"],
    }
    for d in data
]

# Identical (example, query, pseudoword) rows are collapsed into one.
df = pd.DataFrame.from_dict(data).drop_duplicates(ignore_index=True)
df

In [ ]:
# NOTE: unpickling can execute arbitrary code — only load pickle files that
# were produced by a trusted run of this project.

# Definitions: indexed by key in the generation loop (`definitions[k]`).
with open("definitions.pickle", "rb") as definitions_file:
    definitions = pickle.load(definitions_file)

# Sentences: looked up in the generation loop via `sentences[int(k)]`.
with open("sentences.pickle", "rb") as sentences_file:
    sentences = pickle.load(sentences_file)

In [ ]:
# A candidate containing any of these substrings is rejected and generation is
# retried (presumably they indicate meta-talk or unfilled placeholders — confirm).
_REJECTED_SUBSTRINGS = ("Konstruktion", "Satz", "Überschrift", "_", ":", "XY", "XP", "X ", "Y ", "X.", "Y.")


def _is_rejected(candidate: str) -> bool:
    """Return True if `candidate` must be discarded and generation retried."""
    return (
        not any(ch.isalpha() for ch in candidate)  # contains no letter at all
        or any(marker in candidate for marker in _REJECTED_SUBSTRINGS)
        # a "]" followed later by a "[": more than one bracketed group was captured
        or re.search(r".*\].*\[.*", candidate) is not None
    )


def generate_examples(definition: str, sentence: str, temperature=0.75, top_p=0.95, top_k=1,
                      max_new_tokens=1024, max_attempts=100) -> str:
    """Sample new example sentences for `definition` from the Llama model.

    The model is prompted (in German) to output a Python list of sentences that
    follow `definition`; a non-empty `sentence` is added to the prompt as an
    example. Sampling is repeated until the first bracketed group `[...]` found
    in the generated text passes the checks in `_is_rejected`.

    Args:
        definition: Definition text inserted into the prompt.
        sentence: Example text for the prompt; an empty value omits the
            "### Example" line.
        temperature, top_p, top_k, max_new_tokens: Passed to `llama.generate`
            (sampling is always enabled).
            NOTE(review): with the default `top_k=1` only the most likely token
            should be eligible at each step, so retries would likely reproduce
            the same text — confirm that callers pass a larger value.
        max_attempts: Maximum number of generation attempts (default 100).

    Returns:
        The first accepted `[...]` substring as a string, or "[]" if no attempt
        produced an acceptable one.
    """
    # The prompt lines are joined with a newline followed by eight spaces. This
    # reproduces the original indented triple-quoted prompt byte for byte, so the
    # model sees exactly the same text as in earlier runs.
    line_break = "\n" + " " * 8
    instruction = (
        f"### User: Du bist kreativ und gewissenhaft. Hier ist eine Definition: {definition} "
        "Bilde neue Sätze gemäß dieser Definition. Gib die Sätze in einer Python-Liste aus. "
        "Gib sonst nichts aus."
    )
    prompt_lines = [instruction]
    if sentence:
        prompt_lines.append(f"### Example: {sentence}")
    # The assistant turn is primed with the opening of a Python list.
    prompt_lines.append('### Assistant:["')
    prompt = line_break.join(prompt_lines)
    prompt_length = len(prompt)

    # The prompt never changes between attempts, so tokenize it only once.
    input_ids = llama_tokenizer(prompt, return_tensors='pt').input_ids.cuda()

    first_output = "[]"  # fallback when every attempt is rejected
    for attempts_left in range(max_attempts - 1, -1, -1):
        generated = llama.generate(inputs=input_ids, temperature=temperature,
                                   do_sample=True, top_p=top_p, top_k=top_k,
                                   max_new_tokens=max_new_tokens)
        # NOTE(review): the prompt is cut off by *character* count, which assumes
        # the decoded text starts with the prompt exactly as written; special
        # tokens added by the tokenizer would shift this offset — verify.
        text = llama_tokenizer.decode(generated[0])[prompt_length:].strip()
        candidates = re.findall(r'\[.*\]', text)
        print(attempts_left, end=" ")  # progress: attempts remaining
        # Accept as soon as a candidate is valid — including on the final
        # attempt, which was previously always replaced by "[]".
        if candidates and not _is_rejected(candidates[0]):
            first_output = candidates[0]
            break
    print()
    return first_output

In [ ]:
def generate_and_check_examples(definition, sentence, temperature=0.75, max_new_tokens=1000, top_k=100, top_p=0.99):
    """Generate example sentences for `definition` and check them.

    NOTE(review): this cell contained only the signature — a `def` without a
    body, which is a SyntaxError — so the original generation/checking logic is
    not available here and must be restored before running the loop below.
    Until then the function fails loudly instead of returning a guessed result.

    Raises:
        NotImplementedError: always.
    """
    raise NotImplementedError(
        "generate_and_check_examples has no implementation in this notebook; "
        "restore the original cell body before generating examples."
    )

In [ ]:
# Generate examples for every definition, checkpointing after each one so an
# interrupted run can resume where it stopped.
for shot in range(1, 2):
    # One cache file per shot count. It is used for BOTH resuming and saving;
    # previously progress was read from this file but written to
    # "examples_{shot}_shot.pickle", so a restart never saw the saved results.
    examples_path = f"examples_{shot}_shot_plus_bert.pickle"

    examples = {}
    if os.path.exists(examples_path):
        # NOTE: unpickling can execute arbitrary code — only resume from cache
        # files written by this notebook.
        with open(examples_path, "rb") as file:
            examples = pickle.load(file)

    for k in tqdm(definitions.keys()):
        if k in examples:
            continue  # have already generated examples for this construction

        definition = definitions[k]
        try:
            # String representation of the first `shot` sentences stored for this key.
            sentence = str(list(sentences[int(k)])[0:shot])
        except KeyError:
            # No sentences stored under this key: record an empty placeholder.
            print(("[]", "[]", "This seems wrong..."))
            examples[k] = ("[]", "[]")
            continue

        example = generate_and_check_examples(
            definition, sentence, temperature=0.75,
            max_new_tokens=1000, top_k=100, top_p=0.99
        )

        print((sentence, example))
        examples[k] = (sentence, example)

        # Checkpoint after every generated entry.
        with open(examples_path, "wb") as file:
            pickle.dump(examples, file)