# Import Model

In [1]:
import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1)
        self.bn1   = nn.BatchNorm2d(out_channels)
        self.act1  = nn.SiLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)
        self.bn2   = nn.BatchNorm2d(out_channels)
        self.shortcut = None
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )
        self.act2 = nn.SiLU()

    def forward(self, x):
        identity = x
        out = self.act1(self.bn1(self.conv1(x)))
        out = self.act2(self.bn2(self.conv2(out)))
        if self.shortcut is not None:
            identity = self.shortcut(x)
        return out + identity



class DeepTileEncoder(nn.Module):
    """加深的 Tile 分支：全局信息，多尺度池化 + 三层 MLP"""
    def __init__(self, out_dim, in_channels=3, negative_slope=0.01):
        super().__init__()
        self.layer0 = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.SiLU(),
            nn.MaxPool2d(2)  # 78→39
        )
        self.layer1 = nn.Sequential(
            ResidualBlock(32, 64),
            ResidualBlock(64, 64),
            nn.MaxPool2d(2)  # 39→19
        )
        self.layer2 = nn.Sequential(
            ResidualBlock(64, 128),
            ResidualBlock(128, 128),
            nn.MaxPool2d(2)  # 19→9
        )
        self.layer3 = nn.Sequential(
            ResidualBlock(128, 256),
            ResidualBlock(256, 256)
        )  # 保持 9×9

        # 多尺度池化
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))  # [B,256,1,1]
        self.mid_pool    = nn.AdaptiveAvgPool2d((3, 3))  # [B,256,3,3]

        total_dim = 256*1*1 + 256*3*3
        # 三层 MLP：total_dim → 2*out_dim → out_dim → out_dim
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.1),
            nn.Linear(total_dim, out_dim*4),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(0.1),
            nn.Linear(out_dim*4, out_dim*2),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(0.1),
            nn.Linear(out_dim*2, out_dim),
            nn.LeakyReLU(negative_slope),
        )

    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # x: [B,256,9,9]
        g = self.global_pool(x).contiguous().reshape(x.size(0), -1)  # [B,256]
        m = self.mid_pool(x).contiguous().reshape(x.size(0), -1)     # [B,256*3*3]

        return self.fc(torch.cat([g, m], dim=1))


class SubtileEncoder(nn.Module):
    """多尺度 Subtile 分支：局部信息 + 两层 MLP"""
    def __init__(self, out_dim, in_channels=3, negative_slope=0.01):
        super().__init__()
        self.layer0 = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.SiLU(),
            nn.MaxPool2d(2)  # 26→13
        )
        self.layer1 = nn.Sequential(
            ResidualBlock(32, 64),
            ResidualBlock(64, 64),
            nn.MaxPool2d(2)  # 13→6
        )
        self.layer2 = nn.Sequential(
            ResidualBlock(64, 128),
            ResidualBlock(128, 128)
        )  # 保持 6×6

        self.global_pool = nn.AdaptiveAvgPool2d((1,1))
        self.mid_pool    = nn.AdaptiveAvgPool2d((2,2))
        self.large_pool    = nn.AdaptiveAvgPool2d((3,3))

        total_dim = 128*1*1 + 128*2*2 + 128*3*3
        # 两层 MLP：total_dim → out_dim*2 → out_dim
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.1),
            nn.Linear(total_dim, out_dim*2),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(0.1),
            nn.Linear(out_dim*2, out_dim),
            nn.LeakyReLU(negative_slope),
        )

    def forward(self, x):
        B, N, C, H, W = x.shape
        x = x.contiguous().reshape(B*N, C, H, W)
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        # g,m: [B*N, feat]
        g = self.global_pool(x).contiguous().reshape(B, N, -1)
        m = self.mid_pool(x).contiguous().reshape(B, N, -1)
        l = self.large_pool(x).contiguous().reshape(B, N, -1)

        # 合并 N 张 subtiles，再 FC
        feat = torch.cat([g, m, l], dim=2).mean(dim=1).contiguous()  # [B, total_dim]
        return self.fc(feat)
class CenterSubtileEncoder(nn.Module):
    """專門處理中心 subtile 的 Encoder"""
    def __init__(self, out_dim, in_channels=3, negative_slope= 0.01):
        super().__init__()
        self.layer0 = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.SiLU(),
            nn.MaxPool2d(2)  # 26→13
        )
        self.layer1 = nn.Sequential(
            ResidualBlock(32, 64),
            ResidualBlock(64, 64),
            nn.MaxPool2d(2)  # 13→6
        )
        self.layer2 = nn.Sequential(
            ResidualBlock(64, 128),
            ResidualBlock(128, 128)
        )  # 6×6

        # 多尺度池化
        self.global_pool = nn.AdaptiveAvgPool2d((1,1))
        self.mid_pool    = nn.AdaptiveAvgPool2d((2,2))
        self.large_pool    = nn.AdaptiveAvgPool2d((3,3))

        total_dim = 128*1*1 + 128*2*2 + 128*3*3
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.1),
            nn.Linear(total_dim, out_dim*2),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(0.1),
            nn.Linear(out_dim*2, out_dim),
            nn.LeakyReLU(negative_slope),
        )

    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        g = self.global_pool(x).contiguous().reshape(x.size(0), -1)
        m = self.mid_pool(x).contiguous().reshape(x.size(0), -1)
        l = self.large_pool(x).contiguous().reshape(x.size(0), -1)

        return self.fc(torch.cat([g, m, l], dim=1)).contiguous()



class VisionMLP_MultiTask(nn.Module):
    """整體多任務模型：融合 tile + subtile + center，使用動態權重融合"""
    def __init__(self, tile_dim=128, subtile_dim=64, output_dim=35, negative_slope=0.01):
        super().__init__()
        self.encoder_tile    = DeepTileEncoder(tile_dim)
        self.encoder_subtile = SubtileEncoder(subtile_dim)
        self.encoder_center  = CenterSubtileEncoder(subtile_dim)

        # 輸出 decoder：輸入為 tile_dim (因為融合後只剩一個 vector)
        self.decoder = nn.Sequential(
            nn.Linear(tile_dim + subtile_dim + subtile_dim , 256),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(0.1),
            nn.Linear(64, output_dim),
        )

    def forward(self, tile, subtiles):
        tile = tile.contiguous()
        subtiles = subtiles.contiguous()
        center = subtiles[:, 4]

        f_tile = self.encoder_tile(tile)         # [B, tile_dim]
        f_sub  = self.encoder_subtile(subtiles)  # [B, subtile_dim]
        f_center = self.encoder_center(center)   # [B, subtile_dim]

        # 拼接三個分支做 gating
        features_cat = torch.cat([f_tile, f_sub, f_center], dim=1)  # [B, tile+sub+center]
        return self.decoder(features_cat)





# 用法示例
model = VisionMLP_MultiTask(tile_dim=128, subtile_dim=128, output_dim=35)


# —— 5) 确保只有 decoder 可训练 ——  
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total     = sum(p.numel() for p in model.parameters())
print(f"Trainable / total params = {trainable:,} / {total:,}")

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total     = sum(p.numel() for p in model.parameters())
print(f"Trainable / total params = {trainable:,} / {total:,}")
model


Trainable / total params = 6,679,843 / 6,679,843
Trainable / total params = 6,679,843 / 6,679,843


