# Tolerance Model: Neuro Bucket Inference

Uses:
- `../drugs.json` for per-substance metadata
- `baseline.json` for the canonical bucket set
- `inspo.json` as a reference target set to calibrate against

Outputs:
- `outputs/tolerance_neuro_buckets_<anchor_weight>.json` (one JSON document per anchor weight, e.g. `_0.5`, `_1.0`, `_2.0`; JSONB-ready)

In [92]:
from __future__ import annotations

import json
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import Ridge
from sklearn.multioutput import MultiOutputRegressor
from sklearn.pipeline import Pipeline

# Fixed seed so the Ridge fit is reproducible.
# NOTE(review): `rng` is not referenced in the visible cells; only RANDOM_STATE is passed on.
RANDOM_STATE = 42
rng = np.random.default_rng(RANDOM_STATE)

# All input/output paths are relative to the notebook's working directory.
HERE = Path.cwd()
DRUGS_PATH = Path('..') / 'drugs.json'
BASELINE_PATH = Path('baseline.json')
INSPO_PATH = Path('inspo.json')
# Shared canonicalization config (exclude/aliases/groups)
YAML_PATH = Path('..') / 'drug_interaction.yaml'
OUTPUT_DIR = Path('outputs')
# Create the output folder up front so the export cell can write into it directly.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
# NOTE(review): the visible export cell writes per-anchor-weight files and never uses OUTPUT_JSON.
OUTPUT_JSON = OUTPUT_DIR / 'tolerance_neuro_buckets.json'

# Echo the resolved absolute paths so a wrong working directory is obvious.
print('DRUGS_PATH:', DRUGS_PATH.resolve())
print('BASELINE_PATH:', BASELINE_PATH.resolve())
print('INSPO_PATH:', INSPO_PATH.resolve())
print('YAML_PATH:', YAML_PATH.resolve())
print('OUTPUT_JSON:', OUTPUT_JSON.resolve())

DRUGS_PATH: C:\Users\USER\dev\code\mobile_drug_use_app\backend\ML\drugs.json
BASELINE_PATH: C:\Users\USER\dev\code\mobile_drug_use_app\backend\ML\drug_tolerance_model\baseline.json
INSPO_PATH: C:\Users\USER\dev\code\mobile_drug_use_app\backend\ML\drug_tolerance_model\inspo.json
YAML_PATH: C:\Users\USER\dev\code\mobile_drug_use_app\backend\ML\drug_interaction.yaml
OUTPUT_JSON: C:\Users\USER\dev\code\mobile_drug_use_app\backend\ML\drug_tolerance_model\outputs\tolerance_neuro_buckets.json


In [93]:
def load_json(path: Path) -> dict:
    """Read the UTF-8 encoded JSON file at ``path`` and return the parsed document."""
    with open(path, 'r', encoding='utf-8') as handle:
        text = handle.read()
    return json.loads(text)

# Load the three JSON inputs in a fixed order: drugs metadata, baseline buckets, inspo targets.
drugs_raw, baseline, inspo = (
    load_json(source) for source in (DRUGS_PATH, BASELINE_PATH, INSPO_PATH)
)

# The key order of baseline['buckets'] defines the index of every weight vector below.
BUCKETS: List[str] = list((baseline.get('buckets') or {}).keys())
if len(BUCKETS) == 0:
    raise ValueError('No buckets found in baseline.json')

print('✓ Loaded drugs:', len(drugs_raw))
print('✓ Loaded baseline buckets:', BUCKETS)
print('✓ Loaded inspo substances:', len(inspo.get('substances') or {}))

✓ Loaded drugs: 555
✓ Loaded baseline buckets: ['stimulant', 'serotonin_release', 'serotonin_psychedelic', 'gaba', 'opioid', 'nmda', 'cannabinoid']
✓ Loaded inspo substances: 123


In [94]:
# Load shared YAML canonicalization (exclude/aliases/groups)
import yaml

def normalize_name(name: str) -> str:
    """Return ``name`` trimmed and lower-cased; any falsy input becomes ``''``."""
    if not name:
        return ''
    return name.strip().lower()

def load_yaml(path: Path) -> dict:
    """Parse the YAML file at ``path`` with ``yaml.safe_load``.

    Returns an empty dict when the file does not exist or when the parsed
    document is not a mapping.
    """
    if path.exists():
        with open(path, 'r', encoding='utf-8') as handle:
            parsed = yaml.safe_load(handle)
        if isinstance(parsed, dict):
            return parsed
    return {}

def _yaml_list(cfg: dict, *keys: str) -> List[str]:
    out: List[str] = []
    for k in keys:
        v = cfg.get(k)
        if isinstance(v, list):
            out.extend([x for x in v if isinstance(x, str)])
    return out

TOL_CFG = load_yaml(YAML_PATH)

# The YAML file is user-editable, so several spellings of each key are accepted:
# - "exclude from all" variants: excluded everywhere
# - "exlude from neuro-bucket" variants: excluded only from the tolerance model
#   (the misspelled key is accepted on purpose)
EXCLUDE_ALL = _yaml_list(TOL_CFG, 'exclude', 'exclude from all', 'exclude_from_all')
EXCLUDE_NEURO = _yaml_list(
    TOL_CFG,
    'exclude from neuro-bucket',
    'exlude from neuro-bucket',
    'exclude_from_neuro_bucket',
    'exclude_from_neuro-bucket',
)

TOL_EXCLUDE_SET = {
    normalize_name(x) for x in [*EXCLUDE_ALL, *EXCLUDE_NEURO] if isinstance(x, str)
}
TOL_SEPARATE_SET = set(
    normalize_name(x) for x in (TOL_CFG.get('separate') or []) if isinstance(x, str)
)
# Normalized alias -> normalized target; non-string pairs are ignored.
TOL_ALIAS_MAP = dict(
    (normalize_name(alias), normalize_name(target))
    for alias, target in (TOL_CFG.get('aliases') or {}).items()
    if isinstance(alias, str) and isinstance(target, str)
)

