In [2]:
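# Create ../data/input/pairs.csv with 450 random (map_id, prompt_id, operator, param) rows.
# - map_id values: subfolder names under ../data/input/samples/pairs/ (zero-padded to 4 digits)
# - prompt_id values: from ../data/input/prompts.csv (column 'prompt_id' or 'id')
# - operator/param values: looked up per map_id in ../data/input/samples/metadata/meta.csv
# - Each prompt_id is used exactly once; map_id is sampled with replacement.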

import os
import numpy as np
import pandas as pd

# --- paths (relative to the notebook) ---
ROOT_DIR = "../data/input/samples/pairs"      # where map subfolders live
PROMPTS_CSV = "../data/input/prompts.csv"     # your prompts table (must exist)
META_CSV = "../data/input/samples/metadata/meta.csv"  # metadata with operator + param_value
OUTPUT_DIR = "../data/input"
OUTPUT_FILE = os.path.join(OUTPUT_DIR, "pairs.csv")

N_ROWS = 450        # total pairs
SEED = 42           # set None for non-deterministic

rng = np.random.default_rng(SEED)

# --- 1. Collect map_ids from folder names ---
if not os.path.isdir(ROOT_DIR):
    raise FileNotFoundError(f"Folder not found: {ROOT_DIR}")

# Skip hidden directories (e.g. ".ipynb_checkpoints", ".git"): they are not
# samples and would otherwise become bogus map_ids that fail the label lookup.
map_ids = [
    d for d in os.listdir(ROOT_DIR)
    if not d.startswith(".") and os.path.isdir(os.path.join(ROOT_DIR, d))
]
if not map_ids:
    raise RuntimeError(f"No subfolders found under {ROOT_DIR} (expected map_id folders).")

# Normalize map_ids (zero-padded 4 digits) *before* sorting, and de-duplicate,
# so the list order -- which feeds the seeded sampler below -- is well defined
# even when some folder names are unpadded ("80" and "0080" are the same map).
map_ids = sorted({m.zfill(4) for m in map_ids})

# --- 2. Read prompt_ids from prompts.csv ---
if not os.path.isfile(PROMPTS_CSV):
    raise FileNotFoundError(f"prompts.csv not found at: {PROMPTS_CSV}")

df_prompts = pd.read_csv(PROMPTS_CSV)
# Prefer 'prompt_id'; fall back to a generic 'id' column.
id_col = next((c for c in ("prompt_id", "id") if c in df_prompts.columns), None)
if id_col is None:
    raise ValueError("prompts.csv must have a 'prompt_id' or 'id' column.")

# read_csv turns empty cells into NaN, and NaN.astype(str) is the truthy string
# "nan" -- so drop missing values *before* converting, then drop whitespace-only ids.
prompt_ids = df_prompts[id_col].dropna().astype(str).str.strip().tolist()
prompt_ids = [p for p in prompt_ids if p]  # drop blanks
unique_prompts = sorted(set(prompt_ids))
if len(unique_prompts) < N_ROWS:
    raise ValueError(f"Need {N_ROWS} unique prompts; found only {len(unique_prompts)}.")

# Draw N_ROWS distinct prompts (each used exactly once), then pair each with a
# map drawn uniformly with replacement. Both draws come from the same seeded
# generator, in this order, so the result is reproducible for a fixed SEED.
prompt_ids = rng.choice(unique_prompts, size=N_ROWS, replace=False)
chosen_maps = rng.choice(map_ids, size=N_ROWS, replace=True)

pairs = pd.DataFrame({"map_id": chosen_maps, "prompt_id": prompt_ids})

# --- 3. Load metadata (operator + param_value) ---
if not os.path.isfile(META_CSV):
    raise FileNotFoundError(f"meta.csv not found at: {META_CSV}")

meta = pd.read_csv(META_CSV)

# Derive map_id (zero-padded to 4 digits, matching the folder-based map_ids)
# from sample_id, or fall back to the input_geojson path.
if "sample_id" in meta.columns:
    # first run of digits in sample_id, e.g. 80 -> "0080"
    meta["map_id"] = meta["sample_id"].astype(str).str.extract(r"(\d+)")[0].str.zfill(4)
elif "input_geojson" in meta.columns:
    # ".../0080_input.geojson" -> "0080"; accept "/" or "\" separators and pad
    # like the sample_id branch so both sources yield the same key format.
    meta["map_id"] = (
        meta["input_geojson"]
        .astype(str)
        .str.extract(r"(?:^|[\\/])(\d+)_input\.geojson$")[0]
        .str.zfill(4)
    )
else:
    raise ValueError("meta.csv must have a 'sample_id' or 'input_geojson' column.")

# Keep relevant columns; rows whose map_id could not be derived can never match
# a sampled pair, and exact duplicate rows carry no extra information.
meta_small = meta[["map_id", "operator", "param_value"]].copy()
meta_small = meta_small.rename(columns={"param_value": "param"})
meta_small = meta_small.dropna(subset=["map_id"]).drop_duplicates()

# --- 4. Merge operator + param into pairs ---
# validate="many_to_one": each map_id must have exactly one label row. Without
# it, a map_id listed twice with different labels would silently duplicate
# pairs rows (more than N_ROWS rows, prompts reused); pandas raises MergeError
# (a ValueError subclass) instead.
pairs = pairs.merge(meta_small, on="map_id", how="left", validate="many_to_one")

missing_ops   = pairs["operator"].isna().sum()
missing_param = pairs["param"].isna().sum()
if missing_ops or missing_param:
    missing_mask = pairs["operator"].isna() | pairs["param"].isna()
    bad_ids = sorted(pairs.loc[missing_mask, "map_id"].unique())
    shown = ", ".join(bad_ids[:10]) + (", ..." if len(bad_ids) > 10 else "")
    raise ValueError(
        f"Missing labels for some map_ids: operator={missing_ops}, param={missing_param}"
        f" (map_ids: {shown})"
    )

# --- 5. Save final CSV ---
# Make sure the target folder exists, then write the table without the index.
os.makedirs(OUTPUT_DIR, exist_ok=True)
pairs.to_csv(OUTPUT_FILE, index=False)

# Summary printed to the cell output so the run documents what was written.
n_maps_used = pairs["map_id"].nunique()
n_prompts_used = pairs["prompt_id"].nunique()
for line in (
    f"[OK] Saved {len(pairs)} rows to {OUTPUT_FILE}",
    f"Columns: {pairs.columns.tolist()}",
    f"Unique maps used: {n_maps_used} / {len(map_ids)} available",
    f"Unique prompts used: {n_prompts_used} (should be {N_ROWS})",
):
    print(line)
print(pairs.head(10))

[OK] Saved 450 rows to ../data/input/pairs.csv
Columns: ['map_id', 'prompt_id', 'operator', 'param']
Unique maps used: 242 / 300 available
Unique prompts used: 450 (should be 450)
  map_id prompt_id   operator    param
0   1203      p386   displace    1.316
1   0080      p203  aggregate    6.482
2   0171      p096     select   80.931
3   0523      p034  aggregate   11.525
4   1579      p031     select  149.262
5   0948      p411   displace    2.785
6   1344      p199   simplify    6.198
7   0867      p073  aggregate    8.208
8   0469      p416   simplify    6.296
9   0804      p413     select   77.560
