In [3]:
import json
import os
import glob
import networkx as nx
import matplotlib.pyplot as plt

# Create viz directory if it doesn't exist
viz_dir = 'gsm8k-mcts-qwen-32b/viz'
os.makedirs(viz_dir, exist_ok=True)

# Load all json files from the directory
json_files = glob.glob('gsm8k-mcts-qwen-32b/jsons/*.json')

# Fixed seed so spring_layout places nodes identically on every re-run.
LAYOUT_SEED = 42


def _average_reward(rewards):
    """Return the arithmetic mean of ``rewards``, or 0 for an empty list."""
    return sum(rewards) / len(rewards) if rewards else 0


def _build_answer_graph(data):
    """Build the MCTS answer tree of one JSON record as a directed graph.

    Nodes are keyed by the FULL answer text. Truncating to a prefix would
    merge distinct answers that start the same way, which MCTS refinements
    frequently do. Each node carries:
      * ``reward`` - mean of its ``to_explore_reward`` list, formatted ``.1f``;
      * ``color``  - ``'lightgreen'`` if listed in ``correct_answers``,
        ``'red'`` otherwise.
    Edges point parent -> child and are taken from ``fathers``. A node that
    appears only through an edge (not in ``to_explore``) has no reward entry
    of its own. It gets a ``'-'`` label and a light-gray colour (green if
    correct) instead of crashing the drawing step with a KeyError.
    """
    graph = nx.DiGraph()
    correct = set(data['correct_answers'])

    for answer in data['to_explore']:
        rewards = data['to_explore_reward'].get(answer, [0])
        graph.add_node(answer,
                       reward=f'{_average_reward(rewards):.1f}',
                       color='lightgreen' if answer in correct else 'red')

    for child, parent in data['fathers'].items():
        if parent is not None:
            graph.add_edge(parent, child)

    # add_edge silently creates attribute-less nodes; fill in defaults.
    for node, attrs in graph.nodes(data=True):
        attrs.setdefault('color', 'lightgreen' if node in correct else 'lightgray')
        attrs.setdefault('reward', '-')
    return graph


def _draw_answer_graph(graph, title, output_file):
    """Render ``graph`` with reward-only node labels and save it to ``output_file``."""
    fig, ax = plt.subplots(figsize=(12, 8))
    pos = nx.spring_layout(graph, seed=LAYOUT_SEED)

    node_colors = [graph.nodes[node]['color'] for node in graph.nodes()]
    nx.draw_networkx_nodes(graph, pos, node_color=node_colors, node_size=2000, ax=ax)
    nx.draw_networkx_edges(graph, pos, edge_color='gray', arrows=True, ax=ax)

    # Only the reward is shown; full answer texts would be unreadable.
    labels = {node: graph.nodes[node]['reward'] for node in graph.nodes()}
    nx.draw_networkx_labels(graph, pos, labels, font_size=8, ax=ax)

    ax.set_title(title)
    ax.axis('off')
    fig.savefig(output_file, bbox_inches='tight')
    # Close explicitly so looping over many files does not accumulate figures.
    plt.close(fig)


for json_file in json_files:
    with open(json_file) as f:
        data = json.load(f)

    G = _build_answer_graph(data)

    base_name = os.path.basename(json_file)
    # splitext only touches the real extension, unlike str.replace('.json', ...).
    output_file = os.path.join(viz_dir, os.path.splitext(base_name)[0] + '.png')
    _draw_answer_graph(G, f'Answer Graph for {base_name}', output_file)


In [9]:
# Initialize list to store all correct answers data
correct_answers_data = []

# Process all json files to analyze correct answers
for json_file in json_files:
    with open(json_file) as f:
        data = json.load(f)
        
    # Load correct answers and their data
    for answer in data['correct_answers']:
        rewards = data['to_explore_reward'].get(answer, [0])
        avg_reward = sum(rewards) / len(rewards) if rewards else 0
        # Get the first analysis from the reward_analysis list if it exists
        analysis = data['reward_analysis'].get(answer, ['No analysis available'])[0]
        
        correct_answers_data.append({
            'question': data['query'],
            'answer': answer,
            'reward': avg_reward,
            'analysis': analysis,
            'file': os.path.basename(json_file)
        })

# Sort all answers by reward ascending
correct_answers_data.sort(key=lambda x: -x['reward'])

# Display results for all files
print("\nCorrect Answers Analysis Across All Files:")
print("=" * 100)
for item in correct_answers_data:
    print(f"\nFile: {item['file']}")
    print(f"Reward Score: {item['reward']:.1f}")
    print("-" * 50)
    print("Question:")
    print(item['question'])
    print("-" * 50)
    print("Answer:")
    print(item['answer'])
    print("-" * 50) 
    print("Analysis:")
    print(item['analysis'])
    print("=" * 100)



Correct Answers Analysis Across All Files:

File: a831d905a7705d9cbf6693060db48679.json
Reward Score: 85.0
--------------------------------------------------
Question:
Candice put 80 post-it notes in her purse before she headed out to her job at the coffee shop.  On her way, she stopped off at the store and purchased a package of Post-it notes;  At work, she placed a single Post-it note on each of 220 different cups of coffee.  If she had 23 post-it notes remaining overall, how many Post-it notes were in the package that she purchased?
--------------------------------------------------
Answer:
[reasoning process] We start by defining the known quantities and how they interrelate in the problem. Candice initially had 80 post-it notes, she purchased some additional ones, used 220 post-it notes on cups of coffee, and ended up with 23 left over.

Let's denote:
- Initial post-it notes: \(80\)
- Additional post-it notes in the purchased package: \(x\)
- Post-it notes used on cups: \(220\)
-

