# Assignment for LNO
---
#### Name - Saksham Maitri
#### Sr Number - 23787

---
Images used: `Cameraman.png` and `Lena.jpg` (read from the `Images/` directory).

In [None]:
# Imports 
import os, time, math, cv2
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.fftpack import dct, idct
from scipy.sparse.linalg import cg, LinearOperator

def dct2(x_img):
    """Orthonormal separable 2-D DCT of an image: transform along axis 0, then along axis 1."""
    along_axis0 = dct(x_img, axis=0, norm='ortho')
    return dct(along_axis0, axis=1, norm='ortho')

def idct2(Y):
    """Inverse of dct2: orthonormal 2-D inverse DCT, applied along axis 0 and then along axis 1."""
    along_axis0 = idct(Y, axis=0, norm='ortho')
    return idct(along_axis0, axis=1, norm='ortho')

def vec(img):
    """Return a new 1-D row-major (C-order) copy of img; the input array is never aliased."""
    flat_view = np.ravel(img, order='C')
    return flat_view.copy()

def im(v, N):
    """Reshape a length-N*N vector back into an N x N image (row-major); raises ValueError on size mismatch."""
    square_shape = (N, N)
    return np.reshape(v, square_shape)

def make_mask(N2, ratio, rng=np.random):
    """Build a random float64 0/1 sampling mask of length N2.

    Exactly int(round(ratio * N2)) entries are 1.0 (observed pixels); their positions
    are randomized by an in-place shuffle from ``rng`` (the ``np.random`` module by
    default, or any object with a compatible ``shuffle`` such as a ``np.random.Generator``).
    """
    n_observed = int(round(ratio * N2))
    sampling = np.zeros(N2, dtype=np.float64)
    sampling[:n_observed] = 1.0   # place all ones at the front...
    rng.shuffle(sampling)         # ...then scatter them in place
    return sampling

def add_noise_measured(m, w_mask, SNR_db=30.0, rng=np.random):
    """Add i.i.d. Gaussian noise to the observed entries of a measurement vector.

    Only positions where ``w_mask == 1`` receive noise. The standard deviation is
    ``||m[observed]|| / (10**(SNR_db/20) * sqrt(M))`` with ``M = w_mask.sum()``, i.e. the
    ratio of the observed samples' RMS to sigma equals 10**(SNR_db/20).
    If the mask is empty, ``m`` itself (the same object) is returned unchanged.
    """
    n_obs = int(w_mask.sum())
    if not n_obs:
        return m
    observed = (w_mask == 1)
    amplitude_ratio = 10 ** (SNR_db / 20)
    sigma = np.linalg.norm(m[observed]) / (amplitude_ratio * math.sqrt(n_obs))
    noise = np.zeros_like(m)
    noise[observed] = rng.normal(0.0, sigma, n_obs)
    return m + noise

def psnr(x_true, x_hat):
    """Peak signal-to-noise ratio in dB, assuming a peak intensity of 1.0.

    Returns the fixed cap 100.0 when the MSE is <= 1e-20 (numerically perfect match).
    A NaN MSE propagates to a NaN result.
    """
    squared_error = (x_true - x_hat) ** 2
    mse = np.mean(squared_error)
    if mse <= 1e-20:
        return 100.0
    rmse = math.sqrt(mse)
    return 20.0 * np.log10(1.0 / rmse)

def objective_J(w_mask, x, m, p, eps, lam):
    """Evaluate J(x) = ||w_mask * x - m||^2 + lam * sum((eps + c^2)^p).

    ``c`` is the orthonormal 1-D DCT of the vector ``x`` as given. For a flattened image this is
    a DCT of the raster-scanned vector, not the 2-D ``dct2``.
    """
    residual = w_mask * x - m
    fidelity = np.linalg.norm(residual) ** 2
    coeffs = dct(x, norm='ortho')
    penalty = lam * np.sum((eps + coeffs ** 2) ** p)
    return fidelity + penalty

