<a href="https://colab.research.google.com/github/aruntakhur/LLMs/blob/main/Fine_Tune_CoT_T5_ChilleD_SVAMP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 🧠 Fine-Tune T5 with Chain-of-Thought (CoT) Reasoning
This Colab notebook fine-tunes `flan-t5-small` on Chain-of-Thought reasoning using a subset of the **GSM8K** dataset from Hugging Face Datasets.

In [None]:

# ✅ Install required libraries
!pip install transformers datasets peft accelerate --quiet


In [None]:

# ✅ Import libraries
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Trainer, TrainingArguments
from datasets import load_dataset
from peft import LoraConfig, get_peft_model


In [None]:

# ✅ Load tokenizer and model (Flan-T5)
model_name = "google/flan-t5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)


In [None]:

# ✅ Apply LoRA configuration
peft_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=["q", "v"],
    lora_dropout=0.05,
    bias="none",
    task_type="SEQ_2_SEQ_LM"
)
model = get_peft_model(model, peft_config)


In [None]:
!pip install -U datasets fsspec

Collecting fsspec
  Using cached fsspec-2025.5.1-py3-none-any.whl.metadata (11 kB)


In [14]:
from datasets import load_dataset

# Load directly from Hugging Face repo
# dataset = load_dataset("gsm8k", "main", trust_remote_code=True)
# train_ds = dataset["train"].select(range(100))  # small subset for demo

# dataset = load_dataset("svamp")
# train_ds = dataset["train"].select(range(100))
# ✅ Load the SVAMP dataset for CoT training
dataset = load_dataset("ChilleD/SVAMP")
train_ds = dataset["train"].select(range(200))  # small subset for demo



README.md:   0%|          | 0.00/675 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/111k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/54.8k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/700 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/300 [00:00<?, ? examples/s]

In [17]:
print(dataset["train"][0])  # Show the first sample to inspect keys

{'ID': 'chal-777', 'Body': "There are 87 oranges and 290 bananas in Philip's collection. If the bananas are organized into 2 groups and oranges are organized into 93 groups", 'Question': 'How big is each group of bananas?', 'Equation': '( 290.0 / 2.0 )', 'Answer': '145', 'Type': 'Common-Division', 'question_concat': "There are 87 oranges and 290 bananas in Philip's collection. If the bananas are organized into 2 groups and oranges are organized into 93 groups How big is each group of bananas?"}


In [18]:
def format_example(ex):
    question = ex["question_concat"].strip()
    equation = ex.get("Equation", "").strip()
    answer = ex.get("Answer", "").strip()

    # Try to extract numbers and operators from the equation to build a rationale
    try:
        numbers = [float(tok) for tok in equation.replace("(", "").replace(")", "").split() if tok.replace('.', '', 1).isdigit()]
        if "/" in equation:
            rationale = f"There are {int(numbers[0])} items divided into {int(numbers[1])} groups. So {int(numbers[0])} ÷ {int(numbers[1])} = {answer}."
        elif "*" in equation:
            rationale = f"{int(numbers[0])} items each repeated {int(numbers[1])} times gives {int(numbers[0])} × {int(numbers[1])} = {answer}."
        elif "+" in equation:
            rationale = f"Adding the numbers: {int(numbers[0])} + {int(numbers[1])} = {answer}."
        elif "-" in equation:
            rationale = f"Subtracting the numbers: {int(numbers[0])} - {int(numbers[1])} = {answer}."
        else:
            rationale = f"We solve the equation {equation} to get the answer {answer}."
    except:
        rationale = f"We solve the equation {equation} to get the answer {answer}."

    return {
        "input": f"Q: {question}\nA: Let's think step by step.",
        "target": f"{rationale} Therefore, the answer is {answer}."
    }


In [19]:
train_ds = dataset["train"].map(format_example)

Map:   0%|          | 0/700 [00:00<?, ? examples/s]

In [20]:

# ✅ Tokenize the dataset
def tokenize(batch):
    input_encodings = tokenizer(batch["input"], truncation=True, padding="max_length", max_length=256)
    target_encodings = tokenizer(batch["target"], truncation=True, padding="max_length", max_length=256)
    input_encodings["labels"] = target_encodings["input_ids"]
    return input_encodings

train_ds = train_ds.map(tokenize, batched=True)


Map:   0%|          | 0/700 [00:00<?, ? examples/s]

In [21]:

# ✅ Training configuration
training_args = TrainingArguments(
    output_dir="./cot-t5-gsm8k",
    per_device_train_batch_size=4,
    num_train_epochs=1,
    logging_steps=5,
    save_steps=20,
    save_total_limit=2,
    fp16=True,
    report_to="none"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    tokenizer=tokenizer
)

trainer.train()


  trainer = Trainer(
No label_names provided for model class `PeftModelForSeq2SeqLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.
Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Step,Training Loss
5,51.3023
10,51.0256
15,50.7498
20,50.4219
25,49.1372
30,49.5471
35,49.2718
40,48.5114
45,49.2714
50,47.7721


TrainOutput(global_step=175, training_loss=45.869085693359374, metrics={'train_runtime': 1215.465, 'train_samples_per_second': 0.576, 'train_steps_per_second': 0.144, 'total_flos': 65431614259200.0, 'train_loss': 45.869085693359374, 'epoch': 1.0})

In [30]:

# ✅ Inference (test on new question)
input_text = "Q: If you have 10 candies and eat 4, how many are left?\nA: Let's think step by step."
inputs = tokenizer(input_text, return_tensors="pt").to(model.device)

output = model.generate(**inputs, max_length=100)
print(tokenizer.decode(output[0], skip_special_tokens=True))


10 candies and eat 4 are left. So, the total number of candies left is 10 + 4 = 20. The answer: 20.
