## ライブラリ定義

In [1]:
from __future__ import print_function
from builtins import input
from builtins import range

#import pyfftw   # See https://github.com/pyFFTW/pyFFTW/issues/40
import numpy as np
import functools
import operator
import matplotlib.pyplot as mplot
mplot.rcParams["axes.grid"] = False
import math
import pprint
import os
import shutil
import time

from scipy.linalg import toeplitz
from sporco.dictlrn import cbpdndl
from sporco.admm import cbpdn
from sporco import util
from sporco import plot
from sporco import cnvrep
import sporco.linalg as sl
import sporco.metric as sm
from sporco.admm import ccmod
from skimage.measure import compare_ssim
from skimage.measure import compare_psnr
plot.config_notebook_plotting()

In [2]:
def l2norm(A):
    """Return the *squared* Euclidean norm of A, i.e. the sum of |a|^2 over all entries."""
    magnitudes = abs(A)
    return np.sum(magnitudes * magnitudes)

def l0norm(A, threshold):
    """Count the entries of A whose magnitude is not below ``threshold``.

    NaN entries are counted as non-zero (``abs(nan) < threshold`` is False).
    """
    negligible = abs(A) < threshold
    return np.sum(np.logical_not(negligible))

def strict_l0norm(A):
    """Count the entries of A that are exactly non-zero (NaN counts as non-zero)."""
    return np.sum(A != 0)

def smoothedl0norm(A, sigma):
    """Count the entries of A whose magnitude exceeds a small tolerance (1e-7).

    Parameters
    ----------
    A : array_like
        Coefficients to measure.
    sigma : float
        Accepted for interface compatibility but currently unused. A
        Gaussian-smoothed variant would be
        ``A.size - sum(exp(-A**2 / (2 * sigma**2)))``.

    Returns
    -------
    int
        Number of entries with ``|a| > 1e-7``.
    """
    EPS = 0.0000001
    # Compare magnitudes: negative coefficients are non-zero too (the previous
    # loop tested `a > EPS` and silently skipped them). NaN is not counted.
    return int(np.count_nonzero(np.abs(np.asarray(A)) > EPS))

def getimages():
    """Load five grey-scale SPORCO example images and stack them along a third axis.

    Every image is loaded with ``scaled=True, zoom=0.5, gray=True`` and cropped
    with its own index expression; the result is ``np.dstack`` of the crops.
    """
    loader = util.ExampleImages(scaled=True, zoom=0.5, gray=True)
    crops = [
        ('barbara.png', np.s_[10:522, 100:612]),
        ('kodim23.png', np.s_[:, 60:572]),
        ('monarch.png', np.s_[:, 160:672]),
        ('sail.png', np.s_[:, 210:722]),
        ('tulips.png', np.s_[:, 30:542]),
    ]
    return np.dstack([loader.image(name, idxexp=window) for name, window in crops])

def saveimg(img, filename, title=None):
    """Render a single image and save it to ``filename``.

    Parameters
    ----------
    img : array_like
        Image passed to ``plot.imview``.
    filename : str
        Output path.
    title : str, optional
        Title drawn above the image. It was previously accepted but ignored;
        it is now forwarded to ``plot.imview``.
    """
    fig = plot.figure(figsize=(7, 7))
    plot.imview(img, title=title, fig=fig)
    fig.savefig(filename)
    # Close both through sporco and pyplot so figures do not accumulate in long runs.
    plot.close()
    mplot.close()

# imgs.shape == (R, C, imgR, imgC) or (C, imgR, imgC)
def saveimg2D(imgs, filename, titles=None):
    """Save an R x C grid of images to ``filename``.

    Parameters
    ----------
    imgs : ndarray
        Shape (R, C, imgR, imgC), or (C, imgR, imgC) for a single row. An
        object array of shape (R, C) holding individual images also works.
    filename : str
        Output path.
    titles : array_like of str, optional
        Shape (R, C), or (C,) for a single row.
    """
    if imgs.ndim == 3:
        imgs = np.array([imgs])
    if titles is not None:
        titles = np.asarray(titles)
        # A single row of titles must be promoted to a 1 x C grid. The old check
        # (`ndim == 3`) never matched string titles, so titles[0][c] would have
        # picked single characters out of one title.
        if titles.ndim == 1:
            titles = titles[np.newaxis]
    R = imgs.shape[0]
    C = imgs.shape[1]
    fig = plot.figure(figsize=(7*C, 7*R))
    for r in range(R):
        for c in range(C):
            ax = fig.add_subplot(R, C, r*C + c + 1)
            s = None if titles is None else titles[r][c]
            plot.imview(imgs[r][c], title=s, fig=fig, ax=ax)
    fig.savefig(filename)
    plot.close()
    mplot.close()

# NOTE: not robust - assumes a 5-D array in the (N0, N1, C, K, M|1) layout of
# cri.shpS / cri.shpX; any other rank raises.
def format_sig(signal):
    """Move axis 3 (the image index K) to the front and drop all singleton axes."""
    image_first = np.transpose(signal, axes=(3, 0, 1, 2, 4))
    return image_first.squeeze()

def saveXimg(cri, Xr, filename):
    """Save a map of the total coefficient magnitude (summed over the filter axis) as an image."""
    magnitude_map = abs(Xr).sum(axis=cri.axisM).squeeze()
    fig = plot.figure(figsize=(7, 7))
    plot.imview(magnitude_map, cmap=plot.cm.Blues, fig=fig)
    fig.savefig(filename)
    plot.close()
    mplot.close()

def saveXhist(Xr, filename):
    """Save a 500-bin density histogram of |Xr| (all entries flattened) to ``filename``."""
    values = abs(Xr.flatten())
    fig = plot.figure(figsize=(70, 7))
    axes = fig.add_subplot(111)
    axes.hist(values, bins=500, density=True)
    fig.savefig(filename)
    plot.close()
    mplot.close()

def save_result(D0, D, X, S, S_reconstructed, filename):
    """Save a two-row comparison figure.

    The top row shows the original images followed by the tiled initial
    dictionary D0. The bottom row shows the reconstructions, titled with
    PSNR / SSIM / strict l0 norm, followed by the tiled dictionary D.

    Parameters
    ----------
    D0, D : ndarray
        Initial and learned dictionaries (passed to ``util.tiledict``).
    X : ndarray
        Coefficients. Axis 2 is taken as the image index; presumably the
        squeezed (N0, N1, K, M) layout - TODO confirm.
    S, S_reconstructed : ndarray
        Original and reconstructed images, with the image index on the last axis.
    filename : str
        Output path.
    """
    n_img = S.shape[-1]
    # The images and the tiled dictionaries generally differ in shape. Build an
    # explicit object grid instead of np.array() on a ragged nested list, which
    # raises ValueError on NumPy >= 1.24.
    grid = np.empty((2, n_img + 1), dtype=object)
    titles = [[], []]
    for k in range(n_img):
        original = S.T[k].T
        recon = S_reconstructed.T[k].T
        grid[0, k] = original
        grid[1, k] = recon
        titles[0].append('')
        psnr = sm.psnr(original, recon)
        ssim = compare_ssim(original, recon)
        l0 = strict_l0norm(np.rollaxis(X, 2)[k])
        titles[1].append("PSNR: %.3fdb\nSSIM: %.4f\nl0norm: %d" % (psnr, ssim, l0))
    grid[0, n_img] = util.tiledict(D0)
    grid[1, n_img] = util.tiledict(D)
    titles[0].append('')
    titles[1].append('')
    saveimg2D(grid, filename, np.array(titles))