# Groups are NOT merged for tolerance; they only make members share the same bucket *set*.
TOL_GROUPS: Dict[str, dict] = {}
TOL_MEMBER_TO_GROUP: Dict[str, str] = {}
for group_name, g in (TOL_CFG.get('groups') or {}).items():
    if not isinstance(g, dict):
        continue
    group_norm = normalize_name(group_name)
    canon = normalize_name(g.get('canonical', group_name))

    # Declared members plus the group key and the canonical name, de-duplicated.
    member_set = {normalize_name(m) for m in (g.get('members') or []) if isinstance(m, str)}
    member_set.update((group_norm, canon))
    ordered_members = sorted(member_set)

    TOL_GROUPS[group_norm] = {'canonical': canon, 'members': ordered_members}
    for m in ordered_members:
        # Names listed under `separate` stay un-grouped.
        if m not in TOL_SEPARATE_SET:
            TOL_MEMBER_TO_GROUP[m] = group_norm

# Normalized drugs.json key -> original key (the last key wins if two normalize alike)
DRUG_KEY_BY_NORM = dict((normalize_name(k), k) for k in drugs_raw.keys())

def is_excluded(name: str) -> bool:
    """Return True when the normalized ``name`` is in the YAML exclude set."""
    key = normalize_name(name)
    return key in TOL_EXCLUDE_SET

def apply_alias(name: str) -> str:
    """Normalize ``name`` and map it through the YAML alias table (one hop only)."""
    key = normalize_name(name)
    if key in TOL_ALIAS_MAP:
        return TOL_ALIAS_MAP[key]
    return key

def resolve_drugs_key(name: str) -> Optional[str]:
    """Map ``name`` to its original drugs.json key after aliasing.

    Returns None when the aliased name is empty, excluded, or not present in
    drugs.json.
    """
    canonical = apply_alias(name)
    if canonical and not is_excluded(canonical):
        return DRUG_KEY_BY_NORM.get(canonical)
    return None

def tolerance_group_id(name: str) -> Optional[str]:
    """Return the normalized group id ``name`` belongs to, or None if ungrouped.

    Only normalization is applied here; aliases are not resolved.
    """
    return TOL_MEMBER_TO_GROUP.get(normalize_name(name))

# Summary of what was read from the YAML config (all zeros if the file was missing or empty).
print('✓ Loaded drug_interaction.yaml for tolerance')
print('  exclude_all:', len(EXCLUDE_ALL), 'exclude_neuro:', len(EXCLUDE_NEURO), 'aliases:', len(TOL_ALIAS_MAP), 'groups:', len(TOL_GROUPS), 'separate:', len(TOL_SEPARATE_SET))

✓ Loaded drug_interaction.yaml for tolerance
  exclude_all: 26 exclude_neuro: 4 aliases: 11 groups: 10 separate: 1


In [95]:
# Optional: Refresh categories from Supabase (drug_profiles table) to fill in missing/updated categories in drugs.json
# - Uses SUPABASE_URL and SUPABASE_ANON_KEY from .env at workspace root
# - Does NOT print secrets
import os
import requests

def load_dotenv_simple(dotenv_path: Path) -> None:
    """Load ``KEY=VALUE`` pairs from a .env file into ``os.environ``.

    Variables already present in the environment are never overridden.
    Blank lines, ``#`` comment lines and lines without ``=`` are ignored, as
    are lines whose key is empty (``os.environ`` rejects empty names). A
    leading ``export `` is dropped so shell-style files load under the bare
    variable name. A value loses its surrounding quotes only when the same
    quote character both opens and closes it.

    Args:
        dotenv_path: Location of the .env file; a missing file is a no-op.
    """
    if not dotenv_path.exists():
        return
    for raw in dotenv_path.read_text(encoding='utf-8').splitlines():
        line = raw.strip()
        if not line or line.startswith('#') or '=' not in line:
            continue
        key, value = line.split('=', 1)
        key = key.strip()
        if key.startswith('export '):
            key = key[len('export '):].strip()
        if not key:
            # e.g. a stray "=value" line; setting an empty name would raise.
            continue
        value = value.strip()
        if len(value) >= 2 and value[0] == value[-1] and value[0] in ('"', "'"):
            value = value[1:-1]
        # don't override existing env vars
        os.environ.setdefault(key, value)

# The repo root is three levels above this notebook folder
# (backend/ML/drug_tolerance_model -> ML -> backend -> repo root).
DOTENV_PATH = (Path('..') / '..' / '..' / '.env').resolve()
load_dotenv_simple(DOTENV_PATH)

# Either value may be None; the refresh below is skipped in that case.
SUPABASE_URL = os.getenv('SUPABASE_URL')
SUPABASE_ANON_KEY = os.getenv('SUPABASE_ANON_KEY')

def _postgrest_get(table: str, select: str, limit: int = 10000) -> list:
    """Fetch rows from a Supabase PostgREST table.

    Returns an empty list when the Supabase credentials are not configured or
    when the response body is not a JSON array.

    Raises:
        RuntimeError: if the HTTP status is 400 or above (the message carries
            the first 300 characters of the response body).
    """
    if not (SUPABASE_URL and SUPABASE_ANON_KEY):
        return []

    endpoint = f"{SUPABASE_URL.rstrip('/')}/rest/v1/{table}"
    response = requests.get(
        endpoint,
        headers={
            'apikey': SUPABASE_ANON_KEY,
            'Authorization': f'Bearer {SUPABASE_ANON_KEY}',
            'Accept': 'application/json',
        },
        params={'select': select, 'limit': str(limit)},
        timeout=20,
    )
    if response.status_code >= 400:
        raise RuntimeError(f'PostgREST error {response.status_code}: {response.text[:300]}')

    body = response.json()
    if isinstance(body, list):
        return body
    return []

def _extract_row_key(row: dict) -> Optional[str]:
    # Try a few common column names
    for k in ['substance', 'substance_key', 'drug_key', 'drug', 'name', 'slug', 'id']:
        v = row.get(k)
        if isinstance(v, str) and v.strip():
            return v.strip()
    return None

