# Tolerance Model: Neuro Bucket Inference

Uses:
- `../drugs.json` for per-substance metadata
- `baseline.json` for the canonical bucket set
- `inspo.json` as a reference target set to calibrate against
- `../drug_interaction.yaml` for the shared exclude/alias/group/separate canonicalization config

Outputs:
- `outputs/tolerance_neuro_buckets.json` (single JSON document; JSONB-ready)

In [23]:
from __future__ import annotations

import json
from pathlib import Path
from typing import Dict, List, Tuple, Optional

import numpy as np
import pandas as pd

from sklearn.feature_extraction import DictVectorizer
from sklearn.multioutput import MultiOutputRegressor
from sklearn.linear_model import Ridge
from sklearn.pipeline import Pipeline

# Seed reused by the Ridge regressor below so repeated runs are reproducible.
RANDOM_STATE = 42
# NOTE(review): `rng` is not referenced by any other cell visible in this notebook.
rng = np.random.default_rng(RANDOM_STATE)

# All paths are relative, so they resolve against the current working directory.
# NOTE(review): `HERE` is not referenced by any other cell visible in this notebook.
HERE = Path.cwd()
DRUGS_PATH = Path('..') / 'drugs.json'
BASELINE_PATH = Path('baseline.json')
INSPO_PATH = Path('inspo.json')
# Shared canonicalization config (exclude/aliases/groups/separate)
YAML_PATH = Path('..') / 'drug_interaction.yaml'
OUTPUT_DIR = Path('outputs')
# Side effect: creates the output directory (and parents) if it does not exist yet.
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
OUTPUT_JSON = OUTPUT_DIR / 'tolerance_neuro_buckets.json'

# Echo the absolute locations so a wrong working directory is obvious early.
print('DRUGS_PATH:', DRUGS_PATH.resolve())
print('BASELINE_PATH:', BASELINE_PATH.resolve())
print('INSPO_PATH:', INSPO_PATH.resolve())
print('YAML_PATH:', YAML_PATH.resolve())
print('OUTPUT_JSON:', OUTPUT_JSON.resolve())

DRUGS_PATH: C:\Users\USER\dev\code\mobile_drug_use_app\backend\ML\drugs.json
BASELINE_PATH: C:\Users\USER\dev\code\mobile_drug_use_app\backend\ML\drug_tolerance_model\baseline.json
INSPO_PATH: C:\Users\USER\dev\code\mobile_drug_use_app\backend\ML\drug_tolerance_model\inspo.json
YAML_PATH: C:\Users\USER\dev\code\mobile_drug_use_app\backend\ML\drug_interaction.yaml
OUTPUT_JSON: C:\Users\USER\dev\code\mobile_drug_use_app\backend\ML\drug_tolerance_model\outputs\tolerance_neuro_buckets.json


In [24]:
def load_json(path: Path) -> dict:
    """Read the UTF-8 encoded JSON file at *path* and return the decoded document.

    Raises whatever ``open`` or the JSON decoder raise (missing file, bad JSON).
    """
    with open(path, encoding='utf-8') as handle:
        text = handle.read()
    return json.loads(text)

# Load the three JSON inputs once; later cells read these module-level names.
drugs_raw = load_json(DRUGS_PATH)
baseline = load_json(BASELINE_PATH)
inspo = load_json(INSPO_PATH)

# The bucket order is the key order of baseline['buckets']; every weight vector
# in this notebook is indexed in this order.
BUCKETS: List[str] = list((baseline.get('buckets') or {}).keys())
if not BUCKETS:
    # Nothing downstream can work without at least one bucket.
    raise ValueError('No buckets found in baseline.json')

print('✓ Loaded drugs:', len(drugs_raw))
print('✓ Loaded baseline buckets:', BUCKETS)
print('✓ Loaded inspo substances:', len((inspo.get('substances') or {})))

✓ Loaded drugs: 551
✓ Loaded baseline buckets: ['stimulant', 'serotonin_release', 'serotonin_psychedelic', 'gaba', 'opioid', 'nmda', 'cannabinoid']
✓ Loaded inspo substances: 19


In [25]:
# Load shared YAML canonicalization (exclude/aliases/groups)
import yaml

def normalize_name(name: str) -> str:
    """Return *name* trimmed and lower-cased; any falsy input (``None``, ``''``) yields ``''``."""
    if not name:
        return ''
    return name.strip().lower()

