In [15]:
import os
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
import cv2
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import albumentations as A
from albumentations.pytorch import ToTensorV2
import timm
from sklearn.metrics import f1_score, roc_auc_score

In [2]:
# =====================
# CONFIG
# =====================
# Inference device: fall back to the CPU unless CUDA reports as available.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

# Edge length used by the eval-time resize, and the evaluation DataLoader batch size.
PATCH_SIZE, BATCH_SIZE = 256, 64

# Artifacts of the m0m1 classification run: the split table and the saved weights.
SPLIT_CSV_PATH = Path("/home/khdp-user/workspace/dataset/Models/m0m1_run_cls/dataset.csv")
BEST_MODEL_PATH = Path("/home/khdp-user/workspace/dataset/Models/m0m1_run_cls/best_model.pt")

In [3]:
class PatchClsDataset(Dataset):
    """Image-patch classification dataset backed by a DataFrame.

    Every row provides a ``"path"`` (image file read with OpenCV) and a
    ``"y"`` label that is returned as a ``torch.long`` tensor.
    """

    def __init__(self, df, transform=None):
        """Store the sample table and the optional transform.

        Args:
            df: DataFrame with ``"path"`` and ``"y"`` columns. Its index is
                reset so positional lookups 0..len-1 are always valid.
            transform: Optional callable invoked as ``transform(image=img)``
                that returns a mapping with an ``"image"`` entry.
        """
        self.df = df.reset_index(drop=True)
        self.transform = transform

    def __len__(self):
        """Return the number of samples (rows of the table)."""
        return len(self.df)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        Returns:
            The RGB image (transformed when a transform was given, otherwise
            the array produced by OpenCV) and the label as a long tensor.

        Raises:
            RuntimeError: If OpenCV could not read the image file.
        """
        row = self.df.iloc[idx]
        path = row["path"]

        img = cv2.imread(path)
        if img is None:
            # cv2.imread reports failure by returning None instead of raising.
            raise RuntimeError(f"Failed to read image: {path}")
        # OpenCV decodes to BGR channel order; convert to RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Compare against None explicitly: a transform object may define
        # __len__/__bool__ and evaluate falsy even though it must be applied.
        if self.transform is not None:
            img = self.transform(image=img)["image"]

        y = torch.tensor(row["y"]).long()
        return img, y

def get_eval_transform():
    """Build the deterministic evaluation-time preprocessing pipeline.

    The steps run in order: resize to PATCH_SIZE x PATCH_SIZE, normalize with
    the library's default arguments (none are passed here), then convert to a
    tensor via ToTensorV2.
    """
    steps = [
        A.Resize(PATCH_SIZE, PATCH_SIZE),
        A.Normalize(),
        ToTensorV2(),
    ]
    return A.Compose(steps)
def build_model():
    """Create the two-class ResNet-50 classifier without pretrained weights.

    Weights are expected to be loaded separately by the caller.
    """
    create_kwargs = {"pretrained": False, "num_classes": 2}
    return timm.create_model("resnet50", **create_kwargs)

In [4]:
# Rebuild the architecture, restore the saved weights, then move the model to
# the target device and switch it to inference mode.
# NOTE(review): torch.load unpickles the checkpoint file; only load
# checkpoints that come from a trusted source.
model = build_model()
model.load_state_dict(
    torch.load(BEST_MODEL_PATH, map_location=DEVICE)
)
model = model.to(DEVICE).eval()
print("[OK] Model loaded:", BEST_MODEL_PATH)

[OK] Model loaded: /home/khdp-user/workspace/dataset/Models/m0m1_run_cls/best_model.pt


In [5]:
# Load the split table and keep only the rows marked as the test split.
df = pd.read_csv(SPLIT_CSV_PATH)
df_eval = df[df["split"] == "test"].reset_index(drop=True)

eval_tf = get_eval_transform()
dl_eval = DataLoader(
    PatchClsDataset(df_eval, eval_tf),
    batch_size=BATCH_SIZE,
    # Order must stay aligned with df_eval: predictions are attached to it
    # positionally after inference.
    shuffle=False,
    num_workers=4,
    # Pinned host memory only helps host-to-GPU copies; skip it on CPU.
    pin_memory=(DEVICE == "cuda"),
)
print("Eval patches:", len(df_eval))
print(df_eval["y"].value_counts())

Eval patches: 149
y
0    98
1    51
Name: count, dtype: int64


In [16]:
@torch.no_grad()
def infer_with_probs(model, loader, df_eval):
    """Run the classifier over ``loader`` and attach its outputs to ``df_eval``.

    Args:
        model: Module returning 2-D logits with at least two columns; the
            softmax of column 1 is reported as the "m1" probability.
        loader: Iterable of ``(inputs, labels)`` batches. Labels are ignored.
            Batches must arrive in the same order as the rows of ``df_eval``
            (no shuffling), because outputs are attached positionally.
        df_eval: DataFrame with one row per evaluated sample.

    Returns:
        A copy of ``df_eval`` with two extra columns: ``"prob_m1"`` (softmax
        probability of class 1) and ``"pred"`` (argmax class index).

    Raises:
        ValueError: If the number of predictions differs from ``len(df_eval)``.
    """
    model.eval()

    prob_chunks = []
    pred_chunks = []
    for x, _ in tqdm(loader, desc="Infer probs"):
        logits = model(x.to(DEVICE))
        prob_chunks.append(torch.softmax(logits, dim=1)[:, 1].cpu().numpy())  # m1 prob
        pred_chunks.append(torch.argmax(logits, dim=1).cpu().numpy())

    # np.concatenate rejects an empty list, so an empty loader needs its own path.
    if prob_chunks:
        probs = np.concatenate(prob_chunks)
        preds = np.concatenate(pred_chunks)
    else:
        probs = np.empty(0, dtype=np.float32)
        preds = np.empty(0, dtype=np.int64)

    if len(probs) != len(df_eval):
        raise ValueError(
            f"Got {len(probs)} predictions for {len(df_eval)} rows; "
            "loader and df_eval must describe the same samples"
        )

    out = df_eval.copy()
    out["prob_m1"] = probs
    out["pred"] = preds
    return out
    
def compute_metrics_from_df(df, label_col="label"):
    """Compute macro-F1 and ROC AUC from a prediction DataFrame.

    Args:
        df: DataFrame produced by ``infer_with_probs``; must contain the
            ``"pred"`` and ``"prob_m1"`` columns plus the label column.
        label_col: Name of the ground-truth label column.

    Returns:
        Dict with keys ``"macro_f1"`` and ``"auc"``.
    """
    targets = df[label_col].values
    hard_preds = df["pred"].values
    m1_scores = df["prob_m1"].values

    metrics = {}
    # Macro-averaged F1 over the hard class predictions.
    metrics["macro_f1"] = f1_score(targets, hard_preds, average="macro")
    # Binary ROC AUC over the class-1 probabilities.
    metrics["auc"] = roc_auc_score(targets, m1_scores)
    return metrics

In [20]:
# Score the test split, then report macro-F1 and ROC AUC against the "y" column.
df_pred = infer_with_probs(model=model, loader=dl_eval, df_eval=df_eval)
compute_metrics_from_df(df_pred, label_col="y")

Infer probs: 100%|██████████| 3/3 [00:00<00:00,  5.76it/s]


{'macro_f1': 0.5389515835456862, 'auc': 0.6464585834333734}