In [None]:
import os
import sys
import logging
from dotenv import load_dotenv

import dspy
from dspy.datasets.gsm8k import GSM8K, gsm8k_metric
from dspy.evaluate import Evaluate
from dspy.teleprompt import MIPROv2
import numpy as np
import matplotlib.pyplot as plt

sys.path.append(os.path.dirname(os.getcwd()))
from mcts_llm.mctsr import MCTSr

# Pull DEEPSEEK_API_KEY / DEEPSEEK_BASE_URL (read in the LM cell) from a local .env file.
load_dotenv()

# Root logging stays at WARNING; only the "mcts-llm" logger is raised to INFO.
logging.basicConfig(level=logging.WARNING)
logging.getLogger(name="mcts-llm").setLevel(level=logging.INFO)

# Seed NumPy's global RNG so the np.random.shuffle calls on the datasets are reproducible.
np.random.seed(seed=42)

In [None]:
# Two LMs are configured:
#   * `ollama`: a local Qwen2.5-7B served by Ollama; it is the task model that runs MCTSr.
#   * `openai`: DeepSeek's chat model via an OpenAI-style client; MIPROv2 uses it as
#     the prompt model that proposes instructions.
# The DeepSeek key and endpoint are read from the environment; a missing one raises KeyError.
system_prompt = "The user will provide a problem. Solve the problem. Think step by step."
ollama = dspy.OllamaLocal(
    model="qwen2.5:7b-instruct",
    model_type="chat",
    system=system_prompt,
    temperature=1.0,
    max_tokens=1024,
    # NOTE(review): Ollama's num_ctx is the whole context window (prompt + output).
    # With max_tokens also at 1024, long MCTSr prompts may crowd out the answer -- confirm intended.
    num_ctx=1024,
    timeout_s=600,
)
# NOTE: this variable shadows the name of the `openai` package; it holds the DeepSeek LM.
openai = dspy.OpenAI(
    model="deepseek-chat",
    model_type="chat",
    api_key=os.environ["DEEPSEEK_API_KEY"],
    base_url=os.environ["DEEPSEEK_BASE_URL"],
    temperature=1.0,
    max_tokens=4096,
)
# The local Ollama model is the default LM for DSPy modules in this notebook.
dspy.settings.configure(lm=ollama, experimental=True)

In [None]:
gsm8k = GSM8K()  # DSPy's GSM8K loader; its `.train` and `.test` splits are converted below

In [None]:
# Rename DSPy's GSM8K fields to the ones this notebook uses: the question becomes
# `problem` (the only input field); `gold_reasoning` and `answer` are carried as labels.
gsm8k_trainset = [
    dspy.Example(
        problem=raw["question"],
        gold_reasoning=raw["gold_reasoning"],
        answer=raw["answer"],
    ).with_inputs("problem")
    for raw in gsm8k.train
]
# In-place shuffle using NumPy's global RNG (seeded in the setup cell).
np.random.shuffle(gsm8k_trainset)
gsm8k_trainset[:10]

In [None]:
# Same conversion for the test split: `problem` is the input; `gold_reasoning` and `answer` are labels.
gsm8k_testset = [
    dspy.Example(
        problem=raw["question"],
        gold_reasoning=raw["gold_reasoning"],
        answer=raw["answer"],
    ).with_inputs("problem")
    for raw in gsm8k.test
]
# Continues drawing from the global NumPy RNG, so the order also depends on the
# train shuffle having run first.
np.random.shuffle(gsm8k_testset)
gsm8k_testset[:10]

In [None]:
# Evaluator over the first 20 problems of the shuffled test split, scored with
# `gsm8k_metric` and using one thread per CPU core.
# display_table=20 presumably caps the results table at 20 rows.
evaluate = Evaluate(
    devset=gsm8k_testset[:20],
    metric=gsm8k_metric,
    num_threads=os.cpu_count(),
    display_table=20,
    display_progress=True,
)

