In [None]:
# %%
"""
Data Download and Conversion to CoNLL Format.

This script downloads the CoNLL-2003 dataset from Hugging Face,
converts it to CoNLL format, and saves it to the data directory.

The script creates three files:
    - data/conll2003/eng.train: Training data in CoNLL format.
    - data/conll2003/eng.testa: Validation data in CoNLL format.
    - data/conll2003/eng.testb: Test data in CoNLL format.
"""
# Import necessary libraries
import os  # For operating system related tasks like creating directories
from datasets import load_dataset  # For loading datasets from Hugging Face

# Create the output directory; exist_ok=True makes re-running this cell a no-op.
os.makedirs("data/conll2003", exist_ok=True)

print("Downloading CoNLL-2003 dataset from HuggingFace...")
dataset = load_dataset("conll2003")
print("Download complete!")

# Fallback integer-label -> tag-string table for the Hugging Face CoNLL-2003
# dataset (index 0 is "O"). Built once here rather than once per entity token.
NER_TAG_NAMES = [
    "O",
    "B-PER", "I-PER",    # Person
    "B-ORG", "I-ORG",    # Organization
    "B-LOC", "I-LOC",    # Location
    "B-MISC", "I-MISC",  # Miscellaneous
]

# Hugging Face split name -> legacy CoNLL-2003 file suffix
# (train -> eng.train, validation -> eng.testa, test -> eng.testb).
SPLIT_TO_SUFFIX = {"train": "train", "validation": "testa", "test": "testb"}


def _ner_tag_names(split_dataset):
    """Return the split's own `ner_tags` ClassLabel names, or the fallback table.

    Reading the names from the dataset metadata keeps the written tags correct
    even if the hub dataset's label order ever differs from NER_TAG_NAMES.
    """
    label_feature = getattr(split_dataset.features.get("ner_tags"), "feature", None)
    names = getattr(label_feature, "names", None)
    return list(names) if names else NER_TAG_NAMES


print("Converting to CoNLL format...")

output_files = []
for split, suffix in SPLIT_TO_SUFFIX.items():
    output_file = f"data/conll2003/eng.{suffix}"
    print(f"Processing {split} set -> {output_file}")

    split_data = dataset[split]
    tag_names = _ner_tag_names(split_data)

    with open(output_file, "w", encoding="utf-8") as f:
        for example in split_data:
            # One "token TAG" line per token; ids <= 0 are written as "O".
            for token, tag in zip(example["tokens"], example["ner_tags"]):
                tag_str = tag_names[tag] if tag > 0 else "O"
                f.write(f"{token} {tag_str}\n")
            # A blank line terminates each sentence.
            f.write("\n")
    output_files.append(output_file)

print("Dataset downloaded and converted to CoNLL format successfully!")
print("Files created:")
for output_file in output_files:
    print(f"  - {output_file}")

In [None]:
# %%
"""
Data Preprocessing and Exploration.

This script reads the CoNLL-2003 dataset, performs data exploration,
and preprocesses the data for different NER models.

The script creates preprocessed data in various formats:
    - SpaCy format: Data formatted for use with SpaCy NER models.
    - BIO format: Data in BIO (Beginning, Inside, Outside) tagging format.
    - JSON format: Data stored in JSON format for easy access and manipulation.
    - Transformer format: Data preprocessed for use with transformer-based NER models.
"""
# Import necessary libraries
import functools  # For caching reusable, expensive-to-build resources
import importlib  # For refreshing the import system after installing packages
import json  # For working with JSON data
import os  # For operating system related tasks like creating directories
import pickle  # For saving and loading Python objects
import subprocess  # For running external commands
import sys  # For system-specific parameters and functions
from collections import Counter  # For counting the frequency of items
from typing import List, Dict, Tuple, Optional  # For type hinting

import matplotlib.pyplot as plt  # For creating visualizations
import nltk  # For natural language processing tasks
import seaborn as sns  # For creating statistical visualizations
import spacy  # For advanced natural language processing
from nltk.corpus import stopwords  # For removing common words
from nltk.stem import WordNetLemmatizer  # For reducing words to their base form

# NLTK data used by preprocess_text(): 'punkt' / 'punkt_tab' are the tokenizer
# models behind nltk.word_tokenize, 'stopwords' supplies the stopword list and
# 'wordnet' backs WordNetLemmatizer. quiet=True suppresses progress output.
for _nltk_resource in ('punkt', 'punkt_tab', 'stopwords', 'wordnet'):
    nltk.download(_nltk_resource, quiet=True)

# Function to check and install the SpaCy model if not already installed
def ensure_spacy_model(model_name="en_core_web_sm"):
    """
    Ensure that the required SpaCy model is installed, downloading it if needed.

    Args:
        model_name: The name of the SpaCy model to check and install.
            Defaults to "en_core_web_sm".

    Raises:
        subprocess.CalledProcessError: If ``python -m spacy download`` exits
            with a non-zero status.
    """
    try:
        # Loading is the availability check: spacy.load raises OSError when
        # the model package cannot be found.
        spacy.load(model_name)
        print(f"SpaCy model '{model_name}' is already installed.")
    except OSError:
        print(f"SpaCy model '{model_name}' not found. Installing...")
        # Use the running interpreter so the model is installed into the
        # same environment this kernel imports from.
        subprocess.check_call([sys.executable, "-m", "spacy", "download", model_name])
        # The package was created after this interpreter started; clear the
        # import finders' caches so a later spacy.load() in this same process
        # can discover it without restarting the kernel.
        importlib.invalidate_caches()
        print(f"SpaCy model '{model_name}' has been installed.")

# Install the small English pipeline on first run, then load it once; the
# resulting `nlp` object is used by preprocess_spacy().
SPACY_MODEL_NAME = "en_core_web_sm"
ensure_spacy_model(SPACY_MODEL_NAME)
nlp = spacy.load(SPACY_MODEL_NAME)

# Function to read a CoNLL formatted file and extract sentences with their tags
def read_conll_file(file_path: str) -> List[List[Tuple[str, str]]]:
    """
    Parse a CoNLL-2003 style file into sentences of (word, tag) pairs.

    Blank lines, ``-DOCSTART-`` markers and lines starting with ``//`` all end
    the sentence being built. Any other line with at least two
    whitespace-separated fields contributes (first field, last field); lines
    with fewer fields are skipped.

    Args:
        file_path: Path to the CoNLL formatted file (read as UTF-8).

    Returns:
        A list of sentences, each a non-empty list of (word, tag) tuples.
    """
    sentences: List[List[Tuple[str, str]]] = []
    pending: List[Tuple[str, str]] = []

    def flush() -> None:
        # Move the sentence collected so far (if any) into the result.
        nonlocal pending
        if pending:
            sentences.append(pending)
            pending = []

    with open(file_path, 'r', encoding='utf-8') as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if not stripped or stripped.startswith(('-DOCSTART-', '//')):
                flush()
                continue
            fields = stripped.split()
            if len(fields) < 2:
                continue
            pending.append((fields[0], fields[-1]))

    # The file may not end with a blank line.
    flush()
    return sentences

