In [None]:
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import scipy.optimize as spopt
import scipy.fftpack as spfft
import scipy.ndimage as spimg
import cvxpy as cvx
import imageio
import os
from scipy import linalg
from scipy import sparse
import pywt
import scipy.optimize as spopt
import scipy.fftpack as spfft
import scipy.ndimage as spimg
import imageio
from skimage.metrics import peak_signal_noise_ratio
# GPU selection. These variables are assigned before brainpy (and its JAX
# backend) is imported below, presumably so the CUDA runtime sees them when it
# initializes — keep this ordering.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"  # number GPUs by PCI bus ID (same ordering as nvidia-smi)
os.environ["CUDA_VISIBLE_DEVICES"]="7"  # expose only GPU 7 to this process; NOTE(review): host-specific, adjust per machine
import brainpy as bp
import brainpy.math as bm
bm.disable_gpu_memory_preallocation()  # presumably turns off JAX's up-front GPU memory reservation — confirm in BrainPy docs
bm.set_platform('gpu')  # run BrainPy computations on the GPU backend

In [None]:
def imshowgray(im, vmin=None, vmax=None):
    """Draw a 2-D array as a grayscale image on the current matplotlib axes.

    Parameters
    ----------
    im : array_like
        Image data to display.
    vmin, vmax : float, optional
        Intensity limits forwarded to ``plt.imshow``; ``None`` lets matplotlib
        derive them from the data.
    """
    gray_map = plt.get_cmap('gray')
    plt.imshow(im, vmin=vmin, vmax=vmax, cmap=gray_map)

    
def wavMask(dims, scale):
    """Build a display weight map for a stacked wavelet-coefficient image.

    Starts from all ones and repeatedly halves nested top-left regions. The
    region side shrinks by a factor of two each pass, so the coarse
    (top-left) bands are attenuated most. Finer detail bands therefore stand
    out when the map is multiplied into ``|Wim|``.

    Parameters
    ----------
    dims : tuple of int
        ``(rows, cols)`` shape of the stacked coefficient image.
    scale : int
        Number of passes is ``int(min(round(log2(dims))) - scale + 2) // 2``;
        a larger ``scale`` attenuates fewer levels.

    Returns
    -------
    numpy.ndarray
        Float array of shape ``dims``.
    """
    n_rows, n_cols = dims  # unpacking also rejects anything that is not 2-D
    weights = np.ones(dims)
    log_sizes = np.round(np.log2(dims))
    n_passes = int(np.min(log_sizes) - scale + 2) // 2
    for level in range(n_passes):
        row_end = int(np.round(2 ** (log_sizes[0] - level)))
        col_end = int(np.round(2 ** (log_sizes[1] - level)))
        weights[:row_end, :col_end] /= 2
    return weights


def imshowWAV(Wim, scale=1):
    """Show the magnitude of a stacked wavelet image, re-weighted for visibility.

    ``|Wim|`` is multiplied element-wise by ``wavMask(Wim.shape, scale)``.
    The result is drawn in grayscale on the current axes.
    """
    weighted = wavMask(Wim.shape, scale) * np.abs(Wim)
    plt.imshow(weighted, cmap=plt.get_cmap('gray'))

    
def coeffs2img(LL, coeffs):
    """Tile an approximation image and three detail bands into one 2x2 mosaic.

    Layout::

        [ LL | LH ]
        [ HL | HH ]

    Parameters
    ----------
    LL : ndarray
        Top-left tile (approximation band, or an already-stacked coarser image).
    coeffs : sequence of three ndarrays
        Unpacked in order as ``(LH, HL, HH)``. ``dwt2`` passes one per-level
        detail tuple from ``pywt.wavedec2`` here.

    Returns
    -------
    numpy.ndarray
        The assembled mosaic.
    """
    LH, HL, HH = coeffs
    return np.block([[LL, LH],
                     [HL, HH]])


