# Run GoLLIE-7B on Google Colab (Free Tier)

This notebook allows you to run the **HiTZ/GoLLIE-7B** model on Google Colab's free tier using 4-bit quantization.
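As a rough sanity check on why 4-bit quantization matters here, the sketch below estimates the size of the weights alone at different precisions. It assumes a parameter count of exactly 7 billion (taking "7B" literally) and ignores quantization constants, activations, and the KV cache, so real usage will be higher. `estimate_weight_gib` is a hypothetical helper, not part of any library.

```python
def estimate_weight_gib(n_params, bits_per_param):
    """Rough size of the model weights alone, in GiB."""
    return n_params * bits_per_param / 8 / 2**30

N_PARAMS = 7e9  # assumption: "7B" taken literally, not the exact parameter count

for label, bits in [("float16", 16), ("4-bit", 4)]:
    print(f"{label}: ~{estimate_weight_gib(N_PARAMS, bits):.1f} GiB")
```

Under these assumptions the weights take about 13 GiB in float16 and about 3.3 GiB at 4 bits, which leaves headroom on a 16 GB card for activations and generation.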

### Note on Flash Attention
GoLLIE defaults to Flash Attention. However, **Flash Attention 2** requires an Ampere or newer GPU (for example, an A100), while Colab's free tier usually provides a T4 GPU (Turing architecture).

On a T4, we cannot use `use_flash_attention_2=True`. Instead, this notebook uses the default attention implementation (or `sdpa`, Scaled Dot Product Attention), which works on a T4 but is somewhat less memory-efficient than Flash Attention 2. With 4-bit quantization, the model still fits within the T4's 16 GB of VRAM.
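If you are unsure which GPU Colab assigned, a check along these lines can help. It is a minimal sketch: `pick_attn_implementation` is a hypothetical helper, and the rule "compute capability 8.x or higher means Flash Attention 2 is usable" is a simplification that also assumes the `flash-attn` package is installed. The returned strings match the values accepted by the `attn_implementation` argument of `from_pretrained` in recent Transformers releases.

```python
def pick_attn_implementation(capability):
    """Map a CUDA compute capability (major, minor) to an attention backend name."""
    major, _minor = capability
    # Flash Attention 2 targets Ampere (8.x) and newer; the T4 is Turing (7.5).
    return "flash_attention_2" if major >= 8 else "sdpa"

try:
    import torch

    if torch.cuda.is_available():
        cap = torch.cuda.get_device_capability(0)
        print(torch.cuda.get_device_name(0), cap, "->", pick_attn_implementation(cap))
    else:
        print("No CUDA device visible; select a GPU runtime first.")
except ImportError:
    print("torch is not installed in this environment.")
```

On a T4 this should select `sdpa`, which is consistent with the model-loading cell leaving the attention implementation at its default.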

## Instructions
1. **Runtime Type**: Ensure you are using a GPU runtime (Runtime > Change runtime type > T4 GPU).
2. **Files**: Upload your generated `guidelines_coarse_gollie.py` or `guidelines_fine_gollie.py` files to the Colab file explorer (sidebar on the left).
3. **Run All**: Execute the cells below in order.

In [None]:
# 1. Install Dependencies
!pip install -q transformers accelerate bitsandbytes sentencepiece

In [None]:
# 2. Setup Environment and Mock Imports
# Since your guideline files import from 'src.tasks.utils_typing', we mock this module
# so you don't need to clone the entire repository just to define the guidelines.

import sys
from types import ModuleType
from dataclasses import dataclass

# Create a mock module structure
src = ModuleType("src")
src_tasks = ModuleType("src.tasks")
src_tasks_utils_typing = ModuleType("src.tasks.utils_typing")

sys.modules["src"] = src
sys.modules["src.tasks"] = src_tasks
sys.modules["src.tasks.utils_typing"] = src_tasks_utils_typing

# Define the base Entity class and dataclass
@dataclass
class Entity:
    span: str = None

src_tasks_utils_typing.Entity = Entity
src_tasks_utils_typing.dataclass = dataclass

In [None]:
# 3. Load Model with Quantization (Fits in 16GB free Colab GPU)
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig

model_id = "HiTZ/GoLLIE-7B"

# 4-bit quantization configuration
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16  # T4 supports float16
)

print(f"Loading {model_id} in 4-bit mode...")
tokenizer = AutoTokenizer.from_pretrained(model_id)

# We do NOT enforce use_flash_attention_2=True here, due to T4 compatibility.
# Transformers will automatically pick a supported attention implementation.
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)

print("Model loaded successfully!")

In [None]:
# 4. Define Inference Helper Function
def generate_response(prompt, model, tokenizer):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=256,
            do_sample=False, # Deterministic for extraction
            pad_token_id=tokenizer.eos_token_id
        )
    
    # Note: the decoded string contains the prompt followed by the generated continuation.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

print("Inference function ready.")

In [None]:
# 5. Import your Guidelines
# Make sure you have uploaded 'guidelines_fine_gollie.py' to Colab files!

try:
    import guidelines_fine_gollie as guidelines
    print("Loaded guidelines:", dir(guidelines))
    print("Entities defined:", [x.__name__ for x in guidelines.ENTITY_DEFINITIONS])
except ImportError:
    print("Error: guidelines_fine_gollie.py not found. Please upload it to the Files section.")
except Exception as e:
    print(f"Error loading guidelines: {e}")

In [None]:
# 6. Run Inference Example
# Build a simplified prompt (this is NOT the full GoLLIE format).
# A real GoLLIE prompt contains the class definitions followed by the input text.

# Example input text
text_input = "The Shawshank Redemption is a 1994 American drama film written by Frank Darabont."

# Note: GoLLIE follows a specific prompt template. 
# The simplest way is to inspect `guidelines.ENTITY_DEFINITIONS` and construct the instruction.
# Below is a placeholder for testing model liveness.

prompt = f"Extract entities from the following text based on the guidelines: \n\nText: {text_input}\n\nAnswer:"

response = generate_response(prompt, model, tokenizer)
print("OUTPUT:", response)
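The placeholder prompt above only checks that the model responds. As a next step, the sketch below assembles a prompt in the code-style layout GoLLIE expects: guideline class definitions, then the text, then an open `result = [` for the model to complete. The exact comment wording is an assumption based on the templates in the GoLLIE repository, so compare it against the repository before relying on it. `build_gollie_prompt` and the `Movie` class are hypothetical and stand in for your own guidelines.

```python
def build_gollie_prompt(definitions_source, text):
    """Assemble a GoLLIE-style prompt from guideline source code and an input text."""
    return (
        "# The following lines describe the task definition\n"
        f"{definitions_source.strip()}\n\n"
        "# This is the text to analyze\n"
        f"text = {text!r}\n\n"
        "# The annotation instances that take place in the text above are listed here\n"
        "result = ["
    )

# Hypothetical guideline, standing in for the classes in guidelines_fine_gollie.py
example_definitions = '''
@dataclass
class Movie(Entity):
    """Refers to the title of a film."""

    span: str
'''

example_prompt = build_gollie_prompt(
    example_definitions,
    "The Shawshank Redemption is a 1994 American drama film.",
)
print(example_prompt)
```

With your real guidelines, one option is to build `definitions_source` by joining `inspect.getsource(cls)` for each class in `guidelines.ENTITY_DEFINITIONS`, then pass the resulting prompt to `generate_response`.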