# Install libraries

In [None]:
%pip install -r ../requirements.txt

In [4]:
# library
import os
#from pprint import pprint as pp

import torch
from torch import nn
from transformers import LongformerTokenizer, AutoTokenizer
from torch.utils.data import Dataset, DataLoader

from torch.optim import AdamW
from torch.nn import CrossEntropyLoss

from collections import defaultdict

  from .autonotebook import tqdm as notebook_tqdm


### load data

In [5]:
first_dataset_doc_path = "./dataset/First_Phase_Release(Correction)/First_Phase_Text_Dataset/"
second_dataset_doc_path = "./dataset/Second_Phase_Dataset/Second_Phase_Text_Dataset/"
label_path = ["./dataset/First_Phase_Release(Correction)/answer.txt", "./dataset/Second_Phase_Dataset/answer.txt"]
val_dataset_doc_parh = "./dataset/validation_dataset/Validation_Release/"
val_label_path = "./dataset/validation_dataset/answer.txt"

first_dataset_path = [first_dataset_doc_path + file_path for file_path in os.listdir(first_dataset_doc_path)]
second_dataset_path = [second_dataset_doc_path + file_path for file_path in os.listdir(second_dataset_doc_path)]
train_path = first_dataset_path + second_dataset_path
val_path = [val_dataset_doc_parh + file_path for file_path in os.listdir(val_dataset_doc_parh)]

#check number of data-path
print(len(first_dataset_path)) #1120
print(len(second_dataset_path)) #614
print()
print(len(train_path)) #1734
print(len(val_path)) #560

1121
615

1736
561


In [6]:
def create_label_dict(label_path):
    """
    Read labeled data from a file and create dictionaries for training and validation datasets.

    uses the `create_label_dict` function to read label files, addressing the potential UTF-8 BOM issue (U+FEFF) by using the `utf-8-sig` encoding.

    # Function: create_label_dict

    ## Description
    This function reads a label file and creates a dictionary containing labeled data. It removes the UTF-8 BOM if present using the `utf-8-sig` encoding.

    ## Parameters
    - `label_path`: The path to the label file that needs to be processed.

    ## Returns
    - `label_dict`: A dictionary containing labeled data with unique IDs as keys and corresponding label information as values.
    """
    label_dict = {}  # y
    with open(label_path, "r", encoding="utf-8-sig") as f:
        file_text = f.read().strip()  

    # (id, label, start, end, query) or (id, label, start, end, query, time_org, timefix)
    for line in file_text.split("\n"):
        sample = line.split("\t")  
        sample[2], sample[3] = int(sample[2]), int(sample[3])

        if sample[0] not in label_dict:
            label_dict[sample[0]] = [sample[1:]]
        else:
            label_dict[sample[0]].append(sample[1:])

    return label_dict

train_label_dict = create_label_dict(label_path[0])
second_dataset_label_dict = create_label_dict(label_path[1])
train_label_dict.update(second_dataset_label_dict)
val_label_dict = create_label_dict(val_label_path)

In [7]:
def load_medical_records(paths):
    """
    Function to load medical records from text files and create dictionaries for training and validation datasets.

    Description :
    This function takes a list of file paths, reads the corresponding text files, and creates a dictionary containing medical records.
    Each file is identified by its unique ID, extracted from the file path.

    Parameters :
    - `paths`: A list of file paths containing medical records.

    Returns :
    - `medical_record_dict`: A dictionary containing medical records, where file IDs are used as keys, and the corresponding text content is the value.
    """
    medical_record_dict = {}
    for data_path in paths:

        if os.path.isfile(data_path):
            file_id = data_path.split("/")[-1].split(".txt")[0]
            with open(data_path, "r", encoding="utf-8") as f:
                file_text = f.read()
                medical_record_dict[file_id] = file_text
    return medical_record_dict

train_medical_record_dict = load_medical_records(train_path)
val_medical_record_dict = load_medical_records(val_path)

In [8]:
#chect the number of data
print(len(list(train_medical_record_dict.keys()))) #1734
print(len(list(train_label_dict.keys()))) #1734
print(len(list(val_medical_record_dict.keys()))) #560
print(len(list(val_label_dict.keys()))) #560

1734
1734
560
560


### clean data

