<a href="https://colab.research.google.com/github/Lukas-Forst/Blogposts/blob/main/NER_Eval.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# -*- coding: utf-8 -*-

# Install third-party dependencies. These are IPython `!` shell escapes, so this
# cell only runs inside a notebook environment (e.g. Colab), not as plain Python.
!pip install seqeval
!pip install transformers
!pip install datasets
!pip install wandb

import torch
import wandb
# Start a W&B run in project "ner_conll"; the `wandb.config.update(...)` and
# `wandb.finish()` calls further down in this notebook act on this run.
wandb.init(project='ner_conll', settings=wandb.Settings(start_method="thread"))

from transformers import AutoTokenizer, AutoModelForTokenClassification, Trainer, TrainingArguments
# NOTE(review): `load_dataset` and `load_metric` are not used in this notebook
# (data is read from local CoNLL files and scored with seqeval). `load_metric`
# appears to have been removed from recent `datasets` releases, so with an
# unpinned `pip install datasets` this line may raise ImportError — confirm and
# consider dropping it.
from datasets import load_dataset, load_metric, Dataset, DatasetDict
import numpy as np
# Sequence-labelling metrics used by `compute_metrics` below.
from seqeval.metrics import f1_score, precision_score, recall_score, classification_report

print("CUDA available:", torch.cuda.is_available())
# Bind the selected device so it can be reused; evaluating `torch.device(...)`
# without assigning it has no effect.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')



def read_conll_file(file_path, skip_docstart=False):
    """Parse a CoNLL-style column file into sentences of whitespace-split fields.

    Each non-blank line is one token and is split on whitespace into its
    columns. Blank (or whitespace-only) lines separate sentences.

    Args:
        file_path: Path of the text file to read (decoded as UTF-8).
        skip_docstart: If True, drop sentences whose first field is
            "-DOCSTART-". Defaults to False, which keeps such lines as ordinary
            one-token sentences, as before.

    Returns:
        A list of sentences. Each sentence is a non-empty list of tokens, and
        each token is the list of its column strings, e.g.
        ``[[["EU", "NNP", "B-NP", "B-ORG"], ...], ...]``. An empty file gives
        an empty list.
    """
    sentences = []
    current = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            fields = line.split()
            if fields:
                current.append(fields)
            elif current:
                # A blank line closes the sentence in progress. Runs of blank
                # lines or whitespace-only separators never create empty
                # sentences, which would break later `token_data[0]` lookups.
                sentences.append(current)
                current = []
    if current:  # last sentence when the file lacks a trailing blank line
        sentences.append(current)

    if skip_docstart:
        sentences = [s for s in sentences if s[0][0] != "-DOCSTART-"]
    return sentences


# Parse the three CoNLL splits, in train / validation / test order.
train_data, validation_data, test_data = (
    read_conll_file(split_path)
    for split_path in (
        "/content/train.txt",
        "/content/valid.txt",
        "/content/test.txt",
    )
)


def convert_to_dataset(data, label_map, token_col=0, tag_col=3):
    """Convert parsed CoNLL sentences into a Hugging Face ``Dataset``.

    Args:
        data: Sentences as returned by ``read_conll_file``. Each sentence is a
            list of per-token column lists.
        label_map: Mapping from tag string to integer id.
        token_col: Column index holding the surface token (default 0, as before).
        tag_col: Column index holding the NER tag (default 3, as before).

    Returns:
        A ``Dataset`` with a "tokens" column (lists of token strings) and a
        "ner_tags" column (lists of integer ids from ``label_map``).

    Raises:
        KeyError: If a tag is missing from ``label_map``. This can happen when
            the map was built from one split (e.g. train) and another split
            contains a tag it never saw.
    """
    formatted_data = {"tokens": [], "ner_tags": []}
    for sentence_idx, sentence in enumerate(data):
        tokens = [token_data[token_col] for token_data in sentence]
        try:
            ner_tags = [label_map[token_data[tag_col]] for token_data in sentence]
        except KeyError as err:
            # Same exception type, but say which tag/sentence failed and what
            # tags are known, instead of a bare KeyError('I-XYZ').
            raise KeyError(
                f"NER tag {err.args[0]!r} in sentence {sentence_idx} is not in "
                f"label_map (known tags: {sorted(label_map)})"
            ) from err
        formatted_data["tokens"].append(tokens)
        formatted_data["ner_tags"].append(ner_tags)
    return Dataset.from_dict(formatted_data)


