XLLM-ACL Named Entity Recognition Challenge Project




In [2]:
import json
import re
from pathlib import Path
from collections import defaultdict, Counter

import torch
from torch.utils.data import Dataset

from transformers import (
    AutoTokenizer,
    AutoModelForTokenClassification,
    TrainingArguments,
    Trainer,
    DataCollatorForTokenClassification,
    EarlyStoppingCallback,
)


SEED = 42
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ========= SET THESE PATHS =========

PROJECT_ROOT = Path(".")  # current folder

TRAIN_DIR = PROJECT_ROOT / "train"
DEV_DIR   = PROJECT_ROOT / "dev"
TEST_DIR  = PROJECT_ROOT / "test"  

# ========= MODEL / TRAINING CONFIG =========
BASE_NER_MODEL = "allenai/longformer-base-4096"  
MAX_LEN = 2048   
BATCH_SIZE = 1   
NER_EPOCHS = 5   
LR = 2e-5        


Using device: cuda


In [3]:
# Recursively load all *.json files under dir_path. Each file can be a list[dict] or a single dict. 
def load_docs_from_dir(dir_path: Path):    
    docs = []
    for jp in sorted(dir_path.rglob("*.json")):
        with open(jp, "r", encoding="utf-8") as f:
            data = json.load(f)
        if isinstance(data, list):
            docs.extend(data)
        elif isinstance(data, dict):
            docs.append(data)
        else:
            raise ValueError(f"Unexpected JSON structure in {jp}")
    return docs

train_docs = load_docs_from_dir(TRAIN_DIR)
dev_docs   = load_docs_from_dir(DEV_DIR)

print(f"Loaded {len(train_docs)} train docs")
print(f"Loaded {len(dev_docs)} dev docs")
print("Example train doc keys:", train_docs[0].keys())


Loaded 51 train docs
Loaded 23 dev docs
Example train doc keys: dict_keys(['domain', 'title', 'doc', 'entities', 'triples', 'label_set', 'entity_label_set'])


In [4]:
# Collect entity types from TRAIN + DEV data
entity_types = set()

for d in train_docs + dev_docs:
    if "entity_label_set" in d:
        entity_types.update(d["entity_label_set"])
    for ent in d.get("entities", []):
        t = ent.get("type")
        if t:
            entity_types.add(t)

entity_types = sorted(entity_types)
print("Entity types:", entity_types)

ner_labels = ["O"]
for t in entity_types:
    ner_labels.append(f"B-{t}")
    ner_labels.append(f"I-{t}")

label2id_ner = {lbl: i for i, lbl in enumerate(ner_labels)}
id2label_ner = {i: lbl for lbl, i in label2id_ner.items()}

print("Num NER labels:", len(ner_labels))
print("Sample labels:", ner_labels[:20])


Entity types: ['CARDINAL', 'DATE', 'EVENT', 'FAC', 'GPE', 'LANGUAGE', 'LAW', 'LOC', 'MISC', 'MONEY', 'NORP', 'ORDINAL', 'ORG', 'PERCENT', 'PERSON', 'PRODUCT', 'QUANTITY', 'TIME', 'WORK_OF_ART']
Num NER labels: 39
Sample labels: ['O', 'B-CARDINAL', 'I-CARDINAL', 'B-DATE', 'I-DATE', 'B-EVENT', 'I-EVENT', 'B-FAC', 'I-FAC', 'B-GPE', 'I-GPE', 'B-LANGUAGE', 'I-LANGUAGE', 'B-LAW', 'I-LAW', 'B-LOC', 'I-LOC', 'B-MISC', 'I-MISC', 'B-MONEY']


In [5]:
mention_type_counts = defaultdict(Counter)

for doc in train_docs:
    for ent in doc.get("entities", []):
        t = ent.get("type")
        for m in ent.get("mentions", []):
            if not m:
                continue
            mention_type_counts[m.lower()][t] += 1

lexicon = {
    m: cnt.most_common(1)[0][0]
    for m, cnt in mention_type_counts.items()
}

print("Lexicon size:", len(lexicon))
sample_lex = list(lexicon.items())[:10]
print("Lexicon sample:", sample_lex)


Lexicon size: 3445
Lexicon sample: [('william thomson, 1st baron kelvin', 'PERSON'), ('william thomson', 'PERSON'), ('1st baron kelvin', 'PERSON'), ('kelvin', 'PERSON'), ('baron kelvin', 'PERSON'), ('1st', 'ORDINAL'), ('cable signals', 'MISC'), ('wheatstone transmitter', 'MISC'), ('land line', 'MISC'), ('long cable', 'MISC')]