def compressedXk(Xrk, size_rate):
    """Return a copy of Xrk that keeps only the largest-magnitude fraction of entries.

    Parameters
    ----------
    Xrk : ndarray
        Coefficient array; it is not modified.
    size_rate : float
        Fraction of entries to keep, ranked by magnitude. Values outside [0, 1]
        are clamped: <= 0 zeroes everything, >= 1 keeps everything.

    Returns
    -------
    ndarray
        Copy of Xrk in which the ceil(size * (1 - size_rate)) smallest-magnitude
        entries are set to zero. Prints "<size> -> <kept>" as a progress message.
    """
    Xrk = Xrk.copy()  # C-contiguous copy, so reshape(-1) below is a view
    X_flat = Xrk.reshape(-1)
    n = math.ceil(X_flat.size*(1 - size_rate))
    # Clamp: with size_rate > 1 the old code got a negative n, and the slice
    # [0:n] then zeroed almost every entry instead of none.
    n = min(max(n, 0), X_flat.size)
    print(str(X_flat.size) + " -> " + str(X_flat.size - n))
    X_flat[np.argsort(abs(X_flat))[:n]] = 0
    return Xrk

def to_inative(X, sigma):
    """Zero every entry of X that is below ``sigma``.

    NOTE(review): the comparison is signed, so all negative entries are zeroed
    too - unlike l0norm, which thresholds abs(A). Confirm this is intended.
    """
    below = X < sigma
    return np.where(below, 0, X)

# Reduce one axis to length 1 (a view, not a copy).
def compress_axis(A, axis, i):
    """Return ``A`` restricted to index ``i`` along ``axis``, keeping that axis with length 1."""
    # range(...)[axis] normalises negative axes and raises IndexError when out of range.
    target = range(A.ndim)[axis]
    selector = tuple(slice(i, i + 1) if d == target else slice(None) for d in range(A.ndim))
    return A[selector]

def compress_axis_op(A, axis, i):
    """Return the index tuple that selects ``i:i+1`` along ``axis`` and everything elsewhere."""
    # range(...)[axis] normalises negative axes and raises IndexError when out of range.
    target = range(A.ndim)[axis]
    return tuple(slice(i, i + 1) if d == target else slice(None) for d in range(A.ndim))

def reconstruct(cri, Dr, Xr):
    """Return the spatial-domain reconstruction sum_m d_m * x_m, computed in the DFT domain."""
    coef_f = sl.rfftn(Xr, s=cri.Nv, axes=cri.axisN)
    dict_f = sl.rfftn(Dr, s=cri.Nv, axes=cri.axisN)
    recon_f = sl.inner(dict_f, coef_f, axis=cri.axisM)
    return sl.irfftn(recon_f, s=cri.Nv, axes=cri.axisN)

def save_reconstructed(cri, Dr, Xr, Sr, filename, Sr_add=None):
    """Save originals (left) next to reconstructions (right), optionally adding ``Sr_add`` to both."""
    approx = reconstruct(cri, Dr, Xr)
    offset = np.zeros_like(Sr) if Sr_add is None else Sr_add
    pair = np.stack((format_sig(Sr + offset), format_sig(approx + offset)), axis=1)
    saveimg2D(pair, filename)

def compressedX(cri, Xr, Sr, size_rate):
    """Sparsify each image's coefficients independently with compressedXk.

    The kept fraction per image is ``(Sr.size / Xr.size) * size_rate``. Xr is not modified.
    """
    result = Xr.copy()
    keep_rate = (Sr.size / Xr.size)*size_rate
    for k in range(cri.K):
        sel = compress_axis_op(result, cri.axisK, k)
        result[sel] = compressedXk(result[sel], keep_rate)
    return result

def calcXr(cri, Dr, Sr, lmbda=5e-2):
    """Solve convolutional basis pursuit denoising (sporco ConvBPDN) for Sr with dictionary Dr."""
    options = cbpdn.ConvBPDN.Options({
        'Verbose': True,
        'MaxMainIter': 200,
        'RelStopTol': 5e-3,
        'AuxVarObj': False,
    })
    solver = cbpdn.ConvBPDN(Dr.squeeze(), Sr.squeeze(), lmbda, options,
                            dimK=cri.dimK, dimN=cri.dimN)
    return solver.solve()

def evaluate_result(cri, Dr0, Dr, Sr, Sr_add=None, lmbda=5e-2, title='result.png'):
    """Sparse-code Sr with Dr, print sparsity statistics and save a comparison figure to ``title``."""
    coeffs = calcXr(cri, Dr, Sr, lmbda)
    print("strict l0 norm", strict_l0norm(coeffs))
    print("l2norm: ", l2norm(coeffs))
    for k in range(cri.K):
        per_image = strict_l0norm(compress_axis(coeffs, cri.axisK, k))
        print("image %d: strict l0 norm %f" % (k, per_image))
    offset = np.zeros_like(Sr) if Sr_add is None else Sr_add
    reconstructed = (reconstruct(cri, Dr, coeffs) + offset).squeeze()
    save_result(Dr0.squeeze(), Dr.squeeze(), coeffs.squeeze(),
                (Sr + offset).squeeze(), reconstructed, title)

def l2norm_minimize(cri, Dr, Sr):
    """Return minimum-l2-norm coefficients X satisfying D X = S, solved per DFT frequency.

    At each frequency this evaluates X = conj(D) / (D . conj(D)) * S, the
    minimum-norm solution of the single linear equation sum_m d_m x_m = s.
    NOTE: the denominator is zero at any frequency where every filter vanishes.

    Side effect: saves a figure comparing Sr (left) with its reconstruction
    (right) to 'l2norm_minimization_test.png' in the working directory.

    Returns
    -------
    Xr : spatial-domain coefficients.
    """
    Df = sl.rfftn(Dr, s=cri.Nv, axes=cri.axisN) # implicitly zero-padding
    Sf = sl.rfftn(Sr, s=cri.Nv, axes=cri.axisN) # implicitly zero-padding
    Xf = np.conj(Df) / sl.inner(Df, np.conj(Df), axis=cri.axisM) * Sf
    Xr = sl.irfftn(Xf, s=cri.Nv, axes=cri.axisN)

    # Reconstruction from the new coefficients, used only for the sanity-check image.
    Sr_ = sl.irfftn(sl.inner(Df, Xf, axis=cri.axisM), s=cri.Nv, axes=cri.axisN)
    # print(l2norm(np.random.randn(*Xr.shape)))
    # print(l2norm(Xr))
    # print(l2norm(Sr - Sr_))
    po = np.stack((format_sig(Sr), format_sig(Sr_)), axis=1)
    saveimg2D(po, 'l2norm_minimization_test.png') # the right side is Sr_
    return Xr

def convert_to_Df(D):
    """Reshape D to cri.shpD and return its zero-padded DFT.

    NOTE: reads the notebook-level globals ``cri`` and ``S``; the dtype is taken
    from the global S, not from D.
    """
    shaped = np.asarray(D.reshape(cri.shpD), dtype=S.dtype)
    return sl.rfftn(shaped, cri.Nv, cri.axisN)

def convert_to_Sf(S):
    """Reshape S to cri.shpS and return its DFT over the spatial axes (no padding).

    NOTE: reads the notebook-level global ``cri``.
    """
    return sl.rfftn(np.asarray(S.reshape(cri.shpS), dtype=S.dtype), None, cri.axisN)

