# GoLLIE Experiment Runner for Colab

This notebook sets up the environment for the GoLLIE project and runs its experiments. It covers:
1. Environment configuration (Git, dependencies)
2. Data download and preparation
3. Running evaluations across multiple guideline modules

## 1. Project Setup
Detect the environment, clone the experiments repository (on Colab only), and install the dependencies.

In [None]:
import os
import sys
import subprocess

# 1. Detect if we are in Colab and clone YOUR experiments repo if needed
is_colab = 'google.colab' in sys.modules

if is_colab:
    print("Detected Google Colab environment.")
    # Clone your main experiment repo to get guidelines and requirements
    REPO_URL = "https://github.com/Marc8350/KDAI-Experiments.git"
    REPO_NAME = "KDAI-Experiments"
    if not os.path.exists(REPO_NAME):
        print(f"Cloning {REPO_NAME}...")
        !git clone {REPO_URL}
        %cd {REPO_NAME}
    else:
        %cd {REPO_NAME}
else:
    print("Running in local environment.")

# 3. Install requirements
if os.path.exists("requirements.txt"):
    print("Installing dependencies (this may take a few minutes)...")
    %pip install -r requirements.txt
    %pip install bitsandbytes accelerate # Required for quantization
else:
    print("requirements.txt not found!")

## 2. Data Download

In [None]:
def download_few_nerd():
    """Download the supervised Few-NERD splits and save each one to disk.

    Every split returned by ``load_dataset`` is written to ``few-nerd_<split>`` in the
    current directory. Nothing is downloaded when ``few-nerd_test`` already exists as a
    directory.
    """
    marker_dir = "few-nerd_test"
    already_done = os.path.isdir(marker_dir)
    if already_done:
        print("Dataset already prepared.")
        return

    print("Downloading Few-NERD dataset...")
    # Imported only when a download is actually needed.
    from datasets import load_dataset

    splits = load_dataset("DFKI-SLT/few-nerd", name='supervised')
    for split_key in splits:
        print(f"Saving {split_key} split...")
        splits[split_key].save_to_disk(f"few-nerd_{split_key}")
    print("Dataset setup completed.")

# Skips the download if few-nerd_test/ already exists; otherwise fetches and saves every split.
download_few_nerd()

## 3. GoLLIE Experiments Configuration

In [None]:
import torch

# --- CONFIGURATION ---
TEST_MODE = True                # Set to False for a full experiment run
TEST_MODULE_IDX = 0             # Select which guideline module to use (0-13)
TEST_SENTENCE_IDX = 0           # Select which specific sentence to process

# Colab Resource Management
USE_4BIT = True                 # HIGHLY RECOMMENDED for free Colab T4 GPUs
USE_FLASH_ATTN = False          # Use only on A100/L4 GPUs. Set False for T4.
# Native bfloat16 needs CUDA compute capability >= 8 (A100, L4; a T4 is 7.5).
# torch.cuda.is_bf16_supported() is avoided on purpose: recent torch versions can report
# *emulated* bf16 as supported on a T4 (slow), and older ones may raise when no GPU is
# present. Checking the capability directly works across versions.
_HAS_NATIVE_BF16 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
DTYPE = "bfloat16" if _HAS_NATIVE_BF16 else "float16"
# ---------------------

## 4. Experiment Logic

In [None]:
import os
import sys
import json
import re
import inspect
import logging
import black
from datetime import datetime
from typing import Dict, List, Type, Any
from datasets import load_from_disk
from jinja2 import Template

PROJECT_ROOT = os.getcwd()
GOLLIE_PATH = os.path.join(PROJECT_ROOT, "GoLLIE")

# The project root is put at the front of sys.path so its packages take precedence;
# the GoLLIE checkout is appended at the end as a fallback import location.
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
if GOLLIE_PATH not in sys.path:
    sys.path.append(GOLLIE_PATH)

from src.model.load_model import load_model
from src.tasks.utils_typing import Entity, AnnotationList
from src.tasks.utils_scorer import SpanScorer

