In [None]:
# 1. Imports and Config
from dataclasses import asdict, dataclass

import torch
import wandb
from accelerate import Accelerator
from datasets import load_dataset
from diffusers import DDPMPipeline, DDPMScheduler, DiTModel
from piq import FID
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm

In [None]:
# 2. Training Configuration
@dataclass
class TrainingConfig:
    """Hyper-parameters and paths for the run; the whole object is logged to W&B as the run config."""

    image_size: int = 256  # side length (px) that training images are resized/cropped to
    train_batch_size: int = 2
    learning_rate: float = 1e-4
    num_epochs: int = 10
    output_dir: str = "dit_celebahq_model"  # per-epoch checkpoints are saved under f"{output_dir}_epoch{N}"
    mixed_precision: str = "fp16"  # or "no" for float32
    wandb_project: str = "diffusion-dit-celebahq"
    seed: int = 42  # seeds torch's global RNG so runs are reproducible


config = TrainingConfig()
torch.manual_seed(config.seed)

# 3. Init W&B
# asdict() rather than vars(): vars() only sees the instance __dict__, which would be empty for
# plain class attributes, so W&B would have recorded no hyper-parameters at all.
wandb.init(project=config.wandb_project, config=asdict(config))

In [None]:
# 4. Load Dataset (CelebA-HQ stored locally or mounted)
dataset = load_dataset("imagefolder", data_dir="celeba_hq_256", split="train")

# 5. Preprocess Images
# Resize the shorter side, then center-crop to a square: this keeps the aspect ratio of
# non-square inputs. Resizing to an explicit (H, W) tuple would stretch the image and turn
# the CenterCrop into a no-op.
transform = transforms.Compose([
    transforms.Resize(config.image_size),
    transforms.CenterCrop(config.image_size),
    transforms.ToTensor(),               # PIL image -> float tensor (C, H, W) in [0, 1]
    transforms.Normalize([0.5], [0.5]),  # (x - 0.5) / 0.5: [0, 1] -> [-1, 1], broadcast over channels
])


def transform_fn(examples):
    """On-the-fly batch transform used by `dataset.set_transform`.

    Converts every PIL image in `examples["image"]` to RGB and applies `transform`.
    Returns a dict whose "images" key holds the list of normalized (3, H, W) tensors.
    """
    images = [transform(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}


dataset.set_transform(transform_fn)

train_dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True)

In [None]:
# 6. Load Pretrained DiT-S Model
# NOTE(review): I cannot find `DiTModel` in the public diffusers API. DiT weights are normally loaded
# via `DiTTransformer2DModel` (or `Transformer2DModel`) from subfolder="transformer", and the published
# checkpoints I know of are "facebook/DiT-XL-2-256" / "facebook/DiT-XL-2-512". Confirm the class name,
# repo id and subfolder against the installed diffusers version.
model = DiTModel.from_pretrained("facebook/DiT-S-256", subfolder="model")
# DDPM noise schedule with 1000 training timesteps. This same object is also used by the training loop
# (add_noise) and by the sampling pipeline below.
scheduler = DDPMScheduler(num_train_timesteps=1000)
# Sampling pipeline built from the model *before* accelerator.prepare() runs.
# NOTE(review): DDPMPipeline presumably calls the model as `unet(sample, t).sample` on pixel-space
# tensors. Published DiT checkpoints are, as far as I know, class-conditional models on VAE latents
# (they expect `class_labels` and predict extra variance channels). Verify this model is compatible
# with this pipeline and with the pixel-space MSE training loss.
pipeline = DDPMPipeline(unet=model, scheduler=scheduler)

In [None]:
# 7. Prepare for Training
accelerator = Accelerator(mixed_precision=config.mixed_precision)
# Create the optimizer first and hand it to prepare() together with the model and dataloader.
# The prepared optimizer unscales fp16 gradients (GradScaler) before each step and updates the
# loss scale. An unprepared optimizer would instead apply the loss-scaled gradients produced by
# accelerator.backward() directly.
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
model, optimizer, train_dataloader = accelerator.prepare(model, optimizer, train_dataloader)
fid_metric = FID().to(accelerator.device)

In [None]:
# 8. Training Loop
NUM_FID_SAMPLES = 10  # generated (and real) images per FID evaluation; far too few for a stable score
FID_INFERENCE_STEPS = 50  # denoising steps used when sampling images for FID