def convert_to_S(Sf):
    """Inverse DFT of Sf to the spatial domain, with singleton axes removed.

    NOTE: reads the notebook-level global ``cri``.
    """
    spatial = sl.irfftn(Sf, cri.Nv, cri.axisN)
    return spatial.squeeze()

def convert_to_Xf(X):
    """Reshape X to cri.shpX and return its DFT at size cri.Nv.

    NOTE: reads the notebook-level globals ``cri`` and ``S``; the dtype is taken
    from the global S.
    """
    shaped = np.asarray(X.reshape(cri.shpX), dtype=S.dtype)
    return sl.rfftn(shaped, cri.Nv, cri.axisN)

def convert_to_X(Xf):
    """Inverse DFT of Xf to the spatial domain, with singleton axes removed.

    NOTE: reads the notebook-level global ``cri``.
    """
    spatial = sl.irfftn(Xf, cri.Nv, cri.axisN)
    return spatial.squeeze()


def derivD_spdomain(cri, Xr, Sr, Df, Xf, dict_Nv):
    """Compute the data-term gradient with respect to the dictionary in the spatial domain.

    For every filter offset (n1, n2) in ``dict_Nv`` this correlates the
    reconstruction residual (D X - S) with a copy of ``Xr`` circularly rolled
    by (n1, n2). The products are summed over the two spatial axes and the
    image axis.

    Parameters
    ----------
    cri : cnvrep.CSC_ConvRepIndexing
    Xr : spatial-domain coefficients.
    Sr : spatial-domain signals.
    Df, Xf : DFT-domain dictionary and coefficients (used only to form the residual).
    dict_Nv : tuple of two ints, the filter support. It is concatenated with
        ``Xr.shape``, so it must be a tuple.

    Returns
    -------
    Array whose two leading axes index the filter offset. NOTE(review): after
    the final ``[:, :, 0, 0]`` this is presumably (dict_Nv[0], dict_Nv[1], C, 1, M)
    when cri.axisN == (0, 1) - confirm.

    NOTE: memory use is dict_Nv[0] * dict_Nv[1] times the size of Xr, because a
    rolled copy of Xr is materialised for every offset.
    """
    # Spatial-domain residual D*X - S.
    B = sl.irfftn(sl.inner(Df, Xf, axis=cri.axisM), s=cri.Nv, axes=cri.axisN) - Sr
    # Two leading singleton axes so that B broadcasts against Xshifted below.
    B = B[np.newaxis, np.newaxis,]
    # One copy of Xr per filter offset: shape dict_Nv + Xr.shape.
    Xshifted = np.ones(dict_Nv + Xr.shape) * Xr
    
    # Axis positions inside Xshifted. N1/N2 are unused. I/J are Xr's spatial
    # axes shifted by the two prepended offset axes (assumes cri.axisN == (0, 1)).
    N1 = 0
    N2 = 1
    I = 2
    J = 3

    print("start shifting")
    # Circularly shift the (n1, n2)-th copy of Xr by (n1, n2) along the spatial axes.
    for n1 in range(dict_Nv[0]):
        for n2 in range(dict_Nv[1]):
            Xshifted[n1][n2] = np.roll(Xshifted[n1][n2], (n1, n2), axis=(I, J))
            # print("shifted ", (n1, n2))
    # Correlate with the residual: sum over both spatial axes and the image axis K
    # (offset by 2 because of the two prepended axes).
    ret = np.sum(np.conj(B) * Xshifted, axis=(I, J, 2 + cri.axisK), keepdims=True)
    print(ret.shape)
    # Drop the two kept (length-1) spatial axes.
    ret = ret[:, :, 0, 0]
    print(ret.shape)
    return ret

def goldenRatioSearch(function, rng, cnt):
    """Minimise a unimodal ``function`` on [rng[0], rng[1]] by golden-section search.

    Used to optimise a step width. Runs exactly ``cnt`` bracket reductions (no
    tolerance-based early exit) and returns the midpoint of the final bracket.
    """
    ratio = (np.sqrt(5) - 1)/2
    lo, hi = rng[0], rng[1]
    x1 = hi - ratio*(hi - lo)
    x2 = lo + ratio*(hi - lo)
    f1 = function(x1)
    f2 = function(x2)
    for _ in range(cnt):
        if f1 <= f2:
            # Minimum lies in [lo, x2]: the old x1 becomes the new right probe.
            hi, x2, f2 = x2, x1, f1
            x1 = hi - ratio*(hi - lo)
            f1 = function(x1)
        else:
            # Minimum lies in [x1, hi]: the old x2 becomes the new left probe.
            lo, x1, f1 = x1, x2, f2
            x2 = lo + ratio*(hi - lo)
            f2 = function(x2)
    return (lo + hi)/2

# f must be convex (unimodal with a minimum) on the search interval.
def ternary_search(f, rng, cnt):
    """Minimise a convex ``f`` on [rng[0], rng[1]] with ``cnt`` ternary-search steps.

    Returns the midpoint of the final interval.
    """
    lo, hi = rng[0], rng[1]
    for _ in range(cnt):
        m1 = (lo * 2 + hi) / 3
        m2 = (lo + hi * 2) / 3
        if f(m1) > f(m2):
            lo = m1
        else:
            hi = m2
    return (lo + hi) / 2

def min_max(x, axis=None):
    """Rescale x linearly to [0, 1] along ``axis`` (all entries when axis is None).

    A constant slice gives 0/0 (NaN), as before.
    """
    lowest = x.min(axis=axis, keepdims=True)
    span = x.max(axis=axis, keepdims=True) - lowest
    return (x - lowest) / span

def zscore(x, axis=None):
    """Standardise x to zero mean and unit (population) standard deviation along ``axis``."""
    centred = x - x.mean(axis=axis, keepdims=True)
    return centred / np.std(x, axis=axis, keepdims=True)

def normalize(v, axis=-1, order=2):
    """Scale v to unit ``order``-norm along ``axis``; zero-norm slices are left unchanged."""
    norms = np.linalg.norm(v, ord=order, axis=axis, keepdims=True)
    safe_norms = np.where(norms == 0, 1, norms)
    return v / safe_norms

def to_frequency(cri, Ar):
    """Real DFT of Ar over the spatial axes cri.axisN, zero-padded to size cri.Nv."""
    spectrum = sl.rfftn(Ar, s=cri.Nv, axes=cri.axisN)
    return spectrum

def to_spatial(cri, Af):
    """Inverse of to_frequency: real inverse DFT over cri.axisN with output size cri.Nv."""
    spatial = sl.irfftn(Af, s=cri.Nv, axes=cri.axisN)
    return spatial

