In [1]:
import os
from pathlib import Path
import numpy as np
import pandas as pd
import openslide
from skimage import color, filters, morphology, measure, exposure
from skimage.color import hed_from_rgb

import random
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as T
from PIL import Image
from collections import Counter

from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.metrics import classification_report, confusion_matrix
from torchvision.models import ResNet18_Weights

# ===== Global random seeds (for reproducibility) =====
SEED = 0

# Seed Python's `random`, NumPy's legacy global RNG and PyTorch's CPU RNG
# with the same value.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(SEED)
del _seed_fn

# Seed every visible CUDA device as well, when CUDA is available.
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

# Turn off the cuDNN autotuner and request deterministic cuDNN kernels.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# ===== Root directory =====
# Dataset root that contains the "stage <n> <split>" folders read by
# collect_rows. The WSI_ROOT environment variable overrides the default so the
# notebook can run on another machine without editing this cell.
ROOT = os.environ.get("WSI_ROOT", r"D:\total")
# When True, wsi_features skips slides whose microns-per-pixel metadata is
# missing or non-positive.
REQUIRE_MPP = True

# ===== CNN feature extractor =====
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

# ResNet-18 with the default pretrained weights. The final fully connected
# layer is swapped for an identity so the network returns the features that
# would normally feed the classifier head.
resnet = models.resnet18(weights=ResNet18_Weights.DEFAULT)
resnet.fc = nn.Identity()
resnet = resnet.to(device)
resnet.eval()  # inference mode: fixes batch-norm statistics / disables dropout

# Preprocessing applied to every PIL image before it is fed to `resnet`:
# resize to 224x224, convert to a float tensor, then normalise per channel
# (the constants match the commonly used ImageNet statistics).
_cnn_steps = [
    T.Resize((224, 224)),
    T.ToTensor(),
    T.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
]
transform_cnn = T.Compose(_cnn_steps)

def extract_cnn_feature(img):
    """Embed a single image with the module-level `resnet`.

    Args:
        img: a NumPy array, a torch tensor or a PIL image. Arrays and tensors
            are first converted to PIL; anything not in "RGB" mode is
            converted to RGB.

    Returns:
        A flattened 1-D NumPy array holding the network output for the image
        (512 values according to the original author's note).
    """
    if isinstance(img, np.ndarray):
        pil = Image.fromarray(img)
    elif isinstance(img, torch.Tensor):
        pil = T.ToPILImage()(img)
    else:
        pil = img

    if getattr(pil, "mode", None) != "RGB":
        pil = pil.convert("RGB")

    # Add a batch dimension of 1 and move the tensor to the model's device.
    batch = transform_cnn(pil).unsqueeze(0).to(device)
    with torch.no_grad():
        embedding = resnet(batch)
    return embedding.cpu().numpy().flatten()

# ---------- Top-K lesion crops + pooling ----------
def extract_topk_crops(pil_img, nuc_mask, k=3, pad_ratio=0.12, out_size=224):
    """Crop the `k` largest connected components of `nuc_mask` out of `pil_img`.

    Args:
        pil_img: PIL image to crop from.
        nuc_mask: binary mask whose connected components define the regions
            (assumed to have the same pixel grid as `pil_img` — TODO confirm
            for callers other than wsi_features).
        k: maximum number of components to crop, largest area first.
        pad_ratio: each bounding box is grown on every side by this fraction
            of its own width / height, clipped to the image bounds.
        out_size: every crop is resized to `out_size` x `out_size`.

    Returns:
        A list of PIL images. When no component yields a crop, the list holds
        the whole image resized to `out_size` x `out_size`.
    """
    regions = measure.regionprops(measure.label(nuc_mask))
    regions.sort(key=lambda region: region.area, reverse=True)

    img_w, img_h = pil_img.width, pil_img.height
    target = (out_size, out_size)

    crops = []
    for region in regions[:k]:
        # regionprops bbox order is (min_row, min_col, max_row, max_col).
        top, left, bottom, right = region.bbox
        box_h = bottom - top
        box_w = right - left
        if box_h <= 0 or box_w <= 0:
            continue
        pad_x = int(pad_ratio * box_w)
        pad_y = int(pad_ratio * box_h)
        box = (
            max(left - pad_x, 0),
            max(top - pad_y, 0),
            min(right + pad_x, img_w),
            min(bottom + pad_y, img_h),
        )
        crops.append(pil_img.crop(box).resize(target))

    # Fall back to the full image when the mask produced nothing to crop.
    return crops or [pil_img.resize(target)]

