In [None]:
!pip install transformers accelerate --quite
!pip install networkx --quite
import networkx as nx 
import matplotlib.pyplot as pyplot

!apt install graphviz graphviz-dev -y 

!pip install pygraphviz

In [None]:
# Load the reasoning (generator) model and its tokenizer.
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn.functional as F  # conventional alias for the functional API

reasoning_model_id = "HuggingFaceH4/zephyr-7b-beta"

reasoning_tokenizer = AutoTokenizer.from_pretrained(reasoning_model_id, use_fast=True)
reasoning_model = AutoModelForCausalLM.from_pretrained(
    reasoning_model_id,
    torch_dtype=torch.float16,  # load the weights in half precision
    device_map="auto",          # let the library decide device placement
)

# Inference only: switch to eval mode (also displays the model as the cell output).
reasoning_model.eval()

In [None]:
# Reward model used as a stand-in for a process reward model (PRM).
reward_model_id = "OpenAssistant/reward-model-deberta-v3-large"

# Tokenizer that matches the reward model checkpoint.
reward_tokenizer = AutoTokenizer.from_pretrained(reward_model_id)

# Load the sequence-classification head, then move it onto the GPU.
reward_model = AutoModelForSequenceClassification.from_pretrained(reward_model_id)
reward_model = reward_model.to("cuda")

# Inference only: switch to eval mode (also displays the model as the cell output).
reward_model.eval()

In [None]:
# Simulate stepwise reward using a final answer reward model

def stepwise_prm_score(prompt , trace, reward_model,tokenizer,device = "cuda"):
    """Approximate a stepwise (process) reward by scoring growing prefixes of a trace.

    The trace is split into steps on ". ". For each i, the text
    ``prompt + "\\n" + <first i steps joined by ". ">`` is scored with the
    reward model, and the mean of those prefix scores is returned.

    Args:
        prompt: The question/prompt text that prefixes every scored input.
        trace: The reasoning text to score.
        reward_model: Callable returning an object with a ``logits`` tensor;
            ``logits[0]`` must hold a single value.
        tokenizer: Callable producing the reward model's inputs; its result
            must support ``.to(device)`` and ``**`` unpacking.
        device: Device the tokenized inputs are moved to before scoring.

    Returns:
        The mean prefix score as a float, or 0.0 when the trace contains no
        non-empty step (the reward model is not called in that case).
    """
    # str.split never returns an empty list (''.split('. ') == ['']), so empty
    # fragments are dropped explicitly; otherwise an empty trace would be
    # scored as the bare prompt and a trailing ". " would score a prefix twice.
    steps = [step for step in trace.split(". ") if step.strip()]
    if not steps:
        return 0.0

    cumulative_score = 0.0
    for i in range(1, len(steps) + 1):
        # Score the prompt together with the first i steps of the trace.
        partial = prompt + "\n" + ". ".join(steps[:i])
        inputs = tokenizer(partial, return_tensors="pt", truncation=True).to(device)
        with torch.no_grad():
            score = reward_model(**inputs).logits[0].item()
        cumulative_score += score
    return cumulative_score / len(steps)

