In [4]:
# 1. Imports and Config
import torch
from torchvision import transforms
from datasets import load_dataset
from diffusers import DiTPipeline, DDPMScheduler
from accelerate import Accelerator
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import wandb
from piq import FID

In [5]:
# 2. Training Configuration
class TrainingConfig:
    """Hyperparameters and run settings for the fine-tuning notebook.

    Defaults are declared as class attributes. ``__init__`` copies them onto
    the instance so that ``vars(config)`` yields the complete settings dict;
    with class attributes alone, ``vars(TrainingConfig())`` is an empty dict
    and nothing would be recorded when it is passed on as a run config.

    Args:
        **overrides: optional per-run replacements for any default below,
            e.g. ``TrainingConfig(num_epochs=1)``.

    Raises:
        TypeError: if an override names a setting that is not declared here.
    """

    image_size = 256
    train_batch_size = 2
    learning_rate = 1e-4
    num_epochs = 10
    output_dir = "dit_celebahq_model"
    mixed_precision = "fp16"  # or "no" for float32
    wandb_project = "diffusion-dit-celebahq"

    def __init__(self, **overrides):
        # Collect the public, non-callable class attributes: these are the settings.
        defaults = {
            name: value
            for name, value in vars(type(self)).items()
            if not name.startswith("_") and not callable(value)
        }
        unknown = sorted(set(overrides) - set(defaults))
        if unknown:
            raise TypeError(f"Unknown TrainingConfig option(s): {', '.join(unknown)}")
        # Materialise every setting on the instance so vars(self) sees it.
        self.__dict__.update({**defaults, **overrides})

# Single shared settings object read by the cells that follow.
config = TrainingConfig()

# 3. Init W&B
# The run is filed under the configured project; the instance attribute dict
# of `config` (whatever vars() returns for it) is attached as the run config.
wandb.init(
    config=vars(config),
    project=config.wandb_project,
)

