In [5]:
import bisect
import itertools

import numpy as np
import scanpy as sc
import torch
from torch.utils.data import Dataset, DataLoader


In [2]:
def get_h5ad_data(file_name='zebrafish_scNODE0_2000genes_3227cells_12tps.h5ad',
                  dataset_dir='/home/hanyuji/Workbench/Data/h5ad/'):
    """Load an .h5ad file and split its ``X`` matrix by timepoint.

    Parameters
    ----------
    file_name : str
        Name of the .h5ad file inside ``dataset_dir``.
    dataset_dir : str
        Directory containing the file. Defaults to the previously hard-coded
        location; pass another directory to run on a different machine.

    Returns
    -------
    list of numpy.ndarray
        One dense array per distinct value of ``obs['tp']``, ordered by the
        sorted timepoint values. Each array holds the rows of ``X`` for the
        cells at that timepoint. The list position (not the raw ``tp`` value)
        is what a consumer such as ``scDataset`` uses as the label.
    """
    import os

    # os.path.join tolerates dataset_dir with or without a trailing separator.
    loaded_adata = sc.read_h5ad(os.path.join(dataset_dir, file_name))

    tp_column = loaded_adata.obs['tp']
    timepoints = sorted(tp_column.unique())

    data_list = []
    for tp in timepoints:
        # Select the cells observed at this timepoint.
        subset = loaded_adata[tp_column == tp]
        X_matrix = subset.X
        # Densify sparse storage.
        if hasattr(X_matrix, "toarray"):
            X_matrix = X_matrix.toarray()
        # np.asarray turns an AnnData array view into a plain ndarray.
        data_list.append(np.asarray(X_matrix))

    return data_list


In [3]:
# Custom Dataset that concatenates a list of per-timepoint arrays.
class scDataset(Dataset):
    """Map-style dataset over a list of arrays, labelled by list position.

    Global index ``idx`` walks the arrays in order. The first
    ``len(data_list[0])`` indices address rows of ``data_list[0]``, the next
    ``len(data_list[1])`` address rows of ``data_list[1]``, and so on.

    Each item is ``(row, label)``, where ``label`` is the position of the
    source array in ``data_list``. For the output of ``get_h5ad_data``, that
    is the index of the sorted timepoint.

    Row offsets are computed once in ``__init__``. Rebuild the dataset if
    ``data_list`` is modified afterwards.
    """

    def __init__(self, data_list):
        self.data_list = data_list
        # _ends[k] is the global index one past the last row of data_list[k].
        self._ends = list(itertools.accumulate(len(data) for data in data_list))

    def __len__(self):
        return self._ends[-1] if self._ends else 0

    def __getitem__(self, idx):
        """Return ``(row, label)`` for global index ``idx``.

        Negative indices count from the end, like a Python sequence.

        Raises
        ------
        IndexError
            If ``idx`` is outside ``[-len(self), len(self))``.
        """
        total = len(self)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            raise IndexError(f"index out of range for dataset of length {total}")
        # The owning array is the first one whose end offset is strictly
        # greater than idx. bisect_right also skips empty arrays, whose
        # offsets equal the previous end.
        label = bisect.bisect_right(self._ends, idx)
        start = self._ends[label - 1] if label > 0 else 0
        return self.data_list[label][idx - start], label

In [6]:
# Example data: load one array per timepoint and show the first one's shape.
data_list = get_h5ad_data()
first_timepoint_matrix = data_list[0]
first_timepoint_matrix.shape

(311, 2000)

In [8]:

# Build the Dataset and a shuffled DataLoader.
# The seeded generator makes the shuffle order reproducible across Restart & Run All.
SEED = 0
BATCH_SIZE = 32

dataset = scDataset(data_list)
dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)



In [22]:
# Walk the DataLoader once and report the tensor shapes of every batch.
for i, batch in enumerate(dataloader):
    inputs, labels = batch
    print(f"Batch {i+1}:\n  Inputs: {inputs.shape}\n  Labels: {labels.shape}")
    # Training code can be added here.


Batch 1:
  Inputs: torch.Size([32, 2000])
  Labels: torch.Size([32])
Batch 2:
  Inputs: torch.Size([32, 2000])
  Labels: torch.Size([32])
Batch 3:
  Inputs: torch.Size([32, 2000])
  Labels: torch.Size([32])
Batch 4:
  Inputs: torch.Size([32, 2000])
  Labels: torch.Size([32])
Batch 5:
  Inputs: torch.Size([32, 2000])
  Labels: torch.Size([32])
Batch 6:
  Inputs: torch.Size([32, 2000])
  Labels: torch.Size([32])
Batch 7:
  Inputs: torch.Size([32, 2000])
  Labels: torch.Size([32])
Batch 8:
  Inputs: torch.Size([32, 2000])
  Labels: torch.Size([32])
Batch 9:
  Inputs: torch.Size([32, 2000])
  Labels: torch.Size([32])
Batch 10:
  Inputs: torch.Size([32, 2000])
  Labels: torch.Size([32])
Batch 11:
  Inputs: torch.Size([32, 2000])
  Labels: torch.Size([32])
Batch 12:
  Inputs: torch.Size([32, 2000])
  Labels: torch.Size([32])
Batch 13:
  Inputs: torch.Size([32, 2000])
  Labels: torch.Size([32])
Batch 14:
  Inputs: torch.Size([32, 2000])
  Labels: torch.Size([32])
Batch 15:
  Inputs: torch.Siz

In [21]:
# Inspect the labels of one batch. The batch is drawn explicitly rather than
# reading `labels` left over from an earlier loop, whose value depends on the
# order in which cells were executed.
sample_inputs, sample_labels = next(iter(dataloader))
sample_labels

tensor([ 6,  8, 11, 11,  6,  5,  9,  5,  6,  5,  0,  7,  8,  8,  8,  1,  5,  9,
         5,  9, 11,  8, 11, 11,  9, 11, 10,  9, 11,  4,  4,  6])

In [25]:
from dataloader_VAE import get_dataloader

# Build the project's loader with shuffle=False (presumably forwarded to the
# underlying DataLoader), then print each batch's input shape and full labels.
loader = get_dataloader(shuffle=False)

for i, batch in enumerate(loader):
    inputs, labels = batch
    print(f"Batch {i+1}:\n  Inputs: {inputs.shape}\n  Labels: {labels}")

Batch 1:
  Inputs: torch.Size([64, 2000])
  Labels: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Batch 2:
  Inputs: torch.Size([64, 2000])
  Labels: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Batch 3:
  Inputs: torch.Size([64, 2000])
  Labels: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
Batch 4:
  Inputs: torch.Size([64, 2000])
  Labels: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,