In [1]:
import os, json, torch
from transformers import AutoConfig, AutoTokenizer, AutoModelForSequenceClassification
from peft import PeftModel, PeftConfig
import torch.nn as nn
BASE_MODEL     = "Qwen/Qwen3-8B"
MODEL_DIR      = "/data/LLM/qwen3_8b_ais_lora_onehot_4digit_mlphead_a128_epoch10"   # ← tokenizer 위치
ADAPTER_DIR    = "/data/LLM/qwen3_8b_ais_lora_onehot_4digit_mlphead_a128_epoch10/cls_lora_adapter"  # ← LoRA 가중치
DEVICE         = "cuda:1"

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# — 1) load code2idx.json to rebuild labels list
with open("code2idx_3lvl.json") as f:
    code2idx = json.load(f)
num_labels = len(code2idx)
labels     = [None] * num_labels
for code, idx in code2idx.items():
    labels[int(idx)] = code

In [3]:
lvl_map = {
    (1, ()): {
        "0": "OTHER TRAUMA",
        "1": "HEAD",
        "2": "FACE",
        "3": "NECK",
        "4": "THORAX",
        "5": "ABDOMEN",
        "6": "SPINE",
        "7": "UPPER_EXT",
        "8": "LOWER_EXT",
        "9": "EXTERNAL",
    },
    (2, ("0",)): {
        "1": "Environmental/ WHOLE-BODY",
        "2": "Asphyxia/Suffocation",
        "4": "Caustic Agents",
        "6": "Drowning",
        "8": "Electrical injury",
    },
    (2, ("1",)): {
        "0": "WHOLE/NFS",
        "1": "WHOLE AREA",
        "2": "VESSELS, INTRACRICAL",
        "3": "NERVES,CRANICAL",
        "4": "INTERNAL ORGANS",
        "5": "SKELETAL",
        "6": "CONCUSSIVE INJURY",
    },
    (2, ("2",)): {
        "0": "WHOLE/NFS",
        "1": "WHOLE AREA",
        "2": "VESSELS",
        "3": "NERVES",
        "4": "INTERNAL ORGANS",
        "5": "SKELETAL",
    },
    (2, ("3",)): {
        "0": "WHOLE/NFS",
        "1": "WHOLE AREA",
        "2": "VESSELS",
        "3": "NERVES",
        "4": "INTERNAL ORGANS",
        "5": "SKELETAL",
    },
    (2, ("4",)): {
        "0": "WHOLE/NFS",
        "1": "WHOLE AREA",
        "2": "VESSELS",
        "3": "NERVES",
        "4": "INTERNAL ORGANS",
        "5": "SKELETAL including thoracic wall involvement",
    },
    (2, ("5",)): {
        "0": "WHOLE/NFS",
        "1": "WHOLE AREA",
        "2": "VESSELS",
        "3": "NERVES",
        "4": "INTERNAL ORGANS",
    },
    (2, ("6",)): {
        "00": "WHOLE_SPINE",
        "09": "WHOLE_SPINE_DIE",
        "02": "CERVICAL",
        "04": "THORACIC",
        "06": "LUMBAR",
    },
    (2, ("7",)): {
        "0": "WHOLE/NFS",
        "1": "WHOLE AREA",
        "2": "VESSELS",
        "3": "NERVES",
        "4": "MUSCLES, TENDONS, LIGAMENTS",
        "5": "SKELETAL",
        "7": "JOINTS",
    },
    (2, ("8",)): {
        "0": "WHOLE/NFS",
        "1": "WHOLE AREA",
        "2": "VESSELS",
        "3": "NERVES",
        "4": "MUSCLES, TENDONS, LIGAMENTS",
        "5": "SKELETAL",
        "7": "JOINTS",
    },
    (2, ("9",)): {"1": "ALL"},
    (3, ("1", "0")): {
        "00": "Injuries to the Head",
        "09": "Injuries to the Head Died of head injury without further substantiation of injuries or no autopsy confirmation of specific injuries. or no autopsy confirmation of specific injuries.",
    },
    (3, ("1", "1")): {
        "00": "Head Injury involving only headache",
        "02": "Scalp abrasion",
        "04": "Scalp contusion; hematoma",
        "06": "Scalp laceration",
        "08": "Scalp avulsion",
        "30": "Crush Injury",
        "60": "Penetrating Injury to Skull",
    },
    (3, ("1", "2")): {
        "00": "Vascular Injury in Head",
        "01": "Artery",
        "02": "Artery Anterior cerebral artery",
        "04": "Artery Basilar artery",
        "06": "Sinus Carotid-cavernous fistula",
        "08": "Sinus Cavernous sinus",
        "10": "Artery Internal carotid artery",
        "14": "Artery Middle cerebral artery",
        "16": "Artery Other artery",
        "18": "Artery Posterior cerebral artery",
        "20": "Sinus Sigmoid sinus",
        "22": "Sinus",
        "23": "Vein",
        "24": "Sinus Superior longitudinal (saggital) sinus",
        "25": "Vein Vein, major",
        "26": "Sinus Transverse sinus",
        "27": "Vein Vein, non-major",
        "28": "Artery Vertebral artery",
        "30": "Sinus Straight sinus",
    },
    (3, ("1", "3")): {
        "02": "Cranial nerve",
        "04": "I (Olfactory nerve, tract)",
        "06": "II (Optic nerve - intracranial and intracanalicular segments includes chiasm and tracts)",
        "08": "III (Oculomotor nerve)",
        "10": "IV (Trochlear nerve)",
        "12": "V (Trigeminal nerve)",
        "14": "VI (Abducens nerve)",
        "16": "VII (Facial nerve)",
        "18": "VIII (Vestibulocochlear nerve includes auditory, acoustic and vestibular nerves)",
        "20": "IX (Glossopharyngeal nerve)",
        "22": "X (Vagus nerve excludes injury in neck, thorax or abdomen)",
        "24": "XI (Spinal accessory nerve)",
        "26": "XII (Hypoglossal nerve)",
    },
    (3, ("1", "4")): {
        "02": "Brain stem (hypothalamus, medulla, midbrain, pons)",
        "04": "Cerebellum",
        "06": "Cerebrum(includes basal ganglia, thalamus, putamen, globus pallidius)",
        "07": "Pituitary injury [AND] Cerebrum hypoxic or ischemic brain damage secondary to systemic hypoxemia, hypotension or shock not directly related to head trauma",
    },
    (3, ("1", "5")): {
        "00": "Skull fracture",
        "02": "Base (basilar) fracture",
        "04": "Vault fracture",
    },
    (3, ("1", "6")): {
        "10": "Cerebral Concussion [AND] Diffuse Axonal Injury (prolonged traumatic coma LOC >6 hours not due to mass lesion)"
    },
    (3, ("2", "0")): {
        "00": "Injuries to the Face",
        "09": "Injuries to the Face Died of facial injury without further substantiation of injuries or no autopsy confirmation of specific injuries",
    },
    (3, ("2", "1")): {
        "00": "Skin/subcutaneous/muscle",
        "02": "Skin/subcutaneous/muscle abrasion",
        "04": "Skin/subcutaneous/muscle contusion; hematoma",
        "06": "Skin/subcutaneous/muscle laceration",
        "08": "Skin/subcutaneous/muscle avulsion",
        "60": "Penetrating injury",
    },
    (3, ("2", "2")): {
        "00": "Vascular injuries in face",
        "02": "External carotid artery branch(es) laceration (includes facial, temporal, and internal maxillary)",
    },
    (3, ("2", "3")): {"02": "Optic Nerve"},
    (3, ("2", "4")): {
        "02": "Ear",
        "04": "Eye",
        "05": "Eye Choroid",
        "06": "Eye Cornea",
        "07": "Eye Injury with retained Intraocular Foreign Body",
        "08": "Eye Lens",
        "09": "Eye Macula",
        "10": "Eye Retina",
        "12": "Eye Sclera (includes globe)",
        "14": "Eye Uvea",
        "16": "Eye Vitreous",
        "30": "Mouth injury",
        "31": "Palate",
        "32": "Gingiva (gum)",
        "34": "Tongue laceration",
    },
    (3, ("2", "5")): {
        "02": "Alveolar ridge fracture with or without injury to teeth",
        "04": "Facial bone(s) fracture",
        "06": "Mandible fracture",
        "08": "Maxilla fracture (including maxillary sinus)",
        "10": "Nose",
        "12": "Orbit",
        "14": "Teeth",
        "16": "Temporomandibular joint",
        "18": "Zygoma (Includes tripod and malar fractures)",
        "19": "Panfacial fracture",
    },
    (3, ("3", "0")): {
        "00": "Injuries to the Neck",
        "09": "Injuries to the Neck Died of neck injury without further substantiation of injuries or no autopsy confirmation of specific injuries",
    },
    (3, ("3", "1")): {
        "00": "Skin/subcutaneous tissue/muscle",
        "02": "Skin/subcutaneous tissue/muscle abrasion",
        "04": "Skin/subcutaneous tissue/muscle contusion; hematoma",
        "06": "Skin/subcutaneous tissue/muscle laceration",
        "08": "Skin/subcutaneous tissue/muscle avulsion",
        "10": "Decapitation",
        "60": "Penetrating injury",
    },
    (3, ("3", "2")): {
        "00": "Vascular Injury in Neck",
        "02": "Carotid artery (common, internal)",
        "04": "Carotid artery (external includes thyroid)",
        "06": "Jugular vein (external)",
        "08": "Jugular vein (internal)",
        "10": "Vertebral artery",
    },
    (3, ("3", "3")): {
        "00": "Nerve Injury",
        "02": "Phrenic nerve injury",
        "04": "Vagus nerve injury",
    },
    (3, ("3", "4")): {
        "01": "Esophagus injury",
        "02": "Larynx (including thyroid and cricoid cartilage)",
        "06": "Pharynx or Retropharyngeal area",
        "10": "Salivary gland",
        "14": "Thyroid gland",
        "16": "Trachea injury",
        "18": "Vocal cord (not due to intubation)",
    },
    (3, ("3", "5")): {"02": "Hyoid fracture"},
    (3, ("4", "0")): {
        "00": "Injuries to the Whole Thorax",
        "09": "Injuries to the Whole Thorax Died of thoracic injury without further substantiation of injuries or no autopsy confirmation of specific injuries",
    },
    (3, ("4", "1")): {
        "00": "Skin/subcutaneous/muscle",
        "01": "Pectoral muscle tear; laceration",
        "02": "Skin/subcutaneous/muscle abrasion",
        "04": "Skin/subcutaneous/muscle contusion; hematoma",
        "06": "Skin/subcutaneous/muscle laceration",
        "08": "Skin/subcutaneous/muscle avulsion",
        "10": "Breast avulsion, female",
        "30": "Crush injury",
        "50": "Open (“sucking”) chest wound",
        "60": "Penetrating injury",
        "92": "Lung inhalation injury (heat, particulate matter, noxious agents)",
    },
    (3, ("4", "2")): {
        "00": "Vascular injuries in thorax",
        "02": "Aorta, thoracic",
        "04": "Brachiocephalic artery",
        "06": "Brachiocephalic vein",
        "08": "Coronary artery laceration or thrombosis to left main, right main or left anterior descending artery; coronary sinus",
        "10": "Pulmonary artery",
        "12": "Pulmonary vein",
        "14": "Subclavian artery",
        "16": "Subclavian vein",
        "18": "Vena Cava, superior and thoracic portion of inferior",
        "20": "Other named arteries (ex- bronchial, esophageal, intercostal, internal mammary)",
        "22": "Other named veins (ex- azygos, bronchial, esophageal, hemiazygos, intercostal, internal jugular, internal mammary)",
    },
    (3, ("4", "3")): {"04": "Vagus nerve injury"},
    (3, ("4", "4")): {
        "00": "Bronchus injury",
        "01": "Bronchus, main stem",
        "02": "Bronchus, distal to main stem",
        "04": "Intracardiac chordae tendineae laceration;rupture",
        "06": "Diaphragm",
        "08": "Esophagus injury in Thorax",
        "10": "Heart (Myocardium) injury",
        "12": "Intracardiac valve laceration; rupture",
        "13": "Interatrial septum laceration; rupture",
        "14": "Lung",
        "16": "Pericardium",
        "18": "Pleura laceration",
        "22": "Thoracic injury",
        "24": "Thoracic duct laceration",
        "25": "Thymus laceration; perforation",
        "26": "Trachea injury in Thorax",
        "29": "Thoracic injury",
    },
    (3, ("4", "5")): {"02": "Rib cage", "08": "Sternum", "10": "Thoracic Wall"},
    (3, ("5", "0")): {
        "00": "Injuries to the Whole Abdomen",
        "09": "Injuries to the Whole Abdomen Died of abdominal injury without further substantiation of injuries or no autopsy confirmation of specific injuries",
    },
    (3, ("5", "1")): {
        "00": "Skin/Subcutaneous/Muscle (except rectus abdominus)",
        "01": "Rectus Abdominus rupture",
        "02": "Skin/Subcutaneous/Muscle (except rectus abdominus) abrasion",
        "04": "Skin/Subcutaneous/Muscle (except rectus abdominus) contusion; hematoma",
        "06": "Skin/Subcutaneous/Muscle (except rectus abdominus) laceration",
        "08": "Skin/Subcutaneous/Muscle (except rectus abdominus) avulsion",
        "10": "Torso transection",
        "60": "Penetrating injury",
    },
    (3, ("5", "2")): {
        "00": "Vascular injury in Abdomen",
        "02": "Aorta, Abdomina",
        "04": "Celiac Artery",
        "06": "Iliac Artery (common, internal, external) and its named branches",
        "08": "Iliac Vein (common)",
        "10": "Iliac Vein (internal, external)",
        "11": "Superior Mesenteric Artery",
        "12": "Vena Cava, inferior",
        "14": "Other named arteries (ex- hepatic, renal, splenic)",
        "16": "Other named veins (ex- portal, renal, splenic, superior mesenteric)",
    },
    (3, ("5", "3")): {"04": "Vagus nerve injury"},
    (3, ("5", "4")): {
        "02": "Adrenal Gland",
        "03": "Appendix laceration; perforation",
        "04": "Anus",
        "06": "Bladder (urinary)",
        "08": "Colon (large bowel)",
        "10": "Duodenum",
        "12": "Gallbladder",
        "14": "Jejunum-Ileum (small bowel)",
        "16": "Kidney",
        "18": "Liver",
        "20": "Mesentery",
        "22": "Omentum",
        "24": "Ovarian (Fallopian) tube",
        "26": "Ovary",
        "28": "Pancreas",
        "30": "Penis",
        "32": "Perineum",
        "35": "Prostate",
        "36": "Rectum",
        "38": "Retroperitoneumhemorrhage or hematoma",
        "40": "Scrotum",
        "42": "Spleen",
        "44": "Stomach",
        "46": "Testes",
        "48": "Ureter",
        "50": "Urethra",
        "52": "Uterus",
        "54": "Vagina",
        "56": "Vulva",
    },
    (3, ("7", "0")): {
        "00": "Injuries to the Whole Upper Extremity",
        "09": "Injuries to the Whole Upper Extremity Died of upper extremity injury without further substantiation of injuries or no autopsy confirmation of specific injuries",
    },
    (3, ("7", "1")): {
        "00": "Skin/subcutaneous/muscle",
        "02": "Skin/subcutaneous/muscle abrasion",
        "04": "Skin/subcutaneous/muscle contusion; hematoma",
        "06": "Skin/subcutaneous/muscle laceration",
        "08": "Skin/subcutaneous/muscle avulsion",
        "10": "Amputation [traumatic], partial or complete between shoulder and hand",
        "20": "Compartment syndrome from trauma to soft tissue only, not involving fracture or massive destruction of bone or other anatomical structures",
        "30": "Crush Injury to limb between shoulder and wrist",
        "40": "Degloving injury",
        "60": "Penetrating injury",
    },
    (3, ("7", "2")): {
        "00": "Vascular Injury in Upper Extremity",
        "02": "Axillary artery",
        "04": "Axillary vein",
        "06": "Brachial artery",
        "08": "Brachial vein",
        "10": "Other named arteries (ex- radial, ulnar)",
        "12": "Other named veins (ex- cephalic, basilic)",
    },
    (3, ("7", "3")): {
        "00": "Nerve injury in upper extremity",
        "02": "Digital nerve",
        "04": "Median nerve",
        "06": "Radial nerve",
        "08": "Ulnar nerve",
    },
    (3, ("7", "4")): {
        "00": "Muscle, tendon or ligament injury",
        "02": "Tendon tear; avulsion",
        "04": "Muscle tear; avulsion",
        "06": "Joint capsule; rupture; tear; avulsion",
    },
    (3, ("7", "5")): {
        "00": "Upper Extremity fracture",
        "05": "Clavicle, Proximal clavicle fracture",
        "06": "Clavicle shaft fracture",
        "07": "Distal (lateral end) clavicle fracture",
        "09": "Scapula fracture",
        "11": "Humerus, Proximal humerus fracture",
        "12": "Humerus shaft fracture",
        "13": "Distal humerus fracture",
        "18": "Arm fracture",
        "19": "Forearm fracture",
        "20": "Hand fracture",
        "21": "Proximal (Radius, Ulna(olecranon)) fracture",
        "22": "(Radius, Ulna) shaft fracture",
        "23": "Distal (Radius, Ulna) fracture",
        "24": "Carpus fracture",
        "25": "Metacarpus fracture",
        "26": "Phalange fracture",
        "28": "Radius fracture",
        "32": "Ulna fracture",
    },
    (3, ("7", "7")): {
        "00": "Upper extremity joint injury",
        "05": "Sternoclavicular joint",
        "07": "Acromioclavicular joint",
        "10": "Shoulder (glenohumeral) joint",
        "20": "Elbow joint",
        "22": "Carpal (wrist) joint dislocation (distal radioulnar)",
        "23": "Carpal (wrist) joint dislocation (radiocarpal)",
        "24": "Carpal (wrist) joint",
        "25": "Metacarpophalangeal or Interphalangeal joint",
    },
    (3, ("8", "0")): {
        "00": "Injuries to the Whole Lower Extremity",
        "09": "Injuries to the Whole Lower Extremity Died of lower extremity injury without further substantiation of injuries or no autopsy confirmation of specific injuries",
    },
    (3, ("8", "1")): {
        "00": "Skin/subcutaneous/muscle",
        "02": "Skin/subcutaneous/muscle abrasion",
        "04": "Skin/subcutaneous/muscle contusion; hematoma",
        "06": "Skin/subcutaneous/muscle laceration",
        "08": "Skin/subcutaneous/muscle avulsion",
        "10": "Amputation [traumatic], partial or complete between hip and foot",
        "20": "Compartment syndrome resulting from trauma to soft tissue only, not involving fracture or massive destruction of bone and other  \n anatomical structures",
        "30": "Crush Injury to limb between hip and foot",
        "40": "Degloving injury",
        "60": "Penetrating injury",
    },
    (3, ("8", "2")): {
        "00": "Vascular injuries in the lower extremity",
        "02": "Femoral artery and its named branches",
        "04": "Femoral vein",
        "06": "Popliteal artery",
        "08": "Popliteal vein",
        "10": "Other named arteries (ex- tibial,peroneal)",
        "12": "Other named veins (ex- saphenous)",
    },
    (3, ("8", "3")): {
        "00": "Nerve injury in the lower extremity",
        "02": "Digital nerve",
        "03": "Femoral nerve",
        "04": "Sciatic nerve",
        "05": "Peroneal nerve",
        "06": "Tibial nerve",
    },
    (3, ("8", "4")): {
        "00": "Muscle, tendon, ligament injury",
        "02": "Achilles tendon tear; avulsion",
        "03": "Meniscus tear; avulsion",
        "04": "Collateral ligament tear; avulsion",
        "05": "Cruciate ligament (anterior or posterior) tear; avulsion",
        "06": "Muscle tear; avulsion",
        "08": "Tendon (other than Achilles or patellar) tear; avulsion",
        "10": "Patellar tendon tear; avulsion",
    },
    (3, ("8", "5")): {
        "00": "Lower Extremity fracture",
        "20": "Leg [AND] Foot fracture",
        "30": "Femur fracture",
        "31": "Proximal Femur fracture",
        "32": "Femur Shaft fracture",
        "33": "Distal Femur fracture",
        "40": "Tibia fracture",
        "41": "Proximal Tibia fracture",
        "42": "Tibia Shaft fracture",
        "43": "Distal Tibia fracture(includes medial malleolus; also pilon fracture)",
        "44": "Fibula (malleoli) fracture",
        "45": "Patella fracture",
        "61": "Pelvic ring fracture",
        "62": "Acetabulum fracture",
        "72": "Talus fracture",
        "73": "Calcaneus fracture",
        "74": "Navicular fracture",
        "75": "Cuneiform fracture",
        "76": "Cuboid fracture",
        "81": "Metatarsal fracture",
        "82": "Phalange fracture",
    },
    (3, ("8", "7")): {
        "00": "Lower extremity joint injury",
        "04": "Forefoot joint",
        "30": "Hip joint",
        "40": "Knee joint",
        "71": "Ankle joint",
        "72": "Subtalar joint",
        "77": "Midtarsal joint",
        "80": "Tarsometatarsal joint",
        "81": "Metatarsophalangeal or interphalangeal joint",
    },
    (3, ("6", "00")): {
        "0": "CERVICAL_WHOLE",
        "2": "THORACIC_WHOLE",
        "3": "LUMBAR_WHOLE",
    },
    (3, ("6", "09")): {
        "0": "CERVICAL_WHOLE",
        "2": "THORACIC_WHOLE",
        "3": "LUMBAR_WHOLE",
    },
    (3, ("6", "02")): {
        "3": "Nerve root, single or multiple [AND] Brachial Plexus injury (includes trunks, divisions or cords)",
        "4": "Cord [AND] Spinous ligament injury [AND] Strain, acute, with no fracture or dislocation",
        "5": "Disc injury [AND] Dislocation (subluxation), no fracture, no cord involvement [AND] Fracture with or without dislocation but no cord involvement",
    },
    (3, ("6", "04")): {
        "3": " Nerve root, single or multiple",
        "4": "Cord [AND] Spinous ligament injury [AND] Strain, acute with no fracture or dislocation",
        "5": "Disc injury [AND] Dislocation (subluxation), no fracture, no cord involvement [AND] Fracture with or without dislocation but no cord involvement",
    },
    (3, ("6", "06")): {
        "3": "Nerve root or sacral plexus, single or multiple [AND] Cauda equina contusion",
        "4": "Cord [AND] Spinous ligament injury [AND] Strain, acute with no fracture or dislocation",
        "5": "Disc injury [AND] Dislocation (subluxation), no fracture, no cord involvement [AND] Fracture with or without dislocation but no cord involvement",
    },
    (3, ("9", "1")): {
        "00": "Soft tissue (skin) injury",
        "02": "Soft tissue (skin) abrasion",
        "04": "Soft tissue (skin) contusion; hematoma",
        "06": "Soft tissue (skin) laceration",
        "08": "Soft tissue (skin) avulsion",
        "40": "Degloving injury",
        "50": "Frostbite",
        "60": "Penetrating injury",
        "20": "Burns",
    },
    (3, ("0", "1")): {
        "00": "Hypothermia",
        "20": "Whole Body (explosion-type) Injury",
    },
    (3, ("0", "2")): {
        "00": "Asphyxia/Suffocation",
    },
    (3, ("0", "4")): {
        "00": "Caustic Agents",
    },
    (3, ("0", "6")): {
        "00": "Drowning",
    },
    (3, ("0", "8")): {
        "00": "Electrical injury",
    },
}