[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mrawat-dhruv14[0m ([33mrawat-dhruv[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [9]:
# 4. Load Dataset (CelebA-HQ stored locally or mounted)
# Only the "train" split of the local dataset directory is used here.
dataset = load_dataset(
    path="../data/",
    data_dir="celeba_hq_split/",
    split="train",
)

# 5. Preprocess Images
# Pipeline applied to every image: force a square of `config.image_size`,
# center-crop to the same size, convert to a tensor, then shift/scale each
# value with mean 0.5 and std 0.5 (i.e. x -> (x - 0.5) / 0.5).
transform = transforms.Compose(
    [
        transforms.Resize(size=(config.image_size, config.image_size)),
        transforms.CenterCrop(size=config.image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

def transform_fn(examples):
    """Turn a batch of raw dataset examples into preprocessed image tensors.

    Args:
        examples: mapping with an ``"image"`` entry holding an iterable of
            image objects that expose ``.convert("RGB")`` (presumably PIL
            images -- confirm against the dataset).

    Returns:
        dict with a single key ``"images"``: a list containing the result of
        the module-level ``transform`` applied to each RGB-converted image,
        in input order.
    """
    # Convert to RGB first so every image has three channels before `transform`.
    images = [transform(image.convert("RGB")) for image in examples["image"]]
    return {"images": images}

# Register the preprocessing callback on the dataset; examples then come back
# as {"images": ...} instead of the raw {"image": ...} records.
dataset.set_transform(transform_fn)

# Shuffled mini-batches of `config.train_batch_size` examples for training.
train_dataloader = DataLoader(
    dataset,
    shuffle=True,
    batch_size=config.train_batch_size,
)

Resolving data files:   0%|          | 0/2365 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/300 [00:00<?, ?it/s]

Downloading data:   0%|          | 0/2365 [00:00<?, ?files/s]

Downloading data:   0%|          | 0/300 [00:00<?, ?files/s]

Generating train split: 0 examples [00:00, ? examples/s]

Generating test split: 0 examples [00:00, ? examples/s]

In [11]:
# 6. Load Pretrained DiT-XL/2 (256x256) Pipeline
# NOTE: this checkpoint is the XL variant, not DiT-S; its weights are several
# GB, so make sure the Hugging Face cache directory has enough free space
# before running this cell.
scheduler = DDPMScheduler(num_train_timesteps=1000)
pipeline = DiTPipeline.from_pretrained(
    "facebook/DiT-XL-2-256",  # DiT-XL/2 checkpoint at 256x256 resolution
    # `subfolder="model"` was dropped: it is not a DiTPipeline component or
    # loading option, so at best it was ignored with a warning.
    scheduler=scheduler,  # use the DDPM scheduler above instead of the checkpoint's own
)

model_index.json:   0%|          | 0.00/35.8k [00:00<?, ?B/s]

Fetching 5 files:   0%|          | 0/5 [00:00<?, ?it/s]

diffusion_pytorch_model.bin:   0%|          | 0.00/335M [00:00<?, ?B/s]



config.json:   0%|          | 0.00/602 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/601 [00:00<?, ?B/s]

diffusion_pytorch_model.bin:   0%|          | 0.00/3.00G [00:00<?, ?B/s]

OSError: [Errno 28] No space left on device

In [None]:
# 7. Prepare for Training
accelerator = Accelerator(mixed_precision=config.mixed_precision)

# The trainable network is the DiT transformer held by the pipeline. Binding
# it here defines `model` (it was previously referenced without ever being
# assigned) and keeps `pipeline` sampling from the weights being trained.
model = pipeline.transformer

# The VAE only encodes/decodes latents; freeze it and place it on the
# training device so it can be used alongside the prepared transformer.
pipeline.vae.requires_grad_(False)
pipeline.vae.to(accelerator.device)

# Build the optimizer first and pass it through `prepare` together with the
# model: with fp16 mixed precision Accelerate scales the loss in
# `accelerator.backward`, and only a prepared optimizer unscales the
# gradients again before stepping.
optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)
model, optimizer, train_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader
)

fid_metric = FID().to(accelerator.device)

In [None]:
# 8. Training Loop
# DiT is a *latent*, *class-conditional* denoiser: it consumes VAE latents
# (not RGB pixels), requires `class_labels`, and -- for checkpoints with a
# learned variance -- emits twice as many channels as it receives. The loop
# therefore encodes images with the pipeline's VAE, conditions on a single
# "null" class, and trains on the noise-prediction half of the output.

NULL_CLASS_ID = 1000   # label DiTPipeline uses as its classifier-free-guidance "no class" token
NUM_FID_SAMPLES = 10   # generated (and real) images per FID estimate; far too few for a
                       # reliable score -- treat the logged value as a rough trend only

# Frozen VAE on the training device (both calls are idempotent).
vae = pipeline.vae.to(accelerator.device).eval()
vae.requires_grad_(False)
latent_scale = vae.config.scaling_factor

# Sampling/saving must use the bare module, not a distributed wrapper.
denoiser = accelerator.unwrap_model(model)
eval_pipeline = DiTPipeline(transformer=denoiser, vae=vae, scheduler=scheduler)

# piq's default Inception extractor expects images in [0, 1]; ToTensor gives
# exactly that for PIL images (no Normalize step here on purpose).
to_unit_range = transforms.ToTensor()

for epoch in range(config.num_epochs):
    # Re-enter train mode every epoch: the evaluation below switches to eval().
    model.train()
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch+1}")
    for step, batch in enumerate(progress_bar):
        clean_images = batch["images"].to(accelerator.device)

        # Encode pixels to scaled latents; no gradients flow into the VAE.
        with torch.no_grad():
            latents = vae.encode(clean_images).latent_dist.sample() * latent_scale

        noise = torch.randn_like(latents)
        timesteps = torch.randint(0, scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device).long()
        class_labels = torch.full_like(timesteps, NULL_CLASS_ID)

        noisy_latents = scheduler.add_noise(latents, noise, timesteps)
        model_output = model(noisy_latents, timestep=timesteps, class_labels=class_labels).sample
        # Keep only the noise prediction; any extra channels are the learned
        # variance, which this simple MSE objective does not train.
        noise_pred = model_output[:, : noise.shape[1]]

        loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float())
        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

        progress_bar.set_postfix({"loss": loss.item()})
        wandb.log({"train/loss": loss.item()})

    # Generate samples for FID after each epoch
    model.eval()
    fake_images = []
    real_images = []
    with torch.no_grad():
        for _ in range(NUM_FID_SAMPLES):
            with accelerator.autocast():
                # guidance_scale=1.0 disables classifier-free guidance, matching
                # the single-class training above.
                samples = eval_pipeline(
                    class_labels=[NULL_CLASS_ID],
                    guidance_scale=1.0,
                    num_inference_steps=50,
                ).images
            fake_images.extend(to_unit_range(img) for img in samples)

        for real_batch in train_dataloader:
            # Undo Normalize([0.5], [0.5]) so real images are in [0, 1] too;
            # extending with a (B, C, H, W) tensor adds B individual images.
            real_images.extend((real_batch["images"].cpu() * 0.5 + 0.5).clamp(0, 1))
            if len(real_images) >= len(fake_images):
                break
        real_images = real_images[:len(fake_images)]

        # piq's FID compares *feature* matrices, obtained via compute_feats from
        # loaders that yield {"images": batch} dicts.
        fake_loader = DataLoader([{"images": img} for img in fake_images], batch_size=config.train_batch_size)
        real_loader = DataLoader([{"images": img} for img in real_images], batch_size=config.train_batch_size)
        fake_feats = fid_metric.compute_feats(fake_loader, device=accelerator.device)
        real_feats = fid_metric.compute_feats(real_loader, device=accelerator.device)
        fid_score = fid_metric(fake_feats, real_feats).item()
    wandb.log({"eval/fid": fid_score})

    # Save model after each epoch
    if accelerator.is_main_process:
        denoiser.save_pretrained(f"{config.output_dir}_epoch{epoch+1}")

In [None]:
# 9. Inference Example
model.eval()

# Rebuild a DiT pipeline around the trained transformer. DDPMPipeline (never
# imported in this notebook) is a pixel-space UNet pipeline and cannot run a
# DiT transformer, which additionally needs the VAE decoder and class labels.
pipeline = DiTPipeline(
    transformer=accelerator.unwrap_model(model),
    vae=pipeline.vae,
    scheduler=scheduler,
).to("cuda")

with torch.autocast("cuda"):
    # 1000 is the "no class" label DiTPipeline uses for classifier-free
    # guidance; guidance_scale=1.0 turns guidance off.
    images = pipeline(
        class_labels=[1000],
        guidance_scale=1.0,
        num_inference_steps=50,
    ).images

# Leave the image as the cell's last expression so Jupyter renders it inline
# (PIL's .show() tries to launch an external viewer instead).
images[0]