In [3]:
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

In [4]:
# Load the Lena test image (the file name suggests a grayscale BMP) into a
# numpy pixel array, plus a flattened copy and an intensity histogram given
# as a (unique_values, counts) tuple.
lena_im = Image.open('data/lena_gray.bmp')
lena_arr = np.array(lena_im)
lena_vec = lena_arr.reshape(-1).copy()  # 1-D copy, independent of lena_arr
lena_hist = np.unique(lena_vec, return_counts=True)  # np.unique flattens anyway

In [7]:
def get_f_from_rep(fhat, m, n=None):
    """Expand a block representation into a full piecewise-constant image.

    The output is split into ``rows x cols`` rectangular blocks, where
    ``(rows, cols) = fhat.shape``. Row block ``i`` covers rows
    ``[i*m//rows, (i+1)*m//rows)``, and columns are split the same way
    against ``n``. Every pixel in block ``(i, j)`` is set to ``fhat[i, j]``.

    Parameters
    ----------
    fhat : array_like, shape (rows, cols)
        One value per block.
    m : int
        Number of rows in the output image.
    n : int, optional
        Number of columns in the output image. Defaults to ``m`` (square output).

    Returns
    -------
    numpy.ndarray, shape (m, n), dtype float
        The piecewise-constant image.
    """
    fhat = np.asarray(fhat, dtype=float)
    if n is None:
        n = m
    rows, cols = fhat.shape
    # Block sizes from the same integer boundaries (i*m//N) used by Lp_solver.
    # This also handles m not divisible by N, and empty blocks when m < N.
    row_sizes = np.diff(np.arange(rows + 1) * m // rows)
    col_sizes = np.diff(np.arange(cols + 1) * n // cols)
    return np.repeat(np.repeat(fhat, row_sizes, axis=0), col_sizes, axis=1)

In [8]:
def Lp_mse(f, fhat, p, w):
    """Weighted Lp reconstruction error of a block representation.

    Expands ``fhat`` to a ``(f.shape[0], f.shape[0])`` image with
    ``get_f_from_rep``, so ``f`` is expected to be square. Returns
    ``sum(w * |f - reconstruction| ** p)``. Despite the name, this is a sum,
    not a mean.
    """
    side = f.shape[0]
    reconstruction = get_f_from_rep(fhat, side)
    residual = np.abs(f - reconstruction)
    return np.sum(w * residual ** p)

In [9]:
def _weighted_lp_error(values, weights, c, p):
    """Return sum(|values - c| ** p * weights) for a constant approximation c."""
    return np.sum(np.abs(values - c) ** p * weights)


def Lp_solver(f, w, p, N, eps=1e-3, delta=1e-10, max_iter=1000):
    """Fit an N x N piecewise-constant approximation of ``f`` under a weighted Lp loss.

    The image is split into N x N rectangular blocks. Row block ``i`` covers
    ``[i*m//N, (i+1)*m//N)``, and columns are split the same way. Each block's
    constant is fitted independently with iteratively reweighted least
    squares (IRLS). At each step the weights are set to
    ``min(1/delta, |f - c| ** (p - 2) * w)``, and the constant becomes the
    weighted mean of the block.

    Parameters
    ----------
    f : array_like, shape (m, n)
        Image to approximate. It is converted to float.
    w : array_like or scalar, broadcastable to ``f.shape``
        Per-pixel weights of the loss. The caller's array is never modified.
    p : float
        Exponent of the loss.
    N : int
        Number of blocks along each axis.
    eps : float
        Stop iterating a block once its weighted Lp error changes by at most ``eps``.
    delta : float
        IRLS weights are capped at ``1/delta``. This keeps them finite when a
        residual is 0 and ``p < 2``.
    max_iter : int
        Upper bound on IRLS steps per block, in case the error keeps oscillating.

    Returns
    -------
    numpy.ndarray, shape (N, N), dtype float
        Fitted constant per block. Empty blocks and blocks whose weights are
        all zero keep the initial value 1.
    """
    f = np.asarray(f, dtype=float)
    m, n = f.shape
    # A read-only broadcast view: supports scalar weights, and we never write into it.
    w = np.broadcast_to(np.asarray(w, dtype=float), f.shape)
    fhat = np.ones((N, N))
    cap = 1 / delta

    for i in range(N):
        ml, mh = i * m // N, (i + 1) * m // N
        for j in range(N):
            nl, nh = j * n // N, (j + 1) * n // N
            f_blk = f[ml:mh, nl:nh]
            w_blk = w[ml:mh, nl:nh]

            # Only this block's constant changes, so the change in the global
            # error equals the change in this block's error.
            prev_error = _weighted_lp_error(f_blk, w_blk, fhat[i, j], p)
            for _ in range(max_iter):
                # 0 ** (p - 2) is inf for p < 2, and is capped below. Pixels with
                # zero weight get IRLS weight 0, which avoids inf * 0 = nan.
                with np.errstate(divide='ignore', invalid='ignore'):
                    irls = np.abs(f_blk - fhat[i, j]) ** (p - 2) * w_blk
                    wtag = np.where(w_blk > 0, np.minimum(cap, irls), 0.0)
                total = np.sum(wtag)
                if not total > 0:
                    # Empty block or no positive weight: nothing to fit.
                    break
                fhat[i, j] = np.sum(f_blk * wtag) / total
                error = _weighted_lp_error(f_blk, w_blk, fhat[i, j], p)
                if not np.abs(error - prev_error) > eps:
                    break
                prev_error = error

    return fhat