# Tag vocabulary: every distinct value of column 3 in the *training* split,
# sorted so that tag -> id assignments are deterministic.
label_list = sorted({token_data[3] for sentence in train_data for token_data in sentence})
label_map = dict(zip(label_list, range(len(label_list))))

# Convert each split with the same tag -> id mapping.
train_dataset, validation_dataset, test_dataset = (
    convert_to_dataset(split, label_map)
    for split in (train_data, validation_data, test_data)
)

datasets = DatasetDict(
    train=train_dataset,
    validation=validation_dataset,
    test=test_dataset,
)

model_name = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Put the id <-> tag mapping into the model config. Saved checkpoints and
# `model.config` then carry the real tag names instead of generic placeholders.
model = AutoModelForTokenClassification.from_pretrained(
    model_name,
    num_labels=len(label_list),
    id2label=dict(enumerate(label_list)),
    label2id=label_map,
)

def compute_metrics(eval_prediction):
    """Score token-classification predictions with seqeval.

    Args:
        eval_prediction: A ``(logits, label_ids)`` pair from the Trainer.
            Logits are reduced with argmax over axis 2, so they are expected to
            be 3-D (presumably batch x sequence x num_labels).

    Returns:
        A dict with seqeval "precision", "recall" and "f1" scores, plus the
        seqeval "classification_report" text.
    """
    logits, label_ids = eval_prediction
    predicted_ids = np.argmax(logits, axis=2)

    true_predictions = []
    true_labels = []
    for pred_row, label_row in zip(predicted_ids, label_ids):
        # -100 marks positions that must not be scored: special tokens and
        # non-first sub-tokens (tokenize_and_align_labels) and batch padding
        # (data_collator).
        kept = [(pred_id, gold_id) for pred_id, gold_id in zip(pred_row, label_row) if gold_id != -100]
        true_predictions.append([label_list[pred_id] for pred_id, _ in kept])
        true_labels.append([label_list[gold_id] for _, gold_id in kept])

    return {
        "precision": precision_score(true_labels, true_predictions),
        "recall": recall_score(true_labels, true_predictions),
        "f1": f1_score(true_labels, true_predictions),
        "classification_report": classification_report(true_labels, true_predictions),
    }