def _to_unit_range(images):
    """Map tensors normalized to [-1, 1] by `transform` back to [0, 1], the input range of piq's feature extractor."""
    return ((images + 1) / 2).clamp(0, 1)


def _sample_fake_images(num_samples, batch_size):
    """Sample `num_samples` images from the current weights as an (N, C, H, W) float tensor in [0, 1]."""
    eval_pipeline = DDPMPipeline(
        unet=accelerator.unwrap_model(model),
        # Private scheduler copy, so that set_timesteps() during sampling never mutates the training scheduler.
        scheduler=DDPMScheduler.from_config(scheduler.config),
    )
    eval_pipeline.set_progress_bar_config(disable=True)
    chunks = []
    for start in range(0, num_samples, batch_size):
        chunk_size = min(batch_size, num_samples - start)
        with accelerator.autocast():
            # output_type="np" returns an (N, H, W, C) float array already scaled to [0, 1].
            samples = eval_pipeline(
                batch_size=chunk_size, num_inference_steps=FID_INFERENCE_STEPS, output_type="np"
            ).images
        chunks.append(torch.from_numpy(samples).permute(0, 3, 1, 2).float())
    return torch.cat(chunks)


def _collect_real_images(num_samples):
    """Return up to `num_samples` training images as an (N, C, H, W) tensor in [0, 1]."""
    batches, collected = [], 0
    for real_batch in train_dataloader:
        batches.append(real_batch["images"])
        collected += real_batch["images"].shape[0]
        if collected >= num_samples:
            break
    return _to_unit_range(torch.cat(batches)[:num_samples])


def _compute_fid(num_samples):
    """FID between `num_samples` generated images and (up to) as many real training images."""
    fake_images = _sample_fake_images(num_samples, config.train_batch_size)
    real_images = _collect_real_images(num_samples)
    # piq's FID.forward compares feature matrices, not raw images, so extract Inception features first.
    # Both sets go through one compute_feats call, so the feature extractor is built only once.
    features = fid_metric.compute_feats(
        [{"images": fake_images}, {"images": real_images}], device=accelerator.device
    )
    n_fake = fake_images.shape[0]
    return fid_metric(features[:n_fake], features[n_fake:]).item()


for epoch in range(config.num_epochs):
    # Re-enter training mode every epoch: the FID evaluation at the end of each epoch calls model.eval().
    model.train()
    progress_bar = tqdm(
        train_dataloader, desc=f"Epoch {epoch+1}", disable=not accelerator.is_local_main_process
    )
    for batch in progress_bar:
        clean_images = batch["images"].to(accelerator.device)
        noise = torch.randn_like(clean_images)
        timesteps = torch.randint(
            0, scheduler.config.num_train_timesteps, (clean_images.shape[0],), device=clean_images.device
        ).long()

        noisy_images = scheduler.add_noise(clean_images, noise, timesteps)
        noise_pred = model(noisy_images, timesteps).sample

        loss = torch.nn.functional.mse_loss(noise_pred, noise)
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

        loss_value = loss.item()
        progress_bar.set_postfix({"loss": loss_value})
        if accelerator.is_main_process:
            wandb.log({"train/loss": loss_value})

    # Generate samples for FID after each epoch. This runs on every process (iterating the prepared
    # dataloader may involve cross-process synchronization); only the main process logs the score.
    model.eval()
    with torch.no_grad():
        fid_score = _compute_fid(NUM_FID_SAMPLES)
    if accelerator.is_main_process:
        wandb.log({"eval/fid": fid_score})

    # Save model after each epoch (unwrapped, so a DDP/AMP wrapper is not serialized)
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.unwrap_model(model).save_pretrained(f"{config.output_dir}_epoch{epoch+1}")

In [None]:
# 9. Inference Example
model.eval()
# Unwrap in case prepare() wrapped the model (e.g. DDP), and sample on the accelerator's device
# instead of assuming CUDA is available.
pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=scheduler).to(accelerator.device)
# accelerator.autocast() follows config.mixed_precision (a no-op when it is "no").
with accelerator.autocast():
    images = pipeline(num_inference_steps=50).images

images[0]  # last expression: renders the PIL image inline (Image.show() would open an external viewer)