def _normalize_categories(value) -> List[str]:
    """Coerce a DB ``categories`` value into a list of normalized names.

    Accepts a list (non-string and blank items are dropped) or a
    comma-separated string (empty parts are dropped). Anything else,
    including None, yields an empty list.
    """
    if isinstance(value, list):
        return [normalize_name(item) for item in value if isinstance(item, str) and item.strip()]
    if isinstance(value, str):
        pieces = (piece.strip() for piece in value.split(','))
        return [normalize_name(piece) for piece in pieces if piece]
    return []

# Normalized (aliased) substance name -> category list fetched from the DB.
# Left empty when the refresh is skipped or fails.
DB_CATEGORIES_BY_NORM: Dict[str, List[str]] = {}
# Human-readable outcome of the refresh; printed after this block.
DB_FETCH_STATUS = 'skipped'
try:
    if SUPABASE_URL and SUPABASE_ANON_KEY:
        # Try a conservative select first; if schema differs, fall back to '*'
        rows = []
        try:
            rows = _postgrest_get('drug_profiles', 'substance,categories')
        except Exception:
            # Any failure of the narrow select triggers the wide one; if that
            # also fails, the outer handler records the error.
            rows = _postgrest_get('drug_profiles', '*')
        for row in rows:
            if not isinstance(row, dict):
                continue
            key = _extract_row_key(row)
            if not key:
                continue
            cats = _normalize_categories(row.get('categories'))
            if not cats:
                # Rows without categories never override drugs.json.
                continue
            n = apply_alias(key)
            if is_excluded(n):
                continue
            DB_CATEGORIES_BY_NORM[normalize_name(n)] = cats
        DB_FETCH_STATUS = f'ok ({len(DB_CATEGORIES_BY_NORM)} rows)'
    else:
        DB_FETCH_STATUS = 'missing SUPABASE_URL/SUPABASE_ANON_KEY'
except Exception as e:
    # Best-effort refresh: on any error discard partial results and continue
    # with the categories from drugs.json.
    DB_FETCH_STATUS = f'error: {type(e).__name__}: {e}'
    DB_CATEGORIES_BY_NORM = {}

def get_categories_for(drugs_key: str, entry: dict) -> List[str]:
    """Return normalized categories for a substance.

    A DB override keyed by the normalized ``drugs_key`` wins when present;
    otherwise the string items of ``entry['categories']`` are used.
    """
    try:
        return DB_CATEGORIES_BY_NORM[normalize_name(drugs_key)]
    except KeyError:
        pass
    local_categories = entry.get('categories') or []
    return [normalize_name(c) for c in local_categories if isinstance(c, str)]

# Report the refresh outcome (the status string never contains the credentials).
print('✓ Supabase category refresh:', DB_FETCH_STATUS)

✓ Supabase category refresh: ok (522 rows)


In [96]:
def norm(s: str) -> str:
    """Return ``s`` trimmed and lower-cased; any falsy input becomes ``''``."""
    if not s:
        return ''
    return s.strip().lower()

def _iter_str_list(x) -> List[str]:
    """Return the normalized string items of ``x``; non-list input yields ``[]``."""
    if not isinstance(x, list):
        return []
    return [norm(item) for item in x if isinstance(item, str)]

def extract_features(substance: str, entry: dict) -> Dict[str, float]:
    """Build a sparse binary feature dict for one drugs.json entry.

    Every emitted feature has value 1.0. Key prefixes:
    ``cat:`` categories (DB-refreshed when available), ``pwe:`` keys of
    ``entry['pweffects']``, ``fx:`` items of ``entry['formatted_effects']``,
    and ``warn:`` flags parsed from ``entry['properties']['avoid']``.
    """
    feats: Dict[str, float] = {f'cat:{c}': 1.0 for c in get_categories_for(substance, entry)}

    # Effect keys: only the key names are used, not their values.
    effects = entry.get('pweffects') or {}
    if isinstance(effects, dict):
        feats.update({f'pwe:{norm(k)}': 1.0 for k in effects if isinstance(k, str)})

    # Formatted effects list (e.g., Sedative, Stimulation)
    feats.update({f'fx:{fx}': 1.0 for fx in _iter_str_list(entry.get('formatted_effects'))})

    # Minimal substring parsing of the free-text warning field.
    props = entry.get('properties') or {}
    avoid = props.get('avoid') if isinstance(props, dict) else None
    if isinstance(avoid, str):
        warning_text = norm(avoid)
        markers = (
            ('cns depressant', 'warn:cns_depressant'),
            ('serotonergic', 'warn:serotonergic'),
            ('maoi', 'warn:maoi'),
        )
        for needle, feature in markers:
            if needle in warning_text:
                feats[feature] = 1.0

    return feats

# Cell-completion marker only; nothing is computed here.
print('✓ Feature extractor ready')

✓ Feature extractor ready


In [97]:
# Heuristic prior: fast rule-based mapping from categories/effects -> bucket weights
# These priors do most of the work; ML learns residual corrections against inspo.json.
#
# Bucket name -> keywords. heuristic_prior() matches each keyword by exact
# membership against a token set made of feature keys (e.g. 'pwe:sedation'),
# bare category values (e.g. 'stimulant') and the normalized substance name.
# Keywords prefixed with 'pwe:', 'fx:' or 'warn:' score 2, all others score 1.
PRIOR_KEYWORDS = {
    'stimulant': [
        'stimulant', 'dopamine', 'norepinephrine', 'adrenergic', 'amphetamine', 'cathinone',
        'pwe:stimulation', 'fx:stimulation',
    ],
    'serotonin_release': [
        'entactogen', 'empathogen', 'serotonin', 'mdma', 'mda',
        'pwe:empathy, love, and sociability enhancement',
        'fx:empathy',
        'warn:serotonergic',
    ],
    'serotonin_psychedelic': [
        'psychedelic', 'tryptamine', 'lysergamide', 'phenethylamine', '5-ht2a',
        'pwe:hallucinations', 'fx:hallucinations',
    ],
    'gaba': [
        'benzodiazepine', 'benzo', 'z-drug', 'depressant', 'sedative', 'gaba',
        'pwe:sedation', 'fx:sedative', 'fx:hypnotic',
        'warn:cns_depressant',
    ],
    'opioid': [
        'opioid', 'opiate', 'pwe:respiratory depression',
    ],
    'nmda': [
        'dissociative', 'nmda', 'ketamine',
        'pwe:dissociation',
    ],
    'cannabinoid': [
        'cannabinoid', 'thc', 'cannabis',
    ],
}

