In [3]:
# Install necessary packages, including seqeval for evaluation
!pip install torch transformers datasets wandb evaluate seqeval

import torch
from transformers import BertTokenizerFast, BertForTokenClassification, Trainer, TrainingArguments
from datasets import load_dataset
import wandb
from evaluate import load
import os

# Set up your W&B API Key using an input prompt (better security practice)
# You will be prompted to enter the key interactively
os.environ["WANDB_API_KEY"] = input("Enter your W&B API key: ")
wandb.login(key=os.getenv("WANDB_API_KEY"))

# Initialize W&B for tracking
wandb.init(project="ner-bert", reinit=True)  # Ensure 'ner-bert' is your desired project name

# Load the dataset
dataset = load_dataset("conll2003")

# Load pre-trained BERT tokenizer
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")

# Tokenize the dataset
def tokenize_function(examples):
    tokenized_inputs = tokenizer(examples['tokens'],
                                 padding="max_length",
                                 truncation=True,
                                 is_split_into_words=True)

    # Align the labels with tokenized inputs
    labels = []
    for i, label in enumerate(examples['ner_tags']):
        word_ids = tokenized_inputs.word_ids(i)
        label_ids = [-100 if word_id is None else label[word_id] for word_id in word_ids]
        labels.append(label_ids)

    tokenized_inputs["labels"] = labels
    return tokenized_inputs

# Tokenize the entire dataset and add labels
tokenized_datasets = dataset.map(tokenize_function, batched=True)

# Load the model
model = BertForTokenClassification.from_pretrained("bert-base-uncased", num_labels=9)

# Set the label mappings (important for compute_metrics)
label_list = dataset["train"].features["ner_tags"].feature.names
model.config.id2label = {i: label for i, label in enumerate(label_list)}
model.config.label2id = {label: i for i, label in enumerate(label_list)}

# Load evaluation metric (Seqeval)
metric = load("seqeval")

# Define the compute metrics function
def compute_metrics(p):
    predictions, labels = p
    predictions = predictions.argmax(axis=2)

    true_predictions = []
    true_labels = []

    for prediction, label in zip(predictions, labels):
        pred_labels = []
        true_lbls = []
        for p, l in zip(prediction, label):
            if l != -100:
                pred_labels.append(model.config.id2label[p])
                true_lbls.append(model.config.id2label[l])
        true_predictions.append(pred_labels)
        true_labels.append(true_lbls)

    results = metric.compute(predictions=true_predictions, references=true_labels)
    return results

# Split the dataset
train_dataset = tokenized_datasets["train"]
valid_dataset = tokenized_datasets["validation"]
test_dataset = tokenized_datasets["test"]

# Adjusted Training Arguments for Colab
training_args = TrainingArguments(
    output_dir="./results",              # Output directory
    num_train_epochs=3,                  # Number of training epochs
    per_device_train_batch_size=8,       # Adjusted for Colab's limited GPU memory
    per_device_eval_batch_size=16,       # Adjusted for Colab
    warmup_steps=500,                    # Number of warmup steps for learning rate scheduler
    weight_decay=0.01,                   # Strength of weight decay
    logging_dir="./logs",                # Directory for storing logs
    logging_steps=50,                    # Log every 50 steps
    evaluation_strategy="epoch",         # Evaluate at each epoch
    save_strategy="epoch",               # Save model each epoch
    load_best_model_at_end=True,         # Load the best model at the end of training
    report_to="wandb"                     # Enable W&B tracking
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=None,                  # Default: padding handled by tokenizer
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=valid_dataset
)

# Train the model
trainer.train()

# Evaluate the model on the validation set
eval_results = trainer.evaluate()

# Print the evaluation results to understand the available keys
print(f"Evaluation Results Keys: {list(eval_results.keys())}")
print(f"Evaluation Results: {eval_results}")

# Log the evaluation results to W&B
# Use the correct key names based on eval_results
# Typically, metrics from compute_metrics are prefixed with 'eval_'
wandb.log({
    "eval_loss": eval_results.get("eval_loss"),
    "eval_f1": eval_results.get("eval_f1")
})

# Evaluate the model on the test set
test_results = trainer.evaluate(test_dataset)

# Print test results
print(f"Test Results: {test_results}")

# Optionally log test results to W&B
wandb.log({
    "test_loss": test_results.get("eval_loss"),
    "test_f1": test_results.get("eval_f1")
})

# Finish W&B logging
wandb.finish()


Enter your W&B API key: 24401e166a0827939798330887e78d1b02e3a816




0,1
eval/loss,▁▁
eval/overall_accuracy,▁▁
eval/overall_f1,▁▁
eval/overall_precision,▁▁
eval/overall_recall,▁▁
eval/runtime,▁█
eval/samples_per_second,█▁
eval/steps_per_second,█▁
train/epoch,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇██████
train/global_step,▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▆▆▆▆▆▇▇▇▇▇██████

