In [1]:
# MODULE 0: Imports & Global Config

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from pathlib import Path
from glob import glob
import json
import warnings

# Only silence library FutureWarning noise. A blanket "ignore" also hid real
# problems such as pandas SettingWithCopyWarning and divide-by-zero warnings.
warnings.filterwarnings("ignore", category=FutureWarning)

SEED = 42
np.random.seed(SEED)  # legacy global RNG; split_data passes its own random_state

DATA_PATH = Path("../data")


In [9]:
# MODULE 1: Load & Clean Dataset

def load_nih_dataset(data_path=None, image_base_path="E:/archive"):
    """Load the NIH ChestX-ray14 index CSV and attach each image's file path.

    Parameters
    ----------
    data_path : str or pathlib.Path, optional
        Project data root containing ``raw/Data_Entry_2017.csv``. ``None``
        (the default) means the notebook-level ``DATA_PATH``, resolved at
        call time.
    image_base_path : str or pathlib.Path, optional
        Directory searched recursively for ``*.png`` files.
        NOTE(review): the default is a machine-specific absolute path;
        consider moving it to the config cell.

    Returns
    -------
    pandas.DataFrame
        Columns ``Image Index``, ``Finding Labels`` and ``full_path``.
        Rows whose image file was not found are kept, with ``full_path`` NaN.
    """
    data_path = DATA_PATH if data_path is None else Path(data_path)

    # Only the first two CSV columns are used. Reading them by position keeps
    # the previous positional-rename semantics without hard-coding the total
    # column count (including the trailing unused one). It also avoids later
    # assigning into a column subset of a larger frame.
    df = pd.read_csv(data_path / "raw" / "Data_Entry_2017.csv", usecols=[0, 1])
    df.columns = ["Image Index", "Finding Labels"]

    # Map bare file name -> full path for every PNG under the image root.
    png_files = glob(str(Path(image_base_path) / "**" / "*.png"), recursive=True)
    image_paths = {Path(p).name: p for p in png_files}
    if len(image_paths) < len(png_files):
        # The same file name exists in several sub-folders; the last one globbed wins.
        print(
            f"Warning: {len(png_files) - len(image_paths):,} duplicate image file names; "
            "keeping the last path found for each"
        )

    df["full_path"] = df["Image Index"].map(image_paths)

    n_total = len(df)
    found_count = int(df["full_path"].notna().sum())
    # Guard against ZeroDivisionError when the CSV has a header but no rows.
    found_pct = found_count / n_total * 100 if n_total else 0.0
    print(f"Loaded {n_total:,} samples")
    print(f"Found {found_count:,} images ({found_pct:.1f}%)")

    if found_count < n_total:
        missing_count = n_total - found_count
        print(f"Missing {missing_count:,} image paths")
        print(
            "First few missing image names:",
            df.loc[df["full_path"].isna(), "Image Index"].head().tolist(),
        )

    return df


df = load_nih_dataset()
# The file-name -> path lookup assumes one row per image name.
assert df["Image Index"].is_unique, "duplicate image names in Data_Entry_2017.csv"
df.head()

Loaded 112,120 samples
Found 112,120 images (100.0%)


Unnamed: 0,Image Index,Finding Labels,full_path
0,00000001_000.png,Cardiomegaly,E:\archive\images_001\images\00000001_000.png
1,00000001_001.png,Cardiomegaly|Emphysema,E:\archive\images_001\images\00000001_001.png
2,00000001_002.png,Cardiomegaly|Effusion,E:\archive\images_001\images\00000001_002.png
3,00000002_000.png,No Finding,E:\archive\images_001\images\00000002_000.png
4,00000003_000.png,Hernia,E:\archive\images_001\images\00000003_000.png


In [3]:
# MODULE 2: Label Analysis

