In [1]:
# Cell 1: imports & helpers
import json
import numpy as np
from pathlib import Path
import sys
# 直接复用 evaluate_neurlz_correct.py 里的 ROI 指标函数
from evaluate_neurlz_correct import compute_roi_metrics
# 需要 pysz
sys.path.append('/Users/923714256/Data_compression/SZ3/tools/pysz')
from pysz import SZ



In [2]:
# Cell 2: 配置路径（按你的实际情况改这里）
data_path = Path("/Users/923714256/Data_compression/SDRBENCH-EXASKY-NYX-512x512x512/dark_matter_density.f32")

# 选你想看的那次 run，比如 channels_6/run_3
components_dir = Path("/Users/923714256/Data_compression/SZ3_+_2NN/relative_error_bound_4e-4/dual_model/compressed_components/channels_4/run_4")

sz_lib = "/Users/923714256/Data_compression/SZ3/build/lib64/libSZ3c.so"

sz3_path = components_dir / "dark_matter_density.sz3"
meta_path = components_dir / "dark_matter_density_metadata.json"
neurlz_recon_path = components_dir / "dark_matter_density_reconstructed.f32"

print("components_dir:", components_dir)
print("exists sz3:", sz3_path.exists())
print("exists meta:", meta_path.exists())
print("exists neurlz recon:", neurlz_recon_path.exists())

components_dir: /Users/923714256/Data_compression/SZ3_+_2NN/relative_error_bound_4e-4/dual_model/compressed_components/channels_4/run_4
exists sz3: True
exists meta: True
exists neurlz recon: True


In [3]:
# Cell 3: 读取原始数据 + ROI boxes
data = np.fromfile(data_path, dtype=np.float32).reshape(512, 512, 512)

with open(meta_path, "r") as f:
    meta = json.load(f)

roi_boxes = meta.get("roi_boxes_3d") or []
print("ROI boxes:", len(roi_boxes))
print("example ROI[0]:", roi_boxes[1] if roi_boxes else None)

ROI boxes: 100
example ROI[0]: [81, 162, 162, 243, 162, 243]


In [4]:
# Cell 4: SZ3-only 解压（在同一批 ROI boxes 上算 ROI PSNR）
with open(sz3_path, "rb") as f:
    sz_bytes = f.read()
# 关键：转换为 numpy uint8，pysz 才能用 .ctypes
sz_bytes_np = np.frombuffer(sz_bytes, dtype=np.uint8)
sz = SZ(sz_lib)
recon_sz3 = sz.decompress(sz_bytes_np, data.shape, np.float32)

roi_summary_sz3, _ = compute_roi_metrics(data, recon_sz3, roi_boxes, ssim_slice_axis=0)
roi_psnr_sz3 = roi_summary_sz3["roi_psnr"]
print("SZ3 ROI PSNR:", roi_psnr_sz3)

SZ3 ROI PSNR: 80.92042237056938


In [5]:
# Cell 5: NeurLZ 重建（log 里保存的 reconstructed.f32）算 ROI PSNR
recon_neurlz = np.fromfile(neurlz_recon_path, dtype=np.float32).reshape(512, 512, 512)

roi_summary_neurlz, _ = compute_roi_metrics(data, recon_neurlz, roi_boxes, ssim_slice_axis=0)
roi_psnr_neurlz = roi_summary_neurlz["roi_psnr"]
print("NeurLZ ROI PSNR:", roi_psnr_neurlz)

NeurLZ ROI PSNR: 83.24664721107241


In [6]:
# Cell 6: ROI PSNR 提升（你要的 “based on SZ3_roi”）
delta = roi_psnr_neurlz - roi_psnr_sz3
print("ΔROI PSNR (NeurLZ - SZ3 ROI):", delta, "dB")

ΔROI PSNR (NeurLZ - SZ3 ROI): 2.3262248405030306 dB


In [7]:
# Cell 7a: 设置和导入（分步执行以便调试）
import sys
from pathlib import Path
import torch
import gc

# 清理 GPU 内存（如果有）
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()

# Add path for compression_function import
sys.path.insert(0, str(Path("/Users/923714256/Data_compression/SZ3_+_2NN")))
sys.path.insert(0, '/Users/923714256/Data_compression/neural_compression')  # For Model imports

try:
    from compression_function import create_model_for_decompress
    print("✓ Successfully imported create_model_for_decompress")
except Exception as e:
    print(f"✗ Import error: {e}")
    import traceback
    traceback.print_exc()
    raise

# Load ROI model weights (separate file)
model_roi_path = components_dir / "dark_matter_density_model_roi.pt"
print(f"Loading model from: {model_roi_path}")
print(f"File exists: {model_roi_path.exists()}")