# ---------- MM-CG with diagonal preconditioner + diagnostics ----------
def mm_cg_reconstruct(m, w_mask, p, lam, eps, tol_rel=1e-4, cg_maxiter=200, mm_maxiter=100):
    """Reconstruct a signal from masked, noisy samples with Majorize-Minimize + preconditioned CG.

    Minimizes ``objective_J(x) = ||w_mask*x - m||^2 + lam * sum((eps + D(x)^2)^p)``, where D is
    the orthonormal 1-D DCT of the vector (the same transform objective_J uses).

    Each MM step replaces the penalty by its quadratic majorizer at the current iterate x_k,
    with weights ``wk = p * (eps + D(x_k)^2)^(p-1)``. It then solves the normal equations
        (diag(w_mask) + lam * D^T diag(wk) D) x = w_mask * m
    with scipy's CG. The solve is warm-started from x_k and uses a diagonal preconditioner
    ``1 / (w_mask + lam * median(wk))``.

    Args:
        m: measurement vector. x is initialised to a copy of it. In run_experiments it is zero
           at unobserved positions, so this acts as a zero-fill initial guess.
        w_mask: 0/1 sampling mask, same length as m.
        p, lam, eps: exponent, regularisation weight and smoothing constant of the penalty.
        tol_rel: stop when |J_k - J_{k-1}| / J_{k-1} <= tol_rel.
        cg_maxiter: iteration cap for each inner CG solve.
        mm_maxiter: cap on outer MM iterations.

    Returns:
        (x, hist). ``hist`` has the following entries:
        - per-MM-step lists ``J``, ``rel``, ``cg_iters``, ``cg_info`` (scipy's CG exit code;
          0 means converged, >0 means the iteration cap was hit);
        - ``time_sec`` (wall time);
        - ``mm_iters``.
    """
    x = m.copy()
    n = x.size
    b = w_mask * m  # RHS W^T m (binary diagonal W) is the same for every MM step
    hist_J, hist_rel, hist_cg, hist_info = [], [], [], []

    t0 = time.perf_counter()  # monotonic clock for elapsed-time measurement
    J_prev = objective_J(w_mask, x, m, p, eps, lam)

    for _ in range(mm_maxiter):
        # Majorizer weights in the DCT domain at the current iterate.
        Xd = dct(x, norm='ortho')
        wk = p * (eps + Xd ** 2) ** (p - 1)

        def A_mv(v):
            # (W + lam * D^T diag(wk) D) v
            return (w_mask * v) + lam * idct(wk * dct(v, norm='ortho'), norm='ortho')

        A = LinearOperator((n, n), matvec=A_mv, dtype=np.float64)

        # Cheap positive diagonal approximation of A: diag(w_mask) + lam * median(wk) * I.
        inv_diagM = 1.0 / (w_mask + lam * float(np.median(wk)) + 1e-12)

        def M_inv(v):
            return inv_diagM * v

        M = LinearOperator((n, n), matvec=M_inv, dtype=np.float64)

        n_cg = [0]  # mutable cell so the callback can count CG iterations

        def _count_cg(_xk):
            n_cg[0] += 1

        x, info = cg(A, b, x0=x, M=M, maxiter=cg_maxiter, callback=_count_cg)

        J_cur = objective_J(w_mask, x, m, p, eps, lam)
        rel = abs(J_cur - J_prev) / (J_prev + 1e-16)
        hist_J.append(J_cur)
        hist_rel.append(rel)
        hist_cg.append(n_cg[0])
        hist_info.append(int(info))  # expose non-converged inner solves instead of dropping them
        if rel <= tol_rel:
            break
        J_prev = J_cur

    elapsed = time.perf_counter() - t0
    hist = dict(J=hist_J, rel=hist_rel, cg_iters=hist_cg, cg_info=hist_info,
                time_sec=elapsed, mm_iters=len(hist_J))
    return x, hist

# ---------- plotting ----------
def save_psnr_vs_lambda(fig_dir, r, p_vals, lam_vals, psnr_grid):
    """Plot PSNR against lambda (log x-axis) with one curve per p, for sampling ratio r.

    psnr_grid shape: (len(p_vals), len(lam_vals)); row i is the curve for p_vals[i].
    The figure is written to ``<fig_dir>/psnr_vs_lambda_r<r>.png``.
    """
    fig, ax = plt.subplots()
    for row_idx, p in enumerate(p_vals):
        ax.plot(lam_vals, psnr_grid[row_idx], marker='o', label=f"p={p}")
    ax.set_xscale('log')
    ax.set_xlabel(r"$\lambda$ (log scale)")
    ax.set_ylabel("PSNR (dB)")
    ax.set_title(f"PSNR vs $\\lambda$ @ sampling r={r}")
    ax.grid(True, which='both', ls=':')
    ax.legend()
    out_path = os.path.join(fig_dir, f"psnr_vs_lambda_r{r}.png")
    fig.savefig(out_path, dpi=200, bbox_inches='tight')
    plt.close(fig)