def load_yaml(path: Path) -> dict:
    """Load a YAML mapping from *path*.

    Returns an empty dict when the file does not exist or when its top-level
    value is not a mapping. Parsing uses ``yaml.safe_load``.
    """
    if path.exists():
        with open(path, 'r', encoding='utf-8') as handle:
            parsed = yaml.safe_load(handle)
        if isinstance(parsed, dict):
            return parsed
    return {}

# Parse the shared YAML config into normalized lookup structures.
# Non-string entries are ignored rather than raising.
TOL_CFG = load_yaml(YAML_PATH)
TOL_EXCLUDE_SET = {normalize_name(x) for x in (TOL_CFG.get('exclude') or []) if isinstance(x, str)}
TOL_SEPARATE_SET = {normalize_name(x) for x in (TOL_CFG.get('separate') or []) if isinstance(x, str)}
TOL_ALIAS_MAP = {
    normalize_name(k): normalize_name(v)
    for k, v in (TOL_CFG.get('aliases') or {}).items()
    if isinstance(k, str) and isinstance(v, str)
}

# Groups: we do NOT merge for tolerance; we use them to share the same bucket *set* across members.
TOL_GROUPS: Dict[str, dict] = {}
TOL_MEMBER_TO_GROUP: Dict[str, str] = {}
for group_name, g in (TOL_CFG.get('groups') or {}).items():
    # Skip malformed entries (non-string key or non-mapping body), matching how
    # the alias map ignores non-string pairs.
    if not isinstance(group_name, str) or not isinstance(g, dict):
        continue
    group_norm = normalize_name(group_name)
    # A missing, null, non-string or blank `canonical` falls back to the group
    # name. Previously `canonical: null` normalized to '' and the empty string
    # was registered as both the canonical name and a group member.
    canon_raw = g.get('canonical')
    canon = normalize_name(canon_raw) if isinstance(canon_raw, str) else ''
    if not canon:
        canon = group_norm
    members = [normalize_name(m) for m in (g.get('members') or []) if isinstance(m, str)]
    # also treat the group key itself as a member label
    if group_norm not in members:
        members.append(group_norm)
    if canon not in members:
        members.append(canon)

    TOL_GROUPS[group_norm] = {
        'canonical': canon,
        'members': sorted(set(members)),
    }
    for m in TOL_GROUPS[group_norm]['members']:
        # Respect separate: keep separate members un-grouped
        if m in TOL_SEPARATE_SET:
            continue
        # If a name is listed in several groups, the group processed last wins.
        TOL_MEMBER_TO_GROUP[m] = group_norm

# Map from normalized drugs.json key -> original key.
# If two drugs.json keys normalize to the same string (e.g. they differ only in
# case or surrounding whitespace), the one iterated last is the one kept here.
DRUG_KEY_BY_NORM = {normalize_name(k): k for k in drugs_raw.keys()}

def is_excluded(name: str) -> bool:
    """Return True when *name*, once normalized, is listed in the YAML ``exclude`` set."""
    key = normalize_name(name)
    return key in TOL_EXCLUDE_SET

def apply_alias(name: str) -> str:
    """Normalize *name* and map it through the YAML alias table (single hop).

    Names without an alias entry are returned in normalized form.
    """
    key = normalize_name(name)
    if key in TOL_ALIAS_MAP:
        return TOL_ALIAS_MAP[key]
    return key

def resolve_drugs_key(name: str) -> Optional[str]:
    """Map *name* to its original drugs.json key, applying aliases first.

    Returns None when the aliased name is empty, excluded by the YAML config,
    or has no matching drugs.json key.
    """
    canonical = apply_alias(name)
    usable = bool(canonical) and not is_excluded(canonical)
    return DRUG_KEY_BY_NORM.get(canonical) if usable else None

def tolerance_group_id(name: str) -> Optional[str]:
    """Return the normalized YAML group id that *name* belongs to, or None.

    Aliases are not applied here; only the normalized name is looked up.
    """
    key = normalize_name(name)
    return TOL_MEMBER_TO_GROUP[key] if key in TOL_MEMBER_TO_GROUP else None

# Summary of what was parsed. Note that load_yaml returns {} for a missing file,
# so all-zero counts here most likely mean the YAML was not found at YAML_PATH.
print('✓ Loaded drug_interaction.yaml for tolerance')
print('  exclude:', len(TOL_EXCLUDE_SET), 'aliases:', len(TOL_ALIAS_MAP), 'groups:', len(TOL_GROUPS), 'separate:', len(TOL_SEPARATE_SET))

✓ Loaded drug_interaction.yaml for tolerance
  exclude: 11 aliases: 11 groups: 10 separate: 1


