In [1]:
import spacy
import json
import pandas as pd
from collections import defaultdict


In [2]:

nlp = spacy.load("en_core_web_sm")


logical_connectors = {
    "because", "so", "therefore", "however", "thus", "but", "although",
    "though", "moreover", "meanwhile", "consequently", "nevertheless",
    "since", "as", "nonetheless", "then", "hence","while" #too small
}
symbol_tokens = {"Step", ":", ".", ",", "(", ")", "[", "]", "{", "}", "<", ">", "!", "?"}
symbol_tokens.update({str(i) for i in range(10)})


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
file_path = "/content/drive/MyDrive/Cluster-proj/output/deepseek7b-math-0-11_with_all_steps.json"

In [5]:

with open(file_path, "r") as f:
    data = json.load(f)

In [6]:

def analyze_token_clusters(step_data):
    sentence = " ".join([t["token"] for t in step_data])
    # print(sentence)
    doc = nlp(sentence)

    cluster_map = defaultdict(list)
    original_tokens = [t["token"] for t in step_data]
    original_probs = [t["prob"] for t in step_data]

    token_idx = 0
    for token in doc:
        if token_idx >= len(original_tokens):
            break
        if token.text != original_tokens[token_idx]:
            token_idx += 1
            continue
        prob = original_probs[token_idx]
        token_lower = token.text.lower()
        if token.text in symbol_tokens:
            cluster = "Symbol"
        elif token_lower in logical_connectors:
            cluster = "LogicalConnector"
        elif token.ent_type_:
            cluster = f"Entity:{token.ent_type_}"
        elif token.dep_:
            cluster = f"Syntactic:{token.dep_}"
        else:
            cluster = "Other"
        cluster_map[cluster].append({"token": token.text, "prob": prob})
        token_idx += 1

    cluster_confidence = {
        cluster: {
            "total_prob": sum(t["prob"] for t in tokens),
            "avg_prob": sum(t["prob"] for t in tokens) / len(tokens),
            "tokens": tokens,
            "top_token": max(tokens, key=lambda x: x["prob"])
        }
        for cluster, tokens in cluster_map.items()
    }
    return cluster_confidence



In [7]:
def analyze_all_steps_cluster(data, example_id, mode, target_step=None):
    all_step_results = []

    step_items = data[example_id][mode].items()
    if target_step is not None:
        # 只保留指定的 step
        step_items = [(target_step, data[example_id][mode][target_step])] if target_step in data[example_id][mode] else []

    for step_id, step_data in step_items:
        cluster_result = analyze_token_clusters(step_data)
        for cluster, info in cluster_result.items():
            all_step_results.append({
                "Step": step_id,
                "Cluster": cluster,
                "Total_Prob": info["total_prob"],
                "Avg_Prob": info["avg_prob"],
                "Top_Token": info["top_token"]["token"],
                "Top_Prob": info["top_token"]["prob"],
                "Token_Count": len(info["tokens"])
            })

    return pd.DataFrame(all_step_results)


In [10]:

df = analyze_all_steps_cluster(data, id, 'sampling0_step_token_probs')
df_sorted = df.sort_values(by=["Step", "Total_Prob"], ascending=[True, False])



In [11]:
df_sorted

Unnamed: 0,Step,Cluster,Total_Prob,Avg_Prob,Top_Token,Top_Prob,Token_Count
0,1,Symbol,5.06128,0.843547,:,1.0,6
2,1,Syntactic:det,2.474346,0.824782,the,1.0,3
4,1,Syntactic:prep,2.0,1.0,of,1.0,2
7,1,Entity:QUANTITY,2.0,1.0,pounds,1.0,2
13,1,Syntactic:dobj,2.0,1.0,beans,1.0,2
3,1,Syntactic:nsubj,1.0,1.0,weight,1.0,1
5,1,Syntactic:pobj,1.0,1.0,box,1.0,1
12,1,Syntactic:compound,1.0,1.0,jelly,1.0,1
14,1,Syntactic:aux,1.0,1.0,to,1.0,1
11,1,Syntactic:amod,0.970814,0.970814,enough,0.970814,1


In [12]:
df_sorted[df_sorted["Cluster"].str.startswith("Syntactic")]


Unnamed: 0,Step,Cluster,Total_Prob,Avg_Prob,Top_Token,Top_Prob,Token_Count
2,1,Syntactic:det,2.474346,0.824782,the,1.0,3
4,1,Syntactic:prep,2.0,1.0,of,1.0,2
13,1,Syntactic:dobj,2.0,1.0,beans,1.0,2
3,1,Syntactic:nsubj,1.0,1.0,weight,1.0,1
5,1,Syntactic:pobj,1.0,1.0,box,1.0,1
12,1,Syntactic:compound,1.0,1.0,jelly,1.0,1
14,1,Syntactic:aux,1.0,1.0,to,1.0,1
11,1,Syntactic:amod,0.970814,0.970814,enough,0.970814,1
6,1,Syntactic:ROOT,0.872099,0.872099,was,0.872099,1
10,1,Syntactic:advcl,0.638482,0.638482,added,0.638482,1
