In [16]:
import torch
from torch.utils.data import Dataset, DataLoader

# Define a simple toy dataset.
class SimpleDataset(Dataset):
    """Toy map-style dataset of integer pairs ``(i, i * factor)``.

    With the default arguments the data is
    ``[(1, 10), (2, 20), ..., (10, 100)]``.

    ``__getitem__`` prints every index it is asked for. This lets the notebook
    output show which indices the ``DataLoader`` requests for each batch.

    Args:
        size: Number of pairs; ``i`` runs from 1 to ``size`` inclusive.
            Defaults to 10.
        factor: Multiplier used for the second element of each pair.
            Defaults to 10.
    """

    def __init__(self, size=10, factor=10):
        # Each sample is a pair (i, i * factor) for i = 1..size,
        # e.g. [(1, 10), (2, 20), (3, 30), ...] with the defaults.
        self.data = [(i, i * factor) for i in range(1, size + 1)]

    def __len__(self):
        """Return the number of pairs in the dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Print ``idx`` and return the pair stored at that position."""
        # Trace output: shows the order in which the DataLoader fetches samples.
        print(idx)
        return self.data[idx]

# Custom collate function passed to the DataLoader.
def custom_collate_fn(batch):
    """Turn a list of samples into two float32 tensors.

    ``batch`` is the list of values returned by ``__getitem__`` for one
    DataLoader step. The raw list is printed first so the notebook output
    shows exactly what the DataLoader hands over.

    Returns:
        A tuple ``(data1, data2)``. ``data1`` is built from element 0 of every
        sample and ``data2`` from element 1, both in batch order and with
        dtype ``torch.float32``. For scalar fields these are 1-D tensors.
    """
    print(batch)

    # Gather both fields first, then convert each gathered list to a tensor.
    columns = (
        [sample[0] for sample in batch],
        [sample[1] for sample in batch],
    )
    return tuple(torch.tensor(column, dtype=torch.float32) for column in columns)

# Build the dataset and a DataLoader that groups 4 samples per batch,
# collated by the custom function above (10 samples -> batches of 4, 4, 2).
dataset = SimpleDataset()
dataloader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=custom_collate_fn,
)

# Walk the DataLoader and show the two tensors produced for each batch.
for batch_idx, (data1, data2) in enumerate(dataloader):
    print(f'Batch {batch_idx}:\ndata1: {data1!s}\ndata2: {data2!s}')


0
1
2
3
[(1, 10), (2, 20), (3, 30), (4, 40)]
Batch 0:
data1: tensor([1., 2., 3., 4.])
data2: tensor([10., 20., 30., 40.])
4
5
6
7
[(5, 50), (6, 60), (7, 70), (8, 80)]
Batch 1:
data1: tensor([5., 6., 7., 8.])
data2: tensor([50., 60., 70., 80.])
8
9
[(9, 90), (10, 100)]
Batch 2:
data1: tensor([ 9., 10.])
data2: tensor([ 90., 100.])