def update_dict(cri, Pcn, crop_op, Xr, Gr, Hr, Sf, param_rho):
    """Perform one ADMM iteration (D, G and H steps) of the dictionary update.

    Parameters
    ----------
    cri : cnvrep.CSC_ConvRepIndexing
    Pcn : projection returned by cnvrep.getPcn, applied in the spatial domain.
    crop_op : tuple of slices cropping a zero-padded dictionary to its filter support.
    Xr : spatial-domain coefficients.
    Gr : current projected dictionary; the FFT zero-pads it to cri.Nv.
    Hr : scaled dual variable, at full (zero-padded) size.
    Sf : DFT of the training signals.
    param_rho : ADMM penalty parameter.

    Returns
    -------
    (Gr[crop_op], Hr) : the new projected dictionary cropped to its filter
    support, and the updated dual variable (not cropped).
    """
    # D step: least-squares dictionary fit with a rho * ||D - (G - H)||^2 coupling
    # term, solved per frequency by sl.solvemdbi_ism (summing over images on axisK).
    Xf = to_frequency(cri, Xr)
    Gf = to_frequency(cri, Gr)
    Hf = to_frequency(cri, Hr)
    XSf = sl.inner(np.conj(Xf), Sf, cri.axisK)
    b = XSf + param_rho * (Gf - Hf)
    Df = sl.solvemdbi_ism(Xf, param_rho, b, cri.axisM, cri.axisK)
    Dr = to_spatial(cri, Df)
    # G step: project D + H with Pcn (presumably onto support-limited,
    # normalised filters - see cnvrep.getPcn; confirm).
    Gr = Pcn(Dr + Hr)
    # H step: dual ascent on the constraint D = G.
    Hr = Hr + Dr - Gr
    return Gr[crop_op], Hr

def nakashizuka_solve(
    cri, Dr0, Xr, Sr,
    final_sigma,
    maxitr=40,
    param_mu=1,
    param_lambda=1e-2,
    debug_dir=None,
    param_rho=0.5,
):
    """Learn a dictionary with the smoothed-l0 + ADMM scheme (Nakashizuka's method).

    Each of the ``maxitr`` iterations uses one sigma from a geometric schedule
    that runs from ``4 * Xr.max()`` down to ``final_sigma``, and performs:

    1. a smoothed-l0 step on X: ``X -= mu * X * exp(-X**2 / (2 sigma**2))``;
    2. a coefficient update solved per frequency with ``sl.solvedbi_sm``
       (diagonal term ``1 / param_lambda``);
    3. one ADMM dictionary update (``update_dict``) with penalty ``param_rho``.

    Parameters
    ----------
    cri : cnvrep.CSC_ConvRepIndexing
    Dr0, Xr, Sr : initial dictionary, initial coefficients and training
        signals. None of them is modified.
    final_sigma : last value of the sigma schedule.
    maxitr : number of iterations. Values < 2 run one iteration at the initial
        sigma; maxitr == 1 used to raise ZeroDivisionError.
    param_mu : step size of the smoothed-l0 step.
    param_lambda : weight of the coefficient update.
    debug_dir : if given, the dictionary is saved to ``debug_dir/dict/<i>.png``
        after every iteration. The directory is created if it is missing.
    param_rho : ADMM penalty of the dictionary update (previously hard-coded 0.5).

    Returns
    -------
    The learned dictionary, with the same shape as ``Dr0``.
    """
    Xr = Xr.copy()
    Sr = Sr.copy()
    Dr = Dr0.copy()
    # Scaled dual variable of the dictionary ADMM, kept at the full padded size.
    Hr = np.zeros_like(cnvrep.zpad(Dr, cri.Nv))

    Sf = to_frequency(cri, Sr)

    # Geometric sigma schedule from 4 * max(X) down to final_sigma.
    first_sigma = Xr.max()*4
    if maxitr < 2:
        # The ratio below needs maxitr - 1 > 0.
        sigma_list = [first_sigma]
    else:
        c = (final_sigma / first_sigma) ** (1/(maxitr - 1))
        print("c = %.8f" % c)
        sigma_list = [first_sigma]
        for _ in range(maxitr - 1):
            sigma_list.append(sigma_list[-1]*c)

    # Slices that crop a zero-padded dictionary back to the filter support.
    crop_op = tuple(slice(0, l) for l in Dr.shape)
    Pcn = cnvrep.getPcn(Dr.shape, cri.Nv, cri.dimN, cri.dimCd, zm=False)

    if debug_dir is not None:
        os.makedirs(debug_dir + "/dict", exist_ok=True)

    for updcnt, sigma in enumerate(sigma_list):
        print("sigma = %.8f" % sigma)
        # Smoothed-l0 step: shrink coefficients that are small relative to sigma.
        delta = Xr * np.exp(-(Xr*Xr) / (2*sigma*sigma))
        Xr = Xr - param_mu*delta
        Xf = to_frequency(cri, Xr)

        # Coefficient update in the DFT domain.
        Df = to_frequency(cri, Dr)
        b = Xf / param_lambda + np.conj(Df) * Sf
        Xf = sl.solvedbi_sm(Df, 1/param_lambda, b, axis=cri.axisM)
        Xr = to_spatial(cri, Xf).astype(np.float32)

        # One ADMM dictionary update; returns the cropped dictionary.
        Dr, Hr = update_dict(cri, Pcn, crop_op, Xr, Dr, Hr, Sf, param_rho)

        if debug_dir is not None:
            saveimg(util.tiledict(Dr.squeeze()), debug_dir + "/dict/%d.png" % updcnt)

    plot.close()
    mplot.close()
    return Dr

def mysolve(
    cri, Dr0, Xr, Sr,
    final_sigma,
    maxitr=40,
    param_mu=1,
    debug_dir=None
):
    """Proposed dictionary-learning method: smoothed-l0 step, dictionary gradient, projection.

    Each of the ``maxitr`` iterations uses one sigma from a geometric schedule
    that runs from ``4 * Xr.max()`` down to ``final_sigma``, and performs:

    1. a smoothed-l0 step on X: ``X -= mu * X * exp(-X**2 / (2 sigma**2))``;
    2. a gradient step on D for ||D X - S||^2. The step size ``alpha`` is
       halved, kept or doubled, whichever gives the lowest reconstruction error;
    3. projection of D with cnvrep.getPcn, then cropping to the filter support;
    4. projection of X back onto {X : D X = S} (minimum-norm correction per frequency).

    Parameters
    ----------
    cri : cnvrep.CSC_ConvRepIndexing
    Dr0, Xr, Sr : initial dictionary, initial coefficients and training
        signals. None of them is modified.
    final_sigma : last value of the sigma schedule.
    maxitr : number of iterations. Values < 2 run one iteration at the initial
        sigma; maxitr == 1 used to raise ZeroDivisionError.
    param_mu : step size of the smoothed-l0 step.
    debug_dir : if given, the dictionary is saved to ``debug_dir/dict/<i>.png``
        after every iteration. The directory is created if it is missing.

    Returns
    -------
    The learned dictionary, with the same shape as ``Dr0``.
    """
    Dr = Dr0.copy()
    Xr = Xr.copy()
    Sr = Sr.copy()

    # Discrete Fourier transforms, zero-padded to the signal size cri.Nv.
    # (X is transformed at the start of every iteration, so no initial Xf is needed.)
    Df = sl.rfftn(Dr, s=cri.Nv, axes=cri.axisN)
    Sf = sl.rfftn(Sr, s=cri.Nv, axes=cri.axisN)
    # Dictionary step size, adapted by x0.5 / x1 / x2 every iteration.
    alpha = 1e0

    # Geometric sigma schedule.
    first_sigma = Xr.max()*4
    if maxitr < 2:
        # The ratio below needs maxitr - 1 > 0.
        sigma_list = [first_sigma]
    else:
        ratio = (final_sigma / first_sigma) ** (1/(maxitr - 1))
        print("c = %.8f" % ratio)
        sigma_list = [first_sigma]
        for _ in range(maxitr - 1):
            sigma_list.append(sigma_list[-1]*ratio)
            print(sigma_list[-1])

    # Index slices that crop a zero-padded dictionary to its filter support.
    crop_op = tuple(slice(0, l) for l in Dr.shape)
    print(crop_op)

    # Projection operator for the dictionary constraint.
    Pcn = cnvrep.getPcn(Dr.shape, cri.Nv, cri.dimN, cri.dimCd, zm=False)

    if debug_dir is not None:
        os.makedirs(debug_dir + "/dict", exist_ok=True)

    for updcnt, sigma in enumerate(sigma_list):
        print("sigma = %.8f" % sigma)
        # Smoothed-l0 step on the coefficients.
        delta = Xr * np.exp(-(Xr*Xr) / (2*sigma*sigma))
        Xr = Xr - param_mu*delta
        Xf = sl.rfftn(Xr, cri.Nv, cri.axisN)

        # Gradient of the data term with respect to the dictionary (DFT domain).
        B = sl.inner(Xf, Df, axis=cri.axisM) - Sf
        derivDf = sl.inner(np.conj(Xf), B, axis=cri.axisK)

        def func(step):
            # Reconstruction error after a gradient step of size `step`, with
            # the dictionary cropped to its filter support.
            Df_ = Df - step * derivDf
            Dr_ = sl.irfftn(Df_, s=cri.Nv, axes=cri.axisN)[crop_op]
            Df_ = sl.rfftn(Dr_, s=cri.Nv, axes=cri.axisN)
            Sf_ = sl.inner(Df_, Xf, axis=cri.axisM)
            return l2norm(Sr - sl.irfftn(Sf_, s=cri.Nv, axes=cri.axisN))

        choice = np.array([func(alpha / 2), func(alpha), func(alpha * 2)]).argmin()
        alpha *= [0.5, 1, 2][choice]
        print("alpha: ", alpha)
        Df = Df - alpha * derivDf
        print(Df.shape)
        Dr = sl.irfftn(Df, s=cri.Nv, axes=cri.axisN)
        print(Dr.shape)
        Dr = Pcn(Dr)
        print(Dr.shape)
        Dr = Dr[crop_op]
        print(Dr.shape)
        print(l2norm(Dr.T[0]))
        Df = sl.rfftn(Dr, s=cri.Nv, axes=cri.axisN)

        if debug_dir is not None:
            saveimg(util.tiledict(Dr.squeeze()), debug_dir + "/dict/%d.png" % updcnt)

        # Project X back onto the solution space {X : D X = S}. The correction
        # is along conj(D), so D X_new = S exactly at every frequency.
        b = sl.inner(Df, Xf, axis=cri.axisM) - Sf
        DDf = sl.inner(Df, np.conj(Df), axis=cri.axisM)
        Xf = Xf - np.conj(Df) / DDf * b
        Xr = sl.irfftn(Xf, s=cri.Nv, axes=cri.axisN)

    plot.close()
    mplot.close()
    return Dr

