In [3]:
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

In [None]:
class SetNetClassifier(nn.Module):
    def __init__(self, 
                 input_dim=45,          # 每个细胞的特征维度
                 hidden_dim=128,        # 细胞特征编码后的维度
                 classifier_hidden=128, # MLP中间层
                 output_dim=3):         # 类别数（分类）
        super().__init__()

        # Set Encoder: 对每个细胞特征进行变换
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # 分类器 MLP（输入维度是 2×hidden_dim，因为拼接了 mean 和 sum）
        self.classifier = nn.Sequential(
            nn.Linear(2 * hidden_dim, classifier_hidden),
            nn.ReLU(),
            nn.Linear(classifier_hidden, output_dim)
            # add softmax
        )

    def forward(self, batch_cells):
        """
        batch_cells: list of tensors, each of shape [N_i, 45]
        return: logits tensor of shape [B, 3]
        """
        logits_list = []

        for cells in batch_cells:  # 每一组细胞
            z = self.encoder(cells)               # [N, hidden_dim]
            z_mean = z.mean(dim=0)                # [hidden_dim]
            z_sum = z.sum(dim=0)                  # [hidden_dim]
            group_embedding = torch.cat([z_mean, z_sum], dim=0)  # [2 * hidden_dim]
            logits = self.classifier(group_embedding)            # [3]
            logits_list.append(logits)

        return torch.stack(logits_list)  # [B, 3]
    
    
    

In [5]:
class CellGroupDataset(Dataset):
    def __init__(self, npy_path, label_list):
        self.raw_data = np.load(npy_path, allow_pickle=True)  # shape: [40]
        self.labels = label_list  # shape: [40]，例如 [0, 2, 1, ...]
        
        # 每组为 list of np.array → 将每个细胞拼接为 torch tensor
        self.data = []
        for group in self.raw_data:
            group_tensor = []
            for cell in group:
                cell_feat = np.concatenate([cell[0].flatten(), cell[1].flatten()])  # shape (45,)
                cell_feat_tensor = torch.tensor(cell_feat, dtype=torch.float32)
                group_tensor.append(cell_feat_tensor)
            self.data.append(group_tensor)  # list of tensors per group

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# collate_fn 处理变长 batch
def collate_fn(batch):
    data, labels = zip(*batch)
    return list(data), torch.tensor(labels, dtype=torch.long)

In [1]:
# 你自己的标签（例如 0 → class 0, 0.5 → class 1, 1 → class 2）
float_labels = [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0.5, 0.5, 
                0.5, 0.5, 0.5, 0.5, 0.5, 0, 0, 0, 0, 0, 0, 0, 0.5, 
                0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]  # 长度 = 40

In [6]:

label_map = {0.0: 0, 0.5: 1, 1.0: 2}
class_labels = [label_map[val] for val in float_labels]



In [None]:
# 加载 Dataset
dataset = CellGroupDataset("./dataset_first.npy", class_labels)

# 负责打包
loader = DataLoader(dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

In [11]:
model = SetNetClassifier(input_dim=45, hidden_dim=128, output_dim=3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = torch.nn.CrossEntropyLoss()

# 可选 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

for epoch in range(10):
    model.train()
    for batch_cells, batch_labels in loader:
        batch_cells = [torch.stack([cell.to(device) for cell in group]) for group in batch_cells]
        batch_labels = batch_labels.to(device)

        logits = model(batch_cells)  # [B, 3]
        loss = criterion(logits, batch_labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch+1}, Loss = {loss.item():.4f}")
    print(f"Logits sample: {logits[0].detach().cpu().numpy()}")
    print(f"Labels sample: {batch_labels[0].item()}")

Epoch 1, Loss = 115952.2812
Logits sample: [ -54836.598  129788.17  -438236.97 ]
Labels sample: 2
Epoch 2, Loss = 1121383.0000
Logits sample: [-76073.734  30626.912 -92088.1  ]
Labels sample: 0
Epoch 3, Loss = 187645.0156
Logits sample: [-262130.19   -18247.143 -115872.586]
Labels sample: 1
Epoch 4, Loss = 51118.2578
Logits sample: [-3231345.8   -942575.8   -240301.78]
Labels sample: 2
Epoch 5, Loss = 93108.2422
Logits sample: [-2814.9766   814.7249 -1550.7041]
Labels sample: 1
Epoch 6, Loss = 68588.9375
Logits sample: [-591837.44   -60465.223  -81041.54 ]
Labels sample: 1
Epoch 7, Loss = 110411.7109
Logits sample: [-113985.44    -17859.018    -2997.1104]
Labels sample: 0
Epoch 8, Loss = 25992.3809
Logits sample: [-115119.664   -15443.127     2423.8503]
Labels sample: 0
Epoch 9, Loss = 201737.8125
Logits sample: [-60174.33    -5489.643   -7294.5796]
Labels sample: 0
Epoch 10, Loss = 182416.4062
Logits sample: [-349654.88   -95280.45    28354.459]
Labels sample: 0
