In [6]:
import os, re, argparse, pandas as pd
from glob import glob

def stage_from_folder(name: str):
    """Translate a folder name into a Roman-numeral stage label.

    The folder name is searched case-insensitively for the substrings
    ``stage 1`` .. ``stage 4``, tried in that order; the first one found
    decides the result ('I', 'II', 'III' or 'IV').

    Raises:
        ValueError: when none of the four markers occurs in ``name``.
    """
    lowered = name.lower()
    markers = (
        ('stage 1', 'I'),
        ('stage 2', 'II'),
        ('stage 3', 'III'),
        ('stage 4', 'IV'),
    )
    for needle, roman in markers:
        if needle in lowered:
            return roman
    raise ValueError(f"Cannot parse stage from folder: {name}")

def split_from_folder(name: str):
    """Derive the data split ('train' or 'test') from a folder name.

    The check is a case-insensitive substring search. 'train' takes
    precedence over 'test', and a name containing neither falls back to
    'train'.
    """
    lowered = name.lower()
    test_only = 'test' in lowered and 'train' not in lowered
    return 'test' if test_only else 'train'

def parse_slide_id(filename: str):
    """Return the slide id encoded in a tile's file name.

    The id is the base name (directories removed) truncated at its first
    '.'; a base name without any '.' is returned whole.
    """
    # partition() yields the full string when there is no '.', which is also
    # what stripping a (non-existent) extension would give.
    return os.path.basename(filename).partition('.')[0]

def main(args):
    """Index every tile image under ``args.stage_root`` and write label/index CSVs.

    Each immediate sub-folder of ``args.stage_root`` whose name carries a
    ``stage N`` marker is scanned (non-recursively) for ``*.png``, ``*.jpg``
    and ``*.jpeg`` files. The stage comes from ``stage_from_folder``, the
    split from ``split_from_folder``, and the slide id returned by
    ``parse_slide_id`` is used as both ``patient_id`` and ``wsi_id``.

    Sub-folders without a stage marker are skipped with a message instead of
    aborting the run. This matters because ``out_dir`` (and other generated
    folders) may live inside ``stage_root``, which previously made a second
    run of this function raise ``ValueError``.

    Files written into ``args.out_dir`` (created if missing):
        labels.csv       one row per patient_id: patient_id, stage, group
        index_train.csv  patient_id, path, wsi_id for rows with group 'train'
        index_test.csv   patient_id, path, wsi_id for rows with group 'test'

    Args:
        args: object exposing ``stage_root`` and ``out_dir`` attributes.

    Returns:
        None. Prints a notice and writes nothing when no image is found.
    """
    root = args.stage_root
    rows = []
    # Sorted traversal makes "first occurrence" below deterministic.
    for fd in sorted(os.listdir(root)):
        folder_path = os.path.join(root, fd)
        if not os.path.isdir(folder_path):
            continue
        try:
            stage = stage_from_folder(fd)
        except ValueError:
            # Not a stage folder (e.g. a previously generated output folder).
            print(f"Skipping folder without a stage marker: {fd}")
            continue
        split = split_from_folder(fd)
        imgs = []
        for ext in ('*.png', '*.jpg', '*.jpeg'):
            imgs += glob(os.path.join(folder_path, ext))
        for p in sorted(imgs):
            slide = parse_slide_id(p)
            rows.append({'patient_id': slide, 'path': os.path.abspath(p), 'wsi_id': slide,
                         'stage': stage, 'group': split})
    if not rows:
        print("No images found. Check your folder paths.")
        return

    df = pd.DataFrame(rows)

    # Distinct (patient, stage, group) combinations. A patient that shows up
    # with more than one combination is inconsistent and must be detected
    # BEFORE collapsing to one row per patient, otherwise the check can never
    # fire.
    combos = df[['patient_id', 'stage', 'group']].drop_duplicates()
    conflict_mask = combos.duplicated('patient_id', keep=False)
    if conflict_mask.any():
        conflicted = combos.loc[conflict_mask, 'patient_id'].unique().tolist()
        print(f"Warning: {len(conflicted)} patient_id(s) appear with more than one stage/split "
              f"(e.g. {conflicted[:5]}). Keeping the first occurrence.")
    # Labels: exactly one row per patient, first occurrence wins.
    df_lab = combos.drop_duplicates('patient_id', keep='first')

    os.makedirs(args.out_dir, exist_ok=True)
    labels_csv = os.path.join(args.out_dir, 'labels.csv')
    train_csv = os.path.join(args.out_dir, 'index_train.csv')
    test_csv = os.path.join(args.out_dir, 'index_test.csv')
    df_lab.to_csv(labels_csv, index=False)

    # Separate per-split tile indexes.
    index_cols = ['patient_id', 'path', 'wsi_id']
    df.loc[df['group'] == 'train', index_cols].to_csv(train_csv, index=False)
    df.loc[df['group'] == 'test', index_cols].to_csv(test_csv, index=False)

    print("Wrote:")
    print(" -", labels_csv)
    print(" -", train_csv)
    print(" -", test_csv)

