In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import precision_recall_curve, f1_score
import numpy as np
import logging
from torch.utils.data import TensorDataset, DataLoader

In [2]:
class MLP(nn.Module):
    def __init__(
            self,
            input_size=512,
            num_classes=18,
            activation='relu',
            hidden_sizes=[1024, 2048, 1024, 256, 128],
            dropout=0.1
        ):
        super().__init__()
        
        # Pick activation
        if activation == "relu":
            activation_cls = nn.ReLU
        elif activation == "leaky_relu":
            activation_cls = nn.LeakyReLU
        elif activation == "gelu":
            activation_cls = nn.GELU
        else:
            raise ValueError(f"Unsupported activation: {activation}")

        layers = []
        in_dim = input_size
        for h in hidden_sizes:
            layers.append(nn.Linear(in_dim, h))
            layers.append(nn.BatchNorm1d(h))  # helps stabilize
            layers.append(activation_cls())
            layers.append(nn.Dropout(dropout))
            in_dim = h

        # Final classification layer
        layers.append(nn.Linear(in_dim, num_classes))

        self.layers = nn.Sequential(*layers)

    def forward(self, x):
        return self.layers(x)

In [3]:
MODEL_CKPT = '/home/free4ky/projects/chest-diseases/model_multilabel20_pr034_rec0777.pth'

In [4]:
def _load_concat_dataset(embed_paths, label_paths):
    """
    Load multiple .pt files and concatenate along dim 0
    """
    X_list = [torch.load(p) for p in embed_paths]
    y_list = [torch.load(p) for p in label_paths]
    
    X = torch.cat(X_list, dim=0)
    y = torch.cat(y_list, dim=0)
    
    return TensorDataset(X, y)


In [6]:

logger = logging.getLogger(__name__)

# -------------------------
# 1. Load model
# -------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
# input_size: 512
# batch_size: 2048
# activation: leaky_relu
# dropout: 0.2
# hidden_sizes:
# - 512
# - 256
# - 128
model = MLP(
    input_size=512,
    activation='leaky_relu',
    dropout=0.2,
    num_classes=20,  # e.g., 20
    hidden_sizes=[512,256,128]
)

state_dict = torch.load(MODEL_CKPT, map_location=device)
model.load_state_dict(state_dict)
model.to(device).eval()
print("Model loaded successfully")

# -------------------------
# 2. Load dataset
# -------------------------
val_ds = _load_concat_dataset(
['/home/free4ky/projects/chest-diseases/data/preprocessed_mosmed/test_data.pt',
 '/home/free4ky/projects/chest-diseases/data/preprocessed_val_20/validation_data.pt'
],
[
'/home/free4ky/projects/chest-diseases/data/preprocessed_mosmed/test_labels.pt',
'/home/free4ky/projects/chest-diseases/data/preprocessed_val_20/validation_labels.pt'
]
)
val_dl = DataLoader(val_ds, batch_size=1, shuffle=False)

# -------------------------
# 3. Inference
# -------------------------
print("Computing probabilities on validation set...")
all_probs = []
all_labels = []

with torch.no_grad():
    for emb, y in tqdm(val_dl):
        emb = torch.nn.functional.normalize(emb, dim=-1)
        logits = model(emb.to(device))
        probs = torch.sigmoid(logits)  # multilabel, sigmoid per class
        all_probs.append(probs.cpu())
        all_labels.append(y)

all_probs = torch.cat(all_probs, dim=0).numpy()  # shape [num_samples, num_classes]
all_labels = torch.cat(all_labels, dim=0).numpy()

# -------------------------
# 4. Compute best threshold per class
# -------------------------
best_thresholds = []

for i in range(all_labels.shape[1]):
    y_true = all_labels[:, i]
    y_prob = all_probs[:, i]

    precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
    f1_scores = 2 * precision * recall / (precision + recall + 1e-8)
    best_idx = np.argmax(f1_scores)
    best_threshold = thresholds[best_idx] if best_idx < len(thresholds) else 0.5
    best_thresholds.append(best_threshold)
    print(f"Class {i}: Best threshold={best_threshold:.4f}, F1={f1_scores[best_idx]:.4f}")

