In [1]:
%matplotlib inline

import numpy as np
# import cupy as cp
import matplotlib.pyplot as plt
from scipy.signal import convolve2d
from skimage import data, io, color
from skimage.transform import resize
from pandas import DataFrame

# Default size (width, height in inches) for every figure drawn in this notebook.
plt.rcParams.update({"figure.figsize": [10, 10]})

def show_image(image, title, flip_x_axis=False):
    """Render `image` in greyscale with `title` and a colour bar, then show it.

    When `flip_x_axis` is True the image is mirrored left-to-right
    (via ``np.fliplr``) before being drawn; the caller's array is not modified.
    """
    to_draw = np.fliplr(image) if flip_x_axis else image
    plt.imshow(to_draw, cmap=plt.get_cmap("gray"))
    plt.title(title)
    plt.colorbar()
    plt.show()
    
def normalise(data):
    """Linearly rescale `data` so its minimum maps to 0 and its maximum to 1.

    Parameters
    ----------
    data : array-like
        Values to rescale (any shape).

    Returns
    -------
    ndarray
        Array of the same shape. Float inputs keep their dtype; integer inputs
        become float64 (true division).

    A constant input has zero range. Instead of computing 0/0 (which yields
    NaNs plus a RuntimeWarning and poisons any downstream solve), it is
    mapped to all zeros.
    """
    data_min = np.min(data)
    data_range = np.max(data) - data_min
    shifted = data - data_min
    if data_range == 0:
        # Zeros with the same shape/dtype the division would have produced.
        return shifted * 0.0
    return shifted / data_range

def rrmse(observed, ideal, decimal=6):
    """Relative root-mean-square error of `observed` against `ideal`, in percent.

    Computes ``sqrt((1 / N**2) * sum((observed - ideal)**2) / sum(ideal**2)) * 100``
    where ``N = observed.shape[0]``, i.e. the 1/N**2 factor treats the input as
    an N x N image.

    Returns the value formatted as a string with `decimal` digits after the
    decimal point (not a float).
    """
    n_pixels = observed.shape[0] ** 2
    squared_error = np.sum((observed - ideal) ** 2)
    reference_energy = np.sum(ideal ** 2)
    ratio = (1 / n_pixels) * squared_error / reference_energy
    percent = np.sqrt(ratio) * 100.0
    return f"{percent:.{decimal}f}"

def decimation_matrix(l, m):
    """Build the (m**2, l**2) float32 operator that block-sums an l x l image to m x m.

    Both images are flattened in row-major (C) order. With ``d = l // m``,
    output pixel (i, j) is the *sum* (not the mean) of the d x d block of
    input pixels whose top-left corner is (i*d, j*d). Every input pixel
    contributes to exactly one output pixel.

    The operator is separable: if R is the (m, l) 1-D block-sum matrix, then
    ``kron(R, R)[i*m + j, r*l + c] = R[i, r] * R[j, c]``. This holds for any
    l that is a multiple of m, rather than only for d == 2 with l divisible
    by 4, which the earlier loop-and-tile construction silently assumed.

    Parameters
    ----------
    l : int
        Side length of the high-resolution image.
    m : int
        Side length of the low-resolution image; must divide `l`.

    Raises
    ------
    ValueError
        If `m` is not positive or does not divide `l`.
    """
    if m <= 0 or l % m != 0:
        raise ValueError(f"l ({l}) must be a positive multiple of m ({m})")
    d = l // m
    # Row i of the 1-D operator has ones in columns i*d .. i*d + d - 1.
    block_sum_1d = np.repeat(np.eye(m, dtype=np.float32), d, axis=1)
    return np.kron(block_sum_1d, block_sum_1d)

# Produces a dense (l**2, l**2) matrix; row r*l + c holds the kernel weights at
# the neighbours of pixel (r, c) of an l x l image flattened in row-major order.
def convolution_matrix(l, kernel):
    """Return the (l**2, l**2) float32 matrix applying `kernel` to a flattened l x l image.

    Neighbours falling outside the image are dropped (zero padding). The
    kernel's anchor is at index ``((kh - 1) // 2, (kw - 1) // 2)``; for
    even-sized kernels the extra tap lies on the bottom/right side.

    The kernel is placed *without* flipping, so strictly this is correlation.
    It is identical to convolution for point-symmetric kernels (e.g. a
    Laplacian stencil). NOTE(review): flip non-symmetric kernels (such as a
    possibly asymmetric PSF) first if true convolution is intended.

    Parameters
    ----------
    l : int
        Side length of the square image.
    kernel : 2-D array-like
        Kernel of shape (kh, kw). It need not be square or odd-sized, and it
        may be larger than the image.
    """
    kernel = np.asarray(kernel)
    k_rows, k_cols = kernel.shape
    half_r = (k_rows - 1) // 2
    half_c = (k_cols - 1) // 2

    conv = np.zeros((l**2, l**2), dtype=np.float32)

    for conv_row in range(l**2):
        row, col = divmod(conv_row, l)

        # Image column under kernel column 0 (may be negative or overhang the right edge).
        linear_col = col - half_c
        mapped_col_start = max(linear_col, 0)
        mapped_col_end = min(linear_col + k_cols, l)
        # Kernel columns to trim where they fall off the left / right image edge.
        # Both can be non-zero when the kernel is wider than the image.
        left = mapped_col_start - linear_col
        right = (linear_col + k_cols) - mapped_col_end

        for k_row in range(k_rows):
            mapped_row = row + k_row - half_r
            # Rows outside the image contribute nothing (zero padding).
            if 0 <= mapped_row < l:
                base = mapped_row * l
                conv[conv_row, base + mapped_col_start : base + mapped_col_end] = kernel[k_row, left : k_cols - right]
    return conv

