# 02 - Generate Training Traces

This notebook generates training data by having Claude Sonnet play NYT Connections games. Only games that Sonnet wins are saved as traces.

Two types of traces:
1. **Golden traces**: One-shot puzzle solving with step-by-step reasoning
2. **Game traces**: Multi-turn game transcripts with feedback (CORRECT, WRONG, ONE AWAY)

In [None]:
import json
import os
import re
from pathlib import Path

from anthropic import Anthropic

# Setup
# The API key is read from the environment rather than hardcoded, so it never ends up
# in version control or in shared notebook outputs.
ANTHROPIC_API_KEY = os.environ.get("ANTHROPIC_API_KEY")
if not ANTHROPIC_API_KEY:
    raise RuntimeError("Set the ANTHROPIC_API_KEY environment variable before running this notebook.")
client = Anthropic(api_key=ANTHROPIC_API_KEY)

DATA_DIR = Path("data")
OUTPUT_FILE = DATA_DIR / "game_traces.jsonl"             # winning traces, appended one JSON object per line
PROGRESS_FILE = DATA_DIR / "game_traces_progress.json"   # resume checkpoint (completed ids + counters)

SYSTEM_PROMPT = "You are an expert NYT Connections puzzle solver. Find groups of 4 words that share a hidden theme."
MODEL = "claude-sonnet-4-20250514"

In [None]:
# Load training puzzles: train.jsonl holds one JSON object per non-blank line.
# Decode explicitly as UTF-8 so loading does not depend on the platform's locale encoding.
with open(DATA_DIR / "train.jsonl", "r", encoding="utf-8") as f:
    train_puzzles = [json.loads(line) for line in f if line.strip()]

print(f"Loaded {len(train_puzzles)} training puzzles")

## Helper Functions

In [None]:
def extract_json_from_response(text):
    """Extract the last JSON object from a model response.

    A ```json fenced block is preferred: the last such block is parsed first.
    If there is none (or it fails to parse), the raw text is scanned for
    top-level JSON objects and the last one that decodes is returned. Parsing
    uses the real JSON decoder, so nested objects and braces inside JSON
    strings are handled correctly.

    Args:
        text: model response text (may be empty or None).

    Returns:
        The decoded dict, or None if no JSON object could be parsed.
    """
    if not text:
        return None

    # Try markdown code blocks first; the model's final answer is normally the last one.
    fenced = list(re.finditer(r'```json\s*(\{.*?\})\s*```', text, re.DOTALL))
    if fenced:
        try:
            return json.loads(fenced[-1].group(1))
        except json.JSONDecodeError:
            pass  # fall back to scanning the raw text below

    # Try raw JSON: attempt a decode at each '{' from left to right. After a
    # successful decode, jump past that whole object so its nested objects are
    # not reported separately. Keep the last object that decoded.
    decoder = json.JSONDecoder()
    last_obj = None
    pos = text.find('{')
    while pos != -1:
        try:
            obj, end = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            pos = text.find('{', pos + 1)
        else:
            last_obj = obj
            pos = text.find('{', end)
    return last_obj


def extract_guess_from_parsed(parsed):
    """Pick the guessed word list out of a parsed JSON response.

    Lookup order:
      1. the value under "group", then "words", if it is a list;
      2. otherwise the first list value holding exactly 4 items;
      3. otherwise the first list value of any length.

    Returns the list object found (not a copy), or [] when `parsed` is
    empty/None or contains no list values.
    """
    if not parsed:
        return []

    named = next(
        (parsed[key] for key in ("group", "words") if key in parsed and isinstance(parsed[key], list)),
        None,
    )
    if named is not None:
        return named

    list_values = [value for value in parsed.values() if isinstance(value, list)]
    exact_four = next((value for value in list_values if len(value) == 4), None)
    if exact_four is not None:
        return exact_four
    return list_values[0] if list_values else []

## Game Loop for Trace Generation

