In [1]:
import argparse
import json
import os
from glob import glob
import numpy as np

from PIL import Image
import torch
from pathlib import Path
from diffusers.models import AutoencoderKL

from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import CIFAR10
from tqdm import tqdm

from opendit.diffusion import create_diffusion
from opendit.models.mmdit import MMDiT_models
from opendit.utils.data_utils import get_transforms_image

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

import pandas as pd

from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class ImageCaptionDataset(Dataset):
    def __init__(self, csv_path, root_dir, transform=None):
        """
        Args:
            csv_path (string): Path to the CSV file with annotations
            root_dir (string): Base directory for image paths in CSV
            transform (callable, optional): Optional transform to be applied on images
        """
        self.df = pd.read_csv(csv_path)
        self.root_dir = root_dir
        self.transform = transform

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        
        img_rel_path = Path(row['File Path'].replace("\\", "/")) 
        img_full_path = self.root_dir / img_rel_path

        try:
            if not img_full_path.exists():
                raise FileNotFoundError(f"Image not found at: {img_full_path}")
            
        except FileNotFoundError:
            print(FileNotFoundError)
        image = Image.open(img_full_path).convert('RGB')
        
        caption = row['Caption']
        
        if self.transform:
            image = self.transform(image)

        return image, caption

In [3]:
def center_crop_arr(pil_image, image_size):


    while min(*pil_image.size) >= 2 * image_size:
        pil_image = pil_image.resize(
            tuple(x // 2 for x in pil_image.size), resample=Image.BOX
        )

    scale = image_size / min(*pil_image.size)
    pil_image = pil_image.resize(
        tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
    )

    arr = np.array(pil_image)
    crop_y = (arr.shape[0] - image_size) // 2
    crop_x = (arr.shape[1] - image_size) // 2
    return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])

In [4]:
def requires_grad(model, flag=True):
    """Enable/disable gradients for a model's parameters."""
    for p in model.parameters():
        p.requires_grad = flag

In [5]:
def update_ema(ema, model, decay=0.9999):
    """Update EMA parameters."""
    with torch.no_grad():
        for ema_param, model_param in zip(ema.parameters(), model.parameters()):
            ema_param.data.mul_(decay).add_(model_param.data, alpha=1 - decay)


In [None]:

"""Trains a new MMDiT model."""
assert torch.cuda.is_available(), "Training currently requires at least one GPU."

# Setup directories
os.makedirs(args.outputs, exist_ok=True)
experiment_index = len(glob(f"{args.outputs}/*"))
model_string_name = args.model.replace("/", "-")
experiment_dir = f"{args.outputs}/{experiment_index:03d}-{model_string_name}"
os.makedirs(experiment_dir, exist_ok=True)

# Save configuration
with open(f"{experiment_dir}/config.txt", "w") as f:
    json.dump(args.__dict__, f, indent=4)

# Setup tensorboard
tensorboard_dir = f"{experiment_dir}/tensorboard"
os.makedirs(tensorboard_dir, exist_ok=True)
writer = SummaryWriter(tensorboard_dir)

# Setup device and dtype
device = torch.device('cuda')
if args.mixed_precision == "bf16":
    dtype = torch.bfloat16
elif args.mixed_precision == "fp16":
    dtype = torch.float16
else:
    dtype = torch.float32

# Create VAE encoder
vae = AutoencoderKL.from_pretrained(f"stabilityai/sd-vae-ft-{args.vae}").to(device)

# Configure input size
assert args.image_size % 8 == 0, "Image size must be divisible by 8 (for the VAE encoder)."
input_size = args.image_size // 8

# Create model
model_config = {
    "input_size": input_size,
    "num_classes": args.num_classes,
    "clip_text_encoder": args.text_encoder,
    "t5_text_encoder": args.t5_text_encoder,
}

# Initialize model
model_class = MMDiT_models[args.model]
model = model_class(**model_config).to(device, dtype=dtype)

print(f"Model has {sum(p.numel() for p in model.parameters()):,} parameters")


if args.grad_checkpoint:
    model.enable_gradient_checkpointing()

# Create EMA model
ema = MMDiT_models[args.model](**model_config).to(device)
ema.load_state_dict(model.state_dict())
requires_grad(ema, False)

# Create diffusion
diffusion = create_diffusion(timestep_respacing="")

# Setup optimizer
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=args.lr,
    weight_decay=0
)
# Setup dataset
# dataset = CIFAR10(
#     args.data_path,
#     transform=get_transforms_image(args.image_size),
#     download=True
# )

