# 01 · Seed cohort generation
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ORG/Fallrisk-gait/blob/main/datasets/fallrisk/notebooks/01_seed.ipynb)

This notebook generates a synthetic 2,000-row seed cohort with correlated demographics, health indicators, and mobility metrics. It then applies the `multi_feature_v1` label policy and saves the result to `data/seed_fallrisk.csv`. Sampling uses a fixed random seed (7), so re-running the notebook reproduces the same cohort.


In [None]:
from pathlib import Path
import csv
import math
import random
from statistics import mean

def locate_repo_root(max_depth: int = 6) -> Path:
    """Find the repository root by walking upward from the working directory.

    A directory qualifies as the root when it contains both a ``datasets``
    and a ``data`` entry. At most ``max_depth`` directories are inspected,
    starting with the current working directory itself. The search stops
    early at the filesystem root.

    Returns the first qualifying directory. If none is found, returns the
    current working directory.
    """
    start = Path.cwd()
    candidate = start
    remaining = max_depth
    while remaining > 0:
        has_markers = (candidate / 'datasets').exists() and (candidate / 'data').exists()
        if has_markers:
            return candidate
        parent = candidate.parent
        if parent == candidate:
            # Already at the filesystem root; there is nowhere left to climb.
            break
        candidate = parent
        remaining -= 1
    return start

# Fix the global random seed so the synthetic cohort is reproducible.
random.seed(7)

# Resolve the project root once and make sure the output directory exists.
ROOT = locate_repo_root()
DATA_DIR = ROOT.joinpath('data')
DATA_DIR.mkdir(exist_ok=True)


In [None]:
def _sample_past_falls(rng, gait: float, chronic: int) -> int:
    """Draw the number of falls in the past 6 months, as an int in [0, 4].

    ``rng`` is anything exposing ``random()`` (the ``random`` module or a
    ``random.Random`` instance).
    """
    # Probability of at least one fall. Slow gait (< 0.8 m/s) adds 0.12, and
    # each chronic condition beyond two adds 0.07.
    p_any = 0.28 + 0.12 * (gait < 0.8) + 0.07 * max(chronic - 2, 0)
    if rng.random() >= p_any:
        return 0
    # Recurrent fallers: each further fall occurs with a continuation
    # probability, higher for slow walkers. The count is capped at 4 to match
    # the column's documented range. (A single Bernoulli draw, as used before,
    # could only ever yield 0 or 1.) These constants are a modelling choice.
    p_more = 0.3 + 0.1 * (gait < 0.8)
    falls = 1
    while falls < 4 and rng.random() < p_more:
        falls += 1
    return falls


def _assign_risk_labels(*, tug, gait, falls, dual_task, assistive_device,
                        fear_score, med, chronic) -> 'tuple[int, str]':
    """Apply the ``multi_feature_v1`` label policy.

    Returns ``(label_high_fall_risk, label_risk_level)``:

    * Any single high-risk criterion gives ``(1, 'high')``.
    * Otherwise, any moderate criterion gives ``(0, 'moderate')``.
    * Otherwise, the result is ``(0, 'low')``.
    """
    high_risk = int(tug >= 13.5 or gait < 0.8 or falls >= 1 or dual_task >= 22
                    or assistive_device != 'None')
    if high_risk:
        return high_risk, 'high'
    moderate = tug >= 11.2 or fear_score >= 16 or med >= 6 or chronic >= 3
    return high_risk, ('moderate' if moderate else 'low')