In [6]:
# From entity mentions, build a list of (start_char, end_char, type). Uses regex with word-boundary heuristics to avoid mid-word matches.
def build_char_spans(doc_text: str, entities: list):    
    spans = []

    for ent in entities:
        ent_type = ent.get("type")
        if not ent_type:
            continue

        for m in ent.get("mentions", []):
            if not m:
                continue

            pattern = re.escape(m)
            
            if m[0].isalnum() and m[-1].isalnum():
                pattern = r"\b" + pattern + r"\b"

            for match in re.finditer(pattern, doc_text):
                start, end = match.span()
                spans.append((start, end, ent_type))

    spans.sort(key=lambda x: (x[0], x[1]))
    return spans

# quick sanity on first train doc
sample_spans = build_char_spans(train_docs[0]["doc"], train_docs[0].get("entities", []))
print("Example spans:", sample_spans[:10])


Example spans: [(5, 26, 'MISC'), (41, 54, 'MISC'), (68, 83, 'PERSON'), (68, 101, 'PERSON'), (85, 88, 'ORDINAL'), (85, 101, 'PERSON'), (89, 101, 'PERSON'), (95, 101, 'PERSON'), (153, 158, 'MISC'), (178, 200, 'MISC')]


In [7]:
ner_tokenizer = AutoTokenizer.from_pretrained(BASE_NER_MODEL)

# Convert docs to tokenized inputs + BIO label IDs. One Longformer pass per document (up to max_length tokens).

def encode_ner_docs(docs, tokenizer, max_length, label2id):    
    all_input_ids = []
    all_attention_masks = []
    all_label_ids = []

    for d in docs:
        text = d["doc"]
        if not isinstance(text, str):
            if isinstance(text, (list, tuple)):
                text = " ".join(str(x) for x in text)
            else:
                text = str(text)

        entities = d.get("entities", [])
        spans = build_char_spans(text, entities)

        enc = tokenizer(
            text,
            truncation=True,
            max_length=max_length,
            return_offsets_mapping=True,
        )

        input_ids = enc["input_ids"]
        attention_mask = enc["attention_mask"]
        offsets = enc["offset_mapping"]

        labels = ["O"] * len(input_ids)

        
        for (span_start, span_end, span_type) in spans:
            inside = False
            for i, (tok_start, tok_end) in enumerate(offsets):
                if tok_end <= tok_start:  
                    continue

                if tok_start >= span_end:
                    break  

                if tok_start >= span_start and tok_end <= span_end:
                    if not inside:
                        labels[i] = f"B-{span_type}"
                        inside = True
                    else:
                        labels[i] = f"I-{span_type}"

        # Convert to IDs; ignore special tokens with -100
        label_ids = []
        for lab, (tok_start, tok_end) in zip(labels, offsets):
            if tok_end <= tok_start:
                label_ids.append(-100)
            else:
                label_ids.append(label2id.get(lab, label2id["O"]))

        all_input_ids.append(input_ids)
        all_attention_masks.append(attention_mask)
        all_label_ids.append(label_ids)

    return {
        "input_ids": all_input_ids,
        "attention_mask": all_attention_masks,
        "labels": all_label_ids,
    }

train_encodings = encode_ner_docs(train_docs, ner_tokenizer, MAX_LEN, label2id_ner)
dev_encodings   = encode_ner_docs(dev_docs,   ner_tokenizer, MAX_LEN, label2id_ner)

print("Train examples:", len(train_encodings["input_ids"]))
print("Dev examples:", len(dev_encodings["input_ids"]))


Train examples: 51
Dev examples: 23


