In [1]:
import cupy as np
from cupyx.scipy.sparse.linalg import LinearOperator
import cupyx.scipy.sparse.linalg as linalg

def cov_matern(d, loghyper, x):
    """Build a Matern covariance matrix over the 1-D locations ``x``.

    Parameters
    ----------
    d : int
        Selects the Matern polynomial; must be 1, 3, 5 or 7.
    loghyper : sequence
        ``loghyper[0]`` is the log of the length-scale.  ``loghyper[1]`` is
        the log of the signal scale: the matrix is multiplied by
        ``exp(2 * loghyper[1])``.
    x : array, shape (n,)
        Locations; pairwise differences ``x[i] - x[j]`` are used.

    Returns
    -------
    array, shape (n, n)
        ``sf2 * f(t) * exp(-t)`` with ``t = sqrt(d) * |x[i] - x[j]| / ell``.

    Raises
    ------
    ValueError
        If ``d`` is not one of the supported orders.  (Previously the inner
        polynomial silently returned ``None`` and the failure surfaced later
        as an unrelated ``TypeError``.)
    """
    if d not in (1, 3, 5, 7):
        raise ValueError(f"cov_matern: unsupported Matern order d={d!r}; expected 1, 3, 5 or 7")

    ell = np.exp(loghyper[0])
    sf2 = np.exp(2 * loghyper[1])

    def f(t):
        # Polynomial factor of the Matern kernel for the selected order.
        if d == 1: return 1
        if d == 3: return 1 + t
        if d == 5: return 1 + t * (1 + t / 3)
        return 1 + t * (1 + t * (6 + t) / 15)  # d == 7

    def m(t):
        return f(t) * np.exp(-t)

    # Squared pairwise distance, scaled by the length-scale.
    dist_sq = ((x[:, None] - x[None, :]) / ell) ** 2
    return sf2 * m(np.sqrt(d * dist_sq))

def bohman(loghyper, x):
    """Evaluate the Bohman taper on all pairs of 1-D locations in ``x``.

    ``loghyper[0]`` is the log of the taper range.  Distances are divided by
    that range and clipped at 1, so pairs at or beyond the range get weight 0.
    Entries below ``1e-16`` and NaN entries are set to exactly 0.

    Returns an ``(n, n)`` array for ``x`` of shape ``(n,)``.
    """
    support = np.exp(loghyper[0])
    ratio = np.minimum(np.abs(x[:, None] - x[None, :]) / support, 1)
    taper = (1 - ratio) * np.cos(np.pi * ratio) + np.sin(np.pi * ratio) / np.pi
    # Flush round-off residue (and any NaN) to hard zeros.
    taper[(taper < 1e-16) | np.isnan(taper)] = 0
    return taper

def unfold(tensor, mode):
    """Matricize ``tensor`` along ``mode``.

    Axis ``mode`` becomes the rows; the remaining axes are flattened into the
    columns in Fortran (first-index-fastest) order.
    """
    leading = np.moveaxis(tensor, mode, 0)
    n_rows = tensor.shape[mode]
    return np.reshape(leading, (n_rows, -1), order = 'F')

def fold(mat, dim, mode):
    """Rebuild a tensor of shape ``dim`` from its mode-``mode`` unfolding.

    ``dim`` must be an array (it is read through ``dim.shape`` and fancy
    indexing).  The matrix is reshaped in Fortran order with ``mode`` as the
    leading axis, then that axis is moved back to position ``mode``.
    """
    axis_order = [mode] + [axis for axis in range(dim.shape[0]) if axis != mode]
    stacked = np.reshape(mat, list(dim[axis_order]), order = 'F')
    return np.moveaxis(stacked, 0, mode)

