In [1]:
from ast import literal_eval
import functools
from typing import Any
import os

import torch
import tensorflow as tf
import lightning as pl
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import TensorBoardLogger
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
import transformers
from transformers import AutoTokenizer, BertForTokenClassification
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt

# Let float32 matrix multiplications trade some precision for speed on
# supported GPUs (per the PyTorch docs, "medium" may use bfloat16 internally).
torch.set_float32_matmul_precision(precision="medium")

2024-06-22 01:07:48.883330: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-22 01:07:48.908196: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI AVX512_BF16 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Location of the NER CSV. Override with the NER_DATA_PATH environment variable;
# the default is relative to the notebook's working directory.
data_path = os.environ.get("NER_DATA_PATH", "../ner.csv")

In [3]:
df = pd.read_csv(data_path)
# Fail fast if the CSV lacks the columns the tokenization cell reads
# ('Sentence' and 'Tag'), instead of hitting a KeyError several cells later.
_missing_columns = {"Sentence", "Tag"} - set(df.columns)
if _missing_columns:
    raise ValueError(f"{data_path} is missing expected column(s): {sorted(_missing_columns)}")

In [4]:
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
# Sorted list of every distinct tag in the corpus. A tag's position in this list
# is its integer class id (align_labels looks ids up here, and the model head is
# sized with len(all_tags)). Each 'Tag' cell is a string literal of a tag
# sequence, hence literal_eval. set().union(*...) also yields [] for an empty
# frame, where functools.reduce without an initial value raised TypeError.
all_tags = sorted(set().union(*df['Tag'].map(literal_eval)))

def align_labels(tags:list[str], tokenized_inputs: transformers.BatchEncoding, batch_index:int):
    word_ids = tokenized_inputs.word_ids(batch_index=batch_index)
    label_ids = []
    for word_idx in word_ids:
        if word_idx is None:
            label_ids.append(-100)
        else:
            tag = tags[word_idx]
            label_ids.append(all_tags.index(tag))
    return label_ids

def tokenize(df:pd.DataFrame):
    """Tokenize every sentence of ``df`` and attach token-aligned label ids.

    Sentences are split on single spaces and passed to the module-level
    ``tokenizer`` as pre-split words (with truncation). Each row's 'Tag' string
    is parsed with ``literal_eval`` and aligned to the tokens via
    ``align_labels``.

    Returns:
        The tokenizer's batch encoding with an extra "labels" entry holding one
        list of label ids per sentence.
    """
    word_lists = [sentence.split(" ") for sentence in df['Sentence']]
    tag_lists = [literal_eval(raw_tags) for raw_tags in df['Tag']]

    encoded = tokenizer(word_lists, truncation=True, is_split_into_words=True)
    encoded["labels"] = [
        align_labels(tag_list, encoded, row)
        for row, tag_list in enumerate(tag_lists)
    ]
    return encoded

# Tokenize the whole corpus once; the train/val datasets index into this result.
tokenized = tokenize(df=df)

In [5]:
class NERDataset(Dataset):
    """Torch ``Dataset`` over a subset (split) of a pre-tokenized NER corpus.

    Item ``i`` is the sentence at corpus position ``indexes[i]``, so several
    splits can share one tokenized corpus without copying it.
    """

    def __init__(self, tokenized_data: "transformers.BatchEncoding", indexes: list[int]) -> None:
        """
        Args:
            tokenized_data: output of ``tokenize``. Integer indexing must yield an
                encoding exposing ``.ids`` and ``.attention_mask``, and
                ``tokenized_data["labels"]`` must hold one label-id list per sentence.
            indexes: corpus positions that belong to this split.
        """
        super().__init__()
        self.tokenized_data = tokenized_data
        self.indexes = indexes

    def __getitem__(self, index):
        """Return long tensors for the ``index``-th sentence of this split."""
        # Translate the split-local position into a corpus row. Using ``index``
        # directly would make every split read rows 0..len(split)-1, so the
        # validation set would overlap the training set.
        row = self.indexes[index]
        encoding = self.tokenized_data[row]
        return {
            "input_ids": torch.tensor(encoding.ids, dtype=torch.long),
            "attention_mask": torch.tensor(encoding.attention_mask, dtype=torch.long),
            "labels": torch.tensor(self.tokenized_data["labels"][row], dtype=torch.long),
        }

    def __len__(self):
        return len(self.indexes)