def unstack_coeffs(Wim):
    """Split a stacked wavelet image into its four equal quadrants.

    This is the inverse of ``coeffs2img``. It returns ``(LL, [LH, HL, HH])``,
    where LL is the top-left quadrant, LH the top-right, HL the bottom-left and
    HH the bottom-right. Raises ``ValueError`` when a side cannot be halved
    exactly.
    """
    left, right = np.split(Wim, 2, axis=1)
    top_left, bottom_left = np.split(left, 2, axis=0)
    top_right, bottom_right = np.split(right, 2, axis=0)
    return top_left, [top_right, bottom_left, bottom_right]

    
def img2coeffs(Wim, levels=4):
    """Unpack a stacked wavelet image into pywt's coefficient-list layout.

    Peels off one ring of detail bands per level, finest first, using
    ``unstack_coeffs``. Returns ``[LL, details_coarsest, ..., details_finest]``,
    the ordering consumed by ``pywt.waverec2`` in ``idwt2``. At least one level
    is always unpacked, even when ``levels < 1``.
    """
    approx = Wim
    rings = []
    for _ in range(max(levels, 1)):
        approx, bands = unstack_coeffs(approx)
        rings.append(bands)
    rings.reverse()  # collected finest -> coarsest; pywt wants coarsest first
    return [approx] + rings
    
    
def dwt2(im, wavelet='db4', level=4):
    """Forward 2-D discrete wavelet transform, packed into a single image.

    Runs ``pywt.wavedec2`` with periodization (``mode='per'``). It then tiles
    the coefficients with ``coeffs2img``, starting from the coarsest
    approximation and wrapping each successively finer set of detail bands
    around it. The result has the familiar nested-quadrant layout.

    Parameters
    ----------
    im : 2-D array_like
        Input image.
    wavelet : str, optional
        Wavelet name passed to pywt. Defaults to ``'db4'``, the value that was
        previously hard-coded.
    level : int, optional
        Decomposition depth. Defaults to 4, the previously hard-coded value.
        ``idwt2`` must be called with the same ``wavelet`` and ``level`` to
        invert the packing.

    Returns
    -------
    numpy.ndarray
        Stacked coefficient image.

    Notes
    -----
    NOTE(review): the tiling only works when every level's bands line up when
    stacked (e.g. power-of-two image sides). Otherwise the stacking in
    ``coeffs2img`` raises. Confirm before using odd or non-dyadic sizes.
    """
    coeffs = pywt.wavedec2(im, wavelet=wavelet, mode='per', level=level)
    Wim = coeffs[0]  # coarsest approximation band
    # coeffs[1:] goes from the coarsest detail level to the finest, so each pass
    # wraps the current mosaic in the next-larger ring of detail bands.
    for detail_bands in coeffs[1:]:
        Wim = coeffs2img(Wim, detail_bands)
    return Wim


def idwt2(Wim, wavelet='db4', level=4):
    """Inverse of ``dwt2``: unpack a stacked coefficient image and reconstruct.

    Parameters
    ----------
    Wim : 2-D ndarray
        Stacked coefficient image in the layout produced by ``dwt2``.
    wavelet : str, optional
        Wavelet name. Defaults to ``'db4'``, the value that was previously
        hard-coded. Must match the value used for the forward transform.
    level : int, optional
        Number of levels to unpack. Defaults to 4, the previously hard-coded
        value. Must match the forward transform.

    Returns
    -------
    numpy.ndarray
        Reconstructed image from ``pywt.waverec2`` with periodization
        (``mode='per'``).
    """
    coeffs = img2coeffs(Wim, levels=level)
    return pywt.waverec2(coeffs, wavelet=wavelet, mode='per')

In [None]:
# Load the test image as a float grayscale array.
im = imageio.imread('brain.bmp', mode='F')

