In [None]:
# 1. Imports
import torch
from torchvision import transforms
from datasets import load_dataset
from diffusers import DiTModel, DDPMPipeline, DDPMScheduler
from accelerate import Accelerator
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb
from piq import FID

In [None]:
# 2. Training Configuration
class TrainingConfig:
    """Hyperparameters and run settings for this notebook.

    All values are class attributes; instances carry no attributes of their
    own, so ``vars(TrainingConfig())`` is an empty dict.
    """

    # Data
    image_size: int = 256
    train_batch_size: int = 2

    # Optimisation
    learning_rate: float = 1e-4
    num_epochs: int = 10
    mixed_precision: str = "fp16"  # or "no" for float32

    # Output location and experiment tracking
    output_dir: str = "dit_celebahq_model"
    wandb_project: str = "diffusion-dit-celebahq"

config = TrainingConfig()

# 3. Init W&B
# The settings are class attributes, so vars(config) would be an empty dict
# and no hyperparameters would be logged. Collect the public, non-callable
# attributes explicitly instead (this also picks up per-instance overrides).
wandb_config = {
    name: getattr(config, name)
    for name in dir(config)
    if not name.startswith("_") and not callable(getattr(config, name))
}
wandb.init(project=config.wandb_project, config=wandb_config)

In [None]:
# 4. Load Dataset (CelebA-HQ stored locally or mounted)
dataset = load_dataset("imagefolder", data_dir="celeba_hq_256", split="train")

# 5. Preprocess Images
# Resize with a single int scales the shorter edge to image_size and keeps the
# aspect ratio; CenterCrop then cuts the square. Passing an (h, w) tuple to
# Resize would squash non-square images and leave CenterCrop with nothing to do.
# Square inputs are unaffected by this difference.
transform = transforms.Compose([
    transforms.Resize(config.image_size),
    transforms.CenterCrop(config.image_size),
    transforms.ToTensor(),
    # Maps ToTensor's [0, 1] output to [-1, 1] via (x - 0.5) / 0.5.
    transforms.Normalize([0.5], [0.5]),
])

def transform_fn(examples):
    """Preprocess a batch of examples with the module-level ``transform``.

    Args:
        examples: Mapping whose ``"image"`` entry is an iterable of image
            objects exposing ``.convert`` (presumably PIL images -- confirm
            against the dataset).

    Returns:
        A dict with a single key ``"images"`` holding the transformed images
        in input order.
    """
    processed = []
    for raw_image in examples["image"]:
        rgb_image = raw_image.convert("RGB")
        processed.append(transform(rgb_image))
    return {"images": processed}

# Register transform_fn as the dataset's transform so batches expose "images".
dataset.set_transform(transform_fn)

# Shuffled loader over the training split, batch size taken from the config.
train_dataloader = DataLoader(
    dataset=dataset,
    batch_size=config.train_batch_size,
    shuffle=True,
)

In [None]:
# 6. Load Pretrained DiT-S Model
# NOTE(review): confirm that the installed diffusers version exports `DiTModel`
# and that the hub repo "facebook/DiT-S-256" with a "model" subfolder exists --
# neither can be verified from this file. TODO confirm before running.
model = DiTModel.from_pretrained("facebook/DiT-S-256", subfolder="model")
# Noise scheduler configured with 1000 training timesteps.
scheduler = DDPMScheduler(num_train_timesteps=1000)
# NOTE(review): the model is passed through DDPMPipeline's `unet` argument;
# verify that a DiT model's forward signature/outputs are compatible with what
# DDPMPipeline expects from its unet -- TODO confirm.
pipeline = DDPMPipeline(unet=model, scheduler=scheduler)