# 01 · Seed cohort generation
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ORG/Fallrisk-gait/blob/main/datasets/fallrisk/notebooks/01_seed.ipynb)

This notebook generates a synthetic 2,000-row seed cohort with correlated demographics, health indicators, and mobility metrics. It then applies the `multi_feature_v1` label policy and saves the labeled cohort to `data/seed_fallrisk.csv`.


In [None]:
from pathlib import Path
import csv
import math
import random
from statistics import mean
from typing import Dict, Iterable, List


def locate_repo_root(max_depth: int = 6) -> Path:
    """Find the nearest directory that contains both ``datasets`` and ``data``.

    The search starts at the current working directory and walks upward,
    inspecting at most ``max_depth`` directories (the starting directory
    counts as the first one) and never going past the filesystem root.

    Args:
        max_depth: Maximum number of directories to inspect.

    Returns:
        The first directory holding both marker entries, or the current
        working directory when no such directory is found.
    """
    start = Path.cwd()
    # The cwd followed by every ancestor up to the root, capped at max_depth
    # entries; a non-positive depth inspects nothing.
    candidates = [start, *start.parents][:max(max_depth, 0)]
    for candidate in candidates:
        if all((candidate / marker).exists() for marker in ('datasets', 'data')):
            return candidate
    return start


# Resolve the repository root once; the output directory is <ROOT>/data.
ROOT = locate_repo_root()
DATA_DIR = ROOT / 'data'
# Create the output directory if it is missing (no error when it already exists).
DATA_DIR.mkdir(exist_ok=True)
# Fixed seed so every draw made through the `random` module is reproducible across runs.
random.seed(7)


In [None]:
def clamp(value: float, lower: float, upper: float) -> float:
    """Restrict ``value`` to the closed interval ``[lower, upper]``.

    Args:
        value: Number to restrict.
        lower: Smallest value that may be returned.
        upper: Largest value that may be returned.

    Returns:
        ``value`` when it lies inside the interval, otherwise the bound it
        crossed. When ``lower`` exceeds ``upper`` the result is ``lower``.
    """
    # Cap from above first, then lift to the floor.
    capped = value if value < upper else upper
    return capped if capped > lower else lower


# Per-feature settings for label policy B.
#   direction    -- 'low' means values at or below a cutoff count as a hit;
#                   any other value means values at or above a cutoff count.
#   moderate_pct -- percentile (as a 0-1 fraction) used for the moderate cutoff.
#   high_pct     -- percentile (as a 0-1 fraction) used for the high cutoff.
# Iteration order of this dict determines the order of the per-feature
# cutoff columns added to each row.
POLICY_B_CONFIG: Dict[str, Dict[str, float | str]] = {
    'gait_speed_mps': {'direction': 'low', 'moderate_pct': 0.25, 'high_pct': 0.10},
    'stride_length_m': {'direction': 'low', 'moderate_pct': 0.25, 'high_pct': 0.10},
    'cadence_spm': {'direction': 'low', 'moderate_pct': 0.25, 'high_pct': 0.10},
    'stride_time_var': {'direction': 'high', 'moderate_pct': 0.75, 'high_pct': 0.90},
    'double_support_pct': {'direction': 'high', 'moderate_pct': 0.75, 'high_pct': 0.90},
    'symmetry_index': {'direction': 'high', 'moderate_pct': 0.75, 'high_pct': 0.90},
    'turn_time_s': {'direction': 'high', 'moderate_pct': 0.75, 'high_pct': 0.90},
    'sit_to_stand_s': {'direction': 'high', 'moderate_pct': 0.75, 'high_pct': 0.90},
    'stand_to_sit_s': {'direction': 'high', 'moderate_pct': 0.75, 'high_pct': 0.90},
}


# Severity rank of each risk level; used to pick the most severe of several levels.
RISK_ORDER = {'low': 0, 'moderate': 1, 'high': 2}


def percentile(sorted_values: List[float], pct: float) -> float:
    """Return the ``pct`` quantile of pre-sorted data by linear interpolation.

    Args:
        sorted_values: Values in ascending order; the function does not sort.
        pct: Quantile as a fraction. Values at or below 0 select the first
            element and values at or above 1 select the last.

    Returns:
        The interpolated quantile, or NaN when ``sorted_values`` is empty.
    """
    if not sorted_values:
        return float('nan')
    if pct <= 0:
        return sorted_values[0]
    if pct >= 1:
        return sorted_values[-1]

    # Fractional rank of the requested quantile within the index range.
    position = (len(sorted_values) - 1) * pct
    below = math.floor(position)
    above = math.ceil(position)
    if below == above:
        # The rank falls exactly on an element; no interpolation needed.
        return sorted_values[below]

    start = sorted_values[below]
    end = sorted_values[above]
    weight = position - below
    return start + (end - start) * weight