In [26]:
def norm(s: str) -> str:
    """Return *s* stripped and lower-cased; falsy input (``None``, ``''``) yields ``''``."""
    if not s:
        return ''
    return s.strip().lower()

def _iter_str_list(x) -> List[str]:
    """Return the normalized string items of *x*.

    Anything that is not a list produces ``[]``; non-string items are dropped.
    """
    if not isinstance(x, list):
        return []
    return [norm(item) for item in x if isinstance(item, str)]

def extract_features(substance: str, entry: dict) -> Dict[str, float]:
    """Turn one drugs.json entry into a sparse ``{feature_name: 1.0}`` mapping.

    Feature namespaces:
      - ``cat:<category>``  from ``entry['categories']``
      - ``pwe:<effect>``    from the keys of ``entry['pweffects']``
      - ``fx:<effect>``     from ``entry['formatted_effects']``
      - ``warn:<flag>``     from substring matches in ``entry['properties']['avoid']``

    *substance* is accepted for signature symmetry but is not used here.
    """
    features: Dict[str, float] = {}

    # Categories are the most stable structured signal available.
    features.update((f'cat:{cat}', 1.0) for cat in _iter_str_list(entry.get('categories')))

    # PsychonautWiki effect keys: only the key names matter, not their values.
    effects = entry.get('pweffects') or {}
    if isinstance(effects, dict):
        features.update((f'pwe:{norm(key)}', 1.0) for key in effects if isinstance(key, str))

    # Formatted effects list (e.g. Sedative, Stimulation).
    features.update((f'fx:{label}', 1.0) for label in _iter_str_list(entry.get('formatted_effects')))

    # Minimal parsing of the free-text warning field properties.avoid.
    properties = entry.get('properties') or {}
    avoid_text = properties.get('avoid') if isinstance(properties, dict) else None
    if isinstance(avoid_text, str):
        lowered = norm(avoid_text)
        for needle, flag in (
            ('cns depressant', 'warn:cns_depressant'),
            ('serotonergic', 'warn:serotonergic'),
            ('maoi', 'warn:maoi'),
        ):
            if needle in lowered:
                features[flag] = 1.0

    return features

print('✓ Feature extractor ready')

✓ Feature extractor ready


In [27]:
# Heuristic prior: fast rule-based mapping from categories/effects -> bucket weights
# These priors do most of the work; ML learns residual corrections against inspo.json.
#
# How entries are matched (see heuristic_prior): each keyword is compared for
# exact equality against the substance's token set, which holds the full feature
# keys ('cat:...', 'pwe:...', 'fx:...', 'warn:...'), the bare category values,
# and the normalized substance name. So:
#   - prefixed keywords ('pwe:', 'fx:', 'warn:') match a feature key and score 2;
#   - bare keywords match a category value or the substance name and score 1.
# There is no substring matching: a bare keyword such as 'dopamine' only fires
# when a category (or the name) is exactly that string.
PRIOR_KEYWORDS = {
    'stimulant': [
        'stimulant', 'dopamine', 'norepinephrine', 'adrenergic', 'amphetamine', 'cathinone',
        'pwe:stimulation', 'fx:stimulation',
    ],
    'serotonin_release': [
        'entactogen', 'empathogen', 'serotonin', 'mdma', 'mda',
        'pwe:empathy, love, and sociability enhancement',
        'fx:empathy',
        'warn:serotonergic',
    ],
    'serotonin_psychedelic': [
        'psychedelic', 'tryptamine', 'lysergamide', 'phenethylamine', '5-ht2a',
        'pwe:hallucinations', 'fx:hallucinations',
    ],
    'gaba': [
        'benzodiazepine', 'benzo', 'z-drug', 'depressant', 'sedative', 'gaba',
        'pwe:sedation', 'fx:sedative', 'fx:hypnotic',
        'warn:cns_depressant',
    ],
    'opioid': [
        'opioid', 'opiate', 'pwe:respiratory depression',
    ],
    'nmda': [
        'dissociative', 'nmda', 'ketamine',
        'pwe:dissociation',
    ],
    'cannabinoid': [
        'cannabinoid', 'thc', 'cannabis',
    ],
}

