#Setup and Imports:

# In [None]:

import torch, numpy as np, cv2, time, scipy.fftpack
import matplotlib
import matplotlib.pyplot as plt
from skimage.metrics import peak_signal_noise_ratio as psnr, structural_similarity as ssim

# Matplotlib plotting config for clear, publication-quality figures.
# 'text.usetex': True sends all figure text through an external LaTeX install.
# The residual-norm axis label drawn later uses \lVert, \rVert and \boldsymbol,
# which are amsmath commands, so amsmath is added to the LaTeX preamble;
# without it that label fails to compile.
matplotlib.rcParams.update({
    'font.size': 16,
    'axes.labelsize': 20,
    'axes.labelweight': 'bold',
    'axes.titlesize': 20,
    'axes.titleweight': 'bold',
    'xtick.labelsize': 16,
    'ytick.labelsize': 16,
    'legend.fontsize': 18,
    'lines.linewidth': 3,
    'font.family': 'serif',
    'savefig.dpi': 600,
    'text.usetex': True,
    'text.latex.preamble': r'\usepackage{amsmath}',
})

# For reproducibility (the measurement noise is drawn with torch.randn_like)
torch.manual_seed(0)
np.random.seed(0)

# Choose GPU if available
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")


#Parameters and Paths:

# In [None]:

# === Edit these for your data ===
img_path = '/home/shraddha/.dotnet/Set10C/3.png'  # Path to input image

# Solver schedule: number of iterations, initial step size, step-size floor
# and the weight `lam` of the (x - D(x)) regularisation term.
iterations = 2500
gamma_init = 1.25
gamma_min = 0.01 * gamma_init
lam = 0.4

# Std of the additive measurement noise; it is also the noise level the
# denoiser is queried with.
noise_std = 0 / 255.0

# Koopman monitoring: history length (in iterations) and check interval.
koopman_window = 30
koopman_every = 10

# Gaussian blur of the super-resolution forward operator.
kernel_size = 25
sigma = 1.0
channels = 3


#Load Image and Preprocessing:

# In [None]:

