In [2]:
import numpy as np
import imageio.v2 as imageio
import matplotlib.pyplot as plt
import scipy.fftpack as spfft
from skimage.transform import resize
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import mean_squared_error as mse
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from ipywidgets import interact, IntSlider, FloatSlider
import os

# ---------------------------------------------------
# 2D DCT / IDCT
# ---------------------------------------------------
def dct2(x):
    """Return the orthonormal 2-D DCT-II of ``x``.

    The 1-D DCT (``norm='ortho'``) is applied along the last axis and then
    along the first axis, which is the separable 2-D transform for a matrix.
    """
    along_last = spfft.dct(x, norm='ortho', axis=-1)
    return spfft.dct(along_last, norm='ortho', axis=0)

def idct2(x):
    """Return the orthonormal 2-D inverse DCT of ``x`` (inverse of ``dct2``).

    The 1-D inverse DCT (``norm='ortho'``) is applied along the last axis and
    then along the first axis.
    """
    along_last = spfft.idct(x, norm='ortho', axis=-1)
    return spfft.idct(along_last, norm='ortho', axis=0)

# ---------------------------------------------------
# FISTA solver
# ---------------------------------------------------
def fista(b, ri, nx, ny, lam=0.01, max_iter=200):
    """Recover sparse 2-D DCT coefficients from sampled pixels with FISTA.

    Minimises ``0.5 * ||A s - b||_2^2 + lam * ||s||_1`` where ``A`` maps DCT
    coefficients ``s`` (shape ``(ny, nx)``) to pixels via ``idct2`` and then
    keeps only the entries at column-major linear indices ``ri`` (the same
    indexing as ``img.T.flat[ri]`` for an ``(ny, nx)`` image).

    Parameters
    ----------
    b : 1-D array
        Observed pixel values; ``b[j]`` belongs to linear index ``ri[j]``.
    ri : 1-D integer array
        Column-major linear indices of the sampled pixels.
    nx, ny : int
        Image width and height.
    lam : float
        Weight of the L1 (sparsity) penalty.
    max_iter : int
        Number of iterations (no early stopping).

    Returns
    -------
    ndarray of shape (ny, nx)
        Estimated DCT coefficients; ``idct2`` of it is the reconstructed image.
    """
    # Coefficients live on the same (ny, nx) grid as the image, so the forward
    # sampling and the adjoint scatter below index the same positions.
    shape = (ny, nx)
    x = np.zeros(shape)
    y = x.copy()
    t = 1.0
    # Step size 1/L with L = 1: idct2 is orthonormal and sampling only drops
    # entries, so ||A||_2 <= 1 and L = 1 bounds the gradient's Lipschitz constant.
    L = 1.0
    for _ in range(max_iter):
        residual = idct2(y).T.flat[ri] - b

        # Adjoint of A: scatter the residual back onto the pixel grid (same
        # shape and column-major indexing as the forward model), then apply
        # dct2, the adjoint of the orthonormal idct2.
        scattered = np.zeros(shape)
        scattered.T.flat[ri] = residual
        grad = dct2(scattered)

        z = y - grad / L
        # Soft-thresholding: proximal operator of (lam / L) * ||.||_1.
        x_new = np.sign(z) * np.maximum(np.abs(z) - lam / L, 0)

        # Nesterov momentum update.
        t_new = (1 + np.sqrt(1 + 4 * t * t)) / 2
        y = x_new + ((t - 1) / t_new) * (x_new - x)
        x, t = x_new, t_new
    return x