def heuristic_prior(substance: str, entry: dict) -> np.ndarray:
    """Rule-based bucket weights for one substance, in ``BUCKETS`` order.

    Each bucket's keywords are matched exactly against the substance's token set
    (feature keys, bare category values, normalized name). Prefixed keywords
    (``pwe:``/``fx:``/``warn:``) add 2 to the bucket score, bare keywords add 1.
    The score maps to a weight: >=3 -> 1.0, 2 -> 0.7, 1 -> 0.4, 0 -> 0.0.
    """
    feature_keys = extract_features(substance, entry).keys()

    # Token set: feature keys + raw category values + the normalized name.
    vocab = set(feature_keys)
    vocab.update(key.split(':', 1)[1] for key in feature_keys if key.startswith('cat:'))
    vocab.add(norm(substance))

    strong_prefixes = ('pwe:', 'fx:', 'warn:')
    weights = np.zeros(len(BUCKETS), dtype=float)
    for pos, bucket in enumerate(BUCKETS):
        total = sum(
            2 if keyword.startswith(strong_prefixes) else 1
            for keyword in PRIOR_KEYWORDS.get(bucket, [])
            if keyword in vocab
        )
        # Step function from score to weight; a zero score keeps the 0.0 default.
        if total >= 3:
            weights[pos] = 1.0
        elif total == 2:
            weights[pos] = 0.7
        elif total == 1:
            weights[pos] = 0.4
    return weights

print('✓ Heuristic prior ready')

✓ Heuristic prior ready


In [28]:
# Build supervised training set from inspo.json
# `inspo_substances` maps inspo names to their records; `train_names` is the
# sorted list of string keys, which fixes the row order of the training matrices.
inspo_substances = inspo.get('substances') or {}
train_names = sorted([k for k in inspo_substances.keys() if isinstance(k, str)])

# Optional manual mapping for inspo naming quirks -> drugs.json naming
# Applied before the YAML alias table; keys are matched after normalize_name.
INSPO_TO_DRUGS_ALIASES = {
    'dxm': 'dextromethorphan',
    'psilocybin': 'psilocin',
    'thc': 'cannabis',
}

def resolve_inspo_name_to_drugs_key(name: str) -> Optional[str]:
    """Resolve an inspo.json substance name to a drugs.json key.

    Resolution order:
      1. normalize, then apply the manual inspo alias table and the YAML aliases;
      2. return None if the result is excluded;
      3. try a direct drugs.json lookup;
      4. otherwise, if the name belongs to a YAML group, try the group's
         canonical name and then each member in stored order.

    Returns None when nothing matches.
    """
    key = normalize_name(name)
    key = apply_alias(INSPO_TO_DRUGS_ALIASES.get(key, key))
    if is_excluded(key):
        return None

    found = resolve_drugs_key(key)
    if found is not None:
        return found

    # Group fallback: any group member present in drugs.json can stand in.
    group_id = tolerance_group_id(key)
    if group_id is None:
        return None
    group = TOL_GROUPS.get(group_id) or {}
    candidates = [group.get('canonical'), *(group.get('members') or [])]
    for candidate in candidates:
        if not isinstance(candidate, str):
            continue
        found = resolve_drugs_key(candidate)
        if found is not None:
            return found
    return None

def target_vector_from_inspo(substance: str) -> np.ndarray:
    """Build the target weight vector for an inspo substance, in ``BUCKETS`` order.

    Reads ``neuro_buckets[<bucket>]['weight']`` from the inspo record; buckets that
    are absent, malformed, or carry a non-numeric weight stay at 0. The result is
    clipped to [0, 1].
    """
    record = inspo_substances.get(substance) or {}
    bucket_map = record.get('neuro_buckets') or {}
    target = np.zeros(len(BUCKETS), dtype=float)
    if not isinstance(bucket_map, dict):
        return np.clip(target, 0.0, 1.0)

    for pos, bucket in enumerate(BUCKETS):
        spec = bucket_map.get(bucket)
        if not isinstance(spec, dict):
            continue
        weight = spec.get('weight', 0.0)
        if isinstance(weight, (int, float)):
            target[pos] = float(weight)
    return np.clip(target, 0.0, 1.0)

# Assemble the training matrices, one row per inspo substance that resolves to a
# drugs.json entry:
#   X     - feature dicts (vectorized later by the pipeline)
#   y     - target weights from inspo.json
#   prior - heuristic prior weights for the same rows
X = []
y = []
prior = []
missing_from_drugs = []
excluded_from_training = []
for s in train_names:
    s_key = resolve_inspo_name_to_drugs_key(s)
    if s_key is None:
        # Either excluded or not found. Classify on the same aliased name the
        # resolver tested, so that a name whose alias target is excluded is
        # reported as excluded rather than as missing from drugs.json.
        s_norm = normalize_name(s)
        s_aliased = apply_alias(INSPO_TO_DRUGS_ALIASES.get(s_norm, s_norm))
        if is_excluded(s) or is_excluded(s_aliased):
            excluded_from_training.append(s)
        else:
            missing_from_drugs.append(s)
        continue
    entry = drugs_raw[s_key]
    X.append(extract_features(s_key, entry))
    y.append(target_vector_from_inspo(s))
    prior.append(heuristic_prior(s_key, entry))

