# GoLLIE Experiment Runner (Single-Test Version)

This notebook is configured for a **quick test run** (1 module, 1 sentence) to verify the Colab environment.

## 1. Project Setup

In [None]:
import os
import sys
import logging

# Clone the GoLLIE sources once. The leading `!` is IPython shell-escape syntax,
# so this line only works inside a notebook/IPython session.
if not os.path.exists("GoLLIE"): !git clone https://github.com/hitz-zentroa/GoLLIE.git
# Install dependencies via the `%pip` magic when a requirements.txt is present.
# NOTE(review): this checks the current working directory, not GoLLIE/requirements.txt —
# confirm that is the intended file.
if os.path.exists("requirements.txt"): %pip install -r requirements.txt

In [None]:
def download_few_nerd():
    """Cache the Few-NERD supervised test split in ``./few-nerd_test``.

    Does nothing when the ``few-nerd_test`` directory already exists;
    otherwise loads the split with ``datasets.load_dataset`` and saves it
    to that directory.
    """
    target_dir = "few-nerd_test"
    if not os.path.isdir(target_dir):
        # Imported lazily so the already-cached path never needs `datasets`.
        from datasets import load_dataset

        test_split = load_dataset("DFKI-SLT/few-nerd", name='supervised', split='test')
        test_split.save_to_disk(target_dir)
# Populate ./few-nerd_test when the cell runs so the test cell can load it from disk.
download_few_nerd()

In [None]:
import os, sys, json, re, inspect, logging, black
from datetime import datetime
from typing import Dict, List, Type, Any
from datasets import load_from_disk
from jinja2 import Template

# Make both the notebook's working directory and the cloned GoLLIE checkout
# importable for the project imports below.
PROJECT_ROOT = os.getcwd()
GOLLIE_PATH = os.path.join(PROJECT_ROOT, "GoLLIE")

if PROJECT_ROOT not in sys.path:
    # Front of the search path: modules next to the notebook win.
    sys.path.insert(0, PROJECT_ROOT)

if GOLLIE_PATH not in sys.path:
    # End of the search path: lets `src.*` resolve inside the GoLLIE clone.
    sys.path.append(GOLLIE_PATH)

from src.model.load_model import load_model
from src.tasks.utils_typing import Entity, AnnotationList
from src.tasks.utils_scorer import SpanScorer
from annotation_guidelines import guidelines_coarse_gollie

# Configure the root logger to emit INFO-level records and above.
logging.basicConfig(level=logging.INFO)

# Keyword arguments passed verbatim to load_model() in run_quick_test():
# full-precision-free inference of the HiTZ/GoLLIE-7B weights, no quantization, no LoRA.
# NOTE(review): flash attention and bfloat16 presumably require a recent GPU generation —
# confirm the Colab runtime in use supports both before relying on these defaults.
MODEL_LOAD_PARAMS = {
    "inference": True, "model_weights_name_or_path": "HiTZ/GoLLIE-7B",
    "quantization": None, "use_lora": False, "force_auto_device_map": True,
    "use_flash_attention": True, "torch_dtype": "bfloat16"
}

def label_to_classname(label):
    """Convert a dataset tag name into a CamelCase class name.

    The label is split on ``-`` and ``/``; each piece is capitalized
    (first character upper-cased, the rest lower-cased) and the pieces are
    concatenated. The outside tag ``"O"`` has no class and yields ``None``.
    """
    if label != "O":
        pieces = re.split(r'[-/]', label)
        return "".join(map(str.capitalize, pieces))
    return None

def _merge_tagged_spans(tokens, tag_ids, tag_names):
    """Collapse runs of identically tagged tokens into ``(label, span_text)`` pairs.

    Consecutive tokens that share a tag id are joined with single spaces, in
    the same way the sentence text itself is built, so every span text is a
    substring of the space-joined sentence.
    """
    from itertools import groupby

    merged = []
    for tag_id, run in groupby(zip(tokens, tag_ids), key=lambda pair: pair[1]):
        merged.append((tag_names[tag_id], " ".join(token for token, _ in run)))
    return merged


def run_quick_test():
    """Run one Few-NERD sentence through the model and print gold vs. prediction.

    Steps: load the cached test split and the model, render the GoLLIE prompt
    template with the entity definitions of ``guidelines_coarse_gollie``, cut
    the prompt right after ``result =`` and let the model complete it. The
    sentence text, the gold annotations and the raw completion are printed;
    nothing is returned.
    """
    os.makedirs("GOLLIE-results", exist_ok=True)
    ds = load_from_disk("./few-nerd_test")
    model, tokenizer = load_model(**MODEL_LOAD_PARAMS)

    with open(os.path.join(GOLLIE_PATH, "templates/prompt.txt"), "rt", encoding="utf-8") as f:
        template = Template(f.read())

    module = guidelines_coarse_gollie
    sentence = ds[0]
    text = " ".join(sentence["tokens"])

    # Tags are looked up per token and carry no begin/inside marker here, so a
    # multi-token mention shows up as a run of equal tags. Merge each run into
    # one span instead of emitting one entity per token.
    names = ds.features["ner_tags"].feature.names
    gold = []
    for label, span_text in _merge_tagged_spans(sentence["tokens"], sentence["ner_tags"], names):
        class_name = label_to_classname(label)
        if not class_name:
            continue  # "O": token run outside any entity
        cls = getattr(module, class_name, None)
        if cls is None:
            # Keep going, but make the incomplete gold list visible.
            logging.warning(
                "No class %s in the guidelines module; dropping gold span %r",
                class_name, span_text,
            )
            continue
        gold.append(cls(span=span_text))

    prompt = template.render(
        guidelines=[inspect.getsource(d) for d in module.ENTITY_DEFINITIONS],
        text=text, annotations=gold, gold=gold
    )
    # Formatting is best effort: fall back to the unformatted prompt, but do
    # not swallow KeyboardInterrupt/SystemExit and do not hide the reason.
    try:
        prompt = black.format_str(prompt, mode=black.Mode())
    except Exception as exc:
        logging.warning("black could not format the prompt; using it unformatted: %s", exc)

    # Keep everything up to the first "result =" so the model has to produce
    # the annotation list itself.
    prompt = prompt.split("result =")[0] + "result ="
    # Drop the last position of every tokenizer output tensor before generating
    # (presumably the end-of-sequence token the tokenizer appends — confirm).
    inputs = {k: v[:, :-1].to(model.device) for k, v in tokenizer(prompt, return_tensors="pt").items()}
    out = model.generate(**inputs, max_new_tokens=128, do_sample=False)
    res = tokenizer.decode(out[0], skip_special_tokens=True).split("result =")[-1]

    print(f"--- TEST RESULT ---\nText: {text}\nGold: {gold}\nPred: {res}")
# Execute the single-sentence smoke test as soon as this cell runs.
run_quick_test()