In [None]:
def beam_search_with_prm(prompt, reasoning_model, reasoning_tokenizer,
                         reward_model, reward_tokenizer,
                         N=4, M=2, max_steps=3):
    """Reward-guided beam search over reasoning continuations.

    Step 0 generates N completions of the formatted prompt. Each following
    step keeps the M highest-scoring beams and expands every kept beam into
    N // M continuations of up to 64 new tokens, scoring each continuation
    with ``stepwise_prm_score``.

    Args:
        prompt: The user question (wrapped in the Zephyr chat format below).
        reasoning_model: Causal LM exposing ``.device`` and ``.generate``.
        reasoning_tokenizer: Tokenizer paired with ``reasoning_model``.
        reward_model: Reward model exposing ``.device``; passed to
            ``stepwise_prm_score``.
        reward_tokenizer: Tokenizer paired with ``reward_model``.
        N: Number of candidates produced per step.
        M: Number of beams kept for expansion; must divide N.
        max_steps: Number of expansion rounds after the initial generation.

    Returns:
        ``(beams, graph)`` where ``beams`` is a list of
        ``(full_text, score, node_id)`` tuples ordered by descending score and
        ``graph`` is an ``nx.DiGraph`` whose nodes carry ``label`` and
        ``score`` attributes and whose edges link parents to children.

    Raises:
        AssertionError: If N is not divisible by M.
    """
    assert N % M == 0, "N must be divisible by M"
    device = reasoning_model.device
    # Score on whatever device the reward model lives on instead of relying on
    # the scorer's hard-coded "cuda" default.
    reward_device = reward_model.device
    children_per_beam = N // M

    def format_zephyr_prompt(user_prompt: str) -> str:
        return f"<|system|>\nYou are a helpful assistant.\n<|user|>\n{user_prompt}\n<|assistant|>\n"

    def expand(text, num_candidates):
        """Return (full_text, new_text) pairs for beam-searched continuations of `text`."""
        encoded = reasoning_tokenizer(text, return_tensors="pt").to(device)
        input_ids = encoded["input_ids"]
        prefix_len = input_ids.shape[1]
        # Deterministic beam search (do_sample=False); sampling-only options
        # such as temperature are omitted because they have no effect here.
        sequences = reasoning_model.generate(
            input_ids=input_ids,
            attention_mask=encoded.get("attention_mask"),
            max_new_tokens=64,
            do_sample=False,
            num_beams=num_candidates,
            num_return_sequences=num_candidates,
            pad_token_id=reasoning_tokenizer.eos_token_id,
        )
        pairs = []
        for sequence in sequences:
            full_text = reasoning_tokenizer.decode(sequence, skip_special_tokens=True)
            # Slice by token position: decoding does not always reproduce the
            # input text exactly, so removing the prefix by string replacement
            # can silently leave it in place.
            new_text = reasoning_tokenizer.decode(
                sequence[prefix_len:], skip_special_tokens=True
            ).strip()
            pairs.append((full_text, new_text))
        return pairs

    def by_score(entry):
        return entry[1]

    formatted_prompt = format_zephyr_prompt(prompt)

    graph = nx.DiGraph()
    beams = []

    # Step 0: Initial N completions
    for i, (gen_text, completion) in enumerate(expand(formatted_prompt, N)):
        score = stepwise_prm_score(prompt, completion, reward_model, reward_tokenizer,
                                   device=reward_device)
        node_id = f"0-{i}"
        graph.add_node(node_id, label=completion[:40]+"...", score=score)
        beams.append((gen_text, score, node_id))

    # Steps 1 to max_steps
    for step in range(1, max_steps + 1):
        # Keep only the M best beams for expansion.
        beams = sorted(beams, key=by_score, reverse=True)[:M]
        candidates = []

        for parent_text, _, parent_id in beams:
            for i, (child_text, continuation) in enumerate(expand(parent_text, children_per_beam)):
                # NOTE(review): only the newly generated continuation is scored,
                # not the accumulated trace — confirm this is the intended signal.
                score = stepwise_prm_score(prompt, continuation, reward_model, reward_tokenizer,
                                           device=reward_device)
                # Parent ids are unique, so embedding them keeps child ids unique.
                node_id = f"{step}-{i}-{parent_id}"
                graph.add_node(node_id, label=continuation[:40]+"...", score=score)
                graph.add_edge(parent_id, node_id)
                candidates.append((child_text, score, node_id))

        beams = sorted(candidates, key=by_score, reverse=True)[:N]

    # Stable sort: a no-op after an expansion step, but it also ranks the
    # initial completions when max_steps == 0.
    return sorted(beams, key=by_score, reverse=True), graph

