In [1]:
import torch
import torch.nn as nn
from transformers import AutoModel,AutoTokenizer
from torch.utils.data import DataLoader, Dataset
import json
import re
import ast
from tqdm import tqdm
import torch.optim as optim
from torch.nn.utils.rnn import pad_sequence
from typing import List, Tuple
import torch.nn.functional as F
import random
from sklearn.metrics import precision_score, recall_score, f1_score

**Process Pile-NER dataset**

In [2]:
#taken from GLiNER source code: https://github.com/urchade/GLiNER/blob/main/data/process_pilener.py

def load_data(filepath):
    """Loads data from a JSON file."""
    with open(filepath, 'r') as f:
        data = json.load(f)
    return data

def tokenize_text(text):
    """Tokenizes the input text into a list of tokens."""
    return re.findall(r'\w+(?:[-_]\w+)*|\S', text)

def extract_entity_spans(entry):
    """Extracts entity spans from an entry."""
    len_start = len("What describes ")
    len_end = len(" in the text?")
    entity_types, entity_texts, negative = [], [], []

    for c in entry['conversations']:
        if c['from'] == 'human' and c['value'].startswith('Text: '):
            text = c['value'][len('Text: '):]
            tokenized_text = tokenize_text(text)
        elif c['from'] == 'human' and c['value'].startswith('What describes '):
            entity_type = c['value'][len_start:-len_end]
            entity_types.append(entity_type)
        elif c['from'] == 'gpt' and c['value'].startswith('['):
            if c['value'] == '[]':
                negative.append(entity_types.pop())
                continue
            texts_ents = ast.literal_eval(c['value'])
            entity_texts.extend(texts_ents)
            num_repeat = len(texts_ents) - 1
            entity_types.extend([entity_types[-1]] * num_repeat)

    entity_spans = []
    for j, entity_text in enumerate(entity_texts):
        entity_tokens = tokenize_text(entity_text)
        matches = []
        for i in range(len(tokenized_text) - len(entity_tokens) + 1):
            if " ".join(tokenized_text[i:i + len(entity_tokens)]).lower() == " ".join(entity_tokens).lower():
                matches.append((i, i + len(entity_tokens) - 1, entity_types[j]))
        if matches:
            entity_spans.extend(matches)

    return {"tokenized_text": tokenized_text, "ner": entity_spans, "negative": negative}

def process_data(data):
    """Processes a list of data entries to extract entity spans."""
    all_data = [extract_entity_spans(entry) for entry in tqdm(data)]
    return all_data

def save_data_to_file(data, filepath):
    """Saves the processed data to a JSON file."""
    with open(filepath, 'w') as f:
        json.dump(data, f)

**Process CrossNER datasets**

In [3]:
#taken from https://github.com/zliucr/CrossNER/tree/main/src
def process_conll_data(file_path):
    """
    Processes a CONLL-format file into the processed_data format with dynamic entity type extraction.

    Args:
        file_path: Path to the CONLL-format file.

    Returns:
        List of processed examples in the desired format.
    """
    data = []
    sentence_tokens = []
    sentence_labels = []
    all_entity_types = set()

    with open(file_path, "r") as file:
        for line in file:
            line = line.strip()
            if not line:
                if sentence_tokens:
                    ner_entities = []
                    for idx, label in enumerate(sentence_labels):
                        if label != "O":
                            entity_type = label[2:]
                            all_entity_types.add(entity_type)
                            if label.startswith("B-"):
                                ner_entities.append((idx, idx, entity_type))
                            elif label.startswith("I-") and ner_entities:
                                ner_entities[-1] = (
                                    ner_entities[-1][0],
                                    idx,
                                    ner_entities[-1][2],
                                )

                    sentence_entity_types = {ent[2] for ent in ner_entities}

                    negative_types = list(all_entity_types - sentence_entity_types)

                    data.append({
                        "tokenized_text": sentence_tokens,
                        "ner": ner_entities,
                        "negative": negative_types,
                    })

                sentence_tokens = []
                sentence_labels = []
            else:
                token, label = line.split()
                sentence_tokens.append(token)
                sentence_labels.append(label)

    for example in data:
        example["negative"] = list(all_entity_types - {ent[2] for ent in example["ner"]})

    return data