if __name__ == "__main__":
    # Notebook-style invocation: hand main() a namespace carrying the two
    # attributes it reads (stage_root, out_dir) instead of parsing sys.argv.
    main(argparse.Namespace(
        stage_root=r"C:\Users\mxjli\PyCharmMiscProject\stage new",
        out_dir=r"C:\Users\mxjli\PyCharmMiscProject\stage new\processed",
    ))

Wrote:
 - C:\Users\mxjli\PyCharmMiscProject\stage new\processed\labels.csv
 - C:\Users\mxjli\PyCharmMiscProject\stage new\processed\index_train.csv
 - C:\Users\mxjli\PyCharmMiscProject\stage new\processed\index_test.csv


In [7]:
import os, argparse, json, pandas as pd
from sklearn.model_selection import StratifiedGroupKFold

def main(args):
    """Build stratified, patient-grouped CV folds and save one split.json per fold.

    Reads ``args.label_csv`` (needs ``patient_id`` and ``stage`` columns),
    derives the stratification target from ``stage`` and runs
    ``StratifiedGroupKFold`` (shuffled, ``random_state=42``) grouped by
    ``patient_id``. For fold ``k`` the file ``<out_dir>/fold_k/split.json``
    is written with the keys ``train``, ``val`` (lists of patient ids),
    ``n_classes`` and ``task``.

    Args:
        args: object exposing
            label_csv: path of the labels CSV,
            task: 'A' for binary (stage IV vs. everything else); any other
                value selects the 4-class target I/II/III/IV,
            n_splits: number of folds,
            out_dir: directory receiving the ``fold_k`` sub-folders.

    Raises:
        ValueError: in the 4-class task, when a stage value is not one of
            I, II, III, IV (after trimming whitespace and upper-casing).
    """
    df = pd.read_csv(args.label_csv)
    # Normalise once: tolerate stray whitespace and lower-case numerals.
    stage = df['stage'].astype(str).str.strip().str.upper()
    if args.task == 'A':
        y = (stage == 'IV').astype(int).values
        n_classes = 2
    else:
        mapping = {'I': 0, 'II': 1, 'III': 2, 'IV': 3}
        y_mapped = stage.map(mapping)
        # An unmapped stage would turn into NaN and only fail later, inside
        # the splitter, with an unrelated-looking message. Fail here instead.
        unknown = sorted(stage[y_mapped.isna()].unique())
        if unknown:
            raise ValueError(f"Unrecognised stage value(s) in {args.label_csv}: {unknown}")
        y = y_mapped.astype(int).values
        n_classes = 4
    groups = df['patient_id'].values
    sgkf = StratifiedGroupKFold(n_splits=args.n_splits, shuffle=True, random_state=42)
    os.makedirs(args.out_dir, exist_ok=True)
    for k, (tr, va) in enumerate(sgkf.split(X=df['patient_id'], y=y, groups=groups)):
        split = {
            'train': df.iloc[tr]['patient_id'].tolist(),
            'val': df.iloc[va]['patient_id'].tolist(),
            'n_classes': n_classes,
            'task': args.task,
        }
        fold_dir = os.path.join(args.out_dir, f"fold_{k}")
        os.makedirs(fold_dir, exist_ok=True)
        with open(os.path.join(fold_dir, 'split.json'), 'w', encoding='utf-8') as f:
            json.dump(split, f, indent=2)
    print("Saved splits to", args.out_dir)