VisionMLP_MultiTask(
  (encoder_tile): DeepTileEncoder(
    (layer0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (layer1): Sequential(
      (0): ResidualBlock(
        (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): SiLU()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential(
          (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        

## Load Model

# Import training data

## Same in multiple .pt

In [2]:
import os
import torch
import random
import inspect
from python_scripts.import_data import load_all_tile_data

# 用法範例
#folder = "dataset/spot-rank/version-3/only_tile_sub/original_train"
folder = "dataset/spot-rank/filtered_directly_rank/masked/realign/Macenko_4_7_masked/filtered/train_data/"

grouped_data = load_all_tile_data( 
        folder_path=folder,
        model=model,
        fraction=1,
        shuffle=False
    )

    # grouped_data 現在只會有 model.forward() 需要的 key，
    # 像 ['tile','subtiles','neighbors','norm_coord','node_feat','adj_list','edge_feat','label','source_idx']
print("Loaded keys:", grouped_data.keys())
print("Samples:", len(next(iter(grouped_data.values()))))




  from .autonotebook import tqdm as notebook_tqdm
  d = torch.load(fpath, map_location='cpu')


Loaded keys: dict_keys(['source_idx', 'position', 'slide_idx', 'subtiles', 'tile', 'label'])
Samples: 8348


In [36]:
from python_scripts.import_data import convert_item, get_model_inputs

import torch
from torch.utils.data import Dataset
import inspect
import numpy as np

class importDataset(Dataset):
    def __init__(self, data_dict, model, image_keys=None, transform=None, print_sig=False):
        self.data = data_dict
        self.image_keys = set(image_keys) if image_keys is not None else set()
        self.transform = transform if transform is not None else lambda x: x
        self.forward_keys = list(get_model_inputs(model, print_sig=print_sig).parameters.keys())

        expected_length = None
        for key, value in self.data.items():
            if expected_length is None:
                expected_length = len(value)
            if len(value) != expected_length:
                raise ValueError(f"資料欄位 '{key}' 的長度 ({len(value)}) 與預期 ({expected_length}) 不一致。")

        for key in self.forward_keys:
            if key not in self.data:
                raise ValueError(f"data_dict 缺少模型 forward 所需欄位: '{key}'。目前可用的欄位: {list(self.data.keys())}")
        if "label" not in self.data:
            raise ValueError(f"data_dict 必須包含 'label' 欄位。可用的欄位: {list(self.data.keys())}")
        if "source_idx" not in self.data:
            raise ValueError("data_dict 必須包含 'source_idx' 欄位，用於 trace 原始順序對應。")
        if "position" not in self.data:
            raise ValueError("data_dict 必須包含 'position' 欄位，用於 trace 原始順序對應。")
    def __len__(self):
        return len(next(iter(self.data.values())))

    def __getitem__(self, idx):
        sample = {}
        for key in self.forward_keys:
            value = self.data[key][idx]
            value = self.transform(value)
            value = convert_item(value, is_image=(key in self.image_keys))
            if isinstance(value, torch.Tensor):
                value = value.float()
            sample[key] = value

        label = self.transform(self.data["label"][idx])
        label = convert_item(label, is_image=False)
        if isinstance(label, torch.Tensor):
            label = label.float()
        sample["label"] = label

        # 加入 source_idx
        source_idx = self.data["source_idx"][idx]
        sample["source_idx"] = torch.tensor(source_idx, dtype=torch.long)
        # 加入 position （假设 data_dict 中 'position' 是 (x, y) 或 [x, y]）
        pos = self.data["position"][idx]
        sample["position"] = torch.tensor(pos, dtype=torch.float)
        return sample
    def check_item(self, idx=0, num_lines=5):
        expected_keys = self.forward_keys + ['label', 'source_idx', 'position']
        sample = self[idx]
        print(f"🔍 Checking dataset sample: {idx}")
        for key in expected_keys:
            if key not in sample:
                print(f"❌ 資料中缺少 key: {key}")
                continue
            tensor = sample[key]
            if isinstance(tensor, torch.Tensor):
                try:
                    shape = tensor.shape
                except Exception:
                    shape = "N/A"
                dtype = tensor.dtype if hasattr(tensor, "dtype") else "N/A"
                output_str = f"📏 {key} shape: {shape} | dtype: {dtype}"
                if tensor.numel() > 0:
                    try:
                        tensor_float = tensor.float()
                        mn = tensor_float.min().item()
                        mx = tensor_float.max().item()
                        mean = tensor_float.mean().item()
                        std = tensor_float.std().item()
                        output_str += f" | min: {mn:.3f}, max: {mx:.3f}, mean: {mean:.3f}, std: {std:.3f}"
                    except Exception:
                        output_str += " | 無法計算統計數據"
                print(output_str)
                if key not in self.image_keys:
                    if tensor.ndim == 0:
                        print(f"--- {key} 資料為純量:", tensor)
                    elif tensor.ndim == 1:
                        print(f"--- {key} head (前 {num_lines} 個元素):")
                        print(tensor[:num_lines])
                    else:
                        print(f"--- {key} head (前 {num_lines} 列):")
                        print(tensor[:num_lines])
            else:
                # 如果 position 存的是 list/tuple/etc，也会走这里
                print(f"📏 {key} (非 tensor 資料):", tensor)
        print("✅ All checks passed!")


full_dataset = importDataset(grouped_data, model,
                             image_keys=['tile','subtiles'],
                             transform=lambda x: x)

full_dataset.check_item()

🔍 Checking dataset sample: 0
📏 tile shape: torch.Size([3, 78, 78]) | dtype: torch.float32 | min: 0.129, max: 1.000, mean: 0.653, std: 0.150
📏 subtiles shape: torch.Size([9, 3, 26, 26]) | dtype: torch.float32 | min: 0.129, max: 1.000, mean: 0.653, std: 0.150
📏 label shape: torch.Size([35]) | dtype: torch.float32 | min: 1.000, max: 35.000, mean: 18.000, std: 10.247
--- label head (前 5 個元素):
tensor([12., 24., 18.,  6., 30.])
📏 source_idx shape: torch.Size([]) | dtype: torch.int64 | min: 0.000, max: 0.000, mean: 0.000, std: nan
--- source_idx 資料為純量: tensor(0)
📏 position shape: torch.Size([2]) | dtype: torch.float32 | min: 0.171, max: 0.632, mean: 0.401, std: 0.326
--- position head (前 5 個元素):
tensor([0.6318, 0.1707])
✅ All checks passed!


  std = tensor_float.std().item()


In [22]:

from python_scripts.image_features import  *
from python_scripts.prediction_features import  *
import numpy as np
from torch.utils.data import DataLoader

# === Main Function with Names ===
def generate_meta_features(dataset, model_for_recon, device, ae_type, oof_preds = None, latents = None):
    """
    Generate meta-features and corresponding names.

    Returns
    -------
    features : np.ndarray, shape (n_samples, n_features)
    names    : list of str, length n_features
    """

    loader = DataLoader(dataset, batch_size=64, shuffle=False)

    # 1) 收集所有 (feats, names) 到同一个 outputs 列表
    outputs = []

    # # AE reconstruction loss
    feats, names = compute_ae_reconstruction_loss(model_for_recon, loader, device, ae_type)
    feats = feats[:, None]
    outputs.append((feats, names))

    # # AE embeddings
    feats, names = compute_ae_embeddings(loader, model_for_recon, device)
    outputs.append((feats, names))
    if latents is not None:
        # 原始 35 维 preds
        n_classes = latents.shape[1]
        raw_names = [f"trained-latents_{i}" for i in range(n_classes)]
        outputs.append((latents, raw_names))
    # # # Latent stats
    latent_feats = outputs[1][0]
    feats, names = compute_latent_stats(latent_feats)
    outputs.append((feats, names))

    # # # RGB stats

    feats, names = compute_center_subtile_rgb_stats(dataset)
    outputs.append((feats, names))
    feats, names = compute_subtiles_except_center_rgb_stats(dataset)
    outputs.append((feats, names))
    feats, names = compute_tile_rgb_stats(dataset)
    outputs.append((feats, names))
    feats, names = compute_subtile_contrast_stats(dataset)
    outputs.append((feats, names))

    # # # Texture & pattern features
    feats, names = compute_wavelet_stats(dataset)
    outputs.append((feats, names))
    feats, names = compute_sobel_stats(dataset)
    outputs.append((feats, names))

    # Color & distribution features
    feats, names = compute_hsv_stats(dataset)
    outputs.append((feats, names))

    # H&E stain features
    feats, names = compute_he_stats(dataset)
    outputs.append((feats, names))

    # Sliding-window std stats
    feats, names = compute_sliding_window_stats(dataset)
    outputs.append((feats, names))

    # # 9) OOF-based features (only if provided)
    if oof_preds is not None:
        # 原始 35 维 preds
        n_classes = oof_preds.shape[1]
        raw_names = [f"oof_pred_{i}" for i in range(n_classes)]
        outputs.append((oof_preds, raw_names))

        # # 相邻差异
        feats, names = compute_adjacent_diffs(oof_preds, stride=1)
        outputs.append((feats, names))

        # # top-2..top-6 统计
        feats, names = compute_lastn_stats_multi(oof_preds, max_n=35)
        outputs.append((feats, names))
        feats, names = compute_topn_stats_multi(oof_preds, max_n=6)
        outputs.append((feats, names))
        # 更多可选——只需取消注释即可
        feats, names = compute_adj_diff_histogram(oof_preds)
        outputs.append((feats, names))
        feats, names = compute_multi_stride_diffs(oof_preds)
        outputs.append((feats, names))
        feats, names = compute_median_mad(oof_preds)
        outputs.append((feats, names))
        feats, names = compute_skew_kurt(oof_preds)
        outputs.append((feats, names))
        feats, names = compute_percentile_iqr(oof_preds)
        outputs.append((feats, names))
        feats, names = compute_renyi_entropy(oof_preds, alpha=2)
        outputs.append((feats, names))
        feats, names = compute_mass_topk(oof_preds, k=5)
        outputs.append((feats, names))
        feats, names = compute_cdf_slope(oof_preds)
        outputs.append((feats, names))
        feats, names = compute_pca_components(oof_preds, n_components=10)
        outputs.append((feats, names))
        feats, names = compute_peak_stats(oof_preds)
        outputs.append((feats, names))
        feats, names = compute_segment_stats(oof_preds)
        outputs.append((feats, names))
        feats, names = compute_ar_coeffs(oof_preds)
        outputs.append((feats, names))
        feats, names = compute_autocorr_features(oof_preds)
        outputs.append((feats, names))
        feats, names = compute_second_order_diffs(oof_preds)
        outputs.append((feats, names))
        feats, names = compute_third_order_diffs(oof_preds)
        outputs.append((feats, names))
        feats, names = compute_relative_diffs(oof_preds)
        outputs.append((feats, names))
        feats, names = compute_diff_ratio_of_diffs(oof_preds)
        outputs.append((feats, names))


    # 2) unzip 成 feat_list 和 name_seq
    feat_list, name_seq = zip(*outputs)

    # 3) 逐块校验 feats 列数与 names 长度
    for feats, names_block in zip(feat_list, name_seq):
        ncols = feats.shape[1] if feats.ndim == 2 else 1
        if ncols != len(names_block):
            raise ValueError(
                f"Mismatch: got {ncols} columns but {len(names_block)} names "
                f"in block '{names_block[0].split('_')[0]}'"
            )
        print(
            f"{names_block[0].split('_')[0]:12s} -> cols: {ncols:4d}, names: {len(names_block):4d} OK"
        )

    # 4) 扁平化 names 并拼接 features
    name_list = [nm for block in name_seq for nm in block]
    features = np.concatenate(feat_list, axis=1)
    print(f"✅ Generated meta-features with shape: {features.shape}")

    return features, name_list


In [8]:
from collections import defaultdict
import numpy as np

def diagnose_meta_nonfinite(meta: np.ndarray, names: list[str]):
    """
    按名字前缀分组，统计每组：
      - 原始特征数（列数）
      - 总值数（列数 × 行数）
      - NaN 值数量
      - ±Inf 值数量
      - 非数值（non-finite）总数

    Parameters
    ----------
    meta  : np.ndarray, shape (n_samples, n_features)
    names : list of str, length n_features

    Returns
    -------
    stats : dict[prefix, dict]  
        每个 prefix 对应一个字典，
        包含 'n_feats','total_vals','n_nan','n_inf','n_nonfinite'。
    """
    groups = defaultdict(list)
    # 按前缀分组
    for idx, nm in enumerate(names):
        prefix = nm.split('_', 1)[0]
        groups[prefix].append(idx)

    stats = {}
    for prefix, idxs in groups.items():
        sub = meta[:, idxs]  # shape (n_samples, n_group_feats)
        n_feats = sub.shape[1]
        total_vals = sub.size
        n_nan = np.isnan(sub).sum()
        n_inf = np.isinf(sub).sum()
        n_nonfinite = (~np.isfinite(sub)).sum()

        stats[prefix] = {
            'n_feats':        n_feats,
            'total_vals':     total_vals,
            'n_nan':          int(n_nan),
            'n_inf':          int(n_inf),
            'n_nonfinite':    int(n_nonfinite),
        }
        print(
            f"Group '{prefix}': "
            f"features={n_feats}, "
            f"values={total_vals}, "
            f"non-finite={n_nonfinite} "
            f"(nan={n_nan}, inf={n_inf})"
        )
    return stats


In [None]:
import os
import numpy as np
import joblib
import torch
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.multioutput import MultiOutputRegressor
from sklearn.model_selection import train_test_split
import lightgbm as lgb
from scipy.stats import rankdata
from python_scripts.import_data import importDataset
from python_scripts.operate_model import predict
from lightgbm import early_stopping, log_evaluation
import h5py
import pandas as pd
from python_scripts.pretrain_model import PretrainedEncoderRegressor

import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
from torch.optim.lr_scheduler import ReduceLROnPlateau

# ---------------- Settings ----------------
trained_oof_model_folder = 'output_folder/rank-spot/realign/no_pretrain/3_encoder/filtered_directly_rank/k-fold/realign_all/stain_nor_with_4_7/Macenko_masked/'
n_folds    = len([d for d in os.listdir(trained_oof_model_folder) if d.startswith('fold')])
n_samples  = len(full_dataset)
C          = 35
BATCH_SIZE = 64
start_fold = 0

tile_dim = 128
center_dim = 128
neighbor_dim = 128
fusion_dim = tile_dim + center_dim + neighbor_dim
META_EPOCHS = 200
pretrained_ae_name = 'AE_Center_noaug'
pretrained_ae_path = f"AE_model/128/{pretrained_ae_name}/best.pt"
ae_type = 'center'

# Ground truth label (全 dataset)
y_true = np.vstack([ full_dataset[i]['label'].cpu().numpy() for i in range(n_samples) ])

# Build CV splitter (must match first stage splits)
logo = LeaveOneGroupOut()
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

class DeepResMLP(nn.Module):
    def __init__(self, in_dim=2924, hidden_dims=[1024,512,256,128], out_dim=35):
        super().__init__()
        layers = []
        prev = in_dim
        for h in hidden_dims:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.LeakyReLU(0.01))
            layers.append(nn.BatchNorm1d(h))
            layers.append(nn.Dropout(0.2))
            prev = h
        self.net = nn.Sequential(*layers)
        self.head = nn.Linear(prev, out_dim)

        # 残差桥：in_dim→first hidden
        if in_dim != hidden_dims[0]:
            self.res_proj = nn.Linear(in_dim, hidden_dims[0])
        else:
            self.res_proj = nn.Identity()

    def forward(self, x):
        # x: [B,2924], pred: [B,35]
        # 残差桥接
        res = self.res_proj(x)
        h   = self.net[0:4](x)       # 第一段 linear→act→bn→drop
        h  += res                    # add residual
        h   = self.net[4:](h)        # 剩余层
        delta = self.head(h)
        return delta

recon_model = PretrainedEncoderRegressor(
        ae_checkpoint=pretrained_ae_path,
        ae_type=ae_type,
        tile_dim=tile_dim,
        center_dim=center_dim,
        neighbor_dim=neighbor_dim,
        output_dim=C,
        mode='reconstruction'
    ).to(device)
slide_idx = np.array(grouped_data['slide_idx'])   # shape (N,)


  ae.load_state_dict(torch.load(ae_checkpoint, map_location="cpu"))


# Per model per meta model

In [None]:
META_EPOCHS = 300
start_fold = 0
repeats = 3
from python_scripts.aug         import augment_grouped_data, identity, subset_grouped_data
from scipy.stats import rankdata

def extract_feats_preds(loader, rank_preds=False):
    all_latents, all_preds, all_labels = [], [], []
    with torch.no_grad():
        for batch in loader:
            tiles    = batch['tile'].to(device)
            subtiles = batch['subtiles'].to(device)
            labels   = batch['label']                 # 假设 batch['label'] 在 cpu
            center   = subtiles[:, 4].contiguous()

            f_c  = net.encoder_center(center)
            f_n  = net.encoder_subtile(subtiles)
            f_t  = net.encoder_tile(tiles)
            fuse = torch.cat([f_c, f_n, f_t], dim=1)

            out  = net.decoder(fuse)

            all_latents.append(fuse.cpu())
            all_preds.append(out.cpu())
            all_labels.append(labels)

    # cat & to numpy
    latents = torch.cat(all_latents, dim=0).numpy()  # (N, D)
    preds   = torch.cat(all_preds,   dim=0).numpy()  # (N, 35)
    labs    = torch.cat(all_labels,  dim=0).numpy()  # (N, 35)

    if rank_preds:
        # 对每一行 row 做 rank，数值越大 rank 越大；最小的数 rank=1
        # 注意：rankdata 默认 smallest→1, largest→N
        # 如果你想让 largest→35, smallest→1，则：
        ranks = np.apply_along_axis(lambda row: rankdata(row, method='ordinal'), 1, preds)
        preds = ranks.astype(np.float32)

    return latents, preds, labs
# 在 loop 之前

fold_mse = {}
for fold_id, (tr_idx, va_idx) in enumerate(
    logo.split(X=np.zeros(n_samples), y=None, groups=slide_idx)):

    # if fold_id > start_fold:
    #     print(f"⏭️ Skipping fold {fold_id}")
    #     continue

    print(f"\n🚀 Starting fold {fold_id}...")
    ckpt_path = os.path.join(trained_oof_model_folder, f"fold{fold_id}", "best_model.pt")

    # === Load model and predict OOF ===
    net = VisionMLP_MultiTask(tile_dim=tile_dim, subtile_dim=center_dim, output_dim=C)
    net.load_state_dict(torch.load(ckpt_path, map_location=device))
    net = net.to(device).eval()


    
    local_idx = np.arange(len(va_idx))
    train_loc, val_loc = train_test_split(
        local_idx,
        test_size=0.2,
        random_state=42,
        shuffle=True
    )
    train_meta_idx = va_idx[train_loc]   # 真实的 global indices
    val_meta_idx   = va_idx[val_loc]

    # 2) 对 train_meta 做 augment
    train_base = subset_grouped_data(grouped_data, train_meta_idx)
    print("🌀 Starting augment for meta-train …")
    train_aug_ds = augment_grouped_data(
        grouped_data=train_base,
        image_keys=['tile','subtiles'],
        repeats=repeats            # 你要的增强次数
    )
    print("🌀 Starting import sugmentation data …")

    train_dataset = importDataset(train_aug_ds, net,
                                image_keys=['tile','subtiles'],
                                transform=lambda x: x)
    train_aug_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
    val_base = subset_grouped_data(grouped_data, val_meta_idx)
    val_dataset = importDataset(val_base, net,
                             image_keys=['tile','subtiles'],
                             transform=lambda x: x)
    val_meta_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

    print("🌀 Starting prepare OOF data from CNN model…")
    train_lat, train_pr, train_y = extract_feats_preds(train_aug_loader)
    
    val_lat,   val_pr,   val_y   = extract_feats_preds(val_meta_loader)

    extra_train_feats, train_feat_names = generate_meta_features(
        dataset=train_dataset,
        model_for_recon=recon_model,      # 你的 AE 模型
        device=device,
        ae_type=ae_type,
        oof_preds=train_pr,
        latents=train_lat
    )
    diagnose_meta_nonfinite(extra_train_feats, train_feat_names)
    extra_val_feats, val_feat_names = generate_meta_features(
        dataset=val_dataset,
        model_for_recon=recon_model,
        device=device,
        ae_type=ae_type,
        oof_preds=val_pr,
        latents=val_lat
    )
    diagnose_meta_nonfinite(extra_val_feats, val_feat_names)

    # 5) 把三部分特征拼在一起： [latents | preds | extra_feats]
    X_feat_train = np.concatenate([train_lat, extra_train_feats], axis=1)
    X_feat_val   = np.concatenate([val_lat, extra_val_feats  ], axis=1)
    # 这里 D_feat = X_feat_train.shape[1]
    D_feat = X_feat_train.shape[1]

    # 把 features, predictions, labels 三个张量放一起
    ds_train_meta = TensorDataset(
        torch.from_numpy(X_feat_train).float(),
        torch.from_numpy(train_pr).float(),
        torch.from_numpy(train_y).float()
    )
    ds_val_meta   = TensorDataset(
        torch.from_numpy(X_feat_val).float(),
        torch.from_numpy(val_pr).float(),
        torch.from_numpy(val_y).float()
    )

    loader_train_meta = DataLoader(ds_train_meta, batch_size=BATCH_SIZE, shuffle=True)
    loader_val_meta   = DataLoader(ds_val_meta,   batch_size=BATCH_SIZE, shuffle=False)

    # —— 6) 初始化 DeepResMLP —— 
    # 让它的输入维度等于 D_feat
    meta_model = DeepResMLP(in_dim=D_feat, hidden_dims=[1024,512,256,128], out_dim=C).to(device)

    criterion      = nn.MSELoss()
    optimizer_meta = torch.optim.Adam(meta_model.parameters(), lr=1e-3)
    scheduler_meta = ReduceLROnPlateau(optimizer_meta, mode='min', factor=0.5, patience=10, verbose=True)

    best_loss, es_cnt, es_patience = float('inf'), 0, 20
    best_path = os.path.join(trained_oof_model_folder, f"fold{fold_id}", "meta_model_best.pt")

    # —— 7) 训练循环 ——  
    for epoch in range(1, META_EPOCHS+1):
        # ——— 训练 ———
        meta_model.train()
        train_loss = 0.0
        for feats, pr, yb in loader_train_meta:
            feats, pr, yb = feats.to(device), pr.to(device), yb.to(device)
            out = meta_model(feats, pr)
            loss = criterion(out, yb)

            optimizer_meta.zero_grad()
            loss.backward()
            optimizer_meta.step()
            train_loss += loss.item() * feats.size(0)
        train_loss /= len(ds_train_meta)

        # ——— 验证 ———
        meta_model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for feats, pr, yb in loader_val_meta:
                feats, pr, yb = feats.to(device), pr.to(device), yb.to(device)
                out = meta_model(feats, pr)
                val_loss += criterion(out, yb).item() * feats.size(0)
        val_loss /= len(ds_val_meta)

        print(f"Fold{fold_id} Ep{epoch}/{META_EPOCHS} — "
            f"Train MSE: {train_loss:.6f}, Val MSE: {val_loss:.6f}, "
            f"LR: {optimizer_meta.param_groups[0]['lr']:.3e}")

        # 调度 & 早停 & 保存最佳
        scheduler_meta.step(val_loss)
        if val_loss < best_loss:
            best_loss, es_cnt = val_loss, 0
            torch.save(meta_model.state_dict(), best_path)
            print(f" ↳ New best (Val MSE={best_loss:.6f}), saved.")
        else:
            es_cnt += 1
            if es_cnt >= es_patience:
                print(f" ✋ Early stopping (no improvement in {es_patience} epochs).")
                break

    # —— 8) 加载最佳模型并评估（可选） ——  
    meta_model.load_state_dict(torch.load(best_path, map_location=device))
    meta_model.eval()
    with torch.no_grad():
        feats_all = torch.from_numpy(X_feat_val).float().to(device)
        pr_all    = torch.from_numpy(val_pr).float().to(device)
        refined_preds = meta_model(feats_all, pr_all).cpu().numpy()

    mse_val = ((refined_preds - val_y) ** 2).mean()
    print(f" Fold{fold_id} Best refined Val MSE: {mse_val:.6f}")
    print(f" ✅ Best meta-model saved to: {best_path}")
    fold_mse[fold_id] = mse_val
import pandas as pd
mse_df = pd.DataFrame.from_dict(fold_mse, orient='index', columns=['val_mse'])
mse_df.index.name = 'fold'
mse_df.to_csv(os.path.join(trained_oof_model_folder, "fold_meta_mse.csv"))


🚀 Starting fold 0...


  net.load_state_dict(torch.load(ckpt_path, map_location=device))


🌀 Starting augment for meta-train …
🌀 Starting import sugmentation data …
🌀 Starting prepare OOF data from CNN model…


Computing AE recon loss: 100%|██████████| 110/110 [00:11<00:00,  9.41it/s]


ae-recon-loss -> cols:    1, names:    1 OK
ae           -> cols:  384, names:  384 OK
trained-latents -> cols:  384, names:  384 OK
latent       -> cols:    4, names:    4 OK
subtile4     -> cols:   12, names:   12 OK
exsubtiles   -> cols:   12, names:   12 OK
tile         -> cols:   12, names:   12 OK
contrast     -> cols:    3, names:    3 OK
wavelet-tile -> cols:  280, names:  280 OK
sobel-tile   -> cols:   40, names:   40 OK
hsv-tile     -> cols:  120, names:  120 OK
he-tile      -> cols:   80, names:   80 OK
locstd       -> cols:  432, names:  432 OK
oof          -> cols:   35, names:   35 OK
adj          -> cols:   34, names:   34 OK
last         -> cols:  136, names:  136 OK
top          -> cols:   20, names:   20 OK
adj-his      -> cols:   10, names:   10 OK
diff         -> cols:  595, names:  595 OK
mad          -> cols:    1, names:    1 OK
skewness     -> cols:    2, names:    2 OK
p25          -> cols:    4, names:    4 OK
renyi        -> cols:    1, names:    1 OK
mass   

Computing AE recon loss: 100%|██████████| 7/7 [00:01<00:00,  5.72it/s]


ae-recon-loss -> cols:    1, names:    1 OK
ae           -> cols:  384, names:  384 OK
trained-latents -> cols:  384, names:  384 OK
latent       -> cols:    4, names:    4 OK
subtile4     -> cols:   12, names:   12 OK
exsubtiles   -> cols:   12, names:   12 OK
tile         -> cols:   12, names:   12 OK
contrast     -> cols:    3, names:    3 OK
wavelet-tile -> cols:  280, names:  280 OK
sobel-tile   -> cols:   40, names:   40 OK
hsv-tile     -> cols:  120, names:  120 OK
he-tile      -> cols:   80, names:   80 OK
locstd       -> cols:  432, names:  432 OK
oof          -> cols:   35, names:   35 OK
adj          -> cols:   34, names:   34 OK
last         -> cols:  136, names:  136 OK
top          -> cols:   20, names:   20 OK
adj-his      -> cols:   10, names:   10 OK
diff         -> cols:  595, names:  595 OK
mad          -> cols:    1, names:    1 OK
skewness     -> cols:    2, names:    2 OK
p25          -> cols:    4, names:    4 OK
renyi        -> cols:    1, names:    1 OK
mass   



Fold0 Ep1/300 — Train MSE: 405.319227, Val MSE: 376.475567, LR: 1.000e-03
 ↳ New best (Val MSE=376.475567), saved.
Fold0 Ep2/300 — Train MSE: 280.415148, Val MSE: 178.369030, LR: 1.000e-03
 ↳ New best (Val MSE=178.369030), saved.
Fold0 Ep3/300 — Train MSE: 105.109123, Val MSE: 61.640348, LR: 1.000e-03
 ↳ New best (Val MSE=61.640348), saved.
Fold0 Ep4/300 — Train MSE: 49.255349, Val MSE: 49.259682, LR: 1.000e-03
 ↳ New best (Val MSE=49.259682), saved.
Fold0 Ep5/300 — Train MSE: 43.230442, Val MSE: 46.381895, LR: 1.000e-03
 ↳ New best (Val MSE=46.381895), saved.
Fold0 Ep6/300 — Train MSE: 42.082382, Val MSE: 48.489742, LR: 1.000e-03
Fold0 Ep7/300 — Train MSE: 41.195338, Val MSE: 47.441790, LR: 1.000e-03
Fold0 Ep8/300 — Train MSE: 40.897712, Val MSE: 45.491223, LR: 1.000e-03
 ↳ New best (Val MSE=45.491223), saved.
Fold0 Ep9/300 — Train MSE: 40.880134, Val MSE: 51.667825, LR: 1.000e-03
Fold0 Ep10/300 — Train MSE: 40.398880, Val MSE: 46.396107, LR: 1.000e-03
Fold0 Ep11/300 — Train MSE: 40.1

  meta_model.load_state_dict(torch.load(best_path, map_location=device))
  net.load_state_dict(torch.load(ckpt_path, map_location=device))


🌀 Starting augment for meta-train …
🌀 Starting import sugmentation data …
🌀 Starting prepare OOF data from CNN model…


Computing AE recon loss: 100%|██████████| 114/114 [00:14<00:00,  7.68it/s]


ae-recon-loss -> cols:    1, names:    1 OK
ae           -> cols:  384, names:  384 OK
trained-latents -> cols:  384, names:  384 OK
latent       -> cols:    4, names:    4 OK
subtile4     -> cols:   12, names:   12 OK
exsubtiles   -> cols:   12, names:   12 OK
tile         -> cols:   12, names:   12 OK
contrast     -> cols:    3, names:    3 OK
wavelet-tile -> cols:  280, names:  280 OK
sobel-tile   -> cols:   40, names:   40 OK
hsv-tile     -> cols:  120, names:  120 OK
he-tile      -> cols:   80, names:   80 OK
locstd       -> cols:  432, names:  432 OK
oof          -> cols:   35, names:   35 OK
adj          -> cols:   34, names:   34 OK
last         -> cols:  136, names:  136 OK
top          -> cols:   20, names:   20 OK
adj-his      -> cols:   10, names:   10 OK
diff         -> cols:  595, names:  595 OK
mad          -> cols:    1, names:    1 OK
skewness     -> cols:    2, names:    2 OK
p25          -> cols:    4, names:    4 OK
renyi        -> cols:    1, names:    1 OK
mass   

Computing AE recon loss: 100%|██████████| 8/8 [00:01<00:00,  4.99it/s]


ae-recon-loss -> cols:    1, names:    1 OK
ae           -> cols:  384, names:  384 OK
trained-latents -> cols:  384, names:  384 OK
latent       -> cols:    4, names:    4 OK
subtile4     -> cols:   12, names:   12 OK
exsubtiles   -> cols:   12, names:   12 OK
tile         -> cols:   12, names:   12 OK
contrast     -> cols:    3, names:    3 OK
wavelet-tile -> cols:  280, names:  280 OK
sobel-tile   -> cols:   40, names:   40 OK
hsv-tile     -> cols:  120, names:  120 OK
he-tile      -> cols:   80, names:   80 OK
locstd       -> cols:  432, names:  432 OK
oof          -> cols:   35, names:   35 OK
adj          -> cols:   34, names:   34 OK
last         -> cols:  136, names:  136 OK
top          -> cols:   20, names:   20 OK
adj-his      -> cols:   10, names:   10 OK
diff         -> cols:  595, names:  595 OK
mad          -> cols:    1, names:    1 OK
skewness     -> cols:    2, names:    2 OK
p25          -> cols:    4, names:    4 OK
renyi        -> cols:    1, names:    1 OK
mass   



Fold1 Ep1/300 — Train MSE: 411.042008, Val MSE: 377.315096, LR: 1.000e-03
 ↳ New best (Val MSE=377.315096), saved.
Fold1 Ep2/300 — Train MSE: 270.428721, Val MSE: 162.301475, LR: 1.000e-03
 ↳ New best (Val MSE=162.301475), saved.
Fold1 Ep3/300 — Train MSE: 90.509911, Val MSE: 47.384867, LR: 1.000e-03
 ↳ New best (Val MSE=47.384867), saved.
Fold1 Ep4/300 — Train MSE: 36.676492, Val MSE: 28.724335, LR: 1.000e-03
 ↳ New best (Val MSE=28.724335), saved.
Fold1 Ep5/300 — Train MSE: 29.649644, Val MSE: 27.953990, LR: 1.000e-03
 ↳ New best (Val MSE=27.953990), saved.
Fold1 Ep6/300 — Train MSE: 28.837453, Val MSE: 28.143556, LR: 1.000e-03
Fold1 Ep7/300 — Train MSE: 28.451967, Val MSE: 28.259422, LR: 1.000e-03
Fold1 Ep8/300 — Train MSE: 27.967031, Val MSE: 28.058649, LR: 1.000e-03
Fold1 Ep9/300 — Train MSE: 27.670166, Val MSE: 28.164841, LR: 1.000e-03
Fold1 Ep10/300 — Train MSE: 27.576628, Val MSE: 28.310839, LR: 1.000e-03
Fold1 Ep11/300 — Train MSE: 27.387788, Val MSE: 28.470198, LR: 1.000e-03


  meta_model.load_state_dict(torch.load(best_path, map_location=device))
  net.load_state_dict(torch.load(ckpt_path, map_location=device))


🌀 Starting augment for meta-train …
🌀 Starting import sugmentation data …
🌀 Starting prepare OOF data from CNN model…


Computing AE recon loss: 100%|██████████| 35/35 [00:05<00:00,  6.46it/s]


ae-recon-loss -> cols:    1, names:    1 OK
ae           -> cols:  384, names:  384 OK
trained-latents -> cols:  384, names:  384 OK
latent       -> cols:    4, names:    4 OK
subtile4     -> cols:   12, names:   12 OK
exsubtiles   -> cols:   12, names:   12 OK
tile         -> cols:   12, names:   12 OK
contrast     -> cols:    3, names:    3 OK
wavelet-tile -> cols:  280, names:  280 OK
sobel-tile   -> cols:   40, names:   40 OK
hsv-tile     -> cols:  120, names:  120 OK
he-tile      -> cols:   80, names:   80 OK
locstd       -> cols:  432, names:  432 OK
oof          -> cols:   35, names:   35 OK
adj          -> cols:   34, names:   34 OK
last         -> cols:  136, names:  136 OK
top          -> cols:   20, names:   20 OK
adj-his      -> cols:   10, names:   10 OK
diff         -> cols:  595, names:  595 OK
mad          -> cols:    1, names:    1 OK
skewness     -> cols:    2, names:    2 OK
p25          -> cols:    4, names:    4 OK
renyi        -> cols:    1, names:    1 OK
mass   

Computing AE recon loss: 100%|██████████| 3/3 [00:00<00:00,  3.12it/s]


ae-recon-loss -> cols:    1, names:    1 OK
ae           -> cols:  384, names:  384 OK
trained-latents -> cols:  384, names:  384 OK
latent       -> cols:    4, names:    4 OK
subtile4     -> cols:   12, names:   12 OK
exsubtiles   -> cols:   12, names:   12 OK
tile         -> cols:   12, names:   12 OK
contrast     -> cols:    3, names:    3 OK
wavelet-tile -> cols:  280, names:  280 OK
sobel-tile   -> cols:   40, names:   40 OK
hsv-tile     -> cols:  120, names:  120 OK
he-tile      -> cols:   80, names:   80 OK
locstd       -> cols:  432, names:  432 OK
oof          -> cols:   35, names:   35 OK
adj          -> cols:   34, names:   34 OK
last         -> cols:  136, names:  136 OK
top          -> cols:   20, names:   20 OK
adj-his      -> cols:   10, names:   10 OK
diff         -> cols:  595, names:  595 OK
mad          -> cols:    1, names:    1 OK
skewness     -> cols:    2, names:    2 OK
p25          -> cols:    4, names:    4 OK
renyi        -> cols:    1, names:    1 OK
mass   



Fold2 Ep1/300 — Train MSE: 424.875498, Val MSE: 427.288564, LR: 1.000e-03
 ↳ New best (Val MSE=427.288564), saved.
Fold2 Ep2/300 — Train MSE: 418.162374, Val MSE: 413.792458, LR: 1.000e-03
 ↳ New best (Val MSE=413.792458), saved.
Fold2 Ep3/300 — Train MSE: 400.164258, Val MSE: 381.396480, LR: 1.000e-03
 ↳ New best (Val MSE=381.396480), saved.
Fold2 Ep4/300 — Train MSE: 361.391283, Val MSE: 327.860761, LR: 1.000e-03
 ↳ New best (Val MSE=327.860761), saved.
Fold2 Ep5/300 — Train MSE: 303.020588, Val MSE: 264.346616, LR: 1.000e-03
 ↳ New best (Val MSE=264.346616), saved.
Fold2 Ep6/300 — Train MSE: 234.182586, Val MSE: 194.364886, LR: 1.000e-03
 ↳ New best (Val MSE=194.364886), saved.
Fold2 Ep7/300 — Train MSE: 168.792973, Val MSE: 132.308159, LR: 1.000e-03
 ↳ New best (Val MSE=132.308159), saved.
Fold2 Ep8/300 — Train MSE: 117.709889, Val MSE: 94.294266, LR: 1.000e-03
 ↳ New best (Val MSE=94.294266), saved.
Fold2 Ep9/300 — Train MSE: 83.275832, Val MSE: 68.708540, LR: 1.000e-03
 ↳ New bes

  meta_model.load_state_dict(torch.load(best_path, map_location=device))
  net.load_state_dict(torch.load(ckpt_path, map_location=device))


🌀 Starting augment for meta-train …
🌀 Starting import sugmentation data …
🌀 Starting prepare OOF data from CNN model…


Computing AE recon loss: 100%|██████████| 60/60 [00:09<00:00,  6.45it/s]


ae-recon-loss -> cols:    1, names:    1 OK
ae           -> cols:  384, names:  384 OK
trained-latents -> cols:  384, names:  384 OK
latent       -> cols:    4, names:    4 OK
subtile4     -> cols:   12, names:   12 OK
exsubtiles   -> cols:   12, names:   12 OK
tile         -> cols:   12, names:   12 OK
contrast     -> cols:    3, names:    3 OK
wavelet-tile -> cols:  280, names:  280 OK
sobel-tile   -> cols:   40, names:   40 OK
hsv-tile     -> cols:  120, names:  120 OK
he-tile      -> cols:   80, names:   80 OK
locstd       -> cols:  432, names:  432 OK
oof          -> cols:   35, names:   35 OK
adj          -> cols:   34, names:   34 OK
last         -> cols:  136, names:  136 OK
top          -> cols:   20, names:   20 OK
adj-his      -> cols:   10, names:   10 OK
diff         -> cols:  595, names:  595 OK
mad          -> cols:    1, names:    1 OK
skewness     -> cols:    2, names:    2 OK
p25          -> cols:    4, names:    4 OK
renyi        -> cols:    1, names:    1 OK
mass   

Computing AE recon loss: 100%|██████████| 4/4 [00:01<00:00,  3.16it/s]


ae-recon-loss -> cols:    1, names:    1 OK
ae           -> cols:  384, names:  384 OK
trained-latents -> cols:  384, names:  384 OK
latent       -> cols:    4, names:    4 OK
subtile4     -> cols:   12, names:   12 OK
exsubtiles   -> cols:   12, names:   12 OK
tile         -> cols:   12, names:   12 OK
contrast     -> cols:    3, names:    3 OK
wavelet-tile -> cols:  280, names:  280 OK
sobel-tile   -> cols:   40, names:   40 OK
hsv-tile     -> cols:  120, names:  120 OK
he-tile      -> cols:   80, names:   80 OK
locstd       -> cols:  432, names:  432 OK
oof          -> cols:   35, names:   35 OK
adj          -> cols:   34, names:   34 OK
last         -> cols:  136, names:  136 OK
top          -> cols:   20, names:   20 OK
adj-his      -> cols:   10, names:   10 OK
diff         -> cols:  595, names:  595 OK
mad          -> cols:    1, names:    1 OK
skewness     -> cols:    2, names:    2 OK
p25          -> cols:    4, names:    4 OK
renyi        -> cols:    1, names:    1 OK
mass   



Fold3 Ep1/300 — Train MSE: 418.914582, Val MSE: 407.213725, LR: 1.000e-03
 ↳ New best (Val MSE=407.213725), saved.
Fold3 Ep2/300 — Train MSE: 388.166473, Val MSE: 364.770517, LR: 1.000e-03
 ↳ New best (Val MSE=364.770517), saved.
Fold3 Ep3/300 — Train MSE: 311.646583, Val MSE: 262.454061, LR: 1.000e-03
 ↳ New best (Val MSE=262.454061), saved.
Fold3 Ep4/300 — Train MSE: 200.667073, Val MSE: 152.127018, LR: 1.000e-03
 ↳ New best (Val MSE=152.127018), saved.
Fold3 Ep5/300 — Train MSE: 112.136992, Val MSE: 88.034260, LR: 1.000e-03
 ↳ New best (Val MSE=88.034260), saved.
Fold3 Ep6/300 — Train MSE: 70.441844, Val MSE: 66.564086, LR: 1.000e-03
 ↳ New best (Val MSE=66.564086), saved.
Fold3 Ep7/300 — Train MSE: 58.056061, Val MSE: 67.834291, LR: 1.000e-03
Fold3 Ep8/300 — Train MSE: 54.285114, Val MSE: 70.151232, LR: 1.000e-03
Fold3 Ep9/300 — Train MSE: 53.548904, Val MSE: 67.810923, LR: 1.000e-03
Fold3 Ep10/300 — Train MSE: 52.242772, Val MSE: 66.898035, LR: 1.000e-03
Fold3 Ep11/300 — Train MSE

  meta_model.load_state_dict(torch.load(best_path, map_location=device))


 Fold3 Best refined Val MSE: 51.745636
 ✅ Best meta-model saved to: output_folder/rank-spot/realign/no_pretrain/3_encoder/filtered_directly_rank/k-fold/realign_all/stain_nor_with_4_7/Macenko_masked/fold3/meta_model_best.pt

🚀 Starting fold 4...


  net.load_state_dict(torch.load(ckpt_path, map_location=device))


🌀 Starting augment for meta-train …
🌀 Starting import sugmentation data …
🌀 Starting prepare OOF data from CNN model…


Computing AE recon loss: 100%|██████████| 84/84 [00:14<00:00,  5.67it/s]


ae-recon-loss -> cols:    1, names:    1 OK
ae           -> cols:  384, names:  384 OK
trained-latents -> cols:  384, names:  384 OK
latent       -> cols:    4, names:    4 OK
subtile4     -> cols:   12, names:   12 OK
exsubtiles   -> cols:   12, names:   12 OK
tile         -> cols:   12, names:   12 OK
contrast     -> cols:    3, names:    3 OK
wavelet-tile -> cols:  280, names:  280 OK
sobel-tile   -> cols:   40, names:   40 OK
hsv-tile     -> cols:  120, names:  120 OK
he-tile      -> cols:   80, names:   80 OK
locstd       -> cols:  432, names:  432 OK
oof          -> cols:   35, names:   35 OK
adj          -> cols:   34, names:   34 OK
last         -> cols:  136, names:  136 OK
top          -> cols:   20, names:   20 OK
adj-his      -> cols:   10, names:   10 OK
diff         -> cols:  595, names:  595 OK
mad          -> cols:    1, names:    1 OK
skewness     -> cols:    2, names:    2 OK
p25          -> cols:    4, names:    4 OK
renyi        -> cols:    1, names:    1 OK
mass   

Computing AE recon loss: 100%|██████████| 6/6 [00:01<00:00,  3.79it/s]


ae-recon-loss -> cols:    1, names:    1 OK
ae           -> cols:  384, names:  384 OK
trained-latents -> cols:  384, names:  384 OK
latent       -> cols:    4, names:    4 OK
subtile4     -> cols:   12, names:   12 OK
exsubtiles   -> cols:   12, names:   12 OK
tile         -> cols:   12, names:   12 OK
contrast     -> cols:    3, names:    3 OK
wavelet-tile -> cols:  280, names:  280 OK
sobel-tile   -> cols:   40, names:   40 OK
hsv-tile     -> cols:  120, names:  120 OK
he-tile      -> cols:   80, names:   80 OK
locstd       -> cols:  432, names:  432 OK
oof          -> cols:   35, names:   35 OK
adj          -> cols:   34, names:   34 OK
last         -> cols:  136, names:  136 OK
top          -> cols:   20, names:   20 OK
adj-his      -> cols:   10, names:   10 OK
diff         -> cols:  595, names:  595 OK
mad          -> cols:    1, names:    1 OK
skewness     -> cols:    2, names:    2 OK
p25          -> cols:    4, names:    4 OK
renyi        -> cols:    1, names:    1 OK
mass   



Fold4 Ep1/300 — Train MSE: 409.448090, Val MSE: 395.539264, LR: 1.000e-03
 ↳ New best (Val MSE=395.539264), saved.
Fold4 Ep2/300 — Train MSE: 345.559379, Val MSE: 277.021697, LR: 1.000e-03
 ↳ New best (Val MSE=277.021697), saved.
Fold4 Ep3/300 — Train MSE: 199.972403, Val MSE: 129.048939, LR: 1.000e-03
 ↳ New best (Val MSE=129.048939), saved.
Fold4 Ep4/300 — Train MSE: 88.795603, Val MSE: 56.900571, LR: 1.000e-03
 ↳ New best (Val MSE=56.900571), saved.
Fold4 Ep5/300 — Train MSE: 55.122899, Val MSE: 45.398594, LR: 1.000e-03
 ↳ New best (Val MSE=45.398594), saved.
Fold4 Ep6/300 — Train MSE: 49.904167, Val MSE: 44.283881, LR: 1.000e-03
 ↳ New best (Val MSE=44.283881), saved.
Fold4 Ep7/300 — Train MSE: 49.051342, Val MSE: 45.661766, LR: 1.000e-03
Fold4 Ep8/300 — Train MSE: 48.755924, Val MSE: 48.140233, LR: 1.000e-03
Fold4 Ep9/300 — Train MSE: 47.674203, Val MSE: 45.336952, LR: 1.000e-03
Fold4 Ep10/300 — Train MSE: 47.661655, Val MSE: 50.403617, LR: 1.000e-03
Fold4 Ep11/300 — Train MSE: 47

  meta_model.load_state_dict(torch.load(best_path, map_location=device))
  net.load_state_dict(torch.load(ckpt_path, map_location=device))


🌀 Starting augment for meta-train …
🌀 Starting import sugmentation data …
🌀 Starting prepare OOF data from CNN model…


Computing AE recon loss: 100%|██████████| 17/17 [00:02<00:00,  6.11it/s]


ae-recon-loss -> cols:    1, names:    1 OK
ae           -> cols:  384, names:  384 OK
trained-latents -> cols:  384, names:  384 OK
latent       -> cols:    4, names:    4 OK
subtile4     -> cols:   12, names:   12 OK
exsubtiles   -> cols:   12, names:   12 OK
tile         -> cols:   12, names:   12 OK
contrast     -> cols:    3, names:    3 OK
wavelet-tile -> cols:  280, names:  280 OK
sobel-tile   -> cols:   40, names:   40 OK
hsv-tile     -> cols:  120, names:  120 OK
he-tile      -> cols:   80, names:   80 OK
locstd       -> cols:  432, names:  432 OK
oof          -> cols:   35, names:   35 OK
adj          -> cols:   34, names:   34 OK
last         -> cols:  136, names:  136 OK
top          -> cols:   20, names:   20 OK
adj-his      -> cols:   10, names:   10 OK
diff         -> cols:  595, names:  595 OK
mad          -> cols:    1, names:    1 OK
skewness     -> cols:    2, names:    2 OK
p25          -> cols:    4, names:    4 OK
renyi        -> cols:    1, names:    1 OK
mass   

Computing AE recon loss: 100%|██████████| 2/2 [00:00<00:00,  2.34it/s]


ae-recon-loss -> cols:    1, names:    1 OK
ae           -> cols:  384, names:  384 OK
trained-latents -> cols:  384, names:  384 OK
latent       -> cols:    4, names:    4 OK
subtile4     -> cols:   12, names:   12 OK
exsubtiles   -> cols:   12, names:   12 OK
tile         -> cols:   12, names:   12 OK
contrast     -> cols:    3, names:    3 OK
wavelet-tile -> cols:  280, names:  280 OK
sobel-tile   -> cols:   40, names:   40 OK
hsv-tile     -> cols:  120, names:  120 OK
he-tile      -> cols:   80, names:   80 OK
locstd       -> cols:  432, names:  432 OK
oof          -> cols:   35, names:   35 OK
adj          -> cols:   34, names:   34 OK
last         -> cols:  136, names:  136 OK
top          -> cols:   20, names:   20 OK
adj-his      -> cols:   10, names:   10 OK
diff         -> cols:  595, names:  595 OK
mad          -> cols:    1, names:    1 OK
skewness     -> cols:    2, names:    2 OK
p25          -> cols:    4, names:    4 OK
renyi        -> cols:    1, names:    1 OK
mass   



Fold5 Ep1/300 — Train MSE: 424.029936, Val MSE: 420.461090, LR: 1.000e-03
 ↳ New best (Val MSE=420.461090), saved.
Fold5 Ep2/300 — Train MSE: 419.462858, Val MSE: 419.982905, LR: 1.000e-03
 ↳ New best (Val MSE=419.982905), saved.
Fold5 Ep3/300 — Train MSE: 415.673991, Val MSE: 416.259880, LR: 1.000e-03
 ↳ New best (Val MSE=416.259880), saved.
Fold5 Ep4/300 — Train MSE: 410.761667, Val MSE: 405.759045, LR: 1.000e-03
 ↳ New best (Val MSE=405.759045), saved.
Fold5 Ep5/300 — Train MSE: 403.037614, Val MSE: 398.261005, LR: 1.000e-03
 ↳ New best (Val MSE=398.261005), saved.
Fold5 Ep6/300 — Train MSE: 391.263614, Val MSE: 388.079122, LR: 1.000e-03
 ↳ New best (Val MSE=388.079122), saved.
Fold5 Ep7/300 — Train MSE: 373.915523, Val MSE: 364.999288, LR: 1.000e-03
 ↳ New best (Val MSE=364.999288), saved.
Fold5 Ep8/300 — Train MSE: 351.082631, Val MSE: 332.308534, LR: 1.000e-03
 ↳ New best (Val MSE=332.308534), saved.
Fold5 Ep9/300 — Train MSE: 322.790691, Val MSE: 316.812773, LR: 1.000e-03
 ↳ New

  meta_model.load_state_dict(torch.load(best_path, map_location=device))


In [25]:
import torch
from python_scripts.import_data import load_node_feature_data


image_keys = [ 'tile', 'subtiles']

model = VisionMLP_MultiTask(tile_dim=tile_dim, subtile_dim=center_dim, output_dim=C)

# 用法示例
from python_scripts.import_data import importDataset
# 假设你的 model 已经定义好并实例化为 `model`
test_dataset = load_node_feature_data("dataset/spot-rank/filtered_directly_rank/masked/test/Macenko_4_7/test_dataset.pt", model)
test_dataset = importDataset(
        data_dict=test_dataset,
        model=model,
        image_keys=image_keys,
        transform=lambda x: x,  # identity transform
        print_sig=True
    )



  raw = torch.load(pt_path, map_location="cpu")


⚠️ 從 '<class 'list'>' 推斷樣本數量: 2088
Model forward signature: (tile, subtiles)


In [32]:
# --- 0) 先读入各 fold 的 val_mse，并计算权重 ---
mse_df = pd.read_csv(
    os.path.join(trained_oof_model_folder, "fold_meta_mse.csv"),
    index_col="fold"
).sort_index()
mses = mse_df["val_mse"].values          # shape (n_folds,)
weights = 1.0 / mses                      # 越低的 mse 权重大
weights = weights / weights.sum()         # 归一化和为 1
print("Ensemble weights per fold:", weights)

# --- 3) Prepare test meta-features & stacking predictions ---
n_test = len(test_dataset)
all_final = []

for fold_id in range(n_folds):
    # if fold_id > start_fold:
    #     print(f"⏭️ Skipping fold {fold_id}")
    #     continue
    print(f"🍀 Fold {fold_id} predicting ...")
    
    # 1) Load base model (VisionMLP_MultiTask) and get test_preds, test_latents
    ckpt_path   = os.path.join(trained_oof_model_folder, f"fold{fold_id}", "best_model.pt")
    net         = VisionMLP_MultiTask(tile_dim=tile_dim, subtile_dim=center_dim, output_dim=C)
    net.load_state_dict(torch.load(ckpt_path, map_location=device))
    net = net.to(device).eval()

    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    test_latents, test_preds, _ = extract_feats_preds(test_loader)
    extra_test_feats, test_feat_names = generate_meta_features(
        dataset=test_dataset,
        model_for_recon=recon_model,
        device=device,
        ae_type=ae_type,
        oof_preds=test_preds,
        latents=test_latents
    )
    diagnose_meta_nonfinite(extra_test_feats, test_feat_names)

    # 5) 把三部分特征拼在一起： [latents | preds | extra_feats]
    X_feat_test = np.concatenate([test_latents, extra_test_feats], axis=1)
    # 这里 D_feat = X_feat_train.shape[1]
    D_feat = X_feat_test.shape[1]

    # 把 features, predictions, labels 三个张量放一起
    ds_test_meta = TensorDataset(
        torch.from_numpy(X_feat_test).float(),
        torch.from_numpy(test_preds).float())


    loader_test_meta = DataLoader(ds_test_meta, batch_size=BATCH_SIZE, shuffle=False)

    # 2) Load the trained meta-learner for this fold
    D = test_latents.shape[1]
    meta_model = DeepResMLP(in_dim=D_feat, hidden_dims=[1024,512,256,128], out_dim=C).to(device)

    meta_path = os.path.join(
        trained_oof_model_folder,
        f"fold{fold_id}",
        "meta_model_best.pt"
    )
    meta_model.load_state_dict(torch.load(meta_path, map_location=device))
    meta_model.eval()

    # 3) 用 meta-learner 做 refined prediction
    with torch.no_grad():
        feats_tensor = torch.from_numpy(X_feat_test).float().to(device)
        preds_tensor = torch.from_numpy(test_preds).float().to(device)
        refined_pred = meta_model(feats_tensor, preds_tensor).cpu().numpy()

    all_final.append(refined_pred)

# --- 4) 加权 ensemble ---
# all_refined: list of (n_test, C) arrays
# weights: shape (n_folds,)
final_weighted = np.zeros_like(all_final[0])
for w, pred in zip(weights, all_final):
    final_weighted += w * pred

# --- 5) Save submissions ---
with h5py.File("./dataset/elucidata_ai_challenge_data.h5","r") as f:
    test_spot_ids = pd.DataFrame(np.array(f["spots/Test"]["S_7"]))

# unweighted average（保留旧版对比）
final_simple = np.mean(all_final, axis=0)

# 写两个文件
for name, arr in [("stacked", final_simple), ("weighted", final_weighted)]:
    sub = pd.DataFrame(arr, columns=[f"C{i+1}" for i in range(C)])
    sub.insert(0, 'ID', test_spot_ids.index)
    path = os.path.join(trained_oof_model_folder, f"submission_{name}.csv")
    sub.to_csv(path, index=False)
    print(f"✅ Saved {name} ensemble submission → {path}")

Ensemble weights per fold: [0.14668326 0.22717708 0.13848365 0.12272543 0.15362587 0.21130471]
🍀 Fold 0 predicting ...


  net.load_state_dict(torch.load(ckpt_path, map_location=device))
Computing AE recon loss: 100%|██████████| 33/33 [00:05<00:00,  6.41it/s]


ae-recon-loss -> cols:    1, names:    1 OK
ae           -> cols:  384, names:  384 OK
trained-latents -> cols:  384, names:  384 OK
latent       -> cols:    4, names:    4 OK
subtile4     -> cols:   12, names:   12 OK
exsubtiles   -> cols:   12, names:   12 OK
tile         -> cols:   12, names:   12 OK
contrast     -> cols:    3, names:    3 OK
wavelet-tile -> cols:  280, names:  280 OK
sobel-tile   -> cols:   40, names:   40 OK
hsv-tile     -> cols:  120, names:  120 OK
he-tile      -> cols:   80, names:   80 OK
locstd       -> cols:  432, names:  432 OK
oof          -> cols:   35, names:   35 OK
adj          -> cols:   34, names:   34 OK
last         -> cols:  136, names:  136 OK
top          -> cols:   20, names:   20 OK
adj-his      -> cols:   10, names:   10 OK
diff         -> cols:  595, names:  595 OK
mad          -> cols:    1, names:    1 OK
skewness     -> cols:    2, names:    2 OK
p25          -> cols:    4, names:    4 OK
renyi        -> cols:    1, names:    1 OK
mass   

  meta_model.load_state_dict(torch.load(meta_path, map_location=device))


🍀 Fold 1 predicting ...


  net.load_state_dict(torch.load(ckpt_path, map_location=device))
Computing AE recon loss: 100%|██████████| 33/33 [00:05<00:00,  6.33it/s]


ae-recon-loss -> cols:    1, names:    1 OK
ae           -> cols:  384, names:  384 OK
trained-latents -> cols:  384, names:  384 OK
latent       -> cols:    4, names:    4 OK
subtile4     -> cols:   12, names:   12 OK
exsubtiles   -> cols:   12, names:   12 OK
tile         -> cols:   12, names:   12 OK
contrast     -> cols:    3, names:    3 OK
wavelet-tile -> cols:  280, names:  280 OK
sobel-tile   -> cols:   40, names:   40 OK
hsv-tile     -> cols:  120, names:  120 OK
he-tile      -> cols:   80, names:   80 OK
locstd       -> cols:  432, names:  432 OK
oof          -> cols:   35, names:   35 OK
adj          -> cols:   34, names:   34 OK
last         -> cols:  136, names:  136 OK
top          -> cols:   20, names:   20 OK
adj-his      -> cols:   10, names:   10 OK
diff         -> cols:  595, names:  595 OK
mad          -> cols:    1, names:    1 OK
skewness     -> cols:    2, names:    2 OK
p25          -> cols:    4, names:    4 OK
renyi        -> cols:    1, names:    1 OK
mass   

  meta_model.load_state_dict(torch.load(meta_path, map_location=device))
  net.load_state_dict(torch.load(ckpt_path, map_location=device))


🍀 Fold 2 predicting ...


Computing AE recon loss: 100%|██████████| 33/33 [00:06<00:00,  4.86it/s]


ae-recon-loss -> cols:    1, names:    1 OK
ae           -> cols:  384, names:  384 OK
trained-latents -> cols:  384, names:  384 OK
latent       -> cols:    4, names:    4 OK
subtile4     -> cols:   12, names:   12 OK
exsubtiles   -> cols:   12, names:   12 OK
tile         -> cols:   12, names:   12 OK
contrast     -> cols:    3, names:    3 OK
wavelet-tile -> cols:  280, names:  280 OK
sobel-tile   -> cols:   40, names:   40 OK
hsv-tile     -> cols:  120, names:  120 OK
he-tile      -> cols:   80, names:   80 OK
locstd       -> cols:  432, names:  432 OK
oof          -> cols:   35, names:   35 OK
adj          -> cols:   34, names:   34 OK
last         -> cols:  136, names:  136 OK
top          -> cols:   20, names:   20 OK
adj-his      -> cols:   10, names:   10 OK
diff         -> cols:  595, names:  595 OK
mad          -> cols:    1, names:    1 OK
skewness     -> cols:    2, names:    2 OK
p25          -> cols:    4, names:    4 OK
renyi        -> cols:    1, names:    1 OK
mass   

  meta_model.load_state_dict(torch.load(meta_path, map_location=device))
  net.load_state_dict(torch.load(ckpt_path, map_location=device))


🍀 Fold 3 predicting ...


Computing AE recon loss: 100%|██████████| 33/33 [00:05<00:00,  6.24it/s]


ae-recon-loss -> cols:    1, names:    1 OK
ae           -> cols:  384, names:  384 OK
trained-latents -> cols:  384, names:  384 OK
latent       -> cols:    4, names:    4 OK
subtile4     -> cols:   12, names:   12 OK
exsubtiles   -> cols:   12, names:   12 OK
tile         -> cols:   12, names:   12 OK
contrast     -> cols:    3, names:    3 OK
wavelet-tile -> cols:  280, names:  280 OK
sobel-tile   -> cols:   40, names:   40 OK
hsv-tile     -> cols:  120, names:  120 OK
he-tile      -> cols:   80, names:   80 OK
locstd       -> cols:  432, names:  432 OK
oof          -> cols:   35, names:   35 OK
adj          -> cols:   34, names:   34 OK
last         -> cols:  136, names:  136 OK
top          -> cols:   20, names:   20 OK
adj-his      -> cols:   10, names:   10 OK
diff         -> cols:  595, names:  595 OK
mad          -> cols:    1, names:    1 OK
skewness     -> cols:    2, names:    2 OK
p25          -> cols:    4, names:    4 OK
renyi        -> cols:    1, names:    1 OK
mass   

  meta_model.load_state_dict(torch.load(meta_path, map_location=device))


🍀 Fold 4 predicting ...


  net.load_state_dict(torch.load(ckpt_path, map_location=device))
Computing AE recon loss: 100%|██████████| 33/33 [00:05<00:00,  6.06it/s]


ae-recon-loss -> cols:    1, names:    1 OK
ae           -> cols:  384, names:  384 OK
trained-latents -> cols:  384, names:  384 OK
latent       -> cols:    4, names:    4 OK
subtile4     -> cols:   12, names:   12 OK
exsubtiles   -> cols:   12, names:   12 OK
tile         -> cols:   12, names:   12 OK
contrast     -> cols:    3, names:    3 OK
wavelet-tile -> cols:  280, names:  280 OK
sobel-tile   -> cols:   40, names:   40 OK
hsv-tile     -> cols:  120, names:  120 OK
he-tile      -> cols:   80, names:   80 OK
locstd       -> cols:  432, names:  432 OK
oof          -> cols:   35, names:   35 OK
adj          -> cols:   34, names:   34 OK
last         -> cols:  136, names:  136 OK
top          -> cols:   20, names:   20 OK
adj-his      -> cols:   10, names:   10 OK
diff         -> cols:  595, names:  595 OK
mad          -> cols:    1, names:    1 OK
skewness     -> cols:    2, names:    2 OK
p25          -> cols:    4, names:    4 OK
renyi        -> cols:    1, names:    1 OK
mass   

  meta_model.load_state_dict(torch.load(meta_path, map_location=device))
  net.load_state_dict(torch.load(ckpt_path, map_location=device))


🍀 Fold 5 predicting ...


Computing AE recon loss: 100%|██████████| 33/33 [00:02<00:00, 11.57it/s]


ae-recon-loss -> cols:    1, names:    1 OK
ae           -> cols:  384, names:  384 OK
trained-latents -> cols:  384, names:  384 OK
latent       -> cols:    4, names:    4 OK
subtile4     -> cols:   12, names:   12 OK
exsubtiles   -> cols:   12, names:   12 OK
tile         -> cols:   12, names:   12 OK
contrast     -> cols:    3, names:    3 OK
wavelet-tile -> cols:  280, names:  280 OK
sobel-tile   -> cols:   40, names:   40 OK
hsv-tile     -> cols:  120, names:  120 OK
he-tile      -> cols:   80, names:   80 OK
locstd       -> cols:  432, names:  432 OK
oof          -> cols:   35, names:   35 OK
adj          -> cols:   34, names:   34 OK
last         -> cols:  136, names:  136 OK
top          -> cols:   20, names:   20 OK
adj-his      -> cols:   10, names:   10 OK
diff         -> cols:  595, names:  595 OK
mad          -> cols:    1, names:    1 OK
skewness     -> cols:    2, names:    2 OK
p25          -> cols:    4, names:    4 OK
renyi        -> cols:    1, names:    1 OK
mass   

  meta_model.load_state_dict(torch.load(meta_path, map_location=device))


✅ Saved stacked ensemble submission → output_folder/rank-spot/realign/no_pretrain/3_encoder/filtered_directly_rank/k-fold/realign_all/stain_nor_with_4_7/Macenko_masked/submission_stacked.csv
✅ Saved weighted ensemble submission → output_folder/rank-spot/realign/no_pretrain/3_encoder/filtered_directly_rank/k-fold/realign_all/stain_nor_with_4_7/Macenko_masked/submission_weighted.csv


In [None]:
import glob
import torch
import numpy as np
import pandas as pd
import os
import h5py
from torch.utils.data import DataLoader

# 讀 test spot index
with h5py.File("./dataset/elucidata_ai_challenge_data.h5","r") as f:
    test_spots     = f["spots/Test"]
    test_spot_table= pd.DataFrame(np.array(test_spots['S_7']))

fold_ckpts = sorted(glob.glob(os.path.join(trained_oof_model_folder, "fold*", "best_model.pt")))
models = []
for ckpt in fold_ckpts:
    net = VisionMLP_MultiTask(tile_dim=tile_dim, subtile_dim=center_dim, output_dim=C)
    net = net.to(device)
    net.load_state_dict(torch.load(ckpt, map_location="cpu"))
    net.to(device).eval()
    models.append(net)

all_fold_preds = []
for fold_id, net in enumerate(models):
    # 推論
    print(f"🍀 Fold {fold_id} predicting ...")
    with torch.no_grad():
        preds = predict(net, test_loader, device)  # (N_test,35) numpy array

    # 1) 存每一折的原始預測
    df_fold = pd.DataFrame(preds, columns=[f"C{i+1}" for i in range(preds.shape[1])])
    df_fold.insert(0, "ID", test_spot_table.index)
    path_fold = os.path.join(trained_oof_model_folder, f"submission_fold{fold_id}.csv")
    df_fold.to_csv(path_fold, index=False)
    print(f"✅ Saved fold {fold_id} predictions to {path_fold}")

    all_fold_preds.append(preds)

# 2) 做 rank‐average ensemble
all_fold_preds = np.stack(all_fold_preds, axis=0)       # (K, N_test, 35)
ranks          = all_fold_preds.argsort(axis=2).argsort(axis=2).astype(float)
mean_rank      = ranks.mean(axis=0)                    # (N_test,35)

# 3) 存 final ensemble
df_ens = pd.DataFrame(mean_rank, columns=[f"C{i+1}" for i in range(mean_rank.shape[1])])
df_ens.insert(0, "ID", test_spot_table.index)
path_ens = os.path.join(trained_oof_model_folder, "submission_rank_ensemble.csv")
df_ens.to_csv(path_ens, index=False)
print(f"✅ Saved rank‐ensemble submission to {path_ens}")


array([[21.119944, 24.191807, 24.289158, ..., 12.108937,  9.812806,
        23.269012],
       [20.796225, 24.415634, 24.210833, ..., 12.384914,  9.888517,
        24.08245 ],
       [24.244806, 28.043358, 25.876398, ...,  9.660413,  9.817519,
        19.834955],
       ...,
       [21.064796, 30.208473, 23.31816 , ..., 11.452721,  8.848305,
        20.618204],
       [20.672852, 24.048859, 23.066084, ..., 12.829765, 10.746237,
        25.129263],
       [18.379812, 24.315762, 22.625013, ..., 12.282604, 10.867987,
        25.213802]], dtype=float32)

# model stacking with one meta model

In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from python_scripts.import_data import importDataset  # 假设这个能给你 full_dataset
from python_scripts.aug import subset_grouped_data   # 用来切出 grouped_data

# Settings
trained_oof_model_folder = 'output_folder/rank-spot/realign/no_pretrain/3_encoder/filtered_directly_rank/k-fold/realign_all/stain_nor_with_4_7/Macenko_masked'
n_folds    = 6
device     = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 64
C          = 35  # 类别数 / 每个预测向量长度
META_EPOCHS = 250
tile_dim = 128
center_dim = 128
neighbor_dim = 128
# 1) Load full_dataset & y_true
n_samples = len(full_dataset)
y_true = np.vstack([full_dataset[i]['label'].cpu().numpy() for i in range(n_samples)])


In [76]:
# 2) 对每个 fold 收集 OOF 预测
oof_preds = np.zeros((n_samples, n_folds, C), dtype=np.float32)
full_loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=False)

for fold_id, (_, va_idx) in enumerate(
    logo.split(X=np.zeros(n_samples), y=None, groups=slide_idx)
):
    # a) load base model for this fold
    print(f"🍀 Fold {fold_id} predicting ...")
    ckpt = os.path.join(trained_oof_model_folder, f"fold{fold_id}", "best_model.pt")
    net  = VisionMLP_MultiTask(tile_dim, center_dim, output_dim=C).to(device)
    net.load_state_dict(torch.load(ckpt, map_location=device))
    net.eval()
    
    # b) 用这个 model 只对它的 **验证集** 做预测（OOF）
    val_ds = Subset(full_dataset, va_idx)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)
    preds_list = []
    with torch.no_grad():
        for batch in full_loader:
            tiles, subtiles = batch['tile'].to(device), batch['subtiles'].to(device)
            center = subtiles[:,4]
            f_c = net.encoder_center(center)
            f_n = net.encoder_subtile(subtiles)
            f_t = net.encoder_tile(tiles)
            fuse = torch.cat([f_c,f_n,f_t], dim=1)
            out = net.decoder(fuse)
            preds_list.append(out.cpu().numpy())
    preds_fold = np.concatenate(preds_list, axis=0)  # (len(va_idx), C)

    # c) 填回到 oof_preds[:, fold_id, :]
    oof_preds[:, fold_id, :] = preds_fold

print("🍀 Preparing data for meta model ...")

# 3) reshape → stacking 特征矩阵 X_stack
X_stack = oof_preds.reshape(n_samples, n_folds*C)    # (N, n_folds*C)
y_stack = y_true                                   # (N, C)

# 4) 划分 meta‐train / meta‐val
X_train, X_val, y_train, y_val = train_test_split(
    X_stack, y_stack, test_size=0.2, random_state=42, shuffle=True
)

# 5) DataLoader
ds_train = TensorDataset(torch.from_numpy(X_train).float(),
                         torch.from_numpy(y_train).float())
ds_val   = TensorDataset(torch.from_numpy(X_val).float(),
                         torch.from_numpy(y_val).float())
loader_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True)
loader_val   = DataLoader(ds_val,   batch_size=BATCH_SIZE, shuffle=False)

# 6) 定义/初始化 stacking MLP
in_dim = n_folds * C
class StackingMLP(nn.Module):
    def __init__(self, in_dim, hidden_dims=[512,256], out_dim=35, dropout=0.2):
        super().__init__()
        layers = []
        prev = in_dim
        for h in hidden_dims:
            layers.append(nn.Linear(prev, h))
            layers.append(nn.LeakyReLU(0.01))
            layers.append(nn.BatchNorm1d(h))
            layers.append(nn.Dropout(dropout))
            prev = h
        # 最后一层直接输出 out_dim
        layers.append(nn.Linear(prev, out_dim))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """
        x: Tensor of shape [B, in_dim], in_dim = n_folds * C
        returns: Tensor of shape [B, out_dim]
        """
        return self.model(x)

mlp = StackingMLP(in_dim=n_folds*C, hidden_dims=[512,256], out_dim=C).to(device)

optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, min_lr=1e-6, verbose=True)
criterion = nn.MSELoss()

best_val = float('inf')
es_cnt, es_patience = 0, 20
best_path = os.path.join(trained_oof_model_folder, "stacking_meta_best.pt")

print("🍀 Start training meta model ...")

# 7) 训练循环
for epoch in range(1, META_EPOCHS+1):
    mlp.train()
    train_loss = 0.0
    for xb, yb in loader_train:
        xb, yb = xb.to(device), yb.to(device)
        out = mlp(xb)               # 前向
        loss = criterion(out, yb)   # MSE
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item() * xb.size(0)
    train_loss /= len(ds_train)

    # 验证
    mlp.eval()
    val_loss = 0.0
    with torch.no_grad():
        for xb, yb in loader_val:
            xb, yb = xb.to(device), yb.to(device)
            val_loss += criterion(mlp(xb), yb).item() * xb.size(0)
    val_loss /= len(ds_val)


    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch} — Train MSE {train_loss:.6f}, Val MSE {val_loss:.6f}, lr {current_lr:.3e}")
    scheduler.step(val_loss)

    if val_loss < best_val:
        best_val = val_loss; es_cnt = 0
        torch.save(mlp.state_dict(), best_path)
        print(" ↳ New best saved.")
    else:
        es_cnt += 1
        if es_cnt >= es_patience:
            print(" 🛑 Early stopping.")
            break

# 8) 从最好模型 load 回来
mlp.load_state_dict(torch.load(best_path, map_location=device))
mlp.eval()


🍀 Fold 0 predicting ...


  net.load_state_dict(torch.load(ckpt, map_location=device))


🍀 Fold 1 predicting ...
🍀 Fold 2 predicting ...
🍀 Fold 3 predicting ...
🍀 Fold 4 predicting ...
🍀 Fold 5 predicting ...
🍀 Preparing data for meta model ...
🍀 Start training meta model ...




Epoch 1 — Train MSE 398.833550, Val MSE 343.151417, lr 1.000e-03
 ↳ New best saved.
Epoch 2 — Train MSE 220.268595, Val MSE 109.576775, lr 1.000e-03
 ↳ New best saved.
Epoch 3 — Train MSE 70.895166, Val MSE 52.277888, lr 1.000e-03
 ↳ New best saved.
Epoch 4 — Train MSE 53.457557, Val MSE 47.230638, lr 1.000e-03
 ↳ New best saved.
Epoch 5 — Train MSE 51.738313, Val MSE 45.111218, lr 1.000e-03
 ↳ New best saved.
Epoch 6 — Train MSE 50.567188, Val MSE 45.199467, lr 1.000e-03
Epoch 7 — Train MSE 50.151351, Val MSE 46.167478, lr 1.000e-03
Epoch 8 — Train MSE 49.153631, Val MSE 43.752192, lr 1.000e-03
 ↳ New best saved.
Epoch 9 — Train MSE 48.613082, Val MSE 44.310437, lr 1.000e-03
Epoch 10 — Train MSE 48.550224, Val MSE 43.390919, lr 1.000e-03
 ↳ New best saved.
Epoch 11 — Train MSE 48.351976, Val MSE 44.635134, lr 1.000e-03
Epoch 12 — Train MSE 47.697107, Val MSE 42.686886, lr 1.000e-03
 ↳ New best saved.
Epoch 13 — Train MSE 47.436593, Val MSE 42.415576, lr 1.000e-03
 ↳ New best saved.
Ep

  mlp.load_state_dict(torch.load(best_path, map_location=device))


StackingMLP(
  (model): Sequential(
    (0): Linear(in_features=210, out_features=512, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): BatchNorm1d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (3): Dropout(p=0.2, inplace=False)
    (4): Linear(in_features=512, out_features=256, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
    (6): BatchNorm1d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (7): Dropout(p=0.2, inplace=False)
    (8): Linear(in_features=256, out_features=35, bias=True)
  )
)

In [77]:

# --- 测试阶段：stacking 特征 + 一次 mlp.forward() ---
# 对测试集每个 fold 做预测，拼 oof_preds_test (n_test, n_folds, C)，reshape → (n_test, n_folds*C)
oof_test = np.zeros((len(test_dataset), n_folds, C), dtype=np.float32)
for fold_id in range(n_folds):
    print(f"🍀 Fold {fold_id} predicting ...")
    # load each base model & predict on full test_dataset
    ckpt = os.path.join(trained_oof_model_folder, f"fold{fold_id}", "best_model.pt")
    net  = VisionMLP_MultiTask(tile_dim, center_dim, output_dim=C).to(device)
    net.load_state_dict(torch.load(ckpt, map_location=device))
    net.eval()

    loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    preds_list = []
    with torch.no_grad():
        for batch in loader:
            tiles, subtiles = batch['tile'].to(device), batch['subtiles'].to(device)
            center = subtiles[:,4]
            fuse = torch.cat([
                net.encoder_center(center),
                net.encoder_subtile(subtiles),
                net.encoder_tile(tiles)
            ], dim=1)
            preds_list.append(net.decoder(fuse).cpu().numpy())
    oof_test[:, fold_id, :] = np.concatenate(preds_list, axis=0)

print("🍀 Meta model predicting ...")
X_test_stack = oof_test.reshape(len(test_dataset), n_folds*C)
with torch.no_grad():
    X_test_t = torch.from_numpy(X_test_stack).float().to(device)
    final_test_preds = mlp(X_test_t).cpu().numpy()

# 9) 存 submission
with h5py.File("./dataset/elucidata_ai_challenge_data.h5","r") as f:
    ids = pd.DataFrame(np.array(f["spots/Test"]["S_7"])).index
sub = pd.DataFrame(final_test_preds, columns=[f"C{i+1}" for i in range(C)])
sub.insert(0, 'ID', ids)
sub.to_csv(os.path.join(trained_oof_model_folder, "submission_stacked_single_mlp.csv"), index=False)
print("✅ Done single‐MLP stacking submission.")

🍀 Fold 0 predicting ...


  net.load_state_dict(torch.load(ckpt, map_location=device))


🍀 Fold 1 predicting ...
🍀 Fold 2 predicting ...
🍀 Fold 3 predicting ...
🍀 Fold 4 predicting ...
🍀 Fold 5 predicting ...
🍀 Meta model predicting ...
✅ Done single‐MLP stacking submission.


# Training with Lightgbm

In [78]:
import lightgbm as lgb
from sklearn.multioutput import MultiOutputRegressor
import joblib
from sklearn.metrics import mean_squared_error

# 1) 定义 LightGBM 基学习器
params = dict(
    objective='regression',
    metric='rmse',
    learning_rate=0.007522970004049377,
    n_estimators=12000,
    max_depth=11,
    num_leaves=194,
    colsample_bytree=0.7619407413363416,
    subsample=0.8,
    subsample_freq=1,
    min_data_in_leaf=20,
    reg_alpha=0.7480401395491829,
    reg_lambda=0.2589860348178542,
    verbosity=-1,
    random_state=42
)

# 2) MultiOutput 包装
print("🍀 Start training meta model ...")
# 3) 训练，并给内部 estimator 传 early stopping+eval_set
# 2) 用一个 list 保存每个目标的模型
gbm_models = []

print("🍀 Start training one LGBM per target ...")
for k in range(C):
    print(f" ▶ Training target #{k+1}/{C}")
    m = lgb.LGBMRegressor(**params)
    # 传入对应的一维 label
    m.fit(
        X_train, y_train[:,k],
        eval_set=[(X_val, y_val[:,k])],
        callbacks=[
                early_stopping(stopping_rounds=200),
                log_evaluation(period=100)
            ]
    )
    gbm_models.append(m)

# 3) 在验证集上合成多输出预测并算 MSE
y_val_pred = np.column_stack([m.predict(X_val) for m in gbm_models])
mse_val = mean_squared_error(y_val, y_val_pred)
print(f"📊 LightGBM meta‐model Val MSE: {mse_val:.6f}")

# 4) 保存整组模型
gbm_path = os.path.join(trained_oof_model_folder, "stacking_gbm_meta.pkl")
joblib.dump(gbm_models, gbm_path)
print(f"✅ Saved {C} LightGBM models → {gbm_path}")

# … 测试时同样加载这组模型 …
gbm_models = joblib.load(gbm_path)
final_test_preds = np.column_stack([m.predict(X_test_stack) for m in gbm_models])

# 5) 存 submission
with h5py.File("./dataset/elucidata_ai_challenge_data.h5","r") as f:
    test_spot_ids = pd.DataFrame(np.array(f["spots/Test"]["S_7"]))

sub = pd.DataFrame(final_test_preds, columns=[f"C{i+1}" for i in range(C)])
sub.insert(0, 'ID', test_spot_ids.index)
sub_path = os.path.join(trained_oof_model_folder, "submission_stacking_gbm.csv")
sub.to_csv(sub_path, index=False)
print(f"✅ Saved GBM‐stacked submission → {sub_path}")

🍀 Start training meta model ...
🍀 Start training one LGBM per target ...
 ▶ Training target #1/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 5.80957
[200]	valid_0's rmse: 5.19367
[300]	valid_0's rmse: 5.03413
[400]	valid_0's rmse: 4.99558
[500]	valid_0's rmse: 4.98487
[600]	valid_0's rmse: 4.98404
[700]	valid_0's rmse: 4.98417
Early stopping, best iteration is:
[566]	valid_0's rmse: 4.98222
 ▶ Training target #2/35




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 2.58937
[200]	valid_0's rmse: 2.26426
[300]	valid_0's rmse: 2.16704
[400]	valid_0's rmse: 2.13877
[500]	valid_0's rmse: 2.1283
[600]	valid_0's rmse: 2.12534
[700]	valid_0's rmse: 2.12257
[800]	valid_0's rmse: 2.12198
[900]	valid_0's rmse: 2.12155
[1000]	valid_0's rmse: 2.12206
Early stopping, best iteration is:
[892]	valid_0's rmse: 2.12121
 ▶ Training target #3/35




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 4.09634
[200]	valid_0's rmse: 3.52543
[300]	valid_0's rmse: 3.37769
[400]	valid_0's rmse: 3.34427
[500]	valid_0's rmse: 3.33766
[600]	valid_0's rmse: 3.33485
[700]	valid_0's rmse: 3.33407
[800]	valid_0's rmse: 3.33513
Early stopping, best iteration is:
[686]	valid_0's rmse: 3.3334
 ▶ Training target #4/35




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 8.20165
[200]	valid_0's rmse: 7.01577
[300]	valid_0's rmse: 6.70831
[400]	valid_0's rmse: 6.63517
[500]	valid_0's rmse: 6.62244
[600]	valid_0's rmse: 6.62169
[700]	valid_0's rmse: 6.61905
[800]	valid_0's rmse: 6.62346
Early stopping, best iteration is:
[695]	valid_0's rmse: 6.61806
 ▶ Training target #5/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 7.7173
[200]	valid_0's rmse: 6.66838
[300]	valid_0's rmse: 6.35032
[400]	valid_0's rmse: 6.24286
[500]	valid_0's rmse: 6.2111
[600]	valid_0's rmse: 6.19792
[700]	valid_0's rmse: 6.1862
[800]	valid_0's rmse: 6.18073
[900]	valid_0's rmse: 6.18034
[1000]	valid_0's rmse: 6.18071
[1100]	valid_0's rmse: 6.18019
Early stopping, best iteration is:
[914]	valid_0's rmse: 6.1791
 ▶ Training target #6/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 7.88352
[200]	valid_0's rmse: 6.79117
[300]	valid_0's rmse: 6.47416
[400]	valid_0's rmse: 6.37851
[500]	valid_0's rmse: 6.34054
[600]	valid_0's rmse: 6.32187
[700]	valid_0's rmse: 6.31667
[800]	valid_0's rmse: 6.3116
[900]	valid_0's rmse: 6.31055
[1000]	valid_0's rmse: 6.30979
[1100]	valid_0's rmse: 6.30663
[1200]	valid_0's rmse: 6.30798
[1300]	valid_0's rmse: 6.30814
Early stopping, best iteration is:
[1134]	valid_0's rmse: 6.30539
 ▶ Training target #7/35




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 3.67557
[200]	valid_0's rmse: 3.44097
[300]	valid_0's rmse: 3.38332
[400]	valid_0's rmse: 3.37214
[500]	valid_0's rmse: 3.37246
[600]	valid_0's rmse: 3.37353
Early stopping, best iteration is:
[409]	valid_0's rmse: 3.37135
 ▶ Training target #8/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 5.65305
[200]	valid_0's rmse: 5.56243
[300]	valid_0's rmse: 5.54871
[400]	valid_0's rmse: 5.5537
[500]	valid_0's rmse: 5.56196
Early stopping, best iteration is:
[315]	valid_0's rmse: 5.54837
 ▶ Training target #9/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 8.64605
[200]	valid_0's rmse: 7.4803
[300]	valid_0's rmse: 7.15783
[400]	valid_0's rmse: 7.06843
[500]	valid_0's rmse: 7.04112
[600]	valid_0's rmse: 7.02802
[700]	valid_0's rmse: 7.01882
[800]	valid_0's rmse: 7.01577
[900]	valid_0's rmse: 7.01531
[1000]	valid_0's rmse: 7.01511
[1100]	valid_0's rmse: 7.01387
[1200]	valid_0's rmse: 7.01559
Early stopping, best iteration is:
[1075]	valid_0's rmse: 7.01232
 ▶ Training target #10/35




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 7.17653
[200]	valid_0's rmse: 6.20806
[300]	valid_0's rmse: 5.981
[400]	valid_0's rmse: 5.92642
[500]	valid_0's rmse: 5.92103
[600]	valid_0's rmse: 5.91982
[700]	valid_0's rmse: 5.91797
[800]	valid_0's rmse: 5.91826
Early stopping, best iteration is:
[658]	valid_0's rmse: 5.91482
 ▶ Training target #11/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 7.78296
[200]	valid_0's rmse: 7.04437
[300]	valid_0's rmse: 6.81941
[400]	valid_0's rmse: 6.75013
[500]	valid_0's rmse: 6.72555
[600]	valid_0's rmse: 6.71404
[700]	valid_0's rmse: 6.70726
[800]	valid_0's rmse: 6.70565
[900]	valid_0's rmse: 6.70861
[1000]	valid_0's rmse: 6.70967
Early stopping, best iteration is:
[807]	valid_0's rmse: 6.70446
 ▶ Training target #12/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 4.2633
[200]	valid_0's rmse: 3.62364
[300]	valid_0's rmse: 3.44562
[400]	valid_0's rmse: 3.39954
[500]	valid_0's rmse: 3.38621
[600]	valid_0's rmse: 3.37988
[700]	valid_0's rmse: 3.37881
[800]	valid_0's rmse: 3.37708
[900]	valid_0's rmse: 3.37586
[1000]	valid_0's rmse: 3.37608
[1100]	valid_0's rmse: 3.37812
Early stopping, best iteration is:
[986]	valid_0's rmse: 3.3753
 ▶ Training target #13/35




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 6.12512
[200]	valid_0's rmse: 6.06792
[300]	valid_0's rmse: 6.06935
[400]	valid_0's rmse: 6.07591
Early stopping, best iteration is:
[254]	valid_0's rmse: 6.0633
 ▶ Training target #14/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 5.46693
[200]	valid_0's rmse: 5.21039
[300]	valid_0's rmse: 5.16203
[400]	valid_0's rmse: 5.15915
[500]	valid_0's rmse: 5.16455
Early stopping, best iteration is:
[336]	valid_0's rmse: 5.15706
 ▶ Training target #15/35




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 6.33803
[200]	valid_0's rmse: 5.43604
[300]	valid_0's rmse: 5.1774
[400]	valid_0's rmse: 5.10034
[500]	valid_0's rmse: 5.07629
[600]	valid_0's rmse: 5.06675
[700]	valid_0's rmse: 5.06123
[800]	valid_0's rmse: 5.06196
[900]	valid_0's rmse: 5.06235
Early stopping, best iteration is:
[701]	valid_0's rmse: 5.06108
 ▶ Training target #16/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 6.24249
[200]	valid_0's rmse: 5.66736
[300]	valid_0's rmse: 5.51309
[400]	valid_0's rmse: 5.47241
[500]	valid_0's rmse: 5.46156
[600]	valid_0's rmse: 5.46198
Early stopping, best iteration is:
[489]	valid_0's rmse: 5.46022
 ▶ Training target #17/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 7.23501
[200]	valid_0's rmse: 6.33322
[300]	valid_0's rmse: 6.09611
[400]	valid_0's rmse: 6.02032
[500]	valid_0's rmse: 5.9925
[600]	valid_0's rmse: 5.98214
[700]	valid_0's rmse: 5.97779
[800]	valid_0's rmse: 5.97324
[900]	valid_0's rmse: 5.97057
[1000]	valid_0's rmse: 5.97129
[1100]	valid_0's rmse: 5.97274
Early stopping, best iteration is:
[926]	valid_0's rmse: 5.9698
 ▶ Training target #18/35




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 3.92344
[200]	valid_0's rmse: 3.33891
[300]	valid_0's rmse: 3.1936
[400]	valid_0's rmse: 3.16039
[500]	valid_0's rmse: 3.15388
[600]	valid_0's rmse: 3.15111
[700]	valid_0's rmse: 3.15202
Early stopping, best iteration is:
[596]	valid_0's rmse: 3.15094
 ▶ Training target #19/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 7.34135
[200]	valid_0's rmse: 7.01777
[300]	valid_0's rmse: 6.94866
[400]	valid_0's rmse: 6.93829
[500]	valid_0's rmse: 6.94053
[600]	valid_0's rmse: 6.94413
Early stopping, best iteration is:
[403]	valid_0's rmse: 6.93669
 ▶ Training target #20/35




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 5.03803
[200]	valid_0's rmse: 4.83778
[300]	valid_0's rmse: 4.78316
[400]	valid_0's rmse: 4.76957
[500]	valid_0's rmse: 4.76722
[600]	valid_0's rmse: 4.76841
Early stopping, best iteration is:
[445]	valid_0's rmse: 4.76608
 ▶ Training target #21/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 6.6216
[200]	valid_0's rmse: 6.10619
[300]	valid_0's rmse: 5.97404
[400]	valid_0's rmse: 5.93998
[500]	valid_0's rmse: 5.93119
[600]	valid_0's rmse: 5.92652
[700]	valid_0's rmse: 5.92588
[800]	valid_0's rmse: 5.9245
[900]	valid_0's rmse: 5.92876
Early stopping, best iteration is:
[798]	valid_0's rmse: 5.92376
 ▶ Training target #22/35




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 7.75956
[200]	valid_0's rmse: 7.35264
[300]	valid_0's rmse: 7.25175
[400]	valid_0's rmse: 7.2228
[500]	valid_0's rmse: 7.21706
[600]	valid_0's rmse: 7.21345
[700]	valid_0's rmse: 7.2181
[800]	valid_0's rmse: 7.21388
Early stopping, best iteration is:
[630]	valid_0's rmse: 7.21257
 ▶ Training target #23/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 2.82375
[200]	valid_0's rmse: 2.60096
[300]	valid_0's rmse: 2.54835
[400]	valid_0's rmse: 2.53885
[500]	valid_0's rmse: 2.53962
[600]	valid_0's rmse: 2.54082
Early stopping, best iteration is:
[405]	valid_0's rmse: 2.53852
 ▶ Training target #24/35




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 6.37896
[200]	valid_0's rmse: 5.74931
[300]	valid_0's rmse: 5.59342
[400]	valid_0's rmse: 5.56089
[500]	valid_0's rmse: 5.55583
[600]	valid_0's rmse: 5.5565
[700]	valid_0's rmse: 5.5596
Early stopping, best iteration is:
[559]	valid_0's rmse: 5.55501
 ▶ Training target #25/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 7.58222
[200]	valid_0's rmse: 6.87418
[300]	valid_0's rmse: 6.68426
[400]	valid_0's rmse: 6.6332
[500]	valid_0's rmse: 6.61512
[600]	valid_0's rmse: 6.6118
[700]	valid_0's rmse: 6.60934
[800]	valid_0's rmse: 6.60876
[900]	valid_0's rmse: 6.60814
[1000]	valid_0's rmse: 6.61081
[1100]	valid_0's rmse: 6.61534
Early stopping, best iteration is:
[918]	valid_0's rmse: 6.60638
 ▶ Training target #26/35




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 6.91
[200]	valid_0's rmse: 6.57775
[300]	valid_0's rmse: 6.47922
[400]	valid_0's rmse: 6.45402
[500]	valid_0's rmse: 6.45096
[600]	valid_0's rmse: 6.44968
[700]	valid_0's rmse: 6.44988
[800]	valid_0's rmse: 6.45701
Early stopping, best iteration is:
[641]	valid_0's rmse: 6.44766
 ▶ Training target #27/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 8.12002
[200]	valid_0's rmse: 6.82343
[300]	valid_0's rmse: 6.44353
[400]	valid_0's rmse: 6.32508
[500]	valid_0's rmse: 6.27824
[600]	valid_0's rmse: 6.26088
[700]	valid_0's rmse: 6.25802
[800]	valid_0's rmse: 6.25569
[900]	valid_0's rmse: 6.25576
[1000]	valid_0's rmse: 6.2574
[1100]	valid_0's rmse: 6.25979
Early stopping, best iteration is:
[909]	valid_0's rmse: 6.25448
 ▶ Training target #28/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 5.95725
[200]	valid_0's rmse: 5.64077
[300]	valid_0's rmse: 5.58146
[400]	valid_0's rmse: 5.56663
[500]	valid_0's rmse: 5.56868
[600]	valid_0's rmse: 5.56623
[700]	valid_0's rmse: 5.57063
[800]	valid_0's rmse: 5.57202
Early stopping, best iteration is:
[600]	valid_0's rmse: 5.56623
 ▶ Training target #29/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 6.54976
[200]	valid_0's rmse: 6.05634
[300]	valid_0's rmse: 5.94747
[400]	valid_0's rmse: 5.92804
[500]	valid_0's rmse: 5.92549
[600]	valid_0's rmse: 5.93178
Early stopping, best iteration is:
[476]	valid_0's rmse: 5.92275
 ▶ Training target #30/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 7.29486
[200]	valid_0's rmse: 6.18596
[300]	valid_0's rmse: 5.89076
[400]	valid_0's rmse: 5.81109
[500]	valid_0's rmse: 5.78786
[600]	valid_0's rmse: 5.77903
[700]	valid_0's rmse: 5.78003
[800]	valid_0's rmse: 5.78269
Early stopping, best iteration is:
[628]	valid_0's rmse: 5.77654
 ▶ Training target #31/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 2.88191
[200]	valid_0's rmse: 2.48797
[300]	valid_0's rmse: 2.38446
[400]	valid_0's rmse: 2.35892
[500]	valid_0's rmse: 2.35251
[600]	valid_0's rmse: 2.35204
[700]	valid_0's rmse: 2.3534
Early stopping, best iteration is:
[544]	valid_0's rmse: 2.35129
 ▶ Training target #32/35




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 2.35362
[200]	valid_0's rmse: 2.2714
[300]	valid_0's rmse: 2.25079
[400]	valid_0's rmse: 2.24801
[500]	valid_0's rmse: 2.24734
[600]	valid_0's rmse: 2.24763
[700]	valid_0's rmse: 2.24902
Early stopping, best iteration is:
[545]	valid_0's rmse: 2.24685
 ▶ Training target #33/35




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 7.31082
[200]	valid_0's rmse: 7.18214
[300]	valid_0's rmse: 7.16169
[400]	valid_0's rmse: 7.16806
[500]	valid_0's rmse: 7.17284
Early stopping, best iteration is:
[308]	valid_0's rmse: 7.16078
 ▶ Training target #34/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 5.72499
[200]	valid_0's rmse: 5.46612
[300]	valid_0's rmse: 5.41951
[400]	valid_0's rmse: 5.42492
[500]	valid_0's rmse: 5.44309
Early stopping, best iteration is:
[317]	valid_0's rmse: 5.4175
 ▶ Training target #35/35
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 2.90248
[200]	valid_0's rmse: 2.79129
[300]	valid_0's rmse: 2.7663
[400]	valid_0's rmse: 2.76172
[500]	valid_0's rmse: 2.76276
Early stopping, best iteration is:
[381]	valid_0's rmse: 2.76108




📊 LightGBM meta‐model Val MSE: 29.343309
✅ Saved 35 LightGBM models → output_folder/rank-spot/realign/no_pretrain/3_encoder/filtered_directly_rank/k-fold/realign_all/Macenko_masked/stacking_gbm_meta.pkl




✅ Saved GBM‐stacked submission → output_folder/rank-spot/realign/no_pretrain/3_encoder/filtered_directly_rank/k-fold/realign_all/Macenko_masked/submission_stacking_gbm.csv


In [None]:
# Train woth 2 k-fold models

In [79]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from sklearn.model_selection import train_test_split
from torch.optim.lr_scheduler import ReduceLROnPlateau
from python_scripts.import_data import importDataset  # 假设这个能给你 full_dataset
from python_scripts.aug import subset_grouped_data   # 用来切出 grouped_data

# Settings
# 你的两套管道模型所在文件夹
pipe_folders = [
    'output_folder/rank-spot/realign/no_pretrain/3_encoder/filtered_directly_rank/k-fold/realign_all/Macenko_masked' , 
    'output_folder/rank-spot/realign/no_pretrain/3_encoder/filtered_directly_rank/k-fold/realign_all/stain_nor_with_4_7/Macenko_masked']
n_pipes = len(pipe_folders)
n_folds    = 6
device     = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 64
C          = 35  # 类别数 / 每个预测向量长度
META_EPOCHS = 250
tile_dim = 128
center_dim = 128
neighbor_dim = 128
# 1) Load full_dataset & y_true
n_samples = len(full_dataset)
y_true = np.vstack([full_dataset[i]['label'].cpu().numpy() for i in range(n_samples)])
oof_preds = np.zeros((n_pipes, n_folds, n_samples, C), dtype=np.float32)
full_loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, shuffle=False, pin_memory=False)


In [80]:
for p, folder in enumerate(pipe_folders):
    for fold_id in range(n_folds):
        print(f"Pipe {p} Fold {fold_id} predicting...")
        # a) load model
        ckpt = os.path.join(folder, f"fold{fold_id}", "best_model.pt")
        net  = VisionMLP_MultiTask(tile_dim, center_dim, output_dim=C).to(device)
        net.load_state_dict(torch.load(ckpt, map_location=device))
        net.eval()

        # b) 全样本 forward
        preds_full = []
        with torch.no_grad():
            for batch in full_loader:
                tiles, subtiles = batch['tile'].to(device), batch['subtiles'].to(device)
                center = subtiles[:, 4]
                fuse = torch.cat([
                    net.encoder_center(center),
                    net.encoder_subtile(subtiles),
                    net.encoder_tile(tiles)
                ], dim=1)
                preds_full.append(net.decoder(fuse).cpu().numpy())

        preds_full = np.concatenate(preds_full, axis=0)  # (n_samples, C)
        oof_preds[p, fold_id] = preds_full

# 4) 构造 stacking 特征矩阵 (n_samples, n_pipes*n_folds*C)
#    转轴之后 reshape
stack = oof_preds.transpose(2,0,1,3)   # (n_samples, pipe, fold, C)
X_stack = stack.reshape(n_samples, n_pipes * n_folds * C)

# 5) 划分 train/val
X_tr, X_va, y_tr, y_va = train_test_split(
    X_stack, y_true, test_size=0.2, random_state=42, shuffle=True
)

# 6) DataLoader
ds_tr = TensorDataset(torch.from_numpy(X_tr).float(), torch.from_numpy(y_tr).float())
ds_va = TensorDataset(torch.from_numpy(X_va).float(), torch.from_numpy(y_va).float())
loader_tr = DataLoader(ds_tr, batch_size=BATCH_SIZE, shuffle=True, pin_memory=False)
loader_va = DataLoader(ds_va, batch_size=BATCH_SIZE, shuffle=False, pin_memory=False)
# 7) 定义 Stacking MLP
class StackingMLP(nn.Module):
    def __init__(self, in_dim, hidden_dims=[1024,512,256], out_dim=35, dropout=0.2):
        super().__init__()
        layers = []
        prev = in_dim
        for h in hidden_dims:
            layers += [nn.Linear(prev,h),
                       nn.LeakyReLU(0.01),
                       nn.BatchNorm1d(h),
                       nn.Dropout(dropout)]
            prev = h
        layers.append(nn.Linear(prev, out_dim))
        self.model = nn.Sequential(*layers)
    def forward(self, x):
        return self.model(x)

in_dim = n_pipes * n_folds * C
mlp = StackingMLP(in_dim=in_dim, hidden_dims=[1024,512,256], out_dim=C).to(device)

optimizer = torch.optim.Adam(mlp.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10, verbose=True)
criterion = nn.MSELoss()

best_val, es_cnt, es_patience = float('inf'), 0, 20
best_path = os.path.join(pipe_folders[1], "stacking_2meta_best.pt")

# 8) 训练循环
for epoch in range(1, META_EPOCHS+1):
    mlp.train()
    tr_loss = 0.0
    for xb, yb in loader_tr:
        xb, yb = xb.to(device), yb.to(device)
        out = mlp(xb)
        loss = criterion(out, yb)
        optimizer.zero_grad(); loss.backward(); optimizer.step()
        tr_loss += loss.item() * xb.size(0)
    tr_loss /= len(ds_tr)

    mlp.eval()
    va_loss = 0.0
    with torch.no_grad():
        for xb, yb in loader_va:
            xb, yb = xb.to(device), yb.to(device)
            va_loss += criterion(mlp(xb), yb).item() * xb.size(0)
    va_loss /= len(ds_va)

    print(f"Epoch {epoch} → Train MSE {tr_loss:.4f}, Val MSE {va_loss:.4f}, LR {optimizer.param_groups[0]['lr']:.2e}")
    scheduler.step(va_loss)

    if va_loss < best_val:
        best_val, es_cnt = va_loss, 0
        torch.save(mlp.state_dict(), best_path)
        print(" ↳ New best saved.")
    else:
        es_cnt += 1
        if es_cnt >= es_patience:
            print(" 🛑 Early stopping.")
            break

print("✅ Cross-pipe stacking done. Best at", best_path)

Pipe 0 Fold 0 predicting...


  net.load_state_dict(torch.load(ckpt, map_location=device))


Pipe 0 Fold 1 predicting...
Pipe 0 Fold 2 predicting...
Pipe 0 Fold 3 predicting...
Pipe 0 Fold 4 predicting...
Pipe 0 Fold 5 predicting...
Pipe 1 Fold 0 predicting...
Pipe 1 Fold 1 predicting...
Pipe 1 Fold 2 predicting...
Pipe 1 Fold 3 predicting...
Pipe 1 Fold 4 predicting...
Pipe 1 Fold 5 predicting...




Epoch 1 → Train MSE 391.4221, Val MSE 320.4870, LR 1.00e-03
 ↳ New best saved.
Epoch 2 → Train MSE 213.2606, Val MSE 94.9896, LR 1.00e-03
 ↳ New best saved.
Epoch 3 → Train MSE 62.3294, Val MSE 46.4068, LR 1.00e-03
 ↳ New best saved.
Epoch 4 → Train MSE 47.7791, Val MSE 41.9149, LR 1.00e-03
 ↳ New best saved.
Epoch 5 → Train MSE 45.7437, Val MSE 42.7662, LR 1.00e-03
Epoch 6 → Train MSE 45.9071, Val MSE 40.3713, LR 1.00e-03
 ↳ New best saved.
Epoch 7 → Train MSE 44.9099, Val MSE 41.4321, LR 1.00e-03
Epoch 8 → Train MSE 44.4296, Val MSE 39.5969, LR 1.00e-03
 ↳ New best saved.
Epoch 9 → Train MSE 44.3049, Val MSE 42.6474, LR 1.00e-03
Epoch 10 → Train MSE 44.4582, Val MSE 41.5099, LR 1.00e-03
Epoch 11 → Train MSE 44.2516, Val MSE 38.7766, LR 1.00e-03
 ↳ New best saved.
Epoch 12 → Train MSE 43.8401, Val MSE 38.6763, LR 1.00e-03
 ↳ New best saved.
Epoch 13 → Train MSE 43.1779, Val MSE 38.1927, LR 1.00e-03
 ↳ New best saved.
Epoch 14 → Train MSE 43.6875, Val MSE 40.0167, LR 1.00e-03
Epoch 15 

In [None]:
import os
import numpy as np
import torch
from torch.utils.data import DataLoader
import h5py
import pandas as pd

# --- 1) 初始化 ---
n_pipes = len(pipe_folders)    # 2
n_folds = 6
C       = 35
n_test  = len(test_dataset)

# 1) 确定 device
device = torch.device("mps" if torch.backends.mps.is_available() else
                      "cuda" if torch.cuda.is_available() else "cpu")

# 2) 复现模型结构
in_dim     = n_pipes * n_folds * C        # 例如 2*6*35 = 420
hidden_dims= [1024, 512, 256]             # 你训练时用的
out_dim    = C

mlp = StackingMLP(in_dim=in_dim,
                  hidden_dims=hidden_dims,
                  out_dim=out_dim).to(device)

# 3) 加载训练好的权重
best_path = os.path.join(pipe_folders[1], "stacking_2meta_best.pt")
state = torch.load(best_path, map_location=device)
mlp.load_state_dict(state)

# 4) 切换到 eval 模式
mlp.eval()

print(f"✅ Loaded StackingMLP from {best_path} onto {device}")

# oof_test[p, f, i, c] := 第 p 管道，第 f 折 模型 对 test 样本 i 的 第 c 类预测
oof_test = np.zeros((n_pipes, n_folds, n_test, C), dtype=np.float32)

full_test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
)