# Function to explore the dataset and extract statistics
def explore_dataset(sentences: List[List[Tuple[str, str]]]) -> Dict:
    """
    Explore the dataset and return statistics.

    Entities are read as BIO spans: a ``B-X`` tag opens an entity of type X
    and directly following ``I-X`` tags extend it. Any other tag (``O``, a
    ``B-`` tag, an ``I-`` tag of a different type) or the end of the sentence
    closes the open entity. An ``I-`` tag that does not continue an open
    entity is not counted as an entity.

    Args:
        sentences: A list of sentences, where each sentence is a list of (word, tag) tuples.

    Returns:
        A dictionary with:
            - 'num_sentences': number of sentences.
            - 'total_words': number of tokens.
            - 'num_unique_words': number of distinct lowercased tokens.
            - 'entity_counts': entity type -> number of entities (B- tags).
            - 'avg_entity_length': entity type -> mean entity length in tokens.
            - 'avg_sentence_length': mean tokens per sentence (0 if empty).
            - 'max_sentence_length': longest sentence in tokens (0 if empty).
    """
    sentence_lengths = [len(sentence) for sentence in sentences]
    unique_words = {word.lower() for sentence in sentences for word, _ in sentence}

    entity_counts = Counter()
    entity_lengths: Dict[str, List[int]] = {}

    for sentence in sentences:
        # Entities never span sentences, so tracking restarts for each one.
        current_entity: Optional[str] = None
        current_length = 0
        for _, tag in sentence:
            if tag.startswith('I-') and current_entity == tag[2:]:
                current_length += 1
                continue

            # Any tag that does not continue the open entity closes it.
            if current_entity is not None:
                entity_lengths.setdefault(current_entity, []).append(current_length)
                current_entity, current_length = None, 0

            if tag.startswith('B-'):
                current_entity = tag[2:]
                current_length = 1
                entity_counts[current_entity] += 1

        # Record an entity that runs to the end of the sentence.
        if current_entity is not None:
            entity_lengths.setdefault(current_entity, []).append(current_length)

    # Every recorded list is non-empty, so the division is safe.
    avg_entity_length = {
        entity: sum(lengths) / len(lengths)
        for entity, lengths in entity_lengths.items()
    }

    return {
        'num_sentences': len(sentences),
        'total_words': sum(sentence_lengths),
        'num_unique_words': len(unique_words),
        'entity_counts': dict(entity_counts),
        'avg_entity_length': avg_entity_length,
        'avg_sentence_length': sum(sentence_lengths) / len(sentence_lengths) if sentence_lengths else 0,
        'max_sentence_length': max(sentence_lengths) if sentence_lengths else 0,
    }

# Function to visualize the distribution of entity types
def visualize_entity_distribution(stats: Dict, save_path: Optional[str] = None):
    """
    Draw a bar chart of how many entities of each type the dataset contains.

    Args:
        stats: Statistics dictionary; only its ``'entity_counts'`` mapping
            (entity type -> count) is read.
        save_path: File to write the figure to. When falsy (e.g. None), the
            figure is displayed instead of saved.
    """
    fig, ax = plt.subplots(figsize=(12, 6))

    entity_counts = stats['entity_counts']
    ax.bar(list(entity_counts.keys()), list(entity_counts.values()))

    ax.set_title('Distribution of Entity Types')
    ax.set_xlabel('Entity Type')
    ax.set_ylabel('Count')
    # Angle the type labels so longer names do not overlap.
    ax.tick_params(axis='x', labelrotation=45)

    if save_path:
        fig.savefig(save_path, bbox_inches='tight')
    else:
        plt.show()

    # Free the figure whether it was saved or shown.
    plt.close(fig)

@functools.lru_cache(maxsize=None)
def _english_stopwords() -> frozenset:
    """Read NLTK's English stopword list once and reuse it (immutable) across calls."""
    return frozenset(stopwords.words('english'))


@functools.lru_cache(maxsize=None)
def _wordnet_lemmatizer() -> WordNetLemmatizer:
    """Build one WordNetLemmatizer and share it between calls."""
    return WordNetLemmatizer()


# Function to preprocess text using NLTK
def preprocess_text(text: str, remove_stopwords: bool = True, lemmatize: bool = True) -> str:
    """
    Preprocess text using NLTK for lowercasing, stopword removal, and lemmatization.

    The stopword set and lemmatizer are built once and cached, so calling this
    per sentence over a whole corpus does not re-read the stopword corpus each time.

    Args:
        text: The input text to preprocess.
        remove_stopwords: Whether to remove English stopwords (default: True).
        lemmatize: Whether to lemmatize words with WordNet (default: True).

    Returns:
        The processed tokens joined by single spaces.
    """
    # Lowercase first so stopword matching (the NLTK list is lowercase) works.
    tokens = nltk.word_tokenize(text.lower())

    if remove_stopwords:
        stop_words = _english_stopwords()
        tokens = [token for token in tokens if token not in stop_words]

    if lemmatize:
        lemmatizer = _wordnet_lemmatizer()
        tokens = [lemmatizer.lemmatize(token) for token in tokens]

    return ' '.join(tokens)

# Function to preprocess text using SpaCy
def preprocess_spacy(text: str, remove_stopwords: bool = True, lemmatize: bool = True) -> str:
    """
    Normalize text with the module-level spaCy pipeline ``nlp``.

    Args:
        text: The input text to preprocess.
        remove_stopwords: Drop tokens spaCy flags as stopwords (default: True).
        lemmatize: Emit each token's lemma instead of its surface text
            (default: True).

    Returns:
        The lowercased tokens (lemmas or surface forms) joined by single spaces.
    """
    kept_tokens = (
        token for token in nlp(text)
        if not (remove_stopwords and token.is_stop)
    )
    return ' '.join(
        (token.lemma_ if lemmatize else token.text).lower()
        for token in kept_tokens
    )

# Function to convert CoNLL data to SpaCy format
def convert_to_spacy_format(sentences: List[List[Tuple[str, str]]], output_file: str):
    """
    Convert CoNLL sentences to spaCy training tuples and pickle them.

    Each sentence becomes ``(text, {'entities': [(start, end, type), ...]})``
    where ``text`` is the tokens joined by single spaces and ``start``/``end``
    are character offsets into ``text`` (end exclusive). An entity begins at a
    ``B-X`` tag and covers the directly following ``I-X`` tags; ``I-`` tags
    that do not continue a ``B-`` span are ignored.

    Args:
        sentences: A list of sentences, where each sentence is a list of (word, tag) tuples.
        output_file: The path to the output pickle file.

    Returns:
        The list of training tuples that was pickled to ``output_file``.
    """
    training_data = []

    for sentence in sentences:
        words = [word for word, _ in sentence]
        tags = [tag for _, tag in sentence]
        text = ' '.join(words)

        # Character offset where each token starts in `text`; every token is
        # followed by exactly one separating space.
        token_starts = []
        cursor = 0
        for word in words:
            token_starts.append(cursor)
            cursor += len(word) + 1

        entities = []
        idx = 0
        n_tags = len(tags)
        while idx < n_tags:
            if not tags[idx].startswith('B-'):
                idx += 1
                continue
            label = tags[idx][2:]
            stop = idx + 1
            while stop < n_tags and tags[stop].startswith('I-') and tags[stop][2:] == label:
                stop += 1
            last = stop - 1
            entities.append((token_starts[idx], token_starts[last] + len(words[last]), label))
            idx = stop

        training_data.append((text, {'entities': entities}))

    with open(output_file, 'wb') as handle:
        pickle.dump(training_data, handle)

    return training_data