def heuristic_prior(substance: str, entry: dict) -> np.ndarray:
    """Score each bucket from keyword hits and return a weight per bucket.

    Tokens are the feature keys, the bare category values and the normalized
    substance name. A keyword hit scores 2 when it is a ``pwe:``/``fx:``/
    ``warn:`` keyword and 1 otherwise. Scores map to weights as
    ``>=3 -> 1.0``, ``2 -> 0.7``, ``1 -> 0.4``, ``0 -> 0.0``.
    """
    feats = extract_features(substance, entry)
    tokens = set(feats)
    for key in feats:
        if key.startswith('cat:'):
            tokens.add(key.split(':', 1)[1])
    tokens.add(norm(substance))

    strong_prefixes = ('pwe:', 'fx:', 'warn:')
    weights = np.zeros(len(BUCKETS), dtype=float)
    for i, bucket in enumerate(BUCKETS):
        score = sum(
            2 if kw.startswith(strong_prefixes) else 1
            for kw in PRIOR_KEYWORDS.get(bucket, [])
            if kw in tokens
        )
        if score >= 3:
            weights[i] = 1.0
        elif score == 2:
            weights[i] = 0.7
        elif score == 1:
            weights[i] = 0.4
        # score 0 keeps the initial 0.0
    return weights

# Cell-completion marker only; nothing is computed here.
print('✓ Heuristic prior ready')

✓ Heuristic prior ready


In [98]:
# Supervised training set: one row per inspo.json substance that resolves to a drugs.json entry
inspo_substances = inspo.get('substances') or {}
train_names = sorted(k for k in inspo_substances.keys() if isinstance(k, str))

# Manual mapping for inspo naming quirks -> drugs.json naming (applied before YAML aliases)
INSPO_TO_DRUGS_ALIASES = dict(
    dxm='dextromethorphan',
    psilocybin='psilocin',
    thc='cannabis',
)

def resolve_inspo_name_to_drugs_key(name: str) -> Optional[str]:
    """Resolve an inspo.json substance name to a drugs.json key.

    The name is normalized, mapped through the manual inspo aliases and then
    the YAML aliases. An excluded name yields None. If the name itself is not
    in drugs.json, its tolerance group (if any) is searched: first the
    canonical name, then each member in stored order.
    """
    n = normalize_name(name)
    n = apply_alias(INSPO_TO_DRUGS_ALIASES.get(n, n))
    if is_excluded(n):
        return None

    direct_hit = resolve_drugs_key(n)
    if direct_hit is not None:
        return direct_hit

    group_id = tolerance_group_id(n)
    if group_id is None:
        return None
    group = TOL_GROUPS.get(group_id) or {}
    candidates = [group.get('canonical'), *(group.get('members') or [])]
    for candidate in candidates:
        if not isinstance(candidate, str):
            continue
        hit = resolve_drugs_key(candidate)
        if hit is not None:
            return hit
    return None

def target_vector_from_inspo(substance: str) -> np.ndarray:
    """Return the inspo.json target weights for ``substance`` in BUCKETS order.

    Buckets that are missing, malformed, or carry a non-numeric weight stay at
    0.0; the result is clipped to ``[0, 1]``.
    """
    record = inspo_substances.get(substance) or {}
    bucket_specs = record.get('neuro_buckets') or {}
    target = np.zeros(len(BUCKETS), dtype=float)
    if isinstance(bucket_specs, dict):
        for i, bucket in enumerate(BUCKETS):
            spec = bucket_specs.get(bucket)
            if not isinstance(spec, dict):
                continue
            weight = spec.get('weight', 0.0)
            if isinstance(weight, (int, float)):
                target[i] = float(weight)
    return np.clip(target, 0.0, 1.0)

# Parallel accumulators: feature dicts, inspo targets, heuristic priors.
X = []
y = []
prior = []
missing_from_drugs = []
excluded_from_training = []
for s in train_names:
    s_key = resolve_inspo_name_to_drugs_key(s)
    if s_key is not None:
        entry = drugs_raw[s_key]
        X.append(extract_features(s_key, entry))
        y.append(target_vector_from_inspo(s))
        prior.append(heuristic_prior(s_key, entry))
        continue
    # Unresolved: either excluded or not found. Only the raw inspo name is
    # tested here, so a name excluded via an alias is reported as missing.
    (excluded_from_training if is_excluded(s) else missing_from_drugs).append(s)

X = list(X)
# Stack into (rows, buckets) matrices; keep the bucket dimension when there are no rows.
y = np.vstack(y) if y else np.zeros((0, len(BUCKETS)))
prior = np.vstack(prior) if prior else np.zeros((0, len(BUCKETS)))
print('✓ Training rows:', len(X))
print('Missing inspo entries in drugs.json:', missing_from_drugs)
print('Excluded inspo entries (yaml exclude):', excluded_from_training)
print('Target buckets:', BUCKETS)

✓ Training rows: 117
Missing inspo entries in drugs.json: ['4-pmc', 'alpha-pbp']
Excluded inspo entries (yaml exclude): ['2-nmc', 'apap', 'db-mdbp', 'oxiracetam']
Target buckets: ['stimulant', 'serotonin_release', 'serotonin_psychedelic', 'gaba', 'opioid', 'nmda', 'cannabinoid']


In [99]:
# Train residual model: predicts (target - heuristic_prior) from drugs.json features
# We disable intercept so unknown/unseen feature rows don't get a drifting baseline residual.
# With no intercept, a row with no known features predicts a residual of exactly 0,
# so predict_weights() falls back to the heuristic prior alone for it.
residual = y - prior

