In [None]:
# clone the repo
!git clone https://github.com/Sarthak16082/DDT
# cd into DDT
%cd DDT

!pip install lightning==2.5.0 torch torchvision torchaudio pyyaml diffusers timm


!wget https://huggingface.co/MCG-NJU/DDT-XL-22en6de-R512/resolve/main/model.ckpt

In [None]:
import os
import re
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from omegaconf import OmegaConf

from src.diffusion.base.guidance import simple_guidance_fn
from src.diffusion.stateful_flow_matching.scheduling import LinearScheduler
from src.diffusion.stateful_flow_matching.sampling import EulerSampler

# Helper Functions

def instantiate_from_config(config):
    """Instantiate the class named by ``config["class_path"]``.

    Args:
        config: Mapping (a dict or an OmegaConf ``DictConfig`` node) with a
            dotted ``"class_path"`` such as ``"package.module.ClassName"`` and an
            optional ``"init_args"`` mapping of constructor keyword arguments.

    Returns:
        A new instance of the named class, constructed with ``**init_args``.
    """
    import importlib

    module_path, class_name = config["class_path"].rsplit(".", 1)
    module = importlib.import_module(module_path)
    # A YAML key written as ``init_args:`` with no value loads as None; treat it
    # like a missing key instead of failing on ``**None``.
    init_args = config.get("init_args") or {}
    return getattr(module, class_name)(**init_args)

def load_weights(model, checkpoint, prefix="ema_denoiser."):
    """Copy matching tensors from ``checkpoint["state_dict"]`` into ``model`` in place.

    Every entry of ``model.state_dict()`` is looked up in the checkpoint under
    ``prefix + name``. Keys absent from the checkpoint, and entries whose copy
    raises (e.g. a shape mismatch), are reported and skipped instead of aborting.

    Args:
        model: module whose parameters/buffers are overwritten.
        checkpoint: mapping with a ``"state_dict"`` entry (e.g. from ``torch.load``).
        prefix: key prefix under which the model's weights are stored in the checkpoint.

    Returns:
        The same ``model`` object, so the call can be chained.
    """
    source = checkpoint["state_dict"]
    target = model.state_dict()
    n_copied = 0
    for key, tensor in target.items():
        ckpt_key = prefix + key
        if ckpt_key not in source:
            print(f"Missing key in checkpoint: {ckpt_key}")
            continue
        try:
            # state_dict() entries share storage with the module, so the in-place
            # copy updates the model's own weights.
            tensor.copy_(source[ckpt_key])
        except Exception as err:
            print(f"Failed to load {ckpt_key}: {err}")
        else:
            n_copied += 1
    print(f"Loaded {n_copied}/{len(target)} weights.")
    return model

def tensor_to_image(x):
    """Map a tensor with values nominally in [-1, 1] to uint8 pixel values in [0, 255].

    The values are scaled by 127.5, and 0.5 is added so that the truncating uint8
    cast rounds to the nearest integer. The result is clamped to the byte range.
    The shape is unchanged.

    Raises:
        ValueError: if ``x`` is None.
    """
    if x is None:
        raise ValueError("Input tensor is None")
    scaled = (x + 1.0) * 127.5 + 0.5
    return scaled.clamp(0, 255).to(torch.uint8)

def parse_class_labels(filename="imagenet_classlabels.txt"):
    """Build a mapping from lower-cased class name to integer class ID.

    Reads a markdown-style table whose data rows look like ``| 123 | some name |``.
    Lines that do not start with that shape (headers, separator rows, prose) are
    ignored. If a name appears on several rows, the last row wins.

    Args:
        filename: path of the label table to read (decoded as UTF-8).

    Returns:
        dict mapping ``name.strip().lower()`` to the row's integer ID.

    Raises:
        FileNotFoundError: if ``filename`` does not exist.
    """
    row_pattern = re.compile(r'\|\s*(\d+)\s*\|\s*(.*?)\s*\|')
    label_map = {}
    try:
        with open(filename, "r", encoding="utf-8") as f:
            for line in f:
                match = row_pattern.match(line)
                if match:
                    label_map[match.group(2).strip().lower()] = int(match.group(1))
    except FileNotFoundError as err:
        # Raise instead of calling exit(): exit() shuts down the whole interpreter
        # (or notebook kernel) instead of surfacing an error the caller can handle.
        raise FileNotFoundError(f"Label file not found: {filename}") from err
    return label_map

# Settings

# Input/output locations
config_path = "configs/repa_improved_ddt_xlen22de6_512.yaml"
checkpoint_path = "model.ckpt"
output_dir = "outputs_improved_v6"

# What to generate
classes = ["tiger cat"]   # names are matched case-insensitively against the label file
resolution = 512          # image resolution; the noise latent is resolution // 8 per side
num_images = 2            # images generated per class
seed = 1234               # base seed; image i uses seed + i * 10

# Sampler parameters
num_steps = 100
guidance = 8.5
guidance_min = 0.02
guidance_max = 0.98
last_step = 0.005
timeshift = 0.9

