<a href="https://colab.research.google.com/github/Zain-Haider-ML/Fine-Tune-LukeForEntitySpanClassification/blob/main/LukeForEntitySpanClassification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install datasets
!pip install seqeval git+https://github.com/huggingface/transformers.git

In [19]:
import pandas as pd
import spacy
from tqdm import tqdm, trange
import numpy as np
import copy
import torch
import seqeval.metrics
from transformers import AutoTokenizer, LukeForEntitySpanClassification, Trainer, TrainingArguments, LukeConfig
from datasets import Dataset

In [2]:
TRAIN_DATA_PATH = '/content/train_data__news.jsonl'  # Colab upload location

# JSON Lines file: one record per line with text, entity spans and annotation metadata.
ds = pd.read_json(TRAIN_DATA_PATH, lines=True)
ds.head()

Unnamed: 0,text,spans,date,annotator,batch,ind
0,"HanAra Software , a subsidiary of South Korean...","[{'start': 0, 'end': 15, 'label': 'ORG'}, {'st...",2024-10-16,alex,ner-news-3,1168
1,STMI will provide technical assistance and tra...,"[{'start': 0, 'end': 4, 'label': 'ORG'}, {'sta...",2024-10-29,alex-feher,ner-news-5,2470
2,’s infrastructure business and working togethe...,"[{'start': 51, 'end': 71, 'label': 'ORG'}]",2024-10-25,alex,ner-news-5,1969
3,"In this partnership, AkinovA and Benchmark Lab...","[{'start': 21, 'end': 28, 'label': 'ORG'}, {'s...",2024-10-29,alex-feher,ner-news-5,1862
4,About Sarnova and Bound Tree Medical Sarnova a...,"[{'start': 6, 'end': 13, 'label': 'ORG'}, {'st...",2024-10-18,feher,ner-news-3,1075


In [3]:
# Label vocabulary for span classification; index 0 ('NIL') means "not an entity".
label_to_index = {
    'NIL': 0,
    'LOC': 1,
    'JOB': 2,
    'MONEY': 3,
    'ORG': 4,
    'PERSON': 5,
    'DATE': 6,
}

# Inverse mapping, derived from label_to_index so the two dictionaries can never drift apart.
index_to_label = {index: label for label, index in label_to_index.items()}





def extract_entity_spans_and_labels(ds, label_to_index, num_samples=25):
    """
    Collect the character offsets and label indices of the annotated entities, per example.

    Args:
        ds: Table-like object (a pandas DataFrame or a dict of lists) whose 'spans' column
            holds, for each example, a list of dicts with 'start', 'end' and 'label' keys.
        label_to_index (dict): Maps label names to integer indices.
        num_samples (int): How many leading examples of ``ds['spans']`` to read.

    Returns:
        tuple[list, list]:
            - entity_spans_all: per example, a list of ``(start, end)`` offsets;
            - all_labels: per example, the label indices aligned with those offsets.

    Raises:
        KeyError: if an entity's label is missing from ``label_to_index``.
    """
    entity_spans_all = []
    all_labels = []
    for example_spans in ds['spans'][:num_samples]:
        # Pair each entity's offsets with its label index in a single pass over the entities.
        pairs = [((ent['start'], ent['end']), label_to_index[ent['label']]) for ent in example_spans]
        entity_spans_all.append([offsets for offsets, _ in pairs])
        all_labels.append([label_id for _, label_id in pairs])
    return entity_spans_all, all_labels


In [5]:
num_samples = 500       # leading rows used for training
val_num_samples = 110   # rows right after the training rows, used for validation

total_samples = num_samples + val_num_samples
entity_spans_all, all_labels = extract_entity_spans_and_labels(ds, label_to_index, total_samples)
print(entity_spans_all[:2], all_labels[:2])

[[(0, 15), (53, 67), (82, 88)], [(0, 4), (73, 113), (122, 127), (139, 154), (158, 164)]] [[4, 4, 1], [4, 4, 1, 1, 1]]


In [6]:
# Raw texts for the same leading rows as entity_spans_all / all_labels.
texts = list(ds['text'][: num_samples + val_num_samples])

tuple(len(seq) for seq in (texts, entity_spans_all, all_labels))

(610, 610, 610)

In [7]:
first_text = texts[0]; first_text, entity_spans_all[0], all_labels[0], len(first_text)

('HanAra Software , a subsidiary of South Korean-based BNF Technology, has selected Austin as its North American headquarters and plans to triple in size.',
 [(0, 15), (53, 67), (82, 88)],
 [4, 4, 1],
 152)

In [8]:
nlp = spacy.load("en_core_web_sm")

# Token-level training targets: every spaCy token becomes one candidate entity span,
# labelled with the entity whose character range contains the token's first character.
all_token_labels = []  # per text: one label index per spaCy token
all_token_spans = []   # per text: (char_start, char_end) of every spaCy token

nil_index = label_to_index['NIL']
for text, gold_spans, gold_labels in zip(texts, entity_spans_all, all_labels):
    doc = nlp(text)
    token_spans = [(tok.idx, tok.idx + len(tok)) for tok in doc]
    token_labels = [nil_index for _ in doc]

    # Later entities overwrite earlier ones if their ranges overlap.
    for (span_start, span_end), label in zip(gold_spans, gold_labels):
        for tok in doc:
            if span_start <= tok.idx < span_end:
                token_labels[tok.i] = label

    all_token_labels.append(token_labels)
    all_token_spans.append(token_spans)

In [9]:
tuple(map(len, (all_token_labels, all_token_spans, texts)))

(610, 610, 610)

In [10]:
# Training split: the first `num_samples` examples.
train_slice = slice(None, num_samples)
data = {
    "text": texts[train_slice],
    "entity_spans": all_token_spans[train_slice],
    "labels": all_token_labels[train_slice],
}

tuple(len(data[key]) for key in ("text", "entity_spans", "labels"))

(500, 500, 500)

In [11]:
# Validation split: everything after the training examples.
val_slice = slice(num_samples, None)
val_data = {
    "text": texts[val_slice],
    "entity_spans": all_token_spans[val_slice],
    "labels": all_token_labels[val_slice],
}

tuple(len(val_data[key]) for key in ("text", "entity_spans", "labels"))

(110, 110, 110)

In [12]:
# Sanity check: report any example whose label count differs from its token-span count
# (training split first, then validation).
for split in (data, val_data):
    for token_label_seq, token_span_seq in zip(split['labels'], split['entity_spans']):
        if len(token_label_seq) != len(token_span_seq):
            print(len(token_label_seq), len(token_span_seq))

In [13]:
def pad_labels(labels, max_len, pad_value=-100):
    """Right-pad (or right-truncate) every label sequence to exactly ``max_len`` items.

    Args:
        labels: Iterable of per-example label-index sequences.
        max_len: Target length of every output row.
        pad_value: Filler for missing positions. -100 matches the default ``ignore_index``
            of ``torch.nn.CrossEntropyLoss``, so padded positions do not contribute to the loss.

    Returns:
        A new list of lists, each exactly ``max_len`` long, so it can be turned into a
        rectangular tensor. The input sequences are not modified.
    """
    padded = []
    for lbl in labels:
        # Drop labels beyond max_len; otherwise the rows are ragged and tensor creation fails.
        kept = list(lbl[:max_len])
        padded.append(kept + [pad_value] * (max_len - len(kept)))
    return padded

In [14]:
# Longest token sequence per split = number of entity slots the padded labels must cover.
max_entities, val_max_entities = (
    max(map(len, split["entity_spans"])) for split in (data, val_data)
)

print(max_entities, val_max_entities)

261 229


In [15]:
num_labels = len(label_to_index)
print('num_labels = ', num_labels)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda") if use_cuda else torch.device("cpu")
print('device = ', device)

num_labels =  7
device =  cuda


In [16]:
BASE_CHECKPOINT = "studio-ousia/luke-large-finetuned-conll-2003"

# Start from the CoNLL-2003 fine-tuned LUKE, but with this notebook's own label set.
config = LukeConfig.from_pretrained(
    BASE_CHECKPOINT,
    label2id=label_to_index,
    id2label=index_to_label,
    num_labels=num_labels,
)
tokenizer = AutoTokenizer.from_pretrained(BASE_CHECKPOINT)

# The checkpoint's classifier head has a different output size than num_labels, so
# ignore_mismatched_sizes=True re-initializes that layer instead of raising.
model = LukeForEntitySpanClassification.from_pretrained(
    BASE_CHECKPOINT, config=config, ignore_mismatched_sizes=True
).to(device)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/877 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/1.70k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

entity_vocab.json:   0%|          | 0.00/15.3M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/33.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/1.00k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.24G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.24G [00:00<?, ?B/s]

Some weights of the model checkpoint at studio-ousia/luke-large-finetuned-conll-2003 were not used when initializing LukeForEntitySpanClassification: ['luke.embeddings.position_ids']
- This IS expected if you are initializing LukeForEntitySpanClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LukeForEntitySpanClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LukeForEntitySpanClassification were not initialized from the model checkpoint at studio-ousia/luke-large-finetuned-conll-2003 and are newly initialized because the shapes did not match:
- classifier.weight: found shape torch.Size([5, 3072]) in the checkpoint and torch.Size([7, 3072]) in the model insta

In [17]:
def _encode_split(split):
    """Tokenize one split and attach labels padded/truncated to the encoded entity width.

    Args:
        split: dict with 'text', 'entity_spans' and 'labels' lists (see data / val_data).

    Returns:
        A plain dict of CPU tensors, including 'labels' with shape equal to 'entity_ids'.
    """
    encoded = tokenizer(
        split['text'], entity_spans=split['entity_spans'], padding=True, truncation=True,
        return_tensors="pt", max_length=512, max_entity_length=512)

    # Size the labels from the tokenizer's actual entity dimension, not from a separately
    # computed maximum, so labels and entity_ids stay aligned even if truncation drops spans.
    entity_width = encoded['entity_ids'].shape[1]
    clipped = [lbl[:entity_width] for lbl in split['labels']]
    encoded['labels'] = torch.tensor(pad_labels(clipped, entity_width), dtype=torch.long)

    # Tensors stay on the CPU: the Trainer moves each batch to the model's device itself,
    # so copying the whole dataset to the GPU up front only wastes accelerator memory.
    return dict(encoded)


tokenized_inputs = _encode_split(data)
val_tokenized_inputs = _encode_split(val_data)


In [18]:
train_shape_keys = ['entity_ids', 'entity_position_ids', 'attention_mask', 'entity_attention_mask', 'labels']; tuple(tokenized_inputs[key].shape for key in train_shape_keys)

(torch.Size([500, 261]),
 torch.Size([500, 261, 30]),
 torch.Size([500, 405]),
 torch.Size([500, 261]),
 torch.Size([500, 261]))

In [19]:
val_shape_keys = ['entity_ids', 'entity_position_ids', 'attention_mask', 'labels']; tuple(val_tokenized_inputs[key].shape for key in val_shape_keys)

(torch.Size([110, 229]),
 torch.Size([110, 229, 30]),
 torch.Size([110, 268]),
 torch.Size([110, 229]))

In [20]:
# Wrap the encoded tensors as Hugging Face Datasets for the Trainer.
dataset, val_dataset = (Dataset.from_dict(encoded) for encoded in (tokenized_inputs, val_tokenized_inputs))

In [21]:
# Training configuration: evaluate and checkpoint once per epoch, then restore the checkpoint
# with the lowest validation loss when training finishes.
training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",           # replaces the deprecated `evaluation_strategy` argument
    save_strategy="epoch",           # must match eval_strategy when load_best_model_at_end=True
    learning_rate=6e-6,              # kept low for stable fine-tuning of a large pretrained model
    per_device_train_batch_size=2,   # small per-device batch ...
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=2,   # ... effective train batch = 2 * 2 = 4 per device
    num_train_epochs=7,
    weight_decay=0.06,               # regularization against overfitting the small dataset
    logging_dir="./logs",
    logging_steps=10,                # log training loss every 10 optimizer steps
    report_to="none",
    save_total_limit=2,              # keep at most two checkpoints on disk
    load_best_model_at_end=True,     # restore the best checkpoint after training ...
    metric_for_best_model="loss",    # ... judged by validation loss (lower is better)
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    eval_dataset=val_dataset,
)

trainer.train()



Epoch,Training Loss,Validation Loss
1,0.1476,0.086899
2,0.059,0.066651
3,0.0581,0.069455
4,0.0237,0.076363
5,0.0173,0.078133
6,0.0243,0.082096
7,0.0102,0.083028


TrainOutput(global_step=875, training_loss=0.08628214403986931, metrics={'train_runtime': 2509.237, 'train_samples_per_second': 1.395, 'train_steps_per_second': 0.349, 'total_flos': 3225248469135000.0, 'train_loss': 0.08628214403986931, 'epoch': 7.0})

In [22]:
FINE_TUNED_DIR = "/content/fine-tuned_model/"
model.save_pretrained(FINE_TUNED_DIR); tokenizer.save_pretrained(FINE_TUNED_DIR)

('/content/fine-tuned_model/tokenizer_config.json',
 '/content/fine-tuned_model/special_tokens_map.json',
 '/content/fine-tuned_model/vocab.json',
 '/content/fine-tuned_model/merges.txt',
 '/content/fine-tuned_model/entity_vocab.json',
 '/content/fine-tuned_model/added_tokens.json')

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print('device = ', device)

# Reload the fine-tuned model and tokenizer saved after training.
model_path = "/content/fine-tuned_model"
# No ignore_mismatched_sizes here: the saved head already matches the saved config. A shape
# mismatch at this point means a wrong or corrupt checkpoint and should fail loudly instead
# of silently re-initializing the classifier with random weights.
ft_model = LukeForEntitySpanClassification.from_pretrained(model_path).to(device)
ft_model.eval()  # the cells below only run inference

ft_tokenizer = AutoTokenizer.from_pretrained(model_path)

device =  cuda


In [5]:
label_names = ft_model.config.id2label; label_names

{0: 'NIL', 1: 'LOC', 2: 'JOB', 3: 'MONEY', 4: 'ORG', 5: 'PERSON', 6: 'DATE'}

### Recognizing named entities in a text:
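Finally, we extract named entities from a single text with the fine-tuned model. The text is first tokenized with spaCy.

The cell below enumerates every contiguous span of spaCy tokens as a candidate entity and runs the model on all candidates (on the GPU when one is available). It then greedily decodes the highest-scoring, non-overlapping spans into an IOB2 label sequence.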


In [6]:


# Demo: run NER on a single dataset row.
sample_index = 801
text = ds['text'][sample_index]
nlp = spacy.load("en_core_web_sm")
doc = nlp(text)

# Enumerate every contiguous token span as a candidate entity.
# entity_spans: character offsets (model input);
# original_word_spans: matching [start_token, end_token) indices (used to tag tokens).
entity_spans = []
original_word_spans = []
for token_start in doc:
    for token_end in doc[token_start.i:]:
        entity_spans.append((token_start.idx, token_end.idx + len(token_end)))
        original_word_spans.append((token_start.i, token_end.i + 1))

# Tokenize and move the inputs to the model's device (GPU if available, else CPU).
inputs = ft_tokenizer(text, entity_spans=entity_spans, return_tensors="pt", padding=True).to(device)

with torch.no_grad():
    outputs = ft_model(**inputs)

# Highest logit and its label index for each candidate span.
logits = outputs.logits
max_logits, max_indices = logits[0].max(dim=1)

# Keep only spans predicted as a real entity (label index 0 is NIL).
predictions = []
for logit, index, span in zip(max_logits, max_indices, original_word_spans):
    if index != 0:
        predictions.append((logit, span, ft_model.config.id2label[int(index)]))

# Greedy IOB2 decoding: visit spans from highest to lowest logit and accept a span only if
# none of its tokens is already labelled, so accepted entities never overlap.
predicted_sequence = ["O"] * len(doc)
for _, (span_start, span_end), label in sorted(predictions, key=lambda p: p[0], reverse=True):
    if all(tag == "O" for tag in predicted_sequence[span_start:span_end]):
        predicted_sequence[span_start] = "B-" + label
        for i in range(span_start + 1, span_end):
            predicted_sequence[i] = "I-" + label

for token, label in zip(doc, predicted_sequence):
    print(token, label)


Our O
trading O
subsidiaries O
in O
Preston B-LOC
, O
Durham B-LOC
and O
Wimbledon B-LOC
will O
now O
operate O
under O
a O
single O
national O
entity O
, O
BAKO B-ORG
Limited B-ORG
. O


### Measuring performance

Due to limited GPU compute, I use only 20 examples for inference when measuring performance.

In [7]:
# Held-out evaluation rows 1000-1019, re-indexed from 0.
TEST_START, TEST_STOP = 1000, 1020
test_examples_data = ds[TEST_START:TEST_STOP].reset_index(drop=True)

test_texts = test_examples_data['text']
# Only the gold label indices are kept here; the character offsets are discarded.
_, all_labels_for_test = extract_entity_spans_and_labels(
    test_examples_data, label_to_index, len(test_examples_data))

print(all_labels_for_test[:2])

[[5, 5], [4, 1, 6, 4]]


In [8]:
test_examples = dict(text=test_texts, labels=all_labels_for_test)

len(test_examples['text']), len(test_examples['labels'])

(20, 20)

In [9]:
# Column-oriented dict -> one {'text': ..., 'labels': ...} dict per example.
test_examples_list = [
    {key: column[row] for key, column in test_examples.items()}
    for row in range(len(test_examples['text']))
]

In [10]:
first_two_examples = test_examples_list[:2]; first_two_examples

[{'text': 'Chen and Merk were initially apprehensive about hosting a virtual fair of this scale.',
  'labels': [5, 5]},
 {'text': 'Airbus turns to Germany to put its pions in cybersecurity. The European aviation giant announced Monday 25. March the acquisition of Infodas, a German company specializing in cybersecurity, in order to “fortify its portfolio” in this field.',
  'labels': [4, 1, 6, 4]}]

In [11]:

# Load the SpaCy model for tokenization
nlp = spacy.load("en_core_web_sm")

# Gold character-level entity spans for the test rows. Ground-truth token labels must come
# from these, NOT from the enumerated candidate spans that are fed to the model.
gold_spans_for_test, _ = extract_entity_spans_and_labels(
    test_examples_data, label_to_index, len(test_examples_list))
assert len(gold_spans_for_test) == len(test_examples_list), "gold spans / examples misaligned"

all_token_labels = []           # gold label index per spaCy token, per text
all_token_spans = []            # candidate (char_start, char_end) spans given to the model, per text
all_logits = []                 # model logits (one row per candidate span), per text
all_original_word_spans = []    # candidate spans as [start_token, end_token) indices, per text
all_length = []                 # number of spaCy tokens, per text

for example, gold_spans in zip(test_examples_list, gold_spans_for_test):
    if torch.cuda.is_available():
        # The number of candidate spans grows quadratically with text length; release
        # cached GPU memory between examples.
        torch.cuda.empty_cache()
    text, labels = example['text'], example['labels']
    doc = nlp(text)

    # Every contiguous token span is a candidate entity.
    entity_spans = []
    original_word_spans = []
    for token_start in doc:
        for token_end in doc[token_start.i:]:
            entity_spans.append((token_start.idx, token_end.idx + len(token_end)))
            original_word_spans.append((token_start.i, token_end.i + 1))

    # Gold token labels, using the same rule as the training data: a token takes an entity's
    # label when its first character lies inside that entity's [start, end) range.
    token_labels = [label_to_index['NIL']] * len(doc)
    for (start, end), label in zip(gold_spans, labels):
        for token in doc:
            if start <= token.idx < end:
                token_labels[token.i] = label

    all_token_labels.append(token_labels)
    all_token_spans.append(entity_spans)
    all_original_word_spans.append(original_word_spans)
    all_length.append(len(doc))

    # Tokenize and move tensors to the model's device (GPU/CPU).
    inputs = ft_tokenizer(text, entity_spans=entity_spans, return_tensors="pt", padding=True).to(device)

    with torch.no_grad():
        outputs = ft_model(**inputs)

    # Batch of one text: keep its (num_candidate_spans, num_labels) logits.
    all_logits.append(outputs.logits[0].tolist())

In [12]:
tuple(map(len, (all_token_labels, all_token_spans, all_logits, all_original_word_spans)))

(20, 20, 20, 20)

In [13]:
# Gold entity-level label indices as provided by the dataset (kept for reference).
final_labels = list(test_examples['labels'])

# Decode each example's span logits into an IOB2 tag sequence over its spaCy tokens.
final_predictions = []
for example_index, candidate_spans in enumerate(all_original_word_spans):
    span_logits = all_logits[example_index]
    best_scores = np.max(span_logits, axis=1)    # best logit per candidate span
    best_labels = np.argmax(span_logits, axis=1)  # label index of that logit

    # Candidates predicted as a real entity (label index 0 is NIL).
    candidates = [
        (score, span, ft_model.config.id2label[label_id])
        for score, label_id, span in zip(best_scores, best_labels, candidate_spans)
        if label_id != 0
    ]

    # Greedy: highest-scoring spans first; skip any span overlapping an already tagged token.
    tags = ["O"] * all_length[example_index]
    for _, (first, stop), label in sorted(candidates, key=lambda c: c[0], reverse=True):
        if any(tag != "O" for tag in tags[first:stop]):
            continue
        tags[first:stop] = ["B-" + label] + ["I-" + label] * (stop - first - 1)

    final_predictions.append(tags)

In [14]:
first_prediction = final_predictions[0]; print(first_prediction)

['B-PERSON', 'O', 'B-PERSON', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']


**Next, I convert the gold token-level labels (`all_token_labels`) into IOB2 format. They are currently lists of label indices, whereas seqeval expects IOB2 label strings.**

In [15]:
preview_count = 5; final_labels[:preview_count]

[[5, 5], [4, 1, 6, 4], [4, 1, 4, 4], [4, 1], [1, 1, 1, 1]]

In [16]:
def convert_references_to_iob(labels):
    """Turn per-token label indices into IOB2 tag sequences.

    Each index is mapped through the global ``index_to_label``; 'NIL' becomes 'O'. A run of
    consecutive tokens sharing the same non-NIL label becomes one entity: 'B-<label>' on its
    first token and 'I-<label>' on the rest. Consequently, two adjacent but distinct entities
    with the same label are merged into a single entity.

    Args:
        labels: Iterable of per-text lists of label indices.

    Returns:
        List of per-text lists of IOB2 tag strings, parallel to ``labels``.
    """
    iob_labels = []
    for label_seq in labels:
        tags = []
        previous = None  # label of the preceding token, or None after an 'O'
        for label_index in label_seq:
            name = index_to_label[label_index]
            if name == "NIL":
                tags.append("O")
                previous = None
                continue
            prefix = "I-" if name == previous else "B-"
            tags.append(prefix + name)
            previous = name
        iob_labels.append(tags)
    return iob_labels

In [17]:
final_labels_iob = convert_references_to_iob(labels=all_token_labels)  # gold IOB2 tags per test text

In [18]:
tuple(map(len, (final_labels_iob, final_predictions)))

(20, 20)

In [None]:
report = seqeval.metrics.classification_report(final_labels_iob, final_predictions, digits=4); print(report)