# Multistep Consistency Models — Darcy Flow

Implementation of ["Multistep Consistency Models"](https://arxiv.org/abs/2403.06807) (Heek, Hoogeboom, Salimans 2024).

**Task:** Unconditional generation of Darcy Flow PDE solution fields u(x,y).

**Dataset:** PDEBench 2D Darcy Flow (1024 samples, 1x128x128).

**Two-phase pipeline:**
1. Train a VP Diffusion teacher model
2. Distill into a fast Consistency Model (2-8 step generation)

---

## Cell 1: GPU Check & Drive Mount

In [None]:
import torch

# Fail fast when the Colab runtime has no GPU attached; the rest of the
# notebook places models and tensors on "cuda".
assert torch.cuda.is_available(), "Enable GPU: Runtime > Change runtime type > GPU"
print(f"GPU: {torch.cuda.get_device_name(0)}")
# The device-properties object exposes the memory size (in bytes) as
# `total_memory`; `total_mem` does not exist and raised AttributeError.
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# Mount Google Drive so the dataset and checkpoints under /content/drive
# are reachable from the following cells.
from google.colab import drive
drive.mount('/content/drive')

## Cell 2: Setup — Install Deps & Load Repo

In [None]:
# Install dependencies (PyTorch is pre-installed in Colab)
# NOTE: lines starting with "!" are IPython shell escapes, so this cell only
# runs inside a notebook/IPython session, not as a plain Python script.
!pip install -q omegaconf einops tensorboard h5py

# Clone repo (pull latest on re-run)
# If REPO_DIR already exists it is assumed to be a git checkout and is
# updated in place; otherwise the repository is cloned into it.
import os
REPO_DIR = '/content/GenModeling'
if os.path.exists(REPO_DIR):
    !git -C {REPO_DIR} pull
else:
    !git clone https://github.com/MehdiMHeydari/GenModeling.git {REPO_DIR}

# Put the checkout at the front of sys.path (once) so that the
# `src.*` packages imported in the next cell resolve to this repo.
import sys
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)
print("Repo loaded and path configured.")

## Cell 3: Imports

In [None]:
import os
import h5py
import numpy as np
import torch as th
from torch.optim import Adam
from torch.utils.data import DataLoader

from src.models.networks.unet.unet import UNetModelWrapper as UNetModel
from src.models.vp_diffusion import VPDiffusionModel
from src.models.consistency_models import MultistepConsistencyModel
from src.training.trainer import Trainer
from src.training.objectives import VPDiffusionLoss, MultistepCDLoss
from src.inference.samplers import MultistepCMSampler
from src.utils.dataset import VF_FM
from src.utils.logger import configure, log, logkvs, dumpkvs

# Device string handed to the trainer, the CD loss, `.to(...)` calls and the
# sampling noise in the cells below. No CPU fallback is configured here.
DEVICE = "cuda"
print("All imports successful.")

## Cell 4: Configuration

In [None]:
# ============================================================
# CONFIGURATION
# ============================================================
# All tunable values for the notebook live in this cell; later cells only
# read these names.

# --- Data -----------------------------------------------------
# One sample is a single-channel 128x128 field: (channels, height, width).
DATA_SHAPE = (1, 128, 128)

# HDF5 file with the Darcy Flow data, stored on the mounted Google Drive.
DATA_PATH = "/content/drive/MyDrive/2D_DarcyFlow_beta1.0_Train.hdf5"

# --- Network --------------------------------------------------
# Keyword arguments forwarded verbatim to the UNet constructor.
UNET_CFG = {
    "dim": [*DATA_SHAPE],
    "channel_mult": "1, 2, 4, 4",
    "num_channels": 64,
    "num_res_blocks": 2,
    "num_head_channels": 32,
    "attention_resolutions": "32",
    "dropout": 0.0,
    "use_new_attention_order": True,
    "use_scale_shift_norm": True,
    "class_cond": False,
    "num_classes": None,
}

# --- Teacher (VP Diffusion) training --------------------------
TEACHER_EPOCHS = 200
TEACHER_BATCH_SIZE = 16
TEACHER_LR = 1e-4
TEACHER_SAVE_DIR = "/content/drive/MyDrive/cd_darcy/teacher"
TEACHER_SCHEDULE_S = 0.008