In [9]:
def check_labels(text, labels, record_id, tag=False):
    """
    Check if the extracted labels from the text match the expected labels.

    Parameters:
    - text (str): The full text of the medical record.
    - labels (list): A list of labels containing (id, start, end, expected_text).
    - record_id (str): The identifier of the medical record.
    - tag (bool): If True, print correct extractions as well.

    Returns:
    None
    """
    for i, label in enumerate(labels):  
        extracted_text = text[label[1]:label[2]]
        if extracted_text != label[3]:
            print(f"Error in ID {record_id}, Line {i}: {label[0]}, position: {label[1]}-{label[2]}, "
                  f"label: '{label[3]}', extracted: '{extracted_text}'")
        elif tag:
            print(f"Correct in ID {record_id}, Line {i}: {label[0]}, position: {label[1]}-{label[2]}, extracted: '{extracted_text}'")

def check_all_labels(medical_records, label_dict, tag=False):
    """
    Check labels for all medical records against the provided label dictionary.

    Parameters:
    - medical_records (dict): A dictionary with medical record IDs as keys and corresponding text as values.
    - label_dict (dict): A dictionary with medical record IDs as keys and lists of labels as values.
    - tag (bool): If True, print correct extractions as well.

    Returns:
    None
    """
    for record_id, text in medical_records.items():
        if record_id in label_dict:
            labels = label_dict[record_id]
            check_labels(text, labels, record_id, tag)
        else:
            print(f"ID: {record_id} has no label")

In [10]:
# check training data
check_all_labels(train_medical_record_dict, train_label_dict)   

Error in ID 1139, Line 16: HOSPITAL, position: 2702-2722, label: 'PLANTAGENET HOSPITAL', extracted: 'PLANTAGENE3/9 JENNIE'
Error in ID 1481, Line 21: DEPARTMENT, position: 2390-2403, label: 'SEALS Central', extracted: 'SEAKALBARRI H'


In [11]:
# check 1139, PLANTAGENET 3/9 JENNIE COX CLOSE Pathology ?
print(train_medical_record_dict['1139'][2702:2722])
print(train_label_dict['1139'][16])

# replace it
train_label_dict['1139'][16][3]=train_medical_record_dict['1139'][2702:2722]

PLANTAGENE3/9 JENNIE
['HOSPITAL', 2702, 2722, 'PLANTAGENET HOSPITAL']


In [12]:
# check 1481, there is no DEPARTMENT
print(train_medical_record_dict['1481'][2390:2403])
print(train_label_dict['1481'][21])

# remove it 
train_label_dict['1481'].pop(21)

SEAKALBARRI H
['DEPARTMENT', 2390, 2403, 'SEALS Central']


['DEPARTMENT', 2390, 2403, 'SEALS Central']

In [13]:
# check val data
check_all_labels(val_medical_record_dict, val_label_dict) 

Error in ID file21297, Line 20: ORGANIZATION, position: 6045-6064, label: 'KB Home Los Angeles', extracted: 'KB Home	Los Angeles'


In [14]:
# check file21297, index 6047 is '\t'
val_medical_record_dict['file21297'][6045:6064]

# replace it
val_medical_record_dict['file21297'] = val_medical_record_dict['file21297'][:6047] + ' ' + val_medical_record_dict['file21297'][6048:]

### create label type table

In [15]:
#add special token [other] in label list
labels_type = list(set( [label[0] for labels in train_label_dict.values() for label in labels] ))
labels_type = ["OTHER"] + labels_type 
labels_num = len(labels_type)
# print(labels_type)
# print("The number of labels:", labels_num)
labels_type_table = {label_name:id for id, label_name in enumerate(labels_type)}
print(labels_type_table)

{'OTHER': 0, 'ROOM': 1, 'STREET': 2, 'IDNUM': 3, 'TIME': 4, 'DATE': 5, 'MEDICALRECORD': 6, 'PHONE': 7, 'AGE': 8, 'URL': 9, 'ZIP': 10, 'COUNTRY': 11, 'DURATION': 12, 'HOSPITAL': 13, 'PATIENT': 14, 'ORGANIZATION': 15, 'DEPARTMENT': 16, 'STATE': 17, 'DOCTOR': 18, 'LOCATION-OTHER': 19, 'CITY': 20, 'SET': 21}


