# GLiREL RE Evaluation on Common Datasets

In [31]:
from pprint import pprint
from glirel import GLiREL
import spacy
import torch
from data_processing.common import load_jsonl, save_jsonl, run_inference

In [30]:
import importlib
import data_processing.common

importlib.reload(data_processing.common)

<module 'data_processing.common' from '/home/bt19d200/Ayaan/DDP-Baseline/data_processing/common.py'>

Load models

In [3]:
# Load the pretrained GLiREL checkpoint from the Hugging Face hub and a small spaCy English pipeline.
# NOTE(review): `use_fast=False` presumably selects the non-fast tokenizer — confirm against the GLiREL docs.
# NOTE(review): `nlp` is not referenced by any other cell visible in this notebook — confirm it is still needed.
model = GLiREL.from_pretrained("jackboyla/glirel-large-v0", use_fast=False)
nlp = spacy.load("en_core_web_sm")



Check which device the model is on, and move it to the GPU if it is not already there

In [7]:
# Report the device the model currently lives on (before any explicit move).
print(f"Model device: {model.device}")

Model device: cpu


In [8]:
# Move the model to the GPU when PyTorch can see one; otherwise say so explicitly,
# so the reader always knows which device the later inference cells run on.
if torch.cuda.is_available():
    model.to('cuda')
    print(f"Model moved to GPU. New device: {model.device}")
else:
    print(f"CUDA is not available. Model stays on: {model.device}")

Model moved to GPU. New device: cuda


## NYT

#### Create test dataset

In [None]:
# One-off script that produced "data/nyt_test_10pct.jsonl" (the file loaded later in this notebook).
# It is kept commented out so that re-running the notebook does not overwrite the existing split.
# Uncomment and run once if the split file needs to be regenerated from "data/nyt_data.jsonl".

# import random

# full_data = load_jsonl("data/nyt_data.jsonl")

# # Shuffle for randomness
# random.seed(42)
# random.shuffle(full_data)

# # Take 10% as test data
# test_size = int(0.10 * len(full_data))
# test_data = full_data[:test_size]

# # Save to new file
# save_jsonl(test_data, "data/nyt_test_10pct.jsonl")

# print(f"Saved {len(test_data)} samples to 'data/nyt_test_10pct.jsonl'")

### Data Preprocessing

In [None]:
from data_processing.nyt import get_nyt_labels, create_nyt_input, evaluate_nyt

Load data

In [9]:
# Read the 10% NYT test split from its JSONL file.
nyt_data = load_jsonl("data/nyt_test_10pct.jsonl")

Check data format

In [11]:
# Pretty-print the first record to inspect the raw NYT schema.
example = nyt_data[0]

pprint(example)