def kroneckerMVM(K3, K2, K1, vec, d1, d2, d3):
    """Compute ``kron(K3, kron(K2, K1)) @ vec`` without forming the Kronecker product.

    Parameters
    ----------
    K3, K2, K1 : matrices of shape (d3, d3), (d2, d2), (d1, d1)
        Applied along the third, second and first tensor mode respectively.
        ``K1`` and ``K2`` are only used through ``@``; ``K3`` is passed to
        ``np.tensordot``.
    vec : array, shape (d1 * d2 * d3,)
        Fortran-order vectorisation of a ``(d1, d2, d3)`` tensor.
    d1, d2, d3 : int
        Tensor dimensions.

    Returns
    -------
    array, shape (d1 * d2 * d3,)
        Fortran-order vectorisation of the transformed tensor.
    """
    tensor = vec.reshape(d1, d2, d3, order = 'F')
    # Mode-1 product: K1 acts on the first index.
    temp1 = (K1 @ tensor.reshape(d1, -1)).reshape(d1, d2, d3)
    # Mode-2 product: bring axis 1 to the front, multiply, and move it back.
    temp2 = (K2 @ temp1.transpose(1, 0, 2).reshape(d2, -1)).reshape(d2, d1, d3).transpose(1, 0, 2)
    # Mode-3 product: out[i, j, l] = sum_k K3[l, k] * temp2[i, j, k].
    # The contraction must run over K3's *column* index; contracting over the
    # row index would apply K3.T, which only coincides for symmetric K3.
    temp3 = np.tensordot(temp2, K3, axes=([2],[1]))
    return temp3.ravel(order = 'F')

def Ap_operatorT(vec, maskT, KrU, KrU_T, Qu, rho, R, M):
    """Apply the linear system of the factor update to ``vec``.

    ``vec`` is the Fortran-order vectorisation of an ``(R, M)`` matrix ``F``.
    The result is the Fortran-order vectorisation of

        KrU_T @ (maskT * (KrU @ F)) + rho * (F @ Qu)

    where ``maskT`` is multiplied element-wise.
    """
    factor = vec.reshape(R, M, order = 'F')
    # Project, then zero the entries the mask switches off (in place).
    projected = KrU @ factor
    projected *= maskT
    prior_term = rho * (factor @ Qu)
    data_term = KrU_T @ projected
    return (data_term + prior_term).ravel(order = 'F')

def cg_factorT(Qu, rho, KrU, mask_matrixT, YR_tilde, priorvalue, max_iter):
    """Solve the factor-update linear system with conjugate gradients.

    The system matrix is never formed: it is exposed as a ``LinearOperator``
    whose matvec is ``Ap_operatorT``.  The right-hand side is the
    Fortran-order vectorisation of ``YR_tilde`` (shape ``(R, M)``).

    Parameters
    ----------
    Qu, rho, KrU, mask_matrixT
        Forwarded to ``Ap_operatorT``; ``KrU.T`` is computed once here.
    YR_tilde : array, shape (R, M)
        Right-hand side in matrix form.
    priorvalue : array, shape (R * M,)
        Warm start for CG; it is copied, not modified.
    max_iter : int
        Iteration cap passed to ``linalg.cg`` as ``maxiter``.

    Returns
    -------
    (x, info)
        The solution vector and the status code returned by ``linalg.cg``.
    """
    R, M = YR_tilde.shape
    size = R * M
    rhs = YR_tilde.ravel(order = 'F')
    KrU_T = KrU.T

    def apply_system(v):
        # v has shape (size,)
        return Ap_operatorT(v, mask_matrixT, KrU, KrU_T, Qu, rho, R, M)

    # Matrix-free operator standing in for the (R*M) x (R*M) system matrix.
    system = LinearOperator((size, size), matvec = apply_system, dtype = rhs.dtype)
    start = priorvalue.copy()
    solution, status = linalg.cg(system, rhs, x0 = start, atol = 1e-4, maxiter = max_iter)
    return solution, status

def Ap_operatorL(vec, pos_obs, Kd, Kt, Ks, gamma, d1, d2, d3):
    """Apply the local-component system restricted to the observed entries.

    ``vec`` holds values at the positions ``pos_obs`` of a length
    ``d1 * d2 * d3`` vector that is zero elsewhere.  That vector is passed
    through ``kroneckerMVM(Kd, Kt, Ks, ...)``; the result is read back at
    ``pos_obs`` and ``gamma * vec`` is added.
    """
    padded = np.zeros(d1 * d2 * d3)
    padded[pos_obs] = vec
    smoothed = kroneckerMVM(Kd, Kt, Ks, padded, d1, d2, d3)
    return smoothed[pos_obs] + gamma * vec

