In [1]:
# general package
import numpy as np
# deep learning package
import torch
import torchvision.models as models
import torchvision.transforms as T
import torch.nn as nn
import torch.nn.functional as F
from torchvision.datasets import STL10
from torch.utils.data import DataLoader
import torchvision.transforms as T

In [2]:
# self defined modules
import DataLoader_tensor as dl
import SCDataset as ds

### Generate Dummy Data

In [3]:
# Create dummy data
# Set the parameters
n = 180  # Total number of data points
p = 5   # Number of features
M = 13   # Total number of groups

# Create the first array (n x p)
count_matrix = np.random.rand(n, p)  # Filling with random numbers for illustration

# Create the second array (n x 1)
# Ensure each group has at least two members
lineage = np.repeat(np.arange(1, M + 1), repeats=np.ceil(n / M))[:n]
np.random.shuffle(lineage)  # Shuffle to randomize group allocation
lineage = lineage.reshape(n, 1)

# print("Data Array (n x p):\n", count_matrix)
# print("Group Array (n x 1):\n", lineage)


In [4]:
count_matrix.shape, lineage.shape


((180, 5), (180, 1))

### Generate Batches

In [5]:
# step 1 generate designed batches
DLoader = dl.SClineage_DataLoader(count_matrix,lineage,batch_size=3)
batch_all, num_batch = DLoader.batch_generator()
# step 2 generate real dataloader
sc_dataset = ds.SCDataset(batches=batch_all)
data_loader = torch.utils.data.DataLoader(dataset=sc_dataset, batch_size=3, shuffle=False)

### Comparsion between two batches generated by scDataLoader and torch.DataLoader

In [6]:
num_batches = 0
for _ in data_loader:
    num_batches += 1

print(f"Total number of batches generated by scDataLoader: {num_batch}")
print(f"Total number of batches generated by torch.DataLoader: {num_batches}")

Total number of batches generated by scDataLoader: 31
Total number of batches generated by torch.DataLoader: 31


In [7]:
for i, batch in enumerate(data_loader):
    print(f"Batch generated by torch.DataLoader:")
    print(f"Batch {i+1}/{len(data_loader)}:")
    # Assuming each batch is a tuple of two tensors
    for j, cell_pair in enumerate(batch):
        cell_1, cell_2,  cell_3 = cell_pair  # Unpack the tuple
        print(f"  Sample {j+1} in batch:")
        print(f"    Cell 1: {cell_1}")
        print(f"    Cell 2: {cell_2}")
        print(f"    Cell 3: {cell_3}")
    
    # Optionally, break after a few batches to not overload the output
    if i == 0:  # Adjust this number based on how many batches you want to inspect
        break


Batch generated by torch.DataLoader:
Batch 1/31:
  Sample 1 in batch:
    Cell 1: tensor([0.8231, 0.4001, 0.7745, 0.3747, 0.5779])
    Cell 2: tensor([0.6615, 0.2790, 0.8634, 0.6982, 0.0970])
    Cell 3: tensor([0.9173, 0.6221, 0.7738, 0.5678, 0.0665])
  Sample 2 in batch:
    Cell 1: tensor([0.3568, 0.0533, 0.9256, 0.7852, 0.9364])
    Cell 2: tensor([0.6230, 0.2482, 0.2230, 0.5607, 0.5763])
    Cell 3: tensor([0.1637, 0.2104, 0.3233, 0.7858, 0.0229])


In [8]:
print(f"Batch generated by scDataLoader:")
print(f"Batch 1/31:")
batch_all[1]

Batch generated by scDataLoader:
Batch 1/31:


[(tensor([0.8231, 0.4001, 0.7745, 0.3747, 0.5779]),
  tensor([0.3568, 0.0533, 0.9256, 0.7852, 0.9364])),
 (tensor([0.6615, 0.2790, 0.8634, 0.6982, 0.0970]),
  tensor([0.6230, 0.2482, 0.2230, 0.5607, 0.5763])),
 (tensor([0.9173, 0.6221, 0.7738, 0.5678, 0.0665]),
  tensor([0.1637, 0.2104, 0.3233, 0.7858, 0.0229]))]