In [1]:
"""
pip install llama-index
pip install llama-index-llms-huggingface
pip install "transformers[torch]" "huggingface_hub
"""

import os
import re
from random import random
from typing import List, Tuple
import torch
from llama_index.llms.huggingface import HuggingFaceInferenceAPI  # 0.10.25
from transformers import AutoTokenizer, BertTokenizer, BertForTokenClassification
from sklearn.decomposition import PCA

  from .autonotebook import tqdm as notebook_tqdm


# Generate Training Data with Semi-Supervised Labels from an LLM

In [2]:
# connect to LLM
llm_checkpoint = "mistralai/Mixtral-8x7B-Instruct-v0.1"
remotely_run = HuggingFaceInferenceAPI(model_name=llm_checkpoint, token="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")

In [3]:
# LLM-generated text blocks
nachnamen = list(set([
    "Müller", "Schmidt", "Schneider", "Fischer", "Weber Meier", "Meyer", "Wagner", "Becker",
    "Schulz", "Berg Hoffmann", "Schäfer", "Koch", "Bauer", "Richter", "Klein", "Wolf", 
    "Schröder", "Neumann-Schmid", "Schwarz", "Zimmermann", "Braun", "Krüger", "Hofmann", "Hartmann", 
    "Lange", "Schmitt", "Werner", "Schmitz", "Krause", "Meier", "Lehmann", "Schmid", "Schulze", 
    "Maier", "Köhler", "Herrmann", "Öztürk", "König", "Walter", "Mayer", "Huber", "Kaiser", "Fuchs", 
    "Peters", "Lang-Maler", "Scholz", "Möller", "Weiß", "Jung", "Hahn", "Schubert", "Vogel", "Friedrich", 
    "Keller", "Günther", "Frank", "Berger", "Winkler", "Roth", "Beck", "Lorenz Fischer", "Baumann", 
    "Franke", "Albrecht", "Schuster", "Simon", "Ludwig", "Böhm", "Winter", "Kraus", "Martin", 
    "Schumacher", "Krämer", "Vogt", "Stein", "Jäger", "Otto", "Sommer", "Groß", "Seidel", 
    "Heinrich", "Brandt", "Haas", "Schreiber", "Graf", "Schlegel", "Dietrich", "Ziegler", "Kuhn", 
    "Kühn", "Pohl", "Engel", "Horn", "Busch", "Bergmann", "Thomas", "Voigt", "Sauer", "Arnold", 
    "Wolff", "Pfeiffer", "Wolf-Richter", "Yılmaz", "Kaya", "Demir Schmidt", "Şahin", "Çelik", "Yıldız", "Yıldırım", "Öztürk", "Aydın", "Özdemir",
    "Arslan", "Doğan", "Kılıç", "Aslan", "Çetin", "Karadağ", "Koç", "Kurt", "Özkan", "Akar",
    "Acar", "Tekin", "Kara", "Ekici", "Kaplan", "Şimşek", "Avcı", "Güler", "Korkmaz", "Sarı",
    "Balcı", "Selvi", "Göçer", "Polat", "Demirci", "Duman", "Tuna", "Taş", "Keskin", "Güneş",
    "Can", "Aydemir", "Ata", "Özer-Maler", "Çiftçi", "Bayraktar", "Erdoğan", "Bozkurt", "Kan", "Dağ", 
    "Nowak", "Kowalski", "Wiśniewski", "Dąbrowski", "Lewandowski", "Wójcik", "Kamiński", "Kowalczyk", "Zieliński", "Szymański",
    "Woźniak", "Kozłowski", "Jankowski", "Wojciechowski", "Kwiatkowski", "Kaczmarek", "Mazur", "Kubiak", "Król", "Pawłowski"]))

