# Hybrid Drug–Drug Interaction Risk Model (Rules + ML)

This notebook builds a safety-critical hybrid system:

1) **Hard rule engine first** (non-negotiable overrides)
2) **ML classifier second** (trained from `combos.json`)

Outputs are **exactly** one of:
- `Low Risk & Synergy`
- `Low Risk & No Synergy`
- `Low Risk & Decrease`
- `Caution`
- `Unsafe`
- `Dangerous`

In [18]:
# Imports
from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
import yaml

from sklearn.model_selection import train_test_split
from sklearn.feature_extraction import DictVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report

RANDOM_STATE = 42
np.random.seed(RANDOM_STATE)
print('✓ Imports loaded')

✓ Imports loaded


In [19]:
# Paths
HERE = Path.cwd()
# This notebook lives at backend/ML/drug_interaction_model/interaction_model.ipynb
# so ../drugs.json and ../combos.json are correct relative paths.
DRUGS_PATH = Path('..') / 'drugs.json'
COMBOS_PATH = Path('..') / 'combos.json'
YAML_PATH = Path('drug_interaction.yaml')

print('DRUGS_PATH:', DRUGS_PATH.resolve())
print('COMBOS_PATH:', COMBOS_PATH.resolve())
print('YAML_PATH:', YAML_PATH.resolve())
print('✓ Paths set')

DRUGS_PATH: C:\Users\USER\dev\code\mobile_drug_use_app\backend\ML\drugs.json
COMBOS_PATH: C:\Users\USER\dev\code\mobile_drug_use_app\backend\ML\combos.json
YAML_PATH: C:\Users\USER\dev\code\mobile_drug_use_app\backend\ML\drug_interaction_model\drug_interaction.yaml
✓ Paths set


In [3]:
# Fixed output classes (enum-like)
RISK_CLASSES = [
    'Low Risk & Synergy',
    'Low Risk & No Synergy',
    'Low Risk & Decrease',
    'Caution',
    'Unsafe',
    'Dangerous',
]
RISK_CLASS_SET = set(RISK_CLASSES)

# For guardrails: ensure we never downgrade rule-based results
RISK_SEVERITY = {
    'Low Risk & Synergy': 0,
    'Low Risk & No Synergy': 0,
    'Low Risk & Decrease': 0,
    'Caution': 1,
    'Unsafe': 2,
    'Dangerous': 3,
}

def max_risk(a: str, b: str) -> str:
    if a not in RISK_SEVERITY or b not in RISK_SEVERITY:
        raise ValueError(f'Unknown risk class: {a} or {b}')
    return a if RISK_SEVERITY[a] >= RISK_SEVERITY[b] else b

print('✓ Risk classes configured')

✓ Risk classes configured


In [20]:
def load_json(path: Path) -> dict:
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)

drugs_raw = load_json(DRUGS_PATH)
combos_raw = load_json(COMBOS_PATH)

print(f'✓ Loaded drugs.json entries: {len(drugs_raw)}')
print(f'✓ Loaded combos.json root keys: {len(combos_raw)}')

# quick peek
some_drug = next(iter(drugs_raw.keys()))
some_combo = next(iter(combos_raw.keys()))
print('Example drug key:', some_drug)
print('Example combo key:', some_combo)

✓ Loaded drugs.json entries: 551
✓ Loaded combos.json root keys: 31
Example drug key: 1,4-butanediol
Example combo key: 2c-t-x


In [29]:
def normalize_name(name: str) -> str:
    return (name or '').strip().lower()

def load_yaml(path: Path) -> dict:
    with open(path, 'r', encoding='utf-8') as f:
        obj = yaml.safe_load(f)
    return obj if isinstance(obj, dict) else {}

INTERACTION_CFG = load_yaml(YAML_PATH) if YAML_PATH.exists() else {}

EXCLUDE_SET = {normalize_name(x) for x in (INTERACTION_CFG.get('exclude') or []) if isinstance(x, str)}
SEPARATE_SET = {normalize_name(x) for x in (INTERACTION_CFG.get('separate') or []) if isinstance(x, str)}
ALIASES = {normalize_name(k): normalize_name(v) for k, v in (INTERACTION_CFG.get('aliases') or {}).items() if isinstance(k, str) and isinstance(v, str)}