from annotation_guidelines import (
    guidelines_coarse_gollie, guidelines_coarse_gollie_detailed_v1, guidelines_coarse_gollie_detailed_v2,
    guidelines_coarse_gollie_detailed_v3, guidelines_coarse_gollie_v1, guidelines_coarse_gollie_v2,
    guidelines_coarse_gollie_v3, guidelines_fine_gollie, guidelines_fine_gollie_detailed_v1,
    guidelines_fine_gollie_detailed_v2, guidelines_fine_gollie_detailed_v3, guidelines_fine_gollie_v1,
    guidelines_fine_gollie_v2, guidelines_fine_gollie_v3
)

logging.basicConfig(level=logging.INFO)

# Order matters: run_experiment picks guideline_modules[test_m_idx] in test mode,
# so the index comments below are the values to use for TEST_MODULE_IDX.
guideline_modules = [
    # Coarse-grained label set
    guidelines_coarse_gollie,              # 0
    guidelines_coarse_gollie_detailed_v1,  # 1
    guidelines_coarse_gollie_detailed_v2,  # 2
    guidelines_coarse_gollie_detailed_v3,  # 3
    guidelines_coarse_gollie_v1,           # 4
    guidelines_coarse_gollie_v2,           # 5
    guidelines_coarse_gollie_v3,           # 6
    # Fine-grained label set
    guidelines_fine_gollie,                # 7
    guidelines_fine_gollie_detailed_v1,    # 8
    guidelines_fine_gollie_detailed_v2,    # 9
    guidelines_fine_gollie_detailed_v3,    # 10
    guidelines_fine_gollie_v1,             # 11
    guidelines_fine_gollie_v2,             # 12
    guidelines_fine_gollie_v3,             # 13
]

# Keyword arguments forwarded verbatim to GoLLIE's load_model (and logged in every results file).
MODEL_LOAD_PARAMS = dict(
    inference=True,
    model_weights_name_or_path="HiTZ/GoLLIE-7B",
    quantization=4 if USE_4BIT else None,  # 4 when USE_4BIT is set, otherwise None
    use_lora=False,
    force_auto_device_map=True,
    use_flash_attention=USE_FLASH_ATTN,
    torch_dtype=DTYPE,
)

# Keyword arguments forwarded verbatim to model.generate (and logged in every results file).
GENERATE_PARAMS = dict(
    max_new_tokens=128,
    do_sample=False,
    min_new_tokens=0,
    num_beams=1,
    num_return_sequences=1,
)

class MyEntityScorer(SpanScorer):
    """SpanScorer that reports its span results under an ``"entities"`` key.

    The class-level ``valid_types`` default is an empty list; in this notebook each
    instance gets its own list assigned (the guideline module's ENTITY_DEFINITIONS).
    """

    valid_types: List[Type] = []

    def __call__(self, reference: List[List[Entity]], predictions: List[List[Entity]]) -> Dict[str, Any]:
        """Score ``predictions`` against ``reference`` and rename the ``spans`` result."""
        span_scores = super().__call__(reference, predictions)["spans"]
        return {"entities": span_scores}

def label_to_classname(label):
    """Turn a dataset label such as ``"art-film"`` into a class name such as ``"ArtFilm"``.

    The label is split on ``-`` and ``/``, and each piece is passed through
    ``str.capitalize`` (first letter upper, rest lower) before the pieces are joined.
    The outside label ``"O"`` maps to ``None``.
    """
    if label == "O":
        return None
    pieces = re.split(r'[-/]', label)
    return "".join(map(str.capitalize, pieces))