def generate_seed_row(idx: int) -> dict:
    """Draw one synthetic participant record from the module-level ``random`` state.

    Age and sex are sampled first; a latent ``strength`` score (not part of
    the returned record) and gait speed follow from them, and the remaining
    mobility metrics are sampled around means that shift with gait speed
    and, for some, age. A TUG time is then assembled from the component
    times plus noise and penalty terms.

    The sequence of random draws is fixed, so for a given seed the output
    depends on how many rows were generated before this one.

    Args:
        idx: Row number, used only to build ``participant_id``
            (``SEED_`` followed by the zero-padded index).

    Returns:
        A dict with the participant id, age, sex, nine gait/mobility
        features and ``tug_seconds``; numeric values are rounded.
    """
    # Demographics: age is a clamped normal draw, sex is 'Female' with probability 0.62.
    base_age = random.gauss(72, 7)
    age = clamp(base_age, 55, 95)
    sex = 'Female' if random.random() < 0.62 else 'Male'

    # Latent strength score: its mean falls with age; it only feeds gait speed below.
    strength = clamp(random.gauss(65 - 0.35 * (age - 70), 7.5), 30, 88)
    agility_shift = -0.02 if sex == 'Female' else 0.0
    gait_baseline = random.gauss(1.05 + agility_shift, 0.12)
    # Gait speed decreases with age and increases with strength.
    gait_speed = clamp(gait_baseline - 0.0035 * (age - 70) + 0.0025 * (strength - 60), 0.45, 1.55)

    # Stride length and cadence both rise with gait speed.
    stride_length = clamp(random.gauss(1.22, 0.08) + 0.06 * (gait_speed - 1.0), 0.75, 1.60)
    cadence = clamp(random.gauss(108, 6) + 14 * (gait_speed - 1.0), 70, 150)
    # Presumably seconds per stride, assuming cadence is steps/min and two steps per stride -- TODO confirm.
    stride_time = 120.0 / cadence
    # Variability mean grows with long stride times, slow gait and age.
    stride_time_var = clamp(
        random.gauss(0.012 + 0.06 * max(0.0, stride_time - 1.0) + 0.015 * max(0.0, 0.95 - gait_speed) + 0.0004 * (age - 70), 0.006),
        0.008,
        0.100,
    )
    # The remaining features have means that increase as gait speed drops
    # (and, for the timed tasks, as age rises).
    double_support = clamp(random.gauss(28 + 18 * (0.95 - gait_speed), 4.5), 12, 55)
    symmetry_index = clamp(random.gauss(0.07 + 0.025 * (0.9 - gait_speed), 0.025), 0.0, 0.30)
    turn_time = clamp(random.gauss(2.4 + 0.30 * (0.95 - gait_speed) + 0.010 * (age - 70), 0.32), 1.2, 7.0)
    sit_to_stand = clamp(random.gauss(1.7 + 0.38 * (0.92 - gait_speed) + 0.009 * (age - 70), 0.30), 1.2, 5.0)
    stand_to_sit = clamp(random.gauss(1.6 + 0.32 * (0.92 - gait_speed) + 0.008 * (age - 70), 0.30), 1.0, 4.5)

    # NOTE(review): presumably metres; only a single 3.0 leg is timed here -- confirm intent.
    walk_distance = 3.0
    # gait_speed is already clamped to >= 0.45, so the 0.35 floor never changes the value.
    walk_time = walk_distance / max(gait_speed, 0.35)
    # TUG = component times plus Gaussian noise ...
    tug = sit_to_stand + walk_time + turn_time + stand_to_sit + random.gauss(0, 0.4)
    # ... plus penalties for gait speed below 0.9, double support above 32
    # and symmetry index above 0.12.
    risk_factor = max(0.0, 0.9 - gait_speed)
    tug += risk_factor * (6.0 + random.uniform(0.0, 2.0))
    tug += max(0.0, double_support - 32) * 0.12
    tug += max(0.0, symmetry_index - 0.12) * 18
    tug = clamp(tug, 6.0, 28.0)

    # Rounding is applied only to the emitted values; the TUG above was
    # computed from the unrounded features.
    return {
        'participant_id': f'SEED_{idx:05d}',
        'age_years': round(age, 1),
        'sex': sex,
        'gait_speed_mps': round(gait_speed, 3),
        'stride_length_m': round(stride_length, 3),
        'cadence_spm': round(cadence, 1),
        'stride_time_var': round(stride_time_var, 4),
        'double_support_pct': round(double_support, 2),
        'symmetry_index': round(symmetry_index, 3),
        'turn_time_s': round(turn_time, 3),
        'sit_to_stand_s': round(sit_to_stand, 3),
        'stand_to_sit_s': round(stand_to_sit, 3),
        'tug_seconds': round(tug, 3),
    }


