# Import Model

In [1]:
from python_scripts.pretrain_model import PretrainedEncoderRegressor
import torch.nn as nn

name = 'AE_Center_noaug'

checkpoint_path = f"AE_model/128/{name}/best.pt"

# 1) 实例化（会自动加载并冻结 encoder）
model = PretrainedEncoderRegressor(
    ae_checkpoint=checkpoint_path,
    ae_type="center",
    tile_dim=128,
    center_dim=128,
    neighbor_dim=128,
    output_dim=35
)

# 2) monkey‐patch 一个新的 head
model.decoder  = nn.Sequential(
    nn.Linear(128+128+128, 256),
    nn.LeakyReLU(0.01),
    nn.Dropout(0.1),
    nn.Linear(256, 128),
    nn.LeakyReLU(0.01),
    nn.Dropout(0.1),
    nn.Linear(128, 64),
    nn.LeakyReLU(0.01),
    nn.Dropout(0.1),
    nn.Linear(64, 35)
    
)

import torch
import torch.nn as nn

class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, 3, stride=stride, padding=1)
        self.bn1   = nn.BatchNorm2d(out_channels)
        self.act1  = nn.SiLU()
        self.conv2 = nn.Conv2d(out_channels, out_channels, 3, stride=1, padding=1)
        self.bn2   = nn.BatchNorm2d(out_channels)
        self.shortcut = None
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride=stride),
                nn.BatchNorm2d(out_channels)
            )
        self.act2 = nn.SiLU()

    def forward(self, x):
        identity = x
        out = self.act1(self.bn1(self.conv1(x)))
        out = self.act2(self.bn2(self.conv2(out)))
        if self.shortcut is not None:
            identity = self.shortcut(x)
        return out + identity



class DeepTileEncoder(nn.Module):
    """加深的 Tile 分支：全局信息，多尺度池化 + 三层 MLP"""
    def __init__(self, out_dim, in_channels=3, negative_slope=0.01):
        super().__init__()
        self.layer0 = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.SiLU(),
            nn.MaxPool2d(2)  # 78→39
        )
        self.layer1 = nn.Sequential(
            ResidualBlock(32, 64),
            ResidualBlock(64, 64),
            nn.MaxPool2d(2)  # 39→19
        )
        self.layer2 = nn.Sequential(
            ResidualBlock(64, 128),
            ResidualBlock(128, 128),
            nn.MaxPool2d(2)  # 19→9
        )
        self.layer3 = nn.Sequential(
            ResidualBlock(128, 256),
            ResidualBlock(256, 256)
        )  # 保持 9×9

        # 多尺度池化
        self.global_pool = nn.AdaptiveAvgPool2d((1, 1))  # [B,256,1,1]
        self.mid_pool    = nn.AdaptiveAvgPool2d((3, 3))  # [B,256,3,3]

        total_dim = 256*1*1 + 256*3*3
        # 三层 MLP：total_dim → 2*out_dim → out_dim → out_dim
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.1),
            nn.Linear(total_dim, out_dim*4),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(0.1),
            nn.Linear(out_dim*4, out_dim*2),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(0.1),
            nn.Linear(out_dim*2, out_dim),
            nn.LeakyReLU(negative_slope),
        )

    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        # x: [B,256,9,9]
        g = self.global_pool(x).contiguous().reshape(x.size(0), -1)  # [B,256]
        m = self.mid_pool(x).contiguous().reshape(x.size(0), -1)     # [B,256*3*3]

        return self.fc(torch.cat([g, m], dim=1))


class SubtileEncoder(nn.Module):
    """多尺度 Subtile 分支：局部信息 + 两层 MLP"""
    def __init__(self, out_dim, in_channels=3, negative_slope=0.01):
        super().__init__()
        self.layer0 = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.SiLU(),
            nn.MaxPool2d(2)  # 26→13
        )
        self.layer1 = nn.Sequential(
            ResidualBlock(32, 64),
            ResidualBlock(64, 64),
            nn.MaxPool2d(2)  # 13→6
        )
        self.layer2 = nn.Sequential(
            ResidualBlock(64, 128),
            ResidualBlock(128, 128)
        )  # 保持 6×6

        self.global_pool = nn.AdaptiveAvgPool2d((1,1))
        self.mid_pool    = nn.AdaptiveAvgPool2d((2,2))
        self.large_pool    = nn.AdaptiveAvgPool2d((3,3))

        total_dim = 128*1*1 + 128*2*2 + 128*3*3
        # 两层 MLP：total_dim → out_dim*2 → out_dim
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.1),
            nn.Linear(total_dim, out_dim*2),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(0.1),
            nn.Linear(out_dim*2, out_dim),
            nn.LeakyReLU(negative_slope),
        )

    def forward(self, x):
        B, N, C, H, W = x.shape
        x = x.contiguous().reshape(B*N, C, H, W)
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        # g,m: [B*N, feat]
        g = self.global_pool(x).contiguous().reshape(B, N, -1)
        m = self.mid_pool(x).contiguous().reshape(B, N, -1)
        l = self.large_pool(x).contiguous().reshape(B, N, -1)

        # 合并 N 张 subtiles，再 FC
        feat = torch.cat([g, m, l], dim=2).mean(dim=1).contiguous()  # [B, total_dim]
        return self.fc(feat)
class CenterSubtileEncoder(nn.Module):
    """專門處理中心 subtile 的 Encoder"""
    def __init__(self, out_dim, in_channels=3, negative_slope= 0.01):
        super().__init__()
        self.layer0 = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.SiLU(),
            nn.MaxPool2d(2)  # 26→13
        )
        self.layer1 = nn.Sequential(
            ResidualBlock(32, 64),
            ResidualBlock(64, 64),
            nn.MaxPool2d(2)  # 13→6
        )
        self.layer2 = nn.Sequential(
            ResidualBlock(64, 128),
            ResidualBlock(128, 128)
        )  # 6×6

        # 多尺度池化
        self.global_pool = nn.AdaptiveAvgPool2d((1,1))
        self.mid_pool    = nn.AdaptiveAvgPool2d((2,2))
        self.large_pool    = nn.AdaptiveAvgPool2d((3,3))

        total_dim = 128*1*1 + 128*2*2 + 128*3*3
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Dropout(0.1),
            nn.Linear(total_dim, out_dim*2),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(0.1),
            nn.Linear(out_dim*2, out_dim),
            nn.LeakyReLU(negative_slope),
        )

    def forward(self, x):
        x = self.layer0(x)
        x = self.layer1(x)
        x = self.layer2(x)
        g = self.global_pool(x).contiguous().reshape(x.size(0), -1)
        m = self.mid_pool(x).contiguous().reshape(x.size(0), -1)
        l = self.large_pool(x).contiguous().reshape(x.size(0), -1)

        return self.fc(torch.cat([g, m, l], dim=1)).contiguous()



class VisionMLP_MultiTask(nn.Module):
    """整體多任務模型：融合 tile + subtile + center，使用動態權重融合"""
    def __init__(self, tile_dim=128, subtile_dim=64, output_dim=35, negative_slope=0.01):
        super().__init__()
        self.encoder_tile    = DeepTileEncoder(tile_dim)
        self.encoder_subtile = SubtileEncoder(subtile_dim)
        self.encoder_center  = CenterSubtileEncoder(subtile_dim)

        # 融合層：輸入三個分支的 concat，輸出三個 gate
        self.gate_fc = nn.Sequential(
            nn.Linear(tile_dim + subtile_dim + subtile_dim, 64),
            nn.LeakyReLU(negative_slope),
            nn.Linear(64, 3),  # 對 tile, subtile, center 分支輸出 gate
            nn.Softmax(dim=1)  # 轉成權重
        )

        # 輸出 decoder：輸入為 tile_dim (因為融合後只剩一個 vector)
        self.decoder = nn.Sequential(
            nn.Linear(tile_dim, 256),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(0.1),
            nn.Linear(256, 128),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(0.1),
            nn.Linear(128, 64),
            nn.LeakyReLU(negative_slope),
            nn.Dropout(0.1),
            nn.Linear(64, output_dim),
        )

    def forward(self, tile, subtiles):
        tile = tile.contiguous()
        subtiles = subtiles.contiguous()
        center = subtiles[:, 4]

        f_tile = self.encoder_tile(tile)         # [B, tile_dim]
        f_sub  = self.encoder_subtile(subtiles)  # [B, subtile_dim]
        f_center = self.encoder_center(center)   # [B, subtile_dim]

        # 拼接三個分支做 gating
        features_cat = torch.cat([f_tile, f_sub, f_center], dim=1)  # [B, tile+sub+center]
        gates = self.gate_fc(features_cat)  # [B, 3]

        # 對三個分支做 weighted sum
        f_fused = (
            gates[:, 0:1] * f_tile + 
            gates[:, 1:2] * f_sub + 
            gates[:, 2:3] * f_center
        )  # [B, tile_dim]（注意：需保證 f_tile == f_sub == f_center 的維度）

        return self.decoder(f_fused)
    
# class VisionMLP_MultiTask(nn.Module):
#     """整體多任務模型：融合 tile + subtile + center + position 特徵"""
#     def __init__(self, tile_dim=128, subtile_dim=128, output_dim=35, negative_slope=0.01):
#         super().__init__()
#         self.encoder_tile    = DeepTileEncoder(tile_dim)
#         self.encoder_subtile = SubtileEncoder(subtile_dim)
#         self.encoder_center  = CenterSubtileEncoder(subtile_dim)

#         self.feature_dim = tile_dim + subtile_dim + subtile_dim # +2 for position(x,y)

#         self.decoder = nn.Sequential(
#             nn.Linear(self.feature_dim, 256),
#             nn.LeakyReLU(negative_slope),
#             nn.Dropout(0.1),
#             nn.Linear(256, 128),
#             nn.LeakyReLU(negative_slope),
#             nn.Dropout(0.1),
#             nn.Linear(128, 64),
#             nn.LeakyReLU(negative_slope),
#             nn.Dropout(0.1),
#             nn.Linear(64, output_dim),
#         )

#     def forward(self, tile, subtiles):
#         tile = tile.contiguous()
#         subtiles = subtiles.contiguous()
#         center = subtiles[:, 4]

#         f_tile = self.encoder_tile(tile)         # [B, tile_dim]
#         f_sub  = self.encoder_subtile(subtiles)  # [B, subtile_dim]
#         f_center = self.encoder_center(center)   # [B, subtile_dim]

#         # 拼接特徵向量與座標
#         features_cat = torch.cat([f_tile, f_sub, f_center], dim=1)  # [B, tile+sub+center+2]

#         return self.decoder(features_cat)


# 用法示例
model = VisionMLP_MultiTask(tile_dim=128, subtile_dim=128, output_dim=35)


# —— 5) 确保只有 decoder 可训练 ——  
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total     = sum(p.numel() for p in model.parameters())
print(f"Trainable / total params = {trainable:,} / {total:,}")


device   = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model = model.to(device)
model.load_state_dict(torch.load('output_folder/rank-spot/realign/whole_worflow/s_m_l/filtered_directly_rank/k-fold_mix/realign_all/Macenko_masked/results/model_epoch022.pt', map_location="cpu"))
model.to(device).eval()

  ae.load_state_dict(torch.load(ae_checkpoint, map_location="cpu"))
  model.load_state_dict(torch.load('output_folder/rank-spot/realign/whole_worflow/s_m_l/filtered_directly_rank/k-fold_mix/realign_all/Macenko_masked/results/model_epoch022.pt', map_location="cpu"))


Trainable / total params = 6,639,142 / 6,639,142