if not model_roi_path.exists():
    raise FileNotFoundError(f"Model file not found: {model_roi_path}")

with open(model_roi_path, "rb") as f:
    model_weights_roi_dict = torch.load(f, map_location="cpu")

print(f"✓ Model weights loaded: {len(model_weights_roi_dict)} parameters")

# Convert to numpy for compatibility (if needed)
if isinstance(model_weights_roi_dict, dict):
    if len(model_weights_roi_dict) > 0:
        first_val = next(iter(model_weights_roi_dict.values()))
        if isinstance(first_val, torch.Tensor):
            model_weights_roi = {k: v.numpy() for k, v in model_weights_roi_dict.items()}
            print("✓ Converted torch tensors to numpy")
        else:
            model_weights_roi = model_weights_roi_dict
            print("✓ Using numpy arrays directly")
    else:
        model_weights_roi = model_weights_roi_dict
else:
    raise ValueError(f"Unexpected model weights format: {type(model_weights_roi_dict)}")


✓ Successfully imported create_model_for_decompress
Loading model from: /Users/923714256/Data_compression/SZ3_+_2NN/relative_error_bound_4e-4/dual_model/compressed_components/channels_4/run_4/dark_matter_density_model_roi.pt
File exists: True
✓ Model weights loaded: 80 parameters
✓ Converted torch tensors to numpy


In [8]:
# Cell 7b: 创建模型
# Get metadata
spatial_dims = meta.get("spatial_dims", 2)
slice_order = meta.get("slice_order", "zxy")
model_type_roi = meta.get("model_type_roi", meta.get("model_type", "tiny_frequency_residual_predictor_7_attn_roi"))

print(f"Model type: {model_type_roi}")
print(f"Spatial dims: {spatial_dims}, Slice order: {slice_order}")

# Create metadata for ROI model
metadata_roi = meta.copy()
if "roi_low_cutoff" in meta:
    metadata_roi["low_cutoff"] = meta["roi_low_cutoff"]
if "roi_mid_cutoff" in meta:
    metadata_roi["mid_cutoff"] = meta["roi_mid_cutoff"]
if "roi_use_phase_sincos" in meta:
    metadata_roi["use_phase_sincos"] = meta["roi_use_phase_sincos"]
if "roi_gn_groups" in meta:
    metadata_roi["gn_groups"] = meta["roi_gn_groups"]

# Create ROI model
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

try:
    model_roi = create_model_for_decompress(model_type_roi, metadata_roi, spatial_dims, device)
    if model_roi is None:
        raise ValueError(f"Failed to create ROI model: {model_type_roi}")
    print("✓ Model created successfully")
except Exception as e:
    print(f"✗ Model creation error: {e}")
    import traceback
    traceback.print_exc()
    raise

# Load weights
try:
    state_dict_roi = {k: torch.from_numpy(v) for k, v in model_weights_roi.items()}
    model_roi.load_state_dict(state_dict_roi)
    model_roi.eval()
    print("✓ Model weights loaded successfully")
except Exception as e:
    print(f"✗ Model loading error: {e}")
    import traceback
    traceback.print_exc()
    raise

# Clean up
del model_weights_roi_dict, model_weights_roi, state_dict_roi
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()


Model type: tiny_frequency_residual_predictor_7_attn_roi
Spatial dims: 2, Slice order: zxy
Using device: cpu
✓ Model created successfully
✓ Model weights loaded successfully


In [9]:
# Cell 7c: 预测 residuals（使用更小的 batch size）
# Get normalization parameters
residual_mean = float(meta["residual_mean"])
residual_std = float(meta["residual_std"])
input_mean = float(meta["input_mean"])
input_std = float(meta["input_std"])

# 使用更小的 batch size 以避免内存问题
batch_size_2d = min(int(meta.get("batch_size_2d", 256)), 64)  # 限制最大为 64
batch_size_3d = min(int(meta.get("batch_size_3d", 256)), 32)  # 限制最大为 32

print(f"Batch sizes: 2D={batch_size_2d}, 3D={batch_size_3d}")

# Prepare x_prime for prediction
if spatial_dims == 2:
    if slice_order == "zxy":
        x_prime_for_pred = recon_sz3.transpose(2, 0, 1)
    elif slice_order == "yxz":
        x_prime_for_pred = recon_sz3.transpose(1, 0, 2)
    else:  # 'xyz'
        x_prime_for_pred = recon_sz3
else:
    x_prime_for_pred = recon_sz3

print(f"x_prime_for_pred shape: {x_prime_for_pred.shape}")

# Predict residuals using ROI model only (on entire volume)
pred_residuals_roi_list = []

