In [1]:
from functools import partial
import os
import argparse
import yaml
# import pandas as pd
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from PIL import Image
import torch.nn.functional as F
from util.img_utils import Blurkernel#, ifft2_m,fft2_m
from guided_diffusion.fastmri_utils import fft2c_new,ifft2c_new
from guided_diffusion.condition_methods import get_conditioning_method
from guided_diffusion.measurements import get_noise, get_operator
from guided_diffusion.unet import create_model
from guided_diffusion.gaussian_diffusion_correct import create_sampler
from data.dataloader import get_dataset, get_dataloader
from util.img_utils import clear_color, mask_generator
from util.logger import get_logger
from common_utils import *
from ddim_sampler import *
import shutil
import lpips
from scheduling_ddpm import DDPMScheduler
from functools import partial

In [2]:
np.random.seed(41)
torch.manual_seed(41)
torch.cuda.manual_seed_all(41)
torch.backends.cudnn.deterministic = True

In [3]:
def load_yaml(file_path: str) -> dict:
    with open(file_path) as f:
        config = yaml.load(f, Loader=yaml.FullLoader)
    return config

model_config='/egr/research-slim/liangs16/Measurment_Consistent_Diffusion_Trajectory/configs/model_config.yaml'
diffusion_config='/egr/research-slim/liangs16/Measurment_Consistent_Diffusion_Trajectory/configs/diffusion_config.yaml'
task_config= '/egr/research-slim/liangs16/Measurment_Consistent_Diffusion_Trajectory/configs/phase_retrieval_config.yaml'

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Load configurations
model_config = load_yaml(model_config)
diffusion_config = load_yaml(diffusion_config)
task_config = load_yaml(task_config)

In [4]:
model = create_model(**model_config)
model = model.to(device)


# Prepare Operator and noise
measure_config = task_config['measurement']
operator = get_operator(device=device, **measure_config['operator'])
noiser = get_noise(**measure_config['noise'])
#logger.info(f"Operation: {measure_config['operator']['name']} / Noise: {measure_config['noise']['name']}")

# Prepare conditioning method
cond_config = task_config['conditioning']
cond_method = get_conditioning_method(cond_config['method'], operator, noiser, **cond_config['params'])
measurement_cond_fn = cond_method.conditioning
#logger.info(f"Conditioning method : {task_config['conditioning']['method']}")

# Prepare dataloader
data_config = task_config['data']
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
dataset = get_dataset(**data_config, transforms=transform)
loader = get_dataloader(dataset, batch_size=1, num_workers=0, train=False)

# Exception) In case of inpainting, we need to generate a mask 
if measure_config['operator']['name'] == 'inpainting':
    mask_gen = mask_generator(
       **measure_config['mask_opt']
    )



In [5]:

scheduler = DDIMScheduler()

In [6]:
criterion = torch.nn.L1Loss().to(device)

In [7]:
losses = []
psnrs = []
out_visual = []

In [8]:
def compute_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    max_pixel = 1.0  # Assuming the image is normalized to [0, 1]
    psnr = 20 * np.log10(max_pixel / (mse**0.5))
    return psnr.item()

In [9]:
def compute_mce(reconstructed_image, measured_magnitude, target_shape=None):
    """Compute L2 error between Fourier magnitude of reconstruction and measured magnitude."""
    # Bring to [0, 1] range if needed
    x_hat = reconstructed_image.clamp(0, 1)

    # Pad to match measurement shape (e.g., 384x384)
    if target_shape is not None:
        _, _, H, W = x_hat.shape
        target_H, target_W = target_shape[-2], target_shape[-1]
        pad_h = (target_H - H) // 2
        pad_w = (target_W - W) // 2
        x_hat = F.pad(x_hat, (pad_w, pad_w, pad_h, pad_h))

    # Ensure complex
    if not torch.is_complex(x_hat):
        x_hat = x_hat.type(torch.complex64)

    # Compute FFT and error
    fft_x_hat = torch.view_as_complex(fft2c_new(torch.view_as_real(x_hat)))
    magnitude_error = (fft_x_hat.abs() - measured_magnitude).pow(2).mean()
    return magnitude_error.item()

In [10]:
# def optimize_input(input,  sqrt_one_minus_alpha_cumprod, sqrt_alpha_cumprod, t, num_steps=100, learning_rate=0.01):
#     input_tensor = torch.randn(1, model.in_channels, 256, 256, requires_grad=True)
#     input_tensor.data = input.clone().to(device)
#     optimizer = torch.optim.Adam([input_tensor], lr=.01)
#     tt = (torch.ones(1) * t).to(device)
#     for step in range(num_steps):
#         optimizer.zero_grad()
    
#         noise_pred = model(input_tensor.to(device), tt)
#         noise_pred = noise_pred[:, :3]
#         pred_x0 = (input_tensor.to(device) -sqrt_one_minus_alpha_cumprod * noise_pred) / sqrt_alpha_cumprod
#         pred_x0= torch.clamp(pred_x0, -1, 1)
#         loss = torch.norm(operator.forward(pred_x0)-y_n)**2 + torch.norm(input_tensor-sqrt_alpha_cumprod*pred_x0)**2/(0.01+sqrt_one_minus_alpha_cumprod)**2 *0.001
#         loss.backward()
#         optimizer.step()

#      #   print(f"Step {step}/{num_steps}, Loss: {loss.item()}")
#     noise = (input_tensor-sqrt_alpha_cumprod*pred_x0)/sqrt_one_minus_alpha_cumprod
#     return input_tensor.detach(), pred_x0.detach(), noise.detach()

In [11]:
import torch
import torch.nn.functional as F

def gerchberg_saxton(
    y_n,
    num_iters=200,
    pad_ratio=2/8.0,
    device="cuda",
    dtype=torch.float32,
):
    """
    Very simple error-reduction / Gerchberg–Saxton solver:
    - object constraints: real, nonnegative
    - measurement: enforce Fourier magnitude = y_n
    """

    B = 1
    H = W = 256
    pad = int(pad_ratio * H)

    # --- init in object domain: random real, [0,1] ---
    x_real = torch.rand(B, 1, H, W, device=device, dtype=dtype)

    # pad to match your measurement geometry
    x_real = F.pad(x_real, (pad, pad, pad, pad))

    # make complex
    x = x_real.to(device)
    if not torch.is_complex(x):
        x = x.type(torch.complex64)

    eps = 1e-8

    for it in range(num_iters):
        # Forward FFT
        X = torch.view_as_complex(fft2c_new(torch.view_as_real(x)))

        # Enforce magnitude constraint: new_X = y_n * exp(i * angle(X))
        mag = torch.abs(X)
        phase = X / (mag + eps)
        X_new = y_n * phase

        # Backward FFT
        x_back = ifft2c_new(torch.view_as_real(X_new.type(torch.complex64)))
        x_back = torch.view_as_complex(x_back)

        # Object-domain constraints: real, nonnegative
        x_real = x_back.real.clamp_min(0.0)

        # Re-pack as complex with zero imaginary part
        x = x_real.type(torch.complex64)

    # Crop back to 256×256
    x_real = x.real[:, :, pad:-pad, pad:-pad]

    # Normalize to [0,1] (optional mild scaling)
    x_min = x_real.amin(dim=(-2, -1), keepdim=True)
    x_max = x_real.amax(dim=(-2, -1), keepdim=True)
    x_real = (x_real - x_min) / (x_max - x_min + 1e-8)

    return x_real  # shape: (1,1,256,256) in [0,1]

In [12]:
def optimize_input(input,  sqrt_one_minus_alpha_cumprod, sqrt_alpha_cumprod, t, num_steps=20, learning_rate=0.01):
    input_tensor = torch.randn(1, model.in_channels, 256, 256, requires_grad=True)
    input_tensor.data = input.clone().to(device)
    optimizer = torch.optim.Adam([input_tensor], lr=0.01)
    tt = (torch.ones(1) * t).to(device)
    for step in range(num_steps):
        optimizer.zero_grad()
        noise_pred = model(input_tensor.to(device), tt)
        noise_pred = noise_pred[:, :3]
        pred_x0 = (input_tensor.to(device) -sqrt_one_minus_alpha_cumprod * noise_pred) / sqrt_alpha_cumprod
        pred_x0= torch.clamp(pred_x0, -1, 1)
        #out =operator.forward(pred_x0)
        pad = int((2 / 8.0) * 256)
        x = pred_x0 * 0.5 + 0.5  # [-1, 1] -> [0, 1]
        x = F.pad(x, (pad, pad, pad, pad))
        if not torch.is_complex(x):
            x = x.type(torch.complex64)
        fft2_m = torch.view_as_complex(fft2c_new(torch.view_as_real(x)))
        out = fft2_m.abs()
        loss = torch.norm(out-y_n)**2
        loss += 0.1*torch.norm(input-input_tensor)**2
        loss.backward()
        optimizer.step()
        back_y = ifft2c_new(torch.view_as_real(fft2_m.type(torch.complex64)))
        back_y = torch.view_as_complex(back_y)
        out_visual.append(back_y)
        #print(f"Step {step}/{num_steps}, Loss: {loss.item()}, break: {1.1*torch.sqrt(torch.tensor(y.numel(), dtype=torch.float32))}")
    noise = (input_tensor-sqrt_alpha_cumprod*pred_x0)/sqrt_one_minus_alpha_cumprod
#     back_y = ifft2c_new(torch.view_as_real(y.type(torch.complex64)))
#     back_y = torch.view_as_complex(back_y)
#     pred_x0 =  
    return input_tensor.detach(), pred_x0.detach(), noise.detach()

In [13]:
out = []
n_step = 20
scheduler.set_timesteps(num_inference_steps=n_step)
step_size = 1000//n_step

In [14]:
dtype = torch.float32

In [25]:
filename = f"{13:05d}.png"            # 00011, 00012, …, 00019 etc... i
filepath = os.path.join("/egr/research-slim/liangs16/Measurment_Consistent_Diffusion_Trajectory/data/demo_remain/", filename)
gt_img = Image.open(filepath).convert("RGB")
#shutil.copy(gt_img_path, os.path.join(logdir, 'gt.png'))
ref_numpy = np.array(gt_img).astype(np.float32) / 255.0
x = ref_numpy * 2 - 1
x = x.transpose(2, 0, 1)
ref_img = torch.Tensor(x).to(dtype).to(device).unsqueeze(0)
ref_img.requires_grad = False
#ref_img = ref_img.to(device)
if measure_config['operator'] ['name'] == 'inpainting':
    mask = mask
    measurement_cond_fn = partial(cond_method.conditioning, mask=mask)
    sample_fn = partial(sample_fn, measurement_cond_fn=measurement_cond_fn)

    # Forward measurement model (Ax + n)
    y = operator.forward(ref_img, mask=mask)
    y_n = noiser(y)

else: 
    # Forward measurement model (Ax + n)
    y = operator.forward(ref_img)
    y_n = noiser(y)
    #y_n = torch.clamp(y_n,-1,1)
#y_n.requires_grad = False

In [26]:
gt = (ref_img/2+0.5)

In [27]:
resize_transform = torchvision.transforms.Resize((256,256))

In [28]:
#input = resize_transform(y_n).clone()
'''input =torch.randn((1, 3, 256, 256), device=device, dtype=dtype).clone().detach().requires_grad_(True)
noise = torch.randn(input.shape)*((1-scheduler.alphas_cumprod[-1])**0.5)
input = torch.tensor(input)*((scheduler.alphas_cumprod[-1])**0.5) + noise.to(device)'''

x0_gs = gerchberg_saxton(
        y_n=y_n,             # your magnitude measurements
        num_iters=200,
        device=device,
        dtype=dtype,
    )  # (1,1,256,256), real, [0,1]
print(x0_gs.shape)

# Repeat across 3 channels to match model.in_channels=3
'''x0_gs_rgb = x0_gs.repeat(1, 3, 1, 1)
B, C, H, W = x0_gs.shape
x0_gs = x0_gs.view(B, 3, 3, H, W).mean(dim=2)'''

# Scale to [-1,1] because your model expects that
x0_gs_rgb = x0_gs * 2.0 - 1.0  # [0,1] -> [-1,1]

# Sample x_T using DDPM forward formula
alpha_T = scheduler.alphas_cumprod[-1]
sigma_T = (1.0 - alpha_T).sqrt()

noise = torch.randn_like(x0_gs_rgb) * sigma_T
input = (alpha_T.sqrt() * x0_gs_rgb + noise).to(device)

input = input.clone().detach().requires_grad_(True)


# Clear output visualization in case didn't rerun
out_visual = []
measurement_errors = []

# Print Initial MCE
print(f"Reconstructing Image {filename}...")
init_mce = compute_mce(input, y_n.abs(), target_shape=y_n.shape)
print(f"Initial MCE:, {init_mce: .4f}")

for i, t in enumerate(scheduler.timesteps):
        prev_timestep = t - step_size

        alpha_prod_t = scheduler.alphas_cumprod[t]
        alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.alphas_cumprod[0]

        beta_prod_t = 1 - alpha_prod_t
        sqrt_one_minus_alpha_cumprod = beta_prod_t**0.5


        #for k in range(1):
        input, pred_original_sample, noise_pred= optimize_input(input.clone(), sqrt_one_minus_alpha_cumprod, alpha_prod_t**0.5, t, num_steps=20, learning_rate=0.1)
        #    input= pred_original_sample * alpha_prod_t**0.5+(1-alpha_prod_t)**0.5*torch.randn(input.size()).to(device)
        
        phase_image = (pred_original_sample / 2 + 0.5).clamp(0, 1)

        # Compute measurement error with correct shape
        mce = compute_mce(phase_image, y_n.abs(), target_shape=y_n.shape)
        measurement_errors.append(mce)
        
        input = pred_original_sample * alpha_prod_t_prev**0.5+(1-alpha_prod_t_prev)**0.5*torch.randn(input.size()).to(device)

        print(f"Time: {t}")
        print(f"Step {i+1}/{n_step}, MCE: {mce:.4f}")
        
input = (input/2+0.5).clamp(0, 1)
phase_image = (pred_original_sample/2+0.5).clamp(0, 1)
#input = (input + 1) / 2
phase_image = phase_image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()

psnr_value1 = compute_psnr(phase_image, np.array(gt.cpu().detach().numpy()[0].transpose(1,2,0)))
print(f"PSNR: {psnr_value1} dB")

torch.Size([1, 3, 256, 256])
Reconstructing Image 00013.png...
Initial MCE:,  0.0705
Time: 950
Step 1/20, MCE: 0.0088
Time: 900
Step 2/20, MCE: 0.0093
Time: 850
Step 3/20, MCE: 0.0064
Time: 800
Step 4/20, MCE: 0.0057
Time: 750
Step 5/20, MCE: 0.0051
Time: 700
Step 6/20, MCE: 0.0050
Time: 650
Step 7/20, MCE: 0.0044
Time: 600
Step 8/20, MCE: 0.0043
Time: 550
Step 9/20, MCE: 0.0040
Time: 500
Step 10/20, MCE: 0.0037
Time: 450
Step 11/20, MCE: 0.0035
Time: 400
Step 12/20, MCE: 0.0032
Time: 350
Step 13/20, MCE: 0.0031
Time: 300
Step 14/20, MCE: 0.0029
Time: 250
Step 15/20, MCE: 0.0028
Time: 200
Step 16/20, MCE: 0.0027
Time: 150
Step 17/20, MCE: 0.0026
Time: 100
Step 18/20, MCE: 0.0024
Time: 50
Step 19/20, MCE: 0.0022
Time: 0
Step 20/20, MCE: 0.0017
PSNR: 16.191903853773198 dB


## Below is Multi-Run Testing ##

