In [None]:
import os
# Allow more than one OpenMP runtime in the process (e.g. torch and MKL each shipping
# libiomp). This must be set before any OpenMP-linked library is imported below.
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"

from src.admm_model import ADMM_Net
from src.fn import *
import numpy as np
import torch
import warnings

# 屏蔽所有 UserWarning
warnings.filterwarnings("ignore", category=UserWarning)

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(f'Using device: {device}')

In [None]:
# Input files live in ./data relative to the notebook's working directory.
# os.path.join keeps these portable: a backslash is not a path separator on Linux/macOS.
psf_file = os.path.join("data", "psf.tiff")
measurement = os.path.join("data", "measurement.png")

# PSF: grayscale, downsampled 4x, background removed, normalized.
# NOTE(review): presumably the raw PSF is captured at 4x the measurement resolution so
# both land on the same grid -- confirm against the printed shapes below.
psf_resized = load_and_downsample_normal2_image(
    psf_file,
    downsample=4,
    mode="gray",
    remove_bg=True,
    normalize=True,
    visualize=False
)
# Measurement: RGB at full resolution, normalized, background kept.
measurement_resized = load_and_downsample_normal2_image(
    measurement,
    downsample=1,
    mode="rgb",
    remove_bg=False,
    normalize=True,
    visualize=False
)
# The channel-first transpose below needs a 3-D (H, W, C) array; fail with a clear
# message instead of an opaque numpy axes error if the loader returned something else.
assert measurement_resized.ndim == 3, (
    f"expected an (H, W, C) RGB measurement, got shape {measurement_resized.shape}"
)
measurement_resized = measurement_resized.transpose((2, 0, 1))  # to (C, H, W)

print(f"psf_resized: shape={psf_resized.shape}, dtype={psf_resized.dtype}, min={psf_resized.min():.3f}, max={psf_resized.max():.3f}")
print(f"measurement_resized: shape={measurement_resized.shape}, dtype={measurement_resized.dtype}, min={measurement_resized.min():.3f}, max={measurement_resized.max():.3f}")


In [None]:
# ADMM hyperparameters: number of unrolled iterations, and a multiplier applied to the
# network's initial tau values right after construction.
iterations  = 100
tau_scale   = 1000

# (C, H, W) numpy measurement -> (1, C, H, W) float tensor on the target device.
# The .copy() makes the transposed (non-contiguous) view contiguous before torch wraps it.
input_tensor = (
    torch.from_numpy(measurement_resized.copy())
    .float()
    .unsqueeze(0)
    .to(device)
)

admm_net = ADMM_Net(psf_resized, iterations, device)
# Scale tau in place without recording the op in autograd. This is the supported
# alternative to mutating `.data`, which bypasses autograd's version tracking.
with torch.no_grad():
    admm_net.tau.mul_(tau_scale)
admm_net.to(device)

# Pure inference: eval mode and no gradient bookkeeping.
admm_net.eval()
with torch.no_grad():
    output_tensor = admm_net(input_tensor)

print(f"input Tensor: shape={input_tensor.shape}, dtype={input_tensor.dtype}, min={input_tensor.min():.3f}, max={input_tensor.max():.3f}")
print(f"output Tensor: shape={output_tensor.shape}, dtype={output_tensor.dtype}, min={output_tensor.min():.3f}, max={output_tensor.max():.3f}")

# Drop the batch dimension and move channels last -> (H, W, C) numpy array for plotting.
reconstruction = output_tensor[0].cpu().numpy().transpose(1, 2, 0)

In [None]:
# Optional ground-truth image shown next to the reconstruction; set to None if unavailable.
# os.path.join keeps the path portable: a backslash is not a path separator on Linux/macOS.
ground_truth_file = os.path.join("data", "ground_truth.png")
if ground_truth_file is not None and not os.path.exists(ground_truth_file):
    # Fall back to None, the "no ground truth" value, instead of failing on a missing file.
    print(f"Ground truth not found at {ground_truth_file}; showing reconstruction without it.")
    ground_truth_file = None

# Measurement for display: (C, H, W) -> (H, W, C), rescaled to [0, 1].
# Guard the peak so an all-zero measurement does not turn into NaNs through 0 / 0.
measurement_display = measurement_resized.transpose(1, 2, 0)
measurement_peak = measurement_display.max()
if measurement_peak > 0:
    measurement_display = measurement_display / measurement_peak
measurement_display = np.clip(measurement_display, 0, 1)

visualize_reconstruction(
    psf_resized=psf_resized,
    measurement_resized=measurement_display,
    reconstruction=reconstruction,
    ground_truth_file=ground_truth_file,
    iterations=iterations
)
