In [ ]:
!pip3 install transformers==4.33.2
!pip3 install optimum==1.13.2
!pip3 install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/

!nvidia-smi

In [1]:
import pickle
import random
import itertools
import re

from tqdm.notebook import tqdm, trange
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import pandas as pd

In [ ]:
# Alternative model that was also tried: "TheBloke/Llama-2-13B-chat-GPTQ"
model_name_or_path = 'TheBloke/Llama-2-13B-German-Assistant-v4-GPTQ'

# Load the 4-bit GPTQ weights from a specific repo branch; device_map='auto' lets
# accelerate place the layers on the available device(s).
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    revision="gptq-4bit-32g-actorder_True",
    device_map='auto',
    trust_remote_code=False,
)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)

In [3]:
# Load the definitions and example sentences produced earlier in the pipeline.
# NOTE: pickle.load can execute arbitrary code -- only load files this project wrote itself.
with open("../../out/definitions.pickle", mode="rb") as definitions_fh:
    definitions = pickle.load(definitions_fh)

with open("../../out/sentences.pickle", mode="rb") as sentences_fh:
    sentences = pickle.load(sentences_fh)

In [ ]:
def find_examples(definition: str, options: dict, temperature=0.1, top_p=0.95, top_k=1, max_new_tokens=5,
                  max_tries=100, lm=None, tok=None) -> int:
    """Ask the language model which of the numbered `options` matches `definition`.

    The options are rendered as "<num>. <text> " into a German instruction prompt, and the
    model is sampled until the first integer in its answer is one of the keys of `options`.

    Args:
        definition: Definition text inserted into the prompt.
        options: Mapping from option number to example text; the result is one of its keys.
        temperature, top_p, top_k, max_new_tokens: Passed through to `generate`.
            NOTE: with top_k=1 only the single most likely token can be sampled, so every
            retry yields the same answer; raise top_k if retries should explore alternatives.
        max_tries: Maximum number of generation attempts before giving up.
        lm, tok: Model and tokenizer to use; default to the notebook-level `model` and `tokenizer`.

    Returns:
        The chosen option key, or -1 if no attempt produced a valid option number
        (or `options` is empty).
    """
    if not options:
        return -1
    lm = model if lm is None else lm
    tok = tokenizer if tok is None else tok

    examples = "".join(f"{num}. {option} " for num, option in options.items())
    # Kept byte-identical to the prompt the experiments were run with (including the
    # indentation before "### Assistant:"), since changing it changes the model's answers.
    prompt = (
        "### User: Du bist genau und gewissenhaft und gibst stets die korrekte Antwort. "
        f"Hier ist eine Definition: {definition} "
        f"Hier sind nummerierte Beispiele: {examples} "
        "Nenne die Nummer des Beispiels, das zur Definition passt.\n"
        "        ### Assistant: "
    )
    # The prompt does not change between attempts, so tokenize it once.
    input_ids = tok(prompt, return_tensors='pt').input_ids.to(lm.device)
    prompt_tokens = input_ids.shape[-1]

    for attempt in range(max_tries):
        generated = lm.generate(inputs=input_ids, temperature=temperature,
                                do_sample=True, top_p=top_p, top_k=top_k,
                                max_new_tokens=max_new_tokens)
        # Decode only the newly generated tokens: slicing the decoded text by the prompt's
        # character length is off whenever special tokens (e.g. "<s>") are decoded too.
        answer = tok.decode(generated[0][prompt_tokens:], skip_special_tokens=True)
        print(max_tries - attempt - 1, end=" ")  # remaining attempts, as a progress hint
        match = re.search(r"\d+", answer)
        if match is not None and int(match.group()) in options:
            print()
            return int(match.group())
    print()
    return -1

In [12]:
# For 2..5 distractors, ask the model to pick the example sentence that belongs to each
# definition. Each query holds one sentence of the definition's own construction plus
# `false_positives` distinct sentences from other constructions, in shuffled order.
sampling_rng = random.Random(0)  # seeded so the sampled queries are reproducible across runs

for false_positives in range(2, 6):
    records = []
    for key, definition in tqdm(definitions.items(), desc=f"{false_positives} distractors"):
        # Sort (by str, so any hashable item works) because the values of `sentences` may be
        # sets, whose iteration order for strings changes between interpreter runs.
        own_sentences = sorted(sentences[int(key)], key=str)
        if not own_sentences:
            continue  # no example to test this definition against
        sentence = sampling_rng.choice(own_sentences)
        # Distinct distractors from every *other* construction, excluding the correct
        # sentence itself so that no option is silently merged away.
        distractor_pool = sorted(
            {other
             for constr, sentence_list in sentences.items() if int(constr) != int(key)
             for other in sentence_list} - {sentence},
            key=str,
        )
        distractors = sampling_rng.sample(distractor_pool, k=min(false_positives, len(distractor_pool)))
        candidates = distractors + [sentence]
        sampling_rng.shuffle(candidates)
        query = dict(enumerate(candidates))
        prediction = find_examples(definition, query)
        # find_examples returns -1 when it gets no usable answer; count that as incorrect
        # instead of raising KeyError.
        correct = query.get(prediction) == sentence
        print(correct, sentence, prediction)
        records.append({"constr": key, "definition": definition, "example": sentence,
                        "prediction": prediction, "correct": correct})
    result = pd.DataFrame(records)
    result.to_csv(f"result_1_in_{false_positives}.tsv", sep="\t")

  0%|          | 0/211 [00:00<?, ?it/s]

KeyboardInterrupt: 