# Train Correction Engine

This notebook fine-tunes a Seq2Seq model (T5) to **rewrite** incorrect claims using evidence that contradicts them.

- Input: `correct claim: <wrong claim> evidence: <true evidence>`
- Target: `<corrected claim>`


In [None]:
!pip install transformers datasets evaluate torch sentencepiece

In [None]:
import inspect
import json

from datasets import Dataset
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)

## 1. Prepare Synthetic Dataset

In [None]:
# Hand-written toy examples. Each record holds a false claim, the evidence
# that contradicts it, and the corrected claim the model should produce.
data = [
    dict(
        claim="The Earth is flat.",
        evidence="The Earth is an oblate spheroid, meaning it is mostly spherical but slightly flattened at the poles.",
        correction="The Earth is an oblate spheroid.",
    ),
    dict(
        claim="Python was released in 2005.",
        evidence="Python was first released by Guido van Rossum in 1991.",
        correction="Python was released in 1991.",
    ),
    dict(
        claim="Water boils at 50 degrees Celsius.",
        evidence="The boiling point of water is 100 degrees Celsius at standard atmospheric pressure.",
        correction="Water boils at 100 degrees Celsius.",
    ),
    dict(
        claim="Humans have 4 hearts.",
        evidence="The human heart has 4 chambers, but humans only have one heart.",
        correction="Humans have one heart with four chambers.",
    ),
]

def format_data(data):
    """Turn claim/evidence/correction records into model-ready text pairs.

    Args:
        data: Iterable of dicts with the keys "claim", "evidence" and
            "correction".

    Returns:
        A list with one dict per record. "input_text" is the prompt
        "correct claim: <claim> evidence: <evidence>" and "target_text" is
        the record's correction.
    """
    return [
        {
            "input_text": f"correct claim: {record['claim']} evidence: {record['evidence']}",
            "target_text": record["correction"],
        }
        for record in data
    ]

# Build the dataset and hold out 10% of it for evaluation. The fixed seed
# makes the shuffle behind the split deterministic, so "Restart & Run All"
# always trains and evaluates on the same examples.
dataset = Dataset.from_list(format_data(data)).train_test_split(
    test_size=0.1,
    seed=42,
)

## 2. Model & Tokenizer

In [None]:
# Checkpoint identifier shared by the tokenizer and the model so that the
# vocabulary and the weights come from the same pretrained release.
model_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
model = T5ForConditionalGeneration.from_pretrained(
    pretrained_model_name_or_path=model_name
)

def preprocess_function(examples):
    """Tokenize a batch of input/target text pairs for seq2seq training.

    Args:
        examples: Batch mapping with the columns "input_text" and
            "target_text", each a list of strings.

    Returns:
        The tokenizer output for the inputs, with the tokenized targets
        stored under "labels".
    """
    # `text_target` tokenizes the targets in the same call and stores their
    # ids under "labels"; it replaces the deprecated
    # `tokenizer.as_target_tokenizer()` context manager. `max_length` and
    # `truncation` are applied to both the inputs and the targets.
    model_inputs = tokenizer(
        examples["input_text"],
        text_target=examples["target_text"],
        max_length=512,
        truncation=True,
    )
    return model_inputs

# Tokenize every split; batched mode hands preprocess_function lists of
# examples instead of one example at a time.
tokenized_datasets = dataset.map(
    function=preprocess_function,
    batched=True,
)

## 3. Training

In [None]:
# Two keyword names differ between transformers releases:
#   * `evaluation_strategy` was renamed to `eval_strategy`
#   * the trainer's `tokenizer` argument was renamed to `processing_class`
# Pick whichever spelling the installed version accepts, preferring the newer
# one, so this cell runs regardless of which release pip resolved.
_args_params = inspect.signature(Seq2SeqTrainingArguments.__init__).parameters
_eval_strategy_kwarg = (
    "eval_strategy" if "eval_strategy" in _args_params else "evaluation_strategy"
)
_trainer_params = inspect.signature(Seq2SeqTrainer.__init__).parameters
_tokenizer_kwarg = (
    "processing_class" if "processing_class" in _trainer_params else "tokenizer"
)

training_args = Seq2SeqTrainingArguments(
    output_dir="./correction_model_output",
    learning_rate=2e-4,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=15,
    # Decode with generate() during evaluation instead of teacher forcing.
    predict_with_generate=True,
    # Evaluate on the held-out split at the end of every epoch.
    **{_eval_strategy_kwarg: "epoch"},
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["test"],
    # Pads inputs and labels per batch; the model is passed alongside the
    # tokenizer, as in the original cell.
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
    **{_tokenizer_kwarg: tokenizer},
)

trainer.train()

## 4. Inference Test

In [None]:
def correct_claim(claim, evidence):
    """Generate a corrected version of `claim` using `evidence`.

    Args:
        claim: The (presumably incorrect) claim text.
        evidence: Evidence text the correction should be based on.

    Returns:
        The decoded model output as a string, with special tokens removed.
    """
    # Same prompt layout as the training inputs built by format_data.
    input_text = f"correct claim: {claim} evidence: {evidence}"
    # Truncate to the same 512-token limit used in preprocess_function so
    # inference inputs match what the model saw during fine-tuning.
    encoded = tokenizer(
        input_text,
        return_tensors="pt",
        max_length=512,
        truncation=True,
    ).to(model.device)
    # Unpack the full encoding so generate() receives the attention mask
    # together with the input ids.
    outputs = model.generate(**encoded, max_length=128)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)

# Quick smoke test on a claim that is not part of the training examples.
sample_claim = "Sharks are mammals."
sample_evidence = "Sharks are a group of elasmobranch fish characterized by a cartilaginous skeleton."
print(correct_claim(sample_claim, sample_evidence))

In [None]:
save_path = "./saved_correction_model"
# Write the model first, then the tokenizer, into the same directory so a
# single path holds both.
for artifact in (model, tokenizer):
    artifact.save_pretrained(save_path)
print(f"Model saved to {save_path}")