# --- 2) 对每条管道的每折模型做全样本预测 ---
for p, folder in enumerate(pipe_folders):
    for fold_id in range(n_folds):
        print(f"Pipe {p} Fold {fold_id} inference…")
        # a) 加载对应模型
        ckpt = os.path.join(folder, f"fold{fold_id}", "best_model.pt")
        net  = VisionMLP_MultiTask(tile_dim, center_dim, output_dim=C).to(device)
        net.load_state_dict(torch.load(ckpt, map_location=device))
        net.eval()

        # b) 全 test 样本上跑一次 forward
        preds_list = []
        with torch.no_grad():
            for batch in full_test_loader:
                tiles, subtiles = batch['tile'].to(device), batch['subtiles'].to(device)
                center = subtiles[:, 4]
                fuse = torch.cat([
                    net.encoder_center(center),
                    net.encoder_subtile(subtiles),
                    net.encoder_tile(tiles)
                ], dim=1)
                out = net.decoder(fuse)
                preds_list.append(out.cpu().numpy())

        preds_full = np.concatenate(preds_list, axis=0)  # (n_test, C)
        oof_test[p, fold_id] = preds_full

# --- 3) 构造 stacking 特征矩阵 ---
# 转轴到 (n_test, pipe, fold, C)，再 reshape → (n_test, n_pipes*n_folds*C)
X_test_stack = oof_test.transpose(2, 0, 1, 3).reshape(n_test, n_pipes * n_folds * C)