In [None]:
# --- helper: pretty-print an MCE trace for one run ---
def print_mce_trace(run_idx: int, mces, timesteps):
    """
    mces: list[Tensor/float] with [init_mce, mce_at_t_0, mce_at_t_1, ...]
    timesteps: iterable of scheduler timesteps aligned with mces[1:]
    """
    # ensure plain floats for printing
    mces_f = [float(m) for m in mces]

    print(f"\nMCE trace for run {run_idx + 1}:")
    print(f"  init: {mces_f[0]:.4f}")
    for k, (t, m) in enumerate(zip(timesteps, mces_f[1:]), 1):
        # right-aligned columns: step, t, value
        print(f"  step {k:3d} (t={int(t):5d}): {m:8.4f}")
    print()  # blank line for spacing


In [None]:
# Multi-Run Testing

num_runs = 1
run_results = []
mce_results = []

print(f"Reconstructing Image {filename}...")
for run_iter in range(num_runs):
    print('Run Number:', run_iter + 1)
    
    # Clear output visualization in case didn't rerun
    out_visual = []
    measurement_errors = []
    
    input =torch.randn((1, 3, 256, 256), device=device, dtype=dtype).clone().detach().requires_grad_(True)
    noise = torch.randn(input.shape)*((1-scheduler.alphas_cumprod[-1])**0.5)
    input = torch.tensor(input)*((scheduler.alphas_cumprod[-1])**0.5) + noise.to(device)
    
    # Print Initial MCE
    init_mce = compute_mce(input, y_n.abs(), target_shape=y_n.shape) # In practice we don't have access to GT during DMs, otherwise we just know GT
    print(f"Initial MCE:, {init_mce: .4f}")
    measurement_errors.append(init_mce)

    for i, t in enumerate(scheduler.timesteps):
        prev_timestep = t - step_size

        alpha_prod_t = scheduler.alphas_cumprod[t]
        alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.alphas_cumprod[0]

        beta_prod_t = 1 - alpha_prod_t
        sqrt_one_minus_alpha_cumprod = beta_prod_t**0.5


        #for k in range(1):
        input, pred_original_sample, noise_pred= optimize_input(input.clone(), sqrt_one_minus_alpha_cumprod, alpha_prod_t**0.5, t, num_steps=20, learning_rate=0.075)
        #    input= pred_original_sample * alpha_prod_t**0.5+(1-alpha_prod_t)**0.5*torch.randn(input.size()).to(device)
        
        phase_image = (pred_original_sample / 2 + 0.5).clamp(0, 1)

        # Compute measurement error with correct shape
        mce = compute_mce(phase_image, y_n.abs(), target_shape=y_n.shape)
        measurement_errors.append(mce)
        
        input = pred_original_sample * alpha_prod_t_prev**0.5+(1-alpha_prod_t_prev)**0.5*torch.randn(input.size()).to(device)

        print(f"Time: {t}, , MCE: {mce:.4f}")
    input = (input/2+0.5).clamp(0, 1)
    phase_image = (pred_original_sample/2+0.5).clamp(0, 1)
    #input = (input + 1) / 2
    phase_image = phase_image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    
    # (A) show the full MCE trace for this run in a neat block
    print_mce_trace(run_iter, measurement_errors, scheduler.timesteps)

    # (B) store results
    run_results.append(phase_image)
    mce_results.append(mce)  # final mce
    per_image_results.append({
        "run": run_iter + 1,
        "init_mce": float(measurement_errors[0]),
        "final_mce": float(measurement_errors[-1]),
        "mce_seq": [float(v) for v in measurement_errors],  # full trace, including init
    })

In [None]:
psnr_values = []
for i in range(num_runs):
    #plt.imshow(run_results[i])
    psnr_value = compute_psnr(np.array(run_results[i]), np.array(gt.cpu().detach().numpy()[0].transpose(1,2,0)))
    psnr_values.append(psnr_value)
    print(f"After diffusion PSNR: {psnr_value} dB, MCE: {mce_results[i]: .4f}")

In [None]:
plt.hist(psnr_values, bins = 30, alpha = 0.8, edgecolor = 'black')
plt.xlabel("PSNR (dB)")
plt.ylabel("Frequency")
plt.title(f"Distribution of PSNR Values on {num_runs} Runs of Image {filename} with Random Input")
plt.grid(True)
plt.show()

## Make one file for multi run testing ##

In [None]:
import os, time, json, math
import numpy as np
import pandas as pd
from collections import defaultdict

def _make_step_labels(scheduler, include_init=True):
    """
    Build labels like: ['initial', f"t{t}", ..., 't0'] matching scheduler.timesteps.
    """
    labels = []
    if include_init:
        labels.append("initial")
    # timesteps is typically descending (e.g., [999, 979, ...])
    for t in list(scheduler.timesteps):
        labels.append(f"t{int(t)}")
    return labels

class DMRunRecorder:
    """
    Accumulates per-run traces and PSNRs for one image,
    then computes per-image aggregates and returns tidy + wide DataFrames.
    """
    def __init__(self, image_id, scheduler):
        self.image_id = image_id
        self.scheduler = scheduler
        self.step_labels = _make_step_labels(scheduler, include_init=True)
        self.num_steps = len(self.step_labels)  # initial + len(timesteps)
        self._per_run_mce_traces = []  # list of list[float] length == num_steps
        self._per_run_final_psnr = []  # list[float]
        self._per_run_final_mce  = []  # list[float]

    def record_run(self, init_mce, mce_trace_per_step, final_psnr):
        """
        init_mce: float
        mce_trace_per_step: list of floats for *each diffusion step in order of scheduler.timesteps*.
            We'll prepend init_mce so final length is num_steps.
        final_psnr: float
        """
        assert isinstance(mce_trace_per_step, (list, tuple)), "mce_trace_per_step must be a list"
        trace = [float(init_mce)] + [float(x) for x in mce_trace_per_step]
        if len(trace) != self.num_steps:
            raise ValueError(
                f"MCE trace length mismatch for image {self.image_id}: "
                f"got {len(trace)}, expected {self.num_steps} "
                f"(= 1 init + {len(self.step_labels)-1} steps)"
            )
        self._per_run_mce_traces.append(trace)
        self._per_run_final_psnr.append(float(final_psnr))
        self._per_run_final_mce.append(float(trace[-1]))

    def to_dataframes(self):
        """
        Returns:
          tidy_df: long table with (image, run, step_idx, step_label, mce)
          per_image_summary_df: one row with per-step averages and changes, plus psnr_avg
        """
        n_runs = len(self._per_run_mce_traces)
        # --- Tidy long table ---
        tidy_rows = []
        for r, trace in enumerate(self._per_run_mce_traces, start=1):
            for k, (label, mce_val) in enumerate(zip(self.step_labels, trace)):
                tidy_rows.append({
                    "image": self.image_id,
                    "run": r,
                    "step_idx": k,
                    "step_label": label,
                    "mce": float(mce_val),
                })
        tidy_df = pd.DataFrame(tidy_rows)

        # --- Per-image averages per step ---
        traces = np.array(self._per_run_mce_traces, dtype=float)  # shape (n_runs, num_steps)
        step_means = traces.mean(axis=0)                # avg MCE per step
        change_from_init = step_means - step_means[0]   # Δ vs initial per step
        # Δ vs previous (NaN at initial)
        change_from_prev = np.empty_like(step_means)
        change_from_prev[:] = np.nan
        change_from_prev[1:] = step_means[1:] - step_means[:-1]

        psnr_avg = float(np.mean(self._per_run_final_psnr)) if len(self._per_run_final_psnr) else np.nan

        # Build a single-row wide summary with:
        #   MCE_<label>, dInit_<label>, dPrev_<label>, and final_psnr_avg
        summary = {"image": self.image_id, "final_psnr_avg": psnr_avg}
        for lbl, m, dI, dP in zip(self.step_labels, step_means, change_from_init, change_from_prev):
            summary[f"MCE_{lbl}"]   = float(m)
            summary[f"dInit_{lbl}"] = float(dI)
            summary[f"dPrev_{lbl}"] = float(dP) if not math.isnan(dP) else np.nan

        per_image_summary_df = pd.DataFrame([summary])
        return tidy_df, per_image_summary_df


In [None]:
# --- Config (as you already have) ---
out = []
n_step = 20
scheduler.set_timesteps(num_inference_steps=n_step)
step_size = 1000 // n_step
dtype = torch.float32

# Where to save the outputs:
export_dir = "./dm_results_exports"
os.makedirs(export_dir, exist_ok=True)

# Collect all images’ tidy/wide in memory to write one combined CSV per type
all_tidy = []
all_summary = []

for j in range(11, 12):
    filename = f"{j:05d}.png"   # 00011 ... 00020
    filepath = os.path.join("/egr/research-slim/liangs16/Measurment_Consistent_Diffusion_Trajectory/data/demo_remain/", filename)

    gt_img = Image.open(filepath).convert("RGB")
    ref_numpy = np.array(gt_img).astype(np.float32) / 255.0
    x = ref_numpy * 2 - 1
    x = x.transpose(2, 0, 1)
    ref_img = torch.tensor(x, dtype=dtype, device=device).unsqueeze(0)
    ref_img.requires_grad = False

    # Forward measurement model (Ax + n)
    if measure_config['operator']['name'] == 'inpainting':
        measurement_cond_fn = partial(cond_method.conditioning, mask=mask)
        sample_fn = partial(sample_fn, measurement_cond_fn=measurement_cond_fn)
        y = operator.forward(ref_img, mask=mask)
        y_n = noiser(y)
    else:
        y = operator.forward(ref_img)
        y_n = noiser(y)

    gt = (ref_img / 2 + 0.5)  # in [0,1] torch
    gt_np = gt.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()  # (H,W,C) in [0,1]
    
    gt = (ref_img/2+0.5)
    # resize_transform = torchvision.transforms.Resize((256,256))

    num_runs = 2
    print(f"Reconstructing Image {filename}...")

    # Create per-image recorder
    rec = DMRunRecorder(image_id=f"{j:05d}", scheduler=scheduler)

    for run_iter in range(num_runs):
        # fresh init
        input =torch.randn((1, 3, 256, 256), device=device, dtype=dtype).clone().detach().requires_grad_(True)
        noise = torch.randn(input.shape)*((1-scheduler.alphas_cumprod[-1])**0.5)
        input = torch.tensor(input)*((scheduler.alphas_cumprod[-1])**0.5) + noise.to(device)

        # Initial MCE (using the *current sample*; for pure theory you might prefer pred_original_sample @ init)
        init_mce = compute_mce(input, y_n.abs(), target_shape=y_n.shape)

        # Will store MCE at each diffusion step
        mce_trace = []

        # Diffusion sweep
        pred_original_sample = None  # keep last
        for i, t in enumerate(scheduler.timesteps):
            prev_timestep = t - step_size
            alpha_prod_t = scheduler.alphas_cumprod[t]
            alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.alphas_cumprod[0]
            beta_prod_t = 1 - alpha_prod_t
            sqrt_one_minus_alpha_cumprod = beta_prod_t ** 0.5

            # Optimize current latent toward measurement consistency (your routine)
            input, pred_original_sample, noise_pred = optimize_input(
                input.clone(),
                sqrt_one_minus_alpha_cumprod,
                alpha_prod_t ** 0.5,
                t,
                num_steps=20,
                learning_rate=0.075
            )
            

            # Compute MCE on denoised image proxy
            phase_image = (pred_original_sample / 2 + 0.5).clamp(0, 1)
            mce_val = compute_mce(phase_image, y_n.abs(), target_shape=y_n.shape)
            mce_trace.append(float(mce_val))

            # Prepare next step x_{t-1}
            input = pred_original_sample * (alpha_prod_t_prev ** 0.5) + (1 - alpha_prod_t_prev) ** 0.5 * torch.randn_like(input)

        # Final recon for PSNR (use last pred_original_sample)
        phase_image = (pred_original_sample / 2 + 0.5).clamp(0, 1)
        recon_np = phase_image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
        final_psnr = float(compute_psnr(recon_np, gt_np))

        # Record the run (this stores [init] + per-step MCE trace under the hood)
        rec.record_run(init_mce=float(init_mce), mce_trace_per_step=mce_trace, final_psnr=final_psnr)

    # After all runs for this image: convert to DataFrames
    tidy_df, summary_df = rec.to_dataframes()

    # Save per-image (optional)
    ts = time.strftime("%Y%m%d-%H%M%S")
    tidy_path    = os.path.join(export_dir, f"{j:05d}_runs_tidy_{ts}.csv")
    summary_path = os.path.join(export_dir, f"{j:05d}_summary_{ts}.csv")
    tidy_df.to_csv(tidy_path, index=False)
    summary_df.to_csv(summary_path, index=False)

    # Accumulate for global combined files
    all_tidy.append(tidy_df)
    all_summary.append(summary_df)

# Write global combined CSVs (all images)
tidy_all = pd.concat(all_tidy, ignore_index=True)
summary_all = pd.concat(all_summary, ignore_index=True)
tidy_all_path = os.path.join(export_dir, f"ALL_images_runs_tidy_{n_step}.csv")
summary_all_path = os.path.join(export_dir, f"ALL_images_summary_wide_{n_step}.csv")
tidy_all.to_csv(tidy_all_path, index=False)
summary_all.to_csv(summary_all_path, index=False)

# Optional: write one consolidated Excel with two sheets
# xlsx_path = os.path.join(export_dir, "ALL_images_results.xlsx")
# with pd.ExcelWriter(xlsx_path, engine="xlsxwriter") as writer:
#    tidy_all.to_excel(writer, sheet_name="runs_tidy", index=False)
#    summary_all.to_excel(writer, sheet_name="summary_wide", index=False)

print("Saved:", tidy_all_path, summary_all_path, xlsx_path)

In [None]:
# --- Long-format exporters (place this right after you have tidy_all & summary_all) ---
import os
import re
import pandas as pd
import numpy as np

# 0) Make sure your export dir exists
export_dir = "./dm_results_exports"
os.makedirs(export_dir, exist_ok=True)

# 1) Save per-run tidy exactly as-is (already long)
tidy_all_path = os.path.join(export_dir, f"ALL_images_runs_tidy_{n_step}.csv")
tidy_all.to_csv(tidy_all_path, index=False)

# 2) Convert per-image "wide" summary to long, with clean step ordering
def _step_sort_key(lbl: str):
    # order: initial first, then t<...> descending (t1000, t950, ..., t0)
    if lbl == "initial":
        return (0, 0)  # smallest
    m = re.fullmatch(r"t(\d+)", lbl)
    if m:
        return (1, -int(m.group(1)))  # descending by numeric
    # put anything unknown last, stable order
    return (2, lbl)

def summary_wide_to_long(summary_wide_df: pd.DataFrame) -> pd.DataFrame:
    """
    Input columns include:
      - image
      - final_psnr_avg
      - MCE_<label>, dInit_<label>, dPrev_<label> for labels in {'initial','tXXXX',...}
    Output long DF (one row per image, per step label):
      image, step_idx, step_label, avg_mce, dInit, dPrev
    Plus an extra row with step_label='final_psnr' containing final_psnr_avg.
    """
    rows = []
    for _, row in summary_wide_df.iterrows():
        img = row["image"]
        # discover labels present by scanning MCE_* keys
        labels = [k.replace("MCE_", "") for k in row.index if str(k).startswith("MCE_")]
        labels = sorted(labels, key=_step_sort_key)

        for k, lbl in enumerate(labels):
            mce   = float(row[f"MCE_{lbl}"])
            dInit = float(row.get(f"dInit_{lbl}", np.nan))
            dPrev = float(row.get(f"dPrev_{lbl}", np.nan))
            rows.append({
                "image": img,
                "step_idx": k,
                "step_label": lbl,
                "avg_mce": mce,
                "dInit": dInit,
                "dPrev": dPrev,
            })

        # append PSNR as one special row (no MCE/diffs there)
        rows.append({
            "image": img,
            "step_idx": len(labels),          # after the last step
            "step_label": "final_psnr",
            "avg_mce": np.nan,
            "dInit": np.nan,
            "dPrev": np.nan,
            "final_psnr_avg": float(row["final_psnr_avg"]),
        })

    return pd.DataFrame(rows)