def generate_seed_row(idx: int, rng: 'random.Random | None' = None) -> dict:
    """Generate one synthetic participant record with fall-risk labels.

    Args:
        idx: Participant number, rendered as ``SEED_{idx:05d}``.
        rng: Optional ``random.Random`` to draw from. When omitted, the
            module-level ``random`` functions (global state) are used.

    Returns:
        A dict of column name -> value. Key order is the CSV column order.
    """
    r = random if rng is None else rng

    # Demographics and health indicators. Each value is clamped to a
    # plausible range.
    age = max(55, min(95, r.gauss(72, 7)))
    sex = 'Female' if r.random() < 0.62 else 'Male'
    bmi_latent = r.gauss(27, 3)
    bmi_shift = 0.5 if sex == 'Female' else -0.3
    bmi = max(18, min(42, bmi_latent + bmi_shift))
    # int() truncates toward zero; the max() keeps counts non-negative.
    chronic = max(0, int(r.gauss(2.2, 1.4)))
    med = max(0, int(r.gauss(3.5 + 0.6 * chronic, 1.2)))
    muscle = max(20, min(90, r.gauss(70 - 0.4 * (age - 70) - 1.5 * chronic, 6)))

    # Mobility metrics, correlated with age, strength and comorbidity.
    gait_base = r.gauss(1.05, 0.12)
    gait = max(0.4, min(1.6, gait_base - 0.004 * (age - 70) + 0.002 * (muscle - 60) - 0.015 * chronic))
    stride = max(60, min(140, r.gauss(120, 8) + 5 * (gait - 1.0)))
    sway = max(0.5, min(6.0, r.gauss(2.2 + 0.02 * (age - 70), 0.6)))

    falls = _sample_past_falls(r, gait, chronic)

    device_prob = 0.2 + 0.15 * (gait < 0.85) + 0.1 * (falls > 0)
    if r.random() < device_prob:
        assistive_device = 'Walker' if r.random() < 0.45 else 'Cane'
    else:
        assistive_device = 'None'

    dual_task = max(0.0, min(45.0, r.gauss(12 + 6 * (gait < 0.9) + 0.15 * med, 5)))
    fear_score = max(0, min(28, int(r.gauss(11 + 4 * (falls > 0) + 0.12 * dual_task, 4))))
    reaction_time = max(300, min(900, r.gauss(620 + 2 * (age - 70) + 15 * (falls > 0), 60)))
    systolic = max(95, min(190, r.gauss(128 + 0.6 * (age - 65) + 1.2 * chronic, 12)))

    # Timed-Up-and-Go time: a linear mix of the features above plus noise.
    # Each prior fall adds 0.45 s.
    tug = 8.5 + 0.065 * (age - 65) + 0.09 * (bmi - 26) - 2.1 * (gait - 1.0) + 0.18 * max(0, sway - 2.5)
    tug += 0.07 * dual_task + 0.04 * fear_score + 0.45 * falls + (2.5 if assistive_device != 'None' else 0)
    tug += r.gauss(0, 0.9)
    tug = max(6.5, min(35, tug))

    high_risk, risk_level = _assign_risk_labels(
        tug=tug, gait=gait, falls=falls, dual_task=dual_task,
        assistive_device=assistive_device, fear_score=fear_score,
        med=med, chronic=chronic,
    )
    return {
        'participant_id': f'SEED_{idx:05d}',
        'age_years': round(age, 1),
        'sex': sex,
        'bmi': round(bmi, 1),
        'systolic_bp': round(systolic, 1),
        'gait_speed_m_s': round(gait, 3),
        'stride_length_cm': round(stride, 1),
        'postural_sway_cm': round(sway, 3),
        'medication_count': med,
        'chronic_conditions': chronic,
        'past_falls_6mo': falls,
        'assistive_device': assistive_device,
        'dual_task_cost_percent': round(dual_task, 2),
        'fear_of_falling_score': fear_score,
        'muscle_strength_score': round(muscle, 1),
        'reaction_time_ms': round(reaction_time, 1),
        'tug_seconds': round(tug, 3),
        'label_high_fall_risk': high_risk,
        'label_risk_level': risk_level
    }


In [None]:
# Re-seed immediately before sampling so that re-running only this cell
# regenerates exactly the same cohort. The value matches the setup cell's
# seed, so a top-to-bottom run produces the same output as before.
SEED = 7
N_ROWS = 2000
random.seed(SEED)
seed_rows = [generate_seed_row(i + 1) for i in range(N_ROWS)]
# Column order follows the key order of the dict returned by generate_seed_row.
fieldnames = list(seed_rows[0].keys())
output_path = DATA_DIR / 'seed_fallrisk.csv'
# The csv module expects newline=''. The encoding is pinned so the file does
# not depend on the platform's default locale.
with output_path.open('w', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(seed_rows)

# label_high_fall_risk is a 0/1 int, so its mean is the high-risk prevalence.
high_rate = mean(r['label_high_fall_risk'] for r in seed_rows)
print(f'Seed cohort saved to {output_path.resolve()}')
print(f'Total rows: {len(seed_rows)} | High-risk prevalence: {high_rate:.3f}')