# Function to convert data to BIO format
def convert_to_bio_format(sentences: List[List[Tuple[str, str]]], output_file: str):
    """
    Write sentences as "word tag" lines, one token per line, UTF-8 encoded.

    Tags are written exactly as given (no scheme conversion); each sentence is
    followed by a blank line.

    Args:
        sentences: A list of sentences, where each sentence is a list of (word, tag) tuples.
        output_file: The path to the output file.
    """
    with open(output_file, 'w', encoding='utf-8') as handle:
        for sentence in sentences:
            token_lines = ''.join(f"{word} {tag}\n" for word, tag in sentence)
            handle.write(token_lines + "\n")

# Function to convert data to JSON format
def convert_to_json_format(sentences: List[List[Tuple[str, str]]], output_file: str):
    """
    Save sentences as a JSON array of ``{'text', 'tokens', 'tags'}`` records.

    ``text`` is the tokens joined by single spaces; ``tokens`` and ``tags`` are
    the parallel word and tag lists. The file is UTF-8 and indented by 2.

    Args:
        sentences: A list of sentences, where each sentence is a list of (word, tag) tuples.
        output_file: The path to the output JSON file.
    """
    records = [
        {
            'text': ' '.join(word for word, _ in sentence),
            'tokens': [word for word, _ in sentence],
            'tags': [tag for _, tag in sentence],
        }
        for sentence in sentences
    ]

    with open(output_file, 'w', encoding='utf-8') as handle:
        json.dump(records, handle, indent=2)

# Function to preprocess data for transformer models
def preprocess_data_for_transformers(sentences: List[List[Tuple[str, str]]]):
    """
    Split (word, tag) sentences into parallel word lists and tag lists.

    Args:
        sentences: A list of sentences, where each sentence is a list of (word, tag) tuples.

    Returns:
        A tuple ``(tokenized_texts, tags_list)``: ``tokenized_texts[i]`` holds
        the words of sentence i and ``tags_list[i]`` the matching tags.
    """
    tokenized_texts = [[word for word, _ in sentence] for sentence in sentences]
    tags_list = [[tag for _, tag in sentence] for sentence in sentences]
    return tokenized_texts, tags_list

# Function to create a mapping of entity labels to IDs
def create_entity_labels_mapping(sentences: List[List[Tuple[str, str]]]):
    """
    Assign a stable integer id to every tag seen in the data.

    Ids follow the lexicographic order of the distinct tags, so the same tag
    set always yields the same mapping.

    Args:
        sentences: A list of sentences, where each sentence is a list of (word, tag) tuples.

    Returns:
        A tuple ``(tag_to_id, id_to_tag)`` of mutually inverse dictionaries
        (tag -> int id, int id -> tag).
    """
    ordered_tags = sorted({tag for sentence in sentences for _, tag in sentence})
    tag_to_id = {tag: idx for idx, tag in enumerate(ordered_tags)}
    id_to_tag = dict(enumerate(ordered_tags))
    return tag_to_id, id_to_tag

# Main function to execute the preprocessing steps
def main():
    """
    Main function to demonstrate data preprocessing functions.

    NOTE(review): only the first statement of this function is present; the
    line below ends in an unterminated string literal, so this cell will not
    parse until the rest of main() is restored.
    """
    # Define the data directory
    data_dir = "/content/

In [None]:
# %%
"""
Training Script for BERT-based NER Model.

This script trains a BERT-based Named Entity Recognition (NER) model
using the preprocessed CoNLL-2003 dataset. It leverages the Hugging Face
Transformers library for BERT and PyTorch for model training.

The script follows these key steps:
1. Data Loading: Loads preprocessed data and entity tag mappings.
2. Data Encoding: Encodes the dataset using the BERT tokenizer,
   handling subword tokenization and aligning tags.
3. Data Loading and Batching: Creates PyTorch DataLoaders for
   efficient training and validation data handling.
4. Model Initialization: Initializes the BERT model for token
   classification with the appropriate number of entity labels.
5. Training Setup: Defines the optimizer, learning rate scheduler,
   and loss function for model training.
6. Training Loop: Trains the model for a specified number of epochs,
   iterating over batches of data and updating model parameters.
7. Validation: Evaluates the model's performance on the validation
   set after each epoch, calculating metrics like precision, recall,
   F1-score, and accuracy.
8. Model Saving: Saves the trained model and tokenizer for later use.
"""

# Import necessary libraries and modules
import os  # For operating system interactions (e.g., creating directories)
import json  # For working with JSON data (e.g., loading tag mappings)
import pickle  # For object serialization (e.g., loading preprocessed data)
import numpy as np  # For numerical operations (e.g., array manipulation)
import torch  # For deep learning operations (e.g., tensors, model training)
from torch.utils.data import DataLoader, TensorDataset  # For data loading and batching
from torch.optim import AdamW  # For optimization (AdamW optimizer)
from torch.nn import CrossEntropyLoss  # For calculating loss (cross-entropy)
from sklearn.metrics import precision_recall_fscore_support, accuracy_score  # For evaluation metrics
from transformers import (  # For using pre-trained transformer models
    BertTokenizer,  # For tokenizing text with BERT
    BertForTokenClassification,  # For token classification tasks
    get_linear_schedule_with_warmup  # For learning rate scheduling
)
from typing import Dict, List, Tuple  # For type hinting (improved code readability)
from tqdm import tqdm  # For progress bars (visualizing training progress)
import matplotlib.pyplot as plt  # For creating visualizations (e.g., training curves)
import seaborn as sns  # For creating statistical visualizations (e.g., data distributions)

# Function to load preprocessed data and entity tag mappings
def load_preprocessed_data(data_dir: str) -> Tuple[Dict, Dict]:
    """
    Loads preprocessed data for transformer models and entity tag mappings.

    Both files are read from ``<data_dir>/processed/transformer/``:
        - ``transformer_data.pickle``: the pickled dataset object.
        - ``tag_mappings.json``: a JSON object holding the tag mappings.

    Security note: ``pickle.load`` can execute arbitrary code, so only point
    this at pickle files produced by this project's own preprocessing.

    NOTE(review): JSON object keys are always strings, so an integer-keyed
    mapping (such as an id -> tag dict) comes back with string keys; confirm
    callers index it with ``str(id)`` or convert the keys.

    Args:
        data_dir (str): The directory containing the preprocessed data.

    Returns:
        Tuple[Dict, Dict]: The unpickled transformer data and the parsed tag mappings.

    Raises:
        FileNotFoundError: If either file is missing.
    """
    transformer_dir = os.path.join(data_dir, "processed", "transformer")

    with open(os.path.join(transformer_dir, "transformer_data.pickle"), "rb") as f:
        transformer_data = pickle.load(f)

    # JSON text is UTF-8; read it as such instead of relying on the platform's
    # locale encoding (which differs e.g. on Windows).
    with open(os.path.join(transformer_dir, "tag_mappings.json"), "r", encoding="utf-8") as f:
        tag_mappings = json.load(f)

    return transformer_data, tag_mappings