try:
    with torch.no_grad():
        x_prime_norm = (x_prime_for_pred - input_mean) / input_std
        print(f"x_prime_norm shape: {x_prime_norm.shape}, dtype: {x_prime_norm.dtype}")
        
        if spatial_dims == 2:
            n_slices = x_prime_norm.shape[0]
            print(f"Processing {n_slices} slices in batches of {batch_size_2d}")
            
            for i, batch_start in enumerate(range(0, n_slices, batch_size_2d)):
                batch_end = min(batch_start + batch_size_2d, n_slices)
                
                if i % 10 == 0:
                    print(f"  Processing batch {i}: slices {batch_start}-{batch_end}")
                
                x_batch = torch.from_numpy(x_prime_norm[batch_start:batch_end]).float().unsqueeze(1).to(device)
                
                pred_roi = model_roi(x_batch)
                if isinstance(pred_roi, tuple):
                    pred_roi = pred_roi[0]
                
                pred_residuals_roi_list.append(pred_roi.cpu().numpy().squeeze(1))
                
                # 清理内存
                del x_batch, pred_roi
                if (i + 1) % 10 == 0:  # 每 10 个 batch 清理一次
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
            
            pred_residuals_roi_norm = np.concatenate(pred_residuals_roi_list, axis=0)
            print(f"✓ Concatenated residuals shape: {pred_residuals_roi_norm.shape}")
        else:
            z_dim = x_prime_norm.shape[2]
            print(f"Processing {z_dim} z-slices in batches of {batch_size_3d}")
            
            for i, batch_start in enumerate(range(0, z_dim, batch_size_3d)):
                batch_end = min(batch_start + batch_size_3d, z_dim)
                
                if i % 5 == 0:
                    print(f"  Processing batch {i}: z-slices {batch_start}-{batch_end}")
                
                x_batch = x_prime_norm[:, :, batch_start:batch_end]
                x_batch_tensor = torch.from_numpy(x_batch).float().unsqueeze(0).unsqueeze(0).to(device)
                
                pred_roi = model_roi(x_batch_tensor)
                if isinstance(pred_roi, tuple):
                    pred_roi = pred_roi[0]
                
                pred_residuals_roi_list.append(pred_roi.cpu().numpy().squeeze())
                
                # 清理内存
                del x_batch_tensor, pred_roi
                if (i + 1) % 5 == 0:  # 每 5 个 batch 清理一次
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
            
            pred_residuals_roi_norm = np.concatenate(pred_residuals_roi_list, axis=2)
            print(f"✓ Concatenated residuals shape: {pred_residuals_roi_norm.shape}")
        
        # Denormalize
        pred_residuals_roi = pred_residuals_roi_norm * residual_std + residual_mean
        
        # Transpose back if 2D
        if spatial_dims == 2:
            if slice_order == "zxy":
                pred_residuals_roi = pred_residuals_roi.transpose(1, 2, 0)
            elif slice_order == "yxz":
                pred_residuals_roi = pred_residuals_roi.transpose(1, 0, 2)
        
        print(f"✓ Final residuals shape: {pred_residuals_roi.shape}")
        print(f"Residual stats: mean={np.mean(pred_residuals_roi):.3e}, std={np.std(pred_residuals_roi):.3e}")

except Exception as e:
    print(f"✗ Prediction error: {e}")
    import traceback
    traceback.print_exc()
    raise
finally:
    # 清理
    del pred_residuals_roi_list
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


Batch sizes: 2D=64, 3D=32
x_prime_for_pred shape: (512, 512, 512)
x_prime_norm shape: (512, 512, 512), dtype: float32
Processing 512 slices in batches of 64
  Processing batch 0: slices 0-64
✓ Concatenated residuals shape: (512, 512, 512)
✓ Final residuals shape: (512, 512, 512)
Residual stats: mean=-1.961e-01, std=7.505e-01


In [10]:
# Cell 7d: 计算最终结果
# Compute enhanced reconstruction: x_enhanced = x_prime + pred_residuals_roi (entire volume)
recon_roi_only = recon_sz3 + pred_residuals_roi

print(f"\nROI Model only prediction complete")
print(f"Reconstruction shape: {recon_roi_only.shape}")

# Calculate overall PSNR
def calculate_psnr(original, reconstructed):
    mse = np.mean((original - reconstructed) ** 2)
    data_range = np.max(original) - np.min(original)
    psnr = 20 * np.log10(data_range) - 10 * np.log10(mse) if mse > 0 else float("inf")
    return psnr

psnr_sz3 = calculate_psnr(data, recon_sz3)
psnr_roi_only = calculate_psnr(data, recon_roi_only)
delta_psnr_roi_only = psnr_roi_only - psnr_sz3