def _gold_entities(tokens, tag_ids, label_names, module, missing_classes):
    """Build gold entities for one sentence, one entity per contiguous run of equal tags.

    The label names handled here carry no B-/I- prefix (IO scheme), so consecutive tokens
    sharing a tag form ONE entity. Its span is those tokens joined by single spaces, which
    matches how the sentence text is built. Two adjacent same-type entities cannot be
    separated under IO tags and are merged.

    Args:
        tokens: Sentence tokens.
        tag_ids: Tag ids aligned with ``tokens``.
        label_names: Sequence mapping a tag id to its label string.
        module: Guideline module expected to define one class per label.
        missing_classes: Set that receives class names the module does not define;
            spans with such labels are skipped.

    Returns:
        List of entity instances in sentence order.
    """
    import itertools

    gold = []
    for tag_id, run in itertools.groupby(zip(tokens, tag_ids), key=lambda pair: pair[1]):
        class_name = label_to_classname(label_names[tag_id])
        if class_name is None:  # "O": not part of any entity
            continue
        entity_class = getattr(module, class_name, None)
        if entity_class is None:
            missing_classes.add(class_name)
            continue
        gold.append(entity_class(span=" ".join(token for token, _ in run)))
    return gold


def _build_prompt(template, guidelines, text, gold):
    """Render the GoLLIE prompt and truncate it right after the first ``result =``.

    black formatting is best effort. If black rejects the rendered text, the unformatted
    text is used instead.
    """
    rendered = template.render(guidelines=guidelines, text=text, annotations=gold, gold=gold)
    try:
        rendered = black.format_str(rendered, mode=black.Mode())
    except Exception as err:  # not a bare except: KeyboardInterrupt must still stop the run
        logging.debug(f"black could not format the prompt ({err}); using it unformatted.")
    return rendered.split("result =")[0] + "result ="


def _generate_prediction(model, tokenizer, prompt, module_name):
    """Generate a completion for ``prompt`` with GENERATE_PARAMS and parse the annotations.

    Returns an empty list when the generated text cannot be parsed, logging a warning
    instead of failing silently.
    """
    inputs = tokenizer(prompt, add_special_tokens=True, return_tensors="pt")
    # Drop the final token the tokenizer appended; presumably the EOS token added by
    # GoLLIE's tokenizer, which would otherwise mark the prompt as finished. TODO confirm.
    inputs = {k: v[:, :-1].to(model.device) for k, v in inputs.items()}

    out = model.generate(**inputs, **GENERATE_PARAMS)
    res_str = tokenizer.decode(out[0], skip_special_tokens=True).split("result =")[-1]
    try:
        return AnnotationList.from_output(res_str, task_module=module_name)
    except Exception as err:
        logging.warning(f"[{module_name}] Could not parse model output ({err!r}); using no prediction.")
        return []


def _module_snapshot(module_name, timestamp, scorer, gold_per_module, predictions_per_module, sentence_results):
    """Assemble the JSON-serialisable results of one module, re-scoring all sentences so far."""
    return {
        "module": module_name,
        "timestamp": timestamp,
        "model_load_params": MODEL_LOAD_PARAMS,
        "generate_params": GENERATE_PARAMS,
        "overall_score": scorer(reference=gold_per_module, predictions=predictions_per_module),
        "processed_count": len(sentence_results),
        "sentences": sentence_results,
    }


