In [4]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms

class DeepSpotTileDataset(Dataset):
    def __init__(self, tiles, subtiles, neighbor_tiles, labels, meta_info=None, transform=None):
        """
        tiles: List of center tiles (np.array, H×W×3)
        subtiles: List of list (9 subtiles per spot), each subtile is np.array(H', W', 3)
        neighbor_tiles: List of list (8 neighbor tiles per spot), each tile is np.array(H, W, 3)
        labels: List of np.array (35,)
        meta_info: Optional list of (slide_id, x, y) for each sample
        transform: torchvision transform to apply to all tiles
        """
        self.tiles = tiles
        self.subtiles = subtiles
        self.neighbor_tiles = neighbor_tiles
        self.labels = labels
        self.meta_info = meta_info if meta_info is not None else [None] * len(tiles)

        self.transform = transform or transforms.Compose([
            transforms.ToTensor(),  # (H, W, C) → (C, H, W)
            transforms.Normalize(mean=[0.5]*3, std=[0.5]*3)
        ])

    def __len__(self):
        return len(self.tiles)

    def __getitem__(self, idx):
        center_tile = self.transform(self.tiles[idx])  # → (3, H, W)

        subtiles = [self.transform(t) for t in self.subtiles[idx]]  # → list of (3, h, w)
        subtiles = torch.stack(subtiles)  # → (9, 3, h, w)

        neighbor_tiles = [self.transform(t) for t in self.neighbor_tiles[idx]]  # → list of (3, H, W)
        neighbor_tiles = torch.stack(neighbor_tiles)  # → (8, 3, H, W)

        label = torch.tensor(self.labels[idx], dtype=torch.float32)  # → (35,)
        meta = self.meta_info[idx]

        return {
            'center_tile': center_tile,              # (3, H, W)
            'subtiles': subtiles,                    # (9, 3, h, w)
            'neighbor_tiles': neighbor_tiles,        # (8, 3, H, W)
            'label': label,                          # (35,)
            'meta': meta
        }


In [5]:
# 載入資料
data = torch.load("train_dataset.pt")
train_dataset = SubTileDataset(
    tiles=data['tiles'],
    labels=data['labels'],
    subtiles=data['subtiles'],
    neighbor_tiles=data['neighbor_tiles'],
    meta_info=data['meta_info']
)
# 載入資料
data = torch.load("test_dataset.pt")
test_dataset = SubTileDataset(data['tiles'], data['subtiles'], data['neighbor_tiles'], data['labels'], data['meta_info'])

  data = torch.load("train_dataset.pt")


TypeError: __init__() got an unexpected keyword argument 'subtiles'

In [3]:
train_dataset = SubTileDataset(
    tiles=data['tiles'],
    labels=data['labels'],
    subtiles=data['subtiles'],
    neighbor_tiles=data['neighbor_tiles'],
    meta_info=data['meta_info']
)

TypeError: __init__() got an unexpected keyword argument 'subtiles'