In [16]:
# fix it
labels_type_table={'OTHER': 0, 'PATIENT': 1, 'DOCTOR': 2, 'CITY': 3, 'ROOM': 4, 'STREET': 5, 'MEDICALRECORD': 6, 'DEPARTMENT': 7, 'LOCATION-OTHER': 8, 'COUNTRY': 9, 'IDNUM': 10, 'STATE': 11, 'AGE': 12, 'SET': 13, 'HOSPITAL': 14, 'DATE': 15, 'ZIP': 16, 'URL': 17, 'DURATION': 18, 'ORGANIZATION': 19, 'TIME': 20, 'PHONE': 21}
print(labels_type_table)

{'OTHER': 0, 'PATIENT': 1, 'DOCTOR': 2, 'CITY': 3, 'ROOM': 4, 'STREET': 5, 'MEDICALRECORD': 6, 'DEPARTMENT': 7, 'LOCATION-OTHER': 8, 'COUNTRY': 9, 'IDNUM': 10, 'STATE': 11, 'AGE': 12, 'SET': 13, 'HOSPITAL': 14, 'DATE': 15, 'ZIP': 16, 'URL': 17, 'DURATION': 18, 'ORGANIZATION': 19, 'TIME': 20, 'PHONE': 21}


In [17]:
#check the label_type is enough for validation
val_labels_type = list(set( [label[0] for labels in val_label_dict.values() for label in labels] ))
for val_label_type in val_labels_type:
    if val_label_type not in labels_type:
        print("Special label in validation:", val_label_type)

### Load pre-trained model

In [18]:
model_name = "allenai/longformer-base-4096"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

In [20]:
from transformers import LongformerModel
from torchcrf import CRF

class MyLongformerModel(nn.Module):
    """
    Custom PyTorch model using the Longformer architecture with Conditional Random Fields (CRF) for sequence labeling.

    Parameters:
    - num_labels (int): The number of unique labels/classes for sequence labeling.

    Attributes:
    - longformer (LongformerModel): Pre-trained Longformer model.
    - dropout (nn.Dropout): Dropout layer to prevent overfitting.
    - classifier (nn.Linear): Linear layer for classification.
    - crf (CRF): Conditional Random Field layer for sequence labeling.

    Methods:
    - forward(input_ids, attention_mask, labels=None): Forward pass of the model.

    Example Usage:
    ```python
    model = MyLongformerModel(num_labels=22)
    ```

    """

    def __init__(self, num_labels):
        """
        Initializes the MyLongformerModel.

        Parameters:
        - num_labels (int): The number of unique labels/classes for sequence labeling.
        """
        super(MyLongformerModel, self).__init__()

        # Pre-trained Longformer model
        self.longformer = LongformerModel.from_pretrained('allenai/longformer-base-4096')

        # Dropout layer to prevent overfitting
        self.dropout = nn.Dropout(p=0.1)

        # Linear layer for classification
        self.classifier = nn.Linear(768, num_labels)

        # Conditional Random Field layer for sequence labeling
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        """
        Forward pass of the MyLongformerModel.

        Parameters:
        - input_ids (torch.Tensor): Input tensor containing token IDs.
        - attention_mask (torch.Tensor): Attention mask tensor indicating which tokens should be attended to.
        - labels (torch.Tensor): Ground truth labels for sequence labeling. If None, decoding is performed.

        Returns:
        - loss (torch.Tensor) if labels are provided, else decoded sequence labels (torch.Tensor).
        """
        outputs = self.longformer(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)

        if labels is not None:
            loss = -self.crf(logits, labels, mask=attention_mask.byte())
            return loss
        else:
            return self.crf.decode(logits, mask=attention_mask.byte())

# Usage
model = MyLongformerModel(num_labels=22)

In [21]:
BACH_SIZE = 4
#TRAIN_RATIO = 0.9
LEARNING_RATE = 1e-4
EPOCH = 10

### dataloader

In [22]:
import torch
from torch.utils.data import Dataset, DataLoader