def analyze_labels(df):
    """Count how many images carry each finding label.

    Labels are the ``|``-separated tokens of ``df["Finding Labels"]``, matched
    exactly after stripping whitespace. Missing (NaN) rows and empty tokens are
    ignored. A label repeated within one row is counted once for that row.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``"Finding Labels"`` column.

    Returns
    -------
    label_dist : pandas.DataFrame
        Columns ``label``, ``count`` (number of images containing the label)
        and ``percentage`` (``count`` as a percent of ``len(df)``). Rows are
        sorted by ``count`` descending, with ties in alphabetical order.
    all_labels : list of str
        Sorted unique labels.
    """
    # Exact token matching, rather than a word-boundary regex. This means
    # e.g. "Mass-like" is not counted as "Mass", and labels containing regex
    # metacharacters are handled safely.
    label_sets = (
        df["Finding Labels"]
        .dropna()
        .astype(str)
        .map(lambda row: {tok.strip() for tok in row.split("|")} - {""})
    )
    counts = pd.Series(
        [label for row_labels in label_sets for label in row_labels], dtype=object
    ).value_counts()
    all_labels = sorted(counts.index)

    label_dist = (
        pd.DataFrame(
            {"label": all_labels, "count": [int(counts[label]) for label in all_labels]}
        )
        .assign(percentage=lambda x: x["count"] / len(df) * 100)
        # Stable sort: equal counts keep alphabetical order, so the output is deterministic.
        .sort_values("count", ascending=False, kind="stable")
        .reset_index(drop=True)
    )

    print(f"Found {len(all_labels)} unique disease labels")
    return label_dist, all_labels


# Per-label image counts, plus the sorted label vocabulary reused by later cells.
label_dist, ALL_LABELS = analyze_labels(df=df)
label_dist.head(n=5)


Found 15 unique disease labels


Unnamed: 0,label,count,percentage
0,No Finding,60361,53.836068
1,Infiltration,19894,17.743489
2,Effusion,13317,11.877453
3,Atelectasis,11559,10.30949
4,Nodule,6331,5.646629


In [4]:
# MODULE 3: Visualization

def plot_label_distribution(label_dist, top_n=15, output_dir="../results/plots"):
    """Save a horizontal bar chart of the ``top_n`` most frequent labels.

    Parameters
    ----------
    label_dist : pandas.DataFrame
        Needs ``label`` and ``count`` columns. It is assumed to be sorted by
        ``count`` descending, as ``analyze_labels`` returns it.
    top_n : int, optional
        Number of leading rows to plot (default 15).
    output_dir : str or pathlib.Path, optional
        Directory for ``nih_label_distribution_top{top_n}.png``. It is
        created if missing.
    """
    # Reverse so the largest bar is drawn at the top of the horizontal chart.
    top = label_dist.head(top_n).iloc[::-1]

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(12, 7))
    ax.barh(top["label"], top["count"])
    ax.set(title=f"Top {top_n} Disease Labels", xlabel="Number of images", ylabel="Label")
    fig.tight_layout()
    fig.savefig(output_dir / f"nih_label_distribution_top{top_n}.png", dpi=150)
    plt.close(fig)


def analyze_multilabel(df, output_dir="../results/plots"):
    """Add a ``num_labels`` column and save a bar chart of labels per image.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``"Finding Labels"`` column of ``|``-separated labels.
    output_dir : str or pathlib.Path, optional
        Directory for ``nih_labels_per_image_distribution.png``. It is
        created if missing.

    Returns
    -------
    pandas.DataFrame
        A copy of ``df`` with an integer ``num_labels`` column (the number of
        non-empty tokens; a missing label string counts as 0). ``df`` itself
        is not modified.
    """
    num_labels = (
        df["Finding Labels"]
        .fillna("")
        .astype(str)
        .map(lambda row: sum(1 for tok in row.split("|") if tok.strip()))
    )
    result = df.assign(num_labels=num_labels)

    stats = result["num_labels"].value_counts().sort_index()

    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    fig, ax = plt.subplots(figsize=(8, 5))
    stats.plot(kind="bar", ax=ax)
    ax.set(
        title="Labels per Image Distribution",
        xlabel="Number of labels",
        ylabel="Number of images",
    )
    fig.tight_layout()
    fig.savefig(output_dir / "nih_labels_per_image_distribution.png", dpi=150)
    plt.close(fig)

    return result


# Both calls write PNGs under ../results/plots; only the second returns an updated df.
plot_label_distribution(label_dist=label_dist)
df = analyze_multilabel(df=df)


In [5]:
# MODULE 4: Multilabel Encoding

def encode_labels(df, labels):
    """Return a copy of ``df`` with one 0/1 indicator column per label.

    A row gets 1 for ``label`` when ``label`` is exactly one of the
    ``|``-separated, whitespace-stripped tokens in ``df["Finding Labels"]``.
    Missing (NaN) label strings encode as all zeros instead of raising.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain a ``"Finding Labels"`` column.
    labels : iterable of str
        Indicator columns to create, in order.

    Returns
    -------
    pandas.DataFrame
        A new frame; ``df`` is not modified. Existing columns with the same
        names are overwritten in the copy.
    """
    # Exact set membership, so e.g. "Mass" does not match a "Mass-like" token
    # and labels containing regex metacharacters are safe.
    label_sets = (
        df["Finding Labels"]
        .fillna("")
        .astype(str)
        .map(lambda row: {tok.strip() for tok in row.split("|")})
    )
    indicators = {
        label: label_sets.map(lambda present, label=label: label in present).astype(int)
        for label in labels
    }
    # ** unpacking accepts non-identifier keys such as "No Finding".
    return df.assign(**indicators)