# Groups: map member -> canonical (unless member is in SEPARATE_SET). Also map the group key name itself -> canonical.
GROUP_MEMBERS: Dict[str, set] = {}
GROUP_CANONICAL_BY_MEMBER: Dict[str, str] = {}
for group_name, g in (INTERACTION_CFG.get('groups') or {}).items():
    if not isinstance(g, dict):
        continue
    group_norm = normalize_name(group_name)
    canon = normalize_name(g.get('canonical', group_name))
    members = set()
    for m in (g.get('members') or []):
        if isinstance(m, str):
            members.add(normalize_name(m))
    members.add(canon)
    members.add(group_norm)
    GROUP_MEMBERS[group_norm] = members
    # group name itself -> canonical
    if group_norm not in SEPARATE_SET:
        GROUP_CANONICAL_BY_MEMBER[group_norm] = canon
    for m in members:
        if m in SEPARATE_SET:
            continue
        GROUP_CANONICAL_BY_MEMBER[m] = canon

def is_excluded(name: str) -> bool:
    return normalize_name(name) in EXCLUDE_SET

def canonicalize_name(name: str) -> str:
    """Apply aliases + group merging; respect exclude + separate."""
    n = normalize_name(name)
    if not n:
        return n
    # First, aliases
    n = ALIASES.get(n, n)
    # Keep 'separate' as-is (even if it appears in groups)
    if n in SEPARATE_SET:
        return n
    # Then group merge
    n = GROUP_CANONICAL_BY_MEMBER.get(n, n)
    return n

def in_group(name: str, group: str) -> bool:
    g = normalize_name(group)
    members = GROUP_MEMBERS.get(g)
    if not members:
        return False
    return canonicalize_name(name) in members

def drug_categories(drug_key: str, drugs: dict) -> List[str]:
    entry = drugs.get(drug_key) or {}
    cats = entry.get('categories') or []
    return [normalize_name(c) for c in cats if isinstance(c, str)]

# Map from normalized drugs.json key -> original key
DRUG_KEY_BY_NORM = {normalize_name(k): k for k in drugs_raw.keys()}

def resolve_drug_key(name: str) -> Optional[str]:
    canon = canonicalize_name(name)
    if is_excluded(canon):
        return None
    return DRUG_KEY_BY_NORM.get(canon) or DRUG_KEY_BY_NORM.get(normalize_name(name))

print('✓ Loaded drug_interaction.yaml config')
print('  exclude:', len(EXCLUDE_SET), 'aliases:', len(ALIASES), 'groups:', len(GROUP_MEMBERS), 'separate:', len(SEPARATE_SET))

# sanity checks for group-ish keys used by combos.json
for k in ['benzodiazepines', 'diazepam', 'opioids', 'alcohol', 'ghb', 'ghb/gbl', 'maois', 'ssris', 'ssri', 'amphetamines', 'amphetamine']:
    print(k, '->', resolve_drug_key(k), 'canon=', canonicalize_name(k))

✓ Loaded drug_interaction.yaml config
  exclude: 11 aliases: 11 groups: 10 separate: 1
benzodiazepines -> None canon= benzodiazepines
diazepam -> diazepam canon= benzodiazepines
opioids -> None canon= opioids
alcohol -> alcohol canon= alcohol
ghb -> ghb canon= ghb
ghb/gbl -> None canon= ghb/gbl
maois -> None canon= maois
ssris -> None canon= ssri
ssri -> None canon= ssri
amphetamines -> amphetamine canon= amphetamine
amphetamine -> amphetamine canon= amphetamine


## Hard Rule Engine (runs first)

Non-negotiable guardrails:
- Benzodiazepines + opioids → **Dangerous**
- Two CNS depressants (benzos, opioids, alcohol, GHB/GBL) → **minimum Unsafe**
- MAOI + serotonergic drugs → **Dangerous**
- Opioid + opioid → **minimum Unsafe**
- GABAergic + GABAergic → **minimum Unsafe**

If any rule triggers, we return immediately and skip ML.

In [22]:
@dataclass(frozen=True)
class RuleResult:
    risk: str
    reason: str

def has_any(categories: List[str], needles: List[str]) -> bool:
    s = set(categories)
    return any(n in s for n in needles)

def _identity_text(drug_key: str, drugs: dict) -> str:
    """Conservative: only use identity + declared categories (avoid scanning free-text that may mention other drugs)."""
    entry = drugs.get(drug_key) or {}
    cats = drug_categories(drug_key, drugs)
    parts = [normalize_name(drug_key), normalize_name(entry.get('name', '')), normalize_name(entry.get('pretty_name', ''))]
    parts.extend(cats)
    return ' '.join([p for p in parts if p])