# Stack into (n_rows, n_buckets) arrays; keep a well-shaped empty array when no
# inspo substance could be resolved.
y = np.vstack(y) if y else np.zeros((0, len(BUCKETS)))
prior = np.vstack(prior) if prior else np.zeros((0, len(BUCKETS)))
print('✓ Training rows:', len(X))
print('Missing inspo entries in drugs.json:', missing_from_drugs)
print('Excluded inspo entries (yaml exclude):', excluded_from_training)
print('Target buckets:', BUCKETS)

✓ Training rows: 18
Missing inspo entries in drugs.json: []
Excluded inspo entries (yaml exclude): ['melatonin']
Target buckets: ['stimulant', 'serotonin_release', 'serotonin_psychedelic', 'gaba', 'opioid', 'nmda', 'cannabinoid']


In [29]:
# Train residual model: predicts (target - heuristic_prior) from drugs.json features
# We disable intercept so unknown/unseen feature rows don't get a drifting baseline residual.
# (DictVectorizer drops feature names it did not see during fit, so a row made up
# only of unseen features vectorizes to all zeros and gets a zero residual.)
residual = y - prior

# One independent Ridge regressor per bucket, fed by a sparse one-hot encoding of
# the feature dicts.
model: Pipeline = Pipeline([
    ('vec', DictVectorizer(sparse=True)),
    ('reg', MultiOutputRegressor(Ridge(alpha=2.0, fit_intercept=False, random_state=RANDOM_STATE))),
])
# NOTE(review): fitting presumably fails when X is empty (no inspo substance
# resolved); there is no explicit guard for that case here.
model.fit(X, residual)
print('✓ Trained residual model')

def predict_weights(drugs_key: str) -> np.ndarray:
    """Predict bucket weights (in ``BUCKETS`` order) for one drugs.json key.

    The result is ``clip(heuristic_prior + model residual, 0, 1)``. Excluded
    substances get an all-zero vector. If every clipped weight is zero, one
    bucket is set to 0.1 so the vector is never empty: the prior's strongest
    bucket when the prior is non-zero, otherwise the first bucket in ``BUCKETS``.
    NOTE(review): that last fallback labels signal-free substances with whichever
    bucket happens to be first - confirm this is intended.
    """
    if is_excluded(normalize_name(drugs_key)):
        return np.zeros(len(BUCKETS), dtype=float)

    entry = drugs_raw.get(drugs_key) or {}
    prior_weights = heuristic_prior(drugs_key, entry)
    features = extract_features(drugs_key, entry)
    correction = model.predict([features])[0]
    combined = np.clip(prior_weights + correction, 0.0, 1.0)

    if float(combined.max()) <= 0.0:
        fallback = int(np.argmax(prior_weights)) if float(prior_weights.max()) > 0 else 0
        combined[fallback] = 0.1
    return combined

def eval_against_inspo() -> pd.DataFrame:
    """Compare predictions with the inspo targets, one row per resolvable substance.

    Columns: ``substance``, ``drugs_key``, ``mae`` (mean absolute error across
    buckets) plus ``true:<bucket>`` / ``pred:<bucket>`` pairs. Rows are sorted by
    ``mae``, worst first. When no inspo substance resolves, an empty frame with the
    three base columns is returned instead of raising on the missing sort key.

    Note: these are the same substances the residual model was fitted on, so the
    reported error is in-sample and says little about unseen substances.
    """
    rows = []
    for s in train_names:
        s_key = resolve_inspo_name_to_drugs_key(s)
        if s_key is None:
            continue
        y_true = target_vector_from_inspo(s)
        y_pred = predict_weights(s_key)
        row = {'substance': s, 'drugs_key': s_key, 'mae': float(np.mean(np.abs(y_true - y_pred)))}
        for i, b in enumerate(BUCKETS):
            row[f'true:{b}'] = float(y_true[i])
            row[f'pred:{b}'] = float(y_pred[i])
        rows.append(row)
    if not rows:
        # sort_values('mae') on a column-less frame raises KeyError.
        return pd.DataFrame(columns=['substance', 'drugs_key', 'mae'])
    return pd.DataFrame(rows).sort_values('mae', ascending=False)