def policy_a_tug(tug_seconds: float, thresholds: Dict[str, float] | None = None) -> tuple[str, Dict[str, object]]:
    thresholds = thresholds or {'moderate': 11.0, 'high': 13.5}
    if tug_seconds >= thresholds['high']:
        risk = 'high'
    elif tug_seconds >= thresholds['moderate']:
        risk = 'moderate'
    else:
        risk = 'low'
    return risk, {
        'policy': 'A',
        'trigger': risk != 'low',
        'tug_seconds': round(tug_seconds, 3),
        'thresholds': thresholds,
    }


def compute_policy_b_percentiles(
    rows: List[dict], config: Dict[str, Dict[str, float | str]] = POLICY_B_CONFIG
) -> Dict[str, Dict[str, float]]:
    """Derive the policy B cutoffs for every configured feature from ``rows``.

    Args:
        rows: Records that contain a value for each feature named in ``config``.
        config: Per-feature settings providing ``moderate_pct`` and ``high_pct``.

    Returns:
        A dict keyed by feature name. Each entry holds the ``'moderate'`` and
        ``'high'`` cutoff values computed over all rows, together with the
        ``'moderate_pct'`` and ``'high_pct'`` fractions they were computed at.
    """

    def cutoffs_for(feature: str, settings: Dict[str, float | str]) -> Dict[str, float]:
        # Sort the feature's column once and reuse it for both cutoffs.
        ordered = sorted(record[feature] for record in rows)
        return {
            'moderate': percentile(ordered, settings['moderate_pct']),
            'high': percentile(ordered, settings['high_pct']),
            'moderate_pct': settings['moderate_pct'],
            'high_pct': settings['high_pct'],
        }

    return {feature: cutoffs_for(feature, settings) for feature, settings in config.items()}


def policy_b_multi_feature(
    row: dict,
    percentiles: Dict[str, Dict[str, float]],
    config: Dict[str, Dict[str, float | str]] = POLICY_B_CONFIG,
) -> tuple[str, Dict[str, object]]:
    """Score one record against per-feature cutoffs (policy B).

    Each configured feature is compared with its ``'high'`` cutoff first and,
    failing that, with its ``'moderate'`` cutoff. Features whose direction is
    ``'low'`` hit when the value is at or below the cutoff; all other
    features hit when the value is at or above it. A high hit scores 2 and a
    moderate hit scores 1.

    Args:
        row: Record holding a value for every feature in ``config``.
        percentiles: Per-feature dicts with ``'moderate'`` and ``'high'`` cutoffs.
        config: Per-feature settings providing ``'direction'``.

    Returns:
        A ``(risk, details)`` pair. ``risk`` is ``'high'`` for a score of 4 or
        more, ``'moderate'`` for 2 or more, otherwise ``'low'``. ``details``
        holds the policy tag, the trigger flag, both hit lists (in config
        order) and the score.
    """

    def breaches(observed: float, cutoff: float, lower_is_worse: bool) -> bool:
        return observed <= cutoff if lower_is_worse else observed >= cutoff

    high_hits: List[str] = []
    moderate_hits: List[str] = []
    for feature, settings in config.items():
        observed = row[feature]
        cutoffs = percentiles[feature]
        lower_is_worse = settings['direction'] == 'low'
        if breaches(observed, cutoffs['high'], lower_is_worse):
            high_hits.append(feature)
            continue
        if breaches(observed, cutoffs['moderate'], lower_is_worse):
            moderate_hits.append(feature)

    score = len(moderate_hits) + 2 * len(high_hits)
    if score < 2:
        risk = 'low'
    elif score < 4:
        risk = 'moderate'
    else:
        risk = 'high'

    details = {
        'policy': 'B',
        'trigger': risk != 'low',
        'high_hits': high_hits,
        'moderate_hits': moderate_hits,
        'score': score,
    }
    return risk, details


def combine_risk_levels(levels: Iterable[str]) -> str:
    """Return the most severe level in ``levels`` as ranked by ``RISK_ORDER``.

    When several levels share the top rank, the first one encountered is
    returned.
    """
    return max(levels, key=RISK_ORDER.__getitem__)