# --- 4) 用训练好的 mlp 做一次性预测 ---
mlp.eval()
with torch.no_grad():
    X_t = torch.from_numpy(X_test_stack).float().to(device)
    final_preds = mlp(X_t).cpu().numpy()  # (n_test, C)

# --- 5) 存 submission ---
with h5py.File("./dataset/elucidata_ai_challenge_data.h5", "r") as f:
    ids = pd.DataFrame(np.array(f["spots/Test"]["S_7"])).index

sub = pd.DataFrame(final_preds, columns=[f"C{i+1}" for i in range(C)])
sub.insert(0, "ID", ids)
out_path = os.path.join(pipe_folders[1], "submission_stack_multi_pipe.csv")
sub.to_csv(out_path, index=False)
print("✅ Saved submission →", out_path)


  state = torch.load(best_path, map_location=device)
  net.load_state_dict(torch.load(ckpt, map_location=device))


✅ Loaded StackingMLP from output_folder/rank-spot/realign/no_pretrain/3_encoder/filtered_directly_rank/k-fold/realign_all/stain_nor_with_4_7/Macenko_masked/stacking_2meta_best.pt onto mps
Pipe 0 Fold 0 inference…
Pipe 0 Fold 1 inference…
Pipe 0 Fold 2 inference…
Pipe 0 Fold 3 inference…
Pipe 0 Fold 4 inference…
Pipe 0 Fold 5 inference…
Pipe 1 Fold 0 inference…
Pipe 1 Fold 1 inference…
Pipe 1 Fold 2 inference…
Pipe 1 Fold 3 inference…
Pipe 1 Fold 4 inference…
Pipe 1 Fold 5 inference…
✅ Saved submission → output_folder/rank-spot/realign/no_pretrain/3_encoder/filtered_directly_rank/k-fold/realign_all/Macenko_masked/submission_stack_multi_pipe.csv