VisionMLP_MultiTask(
  (encoder_tile): DeepTileEncoder(
    (layer0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU()
      (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (layer1): Sequential(
      (0): ResidualBlock(
        (conv1): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (act1): SiLU()
        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (shortcut): Sequential(
          (0): Conv2d(32, 64, kernel_size=(1, 1), stride=(1, 1))
          (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        )
        

## Load Model

# Import training data

## Same in multiple .pt

In [2]:
import os
import torch
import random
import inspect
from python_scripts.import_data import load_all_tile_data

# 用法範例
#folder = "dataset/spot-rank/version-3/only_tile_sub/original_train"
folder = "dataset/spot-rank/filtered_directly_rank/masked/realign/Macenko_masked/filtered/train_data/"

grouped_data = load_all_tile_data( 
        folder_path=folder,
        model=model,
        fraction=1,
        shuffle=False
    )

    # grouped_data 現在只會有 model.forward() 需要的 key，
    # 像 ['tile','subtiles','neighbors','norm_coord','node_feat','adj_list','edge_feat','label','source_idx']
print("Loaded keys:", grouped_data.keys())
print("Samples:", len(next(iter(grouped_data.values()))))




  from .autonotebook import tqdm as notebook_tqdm
  d = torch.load(fpath, map_location='cpu')


Loaded keys: dict_keys(['subtiles', 'label', 'source_idx', 'slide_idx', 'tile', 'position'])
Samples: 8348


In [5]:
from python_scripts.import_data import convert_item, get_model_inputs

import torch
from torch.utils.data import Dataset
import inspect
import numpy as np

class importDataset(Dataset):
    def __init__(self, data_dict, model, image_keys=None, transform=None, print_sig=False):
        self.data = data_dict
        self.image_keys = set(image_keys) if image_keys is not None else set()
        self.transform = transform if transform is not None else lambda x: x
        self.forward_keys = list(get_model_inputs(model, print_sig=print_sig).parameters.keys())

        expected_length = None
        for key, value in self.data.items():
            if expected_length is None:
                expected_length = len(value)
            if len(value) != expected_length:
                raise ValueError(f"資料欄位 '{key}' 的長度 ({len(value)}) 與預期 ({expected_length}) 不一致。")

        for key in self.forward_keys:
            if key not in self.data:
                raise ValueError(f"data_dict 缺少模型 forward 所需欄位: '{key}'。目前可用的欄位: {list(self.data.keys())}")
        if "label" not in self.data:
            raise ValueError(f"data_dict 必須包含 'label' 欄位。可用的欄位: {list(self.data.keys())}")
        if "source_idx" not in self.data:
            raise ValueError("data_dict 必須包含 'source_idx' 欄位，用於 trace 原始順序對應。")
        if "position" not in self.data:
            raise ValueError("data_dict 必須包含 'position' 欄位，用於 trace 原始順序對應。")
    def __len__(self):
        return len(next(iter(self.data.values())))

    def __getitem__(self, idx):
        sample = {}
        for key in self.forward_keys:
            value = self.data[key][idx]
            value = self.transform(value)
            value = convert_item(value, is_image=(key in self.image_keys))
            if isinstance(value, torch.Tensor):
                value = value.float()
            sample[key] = value

        label = self.transform(self.data["label"][idx])
        label = convert_item(label, is_image=False)
        if isinstance(label, torch.Tensor):
            label = label.float()
        sample["label"] = label

        # 加入 source_idx
        source_idx = self.data["source_idx"][idx]
        sample["source_idx"] = torch.tensor(source_idx, dtype=torch.long)
        # 加入 position （假设 data_dict 中 'position' 是 (x, y) 或 [x, y]）
        pos = self.data["position"][idx]
        sample["position"] = torch.tensor(pos, dtype=torch.float)
        return sample
    def check_item(self, idx=0, num_lines=5):
        expected_keys = self.forward_keys + ['label', 'source_idx', 'position']
        sample = self[idx]
        print(f"🔍 Checking dataset sample: {idx}")
        for key in expected_keys:
            if key not in sample:
                print(f"❌ 資料中缺少 key: {key}")
                continue
            tensor = sample[key]
            if isinstance(tensor, torch.Tensor):
                try:
                    shape = tensor.shape
                except Exception:
                    shape = "N/A"
                dtype = tensor.dtype if hasattr(tensor, "dtype") else "N/A"
                output_str = f"📏 {key} shape: {shape} | dtype: {dtype}"
                if tensor.numel() > 0:
                    try:
                        tensor_float = tensor.float()
                        mn = tensor_float.min().item()
                        mx = tensor_float.max().item()
                        mean = tensor_float.mean().item()
                        std = tensor_float.std().item()
                        output_str += f" | min: {mn:.3f}, max: {mx:.3f}, mean: {mean:.3f}, std: {std:.3f}"
                    except Exception:
                        output_str += " | 無法計算統計數據"
                print(output_str)
                if key not in self.image_keys:
                    if tensor.ndim == 0:
                        print(f"--- {key} 資料為純量:", tensor)
                    elif tensor.ndim == 1:
                        print(f"--- {key} head (前 {num_lines} 個元素):")
                        print(tensor[:num_lines])
                    else:
                        print(f"--- {key} head (前 {num_lines} 列):")
                        print(tensor[:num_lines])
            else:
                # 如果 position 存的是 list/tuple/etc，也会走这里
                print(f"📏 {key} (非 tensor 資料):", tensor)
        print("✅ All checks passed!")


full_dataset = importDataset(grouped_data, model,
                             image_keys=['tile','subtiles'],
                             transform=lambda x: x)

full_dataset.check_item()

🔍 Checking dataset sample: 0
📏 tile shape: torch.Size([3, 78, 78]) | dtype: torch.float32 | min: 0.157, max: 1.000, mean: 0.680, std: 0.142
📏 subtiles shape: torch.Size([9, 3, 26, 26]) | dtype: torch.float32 | min: 0.157, max: 1.000, mean: 0.680, std: 0.142
📏 label shape: torch.Size([35]) | dtype: torch.float32 | min: 1.000, max: 35.000, mean: 18.000, std: 10.247
--- label head (前 5 個元素):
tensor([12., 24., 18.,  6., 30.])
📏 source_idx shape: torch.Size([]) | dtype: torch.int64 | min: 0.000, max: 0.000, mean: 0.000, std: nan
--- source_idx 資料為純量: tensor(0)
📏 position shape: torch.Size([2]) | dtype: torch.float32 | min: 0.171, max: 0.632, mean: 0.401, std: 0.326
--- position head (前 5 個元素):
tensor([0.6318, 0.1707])
✅ All checks passed!


  std = tensor_float.std().item()


In [3]:
import torch
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.cluster import KMeans
import pandas as pd
import joblib
from tqdm import tqdm
import torch
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
import joblib
from tqdm import tqdm

import torch
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.cluster import KMeans
from sklearn.metrics import silhouette_score
from sklearn.decomposition import PCA
from skimage.feature import local_binary_pattern
from skimage.feature.texture import graycomatrix, graycoprops
from skimage.filters import sobel
import pywt
import joblib
from tqdm import tqdm

# === 工具函數 ===
def compute_ae_reconstruction_loss(ae_model, dataloader, device, ae_type):
    ae_model.eval()
    losses = []
    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Computing AE recon loss"):
            tile = batch['tile'].to(device)
            subtiles = batch['subtiles'].to(device)
            recon = ae_model(tile, subtiles)

            if ae_type == 'center':
                target = subtiles[:, 4]
            else:
                target = subtiles

            loss = F.mse_loss(recon, target, reduction='none')
            loss = loss.view(loss.shape[0], -1).mean(dim=1)
            losses.append(loss.cpu().numpy())
    return np.concatenate(losses)

def compute_latent_stats(latents):
    return np.concatenate([
        latents.mean(axis=1, keepdims=True),
        latents.std(axis=1, keepdims=True),
        latents.min(axis=1, keepdims=True),
        latents.max(axis=1, keepdims=True),
    ], axis=1)


# === Loading Label Cluster Map ===
def load_label_cluster_map(filepath="dataset/label_cluster_map.pkl", use_clusters="both"):
    cluster_map = joblib.load(filepath)
    # keys are 'label_cluster_map_4', 'label_cluster_map_20'
    if use_clusters == "4":
        return cluster_map['label_cluster_map_4']
    elif use_clusters == "20":
        return cluster_map['label_cluster_map_20']
    elif use_clusters == "both":
        return {
            '4': cluster_map['label_cluster_map_4'],
            '20': cluster_map['label_cluster_map_20']
        }
    else:
        raise ValueError("Invalid cluster selection. Choose '4', '20', or 'both'.")

# === Helper Functions ===
def choose_best_n_clusters(X, min_k=35, max_k=38, random_state=42):
    best_k, best_score = min_k, -np.inf
    for k in range(min_k, max_k + 1):
        labels = KMeans(n_clusters=k, random_state=random_state).fit_predict(X)
        score = silhouette_score(X, labels)
        print(f"Silhouette score for k={k}: {score:.4f}")
        if score > best_score:
            best_score, best_k = score, k
    print(f"Selected best k={best_k} (score={best_score:.4f})")
    return best_k


def compute_cluster_summary_stats(matrix, cluster_ids):
    """
    matrix: (n_samples, n_dims)
    cluster_ids: length n_samples, 每個 sample 所屬的群編號
    回傳 shape=(n_samples, 4)：
      [mean, std, min, max] 是該 sample 所屬群裡，所有 matrix 值的全局統計
    """
    stats = np.zeros((matrix.shape[0], 4), dtype=float)
    for c in np.unique(cluster_ids):
        mask = (cluster_ids == c)
        values = matrix[mask]            # shape=(n_c, n_dims)
        flat = values.flatten()         # 把維度攤平成一維
        summary = [
            flat.mean(),
            flat.std(),
            flat.min(),
            flat.max(),
        ]
        stats[mask] = summary           # 同一群裡的所有 sample 都用同一組 summary
    return stats


def compute_feature_cluster_stats(matrix, feature_cluster_ids):
    """
    Given matrix shape (n_samples, n_features) and feature_cluster_ids length n_features,
    compute per-sample [mean, std, min, max] within each feature-cluster.
    Returns array shape (n_samples, n_clusters*4)
    """
    unique_clusters = np.unique(feature_cluster_ids)
    stats_per_cluster = []
    for c in unique_clusters:
        mask = (feature_cluster_ids == c)
        cluster_vals = matrix[:, mask]  # shape (n_samples, n_feats_in_cluster)
        mean = cluster_vals.mean(axis=1, keepdims=True)
        std  = cluster_vals.std(axis=1, keepdims=True)
        mn   = cluster_vals.min(axis=1, keepdims=True)
        mx   = cluster_vals.max(axis=1, keepdims=True)
        stats_per_cluster.append(np.hstack([mean, std, mn, mx]))
    return np.hstack(stats_per_cluster)


def compute_rgb_stats(dataset):
    stats = []
    for i in range(len(dataset)):
        sub = dataset[i]['subtiles'][4].numpy()
        ch_stats = []
        for ch in range(sub.shape[0]):
            vals = sub[ch]
            ch_stats += [vals.mean(), vals.std(), vals.min(), vals.max()]
        stats.append(ch_stats)
    return np.array(stats)


def compute_rgb_cluster_stats(dataset, cluster_ids):
    unique = np.unique(cluster_ids)
    stats_dict = {}
    for c in unique:
        idxs = np.where(cluster_ids == c)[0]
        acc = []
        for i in idxs:
            acc.append(dataset[i]['subtiles'][4].numpy())
        arr = np.stack(acc)
        ch_stats = []
        for ch in range(arr.shape[1]):
            flat = arr[:, ch].flatten()
            ch_stats += [flat.mean(), flat.std(), flat.min(), flat.max()]
        stats_dict[c] = np.array(ch_stats)
    return np.vstack([stats_dict[c] for c in cluster_ids])
# === RGB Statistics Extensions ===
# center subtile (index 4) already in compute_rgb_stats

def compute_all_subtiles_rgb_stats(dataset):
    """
    每個 sample 的所有 subtiles (0-8) 中，每個 channel 的 mean/std/min/max，concat 形成 (n, 9*4*C)
    """
    stats = []
    for i in range(len(dataset)):
        subs = dataset[i]['subtiles'].numpy()  # shape (9, C, H, W)
        sample_stats = []
        for idx in range(subs.shape[0]):
            for ch in range(subs.shape[1]):
                vals = subs[idx, ch]
                sample_stats += [vals.mean(), vals.std(), vals.min(), vals.max()]
        stats.append(sample_stats)
    return np.array(stats)

def compute_subtiles_except_center_rgb_stats(dataset):
    """
    每個 sample 的 subtiles 除了 center (index 4) 外，其餘 8 塊合併後，每 channel 的 mean/std/min/max，shape (n, C*4)
    """
    stats = []
    for i in range(len(dataset)):
        subs = dataset[i]['subtiles'].numpy()  # (9, C, H, W)
        exclude = np.concatenate([subs[:4], subs[5:]], axis=0)  # (8, C, H, W)
        sample_stats = []
        for ch in range(exclude.shape[1]):
            vals = exclude[:, ch].flatten()
            sample_stats += [vals.mean(), vals.std(), vals.min(), vals.max()]
        stats.append(sample_stats)
    return np.array(stats)

def compute_tile_rgb_stats(dataset):
    """
    每個 sample 的 tile (整張圖)，每 channel 的 mean/std/min/max，shape (n, C*4)
    """
    stats = []
    for i in range(len(dataset)):
        tile = dataset[i]['tile'].numpy()  # shape (C, H, W)
        sample_stats = []
        for ch in range(tile.shape[0]):
            vals = tile[ch]
            sample_stats += [vals.mean(), vals.std(), vals.min(), vals.max()]
        stats.append(sample_stats)
    return np.array(stats)

# === Texture & Pattern Features ===


def compute_wavelet_stats(dataset, wavelet='db1', level=2):
    feats = []
    for i in range(len(dataset)):
        patch = dataset[i]['subtiles'][4].numpy()[0]
        coeffs = pywt.wavedec2(patch, wavelet=wavelet, level=level)
        sample = []
        for arr in coeffs:
            if isinstance(arr, tuple):
                for sub in arr:
                    sample += [sub.mean(), sub.std()]
            else:
                sample += [arr.mean(), arr.std()]
        feats.append(sample)
    return np.array(feats)


def compute_sobel_stats(dataset):
    feats = []
    for i in range(len(dataset)):
        gray = dataset[i]['tile'].numpy().mean(axis=0)
        edge = sobel(gray)
        feats.append([edge.mean(), edge.std(), edge.min(), edge.max()])
    return np.array(feats)

# === Texture & Pattern Features ===

def compute_hsv_stats(dataset):
    feats = []
    from skimage.color import rgb2hsv
    for i in range(len(dataset)):
        sub = dataset[i]['subtiles'][4].numpy()  # (C, H, W)
        img = sub[:3].transpose(1,2,0)         # (H, W, 3)
        hsv = rgb2hsv(img)
        sample = []
        for ch in range(3):
            vals = hsv[:,:,ch]
            sample += [vals.mean(), vals.std(), vals.min(), vals.max()]
        feats.append(sample)
    return np.array(feats)


def compute_color_moments(dataset):
    feats = []
    from scipy.stats import skew, kurtosis
    for i in range(len(dataset)):
        sub = dataset[i]['subtiles'][4].numpy()  # (C, H, W)
        img = sub.transpose(1,2,0)               # (H, W, C)
        sample = []
        for ch in range(img.shape[2]):
            vals = img[:,:,ch].ravel()
            sample += [vals.mean(), vals.std(), skew(vals), kurtosis(vals)]
        feats.append(sample)
    return np.array(feats)


# === Subtile Contrast Features ===

def compute_subtile_contrast_stats(dataset, eps=1e-7):
    """
    Compute contrast between center subtile and surrounding subtiles:
    For each sample and each channel, calculate:
      diff = center_mean - surround_mean
      ratio = center_mean / (surround_mean + eps)
    Returns array (n_samples, C*2)
    """
    stats = []
    for i in range(len(dataset)):
        subs = dataset[i]['subtiles'].numpy()  # (9, C, H, W)
        center = subs[4]  # (C, H, W)
        surround = np.concatenate([subs[:4], subs[5:]], axis=0)  # (8, C, H, W)
        center_mean = center.reshape(center.shape[0], -1).mean(axis=1)
        surround_mean = surround.reshape(surround.shape[0], surround.shape[1], -1).mean(axis=2).mean(axis=0)
        diff = center_mean - surround_mean
        ratio = center_mean / (surround_mean + eps)
        stats.append(np.hstack([diff, ratio]))
    return np.array(stats)

# === H&E Color Deconvolution Features ===
def compute_he_stats(dataset):
    """
    Compute H&E stain intensity stats from center subtile RGB (3 channels):
    Returns array (n_samples, 8): [H_mean, H_std, H_min, H_max, E_mean, E_std, E_min, E_max]
    """
    from skimage.color import separate_stains, hed_from_rgb
    stats = []
    for i in range(len(dataset)):
        sub = dataset[i]['subtiles'][4].numpy()  # (C, H, W)
        # assume first 3 channels are RGB
        rgb = sub[:3].transpose(1,2,0)
        hed = separate_stains(rgb, hed_from_rgb)
        h = hed[:,:,0]
        e = hed[:,:,1]
        sample = [
            h.mean(), h.std(), h.min(), h.max(),
            e.mean(), e.std(), e.min(), e.max()
        ]
        stats.append(sample)
    return np.array(stats)


# === Sliding Window Std Features ===

def compute_sliding_std_stats(dataset, window_size=3, eps=1e-6):
    """
    Compute local standard deviation within a sliding window of given size on center subtile.
    Returns array shape (n_samples, C*2) with mean and max of local std for each channel.
    """
    from scipy.ndimage import uniform_filter

    stats = []
    for i in range(len(dataset)):
        sub = dataset[i]['subtiles'][4].numpy()  # (C, H, W)
        sample = []
        for ch in range(sub.shape[0]):
            arr = sub[ch]
            # compute local mean and mean of squares
            mean = uniform_filter(arr, size=window_size)
            mean_sq = uniform_filter(arr * arr, size=window_size)
            local_std = np.sqrt(np.maximum(mean_sq - mean * mean, 0))
            sample += [local_std.mean(), local_std.max()]
        stats.append(sample)
    return np.array(stats)


# === Distribution-based Features ===

def compute_entropy(oof_preds, eps=1e-12):
    """
    Compute Shannon entropy for each sample's OOF prediction distribution.
    oof_preds: (n_samples, C)
    returns: array shape (n_samples, 1)
    """
    # normalize to sum 1
    probs = oof_preds / (oof_preds.sum(axis=1, keepdims=True) + eps)
    ent = -np.sum(probs * np.log(probs + eps), axis=1, keepdims=True)
    return ent


def compute_top2_diff(oof_preds):
    """
    Compute difference between the top-1 and top-2 predicted values per sample.
    returns: array shape (n_samples, 1)
    """
    # sort descending
    sorted_preds = -np.sort(-oof_preds, axis=1)
    diff = sorted_preds[:, 0] - sorted_preds[:, 1]
    return diff.reshape(-1, 1)

# === Pairwise Differences for All Cell Types ===
from itertools import combinations

def compute_pairwise_diff(oof_preds):
    """
    Compute raw differences for every pair of cell-type predictions.
    For each sample, returns array of length C*(C-1)/2 in order of (i<j).
    """
    n_samples, C = oof_preds.shape
    # generate list of index pairs i<j
    idx_pairs = list(combinations(range(C), 2))
    # stack differences for each pair
    diffs = np.stack([oof_preds[:, i] - oof_preds[:, j] for i, j in idx_pairs], axis=1)
    return diffs

def compute_dispersion(oof_preds):
    """
    Compute dispersion metric (Gini impurity) or std across cell-type predictions.
    returns: array shape (n_samples, 1)
    """
    # Gini impurity: 1 - sum(p_i^2)
    probs = oof_preds / (oof_preds.sum(axis=1, keepdims=True) + 1e-12)
    gini = 1 - np.sum(probs**2, axis=1)
    return gini.reshape(-1, 1)

def compute_ae_embeddings(loader, recon_model, device):
    """
    Extract fused encoder embeddings from the PretrainedEncoderRegressor.
    Returns numpy array shape (n_samples, fusion_dim).
    """
    recon_model.eval()
    embeddings = []
    with torch.no_grad():
        for batch in loader:
            tiles = batch['tile'].to(device)
            subtiles = batch['subtiles'].to(device)
            subtiles = subtiles.contiguous()
            tiles     = tiles.contiguous()
            # forward up to encoder fusion
            f_c = recon_model.enc_center(subtiles[:, 4])
            f_n = recon_model.enc_neigh(subtiles)
            f_t = recon_model.enc_tile(tiles)
            fused = torch.cat([f_c, f_n, f_t], dim=1)
            embeddings.append(fused.cpu().numpy())
    return np.vstack(embeddings)

# === Main Function ===
def generate_meta_features(
    dataset,
    oof_preds,
    image_latents,
    model_for_recon,
    device,
    ae_type,
    label_cluster_map_path="dataset/label_cluster_map.pkl",
    use_clusters="both"
):
    loader = DataLoader(dataset, batch_size=64, shuffle=False)
    recon_loss = compute_ae_reconstruction_loss(model_for_recon, loader, device, ae_type)

    # 1-4: OOF preds, image latents, recon loss, latent stats
    image_latent = compute_ae_embeddings(loader, model_for_recon, device)
    latent_stats = compute_latent_stats(image_latent)

    # # 6: Ground Truth Cell Expression Cluster Stats
    label_map = load_label_cluster_map(label_cluster_map_path, use_clusters)
    if use_clusters == "both":
        stats4  = compute_feature_cluster_stats(oof_preds, label_map['4'])   # shape (n, 4*4)
        stats20 = compute_feature_cluster_stats(oof_preds, label_map['20'])  # shape (n, 20*4)
        gt_cluster_stats = np.hstack([stats4, stats20])  # (n, 16+80)
    else:
        gt_cluster_stats = compute_feature_cluster_stats(oof_preds, label_map)

    # # 8: OOF pred clusters stats
    # # 把 cell 當作樣本來作群聚
    # # best_k_cells = choose_best_n_clusters(oof_preds.T, min_k=2, max_k=6)
    # cell_cluster_ids = KMeans(n_clusters=4, random_state=42)\
    #                     .fit_predict(oof_preds.T)
    # pred_cluster_stats = compute_feature_cluster_stats(oof_preds, cell_cluster_ids)

    # # 9-10: Latent clusters & summary
    # best_k_latent = choose_best_n_clusters(image_latents, min_k=2, max_k=50)
    # latent_ids = KMeans(n_clusters=best_k_latent, random_state=42).fit_predict(image_latents)
    # latent_summary = compute_cluster_summary_stats(image_latents, latent_ids)  # shape = (2197, 4)

    # # 11-12: AE loss clusters & summary
    # loss_vals = recon_loss.reshape(-1,1)
    # best_k_loss = choose_best_n_clusters(loss_vals, min_k=2, max_k=50)
    # loss_ids = KMeans(n_clusters=best_k_loss, random_state=42).fit_predict(loss_vals)
    # loss_summary_stats = compute_cluster_summary_stats(loss_vals, loss_ids)


    # # 13: RGB stats
    rgb_stats = compute_rgb_stats(dataset)
    # 新增: 所有 subtiles RGB, 除 center 以外的 subtiles RGB, 整張 tile RGB
    rgb_all_subs = compute_all_subtiles_rgb_stats(dataset)
    rgb_except_center = compute_subtiles_except_center_rgb_stats(dataset)
    rgb_tile = compute_tile_rgb_stats(dataset)
    
    # # 14-15: Cluster-level RGB stats
    # latent_rgb_stats = compute_rgb_cluster_stats(dataset, latent_ids)
    # loss_rgb_stats   = compute_rgb_cluster_stats(dataset, loss_ids)
    # concatenate all features
    
        # 質地 & 紋理特徵
    wavelet_feats     = compute_wavelet_stats(dataset)
    sobel_feats       = compute_sobel_stats(dataset)
    
    
    # 新增: 顏色空間與色彩分佈特徵
    hsv_feats         = compute_hsv_stats(dataset)
    color_moments_feats = compute_color_moments(dataset)
    
    # 新增: Subtile 間對比特徵
    contrast_feats    = compute_subtile_contrast_stats(dataset)
    # 新增: H&E 染色成分強度特徵
    he_feats          = compute_he_stats(dataset)
    
    sliding_std_stats = compute_sliding_std_stats(dataset)
    
    # ent = compute_entropy(oof_preds)
    top2 = compute_top2_diff(oof_preds)
    # dis = compute_dispersion(oof_preds)
    
    features = np.concatenate([
        oof_preds,
        image_latents,
        image_latent,
        recon_loss[:,None],
        latent_stats,
        gt_cluster_stats,
        # pred_cluster_stats,
        # latent_ids.reshape(-1,1),
        # latent_summary,
        # loss_ids.reshape(-1,1),
        # loss_summary_stats,
        rgb_stats,
        # latent_rgb_stats,
        # loss_rgb_stats
        rgb_all_subs,
        rgb_except_center,
        rgb_tile,

        wavelet_feats,
        sobel_feats,
        
        hsv_feats,
        color_moments_feats,
        
        contrast_feats,
        he_feats,
        sliding_std_stats,
        
        # ent,
        top2,
        # dis
    ], axis=1)
    print(f"✅ Generated meta-features with shape: {features.shape}")
    return features


In [7]:
import os
import numpy as np
import joblib
import torch
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.multioutput import MultiOutputRegressor
from sklearn.model_selection import train_test_split
import lightgbm as lgb
from scipy.stats import rankdata
from python_scripts.import_data import importDataset
from python_scripts.operate_model import predict
from lightgbm import early_stopping, log_evaluation
import h5py
import pandas as pd
from python_scripts.pretrain_model import PretrainedEncoderRegressor

# --------------- Settings ---------------
trained_model_path = 'output_folder/rank-spot/realign/whole_worflow/s_m_l/filtered_directly_rank/k-fold_mix/realign_all/Macenko_masked/results/model_epoch022.pt'
n_samples  = len(full_dataset)
C          = 35
BATCH_SIZE = 64

tile_dim    = 128
center_dim  = 128
neighbor_dim= 128
version = 'version2'

pretrained_ae_name  = 'AE_Center_noaug'
pretrained_ae_path  = f"AE_model/128/{pretrained_ae_name}/best.pt"
ae_type            = 'center'
device             = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# 1) Ground truth labels
y_true = np.vstack([ full_dataset[i]['label'].cpu().numpy() for i in range(n_samples) ])

# 2) Load the single model and get preds + latents on full_dataset
net = VisionMLP_MultiTask(tile_dim=tile_dim, subtile_dim=center_dim, output_dim=C)
net.load_state_dict(torch.load(trained_model_path, map_location=device))
net = net.to(device).eval()

full_loader = DataLoader(full_dataset, batch_size=BATCH_SIZE, shuffle=False)
all_preds, all_latents = [], []
with torch.no_grad():
    for batch in full_loader:
        tiles, subtiles = batch['tile'].to(device), batch['subtiles'].to(device)
        center = subtiles[:, 4].contiguous()
        f_c = net.encoder_center(center)
        f_n = net.encoder_subtile(subtiles)
        f_t = net.encoder_tile(tiles)
        features_cat = torch.cat([f_c, f_n, f_t], dim=1)
        gates = net.gate_fc(features_cat)        # (B, 3), softmax over features
        f_fused = (
            gates[:, 0:1] * f_t +
            gates[:, 1:2] * f_n +
            gates[:, 2:3] * f_c
        )  # (B, tile_dim) — 注意各 encoder 輸出維度要一致
        out  = net.decoder(f_fused)
        all_preds.append(out.cpu().numpy())
        all_latents.append(features_cat.cpu().numpy())

oof_preds     = np.vstack(all_preds)     # shape (n_samples, 35)
image_latents = np.vstack(all_latents)   # shape (n_samples, fusion_dim)

# 3) AE reconstruction model (unchanged)
recon_model = PretrainedEncoderRegressor(
    ae_checkpoint=pretrained_ae_path,
    ae_type=ae_type,
    tile_dim=tile_dim,
    center_dim=center_dim,
    neighbor_dim=neighbor_dim,
    output_dim=C,
    mode='reconstruction'
).to(device)

# 4) Generate meta features for the full dataset
meta = generate_meta_features(
    dataset         = full_dataset,
    oof_preds       = oof_preds,
    image_latents   = image_latents,
    model_for_recon = recon_model,
    device          = device,
    ae_type         = ae_type,
    use_clusters    = "both"
)

# 5) Train a single Meta‐Model on all meta features

lgb_base = lgb.LGBMRegressor(
    objective='regression',         # 等价于 'l2'
    metric='rmse',
    learning_rate=0.01,             # 稍微放宽到 0.01
    n_estimators=20000,             # 上限提高，配合 early stopping
    max_depth=8,                    # 深度可以再增加一些
    num_leaves=127,                 # 2^7-1，与 max_depth=8 匹配
    feature_fraction=0.8,           # 每棵树采 80% 特征
    bagging_fraction=0.8,           # 每棵树采 80% 样本
    bagging_freq=1,                 # 开启行抽样
    min_data_in_leaf=30,            # 叶子上最少 20 样本
    reg_alpha=1.0,                  # L1 正则
    reg_lambda=1.0,                 # L2 正则
    verbosity=-1
)


# 5a) Optional: split a small val‐set for early stopping
X_tr, X_val, y_tr, y_val = train_test_split(
    meta, y_true, test_size=0.2, random_state=42
)

meta_model = MultiOutputRegressor(lgb_base)
meta_model.estimators_ = []
for i in range(C):
    print(f"Training meta‐model target {i} …")
    m = lgb.LGBMRegressor(**lgb_base.get_params())
    m.fit(
        X_tr, y_tr[:, i],
        eval_set=[(X_val, y_val[:, i])],
        callbacks=[early_stopping(stopping_rounds=200), log_evaluation(period=100)]
    )
    meta_model.estimators_.append(m)

save_folder = f"only_lightgbm/{version}"  # 修改為你想要的資料夾名稱
if not os.path.exists(save_folder):   
    os.makedirs(save_folder)
# 6) Save the single meta‐model
joblib.dump(meta_model, f'{save_folder}meta_model_single.pkl')
print("✅ Saved entire‐dataset meta‐model → meta_model_single.pkl")


  net.load_state_dict(torch.load(trained_model_path, map_location=device))
  ae.load_state_dict(torch.load(ae_checkpoint, map_location="cpu"))
Computing AE recon loss: 100%|██████████| 131/131 [00:11<00:00, 11.33it/s]


✅ Generated meta-features with shape: (8348, 1111)
Training meta‐model target 0 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 5.62928
[200]	valid_0's rmse: 5.1935
[300]	valid_0's rmse: 5.10812
[400]	valid_0's rmse: 5.08589
[500]	valid_0's rmse: 5.07978
[600]	valid_0's rmse: 5.07577
[700]	valid_0's rmse: 5.07662
[800]	valid_0's rmse: 5.07461
[900]	valid_0's rmse: 5.07512
Early stopping, best iteration is:
[760]	valid_0's rmse: 5.07203
Training meta‐model target 1 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 2.50659
[200]	valid_0's rmse: 2.28498
[300]	valid_0's rmse: 2.24173
[400]	valid_0's rmse: 2.23013
[500]	valid_0's rmse: 2.22742
[600]	valid_0's rmse: 2.22476
[700]	valid_0's rmse: 2.22518
[800]	valid_0's rmse: 2.22652
Early stopping, best iteration is:
[630]	valid_0's rmse: 2.22392
Training meta‐model target 2 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 3.90722
[200]	valid_0's rmse: 3.50991
[300]	valid_0's rmse: 3.43278
[400]	valid_0's rmse: 3.41424
[500]	valid_0's rmse: 3.40599
[600]	valid_0's rmse: 3.404
[700]	valid_0's rmse: 3.4007
[800]	valid_0's rmse: 3.40066
[900]	valid_0's rmse: 3.40222
[1000]	valid_0's rmse: 3.40177
Early stopping, best iteration is:
[821]	valid_0's rmse: 3.40003
Training meta‐model target 3 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 7.87697
[200]	valid_0's rmse: 7.07054
[300]	valid_0's rmse: 6.93512
[400]	valid_0's rmse: 6.91232
[500]	valid_0's rmse: 6.9071
[600]	valid_0's rmse: 6.89867
[700]	valid_0's rmse: 6.89215
[800]	valid_0's rmse: 6.89046
[900]	valid_0's rmse: 6.88435
[1000]	valid_0's rmse: 6.88218
[1100]	valid_0's rmse: 6.88041
Early stopping, best iteration is:
[965]	valid_0's rmse: 6.87996
Training meta‐model target 4 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 7.56609
[200]	valid_0's rmse: 6.90237
[300]	valid_0's rmse: 6.77326
[400]	valid_0's rmse: 6.73091
[500]	valid_0's rmse: 6.70357
[600]	valid_0's rmse: 6.69291
[700]	valid_0's rmse: 6.68477
[800]	valid_0's rmse: 6.67288
[900]	valid_0's rmse: 6.66559
[1000]	valid_0's rmse: 6.66112
[1100]	valid_0's rmse: 6.65508
[1200]	valid_0's rmse: 6.64884
[1300]	valid_0's rmse: 6.64436
[1400]	valid_0's rmse: 6.64121
[1500]	valid_0's rmse: 6.63647
[1600]	valid_0's rmse: 6.63581
[1700]	valid_0's rmse: 6.63668
Early stopping, best iteration is:
[1595]	valid_0's rmse: 6.6348
Training meta‐model target 5 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 7.51962
[200]	valid_0's rmse: 6.75658
[300]	valid_0's rmse: 6.62833
[400]	valid_0's rmse: 6.60548
[500]	valid_0's rmse: 6.58967
[600]	valid_0's rmse: 6.58548
[700]	valid_0's rmse: 6.57677
[800]	valid_0's rmse: 6.57028
[900]	valid_0's rmse: 6.56499
[1000]	valid_0's rmse: 6.5585
[1100]	valid_0's rmse: 6.55821
[1200]	valid_0's rmse: 6.55353
[1300]	valid_0's rmse: 6.55061
[1400]	valid_0's rmse: 6.54989
[1500]	valid_0's rmse: 6.54647
[1600]	valid_0's rmse: 6.54648
[1700]	valid_0's rmse: 6.5464
[1800]	valid_0's rmse: 6.54553
[1900]	valid_0's rmse: 6.54694
Early stopping, best iteration is:
[1786]	valid_0's rmse: 6.54492
Training meta‐model target 6 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 3.57117
[200]	valid_0's rmse: 3.3929
[300]	valid_0's rmse: 3.36007
[400]	valid_0's rmse: 3.35145
[500]	valid_0's rmse: 3.34754
[600]	valid_0's rmse: 3.34562
[700]	valid_0's rmse: 3.34475
[800]	valid_0's rmse: 3.34474
[900]	valid_0's rmse: 3.34623
Early stopping, best iteration is:
[757]	valid_0's rmse: 3.34358
Training meta‐model target 7 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 5.64936
[200]	valid_0's rmse: 5.60818
[300]	valid_0's rmse: 5.61569
[400]	valid_0's rmse: 5.62079
Early stopping, best iteration is:
[209]	valid_0's rmse: 5.60484
Training meta‐model target 8 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 8.20126
[200]	valid_0's rmse: 7.35111
[300]	valid_0's rmse: 7.1854
[400]	valid_0's rmse: 7.14574
[500]	valid_0's rmse: 7.12458
[600]	valid_0's rmse: 7.10594
[700]	valid_0's rmse: 7.09464
[800]	valid_0's rmse: 7.08607
[900]	valid_0's rmse: 7.07839
[1000]	valid_0's rmse: 7.07182
[1100]	valid_0's rmse: 7.06809
[1200]	valid_0's rmse: 7.06331
[1300]	valid_0's rmse: 7.05882
[1400]	valid_0's rmse: 7.058
[1500]	valid_0's rmse: 7.05293
[1600]	valid_0's rmse: 7.05104
[1700]	valid_0's rmse: 7.04837
[1800]	valid_0's rmse: 7.04528
[1900]	valid_0's rmse: 7.04279
[2000]	valid_0's rmse: 7.04063
[2100]	valid_0's rmse: 7.03853
[2200]	valid_0's rmse: 7.03899
[2300]	valid_0's rmse: 7.03683
[2400]	valid_0's rmse: 7.03666
[2500]	valid_0's rmse: 7.03682
Early stopping, best iteration is:
[2362]	valid_0's rmse: 7.03559
Training meta‐model target 9 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 6.56198
[200]	valid_0's rmse: 5.78207
[300]	valid_0's rmse: 5.63304
[400]	valid_0's rmse: 5.59688
[500]	valid_0's rmse: 5.57503
[600]	valid_0's rmse: 5.55679
[700]	valid_0's rmse: 5.54878
[800]	valid_0's rmse: 5.54156
[900]	valid_0's rmse: 5.53117
[1000]	valid_0's rmse: 5.52251
[1100]	valid_0's rmse: 5.52049
[1200]	valid_0's rmse: 5.51668
[1300]	valid_0's rmse: 5.51751
[1400]	valid_0's rmse: 5.51491
[1500]	valid_0's rmse: 5.51416
[1600]	valid_0's rmse: 5.51417
Early stopping, best iteration is:
[1447]	valid_0's rmse: 5.51286
Training meta‐model target 10 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 7.58335
[200]	valid_0's rmse: 7.02192
[300]	valid_0's rmse: 6.88899
[400]	valid_0's rmse: 6.8441
[500]	valid_0's rmse: 6.82416
[600]	valid_0's rmse: 6.81238
[700]	valid_0's rmse: 6.8078
[800]	valid_0's rmse: 6.79994
[900]	valid_0's rmse: 6.79479
[1000]	valid_0's rmse: 6.78994
[1100]	valid_0's rmse: 6.78281
[1200]	valid_0's rmse: 6.77952
[1300]	valid_0's rmse: 6.7778
[1400]	valid_0's rmse: 6.7774
[1500]	valid_0's rmse: 6.77514
[1600]	valid_0's rmse: 6.7741
Early stopping, best iteration is:
[1440]	valid_0's rmse: 6.77347
Training meta‐model target 11 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 3.96525
[200]	valid_0's rmse: 3.4672
[300]	valid_0's rmse: 3.3631
[400]	valid_0's rmse: 3.33325
[500]	valid_0's rmse: 3.32589
[600]	valid_0's rmse: 3.32125
[700]	valid_0's rmse: 3.31957
[800]	valid_0's rmse: 3.31909
Early stopping, best iteration is:
[656]	valid_0's rmse: 3.31784
Training meta‐model target 12 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 6.11302
[200]	valid_0's rmse: 6.09307
[300]	valid_0's rmse: 6.0933
[400]	valid_0's rmse: 6.10403
Early stopping, best iteration is:
[254]	valid_0's rmse: 6.08878
Training meta‐model target 13 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 5.34751
[200]	valid_0's rmse: 5.13353
[300]	valid_0's rmse: 5.10263
[400]	valid_0's rmse: 5.09773
[500]	valid_0's rmse: 5.10263
[600]	valid_0's rmse: 5.10973
Early stopping, best iteration is:
[425]	valid_0's rmse: 5.09702
Training meta‐model target 14 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 5.9574
[200]	valid_0's rmse: 5.29078
[300]	valid_0's rmse: 5.16759
[400]	valid_0's rmse: 5.13375
[500]	valid_0's rmse: 5.12102
[600]	valid_0's rmse: 5.11819
[700]	valid_0's rmse: 5.11568
[800]	valid_0's rmse: 5.11891
Early stopping, best iteration is:
[680]	valid_0's rmse: 5.11445
Training meta‐model target 15 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 6.11763
[200]	valid_0's rmse: 5.74066
[300]	valid_0's rmse: 5.6546
[400]	valid_0's rmse: 5.62692
[500]	valid_0's rmse: 5.61239
[600]	valid_0's rmse: 5.60035
[700]	valid_0's rmse: 5.59657
[800]	valid_0's rmse: 5.59559
[900]	valid_0's rmse: 5.59167
[1000]	valid_0's rmse: 5.58773
[1100]	valid_0's rmse: 5.58292
[1200]	valid_0's rmse: 5.58243
[1300]	valid_0's rmse: 5.57973
[1400]	valid_0's rmse: 5.57664
[1500]	valid_0's rmse: 5.57663
[1600]	valid_0's rmse: 5.57734
[1700]	valid_0's rmse: 5.57703
Early stopping, best iteration is:
[1503]	valid_0's rmse: 5.57633
Training meta‐model target 16 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 6.82283
[200]	valid_0's rmse: 6.09639
[300]	valid_0's rmse: 5.92526
[400]	valid_0's rmse: 5.88245
[500]	valid_0's rmse: 5.85867
[600]	valid_0's rmse: 5.8489
[700]	valid_0's rmse: 5.84444
[800]	valid_0's rmse: 5.84011
[900]	valid_0's rmse: 5.83652
[1000]	valid_0's rmse: 5.83316
[1100]	valid_0's rmse: 5.83245
[1200]	valid_0's rmse: 5.82957
[1300]	valid_0's rmse: 5.82576
[1400]	valid_0's rmse: 5.82429
[1500]	valid_0's rmse: 5.82441
[1600]	valid_0's rmse: 5.82219
[1700]	valid_0's rmse: 5.82158
[1800]	valid_0's rmse: 5.81956
[1900]	valid_0's rmse: 5.81823
[2000]	valid_0's rmse: 5.819
[2100]	valid_0's rmse: 5.81768
[2200]	valid_0's rmse: 5.81677
[2300]	valid_0's rmse: 5.81795
Early stopping, best iteration is:
[2188]	valid_0's rmse: 5.81633
Training meta‐model target 17 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 3.7024
[200]	valid_0's rmse: 3.27567
[300]	valid_0's rmse: 3.19361
[400]	valid_0's rmse: 3.17335
[500]	valid_0's rmse: 3.16505
[600]	valid_0's rmse: 3.16274
[700]	valid_0's rmse: 3.15935
[800]	valid_0's rmse: 3.15496
[900]	valid_0's rmse: 3.15477
[1000]	valid_0's rmse: 3.15346
[1100]	valid_0's rmse: 3.15278
[1200]	valid_0's rmse: 3.15243
[1300]	valid_0's rmse: 3.15271
[1400]	valid_0's rmse: 3.1514
[1500]	valid_0's rmse: 3.1513
[1600]	valid_0's rmse: 3.15112
[1700]	valid_0's rmse: 3.15106
[1800]	valid_0's rmse: 3.15075
Early stopping, best iteration is:
[1621]	valid_0's rmse: 3.15041
Training meta‐model target 18 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 7.16375
[200]	valid_0's rmse: 6.92752
[300]	valid_0's rmse: 6.89158
[400]	valid_0's rmse: 6.88595
[500]	valid_0's rmse: 6.88398
[600]	valid_0's rmse: 6.88651
Early stopping, best iteration is:
[464]	valid_0's rmse: 6.88202
Training meta‐model target 19 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 4.90648
[200]	valid_0's rmse: 4.72389
[300]	valid_0's rmse: 4.68299
[400]	valid_0's rmse: 4.67237
[500]	valid_0's rmse: 4.66743
[600]	valid_0's rmse: 4.66336
[700]	valid_0's rmse: 4.66148
[800]	valid_0's rmse: 4.65852
[900]	valid_0's rmse: 4.65698
[1000]	valid_0's rmse: 4.65393
[1100]	valid_0's rmse: 4.65492
Early stopping, best iteration is:
[981]	valid_0's rmse: 4.65265
Training meta‐model target 20 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 6.40041
[200]	valid_0's rmse: 6.04834
[300]	valid_0's rmse: 5.98503
[400]	valid_0's rmse: 5.97331
[500]	valid_0's rmse: 5.97575
[600]	valid_0's rmse: 5.97561
Early stopping, best iteration is:
[405]	valid_0's rmse: 5.97264
Training meta‐model target 21 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 7.6319
[200]	valid_0's rmse: 7.37032
[300]	valid_0's rmse: 7.33343
[400]	valid_0's rmse: 7.32765
[500]	valid_0's rmse: 7.32355
[600]	valid_0's rmse: 7.31929
[700]	valid_0's rmse: 7.3156
[800]	valid_0's rmse: 7.31706
[900]	valid_0's rmse: 7.32337
Early stopping, best iteration is:
[773]	valid_0's rmse: 7.31421
Training meta‐model target 22 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 2.76044
[200]	valid_0's rmse: 2.60316
[300]	valid_0's rmse: 2.56989
[400]	valid_0's rmse: 2.56086
[500]	valid_0's rmse: 2.56004
[600]	valid_0's rmse: 2.55988
[700]	valid_0's rmse: 2.56028
[800]	valid_0's rmse: 2.55849
[900]	valid_0's rmse: 2.55613
[1000]	valid_0's rmse: 2.55346
[1100]	valid_0's rmse: 2.55396
[1200]	valid_0's rmse: 2.55404
Early stopping, best iteration is:
[1016]	valid_0's rmse: 2.55295
Training meta‐model target 23 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 6.10941
[200]	valid_0's rmse: 5.64126
[300]	valid_0's rmse: 5.54626
[400]	valid_0's rmse: 5.51938
[500]	valid_0's rmse: 5.51005
[600]	valid_0's rmse: 5.50438
[700]	valid_0's rmse: 5.50581
[800]	valid_0's rmse: 5.50496
Early stopping, best iteration is:
[638]	valid_0's rmse: 5.50296
Training meta‐model target 24 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 7.25374
[200]	valid_0's rmse: 6.68999
[300]	valid_0's rmse: 6.57134
[400]	valid_0's rmse: 6.52698
[500]	valid_0's rmse: 6.51334
[600]	valid_0's rmse: 6.50556
[700]	valid_0's rmse: 6.49598
[800]	valid_0's rmse: 6.493
[900]	valid_0's rmse: 6.49211
[1000]	valid_0's rmse: 6.48864
[1100]	valid_0's rmse: 6.49235
[1200]	valid_0's rmse: 6.4929
Early stopping, best iteration is:
[1006]	valid_0's rmse: 6.48762
Training meta‐model target 25 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 6.81848
[200]	valid_0's rmse: 6.58717
[300]	valid_0's rmse: 6.53573
[400]	valid_0's rmse: 6.52326
[500]	valid_0's rmse: 6.52199
[600]	valid_0's rmse: 6.52508
Early stopping, best iteration is:
[447]	valid_0's rmse: 6.52065
Training meta‐model target 26 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 7.58908
[200]	valid_0's rmse: 6.64477
[300]	valid_0's rmse: 6.45336
[400]	valid_0's rmse: 6.3978
[500]	valid_0's rmse: 6.37467
[600]	valid_0's rmse: 6.36404
[700]	valid_0's rmse: 6.35738
[800]	valid_0's rmse: 6.3543
[900]	valid_0's rmse: 6.3491
[1000]	valid_0's rmse: 6.34666
[1100]	valid_0's rmse: 6.34176
[1200]	valid_0's rmse: 6.34079
[1300]	valid_0's rmse: 6.34091
Early stopping, best iteration is:
[1165]	valid_0's rmse: 6.33994
Training meta‐model target 27 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 5.87072
[200]	valid_0's rmse: 5.63419
[300]	valid_0's rmse: 5.58598
[400]	valid_0's rmse: 5.57095
[500]	valid_0's rmse: 5.57058
[600]	valid_0's rmse: 5.57103
Early stopping, best iteration is:
[466]	valid_0's rmse: 5.56749
Training meta‐model target 28 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 6.37755
[200]	valid_0's rmse: 6.02791
[300]	valid_0's rmse: 5.96983
[400]	valid_0's rmse: 5.95933
[500]	valid_0's rmse: 5.95411
[600]	valid_0's rmse: 5.95253
[700]	valid_0's rmse: 5.95096
[800]	valid_0's rmse: 5.95387
Early stopping, best iteration is:
[679]	valid_0's rmse: 5.94813
Training meta‐model target 29 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 6.8666
[200]	valid_0's rmse: 6.13368
[300]	valid_0's rmse: 6.03068
[400]	valid_0's rmse: 6.01904
[500]	valid_0's rmse: 6.0102
[600]	valid_0's rmse: 6.00933
[700]	valid_0's rmse: 6.01089
[800]	valid_0's rmse: 6.00678
[900]	valid_0's rmse: 6.00733
[1000]	valid_0's rmse: 6.0173
Early stopping, best iteration is:
[847]	valid_0's rmse: 6.00548
Training meta‐model target 30 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 2.72523
[200]	valid_0's rmse: 2.43192
[300]	valid_0's rmse: 2.37855
[400]	valid_0's rmse: 2.36568
[500]	valid_0's rmse: 2.36227
[600]	valid_0's rmse: 2.35998
[700]	valid_0's rmse: 2.36082
Early stopping, best iteration is:
[595]	valid_0's rmse: 2.35965
Training meta‐model target 31 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 2.35656
[200]	valid_0's rmse: 2.29944
[300]	valid_0's rmse: 2.28869
[400]	valid_0's rmse: 2.2836
[500]	valid_0's rmse: 2.28257
[600]	valid_0's rmse: 2.28174
[700]	valid_0's rmse: 2.28201
Early stopping, best iteration is:
[576]	valid_0's rmse: 2.28121
Training meta‐model target 32 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 7.34933
[200]	valid_0's rmse: 7.27005
[300]	valid_0's rmse: 7.26479
[400]	valid_0's rmse: 7.26584
[500]	valid_0's rmse: 7.26725
Early stopping, best iteration is:
[324]	valid_0's rmse: 7.26016
Training meta‐model target 33 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 5.5962
[200]	valid_0's rmse: 5.38504
[300]	valid_0's rmse: 5.34936
[400]	valid_0's rmse: 5.34352
[500]	valid_0's rmse: 5.34348
[600]	valid_0's rmse: 5.34841
Early stopping, best iteration is:
[440]	valid_0's rmse: 5.34229
Training meta‐model target 34 …




Training until validation scores don't improve for 200 rounds
[100]	valid_0's rmse: 2.88309
[200]	valid_0's rmse: 2.7987
[300]	valid_0's rmse: 2.77899
[400]	valid_0's rmse: 2.77554
[500]	valid_0's rmse: 2.77384
[600]	valid_0's rmse: 2.77268
[700]	valid_0's rmse: 2.77378
[800]	valid_0's rmse: 2.77352
[900]	valid_0's rmse: 2.77444
Early stopping, best iteration is:
[777]	valid_0's rmse: 2.77262
✅ Saved entire‐dataset meta‐model → meta_model_single.pkl


In [None]:
# 6) Save the single meta‐model
joblib.dump(meta_model, 'meta_model_single.pkl')
print("✅ Saved entire‐dataset meta‐model → meta_model_single.pkl")


✅ Saved entire‐dataset meta‐model → meta_model_single.pkl


In [8]:
import torch
from python_scripts.import_data import load_node_feature_data


image_keys = [ 'tile', 'subtiles']

model = VisionMLP_MultiTask(tile_dim=tile_dim, subtile_dim=center_dim, output_dim=C)

# 用法示例
from python_scripts.import_data import importDataset
# 假设你的 model 已经定义好并实例化为 `model`
test_dataset = load_node_feature_data("dataset/spot-rank/filtered_directly_rank/masked/test/Macenko/test_dataset.pt", model)
test_dataset = importDataset(
        data_dict=test_dataset,
        model=model,
        image_keys=image_keys,
        transform=lambda x: x,  # identity transform
        print_sig=True
    )



  raw = torch.load(pt_path, map_location="cpu")


⚠️ 從 '<class 'list'>' 推斷樣本數量: 2088
Model forward signature: (tile, subtiles)


In [None]:
import joblib

# 1) 对 test_dataset 用你的第 1 阶段 model 一次性算出 oof_preds 和 latents
net = VisionMLP_MultiTask(tile_dim=tile_dim, subtile_dim=center_dim, output_dim=C)
net.load_state_dict(torch.load(trained_model_path, map_location=device))
net = net.to(device).eval()

test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_preds, test_latents = [], []
with torch.no_grad():
    for batch in test_loader:
        tiles, subtiles = batch['tile'].to(device), batch['subtiles'].to(device)
        center = subtiles[:, 4].contiguous()
        f_c = net.encoder_center(center)
        f_n = net.encoder_subtile(subtiles)
        f_t = net.encoder_tile(tiles)
        features_cat = torch.cat([f_c, f_n, f_t], dim=1)
        gates = net.gate_fc(features_cat)        # (B, 3), softmax over features
        f_fused = (
            gates[:, 0:1] * f_t +
            gates[:, 1:2] * f_n +
            gates[:, 2:3] * f_c
        )  # (B, tile_dim) — 注意各 encoder 輸出維度要一致
        out  = net.decoder(f_fused)
        test_preds.append(out.cpu())
        test_latents.append(f_fused.cpu())
        
    test_preds = torch.cat(test_preds, dim=0).numpy()
    test_latents = torch.cat(test_latents, dim=0).numpy()
 
# 2) AE Recon Model 不变，算测试的 recon_loss & 必要特征
recon_model = PretrainedEncoderRegressor(
    ae_checkpoint=pretrained_ae_path,
    ae_type=ae_type,
    tile_dim=tile_dim,
    center_dim=center_dim,
    neighbor_dim=neighbor_dim,
    output_dim=C,
    mode='reconstruction'
).to(device)

meta_test = generate_meta_features(
    dataset         = test_dataset,
    oof_preds       = test_preds,
    image_latents   = test_latents,
    model_for_recon = recon_model,
    device          = device,
    ae_type         = ae_type,
    use_clusters    = "both"
)

# 3) 直接载入并用 single‐fold 的 meta_model 预测
meta_model = joblib.load(f'{save_folder}meta_model_single.pkl')
final_preds = meta_model.predict(meta_test)

# 4) 写出 submission
import h5py, pandas as pd
with h5py.File("./dataset/elucidata_ai_challenge_data.h5","r") as f:
    test_spot_ids = pd.DataFrame(np.array(f["spots/Test"]["S_7"]))
sub = pd.DataFrame(final_preds, columns=[f"C{i+1}" for i in range(C)])
sub.insert(0, 'ID', test_spot_ids.index)
sub.to_csv(f'{save_folder}submission_stacked.csv', index=False)
print("✅ Saved submission_stacked.csv")


  net.load_state_dict(torch.load(trained_model_path, map_location=device))


In [29]:
import os
import numpy as np
import joblib
import torch
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.multioutput import MultiOutputRegressor
from sklearn.model_selection import train_test_split
import lightgbm as lgb
from scipy.stats import rankdata
from python_scripts.import_data import importDataset
from python_scripts.operate_model import predict
from lightgbm import early_stopping, log_evaluation
import h5py
import pandas as pd
from python_scripts.pretrain_model import PretrainedEncoderRegressor
# ---------------- Settings ----------------
trained_oof_model_folder = 'output_folder/rank-spot/realign/no_pretrain/3_encoder/filtered_directly_rank/k-fold/realign_all/Macenko_masked/'
n_folds    = len([d for d in os.listdir(trained_oof_model_folder) if d.startswith('fold')])
n_samples  = len(full_dataset)
C          = 35
BATCH_SIZE = 64
start_fold = 0

tile_dim = 128
center_dim = 128
neighbor_dim = 128
fusion_dim = tile_dim + center_dim + neighbor_dim

pretrained_ae_name = 'AE_Center_noaug'
pretrained_ae_path = f"AE_model/128/{pretrained_ae_name}/best.pt"
ae_type = 'center'

# Ground truth label (全 dataset)
y_true = np.vstack([ full_dataset[i]['label'].cpu().numpy() for i in range(n_samples) ])

# Build CV splitter (must match first stage splits)
logo = LeaveOneGroupOut()
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

lgb_base = lgb.LGBMRegressor(
    objective='l2',
    metric='rmse',
    learning_rate=0.007522970004049377,
    n_estimators=12000,
    max_depth=11,
    num_leaves=20,
    colsample_bytree=0.7619407413363416,
    subsample=0.8,
    subsample_freq=1,
    min_data_in_leaf=20,
    reg_alpha=0.7480401395491829,
    reg_lambda=0.2589860348178542,
    verbosity=-1
    )

slide_idx = np.array(grouped_data['slide_idx'])   # shape (N,)


for fold_id, (tr_idx, va_idx) in enumerate(
    logo.split(X=np.zeros(n_samples), y=None, groups=slide_idx)):

    if fold_id > start_fold:
        print(f"⏭️ Skipping fold {fold_id}")
        continue

    print(f"\n🚀 Starting fold {fold_id}...")
    ckpt_path = os.path.join(trained_oof_model_folder, f"fold{fold_id}", "best_model.pt")

    # === Load model and predict OOF ===
    net = VisionMLP_MultiTask(tile_dim=tile_dim, subtile_dim=center_dim, output_dim=C)
    net.load_state_dict(torch.load(ckpt_path, map_location=device))
    net = net.to(device).eval()

    val_ds = Subset(full_dataset, va_idx)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

    preds, latents = [], []
    with torch.no_grad():
        for batch in val_loader:
            tiles = batch['tile'].to(device)
            subtiles = batch['subtiles'].to(device)
            center = subtiles[:, 4].contiguous()

            f_c = net.encoder_center(center)
            f_n = net.encoder_subtile(subtiles)
            f_t = net.encoder_tile(tiles)
            fuse = torch.cat([f_c, f_n, f_t], dim=1).contiguous()
            output = net.decoder(fuse)

            preds.append(output.cpu())
            latents.append(fuse.cpu())

    preds = torch.cat(preds, dim=0).numpy()
    latents = torch.cat(latents, dim=0).numpy()

    # === AE model reconstruction loss ===
    recon_model = PretrainedEncoderRegressor(
        ae_checkpoint=pretrained_ae_path,
        ae_type=ae_type,
        tile_dim=tile_dim,
        center_dim=center_dim,
        neighbor_dim=neighbor_dim,
        output_dim=C,
        mode='reconstruction'
    ).to(device)

    meta = generate_meta_features(
        dataset = val_ds,
        oof_preds = preds,
        image_latents = latents,
        model_for_recon = recon_model,
        device = device,
        ae_type = ae_type,
        use_clusters="4"
    )
    
    y_val = y_true[va_idx]   # shape = (len(va_idx), 35)

    # 2) 再對這個 fold 的 meta 做 train/val 切分
    X_train_tab, X_val_tab, y_train_tab, y_val_tab = train_test_split(
        meta, y_val, test_size=0.2, random_state=42
    )

    # 3) train MultiOutputRegressor with early stopping
    meta_model = MultiOutputRegressor(lgb_base)
    meta_model.estimators_ = []

    for i in range(y_train_tab.shape[1]):
        print(f"[fold {fold_id}] training target {i} on meta features …")
        model = lgb.LGBMRegressor(**lgb_base.get_params())
        model.fit(
            X_train_tab, y_train_tab[:, i],
            eval_set=[(X_val_tab, y_val_tab[:, i])],
            callbacks=[
                early_stopping(stopping_rounds=200),
                log_evaluation(period=100)
            ]
        )
        meta_model.estimators_.append(model)

    # 4) 存下這個 fold 的 meta model
    save_path = os.path.join(trained_oof_model_folder, f"meta_model_fold{fold_id}.pkl")
    joblib.dump(meta_model, save_path)
    print(f"✅ Saved fold {fold_id} meta‐model → {save_path}")



🚀 Starting fold 0...


  net.load_state_dict(torch.load(ckpt_path, map_location=device))
  ae.load_state_dict(torch.load(ae_checkpoint, map_location="cpu"))
Computing AE recon loss: 100%|██████████| 35/35 [00:02<00:00, 11.82it/s]


Silhouette score for k=2: 0.5500
Silhouette score for k=3: 0.5892
Silhouette score for k=4: 0.5201
Silhouette score for k=5: 0.5311
Silhouette score for k=6: 0.5067
Selected best k=3 (score=0.5892)
✅ Generated meta-features with shape: (2197, 227)
[fold 0] training target 0 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 6.4679
[200]	valid_0's rmse: 6.27986
[300]	valid_0's rmse: 6.21137
[400]	valid_0's rmse: 6.17162
[500]	valid_0's rmse: 6.15372
[600]	valid_0's rmse: 6.15222
[700]	valid_0's rmse: 6.16033
[800]	valid_0's rmse: 6.16606
Early stopping, best iteration is:
[611]	valid_0's rmse: 6.14999
[fold 0] training target 1 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 2.72637
[200]	valid_0's rmse: 2.38419
[300]	valid_0's rmse: 2.27991
[400]	valid_0's rmse: 2.25302
[500]	valid_0's rmse: 2.24256
[600]	valid_0's rmse: 2.23851
[700]	valid_0's rmse: 2.23599
[800]	valid_0's rmse: 2.23449
[900]	valid_0's rmse: 2.23727
Early stopping, best iteration is:
[791]	valid_0's rmse: 2.23409
[fold 0] training target 2 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 4.4783
[200]	valid_0's rmse: 4.33658
[300]	valid_0's rmse: 4.27057
[400]	valid_0's rmse: 4.24216
[500]	valid_0's rmse: 4.23617
[600]	valid_0's rmse: 4.22769
[700]	valid_0's rmse: 4.22185
[800]	valid_0's rmse: 4.22169
[900]	valid_0's rmse: 4.21721
[1000]	valid_0's rmse: 4.21561
Early stopping, best iteration is:
[861]	valid_0's rmse: 4.21528
[fold 0] training target 3 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 6.86958
[200]	valid_0's rmse: 6.78088
[300]	valid_0's rmse: 6.75143
[400]	valid_0's rmse: 6.74663
[500]	valid_0's rmse: 6.76621
Early stopping, best iteration is:
[349]	valid_0's rmse: 6.7417
[fold 0] training target 4 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 7.25344
[200]	valid_0's rmse: 6.22301
[300]	valid_0's rmse: 5.82344
[400]	valid_0's rmse: 5.66066
[500]	valid_0's rmse: 5.56592
[600]	valid_0's rmse: 5.52801
[700]	valid_0's rmse: 5.50798
[800]	valid_0's rmse: 5.49388
[900]	valid_0's rmse: 5.49443
[1000]	valid_0's rmse: 5.48707
[1100]	valid_0's rmse: 5.48587
[1200]	valid_0's rmse: 5.48644
[1300]	valid_0's rmse: 5.48098
[1400]	valid_0's rmse: 5.48116
[1500]	valid_0's rmse: 5.47934
[1600]	valid_0's rmse: 5.47653
[1700]	valid_0's rmse: 5.47305
[1800]	valid_0's rmse: 5.47248
[1900]	valid_0's rmse: 5.47189
[2000]	valid_0's rmse: 5.47113
[2100]	valid_0's rmse: 5.4709
[2200]	valid_0's rmse: 5.47206
Early stopping, best iteration is:
[2012]	valid_0's rmse: 5.47008
[fold 0] training target 5 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 8.1606
[200]	valid_0's rmse: 6.98825
[300]	valid_0's rmse: 6.55376
[400]	valid_0's rmse: 6.38535
[500]	valid_0's rmse: 6.29584
[600]	valid_0's rmse: 6.23991
[700]	valid_0's rmse: 6.20386
[800]	valid_0's rmse: 6.17691
[900]	valid_0's rmse: 6.15441
[1000]	valid_0's rmse: 6.1323
[1100]	valid_0's rmse: 6.11909
[1200]	valid_0's rmse: 6.11195
[1300]	valid_0's rmse: 6.10149
[1400]	valid_0's rmse: 6.09578
[1500]	valid_0's rmse: 6.0899
[1600]	valid_0's rmse: 6.08739
[1700]	valid_0's rmse: 6.08184
[1800]	valid_0's rmse: 6.07722
[1900]	valid_0's rmse: 6.07481
[2000]	valid_0's rmse: 6.07287
[2100]	valid_0's rmse: 6.07336
[2200]	valid_0's rmse: 6.06998
[2300]	valid_0's rmse: 6.06779
[2400]	valid_0's rmse: 6.06676
[2500]	valid_0's rmse: 6.06568
[2600]	valid_0's rmse: 6.06579
[2700]	valid_0's rmse: 6.06543
Early stopping, best iteration is:
[2548]	valid_0's rmse: 6.06465
[fold 0] training target 6 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 4.44826
[200]	valid_0's rmse: 4.20094
[300]	valid_0's rmse: 4.13247
[400]	valid_0's rmse: 4.10873
[500]	valid_0's rmse: 4.09578
[600]	valid_0's rmse: 4.08713
[700]	valid_0's rmse: 4.08574
[800]	valid_0's rmse: 4.08396
Early stopping, best iteration is:
[647]	valid_0's rmse: 4.08213
[fold 0] training target 7 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 5.38641
[200]	valid_0's rmse: 5.33827
[300]	valid_0's rmse: 5.32495
[400]	valid_0's rmse: 5.32559
[500]	valid_0's rmse: 5.34017
Early stopping, best iteration is:
[369]	valid_0's rmse: 5.32197
[fold 0] training target 8 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 7.68661
[200]	valid_0's rmse: 6.73552
[300]	valid_0's rmse: 6.4189
[400]	valid_0's rmse: 6.2864
[500]	valid_0's rmse: 6.21924
[600]	valid_0's rmse: 6.19431
[700]	valid_0's rmse: 6.19219
[800]	valid_0's rmse: 6.18215
[900]	valid_0's rmse: 6.17764
[1000]	valid_0's rmse: 6.17513
[1100]	valid_0's rmse: 6.17667
[1200]	valid_0's rmse: 6.17426
[1300]	valid_0's rmse: 6.16828
[1400]	valid_0's rmse: 6.16791
[1500]	valid_0's rmse: 6.1657
[1600]	valid_0's rmse: 6.16385
[1700]	valid_0's rmse: 6.1618
[1800]	valid_0's rmse: 6.16111
[1900]	valid_0's rmse: 6.16288
[2000]	valid_0's rmse: 6.16266
Early stopping, best iteration is:
[1808]	valid_0's rmse: 6.16098
[fold 0] training target 9 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 9.10774
[200]	valid_0's rmse: 8.39064
[300]	valid_0's rmse: 8.15342
[400]	valid_0's rmse: 8.03477
[500]	valid_0's rmse: 7.97941
[600]	valid_0's rmse: 7.95566
[700]	valid_0's rmse: 7.96163
[800]	valid_0's rmse: 7.95926
Early stopping, best iteration is:
[609]	valid_0's rmse: 7.95275
[fold 0] training target 10 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 8.50651
[200]	valid_0's rmse: 8.13523
[300]	valid_0's rmse: 7.99524
[400]	valid_0's rmse: 7.91231
[500]	valid_0's rmse: 7.87991
[600]	valid_0's rmse: 7.87377
[700]	valid_0's rmse: 7.86957
[800]	valid_0's rmse: 7.86932
Early stopping, best iteration is:
[644]	valid_0's rmse: 7.85756
[fold 0] training target 11 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 5.03787
[200]	valid_0's rmse: 4.45928
[300]	valid_0's rmse: 4.25658
[400]	valid_0's rmse: 4.18942
[500]	valid_0's rmse: 4.17441
[600]	valid_0's rmse: 4.16913
[700]	valid_0's rmse: 4.16923
[800]	valid_0's rmse: 4.1712
Early stopping, best iteration is:
[644]	valid_0's rmse: 4.16591
[fold 0] training target 12 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 5.89781
[200]	valid_0's rmse: 5.86192
[300]	valid_0's rmse: 5.86296
Early stopping, best iteration is:
[185]	valid_0's rmse: 5.85755
[fold 0] training target 13 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 5.79806
[200]	valid_0's rmse: 5.73127
[300]	valid_0's rmse: 5.71865
[400]	valid_0's rmse: 5.72507
Early stopping, best iteration is:
[285]	valid_0's rmse: 5.71689
[fold 0] training target 14 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 5.53081
[200]	valid_0's rmse: 5.11301
[300]	valid_0's rmse: 4.93668
[400]	valid_0's rmse: 4.86542
[500]	valid_0's rmse: 4.82702
[600]	valid_0's rmse: 4.80333
[700]	valid_0's rmse: 4.78666
[800]	valid_0's rmse: 4.76851
[900]	valid_0's rmse: 4.75815
[1000]	valid_0's rmse: 4.74286
[1100]	valid_0's rmse: 4.73694
[1200]	valid_0's rmse: 4.73419
[1300]	valid_0's rmse: 4.72631
[1400]	valid_0's rmse: 4.72228
[1500]	valid_0's rmse: 4.71604
[1600]	valid_0's rmse: 4.71245
[1700]	valid_0's rmse: 4.71046
[1800]	valid_0's rmse: 4.71024
[1900]	valid_0's rmse: 4.70858
[2000]	valid_0's rmse: 4.70667
[2100]	valid_0's rmse: 4.70596
[2200]	valid_0's rmse: 4.70571
[2300]	valid_0's rmse: 4.70377
[2400]	valid_0's rmse: 4.70386
[2500]	valid_0's rmse: 4.70327
[2600]	valid_0's rmse: 4.7019
[2700]	valid_0's rmse: 4.70195
Early stopping, best iteration is:
[2568]	valid_0's rmse: 4.70123
[fold 0] training target 15 on meta features …
Training until validation scores don't improve for 200 round



[100]	valid_0's rmse: 2.60967
[200]	valid_0's rmse: 2.48952
[300]	valid_0's rmse: 2.43757
[400]	valid_0's rmse: 2.41192
[500]	valid_0's rmse: 2.40363
[600]	valid_0's rmse: 2.40248
[700]	valid_0's rmse: 2.40159
Early stopping, best iteration is:
[547]	valid_0's rmse: 2.40068
[fold 0] training target 16 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 7.69544
[200]	valid_0's rmse: 6.97326
[300]	valid_0's rmse: 6.71563
[400]	valid_0's rmse: 6.62898
[500]	valid_0's rmse: 6.59072
[600]	valid_0's rmse: 6.57381
[700]	valid_0's rmse: 6.55242
[800]	valid_0's rmse: 6.53982
[900]	valid_0's rmse: 6.53103
[1000]	valid_0's rmse: 6.52339
[1100]	valid_0's rmse: 6.51429
[1200]	valid_0's rmse: 6.50468
[1300]	valid_0's rmse: 6.49661
[1400]	valid_0's rmse: 6.48962
[1500]	valid_0's rmse: 6.48764
[1600]	valid_0's rmse: 6.48545
[1700]	valid_0's rmse: 6.48265
[1800]	valid_0's rmse: 6.48525
[1900]	valid_0's rmse: 6.48441
Early stopping, best iteration is:
[1704]	valid_0's rmse: 6.4826
[fold 0] training target 17 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 3.79048
[200]	valid_0's rmse: 3.77163
[300]	valid_0's rmse: 3.76181
[400]	valid_0's rmse: 3.75824
[500]	valid_0's rmse: 3.75725
[600]	valid_0's rmse: 3.76351
[700]	valid_0's rmse: 3.76593
Early stopping, best iteration is:
[504]	valid_0's rmse: 3.75645
[fold 0] training target 18 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 8.00557
[200]	valid_0's rmse: 7.70243
[300]	valid_0's rmse: 7.62185
[400]	valid_0's rmse: 7.58112
[500]	valid_0's rmse: 7.56092
[600]	valid_0's rmse: 7.56403
[700]	valid_0's rmse: 7.5734
Early stopping, best iteration is:
[508]	valid_0's rmse: 7.55849
[fold 0] training target 19 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 5.4269
[200]	valid_0's rmse: 5.11919
[300]	valid_0's rmse: 5.01449
[400]	valid_0's rmse: 4.96557
[500]	valid_0's rmse: 4.94925
[600]	valid_0's rmse: 4.9414
[700]	valid_0's rmse: 4.93399
[800]	valid_0's rmse: 4.93587
Early stopping, best iteration is:
[676]	valid_0's rmse: 4.93226
[fold 0] training target 20 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 7.32446
[200]	valid_0's rmse: 6.73073
[300]	valid_0's rmse: 6.53688
[400]	valid_0's rmse: 6.4626
[500]	valid_0's rmse: 6.42963
[600]	valid_0's rmse: 6.41676
[700]	valid_0's rmse: 6.41349
[800]	valid_0's rmse: 6.41847
[900]	valid_0's rmse: 6.41254
[1000]	valid_0's rmse: 6.42003
Early stopping, best iteration is:
[867]	valid_0's rmse: 6.409
[fold 0] training target 21 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 8.14351
[200]	valid_0's rmse: 7.70604
[300]	valid_0's rmse: 7.556
[400]	valid_0's rmse: 7.49216
[500]	valid_0's rmse: 7.47353
[600]	valid_0's rmse: 7.47709
[700]	valid_0's rmse: 7.47483
[800]	valid_0's rmse: 7.47358
[900]	valid_0's rmse: 7.46984
[1000]	valid_0's rmse: 7.46346
[1100]	valid_0's rmse: 7.47819
Early stopping, best iteration is:
[989]	valid_0's rmse: 7.46194
[fold 0] training target 22 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 3.3518
[200]	valid_0's rmse: 3.08804
[300]	valid_0's rmse: 3.02736
[400]	valid_0's rmse: 3.01887
[500]	valid_0's rmse: 3.01986
[600]	valid_0's rmse: 3.02052
Early stopping, best iteration is:
[429]	valid_0's rmse: 3.01607
[fold 0] training target 23 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 6.73935
[200]	valid_0's rmse: 6.63321
[300]	valid_0's rmse: 6.59639
[400]	valid_0's rmse: 6.60402
Early stopping, best iteration is:
[296]	valid_0's rmse: 6.59478
[fold 0] training target 24 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 7.77596
[200]	valid_0's rmse: 7.62494
[300]	valid_0's rmse: 7.56398
[400]	valid_0's rmse: 7.53075
[500]	valid_0's rmse: 7.52936
[600]	valid_0's rmse: 7.53013
[700]	valid_0's rmse: 7.53851
Early stopping, best iteration is:
[576]	valid_0's rmse: 7.52369
[fold 0] training target 25 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 7.12509
[200]	valid_0's rmse: 6.91167
[300]	valid_0's rmse: 6.83581
[400]	valid_0's rmse: 6.79851
[500]	valid_0's rmse: 6.77443
[600]	valid_0's rmse: 6.75406
[700]	valid_0's rmse: 6.7496
[800]	valid_0's rmse: 6.75137
Early stopping, best iteration is:
[669]	valid_0's rmse: 6.74552
[fold 0] training target 26 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 8.02291
[200]	valid_0's rmse: 7.01319
[300]	valid_0's rmse: 6.66819
[400]	valid_0's rmse: 6.53019
[500]	valid_0's rmse: 6.47541
[600]	valid_0's rmse: 6.43666
[700]	valid_0's rmse: 6.42391
[800]	valid_0's rmse: 6.41324
[900]	valid_0's rmse: 6.40975
[1000]	valid_0's rmse: 6.39891
[1100]	valid_0's rmse: 6.38682
[1200]	valid_0's rmse: 6.37889
[1300]	valid_0's rmse: 6.37288
[1400]	valid_0's rmse: 6.36601
[1500]	valid_0's rmse: 6.35924
[1600]	valid_0's rmse: 6.35168
[1700]	valid_0's rmse: 6.34631
[1800]	valid_0's rmse: 6.34257
[1900]	valid_0's rmse: 6.34382
[2000]	valid_0's rmse: 6.34075
[2100]	valid_0's rmse: 6.33723
[2200]	valid_0's rmse: 6.33399
[2300]	valid_0's rmse: 6.33448
[2400]	valid_0's rmse: 6.3328
[2500]	valid_0's rmse: 6.33286
[2600]	valid_0's rmse: 6.32975
[2700]	valid_0's rmse: 6.3293
[2800]	valid_0's rmse: 6.32963
[2900]	valid_0's rmse: 6.32925
[3000]	valid_0's rmse: 6.32687
[3100]	valid_0's rmse: 6.32735
[3200]	valid_0's rmse: 6.32613
[3300]	valid_0's rm



[100]	valid_0's rmse: 6.7432
[200]	valid_0's rmse: 6.15762
[300]	valid_0's rmse: 6.00819
[400]	valid_0's rmse: 5.95888
[500]	valid_0's rmse: 5.94522
[600]	valid_0's rmse: 5.94734
Early stopping, best iteration is:
[478]	valid_0's rmse: 5.94123
[fold 0] training target 28 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 7.31572
[200]	valid_0's rmse: 7.0511
[300]	valid_0's rmse: 6.93549
[400]	valid_0's rmse: 6.89763
[500]	valid_0's rmse: 6.88269
[600]	valid_0's rmse: 6.8758
[700]	valid_0's rmse: 6.8694
[800]	valid_0's rmse: 6.86029
[900]	valid_0's rmse: 6.86543
[1000]	valid_0's rmse: 6.86373
Early stopping, best iteration is:
[821]	valid_0's rmse: 6.8557
[fold 0] training target 29 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 4.93819
[200]	valid_0's rmse: 4.87235
[300]	valid_0's rmse: 4.8532
[400]	valid_0's rmse: 4.85126
[500]	valid_0's rmse: 4.84888
[600]	valid_0's rmse: 4.85716
Early stopping, best iteration is:
[446]	valid_0's rmse: 4.84632
[fold 0] training target 30 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 3.7234
[200]	valid_0's rmse: 3.29966
[300]	valid_0's rmse: 3.18448
[400]	valid_0's rmse: 3.14423
[500]	valid_0's rmse: 3.12999
[600]	valid_0's rmse: 3.12275
[700]	valid_0's rmse: 3.11733
[800]	valid_0's rmse: 3.11938
Early stopping, best iteration is:
[696]	valid_0's rmse: 3.11702
[fold 0] training target 31 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 2.69433
[200]	valid_0's rmse: 2.65308
[300]	valid_0's rmse: 2.65267
[400]	valid_0's rmse: 2.65751
Early stopping, best iteration is:
[248]	valid_0's rmse: 2.64932
[fold 0] training target 32 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 7.72682
[200]	valid_0's rmse: 7.67412
[300]	valid_0's rmse: 7.64429
[400]	valid_0's rmse: 7.64086
[500]	valid_0's rmse: 7.64688
[600]	valid_0's rmse: 7.64892
[700]	valid_0's rmse: 7.6519
Early stopping, best iteration is:
[550]	valid_0's rmse: 7.63966
[fold 0] training target 33 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 5.69947
[200]	valid_0's rmse: 5.64388
[300]	valid_0's rmse: 5.63622
[400]	valid_0's rmse: 5.63685
[500]	valid_0's rmse: 5.64229
Early stopping, best iteration is:
[358]	valid_0's rmse: 5.6298
[fold 0] training target 34 on meta features …
Training until validation scores don't improve for 200 rounds




[100]	valid_0's rmse: 3.34226
[200]	valid_0's rmse: 3.22946
[300]	valid_0's rmse: 3.18547
[400]	valid_0's rmse: 3.16129
[500]	valid_0's rmse: 3.15555
[600]	valid_0's rmse: 3.14992
[700]	valid_0's rmse: 3.14998
[800]	valid_0's rmse: 3.14933
[900]	valid_0's rmse: 3.14878
Early stopping, best iteration is:
[741]	valid_0's rmse: 3.14737
✅ Saved fold 0 meta‐model → output_folder/rank-spot/realign/no_pretrain/3_encoder/filtered_directly_rank/k-fold/realign_all/Macenko_masked/meta_model_fold0.pkl
⏭️ Skipping fold 1
⏭️ Skipping fold 2
⏭️ Skipping fold 3
⏭️ Skipping fold 4
⏭️ Skipping fold 5


In [8]:
import torch
from python_scripts.import_data import load_node_feature_data


image_keys = [ 'tile', 'subtiles']

model = VisionMLP_MultiTask(tile_dim=tile_dim, subtile_dim=center_dim, output_dim=C)

# 用法示例
from python_scripts.import_data import importDataset
# 假设你的 model 已经定义好并实例化为 `model`
test_dataset = load_node_feature_data("dataset/spot-rank/filtered_directly_rank/masked/test/Macenko/test_dataset.pt", model)
test_dataset = importDataset(
        data_dict=test_dataset,
        model=model,
        image_keys=image_keys,
        transform=lambda x: x,  # identity transform
        print_sig=True
    )



  raw = torch.load(pt_path, map_location="cpu")


⚠️ 從 '<class 'list'>' 推斷樣本數量: 2088
Model forward signature: (tile, subtiles)


In [30]:
# --- 3) Prepare test meta-features ---
n_test = len(test_dataset)


for fold_id in range(n_folds):
    if fold_id > start_fold:
        print(f"⏭️ Skipping fold {fold_id}")
        continue
    ckpt_path = os.path.join(trained_oof_model_folder, f"fold{fold_id}", "best_model.pt")
    net = VisionMLP_MultiTask(tile_dim=tile_dim, subtile_dim=center_dim, output_dim=C)
    net = net.to(device)
    net.load_state_dict(torch.load(ckpt_path, map_location=device))
    net.eval()

    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    test_preds = []
    test_latents = []

    with torch.no_grad():
        for batch in test_loader:
            tiles = batch['tile'].to(device)
            subtiles = batch['subtiles'].to(device)
            center = subtiles[:, 4].contiguous()

            f_c = net.encoder_center(center)
            f_n = net.encoder_subtile(subtiles)
            f_t = net.encoder_tile(tiles)
            fuse = torch.cat([f_c, f_n, f_t], dim=1).contiguous()
            output = net.decoder(fuse)

            test_preds.append(output.cpu())
            test_latents.append(fuse.cpu())


    test_preds = torch.cat(test_preds, dim=0).numpy()
    test_latents = torch.cat(test_latents, dim=0).numpy()
# === AE model reconstruction loss ===
    recon_model = PretrainedEncoderRegressor(
        ae_checkpoint=pretrained_ae_path,
        ae_type=ae_type,
        tile_dim=tile_dim,
        center_dim=center_dim,
        neighbor_dim=neighbor_dim,
        output_dim=C,
        mode='reconstruction'
    ).to(device)

    meta = generate_meta_features(
        dataset = test_dataset,
        oof_preds = test_preds,
        image_latents = test_latents,
        model_for_recon = recon_model,
        device = device,
        ae_type = ae_type,
        use_clusters="both"
    )
    # 1) 直接載入整個 MultiOutputRegressor
    meta_model_path = os.path.join(trained_oof_model_folder, f"meta_model_fold{fold_id}.pkl")
    meta_model = joblib.load(meta_model_path)

    # 2) 用剛剛算出的 meta features 做預測
    final_preds = meta_model.predict(meta)


# --- Save submission ---
import h5py
import pandas as pd
with h5py.File("./dataset/elucidata_ai_challenge_data.h5","r") as f:
    test_spot_ids = pd.DataFrame(np.array(f["spots/Test"]["S_7"]))
sub = pd.DataFrame(final_preds, columns=[f"C{i+1}" for i in range(C)])
sub.insert(0, 'ID', test_spot_ids.index)
sub.to_csv(os.path.join(trained_oof_model_folder, 'submission_stacked.csv'), index=False)
print(f"✅ Saved stacked submission in {trained_oof_model_folder}")


  net.load_state_dict(torch.load(ckpt_path, map_location=device))
  ae.load_state_dict(torch.load(ae_checkpoint, map_location="cpu"))
Computing AE recon loss: 100%|██████████| 33/33 [00:02<00:00, 12.64it/s]


Silhouette score for k=2: 0.5554
Silhouette score for k=3: 0.6321
Silhouette score for k=4: 0.5934
Silhouette score for k=5: 0.5564
Silhouette score for k=6: 0.5140
Selected best k=3 (score=0.6321)
✅ Generated meta-features with shape: (2088, 307)




ValueError: X has 307 features, but LGBMRegressor is expecting 227 features as input.

In [None]:

# # Base model
# lgb_base = lgb.LGBMRegressor(
#     objective='l2',
#     metric='rmse',
#     n_estimators=12000,
#     max_depth=15,
#     learning_rate=0.008,
#     num_leaves=32,
#     colsample_bytree=0.25
# )

lgb_base = lgb.LGBMRegressor(
    objective='l2',
    metric='rmse',
    learning_rate=0.007522970004049377,
    n_estimators=12000,
    max_depth=11,
    num_leaves=194,
    colsample_bytree=0.7619407413363416,
    subsample=0.8,
    subsample_freq=1,
    min_data_in_leaf=20,
    reg_alpha=0.7480401395491829,
    reg_lambda=0.2589860348178542,
    verbosity=-1
)
# 將每個 target 分別 early stopping
meta_model = MultiOutputRegressor(lgb_base)

print("Training LightGBM on OOF meta-features with early stopping...")
meta_model.estimators_ = []

for i in range(y_train.shape[1]):
    print(f"Training target {i}...")
    model  = lgb.LGBMRegressor(
        objective='l2',
        metric='rmse',
        learning_rate=0.007522970004049377,
        n_estimators=12000,
        max_depth=11,
        num_leaves=194,
        colsample_bytree=0.7619407413363416,
        subsample=0.8,
        subsample_freq=1,
        min_data_in_leaf=20,
        reg_alpha=0.7480401395491829,
        reg_lambda=0.2589860348178542,
        verbosity=-1
    )

    model.fit(
        X_train,
        y_train[:, i],
        eval_set=[(X_val, y_val[:, i])],
        callbacks=[
            early_stopping(stopping_rounds=200),
            log_evaluation(period=100)
        ]
    )

    meta_model.estimators_.append(model)

# 保存模型
joblib.dump(meta_model, os.path.join(save_root, 'meta_model.pkl'))


# --- 3) Prepare test meta-features ---
n_test = len(test_dataset)
test_preds = []
test_latents = []

for fold_id in range(n_folds):
    ckpt_path = os.path.join(save_root, f"fold{fold_id}", "best_model.pt")
    net = PretrainedEncoderRegressor(
        ae_checkpoint=checkpoint_path,
        ae_type="all",
        center_dim=64, neighbor_dim=64, hidden_dim=128,
        tile_size=26, output_dim=35,
        freeze_encoder = True
    )

    net.decoder = nn.Sequential(
        nn.Linear(64+64, 256),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(256, 128),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(128, 64),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(64, 35)
    )

    net = net.to(device)
    net.load_state_dict(torch.load(ckpt_path, map_location=device))
    net.eval()

    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    preds = []
    latents = []

    with torch.no_grad():
        for batch in test_loader:
            tiles = batch['tile'].to(device)
            subtiles = batch['subtiles'].to(device)

            center = subtiles[:, 4].contiguous()
            f_c = net.enc_center(center)
            f_n = net.enc_neigh(subtiles)
            fuse = torch.cat([f_c, f_n], dim=1)

            out = net.decoder(fuse)

            preds.append(out.cpu())
            latents.append(fuse.cpu())  # image embedding (128D)

    test_preds.append(torch.cat(preds, dim=0).numpy())      # shape: (n_test, 35)
    test_latents.append(torch.cat(latents, dim=0).numpy())  # shape: (n_test, 128)

# === Stack + Average ===
test_preds = np.mean(np.stack(test_preds, axis=0), axis=0)      # (n_test, 35)
test_latents = np.mean(np.stack(test_latents, axis=0), axis=0)  # (n_test, 128)

with h5py.File("dataset/elucidata_ai_challenge_data.h5", "r") as f:
    test_spots = f["spots/Test"]
    spot_array = np.array(test_spots['S_7'])
    df = pd.DataFrame(spot_array)

xy = df[["x", "y"]].to_numpy()  # shape: (n_test, 2)

# 合併為最終 test meta features
test_meta = np.concatenate([test_preds, xy, test_latents], axis=1)  # shape: (n_test, 35+2+128)



final_preds = meta_model.predict(test_meta)

# --- Save submission ---
import h5py
import pandas as pd
with h5py.File("./dataset/elucidata_ai_challenge_data.h5","r") as f:
    test_spot_ids = pd.DataFrame(np.array(f["spots/Test"]["S_7"]))
sub = pd.DataFrame(final_preds, columns=[f"C{i+1}" for i in range(C)])
sub.insert(0, 'ID', test_spot_ids.index)
sub.to_csv(os.path.join(save_root, 'submission_stacked.csv'), index=False)
print(f"✅ Saved stacked submission in {save_root}")


In [None]:
import os
import numpy as np
import joblib
import torch
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.multioutput import MultiOutputRegressor
from sklearn.model_selection import train_test_split
import lightgbm as lgb
from scipy.stats import rankdata
from python_scripts.import_data import importDataset
from python_scripts.operate_model import predict
from lightgbm import early_stopping, log_evaluation
import h5py
import pandas as pd
# ---------------- Settings ----------------
save_root  = save_folder  # your save_folder path
n_folds    = len([d for d in os.listdir(save_root) if d.startswith('fold')])
n_samples  = len(full_dataset)
C          = 35  # num cell types
start_fold = 0
BATCH_SIZE = 64
# If optimizing Spearman, convert labels to ranks

# --- 1) Prepare OOF meta-features ---
# Initialize matrix for OOF predictions
n_samples = len(full_dataset)
oof_preds = np.zeros((n_samples, C), dtype=np.float32)
# True labels (raw or rank)
# importDataset returns a dict-like sample, so label is under key 'label'
y_true = np.vstack([ full_dataset[i]['label'].cpu().numpy() for i in range(n_samples) ])
y_meta = y_true

# Build CV splitter (must match first stage splits)
logo = LeaveOneGroupOut()
image_latents = np.zeros((n_samples, 128), dtype=np.float32)

# Loop over folds, load best model, predict on validation indices
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
for fold_id, (tr_idx, va_idx) in enumerate(
        logo.split(X=np.zeros(n_samples), y=None, groups=slide_idx)):
    # Load model
    # if fold_id > start_fold:
    #     print(f"⏭️ Skipping fold {fold_id}")
    #     continue
    ckpt_path = os.path.join(save_root, f"fold{fold_id}", "best_model.pt")
    print(f"Loading model from {ckpt_path}...")
    net = PretrainedEncoderRegressor(
        ae_checkpoint=checkpoint_path,
        ae_type="all",
        center_dim=64, neighbor_dim=64, hidden_dim=128,
        tile_size=26, output_dim=35,
        freeze_encoder = True
    )

    # 2) monkey‐patch 一个新的 head
    net.decoder  = nn.Sequential(
        nn.Linear(64+64, 256),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(256, 128),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(128, 64),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(64, 35)
        
    )
    net = net.to(device)    # Alternatively, if your model requires specific args, replace with:
    # net = VisionMLP_MultiTask(tile_dim=64, subtile_dim=64, output_dim=35).to(device)
    net.load_state_dict(torch.load(ckpt_path, map_location=device))
    net.to(device).eval()
    
    # Predict on validation set
    val_ds = Subset(full_dataset, va_idx)
    val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False)

    preds = []
    latents = []

    with torch.no_grad():
        for batch in val_loader:
            tiles    = batch['tile'].to(device)
            subtiles = batch['subtiles'].to(device)

            center = subtiles[:, 4].contiguous()
            f_c = net.enc_center(center)
            f_n = net.enc_neigh(subtiles)
            fuse = torch.cat([f_c, f_n], dim=1)

            output = net.decoder(fuse)

            preds.append(output.cpu())
            latents.append(fuse.cpu())  # ⬅️ 收集 latent vector

    preds = torch.cat(preds, dim=0).numpy()    # (n_val, 35)
    latents = torch.cat(latents, dim=0).numpy()  # (n_val, 128)

    oof_preds[va_idx] = preds
    image_latents[va_idx] = latents

    print(f"Fold {fold_id}: OOF preds shape {preds.shape}, Latent shape: {latents.shape}")


    
with h5py.File("dataset/realign/filtered_dataset.h5", "r") as f:
    train_spots = f["spots/Train"]
    
    train_spot_tables = {}
    
    for slide_name in train_spots.keys():
        spot_array = np.array(train_spots[slide_name])
        df = pd.DataFrame(spot_array)
        df["slide_name"] = slide_name
        train_spot_tables[slide_name] = df
        print(f"✅ 已讀取 slide: {slide_name}")

# -----------------------------------------------------
# Step 2: 合併所有 slide 的資料
# -----------------------------------------------------
all_train_spots_df = pd.concat(train_spot_tables.values(), ignore_index=True)
# 提取 x, y
xy = all_train_spots_df[["x", "y"]].to_numpy()  # shape: (8348, 2)

# 合併成新的 meta feature
meta_features = np.concatenate([oof_preds, xy, image_latents], axis=1)
# --- 2) Train LightGBM meta-model ---
# Choose objective: regression on rank (for Spearman) or raw (for MSE)
# 將 meta features 拆成訓練集與 early stopping 用的驗證集
X_train, X_val, y_train, y_val = train_test_split(meta_features, y_meta, test_size=0.2, random_state=42)
print("Meta feature shape:", X_train.shape)
print("Feature std (min/max):", np.min(np.std(X_train, axis=0)), np.max(np.std(X_train, axis=0)))


# # Base model
# lgb_base = lgb.LGBMRegressor(
#     objective='l2',
#     metric='rmse',
#     n_estimators=12000,
#     max_depth=15,
#     learning_rate=0.008,
#     num_leaves=32,
#     colsample_bytree=0.25
# )
import optuna
from sklearn.multioutput import MultiOutputRegressor
from sklearn.metrics import mean_squared_error

# Define Optuna objective function
def objective(trial):
    params = {
        'objective': 'regression',
        'metric': 'rmse',
        'verbosity': -1,
        'boosting_type': 'gbdt',
        'device': 'gpu',                # ✅ GPU 支援
        'gpu_platform_id': 0,
        'gpu_device_id': 0,
        'learning_rate': trial.suggest_float("learning_rate", 0.005, 0.1),
        'max_depth': trial.suggest_int("max_depth", 4, 15),
        'num_leaves': trial.suggest_int("num_leaves", 32, 256),
        'min_data_in_leaf': trial.suggest_int("min_data_in_leaf", 20, 100),
        'colsample_bytree': trial.suggest_float("colsample_bytree", 0.6, 1.0),
        'reg_alpha': trial.suggest_float("reg_alpha", 0, 1),
        'reg_lambda': trial.suggest_float("reg_lambda", 0, 1),
        'n_estimators': 12000
    }

    model = lgb.LGBMRegressor(**params)
    multi_model = MultiOutputRegressor(model)
    multi_model.fit(X_train, y_train)

    y_pred = multi_model.predict(X_val)
    rmse = np.mean([
        np.sqrt(mean_squared_error(y_val[:, i], y_pred[:, i]))
        for i in range(y_val.shape[1])
    ])


    return rmse

# Run optimization
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=30)

# Use best params to train final models
best_params = study.best_trial.params
best_params['objective'] = 'l2'
best_params['metric'] = 'rmse'
best_params['verbosity'] = -1

# Train final models with best parameters
meta_model = MultiOutputRegressor(lgb.LGBMRegressor(**best_params))
meta_model.estimators_ = []

print("Training LightGBM on OOF meta-features with best Optuna params...")
for i in range(y_train.shape[1]):
    print(f"Training target {i}...")
    model = lgb.LGBMRegressor(**best_params)

    model.fit(
        X_train,
        y_train[:, i],
        eval_set=[(X_val, y_val[:, i])],
        callbacks=[
            early_stopping(stopping_rounds=200),
            log_evaluation(period=100)
        ]
    )

    meta_model.estimators_.append(model)

# Save model
joblib.dump(meta_model, os.path.join(save_root, 'meta_model.pkl'))
# 保存模型


# --- 3) Prepare test meta-features ---
n_test = len(test_dataset)
test_preds = []
test_latents = []

for fold_id in range(n_folds):
    ckpt_path = os.path.join(save_root, f"fold{fold_id}", "best_model.pt")
    net = PretrainedEncoderRegressor(
        ae_checkpoint=checkpoint_path,
        ae_type="all",
        center_dim=64, neighbor_dim=64, hidden_dim=128,
        tile_size=26, output_dim=35,
        freeze_encoder = True
    )

    net.decoder = nn.Sequential(
        nn.Linear(64+64, 256),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(256, 128),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(128, 64),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(64, 35)
    )

    net = net.to(device)
    net.load_state_dict(torch.load(ckpt_path, map_location=device))
    net.eval()

    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    preds = []
    latents = []

    with torch.no_grad():
        for batch in test_loader:
            tiles = batch['tile'].to(device)
            subtiles = batch['subtiles'].to(device)

            center = subtiles[:, 4].contiguous()
            f_c = net.enc_center(center)
            f_n = net.enc_neigh(subtiles)
            fuse = torch.cat([f_c, f_n], dim=1)

            out = net.decoder(fuse)

            preds.append(out.cpu())
            latents.append(fuse.cpu())  # image embedding (128D)

    test_preds.append(torch.cat(preds, dim=0).numpy())      # shape: (n_test, 35)
    test_latents.append(torch.cat(latents, dim=0).numpy())  # shape: (n_test, 128)

# === Stack + Average ===
test_preds = np.mean(np.stack(test_preds, axis=0), axis=0)      # (n_test, 35)
test_latents = np.mean(np.stack(test_latents, axis=0), axis=0)  # (n_test, 128)

with h5py.File("dataset/elucidata_ai_challenge_data.h5", "r") as f:
    test_spots = f["spots/Test"]
    spot_array = np.array(test_spots['S_7'])
    df = pd.DataFrame(spot_array)

xy = df[["x", "y"]].to_numpy()  # shape: (n_test, 2)

# 合併為最終 test meta features
test_meta = np.concatenate([test_preds, xy, test_latents], axis=1)  # shape: (n_test, 35+2+128)



final_preds = meta_model.predict(test_meta)

# --- Save submission ---
import h5py
import pandas as pd
with h5py.File("./dataset/elucidata_ai_challenge_data.h5","r") as f:
    test_spot_ids = pd.DataFrame(np.array(f["spots/Test"]["S_7"]))
sub = pd.DataFrame(final_preds, columns=[f"C{i+1}" for i in range(C)])
sub.insert(0, 'ID', test_spot_ids.index)
sub.to_csv(os.path.join(save_root, 'submission_stacked.csv'), index=False)
print("✅ Saved stacked submission.")


In [None]:
import os
import numpy as np
import joblib
import torch
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import LeaveOneGroupOut
from sklearn.multioutput import MultiOutputRegressor
import lightgbm as lgb
from scipy.stats import rankdata
from python_scripts.import_data import importDataset
from python_scripts.operate_model import predict

# --- 配置: 只用哪些 fold 的结果来训练/预测 meta-model ---
meta_folds = [0]  # 例如只用 fold0, fold2, fold4

# 1) 准备 full_dataset, slide_idx, test_dataset 等
full_dataset = importDataset(
    grouped_data, model,
    image_keys=['tile','subtiles'],
    transform=lambda x: x
)
n_samples = len(full_dataset)
C = 35  # 类别数

# 2) 预留 oof_preds 和 fold_ids
oof_preds    = np.zeros((n_samples, C), dtype=np.float32)
oof_fold_ids = np.full(n_samples, -1, dtype=int)

# 真标签
y_true = np.vstack([ full_dataset[i]['label'].cpu().numpy() for i in range(n_samples) ])
y_meta = y_true.copy()  # 不做 rank 时直接用 raw

# 3) 生成 OOF 预测并记录 fold id
logo = LeaveOneGroupOut()
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

for fold_id, (tr_idx, va_idx) in enumerate(
        logo.split(X=np.zeros(n_samples), y=None, groups=slide_idx)):

    # 如果当前 fold 不在我们想要的 meta_folds 列表里，就跳过
    if fold_id not in meta_folds:
        print(f"⏭️ Skipping OOF for fold {fold_id}")
        continue

    print(f"\n>>> Generating OOF for fold {fold_id}")
    ckpt_path = os.path.join(save_root, f"fold{fold_id}", "best_model.pt")
    net = PretrainedEncoderRegressor(
        ae_checkpoint=checkpoint_path,
        ae_type="all",
        center_dim=64, neighbor_dim=64, hidden_dim=128,
        tile_size=26, output_dim=35,
        freeze_encoder = True
    )

    # 2) monkey‐patch 一个新的 head
    net.decoder  = nn.Sequential(
        nn.Linear(64+64, 128),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(128, 64),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(64, 35)
        
    )
    net = net.to(device)
    net.load_state_dict(torch.load(ckpt_path, map_location=device))
    net.eval()

    val_loader = DataLoader(Subset(full_dataset, va_idx), batch_size=BATCH_SIZE, shuffle=False)
    preds = predict(net, val_loader, device)  # (n_val, C)

    oof_preds[va_idx]    = preds
    oof_fold_ids[va_idx] = fold_id

    print(f"  → Fold {fold_id} OOF preds shape: {preds.shape}")
# 4) 只选取 meta_folds 的行来训练 meta-model
mask = np.isin(oof_fold_ids, meta_folds)
X_meta = oof_preds[mask]
y_meta_sub = y_meta[mask]

print(f"\nTraining meta-model on folds {meta_folds}:")
print(f"  使用样本数：{X_meta.shape[0]} / {n_samples}")

lgb_base = lgb.LGBMRegressor(
    objective='regression',
    learning_rate=0.001,
    n_estimators=1000,
    num_leaves=31,
    subsample=0.7,
    colsample_bytree=0.7,
    n_jobs=-1,
    force_col_wise=True
)
meta_model = MultiOutputRegressor(lgb_base)
meta_model.fit(X_meta, y_meta_sub)
joblib.dump(meta_model, os.path.join(save_root, 'meta_model.pkl'))

# 5) 准备 test_meta，只平均 meta_folds 中的预测
n_folds = len([d for d in os.listdir(save_root) if d.startswith('fold')])
n_test  = len(test_dataset)
test_meta = np.zeros((n_test, C), dtype=np.float32)

for fold_id in range(n_folds):
    if fold_id not in meta_folds:
        continue
    ckpt_path = os.path.join(save_root, f"fold{fold_id}", "best_model.pt")
    net = PretrainedEncoderRegressor(
        ae_checkpoint=checkpoint_path,
        ae_type="all",
        center_dim=64, neighbor_dim=64, hidden_dim=128,
        tile_size=26, output_dim=35,
        freeze_encoder = True
    )

    # 2) monkey‐patch 一个新的 head
    net.decoder  = nn.Sequential(
        nn.Linear(64+64, 128),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(128, 64),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(64, 35)
        
    )
    net = net.to(device)
    net.load_state_dict(torch.load(ckpt_path, map_location=device))
    net.eval()

    loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)
    preds = predict(net, loader, device)
    test_meta += preds

# 平均时除以参与的 folds 数目
test_meta /= len(meta_folds)

# 6) 用 meta-model 做最终预测
final_preds = meta_model.predict(test_meta)

# --- Save submission ---
import h5py
import pandas as pd

with h5py.File("./dataset/elucidata_ai_challenge_data.h5","r") as f:
    test_spot_ids = pd.DataFrame(np.array(f["spots/Test"]["S_7"]))

sub = pd.DataFrame(final_preds, columns=[f"C{i+1}" for i in range(C)])
sub.insert(0, 'ID', test_spot_ids.index)
sub.to_csv(os.path.join(save_root, 'submission_stacked.csv'), index=False)
print("✅ Saved stacked submission.")


# Predict

In [None]:
import torch
import inspect
from python_scripts.operate_model import get_model_inputs
from python_scripts.import_data import load_node_feature_data


image_keys = [ 'tile', 'subtiles']


# 用法示例
from python_scripts.import_data import importDataset
# 假设你的 model 已经定义好并实例化为 `model`
test_dataset = load_node_feature_data("dataset/spot-rank/filtered_directly_rank/masked/test/Macenko/test_dataset.pt", model)
test_dataset = importDataset(
        data_dict=test_dataset,
        model=model,
        image_keys=image_keys,
        transform=lambda x: x,  # identity transform
        print_sig=True
    )



In [None]:

test_dataset.check_item(1000, 10)


In [None]:
from torch.utils.data import DataLoader
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)


In [None]:
import glob
import torch
import numpy as np
import pandas as pd
import os
import h5py
from torch.utils.data import DataLoader

# 讀 test spot index
with h5py.File("./dataset/elucidata_ai_challenge_data.h5","r") as f:
    test_spots     = f["spots/Test"]
    test_spot_table= pd.DataFrame(np.array(test_spots['S_7']))

fold_ckpts = sorted(glob.glob(os.path.join(save_folder, "fold*", "best_model.pt")))
models = []
for ckpt in fold_ckpts:
    net = PretrainedEncoderRegressor(
        ae_checkpoint=checkpoint_path,
        ae_type="all",
        center_dim=64, neighbor_dim=64, hidden_dim=128,
        tile_size=26, output_dim=35,
        freeze_encoder = False
    )

    # 2) monkey‐patch 一个新的 head
    net.decoder  = nn.Sequential(
        nn.Linear(64+64, 256),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(256, 128),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(128, 64),
        nn.SiLU(),
        nn.Dropout(0.1),
        nn.Linear(64, 35)
        
    )
    net = net.to(device)
    net.load_state_dict(torch.load(ckpt, map_location="cpu"))
    net.to(device).eval()
    models.append(net)

all_fold_preds = []
for fold_id, net in enumerate(models):
    # 推論
    with torch.no_grad():
        preds = predict(net, test_loader, device)  # (N_test,35) numpy array

    # 1) 存每一折的原始預測
    df_fold = pd.DataFrame(preds, columns=[f"C{i+1}" for i in range(preds.shape[1])])
    df_fold.insert(0, "ID", test_spot_table.index)
    path_fold = os.path.join(save_folder, f"submission_fold{fold_id}.csv")
    df_fold.to_csv(path_fold, index=False)
    print(f"✅ Saved fold {fold_id} predictions to {path_fold}")

    all_fold_preds.append(preds)

# 2) 做 rank‐average ensemble
all_fold_preds = np.stack(all_fold_preds, axis=0)       # (K, N_test, 35)
ranks          = all_fold_preds.argsort(axis=2).argsort(axis=2).astype(float)
mean_rank      = ranks.mean(axis=0)                    # (N_test,35)

# 3) 存 final ensemble
df_ens = pd.DataFrame(mean_rank, columns=[f"C{i+1}" for i in range(mean_rank.shape[1])])
df_ens.insert(0, "ID", test_spot_table.index)
path_ens = os.path.join(save_folder, "submission_rank_ensemble.csv")
df_ens.to_csv(path_ens, index=False)
print(f"✅ Saved rank‐ensemble submission to {path_ens}")
