In [None]:
from datasets import load_dataset
import json, os, random
import dotenv
from huggingface_hub import login, hf_hub_download
# OPTIMIZED VERSION - Process subset for faster development
from tqdm import tqdm
import warnings

dotenv.load_dotenv()

HF_TOKEN = os.getenv("HUGGINGFACE_ACCESS_TOKEN")
login(token=HF_TOKEN)

OUTDIR = "reasoning_traces"
os.makedirs(OUTDIR, exist_ok=True)
train_path = f"{OUTDIR}/train.jsonl"
val_path   = f"{OUTDIR}/val.jsonl"


In [None]:
mapping_file = hf_hub_download(
    repo_id="socratesft/SocSci210",
    filename="metadata/participant_mapping.json",
    repo_type="dataset"
)

with open(mapping_file, 'r') as f:
    participant_mapping = json.load(f)


In [None]:
# Suppress the metadata warning (this is harmless - just missing README metadata)
import warnings
warnings.filterwarnings("ignore", message="Repo card metadata block was not found")

dataset = load_dataset("socratesft/SocSci210", token=HF_TOKEN)
# Verify the dataset loaded successfully
print(f"✓ Dataset loaded: {len(dataset['train'])} examples")
print(f"✓ Columns: {list(dataset['train'].features.keys())}") 

Repo card metadata block was not found. Setting CardData to empty.
Using the latest cached version of the dataset since socratesft/SocSci210 couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /Users/amader/.cache/huggingface/datasets/socratesft___soc_sci210/default/0.0.0/048481111a4425ed83dc0eacf15f8431f252b21a (last modified on Fri Sep 26 22:43:56 2025).


ValueError: Column 'train' doesn't exist.

In [35]:
dataset

Dataset({
    features: ['sample_id', 'participant', 'demographic', 'stimuli', 'response', 'condition_num', 'task_num', 'prompt', 'reasoning', 'study_id'],
    num_rows: 2901390
})

In [None]:
def filter_by_demographics(dataset, demographic_filters):
    """    
    Returns:
        Filtered dataset containing only examples that match ALL specified criteria
        and have ALL required demographic keys present
    
    Examples:
        # Filter for females only
        filtered = filter_by_demographics(ds, {"gender": "Female"})
        
        # Filter for married females with college education
        filtered = filter_by_demographics(ds, {
            "gender": "Female", 
            "marital_status": "Married",
            "education": "Post grad study/professional degree"
        })
    """
    def matches_criteria(example):
        # Check if demographic field exists
        if 'demographic' not in example or example['demographic'] is None:
            return False
            
        demographic = example['demographic']
        
        # Check if all required keys exist in demographic data
        for key in demographic_filters.keys():
            if key not in demographic:
                return False
        
        # Check if all criteria match
        for key, required_value in demographic_filters.items():
            actual_value = demographic.get(key)
            if actual_value != required_value:
                return False
                
        return True
    
    # Filter the dataset
    if hasattr(dataset, 'filter'):  # HuggingFace dataset
        return dataset.filter(matches_criteria)
    else:  # Regular list/iterable
        return [ex for ex in dataset if matches_criteria(ex)]

# Test the function with a sample
if 'ds' in globals():
    print(f"\n📊 Dataset info:")
    print(f"Total examples: {len(dataset):,}")
    
    # Test filter for females
    try:
        female_subset = filter_by_demographics(dataset, {"gender": "Female"})
        print(f"Female examples: {len(female_subset):,}")
        
        # Show a sample demographic
        if len(female_subset) > 0:
            sample_demo = female_subset[0]['demographic']
            print(f"Sample female demographic keys: {list(sample_demo.keys())}")
    except Exception as e:
        print(f"Error testing filter: {e}")
else:
    print("\n⚠️  Dataset 'ds' not loaded yet. Run the previous cells first!")

Function defined! Example usage:

# Filter for females only:
filtered_females = filter_by_demographics(ds, {"gender": "Female"})