In [None]:
def play_game_with_sonnet(puzzle, max_mistakes=4, max_retries=16, temperature=0.3):
    """
    Play a full game with Sonnet, returning the transcript.
    Only returns winning games (for training data quality).

    Each turn the model is prompted with the remaining words. Its JSON guess is
    graded against ``puzzle["solution"]``, and the feedback (CORRECT, ONE AWAY,
    WRONG, INVALID, DUPLICATE, or a parse error) is appended to the conversation
    as a user message. Once three groups are found, the last four words are
    auto-completed as the final group.

    Args:
        puzzle: dict with ``"words"`` (the puzzle words) and ``"solution"``
            (group name -> member words).
        max_mistakes: wrong, one-away, invalid, or unparseable guesses (and
            API errors) allowed before the game ends as a loss.
        max_retries: cap on model turns. Auto-completing the final group does
            not consume a turn.
        temperature: sampling temperature for the Sonnet calls.

    Returns:
        dict with ``won``, ``groups_found``, ``mistakes`` and ``transcript``.
        ``transcript`` is the user/assistant message list for won games and
        ``None`` otherwise.
    """
    remaining_words = set(w.upper() for w in puzzle["words"])
    found_groups = {}           # group name -> list of member words
    mistakes = 0
    previous_guesses = []       # graded incorrect guesses: {"words": [...], "feedback": ...}
    tried_combinations = set()  # frozensets of every graded guess, for DUPLICATE detection
    total_attempts = 0          # model turns taken so far

    # Map each solution group (upper-cased frozenset) to its group name.
    solution_groups = {
        frozenset(w.upper() for w in members): name
        for name, members in puzzle["solution"].items()
    }

    transcript = []

    while len(found_groups) < 4 and mistakes < max_mistakes:
        # Auto-complete when 3 groups found. This runs before the attempt cap is
        # checked, so a game whose third group lands on the last allowed turn is
        # still finished and recorded as a win.
        if len(found_groups) == 3:
            final_set = frozenset(remaining_words)
            if final_set in solution_groups:
                found_groups[solution_groups[final_set]] = list(remaining_words)
                transcript.append({
                    "role": "user",
                    "content": f"CORRECT! Found group: {solution_groups[final_set]}\n\nAll 4 groups found. Puzzle solved!"
                })
                break

        if total_attempts >= max_retries:
            break
        total_attempts += 1

        # Build context
        context = ""
        if found_groups:
            found_str = ", ".join([f"{name}: {words}" for name, words in found_groups.items()])
            context += f"FOUND GROUPS: {found_str}\n\n"

        # Only show incorrect guesses made entirely of still-remaining words.
        # INVALID guesses always contain an unavailable word, so they drop out here.
        relevant_wrong = [
            g for g in previous_guesses
            if set(g['words']).issubset(remaining_words) and 'DUPLICATE' not in g.get('feedback', '')
        ]

        if relevant_wrong:
            wrong_str = ", ".join([str(g['words']) for g in relevant_wrong])
            context += f"WRONG GUESSES: {wrong_str}\n\n"

        remaining_list = ", ".join(sorted(remaining_words))
        one_away_guesses = [g for g in relevant_wrong if "ONE AWAY" in g.get('feedback', '')]

        # Build prompt: a targeted "swap one word" prompt after a ONE AWAY,
        # otherwise the general step-by-step prompt.
        if one_away_guesses:
            last_one_away = one_away_guesses[-1]['words']
            already_tried = [g['words'] for g in relevant_wrong]
            other_words = sorted(remaining_words - set(last_one_away))

            prompt = f"""{context}REMAINING: {remaining_list}

Your guess {last_one_away} was ONE AWAY — exactly 3 words are correct, 1 is wrong.
Already tried: {already_tried}

Which word is the impostor? Consider:
- Which word has the weakest connection to the theme?
- Which word might belong to a different group?

Pick ONE word to swap out, replace it with one of: {other_words}
Output: {{"group": ["W1", "W2", "W3", "W4"]}}"""
        else:
            prompt = f"""{context}REMAINING: {remaining_list}

Find 4 words that share a hidden theme.

Think step by step:
1. What patterns do you see? (categories, wordplay, phrases, etc.)
2. Which 4-word group are you MOST confident about?
3. Verify: Are all 4 words in the REMAINING list above?

Output your most confident group:
{{"group": ["W1", "W2", "W3", "W4"]}}"""

        transcript.append({"role": "user", "content": prompt})

        # Call API. On failure, drop the unanswered prompt and count the turn as a mistake.
        try:
            response = client.messages.create(
                model=MODEL,
                max_tokens=1500,
                temperature=temperature,
                system=SYSTEM_PROMPT,
                messages=transcript
            )
            assistant_content = response.content[0].text
        except Exception as e:
            print(f"API error: {e}")
            transcript.pop()
            mistakes += 1
            continue

        transcript.append({"role": "assistant", "content": assistant_content})

        # Parse response
        parsed = extract_json_from_response(assistant_content)
        if not parsed:
            mistakes += 1
            transcript.append({"role": "user", "content": "Could not parse your response. Please output valid JSON."})
            continue

        guess = extract_guess_from_parsed(parsed)
        # str() guards against non-string JSON values (numbers, null), which would
        # otherwise make .upper() raise and abort the whole generation run.
        guess_list = [str(w).strip().upper() for w in guess] if guess else []
        # Reject anything that is not exactly 4 *distinct* words: a repeated word
        # would shrink the set and could be misgraded as ONE AWAY.
        if len(guess_list) != 4 or len(set(guess_list)) != 4:
            mistakes += 1
            transcript.append({"role": "user", "content": "Invalid guess - need exactly 4 words. Try again."})
            continue

        guess_set = frozenset(guess_list)

        # Check duplicate (not counted as a mistake)
        if guess_set in tried_combinations:
            transcript.append({"role": "user", "content": f"DUPLICATE: You already tried {guess_list}. Try a different combination."})
            continue

        tried_combinations.add(guess_set)

        # Check valid words
        if not guess_set.issubset(remaining_words):
            invalid_words = guess_set - remaining_words
            mistakes += 1
            transcript.append({"role": "user", "content": f"INVALID: Words not in remaining list: {list(invalid_words)}"})
            previous_guesses.append({"words": guess_list, "feedback": "INVALID"})
            continue

        # Check if correct
        if guess_set in solution_groups:
            group_name = solution_groups[guess_set]
            found_groups[group_name] = guess_list
            remaining_words -= guess_set

            if len(found_groups) < 4 and len(remaining_words) > 0:
                transcript.append({
                    "role": "user",
                    "content": f"CORRECT! Found group: {group_name}: {guess_list}\n\n{4 - len(found_groups)} groups remaining."
                })
            else:
                transcript.append({
                    "role": "user",
                    "content": f"CORRECT! Found group: {group_name}\n\nAll 4 groups found. Puzzle solved!"
                })
        else:
            # Check ONE AWAY: 3 of the 4 guessed words belong to a single solution group.
            is_one_away = any(len(guess_set & sol) == 3 for sol in solution_groups.keys())
            mistakes += 1

            if is_one_away:
                transcript.append({"role": "user", "content": f"ONE AWAY! {guess_list} has exactly 3 correct words."})
                previous_guesses.append({"words": guess_list, "feedback": "ONE AWAY"})
            else:
                transcript.append({"role": "user", "content": f"WRONG: {guess_list} is not a valid group."})
                previous_guesses.append({"words": guess_list, "feedback": "WRONG"})

    # Only return winning games
    won = len(found_groups) == 4

    return {
        "won": won,
        "groups_found": len(found_groups),
        "mistakes": mistakes,
        "transcript": transcript if won else None
    }

