In [12]:
import lightning as L

import itertools
import re
from torch.utils.data import Dataset
from pathlib import Path
import pandas as pd
import json

from transformers import AutoTokenizer, RobertaModel, BertModel, PreTrainedTokenizer
from tokenizers import Encoding

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchmetrics import Metric, F1Score, Precision, Recall, Accuracy
from torch import Tensor
from torch.nn import ModuleDict
from torch.utils.data import DataLoader
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint
import os

# 
def flatten_list(l):
    if not isinstance(l, list):
        return [l]
    return [item for sublist in l for item in flatten_list(sublist)]


def align_labels_to_text(text_encoding: Encoding, labels: list[dict], tag2label: dict):
    num_labels = len(tag2label.keys())
    text_labels = torch.zeros((text_encoding.input_ids.shape[1], num_labels))
    for label in labels:
        tag, start_idx, end_idx = label["tag"], int(label["start_span"]), int(label["end_span"])
        start_token_idx = text_encoding.char_to_token(start_idx)
        end_token_idx = text_encoding.char_to_token(end_idx - 1)
        text_labels[start_token_idx:end_token_idx, tag2label[tag]] = 1
    text_labels[~text_labels[:, 1:].any(dim=1)] = 1  # Adding null class if no other label is present
    return text_labels


def split_text(text: str, tokenizer: PreTrainedTokenizer, max_seq_len: int):
    paragraphs = re.split("(\n\n)", text)
    paragraphs = ["".join(paragraphs[i : i + 2]) for i in range(0, len(paragraphs), 2)]
    for p_idx in range(len(paragraphs)):
        ids = tokenizer.encode(paragraphs[p_idx], add_special_tokens=True)
        if len(ids) > max_seq_len:
            lines = re.split(("(\n)"), paragraphs[p_idx])
            lines = ["".join(lines[i : i + 2]) for i in range(0, len(lines), 2)]
            for l_idx in range(len(lines)):
                ids = tokenizer.encode(lines[l_idx], add_special_tokens=True)
                if len(ids) > max_seq_len:
                    sentences = re.split("([\.!\?]\s+)", lines[l_idx])
                    sentences = ["".join(sentences[i : i + 2]) for i in range(0, len(sentences), 2)]
                    for s_idx in range(len(sentences)):
                        ids = tokenizer.encode(sentences[s_idx], add_special_tokens=True)
                        if len(ids) > max_seq_len:
                            words = re.split("(\s+)", sentences[s_idx])
                            words = ["".join(words[i : i + 2]) for i in range(0, len(words), 2)]
                            sentences[s_idx] = words
                    lines[l_idx] = sentences
            paragraphs[p_idx] = lines
    splits = flatten_list(paragraphs)
    return splits


def get_tokens_indices(char_to_token_list: list[int], start_idx: int, end_idx: int):
    token_idx_list = [char_to_token_list[i] for i in range(start_idx, end_idx) if char_to_token_list[i] is not None]
    token_idx_list = [k for k, _ in itertools.groupby(token_idx_list)]
    return token_idx_list