if __name__ == "__main__":
    # task "A" selects the binary target (stage IV vs. the rest); any other
    # value, including "B" used here, selects the 4-class target in main().
    args = argparse.Namespace(
        label_csv=r"C:\Users\mxjli\PyCharmMiscProject\stage new\processed\labels.csv",
        task="B",
        n_splits=5,
        out_dir=r"C:\Users\mxjli\PyCharmMiscProject\stage new",
    )
    main(args)

Saved splits to C:\Users\mxjli\PyCharmMiscProject\stage new


In [10]:
# Locations of the two per-split tile indexes and of the merged index to create.
train_path = r"C:\Users\mxjli\PyCharmMiscProject\stage new\processed\index_train.csv"
test_path  = r"C:\Users\mxjli\PyCharmMiscProject\stage new\processed\index_test.csv"
output_path = r"C:\Users\mxjli\PyCharmMiscProject\stage new\processed\index_all.csv"

# Load both splits, stack them (train rows first) under a fresh 0..n-1 index,
# and persist the merged table without the index column.
train_df, test_df = (pd.read_csv(csv_path) for csv_path in (train_path, test_path))
all_df = pd.concat((train_df, test_df), ignore_index=True)
all_df.to_csv(output_path, index=False)

print("✅ integration completed，saved as index_all.csv")

✅ 合并完成，已保存为 index_all.csv


In [9]:
import os
import pandas as pd

# Inputs are read from the processed folder and the cleaned files are written
# back into it: the training cell loads labels_cleaned.csv and
# index_*_cleaned.csv from this folder, not from the current working directory.
processed_dir = r"C:\Users\mxjli\PyCharmMiscProject\stage new\processed"

# 1. labels.csv -> labels_cleaned.csv (patient_id, integer class label)
labels = pd.read_csv(os.path.join(processed_dir, "labels.csv"))
stage2label = {"I": 0, "II": 1, "III": 2, "IV": 3}
labels["label"] = labels["stage"].map(stage2label)
# A stage missing from the mapping becomes NaN; stop here rather than writing
# a labels file that breaks training later.
unmapped = labels.loc[labels["label"].isna(), "stage"].unique().tolist()
if unmapped:
    raise ValueError(f"labels.csv contains stages with no label mapping: {unmapped}")
labels = labels[["patient_id", "label"]].astype({"label": int})
labels.to_csv(os.path.join(processed_dir, "labels_cleaned.csv"), index=False)

# 2. index_train.csv -> index_train_cleaned.csv (unique wsi_id, exposed as patient_id)
train = pd.read_csv(os.path.join(processed_dir, "index_train.csv"))
train_unique = train[["wsi_id"]].drop_duplicates()
train_unique.columns = ["patient_id"]
train_unique.to_csv(os.path.join(processed_dir, "index_train_cleaned.csv"), index=False)

# 3. index_test.csv -> index_val_cleaned.csv (the test split is used as validation set)
test = pd.read_csv(os.path.join(processed_dir, "index_test.csv"))
val_unique = test[["wsi_id"]].drop_duplicates()
val_unique.columns = ["patient_id"]
val_unique.to_csv(os.path.join(processed_dir, "index_val_cleaned.csv"), index=False)

In [11]:
import os, argparse, h5py, numpy as np, torch, torchvision
from PIL import Image
from tqdm import tqdm
import pandas as pd
import json