# Function to encode the dataset using the BERT tokenizer
def encode_dataset(texts: List[List[str]], tags: List[List[str]],
                  tokenizer, tag_to_id: Dict, max_length: int = 128) -> Tuple:
    """
    Encodes the dataset for BERT, aligning one tag to every subword token.

    Each word is WordPiece-tokenized on its own, so every subword piece is
    attributed to the word it came from. This matters for CoNLL tokens such as
    "U.S." or "1-0", which BERT splits into several pieces that do NOT start
    with "##"; treating only "##" pieces as continuations would shift every
    later tag onto the wrong word. Every piece of a word gets that word's tag;
    [CLS], [SEP] and padding positions get "O". Words that produce no pieces
    are dropped together with their tag.

    Sequences are truncated so that, including [CLS] and [SEP], they fit in
    ``max_length``, then right-padded to exactly ``max_length``.

    Args:
        texts (List[List[str]]): List of word-tokenized sentences.
        tags (List[List[str]]): List of corresponding tags for each sentence.
        tokenizer: A BERT-style tokenizer exposing ``tokenize``,
            ``convert_tokens_to_ids``, ``cls_token``, ``sep_token`` and
            ``pad_token_id`` (slow ``BertTokenizer`` and fast tokenizers both do).
        tag_to_id (Dict): Mapping of entity labels to IDs; must contain "O".
        max_length (int, optional): Total sequence length after padding/truncation.
                                    Defaults to 128.

    Returns:
        Tuple: ``(input_ids, attention_masks, tag_ids)``, each a ``torch.long``
        tensor of shape ``(num_sentences, max_length)``.

    Raises:
        ValueError: If ``max_length`` is smaller than 2 (no room for [CLS]/[SEP]).
        KeyError: If a tag (or "O") is missing from ``tag_to_id``.
    """
    if max_length < 2:
        raise ValueError(f"max_length must be at least 2 to fit [CLS] and [SEP], got {max_length}")

    outside_id = tag_to_id["O"]
    # Room left for real word pieces once [CLS] and [SEP] are added.
    max_pieces = max_length - 2

    input_ids: List[List[int]] = []
    attention_masks: List[List[int]] = []
    tag_ids: List[List[int]] = []

    for sentence, sentence_tags in zip(texts, tags):
        pieces: List[str] = []
        piece_tag_ids: List[int] = []
        for word, tag in zip(sentence, sentence_tags):
            word_pieces = tokenizer.tokenize(word)
            pieces.extend(word_pieces)
            # Every piece inherits the tag of the word it came from.
            piece_tag_ids.extend([tag_to_id[tag]] * len(word_pieces))

        pieces = pieces[:max_pieces]
        piece_tag_ids = piece_tag_ids[:max_pieces]

        ids = list(tokenizer.convert_tokens_to_ids(
            [tokenizer.cls_token] + pieces + [tokenizer.sep_token]
        ))
        pad_len = max_length - len(ids)

        input_ids.append(ids + [tokenizer.pad_token_id] * pad_len)
        attention_masks.append([1] * len(ids) + [0] * pad_len)
        # [CLS] + word pieces + [SEP] + padding; special/pad positions are "O".
        tag_ids.append([outside_id] + piece_tag_ids + [outside_id] * (1 + pad_len))

    # view(-1, max_length) keeps a 2-D shape even when there are no sentences.
    return (
        torch.tensor(input_ids, dtype=torch.long).view(-1, max_length),
        torch.tensor(attention_masks, dtype=torch.long).view(-1, max_length),
        torch.tensor(tag_ids, dtype=torch.long).view(-1, max_length),
    )

# Function to calculate metrics
def calculate_metrics(true_tags: List[List[str]], pred_tags: List[List[str]]) -> Dict:
    """
    Token-level precision, recall, F1-score and accuracy.

    Both inputs are flattened (sentence boundaries are ignored) and compared
    position by position. Precision/recall/F1 are sklearn's support-weighted
    averages over every tag present, with 0 returned instead of a warning when
    a class has no predicted or true samples.

    NOTE(review): "O" is included in the average, so these are not entity-level
    (seqeval-style) NER scores.

    Args:
        true_tags (List[List[str]]): Gold tags, one list per sentence.
        pred_tags (List[List[str]]): Predicted tags, aligned with ``true_tags``.

    Returns:
        Dict: ``{"precision", "recall", "f1", "accuracy"}``.
    """
    y_true = [label for sentence in true_tags for label in sentence]
    y_pred = [label for sentence in pred_tags for label in sentence]

    precision, recall, f1, _support = precision_recall_fscore_support(
        y_true, y_pred, average="weighted", zero_division=0
    )

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "accuracy": accuracy_score(y_true, y_pred),
    }

# Function to train the BERT model
def _predict_dev_tags(model, dev_loader, device, id_to_tag: Dict, pad_tag_id):
    """
    Predict tags for every sentence in a validation loader.

    Args:
        model: Token-classification model whose output exposes ``.logits``.
        dev_loader: DataLoader yielding ``(input_ids, attention_mask, tag_ids)``.
        device: Device the model lives on.
        id_to_tag (Dict): Mapping from *stringified* tag ID to tag string
            (keys are strings because the mapping is loaded from JSON).
        pad_tag_id: ID of the ``[PAD]`` tag, or ``None`` if the tag set has none.

    Returns:
        ``(true_tags, pred_tags)``: two lists of per-sentence tag lists. Both
        are filtered with the SAME gold-side position mask, so each pair of
        sentences has equal length and can be flattened for sklearn metrics.
    """
    model.eval()
    all_true_tags, all_pred_tags = [], []

    with torch.no_grad():
        for inputs, masks, tags in dev_loader:
            inputs, masks, tags = inputs.to(device), masks.to(device), tags.to(device)
            predictions = torch.argmax(model(inputs, attention_mask=masks).logits, dim=2)

            for i in range(inputs.shape[0]):
                # Positions are chosen from the mask and the GOLD tags only.
                # Filtering predictions by their own value (as before) could
                # drop a different number of positions than the gold side.
                keep = (masks[i] == 1) & (tags[i] >= 0)
                if pad_tag_id is not None:
                    keep &= tags[i] != pad_tag_id
                all_true_tags.append([id_to_tag[str(t)] for t in tags[i][keep].tolist()])
                all_pred_tags.append([id_to_tag[str(t)] for t in predictions[i][keep].tolist()])

    return all_true_tags, all_pred_tags


