In [7]:
from transformers import pipeline
import pandas as pd
import re

### Modelo

In [5]:
# Text-classification pipeline around the CORe clinical diagnosis model.
# `return_all_scores` was dropped: it is the deprecated predecessor of `top_k`,
# and `top_k=None` already requests the score of every label.
pipe = pipeline(
    task="text-classification",
    model="DATEXIS/CORe-clinical-diagnosis-prediction",
    top_k=None,  # return all labels with their scores, not only the best one
    truncation=True,  # forwarded to the tokenizer so over-long notes are clipped
    function_to_apply="sigmoid",  # score each label independently of the others
)

Loading weights:   0%|          | 0/201 [00:00<?, ?it/s]

BertForSequenceClassification LOAD REPORT from: DATEXIS/CORe-clinical-diagnosis-prediction
Key                          | Status     |  | 
-----------------------------+------------+--+-
bert.embeddings.position_ids | UNEXPECTED |  | 

Notes:
- UNEXPECTED	:can be ignored when loading from different task/architecture; not ok if you expect identical arch.


In [6]:

# Short free-text example notes used to try out the classifier.
notes = [
    "Patient with chest pain radiating to left arm, sweating, nausea.",
    "Fever, cough, shortness of breath, oxygen saturation low.",
    "History of diabetes. Polyuria, polydipsia, very high glucose.",
]

# One result per note; the reporting loop below reads each result as a
# sequence of dicts carrying a "label" and a "score" key.
raw = pipe(notes)

def is_icd9_4plus(label: str) -> bool:
    """Tell whether `label` is a purely numeric code of 4 to 6 digits.

    The label is converted with `str()` first, so non-string inputs are
    accepted. Labels containing anything other than digits, or with fewer
    than 4 or more than 6 digits, are rejected.
    """
    matched = re.fullmatch(r"\d{4,6}", str(label))
    return matched is not None

def icd9_with_dot(code_raw: str) -> str:
    """Insert a dot after the third character of a dot-less code.

    Examples: "4140" -> "414.0", "25013" -> "250.13". Inputs of three
    characters or fewer are returned unchanged (after `str()` conversion).
    """
    code = str(code_raw)
    category, detail = code[:3], code[3:]
    # An empty remainder means the code has at most three characters.
    return f"{category}.{detail}" if detail else code

THRESH = 0.30  # minimum score a label needs to be reported
TOP_N = 10  # maximum number of labels printed per note (previously a bare literal)

for text, scores in zip(notes, raw):
    # Keep only confident labels that are purely numeric 4-6 digit codes.
    keep = [
        (s["label"], s["score"])
        for s in scores
        if s["score"] >= THRESH and is_icd9_4plus(s["label"])
    ]
    # Highest score first; the sort is stable, so ties keep their original order.
    keep = sorted(keep, key=lambda pair: pair[1], reverse=True)[:TOP_N]

    print("\nTEXT:", text)
    for lab, sc in keep:
        # Columns: raw label, dotted form, score rounded to three decimals.
        print(f"  {lab:>6}  ({icd9_with_dot(lab):>7})  {sc:.3f}")


TEXT: Patient with chest pain radiating to left arm, sweating, nausea.
    4140  (  414.0)  1.000
    4111  (  411.1)  0.664

TEXT: Fever, cough, shortness of breath, oxygen saturation low.
    7855  (  785.5)  0.864
    5849  (  584.9)  0.863
    9959  (  995.9)  0.822
    2859  (  285.9)  0.727
    5188  (  518.8)  0.584
    7070  (  707.0)  0.489
    2765  (  276.5)  0.401
    0389  (  038.9)  0.353

TEXT: History of diabetes. Polyuria, polydipsia, very high glucose.
    2501  (  250.1)  0.973
    2765  (  276.5)  0.579
    2505  (  250.5)  0.320


### Data

In [11]:
# Read every column as a string (dtype=str) so codes such as "0010" keep
# their leading zeros instead of being parsed as numbers.
data = pd.read_csv("./Dataset/icd9dx2015.csv", dtype=str)
# Bare expression: the notebook renders the frame as this cell's output.
data

Unnamed: 0,dgns_cd,longdesc,shortdesc,version,fyear
0,0010,Cholera due to vibrio cholerae,Cholera d/t vib cholerae,32,2015
1,0011,Cholera due to vibrio cholerae el tor,Cholera d/t vib el tor,32,2015
2,0019,"Cholera, unspecified",Cholera NOS,32,2015
3,0020,Typhoid fever,Typhoid fever,32,2015
4,0021,Paratyphoid fever A,Paratyphoid fever a,32,2015
...,...,...,...,...,...
14562,V9129,"Quadruplet gestation, unable to determine numb...",Quad gest-plac/sac undet,32,2015
14563,V9190,"Other specified multiple gestation, unspecifie...",Mult gest-plac/sac NOS,32,2015
14564,V9191,"Other specified multiple gestation, with two o...",Mult gest 2+ monochr NEC,32,2015
14565,V9192,"Other specified multiple gestation, with two o...",Mult gest 2+ monoamn NEC,32,2015