# DictVectorizer turns the feature dicts into a sparse matrix; one Ridge
# regressor is fitted per bucket via MultiOutputRegressor.
model: Pipeline = Pipeline([
    ('vec', DictVectorizer(sparse=True)),
    ('reg', MultiOutputRegressor(Ridge(alpha=2.0, fit_intercept=False, random_state=RANDOM_STATE))),
])
model.fit(X, residual)
print('✓ Trained residual model')

# Bucket name -> position in every weight vector.
BUCKET_INDEX = {b: i for i, b in enumerate(BUCKETS)}

def is_benzodiazepine_category(drugs_key: str) -> bool:
    """Return True when the substance's categories include a benzodiazepine label."""
    entry = drugs_raw.get(drugs_key) or {}
    categories = set(get_categories_for(drugs_key, entry))
    benzo_labels = {'benzodiazepine', 'benzodiazepines'}
    return not categories.isdisjoint(benzo_labels)

def predict_weights(drugs_key: str) -> np.ndarray:
    """Predict bucket weights for a drugs.json key as prior + learned residual.

    Excluded substances get an all-zero vector. Otherwise the sum is clipped
    to ``[0, 1]``, benzodiazepines are restricted to the GABA bucket, and an
    all-zero result is bumped to 0.1 on the strongest prior bucket (or bucket
    0 when the prior is empty too).
    """
    bucket_count = len(BUCKETS)
    if is_excluded(normalize_name(drugs_key)):
        return np.zeros(bucket_count, dtype=float)

    entry = drugs_raw.get(drugs_key) or {}
    prior_vec = heuristic_prior(drugs_key, entry)
    residual_vec = model.predict([extract_features(drugs_key, entry)])[0]
    weights = np.clip(prior_vec + residual_vec, 0.0, 1.0)

    # Rule: if category includes Benzodiazepine, only allow GABA bucket
    if is_benzodiazepine_category(drugs_key) and 'gaba' in BUCKET_INDEX:
        gaba_i = BUCKET_INDEX['gaba']
        gaba_w = float(weights[gaba_i])
        weights = np.zeros(bucket_count, dtype=float)
        # A benzodiazepine with no predicted GABA weight still gets 0.7.
        weights[gaba_i] = np.clip(gaba_w if gaba_w > 0 else 0.7, 0.0, 1.0)

    # Never return an all-zero vector for a non-excluded substance.
    if float(weights.max()) <= 0.0:
        fallback_i = int(np.argmax(prior_vec)) if float(prior_vec.max()) > 0 else 0
        weights[fallback_i] = 0.1
    return weights

def eval_against_inspo() -> pd.DataFrame:
    """Compare predictions with inspo.json targets for every resolvable name.

    Returns:
        One row per inspo substance that resolves to a drugs.json key, with
        ``substance``, ``drugs_key``, ``mae`` (mean absolute error over all
        buckets) and a ``true:<bucket>`` / ``pred:<bucket>`` pair per bucket,
        sorted by ``mae`` descending. When nothing resolves, an empty frame
        with the same columns is returned instead of failing on the sort.
    """
    columns = ['substance', 'drugs_key', 'mae']
    for b in BUCKETS:
        columns += [f'true:{b}', f'pred:{b}']

    rows = []
    for s in train_names:
        s_key = resolve_inspo_name_to_drugs_key(s)
        if s_key is None:
            continue
        y_true = target_vector_from_inspo(s)
        y_pred = predict_weights(s_key)
        row = {'substance': s, 'drugs_key': s_key, 'mae': float(np.mean(np.abs(y_true - y_pred)))}
        for i, b in enumerate(BUCKETS):
            row[f'true:{b}'] = float(y_true[i])
            row[f'pred:{b}'] = float(y_pred[i])
        rows.append(row)

    if not rows:
        # sort_values('mae') would raise KeyError on a column-less frame.
        return pd.DataFrame(columns=columns)
    return pd.DataFrame(rows).sort_values('mae', ascending=False)

# Evaluate on the same substances the model was trained on (an in-sample fit check).
df_eval = eval_against_inspo()
print('Mean MAE vs inspo:', float(df_eval['mae'].mean()) if len(df_eval) else None)
# Show the 25 worst-fitting substances (the frame is sorted by mae, descending).
df_eval.head(25)

✓ Trained residual model
Mean MAE vs inspo: 0.04932683479552811


Unnamed: 0,substance,drugs_key,mae,true:stimulant,pred:stimulant,true:serotonin_release,pred:serotonin_release,true:serotonin_psychedelic,pred:serotonin_psychedelic,true:gaba,pred:gaba,true:opioid,pred:opioid,true:nmda,pred:nmda,true:cannabinoid,pred:cannabinoid
65,dimemebfe,dimemebfe,0.28904,1.0,0.039195,0.0,0.395491,0.0,0.585536,0.0,0.0,0.0,0.021844,0.0,0.00055,0.0,0.059056
17,2-mppp,2-mppp,0.259301,0.0,0.888969,0.0,0.0,0.0,0.02919,0.0,0.006289,1.0,0.135831,0.0,0.003864,0.0,0.022629
91,marinol,marinol,0.222411,0.0,0.068499,0.0,0.0,0.0,0.095186,0.0,0.449037,0.0,0.046271,0.0,0.0,1.0,0.102116
92,mbdb,mbdb,0.180361,0.1,0.775769,0.9,0.429801,0.0,0.0,0.0,0.0,0.0,0.027276,0.0,0.014134,0.0,0.075151
43,5f-pb-22,5f-pb-22,0.175855,0.0,0.216125,0.0,0.009705,0.0,0.099742,0.0,0.0,0.0,0.052735,0.0,0.058429,1.0,0.205753
46,ab-chminaca,ab-chminaca,0.175855,0.0,0.216125,0.0,0.009705,0.0,0.099742,0.0,0.0,0.0,0.052735,0.0,0.058429,1.0,0.205753
93,mdai,mdai,0.17041,0.12,0.573522,1.0,0.547516,0.0,0.286863,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
76,ethylcathinone,ethylcathinone,0.149224,1.0,0.621955,0.0,0.153957,0.0,0.512564,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
77,fluorophenibut,fluorophenibut,0.137283,0.0,0.099635,0.0,0.0,0.0,0.022639,1.0,0.455411,0.0,0.058888,0.0,0.044605,0.0,0.190624
54,butyrfentanyl,butyrfentanyl,0.128726,0.0,0.06652,0.0,0.0,0.0,0.02318,0.0,0.528264,1.0,0.785041,0.0,0.0,0.0,0.068158