transform = transforms.Compose([
    transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, args.image_size)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])

csv_path = "datasets/anime/image_labels.csv"
root_dir = "datasets/anime"  

dataset = ImageCaptionDataset(
csv_path=csv_path,
root_dir=root_dir,
transform=transform
)

dataloader = DataLoader(
    dataset,
    batch_size=args.batch_size,
    shuffle=True,
    drop_last=True,
    pin_memory=True,

)

print(f"Dataset contains {len(dataset):,} images ({args.data_path})")

# Ensure EMA is initialized with synced weights
update_ema(ema, model, decay=0)
model.train()
ema.eval()

print(f"Training for {args.epochs} epochs...")
num_steps_per_epoch = len(dataloader)
global_step = 0

for epoch in range(args.epochs):
    print(f"Beginning epoch {epoch}...")
    
    with tqdm(range(num_steps_per_epoch), desc=f"Epoch {epoch}") as pbar:
        for step in pbar:
            # Get batch
            x, y = next(iter(dataloader))
            x = x.to(device)

            # VAE encode
            with torch.no_grad():
                x = vae.encode(x).latent_dist.sample().mul_(0.18215)

            # print('vae encode:', x.shape)

            # Diffusion training step
            t = torch.randint(0, diffusion.num_timesteps, (x.shape[0],), device=device)
            model_kwargs = dict(c=y)
            loss_dict = diffusion.training_losses(model, x, t, model_kwargs)
            loss = loss_dict["loss"].mean()
            
            # Optimization step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Update EMA
            update_ema(ema, model)

            # Logging
            global_step = epoch * num_steps_per_epoch + step
            pbar.set_postfix({"loss": loss.item(), "step": step, "global_step": global_step})

            if (global_step + 1) % args.log_every == 0:
                writer.add_scalar("loss", loss.item(), global_step)

            # Save checkpoint
            if args.ckpt_every > 0 and (global_step + 1) % args.ckpt_every == 0:
                checkpoint = {
                    'model': model.state_dict(),
                    'ema': ema.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch,
                    'global_step': global_step,
                }
                torch.save(
                    checkpoint,
                    f"{experiment_dir}/checkpoint_{global_step:07d}.pt"
                )
                print(f"Saved checkpoint at global step {global_step}")

print("Training finished!")


In [13]:
import argparse

class Config:
    def __init__(self):
        self.parser = argparse.ArgumentParser(description="Training configuration for MMDiT model")
        self._add_arguments()

    def _add_arguments(self):
        """Add all configuration arguments to the parser."""
        self.parser.add_argument("--model", type=str, choices=["MMDiT-XL/2", "MMDiT-L/4"], default="MMDiT-S/8",
                                help="Model architecture to use")
        self.parser.add_argument("--vae", type=str, choices=["ema", "mse"], default="ema",
                                help="VAE type to use for encoding")
        self.parser.add_argument("--outputs", type=str, default="./outputs",
                                help="Directory to save outputs (checkpoints, logs, etc.)")
        self.parser.add_argument("--data_path", type=str, default="./datasets",
                                help="Path to the dataset directory")
        self.parser.add_argument("--image_size", type=int, choices=[256, 512], default=256,
                                help="Size of input images (must be divisible by 8)")
        self.parser.add_argument("--num_classes", type=int, default=1000,
                                help="Number of classes for classification (if applicable)")
        self.parser.add_argument("--epochs", type=int, default=1400,
                                help="Number of training epochs")
        self.parser.add_argument("--batch_size", type=int, default=32,
                                help="Batch size for training")
        self.parser.add_argument("--num_workers", type=int, default=4,
                                help="Number of workers for data loading")
        self.parser.add_argument("--log_every", type=int, default=10,
                                help="Log training metrics every N steps")
        self.parser.add_argument("--ckpt_every", type=int, default=1000,
                                help="Save a checkpoint every N steps")
        self.parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["bf16", "fp16", "fp32"],
                                help="Mixed precision training mode")
        self.parser.add_argument("--lr", type=float, default=1e-4,
                                help="Learning rate for the optimizer")
        self.parser.add_argument("--grad_checkpoint", action="store_true",
                                help="Enable gradient checkpointing to save memory")
        self.parser.add_argument("--text_encoder", type=str, default="openai/clip-vit-base-patch32",
                                help="Text encoder model for CLIP")
        self.parser.add_argument("--t5_text_encoder", type=str, default="google-t5/t5-small",
                                help="Text encoder model for T5")
        self.parser.add_argument("--chkptnumber", type=str, default="None",
                                help="Checkpoint number to resume training from")

    def parse_args(self, args_list=None):
        """
        Parse and return the arguments.
        Args:
            args_list (list): List of arguments to parse (for notebook usage).
        """
        if args_list is None:
            # Parse from command line
            return self.parser.parse_args()
        else:
            # Parse from a list (for notebook usage)
            return self.parser.parse_args(args_list)