## Generate Game Traces

In [None]:
# Resume checkpoint: which puzzles were already played and how many were won.
def load_progress():
    """Read the progress checkpoint from PROGRESS_FILE.

    Returns a fresh zeroed checkpoint (no completed ids, zero wins/total) when
    the file does not exist yet.
    """
    try:
        saved_text = PROGRESS_FILE.read_text()
    except FileNotFoundError:
        return {"completed_ids": [], "win_count": 0, "total_count": 0}
    return json.loads(saved_text)

def save_progress(progress, path=None):
    """Write the progress checkpoint as JSON.

    The data is first written to a sibling ``.tmp`` file, which then replaces the
    target in a single rename. Interrupting the cell mid-write therefore cannot
    leave a truncated, unreadable progress file behind.

    Args:
        progress: JSON-serialisable checkpoint dict.
        path: destination file; defaults to PROGRESS_FILE.
    """
    target = Path(path) if path is not None else PROGRESS_FILE
    tmp_path = target.with_name(target.name + ".tmp")
    with open(tmp_path, "w") as f:
        json.dump(progress, f)
    tmp_path.replace(target)

progress = load_progress()
completed_ids = set(progress["completed_ids"])

# Count the training puzzles not yet attempted. Subtracting len(completed_ids) would be
# wrong if the checkpoint holds ids that are no longer in train.jsonl.
remaining_count = sum(1 for p in train_puzzles if p["game_id"] not in completed_ids)

