In [1]:
import pandas as pd

from data_utils import load_data, GLiNERDataset

from pathlib import Path
from config import TrainingConfig
from gliner import GLiNER
from gliner.data_processing.collator import DataCollator

In [None]:
config = TrainingConfig()

# Ensure the output directory exists; parents=True / exist_ok=True make this safe to re-run.
output_dir = Path(config.output_dir)
output_dir.mkdir(parents=True, exist_ok=True)

# Load the training split (required) and the evaluation split (optional).
train_data = load_data(config.train_file)

eval_path = Path(config.eval_file)
if eval_path.exists():
    eval_data = load_data(config.eval_file)
else:
    # Make the fallback visible instead of silently continuing without an eval set.
    print(f"Eval file not found at {eval_path}; continuing without evaluation data.")
    eval_data = None

data/training_data/preprocessed_data/training_data_dev.json
Loading pre-processed data
data/training_data/preprocessed_data/training_data_test.json
Loading pre-processed data


In [11]:
# Load the GLiNER checkpoint named in the training config.
model = GLiNER.from_pretrained(config.model_name)

# Collator built from the model's own config and data processor.
# prepare_labels=True presumably makes it emit training labels for each batch —
# confirm against gliner's DataCollator.
data_collator = DataCollator(
    model.config,
    data_processor=model.data_processor,
    prepare_labels=True,
)

# Wrap both splits with the model's tokenizer. eval_data is None when the eval
# file was missing in the loading cell, so eval_dataset mirrors that.
tokenizer = model.data_processor.transformer_tokenizer
train_dataset = GLiNERDataset(train_data, tokenizer)
eval_dataset = GLiNERDataset(eval_data, tokenizer) if eval_data is not None else None

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]



In [13]:
# Peek at the first converted training example.
# NOTE(review): the displayed output has 'tokenized_text' as a plain string and
# 'ner' spans that are exclusive-end character offsets (e.g. [4, 17] -> "sloane street").
# GLiNER's data processor is generally expected to take a list of word tokens with
# token-index spans — verify GLiNERDataset's conversion before training.
train_dataset[0]

{'tokenized_text': '207 sloane street, london (sw1x 9qx)',
 'ner': [[4, 17, 'STREET_NAME'],
  [19, 25, 'CITY'],
  [0, 3, 'STREET_NUMBER'],
  [27, 35, 'POSTCODE']]}

In [5]:
def analyze_dataset_entities(dataset, report_sizes=(7, 8), config_entities=None):
    """Summarise the entity labels found in raw (pre-GLiNERDataset) records.

    Each record is read as a mapping with an optional ``'spans'`` list whose
    items carry a ``'label'`` string; labels are upper-cased before counting.

    Prints:
      * one line per sample whose number of *distinct* labels is in ``report_sizes``;
      * the number and sorted list of unique labels in the dataset;
      * the sorted configured labels and a per-label occurrence count;
      * labels present in the dataset but not the config, and vice versa
        (config labels are upper-cased for this comparison).

    Args:
        dataset: Iterable of record dicts (e.g. the output of ``load_data``).
        report_sizes: Distinct-label counts that trigger a per-sample line.
            Defaults to ``(7, 8)``, matching the previous hard-coded checks.
        config_entities: Labels to compare against. Defaults to the
            notebook-global ``config.entity_types``.

    Returns:
        None; all output is printed.
    """
    from collections import Counter

    if config_entities is None:
        # Backward-compatible fallback to the notebook-level TrainingConfig.
        config_entities = config.entity_types

    report_sizes = set(report_sizes)
    all_entities = set()
    entity_counts = Counter()

    for i, item in enumerate(dataset):
        labels = [span['label'].upper() for span in item.get('spans', [])]
        entity_counts.update(labels)
        sample_entities = set(labels)
        all_entities |= sample_entities

        if len(sample_entities) in report_sizes:
            print(f"Sample {i} has {len(sample_entities)} entities: {sample_entities}")

    # Normalise config labels the same way as dataset labels so the diff is meaningful.
    config_set = {str(entity).upper() for entity in config_entities}

    print(f"\nTotal unique entities in dataset: {len(all_entities)}")
    print(f"All entities: {sorted(all_entities)}")
    print(f"Config entities: {sorted(config_entities)}")
    print(f"Entity counts: {dict(entity_counts)}")
    print(f"In dataset but not in config: {sorted(all_entities - config_set)}")
    print(f"In config but not in dataset: {sorted(config_set - all_entities)}")