def cg_local(gamma, Kd, Kt, Ks, pos_obs, YR_tilde, priorvalue, max_iter):
    """Solve the local-component system on the observed entries with CG.

    Parameters
    ----------
    gamma, Kd, Kt, Ks
        Forwarded to ``Ap_operatorL``.
    pos_obs : tuple
        Index tuple whose first element lists the observed positions in the
        Fortran-order vectorisation of ``YR_tilde``.
    YR_tilde : array, shape (d1, d2, d3)
        Tensor whose observed entries form the right-hand side.
    priorvalue : array, shape (n_obs,)
        Warm start for CG; it is copied, not modified.
    max_iter : int
        Iteration cap passed to ``linalg.cg`` as ``maxiter``.

    Returns
    -------
    (x, info)
        The solution at the observed positions and the ``linalg.cg`` status.
    """
    d1, d2, d3 = YR_tilde.shape
    n_obs = pos_obs[0].shape[0]
    rhs = YR_tilde.ravel(order = 'F')[pos_obs]

    def apply_system(v):
        # v has shape (n_obs,)
        return Ap_operatorL(v, pos_obs, Kd, Kt, Ks, gamma, d1, d2, d3)

    system = LinearOperator((n_obs, n_obs), matvec = apply_system, dtype = rhs.dtype)
    start = priorvalue.copy()
    solution, status = linalg.cg(system, rhs, x0 = start, atol = 1e-4, maxiter = max_iter)
    return solution, status

In [2]:
# from scipy.sparse import eye, csr_matrix
# from scipy.linalg import inv, khatri_rao

from cupyx.scipy.linalg import khatri_rao
from cupyx.scipy.sparse import eye, csr_matrix
import numpy

