In [None]:
import os
import sys
import logging
from dotenv import load_dotenv

import dspy
from dspy.datasets.gsm8k import GSM8K, gsm8k_metric
from dspy.evaluate import Evaluate
import numpy as np

# Make the repository root (parent of the notebook's working directory) importable so
# `mcts_llm` resolves. Guarded so re-running this cell does not keep appending duplicates.
_project_root = os.path.dirname(os.getcwd())
if _project_root not in sys.path:
    sys.path.append(_project_root)
from mcts_llm.mctsr import MCTSr, ZeroShotCoT, MultipleTurnSelfRefine, Policy

# Load environment variables from a `.env` file, if one is present.
load_dotenv()

# Root logger at WARNING to keep dependencies quiet; the project's "mcts-llm"
# logger is raised to INFO so its progress messages are visible.
logging.basicConfig(level=logging.WARNING)
mcts_logger = logging.getLogger("mcts-llm")
mcts_logger.setLevel(logging.INFO)

# Seed NumPy's global RNG for any code that draws from `np.random`.
SEED = 42
np.random.seed(SEED)

In [None]:
# Local Ollama-served chat model, installed as dspy's default LM for every program below.
# Ollama's `num_ctx` is the total context window shared by the prompt AND the completion,
# so it must exceed `max_tokens` by at least the prompt length; setting it equal to
# `max_tokens` leaves no room for the prompt and lets the server truncate long prompts.
# 4096 leaves headroom for a prompt plus a full 1024-token completion.
ollama = dspy.OllamaLocal(
    model="qwen2.5:7b-instruct",
    model_type="chat",
    temperature=1.0,   # sampling at temperature 1.0, so outputs vary between runs
    max_tokens=1024,   # cap on generated tokens per call
    num_ctx=4096,
    timeout_s=600,     # generous per-request timeout for slow local inference
    cache=False,       # presumably disables dspy's LM response cache -- confirm
)
dspy.settings.configure(lm=ollama, experimental=True)

In [None]:
# Load the GSM8K dataset via dspy; its `.train` and `.test` splits are converted below.
# NOTE(review): presumably downloads/caches the data on first use -- confirm.
gsm8k = GSM8K()

In [None]:
# Training split as `dspy.Example`s whose only input field is `problem`;
# `gold_reasoning` and `answer` are carried along as labels.
# The shuffle uses a dedicated, locally seeded generator instead of NumPy's global RNG,
# so the order is the same no matter how many times this (or any other) cell has run.
TRAIN_SHUFFLE_SEED = 42
_train_examples = [
    dspy.Example(
        problem=example['question'],
        gold_reasoning=example['gold_reasoning'],
        answer=example['answer'],
    ).with_inputs("problem")
    for example in gsm8k.train
]
_train_order = np.random.default_rng(TRAIN_SHUFFLE_SEED).permutation(len(_train_examples))
gsm8k_trainset = [_train_examples[i] for i in _train_order]
gsm8k_trainset[:10]

In [None]:
# Test split as `dspy.Example`s whose only input field is `problem`.
# The evaluation slice below is taken from the head of this list, so its order must not
# depend on cell execution history: shuffle with a dedicated, locally seeded generator
# rather than NumPy's global RNG (which the train cell also consumes).
TEST_SHUFFLE_SEED = 42
_test_examples = [
    dspy.Example(
        problem=example['question'],
        gold_reasoning=example['gold_reasoning'],
        answer=example['answer'],
    ).with_inputs("problem")
    for example in gsm8k.test
]
_test_order = np.random.default_rng(TEST_SHUFFLE_SEED).permutation(len(_test_examples))
gsm8k_testset = [_test_examples[i] for i in _test_order]
gsm8k_testset[:10]

In [None]:
# Shared evaluator: every program below is scored with `gsm8k_metric` on the same
# fixed head slice of the shuffled test set.
N_EVAL_EXAMPLES = 20  # evaluation slice size; raise for a less noisy accuracy estimate

# os.cpu_count() returns None when the count cannot be determined; fall back to 1.
# NOTE(review): all threads hit one local Ollama server, which may serialize requests
# anyway -- confirm the useful level of parallelism.
EVAL_THREADS = os.cpu_count() or 1

evaluate = Evaluate(
    devset=gsm8k_testset[:N_EVAL_EXAMPLES],
    metric=gsm8k_metric,
    num_threads=EVAL_THREADS,
    display_progress=True,
    display_table=10,  # presumably caps the number of result rows shown -- confirm
)

In [None]:
# Baseline: the `ZeroShotCoT` program from `mcts_llm.mctsr` with default arguments.
zero_shot_cot = ZeroShotCoT()
evaluate(zero_shot_cot)

In [None]:
# Self-refine baseline: `MultipleTurnSelfRefine` configured with `num_turns=1`.
self_refine_1turn = MultipleTurnSelfRefine(num_turns=1)
evaluate(self_refine_1turn)

In [None]:
# MCTSr with all constructor defaults (rollout budget and policy as defined in mcts_llm.mctsr).
mctsr_default = MCTSr()
evaluate(mctsr_default)

In [None]:
# MCTSr with the default rollout budget but the importance-sampling selection policy.
mctsr_importance = MCTSr(policy=Policy.IMPORTANCE_SAMPLING)
evaluate(mctsr_importance)

In [None]:
# MCTSr with an 8-rollout budget and the default policy.
mctsr_8 = MCTSr(max_rollouts=8)
evaluate(mctsr_8)

In [None]:
# MCTSr with an 8-rollout budget and the importance-sampling policy.
mctsr_8_importance = MCTSr(max_rollouts=8, policy=Policy.IMPORTANCE_SAMPLING)
evaluate(mctsr_8_importance)

In [None]:
# MCTSr with a 16-rollout budget and the default policy.
mctsr_16 = MCTSr(max_rollouts=16)
evaluate(mctsr_16)

In [None]:
# MCTSr with a 16-rollout budget and the importance-sampling policy.
mctsr_16_importance = MCTSr(max_rollouts=16, policy=Policy.IMPORTANCE_SAMPLING)
evaluate(mctsr_16_importance)