# Test CD Student — Darcy Flow

Generates samples from the consistency-distillation (CD) student using **only 2 steps**
and compares them with real data and with the teacher's 50-step DDIM samples.

**Run this in a separate tab while CD training continues; it evaluates the most recent student checkpoint saved so far.**

In [None]:
import torch

# Every model and tensor below is placed on CUDA, so stop right away without a GPU runtime.
assert torch.cuda.is_available(), "Enable GPU: Runtime > Change runtime type > GPU"
gpu_name = torch.cuda.get_device_name(0)
print(f"GPU: {gpu_name}")

# Mount Google Drive: the dataset, checkpoints and normalization stats are read
# from paths under /content/drive/MyDrive.
from google.colab import drive

drive.mount('/content/drive')

In [None]:
# Install extra Python dependencies quietly (IPython shell magic).
!pip install -q omegaconf einops h5py

import os
# Location of the GenModeling repository checkout on the Colab VM.
REPO_DIR = '/content/GenModeling'
# Reuse an existing checkout (updating it with `git pull`), otherwise clone it.
# `{REPO_DIR}` is expanded by IPython before the shell command runs.
if os.path.exists(REPO_DIR):
    !git -C {REPO_DIR} pull
else:
    !git clone https://github.com/MehdiMHeydari/GenModeling.git {REPO_DIR}

import sys
# Put the repo first on sys.path so the `src.*` imports in the next cell resolve.
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)
print("Ready.")

In [None]:
import numpy as np
import h5py
import torch as th
import matplotlib.pyplot as plt

from src.models.networks.unet.unet import UNetModelWrapper as UNetModel
from src.models.vp_diffusion import VPDiffusionModel
from src.models.consistency_models import MultistepConsistencyModel
from src.inference.samplers import MultistepCMSampler
from src.models.diffusion_utils import ddim_step

DEVICE = "cuda"  # target device for the models and sampling tensors in later cells
print("Imports OK.")

In [None]:
# ============================================================
# CONFIG
# ============================================================
# --- Data and output locations ---
DATA_SHAPE = (1, 128, 128)  # unpacked later as (C, H, W)
DATA_PATH = "/content/drive/MyDrive/2D_DarcyFlow_beta1.0_Train.hdf5"
TEACHER_SAVE_DIR = "/content/drive/MyDrive/cd_darcy/teacher"
CD_SAVE_DIR = "/content/drive/MyDrive/cd_darcy/student"

# --- Noise schedule / student sampling ---
SCHEDULE_S = 0.008
STUDENT_STEPS = 2

# --- UNet architecture ---
# The same config builds both the student and the teacher network, so it must
# match the architecture stored in both checkpoints for load_state_dict to work.
UNET_CFG = {
    "dim": list(DATA_SHAPE),
    "channel_mult": "1, 2, 4, 4",
    "num_channels": 64,
    "num_res_blocks": 2,
    "num_head_channels": 32,
    "attention_resolutions": "32",
    "dropout": 0.0,
    "use_new_attention_order": True,
    "use_scale_shift_norm": True,
    "class_cond": False,
    "num_classes": None,
}

# --- Sample counts ---
NUM_SAMPLES = 16
BATCH_SIZE = 8

# --- Pick the latest CD checkpoint ---
import re


def _ckpt_sort_key(fname):
    """Sort key for checkpoint filenames: integer epoch first, then the name.

    Names without an integer epoch (e.g. ``checkpoint_best.pt``) get -1, so
    they sort before every numbered checkpoint and are only chosen as
    "latest" when no numbered checkpoint exists.
    """
    match = re.fullmatch(r"checkpoint_(\d+)\.pt", fname)
    return (int(match.group(1)) if match else -1, fname)


# List what's available, ordered by epoch. A plain lexicographic sort would
# place "checkpoint_100.pt" before "checkpoint_20.pt" and pick the wrong file.
cd_ckpts = sorted(
    (f for f in os.listdir(CD_SAVE_DIR) if f.startswith('checkpoint_') and f.endswith('.pt')),
    key=_ckpt_sort_key,
)
print("Available CD checkpoints:", cd_ckpts)
if not cd_ckpts:
    raise FileNotFoundError(f"No checkpoint_*.pt files found in {CD_SAVE_DIR}")

# Use the latest one
CD_CKPT = os.path.join(CD_SAVE_DIR, cd_ckpts[-1])
print(f"\nUsing: {CD_CKPT}")