class PrivacyProtectionDataset(Dataset):
    """
    Custom PyTorch Dataset class for privacy protection tasks with medical records.

    Parameters:
    - medical_record_dict (dict): A dictionary containing medical record IDs as keys and text content as values.
    - medical_record_labels (dict): A dictionary containing medical record IDs as keys and lists of labels as values.
    - tokenizer: Tokenizer for encoding the text.
    - labels_type_table (dict): A dictionary mapping label types to numerical IDs.
    - mode (str): Mode of the dataset, e.g., 'train', 'test', or 'val'.

    Attributes:
    - max_length (int): Maximum length for text chunks.
    - labels_type_table (dict): A dictionary mapping label types to numerical IDs.
    - tokenizer: Tokenizer for encoding the text.
    - data (list): A list containing tuples of text chunks, corresponding labels, and record IDs.

    Methods:
    - split_and_add_data(text, labels, id): Splits text into chunks and adds data to the dataset.
    - __getitem__(index): Returns a tuple containing text chunk, labels, and record ID for a given index.
    - __len__(): Returns the total number of items in the dataset.
    - find_token_ids(label_start, label_end, offset_mapping): Finds token IDs corresponding to label positions after tokenization.
    - encode_labels_position(batch_labels, offset_mapping): Encodes the positions of labels in tokenized text.
    - create_labels_tensor(batch_shape, batch_labels_position_encoded): Creates a tensor representing labels for the batch.
    - collate_fn(batch_items): Collates a batch of items during data loading.

    Example Usage:
    ```python
    dataset = PrivacyProtectionDataset(medical_record_dict, medical_record_labels, tokenizer, labels_type_table, mode='train')
    dataloader = DataLoader(dataset, batch_size=32, collate_fn=dataset.collate_fn)
    ```

    """

    def __init__(self, medical_record_dict: dict, medical_record_labels: dict, tokenizer, labels_type_table: dict, mode: str):
        """
        Initializes the PrivacyProtectionDataset.

        Parameters:
        - medical_record_dict (dict): A dictionary containing medical record IDs as keys and text content as values.
        - medical_record_labels (dict): A dictionary containing medical record IDs as keys and lists of labels as values.
        - tokenizer: Tokenizer for encoding the text.
        - labels_type_table (dict): A dictionary mapping label types to numerical IDs.
        - mode (str): Mode of the dataset, e.g., 'train', 'test', or 'val'.
        """
        self.max_length = 4096
        self.labels_type_table = labels_type_table
        self.tokenizer = tokenizer
        self.data = []

        for id, text in medical_record_dict.items():
            labels = medical_record_labels.get(id, [])
            self.split_and_add_data(text, labels, id)

    def split_and_add_data(self, text, labels, id):
        """
        Splits the text into chunks and adds data to the dataset.

        Parameters:
        - text (str): The full text of the medical record.
        - labels (list): A list of labels containing (id, start, end, expected_text).
        - id (str): The identifier of the medical record.

        Returns:
        None
        """
        # Split text into chunks of max_length
        for i in range(0, len(text), self.max_length):
            text_chunk = text[i:i + self.max_length]
            # Adjust labels for this chunk
            chunk_labels = [label for label in labels if label[1] >= i and label[2] <= i + self.max_length]
            chunk_labels = [[label[0], label[1] - i, label[2] - i] for label in chunk_labels]
            self.data.append((text_chunk, chunk_labels, id))

    def __getitem__(self, index):
        """
        Returns a tuple containing text chunk, labels, and record ID for a given index.

        Parameters:
        - index (int): Index of the item in the dataset.

        Returns:
        - tuple: A tuple containing text chunk, labels, and record ID.
        """
        text_chunk, chunk_labels, id = self.data[index]
        return text_chunk, chunk_labels, id

    def __len__(self):
        """
        Returns the total number of items in the dataset.

        Returns:
        - int: The total number of items in the dataset.
        """
        return len(self.data)

    def find_token_ids(self, label_start, label_end, offset_mapping):
        """
        Finds token IDs corresponding to label positions after tokenization.

        Parameters:
        - label_start (int): Start position of the label.
        - label_end (int): End position of the label.
        - offset_mapping (list): List of token offset mappings.

        Returns:
        - tuple: A tuple containing the start and end token IDs.
        """
        encode_start = float("inf")  # max
        encode_end = 0
        for token_id, token_range in enumerate(offset_mapping):
            token_start, token_end = token_range

            # if token range one side out of label range, still take the token
            if token_start == 0 and token_end == 0:  # special token
                continue

            if label_start < token_end and label_end > token_start:
                if token_id < encode_start:
                    encode_start = token_id
                encode_end = token_id + 1

        return encode_start, encode_end

    def encode_labels_position(self, batch_labels: list, offset_mapping: list):
        """
        Encodes the positions of labels in tokenized text.

        Parameters:
        - batch_labels (list): List of labels for a batch.
        - offset_mapping (list): List of token offset mappings for the batch.

        Returns:
        - list: List of encoded label positions for the batch.
        """
        batch_encoding_labels = []
        for sample_labels, sample_offsets in zip(batch_labels, offset_mapping):
            encoding_labels = []
            for label in sample_labels:
                encoding_start, encoding_end = self.find_token_ids(label[1], label[2], sample_offsets)
                encoding_labels.append([label[0], encoding_start, encoding_end])
            batch_encoding_labels.append(encoding_labels)
        return batch_encoding_labels

    def create_labels_tensor(self, batch_shape: list, batch_labels_position_encoded: list):
        """
        Creates a tensor representing labels for the batch.

        Parameters:
        - batch_shape (list): Shape of the tensor to be created.
        - batch_labels_position_encoded (list): List of encoded label positions for the batch.

        Returns:
        - torch.Tensor: Tensor representing labels for the batch.
        """
        if batch_shape[-1] > self.max_length:
            batch_shape[-1] = self.max_length
        labels_tensor = torch.zeros(batch_shape)

        for sample_id in range(batch_shape[0]):
            for label in batch_labels_position_encoded[sample_id]:
                label_id = self.labels_type_table[label[0]]
                start = label[1]
                end = label[2]

                if start >= self.max_length:
                    continue
                elif end >= self.max_length:
                    end = self.max_length

                labels_tensor[sample_id][start:end] = label_id

        return labels_tensor

    def collate_fn(self, batch_items: list):
        """
        Collates a batch of items during data loading.

        Parameters:
        - batch_items (list): List of items in the batch.

        Returns:
        - tuple: A tuple containing tokenized encodings, labels tensor, and original labels.
        """
        # the calculation process in dataloader iteration
        batch_medical_record = [sample[0] for sample in batch_items]
        batch_labels = [sample[1] for sample in batch_items]
        batch_id_list = [sample[2] for sample in batch_items]

        encodings = self.tokenizer(batch_medical_record, padding=True, truncation=True, return_tensors="pt",
                                   return_offsets_mapping=True)  # truncation=True

        batch_labels_position_encoded = self.encode_labels_position(batch_labels, encodings["offset_mapping"])
        batch_labels_tensor = self.create_labels_tensor(encodings["input_ids"].shape, batch_labels_position_encoded)

        return encodings, batch_labels_tensor, batch_labels