def GLSKF(I, Omega, lengthscaleU: list, lengthscaleR: list, varianceU: list, varianceR: list, tapering_range, d_MaternU, d_MaternR, R, rho, gamma, maxiter, K0, epsilon):
    """Complete a partially observed 3-way tensor as a global plus a local component.

    The global component ``M`` is a rank-``R`` factorisation whose factors are
    updated mode by mode with ``cg_factorT``.  From outer iteration ``K0``
    onwards a local component is fitted to the observed residual with
    ``cg_local`` and spread to all entries with ``kroneckerMVM``.

    Parameters
    ----------
    I : array, shape (N0, N1, N2)
        Full data tensor.  It is also used as ground truth for the printed
        PSNR, which is averaged over ``I[:, :, 0]``, ``I[:, :, 1]`` and
        ``I[:, :, 2]`` with ``max(I)`` as peak value.
    Omega : array, same shape as ``I``
        Observation mask; non-zero entries are treated as observed.
    lengthscaleU, varianceU : sequences of length >= 2
        Matern hyper-parameters of the factor priors for modes 0 and 1.
    lengthscaleR, varianceR : sequences of length >= 2
        Matern hyper-parameters of the local covariances for modes 0 and 1.
        NOTE(review): ``cov_matern`` scales by ``exp(2 * log(variance))``,
        i.e. ``variance ** 2`` - confirm this is intended.
    tapering_range
        Range of the Bohman taper multiplied into the local covariances.
    d_MaternU, d_MaternR
        Matern order passed to ``cov_matern`` for the factor / local kernels.
    R : int
        Number of columns of every factor matrix.
    rho
        Weight of the prior term in the factor systems.
    gamma
        Value added on the diagonal of the local system.
    maxiter : int
        Maximum number of outer iterations.
    K0 : int
        Outer iterations (0-based) before which the local component is kept
        at zero.
    epsilon
        Stop when the relative change of the completed tensor drops below it.

    Returns
    -------
    (Xori, Rtensor_shifted, M_shifted)
        The completed tensor, the local component and the global component,
        each with the training mean added back.
    """
    N = numpy.array(I.shape)
    D = I.ndim
    maxP = float(np.max(I))

    Omega = Omega.astype(bool)
    pos_miss = np.where(Omega == 0)
    num_obser = np.sum(Omega)
    mask_matrix = [unfold(Omega, d) for d in range(D)]
    # Columns of the mode-2 unfolding that contain at least one observation.
    idx = np.sum(mask_matrix[2], axis = 0) > 0
    train_matrix = I * Omega
    train_matrix = train_matrix[train_matrix > 0]
    train_mean = np.mean(train_matrix)
    Isubmean = I - train_mean
    T = Isubmean * Omega
    mask_matrixT = [mask_matrix[d].T for d in range(D)]
    # Observed positions in the Fortran-order vectorisation (only mode 0 is used).
    pos_obs = np.where(mask_matrix[0].ravel(order = 'F') == 1)

    hyper_Ku = [[np.log(lengthscaleU[k]), np.log(varianceU[k])] for k in range(2)]
    hyper_Kr = [[np.log(lengthscaleR[k]), np.log(varianceR[k]), np.log(tapering_range)] for k in range(2)]

    Kr = [None] * D
    invKu = [None] * D
    for k in range(2):
        x = np.arange(1, N[k] + 1)
        Ku = cov_matern(d_MaternU, hyper_Ku[k], x)
        invKu[k] = np.linalg.inv(Ku)
        TaperM = bohman([hyper_Kr[k][2]], x)
        Kr[k] = csr_matrix(cov_matern(d_MaternR, hyper_Kr[k][:2], x) * TaperM)
    # The third mode starts from identity matrices.
    invKu[2] = np.eye(N[2])
    Kr[2] = np.eye(N[2])

    # Work on a copy: X is imputed in place below, while T must keep holding
    # only the observed data (it defines train_norm and the first last_ten).
    X = T.copy()
    X[pos_miss] = T.sum() / num_obser
    U = [0.1 * np.random.randn(N[d], R) for d in range(D)]
    M = fold(U[0] @ khatri_rao(U[2], U[1]).T, N, 0)
    UTvector = [U[d].T.ravel(order = 'F') for d in range(D)]
    Rtensor = np.zeros(N)
    Rvector_temp = Rtensor.ravel(order = 'F')
    X[pos_miss] = M[pos_miss] + Rtensor[pos_miss]

    train_norm = np.linalg.norm(T)
    last_ten = T.copy()
    # At least one slot: the loop body always runs once before maxiter is checked.
    psnrf = np.zeros(max(maxiter, 1))
    approxU = [None] * D
    it = 0
    while True:
        # --- global component: one CG solve per mode ---
        Gtensor_mask = (X - Rtensor) * Omega
        for d in range(D):
            # Host-side index list; avoids a device round-trip per mode.
            dsub = [k for k in range(D) if k != d]
            KrU = khatri_rao(U[dsub[1]], U[dsub[0]])
            HG = KrU.T @ unfold(Gtensor_mask, d).T
            UTvector[d], approxU[d] = cg_factorT(invKu[d], rho, KrU, mask_matrixT[d], HG, UTvector[d], 100)
            U[d] = (UTvector[d].reshape(R, N[d], order = 'F')).T
        M = fold(U[0] @ (khatri_rao(U[2], U[1]).T), N, 0)
        X[pos_miss] = M[pos_miss] + Rtensor[pos_miss]

        # --- local component: only after K0 warm-up iterations ---
        if it >= K0:
            Ltensor_mask = (X - M) * Omega
            Rvector_temp[pos_obs], approxE = cg_local(gamma, Kr[2], Kr[1], Kr[0], pos_obs,
                                                      Ltensor_mask, Rvector_temp[pos_obs], 100)
            Rvector = kroneckerMVM(Kr[2], Kr[1], Kr[0], Rvector_temp, N[0], N[1], N[2])
            Rtensor = Rvector.reshape(N, order = 'F')
            # Re-estimate the third-mode covariance from the current local component.
            Kr[2] = np.cov(unfold(Rtensor, 2)[:, idx])
        else:
            Rtensor = np.zeros_like(Rtensor)
        X[pos_miss] = M[pos_miss] + Rtensor[pos_miss]

        # --- monitoring: PSNR averaged over the first three slices of mode 2 ---
        Xori = X + train_mean
        Xrecovery = np.minimum(maxP, np.maximum(0, Xori))
        psnr_channels = []
        for c in range(3):
            mse = np.linalg.norm(I[:, :, c].astype(float) - Xrecovery[:, :, c], 'fro') ** 2 / (N[0] * N[1])
            psnr_channels.append(10 * np.log10(maxP**2 / mse))
        psnrf[it] = (psnr_channels[0] + psnr_channels[1] + psnr_channels[2]) / 3
        it += 1
        print(f"Epoch = {it}, PSNR = {psnrf[it-1]}")

        # --- stopping rule: relative change of the completed tensor ---
        tol = np.linalg.norm((X - last_ten)) / train_norm
        last_ten = X.copy()
        if (tol < epsilon) or (it >= maxiter):
            break
    return Xori, Rtensor + train_mean, M + train_mean

In [3]:
from PIL import Image
import scipy.io
# Seed the random generator before GLSKF draws its initial factor matrices.
seedr = 6
np.random.seed(seedr)
# Ground-truth image and the observation mask stored under key 'Omega'.
I = np.array(Image.open('./data/original/airplane.tiff'))
Omega = scipy.io.loadmat('./data/mask/airplane_90RM.mat')['Omega']