In [37]:
# Initialize list to store all incorrect answers data
incorrect_answers_data = []
json_files = glob.glob('gsm8k-mcts-qwen-32b/jsons/*.json')

# Process all json files to analyze incorrect answers
for json_file in json_files:
    with open(json_file) as f:
        data = json.load(f)
        
    # Get all answers from answers_list and filter out excluded/correct ones
    correct_answers = set(data['correct_answers'])
    ground_truth = data['ground_truth']
    excluded_answers = set(data.get('exclude', []))
    all_answers = data['answers_list']
    incorrect_answers = [ans for ans in all_answers 
                        if ans not in correct_answers 
                        and ans not in excluded_answers]
    
    # Load incorrect answers and their data
    for answer in incorrect_answers:
        rewards = data['to_explore_reward'].get(answer, [0])
        avg_reward = sum(rewards) / len(rewards) if rewards else 0
        # Get the first analysis from the reward_analysis list if it exists
        analysis = data['reward_analysis'].get(answer, ['No analysis available'])
        
        incorrect_answers_data.append({
            'question': data['query'],
            'answer': answer,
            'ground_truth': ground_truth,
            'rewards': rewards,
            'avg_reward': avg_reward,
            'analysis': analysis,
            'file': os.path.basename(json_file)
        })

# Sort all answers by average reward
incorrect_answers_data.sort(key=lambda x: -x['avg_reward'])

# Display results for all files
print("\nIncorrect Answers Analysis Across All Files:")
print("=" * 100)
for item in incorrect_answers_data:
    print(f"\nFile: {item['file']}")
    print("Reward Scores:")
    for i, reward in enumerate(item['rewards'], 1):
        print(f"iteration {i}: {reward:.1f}")
    print("-" * 50)
    print("Question:")
    print(item['question'])
    print("-" * 50)
    print("Ground Truth:")
    print(item['ground_truth'])
    print("-" * 50)
    print("Answer:")
    print(item['answer'])
    print("-" * 50) 
    for i, analysis in enumerate(item['analysis'], 1):
        print(f"iteration {i}:\n {analysis}\n")
    print("=" * 100)



Incorrect Answers Analysis Across All Files:

File: fc3dc9eceeea5b2867ef12a018aa79e6.json
Reward Scores:
iteration 1: -3.0
iteration 2: -15.0
iteration 3: -20.0
--------------------------------------------------
Question:
Adrien's total salary was 30 percent higher than Lylah's. Four years later, his salary had increased, and he was earning 40% more than what he was making four years ago. If Adrien's and Lylah's salary increased simultaneously, and Adrien earned $40000 four years ago, calculate the total salary the two were receiving four years later?
--------------------------------------------------
Ground Truth:
Since Adrien was earning $40000 four years ago and received a raise that makes him earn 40% more, he received a 40/100*$40000 = $<<40/100*40000=16000>>16000 raise.
In total, four years later, Adrien's salary is $40000+$16000 = $56000
If four years ago Adrien was earning $40000, and Lylah's salary was 30% less, then Lylah's salary was 30/100*$40000= $12000 less than Adrien's

In [35]:
# Initialize list to store all answers data
all_answers_data = []
json_files = glob.glob('gsm8k-mcts-gpt-4o-mini-modified/jsons/*.json')
#json_files = glob.glob('gsm8k-mcts-qwen-32b/jsons/*.json')

# Process all json files to analyze answers
for json_file in json_files:
    with open(json_file) as f:
        data = json.load(f)
        
    # Get all answers from answers_list
    ground_truth = data['ground_truth']
    all_answers = data['answers_list']
    correct_answers = set(data['correct_answers'])
    excluded_answers = set(data.get('exclude', []))
    
    # Load all answers and their data
    for answer in all_answers:
        rewards = data['to_explore_reward'].get(answer, [0])
        avg_reward = sum(rewards) / len(rewards) if rewards else 0
        # Get the first analysis from the reward_analysis list if it exists
        
        is_correct = answer in correct_answers
        is_excluded = answer in excluded_answers
        
        all_answers_data.append({
            'question': data['query'],
            'answer': answer,
            'ground_truth': ground_truth,
            'rewards': rewards,
            'avg_reward': avg_reward,
            'is_correct': is_correct,
            'is_excluded': is_excluded,
            'file': os.path.basename(json_file)
        })

# Group by question and analyze answers
question_stats = {}
for item in all_answers_data:
    question = item['question']
    if question not in question_stats:
        question_stats[question] = {'total': 0, 'correct': 0, 'incorrect': 0}
    
    question_stats[question]['total'] += 1
    if item['is_correct']:
        question_stats[question]['correct'] += 1
    else:
        question_stats[question]['incorrect'] += 1

# Count questions with 3+ answers and show their correct/incorrect breakdown
questions_with_3plus_nodes = 0
for question, stats in question_stats.items():
    if stats['total'] >= 3:
        questions_with_3plus_nodes += 1

print(f"\nNumber of trees with 3 or more nodes: {questions_with_3plus_nodes}")
print("\nBreakdown of answers for trees with 3+ nodes:")
for question, stats in question_stats.items():
    if stats['total'] >= 3:
        print(f"Question has {stats['correct']} correct and {stats['incorrect']} incorrect answers")


Number of trees with 3 or more nodes: 9

Breakdown of answers for trees with 3+ nodes:
Question has 1 correct and 16 incorrect answers
Question has 0 correct and 18 incorrect answers
Question has 1 correct and 5 incorrect answers
Question has 0 correct and 18 incorrect answers
Question has 1 correct and 3 incorrect answers
Question has 1 correct and 3 incorrect answers
Question has 0 correct and 18 incorrect answers
Question has 1 correct and 5 incorrect answers
Question has 1 correct and 3 incorrect answers