print(f"\n{'='*70}")
print(f"Overall PSNR Results (ROI Model Only):")
print(f"{'='*70}")
print(f"SZ3 Baseline PSNR:        {psnr_sz3:.4f} dB")
print(f"ROI Model Only PSNR:      {psnr_roi_only:.4f} dB")
print(f"ΔPSNR (ROI Only - SZ3):   {delta_psnr_roi_only:.4f} dB")
print(f"{'='*70}")



ROI Model only prediction complete
Reconstruction shape: (512, 512, 512)

Overall PSNR Results (ROI Model Only):
SZ3 Baseline PSNR:        82.4817 dB
ROI Model Only PSNR:      84.8987 dB
ΔPSNR (ROI Only - SZ3):   2.4170 dB


### BG Model 

In [11]:
# Cell 8a: BG Model - 设置和导入
import sys
from pathlib import Path
import torch
import gc

# 清理 GPU 内存（如果有）
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()

# Add path for compression_function import (如果还没有添加)
if '/Users/923714256/Data_compression/SZ3_+_2NN' not in sys.path:
    sys.path.insert(0, str(Path("/Users/923714256/Data_compression/SZ3_+_2NN")))
if '/Users/923714256/Data_compression/neural_compression' not in sys.path:
    sys.path.insert(0, '/Users/923714256/Data_compression/neural_compression')

try:
    from compression_function import create_model_for_decompress
    print("✓ Successfully imported create_model_for_decompress")
except Exception as e:
    print(f"✗ Import error: {e}")
    import traceback
    traceback.print_exc()
    raise

# Load BG model weights (separate file)
model_bg_path = components_dir / "dark_matter_density_model_bg.pt"
print(f"Loading BG model from: {model_bg_path}")
print(f"File exists: {model_bg_path.exists()}")

if not model_bg_path.exists():
    raise FileNotFoundError(f"BG Model file not found: {model_bg_path}")

with open(model_bg_path, "rb") as f:
    model_weights_bg_dict = torch.load(f, map_location="cpu")

print(f"✓ BG Model weights loaded: {len(model_weights_bg_dict)} parameters")

# Convert to numpy for compatibility (if needed)
if isinstance(model_weights_bg_dict, dict):
    if len(model_weights_bg_dict) > 0:
        first_val = next(iter(model_weights_bg_dict.values()))
        if isinstance(first_val, torch.Tensor):
            model_weights_bg = {k: v.numpy() for k, v in model_weights_bg_dict.items()}
            print("✓ Converted torch tensors to numpy")
        else:
            model_weights_bg = model_weights_bg_dict
            print("✓ Using numpy arrays directly")
    else:
        model_weights_bg = model_weights_bg_dict
else:
    raise ValueError(f"Unexpected model weights format: {type(model_weights_bg_dict)}")


✓ Successfully imported create_model_for_decompress
Loading BG model from: /Users/923714256/Data_compression/SZ3_+_2NN/relative_error_bound_4e-4/dual_model/compressed_components/channels_4/run_4/dark_matter_density_model_bg.pt
File exists: True
✓ BG Model weights loaded: 80 parameters
✓ Converted torch tensors to numpy


In [12]:
# Cell 8b: BG Model - 创建模型
# Get metadata
spatial_dims = meta.get("spatial_dims", 2)
slice_order = meta.get("slice_order", "zxy")
model_type_bg = meta.get("model_type_bg", meta.get("model_type", "tiny_frequency_residual_predictor_7_attn_roi"))

print(f"BG Model type: {model_type_bg}")
print(f"Spatial dims: {spatial_dims}, Slice order: {slice_order}")

# Create metadata for BG model
metadata_bg = meta.copy()
if "bg_low_cutoff" in meta:
    metadata_bg["low_cutoff"] = meta["bg_low_cutoff"]
if "bg_mid_cutoff" in meta:
    metadata_bg["mid_cutoff"] = meta["bg_mid_cutoff"]
if "bg_use_phase_sincos" in meta:
    metadata_bg["use_phase_sincos"] = meta["bg_use_phase_sincos"]
if "bg_gn_groups" in meta:
    metadata_bg["gn_groups"] = meta["bg_gn_groups"]

# Create BG model
device = "cuda:0" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

try:
    model_bg = create_model_for_decompress(model_type_bg, metadata_bg, spatial_dims, device)
    if model_bg is None:
        raise ValueError(f"Failed to create BG model: {model_type_bg}")
    print("✓ BG Model created successfully")
except Exception as e:
    print(f"✗ BG Model creation error: {e}")
    import traceback
    traceback.print_exc()
    raise