def sporcosolve(cri, Dr0, Sr, maxitr=200):
    """Baseline: learn a dictionary with sporco's ConvBPDNDictLearn (B. Wohlberg), lambda = 0.2.

    ``cri`` is unused and only kept for a signature parallel to the other solvers.
    """
    D_init = Dr0.copy()
    signal = Sr.copy()
    lmbda = 0.2
    options = cbpdndl.ConvBPDNDictLearn.Options(
        {
            'Verbose': True,
            'MaxMainIter': maxitr,
            'CBPDN': {'rho': 50.0*lmbda + 0.5},
            'CCMOD': {'rho': 10.0, 'ZeroMean': True},
        },
        dmethod='cns',
    )
    learner = cbpdndl.ConvBPDNDictLearn(D_init.squeeze(), signal.squeeze(), lmbda,
                                        options, dmethod='cns')
    Dr = learner.solve()
    print("ConvBPDNDictLearn solve time: %.2fs" % learner.timer.elapsed('solve'))
    return Dr

def testdict(cri, Dr0, Dr, Slr, Shr, dir,
    lambdas = (
        '1e-2',
        '2e-2',
        '5e-2',
        '1e-1',
        '2e-1',
        '5e-1',
    )):
    """Evaluate dictionary ``Dr`` on test images for several sparsity weights.

    For each lambda:

    - solve ConvBPDN for the high-pass part ``Shr`` with calcXr;
    - add back the low-pass part ``Slr``;
    - record PSNR, SSIM and strict l0 norm per image;
    - save a comparison figure to ``dir/result_lambda=<lambda>.png``.

    Parameters
    ----------
    cri : index object of the test set (cri.K images).
    Dr0 : initial dictionary (only shown in the figure).
    Dr : dictionary to evaluate.
    Slr, Shr : low-pass and high-pass components of the test images.
    dir : existing output directory.
    lambdas : iterable of lambda values given as strings, so they can also
        label the output files. The default is a tuple so it cannot be
        mutated between calls.

    Returns
    -------
    ret : list of length cri.K; ret[k] is the list of metric dicts for image k.
    """
    S = (Slr + Shr).squeeze()
    ret = [[] for _ in range(cri.K)]
    for s in lambdas:
        print('==========================================')
        print('test dictionary (lambda = %s)' % s)
        print('==========================================')
        lmbda = float(s)
        # Coefficient optimisation by ADMM.
        Xr = calcXr(cri, Dr, Shr, lmbda=lmbda)
        X = Xr.squeeze()
        S_ = (reconstruct(cri, Dr, Xr) + Slr).squeeze()
        for k in range(cri.K):
            d = {
                'lambda': lmbda,
                'psnr': sm.psnr(S.T[k].T, S_.T[k].T),
                'ssim': compare_ssim(S.T[k].T, S_.T[k].T),
                'l0norm': strict_l0norm(np.rollaxis(X, 2)[k]),
            }
            pprint.pprint(d)
            ret[k].append(d)

        save_result(Dr0.squeeze(), Dr.squeeze(), X, S, S_, dir + '/result_lambda=%s.png' % s)
    return ret

def test_mysolve(cri_train, Dr0, Shr_train, cri_test, Slr_test, Shr_test, outdir='.'):
    """Benchmark mysolve over several iteration counts and evaluate each learned dictionary.

    Returns
    -------
    (data, times) : data[k] lists the metric dicts for test image k, each tagged
    with 'maxitr'; times lists {'maxitr', 'time'} wall-clock records. Results
    are written to ``outdir/mysolve_itr=<maxitr>``, which is recreated each run.
    """
    data = [[] for _ in range(cri_test.K)]
    times = []

    # Warm-up run so that memory allocation (e.g. on Google Colab) is not timed.
    X_init = l2norm_minimize(cri_train, Dr0, Shr_train)
    mysolve(cri_train, Dr0, X_init, Shr_train, 1e-4, maxitr=2)

    for maxitr in (5, 10, 20, 30, 40, 50):
        started = time.time()
        # Dictionary learning, starting from the coefficients with minimum l2 error.
        X_init = l2norm_minimize(cri_train, Dr0, Shr_train)
        learned = mysolve(cri_train, Dr0, X_init, Shr_train, 1e-4, maxitr=maxitr)
        times.append({'maxitr': maxitr, 'time': time.time() - started})

        result_dir = outdir + '/mysolve_itr=%d' % maxitr
        if os.path.isdir(result_dir):
            shutil.rmtree(result_dir)
        os.makedirs(result_dir)

        # Evaluate the learned dictionary.
        res = testdict(cri_test, Dr0, learned, Slr_test, Shr_test, result_dir,
                       lambdas=['1e-3', '3e-3', '1e-2', '3e-2', '1e-1', '3e-1'])
        for k in range(cri_test.K):
            for record in res[k]:
                record['maxitr'] = maxitr
            data[k].extend(res[k])
    return data, times

