In [2]:
import json, os, random
import logging
import warnings

import dotenv
from datasets import load_dataset
from huggingface_hub import login, hf_hub_download
# OPTIMIZED VERSION - Process subset for faster development
from tqdm import tqdm

# Pull variables from a local .env file into the process environment so the
# access token never has to be hard-coded in the notebook.
dotenv.load_dotenv()

# Authenticate this session against the Hugging Face Hub.
HF_TOKEN = os.environ.get("HUGGINGFACE_ACCESS_TOKEN")
login(token=HF_TOKEN)

# Output locations for the generated chat-format JSONL files.
OUTDIR = "direct_response"
train_path = f"{OUTDIR}/train.jsonl"
val_path   = f"{OUTDIR}/val.jsonl"
os.makedirs(OUTDIR, exist_ok=True)


In [3]:
# Fetch the participant split metadata from the dataset repository; the
# call returns the local path of the downloaded (cached) file.
mapping_file = hf_hub_download(
    repo_id="socratesft/SocSci210",
    filename="metadata/participant_mapping.json",
    repo_type="dataset"
)

# JSON is UTF-8 by specification; pin the encoding so parsing does not depend
# on the platform's locale default.
with open(mapping_file, 'r', encoding="utf-8") as f:
    participant_mapping = json.load(f)


In [4]:
# Suppress the "Repo card metadata block was not found" notice (harmless - the
# dataset README is just missing its metadata header). The notice is emitted
# through the `logging` module by huggingface_hub, so a `warnings` filter alone
# does not hide it; raise the threshold of the emitting logger as well.
warnings.filterwarnings("ignore", message="Repo card metadata block was not found")
logging.getLogger("huggingface_hub.repocard").setLevel(logging.ERROR)

dataset = load_dataset("socratesft/SocSci210", token=HF_TOKEN)
# Verify the dataset loaded successfully
print(f"✓ Dataset loaded: {len(dataset['train'])} examples")
print(f"✓ Columns: {list(dataset['train'].features.keys())}")

Repo card metadata block was not found. Setting CardData to empty.
Downloading data: 100%|██████████| 17/17 [00:27<00:00,  1.61s/files]
Generating train split: 100%|██████████| 2901390/2901390 [00:06<00:00, 429863.84 examples/s]

✓ Dataset loaded: 2901390 examples
✓ Columns: ['sample_id', 'participant', 'demographic', 'stimuli', 'response', 'condition_num', 'task_num', 'prompt', 'reasoning', 'study_id']





In [5]:
# Narrow `dataset` from the split mapping returned by load_dataset to its
# "train" split. The type check keeps this cell idempotent: once `dataset` has
# been narrowed, re-running the cell is a no-op instead of failing on a second
# ["train"] lookup.
if isinstance(dataset, dict):
    dataset = dataset["train"]

In [6]:
def filter_by_demographics(dataset, demographic_filters):
    """Select the examples whose demographics satisfy every requested filter.

    Args:
        dataset: Either an object exposing a ``filter`` method (e.g. a
            HuggingFace dataset) or any plain iterable of examples. Each
            example is expected to be a mapping with a 'demographic' entry.
        demographic_filters: Mapping of demographic key -> required value.

    Returns:
        ``dataset.filter(...)`` when that method exists, otherwise a list of
        the matching examples. An example is kept only if its 'demographic'
        entry is present and not None, contains ALL keys named in
        ``demographic_filters``, and equals the required value for each one.

    Examples:
        # Filter for females only
        filtered = filter_by_demographics(ds, {"gender": "Female"})

        # Filter for married females with college education
        filtered = filter_by_demographics(ds, {
            "gender": "Female",
            "marital_status": "Married",
            "education": "Post grad study/professional degree"
        })
    """
    def keep_example(example):
        # No usable demographic record -> reject outright.
        if 'demographic' not in example:
            return False
        profile = example['demographic']
        if profile is None:
            return False

        # Every filtered key has to be present before values are compared.
        if not all(field in profile for field in demographic_filters):
            return False

        # Keep the example only if no filtered field deviates from its target.
        return not any(
            profile.get(field) != wanted
            for field, wanted in demographic_filters.items()
        )

    # Plain iterables are filtered eagerly into a list; anything providing its
    # own `filter` method (HuggingFace dataset) handles the predicate itself.
    if not hasattr(dataset, 'filter'):
        return list(filter(keep_example, dataset))
    return dataset.filter(keep_example)