0,1
eval/loss,0.05827
eval/overall_accuracy,0.98472
eval/overall_f1,0.94608
eval/overall_precision,0.9425
eval/overall_recall,0.94967
eval/runtime,110.4578
eval/samples_per_second,29.423
eval/steps_per_second,1.847
total_flos,3669099951393792.0
train/epoch,1.0




Map:   0%|          | 0/14041 [00:00<?, ? examples/s]

Map:   0%|          | 0/3250 [00:00<?, ? examples/s]

Map:   0%|          | 0/3453 [00:00<?, ? examples/s]

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Loc,Misc,Org,Per,Overall Precision,Overall Recall,Overall F1,Overall Accuracy
1,0.0786,0.070253,"{'precision': 0.9309320460453027, 'recall': 0.9576012223071046, 'f1': 0.9440783279984938, 'number': 2618}","{'precision': 0.8603202846975089, 'recall': 0.7855402112103981, 'f1': 0.8212314225053078, 'number': 1231}","{'precision': 0.8175909878682842, 'recall': 0.9178015564202334, 'f1': 0.8648029330889092, 'number': 2056}","{'precision': 0.9838199085473092, 'recall': 0.9218852999340804, 'f1': 0.9518461800238216, 'number': 3034}",0.909679,0.91263,0.911152,0.980777
2,0.0209,0.059609,"{'precision': 0.9438285291943829, 'recall': 0.975553857906799, 'f1': 0.9594290007513148, 'number': 2618}","{'precision': 0.8733552631578947, 'recall': 0.8627132412672623, 'f1': 0.868001634654679, 'number': 1231}","{'precision': 0.920371275036639, 'recall': 0.9163424124513618, 'f1': 0.9183524250548378, 'number': 2056}","{'precision': 0.9747871643745907, 'recall': 0.9812129202373104, 'f1': 0.9779894875164258, 'number': 3034}",0.939488,0.948316,0.943882,0.986385
3,0.022,0.060626,"{'precision': 0.9616552771450265, 'recall': 0.9675324675324676, 'f1': 0.9645849200304646, 'number': 2618}","{'precision': 0.8778501628664495, 'recall': 0.875710804224208, 'f1': 0.8767791785278568, 'number': 1231}","{'precision': 0.9089193015573384, 'recall': 0.9367704280155642, 'f1': 0.9226347305389222, 'number': 2056}","{'precision': 0.9781890284203569, 'recall': 0.975609756097561, 'f1': 0.9768976897689768, 'number': 3034}",0.943377,0.950554,0.946952,0.9871


Evaluation Results Keys: ['eval_loss', 'eval_LOC', 'eval_MISC', 'eval_ORG', 'eval_PER', 'eval_overall_precision', 'eval_overall_recall', 'eval_overall_f1', 'eval_overall_accuracy', 'eval_runtime', 'eval_samples_per_second', 'eval_steps_per_second', 'epoch']
Evaluation Results: {'eval_loss': 0.059609074145555496, 'eval_LOC': {'precision': 0.9438285291943829, 'recall': 0.975553857906799, 'f1': 0.9594290007513148, 'number': 2618}, 'eval_MISC': {'precision': 0.8733552631578947, 'recall': 0.8627132412672623, 'f1': 0.868001634654679, 'number': 1231}, 'eval_ORG': {'precision': 0.920371275036639, 'recall': 0.9163424124513618, 'f1': 0.9183524250548378, 'number': 2056}, 'eval_PER': {'precision': 0.9747871643745907, 'recall': 0.9812129202373104, 'f1': 0.9779894875164258, 'number': 3034}, 'eval_overall_precision': 0.939487975174554, 'eval_overall_recall': 0.9483163664839468, 'eval_overall_f1': 0.9438815276695246, 'eval_overall_accuracy': 0.986385371820738, 'eval_runtime': 108.6411, 'eval_samples_p

0,1
eval/loss,▂▁▁▁█
eval/overall_accuracy,▄███▁
eval/overall_f1,▃███▁
eval/overall_precision,▃▇█▇▁
eval/overall_recall,▂███▁
eval/runtime,▁▁▃▂█
eval/samples_per_second,▆▆▁▄█
eval/steps_per_second,▇█▁▄█
eval_loss,▁
test_loss,▁

0,1
eval/loss,0.13535
eval/overall_accuracy,0.97623
eval/overall_f1,0.90064
eval/overall_precision,0.89709
eval/overall_recall,0.90423
eval/runtime,114.1929
eval/samples_per_second,30.238
eval/steps_per_second,1.892
eval_loss,0.05961
test_loss,0.13535