def merge_splits_into_chunks(
    text: str,
    splits: list[str],
    tokenizer: PreTrainedTokenizer,
    max_seq_len: int,
    labels: list[dict],
    tag2label: dict,
):
    encoding = tokenizer(text, add_special_tokens=False, return_tensors="pt")
    char_to_token_list = [encoding.char_to_token(i) for i in range(len(text))]
    text_ids = encoding.input_ids[0]
    text_label_ids = align_labels_to_text(encoding, labels, tag2label)
    num_labels = len(tag2label.keys())
    assert len(text_ids) == len(text_label_ids)

    # Merge splits into chunks without exceeding max_seq_len
    start_chunk_idx, end_chunk_idx = 0, 0
    chunks = {"text": [], "input_ids": [], "label_ids": []}
    for i in range(len(splits) + 1):
        # TODO: optimize this
        if i < len(splits):
            # Compute the current chunk length after adding the next tokenized split
            sentence = splits[i]
            token_idx_list = get_tokens_indices(char_to_token_list, start_chunk_idx, end_chunk_idx + len(sentence))
            chunk_ids = text_ids[token_idx_list]
        if i == len(splits) or len(chunk_ids) > max_seq_len - 2:  # account for [CLS] and [SEP] token
            # add previous splits as a chunk if current chunk exceeds max_seq_len - 2 or if the splits are finished
            token_idx_list = get_tokens_indices(char_to_token_list, start_chunk_idx, end_chunk_idx)
            chunk_ids = torch.cat(
                [
                    torch.LongTensor([tokenizer.cls_token_id]),
                    text_ids[token_idx_list],
                    torch.LongTensor([tokenizer.sep_token_id]),
                ]
            )
            chunk_labels_ids = torch.cat(
                [
                    torch.LongTensor([[-100] * num_labels]),
                    text_label_ids[token_idx_list],
                    torch.LongTensor([[-100] * num_labels]),
                ],
            )
            chunks["text"].append(text[start_chunk_idx:end_chunk_idx])
            chunks["input_ids"].append(chunk_ids)
            chunks["label_ids"].append(chunk_labels_ids)
            start_chunk_idx = end_chunk_idx
        end_chunk_idx += len(sentence)
    return chunks

  sentences = re.split("([\.!\?]\s+)", lines[l_idx])
  words = re.split("(\s+)", sentences[s_idx])


In [14]:
class CardioCCC(Dataset):
    LABEL_FOLDERS = ["dis", "med", "symp", "proc"]

    def __init__(self, root_path: str, split: str, lang: str = "it", encoding: str = 'latin-1'):
        self.root_path = Path(root_path)
        self.split_file_names = json.load((self.root_path / "splits.json").open())[lang][split]["symp"]
        self.lang = lang
        batches = ["b1", "b2"] if lang != "ro" else ["b1"]
        self.annotations = []
        for batch in batches:
            lang_path = self.root_path / batch / "1_validated_without_sugs" / lang
            raw_annotations = []
            for label_folder in self.LABEL_FOLDERS:
                ann_path = lang_path / label_folder / "tsv"
                raw_annotations.append(pd.read_csv(next(ann_path.glob("*.tsv")), sep="\t", na_filter=False))
            raw_annotations = pd.concat(raw_annotations, axis=0)

            for group in raw_annotations.groupby("name"):
                if group[0] not in self.split_file_names:
                    continue
                file_name = group[0] + ".txt"
                text = (lang_path / "dis/txt" / file_name).read_text(encoding=encoding)
                labels = group[1].loc[:, ["tag", "start_span", "end_span", "text"]].to_dict(orient="records")
                self.annotations.append({"text": text, "labels": labels})

    def __len__(self):
        return len(self.annotations)

    def __getitem__(self, idx):
        return self.annotations[idx]


class ChunkedCardioCCC(Dataset):
    TAG2LABEL = {"0": 0, "DISEASE": 1, "MEDICATION": 2, "PROCEDURE": 3, "SYMPTOM": 4}
    LABEL2TAG = {v: k for k, v in TAG2LABEL.items()}

    def __init__(self, dataset: CardioCCC, tokenizer: PreTrainedTokenizer, language: str, iter_by_chunk: bool = False, model_max_len: int = 512):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.language = language
        self.chunked_data = []
        self.iter_by_chunk = iter_by_chunk
        for i, item in enumerate(dataset):
            text, labels = item["text"], item["labels"]
            splits = split_text(text, tokenizer, model_max_len)
            chunks = merge_splits_into_chunks(text, splits, tokenizer, model_max_len, labels, self.TAG2LABEL)
            if iter_by_chunk:
                for i in range(len(chunks["text"])):
                    self.chunked_data.append(
                        {
                            "text": chunks["text"][i],
                            "input_ids": chunks["input_ids"][i],
                            "label_ids": chunks["label_ids"][i],
                        }
                    )
            else:
                self.chunked_data.append(chunks)

    def __len__(self):
        return len(self.chunked_data)

    def __getitem__(self, idx):
        return self.chunked_data[idx]