def test_nakashizuka_solve(cri_train, Dr0, Shr_train, cri_test, Slr_test, Shr_test, outdir='.'):
    """Benchmark nakashizuka_solve over several iteration counts and evaluate each dictionary.

    Returns
    -------
    (data, times) : data[k] lists the metric dicts for test image k, each tagged
    with 'maxitr'; times lists {'maxitr', 'time'} wall-clock records. Results
    are written to ``outdir/nakashizuka_solve_itr=<maxitr>``, recreated each run.
    """
    data = [[] for _ in range(cri_test.K)]
    times = []

    # Warm-up run so that memory allocation (e.g. on Google Colab) is not timed.
    X_init = l2norm_minimize(cri_train, Dr0, Shr_train)
    nakashizuka_solve(cri_train, Dr0, X_init, Shr_train, 1e-4, maxitr=2)

    for maxitr in (20, 40, 60, 100, 200, 300):
        started = time.time()
        # Start from the coefficients with minimum l2 error.
        X_init = l2norm_minimize(cri_train, Dr0, Shr_train)
        learned = nakashizuka_solve(cri_train, Dr0, X_init, Shr_train, 1e-4, maxitr=maxitr)
        times.append({'maxitr': maxitr, 'time': time.time() - started})

        result_dir = outdir + '/nakashizuka_solve_itr=%d' % maxitr
        if os.path.isdir(result_dir):
            shutil.rmtree(result_dir)
        os.makedirs(result_dir)

        res = testdict(cri_test, Dr0, learned, Slr_test, Shr_test, result_dir)
        for k in range(cri_test.K):
            for record in res[k]:
                record['maxitr'] = maxitr
            data[k].extend(res[k])
    return data, times

def test_sporcosolve(cri_train, Dr0, Shr_train, cri_test, Slr_test, Shr_test, outdir='.'):
    """Benchmark sporcosolve over several iteration counts and evaluate each dictionary.

    Returns
    -------
    (data, times) : data[k] lists the metric dicts for test image k, each tagged
    with 'maxitr'; times lists {'maxitr', 'time'} wall-clock records. Results
    are written to ``outdir/sporcosolve_itr=<maxitr>``, recreated each run.
    """
    data = [[] for _ in range(cri_test.K)]
    times = []

    # Warm-up run so that memory allocation (e.g. on Google Colab) is not timed.
    sporcosolve(cri_train, Dr0, Shr_train, maxitr=2)

    for maxitr in (20, 40, 80, 120, 160, 200):
        started = time.time()
        learned = sporcosolve(cri_train, Dr0, Shr_train, maxitr=maxitr)
        times.append({'maxitr': maxitr, 'time': time.time() - started})

        result_dir = outdir + '/sporcosolve_itr=%d' % maxitr
        if os.path.isdir(result_dir):
            shutil.rmtree(result_dir)
        os.makedirs(result_dir)

        res = testdict(cri_test, Dr0, learned, Slr_test, Shr_test, result_dir)
        for k in range(cri_test.K):
            for record in res[k]:
                record['maxitr'] = maxitr
            data[k].extend(res[k])
    return data, times


## 学習用画像をダウンロード

In [3]:
# Load the training images as a float32 stack; axis 2 indexes the images
# (see the shape report printed below).
S = getimages().astype(np.float32)
print("%d images, each is %dx%d." % (S.shape[2], S.shape[0], S.shape[1]))

5 images, each is 256x256.
0.85105884


## 学習用画像の前処理

In [4]:
# No low-pass/high-pass split for training: the low-pass part is all zeros and
# the "high-pass" part is 2*S with each image's mean removed.
# NOTE(review): the reason for the factor 2 is not stated here - confirm.
Sl = np.zeros_like(S)
# Per-image mean over the two spatial axes (axis 2 indexes the images).
Smean = np.mean(S*2, axis=(0, 1))
Sh = S*2 - Smean

## 学習用辞書、係数の初期化

In [5]:
# TODO: zero-pad the dictionary explicitly instead of relying on the implicit
# zero-padding done by rfftn(..., s=cri.Nv).
D = np.random.randn(12, 12, 256)

# Index/shape bookkeeping required by the convolutional dictionary-learning code.
cri = cnvrep.CSC_ConvRepIndexing(D, S)
# shpS(N0,  N1, ...,  C,   K,   1) C is the number of channels in S
# shpD(N0,  N1, ...,  1,   1,   M) K is the number of signals in S
# shpX(N0,  N1, ...,  C,   K,   M) M is the number of filters in D
# This is the *training* index object, so print it under its own name.
print("cri"+"\n"+str(cri))
Dr0 = np.asarray(D.reshape(cri.shpD), dtype=S.dtype)
Slr = np.asarray(Sl.reshape(cri.shpS), dtype=S.dtype)
Shr = np.asarray(Sh.reshape(cri.shpS), dtype=S.dtype)

# Normalise the initial dictionary filters over the spatial (and channel) axes.
Dr0 = cnvrep.normalise(Dr0 , cri.dimN + cri.dimC)
print(Dr0.shape)

cri_test
{'C': 1,
 'Cd': 1,
 'K': 5,
 'M': 256,
 'N': 65536,
 'Nv': (256, 256),
 'axisC': 2,
 'axisK': 3,
 'axisM': 4,
 'axisN': (0, 1),
 'dimC': 0,
 'dimCd': 0,
 'dimK': 1,
 'dimN': 2,
 'shpD': (12, 12, 1, 1, 256),
 'shpS': (256, 256, 1, 5, 1),
 'shpX': (256, 256, 1, 5, 256)}
(12, 12, 1, 1, 256)


## テスト用画像

In [6]:
# Load the test images (read from the local directory, pth='./').

# scaled=True: pixel values are normalised (sporco ExampleImages option).
# zoom: resampling factor applied to the image.
exim1 = util.ExampleImages(scaled=True, zoom=0.5, pth='./')
S1_test = exim1.image('couple.tiff') #(256, 256)
#print(S1_test.shape)
exim2 = util.ExampleImages(scaled=True, zoom=1, pth='./')
S2_test = exim2.image('LENNA.bmp') #(256, 256)
# Stack the two images into one array.
S_test = np.dstack((S1_test, S2_test)) #(x, y, number of images)

# Index/shape bookkeeping required by the convolutional dictionary-learning code.
cri_test = cnvrep.CSC_ConvRepIndexing(D, S_test)
# shpS(N0,  N1, ...,  C,   K,   1) C is the number of channels in S
# shpD(N0,  N1, ...,  1,   1,   M) K is the number of signals in S
# shpX(N0,  N1, ...,  C,   K,   M) M is the number of filters in D
print("cri_test"+"\n"+str(cri_test))

# Tikhonov filtering: split each test image into low-pass and high-pass parts
# (presumably lambda = 5 and 16 samples of padding - see util.tikhonov_filter).
Sl_test, Sh_test = util.tikhonov_filter(S_test, 5, 16) # low-pass, high-pass
Slr_test = np.asarray(Sl_test.reshape(cri_test.shpS), dtype=S_test.dtype)
Shr_test = np.asarray(Sh_test.reshape(cri_test.shpS), dtype=S_test.dtype)