# Load weights
try:
    state_dict_bg = {k: torch.from_numpy(v) for k, v in model_weights_bg.items()}
    model_bg.load_state_dict(state_dict_bg)
    model_bg.eval()
    print("✓ BG Model weights loaded successfully")
except Exception as e:
    print(f"✗ BG Model loading error: {e}")
    import traceback
    traceback.print_exc()
    raise

# Clean up
del model_weights_bg_dict, model_weights_bg, state_dict_bg
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()


BG Model type: tiny_frequency_residual_predictor_7_attn_roi
Spatial dims: 2, Slice order: zxy
Using device: cpu
✓ BG Model created successfully
✓ BG Model weights loaded successfully


In [13]:
# Cell 8c: BG Model - 预测 residuals（使用更小的 batch size）
# Get normalization parameters
residual_mean = float(meta["residual_mean"])
residual_std = float(meta["residual_std"])
input_mean = float(meta["input_mean"])
input_std = float(meta["input_std"])

# 使用更小的 batch size 以避免内存问题
batch_size_2d = min(int(meta.get("batch_size_2d", 256)), 64)  # 限制最大为 64
batch_size_3d = min(int(meta.get("batch_size_3d", 256)), 32)  # 限制最大为 32

print(f"Batch sizes: 2D={batch_size_2d}, 3D={batch_size_3d}")

# Prepare x_prime for prediction (reuse from Cell 7c if available, otherwise recompute)
if spatial_dims == 2:
    if slice_order == "zxy":
        x_prime_for_pred_bg = recon_sz3.transpose(2, 0, 1)
    elif slice_order == "yxz":
        x_prime_for_pred_bg = recon_sz3.transpose(1, 0, 2)
    else:  # 'xyz'
        x_prime_for_pred_bg = recon_sz3
else:
    x_prime_for_pred_bg = recon_sz3

print(f"x_prime_for_pred shape: {x_prime_for_pred_bg.shape}")

# Initialize pred_residuals_bg outside try block to ensure it's accessible
pred_residuals_bg = None

# Predict residuals using BG model only (on entire volume)
pred_residuals_bg_list = []

try:
    with torch.no_grad():
        x_prime_norm_bg = (x_prime_for_pred_bg - input_mean) / input_std
        print(f"x_prime_norm shape: {x_prime_norm_bg.shape}, dtype: {x_prime_norm_bg.dtype}")
        
        if spatial_dims == 2:
            n_slices = x_prime_norm_bg.shape[0]
            print(f"Processing {n_slices} slices in batches of {batch_size_2d}")
            
            for i, batch_start in enumerate(range(0, n_slices, batch_size_2d)):
                batch_end = min(batch_start + batch_size_2d, n_slices)
                
                if i % 10 == 0:
                    print(f"  Processing batch {i}: slices {batch_start}-{batch_end}")
                
                x_batch = torch.from_numpy(x_prime_norm_bg[batch_start:batch_end]).float().unsqueeze(1).to(device)
                
                pred_bg = model_bg(x_batch)
                if isinstance(pred_bg, tuple):
                    pred_bg = pred_bg[0]
                
                pred_residuals_bg_list.append(pred_bg.cpu().numpy().squeeze(1))
                
                # 清理内存
                del x_batch, pred_bg
                if (i + 1) % 10 == 0:  # 每 10 个 batch 清理一次
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
            
            pred_residuals_bg_norm = np.concatenate(pred_residuals_bg_list, axis=0)
            print(f"✓ Concatenated residuals shape: {pred_residuals_bg_norm.shape}")
        else:
            z_dim = x_prime_norm_bg.shape[2]
            print(f"Processing {z_dim} z-slices in batches of {batch_size_3d}")
            
            for i, batch_start in enumerate(range(0, z_dim, batch_size_3d)):
                batch_end = min(batch_start + batch_size_3d, z_dim)
                
                if i % 5 == 0:
                    print(f"  Processing batch {i}: z-slices {batch_start}-{batch_end}")
                
                x_batch = x_prime_norm_bg[:, :, batch_start:batch_end]
                x_batch_tensor = torch.from_numpy(x_batch).float().unsqueeze(0).unsqueeze(0).to(device)
                
                pred_bg = model_bg(x_batch_tensor)
                if isinstance(pred_bg, tuple):
                    pred_bg = pred_bg[0]
                
                pred_residuals_bg_list.append(pred_bg.cpu().numpy().squeeze())
                
                # 清理内存
                del x_batch_tensor, pred_bg
                if (i + 1) % 5 == 0:  # 每 5 个 batch 清理一次
                    gc.collect()
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
            
            pred_residuals_bg_norm = np.concatenate(pred_residuals_bg_list, axis=2)
            print(f"✓ Concatenated residuals shape: {pred_residuals_bg_norm.shape}")
        
        # Denormalize
        pred_residuals_bg = pred_residuals_bg_norm * residual_std + residual_mean
        
        # Transpose back if 2D
        if spatial_dims == 2:
            if slice_order == "zxy":
                pred_residuals_bg = pred_residuals_bg.transpose(1, 2, 0)
            elif slice_order == "yxz":
                pred_residuals_bg = pred_residuals_bg.transpose(1, 0, 2)
        
        print(f"✓ Final residuals shape: {pred_residuals_bg.shape}")
        print(f"Residual stats: mean={np.mean(pred_residuals_bg):.3e}, std={np.std(pred_residuals_bg):.3e}")
        
        # 确保变量在全局作用域中可用
        globals()['pred_residuals_bg'] = pred_residuals_bg