In [15]:
def collate_fn_chunked_bert(batch: list[dict]):
    input_ids = [chunk["input_ids"] for chunk in batch]
    labels = [chunk["label_ids"] for chunk in batch]
    attention_mask = [torch.ones_like(ids) for ids in input_ids]
    input_ids = torch.nn.utils.rnn.pad_sequence(input_ids, batch_first=True, padding_value=0)
    attention_mask = torch.nn.utils.rnn.pad_sequence(attention_mask, batch_first=True, padding_value=0)
    labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=-100)
    return {"input_ids": input_ids, "labels": labels, "attention_mask": attention_mask}


class NEREval(Metric):
    def __init__(self, num_labels: int):
        super().__init__()
        self.num_labels = num_labels
        self.add_state("preds", default=[], dist_reduce_fx="cat")
        self.add_state("labels", default=[], dist_reduce_fx="cat")
        metric_classes_dict = {"f1": F1Score, "precision": Precision, "recall": Recall, "accuracy": Accuracy}
        self.classification_metrics = ModuleDict(
            {
                k
                + (f"_{avg}" if avg != "none" else ""): v(task="multilabel", num_labels=num_labels, average=avg, zero_division=1)
                for k, v in metric_classes_dict.items()
                for avg in ["none", "micro", "macro"]
            }
        )

    def update(self, preds: Tensor, labels: Tensor) -> None:
        self.preds.append(preds)
        self.labels.append(labels)

    def compute(self):
        preds, labels = self.preds, self.labels
        if isinstance(preds, list):
            preds, labels = torch.cat(self.preds), torch.cat(self.labels)

        results = {}
        for metric_name, metric in self.classification_metrics.items():
            results[metric_name] = metric(preds, labels)
            metric.reset()
        return results