In [23]:
train_id_list = list(train_medical_record_dict.keys())
train_medical_record = {sample_id: train_medical_record_dict[sample_id] for sample_id in train_id_list}
train_labels = {sample_id: train_label_dict[sample_id] for sample_id in train_id_list}

val_id_list = list(val_medical_record_dict.keys())
val_medical_record = {sample_id: val_medical_record_dict[sample_id] for sample_id in val_id_list}
val_labels = {sample_id: val_label_dict[sample_id] for sample_id in val_id_list}

train_dataset = Privacy_protection_dataset(train_medical_record, train_labels, tokenizer, labels_type_table, "train")
val_dataset = Privacy_protection_dataset(val_medical_record, val_labels, tokenizer, labels_type_table, "validation")


train_dataloader = DataLoader( train_dataset, batch_size = BACH_SIZE, shuffle = True, collate_fn = train_dataset.collate_fn)
val_dataloader = DataLoader( val_dataset, batch_size = BACH_SIZE, shuffle = False, collate_fn = val_dataset.collate_fn)

### Training

In [26]:

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device) # Put model on device
optim = AdamW(model.parameters(), lr = LEARNING_RATE)
#if use CRF
#loss_fct = CrossEntropyLoss()

In [27]:
from collections import defaultdict

def decode_model_result(model_predict_list, offsets_mapping, labels_type_table):
    """
    Decode the model predictions into labeled entities.

    Parameters:
    - model_predict_list (list): List of predicted label IDs.
    - offsets_mapping (list): List of offset mappings obtained from tokenization.
    - labels_type_table (dict): Dictionary mapping label types to their corresponding numerical identifiers.

    Returns:
    - list: List of decoded labeled entities.
    """
    id_to_label = {id: label for label, id in labels_type_table.items()}
    predict_y = []
    pre_label_id = 0
    start = 0

    for position_id, label_id in enumerate(model_predict_list):
        if label_id != 0:
            if pre_label_id != label_id:
                start = int(offsets_mapping[position_id][0])
            end = int(offsets_mapping[position_id][1])

        if pre_label_id != label_id and pre_label_id != 0:
            predict_y.append([id_to_label[pre_label_id], start, end])
        pre_label_id = label_id

    if pre_label_id != 0:
        predict_y.append([id_to_label[pre_label_id], start, end])

    return predict_y