def cnn_features_pooled(crops):
    """Embed every crop and pool the embeddings across crops.

    Each crop goes through `extract_cnn_feature`; the per-crop vectors are
    stacked along axis 0 and summarised by their element-wise mean and
    element-wise max, concatenated in that order (mean first, then max).

    Returns:
        1-D array twice the length of a single embedding
        (1024 values according to the original author's note).
    """
    stacked = np.stack([extract_cnn_feature(crop) for crop in crops], axis=0)
    mean_part = stacked.mean(axis=0)
    max_part = stacked.max(axis=0)
    return np.concatenate((mean_part, max_part), axis=0)
# ------------------------------------------

def _parse_mpp(raw):
    """Turn an OpenSlide MPP property into a float; None if missing or unparseable."""
    if raw in (None, ""):
        return None
    try:
        return float(raw)
    except (TypeError, ValueError):
        return None


def read_thumbnail_and_geometry(svs_path, min_dim=3000):
    """Read a downsampled RGB thumbnail and basic geometry of a whole-slide image.

    Args:
        svs_path: path to a slide file that OpenSlide can open.
        min_dim: target length, in pixels, of the thumbnail's longer side
            (despite the name, it bounds the larger dimension).

    Returns:
        (image, scale, mpp_x, mpp_y) where `image` is the thumbnail as a NumPy
        array, `scale` is the level-0 -> thumbnail downsample factor (never
        below 1.0), and `mpp_x` / `mpp_y` are the "openslide.mpp-x" /
        "openslide.mpp-y" properties as floats, or None when absent, empty or
        not parseable as a number.

    Raises:
        Whatever OpenSlide raises when the file cannot be opened or read; the
        slide handle is closed in every case.
    """
    slide = openslide.OpenSlide(svs_path)
    try:
        W0, H0 = slide.dimensions
        # get_thumbnail only ever shrinks, so a slide already smaller than
        # min_dim is returned at native size; keep `scale` consistent with that.
        scale = max(max(W0, H0) / float(min_dim), 1.0)
        # Guard against a 0-pixel side for extremely elongated slides.
        new_w = max(int(W0 / scale), 1)
        new_h = max(int(H0 / scale), 1)
        img = slide.get_thumbnail((new_w, new_h))
        mpp_x = slide.properties.get("openslide.mpp-x", None)
        mpp_y = slide.properties.get("openslide.mpp-y", None)
    finally:
        # Release the native handle even if reading the thumbnail failed.
        slide.close()
    return np.asarray(img), scale, _parse_mpp(mpp_x), _parse_mpp(mpp_y)

def tissue_mask(rgb):
    """Build a boolean tissue mask from an RGB image.

    The HSV value channel is thresholded slightly below its Otsu threshold
    (factor 0.98): pixels darker than that are kept. Holes and objects are
    then filtered with a size parameter of 256 pixels each.

    Returns:
        Boolean array with the same height/width as `rgb`.
    """
    value = color.rgb2hsv(rgb)[..., 2]
    cutoff = filters.threshold_otsu(value) * 0.98
    darker = value < cutoff
    filled = morphology.remove_small_holes(darker, 256)
    return morphology.remove_small_objects(filled, 256)

def he_nuclei_mask(rgb, tissue):
    """Segment strongly stained regions inside the tissue mask.

    Steps: scale `rgb` to [0, 1], run HED colour deconvolution and keep
    channel 0 (hematoxylin in skimage's HED convention), stretch that channel
    between its 2nd and 98th percentile measured inside `tissue`, Otsu-threshold
    the in-tissue values, then drop objects below 64 pixels and apply a binary
    opening with a radius-2 disk.

    Args:
        rgb: RGB image array with values on a 0-255 scale.
        tissue: boolean mask selecting the pixels to analyse.

    Returns:
        (mask, channel): the boolean mask and the rescaled channel-0 image.
    """
    unit_rgb = np.clip(rgb / 255.0, 0, 1)
    stains = color.separate_stains(unit_rgb, hed_from_rgb)
    hema = stains[..., 0]

    # Contrast-stretch using statistics taken from tissue pixels only.
    low = np.percentile(hema[tissue], 2)
    high = np.percentile(hema[tissue], 98)
    hema = exposure.rescale_intensity(hema, in_range=(low, high))

    in_tissue = hema[tissue]
    cutoff = filters.threshold_otsu(in_tissue)

    mask = np.zeros_like(hema, dtype=bool)
    mask[tissue] = in_tissue > cutoff  # background pixels stay False

    mask = morphology.remove_small_objects(mask, 64)
    mask = morphology.binary_opening(mask, morphology.disk(2))
    return mask, hema