def _pweffects_keys_text(drug_key: str, drugs: dict) -> str:
    entry = drugs.get(drug_key) or {}
    pwe = entry.get('pweffects') or {}
    if isinstance(pwe, dict):
        return ' '.join([normalize_name(k) for k in pwe.keys()])
    return ''

def _has_token(hay: str, tokens: List[str]) -> bool:
    hay = hay.lower()
    return any(normalize_name(t) in hay for t in tokens)

def apply_hard_rules(drug_a: str, drug_b: str, drugs: dict) -> Optional[RuleResult]:
    # Canonicalize first (aliases + group merging); keep 'separate' distinct
    a_c = canonicalize_name(drug_a)
    b_c = canonicalize_name(drug_b)

    if is_excluded(a_c) or is_excluded(b_c):
        return None

    # Use drugs.json key if present; otherwise keep canonical string key
    a_key = resolve_drug_key(a_c) or a_c
    b_key = resolve_drug_key(b_c) or b_c

    cats_a = drug_categories(a_key, drugs)
    cats_b = drug_categories(b_key, drugs)

    id_a = _identity_text(a_key, drugs)
    id_b = _identity_text(b_key, drugs)
    pwe_a = _pweffects_keys_text(a_key, drugs)
    pwe_b = _pweffects_keys_text(b_key, drugs)

    # Prefer YAML groups (more reliable than free-text/categories when present)
    is_benzo_a = in_group(a_key, 'benzodiazepines') or has_any(cats_a, ['benzodiazepine', 'benzodiazepines']) or _has_token(id_a, ['benzodiazepine', 'benzodiazepines'])
    is_benzo_b = in_group(b_key, 'benzodiazepines') or has_any(cats_b, ['benzodiazepine', 'benzodiazepines']) or _has_token(id_b, ['benzodiazepine', 'benzodiazepines'])
    is_opioid_a = in_group(a_key, 'opioids') or has_any(cats_a, ['opioid', 'opioids']) or _has_token(id_a, ['opioid', 'opioids', 'opiate'])
    is_opioid_b = in_group(b_key, 'opioids') or has_any(cats_b, ['opioid', 'opioids']) or _has_token(id_b, ['opioid', 'opioids', 'opiate'])
    is_ghb_a = in_group(a_key, 'ghb') or normalize_name(a_key) in ['ghb', 'ghb/gbl', 'gbl']
    is_ghb_b = in_group(b_key, 'ghb') or normalize_name(b_key) in ['ghb', 'ghb/gbl', 'gbl']
    is_alcohol_a = normalize_name(a_key) == 'alcohol'
    is_alcohol_b = normalize_name(b_key) == 'alcohol'

    # MAOI is not in the YAML yet; keep conservative detection
    is_maoi_a = normalize_name(a_key) in ['maois', 'maoi'] or has_any(cats_a, ['maoi', 'maois']) or _has_token(id_a, ['maoi', 'maois'])
    is_maoi_b = normalize_name(b_key) in ['maois', 'maoi'] or has_any(cats_b, ['maoi', 'maois']) or _has_token(id_b, ['maoi', 'maois'])

    # Serotonergic heuristic: SSRI group OR explicit serotonin mechanisms OR empathogen
    is_ssri_a = in_group(a_key, 'ssris') or normalize_name(a_key) in ['ssri', 'ssris']
    is_ssri_b = in_group(b_key, 'ssris') or normalize_name(b_key) in ['ssri', 'ssris']
    is_serotonergic_a = (is_ssri_a
                         or has_any(cats_a, ['serotonergic', 'empathogen', 'entactogen'])
                         or _has_token(pwe_a, ['serotonin', '5-ht', '5ht']))
    is_serotonergic_b = (is_ssri_b
                         or has_any(cats_b, ['serotonergic', 'empathogen', 'entactogen'])
                         or _has_token(pwe_b, ['serotonin', '5-ht', '5ht']))

    # GABAergic heuristic: benzos/alcohol/GHB or explicit GABA mechanisms
    is_gabaergic_a = is_benzo_a or is_alcohol_a or is_ghb_a or has_any(cats_a, ['gaba', 'gabaergic', 'depressant']) or _has_token(pwe_a, ['gaba'])
    is_gabaergic_b = is_benzo_b or is_alcohol_b or is_ghb_b or has_any(cats_b, ['gaba', 'gabaergic', 'depressant']) or _has_token(pwe_b, ['gaba'])

    # CNS depressants: benzos/opioids/alcohol/GHB or declared depressant
    is_cns_dep_a = is_benzo_a or is_opioid_a or is_alcohol_a or is_ghb_a or has_any(cats_a, ['depressant'])
    is_cns_dep_b = is_benzo_b or is_opioid_b or is_alcohol_b or is_ghb_b or has_any(cats_b, ['depressant'])

    # 1) Benzodiazepines + opioids → always Dangerous
    if (is_benzo_a and is_opioid_b) or (is_benzo_b and is_opioid_a):
        return RuleResult('Dangerous', 'Hard rule: Benzodiazepines + opioids')

    # 2) MAOI + serotonergic → always Dangerous
    if (is_maoi_a and is_serotonergic_b) or (is_maoi_b and is_serotonergic_a):
        return RuleResult('Dangerous', 'Hard rule: MAOI + serotonergic')

    # 3) Opioid + opioid → minimum Unsafe
    if is_opioid_a and is_opioid_b:
        return RuleResult('Unsafe', 'Hard rule: Opioid + opioid (min Unsafe)')

    # 4) GABAergic + GABAergic → minimum Unsafe
    if is_gabaergic_a and is_gabaergic_b:
        return RuleResult('Unsafe', 'Hard rule: GABAergic + GABAergic (min Unsafe)')

    # 5) Two CNS depressants → minimum Unsafe
    if is_cns_dep_a and is_cns_dep_b:
        return RuleResult('Unsafe', 'Hard rule: CNS depressant stacking (min Unsafe)')

    return None