In [None]:
# ============================================================
# LOAD CD STUDENT MODEL
# ============================================================
cm_kwargs = dict(student_steps=STUDENT_STEPS, schedule_s=SCHEDULE_S, infer=True)
network = UNetModel(**UNET_CFG)
cm = MultistepConsistencyModel(network=network, **cm_kwargs)

# Read the checkpoint onto the CPU first; the whole model is moved to DEVICE below.
state = th.load(CD_CKPT, map_location='cpu', weights_only=True)
cm.network.load_state_dict(state['model_state_dict'])

# Restore the EMA copy as well when the checkpoint carries one.
# NOTE(review): whether MultistepCMSampler samples from ema_network is not
# visible in this notebook — confirm in src.
if 'ema_state_dict' in state:
    cm.ema_network.load_state_dict(state['ema_state_dict'])
    print("Loaded EMA weights (will use for sampling)")

cm.to(DEVICE)
cm.eval()

epoch = state['epoch']
print(f"Loaded student from epoch {epoch}")
print(f"Student steps: {STUDENT_STEPS} (vs teacher's 50 DDIM steps)")

In [None]:
# ============================================================
# GENERATE SAMPLES — CM (2 steps!)
# ============================================================
from tqdm.auto import tqdm  # also used by the teacher sampling cell

sampler = MultistepCMSampler(cm)
C, H, W = DATA_SHAPE
all_samples = []
rounds = (NUM_SAMPLES + BATCH_SIZE - 1) // BATCH_SIZE  # ceil(NUM_SAMPLES / BATCH_SIZE)

# Draw the initial noise up front on the CPU from a fixed seed, one tensor per
# batch. The teacher cell reseeds with 42 and draws the same per-batch shapes
# in the same order, so both models start from identical noise and the
# comparison is paired sample-by-sample. Unseeded noise drawn directly on the
# GPU would never match the teacher's.
NOISE_SEED = 42
th.manual_seed(NOISE_SEED)
cm_noise = [
    th.randn(min(BATCH_SIZE, NUM_SAMPLES - r * BATCH_SIZE), C, H, W)
    for r in range(rounds)
]

print(f"Generating {NUM_SAMPLES} samples with {STUDENT_STEPS} steps (CM)...")
with th.no_grad():
    for r, noise in enumerate(cm_noise):
        samples = sampler.sample(noise.to(DEVICE))
        all_samples.append(samples.cpu())
        print(f"  Batch {r+1}/{rounds} done")

cm_samples = th.cat(all_samples, dim=0)[:NUM_SAMPLES]
print(f"Generated {cm_samples.shape[0]} samples")
print(f"Normalized range: [{cm_samples.min():.4f}, {cm_samples.max():.4f}]")

In [None]:
# ============================================================
# ALSO GENERATE TEACHER SAMPLES FOR COMPARISON
# ============================================================
TEACHER_CKPT = os.path.join(TEACHER_SAVE_DIR, "checkpoint_175.pt")

# The teacher is rebuilt with the same UNet config as the student and loaded
# from a fixed teacher checkpoint.
teacher_net = UNetModel(**UNET_CFG)
teacher = VPDiffusionModel(network=teacher_net, schedule_s=SCHEDULE_S, infer=True)
t_state = th.load(TEACHER_CKPT, map_location='cpu', weights_only=True)
teacher.network.load_state_dict(t_state['model_state_dict'])
teacher.to(DEVICE)
teacher.eval()

DDIM_STEPS = 50
# DDIM_STEPS + 1 evenly spaced times from 1.0 down to 0.0; step i moves the
# latent from ts[i] to ts[i + 1].
ts = th.linspace(1.0, 0.0, DDIM_STEPS + 1, device=DEVICE)
teacher_samples = []

# Fixed-seed initial noise, one CPU tensor per batch.
# NOTE: this equals the student's starting noise only if the CM sampling cell
# drew its noise with the same seed, per-batch shapes and order.
th.manual_seed(42)
cm_noise = [
    th.randn(min(BATCH_SIZE, NUM_SAMPLES - r * BATCH_SIZE), C, H, W)
    for r in range(rounds)
]

