# GLiREL Relation Extraction (RE) Evaluation on Common Datasets (NYT, CoNLL04)

## NYT

### Data Preprocessing

In [None]:
import json
import random

In [None]:
# Step 1: Load JSONL
def load_nyt_dataset(path):
    """Load a JSON Lines file into a list of dicts, one record per line.

    Blank or whitespace-only lines (e.g. an extra empty line at the end of the
    file) are skipped instead of making ``json.loads`` raise.

    Args:
        path: path to the ``.jsonl`` file (str or path-like).

    Returns:
        list of the decoded JSON objects, in file order.
    """
    with open(path, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]

In [None]:
# Load the full NYT corpus, then peek at the first raw record
# (keys: ner / relations / tokenized_text).
nyt_data = load_nyt_dataset("data/nyt_data.jsonl")
example = nyt_data[0]
print(example)

{'ner': [[11, 12, 'LOCATION', 'Annandale-on-Hudson'], [8, 10, 'ORGANIZATION', 'Bard College']], 'relations': [{'head': {'mention': 'Annandale-on-Hudson', 'position': [11, 12], 'type': 'LOCATION'}, 'tail': {'mention': 'Bard College', 'position': [8, 10], 'type': 'ORGANIZATION'}, 'relation_text': 'contains'}], 'tokenized_text': ['Massachusetts', 'ASTON', 'MAGNA', 'Great', 'Barrington', ';', 'also', 'at', 'Bard', 'College', ',', 'Annandale-on-Hudson', ',', 'N.Y.', ',', 'July', '1-Aug', '.']}


In [None]:
# Extract unique labels across dataset
def get_all_relation_labels(dataset):
    """Return the sorted, de-duplicated ``relation_text`` values of a dataset.

    Args:
        dataset: iterable of NYT-style records, each with a ``"relations"``
            list whose items carry a ``"relation_text"`` string.

    Returns:
        list of unique relation labels in ascending order.
    """
    unique_labels = {
        relation["relation_text"]
        for record in dataset
        for relation in record["relations"]
    }
    return sorted(unique_labels)

In [None]:
# Candidate label inventory for GLiREL, taken from the whole NYT file.
relation_labels = get_all_relation_labels(nyt_data)

print("Relation Labels:", relation_labels)
print("Number of Relation Labels:", len(relation_labels))

Relation Labels: ['administrative_divisions', 'advisors', 'capital', 'children', 'company', 'contains', 'country', 'ethnicity', 'founders', 'geographic_distribution', 'industry', 'location', 'major_shareholder_of', 'major_shareholders', 'nationality', 'neighborhood_of', 'people', 'place_founded', 'place_lived', 'place_of_birth', 'place_of_death', 'profession', 'religion', 'teams']
Number of Relation Labels: 24


In [None]:
def save_jsonl(data, path):
    """Write ``data`` to ``path`` as JSON Lines, one ``json.dumps`` record per line.

    The file is opened in ``'w'`` mode, so an existing file is overwritten.
    """
    with open(path, 'w', encoding='utf-8') as out_file:
        out_file.writelines(json.dumps(record) + '\n' for record in data)

In [None]:
# Shuffle for randomness.
# Shuffle a *copy* with a dedicated seeded RNG so `nyt_data` stays in file order
# and re-running this cell always yields the same split. On freshly loaded data,
# random.Random(42).shuffle gives the same permutation as
# random.seed(42); random.shuffle(...).
random.seed(42)  # also seed the global RNG for any later stochastic step
nyt_shuffled = list(nyt_data)
random.Random(42).shuffle(nyt_shuffled)

# Take 10% as test data
test_size = int(0.10 * len(nyt_shuffled))
test_data = nyt_shuffled[:test_size]

In [None]:
# Persist the sampled test split so it can be reused without re-sampling.
test_split_path = "data/nyt_test_10pct.jsonl"
save_jsonl(test_data, test_split_path)

print(f"Saved {len(test_data)} samples to '{test_split_path}'")

Saved 5619 samples to 'data/nyt_test_10pct.jsonl'


In [None]:
def prepare_glirel_input(ex):
    """Convert one raw NYT record into the inputs used for GLiREL evaluation.

    Args:
        ex: NYT record with ``"tokenized_text"``, ``"ner"`` entries of the form
            ``[start, end, type, text]`` (end-exclusive), and ``"relations"``
            whose ``head`` / ``tail`` carry an end-exclusive ``"position"``.

    Returns:
        dict with
        - ``"tokens"``: the token list, unchanged;
        - ``"ner"``: ``[start, end - 1, type, text]`` spans (end made inclusive);
        - ``"gold_relations"``: set of
          ``((h_start, h_end), (t_start, t_end), relation_text)`` keeping the
          original end-exclusive offsets.
    """
    tokens = ex["tokenized_text"]

    # Entity spans handed to GLiREL: shift the exclusive end back by one.
    inclusive_ner = [[span[0], span[1] - 1, span[2], span[3]] for span in ex["ner"]]

    # Gold triples keep exclusive ends, which matches the `head_pos` / `tail_pos`
    # layout GLiREL returns (ner [50, 50] comes back as head_pos [50, 51]).
    gold_triples = {
        (
            (rel["head"]["position"][0], rel["head"]["position"][1]),
            (rel["tail"]["position"][0], rel["tail"]["position"][1]),
            rel["relation_text"],
        )
        for rel in ex["relations"]
    }

    return {"tokens": tokens, "ner": inclusive_ner, "gold_relations": gold_triples}

In [None]:
# Convert every sampled test record into GLiREL-ready form.
prepared_dataset = list(map(prepare_glirel_input, test_data))

In [None]:
# Inspect one example
from pprint import pprint

In [None]:
# Show one prepared record: end-inclusive `ner` spans next to the
# end-exclusive spans in `gold_relations`.
example = prepared_dataset[4]
pprint(example)

{'gold_relations': {((50, 51), (43, 45), 'contains')},
 'ner': [[50, 50, 'LOCATION', 'Iraq'], [43, 44, 'PERSON', 'Abu Ghraib']],
 'tokens': ['Pitiless',
            'himself',
            ',',
            'he',
            'sent',
            'hundreds',
            'of',
            'thousands',
            'of',
            'his',
            'countrymen',
            'to',
            'miserable',
            'deaths',
            ',',
            'in',
            'the',
            'wars',
            'he',
            'started',
            'against',
            'Iran',
            'and',
            'Kuwait',
            ',',
            'in',
            'the',
            'torture',
            'chambers',
            'of',
            'his',
            'secret',
            'police',
            ',',
            'or',
            'on',
            'the',
            'gallows',
            'that',
            'became',
            'an',
            'industry',
            'a

### Example Testing

In [None]:
# Unpack the inspected example into the arguments passed to GLiREL.
tokens, ner_spans = example["tokens"], example["ner"]
labels = relation_labels

In [None]:
# Ready to predict with GLiREL.
# NOTE(review): `model` is not created in any cell of this notebook; a GLiREL
# model must be loaded first (TODO: add an explicit loading cell near the top).
# top_k tracks the label count (it was hard-coded to 24 == len(relation_labels)),
# so editing the label set cannot silently change how many labels are kept.
relations = model.predict_relations(tokens, labels, threshold=0.0, ner=ner_spans, top_k=len(labels))
print(relations)

[{'head_pos': [50, 51], 'tail_pos': [43, 45], 'head_text': ['Iraq'], 'tail_text': ['Abu', 'Ghraib'], 'label': 'place_of_death', 'score': 0.3135254979133606}, {'head_pos': [43, 45], 'tail_pos': [50, 51], 'head_text': ['Abu', 'Ghraib'], 'tail_text': ['Iraq'], 'label': 'location', 'score': 0.2854765057563782}, {'head_pos': [50, 51], 'tail_pos': [43, 45], 'head_text': ['Iraq'], 'tail_text': ['Abu', 'Ghraib'], 'label': 'location', 'score': 0.28398367762565613}, {'head_pos': [50, 51], 'tail_pos': [43, 45], 'head_text': ['Iraq'], 'tail_text': ['Abu', 'Ghraib'], 'label': 'industry', 'score': 0.27281203866004944}, {'head_pos': [43, 45], 'tail_pos': [50, 51], 'head_text': ['Abu', 'Ghraib'], 'tail_text': ['Iraq'], 'label': 'place_of_death', 'score': 0.26273059844970703}, {'head_pos': [43, 45], 'tail_pos': [50, 51], 'head_text': ['Abu', 'Ghraib'], 'tail_text': ['Iraq'], 'label': 'industry', 'score': 0.18127655982971191}, {'head_pos': [43, 45], 'tail_pos': [50, 51], 'head_text': ['Abu', 'Ghraib'], 

### Model Inference

In [None]:
# Run GLiREL over every prepared NYT test example.
# Dedicated loop names keep the `tokens` / `ner_spans` / `labels` globals from
# the example cell intact, so re-running that cell later still predicts on the
# inspected example rather than on the last test example.
# Results are appended one by one, so a failure part-way keeps earlier output.
# NOTE(review): relies on a GLiREL `model` loaded elsewhere.
predictions = []
for nyt_ex in prepared_dataset:
    # Run GLiREL inference
    nyt_preds = model.predict_relations(
        nyt_ex["tokens"],
        relation_labels,
        threshold=0.0,
        ner=nyt_ex["ner"],
        top_k=len(relation_labels),
    )
    predictions.append(nyt_preds)

In [None]:
# Raw GLiREL output for the first test example.
first_predictions = predictions[0]
print(first_predictions)

[{'head_pos': [0, 1], 'tail_pos': [7, 9], 'head_text': ['Englewood'], 'tail_text': ['New', 'Jersey'], 'label': 'location', 'score': 0.5626252889633179}, {'head_pos': [7, 9], 'tail_pos': [0, 1], 'head_text': ['New', 'Jersey'], 'tail_text': ['Englewood'], 'label': 'location', 'score': 0.43038100004196167}, {'head_pos': [0, 1], 'tail_pos': [7, 9], 'head_text': ['Englewood'], 'tail_text': ['New', 'Jersey'], 'label': 'contains', 'score': 0.09896647930145264}, {'head_pos': [7, 9], 'tail_pos': [0, 1], 'head_text': ['New', 'Jersey'], 'tail_text': ['Englewood'], 'label': 'contains', 'score': 0.07002299278974533}, {'head_pos': [7, 9], 'tail_pos': [0, 1], 'head_text': ['New', 'Jersey'], 'tail_text': ['Englewood'], 'label': 'place_of_birth', 'score': 0.0444660447537899}, {'head_pos': [0, 1], 'tail_pos': [7, 9], 'head_text': ['Englewood'], 'tail_text': ['New', 'Jersey'], 'label': 'place_of_birth', 'score': 0.03361504152417183}, {'head_pos': [0, 1], 'tail_pos': [7, 9], 'head_text': ['Englewood'], 't

### Performance Evaluation

In [None]:
def compute_metrics(dataset, predictions, threshold=0.5):
    """Micro-averaged precision / recall / F1 over relation triples.

    A triple is ``(head_span, tail_span, label)``. A prediction is a true
    positive only if both spans and the label equal a gold triple. Labels are
    lower-cased and stripped on both sides before comparison.

    Args:
        dataset: examples with a ``"gold_relations"`` collection of triples
            (as built by ``prepare_glirel_input``).
        predictions: one list of GLiREL prediction dicts per example, each with
            ``"head_pos"``, ``"tail_pos"``, ``"label"`` and ``"score"``.
        threshold: predictions scoring below this are discarded.

    Returns:
        dict with ``precision``, ``recall``, ``f1`` and the raw
        ``true_positives`` / ``false_positives`` / ``false_negatives`` counts.
    """
    assert len(dataset) == len(predictions)

    def _normalize(label):
        # Case- and surrounding-whitespace-insensitive label comparison.
        return label.lower().strip()

    total_tp = 0
    total_fp = 0
    total_fn = 0

    for ex, preds in zip(dataset, predictions):
        # Build the sets with normalized labels directly; reassigning a loop
        # variable would not change the set contents.
        gold = {(h_span, t_span, _normalize(label)) for h_span, t_span, label in ex["gold_relations"]}

        pred = set()
        for rel in preds:
            if rel["score"] < threshold:
                continue
            h_span = (rel["head_pos"][0], rel["head_pos"][1])
            t_span = (rel["tail_pos"][0], rel["tail_pos"][1])
            pred.add((h_span, t_span, _normalize(rel["label"])))

        total_tp += len(gold & pred)
        total_fp += len(pred - gold)
        total_fn += len(gold - pred)

    precision = total_tp / (total_tp + total_fp) if (total_tp + total_fp) > 0 else 0.0
    recall = total_tp / (total_tp + total_fn) if (total_tp + total_fn) > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.0

    return {
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "true_positives": total_tp,
        "false_positives": total_fp,
        "false_negatives": total_fn
    }

In [None]:
# threshold=0.0 keeps every prediction GLiREL returns (top_k per pair), so expect
# many false positives and therefore low precision.
metrics = compute_metrics(prepared_dataset, predictions, threshold=0.0)

# Print nicely: floats to 4 decimals, counts as-is.
for metric_name, metric_value in metrics.items():
    if isinstance(metric_value, float):
        print(f"{metric_name}: {metric_value:.4f}")
    else:
        print(f"{metric_name}: {metric_value}")

precision: 0.0221
recall: 0.7355
f1: 0.0429
true_positives: 6538
false_positives: 289238
false_negatives: 2351


## CoNLL 2004

In [None]:
from datasets import load_dataset

In [None]:
# CoNLL04 from the Hugging Face Hub (train / validation / test splits).
ds = load_dataset(path="DFKI-SLT/conll04")

README.md: 0.00B [00:00, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/118k [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/40.5k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/46.6k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/922 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/231 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/288 [00:00<?, ? examples/s]

In [None]:
# Evaluate on the held-out CoNLL04 test split (288 examples).
test_ds = ds["test"]

In [None]:
# Summary of the split: feature columns and row count.
print(test_ds)

Dataset({
    features: ['entities', 'tokens', 'relations', 'orig_id'],
    num_rows: 288
})


In [None]:
# First raw test record: entity `start`/`end` are end-exclusive token offsets,
# and relations point at entities by their list index.
first_test_record = test_ds[0]
print(first_test_record)

{'entities': [{'end': 7, 'start': 5, 'type': 'Org'}, {'end': 9, 'start': 8, 'type': 'Other'}, {'end': 11, 'start': 10, 'type': 'Loc'}, {'end': 18, 'start': 17, 'type': 'Other'}], 'tokens': ['An', 'art', 'exhibit', 'at', 'the', 'Hakawati', 'Theatre', 'in', 'Arab', 'east', 'Jerusalem', 'was', 'a', 'series', 'of', 'portraits', 'of', 'Palestinians', 'killed', 'in', 'the', 'rebellion', '.'], 'relations': [{'head': 0, 'tail': 2, 'type': 'OrgBased_In'}], 'orig_id': 17}


### Pre-processing

In [None]:
def create_conll_input(ex):
    """Convert one CoNLL04 record into the inputs used for GLiREL evaluation.

    Args:
        ex: record with ``"tokens"``, ``"entities"`` (dicts with end-exclusive
            ``start`` / ``end`` token offsets and a ``type``) and ``"relations"``
            (dicts whose ``head`` / ``tail`` are indices into ``entities``).

    Returns:
        dict with
        - ``"tokens"``: the token list, unchanged;
        - ``"ner"``: ``[start, end_inclusive, type, surface_text]`` per entity;
        - ``"gold_relations"``: set of
          ``((h_start, h_end), (t_start, t_end), type)`` keeping the dataset's
          end-exclusive offsets.
    """
    tokens = ex["tokens"]
    entities = ex["entities"]

    def entity_span(index):
        # End-exclusive (start, end) span of the entity at `index`.
        entity = entities[index]
        return (entity["start"], entity["end"])

    # 1. Entity spans for GLiREL, with the end made inclusive.
    ner = []
    for entity in entities:
        first = entity["start"]
        last = entity["end"] - 1
        ner.append([first, last, entity["type"], " ".join(tokens[first:last + 1])])

    # 2. Gold relations as a set of span/label triples.
    gold_relations = {
        (entity_span(rel["head"]), entity_span(rel["tail"]), rel["type"])
        for rel in ex["relations"]
    }

    return {"tokens": tokens, "ner": ner, "gold_relations": gold_relations}

In [None]:
conll_input = list(map(create_conll_input, test_ds))

In [None]:
# Extract unique labels across dataset
def get_conll_labels(dataset):
    labels = set()
    for item in dataset:
        for rel in item["relations"]:
            labels.add(rel["relation_text"])
    return sorted(labels)

In [None]:
conll_labels = get_conll_labels(conll_data)
print("Relation Labels:", conll_labels)
print("Number of Relation Labels:", len(conll_labels))

Relation Labels: ['Kill', 'Live_In', 'Located_In', 'OrgBased_In', 'Work_For']
Number of Relation Labels: 5


### Example Test

In [None]:
# First prepared CoNLL04 test example: end-inclusive `ner`, end-exclusive gold spans.
conll_example = conll_input[0]
print(conll_example)

{'tokens': ['An', 'art', 'exhibit', 'at', 'the', 'Hakawati', 'Theatre', 'in', 'Arab', 'east', 'Jerusalem', 'was', 'a', 'series', 'of', 'portraits', 'of', 'Palestinians', 'killed', 'in', 'the', 'rebellion', '.'], 'ner': [[5, 6, 'Org', 'Hakawati Theatre'], [8, 8, 'Other', 'Arab'], [10, 10, 'Loc', 'Jerusalem'], [17, 17, 'Other', 'Palestinians']], 'gold_relations': {((5, 7), (10, 11), 'OrgBased_In')}}


In [None]:
conll_ex_tokens, conll_ex_ner = conll_example["tokens"], conll_example["ner"]

In [None]:
# top_k tracks the label count (it was hard-coded to 5 == len(conll_labels)).
# NOTE(review): relies on a GLiREL `model` loaded elsewhere.
conll_ex_prediction = model.predict_relations(conll_ex_tokens, labels=conll_labels, threshold=0.0, ner=conll_ex_ner, top_k=len(conll_labels))

In [None]:
print('Number of relations:', len(conll_ex_prediction))

# Rank the example's predictions from most to least confident.
sorted_conll_preds = sorted(conll_ex_prediction, key=lambda pred: pred['score'], reverse=True)
print("\nDescending Order by Score:")
for pred in sorted_conll_preds:
    head, label, tail, score = pred['head_text'], pred['label'], pred['tail_text'], pred['score']
    print(f"{head} --> {label} --> {tail} | score: {score}")

Number of relations: 60

Descending Order by Score:
['Hakawati', 'Theatre'] --> Located_In --> ['Jerusalem'] | score: 0.5305671095848083
['Jerusalem'] --> Located_In --> ['Hakawati', 'Theatre'] | score: 0.3967788517475128
['Arab'] --> Located_In --> ['Jerusalem'] | score: 0.39375779032707214
['Jerusalem'] --> Located_In --> ['Arab'] | score: 0.32883286476135254
['Hakawati', 'Theatre'] --> Located_In --> ['Arab'] | score: 0.320782870054245
['Palestinians'] --> Live_In --> ['Jerusalem'] | score: 0.31211742758750916
['Arab'] --> Located_In --> ['Hakawati', 'Theatre'] | score: 0.22228744626045227
['Palestinians'] --> Located_In --> ['Jerusalem'] | score: 0.21948927640914917
['Hakawati', 'Theatre'] --> Kill --> ['Palestinians'] | score: 0.21816113591194153
['Jerusalem'] --> Kill --> ['Palestinians'] | score: 0.1833038032054901
['Arab'] --> Live_In --> ['Jerusalem'] | score: 0.1621631383895874
['Hakawati', 'Theatre'] --> Live_In --> ['Jerusalem'] | score: 0.126860573887825
['Jerusalem'] --> 

### Inference

In [None]:
# Run GLiREL over every prepared CoNLL04 test example.
# Dedicated loop names avoid overwriting the notebook-level `example`,
# `tokens` and `labels` globals set by earlier cells.
# Results are appended one by one, so a failure part-way keeps earlier output.
# NOTE(review): relies on a GLiREL `model` loaded elsewhere.
conll_predictions = []
for conll_ex in conll_input:
    # Run GLiREL inference
    conll_preds = model.predict_relations(
        conll_ex["tokens"],
        labels=conll_labels,
        threshold=0.0,
        ner=conll_ex["ner"],
        top_k=len(conll_labels),
    )
    conll_predictions.append(conll_preds)

In [None]:
# Raw GLiREL output for the first CoNLL04 test example.
first_conll_predictions = conll_predictions[0]
print(first_conll_predictions)

[{'head_pos': [5, 7], 'tail_pos': [10, 11], 'head_text': ['Hakawati', 'Theatre'], 'tail_text': ['Jerusalem'], 'label': 'Located_In', 'score': 0.5305671095848083}, {'head_pos': [10, 11], 'tail_pos': [5, 7], 'head_text': ['Jerusalem'], 'tail_text': ['Hakawati', 'Theatre'], 'label': 'Located_In', 'score': 0.3967788517475128}, {'head_pos': [8, 9], 'tail_pos': [10, 11], 'head_text': ['Arab'], 'tail_text': ['Jerusalem'], 'label': 'Located_In', 'score': 0.39375779032707214}, {'head_pos': [10, 11], 'tail_pos': [8, 9], 'head_text': ['Jerusalem'], 'tail_text': ['Arab'], 'label': 'Located_In', 'score': 0.32883286476135254}, {'head_pos': [5, 7], 'tail_pos': [8, 9], 'head_text': ['Hakawati', 'Theatre'], 'tail_text': ['Arab'], 'label': 'Located_In', 'score': 0.320782870054245}, {'head_pos': [17, 18], 'tail_pos': [10, 11], 'head_text': ['Palestinians'], 'tail_text': ['Jerusalem'], 'label': 'Live_In', 'score': 0.31211742758750916}, {'head_pos': [8, 9], 'tail_pos': [5, 7], 'head_text': ['Arab'], 'tail_

### Evaluation

In [None]:
def evaluate_conll(dataset, predictions, threshold=0.5, label_map=None):
    """Exact-match micro precision / recall / F1 for CoNLL04 relation triples.

    A triple is ``(head_span, tail_span, label)``. A prediction is a true
    positive only if all three parts equal a gold triple.

    Args:
        dataset: examples with a ``"gold_relations"`` collection of triples
            (as built by ``create_conll_input``).
        predictions: one list of GLiREL prediction dicts per example, each with
            ``"head_pos"``, ``"tail_pos"``, ``"label"`` and ``"score"``.
        threshold: predictions scoring below this are discarded.
        label_map: optional ``{label: canonical_label}`` applied to both gold
            and predicted labels before matching (e.g. to merge two relation
            types). ``None`` (default) compares labels as-is.

    Returns:
        dict with ``precision``, ``recall``, ``f1`` (rounded to 4 decimals) and
        the raw ``TP`` / ``FP`` / ``FN`` counts.
    """
    assert len(dataset) == len(predictions)
    mapping = label_map or {}

    tp = fp = fn = 0

    for example, preds in zip(dataset, predictions):
        gold = {
            (tuple(head), tuple(tail), mapping.get(label, label))
            for head, tail, label in example["gold_relations"]
        }

        pred = set()
        for rel in preds:
            if rel["score"] < threshold:
                continue
            label = mapping.get(rel["label"], rel["label"])
            pred.add((tuple(rel["head_pos"]), tuple(rel["tail_pos"]), label))

        tp += len(pred & gold)
        fp += len(pred - gold)
        fn += len(gold - pred)

    # Exact zero-guards (an empty denominator yields 0.0) instead of adding
    # 1e-8 to every denominator, which nudged all scores below their true value.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0

    return {
        "precision": round(precision, 4),
        "recall": round(recall, 4),
        "f1": round(f1, 4),
        "TP": tp,
        "FP": fp,
        "FN": fn
    }

In [None]:
# Exact-match CoNLL04 scores at a 0.3 confidence cut-off.
results_conll = evaluate_conll(conll_input, conll_predictions, threshold=0.3)

print("Conll Evaluation Results:")
for metric_name, metric_value in results_conll.items():
    if isinstance(metric_value, float):
        print(f"{metric_name}: {metric_value:.4f}")
    else:
        print(f"{metric_name}: {metric_value}")

Conll Evaluation Results:
precision: 0.1524
recall: 0.2441
f1: 0.1876
TP: 103
FP: 573
FN: 319


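### Fuzzy Evaluation

Same metrics as the exact evaluation, except that `OrgBased_In` is merged into `Located_In` on both the gold and the predicted side before matching. Confusing these two relation types is therefore not counted as an error.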

In [None]:
def evaluate_conll_fuzzy(dataset, predictions, threshold=0.5, label_map=None):
    """CoNLL04 precision / recall / F1 with selected relation labels merged.

    Same as exact triple matching, except that labels are first rewritten
    through ``label_map`` on both the gold and the predicted side. By default
    ``OrgBased_In`` is merged into ``Located_In``, so confusing those two
    relation types is not penalised. Merging can collapse two predictions for
    the same entity pair into a single triple.

    Args:
        dataset: examples with a ``"gold_relations"`` collection of
            ``(head_span, tail_span, label)`` triples.
        predictions: one list of GLiREL prediction dicts per example, each with
            ``"head_pos"``, ``"tail_pos"``, ``"label"`` and ``"score"``.
        threshold: predictions scoring below this are discarded.
        label_map: ``{label: merged_label}``. ``None`` (default) uses
            ``{'OrgBased_In': 'Located_In'}``; ``{}`` disables merging.

    Returns:
        dict with ``precision``, ``recall``, ``f1`` (rounded to 4 decimals) and
        the raw ``TP`` / ``FP`` / ``FN`` counts.
    """
    assert len(dataset) == len(predictions)
    merge = {'OrgBased_In': 'Located_In'} if label_map is None else label_map

    tp = fp = fn = 0

    for example, preds in zip(dataset, predictions):
        gold_set = {
            (tuple(head), tuple(tail), merge.get(label, label))
            for head, tail, label in example["gold_relations"]
        }

        pred_set = set()
        for rel in preds:
            if rel["score"] < threshold:
                continue
            label = merge.get(rel["label"], rel["label"])
            pred_set.add((tuple(rel["head_pos"]), tuple(rel["tail_pos"]), label))

        tp += len(pred_set & gold_set)
        fp += len(pred_set - gold_set)
        fn += len(gold_set - pred_set)

    # Exact zero-guards (an empty denominator yields 0.0) rather than adding
    # 1e-8 to every denominator.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0

    return {
        "precision": round(precision, 4),
        "recall": round(recall, 4),
        "f1": round(f1, 4),
        "TP": tp,
        "FP": fp,
        "FN": fn
    }

In [None]:
# Fuzzy CoNLL04 scores (OrgBased_In merged into Located_In) at the same 0.3 cut-off.
conll_results_fuzzy = evaluate_conll_fuzzy(conll_input, conll_predictions, threshold=0.3)

print("Conll Fuzzy Evaluation Results:")
for metric_name, metric_value in conll_results_fuzzy.items():
    if isinstance(metric_value, float):
        print(f"{metric_name}: {metric_value:.4f}")
    else:
        print(f"{metric_name}: {metric_value}")

Conll Fuzzy Evaluation Results:
precision: 0.2751
recall: 0.4408
f1: 0.3388
TP: 186
FP: 490
FN: 236