def wsi_features(svs_path):
    """Compute hand-crafted and CNN features for one whole-slide image.

    Returns None when REQUIRE_MPP is set and the slide lacks positive
    microns-per-pixel metadata, or when the tissue mask covers fewer than
    5000 thumbnail pixels.

    Otherwise returns `(feats, cnn_feat)`:
        feats: dict with keys tumor_frac, largest_cc_frac, cc_count,
            cc_small_frac, frag_ratio. All of them are derived from the
            nuclei mask returned by he_nuclei_mask (the "tumor" naming
            refers to that mask) and measured in thumbnail pixels.
        cnn_feat: pooled CNN embedding of the top-3 largest mask components.

    Note: the thumbnail scale and the MPP values are not used beyond the
    MPP presence check.
    """
    rgb, _scale, mpp_x, mpp_y = read_thumbnail_and_geometry(svs_path)

    mpp_known = all(v is not None and v > 0 for v in (mpp_x, mpp_y))
    if REQUIRE_MPP and not mpp_known:
        return None

    mask = tissue_mask(rgb)
    tissue_px = int(mask.sum())
    if tissue_px < 5000:
        return None

    nuc, _channel = he_nuclei_mask(rgb, mask)

    # -- Hand-crafted statistics over connected components of the mask
    nuclei_px = int(nuc.sum())
    regions = measure.regionprops(measure.label(nuc))
    n_regions = len(regions)
    tissue_denom = max(tissue_px, 1)

    if n_regions:
        areas = np.array([r.area for r in regions], dtype=np.float32)
        perims = np.array([r.perimeter for r in regions], dtype=np.float32)
        largest_px = int(areas.max())
        # A component is "small" when it is below 0.1% of the tissue area.
        n_small = int((areas < 0.001 * tissue_px).sum())
        # Total perimeter over total area; epsilon avoids division by zero.
        frag_ratio = float(perims.sum() / (areas.sum() + 1e-6))
    else:
        largest_px = 0
        n_small = 0
        frag_ratio = 0.0

    feats = dict(
        tumor_frac=nuclei_px / tissue_denom,
        largest_cc_frac=largest_px / tissue_denom,
        cc_count=n_regions,
        cc_small_frac=n_small / max(n_regions, 1),
        frag_ratio=frag_ratio,
    )

    # -- CNN: top-K crops + pooling (K=3)
    crops = extract_topk_crops(Image.fromarray(rgb), nuc_mask=nuc, k=3, pad_ratio=0.12)
    cnn_feat = cnn_features_pooled(crops)  # [1024] per the original note

    return feats, cnn_feat

def collect_rows(root):
    """Walk the dataset folders and build one feature row per usable slide.

    Folders are expected to be named "stage <n> <split>" under `root`, with
    n in 1..4 and split in train / val / test. Missing folders are skipped.
    Every *.svs and *.tif file inside is passed to `wsi_features`; slides for
    which it returns None are skipped.

    Args:
        root: dataset root directory (str or Path).

    Returns:
        DataFrame with columns path, label (stage - 1, i.e. 0..3), split,
        the hand-crafted feature columns and cnn_0 .. cnn_{d-1}. When no slide
        is usable the frame is empty but still has path / label / split
        columns so downstream column access does not fail.
    """
    rows = []
    for stage in [1, 2, 3, 4]:
        for split in ["train", "val", "test"]:
            d = Path(root) / f"stage {stage} {split}"
            if not d.exists():
                continue
            label = stage - 1  # 4-class target: stage1->0, stage2->1, stage3->2, stage4->3
            # Directory listing order is filesystem-dependent; sort so the row
            # order (and everything trained on it) is reproducible.
            slide_paths = sorted([*d.glob("*.svs"), *d.glob("*.tif")])
            for p in slide_paths:
                res = wsi_features(str(p))
                if res is None:
                    continue
                feats, cnn_feat = res
                row = dict(path=str(p), label=label, split=split)
                row.update(feats)
                row.update({f"cnn_{i}": v for i, v in enumerate(cnn_feat)})
                rows.append(row)
    if not rows:
        return pd.DataFrame(columns=["path", "label", "split"])
    return pd.DataFrame(rows)