print('✓ Rule engine ready')

✓ Rule engine ready


## Training Schema from combos.json

`combos.json` is a nested mapping: `combos[a][b].status`.
We flatten it into a dataset of `(drug_a, drug_b, label)` and derive features using `drugs.json` categories.

In [30]:
def flatten_combos(combos: dict) -> List[Tuple[str, str, str]]:
    rows: List[Tuple[str, str, str]] = []
    for a, inner in combos.items():
        if not isinstance(inner, dict):
            continue
        for b, obj in inner.items():
            if not isinstance(obj, dict):
                continue
            status = obj.get('status')
            if not isinstance(status, str):
                continue
            a_c = canonicalize_name(a)
            b_c = canonicalize_name(b)
            if not a_c or not b_c:
                continue
            if is_excluded(a_c) or is_excluded(b_c):
                continue
            if a_c == b_c:
                continue
            rows.append((a_c, b_c, status.strip()))
    return rows

pairs = flatten_combos(combos_raw)
print('✓ Flattened labeled pairs (after canonicalization/exclude):', len(pairs))

# keep only the fixed classes
pairs = [(a, b, y) for (a, b, y) in pairs if y in RISK_CLASS_SET]
print('✓ Kept pairs with valid labels:', len(pairs))

# Merge any conflicts introduced by canonicalization by keeping the most severe label per undirected pair
merged: Dict[Tuple[str, str], str] = {}
conflicts = 0
for a, b, y in pairs:
    k = tuple(sorted([a, b]))
    if k not in merged:
        merged[k] = y
        continue
    if merged[k] != y:
        conflicts += 1
        merged[k] = max_risk(merged[k], y)
pairs = [(a, b, y) for (a, b), y in merged.items()]
print('✓ Undirected merged pairs:', len(pairs), '| conflicts resolved:', conflicts)

df_pairs = pd.DataFrame(pairs, columns=['drug_a', 'drug_b', 'label'])
print(df_pairs['label'].value_counts())
df_pairs.head()

✓ Flattened labeled pairs (after canonicalization/exclude): 839
✓ Kept pairs with valid labels: 839
✓ Undirected merged pairs: 392 | conflicts resolved: 34
label
Caution                  102
Low Risk & Synergy        97
Dangerous                 74
Low Risk & Decrease       47
Unsafe                    44
Low Risk & No Synergy     28
Name: count, dtype: int64


Unnamed: 0,drug_a,drug_b,label
0,2c-t-x,2c-x,Caution
1,2c-t-x,5-meo-xxt,Caution
2,2c-t-x,alcohol,Low Risk & Decrease
3,2c-t-x,amphetamine,Unsafe
4,2c-t-x,amt,Dangerous