def save_convergence_plots(fig_dir, r, p, lam, hist):
    """Save three per-MM-iteration diagnostic figures for one (r, p, lam) run.

    The figures show:
    - objective J;
    - relative change of J (log y-axis);
    - CG iterations per MM step.

    ``hist`` must contain the lists 'J', 'rel' and 'cg_iters'. Files go to fig_dir as
    ``conv_J_*``, ``conv_rel_*`` and ``cg_iters_*`` with suffix ``r<r>_p<p>_lam<lam>.png``.
    """
    run_tag = f"(r={r}, p={p}, λ={lam})"
    file_tag = f"r{r}_p{p}_lam{lam}.png"
    panels = [
        # (hist key, log-scale y?, y-axis label, title, output file name)
        ('J', False, "Objective J(x)", f"Convergence: J vs iter {run_tag}", f"conv_J_{file_tag}"),
        ('rel', True, "Rel. change", f"Convergence: rel-change {run_tag}", f"conv_rel_{file_tag}"),
        ('cg_iters', False, "# CG iters", f"CG iters per MM step {run_tag}", f"cg_iters_{file_tag}"),
    ]
    for key, log_y, y_label, title, file_name in panels:
        series = hist[key]
        steps = np.arange(1, len(series) + 1)
        fig, ax = plt.subplots()
        if log_y:
            ax.semilogy(steps, series, marker='o')
        else:
            ax.plot(steps, series, marker='o')
        ax.set_xlabel("MM iteration")
        ax.set_ylabel(y_label)
        ax.set_title(title)
        ax.grid(True, ls=':')
        fig.savefig(os.path.join(fig_dir, file_name), dpi=200, bbox_inches='tight')
        plt.close(fig)

def img01_to_uint8(img01):
    """Convert intensities nominally in [0, 1] to uint8 gray levels in [0, 255].

    Values are clipped to [0, 1] and then rounded to the nearest level. A bare
    ``astype(np.uint8)`` truncates toward zero, which biases every pixel down and can drop a
    level when float round-off makes k/255*255 land just below k.
    """
    return np.rint(np.clip(img01, 0, 1) * 255).astype(np.uint8)


def save_image(path, img01):
    """Write an image with intensities nominally in [0, 1] to ``path`` as 8-bit grayscale via OpenCV.

    Raises:
        OSError: if cv2.imwrite returns False (e.g. unwritable path or unsupported extension),
            which would otherwise silently lose the output.
    """
    if not cv2.imwrite(path, img01_to_uint8(img01)):
        raise OSError(f"cv2.imwrite failed to write {path}")

def save_dct_histograms(fig_dir, name_prefix, x_true_img, x_rec_img):
    """Overlay histograms of |dct2| coefficient magnitudes for the true and reconstructed images.

    Uses 100 bins with a log-scaled count axis. Writes ``<fig_dir>/<name_prefix>_dct_hist.png``.
    """
    series = (
        (x_true_img, '|DCT(true)|'),
        (x_rec_img, '|DCT(recon)|'),
    )
    fig, ax = plt.subplots()
    for image, legend_label in series:
        magnitudes = vec(np.abs(dct2(image))).astype(np.float64)
        ax.hist(magnitudes, bins=100, alpha=0.7, label=legend_label)
    ax.set_yscale('log')
    ax.set_xlabel("|coeff|")
    ax.set_ylabel("count (log)")
    ax.set_title("DCT magnitude histograms")
    ax.legend()
    ax.grid(True, ls=':')
    out_path = os.path.join(fig_dir, f"{name_prefix}_dct_hist.png")
    fig.savefig(out_path, dpi=200, bbox_inches='tight')
    plt.close(fig)

# ---------- experiment runner ----------
def _load_square_grayscale(image_path):
    """Load ``image_path`` as float64 grayscale in [0, 1], center-cropped to a square.

    Raises:
        FileNotFoundError: if OpenCV cannot read the file.
    """
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError(f"Could not open {image_path}")
    img = img.astype(np.float64) / 255.0
    h, w = img.shape[:2]
    if h != w:
        # Center-crop the longer side. Resizing to N x N would distort the aspect ratio.
        n = min(h, w)
        top, left = (h - n) // 2, (w - n) // 2
        img = np.ascontiguousarray(img[top:top + n, left:left + n])
    return img


def _save_best_artifacts(root, fig_dir, r, p, rec, img, N):
    """Save the recon image, residual map, convergence plots and DCT histograms for one best-lambda record."""
    lam = rec["lam"]
    Xbest = im(rec["x"], N)
    save_image(os.path.join(root, f"recon_best_r{r}_p{p}_lam{lam}.png"), Xbest)
    # Map the signed residual into [0, 1] around 0.5 (zero error = mid-gray), scaled by its max |error|.
    residual = img - Xbest
    res_vis = 0.5 + residual / (2 * max(1e-8, np.max(np.abs(residual))))
    save_image(os.path.join(root, f"residual_best_r{r}_p{p}_lam{lam}.png"), res_vis)
    save_convergence_plots(fig_dir, r, p, lam, rec["hist"])
    save_dct_histograms(fig_dir, f"best_r{r}_p{p}_lam{lam}", img, Xbest)