def train_model(data_dir: str, model_name: str = "bert-base-uncased",
              epochs: int = 3, batch_size: int = 32, learning_rate: float = 2e-5,
              max_length: int = 128) -> None:
    """
    Fine-tune a BERT token-classification model for NER and save it.

    Loads the preprocessed splits, encodes train and dev, and trains with AdamW
    and a linear LR schedule. It prints dev-set metrics after every epoch and
    finally writes the model and tokenizer to ``models/bert_ner``.

    Args:
        data_dir (str): Directory containing the preprocessed data.
        model_name (str, optional): Name of the pre-trained BERT model to use.
                                     Defaults to "bert-base-uncased".
        epochs (int, optional): Number of training epochs. Defaults to 3.
        batch_size (int, optional): Batch size for training. Defaults to 32.
        learning_rate (float, optional): Learning rate for the optimizer.
                                         Defaults to 2e-5.
        max_length (int, optional): Maximum sequence length for padding/truncation.
                                    Defaults to 128.
    """
    transformer_data, tag_mappings = load_preprocessed_data(data_dir)
    tag_to_id = tag_mappings["tag_to_id"]
    id_to_tag = tag_mappings["id_to_tag"]
    # The PAD-tag filter is optional: tag sets without "[PAD]" simply skip it.
    pad_tag_id = tag_to_id.get("[PAD]")

    tokenizer = BertTokenizer.from_pretrained(model_name)
    model = BertForTokenClassification.from_pretrained(
        model_name, num_labels=len(tag_to_id)
    )

    train_inputs, train_masks, train_tags = encode_dataset(
        transformer_data["train"]["texts"],
        transformer_data["train"]["tags"],
        tokenizer, tag_to_id, max_length
    )
    dev_inputs, dev_masks, dev_tags = encode_dataset(
        transformer_data["dev"]["texts"],
        transformer_data["dev"]["tags"],
        tokenizer, tag_to_id, max_length
    )

    train_loader = DataLoader(
        TensorDataset(train_inputs, train_masks, train_tags),
        batch_size=batch_size, shuffle=True
    )
    dev_loader = DataLoader(
        TensorDataset(dev_inputs, dev_masks, dev_tags),
        batch_size=batch_size, shuffle=False
    )

    optimizer = AdamW(model.parameters(), lr=learning_rate)
    total_steps = len(train_loader) * epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer, num_warmup_steps=0, num_training_steps=total_steps
    )
    # Any -100 labels (if present) are ignored rather than scored.
    loss_fn = CrossEntropyLoss(ignore_index=-100)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        for inputs, masks, tags in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}"):
            inputs, masks, tags = inputs.to(device), masks.to(device), tags.to(device)

            optimizer.zero_grad()
            logits = model(inputs, attention_mask=masks).logits

            # Only attended positions contribute to the loss; padding is not
            # something the model should be trained to predict.
            active = masks.view(-1) == 1
            loss = loss_fn(
                logits.view(-1, logits.shape[-1])[active],
                tags.view(-1)[active]
            )

            loss.backward()
            optimizer.step()
            scheduler.step()
            total_loss += loss.item()

        avg_loss = total_loss / max(len(train_loader), 1)
        print(f"Epoch {epoch + 1}/{epochs}, Average Loss: {avg_loss:.4f}")

        all_true_tags, all_pred_tags = _predict_dev_tags(
            model, dev_loader, device, id_to_tag, pad_tag_id
        )
        metrics = calculate_metrics(all_true_tags, all_pred_tags)
        print(f"Validation Metrics: {metrics}")

    os.makedirs("models", exist_ok=True)
    model.save_pretrained("models/bert_ner")
    tokenizer.save_pretrained("models/bert_ner")
    print("Model and tokenizer saved to 'models/bert_ner'")

# Main execution block
if __name__ == "__main__":
    # Kick off training on the preprocessed data under the relative "data" directory.
    train_model(data_dir="data")

In [None]:
!pip install seqeval