# Filter for married people in metro areas:
filtered_married_metro = filter_by_demographics(ds, {
    "marital_status": "Married",
    "metro_status": "Metro Area"
})

# Filter by multiple criteria:
filtered_complex = filter_by_demographics(ds, {
    "gender": "Female",
    "education": "Post grad study/professional degree",
    "employment": "Employed as paid employee"
})

📊 Dataset info:
Total examples: 2,901,390


Filter: 100%|██████████| 2901390/2901390 [01:07<00:00, 43118.19 examples/s]

Female examples: 791,278
Sample female demographic keys: ['age', 'education', 'employment', 'ethnicity', 'gender', 'household_size', 'housing_ownership', 'housing_type', 'ideology', 'income', 'internet_access', 'location', 'marital_status', 'metro_status', 'party_id', 'phone_service']





In [None]:

# USE FILTERED DATASET FROM ABOVE HERE
ds = female_subset
print(f"Full dataset size: {len(ds):,} examples")


# Optional single system instruction
SYSTEM_TXT = (
    "You are simulating a survey respondent. You are to answer exactly as instructed, "
    "but also include your reasoning (5 sentences or less) before you output your answer. Please follow the exact output format below.\n"
    "### Output format\n<trace>\n...your step-by-step reasoning here...\n</trace>\n"
    "PREDICTION: <verbatim answer>  (conclude with predicted answer, use exactly the option label/number with no extra commentary)"
)

def build_assistant_text(reasoning, response):
    r = (reasoning or "").strip()
    y = str(response).strip()
    if r:
        return f"<trace>{r}</trace>\nPREDICTION: {y}"
    return f"PREDICTION: {y}"

# Process SUBSET for faster development - change this number as needed
# max_examples = 5000  # Start with 5k examples instead of 2.9M!
# print(f"Processing first {max_examples:,} examples for development...")

# Select a subset for processing
# subset = ds.select(range(min(max_examples, len(ds))))

subset = ds

# Create sets for faster lookup
seen_participants = set(participant_mapping['seen'])
unseen_participants = set(participant_mapping['unseen'])

train_rows = []
val_rows = []
processed = 0
skipped = 0

for ex in tqdm(subset, desc="Converting to chat format"):
    prompt = ex.get("prompt") .strip()
    response = ex.get("response")
    reasoning = (ex.get("reasoning") or "").strip()
    study_id = ex.get("study_id")
    
    if not prompt or response is None:
        skipped += 1
        continue

    assistant = build_assistant_text(reasoning, response)
    message_obj = {
        "messages": [
            {"role": "system", "content": SYSTEM_TXT},
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": assistant}
        ]
    }
    
    # Split based on participant mapping
    if study_id in seen_participants:
        train_rows.append(message_obj)
    elif study_id in unseen_participants:
        val_rows.append(message_obj)
    else:
        # Skip examples from participants not in either set
        skipped += 1
        continue
    
    processed += 1

print(f"✓ Processed: {processed:,} examples, Skipped: {skipped:,} examples")
print(f"Split based on participant_mapping: {len(train_rows):,} train (seen), {len(val_rows):,} validation (unseen)")

print(f"Writing files: {len(train_rows):,} train, {len(val_rows):,} validation...")

# Write files
with open(train_path, "w") as f:
    for r in train_rows: 
        f.write(json.dumps(r, ensure_ascii=False) + "\n")

with open(val_path, "w") as f:
    for r in val_rows:   
        f.write(json.dumps(r, ensure_ascii=False) + "\n")

print(f"✓ Successfully wrote:")
print(f"  - {train_path}: {len(train_rows):,} examples") 
print(f"  - {val_path}: {len(val_rows):,} examples")


Full dataset size: 791,278 examples


Converting to chat format: 100%|██████████| 791278/791278 [01:03<00:00, 12510.42it/s]


✓ Processed: 791,278 examples, Skipped: 0 examples
Writing files: 775,453 train, 15,825 validation...
✓ Successfully wrote:
  - ax_data/train.jsonl: 775,453 examples
  - ax_data/val.jsonl: 15,825 examples