def get_model(backbone='resnet50'):
    """Create the frozen-mode feature extractor for ``backbone``.

    Only 'resnet50' is supported: a torchvision ResNet-50 loaded with the
    IMAGENET1K_V2 weights whose final ``fc`` layer is replaced by an identity,
    switched to eval mode.

    Returns:
        tuple: ``(model, feat_dim)`` where ``feat_dim`` is 2048.

    Raises:
        NotImplementedError: for any other ``backbone`` value.
    """
    if backbone != 'resnet50':
        raise NotImplementedError
    pretrained = torchvision.models.ResNet50_Weights.IMAGENET1K_V2
    net = torchvision.models.resnet50(weights=pretrained)
    # Drop the classification head so the network emits its pooled features.
    net.fc = torch.nn.Identity()
    net.eval()
    return net, 2048

def build_tfms(img_size=256):
    """Return the preprocessing pipeline applied to every tile.

    Steps, in order: resize to ``img_size`` x ``img_size``, convert to a
    tensor, normalise each channel with the mean/std constants below.
    """
    from torchvision import transforms as T
    target_size = (img_size, img_size)
    steps = [
        T.Resize(target_size),
        T.ToTensor(),
        T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
    ]
    return T.Compose(steps)

def walk_tiles_from_index(index_csv):
    """Read an index CSV and list its tiles as ``(patient_id, wsi_id, path)``.

    All three values are converted to ``str``. When the CSV has no ``wsi_id``
    column, the placeholder '_flat' is used for every row.
    """
    frame = pd.read_csv(index_csv)
    has_wsi = 'wsi_id' in frame.columns
    return [
        (
            str(rec['patient_id']),
            str(rec['wsi_id']) if has_wsi else '_flat',
            str(rec['path']),
        )
        for _, rec in frame.iterrows()
    ]

def extract_features(index_csv, out_root, img_size=256, batch_size=256, backbone='resnet50',
                     skip_existing=False):
    """Run the backbone over every tile in ``index_csv`` and save one HDF5 file per patient.

    Tiles are grouped by patient id. For each patient the tiles are
    transformed with ``build_tfms(img_size)``, pushed through the model from
    ``get_model(backbone)`` in chunks of ``batch_size``, and stored in
    ``<out_root>/<patient_id>.h5`` with two datasets:
        feat  the concatenated model outputs, one row per readable tile
              (gzip-compressed),
        meta  one ``(file name, wsi_id)`` string pair per readable tile, in
              the same order as ``feat``.

    Tiles that cannot be opened are skipped, but each one is reported and the
    total is printed at the end. Patients without a single readable tile get
    no file and are reported as well.

    Args:
        index_csv: CSV understood by ``walk_tiles_from_index``.
        out_root: output directory, created if missing.
        img_size: side length the tiles are resized to.
        batch_size: number of tiles per forward pass.
        backbone: backbone name passed to ``get_model``.
        skip_existing: when True, patients whose ``.h5`` file already exists
            are not recomputed. Defaults to False (always recompute).
    """
    os.makedirs(out_root, exist_ok=True)
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model, _ = get_model(backbone)
    model = model.to(device)
    tfms = build_tfms(img_size)

    by_patient = {}
    for pid, wsi, path in walk_tiles_from_index(index_csv):
        by_patient.setdefault(pid, []).append((wsi, path))

    n_unreadable = 0
    for pid, lst in tqdm(by_patient.items(), desc='Extracting features'):
        out_path = os.path.join(out_root, f"{pid}.h5")
        if skip_existing and os.path.exists(out_path):
            continue
        feats, metas = [], []
        for i in range(0, len(lst), batch_size):
            imgs = []
            for wsi, path in lst[i:i+batch_size]:
                try:
                    # The context manager closes the file as soon as the
                    # converted copy has been made.
                    with Image.open(path) as fh:
                        img = fh.convert('RGB')
                except Exception as exc:
                    # Best effort: keep going, but never drop a tile silently.
                    n_unreadable += 1
                    tqdm.write(f"Skipping unreadable image {path}: {exc}")
                    continue
                imgs.append(tfms(img))
                # Appended only for tiles that made it into the batch, so
                # meta rows stay aligned with feat rows.
                metas.append((os.path.basename(path), wsi))
            if not imgs:
                continue
            with torch.no_grad():
                x = torch.stack(imgs, dim=0).to(device)
                f = model(x).detach().cpu().numpy()
            feats.append(f)
        if not feats:
            tqdm.write(f"No readable tiles for patient {pid}; no feature file written.")
            continue
        feats = np.concatenate(feats, axis=0)
        with h5py.File(out_path, 'w') as h5:
            h5.create_dataset('feat', data=feats, compression='gzip')
            meta_arr = np.array(metas, dtype=h5py.string_dtype(encoding='utf-8'))
            h5.create_dataset('meta', data=meta_arr)
    if n_unreadable:
        print(f"Skipped {n_unreadable} unreadable image(s).")
    print("Done. Saved to", out_root)