In [None]:
import os
import numpy as np
import joblib
import torch
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.multioutput import MultiOutputRegressor
from sklearn.model_selection import train_test_split
import lightgbm as lgb
from scipy.stats import rankdata
from python_scripts.import_data import importDataset
from python_scripts.operate_model import predict
from lightgbm import early_stopping, log_evaluation
import h5py
import pandas as pd
# ---------------- Settings ----------------
save_root  = save_folder  # your save_folder path
n_folds    = len([d for d in os.listdir(save_root) if d.startswith('fold')])
n_samples  = len(full_dataset)
C          = 35  # num cell types
start_fold = 0
BATCH_SIZE = 64
# If optimizing Spearman, convert labels to ranks

# --- 1) Prepare OOF meta-features ---
# Initialize matrix for OOF predictions
n_samples = len(full_dataset)
oof_preds = np.zeros((n_samples, C), dtype=np.float32)
# True labels (raw or rank)
# importDataset returns a dict-like sample, so label is under key 'label'
y_true = np.vstack([ full_dataset[i]['label'].cpu().numpy() for i in range(n_samples) ])
y_meta = y_true

# Build CV splitter (must match first stage splits)
logo = LeaveOneGroupOut()
image_latents = np.zeros((n_samples, 128), dtype=np.float32)