In [100]:
# Export: neuro_buckets for all substances (JSONB-ready)
import re
import yaml

# Per-substance parameter defaults. The export loop copies this dict for each
# substance and overrides individual keys with values parsed from drugs.json
# or taken from inspo.json.
DEFAULT_TOLERANCE_PARAMS = {
    'half_life_hours': 12.0,
    'active_threshold': 0.05,
    'standard_unit': {'value': 10.0, 'unit': 'mg'},
    'potency_multiplier': 1.0,
    'duration_multiplier': 1.2,
    'tolerance_gain_rate': 0.25,
    'tolerance_decay_days': 5.0,
}

def parse_half_life_hours_from_drugs(entry: dict) -> Optional[float]:
    """Best-effort parse of drugs.json properties['half-life'] into hours.

    The text is scanned for numbers and time units (minutes, hours, days and
    their common abbreviations). Each number is converted with the unit that
    follows it; a trailing number without its own unit inherits the last
    unit seen. The first two values are averaged, so a range such as
    ``"2-4 hours"`` yields 3.0 and ``"30 minutes - 2 hours"`` yields 1.25.

    Units are matched as whole words, so "minimum" or "today" are not
    mistaken for minutes or days.

    Args:
        entry: One drugs.json record.

    Returns:
        The half-life in hours, or None when the field is missing, is not a
        string, or contains no number with a recognizable unit.
    """
    props = entry.get('properties') or {}
    if not isinstance(props, dict):
        return None
    hl = props.get('half-life') or props.get('half_life')
    if not isinstance(hl, str):
        return None
    s = hl.strip().lower()
    if not s:
        return None

    def to_hours(value: float, unit: str) -> float:
        # Only the first letter is needed: m(inutes), d(ays), h(ours).
        if unit[0] == 'm':
            return value / 60.0
        if unit[0] == 'd':
            return value * 24.0
        return value

    token_re = r'(\d+(?:\.\d+)?)|(?<![a-z])(minutes?|mins?|hours?|hrs?|h|days?)(?![a-z])'
    hours: List[float] = []
    pending: List[float] = []  # numbers still waiting for their unit
    last_unit = None
    for number, unit in re.findall(token_re, s):
        if number:
            pending.append(float(number))
            continue
        last_unit = unit
        hours.extend(to_hours(v, unit) for v in pending)
        pending = []

    if last_unit is None:
        # Numbers without any unit are ambiguous; do not guess.
        return None
    hours.extend(to_hours(v, last_unit) for v in pending)
    if not hours:
        return None

    # Use 1-2 values; a range is averaged.
    picked = hours[:2]
    return sum(picked) / len(picked)

def parse_standard_unit_from_drugs(entry: dict) -> Optional[Dict[str, Any]]:
    """Derive a standard dose unit from ``entry['formatted_dose']``.

    Route priority is Oral, then Insufflated, then the first route listed.
    Within the route, the first non-empty tier among Common, Light, Strong
    and Threshold is used. The first ``<number>[-<number>] <unit>`` found in
    that text is parsed; a range is averaged.

    Numbers must be well-formed decimals, so stray punctuation such as the
    dots in "e.g." is skipped instead of being fed to ``float()``. Ranges
    may use a hyphen or an en dash.

    Args:
        entry: One drugs.json record.

    Returns:
        ``{'value': float, 'unit': str}`` with the value rounded to 4
        decimals and common unit spellings normalized (``ug``/``µg`` ->
        ``mcg``, ``grams`` -> ``g``, ...), or None when nothing parseable
        is found.
    """
    formatted = entry.get('formatted_dose') or {}
    if not isinstance(formatted, dict):
        return None

    # Priority: Oral -> Insufflated -> First Available
    roa_data = formatted.get('Oral') or formatted.get('Insufflated')
    if not roa_data and formatted:
        roa_data = next(iter(formatted.values()))
    if not isinstance(roa_data, dict):
        return None

    # Priority: Common -> Light -> Strong -> Threshold
    dose_str = (
        roa_data.get('Common')
        or roa_data.get('Light')
        or roa_data.get('Strong')
        or roa_data.get('Threshold')
    )
    if not isinstance(dose_str, str):
        return None

    # Matches: "50-100ug", "10mg", "1.5ml", "1 - 2 g", ".5 mg"
    # Group 1: Min, Group 2: Max (optional), Group 3: Unit
    # The unit class includes both the micro sign and the Greek letter mu.
    number = r'(?:\d+(?:\.\d*)?|\.\d+)'
    pattern = rf'({number})(?:\s*[-\u2013]\s*({number}))?\s*([a-zA-Z\u00b5\u03bc]+)'
    match = re.search(pattern, dose_str)
    if not match:
        return None

    val_min = float(match.group(1))
    val_max = float(match.group(2)) if match.group(2) else val_min
    unit = match.group(3).lower()

    # Normalize unit spellings; anything unrecognized (e.g. 'oz') is kept
    # as written in drugs.json.
    unit_aliases = {
        'ug': 'mcg', '\u00b5g': 'mcg', '\u03bcg': 'mcg', 'mcg': 'mcg',
        'g': 'g', 'gram': 'g', 'grams': 'g',
        'mg': 'mg', 'milligram': 'mg', 'milligrams': 'mg',
        'ml': 'ml', 'milliliter': 'ml', 'milliliters': 'ml',
    }
    unit = unit_aliases.get(unit, unit)

    avg_val = (val_min + val_max) / 2.0
    return {'value': round(avg_val, 4), 'unit': unit}