# In-sample check: eval_against_inspo scores the same inspo substances the model
# was fitted on, so this mean MAE measures fit quality, not generalization.
df_eval = eval_against_inspo()
print('Mean MAE vs inspo:', float(df_eval['mae'].mean()) if len(df_eval) else None)
# Notebook display of the worst-fitting rows (sorted by MAE, descending).
df_eval.head(25)

✓ Trained residual model
Mean MAE vs inspo: 0.026984240605400736


Unnamed: 0,substance,drugs_key,mae,true:stimulant,pred:stimulant,true:serotonin_release,pred:serotonin_release,true:serotonin_psychedelic,pred:serotonin_psychedelic,true:gaba,pred:gaba,true:opioid,pred:opioid,true:nmda,pred:nmda,true:cannabinoid,pred:cannabinoid
11,mdma,mdma,0.09333,0.35,0.46855,1.0,0.786762,0.0,0.298442,0.0,0.017865,0.0,0.0,0.0,0.0,0.0,0.005214
4,dexedrine,dexedrine,0.082608,1.0,0.563269,0.0,0.101049,0.0,0.0,0.0,0.036802,0.0,0.0,0.0,0.000822,0.0,0.002854
16,psilocybin,psilocin,0.081186,0.0,0.0,0.0,0.225552,1.0,0.665767,0.0,0.0,0.0,0.005366,0.0,0.00087,0.0,0.002281
2,bupropion,bupropion,0.047384,0.2,0.459196,0.0,0.063022,0.0,0.0,0.0,0.0,0.0,0.005655,0.0,0.000134,0.0,0.00368
8,ghb,ghb,0.044425,0.0,0.046636,0.0,0.02377,0.0,0.0,1.0,0.763992,0.0,0.0,0.0,0.00167,0.0,0.002894
14,morphine,morphine,0.020259,0.0,0.0147,0.0,0.0,0.0,0.0,0.0,0.099216,1.0,0.986169,0.0,0.013488,0.0,0.000582
1,bromazolam,bromazolam,0.020016,0.0,0.036402,0.0,0.0,0.0,0.103709,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
0,alcohol,alcohol,0.016669,0.2,0.185954,0.0,0.038951,0.0,0.001126,0.9,0.884178,0.2,0.214352,0.6,0.568439,0.0,0.000828
12,mdpv,mdpv,0.014582,1.0,0.907213,0.0,0.0,0.0,0.008167,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.001119
17,thc,cannabis,0.014257,0.0,0.013833,0.0,0.000104,0.0,0.018256,0.0,0.041037,0.0,0.0,0.0,0.003043,1.0,0.976471


In [31]:
# Export: neuro_buckets for all substances (JSONB-ready)
import re

# Per-substance tolerance parameters copied into every exported record.
# Only 'half_life_hours' is ever overridden (when it can be parsed from
# drugs.json); all other values are exported unchanged for every substance.
# NOTE(review): units are implied by the key names only (hours, mg, days);
# confirm against the consumer of the exported JSON.
DEFAULT_TOLERANCE_PARAMS = {
    'half_life_hours': 12.0,
    'active_threshold': 0.05,
    'standard_unit_mg': 10.0,
    'potency_multiplier': 1.0,
    'duration_multiplier': 1.2,
    'tolerance_gain_rate': 0.25,
    'tolerance_decay_days': 5.0,
}

# Patterns for the half-life parser. Unit words must not be embedded in a longer
# alphabetic word, so 'min' no longer fires inside 'administration' or
# 'amphetamine'; a digit directly before the unit ('90min') is still accepted.
_HALF_LIFE_NUMBER_RE = re.compile(r'\d+(?:\.\d+)?')
_HALF_LIFE_MINUTE_RE = re.compile(r'(?<![a-z])(?:minutes?|mins?)(?![a-z])')
_HALF_LIFE_DAY_RE = re.compile(r'(?<![a-z])days?(?![a-z])')
_HALF_LIFE_HOUR_RE = re.compile(r'(?<![a-z])(?:hours?|hrs?)(?![a-z])')