except Exception as e:
    print(f"✗ Prediction error: {e}")
    import traceback
    traceback.print_exc()
    # 如果出错，确保变量被设置为 None
    pred_residuals_bg = None
    globals()['pred_residuals_bg'] = None
    raise
finally:
    # 清理临时变量，但保留 pred_residuals_bg
    if 'pred_residuals_bg_list' in locals():
        del pred_residuals_bg_list
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# 验证 pred_residuals_bg 是否成功创建
if 'pred_residuals_bg' not in globals() or pred_residuals_bg is None:
    raise RuntimeError("pred_residuals_bg was not created successfully. Please check the error messages above.")
else:
    print(f"\n✓ pred_residuals_bg successfully created and available in global scope")
    print(f"  Shape: {pred_residuals_bg.shape}, dtype: {pred_residuals_bg.dtype}")


Batch sizes: 2D=64, 3D=32
x_prime_for_pred shape: (512, 512, 512)


x_prime_norm shape: (512, 512, 512), dtype: float32
Processing 512 slices in batches of 64
  Processing batch 0: slices 0-64
✓ Concatenated residuals shape: (512, 512, 512)
✓ Final residuals shape: (512, 512, 512)
Residual stats: mean=-1.931e-01, std=7.273e-01

✓ pred_residuals_bg successfully created and available in global scope
  Shape: (512, 512, 512), dtype: float32


In [14]:
# Cell 8d: BG Model - 计算最终结果
# 检查 pred_residuals_bg 是否存在（使用安全的检查方式）
if 'pred_residuals_bg' not in globals():
    raise NameError(
        "pred_residuals_bg is not defined. Please execute Cell 8c first to predict BG residuals.\n"
        "Cell 8c should create pred_residuals_bg by running BG model prediction."
    )

# 检查变量是否为 None
if globals()['pred_residuals_bg'] is None:
    raise ValueError(
        "pred_residuals_bg is None. Cell 8c may have failed. Please check Cell 8c output for errors."
    )

# 获取变量（现在可以安全访问）
pred_residuals_bg = globals()['pred_residuals_bg']
print(f"✓ Using pred_residuals_bg: shape={pred_residuals_bg.shape}, dtype={pred_residuals_bg.dtype}")

# Compute enhanced reconstruction: x_enhanced = x_prime + pred_residuals_bg (entire volume)
recon_bg_only = recon_sz3 + pred_residuals_bg

print(f"\nBG Model only prediction complete")
print(f"Reconstruction shape: {recon_bg_only.shape}")

# Calculate overall PSNR (reuse function from Cell 7d if available, otherwise define it)
if 'calculate_psnr' not in globals():
    def calculate_psnr(original, reconstructed):
        mse = np.mean((original - reconstructed) ** 2)
        data_range = np.max(original) - np.min(original)
        psnr = 20 * np.log10(data_range) - 10 * np.log10(mse) if mse > 0 else float("inf")
        return psnr

psnr_sz3_bg = calculate_psnr(data, recon_sz3)
psnr_bg_only = calculate_psnr(data, recon_bg_only)
delta_psnr_bg_only = psnr_bg_only - psnr_sz3_bg

print(f"\n{'='*70}")
print(f"Overall PSNR Results (BG Model Only):")
print(f"{'='*70}")
print(f"SZ3 Baseline PSNR:        {psnr_sz3_bg:.4f} dB")
print(f"BG Model Only PSNR:       {psnr_bg_only:.4f} dB")
print(f"ΔPSNR (BG Only - SZ3):    {delta_psnr_bg_only:.4f} dB")
print(f"{'='*70}")

# Compare with ROI Model results (if available)
if 'psnr_roi_only' in globals():
    print(f"\n{'='*70}")
    print(f"Comparison: ROI vs BG Model")
    print(f"{'='*70}")
    print(f"ROI Model Only PSNR:    {psnr_roi_only:.4f} dB")
    print(f"BG Model Only PSNR:     {psnr_bg_only:.4f} dB")
    print(f"Difference (ROI - BG):  {psnr_roi_only - psnr_bg_only:.4f} dB")
    print(f"{'='*70}")