vornamen = list(set(["Leon", "Mustafa", "Lukas", "Finn", "Noah", "Paul", "Jonas", "Luis", "Elias", "Felix", "Luca",
    "Max", "Henry", "Julian", "Niklas", "Tim", "Alexander", "Philipp", "David", "Maximilian", "Liam",
    "Oskar", "Moritz", "Fabian", "Simon", "Erik", "Jakob-Mark", "Vincent", "Benjamin", "Matteo", "Anton",
    "Emil", "Carl", "Jonathan", "Theo", "Samuel", "Linus", "Mats", "Jan", "Nico", "Leonard",
    "Hannes", "Florian", "Ben", "Adam", "Raphael", "Tobias", "Sebastian", "Martin", "Johannes", "Fabio",
    "Lennard", "Michael", "Jona", "Joshua", "Marcel", "Tom", "Valentin", "Lennart", "Levin", "Maxim",
    "Kilian", "Konstantin", "Robin", "Lars", "Emilian", "Arne", "Matthias", "Milan", "Mohammed", "Kai",
    "Nick", "Ole Joost", "Julius", "Benedikt", "Marvin", "Leopold", "Nils", "Daniel", "Franz", "Manuel",
    "Noel", "Pascal", "Mika", "Adrian", "Oliver", "Stefan", "Lorenz", "Valentino", "Magnus", "Jan-Phillip",
    "Constantin", "Artur", "Albert", "Frederik", "Hugo", "Timo", "Jasper", "Aron", "Joel", "Christian",
    "Anna-Lena", "Maria", "Emma", "Sofia", "Mia", "Hannah", "Lena", "Sarah", "Lea", "Laura",
    "Katharina", "Lisa", "Julia", "Sophie", "Isabella", "Charlotte", "Lara", "Marie", "Clara", "Lina",
    "Luisa", "Johanna", "Paula", "Emilia", "Antonia", "Theresa", "Luise", "Helena", "Elisabeth", "Nina",
    "Magdalena", "Melanie", "Anja", "Christina", "Sandra", "Annika", "Silke", "Katja", "Veronika", "Monika",
    "Birgit", "Sabine", "Petra", "Jana Susanne", "Simone", "Annette", "Stefanie", "Nicole", "Barbara", "Sonja",
    "Carina", "Yvonne", "Daniela", "Eva", "Heike", "Tanja", "Ingrid", "Franziska", "Renate", "Irina",
    "Gisela", "Martina", "Andrea", "Ursula", "Ines", "Beate", "Gabriele", "Cornelia", "Diana", "Brigitte",
    "Elena", "Valentina", "Alicia", "Maja", "Anastasia", "Karina", "Doris", "Judith", "Frieda", "Irma",
    "Hilde", "Erika", "Margarete", "Elfriede", "Gertrud", "Edith", "Ruth", "Ilse", "Hedwig", "Lieselotte",
    "Klara", "Olga", "Rita", "Waltraud", "Inge", "Herta", "Martha", "Else", "Ute", "Helga Marie", "Ahmet", 
    "Mehmet", "Mustafa", "Ayşe", "Fatma Binnaz", "Yusuf", "Zeynep",
    "Elif", "Ömer", "Emir", "Hüseyin", "Hasan", "Ali", "Ibrahim",
    "Sümeyye", "Hatice", "Ece", "Kerem", "Büşra", "Taha", "Rümeysa",
    "Furkan", "Selin", "Cem", "Esra", "Berk", "Derya", "Merve", "Cansu", "Deniz", "Aleksandra", "Aneta", "Bartosz", 
    "Czesław", "Daria","Emilia", "Filip", "Grażyna", "Henryk", "Iwona",
    "Jakub", "Katarzyna", "Łukasz", "Małgorzata", "Natalia",
    "Oskar", "Paweł", "Róża", "Szymon", "Tomasz"]))

roles = ["Verkehrspolizist", "Rechtsanwalt", "Beschuldigter", "Beschädigter", "Zeuge", "Reporter"]

categories = [
    "Überhöhte Geschwindigkeit",
    "Alkohol am Steuer",
    "Unaufmerksamkeit des Fahrers",
    "Nichtbeachtung von Verkehrszeichen",
    "Unangepasste Geschwindigkeit bei schlechtem Wetter",
    "Fahren unter Drogeneinfluss",
    "Ablenkung durch Handynutzung",
    "Ermüdung des Fahrers",
    "Technisches Versagen des Fahrzeugs",
    "Fehler beim Abbiegen",
    "Missachtung der Vorfahrt",
    "Fehlerhaftes Überholen",
    "Mangelnder Sicherheitsabstand",
    "Falschfahren auf der Autobahn",
    "Schlechte Straßenverhältnisse",
    "Nichtanpassung der Geschwindigkeit an Verkehrsdichte",
    "Fahren ohne gültige Fahrerlaubnis",
    "Unsicheres Wechseln der Fahrstreifen",
    "Defekte Verkehrssignale",
    "Unzureichende Fahrzeugbeleuchtung",
    "Fehler beim Rückwärtsfahren",
    "Ladungssicherungsmängel",
    "Fehlverhalten von Fußgängern",
    "Aggressives Fahren",
    "Missachtung der Fußgängerüberwege",
    "Unfälle durch Wildwechsel",
    "Unzureichende Kennzeichnung von Baustellen",
    "Mängel an der Bremsanlage",
    "Reifenpannen",
    "Fahren mit überladenen Fahrzeugen",
    "Nichtbeachtung von Stoppschildern",
    "Fehler beim Einfahren in einen Kreisverkehr",
    "Unaufmerksamkeit beim Ausparken",
    "Verwendung von Alkohol oder Drogen durch Fußgänger",
    "Kollisionen beim Spurwechsel",
    "Unfälle in Kreuzungsgebieten",
    "Zu dichtes Auffahren",
    "Unerlaubtes Wenden auf Straßen",
    "Schlechte Sichtverhältnisse",
    "Fehlerhafte Verkehrsplanung",
    "Nicht beachten von Ampelsignalen",
    "Verkehrsunfälle durch Tierkollision",
    "Fahren auf der falschen Fahrbahnseite",
    "Unfälle verursacht durch gesundheitliche Probleme",
    "Verkehrsunsicherer Zustand des Fahrzeugs",
    "Starkes Beschleunigen",
    "Unfälle in Baustellenbereichen",
    "Fahren ohne Sicherheitsgurt",
    "Unfälle durch Fahrerflucht",
    "Irrtum des Fahrers bezüglich der Verkehrsregeln"]

