<a href="https://colab.research.google.com/github/Raymond-Owen1-137/orco_rta/blob/main/orco_rta.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
#!/usr/bin/env python3
"""
orco.py — Open-Source Residue Classification Oracle (ORCO): probabilistic residue assignment via Gaussian likelihoods and secondary-structure filtering.

This module performs residue type classification from chemical shift data using a statistically principled
approach combining:
  • Univariate Gaussian likelihood modeling
  • Chi-squared distance thresholding
  • Softmax-based posterior normalization
  • Marginalization over secondary structure states

Designed for NMR-based spin system assignment.
"""

import csv
import math
import sys
from itertools import permutations
from io import StringIO

# Constants
# Placeholder value that marks a missing (unobserved) chemical shift.
SENTINEL = 1_000_000.0
# Backbone atoms every spin system must provide.
MANDATORY = ("N", "CA", "CO")
# Sidechain carbons that the CX shifts may be matched against.
OPTIONAL = ("CB", "CG", "CD", "CE")

# Chi-squared critical values at alpha = 0.05, keyed by degrees of freedom:
# 3 backbone shifts plus 0–4 sidechain shifts gives df = 3–7.
CHI2_THRESHOLDS = dict(zip(range(3, 8), (7.81, 9.49, 11.07, 12.59, 14.07)))

# Embedded fallback sidechain statistics (mean ± std from RefDB or literature)
# Parsed by load_raw_sidechain_stats(), which keeps only atoms listed in
# OPTIONAL. load_shift_stats() then copies these values into every
# (residue, SS) entry that lacks the atom, so they are not SS-specific.
# There are no CB rows here: CB statistics must come from the loaded CSV.
# NOTE(review): the atom labels appear to collapse IUPAC names onto CG/CD/CE
# (e.g. VAL,CD — valine has CG1/CG2 and no CD; ILE CD is presumably CD1) —
# confirm which source atom each row represents before extending the table.
_RAW_STATS = """Residue,Atom,Mean,StdDev
ARG,CD,43.199,2.785
ARG,CG,27.248,2.899
ASN,CG,176.235,8.700
ASP,CG,177.398,17.314
GLN,CD,179.339,7.130
GLN,CG,33.807,2.441
GLU,CD,181.239,13.531
GLU,CG,36.141,2.816
HIS,CG,131.472,8.708
ILE,CD,13.487,3.290
ILE,CG,27.765,3.200
LEU,CG,26.804,1.457
LEU,CD,24.651,2.002
LYS,CG,24.953,2.975
LYS,CD,28.995,2.519
LYS,CE,41.927,2.896
MET,CG,32.064,3.087
MET,CE,17.238,3.991
PHE,CD,131.205,5.750
PHE,CE,130.346,5.724
PRO,CG,27.265,3.532
PRO,CD,50.344,3.052
TRP,CG,110.286,8.701
TYR,CG,128.267,12.396
VAL,CG,21.529,2.328
VAL,CD,21.344,2.424
"""

def load_raw_sidechain_stats():
    """Parse the embedded _RAW_STATS table into a nested fallback lookup.

    Only rows whose atom is one of the OPTIONAL sidechain atoms are kept.
    Residue and atom names are stripped and upper-cased.

    Returns:
        dict mapping residue name -> {atom name -> (mean, std)}.
    """
    table = {}
    for record in csv.DictReader(StringIO(_RAW_STATS)):
        residue_name = record["Residue"].strip().upper()
        atom_name = record["Atom"].strip().upper()
        if atom_name not in OPTIONAL:
            continue
        mean_sd = (float(record["Mean"]), float(record["StdDev"]))
        table.setdefault(residue_name, {})[atom_name] = mean_sd
    return table

def load_shift_stats(path):
    """
    Load backbone and sidechain shift statistics from a CSV file, then
    supplement missing sidechain atoms with the internal _RAW_STATS fallbacks.

    CSV must contain: Residue,SS,Atom,Mean,StdDev

    Residue and atom names are upper-cased; SS labels are ``str.capitalize()``-d
    (e.g. "helix" -> "Helix"). A row is skipped with a warning on stderr when:
      * a column is missing or empty,
      * Mean/StdDev are not numbers,
      * Mean/StdDev are not finite, or
      * StdDev is not strictly positive.
    A zero or negative std would otherwise cause a ZeroDivisionError or math
    domain error in the Gaussian log-likelihood.

    Fallback sidechain values are added to every (residue, SS) entry that lacks
    that atom; values present in the CSV always take precedence.

    Args:
        path: filesystem path of the CSV file.

    Returns:
        dict mapping (residue, ss) -> {atom: (mean, std)}.

    Raises:
        OSError: if the file cannot be opened.
    """
    stats = {}
    # utf-8-sig transparently strips a leading BOM (common in spreadsheet
    # exports), which would otherwise corrupt the "Residue" header key.
    with open(path, newline="", encoding="utf-8-sig") as f:
        for row in csv.DictReader(f):
            try:
                res = row["Residue"].strip().upper()
                ss = row["SS"].strip().capitalize()
                atom = row["Atom"].strip().upper()
                μ = float(row["Mean"])
                σ = float(row["StdDev"])
            except (KeyError, AttributeError, TypeError, ValueError) as e:
                # KeyError: missing header column; AttributeError/TypeError:
                # short row (DictReader fills absent fields with None);
                # ValueError: non-numeric Mean/StdDev.
                print(f"[WARN] Skipping bad row: {row} ({e})", file=sys.stderr)
                continue
            if not (math.isfinite(μ) and math.isfinite(σ) and σ > 0):
                print(
                    f"[WARN] Skipping bad row: {row} "
                    "(Mean/StdDev must be finite and StdDev > 0)",
                    file=sys.stderr,
                )
                continue
            stats.setdefault((res, ss), {})[atom] = (μ, σ)

    # Supplement missing sidechains with raw fallbacks (CSV values win).
    raw = load_raw_sidechain_stats()
    for (res, _ss), ref in stats.items():
        for atom, mean_sd in raw.get(res, {}).items():
            ref.setdefault(atom, mean_sd)
    return stats