# --- Student (Consistency Distillation) training --------------
CD_EPOCHS = 100
CD_BATCH_SIZE = 16
CD_LR = 1e-4
CD_STUDENT_STEPS = 2
CD_EMA_RATE = 0.9999
CD_X_VAR_FRAC = 0.75
CD_HUBER_EPS = 1e-4
CD_SAVE_DIR = "/content/drive/MyDrive/cd_darcy/student"

# --- Sampling -------------------------------------------------
SAMPLE_BATCH_SIZE = 8
NUM_SAMPLES = 24
SAMPLE_SAVE_PATH = "/content/drive/MyDrive/cd_darcy/samples.pt"

# Create both output directories up front (no error if they already exist).
for _save_dir in (TEACHER_SAVE_DIR, CD_SAVE_DIR):
    os.makedirs(_save_dir, exist_ok=True)
print("Configuration set.")

## Cell 5: Load & Preprocess Darcy Flow Data

In [None]:
# ============================================================
# LOAD DARCY FLOW DATASET
# ============================================================
# File: 2D_DarcyFlow_beta1.0_Train.hdf5 (PDEBench)
#   "nu":     (N, 128, 128)    — input permeability (NOT USED)
#   "tensor": (N, 1, 128, 128) — output solution u(x,y) (USED)
# ============================================================
# This cell reads the solution fields, rescales them to [-1, 1], stores the
# scaling constants next to the teacher checkpoints, splits the samples and
# builds the training DataLoader.

with h5py.File(DATA_PATH, 'r') as f:
    print("HDF5 keys:", list(f.keys()))
    outputs = np.array(f['tensor']).astype(np.float32)
    print(f"Raw tensor shape: {outputs.shape}")

# Handle shape: a 3-D array is taken to be (N, H, W) and gets a channel axis
if outputs.ndim == 3:
    outputs = outputs[:, np.newaxis, :, :]
    print(f"Added channel dim -> {outputs.shape}")

# Check against the configured sample shape so the data and UNET_CFG["dim"]
# cannot drift apart (previously a second hard-coded (1, 128, 128)).
assert outputs.shape[1:] == tuple(DATA_SHAPE), f"Unexpected shape: {outputs.shape}"
print(f"Raw data range: [{outputs.min():.4f}, {outputs.max():.4f}]")

# --- Min-max normalize to [-1, 1] ---
data_min = float(outputs.min())
data_max = float(outputs.max())
data_range = data_max - data_min
# A constant field (range 0) or NaN/inf values would otherwise turn the whole
# training set into NaN/inf without any error.
if not np.isfinite(data_range) or data_range <= 0.0:
    raise ValueError(
        f"Cannot min-max normalize: data_min={data_min}, data_max={data_max}"
    )
outputs_norm = 2.0 * (outputs - data_min) / data_range - 1.0
print(f"Normalized range: [{outputs_norm.min():.4f}, {outputs_norm.max():.4f}]")

# Save normalization stats for denormalization later
np.save(os.path.join(TEACHER_SAVE_DIR, "data_min.npy"), np.array(data_min))
np.save(os.path.join(TEACHER_SAVE_DIR, "data_max.npy"), np.array(data_max))

# --- Train / Val / Test split ---
# Fixed, order-based split: first 800 train, next 200 val, remainder test.
train_data = outputs_norm[:800]
val_data   = outputs_norm[800:1000]
test_data  = outputs_norm[1000:]

print(f"Train: {train_data.shape[0]} samples")
print(f"Val:   {val_data.shape[0]} samples")
print(f"Test:  {test_data.shape[0]} samples")

# Slicing past the end yields empty arrays silently; surface that here
# rather than in a later cell that indexes into the split.
for split_name, split_arr in (("val", val_data), ("test", test_data)):
    if split_arr.shape[0] == 0:
        print(f"WARNING: {split_name} split is empty "
              f"(dataset has only {outputs_norm.shape[0]} samples)")

# --- Build dataset and dataloader ---
dataset = VF_FM(train_data, all_vel=True)
train_loader = DataLoader(
    dataset, batch_size=TEACHER_BATCH_SIZE,
    shuffle=True, num_workers=2, pin_memory=True
)

print(f"\nSample shape: {dataset.shape}")
print(f"Batches per epoch: {len(train_loader)}")

## Cell 6: Build UNet

In [None]:
def build_unet():
    """Construct and return a new UNet configured by ``UNET_CFG``.

    Every call creates a separate ``UNetModel`` instance; the entries of
    ``UNET_CFG`` are passed through as keyword arguments.
    """
    unet = UNetModel(**UNET_CFG)
    return unet