In [31]:
def make_features(drug_a: str, drug_b: str, drugs: dict) -> Dict[str, float]:
    a_c = canonicalize_name(drug_a)
    b_c = canonicalize_name(drug_b)

    a_key = resolve_drug_key(a_c) or a_c
    b_key = resolve_drug_key(b_c) or b_c

    cats_a = set(drug_categories(a_key, drugs))
    cats_b = set(drug_categories(b_key, drugs))

    feats: Dict[str, float] = {}
    # category indicator features
    for c in cats_a:
        feats[f'a_cat:{c}'] = 1.0
    for c in cats_b:
        feats[f'b_cat:{c}'] = 1.0
    for c in (cats_a & cats_b):
        feats[f'both_cat:{c}'] = 1.0

    # pair identity (learn frequent exact pairs; use canonicalized names)
    feats[f'pair:{a_c}|{b_c}'] = 1.0
    feats[f'pair:{b_c}|{a_c}'] = 1.0

    # archetype flags (rules still override; these just help ML when rules don't trigger)
    feats['is_benzo_opioid'] = float(
        (in_group(a_key, 'benzodiazepines') and in_group(b_key, 'opioids'))
        or (in_group(b_key, 'benzodiazepines') and in_group(a_key, 'opioids'))
    )
    feats['is_cns_dep_stack'] = float(
        (in_group(a_key, 'benzodiazepines') or in_group(a_key, 'opioids') or normalize_name(a_key) == 'alcohol' or in_group(a_key, 'ghb'))
        and (in_group(b_key, 'benzodiazepines') or in_group(b_key, 'opioids') or normalize_name(b_key) == 'alcohol' or in_group(b_key, 'ghb'))
    )

    return feats

# Build training matrices
X_dicts = [make_features(a, b, drugs_raw) for (a, b, _) in pairs]
y = [label for (_, _, label) in pairs]

print('✓ Feature dicts:', len(X_dicts))
print('✓ Example feature keys:', list(X_dicts[0].keys())[:10])

✓ Feature dicts: 392
✓ Example feature keys: ['pair:2c-t-x|2c-x', 'pair:2c-x|2c-t-x', 'is_benzo_opioid', 'is_cns_dep_stack']


## ML Classifier (runs only if rules don’t trigger)

We use a deterministic multiclass `LogisticRegression` (good interpretability + fast).

In [32]:
# Train/test split
X_train, X_test, y_train, y_test = train_test_split(
    X_dicts, y, test_size=0.2, random_state=RANDOM_STATE, stratify=y
)

clf: Pipeline = Pipeline([
    ('vec', DictVectorizer(sparse=True)),
    ('lr', LogisticRegression(
        solver='lbfgs',
        max_iter=500,
        random_state=RANDOM_STATE,
        class_weight='balanced',
    )),
])

clf.fit(X_train, y_train)
pred = clf.predict(X_test)
print('✓ Trained ML classifier')
print(classification_report(y_test, pred, labels=RISK_CLASSES, zero_division=0))

✓ Trained ML classifier
                       precision    recall  f1-score   support

   Low Risk & Synergy       0.75      0.32      0.44        19
Low Risk & No Synergy       0.40      0.33      0.36         6
  Low Risk & Decrease       0.25      0.78      0.38         9
              Caution       0.50      0.19      0.28        21
               Unsafe       0.09      0.22      0.13         9
            Dangerous       0.25      0.13      0.17        15

             accuracy                           0.29        79
            macro avg       0.37      0.33      0.29        79
         weighted avg       0.43      0.29      0.30        79



## Resolver: Rules + ML (safe composition)

- If a hard rule triggers: return immediately.
- Otherwise: use ML prediction.
- Always return one of the fixed strings.
- Provide an explanation: rule reason OR top contributing features (log-reg coefficients × feature values).

In [26]:
def explain_logreg_top_features(pipeline: Pipeline, x_dict: Dict[str, float], top_k: int = 8) -> List[Tuple[str, float]]:
    vec: DictVectorizer = pipeline.named_steps['vec']
    lr: LogisticRegression = pipeline.named_steps['lr']

    X = vec.transform([x_dict])  # sparse row
    feature_names = np.array(vec.get_feature_names_out())

    proba = lr.predict_proba(X)[0]
    class_idx = int(np.argmax(proba))
    class_name = lr.classes_[class_idx]

    coef = lr.coef_[class_idx]
    X_dense = X.toarray()[0]
    contrib = coef * X_dense
    top_idx = np.argsort(np.abs(contrib))[::-1][:top_k]

    items = [(feature_names[i], float(contrib[i])) for i in top_idx if X_dense[i] != 0]
    return [(f'{class_name} :: {k}', v) for (k, v) in items]