cri_test
{'C': 1,
 'Cd': 1,
 'K': 2,
 'M': 256,
 'N': 65536,
 'Nv': (256, 256),
 'axisC': 2,
 'axisK': 3,
 'axisM': 4,
 'axisN': (0, 1),
 'dimC': 0,
 'dimCd': 0,
 'dimK': 1,
 'dimN': 2,
 'shpD': (12, 12, 1, 1, 256),
 'shpS': (256, 256, 1, 2, 1),
 'shpX': (256, 256, 1, 2, 256)}


In [7]:
outdir = './no_low-pass'
# -------- choose the experiment to run --------
# prefix = 'nakashizuka_solve' # method from Prof. Nakashizuka's paper
prefix = 'mysolve' # proposed method
# prefix = 'sporcosolve' # method implemented in the SPORCO library (by B. Wohlberg)
# ----------------------------------------------
if prefix == 'nakashizuka_solve':
    data, times = test_nakashizuka_solve(cri, Dr0, Shr, cri_test, Slr_test, Shr_test, outdir=outdir)
elif prefix == 'mysolve':
    data, times = test_mysolve(cri, Dr0, Shr, cri_test, Slr_test, Shr_test, outdir=outdir)
elif prefix == 'sporcosolve':
    data, times = test_sporcosolve(cri, Dr0, Shr, cri_test, Slr_test, Shr_test, outdir=outdir)
else:
    # Fail loudly: otherwise `data`/`times` are left undefined, or silently
    # stale from an earlier cell run.
    raise ValueError("unknown prefix: %r" % prefix)