def collate(samples):
    padding_id = tokenizer.get_vocab()["[PAD]"]
    input_ids = torch.nn.utils.rnn.pad_sequence(
        [sample["input_ids"] for sample in samples], batch_first=True, padding_value=padding_id
    )
    attention_mask = torch.nn.utils.rnn.pad_sequence(
        [sample["attention_mask"] for sample in samples], batch_first=True, padding_value=0
    )
    labels = torch.nn.utils.rnn.pad_sequence(
        [sample["labels"] for sample in samples], batch_first=True, padding_value=-100
    )
    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
    }

In [6]:
import random

# Seed for the reproducible shuffle of sentence positions before splitting.
SPLIT_SEED = 1337
# Number of shuffled sentences used for training; the remainder is validation.
# NOTE(review): looks like ~80% of the corpus — confirm against len(df).
N_TRAIN = 38367
# Set to a small int (e.g. 100) for a quick smoke run on tiny train/val subsets;
# None uses the full split.
DEBUG_SUBSET = None

indexes = list(range(len(df)))
random.Random(SPLIT_SEED).shuffle(indexes)

if DEBUG_SUBSET is None:
    train_indexes = indexes[:N_TRAIN]
    val_indexes = indexes[N_TRAIN:]
else:
    train_indexes = indexes[:DEBUG_SUBSET]
    val_indexes = indexes[-DEBUG_SUBSET:]

# An empty split would otherwise only fail later during training, far from its cause.
if not train_indexes or not val_indexes:
    raise ValueError(
        f"empty split: {len(train_indexes)} train / {len(val_indexes)} val "
        f"sentences out of {len(indexes)}"
    )

train_dataset = NERDataset(tokenized, train_indexes)
val_dataset = NERDataset(tokenized, val_indexes)

In [7]:
class MetricsCalculator:
    def __init__(self) -> None:
        self.targets = []
        self.predictions = []

    def update(self, logits: torch.Tensor, targets: torch.Tensor) -> None:
        prediction = logits.argmax(-1).view(-1)
        targets = targets.view(-1)

        prediction = prediction[targets != -100].cpu().numpy()
        targets = targets[targets != -100].cpu().numpy()

        self.predictions.append(prediction)
        self.targets.append(targets)

    def get_confusion_matrix(self) -> np.ndarray:
        y_true = np.concatenate(self.targets)
        y_pred = np.concatenate(self.predictions)

        fig = plt.figure()
        ax = fig.add_subplot()
        ConfusionMatrixDisplay.from_predictions(
            y_true,
            y_pred,
            display_labels=all_tags,
            labels=list(range(len(all_tags))),
            ax=ax,
            xticks_rotation="vertical",
            normalize="true",
            values_format=".2f",
            text_kw={"fontsize": "xx-small"}
        )
        fig.canvas.draw()
        data = np.array(fig.canvas.renderer.buffer_rgba())
        plt.close(fig)

        return data

    def get_metrics(self) -> dict[str, float]:
        y_true = np.concatenate(self.targets)
        y_pred = np.concatenate(self.predictions)

        accuracy = accuracy_score(y_true, y_pred)
        f1 = f1_score(y_true, y_pred, average="macro")

        return {
            "accuracy": accuracy,
            "f1": f1,
        }

    def clear(self):
        self.targets.clear()
        self.predictions.clear()


