# Load packages
需要安装 cvxpy (requires cvxpy): https://www.cvxpy.org/install/index.html

In [1]:
import math
import numpy as np
import matplotlib.pyplot as plt
import skimage


import cvxpy as cp
import numpy as np

# 一些函数

In [2]:
def PSNR(ground_truth, predict):
    """
    Compute the PSNR (in dB) of a single image against a reference.

    Both images are min-max normalised to [0, 1] independently before the
    comparison, so global offset/scale differences are ignored. A constant
    image (max == min) is mapped to all zeros instead of producing NaN from
    a 0/0 division.

    Parameters
    ----------
    ground_truth : array_like
        Reference image.
    predict : array_like
        Predicted / reconstructed image, same shape as ``ground_truth``.

    Returns
    -------
    float
        PSNR rounded to 2 decimals, or 100 when the normalised images are
        identical (MSE == 0, i.e. PSNR would be infinite).
    """
    def _normalize(img):
        # Min-max scale to [0, 1]; guard the zero-range (constant) case.
        img = np.asarray(img, dtype=float)
        lo, hi = img.min(), img.max()
        if hi == lo:
            return np.zeros_like(img)
        return (img - lo) / (hi - lo)

    ground_truth = _normalize(ground_truth)
    predict = _normalize(predict)
    mse = np.mean((ground_truth - predict) ** 2)
    if mse == 0:
        # No difference at all: PSNR is unbounded, report a sentinel value.
        return 100
    max_pixel = 1.0  # images are normalised to [0, 1]
    psnr = 20 * np.log10(max_pixel / np.sqrt(mse))
    return np.round(psnr, 2)


# Radon (projection) matrix helpers
M_PI = math.pi


def MAXX(x, y):
    """Return the larger of ``x`` and ``y``; ``y`` wins ties and unordered comparisons."""
    if x > y:
        return x
    return y

def radon_matrix(nt, nx, ny):
    """
    Build a dense discrete Radon (parallel-beam projection) matrix.

    Each image pixel is one column and each (angle, detector bin) pair is one
    row. For every pixel and every angle, the pixel's projected position is
    split between the two neighbouring detector bins with linear
    interpolation weights ``w`` and ``1 - w``.

    Parameters
    ----------
    nt : int
        Number of projection angles, ``j * pi / nt`` for ``j in range(nt)``.
    nx, ny : int
        Image width and height in pixels.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(nr * nt, nx * ny)`` where ``nr = 2 * rsize + 1`` is
        the number of detector bins per angle. The rows are flipped upside
        down (``np.flipud``). Column ``m * nx + n`` holds the pixel at grid
        position ``(m, n)``, where ``m`` indexes ``ny`` and ``n`` indexes ``nx``.

    Notes
    -----
    A pixel whose projection for a given angle falls outside ``[0, nr - 1)``
    gets no entry for that angle. This happens, for example, for the corner
    pixel at angles near 45 degrees.
    """
    x_origin = int(max(0, math.floor(nx / 2)))
    y_origin = int(max(0, math.floor(ny / 2)))
    dr = 1  # detector bin spacing
    dx = 1  # pixel spacing
    rsize = math.floor(math.sqrt(float(nx * nx + ny * ny) * dx) / (2 * dr)) + 1  # from zhang xiaoqun
    nr = 2 * rsize + 1

    # Pixel coordinates: start at -origin - 0.5 and advance by dx per pixel.
    x_table = (-x_origin - 0.5) * dx + dx * np.arange(nx)
    y_table = (-y_origin - 0.5) * dx + dx * np.arange(ny)

    dtheta = math.pi / nt
    # cos/sin depend only on the angle, so compute them once, not per pixel.
    trig = [(math.cos(j * dtheta), math.sin(j * dtheta)) for j in range(nt)]

    # Write each entry straight into its own column. The former CSC-style
    # buffer was read back assuming exactly 2*nt entries per pixel, so every
    # skipped (out-of-range) entry shifted all subsequent columns.
    A = np.zeros((nr * nt, nx * ny))
    for m in range(ny):
        for n in range(nx):
            col = m * nx + n
            for j, (cosine, sine) in enumerate(trig):
                x_cos = y_table[m] * cosine + rsize * dr
                y_sin = x_table[n] * sine
                rldx = (x_cos + y_sin) / dr
                r_low = math.floor(rldx)
                if 0 <= r_low < nr - 1:
                    pixel_low = 1 - rldx + r_low
                    A[nr * j + r_low, col] = pixel_low
                    A[nr * j + r_low + 1, col] = 1 - pixel_low
    return np.flipud(A)

In [3]:
import os
# List the available test images (a bare `ls mnist/` raised a SyntaxError here).
sorted(os.listdir("mnist"))

# 生成测试数据-图像和sinogram

In [None]:
import cv2,os
height, width = 32, 32
lenth = height * width
angleNum = 30

i = 6
image_name = str(i) + '.jpg'
image_path = os.path.join("mnist", image_name)
raw_image = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
# cv2.imread returns None (it does not raise) when the file is missing or unreadable.
if raw_image is None:
    raise FileNotFoundError(f"could not read image: {image_path}")
z_true = raw_image / 255
# cv2.resize expects dsize as (width, height), not (height, width).
image = cv2.resize(z_true, (width, height))

height,width = image.shape
plt.imshow(image, cmap=plt.cm.Greys_r)
plt.colorbar()
plt.title(f"{height,width}")

In [None]:
#----------------------
# Tunable parameters:
#   angleNum            number of projection angles (set in the data cell)
#   signal_noise_ratio  noise std as a fraction of the clean sinogram's peak
#                       value (despite the name, larger means noisier)
#-----------------
signal_noise_ratio = 0.1
SEED = 0  # fixed seed so the noisy sinogram (and every PSNR below) is reproducible
np.random.seed(SEED)

A = radon_matrix(angleNum,width,height)
# Flatten column-major; reconstructions are reshaped back with order='F' too.
y_clear = A @ image.flatten(order='F')
y_clear_2d = y_clear.reshape(-1, angleNum, order='F')

# add i.i.d. Gaussian noise, then clip negatives. There is no upper clip:
# the former hard-coded 100 was arbitrary and would truncate larger sinograms.
noise_std = y_clear.max() * signal_noise_ratio
y_1d = y_clear + noise_std * np.random.randn(len(y_clear))
y_1d = np.clip(y_1d, 0, None)
y = y_1d.reshape(-1, angleNum, order='F')

# The sinogram is much taller than it is wide, so let imshow stretch it.
fig, ax = plt.subplots(1, 2)
ax[0].imshow(y_clear_2d, aspect='auto')
ax[0].set_title('clear data')

ax[1].imshow(y, aspect='auto')
ax[1].set_title('noisy data')

# 优化  

## 目标函数

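$$
\begin{aligned}
x^* &= \arg\min_{x} \; \frac{1}{2\sigma^2}\,\|Ax-b\|_2^2 + \lambda\,\|x\|_{TV} \\
\text{s.t.} \quad & 0 \leq x \leq 1
\end{aligned}
$$

Here $\sigma$ is `noise_std`, $b$ is the noisy sinogram `y_1d`, and $\lambda$ is the cvxpy parameter `gamma`.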

In [None]:
import time
# Problem data.
n, m = A.shape
b = y_1d
gamma = cp.Parameter(nonneg=True)  # TV weight (lambda in the objective above)

# Construct the problem.
x = cp.Variable(m)
# x is the image flattened column-major (order='F', as in the data cell).
# Reshape it back so the TV penalty acts on the 2-D image. cp.tv on the flat
# vector would only see differences along one axis and would add spurious
# jumps where one image column wraps into the next.
x_img = cp.reshape(x, (height, width), order='F')
# Data term scaled by 1 / (2 * sigma^2), matching the objective above.
f = cp.sum_squares(A @ x - b) / (2 * noise_std**2) + cp.multiply(gamma, cp.tv(x_img))

constraints = [0.0 <= x, x <= 1.0]
objective = cp.Minimize(f)
p = cp.Problem(objective, constraints)


def get_x(gamma_value):
    """
    Solve the TV-regularised problem ``p`` for one TV weight.

    Sets the shared ``gamma`` parameter, solves with SCS, prints the
    wall-clock solve time in seconds and returns ``x.value`` (the flattened
    reconstruction). NOTE(review): this may be None if the solver fails; confirm.
    """
    start = time.time()
    gamma.value = gamma_value
    p.solve(solver=cp.SCS)
    end = time.time()
    print(end - start)
    return x.value

## 并行优化

In [None]:
gammas = np.linspace(1e-3,10,12).round(3).tolist()
print(gammas)

# Parallel computation.
from multiprocessing import Pool
# Workers call get_x, which uses the problem `p` defined in this notebook.
# NOTE(review): this relies on the 'fork' start method; under 'spawn'
# (Windows/macOS default) get_x may not be picklable; confirm.
# The context manager guarantees the worker processes are shut down.
with Pool(processes=6) as pool:
    # Checked: pool.map returns results in the same order as `gammas`.
    par_x = pool.map(get_x, gammas)

# 展示结果

In [None]:
# Un-flatten each solution column-major, matching how `image` was flattened.
rec_imgs = [img.reshape(height, width, order='F') for img in par_x]
# PSNR(ground_truth, predict): the reference image comes first.
psnrs = [PSNR(image, rec_img) for rec_img in rec_imgs]

best = int(np.argmax(psnrs))
print(gammas[best], psnrs[best])

fig, ax = plt.subplots(figsize=(20, 5))
ax.plot(gammas, psnrs, 'ro-')
ax.set_xlabel('lambda')
ax.set_ylabel('PSNR')
ax.grid()

In [None]:
n_col = 5
# Round up so every lambda gets a panel (int(12 / 5) == 2 rows showed only 10 of 12).
n_row = math.ceil(len(gammas) / n_col)
# squeeze=False keeps a 2-D axes array even for a single row.
_, axs = plt.subplots(n_row, n_col, figsize=(n_col*3, n_row*4), squeeze=False)
axs = axs.flatten()
for img, ax, lambd, psnr in zip(rec_imgs, axs, gammas, psnrs):
    ax.imshow(img, cmap=plt.cm.Greys_r)
    ax.set_title(f"lambda={lambd} \nPSNR={psnr}")
# Hide the panels left over in the last row.
for ax in axs[len(rec_imgs):]:
    ax.axis('off')