In [8]:
class NERDataset(Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        return len(self.encodings["input_ids"])

    def __getitem__(self, idx):
        # Return lists; DataCollator pads & tensorizes
        return {
            "input_ids": self.encodings["input_ids"][idx],
            "attention_mask": self.encodings["attention_mask"][idx],
            "labels": self.encodings["labels"][idx],
        }

train_dataset = NERDataset(train_encodings)
dev_dataset   = NERDataset(dev_encodings)

print("Train dataset size:", len(train_dataset))
print("Dev dataset size:", len(dev_dataset))

data_collator = DataCollatorForTokenClassification(
    tokenizer=ner_tokenizer,
    padding=True,
    max_length=MAX_LEN,
    return_tensors="pt",
)


Train dataset size: 51
Dev dataset size: 23


In [9]:
from collections import Counter

# Flatten labels (ignore -100)
all_labels_flat = [
    lab
    for seq in train_encodings["labels"]
    for lab in seq
    if lab != -100
]

label_counts = Counter(all_labels_flat)
print("Label counts (id -> count):", label_counts)

num_labels = len(ner_labels)
class_weights = torch.ones(num_labels, dtype=torch.float)

total = sum(label_counts.values())
for lab_id, count in label_counts.items():
    class_weights[lab_id] = total / (num_labels * count)

class_weights = class_weights.to(device)
print("Class weights:", class_weights)


Label counts (id -> count): Counter({0: 42757, 18: 2622, 17: 1703, 26: 1210, 30: 1179, 3: 676, 25: 642, 29: 629, 4: 377, 38: 353, 9: 333, 22: 326, 10: 316, 1: 260, 21: 226, 32: 158, 37: 150, 2: 136, 31: 131, 6: 118, 8: 94, 23: 85, 34: 82, 16: 70, 7: 69, 20: 55, 27: 48, 15: 47, 28: 43, 5: 42, 33: 41, 24: 41, 14: 41, 19: 22, 13: 13, 12: 11, 11: 10, 35: 7, 36: 7})
Class weights: tensor([3.3061e-02, 5.4369e+00, 1.0394e+01, 2.0911e+00, 3.7496e+00, 3.3657e+01,
        1.1980e+01, 2.0487e+01, 1.5038e+01, 4.2450e+00, 4.4734e+00, 1.4136e+02,
        1.2851e+02, 1.0874e+02, 3.4478e+01, 3.0076e+01, 2.0194e+01, 8.3006e-01,
        5.3913e-01, 6.4254e+01, 2.5702e+01, 6.2548e+00, 4.3362e+00, 1.6630e+01,
        3.4478e+01, 2.2019e+00, 1.1683e+00, 2.9450e+01, 3.2874e+01, 2.2474e+00,
        1.1990e+00, 1.0791e+01, 8.9468e+00, 3.4478e+01, 1.7239e+01, 2.0194e+02,
        2.0194e+02, 9.4239e+00, 4.0045e+00], device='cuda:0')


In [10]:
ner_model = AutoModelForTokenClassification.from_pretrained(
    BASE_NER_MODEL,
    num_labels=len(ner_labels),
    id2label=id2label_ner,
    label2id=label2id_ner,
    ignore_mismatched_sizes=True, 
)

ner_model.to(device)


Some weights of LongformerForTokenClassification were not initialized from the model checkpoint at allenai/longformer-base-4096 and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


LongformerForTokenClassification(
  (longformer): LongformerModel(
    (embeddings): LongformerEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (position_embeddings): Embedding(4098, 768, padding_idx=1)
    )
    (encoder): LongformerEncoder(
      (layer): ModuleList(
        (0-11): 12 x LongformerLayer(
          (attention): LongformerAttention(
            (self): LongformerSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (query_global): Linear(in_features=768, out_features=768, bias=True)
              (key_global): Linear(in_features=768, out_features=768, bias=True)
             

In [11]:
from transformers import Trainer

class WeightedNERTrainer(Trainer):
    def __init__(self, class_weights=None, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.class_weights = class_weights

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        # accept extra kwargs like num_items_in_batch
        labels = inputs.pop("labels")
        outputs = model(**inputs)
        logits = outputs.logits  

        loss_fct = torch.nn.CrossEntropyLoss(
            weight=self.class_weights,
            ignore_index=-100,
        )
        loss = loss_fct(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
        )
        return (loss, outputs) if return_outputs else loss



In [12]:
ner_training_args = TrainingArguments(
    output_dir="outputs/ner_longformer",
    learning_rate=LR,
    per_device_train_batch_size=BATCH_SIZE,
    num_train_epochs=NER_EPOCHS,
    weight_decay=0.01,
    logging_steps=20,
    save_steps=200,           
    save_total_limit=1,
    fp16=torch.cuda.is_available(),
)



In [13]:
ner_trainer = WeightedNERTrainer(
    model=ner_model,
    args=ner_training_args,
    train_dataset=train_dataset,
    eval_dataset=dev_dataset,   
    data_collator=data_collator,
    class_weights=class_weights,
)


ner_trainer.train()

metrics = ner_trainer.evaluate()
print("Dev metrics:", metrics)

save_dir = "outputs/ner_longformer_model"
ner_trainer.save_model(save_dir)
ner_tokenizer.save_pretrained(save_dir)
print("Saved fine-tuned model to:", save_dir)


Input ids are automatically padded to be a multiple of `config.attention_window`: 512


Step,Training Loss
20,3.6814
40,3.5584
60,3.2554
80,2.7916
100,2.9534
120,2.3
140,2.249
160,2.3562
180,1.885
200,1.8442




Dev metrics: {'eval_loss': 2.366665840148926, 'eval_runtime': 50.0017, 'eval_samples_per_second': 0.46, 'eval_steps_per_second': 0.06, 'epoch': 5.0}
Saved fine-tuned model to: outputs/ner_longformer_model


In [None]:
# Heuristic filters to remove obvious junk spans.
def clean_predicted_entities(entities):    
    cleaned = []
    for ent in entities:
        ent_type = ent["type"]
        new_mentions = []
        for m in ent["mentions"]:
            m_strip = m.strip()
            if len(m_strip) < 3:
                continue
            if all(ch in ",.;:-_()[]{}\"'" for ch in m_strip):
                continue
            if " " not in m_strip and m_strip.islower() and len(m_strip) <= 4:
                continue
            new_mentions.append(m_strip)
        if new_mentions:
            cleaned.append({
                "type": ent_type,
                "mentions": sorted(set(new_mentions)),
            })
    return cleaned

# Predict DocIE-style entities for one document, BIO decode spans, group by (mention, type), optional lexicon-based type correction
def ner_predict_entities(
    doc_text,
    tokenizer,
    model,
    id2label,
    max_length=MAX_LEN,
    lexicon=None,
):
    if not isinstance(doc_text, str):
        if isinstance(doc_text, (list, tuple)):
            doc_text = " ".join(str(x) for x in doc_text)
        else:
            doc_text = str(doc_text)

    model.eval()
    all_spans = []

    with torch.no_grad():
        enc = tokenizer(
            doc_text,
            truncation=True,
            max_length=max_length,
            return_offsets_mapping=True,
            return_tensors="pt",
        ).to(device)

        offsets = enc.pop("offset_mapping")[0].cpu().tolist()
        outputs = model(**enc)
        pred_ids = outputs.logits.argmax(-1)[0].cpu().tolist()

    current_type = None
    current_start = None
    current_end = None

    for label_id, (start_char, end_char) in zip(pred_ids, offsets):
        if end_char <= start_char:
            continue  

        tag = id2label[label_id]
        if tag == "O":
            if current_type is not None:
                all_spans.append((current_start, current_end, current_type))
                current_type = None
                current_start = None
                current_end = None
            continue

        prefix, ent_type = tag.split("-", 1)

        if prefix == "B":
            if current_type is not None:
                all_spans.append((current_start, current_end, current_type))
            current_type = ent_type
            current_start = start_char
            current_end = end_char
        elif prefix == "I":
            if current_type == ent_type:
                current_end = end_char
            else:
                if current_type is not None:
                    all_spans.append((current_start, current_end, current_type))
                current_type = ent_type
                current_start = start_char
                current_end = end_char

    if current_type is not None:
        all_spans.append((current_start, current_end, current_type))

    # Deduplicate spans
    span_set = set(all_spans)

    grouped = defaultdict(set)
    for (s, e, t) in span_set:
        mention = doc_text[s:e].strip()
        if mention:
            grouped[(mention, t)].add(mention)

    entities = []
    for (mention, ent_type), mentions in grouped.items():
        entities.append({
            "mentions": list(mentions),
            "type": ent_type,
        })

    entities = clean_predicted_entities(entities)

    # Lexicon-based type override
    if lexicon is not None:
        for ent in entities:
            votes = Counter()
            for m in ent["mentions"]:
                key = m.lower()
                if key in lexicon:
                    votes[lexicon[key]] += 1
            if votes:
                ent["type"] = votes.most_common(1)[0][0]

    return entities


In [15]:
# Create IDs for dev docs for later analysis
dev_ids = [f"{doc['domain']}_{i}" for i, doc in enumerate(dev_docs)]
len(dev_ids), dev_ids[:5]


(23,
 ['Human_behavior_0',
  'Human_behavior_1',
  'Human_behavior_2',
  'Human_behavior_3',
  'Human_behavior_4'])

In [16]:
# Build reference.json (GROUND TRUTH) for dev
reference = {}

for doc_id, doc in zip(dev_ids, dev_docs):
    
    reference[doc_id] = doc

ref_out_dir = Path("input") / "ref"
ref_out_dir.mkdir(parents=True, exist_ok=True)
ref_file = ref_out_dir / "reference.json"

with open(ref_file, "w", encoding="utf-8") as f:
    json.dump(reference, f, indent=2, ensure_ascii=False)

print("Wrote dev reference to:", ref_file)


Wrote dev reference to: input\ref\reference.json


In [17]:
# Build results.json (PREDICTIONS) for dev 
results = {}

for doc_id, doc in zip(dev_ids, dev_docs):
    text = doc["doc"]
    title = doc.get("title", "")
    ents = ner_predict_entities(text, ner_tokenizer, ner_model, id2label_ner, max_length=MAX_LEN, lexicon=lexicon)

    results[doc_id] = {
        "title":   title,
        "entities": ents,
        "triples": [],   
    }

res_out_dir = Path("input") / "res"
res_out_dir.mkdir(parents=True, exist_ok=True)
res_file = res_out_dir / "results.json"

with open(res_file, "w", encoding="utf-8") as f:
    json.dump(results, f, indent=2, ensure_ascii=False)

print("Wrote dev predictions to:", res_file)


Wrote dev predictions to: input\res\results.json