def calculate_batch_score(batch_labels, model_predict_sequences, offset_mappings, labels_type_table):
    """
    Calculate TP, FP, and FN for each label in a batch.

    Parameters:
    - batch_labels (list): List of ground truth labels for a batch.
    - model_predict_sequences (list): List of predicted label sequences for a batch.
    - offset_mappings (list): List of offset mappings obtained from tokenization.
    - labels_type_table (dict): Dictionary mapping label types to their corresponding numerical identifiers.

    Returns:
    - dict: A nested dictionary containing TP, FP, and FN scores for each label.
    """
    score_table = defaultdict(lambda: {"TP": 0, "FP": 0, "FN": 0})
    id_to_label = {id: label for label, id in labels_type_table.items()}
    batch_size = len(model_predict_sequences)

    for batch_id in range(batch_size):
        sample_prediction = decode_model_result(model_predict_sequences[batch_id], offset_mappings[batch_id], labels_type_table)
        sample_ground_truth = batch_labels[batch_id]

        # convert ground truth and predictions to sets for comparison
        sample_ground_truth = set([tuple(token) for token in sample_ground_truth])
        sample_prediction = set([tuple(token) for token in sample_prediction])

        # calculate TP, FP, FN for each label
        for label_id in labels_type_table.values():
            label = id_to_label[label_id]
            gt_entities = {x for x in sample_ground_truth if x[0] == label}
            pred_entities = {x for x in sample_prediction if x[0] == label}

            score_table[label]["TP"] += len(gt_entities & pred_entities)
            score_table[label]["FP"] += len(pred_entities - gt_entities)
            score_table[label]["FN"] += len(gt_entities - pred_entities)

    return score_table

def calculate_macro_f1(score_table):
    """
    Calculate macro-averaged Precision, Recall, and F1 Score.

    Parameters:
    - score_table (dict): A nested dictionary containing TP, FP, and FN scores for each label.

    Returns:
    - tuple: Macro-averaged Precision, Recall, and F1 Score.
    """
    macro_precision = macro_recall = macro_f1 = 0
    num_labels = len(score_table)

    for label, scores in score_table.items():
        precision = scores["TP"] / (scores["TP"] + scores["FP"]) if scores["TP"] + scores["FP"] > 0 else 0
        recall = scores["TP"] / (scores["TP"] + scores["FN"]) if scores["TP"] + scores["FN"] > 0 else 0
        f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0

        macro_precision += precision
        macro_recall += recall
        macro_f1 += f1

    return macro_precision / num_labels, macro_recall / num_labels, macro_f1 / num_labels


In [28]:
"""
Train and validate a deep learning model for privacy protection using a custom dataset.

Parameters:
- EPOCH (int): Number of training epochs.
- model (torch.nn.Module): The deep learning model to be trained.
- train_dataloader (torch.utils.data.DataLoader): DataLoader for the training dataset.
- val_dataloader (torch.utils.data.DataLoader): DataLoader for the validation dataset.
- optim (torch.optim.Optimizer): The optimizer for updating model parameters during training.
- device (torch.device): Device on which the model and data reside (e.g., "cuda" or "cpu").
- labels_type_table (dict): A mapping of label names to corresponding numerical identifiers.
"""

train_step = 0
val_step = 0