# Build the demographic subset that the conversion cell further down reads as
# `filtered_dataset`. Edit DEMOGRAPHIC_FILTERS to target a different group.
DEMOGRAPHIC_FILTERS = {"gender": "Male"}

if 'dataset' in globals():
    print("\n📊 Dataset info:")
    print(f"Total examples: {len(dataset):,}")

    # Keep only respondents matching DEMOGRAPHIC_FILTERS (male respondents).
    # The error is reported and then re-raised so a failure stops the run here
    # instead of leaving `filtered_dataset` undefined (or stale from an earlier
    # run) for the cell that consumes it.
    try:
        filtered_dataset = filter_by_demographics(dataset, DEMOGRAPHIC_FILTERS)
    except Exception as e:
        print(f"Error testing filter: {e}")
        raise

    print(f"Filtered dataset size: {len(filtered_dataset):,}")
else:
    print("\n'dataset' not loaded yet. Run the previous cells first!")


📊 Dataset info:
Total examples: 2,901,390


Filter: 100%|██████████| 2901390/2901390 [01:08<00:00, 42245.71 examples/s]

Filtered dataset size: 791,278





In [7]:

# Convert the demographically filtered examples into chat-format JSONL files,
# split into train/validation according to participant_mapping.
subset = filtered_dataset
print(f"Full subset size: {len(subset):,} examples")


# Optional single system instruction
SYSTEM_TXT =  "You are simulating a survey respondent. Answer exactly as instructed, following the specified response format without additional commentary."


# Create sets for faster lookup
seen_participants = set(participant_mapping['seen'])
unseen_participants = set(participant_mapping['unseen'])

train_rows = []
val_rows = []
processed = 0
skipped = 0

for ex in tqdm(subset, desc="Converting to chat format"):
    # A missing prompt comes back as None; coerce it to "" so the example is
    # counted as skipped below instead of raising on .strip().
    prompt = (ex.get("prompt") or "").strip()
    response = ex.get("response")
    # NOTE(review): this reads `study_id`, yet it is matched against the
    # participant mapping and the dataset also has a `participant` column.
    # Confirm which id participant_mapping['seen'/'unseen'] actually contains.
    participant = ex.get("study_id")

    if not prompt or response is None:
        skipped += 1
        continue

    assistant = str(response).strip()
    message_obj = {
        "messages": [
            {"role": "system", "content": SYSTEM_TXT},
            {"role": "user", "content": prompt},
            {"role": "assistant", "content": assistant}
        ]
    }

    # Split based on participant mapping
    if participant in seen_participants:
        train_rows.append(message_obj)
    elif participant in unseen_participants:
        val_rows.append(message_obj)
    else:
        # Skip examples whose id is in neither set
        skipped += 1
        continue

    processed += 1

print(f"✓ Processed: {processed:,} examples, Skipped: {skipped:,} examples")
print(f"Split based on participant_mapping: {len(train_rows):,} train (seen), {len(val_rows):,} validation (unseen)")

print(f"Writing files: {len(train_rows):,} train, {len(val_rows):,} validation...")

# Write files. Records are dumped with ensure_ascii=False, so non-ASCII text is
# written verbatim; pin UTF-8 so the output does not depend on the platform's
# locale default encoding.
for path, rows in ((train_path, train_rows), (val_path, val_rows)):
    with open(path, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")

print("✓ Successfully wrote:")
print(f"  - {train_path}: {len(train_rows):,} examples")
print(f"  - {val_path}: {len(val_rows):,} examples")


Full subset size: 791,278 examples


Converting to chat format: 100%|██████████| 791278/791278 [01:31<00:00, 8654.94it/s]


✓ Processed: 791,278 examples, Skipped: 0 examples
Split based on participant_mapping: 678,797 train (seen), 112,481 validation (unseen)
Writing files: 678,797 train, 112,481 validation...
✓ Successfully wrote:
  - direct_response/train.jsonl: 678,797 examples
  - direct_response/val.jsonl: 112,481 examples