def build_export_substances() -> List[str]:
    """List the drugs.json keys to export, in sorted key order.

    Skipped: keys that normalize to an empty or excluded name, keys that are
    aliases of another exportable drugs.json entry, and repeats of an
    already-emitted normalized name.
    """
    exported: List[str] = []
    emitted = set()
    for raw_key in sorted(drugs_raw.keys()):
        name = normalize_name(raw_key)
        if not name or is_excluded(name):
            continue
        # Alias de-dup: drop the alias when its target will be exported itself.
        if name in TOL_ALIAS_MAP:
            target = TOL_ALIAS_MAP[name]
            if target in DRUG_KEY_BY_NORM and not is_excluded(target):
                continue
        if name in emitted:
            continue
        emitted.add(name)
        exported.append(DRUG_KEY_BY_NORM.get(name, raw_key))
    return exported

# Bucket name -> vector position (recomputed so this cell can run on its own).
BUCKET_INDEX = dict((b, i) for i, b in enumerate(BUCKETS))

def weights_to_neuro_buckets(
    w: np.ndarray,
    *,
    threshold: float = 0.05,
    allowed_buckets: Optional[set] = None,
) -> Dict[str, dict]:
    """Convert a weight vector into the exported ``neuro_buckets`` mapping.

    Without ``allowed_buckets``, buckets with weight >= ``threshold`` are
    emitted in descending weight order. With ``allowed_buckets``, exactly
    those buckets are emitted in BUCKETS order regardless of the threshold,
    and non-positive weights are raised to 0.001. If nothing is emitted, the
    arg-max bucket is kept with a weight of at least 0.1. Weights are
    rounded to 3 decimals.
    """
    def make_record(bucket: str, weight: float) -> dict:
        return {'weight': round(weight, 3), 'tolerance_type': bucket}

    result: Dict[str, dict] = {}
    floor = 0.001
    if allowed_buckets is not None:
        # Deterministic order: follow BUCKETS; keep buckets even below threshold.
        for bucket in BUCKETS:
            if bucket in allowed_buckets:
                value = float(w[BUCKET_INDEX[bucket]])
                result[bucket] = make_record(bucket, floor if value <= 0.0 else value)
    else:
        for i in list(np.argsort(w)[::-1]):
            value = float(w[i])
            if value >= threshold:
                result[BUCKETS[i]] = make_record(BUCKETS[i], value)

    if not result:
        # Fallback: always keep at least one bucket.
        top = int(np.argmax(w))
        result[BUCKETS[top]] = make_record(BUCKETS[top], max(float(w[top]), 0.1))
    return result

export_keys = build_export_substances()
print('✓ Export substances (after exclude + alias de-dup):', len(export_keys))

# drugs.json key -> inspo record, used for anchoring and parameter overrides.
# If several inspo names resolve to the same key, the last one wins.
DRUGS_KEY_TO_INSPO = {}
for inspo_name, inspo_data in (inspo.get('substances') or {}).items():
    d_key = resolve_inspo_name_to_drugs_key(inspo_name)
    if not d_key:
        continue
    DRUGS_KEY_TO_INSPO[d_key] = inspo_data

# For each YAML group, pick a reference substance (canonical first, then the
# first resolvable member) and record which buckets it activates.
GROUP_ALLOWED_BUCKETS: Dict[str, set] = {}
for gid, g in TOL_GROUPS.items():
    canon = g.get('canonical')
    ref = None
    if isinstance(canon, str):
        canon_key = resolve_drugs_key(canon)
        if canon_key in drugs_raw:
            ref = canon_key
    if ref is None:
        member_names = [m for m in (g.get('members') or []) if isinstance(m, str)]
        ref = next(
            (hit for hit in map(resolve_drugs_key, member_names) if hit is not None),
            None,
        )
    if ref is None:
        continue
    w_ref = predict_weights(ref)
    allowed = {b for b, weight in zip(BUCKETS, w_ref) if float(weight) >= 0.05}
    if not allowed:
        allowed = {BUCKETS[int(np.argmax(w_ref))]}
    GROUP_ALLOWED_BUCKETS[gid] = allowed

def is_same_unit(u1, u2):
    """Return True when both unit labels are equal after trimming and lower-casing."""
    left, right = normalize_name(u1), normalize_name(u2)
    return left == right

# ---- ANCHORING LOGIC START ----
RULES_FILE = Path('substances_rules.yaml')
DEFAULT_ANCHOR_WEIGHT = 1.0
if RULES_FILE.exists():
    with open(RULES_FILE, 'r', encoding='utf-8') as f:
        rules_data = yaml.safe_load(f) or {}
    # A user-edited file may have a non-mapping root or `inspo_anchor:` /
    # `weight:` left empty (null); fall back to 1.0 instead of crashing.
    _anchor_cfg = rules_data.get('inspo_anchor') if isinstance(rules_data, dict) else None
    _anchor_weight = _anchor_cfg.get('weight', 1.0) if isinstance(_anchor_cfg, dict) else 1.0
    try:
        DEFAULT_ANCHOR_WEIGHT = float(_anchor_weight)
    except (TypeError, ValueError):
        DEFAULT_ANCHOR_WEIGHT = 1.0

# Informational only: the runs below use the explicit list, not this default.
print('Default anchor weight from rules:', DEFAULT_ANCHOR_WEIGHT)

# Explicitly running for these 3 weights as requested
ANCHOR_WEIGHTS_TO_RUN = [0.5, 1.0, 2.0]

def get_inspo_vector(inspo_data: dict) -> np.ndarray:
    """Return the bucket weights of one inspo record in BUCKETS order.

    Buckets that are missing, malformed, or carry a non-numeric weight stay at
    0.0; the result is clipped to ``[0, 1]``.
    """
    bucket_specs = inspo_data.get('neuro_buckets') or {}
    vector = np.zeros(len(BUCKETS), dtype=float)
    if isinstance(bucket_specs, dict):
        for i, bucket in enumerate(BUCKETS):
            spec = bucket_specs.get(bucket)
            if not isinstance(spec, dict):
                continue
            weight = spec.get('weight', 0.0)
            if isinstance(weight, (int, float)):
                vector[i] = float(weight)
    return np.clip(vector, 0.0, 1.0)
# ---- ANCHORING LOGIC END ----