def parse_half_life_hours_from_drugs(entry: dict) -> Optional[float]:
    """Best-effort parse of drugs.json properties['half-life'] into hours.

    The first one or two numbers in the text are used (two numbers are treated
    as a range and averaged). The unit is detected as a whole word, checked in
    the order minutes, days, hours, and applied to that single value.

    Args:
        entry: one drugs.json record; ``properties['half-life']`` (or
            ``properties['half_life']``) is expected to be a string.

    Returns:
        The half-life in hours, or None when the field is missing, is not a
        string, contains no number, or names no recognized unit.

    Limitation: strings mixing units (e.g. '30 min - 2 hours') are still
    interpreted with one unit only.
    """
    props = entry.get('properties') or {}
    if not isinstance(props, dict):
        return None
    hl = props.get('half-life') or props.get('half_life')
    if not isinstance(hl, str):
        return None
    s = hl.strip().lower()
    if not s:
        return None

    # Extract 1-2 numbers; if range, average
    nums = [float(x) for x in _HALF_LIFE_NUMBER_RE.findall(s)[:2]]
    if not nums:
        return None
    value = nums[0] if len(nums) == 1 else (nums[0] + nums[1]) / 2.0

    # Unit handling (whole-word matches only)
    if _HALF_LIFE_MINUTE_RE.search(s):
        return value / 60.0
    if _HALF_LIFE_DAY_RE.search(s):
        return value * 24.0
    if _HALF_LIFE_HOUR_RE.search(s):
        return value
    # No recognizable unit: do not guess, let the caller fall back to defaults.
    return None

def build_export_substances() -> List[str]:
    """
    Build export list from drugs.json, applying YAML excludes and alias de-duplication.

    - Excluded substances are skipped
    - Alias keys are skipped if their alias target exists in drugs.json
      (an alias that normalizes to the key itself is ignored, so a
      case-only alias such as ``MDMA -> mdma`` cannot drop the substance)
    - Group members are NOT merged; they remain separate entries

    Returns the original drugs.json keys, ordered by sorted key, with at most
    one key per normalized name.
    """
    out: List[str] = []
    seen = set()
    for k in sorted(drugs_raw.keys()):
        n = normalize_name(k)
        if not n or is_excluded(n) or n in seen:
            continue
        # alias de-dup: if this key is an alias of a *different* key that will
        # itself be exported, skip this key
        target = TOL_ALIAS_MAP.get(n)
        if (
            target is not None
            and target != n
            and target in DRUG_KEY_BY_NORM
            and not is_excluded(target)
        ):
            continue
        seen.add(n)
        out.append(DRUG_KEY_BY_NORM.get(n, k))
    return out

# Bucket name -> position in BUCKETS (and therefore in every weight vector).
BUCKET_INDEX = {b: i for i, b in enumerate(BUCKETS)}

def weights_to_neuro_buckets(
    w: np.ndarray,
    *,
    threshold: float = 0.05,
    allowed_buckets: Optional[set] = None,
) -> Dict[str, dict]:
    """
    Convert weight vector -> neuro_buckets map.

    Without allowed_buckets, buckets with weight >= threshold are emitted in
    descending weight order.

    If allowed_buckets is provided, enforce the same bucket *set* for grouped substances,
    while allowing weights to differ. Buckets in allowed_buckets are included even if below threshold,
    in BUCKETS order, with the weight floored at 0.001 so that a present key never
    rounds to a weight of 0.0.

    If nothing would be emitted, the strongest bucket is kept with a weight of at
    least 0.1. Weights are rounded to 3 decimals.
    """
    out: Dict[str, dict] = {}
    eps = 0.001
    if allowed_buckets is None:
        for i in np.argsort(w)[::-1]:
            wi = float(w[i])
            if wi >= threshold:
                b = BUCKETS[i]
                out[b] = {'weight': round(wi, 3), 'tolerance_type': b}
    else:
        # deterministic order: use BUCKETS order
        for b in BUCKETS:
            if b not in allowed_buckets:
                continue
            # Floor at eps: previously only wi <= 0 was floored, so tiny positive
            # weights (< 0.0005) were rounded to 0.0 on export.
            wi = max(float(w[BUCKET_INDEX[b]]), eps)
            # keep present even if below threshold
            out[b] = {'weight': round(wi, 3), 'tolerance_type': b}

    if not out:
        # fallback: keep at least one bucket
        i = int(np.argmax(w))
        b = BUCKETS[i]
        out[b] = {'weight': round(max(float(w[i]), 0.1), 3), 'tolerance_type': b}
    return out

# Ordered list of drugs.json keys that will appear in the export.
export_keys = build_export_substances()
print('✓ Export substances (after exclude + alias de-dup):', len(export_keys))

