In [None]:
!pip3 install transformers==4.33.2
!pip3 install optimum==1.13.2
!pip3 install auto-gptq --extra-index-url https://huggingface.github.io/autogptq-index/whl/cu118/

In [None]:
import ast
import csv
import os
import pickle
import re

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from tqdm.notebook import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, BertForMaskedLM, BertTokenizer

# Register tqdm's progress-bar helpers (e.g. `progress_apply`) on pandas objects.
tqdm.pandas()

### Models
First, let us load a LLAMA model:

In [None]:
# GPTQ-quantized German Llama 2 assistant model, placed on the first GPU.
# The revision name suggests the 4-bit / group-size-32 / act-order branch of the repo.
llama_name_or_path = 'TheBloke/Llama-2-13B-German-Assistant-v4-GPTQ'
llama_tokenizer = AutoTokenizer.from_pretrained(llama_name_or_path)
llama = AutoModelForCausalLM.from_pretrained(
    llama_name_or_path,
    revision="gptq-4bit-32g-actorder_True",
    device_map="cuda:0",
    trust_remote_code=False,
)

Then, we also need a BERT model:

In [None]:
# German cased BERT with its masked-LM head, plus the matching tokenizer.
bert_checkpoint = 'dbmdz/bert-base-german-cased'
bert_base_tokenizer = BertTokenizer.from_pretrained(bert_checkpoint)
bert_base = BertForMaskedLM.from_pretrained(bert_checkpoint, return_dict=True)
# Show the word-embedding table before the pseudowords are appended.
bert_base.bert.embeddings.word_embeddings

### Data
Now, we load some data we need. First, we need some definitions and their example sentences.

In [None]:
# Construction definitions and their example sentences.
# NOTE: pickle can run code while loading — only open files produced by this project.
with open("../../out/definitions.pickle", "rb") as definitions_file:
    definitions = pickle.load(definitions_file)

with open("../../out/sentences.pickle", "rb") as sentences_file:
    sentences = pickle.load(sentences_file)

In [None]:
# Inspect the definitions; the generation loop uses their keys as construction ids
# and their values as the definition text put into the LLAMA prompt.
definitions

Next up, we load the prepared BERT pseudoword embeddings.

In [None]:
# The pseudoword embeddings are stored in 15 chunk files whose names encode the row
# range (presumably 37 rows each); stack them in file order.
pseudowords = np.concatenate(
    [
        np.load(f"../../data/pseudowords/bsbbert/pseudowords_comapp_bsbbert_{i*37}_{i*37+37}.npy")
        for i in range(15)
    ]
)
pseudowords

In [None]:
# Pseudoword labels: 15 semicolon-separated files without header; the "order" column
# becomes the index and the "label" column holds the token added to BERT below.
csv_data = pd.concat(
    [
        pd.read_csv(
            f"../../data/pseudowords/bsbbert/order_bsbbert_{i}.csv",
            sep=";",
            index_col=0,
            header=None,
            quotechar="|",
            names=["order", "label"],
        )
        for i in range(1, 16)
    ]
)
csv_data

Also, we define a lookup table that maps each construction id to the set of words in its KE-lex annotations (`kees`).

In [None]:
# Lookup table: construction id -> set of KE-lex words. Each `kees` cell holds a
# Python literal with an iterable of strings; every string is split on whitespace so
# that each word of a multi-word KE-lex is included.
with open("../../data/pseudowords/annotations.csv", "r", newline="") as csv_file:
    data = list(csv.DictReader(csv_file))

kelex_dict = {}
for example in data:
    # literal_eval only accepts Python literals; eval() would execute arbitrary
    # code stored in the CSV.
    kelex_dict[int(example["construction_id"])] = {
        word for kee in ast.literal_eval(example["kees"]) for word in kee.split()
    }

kelex_dict

These pseudowords are now added to BERT.

In [None]:
# One token per pseudoword, taken from the "label" column in file order
# (assumed to match the row order of `pseudowords`).
bert_tokens = list(csv_data.to_numpy()[:, 0])

bert_tokens, len(bert_tokens)

In [None]:
# Append the pseudoword vectors as extra rows after BERT's original vocabulary.
# They are cast explicitly to the existing table's dtype/device: a float64 .npy
# would otherwise promote the whole table and clash with the rest of the model.
base_word_embeddings = bert_base.bert.embeddings.word_embeddings
with torch.no_grad():
    combined_embeddings = torch.cat(
        (
            base_word_embeddings.weight,
            torch.as_tensor(pseudowords).to(
                dtype=base_word_embeddings.weight.dtype,
                device=base_word_embeddings.weight.device,
            ),
        ),
        dim=0,
    )
# from_pretrained freezes the table by default; keep the original padding index.
bert_base.bert.embeddings.word_embeddings = torch.nn.Embedding.from_pretrained(
    combined_embeddings, padding_idx=base_word_embeddings.padding_idx
)
bert_base.bert.embeddings.word_embeddings

In [None]:
# Register one new token per pseudoword. add_tokens gives new tokens ids that continue
# after the current vocabulary, but it skips tokens that already exist. A skipped or
# duplicated token would shift the ids away from the embedding rows appended in the
# previous cell, so verify the alignment explicitly (the check also passes on re-runs).
bert_base_tokenizer.add_tokens(bert_tokens)
pseudoword_token_ids = bert_base_tokenizer.convert_tokens_to_ids(bert_tokens)
first_pseudoword_row = bert_base.bert.embeddings.word_embeddings.num_embeddings - len(bert_tokens)
assert pseudoword_token_ids == list(range(first_pseudoword_row, first_pseudoword_row + len(bert_tokens))), (
    "pseudoword token ids do not line up with the appended embedding rows"
)
# With aligned ids the table already has len(tokenizer) rows; per the transformers docs
# this call also takes care of re-tying the output embeddings.
bert_base.resize_token_embeddings(len(bert_base_tokenizer))

Finally, we move the models to the GPUs: LLAMA goes to the first GPU and BERT to the second one. If fewer than two GPUs are available, BERT stays on the CPU.

In [None]:
# LLAMA always uses the first GPU; BERT gets the second GPU if there is one, else the CPU.
bert_device = "cpu" if torch.cuda.device_count() < 2 else "cuda:1"
llama.to("cuda:0")
bert_base.to(bert_device)

# Generation
Now we can start generating new examples. First, we define a function that lets LLAMA propose a list of new sentences.

In [None]:
def generate_examples(definition: str, sentence: str, temperature=0.75, top_p=0.95, top_k=1, max_new_tokens=1024):
    if len(sentence) > 2:
        prompt = lambda definition, sentence: f'''### User: Du bist kreativ und gewissenhaft. Hier ist eine Definition: {definition} Bilde neue Sätze gemäß dieser Definition. Gib die Sätze in einer Python-Liste aus. Gib sonst nichts aus.
        ### Example: {sentence}
        ### Assistant:["'''
    else:
        prompt = lambda definition, sentence: f'''### User: Du bist kreativ und gewissenhaft. Hier ist eine Definition: {definition} Bilde neue Sätze gemäß dieser Definition. Gib die Sätze in einer Python-Liste aus. Gib sonst nichts aus.
        ### Assistant:["'''

    prompt_length = len(prompt(definition, sentence))
    output = []
    first_output = ""
    i = 100  # only try 100 times
    while (
        (not any([c.isalpha() for c in first_output]))
        or any([x in first_output for x in {"Konstruktion", "Satz", "Überschrift", "_", ":", "XY", "XP", "X ", "Y ", "X.", "Y."}])
        or re.search(r".*\].*\[.*", first_output)
    ):
        input_ids = llama_tokenizer(prompt(definition, sentence), return_tensors='pt').input_ids.cuda()
        output = llama.generate(inputs=input_ids, temperature=temperature,
                                do_sample=True, top_p=top_p, top_k=top_k,
                                max_new_tokens=max_new_tokens)
        output = llama_tokenizer.decode(output[0])[prompt_length:].strip()
        output = re.findall('\[.*\]', output)
        if len(output) > 0:
            first_output = output[0]
        i -= 1
        print(i, end=" ")
        if i == 0:
            first_output = "[]"
            break
        #print(f"\t{output}")
    #print(f"\n{output}")
    print()
    return first_output

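To validate LLAMA's output, we check each produced sentence for the presence of the construction's KE-lex. If none is present, the sentence is rejected. Otherwise, we compare the contextual representation of the KE-lex in the sentence with both the general BERT embedding and the pseudoword embedding. A sentence is accepted if, for at least one KE-lex, the pseudoword is at least as close as the general embedding under at least one metric (cosine similarity, Euclidean distance or Manhattan distance). If no sentence of a generated list is accepted, LLAMA is asked for a new list, up to a fixed number of attempts.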

First, load the contextual embeddings (prepared in `compare_embeddings.ipynb`).

In [None]:
# Reference embeddings per construction id and KE-lex, prepared in compare_embeddings.ipynb.
with open("../../out/comapp/contextual_embeds_ex.pickle", "rb") as embeds_file:
    contextual_embeds_ex = pickle.load(embeds_file)

contextual_embeds_ex[10].keys()

In [None]:
def compare_distances(constr, sentence):
    """Decide whether the pseudoword of construction `constr` fits its use in `sentence`.

    For every KE-lex of `constr` that occurs (as a substring) in `sentence`, BERT's
    hidden state 12 at the KE-lex word-piece positions is compared with the two
    reference vectors in `contextual_embeds_ex[constr][kelex]`: index 0 (the standard
    BERT embedding) and index 1 (the pseudoword embedding). Each metric is averaged
    over the matching positions.

    The sentence is rejected (printing ".") when:
      - `constr` has no reference embeddings,
      - none of its KE-lex occurs in the sentence, or
      - the sentence is longer than 512 BERT tokens.

    Returns:
        True as soon as, for some KE-lex, the pseudoword is at least as close as the
        BERT embedding under cosine similarity, Euclidean distance or Manhattan
        distance. Otherwise False.
    """
    sentence_ids = bert_base_tokenizer(sentence, return_tensors='pt')['input_ids']
    if (
        constr not in contextual_embeds_ex
        or not any(kelex in sentence for kelex in contextual_embeds_ex[constr])
        or sentence_ids.size(-1) > 512  # BERT's maximum input length
    ):
        print(".", end="")
        return False

    with torch.no_grad():
        hidden = bert_base(sentence_ids.to(bert_device), output_hidden_states=True).hidden_states[12][0]
        sentence_token_ids = sentence_ids[0].tolist()

        for kelex, embeds in contextual_embeds_ex[constr].items():
            if kelex not in sentence:
                continue
            # Word-piece ids of the KE-lex without [CLS]/[SEP], computed once per KE-lex.
            kelex_token_ids = set(bert_base_tokenizer(kelex, return_tensors='pt')['input_ids'][0][1:-1].tolist())
            positions = [idx for idx, tok in enumerate(sentence_token_ids) if tok in kelex_token_ids]
            if not positions:
                # Substring match only (e.g. inside a longer word): nothing to compare.
                continue

            observed = hidden[positions]
            bert_ref = embeds[0].to(bert_device)
            pseudo_ref = embeds[1].to(bert_device).expand_as(observed)

            bert_sim = F.cosine_similarity(bert_ref, observed, dim=-1).mean()
            pseudo_sim = F.cosine_similarity(pseudo_ref, observed, dim=-1).mean()
            bert_euclidean = torch.norm(bert_ref - observed, p=2, dim=-1).mean()
            pseudo_euclidean = torch.norm(pseudo_ref - observed, p=2, dim=-1).mean()
            bert_manhattan = torch.norm(bert_ref - observed, p=1, dim=-1).mean()
            pseudo_manhattan = torch.norm(pseudo_ref - observed, p=1, dim=-1).mean()

            if pseudo_sim >= bert_sim or pseudo_euclidean <= bert_euclidean or pseudo_manhattan <= bert_manhattan:
                return True
    return False

In [None]:
def generate_and_check_examples(constr, definition, sentence, temperature=0.75, max_new_tokens=1000, top_k=100, top_p=0.99, patience=10):
    """Generate sentences with LLAMA and keep those whose pseudoword fits.

    Runs up to `patience` rounds. Each round:
      1. asks `generate_examples` for a list of sentences,
      2. parses the list,
      3. keeps every string for which `compare_distances(constr, ...)` is True.
    Stops after the first round that yields at least one fitting sentence.
    Answers that are not a parseable list are skipped, as are non-string items.

    Returns:
        The fitting sentences of the last round (an empty list if no round produced any).
    """
    fitting = []
    for _ in range(patience):  # range() of a non-positive patience is empty
        example = generate_examples(definition=definition, sentence=sentence, temperature=temperature,
                                    max_new_tokens=max_new_tokens, top_k=top_k, top_p=top_p)
        try:
            # The answer is model-generated text: parse it as a literal instead of
            # eval()-ing it, which would execute arbitrary code.
            example_list = ast.literal_eval(example)
        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):
            continue  # not a valid Python literal; ask again
        if not isinstance(example_list, list):
            continue
        # At least one sentence needs to fit for the round to succeed.
        fitting = [e for e in example_list if isinstance(e, str) and compare_distances(constr, e)]
        print(bool(fitting), example_list, fitting)
        if fitting:
            break
    return fitting

In [None]:
for shot in [1, 0]:
    # Progress is checkpointed after every construction. Resume from the file this loop
    # writes; fall back to the older "_plus_bert" file name if only that one exists.
    examples_path = f"../../out/llama_bert/examples_{shot}_shot.pickle"
    legacy_path = f"../../out/llama_bert/examples_{shot}_shot_plus_bert.pickle"
    examples = {}
    for resume_path in (examples_path, legacy_path):
        if os.path.exists(resume_path):
            with open(resume_path, "rb") as file:
                examples = pickle.load(file)
            break

    for k in tqdm(definitions.keys()):
        if k in examples:
            continue  # have already generated examples for this construction

        definition = definitions[k]
        try:
            # String form of the first `shot` example sentences ("[]" for zero-shot).
            sentence = str(list(sentences[int(k)])[0:shot])
        except KeyError:
            print(("[]", "[]", "This seems wrong..."))
            examples[k] = ("[]", "[]")
            continue

        example = generate_and_check_examples(
            constr=k, definition=definition, sentence=sentence, temperature=0.75,
            max_new_tokens=512, top_k=100, top_p=0.99
        )
        print((sentence, example))
        examples[k] = (sentence, example)
        print("=====")

        # Write to a temporary file and swap it in, so an interruption while
        # pickling cannot corrupt the checkpoint.
        tmp_path = examples_path + ".tmp"
        with open(tmp_path, "wb") as file:
            pickle.dump(examples, file)
        os.replace(tmp_path, examples_path)