# Sanity check: build one throwaway network and report its size
net = build_unet()
total_params = sum(weight.numel() for weight in net.parameters())
print(f"UNet parameters: {total_params:,}")

# Quick forward pass test: two random inputs with two time values; the
# network output must have the same shape as its input.
C, H, W = DATA_SHAPE
x_test = th.randn(2, C, H, W)
t_test = th.tensor([0.3, 0.7])
with th.no_grad():
    out = net(t=t_test, x=x_test)
print(f"Forward pass: input {x_test.shape} -> output {out.shape}")
assert out.shape == x_test.shape

# Drop the temporary objects so they do not linger in the notebook namespace.
del net, x_test, t_test, out
print("UNet OK.")

## Cell 7: Train VP Diffusion Teacher (Phase A)

In [None]:
# ============================================================
# STAGE 1: Train VP Diffusion Teacher
# ============================================================

# Teacher = VP diffusion model wrapped around a freshly built UNet.
teacher_network = build_unet()
teacher_model = VPDiffusionModel(
    network=teacher_network,
    schedule_s=TEACHER_SCHEDULE_S,
)

# Loss and optimizer; only the wrapped network's parameters are optimized.
objective = VPDiffusionLoss(class_conditional=False)
optimizer = Adam(teacher_model.network.parameters(), lr=TEACHER_LR)

# Logs go into a "logs" sub-directory next to the checkpoints.
logpath_teacher = os.path.join(TEACHER_SAVE_DIR, "logs")
os.makedirs(logpath_teacher, exist_ok=True)

# Settings passed to the Trainer as `logger_dict`.
teacher_logger_cfg = dict(
    dir=logpath_teacher,
    format_strs=["log", "tensorboard"],
    log_print_freq=5,
)

# Settings passed to the Trainer as `checkpointing_dict`: start from
# scratch (no restart) with a save interval of 25 epochs.
teacher_checkpoint_cfg = dict(
    restart=False,
    restart_epoch=None,
    save_path=TEACHER_SAVE_DIR,
    save_epoch_int=25,
    log_batch_int=10,
)

trainer = Trainer(
    model=teacher_model,
    objective=objective,
    dataloader=train_loader,
    optimizer=optimizer,
    scheduler=None,
    logger_dict=teacher_logger_cfg,
    checkpointing_dict=teacher_checkpoint_cfg,
    device=DEVICE,
)

trainer.train(num_epochs=TEACHER_EPOCHS)
print(f"Teacher training complete. Checkpoints in {TEACHER_SAVE_DIR}")

## Cell 8: Train CD Student (Phase B)

In [None]:
# ============================================================
# STAGE 2: Consistency Distillation
# ============================================================
# Loads the trained teacher and freezes it, initializes the student from the
# teacher weights, then trains the student with the multistep CD loss.

# Pick teacher checkpoint (last saved epoch). The 25 here mirrors the
# "save_epoch_int" used when training the teacher.
# NOTE(review): assumes the Trainer writes checkpoint_<epoch>.pt for
# 0-indexed epochs that are multiples of 25 — confirm against Trainer.
TEACHER_CKPT = os.path.join(
    TEACHER_SAVE_DIR,
    f"checkpoint_{(TEACHER_EPOCHS - 1) // 25 * 25}.pt"
)
assert os.path.exists(TEACHER_CKPT), f"No teacher checkpoint at {TEACHER_CKPT}"

# --- Load frozen teacher ---
teacher_net = build_unet()
teacher = VPDiffusionModel(
    network=teacher_net, schedule_s=TEACHER_SCHEDULE_S, infer=True
)
# Load on CPU first, then move the whole model to the training device.
state = th.load(TEACHER_CKPT, map_location='cpu', weights_only=True)
teacher.network.load_state_dict(state['model_state_dict'])
teacher.to(DEVICE)
teacher.eval()
# The teacher only provides targets; it must not receive gradients.
for p in teacher.parameters():
    p.requires_grad_(False)
print(f"Loaded teacher from {TEACHER_CKPT}")

# --- Build student (initialize from teacher weights) ---
student_net = build_unet()
student_net.load_state_dict(teacher_net.state_dict())

student_model = MultistepConsistencyModel(
    network=student_net,
    student_steps=CD_STUDENT_STEPS,
    schedule_s=TEACHER_SCHEDULE_S,
    ema_rate=CD_EMA_RATE,
)
student_model.to(DEVICE)