In [4]:
# use generated text blocks to create versatile generated texts with semi supervised labels & save texts on disk
for i in range(50):
    try:
        # select role, category and name
        role = roles[int(random() * len(roles))]
        category = categories[int(random() * len(categories))]
        vorname = vornamen[int(random() * len(vornamen))]
        nachname = nachnamen[int(random() * len(nachnamen))]
        # name variations: Herr/Frau <last_name> or <forename> <last_name> or <forename>
        rand = random()
        if rand < 0.1:
            name = "Herr " + nachname
        elif rand < 0.2:
            name = "Frau " + nachname
        elif rand < 0.3:
            name = vorname
        else:
            name = vorname + " " + nachname
        if os.path.exists(f"generated_texts/company {name}.txt"):
            continue
        print(i, name)
        # use LLM to generate text
        prompt = f"Du bist {role}. Schreibe einen Satz, welcher einen Kfz-Unfall der Kategorie {category} beschreibt und den Namen '''{name}''' enthält! Gehe nicht auf zeitliche Details ein. Beginne den Satz mit dem Unfallhergang! Schreibe den Satz in <satz><\satz> XML-Tags!"
        text = remotely_run.complete(prompt).text
        print(text)
        # save text
        with open(f"generated_texts/{name}.txt", "w") as f:
            f.write(text)
    except:
        print("rate limit reached")
        break  # time.sleep(10 * 60)

0 Julius Bayraktar


<satz>Julius Bayraktar bemerkte die Verkehrsdichte zu spät und fuhr mit seinem Fahrzeug ungebremst auf das vor ihm stehende Auto auf.</satz>
1 Nico Şahin


<satz>Nico Şahin, der unaufmerksam war, verursachte einen Kfz-Unfall.</satz>
rate limit reached


# Post-Process LLM-Generated Texts

In [5]:
def create_dataset(texts: List[str], to_label: List[str], label_as: List[int], tokenizer: BertTokenizer) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    creates the tensor data set, padding to maximum length.
    out: token_ids, token_labels, attention_mask
    """
    max_len = 0
    text_ids_dataset = []
    labels_dataset = []
    attention_mask_dataset = []
    # tokenize each text and assign token labels
    for i in range(len(texts)):
        text_ids, labels = ids_labels(text=texts[i], to_label=to_label[i], label=label_as[i], tokenizer=tokenizer)
        text_ids_dataset.append(text_ids)
        labels_dataset.append(labels)
        attention_mask_dataset.append([1] * len(labels))
        if len(labels) > max_len:
            max_len = len(labels)
    # pad according to longest text
    for i in range(len(texts)):
        if len(text_ids_dataset[i]) < max_len:
            text_ids_dataset[i] = text_ids_dataset[i] + [tokenizer.pad_token_id] * (max_len - len(text_ids_dataset[i]))
            labels_dataset[i] = labels_dataset[i] + [0] * (max_len - len(labels_dataset[i]))
            attention_mask_dataset[i] = attention_mask_dataset[i] + [0] * (max_len - len(attention_mask_dataset[i]))
    text_ids_dataset = torch.tensor(text_ids_dataset)
    labels_dataset = torch.tensor(labels_dataset)
    attention_mask_dataset = torch.tensor(attention_mask_dataset)
    return text_ids_dataset, labels_dataset, attention_mask_dataset

def ids_labels(text: str, to_label: str, label: int, tokenizer) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    finds <to_label> in <text> and labels respective tokens with <label>, all other with 0.
    TODO: allow labeling multiple sequences per text; allow labeling by position spans, not by to_label-string
    """
    # fix labels according to grammar rules
    if (" " + to_label + "s" + " ") in text:
        to_label = to_label + "s"
    elif "Herr " in to_label:
        if to_label in text:
            pass
        else:
            if "Herrn " + to_label[5:] in text:
                to_label = "Herrn " + to_label[5:]
    tokens_label = tokenizer.tokenize(to_label)
    tokens_text = tokenizer.tokenize(text)
    labels = []
    last_found_pos = - (len(tokens_text) + 1)  # initiate with impossible value
    for i in range(len(tokens_text)):
        if i < (last_found_pos + len(tokens_label)):
            continue
        matches = True
        for j in range(len(tokens_label)):
            if tokens_label[j] != tokens_text[i + j]:
                matches = False
                break
        if matches:
            labels.extend([label] * len(tokens_label))
            last_found_pos = i
        else:
            labels.append(0)
    return tokenizer(text)["input_ids"], [0] + labels + [0]