if __name__ == "__main__":
    split_json = r"C:\Users\mxjli\PyCharmMiscProject\stage new\fold_0\split.json"
    index_all_csv = r"C:\Users\mxjli\PyCharmMiscProject\stage new\processed\index_all.csv"
    out_root = r"C:\Users\mxjli\PyCharmMiscProject\stage new\features"

    # Read the fold definition with a context manager so the file handle is
    # closed right away instead of being left to the garbage collector.
    with open(split_json, encoding='utf-8') as fh:
        split = json.load(fh)

    # Restrict the full tile index to the patients named in this fold.
    ids_needed = set(split['train'] + split['val'])
    index_all = pd.read_csv(index_all_csv)
    index_needed = index_all[index_all['patient_id'].isin(ids_needed)]

    # extract_features() takes a CSV path, so the filtered index is written out first.
    tmp_index_csv = r"C:\Users\mxjli\PyCharmMiscProject\stage new\processed\index_needed.csv"
    index_needed.to_csv(tmp_index_csv, index=False)

    extract_features(tmp_index_csv, out_root)


Extracting features: 100%|██████████| 71/71 [1:04:03<00:00, 54.14s/it]

Done. Saved to C:\Users\mxjli\PyCharmMiscProject\stage new\features





In [13]:
import os
import h5py
import torch
import torch.nn as nn
import torch.nn.functional as F
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import f1_score, average_precision_score
import numpy as np
from tqdm import tqdm

# ==================== parameters setting ====================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Raw strings need single backslashes; escaping them inside r"..." puts doubled
# separators into the path, unlike every other cell of this notebook.
feature_dir = r"C:\Users\mxjli\PyCharmMiscProject\stage new\features"
label_file = r"C:\Users\mxjli\PyCharmMiscProject\stage new\processed\labels_cleaned.csv"
train_index_file = r"C:\Users\mxjli\PyCharmMiscProject\stage new\processed\index_train_cleaned.csv"
val_index_file = r"C:\Users\mxjli\PyCharmMiscProject\stage new\processed\index_val_cleaned.csv"

num_classes = 4           # stages I-IV
batch_size = 1            # one bag (patient) per step; the loops index feats[0]
max_epochs = 50
early_stop_patience = 10  # epochs without a new best validation F1 before stopping
lr = 1e-4

# Seed the RNGs behind weight initialisation and DataLoader shuffling so that
# a re-run starts from the same state.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# ==================== Dataset ====================
class MILDataset(Dataset):
    """Bag-level dataset: one item per patient, loaded from ``<patient_id>.h5``.

    Args:
        index_csv: CSV with a ``patient_id`` column listing the bags to serve.
        label_csv: CSV with ``patient_id`` and ``label`` columns.
        feature_dir: directory holding one ``<patient_id>.h5`` file per bag.
    """

    def __init__(self, index_csv, label_csv, feature_dir):
        self.index_df = pd.read_csv(index_csv)
        self.label_df = pd.read_csv(label_csv)
        self.feature_dir = feature_dir

        # patient_id -> label lookup, and the ordered list of bags to serve.
        self.label_map = dict(zip(self.label_df["patient_id"], self.label_df["label"]))
        self.patient_ids = self.index_df["patient_id"].tolist()

    def __len__(self):
        return len(self.patient_ids)

    def __getitem__(self, idx):
        """Return ``(features, label)`` for the bag at position ``idx``.

        ``features`` is the bag's feature dataset as a float32 tensor and
        ``label`` a long tensor.

        Raises:
            KeyError: when the patient has no entry in the label CSV.
        """
        pid = self.patient_ids[idx]
        if pid not in self.label_map:
            raise KeyError(f"No label found for patient_id {pid!r}")
        label = self.label_map[pid]

        feature_path = os.path.join(self.feature_dir, f"{pid}.h5")
        with h5py.File(feature_path, 'r') as f:
            # The feature files written by the extraction cell hold two
            # datasets, 'feat' and 'meta'. Ask for 'feat' by name rather than
            # relying on it being listed first; fall back to the first key
            # for files written with a different layout.
            feature_key = 'feat' if 'feat' in f else list(f.keys())[0]
            features = f[feature_key][:]  # shape: [num_patches, feature_dim]

        return torch.tensor(features, dtype=torch.float32), torch.tensor(label, dtype=torch.long)