c = 0.00171715
0.0001
(slice(0, 12, None), slice(0, 12, None), slice(0, 1, None), slice(0, 1, None), slice(0, 256, None))
sigma = 0.05823601
alpha:  2.0
(256, 129, 1, 1, 256)
(256, 256, 1, 1, 256)
(256, 256, 1, 1, 256)
(12, 12, 1, 1, 256)
0.9999999
sigma = 0.00010000
alpha:  1.0
(256, 129, 1, 1, 256)
(256, 256, 1, 1, 256)
(256, 256, 1, 1, 256)
(12, 12, 1, 1, 256)
0.99999964
c = 0.20356453
0.011854785793677214
0.0024132138399370804
0.0004912447292274066
9.999999999999999e-05
(slice(0, 12, None), slice(0, 12, None), slice(0, 1, None), slice(0, 1, None), slice(0, 256, None))
sigma = 0.05823601
alpha:  2.0
(256, 129, 1, 1, 256)
(256, 256, 1, 1, 256)
(256, 256, 1, 1, 256)
(12, 12, 1, 1, 256)
0.9999999
sigma = 0.01185479
alpha:  4.0
(256, 129, 1, 1, 256)
(256, 256, 1, 1, 256)
(256, 256, 1, 1, 256)
(12, 12, 1, 1, 256)
1.0000001
sigma = 0.00241321
alpha:  2.0
(256, 129, 1, 1, 256)
(256, 256, 1, 1, 256)
(256, 256, 1, 1, 256)
(12, 12, 1, 1, 256)
1.0
sigma = 0.00049124
alpha:  1.0
(256, 129, 1, 1

test dictionary (lambda = 3e-3)
Itn   Fnc       DFid      Regℓ1     r         s         ρ       
----------------------------------------------------------------
   0  5.44e+01  8.05e+00  1.55e+04  6.05e-01  9.52e-01  1.15e+00
   1  4.24e+01  8.30e+00  1.14e+04  2.82e-01  7.30e-01  1.15e+00
   2  4.32e+01  4.28e+00  1.30e+04  2.36e-01  4.43e-01  7.11e-01
   3  4.14e+01  2.99e+00  1.28e+04  2.13e-01  3.04e-01  5.15e-01
   4  3.72e+01  2.31e+00  1.16e+04  1.89e-01  2.25e-01  4.29e-01
   5  3.35e+01  2.06e+00  1.05e+04  1.61e-01  1.74e-01  3.91e-01
   6  2.99e+01  1.94e+00  9.33e+03  1.28e-01  1.44e-01  3.91e-01
   7  2.75e+01  1.86e+00  8.53e+03  1.03e-01  1.25e-01  3.91e-01
   8  2.60e+01  1.80e+00  8.07e+03  9.14e-02  1.09e-01  3.52e-01
   9  2.51e+01  1.75e+00  7.77e+03  8.14e-02  9.50e-02  3.20e-01
  10  2.39e+01  1.71e+00  7.38e+03  6.78e-02  8.36e-02  3.20e-01
  11  2.31e+01  1.68e+00  7.15e+03  6.21e-02  7.48e-02  2.86e-01
  12  2.25e+01  1.66e+00  6.94e+03  5.70e-02  6.68e-02  2.

  38  4.66e+01  8.23e+00  3.84e+03  1.04e-02  1.14e-02  5.71e-01
  39  4.64e+01  8.22e+00  3.82e+03  1.00e-02  1.10e-02  5.71e-01
  40  4.62e+01  8.22e+00  3.80e+03  9.68e-03  1.06e-02  5.71e-01
  41  4.60e+01  8.22e+00  3.78e+03  9.36e-03  1.03e-02  5.71e-01
  42  4.58e+01  8.21e+00  3.76e+03  9.05e-03  9.94e-03  5.71e-01
  43  4.57e+01  8.21e+00  3.75e+03  8.75e-03  9.62e-03  5.71e-01
  44  4.55e+01  8.21e+00  3.73e+03  8.47e-03  9.32e-03  5.71e-01
  45  4.54e+01  8.21e+00  3.72e+03  8.20e-03  9.02e-03  5.71e-01
  46  4.52e+01  8.20e+00  3.70e+03  7.95e-03  8.74e-03  5.71e-01
  47  4.51e+01  8.20e+00  3.69e+03  7.71e-03  8.48e-03  5.71e-01
  48  4.50e+01  8.20e+00  3.68e+03  7.48e-03  8.24e-03  5.71e-01
  49  4.49e+01  8.20e+00  3.67e+03  7.26e-03  8.00e-03  5.71e-01
  50  4.47e+01  8.19e+00  3.65e+03  7.05e-03  7.77e-03  5.71e-01
  51  4.46e+01  8.19e+00  3.64e+03  6.85e-03  7.55e-03  5.71e-01
  52  4.45e+01  8.19e+00  3.63e+03  6.66e-03  7.32e-03  5.71e-01
  53  4.44e+01  8.19e+00 

  15  2.66e+02  9.89e+01  1.67e+03  4.91e-02  2.87e-02  7.30e+00
  16  2.60e+02  9.85e+01  1.61e+03  4.34e-02  2.56e-02  7.30e+00
  17  2.54e+02  9.82e+01  1.56e+03  3.87e-02  2.31e-02  7.30e+00
  18  2.49e+02  9.79e+01  1.51e+03  3.48e-02  2.11e-02  7.30e+00
  19  2.45e+02  9.77e+01  1.48e+03  3.32e-02  1.94e-02  6.64e+00
  20  2.42e+02  9.76e+01  1.44e+03  3.03e-02  1.80e-02  6.64e+00
  21  2.39e+02  9.74e+01  1.41e+03  2.78e-02  1.68e-02  6.64e+00
  22  2.36e+02  9.73e+01  1.39e+03  2.70e-02  1.57e-02  6.04e+00
  23  2.34e+02  9.71e+01  1.37e+03  2.51e-02  1.46e-02  6.04e+00
  24  2.32e+02  9.70e+01  1.35e+03  2.34e-02  1.37e-02  6.04e+00
  25  2.31e+02  9.69e+01  1.34e+03  2.20e-02  1.28e-02  6.04e+00
  26  2.29e+02  9.68e+01  1.32e+03  2.06e-02  1.21e-02  6.04e+00
  27  2.27e+02  9.66e+01  1.31e+03  1.95e-02  1.14e-02  6.04e+00
  28  2.26e+02  9.65e+01  1.29e+03  1.84e-02  1.09e-02  6.04e+00
  29  2.24e+02  9.65e+01  1.27e+03  1.75e-02  1.04e-02  6.04e+00
  30  2.22e+02  9.64e+01 

  56  3.23e+02  1.94e+02  4.30e+02  1.04e-02  2.20e-03  1.17e+01
  57  3.23e+02  1.94e+02  4.27e+02  1.02e-02  2.14e-03  1.17e+01
  58  3.22e+02  1.94e+02  4.25e+02  1.00e-02  2.08e-03  1.17e+01
  59  3.21e+02  1.94e+02  4.23e+02  9.87e-03  2.03e-03  1.17e+01
  60  3.21e+02  1.94e+02  4.21e+02  9.71e-03  1.98e-03  1.17e+01
  61  3.20e+02  1.94e+02  4.19e+02  9.55e-03  1.94e-03  1.17e+01
  62  3.20e+02  1.94e+02  4.17e+02  9.40e-03  1.90e-03  1.17e+01
  63  3.19e+02  1.94e+02  4.16e+02  9.27e-03  1.85e-03  1.17e+01
  64  3.19e+02  1.94e+02  4.14e+02  9.13e-03  1.81e-03  1.17e+01
  65  3.18e+02  1.94e+02  4.12e+02  9.00e-03  1.77e-03  1.17e+01
  66  3.18e+02  1.94e+02  4.10e+02  8.87e-03  1.74e-03  1.17e+01
  67  3.17e+02  1.94e+02  4.09e+02  8.74e-03  1.70e-03  1.17e+01
  68  3.17e+02  1.95e+02  4.07e+02  8.63e-03  1.66e-03  1.17e+01
  69  3.16e+02  1.95e+02  4.06e+02  8.52e-03  1.63e-03  1.17e+01
  70  3.16e+02  1.95e+02  4.04e+02  8.41e-03  1.60e-03  1.17e+01
  71  3.15e+02  1.95e+02 

  19  6.93e+00  2.31e-01  6.70e+03  3.86e-02  4.62e-02  8.56e-02
  20  6.79e+00  2.36e-01  6.56e+03  3.67e-02  4.28e-02  7.81e-02
  21  6.62e+00  2.35e-01  6.39e+03  3.33e-02  3.97e-02  7.81e-02
  22  6.49e+00  2.51e-01  6.24e+03  3.05e-02  3.73e-02  7.81e-02
  23  6.44e+00  2.88e-01  6.15e+03  2.99e-02  3.52e-02  7.06e-02
  24  6.34e+00  2.83e-01  6.06e+03  2.79e-02  3.30e-02  7.06e-02
  25  6.25e+00  2.76e-01  5.97e+03  2.61e-02  3.12e-02  7.06e-02
  26  6.16e+00  2.68e-01  5.89e+03  2.46e-02  2.97e-02  7.06e-02
  27  6.15e+00  3.13e-01  5.84e+03  2.44e-02  2.83e-02  6.41e-02
  28  6.10e+00  3.07e-01  5.80e+03  2.33e-02  2.67e-02  6.41e-02
  29  6.06e+00  3.08e-01  5.75e+03  2.22e-02  2.53e-02  6.41e-02
  30  6.03e+00  3.21e-01  5.71e+03  2.12e-02  2.42e-02  6.41e-02
  31  5.96e+00  3.00e-01  5.66e+03  2.02e-02  2.33e-02  6.41e-02
  32  5.92e+00  3.09e-01  5.61e+03  1.94e-02  2.25e-02  6.41e-02
  33  5.88e+00  3.12e-01  5.57e+03  1.86e-02  2.17e-02  6.41e-02
  34  5.84e+00  3.10e-01 

  37  1.51e+01  9.90e-01  4.71e+03  1.21e-02  1.41e-02  1.87e-01
  38  1.51e+01  9.91e-01  4.69e+03  1.17e-02  1.37e-02  1.87e-01
  39  1.50e+01  9.94e-01  4.67e+03  1.13e-02  1.32e-02  1.87e-01
  40  1.49e+01  9.93e-01  4.64e+03  1.09e-02  1.28e-02  1.87e-01
  41  1.49e+01  9.91e-01  4.62e+03  1.05e-02  1.24e-02  1.87e-01
  42  1.48e+01  9.91e-01  4.60e+03  1.02e-02  1.20e-02  1.87e-01
  43  1.47e+01  9.93e-01  4.58e+03  9.84e-03  1.16e-02  1.87e-01
  44  1.47e+01  9.91e-01  4.57e+03  9.53e-03  1.13e-02  1.87e-01
  45  1.46e+01  9.93e-01  4.55e+03  9.24e-03  1.09e-02  1.87e-01
  46  1.46e+01  9.90e-01  4.53e+03  8.96e-03  1.06e-02  1.87e-01
  47  1.45e+01  9.91e-01  4.52e+03  8.70e-03  1.03e-02  1.87e-01
  48  1.45e+01  9.91e-01  4.50e+03  8.45e-03  1.00e-02  1.87e-01
  49  1.45e+01  9.92e-01  4.49e+03  8.21e-03  9.74e-03  1.87e-01
  50  1.44e+01  9.94e-01  4.48e+03  8.33e-03  9.47e-03  1.70e-01
  51  1.44e+01  9.96e-01  4.47e+03  8.12e-03  9.17e-03  1.70e-01
  52  1.44e+01  9.94e-01 

   3  2.00e+02  3.70e+01  5.43e+03  1.71e-01  1.89e-01  2.79e+00
   4  1.89e+02  3.36e+01  5.18e+03  1.37e-01  1.44e-01  2.40e+00
   5  1.82e+02  3.13e+01  5.04e+03  1.16e-01  1.13e-01  2.12e+00
   6  1.73e+02  2.95e+01  4.79e+03  9.64e-02  9.13e-02  2.12e+00
   7  1.63e+02  2.84e+01  4.48e+03  8.17e-02  7.76e-02  2.12e+00
   8  1.53e+02  2.78e+01  4.19e+03  7.05e-02  6.67e-02  2.12e+00
   9  1.46e+02  2.75e+01  3.96e+03  6.18e-02  5.73e-02  2.12e+00
  10  1.41e+02  2.74e+01  3.78e+03  5.47e-02  4.99e-02  2.12e+00
  11  1.36e+02  2.73e+01  3.62e+03  4.88e-02  4.47e-02  2.12e+00
  12  1.32e+02  2.73e+01  3.48e+03  4.39e-02  4.06e-02  2.12e+00
  13  1.29e+02  2.72e+01  3.38e+03  3.98e-02  3.73e-02  2.12e+00
  14  1.26e+02  2.71e+01  3.29e+03  3.63e-02  3.41e-02  2.12e+00
  15  1.24e+02  2.71e+01  3.22e+03  3.34e-02  3.14e-02  2.12e+00
  16  1.22e+02  2.70e+01  3.16e+03  3.09e-02  2.90e-02  2.12e+00
  17  1.20e+02  2.69e+01  3.11e+03  2.88e-02  2.69e-02  2.12e+00
  18  1.18e+02  2.69e+01 

KeyboardInterrupt: 