print(f"Progress: {progress['win_count']} wins / {progress['total_count']} total")
print(f"Remaining puzzles: {remaining_count}")

In [None]:
# Generate traces: play every not-yet-attempted training puzzle once, append each won
# game to OUTPUT_FILE, and checkpoint after every puzzle so the cell can be interrupted
# and re-run to resume.
TARGET_WINS = 200  # Stop after this many winning games

# Lazy filter: the membership test runs per puzzle, after earlier ids were added below.
pending_puzzles = (p for p in train_puzzles if p["game_id"] not in completed_ids)

for puzzle in pending_puzzles:
    if progress["win_count"] >= TARGET_WINS:
        print(f"\nReached target of {TARGET_WINS} winning traces!")
        break

    game_id = puzzle["game_id"]
    result = play_game_with_sonnet(puzzle)
    won = result["won"]

    if won:
        # Prepend the system prompt so every stored trace is a self-contained chat example.
        trace = {
            "messages": [{"role": "system", "content": SYSTEM_PROMPT}, *result["transcript"]],
            "metadata": {
                "puzzle_id": game_id,
                "mistakes": result["mistakes"],
                "source": "sonnet_game_trace",
            },
        }
        with open(OUTPUT_FILE, "a") as out:
            out.write(json.dumps(trace) + "\n")
        progress["win_count"] += 1

    # Checkpoint every puzzle (win or loss) so a re-run skips it.
    completed_ids.add(game_id)
    progress["completed_ids"] = list(completed_ids)
    progress["total_count"] += 1
    save_progress(progress)

    status = "✅" if won else "❌"
    win_rate = progress["win_count"] / progress["total_count"] * 100
    print(f"{status} Puzzle {game_id}: {result['groups_found']}/4 groups, {result['mistakes']} mistakes "
          f"| Total: {progress['win_count']}/{progress['total_count']} ({win_rate:.1f}%)")

## Analyze Generated Traces

In [None]:
# Load and analyze traces. OUTPUT_FILE is only created once a game has been won,
# so a missing file means zero traces rather than an error.
traces = []
if OUTPUT_FILE.exists():
    with open(OUTPUT_FILE, "r") as f:
        traces = [json.loads(line) for line in f if line.strip()]

print(f"Total traces: {len(traces)}")

# Group by number of mistakes made in each winning game
by_mistakes = {}
for t in traces:
    by_mistakes.setdefault(t["metadata"]["mistakes"], []).append(t)

print("\nDistribution by mistakes:")
for m in sorted(by_mistakes):
    print(f"  {m} mistakes: {len(by_mistakes[m])} traces")

In [None]:
# Count traces with ONE AWAY situations (game feedback is stored as user messages).
one_away_count = sum(
    any("ONE AWAY" in msg["content"] for msg in t["messages"] if msg["role"] == "user")
    for t in traces
)

# Guard the percentage: with no traces the ratio is undefined, so report 0.0%.
one_away_pct = one_away_count / len(traces) * 100 if traces else 0.0
print(f"\nTraces with ONE AWAY situations: {one_away_count}/{len(traces)} ({one_away_pct:.1f}%)")