def predict_hybrid(drug_a: str, drug_b: str, drugs: dict, model: Pipeline) -> dict:
    a_c = canonicalize_name(drug_a)
    b_c = canonicalize_name(drug_b)
    if is_excluded(a_c) or is_excluded(b_c):
        raise ValueError(f'Excluded substance in prediction: {drug_a} or {drug_b}')

    rr = apply_hard_rules(a_c, b_c, drugs)
    if rr is not None:
        if rr.risk not in RISK_CLASS_SET:
            raise ValueError(f'Rule produced invalid risk: {rr.risk}')
        return {
            'drug_a': a_c,
            'drug_b': b_c,
            'risk': rr.risk,
            'source': 'rules',
            'reason': rr.reason,
            'explanation': [rr.reason],
        }

    x = make_features(a_c, b_c, drugs)
    pred = model.predict([x])[0]
    if pred not in RISK_CLASS_SET:
        raise ValueError(f'Model produced invalid risk: {pred}')

    proba = model.predict_proba([x])[0]
    classes = list(model.named_steps['lr'].classes_)
    conf = float(np.max(proba))

    top = explain_logreg_top_features(model, x, top_k=10)
    return {
        'drug_a': a_c,
        'drug_b': b_c,
        'risk': pred,
        'source': 'ml',
        'confidence': conf,
        'class_probs': {classes[i]: float(proba[i]) for i in range(len(classes))},
        'explanation': [f'{k}: {v:+.4f}' for (k, v) in top],
    }

print('✓ Resolver ready')

✓ Resolver ready


In [27]:
# Sanity tests
tests = [
    # should trigger hard rules
    ('benzodiazepines', 'opioids'),
    ('alcohol', 'benzodiazepines'),
    ('opioids', 'opioids'),
    ('maois', 'ssris'),
    # should likely fall back to ML
    ('cannabis', 'caffeine'),
    ('lsd', 'cannabis'),
    ('mdma', 'alcohol'),
]

for a, b in tests:
    out = predict_hybrid(a, b, drugs_raw, clf)
    print('\n' + '='*72)
    print(f'{a} + {b} -> {out["risk"]}  (source={out["source"]})')
    if out['source'] == 'rules':
        print('Reason:', out['reason'])
    else:
        print('Confidence:', f"{out['confidence']:.3f}")
        for line in out['explanation'][:6]:
            print('  ', line)

print('\n✓ Sanity tests complete')


benzodiazepines + opioids -> Dangerous  (source=rules)
Reason: Hard rule: Benzodiazepines + opioids

alcohol + benzodiazepines -> Unsafe  (source=rules)
Reason: Hard rule: GABAergic + GABAergic (min Unsafe)

opioids + opioids -> Unsafe  (source=rules)
Reason: Hard rule: Opioid + opioid (min Unsafe)

maois + ssris -> Dangerous  (source=rules)
Reason: Hard rule: MAOI + serotonergic

cannabis + caffeine -> Low Risk & No Synergy  (source=ml)
Confidence: 0.582
   Low Risk & No Synergy :: b_cat:nootropic: +1.2879
   Low Risk & No Synergy :: a_cat:stimulant: +0.9851
   Low Risk & No Synergy :: a_cat:psychedelic: -0.7068
   Low Risk & No Synergy :: pair:cannabis|caffeine: +0.5998
   Low Risk & No Synergy :: pair:caffeine|cannabis: +0.5998
   Low Risk & No Synergy :: both_cat:habit-forming: -0.4804

lsd + cannabis -> Low Risk & Synergy  (source=ml)
Confidence: 0.591
   Low Risk & Synergy :: a_cat:psychedelic: +1.7090
   Low Risk & Synergy :: b_cat:psychedelic: +1.3133
   Low Risk & Synergy :: 

## Export: all pairwise combinations to JSON

This will evaluate **every substance vs every other substance** (undirected pairs) using the hybrid resolver and write a JSON file grouped by risk class.

Note: with ~551 substances, this is ~151k pairs and can take a few minutes to run.

In [33]:
from datetime import datetime, timezone

