# Train Review: Claim Extractor Model

This notebook fine-tunes a T5 model (`t5-small`) to extract atomic factual claims from text.
It checks the environment setup, prepares a small synthetic dataset (50 examples), trains the model, and runs a quick inference check.

## 1. Setup

In [None]:
!pip install transformers datasets evaluate torch sentencepiece

In [None]:
import inspect
import json

import torch
from datasets import Dataset
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)

# Pick the compute backend in order of preference: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Using device: {device}")

## 2. Dataset Preparation (50 Synthetic Examples)
We create a small, varied dataset covering Science, Tech, History, Geography, and general knowledge so that the model sees claims from several domains.

In [None]:
# Hand-written training examples: 50 entries, 10 per section below.
# Each entry pairs a source passage ("text") with its expected extraction
# ("claims"), a list of {"claim": <sentence>, "type": <label>} dicts.
# Every claim in this set is labelled "factual".
# Claims are not always verbatim copies of the passage: some resolve pronouns
# to their referent or split a compound sentence into separate claims.
data = [
    # --- Science ---
    {"text": "The Earth orbits the Sun every 365.25 days. The Moon is Earth's only natural satellite.", "claims": [{"claim": "The Earth orbits the Sun every 365.25 days.", "type": "factual"}, {"claim": "The Moon is Earth's only natural satellite.", "type": "factual"}]},
    {"text": "Water boils at 100 degrees Celsius at sea level. Pure water has a neutral pH of 7.", "claims": [{"claim": "Water boils at 100 degrees Celsius at sea level.", "type": "factual"}, {"claim": "Pure water has a neutral pH of 7.", "type": "factual"}]},
    {"text": "Photosynthesis converts light energy into chemical energy. Plants need sunlight to grow.", "claims": [{"claim": "Photosynthesis converts light energy into chemical energy.", "type": "factual"}, {"claim": "Plants need sunlight to grow.", "type": "factual"}]},
    {"text": "The human skeleton consists of 206 bones. The femur is the longest bone.", "claims": [{"claim": "The human skeleton consists of 206 bones.", "type": "factual"}, {"claim": "The femur is the longest bone in the human body.", "type": "factual"}]},
    {"text": "Light travels at approximately 299,792 kilometers per second.", "claims": [{"claim": "Light travels at approximately 299,792 kilometers per second.", "type": "factual"}]},
    {"text": "DNA contains the genetic instructions for development. It is shaped like a double helix.", "claims": [{"claim": "DNA contains genetic instructions for development.", "type": "factual"}, {"claim": "DNA is shaped like a double helix.", "type": "factual"}]},
    {"text": "Electrons have a negative charge, while protons are positive.", "claims": [{"claim": "Electrons have a negative charge.", "type": "factual"}, {"claim": "Protons have a positive charge.", "type": "factual"}]},
    {"text": "The Big Bang theory explains the origin of the universe.", "claims": [{"claim": "The Big Bang theory explains the origin of the universe.", "type": "factual"}]},
    {"text": "Honey never spoils. Archaeologists have found edible honey in ancient tombs.", "claims": [{"claim": "Honey never spoils.", "type": "factual"}, {"claim": "Archaeologists have found edible honey in ancient tombs.", "type": "factual"}]},
    {"text": "Venus is the hottest planet in our solar system.", "claims": [{"claim": "Venus is the hottest planet in our solar system.", "type": "factual"}]},

    # --- Tech ---
    {"text": "Python was created by Guido van Rossum. It was released in 1991.", "claims": [{"claim": "Python was created by Guido van Rossum.", "type": "factual"}, {"claim": "Python was released in 1991.", "type": "factual"}]},
    {"text": "Linux is an open-source operating system kernel. Linus Torvalds started it.", "claims": [{"claim": "Linux is an open-source operating system kernel.", "type": "factual"}, {"claim": "Linus Torvalds started Linux.", "type": "factual"}]},
    {"text": "HTML stands for HyperText Markup Language.", "claims": [{"claim": "HTML stands for HyperText Markup Language.", "type": "factual"}]},
    {"text": "The first iPhone was announced by Steve Jobs in 2007.", "claims": [{"claim": "The first iPhone was announced by Steve Jobs in 2007.", "type": "factual"}]},
    {"text": "Google was founded by Larry Page and Sergey Brin.", "claims": [{"claim": "Google was founded by Larry Page.", "type": "factual"}, {"claim": "Google was founded by Sergey Brin.", "type": "factual"}]},
    {"text": "RAM stands for Random Access Memory.", "claims": [{"claim": "RAM stands for Random Access Memory.", "type": "factual"}]},
    {"text": "The CPU is considered the brain of the computer.", "claims": [{"claim": "The CPU is considered the brain of the computer.", "type": "factual"}]},
    {"text": "Bluetooth is a wireless technology standard.", "claims": [{"claim": "Bluetooth is a wireless technology standard.", "type": "factual"}]},
    {"text": "JavaScript is a programming language widely used for web development.", "claims": [{"claim": "JavaScript is a programming language used for web development.", "type": "factual"}]},
    {"text": "Amazon Web Services provides cloud computing platforms.", "claims": [{"claim": "Amazon Web Services provides cloud computing platforms.", "type": "factual"}]},

    # --- History ---
    {"text": "World War II ended in 1945. The United Nations was established the same year.", "claims": [{"claim": "World War II ended in 1945.", "type": "factual"}, {"claim": "The United Nations was established in 1945.", "type": "factual"}]},
    {"text": "The Titanic sank in 1912 after hitting an iceberg.", "claims": [{"claim": "The Titanic sank in 1912.", "type": "factual"}, {"claim": "The Titanic hit an iceberg.", "type": "factual"}]},
    {"text": "The Declaration of Independence was signed in 1776.", "claims": [{"claim": "The Declaration of Independence was signed in 1776.", "type": "factual"}]},
    {"text": "Julius Caesar was a Roman general and statesman.", "claims": [{"claim": "Julius Caesar was a Roman general.", "type": "factual"}, {"claim": "Julius Caesar was a statesman.", "type": "factual"}]},
    {"text": "The Berlin Wall fell in 1989.", "claims": [{"claim": "The Berlin Wall fell in 1989.", "type": "factual"}]},
    {"text": "Cleopatra was the last active ruler of the Ptolemaic Kingdom of Egypt.", "claims": [{"claim": "Cleopatra was the last active ruler of the Ptolemaic Kingdom of Egypt.", "type": "factual"}]},
    {"text": "The Industrial Revolution began in Great Britain.", "claims": [{"claim": "The Industrial Revolution began in Great Britain.", "type": "factual"}]},
    {"text": "Neil Armstrong walked on the moon in 1969.", "claims": [{"claim": "Neil Armstrong walked on the moon in 1969.", "type": "factual"}]},
    {"text": "Mahatma Gandhi led the non-violent independence movement in India.", "claims": [{"claim": "Mahatma Gandhi led the non-violent independence movement in India.", "type": "factual"}]},
    {"text": "The Great Wall of China is a series of fortifications.", "claims": [{"claim": "The Great Wall of China is a series of fortifications.", "type": "factual"}]},

    # --- Geography ---
    {"text": "Tokyo is the capital of Japan. It is the most populous metropolitan area.", "claims": [{"claim": "Tokyo is the capital of Japan.", "type": "factual"}, {"claim": "Tokyo is the most populous metropolitan area.", "type": "factual"}]},
    {"text": "Mount Everest is the highest mountain above sea level.", "claims": [{"claim": "Mount Everest is the highest mountain above sea level.", "type": "factual"}]},
    {"text": "The Nile is traditionally considered the longest river in the world.", "claims": [{"claim": "The Nile is traditionally considered the longest river in the world.", "type": "factual"}]},
    {"text": "Australia is both a country and a continent.", "claims": [{"claim": "Australia is a country.", "type": "factual"}, {"claim": "Australia is a continent.", "type": "factual"}]},
    {"text": "Brazil is the largest country in South America.", "claims": [{"claim": "Brazil is the largest country in South America.", "type": "factual"}]},
    {"text": "The Sahara is the largest hot desert in the world.", "claims": [{"claim": "The Sahara is the largest hot desert in the world.", "type": "factual"}]},
    {"text": "Canada has the longest coastline in the world.", "claims": [{"claim": "Canada has the longest coastline in the world.", "type": "factual"}]},
    {"text": "The Pacific Ocean is the largest ocean on Earth.", "claims": [{"claim": "The Pacific Ocean is the largest ocean on Earth.", "type": "factual"}]},
    {"text": "Antarctica is the coldest continent.", "claims": [{"claim": "Antarctica is the coldest continent.", "type": "factual"}]},
    {"text": "London is located on the River Thames.", "claims": [{"claim": "London is located on the River Thames.", "type": "factual"}]},

    # --- General/Complex ---
    {"text": "A leap year has 366 days instead of 365.", "claims": [{"claim": "A leap year has 366 days.", "type": "factual"}]},
    {"text": "Gold is a chemical element with symbol Au.", "claims": [{"claim": "Gold is a chemical element.", "type": "factual"}, {"claim": "The symbol for Gold is Au.", "type": "factual"}]},
    {"text": "Shakespeare wrote Romeo and Juliet.", "claims": [{"claim": "Shakespeare wrote Romeo and Juliet.", "type": "factual"}]},
    {"text": "The Mona Lisa was painted by Leonardo da Vinci.", "claims": [{"claim": "The Mona Lisa was painted by Leonardo da Vinci.", "type": "factual"}]},
    {"text": "E=mc^2 is a formula derived by Albert Einstein.", "claims": [{"claim": "E=mc^2 is a formula derived by Albert Einstein.", "type": "factual"}]},
    {"text": "Penguins are flightless birds found in the Southern Hemisphere.", "claims": [{"claim": "Penguins are flightless birds.", "type": "factual"}, {"claim": "Penguins are found in the Southern Hemisphere.", "type": "factual"}]},
    {"text": "Diamond is the hardest natural substance.", "claims": [{"claim": "Diamond is the hardest natural substance.", "type": "factual"}]},
    {"text": "The human heart pumps blood through the circulatory system.", "claims": [{"claim": "The human heart pumps blood through the circulatory system.", "type": "factual"}]},
    {"text": "Oxygen is essential for human respiration.", "claims": [{"claim": "Oxygen is essential for human respiration.", "type": "factual"}]},
    {"text": "The Euro is the official currency of the Eurozone.", "claims": [{"claim": "The Euro is the official currency of the Eurozone.", "type": "factual"}]}
]