# Loop over folds, load best model, predict on validation indices
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
for fold_id, (tr_idx, va_idx) in enumerate(
        logo.split(X=np.zeros(n_samples), y=None, groups=slide_idx)):
    # Load model
    # if fold_id > start_fold:
    #     print(f"⏭️ Skipping fold {fold_id}")
    #     continue
    ckpt_path = os.path.join(save_root, f"fold{fold_id}", "best_model.pt")
    print(f"Loading model from {ckpt_path}...")
    net = PretrainedEncoderRegressor(
        ae_checkpoint=checkpoint_path,
        ae_type="all",
        center_dim=64, neighbor_dim=64, hidden_dim=128,
        tile_size=26, output_dim=35,
        freeze_encoder = True
    )

    # 2) monkey‐patch 一个新的 head
    net.decoder  = nn.Sequential(
        nn.Linear(64+64, 256),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(256, 128),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(128, 64),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(64, 35)
        
    )
    net = net.to(device)    # Alternatively, if your model requires specific args, replace with:
    # net = VisionMLP_MultiTask(tile_dim=64, subtile_dim=64, output_dim=35).to(device)
    net.load_state_dict(torch.load(ckpt_path, map_location=device))
    net.to(device).eval()
    
    # Predict on validation set
    val_ds = Subset(full_dataset, va_idx)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

    preds = []
    latents = []

    with torch.no_grad():
        for batch in val_loader:
            tiles    = batch['tile'].to(device)
            subtiles = batch['subtiles'].to(device)

            center = subtiles[:, 4].contiguous()
            f_c = net.enc_center(center)
            f_n = net.enc_neigh(subtiles)
            fuse = torch.cat([f_c, f_n], dim=1)

            output = net.decoder(fuse)

            preds.append(output.cpu())
            latents.append(fuse.cpu())  # ⬅️ 收集 latent vector

    preds = torch.cat(preds, dim=0).numpy()    # (n_val, 35)
    latents = torch.cat(latents, dim=0).numpy()  # (n_val, 128)

    oof_preds[va_idx] = preds
    image_latents[va_idx] = latents

    print(f"Fold {fold_id}: OOF preds shape {preds.shape}, Latent shape: {latents.shape}")


    