In [None]:
# Optimise MCTSr with MIPROv2:
#   * the DeepSeek LM (`openai`) proposes candidate instructions;
#   * the local Ollama model executes the program;
#   * `gsm8k_metric` scores it on the first 50 shuffled training problems.
# Both demo limits are 0, so presumably only instructions are tuned (no few-shot demos).
optimizer = MIPROv2(
    metric=gsm8k_metric,
    prompt_model=openai,
    task_model=ollama,
    num_candidates=7,
    init_temperature=0.5,
    num_threads=os.cpu_count(),
    verbose=True,
)
miprov2_mctsr = optimizer.compile(
    MCTSr(),
    trainset=gsm8k_trainset[:50],
    max_bootstrapped_demos=0,
    max_labeled_demos=0,
    num_trials=15,
    requires_permission_to_run=False,
)
# Write the optimised program's state to disk.
miprov2_mctsr.save("miprov2_mctsr.json")

In [None]:
evaluate(miprov2_mctsr)  # score the MIPROv2-optimised program on the 20-problem test slice

In [None]:
# Plot MIPROv2's optimisation trace: one point per trial, coloured by whether the
# trial was pruned. Reads the "score" and "pruned" keys of each trial_logs entry.
trial_logs = miprov2_mctsr.trial_logs
trial_numbers = list(trial_logs.keys())
scores = [trial_logs[trial]['score'] for trial in trial_numbers]
pruning_status = [trial_logs[trial]['pruned'] for trial in trial_numbers]

fig, ax = plt.subplots(figsize=(5, 3))
# Draw each group with a single scatter call so its legend label appears exactly once.
# Empty groups are skipped, so the legend never lists a category with no points.
for group_pruned, color, label in ((True, 'grey', 'Pruned Batch'), (False, 'green', 'Successful Batch')):
    group_x = [t for t, p in zip(trial_numbers, pruning_status) if bool(p) == group_pruned]
    group_y = [s for s, p in zip(scores, pruning_status) if bool(p) == group_pruned]
    if group_x:
        ax.scatter(group_x, group_y, color=color, label=label)

ax.set_xlabel('Batch Number')
ax.set_ylabel('Score')
ax.set_title('Batch Scores')
ax.grid(True)
# Only build a legend when something was plotted; avoids the "No artists with labels" warning.
if ax.get_legend_handles_labels()[1]:
    ax.legend()
plt.show()

In [None]:
best_score = 0

def get_signature(predictor):
    """Return the signature whose ``instructions`` should be shown for ``predictor``.

    Prefers ``extended_signature`` when the predictor has one; otherwise falls back
    to ``signature``.

    Raises:
        AttributeError: if ``predictor`` has neither attribute. Raising here, rather
        than returning None, makes the failure point at the offending predictor
        instead of a later ``.instructions`` lookup on None.
    """
    for attr in ("extended_signature", "signature"):
        if hasattr(predictor, attr):
            return getattr(predictor, attr)
    raise AttributeError(
        f"{type(predictor).__name__} has neither 'extended_signature' nor 'signature'"
    )

# Baseline: the instructions of an unoptimised MCTSr program. The baseline is never
# scored in this notebook, so no number is shown for it.
print("Baseline program | Score: not evaluated:")
for i, predictor in enumerate(MCTSr().predictors(), start=1):
    print(f"Prompt {i} Instruction: {get_signature(predictor).instructions}")
print()

print("----------------")

# Replay the MIPROv2 trials in logged order and, after each one, show the best program
# seen so far. Previously every trial overwrote the "best", so each line just echoed
# that trial (pruned ones included).
# Now only unpruned trials that beat the running best replace it. Pruned trials were
# stopped early, so their scores are presumably partial -- TODO confirm for the
# installed DSPy version.
best_score = 0
best_program_so_far = None
for trial_num, trial in miprov2_mctsr.trial_logs.items():
    program_score = trial["score"]
    program_pruned = trial["pruned"]
    if not program_pruned and (best_program_so_far is None or program_score > best_score):
        best_score = program_score
        best_program_so_far = trial["program"]
    if best_program_so_far is None:
        # Every trial so far was pruned: there is no best program to show yet.
        print(f"Best program after {trial_num} batches | no unpruned trial yet")
        print()
        continue
    print(f"Best program after {trial_num} batches | Score: {best_score}:")
    for i, predictor in enumerate(best_program_so_far.predictors(), start=1):
        print(f"Prompt {i} Instruction: {get_signature(predictor).instructions}")
    print()