def format_dataset(data):
    """Convert raw examples into seq2seq input/target string pairs.

    Args:
        data: Iterable of dicts, each with a "text" entry and a
            JSON-serialisable "claims" entry.

    Returns:
        A list with one dict per example, holding:
          - "input_text": the passage prefixed with "extract claims: "
          - "target_text": the claims serialised with ``json.dumps``
    """
    return [
        {
            "input_text": f"extract claims: {example['text']}",
            "target_text": json.dumps(example['claims']),
        }
        for example in data
    ]

# Build the Hugging Face dataset and hold out 10% of the examples for evaluation.
# The split is seeded so that Restart & Run All always yields the same
# train/test partition (an unseeded split reshuffles differently on every run).
dataset = Dataset.from_list(format_dataset(data))
dataset = dataset.train_test_split(test_size=0.1, seed=42)
print(f"Training examples: {len(dataset['train'])}")
print(f"Test examples: {len(dataset['test'])}")

## 3. Preprocessing & Tokenization

In [None]:
model_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# The training targets are JSON, so "{" and "}" must survive an encode/decode
# round trip. The stock T5 SentencePiece vocabulary is reported not to contain
# curly braces; when that is the case they encode to <unk> and are dropped on
# decode with skip_special_tokens=True, so the output can never parse as JSON.
# add_tokens() only registers tokens that are missing from the vocabulary, so
# this call is harmless if the braces are already present.
tokenizer.add_tokens(["{", "}"])