# Sparsify: keep only the largest-magnitude fraction `f` of wavelet coefficients.
Wim = dwt2(im)
f = 0.1
m = np.sort(np.abs(Wim), axis=None)[::-1]  # flattened magnitudes, largest first
ndx = int(f * m.size)
thr = m[ndx]
Wim_thr = Wim * (np.abs(Wim) > thr)

# Reference reconstruction from the thresholded coefficients (plotted last).
im3 = idwt2(Wim_thr)

# Split into the non-negative and non-positive parts; each is recovered separately.
Wim_plus = np.maximum(Wim_thr, 0)
Wim_minus = np.minimum(Wim_thr, 0)

# Scale each part into [0, 1]. Dividing by the (negative) minimum makes the
# negative part non-negative as well.
X_plus_normalized = Wim_plus / Wim_plus.max()
X_minus_normalized = Wim_minus / Wim_minus.min()

In [None]:
# Compressed-sensing setup for the positive part x = X_plus_normalized (column vector):
#   measurements     proj = phi @ x                 (M = int(0.2 * N) rows)
#   input            b    = phi.T @ proj / factor
#   lateral weights  w    = phi.T @ phi with the diagonal zeroed
X_plus_reshape  = X_plus_normalized.ravel()[:, np.newaxis]
# Variant for the negative-part run (X_minus_normalized = Wim_minus / np.min(Wim_minus) is already non-negative).
# X_plus_reshape  = abs(X_minus_normalized.ravel()[:, np.newaxis])
rng = np.random.RandomState(3)  # fixed seed -> reproducible measurement matrix
# rng = np.random.RandomState(30)
matrix_size = X_plus_reshape.shape[0]  # N: number of wavelet coefficients
A = rng.randn(int(matrix_size* 0.2), matrix_size)  # random design: standard-normal entries, M = int(0.2 * N) rows
A_norm = np.linalg.norm(A,ord=2,axis = 0,keepdims =True)  # 2-norm of each column, shape (1, N)
phi = A /A_norm  # measurement matrix with unit-norm columns

proj   =  phi @ X_plus_reshape  # measurement vector, shape (M, 1)
factor = 1.  # extra scaling applied to b; 1.0 leaves it unchanged

b = bm.array((phi.T @ proj.flatten())/ factor)
# NOTE(review): phi.T @ phi is a dense N x N matrix built in NumPy before
# conversion; it dominates memory for large images.
w = bm.array(phi.T @ phi)
# Remove self-connections: with unit-norm columns the diagonal would be all ones.
# Assumes bm.array returns a mutable BrainPy array that supports item assignment.
w[bm.diag_indices_from(w)] = 0

In [None]:
# Project-local module providing the network model (source not visible here).
# NOTE(review): per notebook convention this import belongs in the top import cell.
import Neuron_models as nm
# net_double_rk2 = nm.SLCA_rk2(w.shape[0], w, b, 0.0001)
# Presumably a spiking-LCA network of size N driven by input b with lateral
# weights w (both from the previous cell). The last argument (0.001 here,
# 0.0001 in the commented variant) is model-specific; presumably a threshold or
# sparsity parameter — TODO confirm in Neuron_models.
net_double_rk2 = nm.SLCA_rk2(w.shape[0], w, b, 0.001)
total_period = 200  # simulated duration for runner.run; also divides spike counts when decoding rates
# Record the `spike` variable of the network's `N` member. With dt = 0.1 the run
# is total_period / dt = 2000 integration steps (time unit presumably ms — confirm).
runner = bp.DSRunner(net_double_rk2,monitors=['N.spike'], dt = 0.1)
runner.run(total_period)

