# TV Denoising with Ensemble Regularization (Python version)
This notebook is a translation of the MATLAB script `tv_denoise_remove_huber_highorder_vectorized.m`.
It implements:
- Anisotropic TV
- Isotropic TV
- Exp(sqrt(.)) gradient regularizer, $\phi(t) = e^{\sqrt{t}} - 1$ applied to the gradient magnitude
- Ensemble combination of the three results (mean, regularizer-weighted mean, pixel-wise median)

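All three problems are solved with ADMM; the quadratic x-update is solved exactly in the Fourier domain (FFT), which assumes periodic boundary conditions.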

In [None]:
# %% [Dependencies]
# Uncomment the following line if packages are missing
# !pip install numpy imageio matplotlib scikit-image scipy
import numpy as np
import imageio.v3 as imageio
import matplotlib.pyplot as plt
from skimage.color import rgb2ycbcr, ycbcr2rgb, rgb2gray
from scipy.fft import fft2, ifft2

## Parameters

In [None]:
# ---------------- Parameters ----------------
noisyFileName = 'noisyballoon2.png'  # input image path, read in the Load Image cell
useColor = True  # for 3-D input: denoise only the Y (luma) plane and reuse the original Cb/Cr

lambda_tv = 1e-4  # weight of the TV term in the anisotropic and isotropic runs
lambda_exp = 1e-4  # weight of the exp(sqrt(|grad x|)) - 1 penalty
rho = 2.0  # ADMM penalty parameter; also enters the FFT denominator 1 + rho * eigDtD
maxIter = 200  # maximum ADMM iterations per method
tol = 1e-5  # stop when ||x_k - x_{k-1}|| / max(1e-8, ||x_{k-1}||) < tol
showEvery = 25  # print the objective every `showEvery` iterations (also at k == 1 and on convergence)
# NOTE(review): lambda_* are absolute weights, so their effect depends on the
# intensity scale of y, which may differ between the colour and grayscale paths.

## Load Image

In [None]:
# Read the noisy image and scale it to [0, 1] according to its storage type.
# Dividing by 255 is only correct for 8-bit data; 16-bit PNGs are common too.
Iin_raw = imageio.imread(noisyFileName)
if np.issubdtype(Iin_raw.dtype, np.integer):
    Iin = Iin_raw.astype(np.float64) / np.iinfo(Iin_raw.dtype).max
else:
    # NOTE(review): non-integer data (float / bool) is assumed to be in [0, 1] already.
    Iin = Iin_raw.astype(np.float64)

# Drop an alpha channel (RGBA -> RGB, gray+alpha -> gray): the skimage colour
# conversions used below expect exactly three channels.
if Iin.ndim == 3 and Iin.shape[2] in (2, 4):
    Iin = Iin[:, :, :-1]
if Iin.ndim == 3 and Iin.shape[2] == 1:
    Iin = Iin[:, :, 0]

if useColor and Iin.ndim == 3:
    # NOTE(review): skimage's rgb2ycbcr returns Y in [16, 235] and Cb/Cr in
    # [16, 240] (unlike MATLAB's rgb2ycbcr on double input, which stays near
    # [0, 1]). In this branch, y therefore lives on a ~255x larger intensity
    # scale than in the grayscale branch, and lambda_tv / lambda_exp are
    # comparatively much weaker. Rescaling Y here would also require undoing
    # the scale before ycbcr2rgb in the reconstruction step.
    I_ycbcr = rgb2ycbcr(Iin)
    Ychan = I_ycbcr[:, :, 0]
    Cb = I_ycbcr[:, :, 1]
    Cr = I_ycbcr[:, :, 2]
    y = Ychan.copy()
else:
    # Grayscale path: collapse RGB to luminance; 2-D input is used as-is.
    if Iin.ndim == 3:
        y = rgb2gray(Iin)
    else:
        y = Iin.copy()
    Cb, Cr = None, None

Ny, Nx = y.shape

## Operators

In [None]:
def grad(x):
    """Forward differences with periodic (wrap-around) boundaries.

    Returns ``(gx, gy)`` where ``gx[i, j] = x[i, j+1] - x[i, j]`` (axis 1)
    and ``gy[i, j] = x[i+1, j] - x[i, j]`` (axis 0), indices taken modulo
    the array size.
    """
    next_col = np.roll(x, shift=-1, axis=1)
    next_row = np.roll(x, shift=-1, axis=0)
    return next_col - x, next_row - x


def divergence(gx, gy):
    """Backward-difference divergence with periodic boundaries.

    This is the negative adjoint of ``grad``:
    ``sum(grad(x)[0] * gx + grad(x)[1] * gy) == -sum(x * divergence(gx, gy))``.
    """
    prev_col = np.roll(gx, shift=1, axis=1)
    prev_row = np.roll(gy, shift=1, axis=0)
    return gx - prev_col + gy - prev_row

## FFT Precomputation

In [None]:
# Fourier symbol of D^T D for the periodic forward-difference operator `grad`:
# each axis contributes 2 - 2*cos(w) at angular frequency w = 2*pi*k/N.
ux, uy = np.meshgrid(np.arange(Nx), np.arange(Ny))  # frequency indices, shape (Ny, Nx)
wx = 2 * np.pi * ux / Nx
wy = 2 * np.pi * uy / Ny
eigDtD = (2 - 2 * np.cos(wx)) + (2 - 2 * np.cos(wy))
# Denominator of the ADMM x-update: (I + rho * D^T D) is diagonal in the 2-D DFT basis,
# so the linear solve becomes an element-wise division after fft2.
denomFFT_first = 1 + rho * eigDtD

## 1️⃣ Anisotropic TV Denoising

In [None]:
# ADMM for anisotropic TV denoising:
#     min_x 0.5 * ||x - y||^2 + lambda_tv * (||gx||_1 + ||gy||_1),  (gx, gy) = grad(x)
# with split variable z = grad(x) and scaled dual u.
#
# x-update: (I + rho * D^T D) x = y + rho * D^T (z - u), solved via FFT.
# `divergence` is the NEGATIVE adjoint of `grad` (divergence = -D^T), so the
# right-hand side is y - rho * divergence(z - u). With "+", the x-step is not
# the minimiser of the augmented Lagrangian and the iteration does not solve
# the stated problem.
thresh = lambda_tv / rho  # soft-threshold level; loop-invariant

x = y.copy()
zx = np.zeros_like(x)
zy = np.zeros_like(x)
ux_d = np.zeros_like(x)
uy_d = np.zeros_like(x)
prevx = x.copy()

for k in range(1, maxIter + 1):
    rhs = y - rho * divergence(zx - ux_d, zy - uy_d)
    x = np.real(ifft2(fft2(rhs) / denomFFT_first))
    dx, dy = grad(x)
    # z-update: element-wise soft thresholding (prox of thresh * |.|)
    sx, sy = dx + ux_d, dy + uy_d
    zx = np.maximum(np.abs(sx) - thresh, 0) * np.sign(sx)
    zy = np.maximum(np.abs(sy) - thresh, 0) * np.sign(sy)
    # scaled dual ascent on the constraint grad(x) = z
    ux_d += dx - zx
    uy_d += dy - zy
    relchg = np.linalg.norm(x - prevx) / max(1e-8, np.linalg.norm(prevx))
    if k % showEvery == 0 or k == 1 or relchg < tol:
        tvTerm = np.sum(np.abs(dx)) + np.sum(np.abs(dy))
        obj = 0.5 * np.sum((x - y) ** 2) + lambda_tv * tvTerm
        print(f"Aniso TV iter {k:4d}: obj={obj:.6f}, relchg={relchg:.3e}")
    if relchg < tol:
        break
    prevx = x.copy()

x_aniso = x.copy()

## 2️⃣ Isotropic TV Denoising

In [None]:
# ADMM for isotropic TV denoising:
#     min_x 0.5 * ||x - y||^2 + lambda_tv * sum(sqrt(gx^2 + gy^2)),  (gx, gy) = grad(x)
# with split variable z = grad(x) and scaled dual u.
#
# x-update: (I + rho * D^T D) x = y + rho * D^T (z - u), solved via FFT.
# `divergence` is the NEGATIVE adjoint of `grad` (divergence = -D^T), hence the
# minus sign in the right-hand side.
x = y.copy()
zx = np.zeros_like(x)
zy = np.zeros_like(x)
ux_d = np.zeros_like(x)
uy_d = np.zeros_like(x)
prevx = x.copy()

for k in range(1, maxIter + 1):
    rhs = y - rho * divergence(zx - ux_d, zy - uy_d)
    x = np.real(ifft2(fft2(rhs) / denomFFT_first))
    dx, dy = grad(x)
    # z-update: vector (group) soft thresholding of (dx + ux, dy + uy) per pixel
    v1, v2 = dx + ux_d, dy + uy_d
    mag = np.sqrt(v1 ** 2 + v2 ** 2)
    scale = np.maximum(0, 1 - lambda_tv / (rho * (mag + 1e-12)))  # 1e-12 avoids 0/0
    zx = scale * v1
    zy = scale * v2
    # scaled dual ascent on the constraint grad(x) = z
    ux_d += dx - zx
    uy_d += dy - zy
    relchg = np.linalg.norm(x - prevx) / max(1e-8, np.linalg.norm(prevx))
    if k % showEvery == 0 or k == 1 or relchg < tol:
        tvTerm = np.sum(np.sqrt(dx ** 2 + dy ** 2))
        obj = 0.5 * np.sum((x - y) ** 2) + lambda_tv * tvTerm
        print(f"Iso TV iter {k:4d}: obj={obj:.6f}, relchg={relchg:.3e}")
    if relchg < tol:
        break
    prevx = x.copy()

x_iso = x.copy()

## 3️⃣ Exp(sqrt(.)) Gradient Regularizer

In [None]:
def phi(t):
    """Penalty exp(sqrt(t)) - 1 for a gradient magnitude t >= 0 (phi(0) == 0)."""
    return np.exp(np.sqrt(t)) - 1

# ADMM for  min_x 0.5 * ||x - y||^2 + lambda_exp * sum(phi(|grad x|))
# with split v = grad(x) and scaled dual b.
#
# Per-pixel prox for the v-update, with r = |grad x + b|:
#     t* = argmin_{t >= 0} lambda_exp * phi(t) + rho/2 * (t - r)^2,   v = t* * q / r
# solved by projected Newton. phi''(t) = exp(sqrt t)/(4t) * (1 - 1/sqrt t) < 0 on
# (0, 1), so this objective is non-convex near 0. Newton can stop at a poor
# stationary point, so its result is compared against the t = 0 candidate.
tol_newton = 1e-12
max_newton = 20
eps_float = np.finfo(float).eps

x = y.copy()
vx = np.zeros_like(x)
vy = np.zeros_like(x)
bx = np.zeros_like(x)
by = np.zeros_like(x)
prevx = x.copy()

for k in range(1, maxIter + 1):
    # x-update: (I + rho D^T D) x = y + rho D^T (v - b); divergence == -D^T.
    rhs = y - rho * divergence(vx - bx, vy - by)
    x = np.real(ifft2(fft2(rhs) / denomFFT_first))

    dx, dy = grad(x)
    qx, qy = dx + bx, dy + by
    r = np.sqrt(qx ** 2 + qy ** 2)

    t = np.maximum(r - lambda_exp / rho, 0)
    for _ in range(max_newton):
        t_safe = np.maximum(t, 1e-12)  # keep sqrt(t) and 1/t finite
        sqrt_t = np.sqrt(t_safe)
        exp_sqrt = np.exp(sqrt_t)
        dphi = 0.5 / sqrt_t * exp_sqrt                                # phi'(t)
        ddphi = (0.25 / t_safe - 0.25 / t_safe ** 1.5) * exp_sqrt     # phi''(t)
        g = lambda_exp * dphi + rho * (t - r)
        H = lambda_exp * ddphi + rho
        t_new = np.maximum(t - g / (H + eps_float), 0)
        converged = np.max(np.abs(t_new - t)) < tol_newton
        t = t_new
        if converged:
            break

    # Keep whichever candidate (Newton result or t = 0) has the lower prox objective.
    f_newton = lambda_exp * phi(t) + 0.5 * rho * (t - r) ** 2
    f_zero = 0.5 * rho * r ** 2  # phi(0) == 0
    t = np.where(f_zero <= f_newton, 0.0, t)

    # v = t * q / |q|: shrink the magnitude, keep the direction.
    shrink = t / np.maximum(r, 1e-12)
    vx = shrink * qx
    vy = shrink * qy
    # scaled dual ascent on the constraint grad(x) = v
    bx += dx - vx
    by += dy - vy

    relchg = np.linalg.norm(x - prevx) / max(1e-8, np.linalg.norm(prevx))
    if k % showEvery == 0 or k == 1 or relchg < tol:
        gradmag = np.sqrt(dx ** 2 + dy ** 2)
        obj = 0.5 * np.sum((x - y) ** 2) + lambda_exp * np.sum(phi(gradmag))
        print(f"ExpReg iter {k:4d}: obj={obj:.6f}, relchg={relchg:.3e}")
    if relchg < tol:
        break
    prevx = x.copy()

x_exp = x.copy()

## 4️⃣ Ensemble Combination

In [None]:
def grad_magnitude(x):
    """Return the per-pixel Euclidean length of grad(x) (periodic forward differences)."""
    gx, gy = grad(x)
    squared = gx**2 + gy**2
    return np.sqrt(squared)


# Score each result with its own regulariser; a smaller score gives a larger
# (inverse-proportional) weight in the weighted ensemble.
# NOTE(review): the three scores are on different scales (l1 TV, isotropic TV,
# sum of exp(sqrt(.)) - 1), so these weights are a heuristic, not a calibrated mix.
gx_aniso, gy_aniso = grad(x_aniso)
reg_a = np.sum(np.abs(gx_aniso) + np.abs(gy_aniso))
reg_i = np.sum(grad_magnitude(x_iso))
reg_e = np.sum(np.exp(np.sqrt(grad_magnitude(x_exp))) - 1)
reg_values = np.array([reg_a, reg_i, reg_e])
weights = 1.0 / (reg_values + np.finfo(float).eps)  # eps guards a zero score
weights = weights / np.sum(weights)

# Stack the three single-channel results along a new last axis.
x_list = np.stack([x_aniso, x_iso, x_exp], axis=-1)
x_ensemble_mean = x_list.mean(axis=-1)
x_ensemble_weighted = weights[0] * x_aniso + weights[1] * x_iso + weights[2] * x_exp
x_ensemble_median = np.median(x_list, axis=-1)

## 5️⃣ Color Reconstruction and Display

In [None]:
def clip_img(img):
    """Clip pixel values to [0, 1] for display and saving."""
    return np.clip(img, 0, 1)

from skimage.transform import resize


def _show_image(img):
    """Draw img on the current axes; 2-D arrays are drawn in gray over a fixed [0, 1] range."""
    # Without cmap/vmin/vmax, matplotlib draws 2-D arrays with its default
    # (viridis) colormap and stretches the contrast of each image independently.
    if img.ndim == 2:
        plt.imshow(img, cmap='gray', vmin=0, vmax=1)
    else:
        plt.imshow(img)


if useColor and Cb is not None and Cr is not None:
    # Safeguard: the chroma planes must match the luma size before recombination.
    if Cb.shape != (Ny, Nx):
        Cb = resize(Cb, (Ny, Nx))
        Cr = resize(Cr, (Ny, Nx))

    def combine(Y):
        """Reattach the original Cb/Cr to a denoised Y plane and convert to clipped RGB."""
        return clip_img(ycbcr2rgb(np.stack([Y, Cb, Cr], axis=-1)))

    to_output = combine
else:
    to_output = clip_img

results = [
    ('Aniso TV', x_aniso),
    ('Iso TV', x_iso),
    ('Exp(sqrt(.)) Reg', x_exp),
    ('Ensemble Mean', x_ensemble_mean),
    ('Ensemble Weighted', x_ensemble_weighted),
    ('Ensemble Median', x_ensemble_median),
]
out_imgs = {name: to_output(img) for name, img in results}

plt.figure(figsize=(16, 8))
plt.subplot(2, 4, 1)
_show_image(Iin)
plt.title('Noisy Input')
for i, (name, img) in enumerate(out_imgs.items(), start=2):
    plt.subplot(2, 4, i)
    _show_image(img)
    plt.title(name)
plt.tight_layout()
plt.show()

## 6️⃣ Save Outputs

In [None]:
# Turn display names into file-name fragments: spaces -> '_', drop '(', ')' and '.'.
_FNAME_TABLE = str.maketrans({' ': '_', '(': None, ')': None, '.': None})

for name, img in out_imgs.items():
    fname = 'denoised_' + name.lower().translate(_FNAME_TABLE) + '.png'
    # Round to the nearest 8-bit level; a bare astype(uint8) truncates and biases values down.
    imageio.imwrite(fname, np.round(clip_img(img) * 255).astype(np.uint8))

print('All outputs saved.')