# Grow the embedding matrix only if the vocabulary now exceeds it; if the
# checkpoint already has spare rows, leave its size untouched.
if len(tokenizer) > model.get_input_embeddings().num_embeddings:
    model.resize_token_embeddings(len(tokenizer))

model = model.to(device)

def preprocess_function(examples):
    """Tokenize a batch of examples for seq2seq training.

    Args:
        examples: Batch (dict of lists) with "input_text" and "target_text"
            columns, as produced by ``dataset.map(..., batched=True)``.

    Returns:
        The tokenizer output for the inputs, with the tokenized targets stored
        under "labels". Inputs and targets are both truncated to 512 tokens.
    """
    # `text_target=` tokenizes the targets in the same call and stores their
    # ids under "labels". It replaces the deprecated
    # `tokenizer.as_target_tokenizer()` context manager.
    return tokenizer(
        examples["input_text"],
        text_target=examples["target_text"],
        max_length=512,
        truncation=True,
    )

# Tokenize both splits; batched=True hands preprocess_function lists of examples.
tokenized_datasets = dataset.map(
    function=preprocess_function,
    batched=True,
)

## 4. Training

In [None]:
def _supported_kwarg(target, preferred, legacy):
    """Return `preferred` if `target` accepts it as a parameter, else `legacy`."""
    return preferred if preferred in inspect.signature(target).parameters else legacy


