In [44]:
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from torch.utils.data.distributed import DistributedSampler

from datasets import load_dataset
import torch

class ImageDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform
        
    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        (image, label) = self.dataset[idx].values()
        if image.mode == "L":
            image = image.convert("RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, label
    
def get_transform(image_size=256):
    transform = transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])
    return transform

def get_ddp_sampler_loader(dataset,
        num_replicas,
        rank,
        sample_shuffle,
        seed,
        batch_size,
        num_workers,
        pin_memory,
        drop_last):
    
    sampler = DistributedSampler(
        dataset,
        num_replicas=num_replicas,
        rank=rank,
        shuffle=sample_shuffle,
        seed=seed
    )
    loader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=False,
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last
    )
    return sampler, loader


In [45]:
dataset = load_dataset("imagenet-1k",cache_dir='./dataset/imagenet1k')
transform = get_transform(image_size=256)
train_dataset = ImageDataset(dataset['train'], transform)

batch_size=int(32)
train_sampler, train_loader = get_ddp_sampler_loader(dataset=train_dataset,
                            num_replicas=1,
                            rank=0,
                            sample_shuffle=True,
                            seed=0,
                            batch_size=batch_size,
                            num_workers=4,
                            pin_memory=True,
                            drop_last=True)

Found cached dataset imagenet-1k (/media/data/xiw136/workspace/local-refine-diffuser/dataset/imagenet1k/imagenet-1k/default/1.0.0/a1e9bfc56c3a7350165007d1176b15e9128fcaf9ab972147840529aed3ae52bc)
100%|██████████| 3/3 [00:00<00:00, 22.56it/s]


In [46]:
for batch in train_loader:
    print(batch[0].shape)
    break

torch.Size([32, 3, 256, 256])


In [6]:
batch[0].shape

NameError: name 'batch' is not defined

In [44]:
dataset[0]

(<PIL.Image.Image image mode=RGB size=1024x1024>, 0)

In [51]:
from torchvision.datasets import ImageFolder
class ImageDataset(Dataset):
    def __init__(self, dataset, transform=None):
        self.dataset = dataset
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        (image, label) = self.dataset[idx]
        image = self.transform(image)
        return image, label
    
transform = get_transform(image_size=256)
dataset = ImageFolder('./dataset/images/ffhq1k/')
train_dataset = ImageDataset(dataset, transform)


batch_size=int(32)
train_sampler, train_loader = get_ddp_sampler_loader(dataset=train_dataset,
                            num_replicas=4,
                            rank=1,
                            sample_shuffle=True,
                            seed=0,
                            batch_size=batch_size,
                            num_workers=4,
                            pin_memory=True,
                            drop_last=True)

In [52]:
for batch in train_loader:
    print(batch[0].shape)
    break

torch.Size([32, 3, 256, 256])


In [64]:
dataset.map()

Dataset({
    features: ['image', 'label'],
    num_rows: 1281167
})

In [62]:
transform = transforms.Compose([
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size=image_size)),
        transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, 256)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
    ])
    
dataset = load_dataset("imagenet-1k",cache_dir='./dataset/imagenet1k')['train']
sampler = DistributedSampler(
    dataset,
    num_replicas=4,
    rank=0,
    shuffle=True,
    seed=0 # 似乎应该是rank specific seed
)

loader = DataLoader(
    dataset,
    batch_size=int(32),
    shuffle=False,
    sampler=sampler,
    num_workers=4,
    pin_memory=True,
    drop_last=True
)

Found cached dataset imagenet-1k (/media/data/xiw136/workspace/local-refine-diffuser/dataset/imagenet1k/imagenet-1k/default/1.0.0/a1e9bfc56c3a7350165007d1176b15e9128fcaf9ab972147840529aed3ae52bc)
100%|██████████| 3/3 [00:00<00:00, 33.51it/s]


In [63]:
for batch in loader:
    print(batch)
    break

TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/AD/xiw136/anaconda3/envs/DiT/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 127, in collate
    return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/AD/xiw136/anaconda3/envs/DiT/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 127, in <dictcomp>
    return elem_type({key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem})
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/AD/xiw136/anaconda3/envs/DiT/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 150, in collate
    raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.JpegImagePlugin.JpegImageFile'>

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/AD/xiw136/anaconda3/envs/DiT/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 308, in _worker_loop
    data = fetcher.fetch(index)
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/AD/xiw136/anaconda3/envs/DiT/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/AD/xiw136/anaconda3/envs/DiT/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 265, in default_collate
    return collate(batch, collate_fn_map=default_collate_fn_map)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/AD/xiw136/anaconda3/envs/DiT/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 130, in collate
    return {key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/AD/xiw136/anaconda3/envs/DiT/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 130, in <dictcomp>
    return {key: collate([d[key] for d in batch], collate_fn_map=collate_fn_map) for key in elem}
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/AD/xiw136/anaconda3/envs/DiT/lib/python3.11/site-packages/torch/utils/data/_utils/collate.py", line 150, in collate
    raise TypeError(default_collate_err_msg_format.format(elem_type))
TypeError: default_collate: batch must contain tensors, numpy arrays, numbers, dicts or lists; found <class 'PIL.JpegImagePlugin.JpegImageFile'>