# Load the ground-truth colour image as float32 RGB in [0, 1] and crop it to an
# even height/width, so that upsampling the 2x-subsampled measurement by 2
# gives back exactly H x W (needed to compare against img_np).
img = cv2.imread(img_path, cv2.IMREAD_COLOR)
if img is None:
    # cv2.imread reports a missing/unreadable file by returning None instead of
    # raising; fail here with a clear message rather than inside cvtColor.
    raise FileNotFoundError(f"Could not read image at {img_path!r}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = img.astype(np.float32) / 255.0

H, W = img.shape[:2]
H_even, W_even = H // 2 * 2, W // 2 * 2
if (H_even != H) or (W_even != W):
    print(f"Input image shape {img.shape} is not even, cropping to ({H_even}, {W_even})")
    img = img[:H_even, :W_even]
H, W = img.shape[:2]
# (1, 3, H, W) tensor on the compute device; img_np keeps the (H, W, 3) array
# used as the reference for PSNR/SSIM.
x_true = torch.tensor(img).permute(2,0,1).unsqueeze(0).to(device)
img_np = img
print(f"Loaded image of shape: {img_np.shape}")


#Forward operator:

# In [None]:

import torch.nn.functional as F

def gaussian_kernel(size=15, sigma=2.0, channels=3):
    """Build a normalised 2-D Gaussian kernel for depthwise (grouped) conv2d.

    Args:
        size: side length of the square kernel.
        sigma: standard deviation of the Gaussian, in pixels.
        channels: number of identical copies along the output-channel axis.

    Returns:
        float32 tensor of shape (channels, 1, size, size). Each copy sums to 1;
        the copies are an expanded view of a single (size, size) kernel.
    """
    offsets = np.arange(-size // 2 + 1., size // 2 + 1.)
    # Squared radius on the grid: [row, col] = offsets[col]**2 + offsets[row]**2
    radius_sq = offsets[None, :] ** 2 + offsets[:, None] ** 2
    weights = np.exp(-radius_sq / (2. * sigma ** 2))
    weights /= weights.sum()
    single = torch.from_numpy(weights).to(torch.float32)
    return single.expand(channels, 1, size, size)

# Depthwise Gaussian kernel shared by the forward blur and its adjoint.
kernel = gaussian_kernel(kernel_size, sigma, channels).to(device)

def blur(x):
    """Blur each channel of x separately (groups=channels) with zero padding of kernel_size // 2."""
    pad = kernel_size // 2
    return F.conv2d(x, kernel, padding=pad, groups=channels)
def blurT(x):
    """Adjoint of blur.

    The Gaussian kernel is symmetric under flipping (assumes an odd
    kernel_size), so correlation with the flipped kernel equals blur itself.
    """
    return blur(x)

def subsample2x(x):
    """Decimate by 2: keep rows and columns with even index in the last two dims."""
    return x[..., 0::2, 0::2]
def subsample2x_adjoint(x, out_shape):
    """Adjoint of subsample2x: place x on the even-index grid of a zero tensor of out_shape."""
    canvas = x.new_zeros(out_shape)
    canvas[..., 0::2, 0::2] = x
    return canvas

def sr_forward(x):
    """Super-resolution forward model A: Gaussian blur, then 2x decimation."""
    blurred = blur(x)
    return subsample2x(blurred)
def sr_adjoint(y, out_shape):
    """Adjoint A^T: zero-filled 2x upsampling to out_shape, then the adjoint blur."""
    return blurT(subsample2x_adjoint(y, out_shape=out_shape))


#Bicubic Upsampling for Initialization
def upsample_bicubic_aligned(y_lr, scale_factor=2):
    """Bicubic upsampling followed by a half-pixel resampling shift.

    Args:
        y_lr: (N, C, h, w) low-resolution batch.
        scale_factor: upsampling factor passed to F.interpolate.

    Returns:
        Tensor of shape (N, C, h*scale_factor, w*scale_factor) with the same
        dtype and device as y_lr. Each output pixel (i, j) is bilinearly
        resampled from the bicubic result at (i - 0.5, j - 0.5), with border
        replication.
    """
    n = y_lr.shape[0]
    up = F.interpolate(y_lr, scale_factor=scale_factor, mode='bicubic', align_corners=True)
    out_h, out_w = up.shape[-2:]
    # Subpixel alignment (optional): sampling positions shifted by -0.5 pixel,
    # mapped into the [-1, 1] range grid_sample uses with align_corners=True.
    # The grid is built in up.dtype because grid_sample requires input and grid
    # to share a dtype (a float32 grid fails for float64 inputs).
    rows = torch.arange(out_h, dtype=up.dtype, device=up.device)
    cols = torch.arange(out_w, dtype=up.dtype, device=up.device)
    rows = (rows - 0.5) / (out_h - 1) * 2 - 1
    cols = (cols - 0.5) / (out_w - 1) * 2 - 1
    grid_y, grid_x = torch.meshgrid(rows, cols, indexing='ij')
    # grid_sample expects (x, y) ordering in the last dimension.
    grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0).repeat(n, 1, 1, 1)
    return F.grid_sample(up, grid, mode='bilinear', padding_mode='border', align_corners=True)


#Preparing measurement:

# In [None]:
# Pretrained DRUNet denoiser from the deepinv library, used as D(x) in the RED
# updates; it is queried at the fixed noise level `noise_std`.
from deepinv.models import DRUNet
denoiser = DRUNet(pretrained='download', in_channels=3, out_channels=3)
denoiser = denoiser.to(device).eval()
sigma_tensor = torch.tensor([noise_std], device=device)

# Simulate the low-resolution measurement and build its bicubic upsampling,
# which is also the starting point of both RED solvers below.
y = sr_forward(x_true)
y_noisy = y + noise_std * torch.randn_like(y)
y_bicubic = upsample_bicubic_aligned(y_noisy, scale_factor=2)
y_bicubic_np = np.clip(y_bicubic.detach()[0].permute(1, 2, 0).cpu().numpy(), 0, 1)
psnr_measurement = psnr(img_np, y_bicubic_np, data_range=1.0)
print(f"PSNR of upsampled measurement: {psnr_measurement:.2f} dB")

plt.imshow(y_bicubic_np)
plt.title("Bicubic Upsample")
plt.axis('off')
plt.show()


#SKOOP-RED:

# In [None]:

import numpy as np
import scipy.fftpack

def extract_koopman_features(x):
    """
    Feature construction for SKOOP-RED.

    For each channel, 22 values are produced in this order:
      - global mean and standard deviation (2)
      - means of the cells of a 4x4 grid (16); each cell is
        (H // 4) x (W // 4), so trailing rows/columns beyond 4 * (H // 4) or
        4 * (W // 4) are not covered by any cell
      - the top-left 2x2 block (lowest frequencies) of the orthonormal 2-D DCT (4)
    Features of all channels are concatenated: 22 * C values, i.e. 66 for an
    RGB input.

    Args:
        x: torch.Tensor or array-like of shape [1, C, H, W] or [C, H, W].
            Only the first batch element of a 4-D input is used.

    Returns:
        1-D np.ndarray of length 22 * C.

    Raises:
        ValueError: if x is not 3-D/4-D, or if H or W is smaller than 4 (the
            grid cells would be empty and their means NaN).
    """
    if hasattr(x, 'detach'):
        x_np = x.detach().cpu().numpy()
    else:
        x_np = np.asarray(x)
    if x_np.ndim == 4:  # [B, C, H, W] -> first batch element
        x_np = x_np[0]
    if x_np.ndim != 3:
        raise ValueError(
            f"expected input of shape [C, H, W] or [B, C, H, W], got {x_np.shape}"
        )
    _, H, W = x_np.shape
    grid_size = 4
    if H < grid_size or W < grid_size:
        raise ValueError(
            f"spatial size must be at least {grid_size}x{grid_size}, got {H}x{W}"
        )
    cell_H = H // grid_size
    cell_W = W // grid_size

    features = []
    for ch in x_np:  # one (H, W) channel at a time
        # (1) Global mean and std
        features.append(ch.mean())
        features.append(ch.std())
        # (2) 4x4 grid means, row-major over the cells
        for i in range(grid_size):
            rows = slice(i * cell_H, (i + 1) * cell_H)
            for j in range(grid_size):
                features.append(ch[rows, j * cell_W:(j + 1) * cell_W].mean())
        # (3) 2D DCT, top-left 2x2 block (lowest frequencies)
        dct = scipy.fftpack.dctn(ch, norm='ortho')
        features.extend(dct[:2, :2].flatten())
    return np.array(features)


def koopman_dmd(X, Y, tol=1e-6):
    """Fit a linear one-step operator K with K @ X ~= Y (least squares, DMD).

    K = Y @ pinv(X), where the pseudo-inverse is taken through the SVD of X and
    singular values <= tol are treated as zero (their inverse set to 0).

    Args:
        X: (d, m) matrix whose columns are snapshots at steps t.
        Y: (d, m) matrix whose columns are snapshots at steps t + 1.
        tol: absolute cutoff on the singular values of X.

    Returns:
        (d, d) float64 array K.
    """
    # Work in float64 so the fixed absolute tolerance is meaningful regardless
    # of the precision of the incoming feature vectors.
    X = np.asarray(X, dtype=np.float64)
    Y = np.asarray(Y, dtype=np.float64)
    U, S, Vt = np.linalg.svd(X, full_matrices=False)
    # Truncated reciprocal of the singular values (0 where S <= tol).
    S_inv = np.zeros_like(S)
    np.divide(1.0, S, out=S_inv, where=S > tol)
    # Scaling the columns of Y @ V by S_inv equals multiplying by diag(S_inv).
    return ((Y @ Vt.T) * S_inv) @ U.T



#SKOOP-RED Main Loop
def skoop_red(
    y_noisy, denoiser, sr_forward, sr_adjoint, lam, gamma_init, gamma_min,
    koopman_window, koopman_every, max_iters,
    koopman_pred_threshold=1.05, beta=4, img_np=None
):
    """SKOOP-RED: RED gradient iterations with a Koopman-based step-size controller.

    Starting from the bicubic upsampling of `y_noisy`, each iteration computes
        x <- clip(x - gamma * (A^T (A x - y) + lam * (x - D(x))), 0, 1)
    where A = sr_forward, A^T = sr_adjoint and D = denoiser (queried at the
    module-level `sigma_tensor`).

    extract_koopman_features(x) of every new iterate is kept in a rolling
    window of `koopman_window` vectors. Every `koopman_every` iterations (once
    k >= koopman_window) a DMD operator is fitted to that window. If its
    spectral radius is >= koopman_pred_threshold, gamma is multiplied by
    eta = clip(1 - beta * (radius - 1)**2, 0.2, 1); otherwise gamma decays by
    0.995. gamma never drops below gamma_min.

    Args:
        img_np: optional (H, W, C) reference image in [0, 1]; when None,
            PSNR/SSIM are recorded as 0.

    Returns:
        dict with
          'psnr', 'ssim', 'gamma', 'radius': per-iteration lists (radius is 0
              on iterations without a Koopman check),
          'norm': per-iteration ||x_{k+1} - x_k|| between consecutive iterates,
          'snapshots': iteration -> (H, W, C) image for iterations 9, 19,
              max_iters - 1 and the best-PSNR iteration,
          'best_idx': iteration with the highest PSNR (-1 if max_iters == 0).
    """
    def to_image(t):
        # (1, C, H, W) tensor -> (H, W, C) numpy array clipped to [0, 1]
        return np.clip(t[0].permute(1, 2, 0).detach().cpu().numpy(), 0, 1)

    gamma = gamma_init
    koopman_history = []
    gamma_list, radius_list, psnr_list, ssim_list, norm_list = [], [], [], [], []
    snapshot_iters = {9, 19, max_iters - 1}  # For illustration
    snapshot_dict = {}
    best_psnr = -np.inf
    best_idx = -1
    best_img = None

    x = upsample_bicubic_aligned(y_noisy, scale_factor=2)
    x_np = to_image(x)  # image of the current iterate, used for the residual norm
    for k in range(max_iters):
        with torch.no_grad():
            grad_f = sr_adjoint(sr_forward(x) - y_noisy, out_shape=x.shape)
            Dx = denoiser(x, sigma=sigma_tensor)
            x_new = torch.clamp(x - gamma * (grad_f + lam * (x - Dx)), 0, 1)

        # --- Koopman feature extraction (rolling window) ---
        koopman_history.append(extract_koopman_features(x_new))
        if len(koopman_history) > koopman_window:
            koopman_history.pop(0)

        # --- Koopman update only if enough history and at interval ---
        radius = 0
        if (k >= koopman_window and k % koopman_every == 0
                and len(koopman_history) > 1):  # DMD needs >= 1 snapshot pair
            Xn = np.stack(koopman_history[:-1], axis=1)  # (n_features, w-1)
            Yn = np.stack(koopman_history[1:], axis=1)   # (n_features, w-1)
            K = koopman_dmd(Xn, Yn)
            radius = np.max(np.abs(np.linalg.eigvals(K)))
            if radius >= koopman_pred_threshold:
                eta = float(np.clip(1 - beta * (radius - 1) ** 2, 0.2, 1.0))
                gamma = max(gamma * eta, gamma_min)
            else:
                gamma = max(gamma * 0.995, gamma_min)
        else:
            gamma = max(gamma * 0.995, gamma_min)
        radius_list.append(radius)
        gamma_list.append(gamma)

        # --- Metrics ---
        x_new_np = to_image(x_new)
        if img_np is not None:
            psnr_val = psnr(img_np, x_new_np, data_range=1.0)
            ssim_val = ssim(img_np, x_new_np, channel_axis=2, data_range=1.0)
        else:
            psnr_val = ssim_val = 0
        psnr_list.append(psnr_val)
        ssim_list.append(ssim_val)
        # Residual between consecutive iterates x_{k+1} - x_k.
        norm_list.append(np.linalg.norm(x_new_np - x_np))
        if k in snapshot_iters:
            snapshot_dict[k] = x_new_np
        if psnr_val > best_psnr:
            # Keep the best iterate itself, not just its index.
            best_psnr, best_idx, best_img = psnr_val, k, x_new_np

        x, x_np = x_new, x_new_np

    if best_img is not None and best_idx not in snapshot_dict:
        snapshot_dict[best_idx] = best_img
    return dict(
        psnr=psnr_list, ssim=ssim_list, gamma=gamma_list, radius=radius_list, norm=norm_list,
        snapshots=snapshot_dict, best_idx=best_idx
    )


#Vanilla RED:

# In [None]:

def vanilla_red(y_noisy, denoiser, sr_forward, sr_adjoint, lam, gamma, max_iters, img_np):
    """Baseline RED gradient iterations with a fixed step size.

    Starting from the bicubic upsampling of `y_noisy`, each iteration computes
        x <- clip(x - gamma * (A^T (A x - y) + lam * (x - D(x))), 0, 1)
    with A = sr_forward, A^T = sr_adjoint and D = denoiser (queried at the
    module-level `sigma_tensor`).

    Args:
        img_np: (H, W, C) reference image in [0, 1] for PSNR/SSIM.

    Returns:
        (psnr_list, ssim_list, norm_list, final_image, snapshot_dict, best_idx)
        where norm_list holds ||x_{k+1} - x_k|| between consecutive iterates,
        final_image is a copy of the last iterate as an (H, W, C) array,
        snapshot_dict maps iteration -> image for iterations 9, 19,
        max_iters - 1 and the best-PSNR iteration, and best_idx is that
        iteration (-1 if max_iters == 0).
    """
    def to_image(t):
        # (1, C, H, W) tensor -> (H, W, C) numpy array clipped to [0, 1]
        return np.clip(t[0].permute(1, 2, 0).detach().cpu().numpy(), 0, 1)

    x = upsample_bicubic_aligned(y_noisy, scale_factor=2)
    x_np = to_image(x)  # image of the current iterate, used for the residual norm
    psnr_list, ssim_list, norm_list = [], [], []
    snapshot_iters = {9, 19, max_iters - 1}
    snapshot_dict = {}
    best_psnr = -np.inf
    best_idx = -1
    best_img = None
    for k in range(max_iters):
        with torch.no_grad():
            grad_f = sr_adjoint(sr_forward(x) - y_noisy, out_shape=x.shape)
            Dx = denoiser(x, sigma=sigma_tensor)
            x_new = torch.clamp(x - gamma * (grad_f + lam * (x - Dx)), 0, 1)
        x_new_np = to_image(x_new)
        psnr_val = psnr(img_np, x_new_np, data_range=1.0)
        ssim_val = ssim(img_np, x_new_np, channel_axis=2, data_range=1.0)
        psnr_list.append(psnr_val)
        ssim_list.append(ssim_val)
        # Residual between consecutive iterates x_{k+1} - x_k.
        norm_list.append(np.linalg.norm(x_new_np - x_np))
        if k in snapshot_iters:
            snapshot_dict[k] = x_new_np
        if psnr_val > best_psnr:
            # Keep the best iterate itself, not just its index.
            best_psnr, best_idx, best_img = psnr_val, k, x_new_np
        x, x_np = x_new, x_new_np
    if best_img is not None and best_idx not in snapshot_dict:
        snapshot_dict[best_idx] = best_img
    return psnr_list, ssim_list, norm_list, x_np.copy(), snapshot_dict, best_idx


#Run the experiments and plot results:

# In [None]:

# SKOOP-RED (adaptive step size), timed end to end.
start_skoop = time.time()
skoop = skoop_red(
    y_noisy, denoiser, sr_forward, sr_adjoint,
    lam=lam, gamma_init=gamma_init, gamma_min=gamma_min,
    koopman_window=koopman_window, koopman_every=koopman_every,
    max_iters=iterations, beta=4, img_np=img_np,
)
runtime_skoop = time.time() - start_skoop

# Vanilla RED with the fixed step size gamma_init, timed end to end.
start_vanilla = time.time()
van_psnr, van_ssim, van_norm, van_img, vanilla_snapshots, vanilla_peak = vanilla_red(
    y_noisy, denoiser, sr_forward, sr_adjoint,
    lam=lam, gamma=gamma_init, max_iters=iterations, img_np=img_np,
)
runtime_vanilla = time.time() - start_vanilla

print(f"Final PSNRs: Vanilla {van_psnr[-1]:.2f}, Skoop-RED {skoop['psnr'][-1]:.2f}")
print(f"Final SSIMs: Vanilla {van_ssim[-1]:.4f}, Skoop-RED {skoop['ssim'][-1]:.4f}")
print(f"Peak Skoop-RED PSNR at iter {skoop['best_idx']}: {max(skoop['psnr']):.2f}")


#Plot Results (PSNR and Residual Norm)
fig, axs = plt.subplots(1, 2, figsize=(12, 6))

# Left panel: PSNR per iteration for both solvers.
axs[0].plot(van_psnr, color='#D55E00', label='Vanilla-RED')
axs[0].plot(skoop['psnr'], color='#0072B2', label='SKOOP-RED')
axs[0].set(xlabel=r"iteration", ylabel=r"PSNR (dB)")
axs[0].grid(True)
axs[0].legend(loc='lower right')

# Right panel: residual norm on a log scale (first entry skipped).
axs[1].plot(van_norm[1:], color='#D55E00', label='Vanilla-RED')
axs[1].plot(skoop['norm'][1:], color='#0072B2', label='SKOOP-RED')
axs[1].set(
    xlabel=r"iteration ($t$)",
    ylabel=r"$\lVert \boldsymbol{x}_t - \boldsymbol{x}_{t-1} \rVert$",
)
axs[1].set_yscale('log')
axs[1].set_ylim(1e-3, 10)
axs[1].grid(True)
axs[1].legend(loc='upper right')

plt.tight_layout()
plt.savefig('RED_num_SR_G.pdf')
plt.show()