# Usage in Jupyter Notebook
if __name__ == "__main__":
    config = Config()
    args = config.parse_args()
    print(args)

usage: ipykernel_launcher.py [-h] [--model {MMDiT-XL/2,MMDiT-L/4}]
                             [--vae {ema,mse}] [--outputs OUTPUTS]
                             [--data_path DATA_PATH] [--image_size {256,512}]
                             [--num_classes NUM_CLASSES] [--epochs EPOCHS]
                             [--batch_size BATCH_SIZE]
                             [--num_workers NUM_WORKERS]
                             [--log_every LOG_EVERY] [--ckpt_every CKPT_EVERY]
                             [--mixed_precision {bf16,fp16,fp32}] [--lr LR]
                             [--grad_checkpoint] [--text_encoder TEXT_ENCODER]
                             [--t5_text_encoder T5_TEXT_ENCODER]
                             [--chkptnumber CHKPTNUMBER]
ipykernel_launcher.py: error: unrecognized arguments: --f=/home/zyro/.local/share/jupyter/runtime/kernel-v3ec2d9e305a0ccd0b3fe1f6b84ec5ba37bd8a8e05.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


In [1]:
pip install datasets

Note: you may need to restart the kernel to use updated packages.




Defaulting to user installation because normal site-packages is not writeable
Collecting datasets
  Downloading datasets-3.2.0-py3-none-any.whl.metadata (20 kB)