# Inspect label coverage on the raw training records, before they are wrapped in GLiNERDataset.
analyze_dataset_entities(dataset=train_data)

Sample 2 has 7 entities: {'CITY', 'BUILDING_NAME', 'UNIT_ID', 'STREET_NUMBER', 'STREET_NAME', 'POSTCODE', 'UNIT_TYPE'}
Sample 118 has 7 entities: {'CITY', 'BUILDING_NAME', 'UNIT_ID', 'STREET_NUMBER', 'STREET_NAME', 'POSTCODE', 'UNIT_TYPE'}
Sample 133 has 7 entities: {'CITY', 'BUILDING_NAME', 'UNIT_ID', 'STREET_NUMBER', 'STREET_NAME', 'POSTCODE', 'UNIT_TYPE'}
Sample 158 has 7 entities: {'CITY', 'BUILDING_NAME', 'UNIT_ID', 'STREET_NUMBER', 'STREET_NAME', 'POSTCODE', 'UNIT_TYPE'}
Sample 199 has 7 entities: {'CITY', 'BUILDING_NAME', 'UNIT_ID', 'STREET_NUMBER', 'STREET_NAME', 'POSTCODE', 'UNIT_TYPE'}
Sample 212 has 7 entities: {'CITY', 'BUILDING_NAME', 'UNIT_ID', 'STREET_NUMBER', 'STREET_NAME', 'POSTCODE', 'UNIT_TYPE'}
Sample 238 has 7 entities: {'CITY', 'BUILDING_NAME', 'UNIT_ID', 'STREET_NUMBER', 'STREET_NAME', 'POSTCODE', 'UNIT_TYPE'}
Sample 241 has 7 entities: {'CITY', 'BUILDING_NAME', 'UNIT_ID', 'STREET_NUMBER', 'STREET_NAME', 'POSTCODE', 'UNIT_TYPE'}
Sample 250 has 7 entities: {'CITY'

In [9]:
# Load the public base checkpoint to explore what the GLiNER object exposes.
# NOTE(review): this rebinds `model`, replacing the model loaded from
# config.model_name in the cell above (and the one data_collator was built from);
# consider a distinct name such as `base_model` if both are needed.
model = GLiNER.from_pretrained("urchade/gliner_base")

# Dump every attribute of the model object, then its config and the config's attributes.
print(dir(model))
print(f"Model config: {model.config}")
print(f"Model config attributes: {dir(model.config)}")

# Surface config fields whose names mention entities or labels.
for field_name in dir(model.config):
    lowered = field_name.lower()
    if any(keyword in lowered for keyword in ('entity', 'label')):
        print(f"Config attribute: {field_name} = {getattr(model.config, field_name)}")

Fetching 4 files:   0%|          | 0/4 [00:00<?, ?it/s]

README.md:   0%|          | 0.00/4.78k [00:00<?, ?B/s]

.gitattributes:   0%|          | 0.00/1.57k [00:00<?, ?B/s]

gliner_config.json:   0%|          | 0.00/732 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/792M [00:00<?, ?B/s]



['T_destination', '__annotations__', '__call__', '__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getstate__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__setstate__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__', '_apply', '_backward_hooks', '_backward_pre_hooks', '_buffers', '_call_impl', '_compiled_call_impl', '_decode_arg', '_encode_arg', '_forward_hooks', '_forward_hooks_always_called', '_forward_hooks_with_kwargs', '_forward_pre_hooks', '_forward_pre_hooks_with_kwargs', '_from_pretrained', '_get_backward_hooks', '_get_backward_pre_hooks', '_get_name', '_hub_mixin_coders', '_hub_mixin_config', '_hub_mixin_config', '_hub_mixin_info', '_hub_mixin_init_parameters', '_hub_mixin_inject_config', '_hub_mixin_jsonable_custom_types', '_hub_mixin_jsonable_default_value