def compute_assignment_probabilities(spin, stats):
    """
    Score every (residue, secondary-structure) candidate against one spin system.

    Given a spin system and stats, compute:
      1. Chi-squared-filtered Gaussian log-likelihoods.
         * Only candidates with complete backbone (N, CA, CO) statistics are scored.
         * For each, every assignment of the m observed CX shifts to distinct
           sidechain atoms is tried.
         * An assignment survives if its squared standardized distance
           (Mahalanobis distance with a diagonal covariance) is <= the
           chi-squared critical value for 3 + m degrees of freedom.
         * The best surviving log-likelihood is kept.
      2. Softmax-normalized posterior P(res, ss | spin) over the surviving
         candidates (equivalent to a uniform prior over candidates).
      3. Marginalized P(res | spin), with the per-SS breakdown.

    Sidechain statistics come from the candidate's own (res, ss) entry when it
    has them. Atoms missing there fall back to statistics pooled across the
    residue's other SS entries.

    Args:
        spin: mapping with shifts under "N", "CA", "CO" and a sequence of
            sidechain shifts under "CX"; missing values equal SENTINEL.
        stats: {(res, ss): {atom: (mean, std)}}, as returned by load_shift_stats.

    Returns:
        (residue_totals, ss_breakdown, posteriors), where
          residue_totals = {res: P(res | spin)},
          ss_breakdown   = {res: {ss: P(res, ss | spin)}},
          posteriors     = {(res, ss): P(res, ss | spin)};
        or None when no candidate passes the chi-squared filter.

    Raises:
        ValueError: if N, CA or CO is missing, or the total number of observed
            shifts is outside 3–7.
    """
    cx_vals = [x for x in spin["CX"] if x != SENTINEL]
    m = len(cx_vals)
    if any(spin[d] == SENTINEL for d in MANDATORY):
        raise ValueError("Missing mandatory atoms: N, CA, CO must be present.")
    if not (3 <= 3 + m <= 7):
        raise ValueError(f"Total atom dimensions must be 3–7; got {3 + m}")
    χ2 = CHI2_THRESHOLDS[3 + m]
    half_log_2pi = 0.5 * math.log(2 * math.pi)

    def gaussian_terms(x, μ, σ):
        # (log N(x; μ, σ²), z²) contribution of one observed shift.
        z = (x - μ) / σ
        return -0.5*z*z - math.log(σ) - half_log_2pi, z*z

    # Sidechain stats pooled per residue across SS states. Later SS entries
    # overwrite earlier ones, so this is used only as a fallback below.
    pooled = {}
    for (res, _), ref in stats.items():
        for atom in OPTIONAL:
            if atom in ref:
                pooled.setdefault(res, {})[atom] = ref[atom]

    results = []
    for (res, ss), ref in stats.items():
        if not all(atom in ref for atom in MANDATORY):
            continue  # incomplete backbone statistics: not a candidate
        # The candidate's own SS-specific sidechain stats override pooled ones.
        own = {a: ref[a] for a in OPTIONAL if a in ref}
        sc = {**pooled.get(res, {}), **own}
        available = [a for a in OPTIONAL if a in sc]
        if len(available) < m:
            continue

        # The backbone contribution is identical for every sidechain permutation.
        bb_ll = bb_d2 = 0.0
        for dim in MANDATORY:
            t_ll, t_d2 = gaussian_terms(spin[dim], *ref[dim])
            bb_ll += t_ll
            bb_d2 += t_d2

        best_ll = None
        for perm in permutations(available, m):
            ll, d2 = bb_ll, bb_d2
            for x, atom in zip(cx_vals, perm):
                t_ll, t_d2 = gaussian_terms(x, *sc[atom])
                ll += t_ll
                d2 += t_d2
            if d2 <= χ2 and (best_ll is None or ll > best_ll):
                best_ll = ll
        if best_ll is not None:
            results.append(((res, ss), best_ll))

    if not results:
        return None

    # Softmax normalization. Subtracting the max log-likelihood keeps exp()
    # in range; the best candidate maps to exp(0) = 1.
    max_ll = max(ll for _, ll in results)
    scores = {k: math.exp(ll - max_ll) for k, ll in results}
    Z = sum(scores.values())
    posteriors = {k: v / Z for k, v in scores.items()}

    # Marginalize over secondary structure, keeping the per-SS split.
    residue_totals = {}
    ss_breakdown = {}
    for (res, ss), p in posteriors.items():
        residue_totals[res] = residue_totals.get(res, 0.0) + p
        per_ss = ss_breakdown.setdefault(res, {})
        per_ss[ss] = per_ss.get(ss, 0.0) + p

    return residue_totals, ss_breakdown, posteriors

