# Train Fact Verifier (NLI Model)

This notebook demonstrates how to fine-tune a Natural Language Inference (NLI) model for the `FactVerifier` component. 
The goal is to classify the relationship between a **Premise** (Evidence) and a **Hypothesis** (Claim) as:
- `ENTAILMENT` (Supports)
- `CONTRADICTION` (Refutes)
- `NEUTRAL` (Not enough info)

## 1. Setup

In [None]:
!pip install transformers datasets evaluate torch scikit-learn

In [None]:
import os
import sys

# Add project root to path to import DatasetLoader
sys.path.append(os.path.abspath('../../'))

from keystone.data.dataset_loader import DatasetLoader
import torch
from datasets import Dataset, DatasetDict
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
    DataCollatorWithPadding
)
import evaluate
import numpy as np

## 2. Load Dataset (FEVER)
We use the `DatasetLoader` to fetch FEVER data. FEVER is a standard large-scale dataset for Fact Extraction and VERification.

In [None]:
loader = DatasetLoader()
# Load a small subset for demonstration. Increase max_samples for real training.
raw_data = loader.load_fever(split="train", max_samples=1000)

print(f"Loaded {len(raw_data)} examples.")
print("Sample:", raw_data[0])

## 3. Preprocess Data
Convert the loader's records into a Hugging Face `Dataset` suitable for NLI training.
The loader's labels map to NLI classes as follows:
- `faithful` -> `ENTAILMENT` (label 0)
- `unverifiable` -> `NEUTRAL` (label 1)
- `hallucinated` -> `CONTRADICTION` (label 2, as used in this project)
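
The ids above are the notebook's own choice. The pretrained checkpoint loaded in Section 4 already has a three-way NLI head with its own class order, and fine-tuning starts from a better point when the ids agree with it. Below is a minimal sketch of deriving the mapping from a checkpoint's `id2label`. `LOADER_TO_NLI`, `align_label2id`, and the example `checkpoint_id2label` dict are hypothetical names, and the class order shown is an assumption that should be checked against `model.config.id2label` for the checkpoint actually loaded.

```python
# Loader label -> NLI class name, following the mapping listed above.
LOADER_TO_NLI = {
    "faithful": "entailment",
    "unverifiable": "neutral",
    "hallucinated": "contradiction",
}

def align_label2id(checkpoint_id2label, loader_to_nli=LOADER_TO_NLI):
    """Return loader-label -> id using the checkpoint's own class order."""
    nli_to_id = {name.lower(): int(idx) for idx, name in checkpoint_id2label.items()}
    missing = set(loader_to_nli.values()) - set(nli_to_id)
    if missing:
        raise ValueError(f"Checkpoint has no class for: {sorted(missing)}")
    return {label: nli_to_id[nli] for label, nli in loader_to_nli.items()}

# Assumed order, for illustration only; read the real one from model.config.id2label.
checkpoint_id2label = {0: "contradiction", 1: "entailment", 2: "neutral"}
aligned_label2id = align_label2id(checkpoint_id2label)
aligned_id2label = {idx: label for label, idx in aligned_label2id.items()}
print(aligned_label2id)
```

If the derived ids differ from the hand-written ones, using the derived pair for both preprocessing and model loading keeps the pretrained head's meaning intact.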

In [None]:
label2id = {"faithful": 0, "unverifiable": 1, "hallucinated": 2}
id2label = {0: "faithful", 1: "unverifiable", 2: "hallucinated"}

def prepare_nli_dataset(data):
    hf_data = []
    for item in data:
        if item['label'] not in label2id: 
            continue
            
        # The loader may return an empty 'source_document' when the Wikipedia
        # evidence pages were not resolved. This demo substitutes a placeholder
        # premise so the pipeline still runs; for real training, skip such items.
        premise = item.get('source_document') or "Evidence text missing in demo loader"
        hypothesis = item.get('generated_text')  # The claim to verify
        if not hypothesis:
            continue  # The tokenizer cannot encode a missing claim
        
        hf_data.append({
            "text": premise,
            "text_pair": hypothesis,
            "label": label2id[item['label']]
        })
    return hf_data

formatted_data = prepare_nli_dataset(raw_data)
dataset = Dataset.from_list(formatted_data)
dataset = dataset.train_test_split(test_size=0.2, seed=42)  # Fixed seed for a reproducible split
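
Substituting a placeholder premise keeps the demo running, but the model cannot learn anything from evidence that is not there. The sketch below shows one way to drop incomplete records before formatting and to check class balance on what remains. The field names come from the loop above; `split_complete` and the toy records are hypothetical and only illustrate the shape of the data.

```python
from collections import Counter

def split_complete(records):
    """Separate records that have both evidence and a claim from incomplete ones."""
    complete, dropped = [], []
    for item in records:
        has_evidence = bool((item.get("source_document") or "").strip())
        has_claim = bool((item.get("generated_text") or "").strip())
        (complete if has_evidence and has_claim else dropped).append(item)
    return complete, dropped

# Toy records in the shape the loop above expects (values are made up).
sample = [
    {"label": "faithful", "source_document": "Paris is the capital of France.",
     "generated_text": "Paris is in France."},
    {"label": "hallucinated", "source_document": "",
     "generated_text": "Paris is in Spain."},
    {"label": "unverifiable", "source_document": "Paris hosts many museums.",
     "generated_text": None},
]

complete, dropped = split_complete(sample)
label_counts = Counter(item["label"] for item in complete)
print(f"kept {len(complete)}, dropped {len(dropped)}; labels: {dict(label_counts)}")
```

A heavily skewed `label_counts` is a hint to raise `max_samples` or rebalance before training.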

## 4. Tokenization

In [None]:
model_name = "cross-encoder/nli-deberta-v3-xsmall" # Efficient base model
tokenizer = AutoTokenizer.from_pretrained(model_name)

def preprocess_function(examples):
    return tokenizer(
        examples["text"], 
        examples["text_pair"], 
        truncation=True, 
        max_length=512
    )

tokenized_input = dataset.map(preprocess_function, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
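
`DataCollatorWithPadding` pads each batch to the length of its longest example rather than to `max_length`, which is why the tokenizer call above truncates but does not pad. The toy function below imitates that behaviour on plain lists to show the idea. `pad_batch` and the pad id `0` are illustrative assumptions, not the library's implementation.

```python
def pad_batch(batch_input_ids, pad_token_id=0):
    """Right-pad every sequence to the longest one in this batch and build masks."""
    longest = max(len(ids) for ids in batch_input_ids)
    input_ids, attention_mask = [], []
    for ids in batch_input_ids:
        padding = longest - len(ids)
        input_ids.append(list(ids) + [pad_token_id] * padding)
        # 1 marks real tokens, 0 marks padding the model should ignore.
        attention_mask.append([1] * len(ids) + [0] * padding)
    return {"input_ids": input_ids, "attention_mask": attention_mask}

# Three token-id sequences of different lengths (ids are arbitrary).
batch = pad_batch([[1, 45, 67, 2], [1, 45, 2], [1, 45, 67, 89, 90, 2]])
print([len(ids) for ids in batch["input_ids"]])
```

Padding per batch means short premise/claim pairs do not pay the cost of a 512-token input.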

## 5. Model Training

In [None]:
model = AutoModelForSequenceClassification.from_pretrained(
    model_name, 
    num_labels=3, 
    id2label=id2label, 
    label2id=label2id,
    ignore_mismatched_sizes=True  # Only has an effect if the checkpoint's head size differs from num_labels
)

training_args = TrainingArguments(
    output_dir="./nli_model_output",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    evaluation_strategy="epoch",  # Renamed to eval_strategy in recent transformers releases
    save_strategy="epoch",
    load_best_model_at_end=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_input["train"],
    eval_dataset=tokenized_input["test"],
    tokenizer=tokenizer,  # Recent transformers releases prefer processing_class=tokenizer
    data_collator=data_collator,
)

trainer.train()
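
As configured, the `Trainer` reports only the evaluation loss each epoch, and the `evaluate` and `numpy` imports from the setup cell go unused. A metric function can be supplied through the `compute_metrics` argument of `Trainer`. The sketch below is one possible version, written without extra dependencies so it accepts either lists or NumPy arrays. The choice of accuracy and macro-F1 is an assumption, not something the notebook prescribes.

```python
def compute_metrics(eval_pred):
    """Accuracy and macro-F1 from a (logits, labels) pair."""
    logits, labels = eval_pred
    # Predicted class = index of the largest logit in each row.
    preds = [max(range(len(row)), key=lambda i: row[i]) for row in logits]
    labels = [int(y) for y in labels]
    num_classes = len(logits[0])

    accuracy = sum(p == y for p, y in zip(preds, labels)) / len(labels)

    f1_scores = []
    for c in range(num_classes):
        tp = sum(p == c and y == c for p, y in zip(preds, labels))
        fp = sum(p == c and y != c for p, y in zip(preds, labels))
        fn = sum(p != c and y == c for p, y in zip(preds, labels))
        denom = 2 * tp + fp + fn
        f1_scores.append(2 * tp / denom if denom else 0.0)

    return {"accuracy": accuracy, "macro_f1": sum(f1_scores) / num_classes}

# Toy check with hand-written logits and labels.
toy_logits = [[2.0, 0.0, 0.0], [0.0, 3.0, 0.0], [0.0, 0.0, 1.0], [1.0, 0.0, 0.0]]
toy_labels = [0, 1, 2, 1]
toy_metrics = compute_metrics((toy_logits, toy_labels))
print(toy_metrics)
```

Passing `compute_metrics=compute_metrics` when constructing the `Trainer` adds these values to each evaluation log. Macro-F1 is worth watching alongside accuracy when the three classes are unevenly represented.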

## 6. Save Model

In [None]:
save_path = "./saved_fact_verifier_model"
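# With load_best_model_at_end=True, `model` now holds the best checkpoint
# (lowest evaluation loss unless metric_for_best_model is set).
model.save_pretrained(save_path)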
tokenizer.save_pretrained(save_path)
print(f"Model saved to {save_path}")
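
Once the saved model is reloaded, its forward pass returns one logit per class for each premise/claim pair. The sketch below shows how those logits could be turned into a verdict with a confidence score. `logits_to_verdict` is a hypothetical helper, the logits are made-up numbers, and the `id2label` dict repeats the notebook's mapping from Section 3.

```python
import math

def logits_to_verdict(logits, id2label):
    """Softmax the logits and return (label, probability) for the top class."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]  # subtract the max for numerical stability
    total = sum(exps)
    probs = [e / total for e in exps]
    best = max(range(len(probs)), key=probs.__getitem__)
    return id2label[best], probs[best]

id2label = {0: "faithful", 1: "unverifiable", 2: "hallucinated"}
label, confidence = logits_to_verdict([0.2, -1.0, 3.1], id2label)
print(label, round(confidence, 3))
```

A downstream `FactVerifier` could use the probability to abstain on low-confidence pairs instead of always committing to the top class.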