# The ML prediction does not depend on the anchor weight, so compute it once
# per substance instead of once per (substance, anchor weight) pair.
ML_WEIGHTS_BY_KEY = {key: predict_weights(key) for key in export_keys}

# Write one JSON document per anchor weight. Each document holds run metadata
# plus, per exported substance, its neuro buckets and tolerance parameters.
for anchor_w in ANCHOR_WEIGHTS_TO_RUN:
    print(f'Generating output for anchor weight: {anchor_w}')

    payload = {
        'metadata': {
            # Timestamp.utcnow() is deprecated in recent pandas; now(tz='UTC')
            # yields the same timezone-aware value.
            'generated_at': pd.Timestamp.now(tz='UTC').isoformat(),
            'source_files': {
                'drugs_json': str(DRUGS_PATH),
                'baseline_json': str(BASELINE_PATH),
                'inspo_json': str(INSPO_PATH),
                'drug_interaction_yaml': str(YAML_PATH),
                'substances_rules_yaml': str(RULES_FILE) if RULES_FILE.exists() else None,
            },
            'buckets': BUCKETS,
            'yaml_config': {
                'exclude': sorted(TOL_EXCLUDE_SET),
                'aliases': dict(TOL_ALIAS_MAP),
                'groups': {gid: {'canonical': g.get('canonical'), 'members': g.get('members')} for gid, g in TOL_GROUPS.items()},
                'separate': sorted(TOL_SEPARATE_SET),
            },
            'default_tolerance_params': dict(DEFAULT_TOLERANCE_PARAMS),
            'inspo_anchor': {
                'weight': anchor_w,
                'formula': '(inspo * weight + ml) / (weight + 1)',
            },
            'notes': [
                'Bucket weights are inferred from drugs.json using a heuristic prior + residual ML fit to inspo.json.',
                'Globally anchored to inspo priors using weighted average.',
                'Group members share the same neuro-bucket keys; weights may differ.',
                'Standard units are derived from drugs.json (Common/Light dose) or inspo.json.',
            ],
        },
        'substances': {},
    }

    for s in export_keys:
        gid = tolerance_group_id(s)
        allowed = GROUP_ALLOWED_BUCKETS.get(gid) if gid is not None else None
        entry = drugs_raw.get(s) or {}
        s_norm = normalize_name(s)

        # --- Bucket weights ---
        # Base ML weight (prior + residual), precomputed above.
        w_ml = ML_WEIGHTS_BY_KEY[s]

        # Inspo anchoring: weighted average of inspo and ML when inspo has data.
        w_final = w_ml
        if s in DRUGS_KEY_TO_INSPO:
            w_inspo = get_inspo_vector(DRUGS_KEY_TO_INSPO[s])
            w_final = (w_inspo * anchor_w + w_ml) / (anchor_w + 1.0)

        # Group members are restricted to their group's bucket set.
        neuro = weights_to_neuro_buckets(w_final, allowed_buckets=allowed)

        # --- Tolerance parameters ---
        params = dict(DEFAULT_TOLERANCE_PARAMS)

        # Half-life: override the default only when parsing yields a positive number.
        hl = parse_half_life_hours_from_drugs(entry)
        if isinstance(hl, (int, float)) and float(hl) > 0:
            params['half_life_hours'] = round(float(hl), 3)

        # Standard unit: try parsing from drugs.json first.
        derived_su = parse_standard_unit_from_drugs(entry)
        if derived_su:
            params['standard_unit'] = derived_su

        # Inspo may supply a standard unit and numeric parameter overrides.
        inspo_su = None
        if s in DRUGS_KEY_TO_INSPO:
            inspo_data = DRUGS_KEY_TO_INSPO[s]
            inspo_su = inspo_data.get('standard_unit')

            for p_key in ['potency_multiplier', 'duration_multiplier', 'tolerance_gain_rate', 'tolerance_decay_days']:
                if p_key in inspo_data and isinstance(inspo_data[p_key], (int, float)):
                    params[p_key] = float(inspo_data[p_key])

        # Standard unit fallback / override: inspo fills a missing unit, and
        # always wins for the specific targets below.
        if inspo_su and isinstance(inspo_su, dict):
            if not derived_su:
                params['standard_unit'] = inspo_su
            elif any(x in s_norm for x in ['bupropion', 'caffeine', 'mdma', 'ketamine']):
                # Enforce inspo for specific targets (Bupropion/Caffeine/MDMA/Ketamine)
                params['standard_unit'] = inspo_su

        # Explicit safety overrides for high-impact substances:
        # reduce potency multiplier for weak stimulants that are often used daily.
        # These take precedence over any inspo-supplied potency_multiplier.
        if 'bupropion' in s_norm:
            # Reduce to 30% impact relative to Dexedrine (1.0)
            params['potency_multiplier'] = 0.3
        elif 'caffeine' in s_norm:
            # Reduce to 15% impact
            params['potency_multiplier'] = 0.15

        payload['substances'][s] = {
            'neuro_buckets': neuro,
            **params,
        }

    # Use unique filename per weight
    file_weight_suffix = str(anchor_w)
    out_file = OUTPUT_DIR / f'tolerance_neuro_buckets_{file_weight_suffix}.json'

    with open(out_file, 'w', encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)

    print('✓ Wrote', out_file.resolve())

print('Done generating all requested variations.')

✓ Export substances (after exclude + alias de-dup): 525
Default anchor weight from rules: 1.0
Generating output for anchor weight: 0.5
✓ Wrote C:\Users\USER\dev\code\mobile_drug_use_app\backend\ML\drug_tolerance_model\outputs\tolerance_neuro_buckets_0.5.json
Generating output for anchor weight: 1.0
✓ Wrote C:\Users\USER\dev\code\mobile_drug_use_app\backend\ML\drug_tolerance_model\outputs\tolerance_neuro_buckets_1.0.json
Generating output for anchor weight: 2.0
✓ Wrote C:\Users\USER\dev\code\mobile_drug_use_app\backend\ML\drug_tolerance_model\outputs\tolerance_neuro_buckets_2.0.json
Done generating all requested variations.