def list_export_substances(drugs: dict) -> List[str]:
    # Start from drugs.json keys, canonicalize, exclude, and make unique
    keys = set()
    for k in drugs.keys():
        c = canonicalize_name(k)
        if not c or is_excluded(c):
            continue
        keys.add(c)
    # Also include canonical group targets and alias targets (even if not present in drugs.json)
    for c in set(GROUP_CANONICAL_BY_MEMBER.values()):
        if c and not is_excluded(c):
            keys.add(c)
    for v in ALIASES.values():
        c = canonicalize_name(v)
        if c and not is_excluded(c):
            keys.add(c)
    return sorted(keys)

def export_all_pairwise_risks(
    output_path: Path = Path('outputs') / 'pairwise_risks_by_class.json',
    max_substances: Optional[int] = None,
    progress_every: int = 25,
    include_groups_only_if_present_in_drugs_json: bool = False,
 ):
    """
    Export all undirected pairs to JSON grouped by risk class.

    Output format:
    {
      "generated_at": "...",
      "n_substances": 551,
      "n_pairs": 151525,
      "counts": {"Caution": 123, ...},
      "pairs_by_risk": {
        "Dangerous": [["a","b"], ...],
        ...
      }
    }
    """
    keys = list_export_substances(drugs_raw)
    if include_groups_only_if_present_in_drugs_json:
        present = {normalize_name(k) for k in drugs_raw.keys()}
        keys = [k for k in keys if normalize_name(k) in present]
    if max_substances is not None:
        keys = keys[:max_substances]

    n = len(keys)
    total_pairs = n * (n - 1) // 2
    pairs_by_risk = {risk: [] for risk in RISK_CLASSES}

    output_path.parent.mkdir(parents=True, exist_ok=True)

    processed = 0
    started = datetime.now(timezone.utc)

    for i in range(n):
        a = keys[i]
        for j in range(i + 1, n):
            b = keys[j]
            risk = predict_hybrid(a, b, drugs_raw, clf)['risk']
            pairs_by_risk[risk].append([a, b])
            processed += 1

        if progress_every and (i + 1) % progress_every == 0:
            elapsed = (datetime.now(timezone.utc) - started).total_seconds()
            rate = processed / elapsed if elapsed > 0 else 0.0
            print(f'... {i+1}/{n} substances processed | {processed}/{total_pairs} pairs | {rate:,.0f} pairs/s')

    counts = {risk: len(pairs_by_risk[risk]) for risk in RISK_CLASSES}
    payload = {
        'generated_at': datetime.now(timezone.utc).isoformat(),
        'n_substances': n,
        'n_pairs': total_pairs,
        'risk_classes': list(RISK_CLASSES),
        'counts': counts,
        'pairs_by_risk': pairs_by_risk,
        'config': {
            'yaml_path': str(YAML_PATH),
            'exclude': sorted(EXCLUDE_SET),
            'aliases': ALIASES,
            'groups': {k: {'canonical': canonicalize_name(v.get('canonical', k))} for k, v in (INTERACTION_CFG.get('groups') or {}).items() if isinstance(v, dict)},
            'separate': sorted(SEPARATE_SET),
        },
        'notes': 'Pairs are undirected (a,b) checked once; substances are canonicalized via drug_interaction.yaml before prediction.',
    }

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(payload, f, ensure_ascii=False, indent=2)

    return output_path, counts


# Quick test export (set to None for full)
out_path, out_counts = export_all_pairwise_risks(max_substances=60, progress_every=10)
print('\n✓ Wrote:', out_path.resolve())
print('Counts:')
for k, v in out_counts.items():
    print(f'  {k:20} {v}')

... 10/60 substances processed | 545/1770 pairs | 1,120 pairs/s
... 20/60 substances processed | 990/1770 pairs | 1,131 pairs/s
... 30/60 substances processed | 1335/1770 pairs | 1,128 pairs/s
... 40/60 substances processed | 1580/1770 pairs | 1,126 pairs/s
... 50/60 substances processed | 1725/1770 pairs | 1,128 pairs/s
... 60/60 substances processed | 1770/1770 pairs | 1,130 pairs/s

✓ Wrote: C:\Users\USER\dev\code\mobile_drug_use_app\backend\ML\drug_interaction_model\outputs\pairwise_risks_by_class.json
Counts:
  Low Risk & Synergy   593
  Low Risk & No Synergy 0
  Low Risk & Decrease  7
  Caution              47
  Unsafe               6
  Dangerous            1117