def tokenize_and_align_labels(examples):
    """Tokenize pre-split words and align word-level NER ids to sub-tokens.

    Args:
        examples: A batch dict with "tokens" (list of word lists) and
            "ner_tags" (list of per-word label-id lists).

    Returns:
        The tokenizer output for the batch with an added "labels" entry. Each
        word's label goes on its first sub-token. Special tokens and further
        sub-tokens of the same word get -100, so they are ignored in the loss
        and in ``compute_metrics``.
    """
    # No padding here. `data_collator` pads every training/eval batch itself
    # with `pad_sequence`, so padding at map time only makes sequences longer
    # (to the longest example in this call) and wastes compute.
    tokenized_inputs = tokenizer(
        examples["tokens"], truncation=True, is_split_into_words=True
    )
    labels = []
    for i, word_labels in enumerate(examples["ner_tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)
        previous_word_idx = None
        label_ids = []
        for word_idx in word_ids:
            if word_idx is None:
                # Special token ([CLS]/[SEP]-style): no word behind it.
                label_ids.append(-100)
            elif word_idx != previous_word_idx:
                # First sub-token of a new word carries the word's label.
                label_ids.append(word_labels[word_idx])
            else:
                # Continuation sub-token of the same word.
                label_ids.append(-100)
            previous_word_idx = word_idx
        labels.append(label_ids)
    tokenized_inputs["labels"] = labels
    return tokenized_inputs

# `accelerate` is not referenced anywhere else in this notebook. Presumably it
# is imported only so a missing install fails here, early; if so, install it
# with `pip install accelerate -U`.
import accelerate  # noqa: F401

# Tokenize every split in batches and attach sub-token-aligned labels.
tokenized_datasets = datasets.map(
    tokenize_and_align_labels,
    batched=True,
)


import inspect

# Newer transformers releases accept `eval_strategy` and reject the older
# `evaluation_strategy` keyword; older releases only know the old name. Pick
# whichever keyword the installed TrainingArguments actually accepts.
_eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in inspect.signature(TrainingArguments).parameters
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    output_dir="./results",
    eval_steps=100,
    # save_steps matches eval_steps so the "best" checkpoint selected via
    # metric_for_best_model corresponds to an evaluated step.
    save_steps=100,
    num_train_epochs=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    logging_steps=20,
    learning_rate=5e-5,
    load_best_model_at_end=True,
    # "f1" is the key returned by compute_metrics.
    metric_for_best_model="f1",
    **{_eval_strategy_kwarg: "steps"},
)

def data_collator(data):
    """Pad a list of tokenized examples into a batch of equal-length tensors.

    Args:
        data: Examples, each with "input_ids", "attention_mask" and "labels"
            integer sequences.

    Returns:
        A dict of 2-D (batch, max_len) tensors under the same three keys.
        input_ids are padded with the tokenizer's pad id, attention_mask with 0
        (padding is not attended), and labels with -100 (padding is ignored by
        the loss and by compute_metrics).
    """
    fill_values = {
        "input_ids": tokenizer.pad_token_id,
        "attention_mask": 0,
        "labels": -100,
    }
    batch = {}
    for key, fill in fill_values.items():
        sequences = [torch.tensor(example[key]) for example in data]
        batch[key] = torch.nn.utils.rnn.pad_sequence(
            sequences, batch_first=True, padding_value=fill
        )
    return batch



# Record in W&B the hyperparameters actually passed to the Trainer. They are
# read from `training_args` rather than copied by hand, so the logged run config
# cannot drift from the real settings (epochs, step intervals, weight decay, ...).
_eval_strategy = getattr(training_args, "eval_strategy", None) or getattr(
    training_args, "evaluation_strategy", None
)
wandb.config.update({
    'num_train_epochs': training_args.num_train_epochs,
    'learning_rate': training_args.learning_rate,
    'weight_decay': training_args.weight_decay,
    'per_device_train_batch_size': training_args.per_device_train_batch_size,
    'per_device_eval_batch_size': training_args.per_device_eval_batch_size,
    'logging_dir': training_args.logging_dir,
    'logging_steps': training_args.logging_steps,
    'eval_steps': training_args.eval_steps,
    'save_steps': training_args.save_steps,
    # Unwrap an enum-typed strategy to its plain value if it has one.
    'evaluation_strategy': getattr(_eval_strategy, 'value', _eval_strategy),
})

# Wire the model, data, collator and metrics into a Trainer.
# NOTE(review): newer transformers releases deprecate `tokenizer=` in favour of
# `processing_class=`; confirm against the installed version.
_trainer_kwargs = dict(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,
)
trainer = Trainer(**_trainer_kwargs)

trainer.train()

# Close the W&B run started at the top of the notebook.
wandb.finish()

sentence = "Apple Inc. announced that Tim Cook will be attending the event in San Francisco."

# Tokenize the example and move its tensors to the model's device.
tokenized_input = tokenizer(sentence, return_tensors="pt").to(model.device)

# Inference only: disable dropout and gradient tracking. Nothing above leaves
# the model in a known train/eval mode after `trainer.train()`, so set it here.
model.eval()
with torch.no_grad():
    outputs = model(**tokenized_input)

# Highest-scoring label id for every token of the single sentence (batch index 0).
predicted_labels = outputs.logits.argmax(-1)[0]

# Keep tokens predicted as anything other than "O".
# - Only label_map['O'] is excluded. Id 0 is just the first entry of the sorted
#   label_list (with B-/I-/O tags, an entity tag), so it must not be filtered.
# - Special tokens have no word id and were trained with label -100 (see
#   tokenize_and_align_labels), so their predictions are meaningless and skipped.
outside_id = label_map['O']
word_ids = tokenized_input.word_ids(batch_index=0)
named_entities = [
    tokenizer.decode([int(token)])
    for token, label, word_id in zip(tokenized_input["input_ids"][0], predicted_labels, word_ids)
    if word_id is not None and label.item() != outside_id
]

print("Named Entities - Example 1:", named_entities)