df = encode_labels(df, ALL_LABELS)
# Sanity check: every image carries at least one label (possibly "No Finding").
assert df[ALL_LABELS].sum(axis=1).ge(1).all(), "some images have no encoded label"
df.head()


Unnamed: 0,Image Index,Finding Labels,full_path,num_labels,Atelectasis,Cardiomegaly,Consolidation,Edema,Effusion,Emphysema,Fibrosis,Hernia,Infiltration,Mass,No Finding,Nodule,Pleural_Thickening,Pneumonia,Pneumothorax
0,00000001_000.png,Cardiomegaly,E:\archive\images_001\images\00000001_000.png,1,0,1,0,0,0,0,0,0,0,0,0,0,0,0,0
1,00000001_001.png,Cardiomegaly|Emphysema,E:\archive\images_001\images\00000001_001.png,2,0,1,0,0,0,1,0,0,0,0,0,0,0,0,0
2,00000001_002.png,Cardiomegaly|Effusion,E:\archive\images_001\images\00000001_002.png,2,0,1,0,0,1,0,0,0,0,0,0,0,0,0,0
3,00000002_000.png,No Finding,E:\archive\images_001\images\00000002_000.png,1,0,0,0,0,0,0,0,0,0,0,1,0,0,0,0
4,00000003_000.png,Hernia,E:\archive\images_001\images\00000003_000.png,1,0,0,0,0,0,0,0,1,0,0,0,0,0,0,0


In [6]:
# MODULE 5: Label Mapping

def save_label_mapping(labels, data_path=None):
    """Write the label vocabulary and its label -> index mapping to disk.

    Writes ``nih_label_mapping.json`` (keys ``labels`` and ``label_to_index``)
    and ``nih_label_mapping.csv`` (columns ``label``, ``index``) into
    ``<data_path>/processed/metadata``, creating that directory if needed.

    Parameters
    ----------
    labels : iterable of str
        Labels in index order; must not contain duplicates.
    data_path : str or pathlib.Path, optional
        Project data root. ``None`` (the default) means the notebook-level
        ``DATA_PATH``.

    Raises
    ------
    ValueError
        If ``labels`` contains duplicates, which would make
        ``label_to_index`` ambiguous.
    """
    labels = list(labels)  # also makes tuples / pandas Index JSON-serialisable
    if len(set(labels)) != len(labels):
        raise ValueError("labels must be unique to build a label -> index mapping")

    base = DATA_PATH if data_path is None else Path(data_path)
    out_dir = base / "processed" / "metadata"
    out_dir.mkdir(parents=True, exist_ok=True)

    mapping = {
        "labels": labels,
        "label_to_index": {label: idx for idx, label in enumerate(labels)},
    }

    with open(out_dir / "nih_label_mapping.json", "w", encoding="utf-8") as f:
        json.dump(mapping, f, indent=2)

    pd.DataFrame(
        list(mapping["label_to_index"].items()), columns=["label", "index"]
    ).to_csv(out_dir / "nih_label_mapping.csv", index=False)

    print("Label mapping saved")


save_label_mapping(labels=ALL_LABELS)  # persist the vocabulary and its label -> index mapping


Label mapping saved


In [7]:
# MODULE 6: Class Weights