class NERModule(pl.LightningModule):
    """Lightning wrapper around ``BertForTokenClassification`` for the NER tags.

    Collects per-epoch loss, token accuracy and macro-F1 for training and
    validation, logs them (validation keys prefixed with ``val_``), and writes the
    validation confusion matrix as an image when a TensorBoard logger is attached.
    """

    def __init__(self) -> None:
        super().__init__()
        self.model = BertForTokenClassification.from_pretrained(
            "bert-base-uncased", torch_dtype=torch.float32, attn_implementation="sdpa", num_labels=len(all_tags)
        )
        self.train_metrics_calculator = MetricsCalculator()
        self.val_metrics_calculator = MetricsCalculator()

        # Per-step loss values; averaged and cleared at the end of each epoch.
        self.losses = []
        self.val_losses = []

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(params=self.model.parameters(), lr=1e-05)
        return optimizer

    def train_dataloader(self) -> Any:
        return DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0, collate_fn=collate)

    def val_dataloader(self) -> Any:
        return DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0, collate_fn=collate)

    def forward(self, x):
        # x is a collated batch dict (input_ids, attention_mask, labels); it is
        # passed straight through, and the steps below read .loss and .logits.
        return self.model(**x)

    def _shared_step(self, batch, metrics_calculator, losses):
        """Run the model on ``batch``, record predictions and loss, return the loss."""
        output = self(batch)
        metrics_calculator.update(output.logits, batch["labels"])
        losses.append(output.loss.item())
        return output.loss

    def training_step(self, x):
        return self._shared_step(x, self.train_metrics_calculator, self.losses)

    def validation_step(self, x):
        return self._shared_step(x, self.val_metrics_calculator, self.val_losses)

    def on_train_epoch_end(self) -> None:
        metrics = self.train_metrics_calculator.get_metrics()
        metrics["loss"] = np.mean(self.losses)
        self.log_dict(metrics, on_epoch=True)
        self.losses.clear()
        self.train_metrics_calculator.clear()

    def on_validation_epoch_end(self) -> None:
        metrics = self.val_metrics_calculator.get_metrics()
        metrics["loss"] = np.mean(self.val_losses)
        metrics = {"val_" + k: v for k, v in metrics.items()}

        # add_image exists on TensorBoard's writer only; calling it unconditionally
        # crashes with logger=None or another logger type. The sanity-check pass
        # (a few validation batches before any training) is skipped as well.
        log_image = isinstance(self.logger, TensorBoardLogger) and not self.trainer.sanity_checking
        # Must be rendered before the calculator is cleared below.
        confusion_mat = self.val_metrics_calculator.get_confusion_matrix() if log_image else None

        self.log_dict(metrics, on_epoch=True)
        self.val_losses.clear()
        self.val_metrics_calculator.clear()
        if confusion_mat is not None:
            self.logger.experiment.add_image(
                "confusion matrix", confusion_mat.astype(np.float32) / 255, dataformats="HWC", global_step=self.global_step
            )

In [8]:
# Seed Python/NumPy/torch RNGs before building the model so the freshly
# initialized classifier head and DataLoader shuffling are reproducible.
SEED = 1337
pl.seed_everything(SEED)

model = NERModule()
logger = TensorBoardLogger("tb_logs")
# One checkpoint directory per TensorBoard run version, so top-k checkpoints
# from earlier runs are not mixed into this run's directory.
checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join("checkpoints", f"version_{logger.version}"),
    save_top_k=2,
    monitor="val_f1",
    mode="max",
)
trainer = pl.Trainer(
    max_epochs=20,
    precision="16-mixed",
    accelerator="gpu",
    logger=logger,
    callbacks=[checkpoint_callback],
)
trainer.fit(model)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/pokropow/programming/Named Entity Recognition/venv/lib/python3.10/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:652: Checkpoint directory /home/pokropow/programming/Named Entity Recognition/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                       | Params | Mode
------------------------------------------------------------
0 | model | BertForTokenClassification | 108 M  | eval
------------------------------------------------------------
108 M     Trainable params

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/pokropow/programming/Named Entity Recognition/venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


                                                                           

/home/pokropow/programming/Named Entity Recognition/venv/lib/python3.10/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Epoch 19: 100%|██████████| 1199/1199 [00:47<00:00, 25.27it/s, v_num=8]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 1199/1199 [00:47<00:00, 25.11it/s, v_num=8]