✓ Using pred_residuals_bg: shape=(512, 512, 512), dtype=float32



BG Model only prediction complete
Reconstruction shape: (512, 512, 512)

Overall PSNR Results (BG Model Only):
SZ3 Baseline PSNR:        82.4817 dB
BG Model Only PSNR:       85.0123 dB
ΔPSNR (BG Only - SZ3):    2.5306 dB

Comparison: ROI vs BG Model
ROI Model Only PSNR:    84.8987 dB
BG Model Only PSNR:     85.0123 dB
Difference (ROI - BG):  -0.1136 dB


### BG + ROI Model Combination 

In [15]:
# Cell 9a: BG + ROI Model - 创建 ROI mask 并组合 residuals
import gc

# 检查必要的变量是否存在
required_vars = ['pred_residuals_bg', 'pred_residuals_roi', 'roi_boxes', 'recon_sz3', 'data']
missing_vars = [var for var in required_vars if var not in globals()]

if missing_vars:
    raise NameError(
        f"Missing required variables: {missing_vars}\n"
        "Please execute the following cells first:\n"
        "- Cell 7c: for pred_residuals_roi\n"
        "- Cell 8c: for pred_residuals_bg\n"
        "- Cell 3: for roi_boxes\n"
        "- Cell 4: for recon_sz3 and data"
    )

# 获取变量
pred_residuals_bg = globals()['pred_residuals_bg']
pred_residuals_roi = globals()['pred_residuals_roi']
roi_boxes = globals()['roi_boxes']

print("✓ All required variables are available")
print(f"  pred_residuals_bg shape: {pred_residuals_bg.shape}")
print(f"  pred_residuals_roi shape: {pred_residuals_roi.shape}")
print(f"  Number of ROI boxes: {len(roi_boxes)}")

# 创建 ROI mask
from compression_function import create_roi_mask

volume_shape = data.shape  # (512, 512, 512)
roi_mask = create_roi_mask(volume_shape, roi_boxes)
roi_mask_float = roi_mask.astype(np.float32)

roi_voxels = np.sum(roi_mask)
total_voxels = roi_mask.size
roi_percentage = (roi_voxels / total_voxels) * 100

print(f"\nROI Mask Statistics:")
print(f"  Total voxels: {total_voxels:,}")
print(f"  ROI voxels: {roi_voxels:,}")
print(f"  ROI percentage: {roi_percentage:.2f}%")

# 组合 residuals: R_combined = R_bg * (1 - M_roi) + R_roi * M_roi
# 在 ROI 区域使用 ROI model，在非 ROI 区域使用 BG model
pred_residuals_combined = pred_residuals_bg * (1 - roi_mask_float) + pred_residuals_roi * roi_mask_float

print(f"\n✓ Combined residuals created")
print(f"  Shape: {pred_residuals_combined.shape}")
print(f"  Mean: {np.mean(pred_residuals_combined):.3e}")
print(f"  Std: {np.std(pred_residuals_combined):.3e}")

# 确保变量在全局作用域中可用
globals()['pred_residuals_combined'] = pred_residuals_combined
globals()['roi_mask'] = roi_mask
globals()['roi_mask_float'] = roi_mask_float

print(f"\n✓ Variables saved to global scope")


✓ All required variables are available
  pred_residuals_bg shape: (512, 512, 512)
  pred_residuals_roi shape: (512, 512, 512)
  Number of ROI boxes: 100

ROI Mask Statistics:
  Total voxels: 134,217,728
  ROI voxels: 53,144,100
  ROI percentage: 39.60%

✓ Combined residuals created
  Shape: (512, 512, 512)
  Mean: -1.953e-01
  Std: 7.344e-01

✓ Variables saved to global scope


In [16]:
# Cell 9b: BG + ROI Model - 计算最终重建和 PSNR
# 检查组合 residuals 是否存在
if 'pred_residuals_combined' not in globals() or pred_residuals_combined is None:
    raise NameError(
        "pred_residuals_combined is not defined. Please execute Cell 9a first."
    )

pred_residuals_combined = globals()['pred_residuals_combined']

# 计算增强重建: x_enhanced = x_prime + pred_residuals_combined
recon_combined = recon_sz3 + pred_residuals_combined

print(f"✓ Combined reconstruction created")
print(f"  Shape: {recon_combined.shape}")
print(f"  Data range: [{np.min(recon_combined):.3e}, {np.max(recon_combined):.3e}]")