print(f"Generating {NUM_SAMPLES} teacher samples with {DDIM_STEPS} DDIM steps...")
with th.no_grad():
    for r, noise in enumerate(cm_noise):
        z = noise.to(DEVICE)
        n = z.shape[0]
        for i in tqdm(range(DDIM_STEPS), desc=f"Batch {r+1}", leave=False):
            t_now, t_next = ts[i].item(), ts[i + 1].item()
            t_batch = th.full((n,), t_now, device=DEVICE)
            s_batch = th.full((n,), t_next, device=DEVICE)
            # Teacher's x prediction at t_now feeds the DDIM update to t_next.
            z = ddim_step(teacher.predict_x(z, t_batch), z, t_batch, s_batch, SCHEDULE_S)
        teacher_samples.append(z.cpu())

teacher_samples = th.cat(teacher_samples, dim=0)[:NUM_SAMPLES]
print("Teacher samples done.")

In [None]:
# ============================================================
# LOAD REAL DATA + DENORMALIZE
# ============================================================
with h5py.File(DATA_PATH, 'r') as f:
    outputs = f['tensor'][()].astype(np.float32)

# Give 3-D data a singleton channel axis so it matches the (N, C, H, W) samples.
if outputs.ndim == 3:
    outputs = np.expand_dims(outputs, axis=1)

# Normalization bounds stored next to the teacher checkpoint.
data_min = float(np.load(os.path.join(TEACHER_SAVE_DIR, "data_min.npy")))
data_max = float(np.load(os.path.join(TEACHER_SAVE_DIR, "data_max.npy")))

def denormalize(x_norm, lo=None, hi=None):
    """Map values from the normalized [-1, 1] range back to physical units.

    Inverts ``x_norm = 2 * (x - lo) / (hi - lo) - 1``.

    Args:
        x_norm: NumPy array or torch tensor of normalized values. Tensors are
            detached and copied to the CPU first, so GPU tensors and tensors
            that require grad are accepted.
        lo: Physical minimum; defaults to the notebook-level ``data_min``.
        hi: Physical maximum; defaults to the notebook-level ``data_max``.

    Returns:
        The values in physical units (a NumPy array for array/tensor input).
    """
    if lo is None:
        lo = data_min
    if hi is None:
        hi = data_max
    if isinstance(x_norm, th.Tensor):
        # Plain .numpy() raises on CUDA tensors and on tensors that require grad.
        x_norm = x_norm.detach().cpu().numpy()
    return (x_norm + 1.0) / 2.0 * (hi - lo) + lo

# Map real data into [-1, 1] with the saved bounds (same space as the model
# outputs), keep samples from index 1000 onward as the comparison set
# (presumably held out from training — confirm against the training split),
# then convert everything back to physical units.
span = data_max - data_min
real_norm = 2.0 * (outputs - data_min) / span - 1.0
test_data = real_norm[1000:]

cm_denorm, teacher_denorm, real_denorm = (
    denormalize(x) for x in (cm_samples, teacher_samples, test_data)
)

for label, arr in (("CM student", cm_denorm), ("Teacher", teacher_denorm), ("Real test", real_denorm)):
    print(f"{label:<10} (physical):  [{arr.min():.4f}, {arr.max():.4f}]")

In [None]:
# ============================================================
# 3-ROW COMPARISON: CM (2-step) vs Teacher (50-step) vs Real
# ============================================================
# One row per source: (images, per-panel title prefix, row label).
# Step counts come from the config so the labels stay correct if they change.
comparison_rows = [
    (cm_denorm, 'CM Student', f'CM ({STUDENT_STEPS} steps)'),
    (teacher_denorm, 'Teacher', f'Teacher ({DDIM_STEPS} steps)'),
    (real_denorm, 'Real', 'Real'),
]
# Never ask for more panels than the smallest set provides.
n_show = min([4] + [len(arr) for arr, _, _ in comparison_rows])

# constrained_layout handles the shared colorbar and suptitle; tight_layout
# does not support the colorbar axes created by fig.colorbar(..., ax=axes).
fig, axes = plt.subplots(
    len(comparison_rows),
    n_show,
    figsize=(4 * n_show, 12),
    squeeze=False,  # keep axes 2-D even when n_show == 1
    constrained_layout=True,
)

# One shared color scale across all shown panels so the images are comparable.
vmin = min(arr[:n_show].min() for arr, _, _ in comparison_rows)
vmax = max(arr[:n_show].max() for arr, _, _ in comparison_rows)

