In [None]:
import os
import warnings

import numpy as np
import scipy.io
from PIL import Image
from matplotlib import pyplot as plt
from scipy.sparse.linalg import spsolve
from scipy.sparse.linalg import cg, gmres, bicgstab
from scipy import sparse
from scipy.sparse import csr_matrix

def proj(x):
    """Project ``x`` elementwise onto the box [0, 1].

    Entries above 1 become 1.0, entries below 0 become 0, everything else
    (including NaN) is passed through unchanged. The result is a new array.
    """
    # Inner clamp handles the lower bound, outer clamp the upper bound; both
    # conditions are evaluated on the original values, so order is irrelevant.
    return np.where(x > 1, 1.0, np.where(x < 0, 0.0, x))

def get_D(m, n):
    """Stacked forward-difference operator for an m x n grid.

    Assumes column-major vectorisation: grid point (i, j) lives at index
    i + j*m. The result has shape (2*m*n, m*n):

    * first m*n rows (Dx): (Dx u)[k] = u[k + m] - u[k], i.e. the difference to
      the next column; rows belonging to the last column are zero.
    * last m*n rows (Dy): block-diagonal with one m x m block per column,
      (Dy u)[k] = u[k + 1] - u[k] inside a column; the last row of every block
      is zero.
    """
    size = m * n

    # Dx: -1 on the main diagonal (0 for the last column's m entries), +1 at offset m.
    dx_main = -np.concatenate([np.ones(size - m), np.zeros(m)])
    Dx = sparse.spdiags(np.stack([dx_main, np.ones(size)]), np.array([0, m]), size, size)

    # Per-column block: -1 on the main diagonal (0 in the last row), +1 at offset 1.
    dy_main = -np.concatenate([np.ones(m - 1), [0]])
    dy_block = sparse.spdiags(np.stack([dy_main, np.ones(m)]), np.array([0, 1]), m, m)
    Dy = sparse.kron(sparse.eye(n), dy_block)

    return sparse.vstack([Dx, Dy])