def assign_fall_risk(rows: List[dict]) -> Dict[str, Dict[str, float]]:
    """Label every record in ``rows`` in place and return the policy B cutoffs.

    Policy B cutoffs are computed once over the whole cohort. Each row is
    then classified by policy A (TUG thresholds) and policy B (multi-feature
    score); the final ``fall_risk`` is the more severe of the two. The
    per-policy levels, trigger flags, thresholds, hit lists, score and the
    cohort-wide cutoff values are added to the row as extra keys.

    Args:
        rows: Records to label; each is mutated in place.

    Returns:
        The per-feature cutoff dict produced by
        ``compute_policy_b_percentiles``.
    """
    percentiles = compute_policy_b_percentiles(rows)

    # The cutoff columns are the same for every row, so build them once here
    # instead of recomputing names and rounded values inside the row loop.
    cutoff_columns: Dict[str, float] = {}
    for feature, cuts in percentiles.items():
        moderate_pct = int(round(cuts['moderate_pct'] * 100))
        high_pct = int(round(cuts['high_pct'] * 100))
        cutoff_columns[f'policy_b_{feature}_cutoff_p{moderate_pct}'] = round(cuts['moderate'], 4)
        cutoff_columns[f'policy_b_{feature}_cutoff_p{high_pct}'] = round(cuts['high'], 4)

    for row in rows:
        policy_a_level, policy_a_details = policy_a_tug(row['tug_seconds'])
        policy_b_level, policy_b_details = policy_b_multi_feature(row, percentiles)
        final_level = combine_risk_levels([policy_a_level, policy_b_level])

        high_hits = policy_b_details['high_hits']
        moderate_hits = policy_b_details['moderate_hits']

        row['policy_a_risk'] = policy_a_level
        row['policy_b_risk'] = policy_b_level
        row['fall_risk'] = final_level
        row['label_high_fall_risk'] = 1 if final_level == 'high' else 0
        row['policy_a_trigger'] = policy_a_details['trigger']
        row['policy_a_threshold_moderate'] = policy_a_details['thresholds']['moderate']
        row['policy_a_threshold_high'] = policy_a_details['thresholds']['high']
        row['policy_b_trigger'] = policy_b_details['trigger']
        # Hit lists are flattened to pipe-separated strings; 'none' marks an empty list.
        row['policy_b_high_feature_hits'] = '|'.join(high_hits) if high_hits else 'none'
        row['policy_b_moderate_feature_hits'] = '|'.join(moderate_hits) if moderate_hits else 'none'
        row['policy_b_trigger_count'] = len(high_hits) + len(moderate_hits)
        row['policy_b_score'] = policy_b_details['score']

        # Appended last so the key order matches per-row assignment.
        row.update(cutoff_columns)

    return percentiles


In [None]:
# Generate the cohort (participant indices start at 1), then label it in place.
seed_rows = [generate_seed_row(i + 1) for i in range(2000)]
policy_b_percentiles = assign_fall_risk(seed_rows)

# Every row carries the same keys after labeling, so the first row defines the CSV header.
fieldnames = list(seed_rows[0].keys())
output_path = DATA_DIR / 'seed_fallrisk.csv'
# newline='' lets the csv module control line endings; the explicit encoding
# keeps the output independent of the platform's locale default.
with output_path.open('w', newline='', encoding='utf-8') as f:
    writer = csv.DictWriter(f, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(seed_rows)

# Share of rows whose final label is high / moderate.
high_rate = mean(1 if r['fall_risk'] == 'high' else 0 for r in seed_rows)
moderate_rate = mean(1 if r['fall_risk'] == 'moderate' else 0 for r in seed_rows)
print(f'Seed cohort saved to {output_path.resolve()}')
print(f'Total rows: {len(seed_rows)} | High-risk prevalence: {high_rate:.3f} | Moderate-risk prevalence: {moderate_rate:.3f}')
print('Policy B percentile metadata:')
for feature, settings in POLICY_B_CONFIG.items():
    cuts = policy_b_percentiles[feature]
    mod_pct = int(round(settings['moderate_pct'] * 100))
    high_pct = int(round(settings['high_pct'] * 100))
    direction = 'lower is higher risk' if settings['direction'] == 'low' else 'higher is higher risk'
    print(f"  {feature}: p{mod_pct}={cuts['moderate']:.3f}, p{high_pct}={cuts['high']:.3f} ({direction})")