def parse_spin_input(line):
    """Parse a whitespace-separated line of shifts into a spin-system dict.

    Expected order: N CA CO CX1 CX2 CX3 CX4; tokens beyond the seventh are
    ignored. The following are recorded as SENTINEL (missing):
      * tokens that are not numbers (e.g. "-", "?");
      * tokens that parse to a non-finite value ("nan", "inf");
      * absent trailing values (padding).
    A NaN/inf shift would otherwise make every candidate fail the chi-squared
    test instead of being treated as absent.

    Returns:
        {"N": float, "CA": float, "CO": float, "CX": [float, float, float, float]}
    """
    vals = []
    for tok in line.split()[:7]:
        try:
            val = float(tok)
        except ValueError:
            val = SENTINEL
        vals.append(val if math.isfinite(val) else SENTINEL)
    vals += [SENTINEL] * (7 - len(vals))
    return {
        "N": vals[0],
        "CA": vals[1],
        "CO": vals[2],
        "CX": vals[3:7],
    }

def print_results(res_totals, ss_breakdown):
    """Print a residue probability table, most probable residue first.

    Columns are:
      * the residue name;
      * the total P(res | spin);
      * the Coil / Helix / Sheet shares from ss_breakdown[res], shown as 0.0
        when a state is absent.
    Residues with equal totals keep their insertion order.
    """
    print(f"{'Res':4} {'Total':>6} {'Coil':>6} {'Helix':>6} {'Sheet':>6}")
    ranked = sorted(res_totals.items(), key=lambda item: item[1], reverse=True)
    for name, total in ranked:
        by_state = ss_breakdown[name]
        coil, helix, sheet = (by_state.get(s, 0.0) for s in ("Coil", "Helix", "Sheet"))
        print(f"{name:4} {total:6.4f} {coil:6.4f} {helix:6.4f} {sheet:6.4f}")

def main(stats_path="stats_refdb.csv"):
    """Run the interactive assignment loop on stdin/stdout.

    Loads shift statistics from *stats_path*. It then repeatedly:
      1. reads a line of "N CA CO CX1 CX2 CX3 CX4" shifts;
      2. prints the residue probability table;
      3. asks whether to continue.
    The loop ends on a blank line, a non-"y" answer, or end of input (EOF).

    Args:
        stats_path: CSV of shift statistics for load_shift_stats; defaults to
            "stats_refdb.csv" in the current working directory.
    """
    try:
        stats = load_shift_stats(stats_path)
    except OSError as e:
        # Report a missing/unreadable statistics file instead of a traceback.
        print(f"[ERROR] Cannot load statistics from {stats_path!r}: {e}",
              file=sys.stderr)
        return
    while True:
        try:
            line = input("Enter N CA CO CX1 CX2 CX3 CX4 (blank to quit): ").strip()
        except EOFError:
            break  # stdin closed (e.g. piped input exhausted)
        if not line:
            break
        try:
            spin = parse_spin_input(line)
            result = compute_assignment_probabilities(spin, stats)
            # compute_assignment_probabilities returns None when every candidate
            # fails the chi-squared filter; handle that explicitly instead of
            # failing on tuple unpacking.
            if result is None:
                print("[ERROR] No residue type is consistent with these shifts "
                      "(all candidates rejected by the chi-squared filter).")
            else:
                res_totals, ss_breakdown, _ = result
                print_results(res_totals, ss_breakdown)
        except Exception as e:  # top-level boundary: report and keep prompting
            print(f"[ERROR] {e}")
        try:
            again = input("Assign another? [Y/n]: ").strip().lower()
        except EOFError:
            break
        if again and not again.startswith("y"):
            break


if __name__ == "__main__":
    main()


Enter N CA CO CX1 CX2 CX3 CX4 (blank to quit): 118.6 57.92 1000000.0  40.21  1000000.0  1000000.0  1000000.0
[ERROR] Missing mandatory atoms: N, CA, CO must be present.
Assign another? [Y/n]: y
Enter N CA CO CX1 CX2 CX3 CX4 (blank to quit): 118.414 51.897 174.555  18.24  1000000.0  1000000.0  1000000.0 
Res   Total   Coil  Helix  Sheet
ALA  0.6001 0.1030 0.0000 0.4971
MET  0.3999 0.1113 0.0000 0.2886