with h5py.File("dataset/realign/filtered_dataset.h5", "r") as f:
    train_spots = f["spots/Train"]
    
    train_spot_tables = {}
    
    for slide_name in train_spots.keys():
        spot_array = np.array(train_spots[slide_name])
        df = pd.DataFrame(spot_array)
        df["slide_name"] = slide_name
        train_spot_tables[slide_name] = df
        print(f"✅ 已讀取 slide: {slide_name}")

# -----------------------------------------------------
# Step 2: 合併所有 slide 的資料
# -----------------------------------------------------
all_train_spots_df = pd.concat(train_spot_tables.values(), ignore_index=True)
# 提取 x, y
xy = all_train_spots_df[["x", "y"]].to_numpy()  # shape: (8348, 2)

# 合併成新的 meta feature
meta_features = np.concatenate([oof_preds, xy, image_latents], axis=1)
# --- 2) Train LightGBM meta-model ---
# Choose objective: regression on rank (for Spearman) or raw (for MSE)
# 將 meta features 拆成訓練集與 early stopping 用的驗證集
X_train, X_val, y_train, y_val = train_test_split(meta_features, y_meta, test_size=0.2, random_state=42)
print("Meta feature shape:", X_train.shape)
print("Feature std (min/max):", np.min(np.std(X_train, axis=0)), np.max(np.std(X_train, axis=0)))


