In [None]:
import json
import re
from pathlib import Path
from typing import List, Dict, Any
from vllm import LLM, SamplingParams

# Configuration
BASE_PATH = Path("/specific/scratches/parallel/evyataroren-2025-12-31/inter_2025")
MODEL_PATH = BASE_PATH / "models" / "DS-qwen-7B" / "DeepSeek-R1-Distill-Qwen-7B"
DATASET_PATH = BASE_PATH / "datasets" / "datasets" / "normalized_datasets.json"
OUTPUT_DIR = BASE_PATH / "prompts" / "play" / "injections" / "irrelevant"
PROMPT_TEMPLATE = BASE_PATH / "prompts" / "play" / "question_formats" / "chain_of_thought.md"

# Create output directory if it doesn't exist
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

# Number of questions to process
N_QUESTIONS = 200

# Seed for vLLM's per-request sampling so a Restart & Run All reproduces
# the same generations (and therefore the same thought_XXXX.txt files).
SEED = 42

# Sampling parameters
SAMPLING_PARAMS = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    # NOTE(review): 2048 new tokens may be short for R1-style reasoning; a
    # generation that hits this limit can end before its closing </think>.
    max_tokens=2048,
    # NOTE(review): "\n\n\n" also stops generation if the model emits a triple
    # newline in the middle of its reasoning — confirm this is intended.
    stop=["<|endoftext|>", "\n\n\n"],
    # Keep special tokens in the decoded text in case the <think>/</think>
    # markers are registered as special tokens — TODO confirm in tokenizer config.
    skip_special_tokens=False,
    seed=SEED,
)

print("Configuration loaded")

: 

In [None]:
# Build the vLLM engine once; the generation cell below reuses `llm`.
print("Loading model...")
llm = LLM(
    model=str(MODEL_PATH),
    trust_remote_code=True,
    # The tokenizer is required: later cells read decoded `.text` from outputs.
    skip_tokenizer_init=False,
    # NOTE(review): float16 is forced here — confirm it matches the checkpoint's native dtype.
    dtype="float16",
    gpu_memory_utilization=0.9,
)
print("Model loaded successfully!")


In [None]:
# Read the normalized dataset file and keep (at most) the first N_QUESTIONS
# entries of its "MATH-500" split.
print(f"Loading {N_QUESTIONS} questions from MATH-500...")
with DATASET_PATH.open('r', encoding='utf-8') as f:
    data = json.load(f)

math500_questions = data.get("MATH-500", [])
too_few = len(math500_questions) < N_QUESTIONS
if too_few:
    print(f"Warning: Only {len(math500_questions)} questions available, using all.")
questions = math500_questions if too_few else math500_questions[:N_QUESTIONS]

print(f"Loaded {len(questions)} questions")


In [None]:
# Read the chain-of-thought question template as one string; it is filled
# per question in the prompt-preparation cell.
prompt_template = PROMPT_TEMPLATE.read_text(encoding='utf-8')

print("Prompt template loaded")


In [None]:
def extract_thinking_process(text: str) -> str:
    """
    Extract the reasoning section from a model generation.

    Three output shapes are recognized, checked in this order:

    1. ``<think> ... </think>``: the text between the first ``<think>`` and
       the first ``</think>`` that follows it.
    2. ``<think> ...`` with no closing tag (e.g. generation stopped early):
       everything after the first ``<think>``.
    3. ``... </think>`` with no opening tag at all: everything before the
       first ``</think>``. This happens when the prompt itself ends with
       ``<think>``, so the model never emits the opening tag.

    Args:
        text: decoded generation text (``None`` is treated as empty).

    Returns:
        The stripped reasoning text, or ``""`` when no tag is present.
    """
    if not text:
        return ""

    # Case 1: complete <think> ... </think> pair.
    match = re.search(r'<think>(.*?)</think>', text, re.DOTALL)
    if match:
        return match.group(1).strip()

    # Case 2: opening tag only — keep everything after it.
    match_start = re.search(r'<think>(.*)', text, re.DOTALL)
    if match_start:
        return match_start.group(1).strip()

    # Case 3: closing tag only — the reasoning is everything before it.
    close_idx = text.find('</think>')
    if close_idx != -1:
        return text[:close_idx].strip()

    return ""