# 5 decimal places everywhere; no scientific notation
step_averages_long = summary_wide_to_long(summary_all)
step_averages_long_path = os.path.join(export_dir, f"ALL_images_step_averages_long_{n_step}.csv")
step_averages_long.to_csv(step_averages_long_path, index=False)


# 3) (Optional) Ultra-tidy molten table: one value per row (handy for seaborn/plotnine)
#    Produces rows like: (image, step_idx, step_label, metric, value)
def melt_step_averages_long(df: pd.DataFrame) -> pd.DataFrame:
    part_mce = df.loc[df["step_label"] != "final_psnr",
                      ["image","step_idx","step_label","avg_mce","dInit","dPrev"]].copy()
    melted_mce = part_mce.melt(id_vars=["image","step_idx","step_label"],
                               value_vars=["avg_mce","dInit","dPrev"],
                               var_name="metric",
                               value_name="value")
    part_psnr = df.loc[df["step_label"] == "final_psnr",
                       ["image","step_idx","step_label","final_psnr_avg"]].copy()
    part_psnr = part_psnr.rename(columns={"final_psnr_avg":"value"})
    part_psnr["metric"] = "final_psnr_avg"
    part_psnr = part_psnr[["image","step_idx","step_label","metric","value"]]
    return pd.concat([melted_mce, part_psnr], ignore_index=True)

metrics_long = melt_step_averages_long(step_averages_long)
metrics_long_path = os.path.join(export_dir, "ALL_images_metrics_long.csv")
metrics_long.to_csv(metrics_long_path, index=False)

tidy_all.to_csv(tidy_all_path, index=False, float_format="%.5f")
step_averages_long.to_csv(step_averages_long_path, index=False, float_format="%.5f")
metrics_long.to_csv(metrics_long_path, index=False, float_format="%.5f")

print("Saved long-format files:")
print(" • Per-run tidy:", tidy_all_path)
print(" • Per-image step averages (long):", step_averages_long_path)
print(" • (Optional) Ultra-tidy metric/value table:", metrics_long_path)


No Pandas -- Try This

In [12]:
# ======== NO-PANDAS RECORDER + EXPORTS ========
import os, time, json, math, csv
import numpy as np

# ----- formatting helpers -----
# --- helper: pretty-print an MCE trace for one run ---
def print_mce_trace(run_idx: int, mces, timesteps):
    """
    mces: list[Tensor/float] with [init_mce, mce_at_t_0, mce_at_t_1, ...]
    timesteps: iterable of scheduler timesteps aligned with mces[1:]
    """
    # ensure plain floats for printing
    mces_f = [float(m) for m in mces]

    print(f"\nMCE trace for run {run_idx + 1}:")
    print(f"  init: {mces_f[0]:.4f}")
    for k, (t, m) in enumerate(zip(timesteps, mces_f[1:]), 1):
        # right-aligned columns: step, t, value
        print(f"  step {k:3d} (t={int(t):5d}): {m:8.4f}")
    print()  # blank line for spacing
    
def is_nan(x):
    try:
        return np.isnan(x)
    except Exception:
        return False

def fmt5(x):
    """Format floats to 5 decimals; return '' for NaN/None."""
    if x is None or is_nan(x):
        return ""
    return f"{float(x):.5f}"

def fmt2(x):
    if x is None or is_nan(x):
        return ""
    return f"{float(x):.2f}"

# ----- step label helpers -----
def _make_step_labels(scheduler, include_init=True):
    labels = []
    if include_init:
        labels.append("initial")
    for t in list(scheduler.timesteps):  # typically descending
        labels.append(f"t{int(t)}")
    return labels

def _step_sort_key(lbl: str):
    # order: initial first, then t<...> descending (t1000, t950, ..., t0)
    if lbl == "initial":
        return (0, 0)
    if lbl.startswith("t") and lbl[1:].isdigit():
        return (1, -int(lbl[1:]))
    return (2, lbl)

# ----- recorder without pandas -----
class DMRunRecorderNP:
    """
    Accumulates per-run MCE traces and PSNRs for one image,
    then produces tidy rows and a single wide summary dict.
    """
    def __init__(self, image_id, scheduler):
        self.image_id = image_id
        self.scheduler = scheduler
        self.step_labels = _make_step_labels(scheduler, include_init=True)
        self.num_steps = len(self.step_labels)
        self._per_run_mce_traces = []  # list[list[float]] length = num_steps
        self._per_run_final_psnr = []  # list[float]

    def record_run(self, init_mce, mce_trace_per_step, final_psnr):
        assert isinstance(mce_trace_per_step, (list, tuple)), "mce_trace_per_step must be list/tuple"
        trace = [float(init_mce)] + [float(x) for x in mce_trace_per_step]
        if len(trace) != self.num_steps:
            raise ValueError(
                f"MCE trace length mismatch for image {self.image_id}: "
                f"got {len(trace)}, expected {self.num_steps} "
                f"(= 1 init + {len(self.step_labels)-1} steps)"
            )
        self._per_run_mce_traces.append(trace)
        self._per_run_final_psnr.append(float(final_psnr))

    def to_records(self):
        """
        Returns:
          tidy_rows: list of dicts: {image, run, step_idx, step_label, mce}
          summary_row: dict with keys:
            - 'image', 'final_psnr_avg'
            - 'MCE_<label>', 'dInit_<label>', 'dPrev_<label>' for each step label
        """
        # --- tidy rows (long per-run) ---
        tidy_rows = []
        for r, trace in enumerate(self._per_run_mce_traces, start=1):
            for k, (label, mce_val) in enumerate(zip(self.step_labels, trace)):
                tidy_rows.append({
                    "image": self.image_id,
                    "run": r,
                    "step_idx": k,
                    "step_label": label,
                    "mce": float(mce_val),
                })

        # --- per-image averages per step ---
        traces = np.array(self._per_run_mce_traces, dtype=float)  # (n_runs, num_steps)
        step_means = traces.mean(axis=0) if traces.size > 0 else np.full(self.num_steps, np.nan)
        dInit = step_means - step_means[0] if traces.size > 0 else np.full(self.num_steps, np.nan)
        dPrev = np.full(self.num_steps, np.nan)
        if traces.size > 0:
            dPrev[1:] = step_means[1:] - step_means[:-1]

        psnr_avg = float(np.mean(self._per_run_final_psnr)) if len(self._per_run_final_psnr) else np.nan

        summary_row = {"image": self.image_id, "final_psnr_avg": psnr_avg}
        for lbl, m, di, dp in zip(self.step_labels, step_means, dInit, dPrev):
            summary_row[f"MCE_{lbl}"]   = float(m)
            summary_row[f"dInit_{lbl}"] = float(di)
            summary_row[f"dPrev_{lbl}"] = float(dp)
        return tidy_rows, summary_row