**GLiNER Model**

In [54]:
class GLiNER(nn.Module):
    def __init__(self, pretrained_model_name, tokenizer, max_span_length=12, hidden_size=768):
        super(GLiNER, self).__init__()

        #tokenizer
        self.tokenizer = tokenizer

        #encoder
        self.encoder = AutoModel.from_pretrained(pretrained_model_name)
        self.encoder.resize_token_embeddings(len(tokenizer))

        #hyperparams
        self.hidden_size = hidden_size
        self.max_span_length = max_span_length

        #FFN for span rep
        self.span_ffn = nn.Sequential(
            nn.Linear(self.hidden_size * 2, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
        )
        #FFN for entity rep
        self.entity_ffn = nn.Sequential(
            nn.Linear(self.hidden_size, self.hidden_size),
            nn.ReLU(),
            nn.Linear(self.hidden_size, self.hidden_size),
        )

    def forward(self, input_ids, attention_mask):

        # encoder layer
        encoder_output = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        token_embeddings = encoder_output.last_hidden_state  # Shape: (batch_size, seq_len, hidden_size)

        # entity embeddings
        ent_token_id = self.tokenizer.convert_tokens_to_ids("[ENT]")
        ent_positions = (input_ids == ent_token_id).nonzero(as_tuple=True)
        entity_embeddings = token_embeddings[ent_positions[0], ent_positions[1], :]  # (num_entities, hidden_size)
        refined_entity_embeddings = self.entity_ffn(entity_embeddings)

        # span embeddings
        spans, span_embeddings = self.create_span_embeddings(token_embeddings, attention_mask, input_ids)

        # span-entity matching
        span_scores = self.compute_span_scores(span_embeddings, refined_entity_embeddings)

        return spans, span_scores

    def create_span_embeddings(self, token_embeddings, attention_mask, input_ids):

        batch_size, seq_len, _ = token_embeddings.size()
        spans = []
        span_embeddings = []

        special_token_ids = [self.tokenizer.convert_tokens_to_ids("[ENT]"), self.tokenizer.sep_token_id]

        #ignore special tokens for embeddings
        for batch_idx in range(batch_size):
            for start in range(seq_len):
                for end in range(start, min(start + self.max_span_length, seq_len)):
                    if (
                        attention_mask[batch_idx, start] == 0
                        or attention_mask[batch_idx, end] == 0
                        or input_ids[batch_idx, start] in special_token_ids
                        or input_ids[batch_idx, end] in special_token_ids
                    ):
                        continue

                    # generate span embeddings
                    spans.append((batch_idx, start, end))
                    start_embedding = token_embeddings[batch_idx, start, :]
                    end_embedding = token_embeddings[batch_idx, end, :]
                    span_embedding = torch.cat((start_embedding, end_embedding), dim=-1)
                    span_embeddings.append(self.span_ffn(span_embedding))

        span_embeddings = torch.stack(span_embeddings) if span_embeddings else torch.empty(0, self.hidden_size)
        return spans, span_embeddings

    def compute_span_scores(self, span_embeddings, entity_embeddings):

        # dot product
        scores = torch.matmul(span_embeddings, entity_embeddings.T)

        # apply sigmoid to normalize between 0-1
        span_scores = torch.sigmoid(scores)

        return span_scores



**GLiNER Dataset class**

**Collate data function**

In [39]:
def collate_func(batch):

    # pad input_ids and attention_mask to the same length
    input_ids = torch.nn.utils.rnn.pad_sequence(
        [item["input_ids"] for item in batch], batch_first=True, padding_value=0
    )
    attention_mask = torch.nn.utils.rnn.pad_sequence(
        [item["attention_mask"] for item in batch], batch_first=True, padding_value=0
    )

    spans = [item["spans"] for item in batch]
    labels = [item["labels"] for item in batch]

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "spans": spans,
        "labels": labels,
    }