# Prepare environment
os.makedirs(output_dir, exist_ok=True)
# Use the GPU when one is visible, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load configuration and models
# Each cfg.model.* node names a class ("class_path") plus optional constructor
# kwargs ("init_args"); instantiate_from_config builds the object from it.
cfg = OmegaConf.load(config_path)
vae = instantiate_from_config(cfg.model.vae)
denoiser = instantiate_from_config(cfg.model.denoiser)
conditioner = instantiate_from_config(cfg.model.conditioner)
# NOTE(review): unlike vae/denoiser, the conditioner is never moved to `device`
# or switched to eval(); only its outputs are moved (.to(device)) at call time.
# Confirm it has no train-mode-only behaviour.

# Load weights
print(f"Loading weights from {checkpoint_path}...")
# Load the checkpoint onto CPU, copy the "ema_denoiser."-prefixed weights (the
# load_weights default prefix) into the denoiser, then move it to `device`.
# NOTE(review): torch >= 2.6 defaults torch.load to weights_only=True; if this
# checkpoint stores non-tensor objects the load may fail. weights_only=False
# unpickles arbitrary objects, so only use it for trusted files.
ckpt = torch.load(checkpoint_path, map_location="cpu")
denoiser = load_weights(denoiser, ckpt).to(device).eval()
# No weights are copied from `ckpt` into the VAE here.
vae = vae.to(device).eval()

# Use multiple GPUs if available
# NOTE(review): images are sampled from a batch-of-1 noise tensor, so DataParallel
# likely adds little. It also moves the wrapped model's attributes behind
# `.module`; confirm the sampler only calls the denoiser.
if torch.cuda.device_count() > 1:
    denoiser = torch.nn.DataParallel(denoiser)

# Set up sampler
# Euler sampler built from the "Sampler parameters" settings. The exact meaning of
# each knob is defined by EulerSampler in src.diffusion.stateful_flow_matching.sampling.
sampler = EulerSampler(
    scheduler=LinearScheduler(),
    w_scheduler=LinearScheduler(),
    guidance_fn=simple_guidance_fn,
    num_steps=num_steps,
    guidance=guidance,
    state_refresh_rate=1,
    # presumably guidance is applied only inside [min, max] of the time range — confirm
    guidance_interval_min=guidance_min,
    guidance_interval_max=guidance_max,
    timeshift=timeshift,
    last_step=last_step
)

# Process class names
# Resolve the requested names to class IDs (case-insensitive lookup).
label_map = parse_class_labels()
valid_classes = [cls for cls in classes if cls.lower() in label_map]
# Report names with no matching label instead of silently skipping them.
unknown_classes = [cls for cls in classes if cls.lower() not in label_map]
if unknown_classes:
    print(f"Warning: skipping unknown class name(s): {unknown_classes}")
if not valid_classes:
    print("Warning: no requested class matched the label file; nothing will be generated.")

# Image Gen

for class_name in valid_classes:
    class_id = label_map[class_name.lower()]
    print(f"\nGenerating images for: {class_name} (ID: {class_id})")
    file_stem = class_name.replace(' ', '_')
    for i in range(num_images):
        img_seed = seed + i * 10
        # Draw the noise from a seeded CPU generator so a given seed is reproducible,
        # then move it to the sampling device.
        generator = torch.Generator().manual_seed(img_seed)
        noise = torch.randn((1, 4, resolution // 8, resolution // 8), generator=generator).to(device)

        with torch.no_grad():
            cond, uncond = conditioner([class_id])
            output = sampler(denoiser, noise, cond.to(device), uncond.to(device))
            if output is None:
                print(f"Sampler failed for {class_name}, seed {img_seed}")
                continue
            decoded = vae.decode(output.to(device))
            if decoded is None:
                print(f"Decoding failed for {class_name}, seed {img_seed}")
                continue

        # decoded is channels-first; convert the first image to HWC uint8 and drop
        # any channels beyond RGB before handing it to PIL.
        img_array = tensor_to_image(decoded.cpu())[0].permute(1, 2, 0).numpy()
        img_array = img_array[:, :, :3] if img_array.shape[2] > 3 else img_array

        img = Image.fromarray(img_array)
        fname = f"{file_stem}_seed{img_seed}.png"
        img.save(os.path.join(output_dir, fname))
        print(f"Saved: {fname}")

        # Optional visualization. tensor_to_image treats the decoder output as [-1, 1],
        # so rescale to [0, 1] for imshow; clipping the raw values to [0, 1] would
        # discard the whole negative half of the range. .float() keeps half/bfloat16
        # outputs convertible to NumPy.
        vis = decoded[0].float().cpu().permute(1, 2, 0).numpy()
        vis = np.clip((vis + 1.0) / 2.0, 0.0, 1.0)
        vis = vis[:, :, :3] if vis.shape[2] > 3 else vis
        fig, ax = plt.subplots()
        ax.imshow(vis)
        ax.axis('off')
        ax.set_title(f"{class_name} (Seed: {img_seed})")
        fig.savefig(os.path.join(output_dir, f"{file_stem}_raw_seed{img_seed}.png"))
        plt.close(fig)

print("\nAll done.")