# Two keyword arguments were renamed across transformers releases:
#   evaluation_strategy -> eval_strategy      (TrainingArguments)
#   tokenizer           -> processing_class   (Trainer)
# The old names are deprecated and may be rejected by newer releases, while
# older releases only know the old names. Pick whichever the installed
# version accepts so the cell runs either way.
eval_strategy_kwarg = _supported_kwarg(
    Seq2SeqTrainingArguments, "eval_strategy", "evaluation_strategy"
)
tokenizer_kwarg = _supported_kwarg(Seq2SeqTrainer, "processing_class", "tokenizer")

training_args = Seq2SeqTrainingArguments(
    output_dir="./model_output",
    learning_rate=3e-4,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    weight_decay=0.01,
    save_total_limit=2,
    num_train_epochs=15,
    predict_with_generate=True,
    logging_steps=10,
    report_to="none",
    **{eval_strategy_kwarg: "epoch"},  # evaluate on the held-out split after every epoch
)

# Pads inputs and labels dynamically per batch.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    data_collator=data_collator,
    **{tokenizer_kwarg: tokenizer},
)

trainer.train()

## 5. Evaluation & Inference

In [None]:
def extract_claims(text):
    """Run the fine-tuned model on a passage and return its extracted claims.

    Args:
        text: The passage to extract claims from.

    Returns:
        The parsed JSON value if the generated text is valid JSON; otherwise
        the raw generated string, so the caller can inspect what the model
        actually produced.
    """
    input_text = f"extract claims: {text}"
    # Truncate to the same 512-token limit used during preprocessing, and place
    # the tensors on the device the model actually lives on.
    inputs = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    ).to(model.device)

    # Make sure dropout is disabled even if the model was left in training mode.
    model.eval()
    with torch.no_grad():
        # **inputs forwards the attention mask together with the input ids.
        outputs = model.generate(
            **inputs,
            max_length=512,
            num_beams=4,
            length_penalty=1.0
        )

    claim_json = tokenizer.decode(outputs[0], skip_special_tokens=True)
    try:
        return json.loads(claim_json)
    except json.JSONDecodeError:
        # Deliberate best-effort: return the raw text instead of raising.
        return claim_json

# Quick check on a passage that does not appear in the training data.
test_text = "Mars is the fourth planet from the Sun. It is known as the Red Planet."
extracted = extract_claims(test_text)
print("Input:", test_text)
print("Output:", json.dumps(extracted, indent=2))

In [None]:
# Write the model weights and the tokenizer files to the same directory so
# they can be reloaded together from that path.
save_path = "./saved_claim_extractor_model"
for artifact in (model, tokenizer):
    artifact.save_pretrained(save_path)
print(f"Model saved to {save_path}")