In [None]:
class IGD:
    """Inertial projected gradient descent for optimising a pixel mask ``c``.

    For a mask ``c`` in [0, 1]^mn (stored as a column of shape (mn, 1)) the
    reconstruction ``x(c)`` solves the sparse linear system

        A(c) x = C u,    A(c) = C + (C - I) L,    C = diag(c),

    so rows with c_i = 1 pin x_i = u_i and rows with c_i = 0 impose (L x)_i = 0.
    The optimiser minimises

        E(c) = f(c) + mu * sum(c),    f(c) = 0.5 * ||x(c) - u||^2,

    over the box [0, 1]^mn, trading reconstruction error against mask density.

    Parameters
    ----------
    u : array_like
        Vectorised image with m*n entries; stored internally as a column (mn, 1).
        ``plot`` and ``auto_initial`` assume column-major (``flatten('F')``)
        vectorisation, i.e. pixel (i, j) at index i + j*m.
    L : sparse matrix, shape (mn, mn)
        Operator used in the reconstruction system.

    Call ``set_parameters`` with at least ``m`` and ``n`` before optimising.
    Keys of ``self.para``:
        lipschitz  initial Lipschitz estimate (``optimize`` backtracking, BB fallback)
        tolerance  relative-change stopping tolerance
        m, n       image rows / columns
        max_iter   iteration cap
        beta       inertia (momentum) weight
        mu         weight of the sum(c) penalty
        method     'spsolve', 'bicgstab', 'cg' or 'gmres'. A(c) is generally
                   non-symmetric, so 'cg' is not guaranteed to converge.
        rate       fraction of pixels selected by ``auto_initial``
        punish     per-neighbour damping factor used by ``auto_initial``
        radius     L1 neighbourhood radius used by ``auto_initial``
    """

    def __init__(self, u, L):
        self.c_last = None      # previous iterate (momentum term / BB step)
        self.c = None           # current iterate, shape (mn, 1)
        self.c_next = None      # candidate iterate
        self.para = {'lipschitz': 0.5, 'tolerance': 1e-6, 'm': None, 'n': None, 'max_iter': 1000,
                     'beta': 0.9, 'mu': 0.01, 'method': 'spsolve', 'rate': 0.05, 'punish': 0.85,
                     'radius': 9}
        # Keep u as a column so residuals x - u never broadcast to (mn, mn).
        self.u = np.asarray(u).reshape(-1, 1)
        self.L = L
        self.mn = None
        self.reconstructed = None

    # ------------------------------------------------------------------ setup

    def solver(self, A, b):
        """Solve ``A x = b`` with ``para['method']``; return x as a column (N, 1).

        Raises ValueError for an unknown method. For the iterative methods a
        RuntimeWarning is issued when the solver reports non-convergence
        (``info != 0``) rather than silently returning the last iterate.
        """
        method = self.para['method']
        iterative = {'bicgstab': bicgstab, 'cg': cg, 'gmres': gmres}
        if method != 'spsolve' and method not in iterative:
            raise ValueError('Invalid method')
        A = csr_matrix(A)
        rhs = np.ravel(b)
        if method == 'spsolve':
            x = spsolve(A, rhs)
        else:
            x, info = iterative[method](A, rhs)
            if info != 0:
                warnings.warn(f"{method} did not converge (info={info})",
                              RuntimeWarning, stacklevel=2)
        return np.asarray(x).reshape(-1, 1)

    def set_parameters(self, parameters):
        """Update ``self.para`` from a dict and reset the optimisation state."""
        for key, value in parameters.items():
            self.para[key] = value
        self.initialize()

    def initialize(self):
        """Recompute ``mn`` from m, n and clear all iterates and results.

        Raises ValueError when m*n does not match the number of entries in ``u``.
        """
        m, n = self.para['m'], self.para['n']
        if m is None or n is None:
            self.mn = None      # allow setting other parameters before the shape
        else:
            self.mn = m * n
            if self.mn != self.u.size:
                raise ValueError(f"m*n = {self.mn} does not match the {self.u.size} entries of u")
        self.c_last = None
        self.c = None
        self.c_next = None
        self.reconstructed = None

    # ------------------------------------------------------- model internals

    def _mask(self, index=0):
        """Return the iterate selected by ``index`` (0: c, 1: c_next, -1: c_last)."""
        if index == 0:
            c = self.c
        elif index == 1:
            c = self.c_next
        elif index == -1:
            c = self.c_last
        else:
            raise ValueError(f"index must be 0, 1 or -1, got {index!r}")
        if c is None:
            raise ValueError(f"iterate for index {index} has not been set")
        return c

    def _system(self, c):
        """Build A(c) = C + (C - I) L and b(c) = C u for the mask ``c``."""
        diag_c = sparse.diags(np.ravel(c))
        A = diag_c + (diag_c - sparse.identity(self.mn)) @ self.L
        return A, diag_c @ self.u

    def _value_and_grad(self, c):
        """Return f(c) and grad f(c) (shape (mn, 1)) from one forward and one adjoint solve.

        Differentiating A(c) x = C u gives A dx/dc_i = e_i (u - x - L x)_i, hence
        grad f = (u - x - L x) * A^{-T} (x - u)   (elementwise product).
        """
        A, b = self._system(c)
        x = self.solver(A, b)
        residual = x - self.u
        adjoint = self.solver(A.T, residual)
        grad = (self.u - x - self.L @ x) * adjoint
        return 0.5 * np.linalg.norm(residual) ** 2, grad

    def reconstruct(self):
        """Solve the system for the current mask and store it in ``self.reconstructed``."""
        A, b = self._system(self.c)
        self.reconstructed = self.solver(A, b)

    def report(self):
        """Print the PSNR (intensities assumed in [0, 1]) and the mask density."""
        err = np.ravel(self.reconstructed - self.u)
        sse = float(err @ err)
        PSNR = 10 * np.log10(self.mn / sse) if sse > 0 else np.inf
        rate = float(np.sum(self.c)) / self.mn
        print('PSNR = ', PSNR, 'rate = ', rate)

    def barzilai_borwein_step_size(self, grad=None, grad_last=None):
        """Barzilai-Borwein step s^T s / s^T y with s = c - c_last, y = grad - grad_last.

        ``grad`` / ``grad_last`` are the objective gradients at c / c_last; when
        omitted they are recomputed. Returns 0 when the curvature s^T y is below
        1e-10 (including a zero step), signalling that the step is undefined.
        """
        if grad is None:
            grad = self.obj_grad()
        if grad_last is None:
            grad_last = self.obj_grad(-1)
        s = np.ravel(self.c - self.c_last)
        y = np.ravel(grad - grad_last)
        curvature = float(s @ y)
        if curvature < 1e-10:
            return 0
        return float(s @ s) / curvature

    def plot(self):
        """Show original image, mask and reconstruction side by side."""
        shape = (self.para['m'], self.para['n'])
        panels = [(self.u, "Original Image", 'viridis'),
                  (self.c, "Compressed Point", None),
                  (self.reconstructed, "Reconstructed Image", 'viridis')]
        fig, axes = plt.subplots(1, 3, figsize=(12, 10))
        for ax, (vec, title, cmap) in zip(axes, panels):
            # order='F' inverts the column-major vectorisation for any m, n.
            ax.imshow(np.reshape(vec, shape, order='F'), cmap=cmap)
            ax.set_title(title)
            ax.axis('off')
        plt.show()

    # ------------------------------------------------- objective / gradient

    def f(self, index=0):
        """Data term 0.5 * ||x(c) - u||^2 at the iterate selected by ``index``."""
        A, b = self._system(self._mask(index))
        x = self.solver(A, b)
        return 0.5 * np.linalg.norm(x - self.u) ** 2

    def obj(self, index=0):
        """Objective f(c) + mu * sum(c), both evaluated at the iterate selected by ``index``."""
        return self.f(index) + self.para['mu'] * float(np.sum(self._mask(index)))

    def f_grad(self, index=0):
        """Gradient of the data term (shape (mn, 1)) at the iterate selected by ``index``."""
        return self._value_and_grad(self._mask(index))[1]

    def obj_grad(self, index=0):
        """Gradient of the full objective at the iterate selected by ``index``."""
        return self.f_grad(index) + self.para['mu'] * np.ones_like(self._mask(index))

    # ---------------------------------------------------------- initial mask

    @staticmethod
    def _edge_energy(img):
        """Sum of squared differences to each existing 4-neighbour (missing neighbours add 0)."""
        energy = np.zeros_like(img)
        dv = np.diff(img, axis=0) ** 2
        dh = np.diff(img, axis=1) ** 2
        energy[:-1, :] += dv
        energy[1:, :] += dv
        energy[:, :-1] += dh
        energy[:, 1:] += dh
        return energy

    @staticmethod
    def _neighbour_exponent(marked, radius):
        """Per pixel, sum of d**-1.5 over marked pixels at L1 distance 1 <= d <= radius.

        Neighbours outside the image are ignored (no wrap-around).
        """
        m, n = marked.shape
        weights = marked.astype(np.float64)
        exponent = np.zeros((m, n))
        for di in range(-radius, radius + 1):
            if abs(di) >= m:
                continue
            span = radius - abs(di)
            for dj in range(-span, span + 1):
                if (di == 0 and dj == 0) or abs(dj) >= n:
                    continue
                dist = abs(di) + abs(dj)
                # exponent[i, j] += weights[i + di, j + dj] for in-bounds pairs.
                exponent[max(0, -di):m - max(0, di), max(0, -dj):n - max(0, dj)] += (
                    weights[max(0, di):m - max(0, -di), max(0, dj):n - max(0, -dj)] / dist ** 1.5)
        return exponent

    def auto_initial(self):
        """Heuristic initial mask with about ``rate`` * mn ones, shape (mn, 1).

        Pixels are scored by their squared differences to their 4-neighbours.
        Each pixel's score is then damped by ``punish`` ** (sum of d**-1.5 over
        high-score pixels within L1 distance ``radius``), which spreads the
        selection out. The top ``rate`` fraction of the damped scores is kept.
        """
        m, n = self.para['m'], self.para['n']
        img = np.reshape(self.u, (m, n), order='F').astype(np.float64)
        q = 100 - self.para['rate'] * 100
        energy = self._edge_energy(img)
        marked = energy >= np.percentile(energy, q)
        exponent = self._neighbour_exponent(marked, self.para.get('radius', 9))
        energy = energy * self.para['punish'] ** exponent
        mask = (energy >= np.percentile(energy, q)).astype(np.float64)
        return mask.reshape(-1, 1, order='F')

    # ------------------------------------------------------------ optimisers

    def _start(self, initial_c):
        """Set c and c_last from ``initial_c`` (True -> ``auto_initial``)."""
        if self.mn is None:
            raise ValueError("call set_parameters with 'm' and 'n' before optimising")
        if initial_c is True:
            c = self.auto_initial()
        else:
            c = np.array(initial_c, dtype=np.float64).reshape(-1, 1)
            if c.shape[0] != self.mn:
                raise ValueError(f"initial_c has {c.shape[0]} entries, expected {self.mn}")
        self.c = c
        self.c_last = c.copy()

    def _converged(self):
        """Relative change ||c - c_last|| / ||c_last|| <= tolerance (safe for c_last = 0)."""
        change = np.linalg.norm(self.c - self.c_last)
        scale = np.linalg.norm(self.c_last)
        if scale > 0:
            return change <= self.para['tolerance'] * scale
        return change == 0

    def _finish(self, simple):
        """Reconstruct with the final mask and, unless ``simple``, plot and report."""
        self.reconstruct()
        if not simple:
            self.plot()
            self.report()

    def optimize(self, initial_c=True, simple=False):
        """Inertial projected gradient descent with backtracking on the Lipschitz estimate.

        Step: c_next = proj(c - alpha * grad E(c) + beta * (c - c_last)) with
        alpha = 1.99 * (1 - beta) / lipschitz. The estimate is relaxed by 0.95
        every 5 iterations and doubled until the quadratic upper bound holds.
        ``self.para`` is not modified, so re-running is reproducible.

        initial_c : True for ``auto_initial``, otherwise an array with mn entries.
        simple    : skip plotting / reporting when True.
        """
        self._start(initial_c)
        beta, mu = self.para['beta'], self.para['mu']
        lipschitz = self.para['lipschitz']
        for it in range(self.para['max_iter']):
            if it % 5 == 0:
                lipschitz *= 0.95   # relax before the step so step and test use the same value
            f_cur, f_grad_cur = self._value_and_grad(self.c)
            obj_grad_cur = f_grad_cur + mu * np.ones_like(self.c)
            inertia = beta * (self.c - self.c_last)
            while True:
                alpha = 1.99 * (1 - beta) / lipschitz
                self.c_next = proj(self.c - alpha * obj_grad_cur + inertia)
                step = self.c_next - self.c
                f_next = self.f(1)
                # Non-strict: a zero step is always accepted (strict >= looped forever).
                if f_next - f_cur <= np.sum(f_grad_cur * step) + 0.5 * lipschitz * np.sum(step ** 2):
                    break
                lipschitz *= 2
                if not np.isfinite(lipschitz):
                    raise FloatingPointError('backtracking failed: Lipschitz estimate overflowed')
            self.c_last, self.c = self.c, self.c_next
            if self._converged():
                break
            if it % 10 == 0:
                print("Iterations: {}".format(it), f'obj = {f_next + mu * np.sum(self.c)}')
        self._finish(simple)

    def optimize_BB(self, initial_c=True, simple=False):
        """Inertial projected gradient descent with Barzilai-Borwein step sizes.

        When the BB step is undefined (first iteration, or curvature < 1e-10)
        the step 1.99 * (1 - beta) / para['lipschitz'] is used instead; without
        this fallback the first step was 0 and the loop stopped immediately.
        """
        self._start(initial_c)
        beta = self.para['beta']
        fallback = 1.99 * (1 - beta) / self.para['lipschitz']
        grad_last = None
        for it in range(self.para['max_iter']):
            grad = self.obj_grad()
            alpha = 0 if grad_last is None else self.barzilai_borwein_step_size(grad, grad_last)
            if alpha <= 0:
                alpha = fallback
            self.c_next = proj(self.c - alpha * grad + beta * (self.c - self.c_last))
            self.c_last, self.c = self.c, self.c_next
            grad_last = grad    # gradient at the new c_last, reused next iteration
            if self._converged():
                break
            if it % 10 == 0:
                print("Iterations: {}".format(it), f'obj = {self.obj()}')
        self._finish(simple)


In [None]:
# Load the test image as grayscale intensities in [0, 1].
i = '512_512_lena.png'
image_path = os.path.join('test_images', i)
# Context manager closes the file handle deterministically once the data is converted.
with Image.open(image_path) as pil_image:
    img = np.asarray(pil_image.convert('L'), dtype=np.float64) / 255
m, n = img.shape
# Column-major vectorisation: pixel (r, s) -> index r + s*m, matching get_D's offsets.
u = img.flatten('F').reshape(-1, 1)
D = get_D(m, n)
L = -(D.T @ D)  # negative semi-definite since D.T @ D is positive semi-definite
Opt = IGD(u, L)
Opt.set_parameters({'m': m, 'n': n, 'method': 'bicgstab', 'max_iter': 500})
Opt.optimize()