def compute_class_weights(df, labels, data_path=None):
    """Compute median-frequency class weights and save them as JSON and CSV.

    For each label, ``freq = positives / len(df)`` and
    ``weight = median(freq of labels with positives) / freq``. A label with no
    positive examples gets the neutral weight ``1.0``; dividing by zero would
    give ``inf``, which ``json`` writes as the invalid token ``Infinity``.

    NOTE(review): in this notebook the function is called on the full dataset
    before the train/val/test split, so test-set label frequencies leak into
    the weights. Consider calling it on the training split only.

    Parameters
    ----------
    df : pandas.DataFrame
        Holds a 0/1 indicator column for every entry of ``labels``.
    labels : iterable of str
        Indicator column names.
    data_path : str or pathlib.Path, optional
        Project data root. ``None`` (the default) means the notebook-level
        ``DATA_PATH``. Files go to ``<data_path>/processed/metadata``, which is
        created if missing.

    Returns
    -------
    dict[str, float]
        Label -> weight, as plain Python floats.

    Raises
    ------
    ValueError
        If ``df`` is empty or no label has any positive example.
    """
    labels = list(labels)
    if len(df) == 0:
        raise ValueError("cannot compute class weights on an empty DataFrame")

    freq = df[labels].sum() / len(df)
    observed = freq[freq > 0]
    if observed.empty:
        raise ValueError("no label has any positive example")
    median_freq = float(np.median(observed))

    class_weights = {
        label: (median_freq / float(f) if f > 0 else 1.0)
        for label, f in freq.items()
    }

    base = DATA_PATH if data_path is None else Path(data_path)
    out_dir = base / "processed" / "metadata"
    out_dir.mkdir(parents=True, exist_ok=True)

    with open(out_dir / "nih_class_weights.json", "w", encoding="utf-8") as f:
        json.dump(class_weights, f, indent=2)

    pd.DataFrame(
        list(class_weights.items()), columns=["label", "weight"]
    ).to_csv(out_dir / "nih_class_weights.csv", index=False)

    print("Class weights computed")
    return class_weights


# NOTE(review): `df` is still the full dataset here (the train/val/test split
# happens in a later cell), so these weights also reflect test-set label frequencies.
class_weights = compute_class_weights(df=df, labels=ALL_LABELS)


Class weights computed


In [8]:
# MODULE 7: Dataset Splitting

def split_data(df, data_path=None, train_frac=0.8, random_state=42):
    """Split ``df`` into train / val / test sets and save each as CSV.

    Test rows are the images listed in ``raw/test_list.txt``. Images listed in
    ``raw/train_val_list.txt`` are shuffled and then split by patient: about
    ``train_frac`` of the patients go to train and the rest to val. This keeps
    one patient's follow-up images from landing in both sets. The earlier
    image-level split leaked patients from train into val.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain an ``"Image Index"`` column.
    data_path : str or pathlib.Path, optional
        Project data root holding ``raw/`` and receiving
        ``processed/metadata/`` (created if missing). ``None`` (the default)
        means the notebook-level ``DATA_PATH``.
    train_frac : float, optional
        Fraction of train/val patients assigned to train (default 0.8).
    random_state : int, optional
        Seed for both the row shuffle and the patient shuffle (default 42).

    Returns
    -------
    tuple of pandas.DataFrame
        ``(train_df, val_df, test_df)``.
    """
    base = DATA_PATH if data_path is None else Path(data_path)
    raw_dir = base / "raw"
    # read_text closes the files, unlike the bare open(...).read() calls it replaces.
    train_val_ids = set((raw_dir / "train_val_list.txt").read_text().split())
    test_ids = set((raw_dir / "test_list.txt").read_text().split())

    train_val_df = df[df["Image Index"].isin(train_val_ids)].sample(
        frac=1, random_state=random_state
    )
    test_df = df[df["Image Index"].isin(test_ids)]

    # Image names look like "00000001_000.png". This assumes the prefix before
    # "_" is the patient id — TODO confirm against the CSV's Patient ID column.
    patients = train_val_df["Image Index"].str.split("_").str[0]
    patient_ids = np.array(sorted(patients.unique()))
    np.random.default_rng(random_state).shuffle(patient_ids)
    n_train_patients = int(train_frac * len(patient_ids))
    is_train = patients.isin(patient_ids[:n_train_patients].tolist())

    train_df = train_val_df[is_train]
    val_df = train_val_df[~is_train]

    out_dir = base / "processed" / "metadata"
    out_dir.mkdir(parents=True, exist_ok=True)
    train_df.to_csv(out_dir / "nih_train_data.csv", index=False)
    val_df.to_csv(out_dir / "nih_val_data.csv", index=False)
    test_df.to_csv(out_dir / "nih_test_data.csv", index=False)

    print("Train / Val / Test splits saved")
    return train_df, val_df, test_df


train_df, val_df, test_df = split_data(df)

# The three splits must never share an image.
assert not set(train_df["Image Index"]) & set(val_df["Image Index"]), "train/val overlap"
assert not set(train_df["Image Index"]) & set(test_df["Image Index"]), "train/test overlap"
assert not set(val_df["Image Index"]) & set(test_df["Image Index"]), "val/test overlap"


Train / Val / Test splits saved