def discard_zero_rows(token_ids: torch.Tensor, token_labels: torch.Tensor, attention_mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    discard rows where all tokens are labeled as 0.
    Note: it might be beneficial to keep few all 0 sequences in training data to demonstrate, 
    that some sentences don't contain any label
    """
    mask = torch.max(token_labels, dim=1).values > 0
    return token_ids[mask], token_labels[mask], attention_mask[mask]

In [6]:
# discard faulty texts
max_len = 250  # too long text -> sth likely went wrong
texts = []
label_as = []
to_label = []
for file in os.listdir("generated_texts"):
    with open(f"generated_texts/{file}", "r") as f:
        generated = f.read()
        generated_between_tags = re.findall(pattern=r"(?<=<satz>).+(?=</satz>|<\\satz>)", string=generated, flags=re.DOTALL)
        if len(generated_between_tags) == 1:  # otherwise something went wrong in generation
            text = generated_between_tags[0]
            text = re.sub(pattern="'", repl="", string=text)  # often indicate names in the generated texts
            text = re.sub(pattern="\s+", repl=" ", string=text)  # multiple whitespaces
            text = text.strip()  # remove leading and trailing whitespaces
            if len(text) <= max_len:  # discard too long texts
                texts.append(text)
                to_label.append(file[:-4])
                label_as.append(1)
# create tensor dataset
tokenizer = AutoTokenizer.from_pretrained("bert-base-german-cased")
token_ids, token_labels, attention_mask = create_dataset(texts=texts, label_as=label_as, to_label=to_label, tokenizer=tokenizer)
print(token_labels.shape)
token_ids, token_labels, attention_mask = discard_zero_rows(token_ids=token_ids, token_labels=token_labels, attention_mask=attention_mask)  # discard further errors
print(token_labels.shape)

torch.Size([860, 76])
torch.Size([847, 76])


# Robust Outlier Sampling for Train-Eval Split

In [7]:
def distance_rank(vector: torch.Tensor, vector_set: torch.Tensor):
    distances = torch.norm(vector - vector_set, dim=1)
    return  torch.argsort(distances, descending=True)

def distance_rank_fused(vectors, vector_set: torch.Tensor):
    ranks = [distance_rank(v, vector_set) for v in vectors]
    ranking_fused = []
    for i in range(len(vector_set)):
        fused_rank = 0
        for rank in ranks:        
            rank: torch.Tensor = rank
            fused_rank += torch.where(rank == i)[0]
        ranking_fused.append(float(fused_rank / len(ranks)))
    return torch.tensor(ranking_fused)

def edgecase_sampling(embeddings: torch.Tensor, n: int) -> List[int]:
    mean = torch.mean(embeddings, dim=0)
    sample = [mean]
    sample_ids = []
    # iteratively select data point that is on average most distant from all others
    for i in range(n):
        dist_rank = distance_rank_fused(vectors=sample, vector_set=embeddings)
        id_most_distant = torch.argmin(dist_rank)
        sample.append(embeddings[id_most_distant])
        embeddings = torch.cat((embeddings[:id_most_distant], embeddings[(id_most_distant + 1):]))
        sample_ids.append(id_most_distant)
    return sample_ids

In [11]:
eval_size = 64

# create bert embeddings for all data points
model = BertForTokenClassification.from_pretrained("bert-base-german-cased")
embeddings = []
with torch.no_grad():
    for i in torch.arange(start=0, end=len(token_ids), step=32):
        embeddings.append(torch.mean(model.bert(input_ids=token_ids[i:i+32], attention_mask=attention_mask[i:i+32]).last_hidden_state, dim=1))
embeddings = torch.cat(embeddings, dim=0)
# robust outlier sampling
indices_eval = edgecase_sampling(embeddings=torch.tensor(embeddings), n=eval_size)

# split:
# eval data
token_ids_eval = token_ids[indices_eval]
token_labels_eval = token_labels[indices_eval]
attention_mask_eval = attention_mask[indices_eval]
# train data
token_ids_train = token_ids[[i for i in range(len(token_ids)) if i not in indices_eval]]
token_labels_train = token_labels[[i for i in range(len(token_ids)) if i not in indices_eval]]
attention_mask_train = attention_mask[[i for i in range(len(token_ids)) if i not in indices_eval]]

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-german-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  indices_eval = edgecase_sampling(embeddings=torch.tensor(embeddings), n=eval_size)


# Train and Evaluate

In [12]:
def acc(pred: torch.Tensor, true: torch.Tensor) -> float:
    return float((pred == true).sum() / len(true.flatten()))

def recall(pred: torch.Tensor, true: torch.Tensor, on: int = 1) -> float:
    pred = pred.flatten()
    true = true.flatten()
    correct = 0
    for i in range(len(pred)):
        if pred[i] == true[i] == on:
            correct += 1
    return correct / len(true[true == on])

def random_batch(token_ids: torch.Tensor, token_labels: torch.Tensor, attention_mask: torch.Tensor, batch_size: int):
    indices = torch.randint(low=0, high=len(token_ids), size=(batch_size,))
    return token_ids[indices], token_labels[indices], attention_mask[indices], indices

In [13]:
# setting
model = BertForTokenClassification.from_pretrained("bert-base-german-cased")
optimizer = torch.optim.Adam(params=model.parameters(), lr=0.00005)
batch_size = 32

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-german-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [14]:
# train-eval loop
for i in range(20):
    # train
    model.train()
    token_ids_batch, token_labels_batch, attention_mask_batch, _ = random_batch(
        token_ids=token_ids_train, 
        token_labels=token_labels_train,
        attention_mask=attention_mask_train,
        batch_size=batch_size)
    out = model.forward(input_ids=token_ids_batch, labels=token_labels_batch, attention_mask=attention_mask_batch)
    loss: torch.Tensor = out["loss"]
    logits: torch.Tensor = out["logits"]
    pred = torch.argmax(logits, dim=2)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    # eval
    with torch.no_grad():
        model.eval()
        # on evaluation data
        out = model.forward(input_ids=token_ids_eval, labels=token_labels_eval, attention_mask=attention_mask_eval)
        logits: torch.Tensor = out["logits"]
        pred = torch.argmax(logits, dim=2)
        acc_eval = acc(pred=pred, true=token_labels_eval)
        rec_eval = recall(true=token_labels_eval, pred=pred) 
        print(i, "accuracy eval:", acc_eval, "recall eval:", rec_eval, "\n")
        if rec_eval > 0.99:
            break
# save to disk
torch.save(model.state_dict(), "model_state_dict.pth")

0 accuracy eval: 0.9636101722717285 recall eval: 0.0 

1 accuracy eval: 0.9636101722717285 recall eval: 0.0 

2 accuracy eval: 0.9636101722717285 recall eval: 0.0 

3 accuracy eval: 0.9636101722717285 recall eval: 0.0 

4 accuracy eval: 0.9636101722717285 recall eval: 0.0 

5 accuracy eval: 0.9732730388641357 recall eval: 0.2655367231638418 

6 accuracy eval: 0.9917762875556946 recall eval: 0.7740112994350282 

7 accuracy eval: 0.9991776347160339 recall eval: 0.9774011299435028 

8 accuracy eval: 0.9989720582962036 recall eval: 0.9830508474576272 

9 accuracy eval: 0.9983552694320679 recall eval: 0.9943502824858758 



# Inference on OOD Data

In [23]:
tokenizer = BertTokenizer.from_pretrained("bert-base-german-cased")
model = BertForTokenClassification.from_pretrained("bert-base-german-cased")
model.load_state_dict(torch.load("model_state_dict.pth"))
text = "Henriette arbeitet heute zusammen mit Krankenschwestern. In der Uni hat Emilio eine neue Methode gelernt."
with torch.no_grad():
    ids, labels, mask = create_dataset(texts=[text], to_label=["Henriette"], label_as=[1], tokenizer=tokenizer)
    pred = torch.argmax(model(input_ids=ids, attention_mask=mask)["logits"], dim=2)
    tokens = [0] + tokenizer.tokenize(text)
    print([t for i, t in enumerate(tokens) if pred[0][i] == 1])

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-german-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


['Henri', '##ette', 'Emil', '##io']