# ----- CSV writing helpers (no pandas) -----
def write_runs_tidy_csv(path, rows):
    # rows: list of dicts with keys ["image","run","step_idx","step_label","mce"]
    header = ["image","run","step_idx","step_label","mce"]
    with open(path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(header)
        for r in rows:
            w.writerow([
                r["image"], int(r["run"]), int(r["step_idx"]), r["step_label"], fmt5(r["mce"])
            ])

def build_summary_wide_header(step_labels):
    header = ["image", "final_psnr_avg"]
    for lbl in step_labels:
        header += [f"MCE_{lbl}", f"dInit_{lbl}", f"dPrev_{lbl}"]
    return header

def write_summary_wide_csv(path, rows, step_labels):
    # rows: list of summary_row dicts from recorder
    header = build_summary_wide_header(step_labels)
    with open(path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(header)
        for r in rows:
            row_out = [r["image"], fmt5(r.get("final_psnr_avg", np.nan))]
            for lbl in step_labels:
                row_out.append(fmt5(r.get(f"MCE_{lbl}", np.nan)))
                row_out.append(fmt5(r.get(f"dInit_{lbl}", np.nan)))
                row_out.append(fmt5(r.get(f"dPrev_{lbl}", np.nan)))
            w.writerow(row_out)

def summary_wide_to_step_averages_long(summary_rows, step_labels):
    """
    summary_rows: list of summary_row dicts
    Returns list of dict rows:
      - for MCE steps: {image, step_idx, step_label, avg_mce, dInit, dPrev}
      - for PSNR:      {image, step_idx, step_label='final_psnr', final_psnr_avg}
    """
    out = []
    # ensure sorted step order
    ordered = sorted(step_labels, key=_step_sort_key)
    for r in summary_rows:
        img = r["image"]
        for k, lbl in enumerate(ordered):
            out.append({
                "image": img,
                "step_idx": k,
                "step_label": lbl,
                "avg_mce": r.get(f"MCE_{lbl}", np.nan),
                "dInit":   r.get(f"dInit_{lbl}", np.nan),
                "dPrev":   r.get(f"dPrev_{lbl}", np.nan),
            })
        out.append({
            "image": img,
            "step_idx": len(ordered),
            "step_label": "final_psnr",
            "final_psnr_avg": r.get("final_psnr_avg", np.nan),
        })
    return out

def write_step_averages_long_csv(path, rows):
    # rows: from summary_wide_to_step_averages_long
    # Columns: image, step_idx, step_label, avg_mce, dInit, dPrev, (optional) final_psnr_avg
    header = ["image","step_idx","step_label","avg_mce","dInit","dPrev","final_psnr_avg"]
    with open(path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(header)
        for r in rows:
            if r["step_label"] == "final_psnr":
                w.writerow([r["image"], int(r["step_idx"]), "final_psnr", "", "", "", fmt5(r["final_psnr_avg"])])
            else:
                w.writerow([
                    r["image"], int(r["step_idx"]), r["step_label"],
                    fmt5(r["avg_mce"]), fmt5(r["dInit"]), fmt5(r["dPrev"]), ""
                ])

def melt_metrics_long(step_avg_long_rows):
    """
    Input: rows from summary_wide_to_step_averages_long
    Output: list of dicts with columns: image, step_idx, step_label, metric, value
    Where metric ∈ {'avg_mce','dInit','dPrev'} for MCE steps, or 'final_psnr_avg' for PSNR row.
    """
    out = []
    for r in step_avg_long_rows:
        if r["step_label"] == "final_psnr":
            out.append({
                "image": r["image"], "step_idx": r["step_idx"], "step_label": "final_psnr",
                "metric": "final_psnr_avg", "value": r.get("final_psnr_avg", np.nan)
            })
        else:
            for mkey in ["avg_mce","dInit","dPrev"]:
                out.append({
                    "image": r["image"], "step_idx": r["step_idx"], "step_label": r["step_label"],
                    "metric": mkey, "value": r.get(mkey, np.nan)
                })
    return out

def write_metrics_long_csv(path, rows):
    header = ["image","step_idx","step_label","metric","value"]
    with open(path, "w", newline="") as f:
        w = csv.writer(f)
        w.writerow(header)
        for r in rows:
            w.writerow([r["image"], int(r["step_idx"]), r["step_label"], r["metric"], fmt5(r["value"])])

# ======== INTEGRATE WITH YOUR MAIN LOOP (replace pandas bits) ========

# --- Config (as you already have) ---
out = []
n_step = 20
scheduler.set_timesteps(num_inference_steps=n_step)
step_size = 1000 // n_step
dtype = torch.float32

# Save dir
export_dir = "./dm_results_exports"
os.makedirs(export_dir, exist_ok=True)

all_tidy_rows = []     # accumulate all images' tidy rows
all_summary_rows = []  # accumulate all images' summary rows
global_step_labels = _make_step_labels(scheduler, include_init=True)

for j in range(19, 20):  # adjust range as needed
    filename = f"{j:05d}.png"
    filepath = os.path.join("/egr/research-slim/liangs16/Measurment_Consistent_Diffusion_Trajectory/data/demo_remain/", filename)

    gt_img = Image.open(filepath).convert("RGB")
    ref_numpy = np.array(gt_img).astype(np.float32) / 255.0
    x = ref_numpy * 2 - 1
    x = x.transpose(2, 0, 1)
    ref_img = torch.tensor(x, dtype=dtype, device=device).unsqueeze(0)
    ref_img.requires_grad = False

    # Forward measurement model (Ax + n)
    if measure_config['operator']['name'] == 'inpainting':
        measurement_cond_fn = partial(cond_method.conditioning, mask=mask)
        sample_fn = partial(sample_fn, measurement_cond_fn=measurement_cond_fn)
        y = operator.forward(ref_img, mask=mask)
        y_n = noiser(y)
    else:
        y = operator.forward(ref_img)
        y_n = noiser(y)

    gt = (ref_img / 2 + 0.5)
    gt_np = gt.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()

    num_runs = 10
    print(f"Reconstructing Image {filename}...")

    rec = DMRunRecorderNP(image_id=f"{j:05d}", scheduler=scheduler)

    for run_iter in range(num_runs):
        # fresh init
        # input = torch.randn((1, 3, 256, 256), device=device, dtype=dtype).clone().detach().requires_grad_(True)
        # noise = torch.randn_like(input) * ((1 - scheduler.alphas_cumprod[-1]) ** 0.5)
        # input = input * (scheduler.alphas_cumprod[-1] ** 0.5) + noise  # avoid torch.tensor(input)
        
        input =torch.randn((1, 3, 256, 256), device=device, dtype=dtype).clone().detach().requires_grad_(True)
        noise = torch.randn(input.shape)*((1-scheduler.alphas_cumprod[-1])**0.5)
        input = torch.tensor(input)*((scheduler.alphas_cumprod[-1])**0.5) + noise.to(device)


        # Initial MCE
        init_mce = compute_mce(input, y_n.abs(), target_shape=y_n.shape)

        mce_trace = []
        pred_original_sample = None

        for i, t in enumerate(scheduler.timesteps):
            prev_timestep = t - step_size
            alpha_prod_t = scheduler.alphas_cumprod[t]
            alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.alphas_cumprod[0]
            beta_prod_t = 1 - alpha_prod_t
            sqrt_one_minus_alpha_cumprod = beta_prod_t ** 0.5

            # Your inner optimizer (ensure it builds a fresh graph per iter, as discussed earlier)
            input, pred_original_sample, noise_pred = optimize_input(
                input.clone(),
                sqrt_one_minus_alpha_cumprod,
                alpha_prod_t ** 0.5,
                t,
                num_steps=20,
                learning_rate=0.075
            )

            phase_image = (pred_original_sample / 2 + 0.5).clamp(0, 1)
            mce_val = compute_mce(phase_image, y_n.abs(), target_shape=y_n.shape)
            mce_trace.append(float(mce_val))

            input = pred_original_sample * (alpha_prod_t_prev ** 0.5) + (1 - alpha_prod_t_prev) ** 0.5 * torch.randn_like(input)

        # Final PSNR
        phase_image = (pred_original_sample / 2 + 0.5).clamp(0, 1)
        recon_np = phase_image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
        final_psnr = float(compute_psnr(recon_np, gt_np))

        rec.record_run(init_mce=float(init_mce), mce_trace_per_step=mce_trace, final_psnr=final_psnr)

    # Convert to tidy rows + per-image summary row (no pandas)
    tidy_rows, summary_row = rec.to_records()
    # Write per-image (optional)
    ts = time.strftime("%Y%m%d-%H%M%S")
    per_image_tidy_path = os.path.join(export_dir, f"{j:05d}_runs_tidy_{ts}_{j}.csv")
    per_image_summary_path = os.path.join(export_dir, f"{j:05d}_summary_{ts}_{j}.csv")
    write_runs_tidy_csv(per_image_tidy_path, tidy_rows)
    write_summary_wide_csv(per_image_summary_path, [summary_row], rec.step_labels)

    # Accumulate
    all_tidy_rows.extend(tidy_rows)
    all_summary_rows.append(summary_row)

# ---- Write combined CSVs (ALL images) ----
tidy_all_path = os.path.join(export_dir, f"ALL_images_runs_tidy_{n_step}_{j}.csv")
summary_wide_all_path = os.path.join(export_dir, f"ALL_images_summary_wide_{n_step}_{j}.csv")
write_runs_tidy_csv(tidy_all_path, all_tidy_rows)
write_summary_wide_csv(summary_wide_all_path, all_summary_rows, global_step_labels)

# ---- Long per-image step averages (avg_mce, dInit, dPrev) + PSNR row ----
step_averages_long_rows = summary_wide_to_step_averages_long(all_summary_rows, global_step_labels)
step_averages_long_path = os.path.join(export_dir, f"ALL_images_step_averages_long_{n_step}_{j}.csv")
write_step_averages_long_csv(step_averages_long_path, step_averages_long_rows)

# ---- Ultra-tidy metric/value table (optional) ----
metrics_long_rows = melt_metrics_long(step_averages_long_rows)
metrics_long_path = os.path.join(export_dir, f"ALL_images_metrics_long_{n_step}_{j}.csv")
write_metrics_long_csv(metrics_long_path, metrics_long_rows)

print("Saved:")
print(" • Per-run tidy:", tidy_all_path)
print(" • Per-image summary wide:", summary_wide_all_path)
print(" • Per-image step averages (long):", step_averages_long_path)
print(" • Ultra-tidy metric/value table:", metrics_long_path)


Reconstructing Image 00019.png...


  input = torch.tensor(input)*((scheduler.alphas_cumprod[-1])**0.5) + noise.to(device)


Saved:
 • Per-run tidy: ./dm_results_exports/ALL_images_runs_tidy_20_19.csv
 • Per-image summary wide: ./dm_results_exports/ALL_images_summary_wide_20_19.csv
 • Per-image step averages (long): ./dm_results_exports/ALL_images_step_averages_long_20_19.csv
 • Ultra-tidy metric/value table: ./dm_results_exports/ALL_images_metrics_long_20_19.csv


Old one

In [None]:
## Make one file for multi run testing 

out = []
n_step = 20
scheduler.set_timesteps(num_inference_steps=n_step)
step_size = 1000//n_step
dtype = torch.float32

# Loop through images 00011, 00012, …, 00019 etc... i

for j in range(11, 12):
    filename = f"{j:05d}.png"            # 00011, 00012, …, 00019 etc... i
    filepath = os.path.join("/egr/research-slim/liangs16/Measurment_Consistent_Diffusion_Trajectory/data/demo_remain/", filename)
    gt_img = Image.open(filepath).convert("RGB")
    #shutil.copy(gt_img_path, os.path.join(logdir, 'gt.png'))
    ref_numpy = np.array(gt_img).astype(np.float32) / 255.0
    x = ref_numpy * 2 - 1
    x = x.transpose(2, 0, 1)
    ref_img = torch.Tensor(x).to(dtype).to(device).unsqueeze(0)
    ref_img.requires_grad = False
    #ref_img = ref_img.to(device)
    if measure_config['operator'] ['name'] == 'inpainting':
        mask = mask
        measurement_cond_fn = partial(cond_method.conditioning, mask=mask)
        sample_fn = partial(sample_fn, measurement_cond_fn=measurement_cond_fn)

        # Forward measurement model (Ax + n)
        y = operator.forward(ref_img, mask=mask)
        y_n = noiser(y)

    else: 
        # Forward measurement model (Ax + n)
        y = operator.forward(ref_img)
        y_n = noiser(y)
        #y_n = torch.clamp(y_n,-1,1)
    #y_n.requires_grad = False

    gt = (ref_img/2+0.5)
    resize_transform = torchvision.transforms.Resize((256,256))

    # Multi-Run Testing

    num_runs = 2
    run_results = []
    mce_results = []
    per_image_results = []        # optional: collect per-run records for this image

    print(f"Reconstructing Image {filename}...")
    for run_iter in range(num_runs):
        # print('Run Number:', run_iter + 1)
    
        # Clear output visualization in case didn't rerun
        out_visual = []
        measurement_errors = []
    
        input =torch.randn((1, 3, 256, 256), device=device, dtype=dtype).clone().detach().requires_grad_(True)
        noise = torch.randn(input.shape)*((1-scheduler.alphas_cumprod[-1])**0.5)
        input = torch.tensor(input)*((scheduler.alphas_cumprod[-1])**0.5) + noise.to(device)
    
        # Print Initial MCE
        init_mce = compute_mce(input, y_n.abs(), target_shape=y_n.shape) # In practice we don't have access to GT during DMs, otherwise we just know GT
        # print(f"Initial MCE:, {init_mce: .4f}")

        for i, t in enumerate(scheduler.timesteps):
            prev_timestep = t - step_size

            alpha_prod_t = scheduler.alphas_cumprod[t]
            alpha_prod_t_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.alphas_cumprod[0]

            beta_prod_t = 1 - alpha_prod_t
            sqrt_one_minus_alpha_cumprod = beta_prod_t**0.5


            #for k in range(1):
            input, pred_original_sample, noise_pred= optimize_input(input.clone(), sqrt_one_minus_alpha_cumprod, alpha_prod_t**0.5, t, num_steps=20, learning_rate=0.075)
            #    input= pred_original_sample * alpha_prod_t**0.5+(1-alpha_prod_t)**0.5*torch.randn(input.size()).to(device)
        
            phase_image = (pred_original_sample / 2 + 0.5).clamp(0, 1)

            # Compute measurement error with correct shape
            mce = compute_mce(phase_image, y_n.abs(), target_shape=y_n.shape)
            measurement_errors.append(mce)
        
            input = pred_original_sample * alpha_prod_t_prev**0.5+(1-alpha_prod_t_prev)**0.5*torch.randn(input.size()).to(device)

            # print(f"Time: {t}, , MCE: {mce:.4f}")
        input = (input/2+0.5).clamp(0, 1)
        phase_image = (pred_original_sample/2+0.5).clamp(0, 1)
        #input = (input + 1) / 2
        phase_image = phase_image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    
        run_results.append(phase_image)
        mce_results.append(mce)
        
        # (A) show the full MCE trace for this run in a neat block
        print_mce_trace(run_iter, measurement_errors, scheduler.timesteps)

        # (B) store results
        run_results.append(phase_image)
        mce_results.append(mce)  # final mce
        per_image_results.append({
            "run": run_iter + 1,
            "init_mce": float(measurement_errors[0]),
            "final_mce": float(measurement_errors[-1]),
            "mce_seq": [float(v) for v in measurement_errors],  # full trace, including init
        })

        psnr_values = []
    for i in range(num_runs):
        #plt.imshow(run_results[i])
        psnr_value = compute_psnr(np.array(run_results[i]), np.array(gt.cpu().detach().numpy()[0].transpose(1,2,0)))
        psnr_values.append(psnr_value)
        print(f"After diffusion PSNR: {psnr_value} dB, MCE: {mce_results[i]: .4f}")



Reconstructing Image 00011.png...
Run Number: 1
/tmp/ipykernel_435532/1188489366.py:57: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  input = torch.tensor(input)*((scheduler.alphas_cumprod[-1])**0.5) + noise.to(device)
Run Number: 2
Run Number: 3
Run Number: 4
Run Number: 5
Run Number: 6
Run Number: 7
Run Number: 8
Run Number: 9
Run Number: 10
After diffusion PSNR: 28.100968456509925 dB, MCE:  0.0015
After diffusion PSNR: 28.270696040092986 dB, MCE:  0.0015
After diffusion PSNR: 11.455024601889816 dB, MCE:  0.0018
After diffusion PSNR: 28.094056601340395 dB, MCE:  0.0015
After diffusion PSNR: 28.05229052845054 dB, MCE:  0.0015
After diffusion PSNR: 11.66003553083227 dB, MCE:  0.0017
After diffusion PSNR: 11.559687522342362 dB, MCE:  0.0017
After diffusion PSNR: 28.188196534392603 dB, MCE:  0.0015
After diffusion PSNR: 28.174921946261616 dB, MCE:  0.0015
After diffusion PSNR: 27.94670647299117 dB, MCE:  0.0015
Reconstructing Image 00012.png...
Run Number: 1
Run Number: 2
Run Number: 3
Run Number: 4
Run Number: 5
Run Number: 6
Run Number: 7
Run Number: 8
Run Number: 9
Run Number: 10
After diffusion PSNR: 26.940155056594772 dB, MCE:  0.0015
After diffusion PSNR: 26.731576885850366 dB, MCE:  0.0015
After diffusion PSNR: 26.66404642307755 dB, MCE:  0.0015
After diffusion PSNR: 27.3286891149528 dB, MCE:  0.0015
After diffusion PSNR: 27.5034707883839 dB, MCE:  0.0015
After diffusion PSNR: 27.063347006811128 dB, MCE:  0.0015
After diffusion PSNR: 27.474818271877446 dB, MCE:  0.0015
After diffusion PSNR: 27.39552779143446 dB, MCE:  0.0015
After diffusion PSNR: 26.921228538103055 dB, MCE:  0.0015
After diffusion PSNR: 26.60586533247784 dB, MCE:  0.0015
Reconstructing Image 00013.png...
Run Number: 1
Run Number: 2
Run Number: 3
Run Number: 4
Run Number: 5
Run Number: 6
Run Number: 7
Run Number: 8
Run Number: 9
Run Number: 10
After diffusion PSNR: 11.363108193156277 dB, MCE:  0.0017
After diffusion PSNR: 12.083036982990214 dB, MCE:  0.0017
After diffusion PSNR: 16.64904020110488 dB, MCE:  0.0017
After diffusion PSNR: 9.415043187755677 dB, MCE:  0.0017
After diffusion PSNR: 8.866917820733708 dB, MCE:  0.0017
After diffusion PSNR: 9.646573669323892 dB, MCE:  0.0018
After diffusion PSNR: 20.269110522060153 dB, MCE:  0.0016
After diffusion PSNR: 8.729081822900417 dB, MCE:  0.0017
After diffusion PSNR: 18.484052998158578 dB, MCE:  0.0016
After diffusion PSNR: 19.98795395736829 dB, MCE:  0.0016
Reconstructing Image 00014.png...
Run Number: 1
Run Number: 2
Run Number: 3
Run Number: 4
Run Number: 5
Run Number: 6
Run Number: 7
Run Number: 8
Run Number: 9
Run Number: 10
After diffusion PSNR: 26.929716242264906 dB, MCE:  0.0015
After diffusion PSNR: 26.985969714366604 dB, MCE:  0.0015
After diffusion PSNR: 27.1693873825237 dB, MCE:  0.0015
After diffusion PSNR: 27.03242907426711 dB, MCE:  0.0015
After diffusion PSNR: 26.689538557373606 dB, MCE:  0.0015
After diffusion PSNR: 26.8395733024615 dB, MCE:  0.0015
After diffusion PSNR: 14.015525037617174 dB, MCE:  0.0016
After diffusion PSNR: 13.540458112686235 dB, MCE:  0.0016
After diffusion PSNR: 16.99490836422439 dB, MCE:  0.0016
After diffusion PSNR: 13.832507705817783 dB, MCE:  0.0016
Reconstructing Image 00015.png...
Run Number: 1
Run Number: 2
Run Number: 3
Run Number: 4
Run Number: 5
Run Number: 6
Run Number: 7
Run Number: 8
Run Number: 9
Run Number: 10
After diffusion PSNR: 26.792820120253072 dB, MCE:  0.0015
After diffusion PSNR: 27.61752523171412 dB, MCE:  0.0015
After diffusion PSNR: 27.2816830756511 dB, MCE:  0.0015
After diffusion PSNR: 27.251475007559453 dB, MCE:  0.0015
After diffusion PSNR: 23.259608292515747 dB, MCE:  0.0015
After diffusion PSNR: 13.441754281690955 dB, MCE:  0.0016
After diffusion PSNR: 13.142495755413277 dB, MCE:  0.0016
After diffusion PSNR: 27.65922756139282 dB, MCE:  0.0015
After diffusion PSNR: 12.838730532741472 dB, MCE:  0.0015
After diffusion PSNR: 13.055441963791116 dB, MCE:  0.0016
Reconstructing Image 00016.png...
Run Number: 1
Run Number: 2
Run Number: 3
Run Number: 4
Run Number: 5
Run Number: 6
Run Number: 7
Run Number: 8
Run Number: 9
Run Number: 10
After diffusion PSNR: 26.840034383038983 dB, MCE:  0.0015
After diffusion PSNR: 27.54649289923041 dB, MCE:  0.0015
After diffusion PSNR: 27.72653724075265 dB, MCE:  0.0015
After diffusion PSNR: 27.97125742699196 dB, MCE:  0.0015
After diffusion PSNR: 27.875255593484916 dB, MCE:  0.0015
After diffusion PSNR: 27.738405159443943 dB, MCE:  0.0015
After diffusion PSNR: 27.66921234342825 dB, MCE:  0.0015
After diffusion PSNR: 27.64957848482655 dB, MCE:  0.0015
After diffusion PSNR: 27.793223880376967 dB, MCE:  0.0015
After diffusion PSNR: 27.743999570460307 dB, MCE:  0.0015
Reconstructing Image 00017.png...
Run Number: 1
Run Number: 2
Run Number: 3
Run Number: 4
Run Number: 5
Run Number: 6
Run Number: 7
Run Number: 8
Run Number: 9
Run Number: 10
After diffusion PSNR: 5.743272710826254 dB, MCE:  0.0018
After diffusion PSNR: 28.798906000257546 dB, MCE:  0.0016
After diffusion PSNR: 6.072536987315037 dB, MCE:  0.0019
After diffusion PSNR: 29.798010102832784 dB, MCE:  0.0016
After diffusion PSNR: 6.301213807972318 dB, MCE:  0.0019
After diffusion PSNR: 29.87369556376093 dB, MCE:  0.0016
After diffusion PSNR: 6.1885579968220625 dB, MCE:  0.0019
After diffusion PSNR: 5.940199347231086 dB, MCE:  0.0019
After diffusion PSNR: 29.82425631453917 dB, MCE:  0.0016
After diffusion PSNR: 5.7959923188265705 dB, MCE:  0.0018
Reconstructing Image 00018.png...
Run Number: 1
Run Number: 2
Run Number: 3
Run Number: 4
Run Number: 5
Run Number: 6
Run Number: 7
Run Number: 8
Run Number: 9
Run Number: 10
After diffusion PSNR: 24.562092359496937 dB, MCE:  0.0015
After diffusion PSNR: 23.784467321121895 dB, MCE:  0.0015
After diffusion PSNR: 25.996529788705093 dB, MCE:  0.0015
After diffusion PSNR: 24.70376010475503 dB, MCE:  0.0015
After diffusion PSNR: 21.411937894085934 dB, MCE:  0.0015
After diffusion PSNR: 24.390761619050757 dB, MCE:  0.0015
After diffusion PSNR: 24.411338219158726 dB, MCE:  0.0015
After diffusion PSNR: 23.61230814272552 dB, MCE:  0.0015
After diffusion PSNR: 15.035636875019094 dB, MCE:  0.0015
After diffusion PSNR: 26.055713929914674 dB, MCE:  0.0015
Reconstructing Image 00019.png...
Run Number: 1
Run Number: 2
Run Number: 3
Run Number: 4
Run Number: 5
Run Number: 6
Run Number: 7
Run Number: 8
Run Number: 9
Run Number: 10
After diffusion PSNR: 28.03935703206749 dB, MCE:  0.0015
After diffusion PSNR: 8.808278893364239 dB, MCE:  0.0015
After diffusion PSNR: 28.248407550365442 dB, MCE:  0.0015
After diffusion PSNR: 28.072686061773776 dB, MCE:  0.0015
After diffusion PSNR: 28.225377001623496 dB, MCE:  0.0015
After diffusion PSNR: 27.993134702380956 dB, MCE:  0.0015
After diffusion PSNR: 28.604166988025618 dB, MCE:  0.0015
After diffusion PSNR: 28.248771017924174 dB, MCE:  0.0015
After diffusion PSNR: 9.373544984918373 dB, MCE:  0.0016
After diffusion PSNR: 28.186135946278633 dB, MCE:  0.0015
Reconstructing Image 00020.png...
Run Number: 1
Run Number: 2
Run Number: 3
Run Number: 4
Run Number: 5
Run Number: 6
Run Number: 7
Run Number: 8
Run Number: 9
Run Number: 10
After diffusion PSNR: 27.631227521988883 dB, MCE:  0.0015
After diffusion PSNR: 27.329056739101052 dB, MCE:  0.0015
After diffusion PSNR: 27.61261121107919 dB, MCE:  0.0015
After diffusion PSNR: 14.489585813764085 dB, MCE:  0.0016
After diffusion PSNR: 27.75191148168183 dB, MCE:  0.0015
After diffusion PSNR: 17.78303706664373 dB, MCE:  0.0015
After diffusion PSNR: 27.783757170769945 dB, MCE:  0.0015
After diffusion PSNR: 15.750993342903797 dB, MCE:  0.0016
After diffusion PSNR: 27.142005428689583 dB, MCE:  0.0015
After diffusion PSNR: 27.940152676962033 dB, MCE:  0.0015
Reconstructing Image 00021.png...
Run Number: 1
Run Number: 2
Run Number: 3
Run Number: 4
Run Number: 5
Run Number: 6
Run Number: 7
Run Number: 8
Run Number: 9
Run Number: 10
After diffusion PSNR: 27.633408238307045 dB, MCE:  0.0015
After diffusion PSNR: 27.86082722449575 dB, MCE:  0.0015
After diffusion PSNR: 27.493037341698866 dB, MCE:  0.0015
After diffusion PSNR: 25.957219039690766 dB, MCE:  0.0015
After diffusion PSNR: 27.46857552061432 dB, MCE:  0.0015
After diffusion PSNR: 28.08782169024601 dB, MCE:  0.0015
After diffusion PSNR: 27.702768071998 dB, MCE:  0.0015
After diffusion PSNR: 27.42008259849309 dB, MCE:  0.0015
After diffusion PSNR: 27.52073730703134 dB, MCE:  0.0015
After diffusion PSNR: 27.579953984064403 dB, MCE:  0.0015
Reconstructing Image 00022.png...
Run Number: 1
Run Number: 2
Run Number: 3
Run Number: 4
Run Number: 5
Run Number: 6
Run Number: 7
Run Number: 8
Run Number: 9
Run Number: 10
After diffusion PSNR: 14.901366177258575 dB, MCE:  0.0015
After diffusion PSNR: 12.71349172138037 dB, MCE:  0.0016
After diffusion PSNR: 15.440604824776132 dB, MCE:  0.0015
After diffusion PSNR: 12.82896302130289 dB, MCE:  0.0016
After diffusion PSNR: 20.573017083742855 dB, MCE:  0.0015
After diffusion PSNR: 13.546121083431686 dB, MCE:  0.0016
After diffusion PSNR: 14.784194750770382 dB, MCE:  0.0016
After diffusion PSNR: 13.367328713870377 dB, MCE:  0.0016
After diffusion PSNR: 21.3134021394038 dB, MCE:  0.0015
After diffusion PSNR: 21.251997337143337 dB, MCE:  0.0015
Reconstructing Image 00023.png...
Run Number: 1
Run Number: 2
Run Number: 3
Run Number: 4
Run Number: 5
Run Number: 6
Run Number: 7
Run Number: 8
Run Number: 9
Run Number: 10
After diffusion PSNR: 27.621637755416153 dB, MCE:  0.0015
After diffusion PSNR: 27.811321954763145 dB, MCE:  0.0015
After diffusion PSNR: 27.895848858924044 dB, MCE:  0.0015
After diffusion PSNR: 27.534319962033855 dB, MCE:  0.0015
After diffusion PSNR: 27.61300921728704 dB, MCE:  0.0015
After diffusion PSNR: 27.86293750311696 dB, MCE:  0.0015
After diffusion PSNR: 27.912073257060776 dB, MCE:  0.0015
After diffusion PSNR: 28.11110064973633 dB, MCE:  0.0015
After diffusion PSNR: 27.768674458919946 dB, MCE:  0.0015
After diffusion PSNR: 28.09685512216681 dB, MCE:  0.0015
Reconstructing Image 00024.png...
Run Number: 1
Run Number: 2
Run Number: 3
Run Number: 4
Run Number: 5
Run Number: 6
Run Number: 7
Run Number: 8
Run Number: 9
Run Number: 10
After diffusion PSNR: 17.48211655066524 dB, MCE:  0.0015
After diffusion PSNR: 23.277887829279514 dB, MCE:  0.0014
After diffusion PSNR: 23.51545800558304 dB, MCE:  0.0014
After diffusion PSNR: 11.037143934600639 dB, MCE:  0.0016
After diffusion PSNR: 24.949964874942697 dB, MCE:  0.0014
After diffusion PSNR: 13.931928073919355 dB, MCE:  0.0016
After diffusion PSNR: 21.19443340334766 dB, MCE:  0.0014
After diffusion PSNR: 13.478402535412243 dB, MCE:  0.0016
After diffusion PSNR: 16.017040692484013 dB, MCE:  0.0016
After diffusion PSNR: 14.512906661177137 dB, MCE:  0.0015

## Trying: Enforce MCE Thresehold and Restrarting ##

Note: Both dont work

In [None]:
import os, torch, numpy as np, torchvision
import torch.nn.functional as F
from functools import partial
from PIL import Image

# --- assumes you already have: model, device, operator, noiser, measure_config, fft2c_new, ifft2c_new, DDIMScheduler ---

# ========== helpers ==========

def compute_mce(reconstructed_image, measured_magnitude, target_shape=None):
    """L2 error between Fourier magnitude of reconstruction (in [0,1]) and measured magnitude."""
    x_hat = reconstructed_image.clamp(0, 1)
    if target_shape is not None:
        _, _, H, W = x_hat.shape
        target_H, target_W = target_shape[-2], target_shape[-1]
        pad_h = (target_H - H) // 2
        pad_w = (target_W - W) // 2
        x_hat = F.pad(x_hat, (pad_w, pad_w, pad_h, pad_h))
    if not torch.is_complex(x_hat):
        x_hat = x_hat.type(torch.complex64)
    fft_x_hat = torch.view_as_complex(fft2c_new(torch.view_as_real(x_hat)))
    return (fft_x_hat.abs() - measured_magnitude).pow(2).mean().item()

def compute_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse <= 0:
        return float('inf')
    return float(20 * np.log10(1.0 / (mse**0.5)))

@torch.no_grad()
def to_start_latent_from_x0(x0, sqrt_alpha_bar_T, sqrt_one_minus_alpha_bar_T, noise_scale=1.0):
    """
    Given x0 in [-1,1], sample a start latent x_T = sqrt(alpha_bar_T)*x0 + sqrt(1-alpha_bar_T)*noise.
    """
    eps = torch.randn_like(x0) * noise_scale
    return sqrt_alpha_bar_T * x0 + sqrt_one_minus_alpha_bar_T * eps

def optimize_input(
    input_tensor, sqrt_one_minus_alpha_cumprod, sqrt_alpha_cumprod, t, y_n,
    num_steps=20, learning_rate=0.01
):
    """
    One inner optimization block producing (x_t_next, pred_x0, noise_pred).
    Uses your model(x_t, t) -> eps_hat, and a magnitude data term vs y_n.
    """
    x_t = input_tensor.clone().detach().requires_grad_(True)
    opt = torch.optim.Adam([x_t], lr=learning_rate)
    tt = (torch.ones(1, device=x_t.device, dtype=torch.long) * int(t))

    for _ in range(num_steps):
        opt.zero_grad()
        noise_pred = model(x_t, tt)[:, :3]
        pred_x0 = (x_t - sqrt_one_minus_alpha_cumprod * noise_pred) / sqrt_alpha_cumprod
        pred_x0 = torch.clamp(pred_x0, -1, 1)

        # pad and FFT magnitude for the data term
        pad = int((2 / 8.0) * 256)
        x = pred_x0 * 0.5 + 0.5
        x = F.pad(x, (pad, pad, pad, pad))
        if not torch.is_complex(x):
            x = x.type(torch.complex64)
        fft2_m = torch.view_as_complex(fft2c_new(torch.view_as_real(x)))
        out_mag = fft2_m.abs()

        data_loss = torch.norm(out_mag - y_n)**2
        # small tether to keep x_t from drifting too wildly (optional)
        reg_loss = 0.1 * torch.norm(input_tensor.detach() - x_t)**2
        loss = data_loss + reg_loss
        loss.backward()
        opt.step()

    # recompute noise from the final pair (x_t, pred_x0)
    noise = (x_t - sqrt_alpha_cumprod * pred_x0) / sqrt_one_minus_alpha_cumprod
    return x_t.detach(), pred_x0.detach(), noise.detach()

# ========== main loop with threshold-triggered restarts ==========

def run_threshold_restart_experiment(
    image_indices=range(11, 21),
    num_runs=1,
    n_step=20,
    mce_threshold=0.00148,
    max_restarts=5,
    stop_on_hit=False,              # if True, stop as soon as threshold hit once
    restart_noise_scale=0.5,       # 0 = deterministic reuse, 1 = full start noise
    patience_restarts=None,         # e.g. 2: stop if no best-MCE improvement for 2 restarts
    inner_opt_steps=20,
    inner_lr=0.075,
    seed=41
):
    # reproducibility
    torch.manual_seed(seed); np.random.seed(seed)

    scheduler = DDIMScheduler()
    scheduler.set_timesteps(num_inference_steps=n_step)
    step_size = 1000 // n_step

    dtype = torch.float32
    results_table = []  # for stats / correlation

    # Precompute start-time alphas
    T = int(scheduler.timesteps[0].item())  # largest t (close to 1000)
    alpha_bar_T = scheduler.alphas_cumprod[T].item()
    sqrt_alpha_bar_T = alpha_bar_T**0.5
    sqrt_one_minus_alpha_bar_T = (1.0 - alpha_bar_T)**0.5

    for j in image_indices:
        filename = f"{j:05d}.png"
        filepath = os.path.join("/egr/research-slim/liangs16/Measurment_Consistent_Diffusion_Trajectory/data/demo_remain/", filename)
        gt_img = Image.open(filepath).convert("RGB")
        ref_numpy = np.array(gt_img).astype(np.float32) / 255.0
        x = ref_numpy * 2 - 1
        x = x.transpose(2, 0, 1)
        ref_img = torch.tensor(x, dtype=dtype, device=device).unsqueeze(0)
        gt = (ref_img / 2 + 0.5)

        # set up measurement
        if measure_config['operator']['name'] == 'inpainting':
            # you can keep your partial(...) pathway if used elsewhere
            y = operator.forward(ref_img, mask=mask)
        else:
            y = operator.forward(ref_img)
        y_n = noiser(y)  # measured magnitude domain used in losses/MCE

        print(f"\nReconstructing Image {filename}...")
        for run_iter in range(num_runs):
            print(f"  Run {run_iter+1}/{num_runs}")

            # ----- initialize x_T -----
            x0_init = torch.randn((1, 3, 256, 256), device=device, dtype=dtype)
            x_T = to_start_latent_from_x0(
                x0_init, sqrt_alpha_bar_T, sqrt_one_minus_alpha_bar_T, noise_scale=1.0
            )

            best_mce_overall = float('inf')
            best_psnr_overall = -float('inf')
            best_final_image = None

            restart = 0
            no_improve = 0
            while True:
                # one full pass over all timesteps
                measurement_errors = []
                input_t = x_T.clone()

                for i, t in enumerate(scheduler.timesteps):
                    prev_timestep = int(t) - step_size
                    alpha_prod_t = scheduler.alphas_cumprod[int(t)]
                    alpha_prod_t_prev = (
                        scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0
                        else scheduler.alphas_cumprod[0]
                    )
                    beta_prod_t = 1 - alpha_prod_t
                    sqrt_one_minus_alpha_cumprod = beta_prod_t**0.5
                    sqrt_alpha_cumprod = alpha_prod_t**0.5
                    sqrt_alpha_cumprod_prev = alpha_prod_t_prev**0.5

                    input_t, pred_x0, noise_pred = optimize_input(
                        input_t, sqrt_one_minus_alpha_cumprod, sqrt_alpha_cumprod,
                        int(t), y_n, num_steps=inner_opt_steps, learning_rate=inner_lr
                    )

                    # phase image in [0,1]
                    phase_image = (pred_x0 / 2 + 0.5).clamp(0, 1)
                    mce = compute_mce(phase_image, y_n.abs(), target_shape=y_n.shape)
                    measurement_errors.append(mce)
                    print(f"    t={int(t):4d}  MCE={mce:.6f}")

                    # DDIM/ODE next-step latent (ancestral)
                    eps = torch.randn_like(input_t)
                    input_t = sqrt_alpha_cumprod_prev * pred_x0 + (1 - alpha_prod_t_prev)**0.5 * eps

                    # threshold trigger: restart ASAP with warm-start if met
                    if mce >= mce_threshold:
                        print(f"    *** threshold hit ({mce:.6f} ≤ {mce_threshold}) → scheduling restart")
                        # prepare new x_T from current x0 (pred_x0)
                        x_T = to_start_latent_from_x0(
                            pred_x0, sqrt_alpha_bar_T, sqrt_one_minus_alpha_bar_T,
                            noise_scale=restart_noise_scale
                        )
                        break  # break this pass early to restart

                # end of pass: compute final image & stats of this pass
                final_x0 = pred_x0.detach()
                final_img = (final_x0 / 2 + 0.5).clamp(0, 1).squeeze(0).permute(1, 2, 0).cpu().numpy()
                gt_img_np = gt[0].permute(1, 2, 0).detach().cpu().numpy()
                final_psnr = compute_psnr(final_img, gt_img_np)
                final_mce = measurement_errors[-1] if len(measurement_errors) else float('inf')
                min_mce_pass = min(measurement_errors) if len(measurement_errors) else float('inf')

                print(f"    Pass done: final MCE={final_mce:.6f}  min-pass MCE={min_mce_pass:.6f}  PSNR={final_psnr:.3f} dB")

                # track best across restarts (by min MCE in the pass)
                improved = min_mce_pass < best_mce_overall - 1e-9
                if improved:
                    best_mce_overall = min_mce_pass
                    best_psnr_overall = final_psnr
                    best_final_image = final_img
                    no_improve = 0
                else:
                    no_improve += 1

                # stopping conditions for restarts
                hit_threshold = (min_mce_pass <= mce_threshold)
                if stop_on_hit and hit_threshold:
                    print("    Stopping on first hit.")
                    break
                if restart >= max_restarts:
                    print("    Max restarts reached.")
                    break
                if patience_restarts is not None and no_improve >= patience_restarts:
                    print("    Early stop: no improvement across restarts.")
                    break

                # otherwise: prepare next restart if threshold was hit inside the pass
                restart += 1
                if not hit_threshold:
                    # even if not hit, we can still warm-start next pass from best x0 of this pass
                    x_T = to_start_latent_from_x0(
                        final_x0, sqrt_alpha_bar_T, sqrt_one_minus_alpha_bar_T,
                        noise_scale=restart_noise_scale
                    )
                    print(f"    Restarting (no threshold hit). Restart #{restart}")

            # record final stats for this run (use the best of the restarts)
            results_table.append({
                "image": filename,
                "run": run_iter,
                "best_mce_overall": best_mce_overall,
                "best_psnr_overall": best_psnr_overall,
                "max_restarts": max_restarts,
                "restart_noise_scale": restart_noise_scale,
                "threshold": mce_threshold,
                "n_steps": n_step,
            })
            print(f"  ==> Best-overall for run: MCE={best_mce_overall:.6f}, PSNR={best_psnr_overall:.3f} dB")

    # correlation check across all images/runs
    mces = np.array([r["best_mce_overall"] for r in results_table], dtype=float)
    psnrs = np.array([r["best_psnr_overall"] for r in results_table], dtype=float)
    if len(mces) >= 2 and np.std(mces) > 0 and np.std(psnrs) > 0:
        corr = float(np.corrcoef(mces, psnrs)[0,1])
    else:
        corr = float('nan')

    print("\n=== Summary ===")
    print(f"Runs (total): {len(results_table)}")
    print(f"Pearson corr(best_MCE, best_PSNR): {corr:.4f} (expect negative if smaller MCE → higher PSNR)")
    return results_table, corr


In [None]:
results, corr = run_threshold_restart_experiment()

In [None]:
# --- Add these helpers somewhere below the experiment function ---

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

def save_results_to_csv(results_table, csv_path="threshold_restart_results_con.csv"):
    """
    Save the per-run summary rows returned by run_threshold_restart_experiment()
    to a CSV. Each row contains: image, run, best_mce_overall, best_psnr_overall, etc.
    """
    df = pd.DataFrame(results_table)
    df.to_csv(csv_path, index=False)
    print(f"Saved {len(df)} rows to {csv_path}")
    return csv_path

def plot_psnr_vs_mce(results_table, png_path="psnr_vs_mce_con.png"):
    """
    Scatter plot of best PSNR [dB] vs best MCE (lower is better).
    Annotates Pearson correlation r in the title.
    """
    mces  = np.array([r["best_mce_overall"]  for r in results_table], dtype=float)
    psnrs = np.array([r["best_psnr_overall"] for r in results_table], dtype=float)

    # guard against degenerate cases
    if len(mces) < 2 or np.std(mces) == 0 or np.std(psnrs) == 0:
        r = float('nan')
    else:
        r = float(np.corrcoef(mces, psnrs)[0, 1])

    plt.figure()
    plt.scatter(mces, psnrs, s=16)
    plt.xlabel("Best MCE (lower is better)")
    plt.ylabel("Best PSNR [dB] (higher is better)")
    plt.title(f"PSNR vs MCE (Pearson r = {r:.3f}, n = {len(mces)})")
    plt.grid(True, linestyle="--", linewidth=0.5, alpha=0.6)
    plt.tight_layout()
    plt.savefig(png_path, dpi=150)
    plt.close()
    print(f"Saved scatter plot to {png_path} (r = {r:.3f})")
    return png_path, r


In [None]:
csv_path = save_results_to_csv(results)          # -> "threshold_restart_results.csv"
png_path, r = plot_psnr_vs_mce(results)          # -> "psnr_vs_mce.png", correlation value


In [None]:
import os, torch, numpy as np, torchvision
import torch.nn.functional as F
from PIL import Image

# ===== Helpers (same API as before) =====

def compute_mce(reconstructed_image, measured_magnitude, target_shape=None):
    x_hat = reconstructed_image.clamp(0, 1)
    if target_shape is not None:
        _, _, H, W = x_hat.shape
        target_H, target_W = target_shape[-2], target_shape[-1]
        pad_h = (target_H - H) // 2
        pad_w = (target_W - W) // 2
        x_hat = F.pad(x_hat, (pad_w, pad_w, pad_h, pad_h))
    if not torch.is_complex(x_hat):
        x_hat = x_hat.type(torch.complex64)
    fft_x_hat = torch.view_as_complex(fft2c_new(torch.view_as_real(x_hat)))
    return (fft_x_hat.abs() - measured_magnitude).pow(2).mean().item()

def compute_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse <= 0:
        return float('inf')
    return float(20 * np.log10(1.0 / (mse**0.5)))

@torch.no_grad()
def to_start_latent_from_x0(x0, sqrt_alpha_bar_T, sqrt_one_minus_alpha_bar_T, noise_scale=1.0):
    eps = torch.randn_like(x0) * noise_scale
    return sqrt_alpha_bar_T * x0 + sqrt_one_minus_alpha_bar_T * eps

def optimize_input(
    input_tensor, sqrt_one_minus_alpha_cumprod, sqrt_alpha_cumprod, t, y_n,
    num_steps=20, learning_rate=0.075
):
    """
    One inner optimization block producing (x_t_next, pred_x0, noise_pred).
    Assumes global: `model` on correct device, expects (x_t, t_long).
    """
    x_t = input_tensor.clone().detach().requires_grad_(True)
    opt = torch.optim.Adam([x_t], lr=learning_rate)
    tt = (torch.ones(1, device=x_t.device, dtype=torch.long) * int(t))

    for _ in range(num_steps):
        opt.zero_grad()
        noise_pred = model(x_t, tt)[:, :3]
        pred_x0 = (x_t - sqrt_one_minus_alpha_cumprod * noise_pred) / sqrt_alpha_cumprod
        pred_x0 = torch.clamp(pred_x0, -1, 1)

        # data term in magnitude space (mirror the inner-loop forward exactly)
        pad = int((2 / 8.0) * 256)
        x = pred_x0 * 0.5 + 0.5
        x = F.pad(x, (pad, pad, pad, pad))
        if not torch.is_complex(x):
            x = x.type(torch.complex64)
        fft2_m = torch.view_as_complex(fft2c_new(torch.view_as_real(x)))
        out_mag = fft2_m.abs()

        data_loss = torch.norm(out_mag - y_n)**2
        reg_loss = 0.1 * torch.norm(input_tensor.detach() - x_t)**2
        loss = data_loss + reg_loss
        loss.backward()
        opt.step()

    noise = (x_t - sqrt_alpha_cumprod * pred_x0) / sqrt_one_minus_alpha_cumprod
    return x_t.detach(), pred_x0.detach(), noise.detach()

# ===== Main: end-of-pass restart logic (Path A) =====

def run_threshold_restart_experiment(
    image_indices=range(11, 12),
    num_runs=1,
    n_step=20,
    mce_threshold=0.0015,
    restart_policy="bad_warm",     # "good_refine" | "bad_warm" | "bad_random"
    max_restarts=10,
    patience_restarts=None,        # e.g., 2: stop if best MCE doesn't improve across this many restarts
    restart_noise_scale=0.5,      # used for warm starts
    inner_opt_steps=20,
    inner_lr=0.075,
    seed=0,
):
    """
    End-of-pass decision only. No step-level restarts.
    Requires globals: model, device, operator, noiser, measure_config, fft2c_new, ifft2c_new, DDIMScheduler
    """
    # reproducibility
    torch.manual_seed(seed); np.random.seed(seed)

    scheduler = DDIMScheduler()
    scheduler.set_timesteps(num_inference_steps=n_step)
    step_size = 1000 // n_step

    dtype = torch.float32
    results_table = []

    # Precompute start-time alphas (largest t in schedule)
    T = int(scheduler.timesteps[0].item())
    alpha_bar_T = scheduler.alphas_cumprod[T].item()
    sqrt_alpha_bar_T = alpha_bar_T**0.5
    sqrt_one_minus_alpha_bar_T = (1.0 - alpha_bar_T)**0.5

    for j in image_indices:
        filename = f"{j:05d}.png"
        filepath = os.path.join(
            "/egr/research-slim/liangs16/Measurment_Consistent_Diffusion_Trajectory/data/demo_remain/",
            filename
        )
        gt_img = Image.open(filepath).convert("RGB")
        ref_numpy = np.array(gt_img).astype(np.float32) / 255.0
        x = ref_numpy * 2 - 1
        x = x.transpose(2, 0, 1)
        ref_img = torch.tensor(x, dtype=dtype, device=device).unsqueeze(0)
        gt = (ref_img / 2 + 0.5)

        # forward/measurement
        if measure_config['operator']['name'] == 'inpainting':
            y = operator.forward(ref_img, mask=mask)
        else:
            y = operator.forward(ref_img)
        y_n = noiser(y)

        print(f"\nReconstructing Image {filename}...")
        for run_iter in range(num_runs):
            print(f"  Run {run_iter+1}/{num_runs}")

            # initial latent x_T (random)
            x0_init = torch.randn((1, 3, 256, 256), device=device, dtype=dtype)
            x_T = to_start_latent_from_x0(x0_init, sqrt_alpha_bar_T, sqrt_one_minus_alpha_bar_T, noise_scale=1.0)

            best_mce_overall = float('inf')
            best_psnr_overall = -float('inf')
            best_final_image = None
            no_improve = 0
            restart = 0
            passes_used = 0

            while True:
                measurement_errors = []
                input_t = x_T.clone()
                passes_used += 1

                # ===== Full pass over the schedule; NO step-level restarts =====
                for i, t in enumerate(scheduler.timesteps):
                    prev_timestep = int(t) - step_size
                    alpha_prod_t = scheduler.alphas_cumprod[int(t)]
                    alpha_prod_t_prev = (
                        scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0
                        else scheduler.alphas_cumprod[0]
                    )
                    beta_prod_t = 1 - alpha_prod_t
                    sqrt_one_minus_alpha_cumprod = beta_prod_t**0.5
                    sqrt_alpha_cumprod = alpha_prod_t**0.5
                    sqrt_alpha_cumprod_prev = alpha_prod_t_prev**0.5

                    input_t, pred_x0, noise_pred = optimize_input(
                        input_t, sqrt_one_minus_alpha_cumprod, sqrt_alpha_cumprod,
                        int(t), y_n, num_steps=inner_opt_steps, learning_rate=inner_lr
                    )

                    phase_image = (pred_x0 / 2 + 0.5).clamp(0, 1)
                    mce = compute_mce(phase_image, y_n.abs(), target_shape=y_n.shape)
                    measurement_errors.append(mce)
                    print(f"    t={int(t):4d}  MCE={mce:.6f}")

                    # ancestral move to previous timestep latent
                    eps = torch.randn_like(input_t)
                    input_t = sqrt_alpha_cumprod_prev * pred_x0 + (1 - alpha_prod_t_prev)**0.5 * eps

                # ===== End-of-pass statistics & decision =====
                final_x0 = pred_x0.detach()
                final_img = (final_x0 / 2 + 0.5).clamp(0, 1).squeeze(0).permute(1, 2, 0).cpu().numpy()
                gt_img_np = gt[0].permute(1, 2, 0).detach().cpu().numpy()
                final_psnr = compute_psnr(final_img, gt_img_np)
                final_mce  = measurement_errors[-1] if measurement_errors else float('inf')
                min_mce_pass = min(measurement_errors) if measurement_errors else float('inf')

                print(f"    Pass done: final MCE={final_mce:.6f}  min-pass MCE={min_mce_pass:.6f}  PSNR={final_psnr:.3f} dB")

                # track best across passes
                improved = min_mce_pass < best_mce_overall - 1e-12
                if improved:
                    best_mce_overall = min_mce_pass
                    best_psnr_overall = final_psnr
                    best_final_image = final_img
                    no_improve = 0
                else:
                    no_improve += 1

                # decide whether to restart another pass (END-OF-PASS ONLY)
                trigger = False
                if restart_policy == "good_refine":
                    trigger = (min_mce_pass <= mce_threshold)
                elif restart_policy in ("bad_warm", "bad_random"):
                    trigger = (min_mce_pass >= mce_threshold)

                # patience & caps
                if patience_restarts is not None and no_improve >= patience_restarts:
                    print("    Early stop: no improvement across passes.")
                    break
                if restart >= max_restarts:
                    print("    Max restarts reached.")
                    break

                if trigger:
                    # prepare x_T for the next pass
                    if restart_policy in ("good_refine", "bad_warm"):
                        # warm start from this pass's x0
                        x_T = to_start_latent_from_x0(
                            final_x0, sqrt_alpha_bar_T, sqrt_one_minus_alpha_bar_T,
                            noise_scale=restart_noise_scale
                        )
                        print(f"    Restarting (policy={restart_policy}).")
                    else:  # bad_random
                        x_T = torch.randn_like(x_T)
                        print("    Restarting (policy=bad_random).")
                    restart += 1
                    continue  # run another full pass
                else:
                    # stop passes for this run
                    break

            # record best-overall for this (image, run)
            results_table.append({
                "image": filename,
                "run": run_iter,
                "best_mce_overall": best_mce_overall,
                "best_psnr_overall": best_psnr_overall,
                "max_restarts": max_restarts,
                "restart_noise_scale": restart_noise_scale,
                "threshold": mce_threshold,
                "n_steps": n_step,
                "restart_policy": restart_policy,
                "passes_used": passes_used,   
            })
            print(f"  ==> Best-overall for run: MCE={best_mce_overall:.6f}, PSNR={best_psnr_overall:.3f} dB")

    # correlation across all (image, run)
    mces  = np.array([r["best_mce_overall"]  for r in results_table], dtype=float)
    psnrs = np.array([r["best_psnr_overall"] for r in results_table], dtype=float)
    if len(mces) >= 2 and np.std(mces) > 0 and np.std(psnrs) > 0:
        corr = float(np.corrcoef(mces, psnrs)[0,1])
    else:
        corr = float('nan')

    print("\n=== Summary ===")
    print(f"Runs (total): {len(results_table)}")
    print(f"Pearson corr(best_MCE, best_PSNR): {corr:.4f} (expect negative if smaller MCE → higher PSNR)")
    return results_table, corr


In [None]:
run_threshold_restart_experiment()

In [None]:
def _summarize_runs(results):
    import numpy as np
    if not results:
        return {"n": 0, "mean_best_mce": float('nan'), "median_best_mce": float('nan'),
                "mean_best_psnr": float('nan'), "median_best_psnr": float('nan'),
                "pct_hits_threshold": float('nan'), "mean_passes_used": float('nan')}
    mces   = np.array([r["best_mce_overall"] for r in results], dtype=float)
    psnrs  = np.array([r["best_psnr_overall"] for r in results], dtype=float)
    thr    = np.array([r["threshold"] for r in results], dtype=float)
    hits   = (mces <= thr)
    passes = np.array([r.get("passes_used", np.nan) for r in results], dtype=float)
    return {
        "n": len(results),
        "mean_best_mce": float(np.nanmean(mces)),
        "median_best_mce": float(np.nanmedian(mces)),
        "mean_best_psnr": float(np.nanmean(psnrs)),
        "median_best_psnr": float(np.nanmedian(psnrs)),
        "pct_hits_threshold": float(100.0 * hits.mean()),
        "mean_passes_used": float(np.nanmean(passes)),
    }

def ab_compare_parity_10_passes(
    image_indices=range(11,14),
    n_step=20,
    # inner loop config
    inner_opt_steps=20,
    inner_lr=0.075,
    # asymmetric thresholds
    good_refine_threshold=0.0016,  # looser: refine “decent” passes
    bad_warm_threshold=0.0013,     # tighter: only salvage truly-bad passes
    # warm-start noise
    good_refine_noise=0.5,
    bad_warm_noise=0.5,
    seed=123,
):
    """
    Apples-to-apples: ~10 full passes total per variant.

    - baseline_no_restarts: num_runs=10, max_restarts=0         -> 10 passes
    - good_refine:         num_runs=1,  max_restarts=9          -> up to 10 passes
    - bad_warm:            num_runs=1,  max_restarts=9          -> up to 10 passes
    """
    shared_base = dict(
        image_indices=image_indices,
        n_step=n_step,
        inner_opt_steps=inner_opt_steps,
        inner_lr=inner_lr,
        seed=seed,
        patience_restarts=None,  # keep simple for parity
    )

    # --- baseline: 10 independent inits (10 passes) ---
    print("\n=== baseline_no_restarts (10 inits) ===")
    base_res, base_corr = run_threshold_restart_experiment(
        **shared_base,
        num_runs=10,
        restart_policy="bad_warm",        # irrelevant when max_restarts=0
        mce_threshold=bad_warm_threshold, # for consistent logging
        restart_noise_scale=0.0,
        max_restarts=0
    )
    base_sum = _summarize_runs(base_res); base_sum["name"] = "baseline_no_restarts"; base_sum["corr"] = base_corr

    # --- good_refine: 1 run, ≤10 passes ---
    print("\n=== good_refine (1 run, ≤10 passes) ===")
    gr_res, gr_corr = run_threshold_restart_experiment(
        **shared_base,
        num_runs=1,
        restart_policy="good_refine",
        mce_threshold=good_refine_threshold,
        restart_noise_scale=good_refine_noise,
        max_restarts=9
    )
    gr_sum = _summarize_runs(gr_res); gr_sum["name"] = "good_refine"; gr_sum["corr"] = gr_corr

    # --- bad_warm: 1 run, ≤10 passes ---
    print("\n=== bad_warm (1 run, ≤10 passes) ===")
    bw_res, bw_corr = run_threshold_restart_experiment(
        **shared_base,
        num_runs=1,
        restart_policy="bad_warm",
        mce_threshold=bad_warm_threshold,
        restart_noise_scale=bad_warm_noise,
        max_restarts=9
    )
    bw_sum = _summarize_runs(bw_res); bw_sum["name"] = "bad_warm"; bw_sum["corr"] = bw_corr

    # --- summary table ---
    print("\n=== A/B Summary (≈10 passes per variant) ===")
    for s in (base_sum, gr_sum, bw_sum):
        print(f"{s['name']:>20} | n={s['n']:>3} | mean MCE={s['mean_best_mce']:.6f} "
              f"| mean PSNR={s['mean_best_psnr']:.3f} dB | hits≤thr={s['pct_hits_threshold']:.1f}% "
              f"| mean passes={s['mean_passes_used']:.2f} | corr(MCE,PSNR)={s['corr']:.3f}")

    return {
        "baseline": (base_res, base_sum, base_corr),
        "good_refine": (gr_res, gr_sum, gr_corr),
        "bad_warm": (bw_res, bw_sum, bw_corr),
    }



In [None]:
ab_out = ab_compare_parity_10_passes(
    image_indices=range(11,14),
    n_step=20,
    good_refine_threshold=0.0016,
    bad_warm_threshold=0.0013,
    good_refine_noise=0.5,
    bad_warm_noise=0.5,
)



## A/B Testing ##

In [None]:
def run_fixed_pass_experiment(
    image_indices=range(11,14),
    num_runs=1,
    n_step=20,
    total_passes=10,                 # <- enforces exactly this many passes per run
    policy="good_refine",            # "baseline_random" | "good_refine" | "bad_warm"
    mce_threshold=0.0015,
    restart_noise_scale=0.5,         # warm-start noise
    inner_opt_steps=20,
    inner_lr=0.075,
    seed=0,
):
    """
    EXACTLY `total_passes` per run. No early exits.
    Policies:
      - baseline_random: every pass is fresh random init (equiv. to 10 random inits).
      - good_refine:     if pass hits (min_mce_pass <= thr)  -> warm-start next; else random.
      - bad_warm:        if pass NOT hit (min_mce_pass > thr)-> warm-start next; else random.
    """
    torch.manual_seed(seed); np.random.seed(seed)

    scheduler = DDIMScheduler()
    scheduler.set_timesteps(num_inference_steps=n_step)
    step_size = 1000 // n_step
    dtype = torch.float32

    # Start-time alphas (largest t)
    T = int(scheduler.timesteps[0].item())
    alpha_bar_T = scheduler.alphas_cumprod[T].item()
    sqrt_alpha_bar_T = alpha_bar_T**0.5
    sqrt_one_minus_alpha_bar_T = (1.0 - alpha_bar_T)**0.5

    def _random_xT_like(shape, device, dtype):
        x0_init = torch.randn(shape, device=device, dtype=dtype)
        return to_start_latent_from_x0(x0_init, sqrt_alpha_bar_T, sqrt_one_minus_alpha_bar_T, noise_scale=1.0)

    results_table = []

    for j in image_indices:
        filename = f"{j:05d}.png"
        filepath = os.path.join(
            "/egr/research-slim/liangs16/Measurment_Consistent_Diffusion_Trajectory/data/demo_remain/",
            filename
        )
        gt_img = Image.open(filepath).convert("RGB")
        ref_numpy = np.array(gt_img).astype(np.float32) / 255.0
        x = (ref_numpy * 2 - 1).transpose(2, 0, 1)
        ref_img = torch.tensor(x, dtype=dtype, device=device).unsqueeze(0)
        gt = (ref_img/2+0.5)

        # forward/measurement (no GT leakage)
        if measure_config['operator']['name'] == 'inpainting':
            y = operator.forward(ref_img, mask=mask)
        else:
            y = operator.forward(ref_img)
        y_n = noiser(y)

        print(f"\n[FIXED-PASSES:{policy}] Image {filename}")
        for run_iter in range(num_runs):
            print(f"  Run {run_iter+1}/{num_runs}")

            # initialize starting x_T depending on policy
            x_T = _random_xT_like((1,3,256,256), device, dtype)

            best_mce_overall = float('inf')
            best_psnr_overall = -float('inf')
            best_final_image = None

            passes_used = 0

            for pass_idx in range(total_passes):
                passes_used += 1
                measurement_errors = []
                input_t = x_T.clone()

                # --- full reverse pass ---
                for t in scheduler.timesteps:
                    t = int(t)
                    prev_timestep = t - step_size
                    alpha_t = scheduler.alphas_cumprod[t]
                    alpha_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.alphas_cumprod[0]
                    sqrt_one_minus = (1 - alpha_t)**0.5
                    sqrt_alpha = alpha_t**0.5
                    sqrt_alpha_prev = alpha_prev**0.5

                    input_t, pred_x0, _ = optimize_input(
                        input_t, sqrt_one_minus, sqrt_alpha, t, y_n,
                        num_steps=inner_opt_steps, learning_rate=inner_lr
                    )

                    phase_image = (pred_x0/2+0.5).clamp(0,1)
                    mce = compute_mce(phase_image, y_n.abs(), target_shape=y_n.shape)
                    measurement_errors.append(mce)

                    # ancestral step
                    eps = torch.randn_like(input_t)
                    input_t = sqrt_alpha_prev * pred_x0 + (1 - alpha_prev)**0.5 * eps

                # --- end-of-pass stats ---
                final_x0 = pred_x0.detach()
                final_img = (final_x0/2+0.5).clamp(0,1).squeeze(0).permute(1,2,0).cpu().numpy()
                gt_img_np = gt[0].permute(1,2,0).detach().cpu().numpy()
                final_psnr = compute_psnr(final_img, gt_img_np)
                min_mce_pass = min(measurement_errors) if measurement_errors else float('inf')

                print(f"    Pass {pass_idx+1}/{total_passes}: min-pass MCE={min_mce_pass:.6f}  PSNR={final_psnr:.3f} dB")

                # track best across all passes
                if min_mce_pass < best_mce_overall - 1e-12:
                    best_mce_overall = min_mce_pass
                    best_psnr_overall = final_psnr
                    best_final_image = final_img

                # --- choose next x_T based on policy (ALWAYS do next pass unless last) ---
                if pass_idx < total_passes - 1:
                    if policy == "baseline_random":
                        # always random
                        x_T = _random_xT_like((1,3,256,256), device, dtype)

                    elif policy == "good_refine":
                        if min_mce_pass <= mce_threshold:
                            # threshold HIT => warm-start from current x0
                            x_T = to_start_latent_from_x0(final_x0, sqrt_alpha_bar_T, sqrt_one_minus_alpha_bar_T,
                                                          noise_scale=restart_noise_scale)
                            print("      -> HIT thr → warm-start next pass")
                        else:
                            # threshold NOT hit => random
                            x_T = _random_xT_like((1,3,256,256), device, dtype)
                            print("      -> NOT hit thr → random next pass")

                    elif policy == "bad_warm":
                        if min_mce_pass > mce_threshold:
                            # threshold NOT hit => warm start
                            x_T = to_start_latent_from_x0(final_x0, sqrt_alpha_bar_T, sqrt_one_minus_alpha_bar_T,
                                                          noise_scale=restart_noise_scale)
                            print("      -> NOT hit thr → warm-start next pass")
                        else:
                            # threshold HIT => random
                            x_T = _random_xT_like((1,3,256,256), device, dtype)
                            print("      -> HIT thr → random next pass")

                    else:
                        raise ValueError(f"Unknown policy: {policy}")

            # append per (image, run)
            results_table.append({
                "image": filename,
                "run": run_iter,
                "best_mce_overall": float(best_mce_overall),
                "best_psnr_overall": float(best_psnr_overall),
                "policy": policy,
                "threshold": float(mce_threshold),
                "n_steps": int(n_step),
                "passes_used": int(passes_used),   # should equal total_passes
                "total_passes": int(total_passes),
                "restart_noise_scale": float(restart_noise_scale),
            })
            print(f"  ==> Best-overall: MCE={best_mce_overall:.6f}, PSNR={best_psnr_overall:.3f} dB "
                  f"(passes={passes_used})")

    # global correlation across all rows
    mces  = np.array([r["best_mce_overall"]  for r in results_table], dtype=float)
    psnrs = np.array([r["best_psnr_overall"] for r in results_table], dtype=float)
    corr = float('nan') if len(mces)<2 or np.std(mces)==0 or np.std(psnrs)==0 else float(np.corrcoef(mces, psnrs)[0,1])

    print("\n=== Fixed-Passes Summary ===")
    print(f"Rows (image×run): {len(results_table)}")
    print(f"Pearson corr(best_MCE, best_PSNR): {corr:.4f}")
    return results_table, corr

def compute_mce(reconstructed_image, measured_magnitude, target_shape=None):
    x_hat = reconstructed_image.clamp(0, 1)
    if target_shape is not None:
        _, _, H, W = x_hat.shape
        target_H, target_W = target_shape[-2], target_shape[-1]
        pad_h = (target_H - H) // 2
        pad_w = (target_W - W) // 2
        x_hat = F.pad(x_hat, (pad_w, pad_w, pad_h, pad_h))
    if not torch.is_complex(x_hat):
        x_hat = x_hat.type(torch.complex64)
    fft_x_hat = torch.view_as_complex(fft2c_new(torch.view_as_real(x_hat)))
    return (fft_x_hat.abs() - measured_magnitude).pow(2).mean().item()

def compute_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse <= 0:
        return float('inf')
    return float(20 * np.log10(1.0 / (mse**0.5)))

@torch.no_grad()
def to_start_latent_from_x0(x0, sqrt_alpha_bar_T, sqrt_one_minus_alpha_bar_T, noise_scale=1.0):
    eps = torch.randn_like(x0) * noise_scale
    return sqrt_alpha_bar_T * x0 + sqrt_one_minus_alpha_bar_T * eps

def optimize_input(
    input_tensor, sqrt_one_minus_alpha_cumprod, sqrt_alpha_cumprod, t, y_n,
    num_steps=20, learning_rate=0.075
):
    """
    One inner optimization block producing (x_t_next, pred_x0, noise_pred).
    Assumes global: `model` on correct device, expects (x_t, t_long).
    """
    x_t = input_tensor.clone().detach().requires_grad_(True)
    opt = torch.optim.Adam([x_t], lr=learning_rate)
    tt = (torch.ones(1, device=x_t.device, dtype=torch.long) * int(t))

    for _ in range(num_steps):
        opt.zero_grad()
        noise_pred = model(x_t, tt)[:, :3]
        pred_x0 = (x_t - sqrt_one_minus_alpha_cumprod * noise_pred) / sqrt_alpha_cumprod
        pred_x0 = torch.clamp(pred_x0, -1, 1)

        # data term in magnitude space (mirror the inner-loop forward exactly)
        pad = int((2 / 8.0) * 256)
        x = pred_x0 * 0.5 + 0.5
        x = F.pad(x, (pad, pad, pad, pad))
        if not torch.is_complex(x):
            x = x.type(torch.complex64)
        fft2_m = torch.view_as_complex(fft2c_new(torch.view_as_real(x)))
        out_mag = fft2_m.abs()

        data_loss = torch.norm(out_mag - y_n)**2
        reg_loss = 0.1 * torch.norm(input_tensor.detach() - x_t)**2
        loss = data_loss + reg_loss
        loss.backward()
        opt.step()

    noise = (x_t - sqrt_alpha_cumprod * pred_x0) / sqrt_one_minus_alpha_cumprod
    return x_t.detach(), pred_x0.detach(), noise.detach()

def ab_fixed_passes_10(
    image_indices=range(11,14),
    n_step=20,
    total_passes=10,
    good_refine_threshold=0.0016,   # looser
    bad_warm_threshold=0.0013,      # tighter
    restart_noise_scale=0.5,
    inner_opt_steps=20,
    inner_lr=0.075,
    seed=123,
):
    def summarize(rows, name, corr):
        import numpy as np
        mces  = np.array([r["best_mce_overall"] for r in rows], dtype=float)
        psnrs = np.array([r["best_psnr_overall"] for r in rows], dtype=float)
        passes= np.array([r["passes_used"] for r in rows], dtype=float)
        print(f"{name:>18} | n={len(rows):>3} | mean MCE={np.nanmean(mces):.6f} | mean PSNR={np.nanmean(psnrs):.3f} dB "
              f"| mean passes={np.nanmean(passes):.2f} | corr={corr:.3f}")

    print("\n=== baseline_random (10 passes) ===")
    base_res, base_corr = run_fixed_pass_experiment(
        image_indices=image_indices, num_runs=1, n_step=n_step, total_passes=total_passes,
        policy="baseline_random", mce_threshold=bad_warm_threshold,  # threshold unused; kept for log symmetry
        restart_noise_scale=0.0, inner_opt_steps=inner_opt_steps, inner_lr=inner_lr, seed=seed
    )
    print("\n=== good_refine (10 passes) ===")
    gr_res, gr_corr = run_fixed_pass_experiment(
        image_indices=image_indices, num_runs=1, n_step=n_step, total_passes=total_passes,
        policy="good_refine", mce_threshold=good_refine_threshold,
        restart_noise_scale=restart_noise_scale, inner_opt_steps=inner_opt_steps, inner_lr=inner_lr, seed=seed
    )
    print("\n=== bad_warm (10 passes) ===")
    bw_res, bw_corr = run_fixed_pass_experiment(
        image_indices=image_indices, num_runs=1, n_step=n_step, total_passes=total_passes,
        policy="bad_warm", mce_threshold=bad_warm_threshold,
        restart_noise_scale=restart_noise_scale, inner_opt_steps=inner_opt_steps, inner_lr=inner_lr, seed=seed
    )

    print("\n=== A/B Summary (exactly 10 passes each) ===")
    summarize(base_res, "baseline_random", base_corr)
    summarize(gr_res,   "good_refine",     gr_corr)
    summarize(bw_res,   "bad_warm",        bw_corr)

    return {
        "baseline_random": base_res,
        "good_refine": gr_res,
        "bad_warm": bw_res,
    }


In [None]:
ab_fixed = ab_fixed_passes_10(
    image_indices=range(11,14),
    n_step=20,
    total_passes=10,
    good_refine_threshold=0.0016,
    bad_warm_threshold=0.00147,
    restart_noise_scale=0.5,
)


In [None]:
def make_policy_table_from_dict(ab_results_dict):
    """
    ab_results_dict is what ab_fixed_passes_10(...) returns:
      {
        "baseline_random": [...rows...],
        "good_refine":     [...rows...],
        "bad_warm":        [...rows...],
      }
    Each row must have: 'image', 'best_psnr_overall', 'best_mce_overall'.
    """
    return make_policy_table(
        ab_results_dict["baseline_random"],
        ab_results_dict["good_refine"],
        ab_results_dict["bad_warm"]
    )

def make_policy_table(base_rows, good_rows, bad_rows):
    """
    Build a table: rows = images, cols = policies (baseline_random, good_refine, bad_warm)
    Cell value: "<best_psnr_overall:.3f} dB (MCE={best_mce_overall:.6f})"
    If multiple runs per image exist, picks the run with max PSNR for that policy.
    """
    by_policy = {
        "baseline_random": base_rows,
        "good_refine":     good_rows,
        "bad_warm":        bad_rows,
    }

    # collect all images seen
    images = sorted(set(
        [r["image"] for r in by_policy["baseline_random"]] +
        [r["image"] for r in by_policy["good_refine"]] +
        [r["image"] for r in by_policy["bad_warm"]]
    ))

    # build pretty table + numeric companions (optional/useful)
    pretty = pd.DataFrame(index=images, columns=list(by_policy.keys()))
    psnr_numeric = pd.DataFrame(index=images, columns=list(by_policy.keys()), dtype=float)
    mce_numeric  = pd.DataFrame(index=images, columns=list(by_policy.keys()), dtype=float)

    for policy, rows in by_policy.items():
        # group rows by image
        rows_by_img = {}
        for r in rows:
            rows_by_img.setdefault(r["image"], []).append(r)

        for img in images:
            if img not in rows_by_img or len(rows_by_img[img]) == 0:
                pretty.loc[img, policy] = "—"
                psnr_numeric.loc[img, policy] = np.nan
                mce_numeric.loc[img, policy]  = np.nan
                continue

            # choose the run with the highest PSNR for this image & policy
            best_row = max(rows_by_img[img], key=lambda rr: rr.get("best_psnr_overall", -np.inf))
            best_psnr = float(best_row["best_psnr_overall"])
            best_mce  = float(best_row["best_mce_overall"])

            pretty.loc[img, policy] = f"{best_psnr:.3f} dB (MCE={best_mce:.6f})"
            psnr_numeric.loc[img, policy] = best_psnr
            mce_numeric.loc[img, policy]  = best_mce

    return pretty, psnr_numeric, mce_numeric

In [None]:
ab_fixed = ab_fixed_passes_10(
    image_indices=range(11,14),
    n_step=20,
    total_passes=10,
    good_refine_threshold=0.0016,
    bad_warm_threshold=0.0013,
    restart_noise_scale=0.5,
)
pretty, psnr_tbl, mce_tbl = make_policy_table_from_dict(ab_fixed)
print(pretty)


## Reburn ##

In [None]:
def run_fixed_pass_experiment_reburn(
    image_indices=range(11,14),
    num_runs=1,
    n_step=20,
    total_passes=10,
    policy="good_refine",            # "baseline_random" | "good_refine" | "bad_warm"
    mce_threshold=0.0015,
    # reburn controls
    reburn=True,                     # if False, behavior matches your previous fixed-pass runner
    reburn_frac_hit=0.5,             # where to restart (fraction of schedule) when "hit" event
    reburn_frac_miss=0.7,            # where to restart when "miss" event
    restart_noise_scale=0.5,         # noise injected at the chosen reburn timestep
    # inner optimizer
    inner_opt_steps=20,
    inner_lr=0.075,
    seed=0,
    save_dir=None,                   # optional: save best images like before
):
    """
    EXACTLY `total_passes` per run. If `reburn=True`, warm-start restarts from an intermediate
    timestep t* (chosen by reburn_frac_*), not from the noisiest x_T.

    Fractions are in [0,1]:
      - 0.0 → start at the VERY END (almost no steps left)
      - 1.0 → start at the VERY BEGINNING (most steps left; equivalent to x_T when noise_scale=1)

    Practical tips:
      - good_refine: set reburn_frac_hit ~ 0.4-0.6, reburn_frac_miss ~ 0.9 (fallback almost random)
      - bad_warm:    set reburn_frac_miss ~ 0.5-0.7, reburn_frac_hit ~ 0.9 (fallback almost random)
    """
    import os
    torch.manual_seed(seed); np.random.seed(seed)
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    scheduler = DDIMScheduler()
    scheduler.set_timesteps(num_inference_steps=n_step)
    step_size = 1000 // n_step
    dtype = torch.float32

    timesteps = list(map(int, scheduler.timesteps))           # length = n_step, descending
    n_steps_sched = len(timesteps)                             # == n_step

    def _start_idx_from_frac(frac: float) -> int:
        # frac in [0,1]; 1.0 -> index 0 (most steps), 0.0 -> index n_step-1 (fewest steps)
        frac = max(0.0, min(1.0, float(frac)))
        # convert to index (round toward giving *more* steps when ambiguous)
        # e.g., frac=0.5 with n_step=20 -> idx=10 (do ~half the schedule)
        idx = int(round((1.0 - frac) * (n_steps_sched - 1)))
        return max(0, min(n_steps_sched - 1, idx))

    def _q_xt_given_x0(x0, t_idx, noise_scale=1.0):
        """Sample q(x_t|x0) at scheduler index t_idx using alphas_cumprod[t] and isotropic noise."""
        t_val = timesteps[t_idx]
        alpha_bar = scheduler.alphas_cumprod[t_val].item()
        sqrt_a = alpha_bar**0.5
        sqrt_1ma = (1.0 - alpha_bar)**0.5
        return sqrt_a * x0 + sqrt_1ma * torch.randn_like(x0) * noise_scale

    def _random_xT_like(shape, device, dtype):
        # If you want pure x_T (max noise), just draw a fresh x0 and reburn at idx=0 with noise=1
        x0_init = torch.randn(shape, device=device, dtype=dtype)
        return _q_xt_given_x0(x0_init, t_idx=0, noise_scale=1.0), 0  # start_idx=0

    results_table = []

    for j in image_indices:
        filename = f"{j:05d}.png"
        filepath = os.path.join(
            "/egr/research-slim/liangs16/Measurment_Consistent_Diffusion_Trajectory/data/demo_remain/",
            filename
        )
        gt_img = Image.open(filepath).convert("RGB")
        ref_numpy = np.array(gt_img).astype(np.float32) / 255.0
        x = (ref_numpy * 2 - 1).transpose(2, 0, 1)
        ref_img = torch.tensor(x, dtype=dtype, device=device).unsqueeze(0)
        gt = (ref_img/2+0.5)

        # forward measurement (no GT leakage)
        if measure_config['operator']['name'] == 'inpainting':
            y = operator.forward(ref_img, mask=mask)
        else:
            y = operator.forward(ref_img)
        y_n = noiser(y)

        print(f"\n[REBURN:{policy}] Image {filename}")
        for run_iter in range(num_runs):
            print(f"  Run {run_iter+1}/{num_runs}")

            # initialize: random at the *beginning* of schedule
            x_T, start_idx = _random_xT_like((1,3,256,256), device, dtype)

            best_mce_overall = float('inf')
            best_psnr_overall = -float('inf')
            best_final_image = None
            best_image_path = ""

            passes_used = 0

            for pass_idx in range(total_passes):
                passes_used += 1
                measurement_errors = []
                input_t = x_T.clone()

                # run from start_idx to the end of schedule
                for si in range(start_idx, n_steps_sched):
                    t = timesteps[si]
                    prev_timestep = t - step_size
                    alpha_t = scheduler.alphas_cumprod[t]
                    alpha_prev = scheduler.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else scheduler.alphas_cumprod[0]
                    sqrt_one_minus = (1 - alpha_t)**0.5
                    sqrt_alpha = alpha_t**0.5
                    sqrt_alpha_prev = alpha_prev**0.5

                    input_t, pred_x0, _ = optimize_input(
                        input_t, sqrt_one_minus, sqrt_alpha, t, y_n,
                        num_steps=inner_opt_steps, learning_rate=inner_lr
                    )

                    phase_image = (pred_x0/2+0.5).clamp(0,1)
                    mce = compute_mce(phase_image, y_n.abs(), target_shape=y_n.shape)
                    measurement_errors.append(mce)

                    # ancestral update (unless at the very last step)
                    if si < n_steps_sched - 1:
                        eps = torch.randn_like(input_t)
                        input_t = sqrt_alpha_prev * pred_x0 + (1 - alpha_prev)**0.5 * eps

                # end-of-pass stats from this start_idx
                final_x0 = pred_x0.detach()
                final_img = (final_x0/2+0.5).clamp(0,1).squeeze(0).permute(1,2,0).cpu().numpy()
                gt_img_np = gt[0].permute(1,2,0).detach().cpu().numpy()
                final_psnr = compute_psnr(final_img, gt_img_np)
                min_mce_pass = min(measurement_errors) if measurement_errors else float('inf')

                print(f"    Pass {pass_idx+1}/{total_passes} (start_idx={start_idx:02d}/{n_steps_sched-1}): "
                      f"min-pass MCE={min_mce_pass:.6f}  PSNR={final_psnr:.3f} dB")

                # track best across passes
                if min_mce_pass < best_mce_overall - 1e-12:
                    best_mce_overall = min_mce_pass
                    best_psnr_overall = final_psnr
                    best_final_image = final_img

                # choose next start and latent (always do next unless last pass)
                if pass_idx < total_passes - 1:
                    if policy == "baseline_random":
                        # Always random, and start at beginning (index 0)
                        x_T, start_idx = _random_xT_like((1,3,256,256), device, dtype)

                    elif policy == "good_refine":
                        if min_mce_pass <= mce_threshold:
                            # HIT → reburn warm-start around current x0 at mid-ish steps
                            if reburn:
                                idx = _start_idx_from_frac(reburn_frac_hit)
                                x_T = _q_xt_given_x0(final_x0, idx, noise_scale=restart_noise_scale)
                                start_idx = idx
                                print(f"      -> HIT thr → REBURN warm-start at idx={idx} (frac≈{reburn_frac_hit:.2f})")
                            else:
                                x_T, start_idx = _random_xT_like((1,3,256,256), device, dtype)
                        else:
                            # NOT hit → random
                            x_T, start_idx = _random_xT_like((1,3,256,256), device, dtype)
                            print("      -> NOT hit thr → random next pass")

                    elif policy == "bad_warm":
                        if min_mce_pass > mce_threshold:
                            # NOT hit → reburn warm-start
                            if reburn:
                                idx = _start_idx_from_frac(reburn_frac_miss)
                                x_T = _q_xt_given_x0(final_x0, idx, noise_scale=restart_noise_scale)
                                start_idx = idx
                                print(f"      -> NOT hit thr → REBURN warm-start at idx={idx} (frac≈{reburn_frac_miss:.2f})")
                            else:
                                x_T, start_idx = _random_xT_like((1,3,256,256), device, dtype)
                        else:
                            # HIT → random
                            x_T, start_idx = _random_xT_like((1,3,256,256), device, dtype)
                            print("      -> HIT thr → random next pass")

                    else:
                        raise ValueError(f"Unknown policy: {policy}")

            # (optional) save best image
            if save_dir is not None and best_final_image is not None:
                out_path = os.path.join(save_dir, f"reburn_{policy}_{filename}_run{run_iter}_best.png")
                Image.fromarray((np.clip(best_final_image,0,1)*255).astype(np.uint8)).save(out_path)
                best_image_path = out_path
            else:
                best_image_path = ""

            results_table.append({
                "image": filename,
                "run": run_iter,
                "best_mce_overall": float(best_mce_overall),
                "best_psnr_overall": float(best_psnr_overall),
                "policy": policy,
                "threshold": float(mce_threshold),
                "n_steps": int(n_step),
                "passes_used": int(passes_used),
                "total_passes": int(total_passes),
                "restart_noise_scale": float(restart_noise_scale),
                "reburn": bool(reburn),
                "reburn_frac_hit": float(reburn_frac_hit),
                "reburn_frac_miss": float(reburn_frac_miss),
                "best_image_path": best_image_path,
            })
            print(f"  ==> Best-overall: MCE={best_mce_overall:.6f}, PSNR={best_psnr_overall:.3f} dB "
                  f"(passes={passes_used})")

    # correlation across all rows
    mces  = np.array([r["best_mce_overall"]  for r in results_table], dtype=float)
    psnrs = np.array([r["best_psnr_overall"] for r in results_table], dtype=float)
    corr = float('nan') if len(mces)<2 or np.std(mces)==0 or np.std(psnrs)==0 else float(np.corrcoef(mces, psnrs)[0,1])

    print("\n=== Reburn Summary ===")
    print(f"Rows (image×run): {len(results_table)}")
    print(f"Pearson corr(best_MCE, best_PSNR): {corr:.4f}")
    return results_table, corr


In [None]:
# Baseline (unchanged, always random starts)
'''base_res, _ = run_fixed_pass_experiment_reburn(
    image_indices=range(11,14),
    n_step=20, total_passes=10,
    policy="baseline_random",
    mce_threshold=0.0014,      # unused logically for baseline
    reburn=False,              # ensure it's pure baseline
    restart_noise_scale=0.0,
)'''

# good_refine with reburn from mid schedule when threshold is HIT
gr_res, _ = run_fixed_pass_experiment_reburn(
    image_indices=range(11,14),
    n_step=20, total_passes=10,
    policy="good_refine",
    mce_threshold=0.0016,      # looser
    reburn=True,
    reburn_frac_hit=0.5,       # start around the middle
    reburn_frac_miss=0.9,      # if ever used for miss, nearly random
    restart_noise_scale=0.5,
)

# bad_warm with reburn when threshold is NOT HIT
bw_res, _ = run_fixed_pass_experiment_reburn(
    image_indices=range(11,14),
    n_step=20, total_passes=10,
    policy="bad_warm",
    mce_threshold=0.00147,      # tighter
    reburn=True,
    reburn_frac_hit=0.9,       # if hit → random-like (we won't warm when hit)
    reburn_frac_miss=0.8,      # when miss → mid/high noise to escape
    restart_noise_scale=0.5,
)



# Loop over images and show them step by step
for i, img_tensor in enumerate(out_visual):
    image_np = clear_color(img_tensor)
    
    # Display the image
    plt.imshow(image_np)
    plt.axis('off')  # Hide axis for better display
    plt.title(f"Image {i}")  # Add title with image number
    plt.show()

In [None]:
plt.imshow(run_results[1])
# plt.axis('off')
#plt.savefig("ffhq0000_db2.png", dpi=300,bbox_inches='tight', pad_inches=0)
plt.colorbar()

In [None]:
gt = (ref_img/2+0.5)

In [None]:
plt.imshow(np.array(gt.cpu().detach().numpy()[0].transpose(1,2,0)))
plt.axis('off')
plt.colorbar()

In [None]:
def compute_psnr(img1, img2):
    mse = np.mean((img1 - img2) ** 2)
    if mse == 0:
        return float('inf')
    max_pixel = 1.0  # Assuming the image is normalized to [0, 1]
    psnr = 20 * np.log10(max_pixel / (mse**0.5))
    return psnr.item()
psnr_value = compute_psnr(np.array(phase_image), np.array(gt.cpu().detach().numpy()[0].transpose(1,2,0)))
print(f"After diffusion PSNR: {psnr_value} dB")


In [None]:
print(input.shape)

In [None]:
from piq import psnr, ssim, LPIPS

In [None]:
def norm( x):
    return (x * 0.5 + 0.5).clip(0, 1)

In [None]:
print(psnr(norm(ref_img), norm(input), 1.0, reduction='mean'))

In [None]:
fname = str(i).zfill(5) + 'test'

# New: Save image to explicit folder: SITCOMOUT at Cheng_Hans

# Method 1: Use PIL 
# !!NotWorking!!
def save_image(image, folder_path, filename="output.png"):
    # Ensure the folder exists
    os.makedirs(folder_path, exist_ok=True)

    # Define the full file path
    file_path = os.path.join(folder_path, filename)

    # Save the image
    image.save(file_path)
    print(f"Image saved at: {file_path}")

# Usage
save_image(phase_image, "/egr/research-slim/liangs16/Cheng_Hans/SITCOMOUTPUT", "demo_remain_00024_1.png")

In [None]:
#Method 2: numpy

def save_image(image_array, folder_path, filename="output.png"):
    # Ensure the folder exists
    os.makedirs(folder_path, exist_ok=True)

    # Define the full file path
    file_path = os.path.join(folder_path, filename)

    # Save the image
    # plt.imsave(file_path, image_array, cmap = "gray")  # Use `gray` for grayscale images
    plt.imsave(file_path, image_array)
    print(f"Image saved at: {file_path}")

# Example Usage (creating a dummy NumPy array)
save_image(phase_image, "/egr/research-slim/liangs16/Cheng_Hans/SITCOMOUTPUT", "demo_remain_00026_1.png")

gt_img = Image.open('/egr/research-slim/liangs16/Cheng_Hans/SITCOMOUTPUT/demo_remain_00024_1.png').convert("RGB")
plt.imshow(gt_img)