In [4]:
config = AutoConfig.from_pretrained(
    "Qwen/Qwen3-8B",
    num_labels=355,
    problem_type="multi_label_classification",
    id2label={i:lab for i,lab in enumerate(labels)},
    label2id={lab:i for i,lab in enumerate(labels)},
)
base_model = AutoModelForSequenceClassification.from_pretrained(
    "Qwen/Qwen3-8B",
    config=config,
    torch_dtype=torch.bfloat16,
).to(DEVICE)

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00, 91.92it/s]
Some weights of Qwen3ForSequenceClassification were not initialized from the model checkpoint at Qwen/Qwen3-8B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
hidden_mid = 512
base_model.score = nn.Sequential(
    nn.Linear(config.hidden_size, hidden_mid, bias=False),
    nn.ReLU(),
    nn.Dropout(0.1),
    nn.Linear(hidden_mid, config.num_labels, bias=False)
)


In [6]:
# — 4) load your LoRA adapter on top
adapter_path = os.path.join(MODEL_DIR, "cls_lora_adapter")
model = PeftModel.from_pretrained(
    base_model,
    adapter_path,
    torch_dtype=torch.bfloat16,
).to(DEVICE)
model = model.bfloat16() 
model.eval()

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): Qwen3ForSequenceClassification(
      (model): Qwen3Model(
        (embed_tokens): Embedding(151936, 4096)
        (layers): ModuleList(
          (0-35): 36 x Qwen3DecoderLayer(
            (self_attn): Qwen3Attention(
              (q_proj): lora.Linear(
                (base_layer): Linear(in_features=4096, out_features=4096, bias=False)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.1, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): Linear(in_features=4096, out_features=32, bias=False)
                )
                (lora_B): ModuleDict(
                  (default): Linear(in_features=32, out_features=4096, bias=False)
                )
                (lora_embedding_A): ParameterDict()
                (lora_embedding_B): ParameterDict()
                (lora_magnitude_vector): ModuleDict()
              )
        

In [7]:
# — 5) load tokenizer from your OUTPUT_DIR
tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR, use_fast=True)
tokenizer.pad_token    = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id

In [8]:
def build_prompt(report: str) -> str:
    """Format input in Qwen chat template."""
    return (
        "<|im_start|>user\n"
        f"{report}\n"
        "<|im_end|>\n"
        "<|im_start|>assistant\n"
    )


In [9]:
THRESHOLD  = 0.5
MAX_LEN = 1024

In [10]:
import json

In [11]:
with open("code2idx_3lvl.json") as f:
    code2idx = json.load(f)
idx2code = {int(v): k for k, v in code2idx.items()}


In [12]:
idx2code

{0: '0/1/00',
 1: '0/1/20',
 2: '0/4/00',
 3: '0/8/00',
 4: '1/0/00',
 5: '1/0/09',
 6: '1/1/00',
 7: '1/1/02',
 8: '1/1/04',
 9: '1/1/06',
 10: '1/1/08',
 11: '1/1/30',
 12: '1/1/60',
 13: '1/2/00',
 14: '1/2/01',
 15: '1/2/02',
 16: '1/2/04',
 17: '1/2/06',
 18: '1/2/08',
 19: '1/2/10',
 20: '1/2/14',
 21: '1/2/16',
 22: '1/2/18',
 23: '1/2/20',
 24: '1/2/22',
 25: '1/2/23',
 26: '1/2/24',
 27: '1/2/25',
 28: '1/2/26',
 29: '1/2/27',
 30: '1/2/28',
 31: '1/2/30',
 32: '1/3/02',
 33: '1/3/04',
 34: '1/3/06',
 35: '1/3/08',
 36: '1/3/10',
 37: '1/3/12',
 38: '1/3/14',
 39: '1/3/16',
 40: '1/3/18',
 41: '1/3/20',
 42: '1/3/22',
 43: '1/3/24',
 44: '1/3/26',
 45: '1/4/02',
 46: '1/4/04',
 47: '1/4/06',
 48: '1/4/07',
 49: '1/5/00',
 50: '1/5/02',
 51: '1/5/04',
 52: '1/6/10',
 53: '2/0/00',
 54: '2/0/09',
 55: '2/1/00',
 56: '2/1/02',
 57: '2/1/04',
 58: '2/1/06',
 59: '2/1/08',
 60: '2/1/60',
 61: '2/2/00',
 62: '2/2/02',
 63: '2/3/02',
 64: '2/4/02',
 65: '2/4/04',
 66: '2/4/05',
 67: 

In [13]:
idx2code = {k: v.replace("/", "") for k, v in idx2code.items()}


In [14]:
@torch.inference_mode()
def predict_codes_with_probs(report: str,
                             threshold: float = THRESHOLD
                            ) -> str:
    # 1) 토큰화
    prompt = build_prompt(report)
    inputs = tokenizer(
        prompt,
        truncation=True,
        padding="max_length",
        max_length=MAX_LEN,
        return_tensors="pt",
    ).to(DEVICE)

    # 2) forward → logits → sigmoid → probs
    logits = model(**inputs).logits    # [1, num_labels]
    probs  = torch.sigmoid(logits)[0].cpu()
    k=2
    # 3) threshold 이상인 인덱스와 확률 추출
    idxs = (probs >= threshold).nonzero(as_tuple=True)[0].tolist()
    if not idxs:
        idxs = probs.topk(k).indices.tolist()

    # 4) (코드, 확률) 튜플 리스트로 변환
    code_prob_pairs = [(idx2code[i], float(probs[i])) for i in idxs]
    # 5) 확률 내림차순으로 정렬
    code_prob_pairs.sort(key=lambda x: x[1], reverse=True)

    # 6) 문자열 포맷: "코드1(0.87), 코드2(0.76), …"
    return ",".join(f"{code}({prob:.2f})" for code, prob in code_prob_pairs)


In [15]:
sample = """
CHEST CT PRE AND POSTCONTRAST : 1. Small amount hemopneumothorax, left. 
2. Acute fractures at left clavicle midshaft, posterior arc of left 3rd-6th rib. 
3. Localized bronchiolectasis at posterior basal segment of RLL. 
4. No significant enlarged lymph nodes in mediastinum. 
5. No coronary calcification. 
LT SHOULDER CT (3D) : 1. Comminuted fracture, left midclavicle.
2. Fractures, left 3rd~6th ribs posterior arc.
3. Left lung pneumothorax.
"""

# quick smoke-test
print("Predicted AIS codes →", predict_codes_with_probs(sample, threshold=0.3))

Predicted AIS codes → 4422(1.00),4502(1.00),7506(1.00)


In [16]:
import ast
import re
from pathlib import Path
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support


In [17]:
def to_prefix_set(code_list):
    """['1406942', '2510001'] → {'1406','2510'}"""
    return {str(c)[:4] for c in code_list}


def ensure_list(x):
    # 문자열이면 literal_eval, 아니면 그대로 반환
    if isinstance(x, str):
        try:
            return ast.literal_eval(x)
        except Exception:
            return []
    elif pd.isna(x):
        return []
    else:
        return x

In [18]:
gt_path     = "충북대 ct_reading_all_ais.xlsx"       # 7개의 시트 포함
#sheet_dict  = pd.read_excel(gt_path, sheet_name=None)   # dict[시트명] = DataFrame


In [19]:
# records = []
# for sheet_name, df in sheet_dict.items():
#     for _, row in df.iterrows():
#         pid   = row.iloc[0]                 # id
#         text  = row.iloc[1]                 # reading_text
#         codes = row.iloc[2]                 # "['1406942', ...]"
#         if isinstance(codes, str):
#             codes = ast.literal_eval(codes)
#         records.append({
#             "id":           pid,
#             "reading":      text,
#             "gt_set":       to_prefix_set(codes)
#         })

# gt_df = pd.DataFrame(records)

In [20]:
gt_ds = pd.read_excel(gt_path)

In [21]:
gt_df = gt_ds.copy()

In [22]:
gt_df

Unnamed: 0.1,Unnamed: 0,내원일시_환자번호,code,reading
0,0,20220102_10501227,"[1406212, 1406942, 1102021, 2512312, 2510001, ...",CT:Chest Routine (CE): [Conclusion]\n<Contrast...
1,1,20220103_10501467,"[8561634, 5450992]",CT:Chest Routine (CE): [Conclusion]\nEnhanced ...
2,2,20220104_7955334,"[4502033, 4220062, 7102021, 8102021, 4422053]",CT:3D-C-Spine: [Conclusion]\nC-spine CT 3D\n\n...
3,3,20220105_10501812,"[6506202, 6506202, 2102021, 2102021, 8106021, ...",CT:Brain : [Conclusion]\nno evidence of intrac...
4,4,20220105_456384,"[1406565, 1406143, 1406942]",CT:Angio 3D Brain(CE): [Conclusion]\n1. acute ...
...,...,...,...,...
1472,1472,20241228_7692905,"[1104021, 1106021]",CT:Brain : [Conclusion]\n1. no evidence of abn...
1473,1473,20241229_10598701,"[5410102, 5428122, 5420102]",CT:Abdomen & Pelvis (Dynamic) (CE): [Conclusio...
1474,1474,20241230_10598715,"[1406942, 1504022, 4502033, 7506612, 5214023, ...",CT:3D-T-L Spine: [Conclusion]\nT7~9 right tran...
1475,1475,20241230_10598761,"[2106021, 4422034, 7506512, 8102021]",CT:3D-Shoulder (Left): [Conclusion]\nFx. shaft...


In [23]:
print(gt_df['code'].apply(type).value_counts())


code
<class 'str'>    1477
Name: count, dtype: int64


In [24]:
# 2) 새로운 컬럼 추가
gt_df['code'] = gt_df['code'].apply(ensure_list)

In [25]:
print(gt_df['code'].apply(type).value_counts())


code
<class 'list'>    1477
Name: count, dtype: int64


In [26]:
# 2) 새로운 컬럼 추가
gt_df['code_prefix'] = gt_df['code'].apply(to_prefix_set)

In [27]:
gt_df

Unnamed: 0.1,Unnamed: 0,내원일시_환자번호,code,reading,code_prefix
0,0,20220102_10501227,"[1406212, 1406942, 1102021, 2512312, 2510001, ...",CT:Chest Routine (CE): [Conclusion]\n<Contrast...,"{2106, 1102, 2510, 5102, 1406, 2512, 2102}"
1,1,20220103_10501467,"[8561634, 5450992]",CT:Chest Routine (CE): [Conclusion]\nEnhanced ...,"{5450, 8561}"
2,2,20220104_7955334,"[4502033, 4220062, 7102021, 8102021, 4422053]",CT:3D-C-Spine: [Conclusion]\nC-spine CT 3D\n\n...,"{7102, 4220, 4502, 4422, 8102}"
3,3,20220105_10501812,"[6506202, 6506202, 2102021, 2102021, 8106021, ...",CT:Brain : [Conclusion]\nno evidence of intrac...,"{6506, 8106, 2102}"
4,4,20220105_456384,"[1406565, 1406143, 1406942]",CT:Angio 3D Brain(CE): [Conclusion]\n1. acute ...,{1406}
...,...,...,...,...,...
1472,1472,20241228_7692905,"[1104021, 1106021]",CT:Brain : [Conclusion]\n1. no evidence of abn...,"{1106, 1104}"
1473,1473,20241229_10598701,"[5410102, 5428122, 5420102]",CT:Abdomen & Pelvis (Dynamic) (CE): [Conclusio...,"{5428, 5410, 5420}"
1474,1474,20241230_10598715,"[1406942, 1504022, 4502033, 7506612, 5214023, ...",CT:3D-T-L Spine: [Conclusion]\nT7~9 right tran...,"{1504, 5214, 1406, 4502, 2512, 2508, 7506, 1502}"
1475,1475,20241230_10598761,"[2106021, 4422034, 7506512, 8102021]",CT:3D-Shoulder (Left): [Conclusion]\nFx. shaft...,"{2106, 7506, 8102, 4422}"


In [28]:
def predict_prefixes(text: str) -> set[str]:
    """
    reading 텍스트 1개 → {'1406', '2510', ...} 4-digit prefix 집합 반환
    """
    out_str = predict_codes_with_probs(text)           # "1406942(0.87), 2510001(0.78)"
    if out_str == "NONE":
        return set()

    # "코드(확률)" → 코드만 추출
    prefixes = {
        seg.split("(")[0]        # '1406942' → '1406'
        for seg in out_str.split(",")
    }
    return prefixes

In [29]:
# 1) 먼저 tqdm.pandas() 로 프로그레스 바를 pandas에 등록
from tqdm import tqdm
tqdm.pandas()
gt_df["pred_set"] = gt_df["reading"].apply(lambda txt: predict_prefixes(txt))


In [30]:
gt_df

Unnamed: 0.1,Unnamed: 0,내원일시_환자번호,code,reading,code_prefix,pred_set
0,0,20220102_10501227,"[1406212, 1406942, 1102021, 2512312, 2510001, ...",CT:Chest Routine (CE): [Conclusion]\n<Contrast...,"{2106, 1102, 2510, 5102, 1406, 2512, 2102}","{1406, 2510}"
1,1,20220103_10501467,"[8561634, 5450992]",CT:Chest Routine (CE): [Conclusion]\nEnhanced ...,"{5450, 8561}",{8561}
2,2,20220104_7955334,"[4502033, 4220062, 7102021, 8102021, 4422053]",CT:3D-C-Spine: [Conclusion]\nC-spine CT 3D\n\n...,"{7102, 4220, 4502, 4422, 8102}","{4502, 2510, 4414, 4422}"
3,3,20220105_10501812,"[6506202, 6506202, 2102021, 2102021, 8106021, ...",CT:Brain : [Conclusion]\nno evidence of intrac...,"{6506, 8106, 2102}","{6506, 8561}"
4,4,20220105_456384,"[1406565, 1406143, 1406942]",CT:Angio 3D Brain(CE): [Conclusion]\n1. acute ...,{1406},{1406}
...,...,...,...,...,...,...
1472,1472,20241228_7692905,"[1104021, 1106021]",CT:Brain : [Conclusion]\n1. no evidence of abn...,"{1106, 1104}",{6402}
1473,1473,20241229_10598701,"[5410102, 5428122, 5420102]",CT:Abdomen & Pelvis (Dynamic) (CE): [Conclusio...,"{5428, 5410, 5420}","{5420, 5418}"
1474,1474,20241230_10598715,"[1406942, 1504022, 4502033, 7506612, 5214023, ...",CT:3D-T-L Spine: [Conclusion]\nT7~9 right tran...,"{1504, 5214, 1406, 4502, 2512, 2508, 7506, 1502}","{1504, 1406, 2512, 2508, 7506, 1502}"
1475,1475,20241230_10598761,"[2106021, 4422034, 7506512, 8102021]",CT:3D-Shoulder (Left): [Conclusion]\nFx. shaft...,"{2106, 7506, 8102, 4422}","{7506, 4422}"


In [31]:
# ───────────────────────────────────────────────────────────
# 3. 지표 계산
# ───────────────────────────────────────────────────────────
gt_df["hit"] = gt_df.apply(lambda r: len(r["code_prefix"] & r["pred_set"]) > 0, axis=1)
gt_df["cor"] = gt_df.apply(lambda r: len(r["code_prefix"] & r["pred_set"])  == len(r["code_prefix"]), axis=1)
abs_acc = accuracy_score(gt_df["cor"], [True]*len(gt_df))
sample_acc   = accuracy_score(gt_df["hit"], [True]*len(gt_df))

# 멀티레이블 클래스 목록
labels = sorted({p for s in gt_df["code_prefix"] for p in s} |
                {p for s in gt_df["pred_set"] for p in s})

def to_binary(series, label_list):
    """set → binary row vector"""
    return pd.DataFrame(
        [[1 if l in s else 0 for l in label_list] for s in series],
        columns=label_list
    )

y_true_bin = to_binary(gt_df["code_prefix"],  labels)
y_pred_bin = to_binary(gt_df["pred_set"], labels)

prec, rec, f1, _ = precision_recall_fscore_support(
    y_true_bin.values, y_pred_bin.values,
    average='micro', zero_division=0
)

print(f"Sample-level Accuracy : {sample_acc:5.3f}")
print(f"Absolute Accuracy : {abs_acc:5.3f}")

print(f"Micro Precision       : {prec:5.3f}")
print(f"Micro Recall          : {rec:5.3f}")
print(f"Micro F1-score        : {f1:5.3f}")


Sample-level Accuracy : 0.914
Absolute Accuracy : 0.227
Micro Precision       : 0.745
Micro Recall          : 0.511
Micro F1-score        : 0.607


In [32]:
def to_prefix_set2(code_list):
    """['1406942', '2510001'] → {'1406','2510'}"""
    return {str(c)[:2] for c in code_list}

In [33]:
def sample_exact_match_accuracy(gt_sets, pred_sets):
    assert len(gt_sets) == len(pred_sets)
    return sum(gt == pred for gt, pred in zip(gt_sets, pred_sets)) / len(gt_sets)

In [34]:
from typing import List, Set, Union

def sample_level_accuracy_subset(
    gt_sets:   List[Set[Union[int,str]]],
    pred_sets: List[Set[Union[int,str]]],
) -> float:
    """
    gt_sets  : 각 샘플의 정답 코드 집합, ex) [{'1','2'}, {'3'}, ...]
    pred_sets: 각 샘플의 예측 코드 집합, ex) [{'1','2','3'}, {'4'}, ...]
    
    정답 처리 기준: pred ⊇ gt 인 경우만 맞음.
    """
    assert len(gt_sets) == len(pred_sets), "gt_sets 와 pred_sets 길이가 달라요!"
    total = len(gt_sets)
    correct = 0
    for gt, pred in zip(gt_sets, pred_sets):
        # gt 집합이 빈 경우(예: 코드가 없을 땐)에도
        # pred_sets에 빈 집합이 있으면 맞은 것으로 처리하려면 아래처럼:
        # if not gt and not pred: correct += 1
        # 이 예시에선 gt가 비어있으면 모두 맞은 것으로 봅니다:
        if gt.issubset(pred):
            correct += 1
    return correct / total

In [35]:
# ───────────────────────────────────────────────────────────
# 3. 지표 계산
# ───────────────────────────────────────────────────────────
gt_df["hit"] = gt_df.apply(lambda r: len(r["code_prefix"] & r["pred_set"]) > 0, axis=1)
gt_df["cor"] = gt_df.apply(lambda r: len(r["code_prefix"] & r["pred_set"])  == len(r["code_prefix"]), axis=1)
#abs_acc = accuracy_score(gt_df["cor"], [True]*len(gt_df))
sample_acc   = accuracy_score(gt_df["hit"], [True]*len(gt_df))

# 멀티레이블 클래스 목록
labels = sorted({p for s in gt_df["code_prefix"] for p in s} |
                {p for s in gt_df["pred_set"] for p in s})

def to_binary(series, label_list):
    """set → binary row vector"""
    return pd.DataFrame(
        [[1 if l in s else 0 for l in label_list] for s in series],
        columns=label_list
    )

y_true_bin = to_binary(gt_df["code_prefix"],  labels)
y_pred_bin = to_binary(gt_df["pred_set"], labels)
acc = sample_level_accuracy_subset(gt_df["code_prefix"], gt_df["pred_set"])
abs_acc = sample_exact_match_accuracy(gt_df["code_prefix"], gt_df["pred_set"])

prec, rec, f1, _ = precision_recall_fscore_support(
    y_true_bin.values, y_pred_bin.values,
    average='micro', zero_division=0
)
print('외부제외 결과')
print(f"Sample-level Accuracy : {sample_acc:5.3f}")
print(f"Absolute Accuracy     : {abs_acc:5.3f}")
print(f"Inclusion Accuracy    : {acc:5.3f}")

print(f"Micro Precision       : {prec:5.3f}")
print(f"Micro Recall          : {rec:5.3f}")
print(f"Micro F1-score        : {f1:5.3f}")


외부제외 결과
Sample-level Accuracy : 0.914
Absolute Accuracy     : 0.160
Inclusion Accuracy    : 0.227
Micro Precision       : 0.745
Micro Recall          : 0.511
Micro F1-score        : 0.607


In [36]:

# 방법1) 문자열 슬라이싱
gt_df['gt_2digit'] = gt_df['code_prefix'].apply(
    lambda codes: {int(str(code)[:2]) for code in codes}
)
gt_df['pred_2digit'] = gt_df['pred_set'].apply(
    lambda codes: {int(str(code)[:2]) for code in codes}
)


# 방법1) 문자열 슬라이싱
gt_df['gt_1digit'] = gt_df['code_prefix'].apply(
    lambda codes: {int(str(code)[:1]) for code in codes}
)
gt_df['pred_1digit'] = gt_df['pred_set'].apply(
    lambda codes: {int(str(code)[:1]) for code in codes}
)






In [37]:
# ───────────────────────────────────────────────────────────
# 3. 지표 계산
# ───────────────────────────────────────────────────────────
gt_df["hit2"] = gt_df.apply(lambda r: len(r["gt_2digit"] & r["pred_2digit"]) > 0, axis=1)
gt_df["cor2"] = gt_df.apply(lambda r: len(r["gt_2digit"] & r["pred_2digit"])  == len(r["gt_2digit"]), axis=1)
# abs_acc = accuracy_score(gt_df["cor2"], [True]*len(gt_df))
sample_acc   = accuracy_score(gt_df["hit2"], [True]*len(gt_df))

# 멀티레이블 클래스 목록
labels = sorted({p for s in gt_df["gt_2digit"] for p in s} |
                {p for s in gt_df["pred_2digit"] for p in s})

def to_binary(series, label_list):
    """set → binary row vector"""
    return pd.DataFrame(
        [[1 if l in s else 0 for l in label_list] for s in series],
        columns=label_list
    )

y_true_bin = to_binary(gt_df["gt_2digit"],  labels)
y_pred_bin = to_binary(gt_df["pred_2digit"], labels)

prec, rec, f1, _ = precision_recall_fscore_support(
    y_true_bin.values, y_pred_bin.values,
    average='micro', zero_division=0
)
acc = sample_level_accuracy_subset(gt_df["gt_2digit"], gt_df["pred_2digit"])
abs_acc = sample_exact_match_accuracy(gt_df["gt_2digit"], gt_df["pred_2digit"])

print(f"Sample-level Accuracy2 : {sample_acc:5.3f}")
print(f"Absolute Accuracy2     : {abs_acc:5.3f}")
print(f"Inclusion Accuracy2    : {acc:5.3f}")

print(f"Micro Precision2      : {prec:5.3f}")
print(f"Micro Recall2         : {rec:5.3f}")
print(f"Micro F1-score2       : {f1:5.3f}")


Sample-level Accuracy2 : 0.942
Absolute Accuracy2     : 0.234
Inclusion Accuracy2    : 0.301
Micro Precision2      : 0.844
Micro Recall2         : 0.572
Micro F1-score2       : 0.682


In [38]:
# ───────────────────────────────────────────────────────────
# 3. 지표 계산
# ───────────────────────────────────────────────────────────
gt_df["hit1"] = gt_df.apply(lambda r: len(r["gt_1digit"] & r["pred_1digit"]) > 0, axis=1)
gt_df["cor1"] = gt_df.apply(lambda r: len(r["gt_1digit"] & r["pred_1digit"])  == len(r["gt_1digit"]), axis=1)
# abs_acc = accuracy_score(gt_df["cor1"], [True]*len(gt_df))
sample_acc   = accuracy_score(gt_df["hit1"], [True]*len(gt_df))

# 멀티레이블 클래스 목록
labels = sorted({p for s in gt_df["gt_1digit"] for p in s} |
                {p for s in gt_df["pred_1digit"] for p in s})

def to_binary(series, label_list):
    """set → binary row vector"""
    return pd.DataFrame(
        [[1 if l in s else 0 for l in label_list] for s in series],
        columns=label_list
    )

y_true_bin = to_binary(gt_df["gt_1digit"],  labels)
y_pred_bin = to_binary(gt_df["pred_1digit"], labels)

acc = sample_level_accuracy_subset(gt_df["gt_1digit"], gt_df["pred_1digit"])
abs_acc = sample_exact_match_accuracy(gt_df["gt_1digit"], gt_df["pred_1digit"])
                                   
prec, rec, f1, _ = precision_recall_fscore_support(
    y_true_bin.values, y_pred_bin.values,
    average='micro', zero_division=0
)
print('외부포함 결과')
print(f"Sample-level Accuracy1 : {sample_acc:5.3f}")
print(f"Absolute Accuracy1     : {abs_acc:5.3f}")
print(f"Inclusion Accuracy1    : {acc:5.3f}")
print(f"Micro Precision1      : {prec:5.3f}")
print(f"Micro Recall1         : {rec:5.3f}")
print(f"Micro F1-score1       : {f1:5.3f}")


외부포함 결과
Sample-level Accuracy1 : 0.970
Absolute Accuracy1     : 0.404
Inclusion Accuracy1    : 0.476
Micro Precision1      : 0.898
Micro Recall1         : 0.674
Micro F1-score1       : 0.770


In [39]:
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, jaccard_score, hamming_loss


In [40]:
def evaluate_with_hamming_jaccard(gt_df, true_col='gt_1digit', pred_col='pred_1digit'):
    """
    Extends existing multi-label metrics with Jaccard and Hamming.
    gt_df: DataFrame containing sets in true_col and pred_col
    """
    # 1. Sample‐level hits and exact matches
    gt_sets = gt_df[true_col]
    pred_sets = gt_df[pred_col]
    
    gt_df["hit1"] = gt_df.apply(lambda r: len(r[true_col] & r[pred_col]) > 0, axis=1)
    gt_df["cor1"] = gt_df.apply(lambda r: r[true_col] == r[pred_col], axis=1)
    sample_acc = accuracy_score(gt_df["hit1"], [True]*len(gt_df))
    abs_acc    = accuracy_score(gt_df["cor1"], [True]*len(gt_df))
    incl_acc = accuracy_score(
        gt_df.apply(lambda r: r[true_col].issubset(r[pred_col]), axis=1),
        [True]*len(gt_df)
    )
    
    # 2. Binary indicator matrices
    labels = sorted({p for s in gt_sets for p in s} | {p for s in pred_sets for p in s})
    def to_binary(series):
        return pd.DataFrame([[1 if l in s else 0 for l in labels] for s in series], columns=labels)
    y_true_bin = to_binary(gt_sets)
    y_pred_bin = to_binary(pred_sets)
    
    # 3. Micro P/R/F1
    prec, rec, f1, _ = precision_recall_fscore_support(
        y_true_bin.values, y_pred_bin.values, average='micro', zero_division=0
    )
    
    # 4. Jaccard (sample average)
    jaccard = jaccard_score(y_true_bin.values, y_pred_bin.values, average='samples', zero_division=0)
    
    # 5. Hamming Loss/Score
    hloss = hamming_loss(y_true_bin.values, y_pred_bin.values)
    hscore = 1 - hloss
    
    # 6. Print results
    print("=== Extended Metrics ===")
    print(f"Sample-level Accuracy  : {sample_acc:.3f}")
    print(f"Absolute Exact-match   : {abs_acc:.3f}")
    print(f"Inclusion Accuracy     : {incl_acc:.3f}")
    print(f"Micro Precision        : {prec:.3f}")
    print(f"Micro Recall           : {rec:.3f}")
    print(f"Micro F1-score         : {f1:.3f}")
    print(f"Average Jaccard Score  : {jaccard:.3f}")
    print(f"Hamming Loss           : {hloss:.3f}")
    print(f"Hamming Score (1-Loss) : {hscore:.3f}")

In [41]:
evaluate_with_hamming_jaccard(gt_df, true_col='gt_1digit', pred_col='pred_1digit')

=== Extended Metrics ===
Sample-level Accuracy  : 0.970
Absolute Exact-match   : 0.404
Inclusion Accuracy     : 0.476
Micro Precision        : 0.898
Micro Recall           : 0.674
Micro F1-score         : 0.770
Average Jaccard Score  : 0.701
Hamming Loss           : 0.115
Hamming Score (1-Loss) : 0.885


In [42]:
evaluate_with_hamming_jaccard(gt_df, true_col='gt_2digit', pred_col='pred_2digit')

=== Extended Metrics ===
Sample-level Accuracy  : 0.942
Absolute Exact-match   : 0.234
Inclusion Accuracy     : 0.301
Micro Precision        : 0.844
Micro Recall           : 0.572
Micro F1-score         : 0.682
Average Jaccard Score  : 0.583
Hamming Loss           : 0.056
Hamming Score (1-Loss) : 0.944


In [43]:
evaluate_with_hamming_jaccard(gt_df, true_col='code_prefix', pred_col='pred_set')

=== Extended Metrics ===
Sample-level Accuracy  : 0.914
Absolute Exact-match   : 0.160
Inclusion Accuracy     : 0.227
Micro Precision        : 0.745
Micro Recall           : 0.511
Micro F1-score         : 0.607
Average Jaccard Score  : 0.496
Hamming Loss           : 0.017
Hamming Score (1-Loss) : 0.983


In [44]:
def evaluate_prefix_metrics(df, true_col, pred_col, prefix_lens=[1,2,4]):
    """
    Computes multi-label metrics for different prefix lengths.
    
    Parameters:
    - df: DataFrame containing true and predicted code lists.
    - true_col: column name for ground-truth code lists.
    - pred_col: column name for predicted code lists.
    - prefix_lens: list of prefix lengths to evaluate (e.g., [1,2,4]).
    
    Returns:
    - DataFrame where rows are prefix lengths and columns are metrics.
    """
    results = []
    for L in prefix_lens:
        # Build sets of prefixes
        gt_sets = df[true_col].apply(lambda lst: {str(c)[:L] for c in (lst or [])})
        pred_sets = df[pred_col].apply(lambda lst: {str(c)[:L] for c in (lst or [])})
        
        # Sample‐level hit: any intersection
        hits = [len(gt & pr) > 0 for gt, pr in zip(gt_sets, pred_sets)]
        sample_acc = accuracy_score(hits, [True]*len(df))
        
        # Absolute exact‐match
        exacts = [gt == pr for gt, pr in zip(gt_sets, pred_sets)]
        abs_acc = accuracy_score(exacts, [True]*len(df))
        
        # Inclusion: gt subset of pred
        inclusions = [gt.issubset(pr) for gt, pr in zip(gt_sets, pred_sets)]
        incl_acc = accuracy_score(inclusions, [True]*len(df))
        
        # Prepare binary indicator matrices
        labels = sorted(set().union(*gt_sets, *pred_sets))
        y_true = pd.DataFrame([[1 if lab in s else 0 for lab in labels] for s in gt_sets], columns=labels)
        y_pred = pd.DataFrame([[1 if lab in s else 0 for lab in labels] for s in pred_sets], columns=labels)
        
        # Micro precision/recall/f1
        prec, rec, f1, _ = precision_recall_fscore_support(
            y_true.values, y_pred.values, average='micro', zero_division=0
        )
        
        # Jaccard (samples)
        jaccard = jaccard_score(y_true.values, y_pred.values, average='samples', zero_division=0)
        
        # Hamming loss and score
        hloss = hamming_loss(y_true.values, y_pred.values)
        hscore = 1 - hloss
        
        # Collect
        results.append({
            'prefix_len': L,
            'sample_acc': sample_acc,
            'exact_acc': abs_acc,
            'inclusion_acc': incl_acc,
            'micro_prec': prec,
            'micro_rec': rec,
            'micro_f1': f1,
            'jaccard': jaccard,
            'hamming_loss': hloss,
            'hamming_score': hscore
        })
    
    return pd.DataFrame(results).set_index('prefix_len')


In [45]:
evaluate_prefix_metrics(gt_df, 'code_prefix', 'pred_set', prefix_lens=[1,2,4])

Unnamed: 0_level_0,sample_acc,exact_acc,inclusion_acc,micro_prec,micro_rec,micro_f1,jaccard,hamming_loss,hamming_score
prefix_len,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1
1,0.969533,0.404198,0.475965,0.898174,0.674045,0.770134,0.701404,0.114872,0.885128
2,0.942451,0.233582,0.300609,0.843697,0.57173,0.681585,0.583358,0.055538,0.944462
4,0.914015,0.16046,0.226811,0.745198,0.511349,0.606513,0.496157,0.016672,0.983328


In [46]:
gt_df

Unnamed: 0.1,Unnamed: 0,내원일시_환자번호,code,reading,code_prefix,pred_set,hit,cor,gt_2digit,pred_2digit,gt_1digit,pred_1digit,hit2,cor2,hit1,cor1
0,0,20220102_10501227,"[1406212, 1406942, 1102021, 2512312, 2510001, ...",CT:Chest Routine (CE): [Conclusion]\n<Contrast...,"{2106, 1102, 2510, 5102, 1406, 2512, 2102}","{1406, 2510}",True,False,"{11, 14, 51, 21, 25}","{25, 14}","{1, 2, 5}","{1, 2}",True,False,True,False
1,1,20220103_10501467,"[8561634, 5450992]",CT:Chest Routine (CE): [Conclusion]\nEnhanced ...,"{5450, 8561}",{8561},True,False,"{85, 54}",{85},"{8, 5}",{8},True,False,True,False
2,2,20220104_7955334,"[4502033, 4220062, 7102021, 8102021, 4422053]",CT:3D-C-Spine: [Conclusion]\nC-spine CT 3D\n\n...,"{7102, 4220, 4502, 4422, 8102}","{4502, 2510, 4414, 4422}",True,False,"{71, 42, 44, 45, 81}","{25, 44, 45}","{8, 4, 7}","{2, 4}",True,False,True,False
3,3,20220105_10501812,"[6506202, 6506202, 2102021, 2102021, 8106021, ...",CT:Brain : [Conclusion]\nno evidence of intrac...,"{6506, 8106, 2102}","{6506, 8561}",True,False,"{81, 65, 21}","{65, 85}","{8, 2, 6}","{8, 6}",True,False,True,False
4,4,20220105_456384,"[1406565, 1406143, 1406942]",CT:Angio 3D Brain(CE): [Conclusion]\n1. acute ...,{1406},{1406},True,True,{14},{14},{1},{1},True,True,True,True
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
1472,1472,20241228_7692905,"[1104021, 1106021]",CT:Brain : [Conclusion]\n1. no evidence of abn...,"{1106, 1104}",{6402},False,False,{11},{64},{1},{6},False,False,False,False
1473,1473,20241229_10598701,"[5410102, 5428122, 5420102]",CT:Abdomen & Pelvis (Dynamic) (CE): [Conclusio...,"{5428, 5410, 5420}","{5420, 5418}",True,False,{54},{54},{5},{5},True,True,True,False
1474,1474,20241230_10598715,"[1406942, 1504022, 4502033, 7506612, 5214023, ...",CT:3D-T-L Spine: [Conclusion]\nT7~9 right tran...,"{1504, 5214, 1406, 4502, 2512, 2508, 7506, 1502}","{1504, 1406, 2512, 2508, 7506, 1502}",True,False,"{75, 45, 14, 15, 52, 25}","{25, 75, 14, 15}","{1, 2, 4, 5, 7}","{1, 2, 7}",True,False,True,False
1475,1475,20241230_10598761,"[2106021, 4422034, 7506512, 8102021]",CT:3D-Shoulder (Left): [Conclusion]\nFx. shaft...,"{2106, 7506, 8102, 4422}","{7506, 4422}",True,False,"{81, 75, 44, 21}","{75, 44}","{8, 2, 4, 7}","{4, 7}",True,False,True,False


In [47]:

# 1) 이미 gt_df에 gt_1digit, pred_1digit 컬럼이 set 형태로 있다고 가정
labels = sorted({p for s in gt_df["gt_1digit"] for p in s} |
                {p for s in gt_df["pred_1digit"] for p in s})

# 2) set → 바이너리 매트릭스로 변환
def to_binary(series, label_list):
    return pd.DataFrame(
        [[1 if l in s else 0 for l in label_list] for s in series],
        columns=label_list
    )

y_true_bin = to_binary(gt_df["gt_1digit"], labels)
y_pred_bin = to_binary(gt_df["pred_1digit"], labels)

# 3) per-label precision, recall, f1, support
prec, rec, f1, support = precision_recall_fscore_support(
    y_true_bin.values, y_pred_bin.values,
    average=None, zero_division=0
)

per_label_metrics = pd.DataFrame({
    'precision': prec,
    'recall':    rec,
    'f1-score':  f1,
    'support':   support
}, index=labels)

# 4) per-label accuracy 추가
accuracy = (y_true_bin == y_pred_bin).mean(axis=0)
per_label_metrics['accuracy'] = accuracy.values

In [48]:
label_map = {
    '1': 'HEAD',
    '2': 'FACE',
    '3': 'NECK',
    '4': 'THORAX',
    '5': 'ABDOMEN',
    '6': 'SPINE',
    '7': 'UPPER EXTERMITY',
    '8': 'LOWER EXTERMITY',
    '9': 'EXTERNAL'
}


In [49]:
per_label_metrics.index = per_label_metrics.index.astype(str)
per_label_metrics_renamed = per_label_metrics.rename(index=label_map)

In [50]:
per_label_metrics_renamed

Unnamed: 0,precision,recall,f1-score,support,accuracy
HEAD,0.966495,0.663717,0.786988,565,0.862559
FACE,0.754941,0.519022,0.615137,368,0.838186
NECK,0.833333,0.3125,0.454545,32,0.983751
THORAX,0.876977,0.800963,0.837248,623,0.868653
ABDOMEN,0.930736,0.590659,0.722689,364,0.888287
SPINE,0.849687,0.875269,0.862288,465,0.911984
UPPER EXTERMITY,0.916914,0.506557,0.652587,610,0.777251
LOWER EXTERMITY,0.958261,0.726913,0.826707,758,0.843602
EXTERNAL,0.25,0.1,0.142857,10,0.991875


In [51]:
df = pd.read_csv("AIS_2005_depth_full_desc.csv", encoding="cp949")
df["code"] = df["code"].astype(str).str.zfill(7)  # zero-pad

In [52]:
# -----------------------------------------------------------
# 1. 4-자리 prefix → ‘Region-Type-Specific’ 한글 설명
# -----------------------------------------------------------
def describe_prefix(code4: str) -> str:
    """예) 4422 → 'Thorax-Internal Organs-Thoracic injury'"""
    r, t, spec = code4[0], code4[1], code4[2:]
    region = lvl_map[(1, ())].get(r, "UNKNOWN")
    type_  = lvl_map.get((2, (r,)), {}).get(t, "UNKNOWN")
    spec_  = lvl_map.get((3, (r, t)), {}).get(spec, "UNKNOWN")
    return f"{region} -> {type_} -> {spec_}"

In [53]:
# -----------------------------------------------------------
# 2. 4-자리 prefix → 가능한 7-자리 전체 코드 및 상세 설명
# -----------------------------------------------------------
def list_full_codes(prefix4: str, max_show: int = 10):
    """prefix4로 시작하는 7-digit AIS 코드 목록 (상위 max_show개)"""
    matches = df[df["code"].str.startswith(prefix4)]
    if matches.empty:
        return ["  ‣ (후보 코드 없음)"]
    else:
        lines = []
        for _, row in matches.head(max_show).iterrows():
            lines.append(f"  ‣ {row['code']}: {row['full_description']}")
        if len(matches) > max_show:
            lines.append(f"  ‣ … (외 {len(matches)-max_show}개 더 있음)")
        return lines

In [54]:
# -----------------------------------------------------------
# 3. 새 예측 함수
# -----------------------------------------------------------
@torch.inference_mode()
def predict_codes_verbose(report: str,
                          threshold: float = THRESHOLD,
                          k: int = 2) -> str:
    """
    • 4-자리 코드(확률)를 예측
    • 각 코드에 대해 Region-Type-Spec 설명 + 이어질 수 있는 7-자리 후보 코드 목록 출력
    """
    # ── 1) 토큰화 + 모델 forward ───────────────────────────
    prompt  = build_prompt(report)
    inputs  = tokenizer(prompt,
                        truncation=True,
                        padding="max_length",
                        max_length=MAX_LEN,
                        return_tensors="pt").to(DEVICE)

    logits  = model(**inputs).logits      # [1, num_labels]
    probs   = torch.sigmoid(logits)[0].cpu()

    # ── 2) threshold 적용 (없으면 상위 k) ──────────────────
    idxs = (probs >= threshold).nonzero(as_tuple=True)[0].tolist()
    if not idxs:
        idxs = probs.topk(k).indices.tolist()

    # ── 3) 예측 코드별 상세 출력 구축 ─────────────────────
    lines = []
    for i in idxs:
        code4 = idx2code[i]          # 4-자리 prefix
        p     = float(probs[i])
        desc  = describe_prefix(code4)
        lines.append(f"\nPredicted 4-digit  {code4}  (p={p:.2f})")
        lines.append(f" → {desc}")
        lines.extend(list_full_codes(code4))   # 7-자리 후보 코드 나열

    return "\n".join(lines)

In [58]:
report_txt = """
외부판독 BRAIN CT : Fracture, left parietal bone.

Traumatic SAH on the right.
외부판독 CHEST CT : * Limited evaluation due to motion artifact.

1. Fracture in the right clavicle.
2. r/o buckled fx in the Rt 5-6th rib.
3. Suspicious ill-defined ground glass opacities in the right middle lobe, possible mild lung contusion D/Dx) mild infectious/inflammatory process.
외부판독 ABDMINOPELVIC CT : No evidence of injury in abdomen.
외부판독 SPINE CT : No definite acute fracture of dislocation at the C-spine.
Small OPLL at C3-C6 level. 
"""
print(predict_codes_verbose(report_txt, threshold=0.3))


Predicted 4-digit  1406  (p=0.98)
 → HEAD -> INTERNAL ORGANS -> Cerebrum(includes basal ganglia, thalamus, putamen, globus pallidius)
  ‣ 1406993: Cerebrum NFS [includes basal ganglia, thalamus, putamen, globus pallidius]
  ‣ 1406023: Cerebrum contusion NFS [include perilesional edema for size]
  ‣ 1406043: Cerebrum contusion single NFS
  ‣ 1406052: Cerebrum contusion single tiny; <1cm diameter
  ‣ 1406063: Cerebrum contusion single small; superficial; ≤30cc or ≤15cc if ≤age 10; 1-4cm diameter or 1-2cm diameter if ≤age 10; midline shift ≤5mm
  ‣ 1406084: Cerebrum contusion single large; deep; 30-50cc or 15-30cc if ≤age 10; >4cm diameter or 2-4cm diameter if ≤age 10; midline shift >5mm
  ‣ 1406105: Cerebrum contusion single extensive; massive; total volume >50cc or >30cc if ≤age 10
  ‣ 1406113: Cerebrum contusion multiple NFS
  ‣ 1406123: Cerebrum contusion multiple, on same side but NFS
  ‣ 1406132: Cerebrum contusion multiple, on same side but tiny; each <1cm diameter
  ‣ … (외 63개 더 