In [15]:
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, Subset
import random
import shap
import matplotlib.pyplot as plt

In [11]:
class CellGroupDataset(Dataset):
    def __init__(self, data, labels):
        self.data = []
        for group in data:
            group_tensor = []
            for cell in group:
                feature = np.concatenate([cell[0].flatten(), cell[1].flatten()])  # (45,)
                group_tensor.append(torch.tensor(feature, dtype=torch.float32))
            self.data.append(torch.stack(group_tensor))  # shape: [11, 45]
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

In [12]:

# 加载数据
raw_data = np.load("dataset_first.npy", allow_pickle=True)
float_labels = [1, 0, 0.5, 0, 0.5, 0.5]

label_map = {0.0: 0, 0.5: 1, 1.0: 2}
labels = [label_map[val] for val in float_labels]
dataset = CellGroupDataset(raw_data, labels)

In [13]:
class SetNetClassifier(nn.Module):
    def __init__(self, 
                 input_dim=45 * 8,          # 每个细胞的特征维度
                 hidden_dim=128,        # 细胞特征编码后的维度
                 classifier_hidden=128, # MLP中间层
                 output_dim=3):         # 类别数（分类）
        super().__init__()

        # Set Encoder: 对每个细胞特征进行变换
        # self.encoder = nn.Sequential(
        #     nn.Linear(input_dim, hidden_dim),
        #     nn.ReLU(),
        #     nn.Linear(hidden_dim, hidden_dim)
        # )

        # 分类器 MLP
        self.classifier = nn.Sequential(
           
            nn.Linear(input_dim, classifier_hidden),
            nn.ReLU(),
            nn.Linear(classifier_hidden, output_dim)
            # add softmax
        )
    
    def forward(self, batch_cells):
        """
        batch_cells: list of tensors, each of shape [8, 45]
        return: logits tensor of shape [B, 3]
        """
        logits_list = []

        for cells in batch_cells:  # 每一组细胞
            # z = self.encoder(cells)               # [N, hidden_dim]

            flat = cells.flatten() 
            logits = self.classifier(flat)            # [3]
            logits_list.append(logits)

        return torch.stack(logits_list)  # [B, 3]        

In [14]:
model = SetNetClassifier()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()

num_epochs = 10
all_indices = list(range(len(dataset)))

for epoch in range(num_epochs):
    # 随机划分 2 个测试样本，剩下用于训练
    test_indices = random.sample(all_indices, 2)
    train_indices = list(set(all_indices) - set(test_indices))

    train_loader = DataLoader(Subset(dataset, train_indices), batch_size=2, shuffle=True)
    test_loader = DataLoader(Subset(dataset, test_indices), batch_size=2)

    # === 训练 ===
    model.train()
    for batch_x, batch_y in train_loader:
        logits = model(batch_x)
        loss = loss_fn(logits, batch_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # === 测试 ===
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            logits = model(batch_x)
            pred = torch.argmax(logits, dim=1)
            correct += (pred == batch_y).sum().item()
            total += batch_y.size(0)
    
    acc = correct / total
    print(f"Epoch {epoch+1} | Loss: {loss.item():.4f} | Test Accuracy: {acc*100:.2f}%")
    
    

Epoch 1 | Loss: 8742.1650 | Test Accuracy: 50.00%
Epoch 2 | Loss: 420914.0625 | Test Accuracy: 50.00%
Epoch 3 | Loss: 125624.9375 | Test Accuracy: 50.00%
Epoch 4 | Loss: 0.0000 | Test Accuracy: 50.00%
Epoch 5 | Loss: 0.0000 | Test Accuracy: 50.00%
Epoch 6 | Loss: 46718.1094 | Test Accuracy: 0.00%
Epoch 7 | Loss: 1743.3209 | Test Accuracy: 50.00%
Epoch 8 | Loss: 1216.3971 | Test Accuracy: 100.00%
Epoch 9 | Loss: 637.4739 | Test Accuracy: 50.00%
Epoch 10 | Loss: 0.0000 | Test Accuracy: 50.00%


In [16]:
torch.save(model.state_dict(), 'model.pt')