# loadmat returns a host array; convert it with the notebook's `np` alias.
Omega = np.array(Omega)
# Kernel hyper-parameters for modes 0 and 1 (factor priors U, local component R).
lengthscaleU = np.ones(2) * 30
varianceU = np.ones(2)
lengthscaleR = np.ones(2) * 4
varianceR = np.ones(2)
tapering_range = 20
d_MaternU, d_MaternR = 3, 3
# Number of columns of each factor matrix.
R = 20
# rho weights the prior term of the factor systems; gamma is added to the local system.
rho, gamma = 20, 5
# At most 100 outer iterations; the local component is estimated from iteration index 70 on.
maxiter, K0 = 100, 70
# Relative-change threshold of the stopping rule.
epsilon = 1e-4
# X, Rtensor and Mtensor are reused by the plotting cell further down.
X, Rtensor, Mtensor = GLSKF(I, Omega, lengthscaleU, lengthscaleR, varianceU, varianceR, tapering_range, d_MaternU, d_MaternR, R, rho, gamma, maxiter, K0, epsilon)

Epoch = 1, PSNR = 15.256277296023853
Epoch = 2, PSNR = 19.313385559926797
Epoch = 3, PSNR = 21.731999296772486
Epoch = 4, PSNR = 22.310086864995473
Epoch = 5, PSNR = 22.50268845136583
Epoch = 6, PSNR = 22.621624824397816
Epoch = 7, PSNR = 22.71009957098217
Epoch = 8, PSNR = 22.77496573942173
Epoch = 9, PSNR = 22.821805077054137
Epoch = 10, PSNR = 22.856970634425483
Epoch = 11, PSNR = 22.88504630996594
Epoch = 12, PSNR = 22.90857176969148
Epoch = 13, PSNR = 22.928856140257665
Epoch = 14, PSNR = 22.946868721215697
Epoch = 15, PSNR = 22.96314882023564
Epoch = 16, PSNR = 22.977940959766304
Epoch = 17, PSNR = 22.99156371768808
Epoch = 18, PSNR = 23.004171612351314
Epoch = 19, PSNR = 23.015976934176184
Epoch = 20, PSNR = 23.02710892199489
Epoch = 21, PSNR = 23.03767961594926
Epoch = 22, PSNR = 23.047762931971636
Epoch = 23, PSNR = 23.057321427558147
Epoch = 24, PSNR = 23.066364008904685
Epoch = 25, PSNR = 23.074862101575025
Epoch = 26, PSNR = 23.082925068790917
Epoch = 27, PSNR = 23.09055035

In [None]:
# Profiling cell: requires the line_profiler IPython extension.
%load_ext line_profiler
# NOTE: this rebinds the notebook-level maxiter and K0 (short run: 4 outer
# iterations, local component from iteration index 1). Re-run the experiment
# cell to restore the full-run values.
maxiter, K0 = 4, 1

# Line-by-line profile of linalg.cg while GLSKF runs with the settings above.
%lprun -f linalg.cg GLSKF(I, Omega, lengthscaleU, lengthscaleR, varianceU, varianceR, tapering_range, d_MaternU, d_MaternR, R, rho, gamma, maxiter, K0, epsilon)

Epoch = 1, PSNR = 15.255557053642285
Epoch = 2, PSNR = 22.096866601103056


In [None]:
import matplotlib.pyplot as plt
import numpy as npp

# Reload the image and mask as host (NumPy) arrays for matplotlib.
# NOTE: this rebinds the notebook-level I and Omega.
I = npp.array(Image.open('./data/original/airplane.tiff'))
Omega = scipy.io.loadmat('./data/mask/airplane_90RM.mat')['Omega']

fig, axes = plt.subplots(2, 3, figsize=(12, 8))
fig.set_tight_layout(True)

# (grid position, image values, title) for every populated panel.
# X, Mtensor and Rtensor come from the GLSKF run and are copied to the host.
panels = [
    ((0, 0), npp.abs(I * Omega), 'Observed data, 90% RM'),
    ((0, 1), npp.abs(X.get()), 'Completed data'),
    ((1, 0), npp.abs(Mtensor.get()), 'Global component'),
    ((1, 1), npp.abs(Rtensor.get()), 'Local component'),
    ((0, 2), I, 'Truth'),
]
for position, values, title in panels:
    ax = axes[position]
    # Clip at 255 and cast to 8-bit before display.
    ax.imshow(npp.uint8(npp.minimum(255, values)))
    ax.set_title(title)
    ax.axis("off")

# The bottom-right slot stays empty.
axes[1, 2].axis("off")

fig.savefig('Result.pdf')