print("Helper function defined")


In [None]:
# Prepare prompts for all questions
print("Preparing prompts...")


def _fill_template(template: str, question: str) -> str:
    """Return ``template`` with its ``{question}`` placeholder set to ``question``.

    ``str.format`` is tried first so templates that escape literal braces as
    ``{{``/``}}`` keep their original rendering. A markdown/LaTeX template with
    other unescaped braces (e.g. ``\\boxed{}``) makes ``str.format`` raise, so
    fall back to a plain replacement of the ``{question}`` placeholder.
    """
    try:
        return template.format(question=question)
    except (KeyError, IndexError, ValueError):
        return template.replace("{question}", question)


# Without a placeholder every prompt would silently be the same text.
if "{question" not in prompt_template:
    raise ValueError(f"Prompt template {PROMPT_TEMPLATE} has no '{{question}}' placeholder")

batch_prompts = []
question_metadata = []

for i, question_data in enumerate(questions):
    question_text = question_data["question"]
    batch_prompts.append(_fill_template(prompt_template, question_text))
    question_metadata.append({
        "index": i,
        "question": question_text,
        "answer": question_data.get("answer", ""),
        "originalattr": question_data.get("originalattr", {})
    })

print(f"Prepared {len(batch_prompts)} prompts")


In [None]:
# Generate completions for every prepared prompt in one batched vLLM call;
# `outputs` is aligned index-by-index with `batch_prompts`.
print("Running forward pass on all questions...")
outputs = llm.generate(batch_prompts, sampling_params=SAMPLING_PARAMS, use_tqdm=True)

print(f"Generated {len(outputs)} responses")


In [None]:
# Extract thinking processes and save each one to OUTPUT_DIR/thought_XXXX.txt
print("Extracting thinking processes and saving...")
all_generated_tokens = []  # decoded generation text per question (strings, not token ids)
n_saved = 0
n_truncated = 0

for i, (_meta, output) in enumerate(zip(question_metadata, outputs)):
    completion = output.outputs[0]
    full_generation = completion.text
    all_generated_tokens.append(full_generation)

    # finish_reason == "length" means max_tokens was reached, so the
    # reasoning may be cut off.
    if getattr(completion, "finish_reason", None) == "length":
        n_truncated += 1

    thinking_process = extract_thinking_process(full_generation)

    if thinking_process:
        output_file = OUTPUT_DIR / f"thought_{i:04d}.txt"
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(thinking_process)
        n_saved += 1
    else:
        print(f"  Warning: No thinking process found for question {i}")

    # Report progress regardless of whether this question had a thinking section.
    if (i + 1) % 20 == 0:
        print(f"  Processed {i + 1}/{len(outputs)} questions...")

if n_truncated:
    print(f"  Warning: {n_truncated} generations hit max_tokens and may be truncated")

print(f"\nSaved {n_saved} thinking processes to {OUTPUT_DIR}")


In [None]:
# Save all generated text (plus extracted thinking) to one JSON file for reference.
output_summary_file = OUTPUT_DIR / "generated_tokens_summary.json"

# Extract the thinking section once per generation and reuse it for the count.
generation_records = [
    {
        "index": i,
        "full_generation": gen,
        "thinking_process": extract_thinking_process(gen),
        "question": meta["question"],
        "answer": meta["answer"]
    }
    for i, (gen, meta) in enumerate(zip(all_generated_tokens, question_metadata))
]

summary_data = {
    "total_questions": len(questions),
    "total_generated": len(all_generated_tokens),
    "questions_with_thinking": sum(1 for rec in generation_records if rec["thinking_process"]),
    "generations": generation_records,
}

with open(output_summary_file, 'w', encoding='utf-8') as f:
    json.dump(summary_data, f, indent=2, ensure_ascii=False)

print(f"Saved summary with all generated tokens to {output_summary_file}")
print("\nDone!")
