
# DSPy + GSM8K "Under the Hood" Demo (Ollama, Real Dataset)

This notebook is a **larger-scale but still demo-friendly** version of the earlier toy example.

What it does:

* Uses the real **GSM8K** dataset (via Hugging Face `datasets`).
* Uses a **smaller Ollama model** so you can actually see errors and improvements.
* Shows **raw baseline behavior** of the DSPy program on many GSM8K problems.
* Shows **manual "backtracking"-style traces**, where each attempt is kept or rejected based on correctness.
* Uses a real DSPy teleprompter (`BootstrapFewShotWithRandomSearch`) to optimize on a GSM8K subset.
* Compares **before vs after** accuracy and prints **example traces** of successes and failures.



## 1. Install dependencies

You need:

* `dspy-ai`
* `datasets` (Hugging Face)

Uncomment and run the cell below if you have not installed them.


In [None]:

# !pip install -qU dspy-ai datasets


In [None]:

import random
from pprint import pprint

import dspy
from datasets import load_dataset



## 2. Configure DSPy with a smaller Ollama model

To make improvements visible and keep things reasonably fast, we will use a **smaller model** than `qwen3:30b`.

For example, you could try:

* `phi3:mini`
* or `qwen2.5:3b-instruct`

This cell assumes you have already run the following in a terminal:

```bash
ollama pull phi3:mini
```

and that Ollama is running on `localhost:11434`.
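Before configuring DSPy, it can help to confirm that the server is actually reachable and that the model has been pulled. The sketch below queries Ollama's `/api/tags` endpoint (which lists locally available models) using only the standard library. The helper name `list_ollama_models` is hypothetical, and it assumes the default address above.

```python
import json
import urllib.request


def list_ollama_models(base_url="http://localhost:11434", timeout=2.0):
    """Return the model tags a local Ollama server reports, or None if it is unreachable."""
    try:
        with urllib.request.urlopen(f"{base_url}/api/tags", timeout=timeout) as resp:
            payload = json.load(resp)
    except (OSError, ValueError):
        # OSError covers connection refused / timeouts (URLError is a subclass);
        # ValueError covers a response that is not valid JSON.
        return None
    # Each entry under "models" carries a "name" such as "phi3:mini".
    return [m.get("name", "") for m in payload.get("models", [])]


models = list_ollama_models()
if models is None:
    print("Ollama does not seem to be running on localhost:11434")
elif not any(name.startswith("phi3:mini") for name in models):
    print("phi3:mini not found; available models:", models)
else:
    print("phi3:mini is available")
```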


In [None]:

# Configure DSPy to use Ollama with a smaller model (e.g., phi3:mini).
# You can swap 'phi3:mini' for any other local Ollama model tag you like.
ollama_model = dspy.LM(
    model='ollama_chat/phi3:mini',  # 'ollama_chat/' uses Ollama's chat API; e.g. 'ollama_chat/qwen2.5:3b-instruct'
    api_base='http://localhost:11434',
    api_key='',        # Ollama does not require an API key
    temperature=0.6,   # non-zero so we can see some variation / errors (set on the LM, not in dspy.configure)
    cache=False,       # DSPy caches LM responses by default; disable it so repeated
                       # identical calls (section 7) can return different samples
)

dspy.configure(lm=ollama_model)



## 3. Load GSM8K from Hugging Face

We use the official `openai/gsm8k` dataset. This cell:

* downloads the dataset (first run only)
* creates small **train** and **dev** subsets for this demo (the dev subset is sampled from the GSM8K **test** split)


In [None]:

gsm8k = load_dataset("openai/gsm8k", "main")

print(gsm8k)
print("Train size:", len(gsm8k["train"]))
print("Test size: ", len(gsm8k["test"]))
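Each gold `answer` combines a worked rationale with the final result after `####`, and the rationales contain calculator annotations of the form `<<48/2=24>>`. A small helper like the hypothetical `split_gsm8k_answer` below separates the two parts; the sample string is written in the GSM8K format for illustration.

```python
import re


def split_gsm8k_answer(answer: str):
    """Split a GSM8K gold answer into (rationale, final_answer).

    Calculator annotations such as '<<48/2=24>>' are removed from the rationale.
    Returns final_answer=None if the '####' marker is missing.
    """
    rationale, sep, final = answer.rpartition("####")
    if not sep:
        return re.sub(r"<<[^>]*>>", "", answer).strip(), None
    rationale = re.sub(r"<<[^>]*>>", "", rationale).strip()
    return rationale, final.strip()


sample = (
    "Natalia sold 48/2 = <<48/2=24>>24 clips in May.\n"
    "Natalia sold 48+24 = <<48+24=72>>72 clips altogether in April and May.\n"
    "#### 72"
)
rationale, final = split_gsm8k_answer(sample)
print(rationale)
print("final:", final)
```

On the real data you could call `split_gsm8k_answer(gsm8k["train"][0]["answer"])`.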


In [None]:

# Build DSPy-style examples

def make_example(row):
    # GSM8K has 'question' and 'answer' where 'answer' includes a rationale and '#### final_number'
    return dspy.Example(
        question=row["question"],
        answer=row["answer"],
    ).with_inputs("question")

# For the demo, we use a subset so that you can run live without waiting forever.
# You can scale these up later.
random.seed(0)

full_train = gsm8k["train"]
full_test = gsm8k["test"]

# Choose a subset for training / compilation
train_subset_size = 200
dev_subset_size = 100

train_indices = random.sample(range(len(full_train)), train_subset_size)
dev_indices = random.sample(range(len(full_test)), dev_subset_size)

train_examples = [make_example(full_train[i]) for i in train_indices]
dev_examples = [make_example(full_test[i]) for i in dev_indices]

len(train_examples), len(dev_examples)



## 4. Define a GSM8K-style metric

Gold GSM8K answers are step-by-step rationales that **end with** a line like:

```text
#### 42
```
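Model predictions rarely follow that format exactly: a small model may answer `$1,200` or `12.50`. The sketch below contrasts a naive "last integer" regex with a version that strips thousands separators and keeps decimals. Both function names are hypothetical and the inputs are made-up strings.

```python
import re


def naive_last_int(text):
    """Last digit run in the text (no handling of commas or decimals)."""
    ints = re.findall(r"-?\d+", text)
    return int(ints[-1]) if ints else None


def normalized_last_number(text):
    """Drop thousands separators first, then allow an optional decimal part."""
    nums = re.findall(r"-?\d+(?:\.\d+)?", text.replace(",", ""))
    return float(nums[-1]) if nums else None


for output in ["The total is $1,200.", "Each ticket costs 12.50 dollars.", "72"]:
    print(repr(output), naive_last_int(output), normalized_last_number(output))
```

The naive version reads `$1,200` as 200 and `12.50` as 50. Stripping commas has its own edge case (a list written as `3,4` becomes `34`), which is usually an acceptable trade-off for single-number GSM8K answers.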

The metric takes the number after `####` (falling back to the last number in the text), ignores thousands separators, and compares it numerically to the gold final answer.


In [None]:

import re

def extract_final_number(text):
    """Return the number after '####' if present, else the last number in the text.

    Thousands separators are stripped first, so '1,200' parses as 1200, and
    decimals are kept, so a prediction of '12.0' matches a gold answer of 12.
    """
    if text is None:
        return None
    text = str(text).replace(",", "")
    # Gold GSM8K answers end with '#### <number>'
    m = re.search(r"####\s*\$?\s*(-?\d+(?:\.\d+)?)", text)
    if m:
        return float(m.group(1))
    # Fallback (typical for model outputs): last number anywhere in the text
    nums = re.findall(r"-?\d+(?:\.\d+)?", text)
    return float(nums[-1]) if nums else None

def gsm8k_metric(example, prediction, trace=None):
    gold = extract_final_number(example.answer)
    pred = extract_final_number(getattr(prediction, "answer", ""))
    return int(gold is not None and pred is not None and gold == pred)



## 5. Define a GSM8K DSPy program

We use a simple **Chain-of-Thought** module that maps:

```text
question -> answer
```

This keeps the program easy to reason about while still letting the LM do multi-step reasoning.


In [None]:

class GSM8KCoT(dspy.Module):
    def __init__(self):
        super().__init__()
        self.cot = dspy.ChainOfThought("question -> answer")

    def forward(self, question: str):
        return self.cot(question=question)

base_program = GSM8KCoT()
base_program
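Under the hood, recent DSPy versions (2.5 and later) build `ChainOfThought` by extending the signature with a `reasoning` output field placed before the declared outputs, so the LM writes its reasoning first and the answer second. The toy below mimics only that field bookkeeping on the string form of a signature; `parse_signature` and `chain_of_thought_fields` are hypothetical helpers, not DSPy functions.

```python
def parse_signature(sig: str):
    """Split an 'a, b -> c' signature string into input and output field names."""
    inputs, outputs = sig.split("->")

    def names(part):
        return [field.strip() for field in part.split(",") if field.strip()]

    return names(inputs), names(outputs)


def chain_of_thought_fields(sig: str):
    """Toy bookkeeping only: prepend a 'reasoning' output ahead of the declared outputs."""
    inputs, outputs = parse_signature(sig)
    return inputs, ["reasoning"] + outputs


print(chain_of_thought_fields("question -> answer"))
```

In those versions, a prediction from `base_program` therefore exposes `pred.reasoning` as well as `pred.answer`.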



## 6. Baseline evaluation on a GSM8K subset

This helper runs the program on a dataset and computes average accuracy according to our metric.
It also prints a few example predictions.


In [None]:

def evaluate_program(program, dataset, metric_fn, max_print=5, label=""):
    scores = []
    for i, ex in enumerate(dataset):
        with dspy.context(trace=[]):
            pred = program(question=ex.question)
            trace = dspy.settings.trace.copy()

        score = metric_fn(ex, pred, trace)
        scores.append(score)

        if i < max_print:
            print(f"Example {i} {label}:")
            print("Q:", ex.question)
            print("Predicted answer raw:")
            print(getattr(pred, "answer", ""))
            print("Gold answer raw:")
            print(ex.answer)
            print("Metric score:", score)
            print("-" * 80)

    avg = sum(scores) / max(len(scores), 1)
    print(f"Average metric on {len(dataset)} examples {label}: {avg:.3f}")
    return avg
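A single average hides how uncertain it is on a small subset. As a rough guide, the Wilson score interval below estimates a 95% confidence range for an accuracy of `successes / n`. `wilson_interval` is a hypothetical helper, and the 40-out-of-100 figure is an illustrative input, not a measured result.

```python
import math


def wilson_interval(successes: int, n: int, z: float = 1.96):
    """Approximate 95% Wilson score interval for a binomial proportion."""
    if n == 0:
        return (0.0, 1.0)  # no data: the proportion could be anything
    p = successes / n
    denom = 1 + z**2 / n
    center = (p + z**2 / (2 * n)) / denom
    half = z * math.sqrt(p * (1 - p) / n + z**2 / (4 * n**2)) / denom
    return (max(0.0, center - half), min(1.0, center + half))


low, high = wilson_interval(40, 100)
print(f"40/100 correct -> 95% CI roughly [{low:.3f}, {high:.3f}]")
```

On 100 examples the interval is roughly ±10 points wide, so differences of a few points between runs are hard to distinguish from noise.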


In [None]:

print("Baseline on train subset:")
base_train_acc = evaluate_program(base_program, train_examples, gsm8k_metric, max_print=3, label="(train)")

print("\nBaseline on dev subset:")
base_dev_acc = evaluate_program(base_program, dev_examples, gsm8k_metric, max_print=5, label="(dev)")



## 7. Manual "backtracking" style tracing on real GSM8K examples

Here we:

* pick a few random GSM8K training problems
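* for each, run the program multiple times (with sampling; attempts can only differ if the LM was created with `cache=False`, because DSPy otherwise returns the cached response for a repeated identical call)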
* record the trace for each attempt
* accept a trace only if the final answer is correct

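This imitates a simplified version of what DSPy's `BootstrapFewShot` does: run training examples through the program and keep a trace as a demonstration only when the metric accepts it.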


In [None]:

def run_with_trace_multiple_attempts(program, example, metric_fn, max_attempts=3):
    accepted_trace = None
    print("=" * 100)
    print("Question:")
    print(example.question)
    print("Gold answer (raw):")
    print(example.answer)
    print()

    for attempt in range(1, max_attempts + 1):
        with dspy.context(trace=[]):
            pred = program(question=example.question)
            trace = dspy.settings.trace.copy()

        score = metric_fn(example, pred, trace)

        print(f"Attempt {attempt}:")
        print("Predicted answer (raw):")
        print(getattr(pred, "answer", ""))
        print("Metric score:", score)
        print("Trace:")
        for j, (mod, inputs, outputs) in enumerate(trace):
            print(f"  Step {j}: module={type(mod).__name__}")
            print(f"    inputs:  {inputs}")
            print(f"    outputs keys: {list(outputs.keys())}")
        print()

        if score == 1:
            print("  -> Accepted trace (correct answer).")
            accepted_trace = trace
            break
        else:
            print("  -> Rejected trace (wrong answer).")
            print()

    return accepted_trace

# Pick a few random training examples and run the backtracking demo
demo_indices = random.sample(range(len(train_examples)), 3)
for idx in demo_indices:
    _ = run_with_trace_multiple_attempts(base_program, train_examples[idx], gsm8k_metric, max_attempts=3)
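Stripped of DSPy, the accept/reject loop above is rejection sampling over traces. The self-contained toy below replaces the LM with a hypothetical `noisy_solver` that adds two numbers correctly about half the time, and keeps the first correct output per training example as a demonstration, loosely the role bootstrapping plays for the real optimizer.

```python
import random


def noisy_solver(question, rng):
    """Hypothetical stand-in for the LM program: adds two numbers, right about half the time."""
    a, b = question
    if rng.random() < 0.5:
        answer = a + b
    else:
        answer = a + b + rng.choice([-1, 1])  # off-by-one mistake
    return {"reasoning": f"{a} + {b}", "answer": answer}


def bootstrap_demos(trainset, max_attempts=3, seed=0):
    """Keep the first correct output per example as a demo; skip examples that never pass."""
    rng = random.Random(seed)
    kept = []
    for question, gold in trainset:
        for _ in range(max_attempts):
            output = noisy_solver(question, rng)
            if output["answer"] == gold:  # the "metric": accept only correct traces
                kept.append({"question": question, **output})
                break  # accepted; move on to the next example
    return kept


toy_train = [((2, 3), 5), ((10, 7), 17), ((4, 4), 8), ((9, 1), 10)]
toy_demos = bootstrap_demos(toy_train)
print(f"kept {len(toy_demos)} of {len(toy_train)} examples as demos")
```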



## 8. Real DSPy optimization with `BootstrapFewShotWithRandomSearch`

Now we switch from the manual loop to a real DSPy teleprompter.

We will:

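* use `BootstrapFewShotWithRandomSearch` on most of the GSM8K train subset
* hold out the remaining train examples as the validation set the optimizer uses to pick its best candidate (validating on the dev subset would let the optimizer select on the same examples we later report accuracy on, inflating the "after" number)
* get a **compiled** version of our program
* compare before vs after accuracy on the dev subset


In [None]:

from dspy.teleprompt import BootstrapFewShotWithRandomSearch

# Hold out part of the train subset for candidate selection, so the dev subset
# stays untouched until the before/after comparison in section 9.
split = int(0.75 * len(train_examples))
opt_trainset = train_examples[:split]
opt_valset = train_examples[split:]

optimizer = BootstrapFewShotWithRandomSearch(
    metric=gsm8k_metric,
    max_bootstrapped_demos=4,   # at most 4 bootstrapped (metric-passing) demos per predictor
    max_labeled_demos=4,        # also cap the total demos (default 16) to keep prompts small for a small model
    num_candidate_programs=4,   # randomly shuffled demo sets to try (DSPy adds a few fixed baselines too)
    num_threads=1,              # run sequentially: simpler to follow and gentle on a local Ollama server
)

compiled_program = optimizer.compile(
    student=GSM8KCoT(),
    trainset=opt_trainset,
    valset=opt_valset,
)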

compiled_program
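The "random search" part can be pictured as: build several candidate demo sets, score the program with each one on the validation set, and keep the best. The toy below shows only that selection loop. `random_search_demo_sets` and `toy_score` (which stands in for a real validation-set evaluation) are hypothetical; DSPy's actual optimizer also bootstraps the demos and includes a few fixed candidates.

```python
import random


def random_search_demo_sets(pool, score_fn, k=4, num_candidates=4, seed=0):
    """Toy random search: sample demo subsets from `pool`, keep the best-scoring one.

    The zero-shot candidate (no demos) is scored first as a baseline.
    """
    rng = random.Random(seed)
    best_demos, best_score = [], score_fn([])
    for _ in range(num_candidates):
        candidate = rng.sample(pool, min(k, len(pool)))
        score = score_fn(candidate)
        if score > best_score:  # only strictly better candidates replace the incumbent
            best_demos, best_score = candidate, score
    return best_demos, best_score


# Hypothetical scorer: pretend some demos ("d1", "d4") help and the others don't.
HELPFUL = {"d1", "d4"}


def toy_score(demos):
    return 0.30 + 0.10 * len(HELPFUL.intersection(demos))


toy_pool = ["d1", "d2", "d3", "d4", "d5", "d6"]
toy_best, toy_best_score = random_search_demo_sets(toy_pool, toy_score, k=2)
print(toy_best, round(toy_best_score, 2))
```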



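## 9. Before vs after: dev subset accuracy

At temperature 0.6 and with only 100 dev examples, accuracy can move by several points from run to run, so treat small differences with caution.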


In [None]:

print("Dev subset BEFORE optimization:")
base_dev_acc = evaluate_program(base_program, dev_examples, gsm8k_metric, max_print=3, label="(dev, base)")

print("\nDev subset AFTER optimization:")
compiled_dev_acc = evaluate_program(compiled_program, dev_examples, gsm8k_metric, max_print=5, label="(dev, compiled)")

print("\nImprovement on dev subset:", compiled_dev_acc - base_dev_acc)
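Because both programs are scored on the same dev examples, a paired comparison is more informative than subtracting two averages: what matters is how many questions the compiled program fixed versus broke. The sketch assumes you have per-example 0/1 score lists (for instance from a variant of `evaluate_program` that also returns its `scores` list, which the helper above does not do as written). `paired_comparison` is hypothetical and applies an exact sign test (an exact McNemar test) to the discordant pairs.

```python
from math import comb


def paired_comparison(before, after):
    """Compare per-example 0/1 scores of two programs on the same examples.

    Returns (wins, losses, p_value): wins = examples only the new program gets right,
    losses = examples only the old one gets right, and p_value is a two-sided exact
    binomial (sign) test on those discordant pairs.
    """
    if len(before) != len(after):
        raise ValueError("score lists must cover the same examples")
    wins = sum(1 for b, a in zip(before, after) if not b and a)
    losses = sum(1 for b, a in zip(before, after) if b and not a)
    n = wins + losses
    if n == 0:
        return wins, losses, 1.0
    tail = sum(comb(n, i) for i in range(min(wins, losses) + 1)) / 2**n
    return wins, losses, min(1.0, 2 * tail)


# Illustrative score lists, not results from this notebook.
base_scores = [1, 0, 0, 1, 0, 1, 0, 0]
compiled_scores = [1, 1, 1, 1, 0, 1, 1, 0]
print(paired_comparison(base_scores, compiled_scores))
```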



## 10. Inspect internal demonstrations in the compiled program

Each `Predict` module inside the compiled program now holds a list of **demonstrations** (`predictor.demos`)
that the optimizer selected using GSM8K examples and our metric.


In [None]:

for name, predictor in compiled_program.named_predictors():
    demos = predictor.demos  # DSPy stores a predictor's few-shot demonstrations in `.demos`
    print(f"Predictor: {name}")
    print(f"  Type: {type(predictor).__name__}")
    print(f"  Number of demos: {len(demos)}")
    if demos:
        first = demos[0]
        # Demos are dspy.Example objects; bootstrapped ones carry augmented=True
        # and include model-generated fields (e.g. reasoning) next to the inputs.
        print("  First demo fields:", list(first.keys()))
        print("  First demo bootstrapped:", bool(first.get("augmented", False)))
    print("-" * 80)
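Conceptually, these demos end up in the prompt as worked examples ahead of the real question. DSPy's adapters decide the exact layout (field markers, chat turns), so the hypothetical `render_few_shot_prompt` below is only a plain-text illustration of that idea, not DSPy's actual prompt format.

```python
def render_few_shot_prompt(demos, question):
    """Place each demo (question, reasoning, answer) before the new question."""
    blocks = [
        f"Question: {d['question']}\nReasoning: {d['reasoning']}\nAnswer: {d['answer']}"
        for d in demos
    ]
    # The final block leaves Reasoning open for the model to continue.
    blocks.append(f"Question: {question}\nReasoning:")
    return "\n\n".join(blocks)


toy_prompt_demos = [
    {"question": "What is 2 + 3?", "reasoning": "2 + 3 = 5.", "answer": "5"},
    {"question": "What is 4 * 6?", "reasoning": "4 * 6 = 24.", "answer": "24"},
]
print(render_few_shot_prompt(toy_prompt_demos, "What is 7 + 8?"))
```

This is also why capping the number of demos matters for a small local model: every demo adds to the prompt length.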



## 11. Traces from the compiled program

Now we look at **compiled** traces on some dev examples:

* same `with dspy.context(trace=[])` mechanism
* but now the internal modules use the demonstrations the optimizer selected


In [None]:

for i in range(3):
    ex = dev_examples[i]
    print("=" * 100)
    print(f"Dev example {i}:")
    print("Question:")
    print(ex.question)
    print()

    with dspy.context(trace=[]):
        pred = compiled_program(question=ex.question)
        trace = dspy.settings.trace.copy()

    print("Predicted answer (raw):")
    print(getattr(pred, "answer", ""))
    print("Gold answer (raw):")
    print(ex.answer)
    print("Metric score:", gsm8k_metric(ex, pred, trace))
    print()

    print("Trace:")
    for j, (mod, inputs, outputs) in enumerate(trace):
        print(f"  Step {j}: module={type(mod).__name__}")
        print(f"    inputs keys:  {list(inputs.keys())}")
        print(f"    outputs keys: {list(outputs.keys())}")
    print()



## 12. Optional: two-step GSM8K program to show multi-module traces

This variation:

1. rewrites the question into a simpler question
2. answers the simpler question

It is mainly to highlight how traces look when more than one `Predict` module is involved.


In [None]:

class GSM8KRewriteAndSolve(dspy.Module):
    def __init__(self):
        super().__init__()
        self.rewrite = dspy.ChainOfThought("question -> simpler_question")
        self.solve = dspy.ChainOfThought("simpler_question -> answer")

    def forward(self, question: str):
        rewritten = self.rewrite(question=question)
        simpler_question = rewritten.simpler_question
        solved = self.solve(simpler_question=simpler_question)
        return dspy.Prediction(
            simpler_question=simpler_question,
            answer=solved.answer,
        )

two_step_program = GSM8KRewriteAndSolve()

# Show one trace from this two-step program on a GSM8K dev example
ex = dev_examples[0]

with dspy.context(trace=[]):
    pred = two_step_program(question=ex.question)
    trace = dspy.settings.trace.copy()

print("Question:")
print(ex.question)
print()
print("Simpler question (model output):")
print(getattr(pred, "simpler_question", ""))
print()
print("Answer (raw):")
print(getattr(pred, "answer", ""))
print()

print("Two-step trace:")
for j, (mod, inputs, outputs) in enumerate(trace):
    print(f"  Step {j}: module={type(mod).__name__}")
    print(f"    inputs keys:  {list(inputs.keys())}")
    print(f"    outputs keys: {list(outputs.keys())}")



## 13. Wrap up

This notebook gives you a **real GSM8K** demo of DSPy on top of a **small Ollama model**, with:

* baseline performance on genuine GSM8K questions
* concrete traces for both correct and incorrect answers
* a manual "backtracking"-style loop that accepts or rejects traces
* a real `BootstrapFewShotWithRandomSearch` run over a GSM8K subset
* before vs after accuracy comparison
* introspection of the compiled program's internal demonstrations

You can scale this up by:

* increasing `train_subset_size` and `dev_subset_size`
* increasing `max_bootstrapped_demos` or `num_candidate_programs`
* swapping in a different Ollama model

For a live talk, running exactly this notebook and scrolling through the printed traces works nicely
as an "under the hood" tour of DSPy on a realistic benchmark.