# 计算整体 PSNR
if 'calculate_psnr' not in globals():
    def calculate_psnr(original, reconstructed):
        mse = np.mean((original - reconstructed) ** 2)
        data_range = np.max(original) - np.min(original)
        psnr = 20 * np.log10(data_range) - 10 * np.log10(mse) if mse > 0 else float("inf")
        return psnr

psnr_sz3_baseline = calculate_psnr(data, recon_sz3)
psnr_combined = calculate_psnr(data, recon_combined)
delta_psnr_combined = psnr_combined - psnr_sz3_baseline

print(f"\n{'='*70}")
print(f"Overall PSNR Results (BG + ROI Model Combination):")
print(f"{'='*70}")
print(f"SZ3 Baseline PSNR:        {psnr_sz3_baseline:.4f} dB")
print(f"BG + ROI Combined PSNR:    {psnr_combined:.4f} dB")
print(f"ΔPSNR (Combined - SZ3):    {delta_psnr_combined:.4f} dB")
print(f"{'='*70}")

# 保存结果到全局作用域
globals()['recon_combined'] = recon_combined
globals()['psnr_combined'] = psnr_combined
globals()['delta_psnr_combined'] = delta_psnr_combined


✓ Combined reconstruction created
  Shape: (512, 512, 512)
  Data range: [-6.536e+00, 1.378e+04]

Overall PSNR Results (BG + ROI Model Combination):
SZ3 Baseline PSNR:        82.4817 dB
BG + ROI Combined PSNR:    84.9475 dB
ΔPSNR (Combined - SZ3):    2.4657 dB


In [17]:
# Cell 9c: BG + ROI Model - 完整对比总结
print(f"\n{'='*70}")
print(f"Complete Comparison: All Methods vs SZ3 Baseline")
print(f"{'='*70}")

# 获取所有可用的 PSNR 结果
results = {
    "SZ3 Baseline": psnr_sz3_baseline
}

if 'psnr_roi_only' in globals():
    results["ROI Model Only"] = psnr_roi_only
    results["ROI ΔPSNR"] = psnr_roi_only - psnr_sz3_baseline

if 'psnr_bg_only' in globals():
    results["BG Model Only"] = psnr_bg_only
    results["BG ΔPSNR"] = psnr_bg_only - psnr_sz3_baseline

if 'psnr_combined' in globals():
    results["BG + ROI Combined"] = psnr_combined
    results["Combined ΔPSNR"] = delta_psnr_combined

# 打印结果表格
print(f"\n{'Method':<25} {'PSNR (dB)':<15} {'ΔPSNR (dB)':<15}")
print(f"{'-'*55}")

for method, value in results.items():
    if "ΔPSNR" in method:
        continue
    delta_key = method.replace("Model Only", "ΔPSNR").replace("Combined", "Combined ΔPSNR")
    delta = results.get(delta_key, 0.0)
    print(f"{method:<25} {value:>14.4f} {delta:>14.4f}")

print(f"\n{'='*70}")

# 计算相对提升
if 'psnr_combined' in globals():
    print(f"\nImprovement Analysis:")
    print(f"  Combined vs SZ3:        {delta_psnr_combined:.4f} dB ({delta_psnr_combined/psnr_sz3_baseline*100:.2f}% relative)")
    
    if 'psnr_roi_only' in globals() and 'psnr_bg_only' in globals():
        improvement_over_roi = psnr_combined - psnr_roi_only
        improvement_over_bg = psnr_combined - psnr_bg_only
        print(f"  Combined vs ROI Only:   {improvement_over_roi:.4f} dB")
        print(f"  Combined vs BG Only:    {improvement_over_bg:.4f} dB")
        
        if improvement_over_roi > 0 and improvement_over_bg > 0:
            print(f"\n  ✓ Combined method outperforms both individual models!")
        elif improvement_over_roi > 0:
            print(f"\n  ✓ Combined method outperforms ROI model only")
        elif improvement_over_bg > 0:
            print(f"\n  ✓ Combined method outperforms BG model only")
        else:
            print(f"\n  ⚠ Combined method does not outperform individual models")

print(f"{'='*70}\n")



Complete Comparison: All Methods vs SZ3 Baseline

Method                    PSNR (dB)       ΔPSNR (dB)     
-------------------------------------------------------
SZ3 Baseline                     82.4817        82.4817
ROI Model Only                   84.8987         2.4170
BG Model Only                    85.0123         2.5306
BG + ROI Combined                84.9475         0.0000


Improvement Analysis:
  Combined vs SZ3:        2.4657 dB (2.99% relative)
  Combined vs ROI Only:   0.0487 dB
  Combined vs BG Only:    -0.0649 dB

  ✓ Combined method outperforms ROI model only