# # Base model
# lgb_base = lgb.LGBMRegressor(
#     objective='l2',
#     metric='rmse',
#     n_estimators=12000,
#     max_depth=15,
#     learning_rate=0.008,
#     num_leaves=32,
#     colsample_bytree=0.25
# )
import optuna
from sklearn.multioutput import MultiOutputRegressor
from sklearn.metrics import mean_squared_error

# Define Optuna objective function
def objective(trial):
    params = {
        'objective': 'regression',
        'metric': 'rmse',
        'verbosity': -1,
        'boosting_type': 'gbdt',
        'device': 'gpu',                # ✅ GPU 支援
        'gpu_platform_id': 0,
        'gpu_device_id': 0,
        'learning_rate': trial.suggest_float("learning_rate", 0.005, 0.1),
        'max_depth': trial.suggest_int("max_depth", 4, 15),
        'num_leaves': trial.suggest_int("num_leaves", 32, 256),
        'min_data_in_leaf': trial.suggest_int("min_data_in_leaf", 20, 100),
        'colsample_bytree': trial.suggest_float("colsample_bytree", 0.6, 1.0),
        'reg_alpha': trial.suggest_float("reg_alpha", 0, 1),
        'reg_lambda': trial.suggest_float("reg_lambda", 0, 1),
        'n_estimators': 12000
    }

    model = lgb.LGBMRegressor(**params)
    multi_model = MultiOutputRegressor(model)
    multi_model.fit(X_train, y_train)

    y_pred = multi_model.predict(X_val)
    rmse = np.mean([
        np.sqrt(mean_squared_error(y_val[:, i], y_pred[:, i]))
        for i in range(y_val.shape[1])
    ])


    return rmse

# Run optimization
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=30)

# Use best params to train final models
best_params = study.best_trial.params
best_params['objective'] = 'l2'
best_params['metric'] = 'rmse'
best_params['verbosity'] = -1

# Train final models with best parameters
meta_model = MultiOutputRegressor(lgb.LGBMRegressor(**best_params))
meta_model.estimators_ = []

print("Training LightGBM on OOF meta-features with best Optuna params...")
for i in range(y_train.shape[1]):
    print(f"Training target {i}...")
    model = lgb.LGBMRegressor(**best_params)

    model.fit(
        X_train,
        y_train[:, i],
        eval_set=[(X_val, y_val[:, i])],
        callbacks=[
            early_stopping(stopping_rounds=200),
            log_evaluation(period=100)
        ]
    )

    meta_model.estimators_.append(model)

# Save model
joblib.dump(meta_model, os.path.join(save_root, 'meta_model.pkl'))
# 保存模型


# --- 3) Prepare test meta-features ---
n_test = len(test_dataset)
test_preds = []
test_latents = []

for fold_id in range(n_folds):
    ckpt_path = os.path.join(save_root, f"fold{fold_id}", "best_model.pt")
    net = PretrainedEncoderRegressor(
        ae_checkpoint=checkpoint_path,
        ae_type="all",
        center_dim=64, neighbor_dim=64, hidden_dim=128,
        tile_size=26, output_dim=35,
        freeze_encoder = True
    )

    net.decoder = nn.Sequential(
        nn.Linear(64+64, 256),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(256, 128),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(128, 64),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(64, 35)
    )

    net = net.to(device)
    net.load_state_dict(torch.load(ckpt_path, map_location=device))
    net.eval()

    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    preds = []
    latents = []

    with torch.no_grad():
        for batch in test_loader:
            tiles = batch['tile'].to(device)
            subtiles = batch['subtiles'].to(device)

            center = subtiles[:, 4].contiguous()
            f_c = net.enc_center(center)
            f_n = net.enc_neigh(subtiles)
            fuse = torch.cat([f_c, f_n], dim=1)

            out = net.decoder(fuse)

            preds.append(out.cpu())
            latents.append(fuse.cpu())  # image embedding (128D)

    test_preds.append(torch.cat(preds, dim=0).numpy())      # shape: (n_test, 35)
    test_latents.append(torch.cat(latents, dim=0).numpy())  # shape: (n_test, 128)

# === Stack + Average ===
test_preds = np.mean(np.stack(test_preds, axis=0), axis=0)      # (n_test, 35)
test_latents = np.mean(np.stack(test_latents, axis=0), axis=0)  # (n_test, 128)

with h5py.File("dataset/elucidata_ai_challenge_data.h5", "r") as f:
    test_spots = f["spots/Test"]
    spot_array = np.array(test_spots['S_7'])
    df = pd.DataFrame(spot_array)

xy = df[["x", "y"]].to_numpy()  # shape: (n_test, 2)

# 合併為最終 test meta features
test_meta = np.concatenate([test_preds, xy, test_latents], axis=1)  # shape: (n_test, 35+2+128)



final_preds = meta_model.predict(test_meta)

# --- Save submission ---
import h5py
import pandas as pd
with h5py.File("./dataset/elucidata_ai_challenge_data.h5","r") as f:
    test_spot_ids = pd.DataFrame(np.array(f["spots/Test"]["S_7"]))
sub = pd.DataFrame(final_preds, columns=[f"C{i+1}" for i in range(C)])
sub.insert(0, 'ID', test_spot_ids.index)
sub.to_csv(os.path.join(save_root, 'submission_stacked.csv'), index=False)
print("✅ Saved stacked submission.")


In [None]:
import os
import numpy as np
import joblib
import torch
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.multioutput import MultiOutputRegressor
import lightgbm as lgb
from scipy.stats import rankdata
from python_scripts.import_data import importDataset
from python_scripts.operate_model import predict

# --- 配置: 只用哪些 fold 的结果来训练/预测 meta-model ---
meta_folds = [0]  # 例如只用 fold0, fold2, fold4

# 1) 准备 full_dataset, slide_idx, test_dataset 等
full_dataset = importDataset(
    grouped_data, model,
    image_keys=['tile','subtiles'],
    transform=lambda x: x
)
n_samples = len(full_dataset)
C = 35  # 类别数

# 2) 预留 oof_preds 和 fold_ids
oof_preds    = np.zeros((n_samples, C), dtype=np.float32)
oof_fold_ids = np.full(n_samples, -1, dtype=int)

# 真标签
y_true = np.vstack([ full_dataset[i]['label'].cpu().numpy() for i in range(n_samples) ])
y_meta = y_true.copy()  # 不做 rank 时直接用 raw

# 3) 生成 OOF 预测并记录 fold id
logo = LeaveOneGroupOut()
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

for fold_id, (tr_idx, va_idx) in enumerate(
        logo.split(X=np.zeros(n_samples), y=None, groups=slide_idx)):

    # 如果当前 fold 不在我们想要的 meta_folds 列表里，就跳过
    if fold_id not in meta_folds:
        print(f"⏭️ Skipping OOF for fold {fold_id}")
        continue

    print(f"\n>>> Generating OOF for fold {fold_id}")
    ckpt_path = os.path.join(save_root, f"fold{fold_id}", "best_model.pt")
    net = PretrainedEncoderRegressor(
        ae_checkpoint=checkpoint_path,
        ae_type="all",
        center_dim=64, neighbor_dim=64, hidden_dim=128,
        tile_size=26, output_dim=35,
        freeze_encoder = True
    )

    # 2) monkey‐patch 一个新的 head
    net.decoder  = nn.Sequential(
        nn.Linear(64+64, 128),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(128, 64),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(64, 35)
        
    )
    net = net.to(device)
    net.load_state_dict(torch.load(ckpt_path, map_location=device))
    net.eval()

    val_loader = DataLoader(Subset(full_dataset, va_idx), batch_size=BATCH_SIZE, shuffle=False)
    preds = predict(net, val_loader, device)  # (n_val, C)

    oof_preds[va_idx]    = preds
    oof_fold_ids[va_idx] = fold_id

    print(f"  → Fold {fold_id} OOF preds shape: {preds.shape}")
# 4) 只选取 meta_folds 的行来训练 meta-model
mask = np.isin(oof_fold_ids, meta_folds)
X_meta = oof_preds[mask]
y_meta_sub = y_meta[mask]

print(f"\nTraining meta-model on folds {meta_folds}:")
print(f"  使用样本数：{X_meta.shape[0]} / {n_samples}")

lgb_base = lgb.LGBMRegressor(
    objective='regression',
    learning_rate=0.001,
    n_estimators=1000,
    num_leaves=31,
    subsample=0.7,
    colsample_bytree=0.7,
    n_jobs=-1,
    force_col_wise=True
)
meta_model = MultiOutputRegressor(lgb_base)
meta_model.fit(X_meta, y_meta_sub)
joblib.dump(meta_model, os.path.join(save_root, 'meta_model.pkl'))

# 5) 准备 test_meta，只平均 meta_folds 中的预测
n_folds = len([d for d in os.listdir(save_root) if d.startswith('fold')])
n_test  = len(test_dataset)
test_meta = np.zeros((n_test, C), dtype=np.float32)

for fold_id in range(n_folds):
    if fold_id not in meta_folds:
        continue
    ckpt_path = os.path.join(save_root, f"fold{fold_id}", "best_model.pt")
    net = PretrainedEncoderRegressor(
        ae_checkpoint=checkpoint_path,
        ae_type="all",
        center_dim=64, neighbor_dim=64, hidden_dim=128,
        tile_size=26, output_dim=35,
        freeze_encoder = True
    )

    # 2) monkey‐patch 一个新的 head
    net.decoder  = nn.Sequential(
        nn.Linear(64+64, 128),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(128, 64),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(64, 35)
        
    )
    net = net.to(device)
    net.load_state_dict(torch.load(ckpt_path, map_location=device))
    net.eval()

    loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    preds = predict(net, loader, device)
    test_meta += preds

# 平均时除以参与的 folds 数目
test_meta /= len(meta_folds)

# 6) 用 meta-model 做最终预测
final_preds = meta_model.predict(test_meta)

# --- Save submission ---
import h5py
import pandas as pd

with h5py.File("./dataset/elucidata_ai_challenge_data.h5","r") as f:
    test_spot_ids = pd.DataFrame(np.array(f["spots/Test"]["S_7"]))

sub = pd.DataFrame(final_preds, columns=[f"C{i+1}" for i in range(C)])
sub.insert(0, 'ID', test_spot_ids.index)
sub.to_csv(os.path.join(save_root, 'submission_stacked.csv'), index=False)
print("✅ Saved stacked submission.")


# Predict

In [21]:
import glob
import torch
import numpy as np
import pandas as pd
import os
import h5py
from torch.utils.data import DataLoader

# 讀 test spot index
with h5py.File("./dataset/elucidata_ai_challenge_data.h5","r") as f:
    test_spots     = f["spots/Test"]
    test_spot_table= pd.DataFrame(np.array(test_spots['S_7']))

fold_ckpts = sorted(glob.glob(os.path.join(trained_oof_model_folder, "fold*", "best_model.pt")))
models = []
for ckpt in fold_ckpts:
    net = VisionMLP_MultiTask(tile_dim=tile_dim, subtile_dim=center_dim, output_dim=C)
    net = net.to(device)
    net.load_state_dict(torch.load(ckpt, map_location="cpu"))
    net.to(device).eval()
    models.append(net)

all_fold_preds = []
for fold_id, net in enumerate(models):
    # 推論
    with torch.no_grad():
        preds = predict(net, test_loader, device)  # (N_test,35) numpy array

    # 1) 存每一折的原始預測
    df_fold = pd.DataFrame(preds, columns=[f"C{i+1}" for i in range(preds.shape[1])])
    df_fold.insert(0, "ID", test_spot_table.index)
    path_fold = os.path.join(trained_oof_model_folder, f"submission_fold{fold_id}.csv")
    df_fold.to_csv(path_fold, index=False)
    print(f"✅ Saved fold {fold_id} predictions to {path_fold}")

    all_fold_preds.append(preds)

# 2) 做 rank‐average ensemble
all_fold_preds = np.stack(all_fold_preds, axis=0)       # (K, N_test, 35)
ranks          = all_fold_preds.argsort(axis=2).argsort(axis=2).astype(float)
mean_rank      = ranks.mean(axis=0)                    # (N_test,35)

# 3) 存 final ensemble
df_ens = pd.DataFrame(mean_rank, columns=[f"C{i+1}" for i in range(mean_rank.shape[1])])
df_ens.insert(0, "ID", test_spot_table.index)
path_ens = os.path.join(trained_oof_model_folder, "submission_rank_ensemble.csv")
df_ens.to_csv(path_ens, index=False)
print(f"✅ Saved rank‐ensemble submission to {path_ens}")


  net.load_state_dict(torch.load(ckpt, map_location="cpu"))


✅ Saved fold 0 predictions to output_folder/rank-spot/realign/no_pretrain/3_encoder/filtered_directly_rank/k-fold/realign_all/Macenko_masked/submission_fold0.csv
✅ Saved fold 1 predictions to output_folder/rank-spot/realign/no_pretrain/3_encoder/filtered_directly_rank/k-fold/realign_all/Macenko_masked/submission_fold1.csv
✅ Saved fold 2 predictions to output_folder/rank-spot/realign/no_pretrain/3_encoder/filtered_directly_rank/k-fold/realign_all/Macenko_masked/submission_fold2.csv
✅ Saved fold 3 predictions to output_folder/rank-spot/realign/no_pretrain/3_encoder/filtered_directly_rank/k-fold/realign_all/Macenko_masked/submission_fold3.csv
✅ Saved fold 4 predictions to output_folder/rank-spot/realign/no_pretrain/3_encoder/filtered_directly_rank/k-fold/realign_all/Macenko_masked/submission_fold4.csv
✅ Saved fold 5 predictions to output_folder/rank-spot/realign/no_pretrain/3_encoder/filtered_directly_rank/k-fold/realign_all/Macenko_masked/submission_fold5.csv
✅ Saved rank‐ensemble submis