# ==================== Gated Attention MIL ====================
class GatedAttentionMIL(nn.Module):
    """Gated-attention pooling over a bag of instance features.

    Each instance is embedded (Linear + ReLU) and scored by a gated attention
    branch (tanh path multiplied by a sigmoid path, then projected to one
    score). The scores are softmax-normalised across the instances and used
    to form a weighted sum of the embeddings, which a linear layer classifies.
    """

    def __init__(self, input_dim=2048, hidden_dim=256, num_classes=4):
        super().__init__()
        attn_dim = 128
        self.embedding = nn.Sequential(nn.Linear(input_dim, hidden_dim), nn.ReLU())
        self.attention_V = nn.Linear(hidden_dim, attn_dim)
        self.attention_U = nn.Linear(hidden_dim, attn_dim)
        self.attention_weights = nn.Linear(attn_dim, 1)

        self.classifier = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Return ``(logits, attention)`` for one bag ``x`` (instances along dim 0)."""
        embedded = self.embedding(x)                     # [N, H]
        candidate = self.attention_V(embedded).tanh()
        gate = self.attention_U(embedded).sigmoid()
        scores = self.attention_weights(candidate * gate)  # [N, 1]
        attention = scores.softmax(dim=0)                # normalise across instances

        pooled = (attention * embedded).sum(dim=0)       # attention-weighted sum
        return self.classifier(pooled), attention

# ==================== Train ====================
def train_one_epoch(model, loader, optimizer, criterion):
    """Train ``model`` for one pass over ``loader``.

    Only the first element of every batch is used (``feats[0]``), i.e. one
    bag per optimisation step. Uses the module-level ``device``.

    Returns:
        tuple: ``(mean loss over the steps, macro F1 of the argmax predictions)``.
    """
    model.train()
    step_losses = []
    predicted = []
    targets = []
    for bag, target in loader:
        bag = bag[0].to(device)
        target = target.to(device)
        logits, _attention = model(bag)
        # The model returns logits for a single bag; add the batch dimension
        # the criterion is called with.
        step_loss = criterion(logits.unsqueeze(0), target)

        optimizer.zero_grad()
        step_loss.backward()
        optimizer.step()

        step_losses.append(step_loss.item())
        predicted.append(logits.argmax().item())
        targets.append(target.item())

    macro_f1 = f1_score(targets, predicted, average='macro')
    return np.mean(step_losses), macro_f1

# ==================== Validate ====================
def evaluate(model, loader, criterion):
    """Score ``model`` on ``loader`` without updating it.

    Like the training loop, only the first element of every batch is used.
    Uses the module-level ``device``.

    Returns:
        tuple: ``(mean loss, macro F1, labels, probs)`` where ``labels`` and
        ``probs`` are numpy arrays holding, per bag, the true label and the
        softmax of the model output.
    """
    model.eval()
    step_losses = []
    predicted = []
    targets = []
    probabilities = []
    with torch.no_grad():
        for bag, target in loader:
            bag = bag[0].to(device)
            target = target.to(device)
            logits, _attention = model(bag)
            step_losses.append(criterion(logits.unsqueeze(0), target).item())

            predicted.append(logits.argmax().item())
            targets.append(target.item())
            probabilities.append(logits.softmax(dim=-1).cpu().numpy())

    macro_f1 = f1_score(targets, predicted, average='macro')
    return np.mean(step_losses), macro_f1, np.array(targets), np.array(probabilities)

# ==================== Run ====================
# Build the train/validation datasets from the cleaned index files; both share
# the same label file and feature directory. Only the training loader shuffles.
train_dataset = MILDataset(train_index_file, label_file, feature_dir)
val_dataset = MILDataset(val_index_file, label_file, feature_dir)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Model, optimiser and loss. input_dim=2048 matches the feat_dim reported by
# get_model() for the resnet50 backbone.
model = GatedAttentionMIL(input_dim=2048, hidden_dim=256, num_classes=num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
criterion = nn.CrossEntropyLoss()

# best_f1: best validation macro F1 seen so far.
# patience: number of consecutive epochs without a new best (reset on improvement).
best_f1, patience = 0, 0
for epoch in range(1, max_epochs + 1):
    train_loss, train_f1 = train_one_epoch(model, train_loader, optimizer, criterion)
    val_loss, val_f1, val_labels, val_probs = evaluate(model, val_loader, criterion)

    print(f"Epoch {epoch}: Train Loss={train_loss:.4f}, F1={train_f1:.4f} | Val Loss={val_loss:.4f}, F1={val_f1:.4f}")

    # Strict comparison against a starting value of 0: an epoch whose
    # validation F1 is exactly 0 never counts as an improvement, so no
    # checkpoint is written for it.
    if val_f1 > best_f1:
        best_f1 = val_f1
        patience = 0
        # Checkpoint and predictions go to the current working directory and
        # are overwritten on every new best epoch.
        torch.save(model.state_dict(), "best_model.pth")

        # Save validation set prediction probabilities
        df_out = pd.DataFrame(val_probs, columns=[f"prob_{i}" for i in range(num_classes)])
        df_out.insert(0, "label", val_labels)
        df_out.to_csv("val_preds.csv", index=False)
    else:
        patience += 1
        # Stop once early_stop_patience epochs in a row failed to improve.
        if patience >= early_stop_patience:
            print("Early stopping triggered.")
            break

print(f"Best validation F1: {best_f1:.4f}")

Epoch 1: Train Loss=1.3913, F1=0.1366 | Val Loss=1.4167, F1=0.1875
Epoch 2: Train Loss=1.3365, F1=0.3749 | Val Loss=1.4200, F1=0.2667
Epoch 3: Train Loss=1.3064, F1=0.3271 | Val Loss=1.4235, F1=0.1111
Epoch 4: Train Loss=1.2684, F1=0.3820 | Val Loss=1.4340, F1=0.3631
Epoch 5: Train Loss=1.2226, F1=0.4121 | Val Loss=1.4389, F1=0.2833
Epoch 6: Train Loss=1.1883, F1=0.3802 | Val Loss=1.4441, F1=0.3631
Epoch 7: Train Loss=1.1639, F1=0.4108 | Val Loss=1.4616, F1=0.2429
Epoch 8: Train Loss=1.1151, F1=0.4425 | Val Loss=1.4535, F1=0.3631
Epoch 9: Train Loss=1.0990, F1=0.5193 | Val Loss=1.4698, F1=0.1875
Epoch 10: Train Loss=1.0656, F1=0.5037 | Val Loss=1.4589, F1=0.4583
Epoch 11: Train Loss=1.0285, F1=0.5822 | Val Loss=1.4905, F1=0.2833
Epoch 12: Train Loss=1.0192, F1=0.5469 | Val Loss=1.5076, F1=0.4583
Epoch 13: Train Loss=0.9768, F1=0.5750 | Val Loss=1.5070, F1=0.2667
Epoch 14: Train Loss=0.9209, F1=0.5861 | Val Loss=1.5659, F1=0.3000
Epoch 15: Train Loss=0.8710, F1=0.6457 | Val Loss=1.5741,