# 🔥 DimABSA Subtask 3: AOC Extraction with LoRA + Cross-Entropy Loss

## 🎯 Objective
Fine-tune **Llama 3.2 3B Instruct** with **LoRA** to extract **(Aspect, Opinion, Category)** triplets using a token-level **cross-entropy loss**.

## 📊 Pipeline
```
Training:  Prompt + Target JSON → Llama + LoRA → Logits → Cross-Entropy Loss (prompt masked) → Backprop
Inference: Input Text → Llama + LoRA → Token Generation → Parse JSON
```
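
To make the loss step of the pipeline concrete, here is a minimal pure-Python sketch of token-level cross-entropy with an ignore index. The toy logits, the three-token vocabulary, and the helper name `masked_cross_entropy` are illustrative assumptions. The sketch also assumes logits and labels are already aligned position by position, whereas a causal language model shifts them by one token internally.

```python
import math

IGNORE_INDEX = -100  # label value that is skipped by the loss


def masked_cross_entropy(logits, labels, ignore_index=IGNORE_INDEX):
    """Mean negative log-likelihood over positions whose label is not ignored.

    logits: list of per-position score lists (one score per vocabulary item)
    labels: list of target token ids, or ignore_index for masked positions
    """
    total, count = 0.0, 0
    for scores, label in zip(logits, labels):
        if label == ignore_index:
            continue  # prompt or padding position: contributes nothing
        # log-sum-exp with the maximum subtracted for numerical stability
        m = max(scores)
        log_z = m + math.log(sum(math.exp(s - m) for s in scores))
        total += log_z - scores[label]
        count += 1
    return total / count if count else 0.0


# Toy example: 4 positions, the first two play the role of the masked prompt
logits = [[2.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0], [1.0, 1.0, 1.0]]
labels = [IGNORE_INDEX, IGNORE_INDEX, 2, 0]
loss = masked_cross_entropy(logits, labels)
print(f"masked loss: {loss:.4f}")
```

Only the last two positions contribute, so the result is the average of their two negative log-likelihoods.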

---

## 📦 1. Setup and Package Installation

In [None]:
# Install the required packages
import sys
import subprocess

packages = [
    'transformers>=4.30.0',
    'datasets',
    'torch',
    'accelerate',
    'sentencepiece',
    'huggingface-hub',
    'peft',  # ← LoRA library
    'tqdm',
    'scikit-learn'
]

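print("📦 Installing packages...")
failed = []
for package in packages:
    try:
        subprocess.check_call([sys.executable, '-m', 'pip', 'install', '-q', package])
    except subprocess.CalledProcessError:
        # Record the failure instead of silently swallowing every exception
        failed.append(package)
        print(f"⚠ Failed to install {package}")

if failed:
    print(f"⚠ {len(failed)} package(s) could not be installed: {failed}")
else:
    print("✅ All packages installed successfully!")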

In [None]:
# Import libraries
import json
import torch
import numpy as np
from pathlib import Path
from typing import List, Dict, Tuple, Optional
from tqdm.auto import tqdm
import warnings
warnings.filterwarnings('ignore')

# Transformers & PEFT
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    get_linear_schedule_with_warmup
)
from peft import LoraConfig, get_peft_model, PeftModel, prepare_model_for_kbit_training

# PyTorch
from torch.utils.data import Dataset, DataLoader
from torch import nn
from torch.optim import AdamW

# Sklearn
from sklearn.model_selection import train_test_split

# Set seed
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"🖥️  Using device: {device}")
if torch.cuda.is_available():
    print(f"🎮 GPU: {torch.cuda.get_device_name(0)}")
    print(f"💾 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")

## 🔑 2. HuggingFace Authentication

In [None]:
from huggingface_hub import login
import os

# Read the token from an environment variable
HF_TOKEN = os.getenv('HF_TOKEN')

if HF_TOKEN:
    login(token=HF_TOKEN)
    print("✅ Authenticated with HuggingFace!")
else:
    print("⚠️  HF_TOKEN not found. Gated models may not be accessible.")
    print("   Set it with: $env:HF_TOKEN='your_token' (Windows PowerShell)")

## 📁 3. Dataset Loading

In [None]:
# Dataset configuration
TRACK = "track_a"
SUBTASK = "3"
LANGUAGE = "eng"
DOMAIN = "restaurant"

BASE_URL = "https://raw.githubusercontent.com/DimABSA/DimABSA2026/main/task-dataset"
TRAIN_URL = f"{BASE_URL}/{TRACK}/subtask_{SUBTASK}/{LANGUAGE}/{LANGUAGE}_{DOMAIN}_train_alltasks.jsonl"

print("="*80)
print("📊 DATASET CONFIGURATION")
print("="*80)
print(f"Track:    {TRACK}")
print(f"Subtask:  {SUBTASK} (AOC Extraction)")
print(f"Language: {LANGUAGE}")
print(f"Domain:   {DOMAIN}")
print(f"URL:      {TRAIN_URL}")
print("="*80)

In [None]:
def load_jsonl_from_url(url: str) -> List[Dict]:
    """Load JSONL data from a GitHub URL."""
    import urllib.request
    
    print(f"⏳ Loading {url.split('/')[-1]}...")
    with urllib.request.urlopen(url) as response:
        data = response.read().decode('utf-8')
        items = [json.loads(line) for line in data.strip().split('\n') if line.strip()]
    print(f"✓ Loaded {len(items)} examples")
    return items

# Load the dataset
train_data = load_jsonl_from_url(TRAIN_URL)

print(f"\n{'='*80}")
print(f"📊 DATASET LOADED")
print(f"{'='*80}")
print(f"Total number of examples: {len(train_data)}")
print(f"\nExample record:")
print(json.dumps(train_data[0], indent=2))
print(f"{'='*80}")
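
The later cells read the fields `Text` and `Quadruplet` (with `Aspect`, `Category`, `Opinion`) from each record. The sketch below shows how such JSONL lines become the target JSON string used for training. The two records and the `ID` field are made up for illustration and are not taken from the real dataset.

```python
import json

# Hypothetical JSONL content; the values and the "ID" field are assumptions.
sample_jsonl = """
{"ID": "demo_1", "Text": "the pizza was great", "Quadruplet": [{"Aspect": "pizza", "Category": "FOOD#QUALITY", "Opinion": "great"}]}
{"ID": "demo_2", "Text": "slow service", "Quadruplet": [{"Aspect": "service", "Category": "SERVICE#GENERAL", "Opinion": "slow"}]}
"""


def parse_jsonl(text):
    """Parse one JSON object per non-empty line."""
    return [json.loads(line) for line in text.splitlines() if line.strip()]


def to_target_json(record):
    """Keep only aspect, category and opinion, with lowercase keys."""
    triplets = [
        {"aspect": q["Aspect"], "category": q["Category"], "opinion": q["Opinion"]}
        for q in record["Quadruplet"]
    ]
    return json.dumps(triplets)


records = parse_jsonl(sample_jsonl)
targets = [to_target_json(r) for r in records]
print(targets[0])
```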

## ✂️ 4. Train/Val/Test Split

In [None]:
def create_dataset_splits(data: List[Dict], 
                         train_ratio: float = 0.7,
                         val_ratio: float = 0.15, 
                         test_ratio: float = 0.15,
                         seed: int = 42):
    """Create train/val/test splits; the test split takes whatever remains after train and val."""
    # test_ratio is implied by the other two, so check that the three are consistent
    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-8:
        raise ValueError("train_ratio + val_ratio + test_ratio must sum to 1")
    
    # Local generator: does not reseed NumPy's global random state
    rng = np.random.default_rng(seed)
    indices = rng.permutation(len(data))
    
    n_train = int(len(data) * train_ratio)
    n_val = int(len(data) * val_ratio)
    
    train_indices = indices[:n_train]
    val_indices = indices[n_train:n_train + n_val]
    test_indices = indices[n_train + n_val:]
    
    train_split = [data[i] for i in train_indices]
    val_split = [data[i] for i in val_indices]
    test_split = [data[i] for i in test_indices]
    
    return train_split, val_split, test_split

# Create the splits
train_split, val_split, test_split = create_dataset_splits(train_data, seed=SEED)

print("\n" + "="*80)
print("✂️  DATASET SPLIT")
print("="*80)
print(f"Train: {len(train_split)} examples ({len(train_split)/len(train_data)*100:.1f}%)")
print(f"Val:   {len(val_split)} examples ({len(val_split)/len(train_data)*100:.1f}%)")
print(f"Test:  {len(test_split)} examples ({len(test_split)/len(train_data)*100:.1f}%)")
print("="*80)
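
Because the train and validation sizes are truncated with `int(...)`, the test split absorbs the rounding remainder. The sketch below reproduces the index arithmetic with the standard library only. The helper name `split_indices` and the dataset size of 103 are assumptions chosen so that the rounding is visible.

```python
import random


def split_indices(n, train_ratio=0.7, val_ratio=0.15, seed=42):
    """Shuffle range(n) and cut it into train / val / test index lists."""
    rng = random.Random(seed)  # local generator, global state untouched
    indices = list(range(n))
    rng.shuffle(indices)
    n_train = int(n * train_ratio)
    n_val = int(n * val_ratio)
    return (
        indices[:n_train],
        indices[n_train:n_train + n_val],
        indices[n_train + n_val:],  # everything that is left
    )


train_idx, val_idx, test_idx = split_indices(103)
sizes = (len(train_idx), len(val_idx), len(test_idx))
print(sizes)
```

With 103 examples the sizes come out as 72, 15 and 16: the one example lost to truncation ends up in the test split.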

## 📝 5. Prompt Engineering

In [None]:
def create_prompt(text: str) -> str:
    """Build a few-shot prompt for AOC extraction."""
    prompt = f"""### TASK
You are an aspect-category-opinion extraction system for restaurant reviews. Extract triplets and return ONLY a valid JSON array.

### RULES:
1. ASPECT = target entity/attribute. Use "NULL" if implicit.
2. CATEGORY = Entity#Attribute in UPPERCASE (e.g., "FOOD#QUALITY"). NEVER NULL.
3. OPINION = sentiment word/phrase from text. Use "NULL" if implicit.
4. Return ONLY the JSON array, nothing else.

VALID CATEGORIES:
- Entities: RESTAURANT, FOOD, DRINKS, AMBIENCE, SERVICE, LOCATION
- Attributes: GENERAL, PRICES, QUALITY, STYLE_OPTIONS, MISCELLANEOUS

### EXAMPLES:

Text: "the spicy tuna roll was unusually good and the rock shrimp tempura was awesome."
[{{"aspect": "spicy tuna roll", "category": "FOOD#QUALITY", "opinion": "unusually good"}}, {{"aspect": "rock shrimp tempura", "category": "FOOD#QUALITY", "opinion": "awesome"}}]

Text: "we love the pink pony."
[{{"aspect": "pink pony", "category": "RESTAURANT#GENERAL", "opinion": "love"}}]

Text: "the food here is rather good, but only if you like to wait for it."
[{{"aspect": "food", "category": "FOOD#QUALITY", "opinion": "rather good"}}, {{"aspect": "NULL", "category": "SERVICE#GENERAL", "opinion": "NULL"}}]

### INPUT
Text: "{text}"

### OUTPUT
"""
    return prompt

# Test the prompt
test_text = "The food was amazing but the service was slow."
print("\n📝 Example prompt:")
print("="*80)
print(create_prompt(test_text))
print("="*80)
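
The prompt describes categories as `Entity#Attribute` pairs built from two lists. As a rough sanity check on model output, the sketch below enumerates every pairing of the two lists and tests membership. Treating the full Cartesian product as valid is an assumption: the real annotation scheme may allow only a subset. The names `CANDIDATE_CATEGORIES` and `is_well_formed_category` are hypothetical.

```python
from itertools import product

ENTITIES = ["RESTAURANT", "FOOD", "DRINKS", "AMBIENCE", "SERVICE", "LOCATION"]
ATTRIBUTES = ["GENERAL", "PRICES", "QUALITY", "STYLE_OPTIONS", "MISCELLANEOUS"]

# Every Entity#Attribute pairing; a superset of what the task may actually use
CANDIDATE_CATEGORIES = {f"{e}#{a}" for e, a in product(ENTITIES, ATTRIBUTES)}


def is_well_formed_category(category: str) -> bool:
    """True if the label is one of the Entity#Attribute pairings (case-insensitive)."""
    return category.strip().upper() in CANDIDATE_CATEGORIES


print(len(CANDIDATE_CATEGORIES))
print(is_well_formed_category("food#quality"), is_well_formed_category("FOOD#TASTE"))
```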

## 🔧 6. JSON Parsing (_parse_json_response)

In [None]:
def _parse_json_response(response: str) -> List[Dict[str, str]]:
    """
    Strip non-JSON text from the model output and parse the result.
    
    Args:
        response: Text generated by the model
        
    Returns:
        List of dicts with Aspect, Category, Opinion
    """
    import re
    
    try:
        # Step 1: Strip markdown code fences
        response = response.replace('```json', '').replace('```', '').strip()
        
        # Step 2: Find candidate JSON arrays (non-greedy: each match ends at the first ']')
        candidates = re.findall(r'\[[\s\S]*?\]', response)
        
        for json_str in candidates:
            try:
                result = json.loads(json_str)
                
                if isinstance(result, list) and len(result) > 0:
                    normalized = []
                    seen = set()
                    
                    for item in result:
                        if not isinstance(item, dict):
                            continue
                        
                        aspect = str(item.get('aspect', item.get('Aspect', 'NULL'))).strip()
                        category = str(item.get('category', item.get('Category', 'RESTAURANT#GENERAL'))).strip().upper()
                        opinion = str(item.get('opinion', item.get('Opinion', 'NULL'))).strip()
                        
                        if not aspect:
                            aspect = 'NULL'
                        if not opinion:
                            opinion = 'NULL'
                        if not category or category == 'NULL':
                            category = 'RESTAURANT#GENERAL'
                        
                        # Drop duplicates
                        triplet_key = (aspect.lower(), category, opinion.lower())
                        if triplet_key not in seen:
                            seen.add(triplet_key)
                            normalized.append({
                                'Aspect': aspect,
                                'Category': category,
                                'Opinion': opinion
                            })
                    
                    if normalized:
                        return normalized
            
            except json.JSONDecodeError:
                continue
        
        # Fallback
        return [{'Aspect': 'NULL', 'Category': 'RESTAURANT#GENERAL', 'Opinion': 'NULL'}]
    
    except Exception:
        return [{'Aspect': 'NULL', 'Category': 'RESTAURANT#GENERAL', 'Opinion': 'NULL'}]

# Test parsing
test_response = '[{"aspect": "food", "category": "FOOD#QUALITY", "opinion": "great"}]'
print("\n🔧 Parsing test:")
print(f"Input:  {test_response}")
print(f"Output: {_parse_json_response(test_response)}")
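
A non-greedy regex ends each candidate at the first `]`, so an aspect or opinion string that itself contains a bracket can truncate the candidate and trigger the fallback triplet. One possible alternative, sketched here under the assumption that the model emits a single JSON array somewhere in its output, is to let the JSON decoder find the end of the array with `json.JSONDecoder.raw_decode`. The helper name `extract_first_json_array` is hypothetical.

```python
import json


def extract_first_json_array(text: str):
    """Return the first JSON array found in text, or None if there is none."""
    decoder = json.JSONDecoder()
    start = text.find("[")
    while start != -1:
        try:
            # raw_decode parses one JSON value starting at `start`
            value, _end = decoder.raw_decode(text, start)
            if isinstance(value, list):
                return value
        except json.JSONDecodeError:
            pass  # not valid JSON from this bracket, try the next one
        start = text.find("[", start + 1)
    return None


noisy = (
    'Sure! Here it is: [{"aspect": "menu [lunch]", '
    '"category": "FOOD#STYLE_OPTIONS", "opinion": "varied"}] Hope this helps.'
)
parsed = extract_first_json_array(noisy)
print(parsed)
```

The decoder tracks string boundaries, so the bracket inside `"menu [lunch]"` does not end the array early.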

## 🤖 7. Loading the Model with LoRA

In [None]:
MODEL_NAME = "meta-llama/Llama-3.2-3B-Instruct"

print("\n" + "="*80)
print("🤖 LOADING MODEL WITH LoRA")
print("="*80)

# Load the tokenizer
print("⏳ Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer.padding_side = 'left'  # left padding suits decoder-only generation
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
print("✓ Tokenizer loaded")

# Load the base model
print("\n⏳ Loading base model...")
base_model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True
)
print("✓ Base model loaded")

# Configure LoRA
print("\n⚙️  Configuring LoRA...")
lora_config = LoraConfig(
    r=16,                    # Rank of the LoRA matrices
    lora_alpha=32,           # Scaling factor (effective scale = lora_alpha / r)
    target_modules=["q_proj", "v_proj", "k_proj", "o_proj"],  # Attention projections
    lora_dropout=0.1,
    bias="none",
    task_type="CAUSAL_LM"
)

# Apply LoRA to the model
model = get_peft_model(base_model, lora_config)
model.print_trainable_parameters()

print("\n✅ LoRA model ready for training!")
print("="*80)
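
As a back-of-the-envelope check on the trainable-parameter count, each adapted linear layer gains two low-rank matrices with `r * (in_features + out_features)` weights in total. The layer shapes below are assumptions about the Llama 3.2 3B architecture (width 3072, 28 layers, grouped-query attention with a 1024-wide key/value projection) and are not stated in this notebook; check them against the model config before relying on the result.

```python
# Assumed attention shapes for Llama 3.2 3B (verify against the model config)
HIDDEN = 3072      # model width
KV_DIM = 1024      # 8 key/value heads x 128 dims (grouped-query attention)
NUM_LAYERS = 28
R, ALPHA = 16, 32  # LoRA rank and alpha from the config above

# (in_features, out_features) of each targeted projection
shapes = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
}


def lora_params(in_features, out_features, r):
    """Weights in A (r x in_features) plus B (out_features x r)."""
    return r * (in_features + out_features)


per_layer = sum(lora_params(i, o, R) for i, o in shapes.values())
total = per_layer * NUM_LAYERS
scaling = ALPHA / R  # factor applied to the low-rank update

print(f"per layer: {per_layer:,}  total: {total:,}  scaling: {scaling}")
```

Under these assumptions the adapters add roughly 9.2 million weights, a small fraction of the base model.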

## 📊 8. Dataset Class for Training

In [None]:
class AOCDataset(Dataset):
    """Dataset for training AOC extraction with LoRA."""
    
    def __init__(self, data: List[Dict], tokenizer, max_length: int = 512):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        # Input: prompt containing the review text
        prompt = create_prompt(item['Text'])
        
        # Target: JSON output with the AOC triplets
        target_aocs = [
            {
                "aspect": q['Aspect'],
                "category": q['Category'],
                "opinion": q['Opinion']
            }
            for q in item['Quadruplet']
        ]
        # Append EOS so the model learns to stop after the JSON array
        target_json = json.dumps(target_aocs) + self.tokenizer.eos_token
        
        # Tokenize prompt and target separately so the boundary between them is exact
        prompt_ids = self.tokenizer(prompt)['input_ids']
        target_ids = self.tokenizer(target_json, add_special_tokens=False)['input_ids']
        
        # Labels: -100 on the prompt, real token ids on the target JSON
        input_ids = (prompt_ids + target_ids)[:self.max_length]
        labels = ([-100] * len(prompt_ids) + target_ids)[:self.max_length]
        attention_mask = [1] * len(input_ids)
        
        # Right-pad by hand, independent of tokenizer.padding_side, and mask the
        # padding too: slicing labels[:, :prompt_length] on a left-padded sequence
        # would hide padding instead of the prompt
        pad_length = self.max_length - len(input_ids)
        input_ids += [self.tokenizer.pad_token_id] * pad_length
        attention_mask += [0] * pad_length
        labels += [-100] * pad_length
        
        return {
            'input_ids': torch.tensor(input_ids),
            'attention_mask': torch.tensor(attention_mask),
            'labels': torch.tensor(labels)
        }

# Create the datasets
print("\n📊 Creating training datasets...")
train_dataset = AOCDataset(train_split, tokenizer, max_length=512)
val_dataset = AOCDataset(val_split, tokenizer, max_length=512)

print(f"✓ Train dataset: {len(train_dataset)} examples")
print(f"✓ Val dataset:   {len(val_dataset)} examples")

# Inspect one sample
print("\n🔍 Dataset check:")
sample = train_dataset[0]
print(f"Input IDs shape:      {sample['input_ids'].shape}")
print(f"Attention Mask shape: {sample['attention_mask'].shape}")
print(f"Labels shape:         {sample['labels'].shape}")
print(f"\nFirst 10 labels: {sample['labels'][:10].tolist()}")
print(f"(Note: -100 = masked position, excluded from the loss)")
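
Prompt masking depends on where the padding sits: masking "the first `prompt_length` positions" only covers the prompt when padding is on the right, and padding positions need `-100` as well or they are trained on. The toy sketch below, with made-up token ids and a hypothetical helper `build_example`, shows the intended layout of ids, attention mask and labels.

```python
IGNORE_INDEX = -100


def build_example(prompt_ids, target_ids, max_length, pad_id):
    """Concatenate prompt and target, right-pad, and mask prompt and padding in the labels."""
    input_ids = (prompt_ids + target_ids)[:max_length]
    labels = ([IGNORE_INDEX] * len(prompt_ids) + target_ids)[:max_length]
    pad = max_length - len(input_ids)
    attention_mask = [1] * len(input_ids) + [0] * pad
    input_ids = input_ids + [pad_id] * pad
    labels = labels + [IGNORE_INDEX] * pad
    return input_ids, attention_mask, labels


# Toy ids: three prompt tokens, two target tokens, padded to length 8
ids, mask, labels = build_example([11, 12, 13], [21, 22], max_length=8, pad_id=0)
print(ids)
print(mask)
print(labels)
```

Only the two target positions keep a real label, so only they contribute to the loss.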

## 🎓 9. Training Configuration

In [None]:
# Hyperparameters
BATCH_SIZE = 4
GRADIENT_ACCUMULATION_STEPS = 4  # Effective batch size = 4 * 4 = 16
LEARNING_RATE = 1e-4
NUM_EPOCHS = 3
WARMUP_STEPS = 100
LOGGING_STEPS = 50
EVAL_STEPS = 200
SAVE_STEPS = 500
OUTPUT_DIR = "./lora_checkpoints"

print("\n" + "="*80)
print("🎓 TRAINING CONFIGURATION")
print("="*80)
print(f"Batch size:                {BATCH_SIZE}")
print(f"Gradient accumulation:     {GRADIENT_ACCUMULATION_STEPS}")
print(f"Effective batch size:      {BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
print(f"Learning rate:             {LEARNING_RATE}")
print(f"Num epochs:                {NUM_EPOCHS}")
print(f"Warmup steps:              {WARMUP_STEPS}")
print(f"Total training steps:      {len(train_dataset) // (BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS) * NUM_EPOCHS}")
print(f"Output directory:          {OUTPUT_DIR}")
print("="*80)
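
With gradient accumulation, the number of optimizer updates is smaller than the number of batches, which matters when choosing `WARMUP_STEPS`. The sketch below works through the arithmetic for a hypothetical training set of 1,500 examples. That size is an assumption: the real value is whatever `len(train_dataset)` reports.

```python
def optimizer_steps(num_examples, batch_size, grad_accum, num_epochs):
    """Count weight updates when one update happens every `grad_accum` micro-batches."""
    micro_batches = -(-num_examples // batch_size)  # ceil: the last batch may be partial
    steps_per_epoch = micro_batches // grad_accum   # leftover micro-batches trigger no update
    return steps_per_epoch * num_epochs


NUM_EXAMPLES = 1500  # hypothetical dataset size
steps = optimizer_steps(NUM_EXAMPLES, batch_size=4, grad_accum=4, num_epochs=3)
print(f"optimizer steps: {steps}")
print(f"warmup share with 100 warmup steps: {100 / steps:.0%}")
```

In this hypothetical case, 100 warmup steps would take up about a third of training, so the warmup length is worth checking against the real step count.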

## 🚀 10. Training Loop with Cross-Entropy Loss

In [None]:
# DataLoaders
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=0  # Set to 0 for Windows compatibility
)

val_loader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=0
)

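# Optimizer & Scheduler (only the LoRA parameters require gradients)
optimizer = AdamW([p for p in model.parameters() if p.requires_grad], lr=LEARNING_RATE)

# The scheduler advances once per optimizer update, not once per micro-batch
total_steps = (len(train_loader) // GRADIENT_ACCUMULATION_STEPS) * NUM_EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=WARMUP_STEPS,
    num_training_steps=total_steps
)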

# Training tracking
training_stats = {
    'train_loss': [],
    'val_loss': [],
    'learning_rates': []
}

print("\n" + "="*80)
print("🚀 STARTING TRAINING")
print("="*80)

# Training loop
model.train()
global_step = 0

for epoch in range(NUM_EPOCHS):
    print(f"\n📅 Epoch {epoch + 1}/{NUM_EPOCHS}")
    print("-" * 80)
    
    epoch_loss = 0
    progress_bar = tqdm(train_loader, desc=f"Training Epoch {epoch + 1}")
    
    for step, batch in enumerate(progress_bar):
        # Move to device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
        
        # Forward pass
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        
        loss = outputs.loss / GRADIENT_ACCUMULATION_STEPS
        
        # Backward pass
        loss.backward()
        
        epoch_loss += loss.item() * GRADIENT_ACCUMULATION_STEPS
        
        # Update weights every GRADIENT_ACCUMULATION_STEPS micro-batches
        if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            global_step += 1
            
            # Logging
            if global_step % LOGGING_STEPS == 0:
                avg_loss = epoch_loss / (step + 1)
                lr = scheduler.get_last_lr()[0]
                progress_bar.set_postfix({
                    'loss': f'{avg_loss:.4f}',
                    'lr': f'{lr:.2e}'
                })
                training_stats['train_loss'].append(avg_loss)
                training_stats['learning_rates'].append(lr)
            
            # Validation
            if global_step % EVAL_STEPS == 0:
                print(f"\n\n📊 Validation at step {global_step}")
                model.eval()
                val_loss = 0
                
                with torch.no_grad():
                    for val_batch in tqdm(val_loader, desc="Validation"):
                        val_input_ids = val_batch['input_ids'].to(device)
                        val_attention_mask = val_batch['attention_mask'].to(device)
                        val_labels = val_batch['labels'].to(device)
                        
                        val_outputs = model(
                            input_ids=val_input_ids,
                            attention_mask=val_attention_mask,
                            labels=val_labels
                        )
                        val_loss += val_outputs.loss.item()
                
                avg_val_loss = val_loss / len(val_loader)
                training_stats['val_loss'].append(avg_val_loss)
                
                print(f"Validation Loss: {avg_val_loss:.4f}\n")
                model.train()
            
            # Save checkpoint
            if global_step % SAVE_STEPS == 0:
                checkpoint_dir = f"{OUTPUT_DIR}/checkpoint-{global_step}"
                print(f"\n💾 Saving checkpoint to {checkpoint_dir}")
                model.save_pretrained(checkpoint_dir)
                tokenizer.save_pretrained(checkpoint_dir)
    
    # End of epoch
    avg_epoch_loss = epoch_loss / len(train_loader)
    print(f"\n📊 Epoch {epoch + 1} completed. Average Loss: {avg_epoch_loss:.4f}")

print("\n" + "="*80)
print("✅ TRAINING COMPLETE!")
print("="*80)

# Save final model
final_dir = f"{OUTPUT_DIR}/final"
print(f"\n💾 Saving final model to {final_dir}")
model.save_pretrained(final_dir)
tokenizer.save_pretrained(final_dir)
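
The loop divides each loss by `GRADIENT_ACCUMULATION_STEPS` before calling `backward()`. The sketch below illustrates why, using a one-parameter squared-error model with made-up data: accumulating the gradients of `loss / N` over `N` micro-batches gives the gradient of the mean loss. For the language model this equivalence is only approximate, since each micro-batch loss is averaged over its own number of unmasked tokens.

```python
def grad(w, x, y):
    """d/dw of the squared error (w * x - y) ** 2."""
    return 2.0 * (w * x - y) * x


w = 0.5
micro_batches = [(1.0, 2.0), (2.0, 3.0), (3.0, 5.0), (4.0, 4.0)]  # toy (x, y) pairs
ACCUM = len(micro_batches)

# Gradient of the mean loss over all four samples at once
full_batch_grad = sum(grad(w, x, y) for x, y in micro_batches) / ACCUM

# Accumulation: each micro-batch contributes the gradient of loss / ACCUM
accumulated = 0.0
for x, y in micro_batches:
    accumulated += grad(w, x, y) / ACCUM

print(full_batch_grad, accumulated)
```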

## 📈 11. Training Visualization

In [None]:
import matplotlib.pyplot as plt

# Plot training loss
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Training loss
axes[0].plot(training_stats['train_loss'], label='Train Loss', color='blue', alpha=0.7)
axes[0].set_xlabel('Step (x50)')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training Loss', fontweight='bold')
axes[0].legend()
axes[0].grid(alpha=0.3)

# Learning rate
axes[1].plot(training_stats['learning_rates'], label='Learning Rate', color='green', alpha=0.7)
axes[1].set_xlabel('Step (x50)')
axes[1].set_ylabel('Learning Rate')
axes[1].set_title('Learning Rate Schedule', fontweight='bold')
axes[1].legend()
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

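print(f"\n📊 Training Statistics:")
# Both lists can be empty if training ended before the first logging/eval step
if training_stats['train_loss']:
    print(f"Final Train Loss: {training_stats['train_loss'][-1]:.4f}")
if training_stats['val_loss']:
    print(f"Final Val Loss:   {training_stats['val_loss'][-1]:.4f}")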
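
For reference when reading the learning-rate plot, a linear schedule with warmup ramps the rate from zero to its peak and then decays it linearly back to zero. The sketch below reimplements that shape in plain Python. The total of 279 steps is a hypothetical value, and the helper name `linear_warmup_decay` is not part of any library.

```python
def linear_warmup_decay(step, warmup_steps, total_steps):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


BASE_LR = 1e-4
WARMUP, TOTAL = 100, 279  # TOTAL is a hypothetical number of optimizer steps
lrs = [BASE_LR * linear_warmup_decay(s, WARMUP, TOTAL) for s in range(TOTAL + 1)]

print(f"start={lrs[0]:.1e}  mid-warmup={lrs[50]:.1e}  peak={lrs[WARMUP]:.1e}  end={lrs[-1]:.1e}")
```

If the plotted curve never comes close to zero by the end of training, the scheduler was probably configured with more total steps than the loop actually performs.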

## 🎯 12. Inference and Evaluation

In [None]:
def generate_aoc(text: str, model, tokenizer, max_new_tokens: int = 150):
    """Generate AOC triplets for a text with the fine-tuned model."""
    prompt = create_prompt(text)
    
    inputs = tokenizer(prompt, return_tensors='pt', truncation=True, max_length=4096)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    
    model.eval()
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id
        )
    
    # Decode only the generated part: generate() returns prompt + continuation
    input_length = inputs['input_ids'].shape[1]
    generated_ids = outputs[0][input_length:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    
    # Parse the JSON
    aocs = _parse_json_response(response)
    
    return aocs, response

# Test on a few examples
print("\n" + "="*80)
print("🎯 INFERENCE TEST")
print("="*80)

test_texts = [
    "The food was amazing but the service was slow.",
    "I love this restaurant! Great atmosphere and delicious pizza.",
    "The prices are too high for the quality you get."
]

for i, text in enumerate(test_texts, 1):
    print(f"\n--- Test {i} ---")
    print(f"Text: {text}")
    
    aocs, raw_response = generate_aoc(text, model, tokenizer)
    
    print(f"\nRaw response: {raw_response}")
    print(f"\nParsed AOCs:")
    for aoc in aocs:
        print(f"  - Aspect: '{aoc['Aspect']}', Category: '{aoc['Category']}', Opinion: '{aoc['Opinion']}'")

print("\n" + "="*80)

## 📊 13. Evaluation on the Validation Set

In [None]:
def evaluate_model(model, tokenizer, dataset: List[Dict], num_samples: int = 100):
    """
    Evaluate the model on a dataset.
    Computes Precision, Recall and F1 with exact triplet matching.
    """
    print(f"\n{'='*80}")
    print(f"📊 EVALUATION")
    print(f"{'='*80}")
    print(f"Num samples: {num_samples}")
    
    sample_data = dataset[:num_samples]
    
    total_tp = 0
    total_pred = 0
    total_gold = 0
    
    for item in tqdm(sample_data, desc="Evaluating"):
        text = item['Text']
        gold_quadruplets = item['Quadruplet']
        
        # Prediction
        pred_aocs, _ = generate_aoc(text, model, tokenizer)
        
        # Normalize gold triplets
        gold_triplets = [
            {
                'aspect': q['Aspect'].strip().lower(),
                'category': q['Category'].strip().upper(),
                'opinion': q['Opinion'].strip().lower()
            }
            for q in gold_quadruplets
        ]
        
        # Normalize predicted triplets
        pred_triplets = [
            {
                'aspect': p['Aspect'].strip().lower(),
                'category': p['Category'].strip().upper(),
                'opinion': p['Opinion'].strip().lower()
            }
            for p in pred_aocs
        ]
        
        total_gold += len(gold_triplets)
        total_pred += len(pred_triplets)
        
        # Match
        matched_gold = set()
        for pred in pred_triplets:
            for j, gold in enumerate(gold_triplets):
                if j in matched_gold:
                    continue
                if (pred['aspect'] == gold['aspect'] and
                    pred['category'] == gold['category'] and
                    pred['opinion'] == gold['opinion']):
                    total_tp += 1
                    matched_gold.add(j)
                    break
    
    # Compute metrics
    precision = total_tp / total_pred if total_pred > 0 else 0
    recall = total_tp / total_gold if total_gold > 0 else 0
    f1 = (2 * precision * recall) / (precision + recall) if (precision + recall) > 0 else 0
    
    print(f"\n{'='*80}")
    print(f"📈 RESULTS")
    print(f"{'='*80}")
    print(f"Precision: {precision:.4f} ({precision*100:.2f}%)")
    print(f"Recall:    {recall:.4f} ({recall*100:.2f}%)")
    print(f"F1 Score:  {f1:.4f} ({f1*100:.2f}%)")
    print(f"\nTrue Positives:  {total_tp}")
    print(f"Total Predicted: {total_pred}")
    print(f"Total Gold:      {total_gold}")
    print(f"{'='*80}")
    
    return {
        'precision': precision,
        'recall': recall,
        'f1': f1,
        'tp': total_tp,
        'total_pred': total_pred,
        'total_gold': total_gold
    }

# Evaluate on the validation set
val_results = evaluate_model(model, tokenizer, val_split, num_samples=100)
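
The metric is micro-averaged exact matching: a predicted triplet counts as a true positive only if aspect, category and opinion all equal an unmatched gold triplet. The worked example below uses two made-up sentences and a hypothetical helper `triplet_prf` to show how the counts turn into precision, recall and F1.

```python
def triplet_prf(pred_lists, gold_lists):
    """Micro-averaged precision, recall and F1 over exact triplet matches."""
    tp = n_pred = n_gold = 0
    for preds, golds in zip(pred_lists, gold_lists):
        n_pred += len(preds)
        n_gold += len(golds)
        remaining = list(golds)  # each gold triplet can be matched only once
        for p in preds:
            if p in remaining:
                remaining.remove(p)
                tp += 1
    precision = tp / n_pred if n_pred else 0.0
    recall = tp / n_gold if n_gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# Made-up predictions and gold labels as (aspect, category, opinion) tuples
preds = [
    [("food", "FOOD#QUALITY", "amazing"), ("service", "SERVICE#GENERAL", "bad")],
    [("pizza", "FOOD#QUALITY", "delicious")],
]
golds = [
    [("food", "FOOD#QUALITY", "amazing"), ("service", "SERVICE#GENERAL", "slow")],
    [("pizza", "FOOD#QUALITY", "delicious"), ("atmosphere", "AMBIENCE#GENERAL", "great")],
]
precision, recall, f1 = triplet_prf(preds, golds)
print(f"P={precision:.4f} R={recall:.4f} F1={f1:.4f}")
```

Here 2 of 3 predictions are correct and 2 of 4 gold triplets are found, so a single wrong opinion word ("bad" instead of "slow") costs both precision and recall.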

## 💾 14. Load LoRA Checkpoint

In [None]:
# Load a saved checkpoint later on
def load_lora_model(checkpoint_path: str, model_name: str = MODEL_NAME):
    """Load a LoRA model from a checkpoint."""
    print(f"\n⏳ Loading LoRA checkpoint from {checkpoint_path}...")
    
    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)
    
    # Load the base model
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float16,
        device_map="auto"
    )
    
    # Load the LoRA adapter
    model = PeftModel.from_pretrained(base_model, checkpoint_path)
    
    print("✅ LoRA model loaded successfully!")
    return model, tokenizer

# Example usage:
# model, tokenizer = load_lora_model("./lora_checkpoints/checkpoint-1000")
## 🎉 Done!

### 📝 Summary

This notebook implements:

1. ✅ **LoRA fine-tuning** of Llama 3.2 3B
2. ✅ Token-level **cross-entropy loss**
3. ✅ **Prompt masking** (loss only on the target JSON)
4. ✅ Custom **training loop** with validation
5. ✅ **Checkpointing** to save and load models
6. ✅ **Evaluation** with Precision/Recall/F1

### 🚀 Next Steps

- Tune hyperparameters (learning rate, batch size, LoRA rank)
- Train for more epochs if needed
- Compare against a baseline (zero-shot Llama)
- Evaluate on the held-out test set
- Consider QLoRA (4-bit quantization) to reduce memory use

---