# 🧠 Chain-of-Thought Gradients

This tutorial demonstrates how `af.pullback` carries feedback on the final answer of a multi-step reasoning pipeline back through each step, turning it into hints for improving the input.

## Setup (Colab only)

Uncomment and run the following cell if running in Google Colab:

```python
# !pip install autoform
# import os
# os.environ["OPENAI_API_KEY"] = "your-key-here"
```

```python
import autoform as af

MODEL = "openai/gpt-4o"  # or "ollama/llama3.2:3b" for a local model
```

## 1. The Problem

Complex reasoning often fails silently. When a chain-of-thought answer is wrong, which step caused it? Without visibility into intermediate steps, debugging is guesswork.

## 2. Define Output Structure

```python
class Answer(af.Struct):
    """Structured output of the final reasoning step."""

    reasoning: str
    answer: str
```
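
The last step returns an `Answer` rather than free text. As a rough, plain-Python illustration of what structured output involves, the sketch below assumes the model replies with a JSON object and checks that every declared field is present and is a string. `AnswerSketch` and `parse_struct` are hypothetical names used only here; they are not how autoform implements `af.struct_lm_call`, which may work quite differently.

```python
import json
from dataclasses import dataclass, fields


@dataclass
class AnswerSketch:
    """Plain-dataclass stand-in for the `Answer` struct (hypothetical)."""
    reasoning: str
    answer: str


def parse_struct(raw: str, cls):
    """Parse a JSON reply into `cls`, requiring each declared field to be a string.

    Assumption: the model was asked to reply with a JSON object.
    """
    data = json.loads(raw)
    values = {}
    for f in fields(cls):
        value = data.get(f.name)
        if not isinstance(value, str):
            raise ValueError(f"field {f.name!r} is missing or not a string")
        values[f.name] = value
    return cls(**values)


reply = '{"reasoning": "10% of 80 is 8 and 5% is 4, so 15% is 12.", "answer": "12"}'
print(parse_struct(reply, AnswerSketch))
```

Validating against the declared fields means a malformed reply fails loudly at the final step instead of flowing silently into later code.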

## 3. Multi-Step Reasoning with Checkpoints

We tag each reasoning step with `af.checkpoint` under the `reasoning` collection so it can be inspected later:

```python
def chain_of_thought(question: str) -> Answer:
    """Multi-step reasoning with checkpoints at each step."""

    # Step 1: Break down the problem
    step1_prompt = af.format("Break down this question into sub-problems:\n{}", question)
    msgs1 = [{"role": "user", "content": step1_prompt}]
    step1 = af.lm_call(msgs1, model=MODEL)
    step1 = af.checkpoint(step1, key="breakdown", collection="reasoning")

    # Step 2: Solve each sub-problem (sees only the breakdown)
    step2_prompt = af.format("Given these sub-problems:\n{}\n\nSolve each one:", step1)
    msgs2 = [{"role": "user", "content": step2_prompt}]
    step2 = af.lm_call(msgs2, model=MODEL)
    step2 = af.checkpoint(step2, key="solutions", collection="reasoning")

    # Step 3: Synthesize final answer from both earlier steps
    step3_prompt = af.format(
        "Sub-problems:\n{}\n\nSolutions:\n{}\n\nProvide a final answer with reasoning:",
        step1,
        step2,
    )
    msgs3 = [{"role": "user", "content": step3_prompt}]
    # The final step returns a structured Answer instead of free text
    return af.struct_lm_call(msgs3, model=MODEL, struct=Answer)
```
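
Step 3's prompt embeds both earlier outputs, and step 2's prompt embeds step 1's, so feedback on the final answer has a path back to every step. The plain-Python sketch below replays the same prompt templates with a hypothetical `fake_lm` stub in place of `af.lm_call` (no model is called) and recovers those dependencies by checking which outputs appear inside each prompt.

```python
def fake_lm(name: str, prompt: str) -> str:
    """Hypothetical model stub: ignores `prompt` and returns a tag naming the step."""
    return f"[{name} output]"


def chain_of_thought_sketch(question: str) -> tuple[dict, dict]:
    """Replay the three prompt templates from chain_of_thought with the stub model."""
    prompts, outputs = {}, {}
    prompts["breakdown"] = f"Break down this question into sub-problems:\n{question}"
    outputs["breakdown"] = fake_lm("breakdown", prompts["breakdown"])
    prompts["solutions"] = f"Given these sub-problems:\n{outputs['breakdown']}\n\nSolve each one:"
    outputs["solutions"] = fake_lm("solutions", prompts["solutions"])
    prompts["final"] = (
        f"Sub-problems:\n{outputs['breakdown']}\n\n"
        f"Solutions:\n{outputs['solutions']}\n\n"
        "Provide a final answer with reasoning:"
    )
    outputs["final"] = fake_lm("final", prompts["final"])
    return prompts, outputs


def upstream(prompts: dict, outputs: dict) -> dict:
    """For each step, list the other steps whose outputs appear verbatim in its prompt."""
    return {
        step: [src for src, out in outputs.items() if src != step and out in prompt]
        for step, prompt in prompts.items()
    }


cot_prompts, cot_outputs = chain_of_thought_sketch("What is 15% of 80?")
print(upstream(cot_prompts, cot_outputs))
# {'breakdown': [], 'solutions': ['breakdown'], 'final': ['breakdown', 'solutions']}
```

The breakdown influences the answer along two paths (directly, and through the solutions), which is why a critique of the final answer can implicate it.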

## 4. Build the IR

```python
# Placeholder input: it is only used to trace chain_of_thought into an IR
dummy = "..."
ir = af.trace(chain_of_thought)(dummy)
print(ir)
```
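
Tracing in this style typically works by running the function on placeholder values and recording each operation instead of executing it. The toy tracer below makes that concrete: every `format` or `lm_call` appends an entry and returns a symbolic variable, so the result is a list of operations rather than an answer. `Var`, `Tracer`, and `traced_pipeline` are hypothetical and far simpler than autoform's IR; the printed format is ours, not the output of `print(ir)`.

```python
from dataclasses import dataclass, field


@dataclass
class Var:
    """Symbolic placeholder for a value that does not exist yet while tracing."""
    name: str


@dataclass
class Tracer:
    """Records operations instead of executing them (toy IR, not autoform's)."""
    ops: list = field(default_factory=list)

    def _emit(self, op: str, *args) -> Var:
        out = Var(f"v{len(self.ops)}")
        arg_names = [a.name if isinstance(a, Var) else repr(a) for a in args]
        self.ops.append((out.name, op, arg_names))
        return out

    def format(self, template: str, *args) -> Var:
        return self._emit("format", template, *args)

    def lm_call(self, prompt: Var) -> Var:
        return self._emit("lm_call", prompt)


def traced_pipeline(t: Tracer, question: Var) -> Var:
    # Same shape as chain_of_thought: step 3 reads both earlier results.
    step1 = t.lm_call(t.format("Break down: {}", question))
    step2 = t.lm_call(t.format("Solve: {}", step1))
    return t.lm_call(t.format("Answer: {} {}", step1, step2))


tracer = Tracer()
final_var = traced_pipeline(tracer, Var("question"))
for name, op, args in tracer.ops:
    print(f"{name} = {op}({', '.join(args)})")
```

Because the recorded program is data, later transformations (collecting, pullback, batching) can walk it without re-reading the Python source.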

## 5. Run with Collect

Run the traced program and capture every value checkpointed in the `reasoning` collection:

```python
result, captured = af.collect(ir, collection="reasoning")("What is 15% of 80?")

print("Answer:", result.answer)
print("\nCaptured steps:", list(captured.keys()))
print("\nBreakdown:", captured["breakdown"])
```

## 6. Pullback: Get Improvement Hints

Given feedback on the output, pullback suggests how to improve the **input**:

```python
pb_ir = af.pullback(ir)

# Feedback has the same shape as the output: one critique per Answer field
critique = Answer(
    reasoning="The breakdown was good but solutions were too verbose",
    answer="correct but explained too much",
)

# Inputs are (question, feedback on the output); results are (output, gradient)
output, gradient = af.call(pb_ir)(("What is 15% of 80?", critique))

print("Answer:", output.answer)
print("\nGradient (how to improve input):")
print(gradient)
```
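
In automatic differentiation, a pullback carries feedback on an output back to the inputs, one operation at a time in reverse order. The sketch below illustrates only that ordering, assuming each step has a `forward` rule and a `backward` rule that rewrites feedback on its output into feedback on its input. `Step`, `make_step`, and `pullback_sketch` are hypothetical, and the string templates stand in for the model calls a real textual-gradient system would make.

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class Step:
    name: str
    forward: Callable[[str], str]   # step input -> step output
    backward: Callable[[str], str]  # feedback on output -> feedback on input


def make_step(name: str) -> Step:
    # Hypothetical stand-ins: a real pipeline would call a model in both directions.
    return Step(
        name=name,
        forward=lambda text: f"{name}[{text}]",
        backward=lambda feedback: f"{name} <- {feedback}",
    )


def pullback_sketch(steps: list[Step], x: str, output_feedback: str) -> tuple[str, str]:
    """Run the steps forward, then push feedback through them in reverse order."""
    for step in steps:
        x = step.forward(x)
    feedback = output_feedback
    for step in reversed(steps):
        feedback = step.backward(feedback)
    return x, feedback


steps = [make_step("breakdown"), make_step("solutions"), make_step("final")]
sketch_output, sketch_gradient = pullback_sketch(steps, "What is 15% of 80?", "too verbose")
print(sketch_output)    # final[solutions[breakdown[What is 15% of 80?]]]
print(sketch_gradient)  # breakdown <- solutions <- final <- too verbose
```

Like the notebook's `pb_ir`, the sketch returns both the forward output and the feedback on the input; the reversed loop is what lets each step's contribution be attributed in turn.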

## 7. Batched Gradients

Get improvement hints for multiple questions at once:

```python
# True marks batched inputs: the question and both fields of the critique
batched_pb = af.batch(
    pb_ir,
    in_axes=(True, Answer.model_construct(reasoning=True, answer=True)),
)

questions = [
    "What is 15% of 80?",
    "How many days in a leap year?",
    "What is the capital of Japan?",
]

critiques = Answer.model_construct(
    reasoning=["too verbose", "perfect", "needs more context"],
    answer=["correct", "correct", "correct"],
)

outputs, grads = af.call(batched_pb)((questions, critiques))

for i, (q, g) in enumerate(zip(questions, grads)):
    print(f"\nQ{i + 1}: {q}")
    print(f"Hint: {g}")
```
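
`in_axes` here plays a role similar to `in_axes` in JAX's `jax.vmap`, although JAX uses an integer axis or `None` rather than booleans. A plain-Python sketch of the boolean convention, with `True` meaning "one value per example" and `False` meaning "share this value across the batch" (our assumption), dicts standing in for struct fields, and tuple results regrouped per output so the call unpacks like `outputs, grads` above. `batch_sketch` and `fake_pullback` are hypothetical.

```python
from typing import Any, Callable


def _batch_sizes(arg: Any, axis: Any) -> set:
    """Lengths of every leaf marked True in `axis` (dicts are handled per field)."""
    if isinstance(axis, dict):
        return set().union(*(_batch_sizes(arg[k], a) for k, a in axis.items()))
    return {len(arg)} if axis else set()


def _take(arg: Any, axis: Any, i: int) -> Any:
    """Select example `i` from batched leaves; pass unbatched leaves through."""
    if isinstance(axis, dict):
        return {k: _take(arg[k], a, i) for k, a in axis.items()}
    return arg[i] if axis else arg


def batch_sketch(fn: Callable, in_axes: tuple) -> Callable:
    """Run `fn` once per example; regroup tuple results into one list per output."""
    def batched(*args):
        sizes = set().union(*(_batch_sizes(a, ax) for a, ax in zip(args, in_axes)))
        if len(sizes) != 1:
            raise ValueError(f"batched inputs must share one length, got {sizes}")
        (n,) = sizes
        results = [fn(*(_take(a, ax, i) for a, ax in zip(args, in_axes))) for i in range(n)]
        if results and isinstance(results[0], tuple):
            return tuple(list(column) for column in zip(*results))
        return results
    return batched


def fake_pullback(question: str, critique: dict) -> tuple[str, str]:
    # Hypothetical stand-in for the pullback program: returns (output, gradient).
    return f"answer to {question!r}", f"revise {question!r}: {critique['reasoning']}"


demo_batched = batch_sketch(fake_pullback, in_axes=(True, {"reasoning": True, "answer": True}))
demo_questions = ["What is 15% of 80?", "How many days in a leap year?"]
demo_critiques = {"reasoning": ["too verbose", "perfect"], "answer": ["correct", "correct"]}

demo_outputs, demo_grads = demo_batched(demo_questions, demo_critiques)
for q, g in zip(demo_questions, demo_grads):
    print(f"{q} -> {g}")
```

Checking that all batched leaves share one length catches a mismatched critique list before any model call is made.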

## Summary

1. **Checkpoint** each reasoning step for visibility
2. **Collect** captures intermediate values for debugging
3. **Pullback** provides improvement hints from output feedback
4. **Batch** processes multiple inputs with their critiques

Use this pattern for any multi-step pipeline where you need to understand _which step_ to improve.