best_thresholds = np.array(best_thresholds)
print(f"Optimal thresholds per class: {best_thresholds}")


Model loaded successfully
Computing probabilities on validation set...


  0%|          | 0/3133 [00:00<?, ?it/s]

100%|██████████| 3133/3133 [00:00<00:00, 4584.39it/s]

Class 0: Best threshold=0.7061, F1=0.3781
Class 1: Best threshold=0.5387, F1=0.6974
Class 2: Best threshold=0.8773, F1=0.5885
Class 3: Best threshold=0.8802, F1=0.3786
Class 4: Best threshold=0.6020, F1=0.6442
Class 5: Best threshold=0.6131, F1=0.3719
Class 6: Best threshold=0.5477, F1=0.5210
Class 7: Best threshold=0.4760, F1=0.4824
Class 8: Best threshold=0.5168, F1=0.4488
Class 9: Best threshold=0.3881, F1=0.6346
Class 10: Best threshold=0.4427, F1=0.6290
Class 11: Best threshold=0.4749, F1=0.4716
Class 12: Best threshold=0.6319, F1=0.6395
Class 13: Best threshold=0.8423, F1=0.3762
Class 14: Best threshold=0.7366, F1=0.3939
Class 15: Best threshold=0.6886, F1=0.5043
Class 16: Best threshold=0.6208, F1=0.3194
Class 17: Best threshold=0.7119, F1=0.3768
Class 18: Best threshold=0.9429, F1=0.6076
Class 19: Best threshold=0.9712, F1=0.6111
Optimal thresholds per class: [0.70609677 0.53872406 0.87727875 0.88017046 0.60204166 0.613144
 0.5476915  0.47602674 0.5168072  0.38809362 0.44271022




In [8]:
best_thresholds

array([0.70609677, 0.53872406, 0.87727875, 0.88017046, 0.60204166,
       0.613144  , 0.5476915 , 0.47602674, 0.5168072 , 0.38809362,
       0.44271022, 0.47488382, 0.6318726 , 0.8423309 , 0.7365511 ,
       0.6886128 , 0.6208342 , 0.711936  , 0.94288176, 0.9711801 ],
      dtype=float32)

In [7]:
from sklearn.metrics import (
    precision_recall_curve,
    auc,
    f1_score,
    precision_score,
    recall_score,
    accuracy_score,
    roc_auc_score
)
# -------------------------
# 5. Binarize predictions using best thresholds
# -------------------------
binary_preds = (all_probs > best_thresholds).astype(int)

# -------------------------
# 6. Compute metrics
# -------------------------
# Macro F1
f1_macro = f1_score(all_labels, binary_preds, average="macro")
# Micro F1
f1_micro = f1_score(all_labels, binary_preds, average="micro")
# Per-class F1
f1_per_class = f1_score(all_labels, binary_preds, average=None)

# Precision & Recall (macro)
precision_macro = precision_score(all_labels, binary_preds, average="macro")
recall_macro = recall_score(all_labels, binary_preds, average="macro")

# Accuracy (sample-level)
accuracy = accuracy_score(all_labels, binary_preds)

# AUROC
try:
    auroc_macro = roc_auc_score(all_labels, all_probs, average="macro")
    auroc_per_class = roc_auc_score(all_labels, all_probs, average=None)
except ValueError:
    auroc_macro = None
    auroc_per_class = None

# -------------------------
# AU-PR per class and macro
# -------------------------
aupr_per_class = []

for i in range(all_labels.shape[1]):
    y_true = all_labels[:, i]
    y_prob = all_probs[:, i]
    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    # recall decreases, so reverse both arrays to make x-axis increasing
    aupr_per_class.append(auc(recall[::-1], precision[::-1]))

aupr_per_class = np.array(aupr_per_class)

# Macro AU-PR: flatten all classes
precision_flat, recall_flat, _ = precision_recall_curve(
    all_labels.flatten(), all_probs.flatten()
)
aupr_macro = auc(recall_flat[::-1], precision_flat[::-1])

# -------------------------
# 7. Print metrics
# -------------------------
print(f"Macro F1: {f1_macro:.4f}, Micro F1: {f1_micro:.4f}")
print(f"Precision (macro): {precision_macro:.4f}, Recall (macro): {recall_macro:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print(f"AUROC macro: {auroc_macro}")
print(f"AU-PR macro: {aupr_macro}")
print(f"Per-class F1: {f1_per_class}")
print(f"Per-class AUROC: {auroc_per_class}")
print(f"Per-class AU-PR: {aupr_per_class}")

Macro F1: 0.5008, Micro F1: 0.5252
Precision (macro): 0.4087, Recall (macro): 0.6692
Accuracy: 0.0753
AUROC macro: 0.8040248288977032
AU-PR macro: 0.38417851807324566
Per-class F1: [0.37653737 0.69681104 0.58679707 0.37614679 0.64343164 0.37055418
 0.5202934  0.48150256 0.44804866 0.63427242 0.62849258 0.47098976
 0.63796909 0.37396122 0.39240506 0.50326323 0.31789282 0.37485172
 0.58974359 0.5915493 ]
Per-class AUROC: [0.75454648 0.87355278 0.91557199 0.8004268  0.86182764 0.74423348
 0.72725877 0.77809975 0.6984549  0.67501074 0.73360559 0.6638663
 0.92844327 0.82221056 0.77672761 0.78473298 0.73475929 0.82065572
 0.99334527 0.99316668]
Per-class AU-PR: [0.25831469 0.67865456 0.53710951 0.23486803 0.61553397 0.28869146
 0.45929579 0.47251787 0.39937928 0.60204746 0.61789258 0.39263981
 0.6584722  0.27616092 0.28926137 0.45531531 0.22114663 0.26570183
 0.58131068 0.61506464]


In [14]:
binary_preds

array([[0, 1, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 1, 0],
       [0, 0, 0, ..., 0, 1, 1],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 1, 1, ..., 1, 0, 0],
       [0, 1, 0, ..., 0, 0, 0]], shape=(3133, 20))

In [15]:
all_probs

array([[0.21652523, 0.8563636 , 0.02323474, ..., 0.02188679, 0.08152154,
        0.13146359],
       [0.01217982, 0.01380807, 0.02129888, ..., 0.01234266, 0.9964037 ,
        0.51302296],
       [0.01879999, 0.01722941, 0.0245573 , ..., 0.01340362, 0.98717   ,
        0.98037535],
       ...,
       [0.01326523, 0.0242513 , 0.00580554, ..., 0.00947404, 0.03806407,
        0.0230825 ],
       [0.1676907 , 0.54416275, 0.8802362 , ..., 0.68881893, 0.03004128,
        0.01985495],
       [0.0692402 , 0.46907988, 0.49736205, ..., 0.5212696 , 0.02730496,
        0.00875787]], shape=(3133, 20), dtype=float32)

In [38]:
import numpy as np
from sklearn.metrics import roc_curve

def compute_youden_thresholds(y_true: np.ndarray, y_probs: np.ndarray):
    """
    Compute per-class thresholds using Youden's J statistic.
    
    Args:
        y_true (np.ndarray): shape [num_samples, num_classes], binary labels
        y_probs (np.ndarray): shape [num_samples, num_classes], predicted probabilities

    Returns:
        np.ndarray: per-class thresholds
    """
    num_classes = y_true.shape[1]
    thresholds = np.zeros(num_classes, dtype=float)

    for i in range(num_classes):
        y_c = y_true[:, i]
        p_c = y_probs[:, i]

        if np.unique(y_c).size < 2:
            # Class not present in validation, fallback
            thresholds[i] = 0.5
            continue

        fpr, tpr, thr = roc_curve(y_c, p_c)
        j_scores = tpr - fpr
        best_idx = np.argmax(j_scores)
        thresholds[i] = thr[best_idx]

    return thresholds


In [39]:
best_thresholds = compute_youden_thresholds(all_labels, all_probs)
binary_preds = (all_probs > best_thresholds).astype(int)

In [40]:
from sklearn.metrics import (
    precision_recall_curve,
    auc,
    f1_score,
    precision_score,
    recall_score,
    accuracy_score,
    roc_auc_score
)
# -------------------------
# 5. Binarize predictions using best thresholds
# -------------------------
binary_preds = (all_probs > best_thresholds).astype(int)

# -------------------------
# 6. Compute metrics
# -------------------------
# Macro F1
f1_macro = f1_score(all_labels, binary_preds, average="macro")
# Micro F1
f1_micro = f1_score(all_labels, binary_preds, average="micro")
# Per-class F1
f1_per_class = f1_score(all_labels, binary_preds, average=None)

# Precision & Recall (macro)
precision_macro = precision_score(all_labels, binary_preds, average="macro")
recall_macro = recall_score(all_labels, binary_preds, average="macro")

# Accuracy (sample-level)
accuracy = accuracy_score(all_labels, binary_preds)

# AUROC
try:
    auroc_macro = roc_auc_score(all_labels, all_probs, average="macro")
    auroc_per_class = roc_auc_score(all_labels, all_probs, average=None)
except ValueError:
    auroc_macro = None
    auroc_per_class = None

# -------------------------
# AU-PR per class and macro
# -------------------------
aupr_per_class = []

for i in range(all_labels.shape[1]):
    y_true = all_labels[:, i]
    y_prob = all_probs[:, i]
    precision, recall, _ = precision_recall_curve(y_true, y_prob)
    # recall decreases, so reverse both arrays to make x-axis increasing
    aupr_per_class.append(auc(recall[::-1], precision[::-1]))

aupr_per_class = np.array(aupr_per_class)

# Macro AU-PR: flatten all classes
precision_flat, recall_flat, _ = precision_recall_curve(
    all_labels.flatten(), all_probs.flatten()
)
aupr_macro = auc(recall_flat[::-1], precision_flat[::-1])

# -------------------------
# 7. Print metrics
# -------------------------
print(f"Macro F1: {f1_macro:.4f}, Micro F1: {f1_micro:.4f}")
print(f"Precision (macro): {precision_macro:.4f}, Recall (macro): {recall_macro:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print(f"AUROC macro: {auroc_macro}")
print(f"AU-PR macro: {aupr_macro}")
print(f"Per-class F1: {f1_per_class}")
print(f"Per-class AUROC: {auroc_per_class}")
print(f"Per-class AU-PR: {aupr_per_class}")

Macro F1: 0.4560, Micro F1: 0.4612
Precision (macro): 0.3473, Recall (macro): 0.7485
Accuracy: 0.0361
AUROC macro: 0.7806986463679311
AU-PR macro: 0.41625180442996307
Per-class F1: [0.35174954 0.6819222  0.51056015 0.32933479 0.63254862 0.352473
 0.49455984 0.45404814 0.42950326 0.59057239 0.57673267 0.44187009
 0.59513075 0.29705882 0.29417773 0.48764629 0.24307837 0.29002154
 0.60504202 0.46153846]
Per-class AUROC: [0.75064464 0.86808149 0.90605194 0.77552201 0.84996053 0.73821267
 0.69849185 0.73867219 0.66306608 0.65426846 0.70846969 0.6184885
 0.92538625 0.7885101  0.72387978 0.7529817  0.65438221 0.81139259
 0.99248338 0.99502686]
Per-class AU-PR: [0.26142292 0.66527927 0.49897836 0.28077083 0.57010477 0.2916531
 0.4108997  0.42890124 0.32771492 0.57235455 0.57916542 0.35877698
 0.68927743 0.267614   0.235799   0.39422172 0.17461135 0.2542383
 0.68605984 0.69032716]