def run_experiments(image_path, image_name, r_list=(0.1, 0.2, 0.3, 0.5),
                    p_list=(0.3, 0.4, 0.5),
                    lam_list=(1e-4, 1e-3, 1e-2, 1e-1, 1.0),
                    eps=1e-6, snr_db=30.0,
                    tol_rel=1e-4, cg_maxiter=200, mm_maxiter=100,
                    out_root="results"):
    """Run the masked-reconstruction sweep on one image and save all outputs.

    Steps:
    1. Load the image as grayscale in [0, 1], center-cropping it to a square if necessary.
    2. For each sampling ratio r:
       - draw a random mask with a fraction r of observed pixels;
       - add Gaussian noise at ``snr_db`` to the observed pixels;
       - run mm_cg_reconstruct for every (p, lambda) pair.
    3. Log per-run metrics and save images and figures for the best lambda per (r, p).

    A fixed seed (0) makes masks and noise reproducible.

    Outputs are written under ``<out_root>/<image_name>/``:
    - images;
    - ``figs/``;
    - ``<image_name>_reconstruction_log.csv``.

    Returns:
        The output directory path.

    Raises:
        FileNotFoundError: if the image cannot be read.
    """
    img = _load_square_grayscale(image_path)
    N = img.shape[0]
    x_true = vec(img)  # ground-truth vector
    N2 = x_true.size
    true_norm = np.linalg.norm(x_true)

    root = os.path.join(out_root, image_name)
    fig_dir = os.path.join(root, "figs")
    os.makedirs(fig_dir, exist_ok=True)  # also creates root
    save_image(os.path.join(root, f"{image_name}_orig.png"), img)

    rows = []
    rng = np.random.default_rng(0)

    for r in r_list:
        # mask + noisy measurements (noise only on observed pixels)
        w_mask = make_mask(N2, r, rng)
        m_clean = w_mask * x_true
        m_noisy = add_noise_measured(m_clean, w_mask, SNR_db=snr_db, rng=rng)
        save_image(os.path.join(root, f"mask_r{r}.png"), im(w_mask, N))
        save_image(os.path.join(root, f"noisy_r{r}.png"), im(m_noisy, N))

        psnr_grid = np.zeros((len(p_list), len(lam_list)), dtype=np.float64)
        best_records = {}  # p -> best record over lambda
        for i, p in enumerate(p_list):
            # Start at -inf rather than -1.0: PSNR is negative when MSE > 1 and must still win.
            best = {"psnr": -np.inf, "lam": None, "x": None, "hist": None}
            for j, lam in enumerate(lam_list):
                x_hat, hist = mm_cg_reconstruct(m_noisy, w_mask, p, lam, eps,
                                                tol_rel=tol_rel, cg_maxiter=cg_maxiter, mm_maxiter=mm_maxiter)
                cur_psnr = psnr(img, im(x_hat, N))
                rel_l2 = np.linalg.norm(x_true - x_hat) / true_norm if true_norm > 0 else np.nan
                psnr_grid[i, j] = cur_psnr
                rows.append({
                    "image": image_name,
                    "r": r, "p": p, "lambda": lam,
                    "PSNR": cur_psnr,
                    "rel_l2": rel_l2,
                    "mm_iters": hist["mm_iters"],
                    "time_sec": hist["time_sec"],
                    # J is empty when mm_maxiter == 0
                    "final_J": hist["J"][-1] if hist["J"] else np.nan,
                    "last_rel_change": hist["rel"][-1] if hist["rel"] else np.nan,
                    "avg_cg_iters": float(np.mean(hist["cg_iters"])) if hist["cg_iters"] else np.nan,
                })
                if cur_psnr > best["psnr"]:
                    best = {"psnr": cur_psnr, "lam": lam, "x": x_hat.copy(), "hist": hist}
            best_records[p] = best

        save_psnr_vs_lambda(fig_dir, r, p_list, lam_list, psnr_grid)
        for p in p_list:
            rec = best_records[p]
            if rec["x"] is None:
                # Every lambda gave a NaN PSNR, so there is no reconstruction to save.
                print(f"[{image_name}] r={r}, p={p}: no valid reconstruction, skipping artifacts")
                continue
            _save_best_artifacts(root, fig_dir, r, p, rec, img, N)

    df = pd.DataFrame(rows)
    csv_path = os.path.join(root, f"{image_name}_reconstruction_log.csv")
    df.to_csv(csv_path, index=False)
    print(f"[{image_name}] results saved to: {root}")
    return root

# Run the full experiment sweep on each test image, in order.
for _img_path, _img_name in (("Images/Cameraman.png", "cameraman"),
                             ("Images/Lena.jpg", "lena")):
    _ = run_experiments(image_path=_img_path, image_name=_img_name)


[cameraman] results saved to: results/cameraman
[lena] results saved to: results/lena
