# Test VP Diffusion Teacher — Darcy Flow

Generates samples from the trained teacher model using DDIM sampling
and compares them to real Darcy Flow fields.

**Run this while Stage 2 (CD) is training to verify the teacher works.**

In [None]:
import torch

# Fail loudly when no GPU is attached. An explicit raise is used instead of
# `assert` because assertions are stripped when Python runs with -O.
if not torch.cuda.is_available():
    raise RuntimeError("Enable GPU: Runtime > Change runtime type > GPU")
print(f"GPU: {torch.cuda.get_device_name(0)}")

# Mount Google Drive so the dataset and teacher checkpoints under
# /content/drive/MyDrive are reachable.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Install omegaconf, einops and h5py quietly (IPython shell magic).
!pip install -q omegaconf einops h5py

import os
REPO_DIR = '/content/GenModeling'
# Update an existing checkout in place; otherwise clone it fresh.
# `{REPO_DIR}` is expanded by IPython before the shell command runs.
if os.path.exists(REPO_DIR):
    !git -C {REPO_DIR} pull
else:
    !git clone https://github.com/MehdiMHeydari/GenModeling.git {REPO_DIR}

# Put the repo at the front of sys.path so `src.models...` imports resolve
# without installing the repo as a package.
import sys
if REPO_DIR not in sys.path:
    sys.path.insert(0, REPO_DIR)
print("Ready.")

In [None]:
import numpy as np
import h5py
import torch as th
import matplotlib.pyplot as plt

from src.models.networks.unet.unet import UNetModelWrapper as UNetModel
from src.models.vp_diffusion import VPDiffusionModel
from src.models.diffusion_utils import ddim_step

# Device used for the teacher model and every tensor created during sampling.
DEVICE: str = "cuda"
print("Imports OK.")

In [None]:
# ============================================================
# CONFIG — must match what you used for training
# ============================================================
import re

DATA_SHAPE = (1, 128, 128)
DATA_PATH = "/content/drive/MyDrive/2D_DarcyFlow_beta1.0_Train.hdf5"
TEACHER_SAVE_DIR = "/content/drive/MyDrive/cd_darcy/teacher"
SCHEDULE_S = 0.008

UNET_CFG = dict(
    dim=list(DATA_SHAPE),
    channel_mult="1, 2, 4, 4",
    num_channels=64,
    num_res_blocks=2,
    num_head_channels=32,
    attention_resolutions="32",
    dropout=0.0,
    use_new_attention_order=True,
    use_scale_shift_norm=True,
    class_cond=False,
    num_classes=None,
)

# Sampling config
DDIM_STEPS = 50          # number of DDIM steps (more = better quality, slower)
NUM_SAMPLES = 16         # how many to generate
BATCH_SIZE = 8


def find_latest_checkpoint(save_dir, default_name="checkpoint_175.pt"):
    """Return the path of the highest-numbered ``checkpoint_<N>.pt`` in *save_dir*.

    Epoch numbers are compared as integers, so ``checkpoint_175.pt`` beats
    ``checkpoint_20.pt``; a plain lexical sort would get this wrong.

    Args:
        save_dir: Directory to scan for checkpoint files.
        default_name: File name to use when no matching checkpoint is found,
            or when *save_dir* cannot be listed.

    Returns:
        ``os.path.join(save_dir, <chosen file name>)``. The path is returned
        even if the file does not exist; the caller reports existence.
    """
    pattern = re.compile(r"checkpoint_(\d+)\.pt")
    try:
        names = os.listdir(save_dir)
    except OSError:
        # Missing or unreadable directory (e.g. Drive not mounted): fall back.
        names = []
    by_epoch = {}
    for name in names:
        match = pattern.fullmatch(name)
        if match:
            by_epoch[int(match.group(1))] = name
    chosen = by_epoch[max(by_epoch)] if by_epoch else default_name
    return os.path.join(save_dir, chosen)


# Pick checkpoint: the highest-numbered checkpoint_<N>.pt in TEACHER_SAVE_DIR
TEACHER_CKPT = find_latest_checkpoint(TEACHER_SAVE_DIR)
print(f"Will load: {TEACHER_CKPT}")
print(f"Exists: {os.path.exists(TEACHER_CKPT)}")

In [None]:
# ============================================================
# LOAD TEACHER MODEL
# ============================================================
# Build the UNet backbone and wrap it in the VP diffusion model (infer=True).
network = UNetModel(**UNET_CFG)
teacher = VPDiffusionModel(network=network, schedule_s=SCHEDULE_S, infer=True)

# Read the checkpoint on CPU first; weights_only=True limits unpickling to
# tensors and plain containers/primitives.
state = th.load(TEACHER_CKPT, map_location='cpu', weights_only=True)
weights = state['model_state_dict']
teacher.network.load_state_dict(weights)

# Move to the GPU and switch to eval mode before sampling.
teacher.to(DEVICE)
teacher.eval()