In [None]:
# Read back the recorded spikes. NOTE(review): the record is presumably shaped
# (time_steps, n_neurons); the axis=0 sum below relies on time being axis 0 — confirm.
size_num = runner.mon['N.spike'].shape[0]  # length of the record's first axis (not used in the visible cells)
spike_calculate = runner.mon['N.spike']
# Spike count per neuron divided by the simulated duration gives a mean firing
# rate. It plays the role of the normalized coefficients X_plus_normalized.
lca = np.sum(spike_calculate, axis=0)/total_period
# Undo the normalization X_plus_normalized = Wim_plus / np.max(Wim_plus) and
# restore the 2-D coefficient-image layout.
X_plus_recov = (np.max(Wim_plus)*lca.reshape(Wim_plus.shape[0],Wim_plus.shape[1]))
# Variant for the negative-part run.
# NOTE(review): np.max(Wim_minus) is 0, because Wim_minus = np.minimum(Wim_thr, 0)
# is <= 0 and contains zeros. This line would therefore yield all zeros. The
# inverse of X_minus_normalized = Wim_minus / np.min(Wim_minus) needs
# np.min(Wim_minus) (or -np.min(Wim_minus) for magnitudes) — fix before enabling.
# X_plus_recov = (np.max(Wim_minus)*lca.reshape(Wim_plus.shape[0],Wim_plus.shape[1]))

In [None]:
# Re-assemble a signed coefficient map from two parts:
#   - the SNN-recovered positive part (X_plus_recov);
#   - the negative part taken directly from the thresholded ground truth
#     (Wim_minus is not estimated by the network in this run).
# NOTE(review): the PSNR printed for im_recov therefore measures recovery of the
# positive coefficients only.
Wim_recov = X_plus_recov + Wim_minus
im_recov = idwt2(Wim_recov)

In [None]:
# Persist the recovered positive coefficients. For the negative-part run, use
# the commented line instead. It still saves X_plus_recov, i.e. the same
# variable holds whichever part was just recovered.
np.save('W_plus.npy',  X_plus_recov)
# np.save('W_minus.npy', X_plus_recov)

In [None]:
# Optional combination step: load the two saved runs and build the signed
# coefficient map. The subtraction implies W_minus.npy is expected to hold
# non-negative magnitudes of the negative coefficients.
# NOTE(review): enabling this rebinds Wim_plus / Wim_minus, which previously
# held the thresholded ground-truth parts, to the recovered arrays. Distinct
# names would avoid hidden-state confusion on partial re-runs.
# Wim_plus  = np.load('W_plus.npy')
# Wim_minus = np.load('W_minus.npy')
# Wim_recov = Wim_plus - Wim_minus
# im_recov = idwt2(Wim_recov)

In [None]:
# Side-by-side comparison: original image, its wavelet coefficients, and the
# SNN-based reconstruction. The size is set on this figure only; assigning
# plt.rcParams['figure.figsize'] would leak into every later figure in the session.
plt.figure(figsize=(16, 16))

plt.subplot(1, 3, 1)
imshowgray(np.abs(im))
plt.title('Original')

plt.subplot(1, 3, 2)
imshowWAV(Wim)
plt.title('DWT')

plt.subplot(1, 3, 3)
imshowgray(np.abs(im_recov))
plt.title('Reconstruction')
plt.show()

# PSNR of the SNN reconstruction against the original, using the original's dynamic range.
print('PSNR:', peak_signal_noise_ratio(im, im_recov, data_range=im.max() - im.min()))

In [None]:
# Reference comparison: same layout, but the right panel is im3, the direct
# inverse DWT of the thresholded coefficients (no SNN). It gives an upper
# bound for the reconstruction above. The size is set on this figure only
# instead of mutating the global rcParams.
plt.figure(figsize=(16, 16))

plt.subplot(1, 3, 1)
imshowgray(np.abs(im))
plt.title('Original')

plt.subplot(1, 3, 2)
imshowWAV(Wim)
plt.title('DWT')

plt.subplot(1, 3, 3)
imshowgray(np.abs(im3))
plt.title('Reconstruction')
plt.show()

# PSNR of the thresholded-DWT reconstruction against the original.
print('PSNR:', peak_signal_noise_ratio(im, im3, data_range=im.max() - im.min()))