{'ner': [[7, 9, 'LOCATION', 'New Jersey'], [0, 1, 'LOCATION', 'Englewood']],
 'relations': [{'head': {'mention': 'New Jersey',
                         'position': [7, 9],
                         'type': 'LOCATION'},
                'relation_text': 'contains',
                'tail': {'mention': 'Englewood',
                         'position': [0, 1],
                         'type': 'LOCATION'}}],
 'tokenized_text': ['Englewood',
                    'is',
                    'one',
                    'of',
                    'nine',
                    'hospitals',
                    'in',
                    'New',
                    'Jersey',
                    'where',
                    'the',
                    'nurses',
                    'voted',
                    'last',
                    'month',
                    'to',
                    'authorize',
                    'a',
                    'strike',
                    'if',
                    'the',


Get and view the relation labels in the NYT dataset:

In [14]:
# Collect the relation labels from the loaded NYT data; they are passed to the model as candidate labels below.
nyt_labels = get_nyt_labels(nyt_data)

print("Relation Labels:", nyt_labels)
print("Number of Relation Labels:", len(nyt_labels))

Relation Labels: ['administrative_divisions', 'advisors', 'capital', 'children', 'company', 'contains', 'country', 'ethnicity', 'founders', 'geographic_distribution', 'location', 'major_shareholder_of', 'major_shareholders', 'nationality', 'neighborhood_of', 'people', 'place_founded', 'place_lived', 'place_of_birth', 'place_of_death', 'religion', 'teams']
Number of Relation Labels: 22


Convert the NYT data to the GLiREL input format and verify the result

In [16]:
# Run every NYT record through the converter to build the model inputs.
nyt_input = list(map(create_nyt_input, nyt_data))

# Spot-check the first converted record.
nyt_example = nyt_input[0]
pprint(nyt_example)

{'ner': [[7, 8, 'LOCATION', 'New Jersey'], [0, 0, 'LOCATION', 'Englewood']],
 'tokens': ['Englewood',
            'is',
            'one',
            'of',
            'nine',
            'hospitals',
            'in',
            'New',
            'Jersey',
            'where',
            'the',
            'nurses',
            'voted',
            'last',
            'month',
            'to',
            'authorize',
            'a',
            'strike',
            'if',
            'the',
            'contract',
            'dispute',
            'was',
            'not',
            'settled',
            'by',
            'June',
            '1',
            ',',
            'but',
            'the',
            'other',
            'hospitals',
            'either',
            'reached',
            'agreements',
            'or',
            'are',
            'still',
            'negotiating',
            '.']}


### Example Testing

In [17]:
# Pull the token list and entity spans out of the converted example for a single-sentence test.
nyt_tokens = nyt_example["tokens"]
nyt_ner = nyt_example["ner"]

In [18]:
# Predict relations for the single example. threshold=0.0 together with
# top_k=len(nyt_labels) asks for scores without filtering at this stage.
nyt_relations = model.predict_relations(
    nyt_tokens,
    nyt_labels,
    threshold=0.0,
    ner=nyt_ner,
    top_k=len(nyt_labels),
)

pprint(nyt_relations)

[{'head_pos': [0, 1],
  'head_text': ['Englewood'],
  'label': 'location',
  'score': 0.5665032267570496,
  'tail_pos': [7, 9],
  'tail_text': ['New', 'Jersey']},
 {'head_pos': [7, 9],
  'head_text': ['New', 'Jersey'],
  'label': 'location',
  'score': 0.43044036626815796,
  'tail_pos': [0, 1],
  'tail_text': ['Englewood']},
 {'head_pos': [0, 1],
  'head_text': ['Englewood'],
  'label': 'contains',
  'score': 0.10872403532266617,
  'tail_pos': [7, 9],
  'tail_text': ['New', 'Jersey']},
 {'head_pos': [7, 9],
  'head_text': ['New', 'Jersey'],
  'label': 'contains',
  'score': 0.0758998841047287,
  'tail_pos': [0, 1],
  'tail_text': ['Englewood']},
 {'head_pos': [7, 9],
  'head_text': ['New', 'Jersey'],
  'label': 'place_of_birth',
  'score': 0.04050888493657112,
  'tail_pos': [0, 1],
  'tail_text': ['Englewood']},
 {'head_pos': [7, 9],
  'head_text': ['New', 'Jersey'],
  'label': 'neighborhood_of',
  'score': 0.03361096605658531,
  'tail_pos': [0, 1],
  'tail_text': ['Englewood']},
 {'he

### Inference

Run inference and collect results

In [29]:
# Run the model over the whole NYT test split; filtering by score is deferred to evaluation.
nyt_predictions = run_inference(
    model,
    nyt_input,
    nyt_labels,
    threshold=0.0,
    top_k=len(nyt_labels),
)

Save results

In [33]:
# Persist the raw NYT predictions as JSONL.
# NOTE(review): the recorded run of this cell failed with "OSError: [Errno 28] No space left on device",
# so the file on disk may be missing or incomplete — free disk space and re-run before relying on it.
save_jsonl(nyt_predictions, "data_predictions/nyt_predictions.jsonl")

OSError: [Errno 28] No space left on device

### Evaluation

Check results (can vary threshold)

In [34]:
# Evaluate the in-memory NYT predictions with a score threshold of 0.5.
results = evaluate_nyt(nyt_data, nyt_predictions, threshold=0.5)

# Float metrics are shown to four decimals; any other value is printed unchanged.
for key, value in results.items():
    if isinstance(value, float):
        print(f"{key}: {value:.4f}")
    else:
        print(f"{key}: {value}")

The history saving thread hit an unexpected error (OperationalError('database or disk is full')).History will not be written to the database.
precision: 0.0951
recall: 0.0244
f1: 0.0389
true_positives: 217
false_positives: 2065
false_negatives: 8672


## CoNLL 2004

In [None]:
from datasets import load_dataset

In [57]:
import data_processing.conll04
importlib.reload(data_processing.conll04)

from data_processing.conll04 import get_conll04_labels, create_conll04_input, evaluate_conll04, fuzzy_evaluate_conll04

Load dataset and check data format

In [39]:
# Load the test split of the CoNLL04 dataset from the Hugging Face hub,
# then show the dataset summary and the first record to inspect its schema.
conll04_data = load_dataset("DFKI-SLT/conll04", split="test")

print(conll04_data)
pprint(conll04_data[0])

Dataset({
    features: ['entities', 'tokens', 'relations', 'orig_id'],
    num_rows: 288
})
{'entities': [{'end': 7, 'start': 5, 'type': 'Org'},
              {'end': 9, 'start': 8, 'type': 'Other'},
              {'end': 11, 'start': 10, 'type': 'Loc'},
              {'end': 18, 'start': 17, 'type': 'Other'}],
 'orig_id': 17,
 'relations': [{'head': 0, 'tail': 2, 'type': 'OrgBased_In'}],
 'tokens': ['An',
            'art',
            'exhibit',
            'at',
            'the',
            'Hakawati',
            'Theatre',
            'in',
            'Arab',
            'east',
            'Jerusalem',
            'was',
            'a',
            'series',
            'of',
            'portraits',
            'of',
            'Palestinians',
            'killed',
            'in',
            'the',
            'rebellion',
            '.']}


### Data Pre-processing

Get conll04 relation labels

In [48]:
# Collect the relation labels from the CoNLL04 test data; they are passed to the model as candidate labels below.
conll04_labels = get_conll04_labels(conll04_data)
print("Relation Labels:", conll04_labels)
print("Number of Relation Labels:", len(conll04_labels))

Relation Labels: ['Kill', 'Live_In', 'Located_In', 'OrgBased_In', 'Work_For']
Number of Relation Labels: 5


Convert the CoNLL04 data to the GLiREL input format and verify the result

In [47]:
# Run every CoNLL04 record through the converter to build the model inputs.
conll04_input = list(map(create_conll04_input, conll04_data))

# Spot-check the first converted record.
conll04_example = conll04_input[0]
pprint(conll04_example)

{'ner': [[5, 6, 'Org', 'Hakawati Theatre'],
         [8, 8, 'Other', 'Arab'],
         [10, 10, 'Loc', 'Jerusalem'],
         [17, 17, 'Other', 'Palestinians']],
 'tokens': ['An',
            'art',
            'exhibit',
            'at',
            'the',
            'Hakawati',
            'Theatre',
            'in',
            'Arab',
            'east',
            'Jerusalem',
            'was',
            'a',
            'series',
            'of',
            'portraits',
            'of',
            'Palestinians',
            'killed',
            'in',
            'the',
            'rebellion',
            '.']}


### Example Test

In [49]:
# Pull the token list and entity spans out of the converted example for a single-sentence test.
conll04_tokens = conll04_example["tokens"]
conll04_ner = conll04_example["ner"]

In [51]:
# Predict relations for the single CoNLL04 example without score filtering.
conll04_prediction = model.predict_relations(
    conll04_tokens,
    labels=conll04_labels,
    threshold=0.0,
    ner=conll04_ner,
    top_k=len(conll04_labels),
)

print('Number of relations:', len(conll04_prediction))

# Rank a copy of the predictions from highest to lowest score, then list them.
sorted_conll_preds = list(conll04_prediction)
sorted_conll_preds.sort(key=lambda x: x['score'], reverse=True)
print("\nDescending Order by Score:")
for item in sorted_conll_preds:
    print(f"{item['head_text']} --> {item['label']} --> {item['tail_text']} | score: {item['score']}")

Number of relations: 60

Descending Order by Score:
['Hakawati', 'Theatre'] --> Located_In --> ['Jerusalem'] | score: 0.5305655598640442
['Jerusalem'] --> Located_In --> ['Hakawati', 'Theatre'] | score: 0.39677685499191284
['Arab'] --> Located_In --> ['Jerusalem'] | score: 0.3937556743621826
['Jerusalem'] --> Located_In --> ['Arab'] | score: 0.3288310468196869
['Hakawati', 'Theatre'] --> Located_In --> ['Arab'] | score: 0.3207818269729614
['Palestinians'] --> Live_In --> ['Jerusalem'] | score: 0.31211718916893005
['Arab'] --> Located_In --> ['Hakawati', 'Theatre'] | score: 0.2222864329814911
['Palestinians'] --> Located_In --> ['Jerusalem'] | score: 0.21948803961277008
['Hakawati', 'Theatre'] --> Kill --> ['Palestinians'] | score: 0.21816208958625793
['Jerusalem'] --> Kill --> ['Palestinians'] | score: 0.18330393731594086
['Arab'] --> Live_In --> ['Jerusalem'] | score: 0.16216304898262024
['Hakawati', 'Theatre'] --> Live_In --> ['Jerusalem'] | score: 0.12686099112033844
['Jerusalem'] -

### Inference

Run inference on the entire dataset and check output format

In [52]:
# Run the model over the whole CoNLL04 test split (threshold=0.0, so score filtering is left to evaluation),
# then show the predictions for the first sentence to check the output format.
conll04_predictions = run_inference(model, conll04_input, conll04_labels, threshold=0.0, top_k=len(conll04_labels))

pprint(conll04_predictions[0])

Running inference on device: cuda
[{'head_pos': [5, 7],
  'head_text': ['Hakawati', 'Theatre'],
  'label': 'Located_In',
  'score': 0.5305655598640442,
  'tail_pos': [10, 11],
  'tail_text': ['Jerusalem']},
 {'head_pos': [10, 11],
  'head_text': ['Jerusalem'],
  'label': 'Located_In',
  'score': 0.39677685499191284,
  'tail_pos': [5, 7],
  'tail_text': ['Hakawati', 'Theatre']},
 {'head_pos': [8, 9],
  'head_text': ['Arab'],
  'label': 'Located_In',
  'score': 0.3937556743621826,
  'tail_pos': [10, 11],
  'tail_text': ['Jerusalem']},
 {'head_pos': [10, 11],
  'head_text': ['Jerusalem'],
  'label': 'Located_In',
  'score': 0.3288310468196869,
  'tail_pos': [8, 9],
  'tail_text': ['Arab']},
 {'head_pos': [5, 7],
  'head_text': ['Hakawati', 'Theatre'],
  'label': 'Located_In',
  'score': 0.3207818269729614,
  'tail_pos': [8, 9],
  'tail_text': ['Arab']},
 {'head_pos': [17, 18],
  'head_text': ['Palestinians'],
  'label': 'Live_In',
  'score': 0.31211718916893005,
  'tail_pos': [10, 11],
  

### Evaluation

In [59]:
# Evaluate the CoNLL04 predictions with a score threshold of 0.5.
results_conll04 = evaluate_conll04(conll04_data, conll04_predictions, threshold=0.5)

print("Conll04 Evaluation Results:")
# Float metrics are shown to four decimals; any other value is printed unchanged.
for key, value in results_conll04.items():
    if isinstance(value, float):
        print(f"{key}: {value:.4f}")
    else:
        print(f"{key}: {value}")

Conll04 Evaluation Results:
precision: 0.1286
recall: 0.0213
f1: 0.0366
TP: 9
FP: 61
FN: 413


### Fuzzy Evaluation

In [60]:
# Fuzzy evaluation of the CoNLL04 predictions.
#
# This cell previously called `evaluate_conll_fuzzy(conll_input, conll_predictions, ...)`.
# None of those names are defined anywhere in this notebook, so it raised NameError on a
# fresh kernel. It now uses the names the notebook actually defines: `fuzzy_evaluate_conll04`
# (imported from data_processing.conll04), `conll04_input` and `conll04_predictions`.
#
# NOTE(review): the strict evaluation is given `conll04_data` as its first argument; confirm
# that `fuzzy_evaluate_conll04` expects the converted `conll04_input` rather than `conll04_data`.
# NOTE(review): threshold here is 0.3 while the strict evaluation uses 0.5; the recorded output
# below has TP + FP = 70, the same total as the strict run — confirm which threshold is intended.
conll_results_fuzzy = fuzzy_evaluate_conll04(conll04_input, conll04_predictions, threshold=0.3)

print("Conll04 Fuzzy Evaluation Results:")
# Float metrics are shown to four decimals; any other value is printed unchanged.
for key, value in conll_results_fuzzy.items():
    print(f"{key}: {value:.4f}" if isinstance(value, float) else f"{key}: {value}")

Conll04 Fuzzy Evaluation Results:
precision: 0.4714
recall: 0.0782
f1: 0.1341
TP: 33
FP: 37
FN: 389