def _write_json_atomic(path, payload):
    """Write ``payload`` as indented JSON to ``path`` without ever leaving a partial file.

    The data goes to ``<path>.tmp`` first and is then moved over ``path`` with os.replace.
    An interruption mid-write therefore leaves the previous snapshot intact.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(payload, f, indent=4)
    os.replace(tmp_path, path)


def run_experiment(test_mode=False, test_m_idx=0, test_s_idx=0, save_every=1):
    """Run GoLLIE over the Few-NERD test split for each guideline module and log results.

    For every processed module, a JSON file ``GOLLIE-results/<module>_<timestamp>.json`` is
    written. It holds per-sentence text/gold/prediction/score entries plus the overall score
    of all sentences processed so far.

    Args:
        test_mode: If True, only module ``test_m_idx`` and sentence ``test_s_idx`` are run.
        test_m_idx: Index into ``guideline_modules`` used in test mode.
        test_s_idx: Index of the test-split sentence used in test mode.
        save_every: Write a snapshot after every ``save_every`` sentences, plus a final one
            per module. Each snapshot re-scores everything so far and rewrites the whole file.
            Values > 1 cut that cost on full runs. The default of 1 saves after every sentence.

    Raises:
        ValueError: If ``save_every`` is smaller than 1.
    """
    if save_every < 1:
        raise ValueError(f"save_every must be >= 1, got {save_every}")

    RESULTS_DIR = "GOLLIE-results"
    os.makedirs(RESULTS_DIR, exist_ok=True)

    ds = load_from_disk("./few-nerd_test")
    coarse_names = ds.features["ner_tags"].feature.names
    fine_names = ds.features["fine_ner_tags"].feature.names

    print("Loading model with params:", MODEL_LOAD_PARAMS)
    model, tokenizer = load_model(**MODEL_LOAD_PARAMS)

    template_path = os.path.join(GOLLIE_PATH, "templates", "prompt.txt")
    with open(template_path, "rt") as f:
        template = Template(f.read())

    active_modules = [guideline_modules[test_m_idx]] if test_mode else guideline_modules

    for module in active_modules:
        module_name = module.__name__
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        log_filename = os.path.join(RESULTS_DIR, f"{module_name}_{timestamp}.json")

        logging.info(f"Processing module: {module_name}")
        is_coarse = "coarse" in module_name
        tag_key = "ner_tags" if is_coarse else "fine_ner_tags"
        names_ref = coarse_names if is_coarse else fine_names

        scorer = MyEntityScorer()
        scorer.valid_types = module.ENTITY_DEFINITIONS
        # The guideline sources are identical for every sentence of a module: compute once.
        guidelines = [inspect.getsource(def_obj) for def_obj in module.ENTITY_DEFINITIONS]
        gold_per_module, predictions_per_module, sentence_results = [], [], []
        missing_classes = set()

        if test_mode:
            sentences_to_process = [(test_s_idx, ds[test_s_idx])]
            print(f"TEST MODE: Running index {test_s_idx}")
        else:
            sentences_to_process = enumerate(ds)

        for i, sentence in sentences_to_process:
            tokens = sentence["tokens"]
            text = " ".join(tokens)
            gold = _gold_entities(tokens, sentence[tag_key], names_ref, module, missing_classes)
            prompt = _build_prompt(template, guidelines, text, gold)
            prediction = _generate_prediction(model, tokenizer, prompt, module_name)

            score = scorer(reference=[gold], predictions=[prediction])
            gold_per_module.append(gold)
            predictions_per_module.append(prediction)

            sentence_results.append({
                "index": i, "timestamp": datetime.now().isoformat(), "text": text,
                "gold": [str(g) for g in gold], "prediction": [str(p) for p in prediction], "score": score
            })

            processed = len(sentence_results)
            # Intermediate saving (avoid data loss on long runs).
            if processed % save_every == 0:
                _write_json_atomic(log_filename, _module_snapshot(
                    module_name, timestamp, scorer, gold_per_module, predictions_per_module, sentence_results
                ))
            if processed % 10 == 0:
                logging.info(f"[{module_name}] Progress: {processed} sentences processed.")

        # Flush the tail that did not land on a save_every boundary.
        if sentence_results and len(sentence_results) % save_every != 0:
            _write_json_atomic(log_filename, _module_snapshot(
                module_name, timestamp, scorer, gold_per_module, predictions_per_module, sentence_results
            ))

        if missing_classes:
            logging.warning(
                f"[{module_name}] Gold labels without a matching guideline class were skipped: "
                f"{sorted(missing_classes)}"
            )

    logging.info("Execution finished.")

## 5. Run Experiment

In [None]:
# Runs with the section-3 settings; TEST_MODE=True limits it to one module and one sentence.
run_experiment(test_mode=TEST_MODE, test_m_idx=TEST_MODULE_IDX, test_s_idx=TEST_SENTENCE_IDX)