In [17]:
def decimalize(code: str) -> str:
    """Convert a dot-less numeric code into its dotted form.

    The input is converted with `str()` and stripped of surrounding
    whitespace. Then:
    - anything that is not made up solely of digits is returned as is;
    - up to three digits are left-padded with zeros to width 3 ("38" -> "038");
    - longer codes get a dot after the third digit ("25013" -> "250.13").
    """
    text = str(code).strip()
    if not text.isdigit():
        return text

    category, detail = text[:3], text[3:]
    if not detail:
        # Three digits or fewer: only zero-pad, there is nothing after the dot.
        return text.zfill(3)
    return f"{category}.{detail}"

def lookup_icd9(code_raw: str, code_to_long: dict, code_to_short: dict, all_codes: list[str], max_children: int = 10):
    """Look up a dot-less numeric code in the description tables.

    Parameters
    ----------
    code_raw : code to resolve; converted with `str()` and stripped.
    code_to_long : mapping from code to long description.
    code_to_short : mapping from code to short description.
    all_codes : codes scanned for prefix matches.
    max_children : maximum number of candidate codes returned.

    Returns
    -------
    dict
        Always has "query" and "found". Depending on the outcome:
        - query is not 3-6 digits: found=False, reason="non-numeric";
        - query is a key of `code_to_long`: found=True plus the code details;
        - 4-digit query with longer codes sharing its prefix: found=False,
          the candidates under "children" and an optional "best_guess"
          (the query padded with "0", when that code exists);
        - 4-digit query without such candidates but whose "0"-padded form
          exists: found=True with the padded code's details;
        - any codes starting with the query: found=False plus "children";
        - otherwise: found=False, reason="no match".
    """
    query = str(code_raw).strip()

    def strict_entry(code, note):
        # Direct indexing on purpose: a code missing from a table raises KeyError.
        return {
            "code_csv": code,
            "code_decimal": decimalize(code),
            "longdesc": code_to_long[code],
            "shortdesc": code_to_short[code],
            "note": note,
        }

    def candidate_entries(codes):
        # Lenient lookups: a candidate without a description yields None.
        entries = []
        for candidate in codes:
            entries.append(
                {
                    "code_csv": candidate,
                    "code_decimal": decimalize(candidate),
                    "longdesc": code_to_long.get(candidate),
                    "shortdesc": code_to_short.get(candidate),
                }
            )
        return entries

    if re.fullmatch(r"\d{3,6}", query) is None:
        return {"query": query, "found": False, "reason": "non-numeric"}

    if query in code_to_long:
        return {"query": query, "found": True, **strict_entry(query, "exact match")}

    if len(query) == 4:
        # A 4-digit query may be the parent of more specific 5+ digit codes.
        longer = sorted(
            c for c in all_codes if c.startswith(query) and len(c) >= 5
        )[:max_children]

        zero_padded = query + "0"
        padded_known = zero_padded in code_to_long
        best_guess = (
            strict_entry(zero_padded, "best_guess: padded 0 (often NOS)")
            if padded_known
            else None
        )

        if longer:
            return {
                "query": query,
                "found": False,
                "reason": "no exact; returning child candidates",
                "query_decimal": decimalize(query),
                "best_guess": best_guess,
                "children": candidate_entries(longer),
            }

        if padded_known:
            return {
                "query": query,
                "found": True,
                **strict_entry(zero_padded, "matched by padding 0 (no children found)"),
            }

    # Generic fallback: any code that starts with the query.
    prefixed = sorted(c for c in all_codes if c.startswith(query))[:max_children]
    if prefixed:
        return {
            "query": query,
            "found": False,
            "reason": "no exact; returning prefix candidates",
            "query_decimal": decimalize(query),
            "children": candidate_entries(prefixed),
        }

    return {"query": query, "found": False, "reason": "no match"}

In [18]:
# Lookup structures derived from the loaded table, all keyed by the
# dot-less code in the "dgns_cd" column.
dgns_codes = data["dgns_cd"]
all_codes = dgns_codes.tolist()
code_to_long = dict(zip(dgns_codes, data["longdesc"]))
code_to_short = dict(zip(dgns_codes, data["shortdesc"]))

# Spot-check the lookup with a handful of example codes.
SAMPLE_CODES = ("4140", "0389", "2501", "2505", "5849")
for c in SAMPLE_CODES:
    result = lookup_icd9(c, code_to_long, code_to_short, all_codes)
    print("\n", c, "->", result)


 4140 -> {'query': '4140', 'found': False, 'reason': 'no exact; returning child candidates', 'query_decimal': '414.0', 'best_guess': {'code_csv': '41400', 'code_decimal': '414.00', 'longdesc': 'Coronary atherosclerosis of unspecified type of vessel, native or graft', 'shortdesc': 'Cor ath unsp vsl ntv/gft', 'note': 'best_guess: padded 0 (often NOS)'}, 'children': [{'code_csv': '41400', 'code_decimal': '414.00', 'longdesc': 'Coronary atherosclerosis of unspecified type of vessel, native or graft', 'shortdesc': 'Cor ath unsp vsl ntv/gft'}, {'code_csv': '41401', 'code_decimal': '414.01', 'longdesc': 'Coronary atherosclerosis of native coronary artery', 'shortdesc': 'Crnry athrscl natve vssl'}, {'code_csv': '41402', 'code_decimal': '414.02', 'longdesc': 'Coronary atherosclerosis of autologous vein bypass graft', 'shortdesc': 'Crn ath atlg vn bps grft'}, {'code_csv': '41403', 'code_decimal': '414.03', 'longdesc': 'Coronary atherosclerosis of nonautologous biological bypass graft', 'shortde