Collecting seqeval
  Downloading seqeval-1.2.2.tar.gz (43 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/43.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m1.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: seqeval
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone
  Created wheel for seqeval: filename=seqeval-1.2.2-py3-none-any.whl size=16161 sha256=4898eadf8d4badda107c0c8d20f043be6d4ef0664b1c3be10e48aae8615e4abb
  Stored in directory: /root/.cache/pip/wheels/bc/92/f0/243288f899c2eacdfa8c5f9aede4c71a9bad0ee26a01dc5ead
Successfully built seqeval
Installing collected packages: seqeval
Successfully installed seqeval-1.2.2


In [None]:
import os
import json
import pickle
import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from sklearn.metrics import precision_recall_fscore_support, accuracy_score
from transformers import (
    BertTokenizer,
    BertForTokenClassification,
    get_linear_schedule_with_warmup
)
from typing import Dict, List, Tuple
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns

def load_preprocessed_data(data_dir: str) -> Tuple[Dict, Dict]:
    """
    Load preprocessed transformer data and tag mappings.

    Reads ``<data_dir>/processed/transformer/transformer_data.pickle`` and
    ``<data_dir>/processed/transformer/tag_mappings.json``.

    Args:
        data_dir: Root directory containing the ``processed/transformer`` folder.

    Returns:
        Tuple ``(transformer_data, tag_mappings)`` exactly as stored on disk.

    Raises:
        FileNotFoundError: If either file is missing.
    """
    transformer_dir = os.path.join(data_dir, "processed", "transformer")
    data_path = os.path.join(transformer_dir, "transformer_data.pickle")
    mappings_path = os.path.join(transformer_dir, "tag_mappings.json")

    # NOTE(security): pickle.load can execute arbitrary code. Only load files
    # produced by this project's own preprocessing step, never untrusted ones.
    with open(data_path, "rb") as f:
        transformer_data = pickle.load(f)

    # JSON is UTF-8 by spec; being explicit avoids locale-dependent decoding.
    with open(mappings_path, "r", encoding="utf-8") as f:
        tag_mappings = json.load(f)

    return transformer_data, tag_mappings

def encode_dataset(texts: List[List[str]], tags: List[List[str]],
                  tokenizer, tag_to_id: Dict, max_length: int = 128,
                  pad_label_id: int = -100) -> Tuple:
    """
    Encode word-tokenized sentences for BERT token classification.

    Each word is split into WordPiece tokens. Only the first sub-token of a word
    carries the word's tag; continuation sub-tokens, ``[CLS]``, ``[SEP]`` and
    padding get ``pad_label_id``. With the default ``-100`` (the default
    ``ignore_index`` of ``torch.nn.CrossEntropyLoss``) those positions are left
    out of the loss. Any metric code that filters ``-100`` also skips them.

    Args:
        texts: Sentences, each a list of words.
        tags: Tag strings aligned one-to-one with the words in ``texts``.
        tokenizer: BERT tokenizer. Uses ``tokenize``, ``convert_tokens_to_ids``,
            ``pad_token_id``, ``cls_token``, ``sep_token`` and ``unk_token``.
        tag_to_id: Mapping from tag string to integer ID.
        max_length: Final sequence length including ``[CLS]``/``[SEP]``; longer
            sentences are truncated, shorter ones padded. Must be >= 2.
        pad_label_id: Label given to positions that must not be scored.

    Returns:
        Tuple ``(input_ids, attention_masks, labels)``, each a tensor of shape
        ``(len(texts), max_length)``.

    Raises:
        ValueError: If ``max_length`` is smaller than 2.
        KeyError: If a tag is missing from ``tag_to_id``.
    """
    if max_length < 2:
        raise ValueError(
            f"max_length must be at least 2 to fit [CLS] and [SEP], got {max_length}"
        )

    pad_token_id = tokenizer.pad_token_id
    body_length = max_length - 2  # room left for word pieces after [CLS] and [SEP]

    input_ids = []
    attention_masks = []
    labels = []

    for sentence_tokens, sentence_tags in tqdm(zip(texts, tags), total=len(texts), desc="Encoding dataset"):
        bert_tokens: List[str] = []
        bert_labels: List[int] = []

        for word, tag in zip(sentence_tokens, sentence_tags):
            # Some words (e.g. ones made only of characters the tokenizer
            # strips) can tokenize to an empty list. Without a fallback the
            # label would be appended with no token, misaligning every later
            # label and producing ragged rows.
            word_tokens = tokenizer.tokenize(word) or [tokenizer.unk_token]
            bert_tokens.extend(word_tokens)
            bert_labels.append(tag_to_id[tag])
            bert_labels.extend([pad_label_id] * (len(word_tokens) - 1))

        # Truncate, then wrap with the special tokens (which are never scored).
        bert_tokens = [tokenizer.cls_token] + bert_tokens[:body_length] + [tokenizer.sep_token]
        bert_labels = [pad_label_id] + bert_labels[:body_length] + [pad_label_id]

        token_ids = list(tokenizer.convert_tokens_to_ids(bert_tokens))
        attention_mask = [1] * len(token_ids)

        padding_length = max_length - len(token_ids)
        token_ids += [pad_token_id] * padding_length
        attention_mask += [0] * padding_length
        bert_labels += [pad_label_id] * padding_length

        input_ids.append(token_ids)
        attention_masks.append(attention_mask)
        labels.append(bert_labels)

    return torch.tensor(input_ids), torch.tensor(attention_masks), torch.tensor(labels)

def create_data_loaders(train_inputs, train_masks, train_labels,
                       val_inputs=None, val_masks=None, val_labels=None,
                       test_inputs=None, test_masks=None, test_labels=None,
                       batch_size=32, num_workers=4) -> Dict:
    """
    Create DataLoaders for training, validation, and testing.

    Args:
        train_inputs, train_masks, train_labels: Training tensors (required).
        val_inputs, val_masks, val_labels: Validation tensors; the validation
            loader is built only when ``val_inputs`` is not None.
        test_inputs, test_masks, test_labels: Test tensors; the test loader is
            built only when ``test_inputs`` is not None.
        batch_size: Batch size for every loader.
        num_workers: Number of worker processes for data loading.

    Returns:
        Dict with key "train" and, when provided, "validation" and "test".
        Only the training loader shuffles.
    """
    # Pinned host memory only helps host->GPU copies; on CPU-only machines
    # PyTorch just warns and ignores it.
    pin_memory = torch.cuda.is_available()

    def _make_loader(inputs, masks, labels, shuffle):
        """Wrap one split's tensors in a DataLoader with the shared settings."""
        return DataLoader(
            TensorDataset(inputs, masks, labels),
            batch_size=batch_size,
            shuffle=shuffle,
            pin_memory=pin_memory,
            num_workers=num_workers,
        )

    loaders = {"train": _make_loader(train_inputs, train_masks, train_labels, True)}

    if val_inputs is not None:
        loaders["validation"] = _make_loader(val_inputs, val_masks, val_labels, False)

    if test_inputs is not None:
        loaders["test"] = _make_loader(test_inputs, test_masks, test_labels, False)

    return loaders

def train_model(model, data_loaders, optimizer, scheduler, device,
               num_epochs=3, evaluation_steps=100, id_to_tag=None,
               gradient_accumulation_steps=1):
    """
    Train the NER model and evaluate periodically.

    Gradients are accumulated over ``gradient_accumulation_steps`` batches. An
    optimizer step is also forced on the last batch of every epoch, so a
    partial accumulation never leaks into the next epoch and is never silently
    dropped at the end of training. Mixed precision (autocast + GradScaler) is
    used only when ``device`` is a CUDA device.

    Args:
        model: Model called as ``model(input_ids=..., attention_mask=...,
            labels=...)`` and returning an object with ``.loss``.
        data_loaders: Dict with a "train" DataLoader and optionally "validation".
        optimizer: Optimizer for training.
        scheduler: Learning-rate scheduler, stepped once per optimizer step.
        device: Device (str or ``torch.device``) to train on.
        num_epochs: Number of training epochs.
        evaluation_steps: Evaluate on the validation set every this many
            optimizer steps; values <= 0 disable periodic evaluation.
        id_to_tag: Mapping from IDs to tags, forwarded to ``evaluate_model``.
        gradient_accumulation_steps: Number of batches to accumulate per
            optimizer step (must be >= 1).

    Returns:
        Tuple ``(model, history)``. ``history["train_loss"]`` has one entry per
        epoch; the ``val_*`` lists have one entry per periodic evaluation.

    Raises:
        ValueError: If ``gradient_accumulation_steps`` is smaller than 1.
    """
    if gradient_accumulation_steps < 1:
        raise ValueError(
            f"gradient_accumulation_steps must be >= 1, got {gradient_accumulation_steps}"
        )

    device = torch.device(device)
    # AMP is tied to the device actually used: GradScaler cannot unscale
    # gradients of a model trained on CPU even when a GPU happens to exist.
    use_amp = device.type == "cuda"

    if use_amp:
        # Enable cuDNN benchmarking for faster convolutions
        torch.backends.cudnn.benchmark = True
        print(f"Using GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    model.to(device)
    model.train()
    model.zero_grad()

    scaler = torch.cuda.amp.GradScaler() if use_amp else None

    history = {
        "train_loss": [],
        "val_loss": [],
        "val_f1": [],
        "val_accuracy": []
    }
    has_validation = "validation" in data_loaders
    train_loader = data_loaders["train"]

    global_step = 0
    for epoch in range(num_epochs):
        total_train_loss = 0.0
        num_batches = len(train_loader)
        progress_bar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}")

        for step, batch in enumerate(progress_bar):
            batch_inputs, batch_masks, batch_labels = [b.to(device) for b in batch]

            # autocast is a no-op when disabled, so one forward path serves both modes.
            with torch.autocast(device_type=device.type, enabled=use_amp):
                outputs = model(
                    input_ids=batch_inputs,
                    attention_mask=batch_masks,
                    labels=batch_labels
                )
                loss = outputs.loss / gradient_accumulation_steps

            if scaler is not None:
                scaler.scale(loss).backward()
            else:
                loss.backward()

            is_update_step = (
                (step + 1) % gradient_accumulation_steps == 0
                or (step + 1) == num_batches  # flush partial accumulation at epoch end
            )
            if is_update_step:
                if scaler is not None:
                    # Unscale first so clipping sees the true gradient norm.
                    scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                if scaler is not None:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

            # Track loss (scaled back to original)
            batch_loss = loss.item() * gradient_accumulation_steps
            total_train_loss += batch_loss
            progress_bar.set_postfix({"loss": batch_loss})

            # Evaluate only right after an optimizer step; otherwise the same
            # global_step would trigger an evaluation on several batches.
            if (is_update_step and has_validation and evaluation_steps > 0
                    and global_step % evaluation_steps == 0):
                val_metrics = evaluate_model(model, data_loaders["validation"], device, id_to_tag)

                history["val_loss"].append(val_metrics["loss"])
                history["val_f1"].append(val_metrics["f1"])
                history["val_accuracy"].append(val_metrics["accuracy"])

                print(f"\nStep {global_step}: Validation Loss: {val_metrics['loss']:.4f}, "
                      f"F1: {val_metrics['f1']:.4f}, Accuracy: {val_metrics['accuracy']:.4f}")

                # evaluate_model switches to eval mode; resume training mode.
                model.train()

        avg_train_loss = total_train_loss / max(num_batches, 1)
        history["train_loss"].append(avg_train_loss)

        print(f"Epoch {epoch+1}/{num_epochs} - Average training loss: {avg_train_loss:.4f}")

        # Validate at the end of each epoch (reported, not stored in history)
        if has_validation:
            val_metrics = evaluate_model(model, data_loaders["validation"], device, id_to_tag)
            print(f"Validation - Loss: {val_metrics['loss']:.4f}, "
                  f"F1: {val_metrics['f1']:.4f}, Accuracy: {val_metrics['accuracy']:.4f}")
            model.train()

    return model, history

def evaluate_model(model, data_loader, device, id_to_tag=None):
    """
    Evaluate a token-classification model on a dataset.

    Positions whose label is -100 or whose attention mask is 0 are excluded
    from the loss and from every metric.

    Args:
        model: Model called as ``model(input_ids=..., attention_mask=...)`` and
            returning an object with ``.logits``.
        data_loader: DataLoader yielding ``(input_ids, attention_mask, labels)``.
        device: Device (str or ``torch.device``) the model runs on; autocast is
            used only for CUDA devices.
        id_to_tag: Optional mapping from tag ID (int or numeric string) to tag
            string. When given, "entity_*" metrics are added.

    Returns:
        Dict with:
        - "loss": mean per-batch cross-entropy.
        - "precision", "recall", "f1", "accuracy": token-level over all scored
          tokens. Micro-averaged over every class, so P, R and F1 all equal
          accuracy for single-label tokens.
        - "entity_precision", "entity_recall", "entity_f1" (with ``id_to_tag``):
          token-level, micro-averaged over the non-"O" tags only. Predicting an
          entity tag on a gold "O" token counts as a false positive. These are
          NOT span-level (conlleval/seqeval-style) scores.
    """
    device = torch.device(device)
    use_amp = device.type == "cuda"
    model.eval()

    total_loss = 0.0
    loss_fn = CrossEntropyLoss(ignore_index=-100)

    true_labels = []
    predicted_labels = []

    with torch.no_grad():
        for batch in tqdm(data_loader, desc="Evaluating"):
            batch_inputs, batch_masks, batch_labels = [b.to(device) for b in batch]

            with torch.autocast(device_type=device.type, enabled=use_amp):
                outputs = model(
                    input_ids=batch_inputs,
                    attention_mask=batch_masks
                )
                logits = outputs.logits
                # Padding positions are forced to the ignore index for the loss.
                active_labels = batch_labels.masked_fill(batch_masks == 0, -100)
                loss = loss_fn(
                    logits.reshape(-1, logits.shape[-1]),
                    active_labels.reshape(-1)
                )

            total_loss += loss.item()

            batch_preds = torch.argmax(logits, dim=2)

            # Keep only scored, attended positions (vectorised instead of a
            # per-element Python loop).
            keep = (batch_labels != -100) & (batch_masks == 1)
            true_labels.extend(batch_labels[keep].tolist())
            predicted_labels.extend(batch_preds[keep].tolist())

    metrics = {
        "loss": total_loss / max(len(data_loader), 1)
    }

    if true_labels:
        if id_to_tag:
            id_map = {int(k): v for k, v in id_to_tag.items()}
            true_tags = [id_map.get(label, "O") for label in true_labels]
            pred_tags = [id_map.get(label, "O") for label in predicted_labels]

            # Score over all tokens but only the non-"O" classes, so false
            # positives on gold-"O" tokens lower precision. Dropping gold-"O"
            # tokens beforehand (the previous approach) hid them entirely.
            entity_labels = sorted((set(true_tags) | set(pred_tags)) - {"O"})
            if entity_labels:
                entity_precision, entity_recall, entity_f1, _ = precision_recall_fscore_support(
                    true_tags, pred_tags, labels=entity_labels,
                    average="micro", zero_division=0
                )
            else:
                entity_precision = entity_recall = entity_f1 = 0.0

            metrics["entity_precision"] = entity_precision
            metrics["entity_recall"] = entity_recall
            metrics["entity_f1"] = entity_f1

        # Token-level metrics (all tokens including "O")
        precision, recall, f1, _ = precision_recall_fscore_support(
            true_labels, predicted_labels, average='micro', zero_division=0
        )
        accuracy = accuracy_score(true_labels, predicted_labels)

        metrics.update({
            "precision": precision,
            "recall": recall,
            "f1": f1,
            "accuracy": accuracy
        })

    return metrics

def visualize_training_history(history, save_path=None):
    """
    Plot the training history collected during training.

    Training loss is recorded once per epoch, while validation loss/F1/accuracy
    are recorded once per periodic evaluation. Each cadence gets its own panel
    and x-axis instead of sharing a single "Epoch" axis.

    Args:
        history: Dict with optional lists "train_loss", "val_loss", "val_f1"
            and "val_accuracy"; missing or empty series are skipped.
        save_path: If given, save the figure to this path; otherwise show it.
    """
    train_loss = history.get("train_loss", [])
    val_loss = history.get("val_loss", [])
    val_f1 = history.get("val_f1", [])
    val_accuracy = history.get("val_accuracy", [])

    fig, (ax_train, ax_val, ax_metrics) = plt.subplots(1, 3, figsize=(18, 5))

    if train_loss:
        ax_train.plot(range(1, len(train_loss) + 1), train_loss,
                      marker="o", label="Training Loss")
    ax_train.set_title("Training Loss per Epoch")
    ax_train.set_xlabel("Epoch")
    ax_train.set_ylabel("Loss")

    if val_loss:
        ax_val.plot(range(1, len(val_loss) + 1), val_loss,
                    marker="o", color="tab:orange", label="Validation Loss")
    ax_val.set_title("Validation Loss per Evaluation")
    ax_val.set_xlabel("Evaluation Step")
    ax_val.set_ylabel("Loss")

    if val_f1:
        ax_metrics.plot(range(1, len(val_f1) + 1), val_f1, marker="o", label="F1 Score")
    if val_accuracy:
        ax_metrics.plot(range(1, len(val_accuracy) + 1), val_accuracy,
                        marker="o", label="Accuracy")
    ax_metrics.set_title("Metrics During Training")
    ax_metrics.set_xlabel("Evaluation Step")
    ax_metrics.set_ylabel("Score")

    for ax in (ax_train, ax_val, ax_metrics):
        # legend() on a panel without labelled artists only emits a warning.
        if ax.get_legend_handles_labels()[0]:
            ax.legend()

    fig.tight_layout()

    if save_path:
        fig.savefig(save_path)
    else:
        plt.show()

    plt.close(fig)

def main(data_dir: str = "/content/data", seed: int = 42):
    """
    Train, evaluate and save a BERT-based NER model end to end.

    Steps:
    - Load the preprocessed data.
    - Encode the train (first 5000 sentences), dev and test splits with the
      ``bert-base-cased`` tokenizer.
    - Fine-tune the model and plot its training history.
    - Report test metrics.
    - Save the model, tokenizer, tag mappings and test metrics under
      ``<data_dir>/models/bert_ner_model``.

    Args:
        data_dir: Root directory containing ``processed/transformer``; plots and
            the trained model are written beneath it as well. Defaults to the
            Colab path this notebook was written for.
        seed: Seed for the torch and numpy RNGs (weight init, shuffling).
    """
    import math  # only needed for the optimizer-step arithmetic below

    torch.manual_seed(seed)
    np.random.seed(seed)

    # Set device (GPU if available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        # reset_max_memory_allocated() is deprecated in favour of this call.
        torch.cuda.reset_peak_memory_stats()
        torch.cuda.set_device(0)

        print(f"GPU: {torch.cuda.get_device_name(0)}")
        print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

    cpu_count = os.cpu_count()
    num_workers = min(cpu_count, 8) if cpu_count else 4
    print(f"Using {num_workers} dataloader workers")

    transformer_data, tag_mappings = load_preprocessed_data(data_dir)
    tag_to_id = tag_mappings["tag_to_id"]
    id_to_tag = tag_mappings["id_to_tag"]

    # Use a subset of the training data for faster iteration.
    train_texts = transformer_data["train"]["texts"][:5000]
    train_tags = transformer_data["train"]["tags"][:5000]
    dev_texts = transformer_data["dev"]["texts"]
    dev_tags = transformer_data["dev"]["tags"]
    test_texts = transformer_data["test"]["texts"]
    test_tags = transformer_data["test"]["tags"]

    print(f"Training set: {len(train_texts)} examples")
    print(f"Validation set: {len(dev_texts)} examples")
    print(f"Test set: {len(test_texts)} examples")

    tokenizer = BertTokenizer.from_pretrained("bert-base-cased")

    print("Encoding datasets...")
    max_length = 128  # Adjust based on your dataset

    train_inputs, train_masks, train_labels = encode_dataset(
        train_texts, train_tags, tokenizer, tag_to_id, max_length
    )
    dev_inputs, dev_masks, dev_labels = encode_dataset(
        dev_texts, dev_tags, tokenizer, tag_to_id, max_length
    )
    test_inputs, test_masks, test_labels = encode_dataset(
        test_texts, test_tags, tokenizer, tag_to_id, max_length
    )

    print("Datasets encoded successfully!")

    best_params = {
        "learning_rate": 3e-5,
        "batch_size": 32,  # Larger batch size for faster training
        "weight_decay": 0.01,
        "gradient_accumulation_steps": 2  # Accumulate gradients for larger effective batch
    }

    print(f"\nUsing hyperparameters: {best_params}")

    batch_size = best_params["batch_size"]
    data_loaders = create_data_loaders(
        train_inputs, train_masks, train_labels,
        dev_inputs, dev_masks, dev_labels,
        test_inputs, test_masks, test_labels,
        batch_size=batch_size,
        num_workers=num_workers
    )

    print("\nInitializing model...")
    model = BertForTokenClassification.from_pretrained(
        "bert-base-cased",
        num_labels=len(tag_to_id)
    )

    learning_rate = best_params["learning_rate"]
    weight_decay = best_params["weight_decay"]
    gradient_accumulation_steps = best_params["gradient_accumulation_steps"]

    # Standard BERT fine-tuning practice: no weight decay on biases and
    # LayerNorm weights.
    no_decay = ("bias", "LayerNorm.weight")
    grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            "weight_decay": weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(grouped_parameters, lr=learning_rate)

    num_epochs = 3
    # train_model counts OPTIMIZER steps (one per gradient_accumulation_steps
    # batches), so both the schedule length and the evaluation interval must
    # be expressed in optimizer steps, not batches.
    steps_per_epoch = max(1, math.ceil(len(data_loaders["train"]) / gradient_accumulation_steps))
    total_steps = steps_per_epoch * num_epochs
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps
    )

    print("\nTraining model...")
    model, history = train_model(
        model, data_loaders, optimizer, scheduler,
        device, num_epochs=num_epochs,
        evaluation_steps=max(1, steps_per_epoch // 2),  # Evaluate twice per epoch
        id_to_tag=id_to_tag,
        gradient_accumulation_steps=gradient_accumulation_steps
    )

    print("Visualizing training history...")
    os.makedirs(os.path.join(data_dir, "visualizations"), exist_ok=True)
    visualize_training_history(
        history,
        save_path=os.path.join(data_dir, "visualizations", "bert_training_history.png")
    )

    print("\nEvaluating on test set...")
    test_metrics = evaluate_model(model, data_loaders["test"], device, id_to_tag)

    print("\nTest Set Metrics:")
    print(f"Loss: {test_metrics['loss']:.4f}")
    print(f"Precision: {test_metrics['precision']:.4f}")
    print(f"Recall: {test_metrics['recall']:.4f}")
    print(f"F1 Score: {test_metrics['f1']:.4f}")
    print(f"Accuracy: {test_metrics['accuracy']:.4f}")

    if "entity_f1" in test_metrics:
        print(f"\nEntity-Level Metrics:")
        print(f"Precision: {test_metrics['entity_precision']:.4f}")
        print(f"Recall: {test_metrics['entity_recall']:.4f}")
        print(f"F1 Score: {test_metrics['entity_f1']:.4f}")

    print("\nSaving model...")
    output_dir = os.path.join(data_dir, "models")
    os.makedirs(output_dir, exist_ok=True)

    model_path = os.path.join(output_dir, "bert_ner_model")
    model.save_pretrained(model_path)
    tokenizer.save_pretrained(model_path)

    with open(os.path.join(model_path, "tag_mappings.json"), "w") as f:
        json.dump(tag_mappings, f, indent=2)

    # float() turns numpy scalars into JSON-serialisable Python floats.
    with open(os.path.join(model_path, "test_metrics.json"), "w") as f:
        json.dump({k: float(v) for k, v in test_metrics.items()}, f, indent=2)

    print(f"\nModel saved at {model_path}")
    print("\nDone!")

# Entry point: runs the full load -> encode -> train -> evaluate -> save
# pipeline. Inside a notebook kernel __name__ is "__main__", so executing this
# cell starts training immediately.
if __name__ == "__main__":
    main()

Using device: cpu
Using 2 dataloader workers
Training set: 5000 examples
Validation set: 3250 examples
Test set: 3453 examples
Encoding datasets...


Encoding dataset: 100%|██████████| 5000/5000 [00:08<00:00, 599.35it/s]
Encoding dataset: 100%|██████████| 3250/3250 [00:07<00:00, 461.18it/s]
Encoding dataset:  31%|███       | 1060/3453 [00:00<00:01, 1328.80it/s]


KeyboardInterrupt: 