class NERModule(L.LightningModule):
    def __init__(self, lm: nn.Module, lm_output_size: int, label2tag: int):
        super().__init__()
        self.lm = lm
        self.lm_output_size = lm_output_size
        self.label2tag = label2tag
        self.num_labels = len(label2tag.keys())
        self.classifier = nn.Linear(lm_output_size, self.num_labels)
        self.metric = NEREval(num_labels=self.num_labels)

    def exclude_padding_and_special_tokens(self, logits: torch.Tensor, labels: torch.Tensor):
        logits = logits.view(-1, self.num_labels)
        labels = labels.view(-1, self.num_labels)
        label_mask = labels[:, 0] != -100  # exclude padding and special tokens
        logits = logits[label_mask]
        labels = labels[label_mask]
        return logits, labels

    def training_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        sequence_out = self.lm(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        logits = self.classifier(sequence_out)
        logits, labels = self.exclude_padding_and_special_tokens(logits, labels)
        loss = F.binary_cross_entropy_with_logits(logits, labels)

        self.log("train_loss", loss, on_epoch=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        sequence_out = self.lm(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        logits = self.classifier(sequence_out)
        logits, labels = self.exclude_padding_and_special_tokens(logits, labels)
        loss = F.binary_cross_entropy_with_logits(logits, labels)
        self.log("val_loss", loss, on_epoch=True, sync_dist=True)
        preds = logits.sigmoid()
        self.metric.update(preds, labels)

    def on_validation_epoch_end(self):
        results = self.metric.compute()
        for k, v in results.items():
            if "micro" not in k and "macro" not in k:
                for i in range(self.num_labels):
                    self.log(f"val_{k}_class_{self.label2tag[i]}", v[i], on_epoch=True, sync_dist=True)
            else:
                self.log(f"val_{k}", v, on_epoch=True, sync_dist=True)
        self.metric.reset()

    def test_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]
        sequence_out = self.lm(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        logits = self.classifier(sequence_out)
        logits, labels = self.exclude_padding_and_special_tokens(logits, labels)
        preds = logits.sigmoid()
        self.metric.update(preds, labels)

    def on_test_epoch_end(self):
        results = self.metric.compute()
        new_results = {}
        for k, v in results.items():
            if "micro" not in k and "macro" not in k:
                for i in range(self.num_labels):
                    new_results[f"test_{k}_class_{self.label2tag[i]}"] = v[i].item()
                    self.log(f"test_{k}_class_{self.label2tag[i]}", v[i], on_epoch=True, sync_dist=True)
            else:
                new_results[f"test_{k}"] = v.item()
                self.log(f"test_{k}", v, on_epoch=True, sync_dist=True)
        self.metric.reset()
        return new_results

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=2e-5)
        return optimizer

In [26]:
batch_size = 32
patience = 5
num_workers = 4
max_epochs = 30
num_labels = len(ChunkedCardioCCC.TAG2LABEL.keys())
root_path = "T://laupodteam/AIOS/Bram/notebooks/code_dev/CardioNER.nl/assets"
lang = "nl"
model_name = "CLTL/MedRoBERTa.nl"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = RobertaModel.from_pretrained(model_name, add_pooling_layer=False)
devices = [4] #[0]
use_cpu = True

max_len = model.config.max_position_embeddings

print(f"The maximum length: {max_len}")

train = CardioCCC(root_path, "train", lang)
val = CardioCCC(root_path, "validation", lang)
test = CardioCCC(root_path, "test", lang)
train = ChunkedCardioCCC(train, tokenizer, lang, iter_by_chunk=True, model_max_len=max_len)
val = ChunkedCardioCCC(val, tokenizer, lang, iter_by_chunk=True,  model_max_len=max_len)
test = ChunkedCardioCCC(test, tokenizer, lang, iter_by_chunk=True)
train_loader = DataLoader(train, batch_size=batch_size, collate_fn=collate_fn_chunked_bert, shuffle=True, num_workers=num_workers)
val_loader = DataLoader(val, batch_size=batch_size, collate_fn=collate_fn_chunked_bert, shuffle=False, num_workers=num_workers)
test_loader = DataLoader(test, batch_size=batch_size, collate_fn=collate_fn_chunked_bert, shuffle=False, num_workers=num_workers)

module = NERModule(lm=model, lm_output_size=model.config.hidden_size, label2tag=train.LABEL2TAG)
trainer = L.Trainer(max_epochs=1)

if torch.cuda.is_available():
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    torch.set_float32_matmul_precision("medium")

callbacks = [
    EarlyStopping(monitor="val_loss", mode="min", patience=patience),
    ModelCheckpoint(monitor="val_loss", mode="min"),
]
strategy = "ddp_find_unused_parameters_true" if len(devices) > 1 else "auto" 
strategy = 'ddp' if use_cpu else strategy # ddp_spawn if use_cpu and not in notebook

trainer = L.Trainer(
    callbacks=callbacks,
    devices=devices[0] if use_cpu else devices,
    max_epochs=max_epochs,
    strategy=strategy,
    precision="16-mixed" if isinstance(devices, list) or devices == "cuda" else "bf16",
)
trainer.fit(module, train_dataloaders=train_loader, val_dataloaders=val_loader)
trainer.test(model=module, dataloaders=test_loader)

The maximum length: 514


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
c:\Users\bes3\AppData\Local\pypoetry\Cache\virtualenvs\cardioner-nl-o7dqlUGo-py3.12\Lib\site-packages\lightning\pytorch\trainer\connectors\accelerator_connector.py:513: You passed `Trainer(accelerator='cpu', precision='16-mixed')` but AMP with fp16 is not supported on CPU. Using `precision='bf16-mixed'` instead.
Using bfloat16 Automatic Mixed Precision (AMP)


MisconfigurationException: `Trainer(strategy='ddp')` is not compatible with an interactive environment. Run your code as a script, or choose a notebook-compatible strategy: `Trainer(strategy='ddp_notebook')`. In case you are spawning processes yourself, make sure to include the Trainer creation inside the worker function.