Collecting pyarrow>=15.0.0 (from datasets)
  Downloading pyarrow-19.0.0-cp312-cp312-win_amd64.whl.metadata (3.4 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp312-cp312-win_amd64.whl.metadata (13 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py312-none-any.whl.metadata (7.2 kB)
Collecting huggingface-hub>=0.23.0 (from datasets)
  Downloading huggingface_hub-0.27.1-py3-none-any.whl.metadata (13 kB)
Downloading datasets-3.2.0-py3-none-any.whl (480 kB)
   ---------------------------------------- 0.0/480.6 kB ? eta -:--:--
   - -------------------------------------- 20.5/480.6 kB ? eta -:--:--
   ---- ---------------------------------- 51.2/480.6 kB 871.5 kB/s eta 0:00:01
   ------ -------------------------------- 81.9/480.6 kB 762.6 kB/s eta 0:00:01
   ---------- -------

In [2]:
help(load_dataset)

Help on function load_dataset in module datasets.load:

load_dataset(path: str, name: Optional[str] = None, data_dir: Optional[str] = None, data_files: Union[str, Sequence[str], Mapping[str, Union[str, Sequence[str]]], NoneType] = None, split: Union[str, datasets.splits.Split, NoneType] = None, cache_dir: Optional[str] = None, features: Optional[datasets.features.features.Features] = None, download_config: Optional[datasets.download.download_config.DownloadConfig] = None, download_mode: Union[datasets.download.download_manager.DownloadMode, str, NoneType] = None, verification_mode: Union[datasets.utils.info_utils.VerificationMode, str, NoneType] = None, keep_in_memory: Optional[bool] = None, save_infos: bool = False, revision: Union[str, datasets.utils.version.Version, NoneType] = None, token: Union[bool, str, NoneType] = None, streaming: bool = False, num_proc: Optional[int] = None, storage_options: Optional[Dict] = None, trust_remote_code: bool = None, **config_kwargs) -> Union[datas

In [None]:
from datasets import load_dataset

ds = load_dataset("Spawning/PD12M" )

README.md:   0%|          | 0.00/3.62k [00:00<?, ?B/s]

Resolving data files:   0%|          | 0/126 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/125 [00:00<?, ?files/s]

pd12m.013.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.000.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.007.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.011.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.014.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.008.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.015.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.001.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.010.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.009.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.004.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.006.parquet:   0%|          | 0.00/18.8M [00:00<?, ?B/s]

pd12m.002.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.005.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.012.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.003.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.016.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.017.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.019.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.018.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.021.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.020.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.022.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.023.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.025.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.024.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.026.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.027.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.028.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.029.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.030.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.031.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.032.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

Error while downloading from https://cdn-lfs-us-1.hf.co/repos/bb/d1/bbd10a40cbf91b41e63bac21b8ca45db58c9dc8029fa3012009760dc9472b148/0cccae3d3864b3f81d4abb5d0bd1712447454db9129b7daba7f9e4be23c8f257?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27pd12m.032.parquet%3B+filename%3D%22pd12m.032.parquet%22%3B&Expires=1738039830&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczODAzOTgzMH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmhmLmNvL3JlcG9zL2JiL2QxL2JiZDEwYTQwY2JmOTFiNDFlNjNiYWMyMWI4Y2E0NWRiNThjOWRjODAyOWZhMzAxMjAwOTc2MGRjOTQ3MmIxNDgvMGNjY2FlM2QzODY0YjNmODFkNGFiYjVkMGJkMTcxMjQ0NzQ1NGRiOTEyOWI3ZGFiYTdmOWU0YmUyM2M4ZjI1Nz9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=j3TsVPtrwn6z1SqfhtV9Evb046027mgD6TugtxOviKUSbCAkhV4G6vR7ThVeo1yHuHSaIpCQ6tD8xA%7EpHv9dOL2vJX5zqSBcBtXADCmkMneXw4wacaGpjkX%7Eabrq5Yboc8zKf6CBqmVgWfpd6Ufp6A1XWs1zUgKrAV7KvY75GAxBzIwybruYY2-FcFI9mhHrEZmi0c-tzJAhWXMiS%7EI9jmtl8rQpp9BQXgdO5qrl-AHE0kRM4i6LDOoZZUzuzD

pd12m.028.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.027.parquet:  56%|#####5    | 10.5M/18.9M [00:00<?, ?B/s]

pd12m.032.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.031.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.018.parquet:  56%|#####5    | 10.5M/18.9M [00:00<?, ?B/s]

pd12m.025.parquet:  56%|#####5    | 10.5M/18.9M [00:00<?, ?B/s]

pd12m.016.parquet:  56%|#####5    | 10.5M/18.9M [00:00<?, ?B/s]

pd12m.023.parquet:  56%|#####5    | 10.5M/18.9M [00:00<?, ?B/s]

pd12m.019.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

Error while downloading from https://cdn-lfs-us-1.hf.co/repos/bb/d1/bbd10a40cbf91b41e63bac21b8ca45db58c9dc8029fa3012009760dc9472b148/f73ecec628c26af28c53f250f02232fdcfa318dce60b39e644fc175e5dd04a21?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27pd12m.031.parquet%3B+filename%3D%22pd12m.031.parquet%22%3B&Expires=1738039780&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczODAzOTc4MH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmhmLmNvL3JlcG9zL2JiL2QxL2JiZDEwYTQwY2JmOTFiNDFlNjNiYWMyMWI4Y2E0NWRiNThjOWRjODAyOWZhMzAxMjAwOTc2MGRjOTQ3MmIxNDgvZjczZWNlYzYyOGMyNmFmMjhjNTNmMjUwZjAyMjMyZmRjZmEzMThkY2U2MGIzOWU2NDRmYzE3NWU1ZGQwNGEyMT9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=mKLqsHXbAjUlYNshjN0U0YIWv2WbLl3EwlHtOWPoNRl0zQn9JI4KfZsXB%7EFUBK2J2MYZgAV3tNFHvPpZPH0JzoocHIrR5U36jBBex86s4LZtWVWeShKA7aaM%7EHzqql%7EZIk9JnqmU3e%7EwcZ9VgMpAgYMTojJrudlGSQXWqnnV6uNsg0f1cUnQCTGO4NlQdyCULtmHit1x3GKbRqqRQ7g3%7ECiATKosh1IZPN6FeRZphCtC67wb3QavOdKFQQ

pd12m.031.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.023.parquet:  56%|#####5    | 10.5M/18.9M [00:00<?, ?B/s]

pd12m.032.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.021.parquet:  55%|#####5    | 10.5M/18.9M [00:00<?, ?B/s]

pd12m.019.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.020.parquet:  56%|#####5    | 10.5M/18.9M [00:00<?, ?B/s]

pd12m.030.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.018.parquet:  56%|#####5    | 10.5M/18.9M [00:00<?, ?B/s]

pd12m.016.parquet:  56%|#####5    | 10.5M/18.9M [00:00<?, ?B/s]

pd12m.024.parquet:  56%|#####5    | 10.5M/18.9M [00:00<?, ?B/s]

pd12m.026.parquet:  55%|#####5    | 10.5M/18.9M [00:00<?, ?B/s]

pd12m.029.parquet:  55%|#####5    | 10.5M/18.9M [00:00<?, ?B/s]

pd12m.022.parquet:  56%|#####5    | 10.5M/18.9M [00:00<?, ?B/s]

pd12m.028.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.025.parquet:  56%|#####5    | 10.5M/18.9M [00:00<?, ?B/s]

pd12m.027.parquet:  56%|#####5    | 10.5M/18.9M [00:00<?, ?B/s]

pd12m.038.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.033.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.037.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.034.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.036.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.035.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.039.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.042.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.040.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.041.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.043.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.044.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.045.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.048.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.046.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.047.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.050.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.049.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.051.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.052.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.053.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.054.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.055.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.056.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.057.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.058.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

Error while downloading from https://cdn-lfs-us-1.hf.co/repos/bb/d1/bbd10a40cbf91b41e63bac21b8ca45db58c9dc8029fa3012009760dc9472b148/d4b420b7db4ad826e34ba1451e7ead728c3701725559a9dda050a17938e3cfcd?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27pd12m.049.parquet%3B+filename%3D%22pd12m.049.parquet%22%3B&Expires=1738040071&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTczODA0MDA3MX19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmhmLmNvL3JlcG9zL2JiL2QxL2JiZDEwYTQwY2JmOTFiNDFlNjNiYWMyMWI4Y2E0NWRiNThjOWRjODAyOWZhMzAxMjAwOTc2MGRjOTQ3MmIxNDgvZDRiNDIwYjdkYjRhZDgyNmUzNGJhMTQ1MWU3ZWFkNzI4YzM3MDE3MjU1NTlhOWRkYTA1MGExNzkzOGUzY2ZjZD9yZXNwb25zZS1jb250ZW50LWRpc3Bvc2l0aW9uPSoifV19&Signature=ceBVxG7MTgo8T98cW-DeFUJEYwUEclj9nYrSXbwNyWGgEtc%7EuQRWawgTLKjiQ7MCYzp0o9r9UE1DmLDEic7wWPtMY5U9BMZQzN3crFDCCbCpz8KikomKFkPOIgKs6wQ2MjFLr%7EZxI9InibRulnbyMMlOuwSqvvKf41ZOOFaR819LjPPa8PfY6UGFgTjHo0r8BGAQZWmgvGtAkPBI3hFuANqZfSJpYNtB92JaS617nQhNoJ%7EAHtR581jWZ1LSCD

pd12m.053.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.057.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.049.parquet:  56%|#####5    | 10.5M/18.9M [00:00<?, ?B/s]

pd12m.059.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.060.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.062.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.061.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.063.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.064.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.067.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.066.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.065.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.069.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.068.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.072.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.071.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.070.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.073.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.074.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.075.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.076.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.077.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.078.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.079.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.080.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.081.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.082.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.083.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.084.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.085.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.086.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.087.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.088.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.089.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.090.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.092.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.093.parquet:   0%|          | 0.00/18.8M [00:00<?, ?B/s]

pd12m.091.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.094.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.095.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.096.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.097.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.098.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.099.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.100.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.101.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.102.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.103.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.104.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.105.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.106.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.107.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.108.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.109.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.110.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.111.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.112.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.113.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.114.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.115.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.116.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.117.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.118.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.119.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.120.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.121.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.122.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.123.parquet:   0%|          | 0.00/18.9M [00:00<?, ?B/s]

pd12m.124.parquet:   0%|          | 0.00/29.5k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/12400094 [00:00<?, ? examples/s]

DatasetGenerationError: An error occurred while generating the dataset

In [None]:
from torch.utils.data import Dataset, DataLoader

: 

In [3]:
class SpawningPD12Dataset(Dataset):
    def __init__(self, transform=None, max_images=60000):
        """
        Args:
            transform (callable, optional): Optional transform to be applied on images
            max_images (int): Maximum number of images to include in the dataset
        """
        self.dataset = load_dataset("spawning/pd12", split="train")
        self.transform = transform
        self.max_images = min(max_images, len(self.dataset))
        
    def __len__(self):
        return self.max_images

    def __getitem__(self, idx):
        item = self.dataset[idx]
        
        # Load and convert image
        image = Image.open(item['image'].convert('RGB'))
        caption = item['text']
        
        if self.transform:
            image = self.transform(image)

        return image, caption

NameError: name 'Dataset' is not defined

In [None]:
dataset = SpawningPD12Dataset(transform=transform, max_images=60000)