In [None]:
def plot_trace_graph_tree_clean(graph, figsize=(14, 8), title="Beam Search Tree (PRM-Guided)"):
    """Draw the beam-search graph with nodes coloured by their ``score`` attribute.

    Nodes are laid out with Graphviz ``dot`` when available, otherwise with a
    seeded spring layout. Each node shows its (shortened) ``label`` attribute
    beneath it, and a colorbar maps colours back to scores.

    Args:
        graph: networkx graph whose nodes may carry ``score`` and ``label``
            attributes; a missing score is treated as 0 and a missing label as "".
        figsize: Figure size passed to ``plt.subplots``.
        title: Title of the plot.

    Returns:
        None. The figure is shown with ``plt.show()``. An empty graph prints a
        notice and draws nothing.
    """
    import matplotlib.pyplot as plt
    import networkx as nx
    import textwrap

    node_order = list(graph.nodes())
    if not node_order:
        # min()/max() over zero scores would raise; there is nothing to draw.
        print("Graph has no nodes; nothing to plot.")
        return

    try:
        pos = nx.nx_agraph.graphviz_layout(graph, prog='dot')
        # Offset used to place labels under nodes in the Graphviz layout.
        label_offset = 30
    except Exception as exc:
        # Graphviz/pygraphviz may be missing or fail; report it and fall back
        # instead of swallowing every error (including KeyboardInterrupt).
        print(f"Graphviz layout unavailable ({exc!r}); falling back to spring layout.")
        pos = nx.spring_layout(graph, seed=42)
        # The fallback layout uses a different coordinate scale, so derive the
        # label offset from the layout's own vertical extent.
        ys = [xy[1] for xy in pos.values()]
        label_offset = 0.08 * ((max(ys) - min(ys)) or 1.0)

    scores = nx.get_node_attributes(graph, 'score')
    labels = nx.get_node_attributes(graph, 'label')
    node_colors = [scores.get(n, 0) for n in node_order]

    # Colour range, computed once. When every score is identical the range is
    # widened so normalisation does not divide by zero.
    vmin, vmax = min(node_colors), max(node_colors)
    if vmin == vmax:
        vmin, vmax = vmin - 0.5, vmax + 0.5

    fig, ax = plt.subplots(figsize=figsize)

    # Draw edges
    nx.draw_networkx_edges(graph, pos, ax=ax, alpha=0.3)

    # Draw nodes with color based on PRM score
    for node, score in zip(node_order, node_colors):
        x, y = pos[node]
        fill_color = plt.cm.viridis((score - vmin) / (vmax - vmin))

        ax.scatter(x, y, s=800, c=[fill_color], edgecolors='black', linewidths=1, zorder=5)

        # Shortened label below node
        text = textwrap.shorten(labels.get(node, ""), width=50, placeholder="...")
        ax.text(x, y - label_offset, text, ha='center', va='top', fontsize=8,
                bbox=dict(boxstyle="round,pad=0.2", fc="white", ec="black", lw=0.4), zorder=10)

    # Colorbar
    sm = plt.cm.ScalarMappable(cmap=plt.cm.viridis,
                               norm=plt.Normalize(vmin=vmin, vmax=vmax))
    sm.set_array(node_colors)
    fig.colorbar(sm, ax=ax, label="PRM Score")

    plt.title(title)
    plt.axis('off')
    plt.tight_layout()
    plt.show()


In [None]:
# Demo: run the reward-guided beam search on one word problem and print the final beams.
prompt = (
    "Roger has 5 tennis balls. He buys 2 cans of 3 tennis balls each. "
    "How many tennis balls does he have now?"
)
beams, graph = beam_search_with_prm(
    prompt,
    reasoning_model,
    reasoning_tokenizer,
    reward_model,
    reward_tokenizer,
    N=4,
    M=2,
    max_steps=3,
)

# Each beam is a (text, score, node_id) tuple; the node id is not printed.
for i, beam in enumerate(beams):
    text, score = beam[0], beam[1]
    print(f"--- Final Beam {i+1} ---\nScore: {score:.2f}\n{text}\n")