for epoch in range(EPOCH):
    """
    Training loop for each epoch.

    Parameters:
    - epoch (int): Current epoch number.
    """

    model.train()
    total_train_loss = 0

    for batch_x, batch_y, batch_labels in train_dataloader:
        """
        Training iteration over the batches in the training dataset.

        Parameters:
        - batch_x (dict): Input features for the batch.
        - batch_y (torch.Tensor): Ground truth labels for the batch.
        - batch_labels (list): True label information for each token in the batch.

        Returns:
        - None
        """

        train_step += 1
        optim.zero_grad()
        input_ids = batch_x["input_ids"].to(device)
        attention_mask = batch_x["attention_mask"].to(device)
        labels = batch_y.long().to(device)

        loss = model(input_ids, attention_mask, labels)
        total_train_loss += loss.item()

        loss.backward()
        optim.step()

    avg_train_loss = total_train_loss / len(train_dataloader)
    print(f"Epoch {epoch} - Train Loss: {avg_train_loss:.6f}")

    model.eval()
    total_score_table = defaultdict(lambda: {"TP": 0, "FP": 0, "FN": 0})

    for batch_x, batch_y, batch_labels in val_dataloader:
        """
        Validation iteration over the batches in the validation dataset.

        Parameters:
        - batch_x (dict): Input features for the batch.
        - batch_y (torch.Tensor): Ground truth labels for the batch.
        - batch_labels (list): True label information for each token in the batch.

        Returns:
        - None
        """

        batch_x["input_ids"] = batch_x["input_ids"].to(device)
        batch_x["attention_mask"] = batch_x["attention_mask"].to(device)

        with torch.no_grad():
            model_predict_sequences = model(batch_x["input_ids"], batch_x["attention_mask"])
            batch_score_table = calculate_batch_score(
                batch_labels, model_predict_sequences, batch_x["offset_mapping"], labels_type_table
            )

            for label, scores in batch_score_table.items():
                for key in total_score_table[label]:
                    total_score_table[label][key] += scores[key]

    avg_precision, avg_recall, avg_macro_f1 = calculate_macro_f1(total_score_table)
    print(f"Epoch {epoch} - Macro Precision: {avg_precision:.6f}, Macro Recall: {avg_recall:.6f}, Macro F1 Score: {avg_macro_f1:.6f}")

    torch.save(model.state_dict(), "./model/" + "longformer" + "_" + str(epoch) + "_" + str(avg_macro_f1))


  score = torch.where(mask[i].unsqueeze(1), next_score, score)


Epoch 0 - Train Loss: 223.257924
Epoch 0 - Macro Precision: 0.576980, Macro Recall: 0.596096, Macro F1 Score: 0.584692
Epoch 1 - Train Loss: 21.471636
Epoch 1 - Macro Precision: 0.627141, Macro Recall: 0.639088, Macro F1 Score: 0.631008
Epoch 2 - Train Loss: 16.164490
Epoch 2 - Macro Precision: 0.661379, Macro Recall: 0.672359, Macro F1 Score: 0.664760
Epoch 3 - Train Loss: 15.499107
Epoch 3 - Macro Precision: 0.654168, Macro Recall: 0.710479, Macro F1 Score: 0.673221
Epoch 4 - Train Loss: 17.719697
Epoch 4 - Macro Precision: 0.633542, Macro Recall: 0.679608, Macro F1 Score: 0.652611
Epoch 5 - Train Loss: 11.341971
Epoch 5 - Macro Precision: 0.663300, Macro Recall: 0.691600, Macro F1 Score: 0.674654


KeyboardInterrupt: 

In [None]:
model.eval()
total_score_table = defaultdict(lambda: {"TP": 0, "FP": 0, "FN": 0})

for batch_x, batch_y, batch_labels in val_dataloader:
    batch_x["input_ids"] = batch_x["input_ids"].to(device)
    batch_x["attention_mask"] = batch_x["attention_mask"].to(device)

    with torch.no_grad():
        model_predict_sequences = model(batch_x["input_ids"], batch_x["attention_mask"])
        batch_score_table = calculate_batch_score(batch_labels, model_predict_sequences, batch_x["offset_mapping"], labels_type_table)
        for label, scores in batch_score_table.items():
            for key in total_score_table[label]:
                total_score_table[label][key] += scores[key]

avg_precision, avg_recall, avg_macro_f1 = calculate_macro_f1(total_score_table)
print(f"Epoch {epoch} - Macro Precision: {avg_precision:.6f}, Macro Recall: {avg_recall:.6f}, Macro F1 Score: {avg_macro_f1:.6f}")