In [40]:
class GlinerDataset(Dataset):
    def __init__(self, data, tokenizer, max_seq_length=512, max_span_length=12,max_entity_types=25):

        self.data = data
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.max_span_length = max_span_length
        self.max_entity_types = max_entity_types

    def __len__(self):

        return len(self.data)

    def __getitem__(self, idx):

        item = self.data[idx]

        #limit to 25 entities per sentence
        all_entity_types = list(set([ent[2] for ent in item["ner"]] + item["negative"]))
        if len(all_entity_types) > self.max_entity_types:
            all_entity_types = random.sample(all_entity_types, self.max_entity_types)

        # format input
        entity_type_str = " ".join([f"[ENT] {etype}" for etype in all_entity_types])

        text = " ".join(item["tokenized_text"])
        formatted_input = f"{entity_type_str} [SEP] {text}"

        # tokenize
        tokenized = self.tokenizer(
            formatted_input, padding="max_length", truncation=True, max_length=self.max_seq_length, return_tensors="pt"
        )
        input_ids = tokenized["input_ids"].squeeze(0)
        attention_mask = tokenized["attention_mask"].squeeze(0)

        # generate spans
        positive_spans = item["ner"]
        positive_labels = {(start, end): 1 for start, end, _ in positive_spans}
        spans = []
        labels = []

        #ignore special tokens for spans and labels
        special_token_ids = [self.tokenizer.convert_tokens_to_ids("[ENT]"), self.tokenizer.sep_token_id]
        for start in range(len(input_ids)):
            for end in range(start, min(start + self.max_span_length, len(input_ids))):
                if (
                    input_ids[start].item() in special_token_ids
                    or input_ids[end].item() in special_token_ids
                    or attention_mask[start].item() == 0
                    or attention_mask[end].item() == 0
                ):
                    continue
                spans.append((start, end))
                labels.append(positive_labels.get((start, end), 0))

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "spans": spans,
            "labels": torch.tensor(labels, dtype=torch.float),
        }


**Training Function**

In [41]:
def train_gliner_model(model, dataloader, optimizer, num_epochs, device, threshold=0.5):

    model = model.to(device)

    model.train()
    for epoch in range(num_epochs):
        total_loss, total_correct, total_spans = 0.0, 0, 0

        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            spans = batch["spans"]
            labels = [label.to(device) for label in batch["labels"]]

            optimizer.zero_grad()
            _, span_scores = model(input_ids=input_ids, attention_mask=attention_mask)

            loss = 0.0
            score_idx = 0

            for item_idx, (item_spans, item_labels) in enumerate(zip(spans, labels)):
                # get score for item
                item_scores = span_scores[score_idx:score_idx + len(item_spans)]

                # skip if no valid spans
                if item_scores.size(0) == 0:
                    continue

                item_labels = item_labels.to(device)

                # expand labels to match num entity types
                if item_labels.dim() == 1:
                    expanded_labels = torch.zeros(
                        item_labels.size(0), item_scores.size(1), device=device
                    )
                    expanded_labels.scatter_(1, item_labels.long().unsqueeze(1), 1)
                    item_labels = expanded_labels

                # compute binary cross-entropy for this item
                loss += F.binary_cross_entropy(item_scores, item_labels)

                # predictions for accuracy
                predictions = (item_scores > threshold).long()
                total_correct += (predictions == item_labels).sum().item()
                total_spans += len(item_labels)

                # next
                score_idx += len(item_spans)

            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        accuracy = total_correct / total_spans if total_spans > 0 else 0
        print(f"Epoch {epoch + 1}, Loss: {total_loss:.4f}, Accuracy: {accuracy:.4f}")

    return model


**Initialize and Train**

In [18]:
#path_pile_ner = 'train.json'
#data = load_data(path_pile_ner)
#processed_data = process_data(data)
#save_data_to_file(processed_data, 'pilener_train.json')
processed_data = load_data('pilener_train.json')
print("dataset size:", len(processed_data))

dataset size: 45889


In [55]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

#load tokenizer and add special tokens
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-small")
special_tokens = {"additional_special_tokens": ["[ENT]", "[SEP]"]}
tokenizer.add_special_tokens(special_tokens)