# ===== Main program =====
# Extract features for every slide, train a gradient-boosted classifier on the
# train split and report its performance on the validation split.
if __name__ == "__main__":
    df = collect_rows(ROOT)
    print("Feature rows:", df.shape)

    # Fail early with a readable message instead of a KeyError further down.
    if df.empty or "split" not in df.columns:
        raise RuntimeError(f"No usable slides found under {ROOT!r}")

    # Print the class distribution of each split
    for sp in ["train","val","test"]:
        if sp in df["split"].unique():
            print(f"\nClass distribution in {sp}:")
            print(df[df.split==sp]["label"].value_counts().sort_index())

    df.to_csv("wsi_stage_features_topk_4class_split.csv", index=False)

    train_df = df[df.split=="train"].copy()
    val_df   = df[df.split=="val"].copy()
    test_df  = df[df.split=="test"].copy()

    if train_df.empty:
        raise RuntimeError("No training slides found; cannot fit the classifier")

    # CNN columns first, then the hand-crafted statistics.
    feature_cols = [c for c in df.columns if c.startswith("cnn_")] + [
        "tumor_frac","largest_cc_frac","cc_count","cc_small_frac","frag_ratio"
    ]

    Xtr, ytr = train_df[feature_cols].values, train_df["label"].values
    Xval, yval = val_df[feature_cols].values, val_df["label"].values
    # The test split is prepared here but not evaluated in this cell.
    Xte, yte = test_df[feature_cols].values, test_df["label"].values

    # ===== Class-imbalance weights (computed from the training split only) =====
    # Balanced weighting: n_samples / (n_classes * n_samples_in_class).
    counts = train_df["label"].value_counts()
    n_classes = 4
    class_weight_map = {c: len(train_df) / (n_classes * counts[c]) for c in counts.index}

    custom_multipliers = {3: 2.5}  # extra weight for stage 4 (label 3)
    for k, mult in custom_multipliers.items():
        if k in class_weight_map:
            class_weight_map[k] *= mult

    weights = train_df["label"].map(class_weight_map).values

    # ===== Model training =====
    # Only the training split is passed to fit(); the validation split is
    # used for the report below, not for monitoring during training.
    clf = HistGradientBoostingClassifier(
        max_depth=6, learning_rate=0.06, max_iter=600,
        l2_regularization=1.0, min_samples_leaf=20, random_state=SEED
    )
    clf.fit(Xtr, ytr, sample_weight=weights)

    if len(val_df):
        # Validation performance
        yval_pr = clf.predict(Xval)
        print("\nValidation performance:")
        print(confusion_matrix(yval, yval_pr, labels=[0,1,2,3]))
        # zero_division=0 keeps the reported numbers (0.0 for classes that are
        # never predicted) but silences the undefined-metric warnings.
        print(classification_report(yval, yval_pr, labels=[0,1,2,3],
              target_names=["stage1","stage2","stage3","stage4"], digits=4,
              zero_division=0))

        # ===== Save the validation predictions =====
        val_df["y_true"] = yval
        val_df["y_pred"] = yval_pr
        val_df.to_csv("val_preds.csv", index=False)
        print("\n✅ val predictions will be saved in val_preds.csv")
    else:
        # predict() rejects an empty feature matrix, so skip the report.
        print("\nNo validation slides found; skipping validation report.")

Feature rows: (190, 1032)

Class distribution in train:
label
0    46
1    41
2    52
3     9
Name: count, dtype: int64

Class distribution in val:
label
0    6
1    6
2    6
3    1
Name: count, dtype: int64

Class distribution in test:
label
0    12
1     6
2     4
3     1
Name: count, dtype: int64

Validation performance:
[[3 1 2 0]
 [0 6 0 0]
 [2 1 3 0]
 [1 0 0 0]]
              precision    recall  f1-score   support

      stage1     0.5000    0.5000    0.5000         6
      stage2     0.7500    1.0000    0.8571         6
      stage3     0.6000    0.5000    0.5455         6
      stage4     0.0000    0.0000    0.0000         1

    accuracy                         0.6316        19
   macro avg     0.4625    0.5000    0.4756        19
weighted avg     0.5842    0.6316    0.6008        19


✅ val predictions will be saved in val_preds.csv


  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])
  _warn_prf(average, modifier, f"{metric.capitalize()} is", result.shape[0])


In [2]:
import joblib
import torch

# Persist the fitted classifier `clf` (created in the previous cell) to disk.
joblib.dump(clf, "new_stage_classifier_gbdt.pkl")

['new_stage_classifier_gbdt.pkl']