In [ ]:
!pip3 install transformers==4.33.2
!pip3 install optimum==1.13.2
!pip3 install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/

!nvidia-smi

In [None]:
import itertools
import pickle
import random
import re

import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer

from tqdm.notebook import tqdm
tqdm.pandas()

In [None]:
# Load the GPTQ build of the German Llama-2 13B assistant from the Hugging Face Hub
# (branch "gptq-4bit-32g-actorder_True") and place the model on the first GPU.
model_name_or_path = 'TheBloke/Llama-2-13B-German-Assistant-v4-GPTQ'
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
model = AutoModelForCausalLM.from_pretrained(
    model_name_or_path,
    revision="gptq-4bit-32g-actorder_True",
    trust_remote_code=False,  # do not execute custom modeling code from the Hub repository
    device_map="cuda:0",
)

In [None]:
def _load_pickle(path):
    """Unpickle and return the object stored at `path`.

    NOTE: pickle.load can execute arbitrary code - only use it on files you produced yourself.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


definitions = _load_pickle("../../out/definitions.pickle")
sentences = _load_pickle("../../out/sentences.pickle")

In [None]:
def find_examples(definition, examples, temperature=0.1, top_p=0.95, top_k=1, true_examples=0,
                  max_rounds=5, tokens_per_index=5, llm=None, llm_tokenizer=None, device="cuda:0"):
    """Ask the LLM which of `examples` match `definition` and return the matching examples.

    The prompt lists the candidates as ``{index: "sentence", ...}`` and asks for the indices that
    fit the definition. Each round sends a batch of 5 identical prompts and merges the indices
    found in all 5 completions.

    Args:
        definition: Definition text inserted into the prompt.
        examples: Candidate sentences; their list positions are the indices the model answers with.
        temperature, top_p, top_k: Sampling parameters forwarded to ``generate``. With the default
            ``top_k=1`` sampling is effectively greedy, so repeated samples/rounds usually agree.
        true_examples: Number of matching sentences the model is told to find. If > 0, rounds are
            repeated until exactly that many distinct in-range indices come back, for at most
            `max_rounds` rounds (the last round's answer is used otherwise). If <= 0, the model
            may answer ``[]`` and a single round is run.
        max_rounds: Upper bound on sampling rounds when `true_examples` > 0.
        tokens_per_index: New-token budget per expected index; one extra slot is added as
            head-room for brackets or a short lead-in.
        llm, llm_tokenizer: Model and tokenizer to use; default to the notebook-level
            `model` / `tokenizer`.
        device: Device the prompt token ids are moved to.

    Returns:
        The selected sentences from `examples`, in ascending index order. The selected index set
        is printed (without newline) as a progress trace.
    """
    llm = model if llm is None else llm
    llm_tokenizer = tokenizer if llm_tokenizer is None else llm_tokenizer
    tries = 5  # identical prompts per generate() call, i.e. samples per round

    options = "{" + ", ".join(f'{num}: "{example}"' for num, example in enumerate(examples)) + "}"
    if true_examples == 1:
        instruction = "Nenne den Index des Satzes, der zur Definition passt."
        answer_prefix = ""
    elif true_examples > 1:
        instruction = ("Nenne die Indizes der Sätze, die zur Definition passen, in einer Liste. "
                       f"Es sind genau {true_examples} Beispiele richtig.")
        answer_prefix = "["
    else:
        instruction = ("Nenne die Indizes der Sätze, die zur Definition passen, in einer Liste. "
                       "Falls kein Satz richtig ist, gib [] aus.")
        answer_prefix = "["
    # Built without triple-quoted strings so the code's indentation no longer leaks into the
    # prompt as a run of spaces in front of "### Assistant:".
    prompt = ("### User: Du bist genau und gewissenhaft und gibst stets die korrekte Antwort. "
              f"Hier sind mögliche Sätze: {options} Hier ist eine Definition: {definition} "
              f"{instruction}\n### Assistant: {answer_prefix}")

    # `max_new_tokens` counts generated tokens only; the prompt length must not be added to it.
    expected = true_examples if true_examples > 0 else len(examples)
    max_new_tokens = tokens_per_index * (max(expected, 1) + 1)

    input_ids = llm_tokenizer([prompt] * tries, return_tensors="pt").input_ids.to(device)
    prompt_tokens = input_ids.shape[1]

    def sample_indices():
        """Run one batched generation and return the in-range indices found in the answers."""
        generated = llm.generate(
            inputs=input_ids, temperature=temperature, do_sample=True, top_p=top_p, top_k=top_k,
            max_new_tokens=max_new_tokens,
        )
        # Decode only the newly generated token ids. Slicing the decoded text by the prompt's
        # character length is unreliable: the decoded text also contains special tokens and need
        # not reproduce the prompt character-for-character.
        completions = llm_tokenizer.batch_decode(generated[:, prompt_tokens:], skip_special_tokens=True)
        indices = set()
        for completion in completions:
            # Read only the first answer line, up to a closing bracket or a new "###" turn, so
            # numbers in any follow-up rambling are not counted.
            answer = completion.strip().split("\n", 1)[0].split("###", 1)[0].split("]", 1)[0]
            indices.update(int(digits) for digits in re.findall(r"\d+", answer))
        return {i for i in indices if 0 <= i < len(examples)}

    output = set()
    for _ in range(max(1, max_rounds)):
        output = sample_indices()
        if true_examples <= 0 or len(output) == true_examples:
            break
    print(output, end=" ")
    return [examples[i] for i in sorted(output)]

In [None]:
def get_metrics(positive_predicted, negative_predicted, true_sentences, false_sentences, key, definition, examples):
    """Score one prediction against the ground truth and bundle everything into a result row.

    Args:
        positive_predicted: Examples the model selected as matching the definition.
        negative_predicted: Examples the model did not select.
        true_sentences: Examples that actually belong to construction `key`.
        false_sentences: Distractor examples taken from other constructions.
        key: Construction identifier, stored under "constr".
        definition: Definition text, stored as-is.
        examples: All examples shown to the model, in prompt order.

    Returns:
        pd.Series with the inputs, the TP/FP/FN/TN lists and precision, recall, F1 and accuracy.
        Conventions for empty denominators: precision is 1.0 when nothing was predicted, recall is
        1.0 when there was nothing to find, F1 is 0.0 when precision and recall are both 0, and
        accuracy is 1.0 when there were no examples at all.
    """
    true_positives = [pr for pr in positive_predicted if pr in true_sentences]
    false_positives = [pr for pr in positive_predicted if pr in false_sentences]
    false_negatives = [pr for pr in negative_predicted if pr in true_sentences]
    true_negatives = [pr for pr in negative_predicted if pr in false_sentences]

    tp, fp, fn, tn = map(len, (true_positives, false_positives, false_negatives, true_negatives))

    # Nothing predicted => nothing predicted wrongly.
    precision = tp / (tp + fp) if tp + fp else 1.0
    # Recall is only vacuously perfect when there was nothing to find; if true sentences existed
    # but none were found it must be 0 (checking `tp > 0` reported 1.0 in exactly that case).
    recall = tp / (tp + fn) if tp + fn else 1.0
    f1 = (2 * precision * recall) / (precision + recall) if precision + recall else 0.0
    total = len(true_sentences) + len(false_sentences)
    accuracy = (tp + tn) / total if total else 1.0

    return pd.Series({
        "constr": key,
        "definition": definition,
        "examples": examples,
        "positive_predicted": positive_predicted,
        "negative_predicted": negative_predicted,
        "true_positives": true_positives,
        "false_positives": false_positives,
        "false_negatives": false_negatives,
        "true_negatives": true_negatives,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "accuracy": accuracy,
    })

In [None]:
# Experiment 1: the model is NOT told how many sentences match (find_examples' default prompt).
# For every mix of 0-5 matching and 0-5 distractor sentences, run `attempts` prompts per
# construction and write one TSV of per-prompt metrics.
random.seed(15)
attempts = 3  # each find_examples call samples a batch of 5 => 15 generations per construction
for num_true in range(0, 6):
    for num_false in range(0, 6):
        print()
        if not num_true and not num_false:
            continue  # skip (0, 0): there would be nothing to classify
        result = []
        for key, definition in tqdm(definitions.items()):
            if num_true and int(key) not in sentences:
                # No sentences for this construction: record a single placeholder row.
                result.append(pd.Series({"constr": key, "definition": definition}))
                continue
            own = list(sentences[int(key)]) if num_true else []
            # Flatten all sentences of the *other* constructions; these are the distractors.
            others = list(itertools.chain.from_iterable(
                sentence_list for constr, sentence_list in sentences.items() if int(constr) != int(key)
            ))
            # Only the definition up to (and including) its first period is shown to the model.
            head, period, _ = definition.partition(".")
            short_definition = head + period

            for attempt in range(attempts):
                # Sample WITHOUT replacement so the prompt really contains num_true / num_false
                # distinct sentences (random.choice / random.choices could draw duplicates).
                true_sentences = set(random.sample(own, min(num_true, len(own))))
                false_sentences = set(random.sample(others, min(num_false, len(others))))
                # Set iteration order depends on string hashing, which varies between interpreter
                # runs; sort, then shuffle with the seeded RNG so prompt indices are reproducible.
                # NOTE(review): assumes sentences are mutually comparable (e.g. all str) - confirm.
                examples = sorted(true_sentences | false_sentences)
                random.shuffle(examples)

                positive_predicted = find_examples(short_definition, examples)
                negative_predicted = [ex for ex in examples if ex not in positive_predicted]

                result.append(get_metrics(positive_predicted, negative_predicted, true_sentences,
                                          false_sentences, key, definition, examples))

        result = pd.DataFrame(result)
        result.to_csv(f"../../out/llama/result_{num_true}t_vs_{num_false}f_{attempts}attempts_llama.tsv",
                      sep="\t", decimal=",", header=True, index=False)

In [ ]:
# Experiment 2: the model IS told how many sentences match (find_examples(true_examples=...)).
# For every mix of 1-5 matching and 1-5 distractor sentences, run `attempts` prompts per
# construction and write one TSV of per-prompt metrics.
random.seed(15)
attempts = 15
for num_true in range(1, 6):
    for num_false in range(1, 6):
        print()
        result = []
        for key, definition in tqdm(definitions.items()):
            if int(key) not in sentences:
                # No sentences for this construction: record a single placeholder row.
                result.append(pd.Series({"constr": key, "definition": definition}))
                continue
            own = list(sentences[int(key)])
            # Flatten all sentences of the *other* constructions; these are the distractors.
            others = list(itertools.chain.from_iterable(
                sentence_list for constr, sentence_list in sentences.items() if int(constr) != int(key)
            ))
            # Only the definition up to (and including) its first period is shown to the model.
            head, period, _ = definition.partition(".")
            short_definition = head + period

            for attempt in range(attempts):
                # Sample WITHOUT replacement so the prompt really contains num_true / num_false
                # distinct sentences (random.choice / random.choices could draw duplicates).
                true_sentences = set(random.sample(own, min(num_true, len(own))))
                false_sentences = set(random.sample(others, min(num_false, len(others))))
                # Set iteration order depends on string hashing, which varies between interpreter
                # runs; sort, then shuffle with the seeded RNG so prompt indices are reproducible.
                # NOTE(review): assumes sentences are mutually comparable (e.g. all str) - confirm.
                examples = sorted(true_sentences | false_sentences)
                random.shuffle(examples)

                # The count must be passed by keyword: positionally it landed in `temperature`,
                # so the model was never told the count and was sampled at temperature 1-5.
                # Use the actual number drawn (smaller than num_true only if `own` is too short).
                positive_predicted = find_examples(short_definition, examples,
                                                   true_examples=len(true_sentences))
                negative_predicted = [ex for ex in examples if ex not in positive_predicted]

                result.append(get_metrics(positive_predicted, negative_predicted, true_sentences,
                                          false_sentences, key, definition, examples))

        result = pd.DataFrame(result)
        result.to_csv(f"../../out/llama/result_{num_true}t_vs_{num_false}f_{attempts}attempts_llama_2.tsv",
                      sep="\t", decimal=",", header=True, index=False)