n_params = sum(param.numel() for param in teacher.parameters())
print(f"Loaded teacher from epoch {state['epoch']}")
print(f"Parameters: {n_params:,}")

In [None]:
# ============================================================
# GENERATE SAMPLES VIA DDIM
# ============================================================
# DDIM: start from pure noise z ~ N(0, I) at t=1 and walk the time grid down
# to t=0. Each step asks the teacher for a clean-sample estimate x_hat at the
# current time, then ddim_step moves z to the next (smaller) time.
# ============================================================
from tqdm.auto import tqdm

C, H, W = DATA_SHAPE
all_samples = []
rounds = -(-NUM_SAMPLES // BATCH_SIZE)  # ceil(NUM_SAMPLES / BATCH_SIZE)

# Time grid: DDIM_STEPS + 1 evenly spaced points from 1.0 down to 0.0
# (the endpoint itself, not merely close to it).
# NOTE(review): the first evaluation is at t=1.0 and the last step targets
# t=0.0; whether predict_x / ddim_step are well-conditioned at those endpoints
# depends on VPDiffusionModel's schedule and parameterization — confirm.
ts = th.linspace(1.0, 0.0, DDIM_STEPS + 1, device=DEVICE)

print(f"Generating {NUM_SAMPLES} samples with {DDIM_STEPS} DDIM steps...")

with th.no_grad():
    for r in range(rounds):
        # The final batch may be smaller than BATCH_SIZE.
        n = min(BATCH_SIZE, NUM_SAMPLES - r * BATCH_SIZE)
        z = th.randn(n, C, H, W, device=DEVICE)

        # Consecutive (current, next) time pairs along the grid.
        step_pairs = zip(ts[:-1], ts[1:])
        for t_now, t_next in tqdm(step_pairs, total=DDIM_STEPS,
                                  desc=f"Batch {r+1}/{rounds}", leave=False):
            # Broadcast the scalar times to one entry per sample, staying on
            # the device (no host round-trip through .item()).
            t_batch = t_now.repeat(n)
            s_batch = t_next.repeat(n)

            x_hat = teacher.predict_x(z, t_batch)
            z = ddim_step(x_hat, z, t_batch, s_batch, SCHEDULE_S)

        all_samples.append(z.cpu())
        print(f"  Batch {r+1}/{rounds} done")

all_samples = th.cat(all_samples, dim=0)[:NUM_SAMPLES]
print(f"Generated {all_samples.shape[0]} samples, shape: {all_samples.shape}")
print(f"Normalized range: [{all_samples.min():.4f}, {all_samples.max():.4f}]")
print(f"Normalized range: [{all_samples.min():.4f}, {all_samples.max():.4f}]")

In [None]:
# ============================================================
# LOAD REAL DATA FOR COMPARISON
# ============================================================
# Read the whole 'tensor' dataset into memory as float32.
with h5py.File(DATA_PATH, 'r') as h5file:
    outputs = h5file['tensor'][()].astype(np.float32)
# For 3-D data, insert a singleton axis at position 1. This matches the
# leading C=1 of DATA_SHAPE and the later `[i, 0]` indexing.
if outputs.ndim == 3:
    outputs = np.expand_dims(outputs, axis=1)

# Normalization stats stored next to the teacher checkpoints.
data_min = float(np.load(os.path.join(TEACHER_SAVE_DIR, "data_min.npy")))
data_max = float(np.load(os.path.join(TEACHER_SAVE_DIR, "data_max.npy")))

def denormalize(x_norm, lo=None, hi=None):
    """Map values from the normalized [-1, 1] range back to physical units.

    Inverts the min-max normalization ``x_norm = 2 * (x - lo) / (hi - lo) - 1``.

    Args:
        x_norm: Normalized values (scalar or NumPy array).
        lo: Physical minimum. Defaults to the module-level ``data_min``.
        hi: Physical maximum. Defaults to the module-level ``data_max``.

    Returns:
        Values in physical units, with the same shape as ``x_norm``.
    """
    # Resolve the globals only when needed, so explicit bounds work standalone.
    if lo is None:
        lo = data_min
    if hi is None:
        hi = data_max
    return (x_norm + 1.0) / 2.0 * (hi - lo) + lo

# Push the real fields through the same min-max normalization (physical -> [-1, 1]).
data_span = data_max - data_min
real_norm = 2.0 * (outputs - data_min) / data_span - 1.0
# Keep samples from index 1000 onward as the held-out comparison set.
# NOTE(review): assumes the first 1000 samples were the training split — confirm.
test_data = real_norm[1000:]

# Map generated samples and real test fields back to physical units.
gen_denorm = denormalize(all_samples.numpy())
real_denorm = denormalize(test_data)

print(f"Generated (physical): [{gen_denorm.min():.4f}, {gen_denorm.max():.4f}]")
print(f"Real test  (physical): [{real_denorm.min():.4f}, {real_denorm.max():.4f}]")
print(f"Real full   (physical): [{outputs.min():.4f}, {outputs.max():.4f}]")

In [None]:
# ============================================================
# VISUAL COMPARISON: Generated vs Real
# ============================================================
# Show up to 4 pairs, but never more than either set actually contains.
n_show = min(4, len(gen_denorm), len(real_denorm))
# squeeze=False keeps `axes` 2-D even when n_show == 1, so axes[row, col] always
# works. Constrained layout reserves room for the shared colorbar and suptitle;
# tight_layout cannot handle a colorbar attached to several axes (it warns and
# may produce overlapping panels).
fig, axes = plt.subplots(2, n_show, figsize=(4 * n_show, 8),
                         squeeze=False, constrained_layout=True)

# A single shared color scale makes generated and real panels directly comparable.
vmin = min(gen_denorm[:n_show].min(), real_denorm[:n_show].min())
vmax = max(gen_denorm[:n_show].max(), real_denorm[:n_show].max())

for i in range(n_show):
    im = axes[0, i].imshow(gen_denorm[i, 0], cmap='viridis', vmin=vmin, vmax=vmax)
    axes[0, i].set_title(f'Generated {i+1}')
    axes[0, i].axis('off')

    axes[1, i].imshow(real_denorm[i, 0], cmap='viridis', vmin=vmin, vmax=vmax)
    axes[1, i].set_title(f'Real {i+1}')
    axes[1, i].axis('off')

fig.colorbar(im, ax=axes, shrink=0.6, label='u(x,y)')
fig.suptitle(f'VP Diffusion Teacher ({DDIM_STEPS}-step DDIM) vs Real Darcy Flow', fontsize=14)
plt.show()

In [None]:
# ============================================================
# GRID: All Generated Samples
# ============================================================
n_grid = min(NUM_SAMPLES, 16)      # cap the grid at 16 panels
rows = -(-n_grid // 4)             # ceil(n_grid / 4): four panels per row
fig, axes = plt.subplots(rows, 4, figsize=(16, 4 * rows))
axes = axes.flatten()

for i, ax in enumerate(axes):
    # Every panel hides its axes; only the first n_grid get an image and title.
    if i < n_grid:
        ax.imshow(gen_denorm[i, 0], cmap='viridis')
        ax.set_title(f'Sample {i+1}')
    ax.axis('off')

plt.suptitle(f'All Teacher Samples ({DDIM_STEPS}-step DDIM)', fontsize=14)
plt.tight_layout()
plt.show()

In [None]:
# ============================================================
# STATISTICS COMPARISON
# ============================================================
banner = "=" * 50
print(banner)
print("STATISTICS COMPARISON (physical units)")
print(banner)
print(f"{'':20s} {'Generated':>12s} {'Real (test)':>12s}")
print("-" * 50)

# (row label, reduction over every pixel of every sample)
summary_rows = (
    ("Mean", np.mean),
    ("Std", np.std),
    ("Min", np.min),
    ("Max", np.max),
    ("Median", np.median),
)
for label, reduce_fn in summary_rows:
    print(f"{label:20s} {reduce_fn(gen_denorm):12.6f} {reduce_fn(real_denorm):12.6f}")

# Per-sample mean comparison: average each field over axes 1-3, then
# summarize those per-sample means across the set.
per_sample_axes = (1, 2, 3)
gen_means = gen_denorm.mean(axis=per_sample_axes)
real_means = real_denorm.mean(axis=per_sample_axes)
print(f"\n{'Per-sample mean':20s}")
print(f"  Generated:  {gen_means.mean():.6f} +/- {gen_means.std():.6f}")
print(f"  Real:       {real_means.mean():.6f} +/- {real_means.std():.6f}")

In [None]:
# ============================================================
# HISTOGRAM: Pixel Value Distributions
# ============================================================
# ravel() gives flat views of contiguous arrays instead of copies.
gen_vals = gen_denorm.ravel()
real_vals = real_denorm.ravel()

# Use ONE set of bin edges for both histograms. With `bins=100` alone, each
# call picks its own data range. The two density curves would then sit on
# different bin grids and not be directly comparable.
lo = float(min(gen_vals.min(), real_vals.min()))
hi = float(max(gen_vals.max(), real_vals.max()))
bin_edges = np.histogram_bin_edges(gen_vals, bins=100, range=(lo, hi))

fig, ax = plt.subplots(1, 1, figsize=(10, 5))
ax.hist(gen_vals, bins=bin_edges, alpha=0.6, density=True, label='Generated', color='tab:blue')
ax.hist(real_vals, bins=bin_edges, alpha=0.6, density=True, label='Real (test)', color='tab:orange')
ax.set_xlabel('u(x,y)')
ax.set_ylabel('Density')
ax.set_title('Pixel Value Distribution: Generated vs Real')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()