# Consistency Distillation — MicroDiT (EDM)

Distill MicroDiT 0.5B teacher (30 steps) → student (4 steps) via Consistency Distillation.

In [None]:
# !pip install -r requirements.txt
# micro_diffusion is put on sys.path in the imports cell below (no install step needed for it)

In [None]:
import os, sys

SCRIPTS_DIR = '/kaggle/input/datasets/albertdavletshin/consistency-distillation-micro-diffusion'
# Prepend the scripts directory to the module search path (presumably where
# cd_utils / dataset / sampler live — confirm) and resolve every relative path
# used later in this notebook against it.
sys.path[:0] = [SCRIPTS_DIR]
os.chdir(SCRIPTS_DIR)

In [None]:
import sys, os, copy, gc
import torch
import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
from accelerate import Accelerator
from accelerate.utils import set_seed
from diffusers.optimization import get_scheduler

# micro_diffusion lives in /kaggle/input/datasets/albertdavletshin/micro-diffusers
sys.path.insert(0, '/kaggle/input/datasets/albertdavletshin/micro-diffusers')
from micro_diffusion.models.model import create_latent_diffusion

from cd_utils import get_sigmas, cd_loss
from dataset import get_dataloader
from sampler import generate_images, run_parti_prompts_benchmark

print(f"torch {torch.__version__}, CUDA: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # CUDA device properties expose total VRAM as `total_memory` (bytes);
    # there is no `total_mem` attribute, so the old code raised AttributeError.
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}, {gpu_props.total_memory / 1e9:.1f} GB")

## Config

In [None]:
# --- Paths (relative to the working directory chosen above) ---
TEACHER_CHECKPOINT = "dit_4_channel_0.5B_synthetic_data.pt"
DATA_DIR = "data/journeydb_subset"
OUTPUT_DIR = "output"

# --- DiT geometry, forwarded to create_latent_diffusion ---
LATENT_RES, IN_CHANNELS, POS_INTERP_SCALE = 64, 4, 2.0

# --- Consistency-distillation objective ---
NUM_TIMESTEPS = 50      # number of sigma levels requested from get_sigmas
GUIDANCE_SCALE = 5.0    # CFG scale used by cd_loss and by teacher sampling
LOSS_TYPE, HUBER_C = "huber", 0.001

# --- Optimization ---
SEED = 42
MIXED_PRECISION = "fp16"
LEARNING_RATE = 1e-5
LR_WARMUP_STEPS = 100
TRAIN_BATCH_SIZE = 2
GRADIENT_ACCUMULATION_STEPS = 4
MAX_TRAIN_STEPS = 5000  # counted in loop iterations, i.e. micro-batches
MAX_GRAD_NORM = 1.0
RESOLUTION = 512

# --- Periodic validation / checkpointing (in loop iterations) ---
VALIDATION_STEPS, SAVE_STEPS = 500, 1000

VALIDATION_PROMPTS = [
    "A beautiful sunset over mountains with a clear lake, highly detailed",
    "Portrait of a girl with golden hair, 8k, masterpiece",
    "Astronaut floating in space with Earth in background",
    "A cute corgi puppy playing in autumn leaves, photograph",
]

os.makedirs(OUTPUT_DIR, exist_ok=True)

## Data Preparation

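Stream 5,000 image–caption pairs from JourneyDB (streaming mode, so the full dataset is never downloaded).
JourneyDB may be gated on the Hugging Face Hub; if so, accept its terms and run `huggingface-cli login` first.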

In [None]:
import json
from datasets import load_dataset
from PIL import Image

NUM_SAMPLES = 5000

captions_path = os.path.join(DATA_DIR, "captions.json")

if os.path.exists(captions_path):
    # A previous run already produced the subset; just report its size.
    with open(captions_path, encoding="utf-8") as f:
        n = len(json.load(f))
    print(f"Dataset already exists: {n} samples in {DATA_DIR}")
else:
    images_dir = os.path.join(DATA_DIR, "images")
    os.makedirs(images_dir, exist_ok=True)

    print(f"Downloading {NUM_SAMPLES} samples from JourneyDB...")
    # Streaming mode avoids materializing the full dataset locally.
    ds = load_dataset("JourneyDB/JourneyDB", split="train", streaming=True)

    captions = {}  # image filename -> prompt text, written to captions.json
    count = 0
    num_failed = 0
    last_error = None
    # The bar tracks *saved* samples, not raw samples pulled from the stream.
    with tqdm(total=NUM_SAMPLES) as pbar:
        for sample in ds:
            if count >= NUM_SAMPLES:
                break
            image = sample.get('image')
            prompt = sample.get('prompt', sample.get('text', ''))
            # Skip samples without a usable text prompt or a decoded PIL image.
            if not prompt or not isinstance(prompt, str) or not isinstance(image, Image.Image):
                continue
            fname = f"{count:06d}.jpg"
            try:
                image = image.convert("RGB")
                # Cap the longest side at 1024 px to bound disk usage.
                if max(image.size) > 1024:
                    image.thumbnail((1024, 1024), Image.LANCZOS)
                image.save(os.path.join(images_dir, fname), quality=95)
            except Exception as e:  # best effort: one bad image must not abort the download
                num_failed += 1
                last_error = e
                continue
            captions[fname] = prompt
            count += 1
            pbar.update(1)

    # Explicit UTF-8: prompts are written with ensure_ascii=False.
    with open(captions_path, 'w', encoding="utf-8") as f:
        json.dump(captions, f, indent=2, ensure_ascii=False)
    print(f"Saved {count} samples to {DATA_DIR}")
    if num_failed:
        print(f"Skipped {num_failed} samples that failed to save (last error: {last_error!r})")
    if count < NUM_SAMPLES:
        print(f"Warning: only {count} of {NUM_SAMPLES} requested samples were collected")

## Model Init

In [None]:
if not os.path.exists(TEACHER_CHECKPOINT):
    !wget -q https://huggingface.co/VSehwag24/MicroDiT/resolve/main/ckpts/dit_4_channel_0.5B_synthetic_data.pt

model = create_latent_diffusion(latent_res=LATENT_RES, in_channels=IN_CHANNELS, pos_interp_scale=POS_INTERP_SCALE)
model.dit.load_state_dict(torch.load(TEACHER_CHECKPOINT, map_location='cpu'))
print(f"DiT: {sum(p.numel() for p in model.dit.parameters()) / 1e6:.1f}M params")

In [None]:
# Teacher: an independent, frozen snapshot of the pretrained DiT.
teacher_dit = copy.deepcopy(model.dit)
teacher_dit.requires_grad_(False)
teacher_dit.eval()

# Student: the model's own DiT object (NOT a copy), so it starts from the same
# weights and anything that samples through `model.dit` uses the student.
student_dit = model.dit
student_dit.train()

# Short aliases for the remaining (frozen) pieces of the latent diffusion model.
vae, text_encoder, tokenizer = model.vae, model.text_encoder, model.tokenizer
edm_config, latent_scale = model.edm_config, model.latent_scale
vae.requires_grad_(False)
text_encoder.requires_grad_(False)

print(f"sigma_min={edm_config.sigma_min}, sigma_max={edm_config.sigma_max}, sigma_data={edm_config.sigma_data}")

## Training

In [None]:
set_seed(SEED)

accelerator = Accelerator(
    mixed_precision=MIXED_PRECISION,
    gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS,
)
# dtype for the frozen teacher / text encoder. It must follow the accelerator's
# precision mode; previously anything other than "fp16" (including "no")
# silently became bfloat16.
weight_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(MIXED_PRECISION, torch.float32)

optimizer = torch.optim.AdamW(student_dit.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999), eps=1e-8)
lr_scheduler = get_scheduler("constant_with_warmup", optimizer=optimizer,
                              num_warmup_steps=LR_WARMUP_STEPS, num_training_steps=MAX_TRAIN_STEPS)

# Only the student is trained, so only it goes through accelerate's prepare().
student_dit, optimizer, lr_scheduler = accelerator.prepare(student_dit, optimizer, lr_scheduler)

device = accelerator.device
teacher_dit.to(device, dtype=weight_dtype)
vae.to(device)  # keeps the VAE's existing dtype
text_encoder.to(device, dtype=weight_dtype)

sigmas = get_sigmas(NUM_TIMESTEPS, sigma_min=edm_config.sigma_min,
                    sigma_max=edm_config.sigma_max, rho=edm_config.rho, device=device)

print(f"device={device}, effective_bs={TRAIN_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS}")
print(f"sigmas: {NUM_TIMESTEPS} steps, [{sigmas[0]:.2f} -> {sigmas[-1]:.4f}]")

In [None]:
from torchvision.utils import save_image

@torch.no_grad()
def validate(step):
    """Sample the current student on VALIDATION_PROMPTS and save one grid per prompt.

    Each prompt gets two 4-step samples (seed ``SEED + i``) saved to
    ``OUTPUT_DIR/val_{step}/p{i}.png``. The student's previous train/eval mode
    is restored afterwards, even if generation raises.

    Args:
        step: training step number, used only to name the output directory.
    """
    student = accelerator.unwrap_model(student_dit)
    was_training = student.training
    student.eval()
    val_dir = os.path.join(OUTPUT_DIR, f"val_{step}")
    os.makedirs(val_dir, exist_ok=True)
    try:
        for i, prompt in enumerate(VALIDATION_PROMPTS):
            imgs = generate_images(student, model, [prompt]*2, num_steps=4, seed=SEED+i, device=device)
            save_image(imgs, os.path.join(val_dir, f"p{i}.png"), nrow=2)
    finally:
        # Never leave the student stuck in eval mode mid-training.
        student.train(was_training)
    print(f"  val images -> {val_dir}")

In [None]:
# Build the image/caption loader and a manual iterator that the training loop
# advances with next().
train_dataloader = get_dataloader(DATA_DIR, batch_size=TRAIN_BATCH_SIZE, resolution=RESOLUTION)
train_iter = iter(train_dataloader)

In [None]:
losses = []
progress_bar = tqdm(range(MAX_TRAIN_STEPS), desc="CD Training")

# Loop-invariant: the student module with accelerate's prepare() wrapper removed.
student = accelerator.unwrap_model(student_dit)
# The VAE was moved with `vae.to(device)` and keeps the dtype it was built with
# (weight_dtype is only applied to the teacher and text encoder), so feed it
# pixels in its own parameter dtype to avoid an input/weight dtype mismatch.
vae_dtype = next(vae.parameters()).dtype

for step in progress_bar:
    # One iteration consumes one micro-batch. MAX_TRAIN_STEPS can exceed the
    # number of batches in a single pass over the data, so restart the iterator
    # when it runs dry instead of letting StopIteration abort training.
    try:
        batch = next(train_iter)
    except StopIteration:
        train_iter = iter(train_dataloader)
        batch = next(train_iter)
    images = batch['image'].to(device, dtype=vae_dtype)
    batch_captions = batch['caption']

    # Frozen encoders: no graph needed.
    with torch.no_grad():
        latents = vae.encode(images)['latent_dist'].sample().data * latent_scale
        tokens = tokenizer.tokenize(batch_captions)['input_ids'].to(device)
        text_emb = text_encoder.encode(tokens)[0]

    with accelerator.accumulate(student_dit):
        loss = cd_loss(
            student, teacher_dit, edm_config,
            latents.float(), text_emb.float(),
            sigmas, GUIDANCE_SCALE,
            loss_type=LOSS_TYPE, huber_c=HUBER_C, weight_dtype=weight_dtype
        )
        accelerator.backward(loss)
        # Clip only when accelerate reports synchronized gradients
        # (the end of a gradient-accumulation window).
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(student_dit.parameters(), MAX_GRAD_NORM)
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad(set_to_none=True)

    loss_val = loss.detach().item()
    losses.append(loss_val)
    progress_bar.set_postfix(loss=f"{loss_val:.4f}", lr=f"{lr_scheduler.get_last_lr()[0]:.2e}")

    if (step + 1) % VALIDATION_STEPS == 0:
        print(f"\nstep {step+1}, avg_loss={np.mean(losses[-VALIDATION_STEPS:]):.4f}")
        validate(step + 1)

    if (step + 1) % SAVE_STEPS == 0:
        ckpt = os.path.join(OUTPUT_DIR, f"student_step_{step+1}.pt")
        torch.save(accelerator.unwrap_model(student_dit).state_dict(), ckpt)
        print(f"  saved {ckpt}")

final_ckpt = os.path.join(OUTPUT_DIR, "student_dit_final.pt")
torch.save(accelerator.unwrap_model(student_dit).state_dict(), final_ckpt)
print(f"Done. Final checkpoint: {final_ckpt}")

In [None]:
# Raw per-step loss plus a moving average (window up to 100, at most 1/5 of the run).
window = min(100, len(losses) // 5)
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(losses, alpha=0.3, label='Raw')
if window > 1:
    kernel = np.full(window, 1.0 / window)
    smoothed = np.convolve(losses, kernel, mode='valid')
    # 'valid' mode drops window-1 points, so the curve starts at index window-1.
    ax.plot(np.arange(window - 1, len(losses)), smoothed, label=f'Smooth (w={window})', color='red')
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
ax.set_title('CD Training Loss')
ax.legend()
ax.grid(alpha=0.3)
fig.tight_layout()
fig.savefig(os.path.join(OUTPUT_DIR, 'loss.png'), dpi=150)
plt.show()

## Inference

In [None]:
# uncomment to load from checkpoint:
# model.dit.load_state_dict(torch.load(os.path.join(OUTPUT_DIR, "student_dit_final.pt"), map_location='cpu'))
# student_dit = model.dit

# Strip accelerate's wrapper, place the student on the training device and
# switch it to inference mode.
student = accelerator.unwrap_model(student_dit).to(device)
student.eval()

In [None]:
test_prompts = [
    "A beautiful sunset over mountains with a clear lake",
    "An elegant squirrel pirate on a ship",
    "A photo of an astronaut riding a horse",
    "A cute corgi puppy playing in autumn leaves",
]
step_counts = [1, 2, 4, 8]

# squeeze=False keeps `axes` 2-D even if test_prompts is edited down to one
# prompt; otherwise axes[row, col] indexing breaks.
fig, axes = plt.subplots(len(test_prompts), len(step_counts),
                         figsize=(4 * len(step_counts), 4 * len(test_prompts)), squeeze=False)
for row, prompt in enumerate(test_prompts):
    for col, n_steps in enumerate(step_counts):
        imgs = generate_images(student, model, [prompt], num_steps=n_steps, seed=SEED, device=device)
        ax = axes[row, col]
        ax.imshow(imgs[0].cpu().permute(1, 2, 0).numpy())  # CHW -> HWC for matplotlib
        ax.set_title(f"{n_steps} step{'s' if n_steps>1 else ''}")
        ax.axis('off')
plt.suptitle('Steps comparison'); plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'steps.png'), dpi=150, bbox_inches='tight')
plt.show()

## PartiPrompts Benchmark

In [None]:
# 4-step student samples for the first 100 PartiPrompts (seed 2024).
parti_output_dir = os.path.join(OUTPUT_DIR, "parti_prompts_4steps")
results = run_parti_prompts_benchmark(
    student,
    model,
    tsv_path="PartiPrompts.tsv",
    output_dir=parti_output_dir,
    num_steps=4,
    num_prompts=100,
    batch_size=4,
    seed=2024,
    device=device,
)
print(f"Generated {len(results)} images")

In [None]:
from PIL import Image

# 4x4 preview of the first 16 PartiPrompts results; each entry of `results`
# is unpacked as (prompt, image_path).
TITLE_MAX_CHARS = 50
fig, axes = plt.subplots(4, 4, figsize=(16, 16))
for idx, ax in enumerate(axes.flat):
    # Hide every cell's frame, including cells left empty when len(results) < 16.
    ax.axis('off')
    if idx >= len(results):
        continue
    prompt, img_path = results[idx]
    # Close the file handle once matplotlib has copied the pixels.
    with Image.open(img_path) as img:
        ax.imshow(img)
    # Only mark a title as truncated when it actually was.
    title = prompt if len(prompt) <= TITLE_MAX_CHARS else prompt[:TITLE_MAX_CHARS] + '...'
    ax.set_title(title, fontsize=8)
plt.suptitle('PartiPrompts (4 steps)'); plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'parti_grid.png'), dpi=150, bbox_inches='tight')
plt.show()

## Teacher vs Student

In [None]:
cmp_prompts = [
    "An elegant squirrel pirate on a ship",
    "A photo of a cat wearing sunglasses",
    "A beautiful mountain landscape at sunset",
    "A robot couple fine dining with Eiffel Tower in the background",
]

fig, axes = plt.subplots(len(cmp_prompts), 2, figsize=(10, 5*len(cmp_prompts)), squeeze=False)
for row, prompt in enumerate(cmp_prompts):
    # `model.dit` is the trained student (setup did `student_dit = model.dit`), so
    # calling model.generate() as-is would sample the student, not the teacher.
    # Temporarily swap in the frozen teacher copy (assumes model.generate samples
    # through model.dit) and always restore the student afterwards.
    student_backbone = model.dit
    model.dit = teacher_dit
    try:
        # teacher_dit was cast to weight_dtype for training; run it under the
        # accelerator's mixed-precision autocast so activations match its weights.
        with accelerator.autocast():
            t_img = model.generate(prompt=[prompt], num_inference_steps=30, guidance_scale=GUIDANCE_SCALE, seed=SEED)
    finally:
        model.dit = student_backbone
    s_img = generate_images(student, model, [prompt], num_steps=4, seed=SEED, device=device)

    axes[row, 0].imshow(t_img[0].cpu().permute(1,2,0).numpy())
    axes[row, 0].set_title('Teacher (30 steps)'); axes[row, 0].axis('off')
    axes[row, 1].imshow(s_img[0].cpu().permute(1,2,0).numpy())
    axes[row, 1].set_title('Student CD (4 steps)'); axes[row, 1].axis('off')

plt.suptitle('Teacher (30) vs Student (4)'); plt.tight_layout()
plt.savefig(os.path.join(OUTPUT_DIR, 'comparison.png'), dpi=150, bbox_inches='tight')
plt.show()