# --- CD Loss ---
cd_loss = MultistepCDLoss(
    class_conditional=False,
    teacher_model=teacher,
    student_steps=CD_STUDENT_STEPS,
    x_var_frac=CD_X_VAR_FRAC,
    huber_epsilon=CD_HUBER_EPS,
    schedule_s=TEACHER_SCHEDULE_S,
)

optimizer_cd = Adam(student_model.network.parameters(), lr=CD_LR)

# Rebuild loader with CD batch size
cd_loader = DataLoader(
    dataset, batch_size=CD_BATCH_SIZE,
    shuffle=True, num_workers=2, pin_memory=True
)

# --- Training loop ---
logpath_cd = os.path.join(CD_SAVE_DIR, "logs")
os.makedirs(logpath_cd, exist_ok=True)
configure(dir=logpath_cd, format_strs=["log", "tensorboard"])

final_epoch = CD_EPOCHS - 1
for epoch in range(CD_EPOCHS):
    student_model.network.train()
    total_loss = 0.0
    for batch in cd_loader:
        loss = cd_loss(student_model, batch, device=DEVICE)
        optimizer_cd.zero_grad()
        loss.backward()
        optimizer_cd.step()
        # Presumably refreshes student_model.ema_network from the updated
        # student weights — see MultistepConsistencyModel.update_ema.
        student_model.update_ema()
        total_loss += loss.item()

    # Mean per-batch loss over the epoch.
    avg_loss = total_loss / len(cd_loader)
    logkvs({"epoch": epoch, "cd_loss": avg_loss})
    dumpkvs()

    if epoch % 5 == 0:
        print(f"Epoch {epoch}: loss={avg_loss:.6f}, "
              f"teacher_steps={cd_loss._teacher_step_schedule()}")

    # Periodic checkpoint every 10 epochs, plus one for the final epoch.
    # Without the final save, the epochs after the last multiple of 10
    # (e.g. 91-99 of 100) were trained but never written to disk.
    if epoch % 10 == 0 or epoch == final_epoch:
        ckpt = {
            'epoch': epoch,
            'model_state_dict': student_model.network.state_dict(),
            'ema_state_dict': student_model.ema_network.state_dict(),
            'optimizer_state_dict': optimizer_cd.state_dict(),
        }
        th.save(ckpt, os.path.join(CD_SAVE_DIR, f"checkpoint_{epoch}.pt"))

print("CD training complete.")

## Cell 9: Sample from Trained CM

In [None]:
# ============================================================
# STAGE 3: Generate Samples
# ============================================================
# Restores the distilled student from disk, draws NUM_SAMPLES samples in
# batches of SAMPLE_BATCH_SIZE and saves them (still in normalized units).

# Pick the CD checkpoint to load: prefer one written for the final epoch if
# it exists, otherwise fall back to the last periodic one (every 10 epochs).
cd_ckpt_candidates = [
    os.path.join(CD_SAVE_DIR, f"checkpoint_{CD_EPOCHS - 1}.pt"),
    os.path.join(CD_SAVE_DIR, f"checkpoint_{(CD_EPOCHS - 1) // 10 * 10}.pt"),
]
CD_CKPT = next(
    (path for path in cd_ckpt_candidates if os.path.exists(path)), None
)
# Fail with a readable message instead of an error from deep inside th.load.
assert CD_CKPT is not None, f"No CD checkpoint found; tried {cd_ckpt_candidates}"
print(f"Loading CD checkpoint {CD_CKPT}")

# Build and load model
sample_net = build_unet()
cm = MultistepConsistencyModel(
    network=sample_net,
    student_steps=CD_STUDENT_STEPS,
    schedule_s=TEACHER_SCHEDULE_S,
    infer=True,
)
state = th.load(CD_CKPT, map_location='cpu', weights_only=True)
cm.network.load_state_dict(state['model_state_dict'])
# EMA weights are restored only when the checkpoint contains them.
if 'ema_state_dict' in state:
    cm.ema_network.load_state_dict(state['ema_state_dict'])
cm.to(DEVICE)
cm.eval()

sampler = MultistepCMSampler(cm)

# Generate
C, H, W = DATA_SHAPE
# Ceil division: the last round may be smaller than SAMPLE_BATCH_SIZE.
rounds = (NUM_SAMPLES + SAMPLE_BATCH_SIZE - 1) // SAMPLE_BATCH_SIZE
all_samples = []