# initialize model
model = GLiNER(pretrained_model_name="microsoft/deberta-v3-small", tokenizer=tokenizer, max_span_length=12)

#resize embeddings to accomodate new tokens
model.encoder.resize_token_embeddings(len(tokenizer))




Embedding(128002, 768, padding_idx=0)

In [56]:
#just trying a subset of data...
data = processed_data[:200]

#prepare dataset and dataloader with custom collate function
dataset = GlinerDataset(data, tokenizer, max_span_length=12)

In [60]:
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_func,)

#optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

#epochs
num_epochs = 3

#total steps
total_steps = len(dataloader) * num_epochs

In [61]:
trained_model = train_gliner_model(model, dataloader, optimizer, num_epochs=num_epochs, device=device)

Epoch 1, Loss: 15.5279, Accuracy: 53.8741
Epoch 2, Loss: 13.2279, Accuracy: 53.6064
Epoch 3, Loss: 13.9754, Accuracy: 54.1247


**Greedy Decoding algorithm (Flat NER)**

In [67]:
#Flat NER Greedy Decoding algorithm
def greedy_decode(spans, span_scores, threshold=0.1):

    scored_spans = [(span, score.item()) for span, score in zip(spans, span_scores) if score > threshold]
    scored_spans = sorted(scored_spans, key=lambda x: x[1], reverse=True)

    selected_spans = []
    for span, score in scored_spans:
        if all(span[1] < other[0] or span[0] > other[1] for other, _ in selected_spans):
            selected_spans.append((span, score))

    return selected_spans

**Evaluation Function**

In [75]:
def evaluate_gliner_model(model, dataloader, device, threshold=0.1):
    """
    Evaluate the GLiNER model on a validation/test dataset.

    Args:
        model: The trained GLiNER model.
        dataloader: DataLoader providing the evaluation data.
        device: Device to run the evaluation on (e.g., "cuda" or "cpu").
        threshold: Classification threshold for span predictions.

    Returns:
        metrics: A dictionary containing precision, recall, and F1-score.
    """
    model.eval()
    model.to(device)

    all_true_spans = []
    all_predicted_spans = []

    with torch.no_grad():
        for batch in dataloader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            spans = batch["spans"]
            labels = [label.cpu().numpy() for label in batch["labels"]]

            predicted_spans, span_scores = model(input_ids=input_ids, attention_mask=attention_mask)

            decoded_spans = greedy_decode(predicted_spans, span_scores.flatten(), threshold)

            for item_spans, item_labels, item_decoded in zip(spans, labels, decoded_spans):
                true_spans = [span for span, label in zip(item_spans, item_labels) if label == 1]
                all_true_spans.append(true_spans)

                predicted_spans = [span for span, _ in item_decoded]
                all_predicted_spans.append(predicted_spans)

    flat_true = [span for spans in all_true_spans for span in spans]
    flat_pred = [span for spans in all_predicted_spans for span in spans]

    precision = precision_score(flat_true, flat_pred, average="micro", zero_division=0)
    recall = recall_score(flat_true, flat_pred, average="micro", zero_division=0)
    f1 = f1_score(flat_true, flat_pred, average="micro", zero_division=0)

    metrics = {
        "precision": precision,
        "recall": recall,
        "f1_score": f1
    }

    print(f"Evaluation Metrics: Precision={precision:.4f}, Recall={recall:.4f}, F1-Score={f1:.4f}")

    return metrics

**Evaluate**

In [50]:
crossner_ai_test = process_conll_data("test.txt")

In [72]:
validation_data = crossner_ai_test[:20]
validation_dataset = GlinerDataset(validation_data, tokenizer, max_span_length=12)
validation_dataloader = DataLoader(
    validation_dataset, batch_size=8, shuffle=False, collate_fn=collate_func
)

In [76]:
metrics = evaluate_gliner_model(trained_model, validation_dataloader, device=device)

Evaluation Metrics: Precision=0.0000, Recall=0.0000, F1-Score=0.0000


In [77]:
metrics

{'precision': 0.0, 'recall': 0.0, 'f1_score': 0.0}