# ---------------------------------------------------
# Interactive function
# ---------------------------------------------------
def run_cs(nx=64, sampling_percent=20, lam=0.05):
    """Run one compressed-sensing reconstruction and display the result.

    Loads ``portrait.jpg`` and converts it to grayscale. Resizes it to an
    ``nx`` x ``nx`` image and keeps a random ``sampling_percent`` percent of
    its pixels. Reconstructs the image with ``fista`` in the 2-D DCT basis,
    then shows original / masked / reconstructed panels with PSNR, SSIM
    and MSE.

    Parameters
    ----------
    nx : int
        Side length of the (square) working image.
    sampling_percent : int or float
        Percentage of pixels observed.
    lam : float
        L1 weight passed to ``fista``.

    Raises
    ------
    FileNotFoundError
        If ``portrait.jpg`` does not exist in the working directory.
    """
    # Load image
    file_name = "portrait.jpg"
    if not os.path.exists(file_name):
        raise FileNotFoundError(f"Image file '{file_name}' not found.")
    Xorig = imageio.imread(file_name)
    if Xorig.ndim == 3:
        # Drop a trailing alpha channel (LA / RGBA) so transparency is not
        # averaged into the gray level.
        if Xorig.shape[2] in (2, 4):
            Xorig = Xorig[..., :-1]
        # Round (not truncate) the channel mean back to 8 bits.
        # NOTE(review): assumes 8-bit input; 16-bit colour images would wrap here.
        Xorig = np.rint(np.mean(Xorig, axis=2)).astype(np.uint8)

    # Resize image to (ny, nx); resize returns floats in [0, 1] for uint8 input.
    ny = nx
    Xorig_resized = resize(Xorig, (ny, nx), anti_aliasing=True)
    Xorig_resized = np.clip(np.rint(255 * Xorig_resized), 0, 255).astype(np.uint8)

    # Sampling: k distinct column-major pixel indices.
    # NOTE: the mask is redrawn on every call, so changing only lam also
    # changes which pixels are observed.
    n = nx * ny
    k = round(n * (sampling_percent / 100))
    ri = np.random.choice(n, k, replace=False)

    # Unobserved pixels are shown as white in the masked panel.
    Xm = 255 * np.ones_like(Xorig_resized)
    Xm.T.flat[ri] = Xorig_resized.T.flat[ri]
    b = Xorig_resized.T.flat[ri].astype(float)

    # Reconstruction with chosen λ
    print(f"Running FISTA: nx={nx}, samples={sampling_percent}%, λ={lam}")
    Xdct = fista(b, ri, nx, ny, lam=lam, max_iter=200)
    Xrec = np.clip(np.rint(idct2(Xdct)), 0, 255).astype("uint8")

    # Metrics
    psnr_val = psnr(Xorig_resized, Xrec, data_range=255)
    ssim_val = ssim(Xorig_resized, Xrec, data_range=255)
    mse_val = mse(Xorig_resized, Xrec)

    # Plotly figure
    fig = make_plotly_display(Xorig_resized, Xm, Xrec, nx, sampling_percent,
                              psnr_val, ssim_val, mse_val)
    fig.show()

# ---------------------------------------------------
# Helper: make Plotly figure
# ---------------------------------------------------
def make_plotly_display(orig, mask, rec, nx, sampling_percent, psnr_val, ssim_val, mse_val):
    """Build a one-row, three-panel Plotly figure: original, masked, reconstructed.

    Each panel is a grayscale heatmap on a fixed 0-255 scale. Its y axis is
    reversed so that row 0 is drawn at the top, as for an image. The title
    reports the image size, the sampling percentage and the given
    PSNR / SSIM / MSE values. All text uses Times New Roman.

    Returns the figure without showing it.
    """
    panel_titles = [
        "Original (resized)",
        f"Masked ({sampling_percent}%)",
        "Reconstructed (FISTA)",
    ]
    fig = make_subplots(rows=1, cols=len(panel_titles), subplot_titles=panel_titles)

    for col, img in enumerate((orig, mask, rec), start=1):
        heatmap = go.Heatmap(z=img, colorscale="gray", showscale=False, zmin=0, zmax=255)
        fig.add_trace(heatmap, row=1, col=col)
        # Image convention: first row at the top of the panel.
        fig.update_yaxes(autorange="reversed", row=1, col=col)

    title_text = (
        f"CS Reconstruction | size={nx}×{nx}, samples={sampling_percent}% "
        f"| PSNR={psnr_val:.2f} dB, SSIM={ssim_val:.3f}, MSE={mse_val:.2f}"
    )
    typeface = "Times New Roman"
    fig.update_layout(
        title=title_text,
        margin=dict(l=10, r=10, t=50, b=10),
        height=400,
        width=1200,
        font=dict(family=typeface, size=14),
        title_font=dict(family=typeface, size=18),
    )
    return fig

# ---------------------------------------------------
# Interactive sliders
# ---------------------------------------------------
# One slider per keyword argument of run_cs; ipywidgets' interact calls
# run_cs with the current slider values whenever one of them changes.
size_slider = IntSlider(min=16, max=128, step=16, value=64, description="Image size")
samples_slider = IntSlider(min=5, max=80, step=5, value=20, description="% samples")
lam_slider = FloatSlider(min=0.0, max=0.5, step=0.01, value=0.05, description="λ (sparsity)")
interact(run_cs, nx=size_slider, sampling_percent=samples_slider, lam=lam_slider)


interactive(children=(IntSlider(value=64, description='Image size', max=128, min=16, step=16), IntSlider(value…

<function __main__.run_cs(nx=64, sampling_percent=20, lam=0.05)>