print(f"Generating {NUM_SAMPLES} samples with {CD_STUDENT_STEPS} steps...")
with th.no_grad():
    for i in range(rounds):
        n = min(SAMPLE_BATCH_SIZE, NUM_SAMPLES - i * SAMPLE_BATCH_SIZE)
        # Gaussian noise is the sampler's input.
        z = th.randn(n, C, H, W, device=DEVICE)
        samples = sampler.sample(z)
        # Collect on CPU so GPU memory is not held across rounds.
        all_samples.append(samples.cpu())
        print(f"  Batch {i+1}/{rounds} done")

all_samples = th.cat(all_samples, dim=0)[:NUM_SAMPLES]
th.save(all_samples, SAMPLE_SAVE_PATH)
print(f"Saved {NUM_SAMPLES} samples to {SAMPLE_SAVE_PATH}")

## Cell 10: Visualize — Generated vs Real

In [None]:
import matplotlib.pyplot as plt

# Read back the min/max scaling constants stored with the teacher checkpoints
stats_dir = TEACHER_SAVE_DIR
data_min = float(np.load(os.path.join(stats_dir, "data_min.npy")))
data_max = float(np.load(os.path.join(stats_dir, "data_max.npy")))

def denormalize(x_norm):
    """Invert the min-max scaling: map values in [-1, 1] to physical units.

    Uses the module-level ``data_min`` / ``data_max`` at call time, so -1
    maps to ``data_min`` and +1 maps to ``data_max``.
    """
    span = data_max - data_min
    unit_interval = (x_norm + 1.0) / 2.0
    return unit_interval * span + data_min

# Denormalize generated samples
gen_np = denormalize(all_samples.numpy())  # (num generated, C, H, W)

# Denormalize test data for comparison
test_denorm = denormalize(test_data)  # (num test samples, C, H, W)

# --- Plot: top row = generated, bottom row = real ---
# Show at most 4 columns, but never more than either set actually holds;
# a fixed 4 raised IndexError when fewer samples were available.
n_show = min(4, gen_np.shape[0], test_denorm.shape[0])
if n_show > 0:
    # constrained_layout makes room for the shared colorbar; tight_layout is
    # not compatible with a colorbar attached to several axes.
    # squeeze=False keeps `axes` 2-D even when n_show == 1.
    fig, axes = plt.subplots(
        2, n_show, figsize=(4 * n_show, 8),
        squeeze=False, constrained_layout=True,
    )

    # Shared color scale so generated and real fields are comparable.
    vmin = min(gen_np[:n_show].min(), test_denorm[:n_show].min())
    vmax = max(gen_np[:n_show].max(), test_denorm[:n_show].max())

    for i in range(n_show):
        im = axes[0, i].imshow(
            gen_np[i, 0], cmap='viridis', vmin=vmin, vmax=vmax
        )
        axes[0, i].set_title(f"Generated {i+1}")
        axes[0, i].axis('off')

        axes[1, i].imshow(
            test_denorm[i, 0], cmap='viridis', vmin=vmin, vmax=vmax
        )
        axes[1, i].set_title(f"Real {i+1}")
        axes[1, i].axis('off')

    fig.colorbar(im, ax=axes, shrink=0.6, label='u(x,y)')
    fig.suptitle(
        f"Darcy Flow: CM Samples ({CD_STUDENT_STEPS} steps) vs Real",
        fontsize=14
    )
    plt.savefig('/content/darcy_comparison.png', dpi=150, bbox_inches='tight')
    plt.show()
else:
    print("Skipping generated-vs-real comparison: no samples available.")

# --- Grid of all generated samples ---
# Up to 16 samples, 4 per row; unused cells in the last row are left blank.
n_grid = min(NUM_SAMPLES, gen_np.shape[0], 16)
rows = (n_grid + 3) // 4
fig, axes = plt.subplots(rows, 4, figsize=(16, 4 * rows))
axes = axes.flatten()

for i, ax in enumerate(axes):
    if i < n_grid:
        ax.imshow(gen_np[i, 0], cmap='viridis')
        ax.set_title(f"Sample {i+1}")
    ax.axis('off')

plt.suptitle(
    f"All Generated Darcy Flow Fields ({CD_STUDENT_STEPS}-step CM)",
    fontsize=14
)
plt.tight_layout()
plt.savefig('/content/darcy_grid.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"Generated range: [{gen_np.min():.4f}, {gen_np.max():.4f}]")
# min()/max() of an empty array raise, so only report when data exists.
if test_denorm.size:
    print(f"Real range:      [{test_denorm.min():.4f}, {test_denorm.max():.4f}]")
print("Visualization complete.")