# For each YAML group, choose a reference substance and compute its bucket *set*;
# all members will share this set (weights may differ).
# Groups with no member present in drugs.json get no entry and therefore impose
# no bucket set on anything.
GROUP_ALLOWED_BUCKETS: Dict[str, set] = {}
for gid, g in TOL_GROUPS.items():
    members = [m for m in (g.get('members') or []) if isinstance(m, str)]
    canon = g.get('canonical') if isinstance(g.get('canonical'), str) else None
    # pick reference: canonical if present, else first present member
    # (members were stored sorted, so "first" means alphabetically first)
    ref = None
    if canon is not None and resolve_drugs_key(canon) in drugs_raw:
        ref = resolve_drugs_key(canon)
    if ref is None:
        for m in members:
            hit = resolve_drugs_key(m)
            if hit is not None:
                ref = hit
                break
    if ref is None:
        continue
    w_ref = predict_weights(ref)
    # determine bucket set from reference using normal thresholding
    # (0.05 matches the default `threshold` of weights_to_neuro_buckets)
    allowed = {BUCKETS[i] for i in range(len(BUCKETS)) if float(w_ref[i]) >= 0.05}
    if not allowed:
        # Never leave a group with an empty set: fall back to the strongest bucket.
        allowed = {BUCKETS[int(np.argmax(w_ref))]}
    GROUP_ALLOWED_BUCKETS[gid] = allowed

# Skeleton of the exported document: run metadata plus an (initially empty)
# 'substances' map that the export loop fills in.
payload = {
    'metadata': {
        # Timezone-aware UTC timestamp. Timestamp.now(tz='UTC') replaces the
        # deprecated Timestamp.utcnow() and serializes identically.
        'generated_at': pd.Timestamp.now(tz='UTC').isoformat(),
        'source_files': {
            'drugs_json': str(DRUGS_PATH),
            'baseline_json': str(BASELINE_PATH),
            'inspo_json': str(INSPO_PATH),
            'drug_interaction_yaml': str(YAML_PATH),
        },
        'buckets': BUCKETS,
        # Normalized view of the YAML config actually used for this run.
        'yaml_config': {
            'exclude': sorted(TOL_EXCLUDE_SET),
            'aliases': dict(TOL_ALIAS_MAP),
            'groups': {gid: {'canonical': g.get('canonical'), 'members': g.get('members')} for gid, g in TOL_GROUPS.items()},
            'separate': sorted(TOL_SEPARATE_SET),
        },
        'default_tolerance_params': dict(DEFAULT_TOLERANCE_PARAMS),
        'notes': [
            'Bucket weights are inferred from drugs.json using a heuristic prior + residual ML fit to inspo.json.',
            'Excluded substances (yaml) are omitted from the export.',
            'Alias keys are de-duplicated (only the alias target is exported when present).',
            'Group members share the same neuro-bucket keys; weights may differ per substance.',
            'Tolerance parameters are exported per substance; half_life_hours is parsed from drugs.json when available, otherwise defaults apply.',
            'This is a single JSON document suitable for Postgres JSONB.',
        ],
    },
    'substances': {},
}

# Fill payload['substances']: one record per exported drugs.json key.
for s in export_keys:
    # Grouped substances are restricted to their group's reference bucket set;
    # `allowed` stays None for ungrouped substances and for groups that have no
    # reference entry in GROUP_ALLOWED_BUCKETS.
    gid = tolerance_group_id(s)
    allowed = GROUP_ALLOWED_BUCKETS.get(gid) if gid is not None else None
    entry = drugs_raw.get(s) or {}
    w = predict_weights(s)
    neuro = weights_to_neuro_buckets(w, allowed_buckets=allowed)

    # Start from the defaults and override the half-life only when a positive
    # value could be parsed from drugs.json.
    params = dict(DEFAULT_TOLERANCE_PARAMS)
    hl = parse_half_life_hours_from_drugs(entry)
    if isinstance(hl, (int, float)) and float(hl) > 0:
        params['half_life_hours'] = round(float(hl), 3)

    # Tolerance parameters are flattened next to 'neuro_buckets' in the record.
    payload['substances'][s] = {
        'neuro_buckets': neuro,
        **params,
    }

# Write the whole document in one go (UTF-8, non-ASCII characters kept as-is).
with open(OUTPUT_JSON, 'w', encoding='utf-8') as f:
    json.dump(payload, f, ensure_ascii=False, indent=2)

print('✓ Wrote', OUTPUT_JSON.resolve())
print('Substances exported:', len(payload['substances']))

# Sanity check: every exported record should carry at least one bucket.
missing = [k for k, v in payload['substances'].items() if not (v.get('neuro_buckets') or {})]
print('Missing neuro_buckets:', len(missing))

✓ Export substances (after exclude + alias de-dup): 538
✓ Wrote C:\Users\USER\dev\code\mobile_drug_use_app\backend\ML\drug_tolerance_model\outputs\tolerance_neuro_buckets.json
Substances exported: 538
Missing neuro_buckets: 0