#### Configuration and data set up...

In [None]:
# Results table: one row per (β, half-support) pair evaluated below.
dataframe = DataFrame()

timesteps = 30  # total timesteps
timesteps_per_y = 5  # timesteps batched into each low-resolution image y[i]
l = 100  # side length of the high-resolution reconstruction x
m = 50  # side length of each low-resolution image y[i]
n = timesteps // timesteps_per_y  # number of low-resolution images
w = np.ones(n)  # per-image weights in the data term (uniform here)

# Weight β of the Laplacian regulariser, paired with each PSF half-support tried.
betas = [0.203091762, 0.203091762, 0.203091762, 0.366524124, 0.366524124, 0.412462638,
         0.366524124, 0.366524124, 0.366524124, 0.325702066, 0.289426612, 0.289426612, 
         0.289426612, 0.142510267, 0.289426612, 0.257191381, 0.257191381, 0.289426612, 
         0.289426612, 0.289426612, 0.289426612, 0.289426612, 0.289426612, 0.257191381, 
         0.257191381]
half_supports = np.arange(26, 51)
# zip() would silently drop unmatched entries.
assert len(betas) == len(half_supports), "each half-support needs exactly one β"

beta_half_supp_pairs = list(zip(betas, half_supports))

dataframe["Betas"] = betas
dataframe["Supports"] = half_supports * 2 - 1

# Ground-truth image: 800x800 resized to l x l and rescaled to [0, 1].
filename = "../data/direct_image_ts_0_29_800x800.bin"
x_true = np.fromfile(filename, dtype=np.float32)
x_true = resize(x_true.reshape(800, 800), (l, l), anti_aliasing=False, order=1)
x_true = normalise(x_true)

# PSF: drop the first row/col, resize to (l-1, l-1), then pad a new leading zero
# row/col so that crops taken around l//2 below are centred (presumably on the peak).
filename = "../data/direct_psf_ts_0_29_800x800.bin"
x_psf = np.fromfile(filename, dtype=np.float32).reshape(800, 800)[1:, 1:]
x_psf = resize(x_psf, (l-1, l-1), anti_aliasing=False, order=1)
x_psf = np.pad(x_psf, ((1, 0), (1, 0)))

# Storing all low-res images as layered stack
y = np.zeros((n, m, m))

# batched time steps direct images
for i in np.arange(n):
    filename = f"../data/direct_image_ts_{i * timesteps_per_y}_{i * timesteps_per_y + timesteps_per_y - 1}.bin"
    y[i] = np.fromfile(filename, dtype=np.float32).reshape(m, m)
    y[i] = normalise(y[i])

# Decimation matrix
d = decimation_matrix(l, m)

# Sharpening matrix (laplacian)
laplacian = np.array([[0, -1,  0], [-1,  4, -1], [0, -1,  0]], dtype=np.float32)
s = convolution_matrix(l, laplacian)

# The loop solves the normal equations
#   (β S^T S + (Σ_i w_i) H^T D^T D H) x = H^T Σ_i w_i D^T y_i
# Only H and β change per iteration, so the remaining pieces are computed once.
laplacian_gram = s.T @ s
total_weight = np.sum(w)
back_projection = d.T @ (w @ y.reshape(n, -1))  # Σ_i w_i D^T y_i

errors = []
a_lap_l1_norms = []
a_dec_l1_norms = []

for β, half_supp in beta_half_supp_pairs:
    print(f"β = {β}, half supp = {half_supp} (or full = {half_supp * 2 - 1})") 

    # Centred (2*half_supp - 1)-wide crop of the PSF, normalised to unit sum.
    # The division must be out-of-place: basic slicing returns a view, so an
    # in-place `/=` would rescale x_psf itself and corrupt every later crop.
    psf_min = l//2 - (half_supp - 1)
    psf_max = l//2 + half_supp
    assert 0 <= psf_min and psf_max <= l, "PSF crop exceeds the l x l PSF"
    x_psf_trim = x_psf[psf_min:psf_max, psf_min:psf_max]
    x_psf_trim = x_psf_trim / np.sum(x_psf_trim)

    # Blur matrix (psf)
    h = convolution_matrix(l, x_psf_trim)

    b = (h.T @ back_projection).astype(np.float32)

    lhs = β * laplacian_gram
    # D H first keeps the products at (m^2 x l^2) x (l^2 x l^2) and
    # (l^2 x m^2) x (m^2 x l^2), instead of a full (l^2)^3 product.
    dh = d @ h
    rhs = (dh.T @ dh) * total_weight
    a = lhs + rhs

    # Dense (l**2, l**2) solve; the expensive step of each iteration.
    x = np.linalg.solve(a, b).reshape(l, l)

    errors.append(rrmse(normalise(x), normalise(x_true)))

    # Matrix 1-norm (maximum absolute column sum) of each term of `a`.
    a_lap_l1_norms.append(np.linalg.norm(lhs, ord=1))
    a_dec_l1_norms.append(np.linalg.norm(rhs, ord=1))

dataframe["RRMSE"] = errors
dataframe["A Lap L1 Norm"] = a_lap_l1_norms
dataframe["A Dec L1 Norm"] = a_dec_l1_norms

β = 0.203091762, half supp = 26 (or full = 51)
β = 0.203091762, half supp = 27 (or full = 53)
β = 0.203091762, half supp = 28 (or full = 55)
β = 0.366524124, half supp = 29 (or full = 57)
β = 0.366524124, half supp = 30 (or full = 59)
β = 0.412462638, half supp = 31 (or full = 61)
β = 0.366524124, half supp = 32 (or full = 63)
β = 0.366524124, half supp = 33 (or full = 65)


In [None]:
# Display the results table: one row per (β, half-support) pair from the previous cell.
dataframe