for row, (images, title, row_label) in enumerate(comparison_rows):
    for col in range(n_show):
        ax = axes[row, col]
        im = ax.imshow(images[col, 0], cmap='viridis', vmin=vmin, vmax=vmax)
        ax.set_title(f'{title} {col+1}')
        # Hide ticks rather than calling ax.axis('off'): turning the axis off
        # also hides the y-label, which made the row labels invisible.
        ax.set_xticks([])
        ax.set_yticks([])
    axes[row, 0].set_ylabel(row_label, fontsize=12)

fig.colorbar(im, ax=axes, shrink=0.5, label='u(x,y)')
fig.suptitle('Consistency Model vs Teacher vs Real Darcy Flow', fontsize=14)
plt.show()

In [None]:
# ============================================================
# GRID: All CM Student Samples
# ============================================================
n_grid = min(NUM_SAMPLES, 16)
n_cols = 4
rows = -(-n_grid // n_cols)  # ceiling division: enough rows of 4 for n_grid panels
fig, axes = plt.subplots(rows, n_cols, figsize=(4 * n_cols, 4 * rows))
axes = axes.flatten()

# Each panel uses its own color scale (no shared vmin/vmax); leftover panels stay blank.
for idx, ax in enumerate(axes):
    if idx < n_grid:
        ax.imshow(cm_denorm[idx, 0], cmap='viridis')
        ax.set_title(f'CM Sample {idx+1}')
    ax.axis('off')

plt.suptitle(f'All CM Student Samples ({STUDENT_STEPS}-step generation)', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# ============================================================
# STATISTICS COMPARISON
# ============================================================
# Column labels are built from the config so they stay correct if
# STUDENT_STEPS or DDIM_STEPS changes.
cm_label = f"CM ({STUDENT_STEPS}-step)"
teacher_label = f"Teacher ({DDIM_STEPS})"

print("=" * 60)
print("STATISTICS COMPARISON (physical units)")
print("=" * 60)
print(f"{'':20s} {cm_label:>12s} {teacher_label:>12s} {'Real (test)':>12s}")
print("-" * 60)

# Whole-set (all-pixel) statistics for each sample set.
stat_sets = (cm_denorm, teacher_denorm, real_denorm)
for stat_name, stat_fn in (
    ("Mean", np.mean),
    ("Std", np.std),
    ("Min", np.min),
    ("Max", np.max),
    ("Median", np.median),
):
    cm_v, t_v, r_v = (stat_fn(arr) for arr in stat_sets)
    print(f"{stat_name:20s} {cm_v:12.6f} {t_v:12.6f} {r_v:12.6f}")

# Spread of the per-sample spatial mean across samples.
print("\nPer-sample mean:")
cm_means = cm_denorm.mean(axis=(1, 2, 3))
t_means = teacher_denorm.mean(axis=(1, 2, 3))
r_means = real_denorm.mean(axis=(1, 2, 3))
print(f"  CM:       {cm_means.mean():.6f} +/- {cm_means.std():.6f}")
print(f"  Teacher:  {t_means.mean():.6f} +/- {t_means.std():.6f}")
print(f"  Real:     {r_means.mean():.6f} +/- {r_means.std():.6f}")

In [None]:
# ============================================================
# HISTOGRAM: All Three Distributions
# ============================================================
fig, ax = plt.subplots(1, 1, figsize=(10, 5))

hist_sets = [
    (cm_denorm, f'CM ({STUDENT_STEPS}-step)', 'tab:green'),
    (teacher_denorm, f'Teacher ({DDIM_STEPS}-step)', 'tab:blue'),
    (real_denorm, 'Real (test)', 'tab:orange'),
]

# Share one set of bin edges across all three histograms so the densities line
# up bin-for-bin; with bins=100 each set would get its own edges.
lo = min(float(arr.min()) for arr, _, _ in hist_sets)
hi = max(float(arr.max()) for arr, _, _ in hist_sets)
bin_edges = np.linspace(lo, hi, 101) if hi > lo else 100  # degenerate range: let numpy choose

for arr, label, color in hist_sets:
    # ravel() returns a view for contiguous arrays instead of copying the large real set.
    ax.hist(arr.ravel(), bins=bin_edges, alpha=0.5, density=True, label=label, color=color)

ax.set_xlabel('u(x,y)')
ax.set_ylabel('Density')
ax.set_title('Pixel Value Distribution: CM vs Teacher vs Real')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print(f"\nKey question: Does the CM ({STUDENT_STEPS}-step) distribution match the Teacher ({DDIM_STEPS}-step)?")
print("If yes -> distillation is working, quality is limited by teacher.")
print("If no  -> student needs more training or different hyperparameters.")