### GPU setting

In [None]:
# Select the compute device: the CUDA GPU when one is visible, otherwise the CPU.
import shutil
import subprocess

import torch  # this cell runs before the imports cell, so it imports what it needs

cuda_available = torch.cuda.is_available()
print(cuda_available)
device = torch.device('cuda' if cuda_available else 'cpu')
print('Using device:', device)

if cuda_available:
    print('Device count:', torch.cuda.device_count())
    print('Current device:', torch.cuda.current_device(), torch.cuda.get_device_name(0))

# nvidia-smi only exists where NVIDIA drivers are installed; skip it elsewhere
# instead of letting the shell command fail.
if shutil.which('nvidia-smi') is not None:
    subprocess.run(['nvidia-smi'], check=False)

In [None]:
import math
import random
import time

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torch.nn.init as init
import torchvision
from scipy.linalg import orth
from skimage import io
from torch import nn, optim
from torch.autograd import Variable
from torch.distributions import normal
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, models


# Basic Linear Problem

## Condition number

In [None]:
# Condition number of A_eps = [[1, 1], [1, 1 + eps]] as eps shrinks: the two
# columns approach each other, so the matrix approaches singularity.
EPSILONS = [1, 0.1, 0.01, 0.001, 0.0001]
cond = []
for e in EPSILONS:
    A_epsilon = np.array([[1, 1], [1, 1 + e]])
    condition = np.linalg.cond(A_epsilon)
    cond.append(condition)
    print('epsilon: ', e, ' condition number: ', condition)

## Tikhonov Regularization

### Optimal parameter

In [None]:
# training
# Oracle ("optimal") Tikhonov parameter: for each eps, sweep lambda over a grid and
# keep the value whose reconstruction is closest to the true X. This uses the ground
# truth, so it is a best-case reference rather than a practical selection rule.
np.random.seed(1)
N = 10000 # number of data size
SCALE = 1e6
lambdas = np.arange(1, 100001) / SCALE  # lambda grid 1/SCALE, 2/SCALE, ..., 0.1
LAMBDA = []
X = np.random.randn(2,N)
for k in range(5):
    epsilon = 10**(-k)
    A_epsilon = np.array([[1,1],[1,1+epsilon]]) # (2,2)
    eta = np.random.randn(2,N) # (2,N)
    Y = np.dot(A_epsilon,X) + 0.01*eta # (2,N)
    gram = np.dot(A_epsilon.T,A_epsilon)  # loop-invariant A^T A

    MSE = []
    for a in lambdas:
        # Tikhonov solution (A^T A + a I)^(-1) A^T Y
        X_tik = np.linalg.inv(gram + a * np.eye(2)).dot(A_epsilon.T).dot(Y)
        # per-sample squared reconstruction error
        MSE.append(np.linalg.norm(X_tik-X)**2/X.shape[1])

    best = int(np.argmin(MSE))  # first minimum, like list.index(min(...))
    Lambda = float(lambdas[best])
    LAMBDA.append(Lambda)
    plt.plot(lambdas,MSE,'-',label='epsilon = '+ str(epsilon))
    plt.legend()
    # the plotted quantity is the *squared* norm divided by n
    plt.ylabel(r'$MSE=\frac{1}{n}\Vert X_{tik}-X \Vert^{2}$')
    plt.xlabel(r'$\lambda$')
    plt.title('Optimal')
    plt.show()
    print('epsilon = ', str(epsilon),' MSE:',MSE[best],' lambda:',Lambda ,'\n')

In [2]:
# test: score each training-selected lambda on fresh data (seed 2, N = 2000)
np.random.seed(2)
N = 2000 # number of data size

X = np.random.randn(2, N)
MSE = []
for k, a in zip(range(5), LAMBDA):
    epsilon = 10**(-k)
    A_epsilon = np.array([[1, 1], [1, 1 + epsilon]])
    eta = np.random.randn(2, N)
    Y = A_epsilon.dot(X) + 0.01 * eta

    # Tikhonov estimate (A^T A + a I)^(-1) A^T Y
    regularized = np.dot(A_epsilon.T, A_epsilon) + a * np.eye(2)
    X_tik = np.linalg.inv(regularized).dot(A_epsilon.T).dot(Y)

    # per-sample squared reconstruction error
    error = np.linalg.norm(X_tik - X)**2 / X.shape[1]
    MSE.append(error)
print(MSE)

[0.0006812162305501012, 0.040074347857498874, 0.8146861803659274, 1.03274110047104, 1.0341222194233362]


### L-curve

In [None]:
# curvature
def curvature(x, y, log_scale=False):
    """Return the index of the point of maximum curvature of the curve (x[i], y[i]).

    Used to locate the corner of a discrete L-curve. Derivatives are taken with
    ``np.gradient`` with respect to the sample index, and the curvature is
    |x'' y' - x' y''| / (x'^2 + y'^2)^(3/2).

    Parameters
    ----------
    x, y : sequence of float
        Curve coordinates, same length, at least two points (``np.gradient``).
    log_scale : bool, default False
        If True, use (log10 x, log10 y), i.e. the curve as drawn on log-log axes.
        All values must then be positive.

    Returns
    -------
    int
        Index of the largest curvature. Points where it is undefined (0/0 when the
        tangent has zero length) are ignored.

    Raises
    ------
    ValueError
        If the curvature is undefined at every point.
    """
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)
    if log_scale:
        x = np.log10(x)
        y = np.log10(y)
    x_t = np.gradient(x)
    y_t = np.gradient(y)
    xx_t = np.gradient(x_t)
    yy_t = np.gradient(y_t)
    with np.errstate(divide='ignore', invalid='ignore'):
        curvature_val = np.abs(xx_t * y_t - x_t * yy_t) / (x_t * x_t + y_t * y_t)**1.5
    if np.all(np.isnan(curvature_val)):
        raise ValueError('curvature is undefined at every point')
    # nanargmax skips NaNs; builtin max() can return a leading NaN as the "maximum".
    return int(np.nanargmax(curvature_val))

In [None]:
# training
# Discrete L-curve for the 2x2 problem: for each eps, trace (||A X_tik - Y||, ||X_tik||)
# over a lambda grid and pick lambda at the corner (maximum-curvature point).
np.random.seed(1)
N = 10000 # number of data size
scale = 1e6
lambdas = [a / scale for a in range(1, 100001)]  # 1e-6 ... 0.1
LAMBDA = []
X = np.random.randn(2,N)
for k in range(5):
    epsilon = 10**(-k)
    A_epsilon = np.array([[1,1],[1,1+epsilon]]) # (2,2)
    eta = np.random.randn(2,N) # (2,N)
    Y = np.dot(A_epsilon,X) + 0.01*eta # (2,N)
    gram = np.dot(A_epsilon.T,A_epsilon)  # loop-invariant A^T A
    solution = []
    residual = []

    for a in lambdas:
        # Tikhonov solution
        X_tik = np.dot(np.dot(np.linalg.inv(gram + a * np.eye(2)),A_epsilon.T),Y)
        solution.append(np.linalg.norm(X_tik))
        residual.append(np.linalg.norm(np.dot(A_epsilon,X_tik) - Y))

    # L-curve figure (raw strings: '\e' in a plain string is an invalid escape)
    plt.plot(residual,solution,'.-',label=r'$\epsilon$ = '+ str(epsilon))
    plt.legend()
    plt.ylabel(r'$\log\Vert X_{Tik} \Vert$')
    plt.xlabel(r'$\log\Vert AX_{Tik}-Y^{\eta} \Vert$')
    plt.title('Discrete L-curve for Tikhonov regularization')
    plt.yscale('log')
    plt.xscale('log')
    plt.show()

    # residual[i] / solution[i] belong to lambda = (i + 1) / scale
    Lambda = (curvature(residual,solution)+1)/scale
    LAMBDA.append(Lambda)
    X_tik_Lambda = np.dot(np.dot(np.linalg.inv(gram + Lambda * np.eye(2)),A_epsilon.T),Y)
    MSE = np.linalg.norm(X_tik_Lambda-X)**2 / Y.shape[1]

    print('epsilon: ',epsilon,'lambda: ',Lambda, ' MSE: ',MSE)

In [None]:
# test: evaluate the L-curve lambdas on held-out data (seed 2, N = 2000)
np.random.seed(2)
N = 2000 # number of data size

X = np.random.randn(2, N)
MSE = []
for k, a in enumerate(LAMBDA[:5]):
    epsilon = 10**(-k)
    A_epsilon = np.array([[1, 1], [1, 1 + epsilon]]) # (2,2)
    eta = np.random.randn(2, N) # (2,N)
    Y = np.dot(A_epsilon, X) + 0.01 * eta # (2,N)

    gram = np.dot(A_epsilon.T, A_epsilon)
    X_tik = np.linalg.inv(gram + a * np.eye(2)).dot(A_epsilon.T).dot(Y)
    error = np.linalg.norm(X_tik - X)**2 / X.shape[1]
    MSE.append(error)
print(MSE)

### GCV

In [None]:
# training
# Generalized cross-validation for the 2x2 problem:
#   GCV(lambda) = ||A X_tik - Y||^2 / trace(I - A (A^T A + lambda I)^(-1) A^T)^2,
# minimized over a lambda grid for each eps.
np.random.seed(1)
N = 10000 # number of data size
scale = 1e6
lambdas = [a / scale for a in range(1, 100001)]  # 1e-6 ... 0.1
LAMBDA = []
X = np.random.randn(2,N)
for k in range(5):
    GCV = []
    epsilon = 10**(-k)
    A_epsilon = np.array([[1,1],[1,1+epsilon]]) # (2,2)
    eta = np.random.randn(2,N) # (2,N)
    Y = np.dot(A_epsilon,X) + 0.01*eta # (2,N)
    gram = np.dot(A_epsilon.T,A_epsilon)

    for a in lambdas:
        inv_reg = np.linalg.inv(gram + a * np.eye(2))  # reused by both GCV terms
        X_tik = np.dot(np.dot(inv_reg,A_epsilon.T),Y)
        upper = np.linalg.norm(np.dot(A_epsilon,X_tik)-Y)**2
        lower = np.trace(np.eye(2) - np.dot(np.dot(A_epsilon,inv_reg),A_epsilon.T))**2
        GCV.append(upper/lower)

    # GCV[i] belongs to lambdas[i] = (i + 1) / scale; plotting against i / scale
    # shifted the curve one grid step away from the lambda reported below.
    plt.plot(lambdas,GCV,'-',label=r'$\epsilon$ = '+ str(epsilon))
    plt.legend()
    plt.ylabel(r'$GCV(\lambda)$')
    plt.xlabel(r'$\lambda$')
    plt.title('GCV')
    plt.ylim(0)

    best = int(np.argmin(GCV))
    Lambda = lambdas[best]
    LAMBDA.append(Lambda)
    X_tik_Lambda = np.dot(np.dot(np.linalg.inv(gram + Lambda * np.eye(2)),A_epsilon.T),Y)
    MSE = np.linalg.norm(X_tik_Lambda-X)**2 / Y.shape[1]

    print('epsilon = ', str(epsilon),' GCV:',GCV[best],' lambda:',Lambda ,' MSE: ',MSE,'\n')

In [None]:
# test: score the GCV-selected lambdas on fresh data (seed 2, N = 2000)
np.random.seed(2)
N = 2000 # number of data size
X = np.random.randn(2, N)
MSE = []
for k in range(min(5, len(LAMBDA))):
    a = LAMBDA[k]
    epsilon = 10**(-k)
    A_epsilon = np.array([[1, 1], [1, 1 + epsilon]]) # (2,2)
    eta = np.random.randn(2, N) # (2,N)
    Y = np.dot(A_epsilon, X) + 0.01 * eta # (2,N)

    # Tikhonov solution and its per-sample squared error
    gram_inverse = np.linalg.inv(np.dot(A_epsilon.T, A_epsilon) + a * np.eye(2))
    X_tik = gram_inverse.dot(A_epsilon.T).dot(Y)
    error = np.linalg.norm(X_tik - X)**2 / X.shape[1]
    MSE.append(error)
print(MSE)

## Neural Network

### Forward problem

In [None]:
# torch version (without weight restriction)
# Forward problem: fit a 2 -> 4 -> 2 ReLU network to Y = A_eps X + noise for each
# eps, then report the MSE on a fresh test set (torch seed 2, n = 2000).
EPOCH = 500
LEARNING_RATE = 0.05
N = 10000
MSE = []  # training loss per epoch, EPOCH entries per eps, concatenated

for k in range(5):
    # data generation (seed reset, so X is the same for every eps)
    torch.manual_seed(1)
    X = torch.randn(2,N)
    epsilon = 10**(-k)
    A_epsilon = torch.Tensor([[1,1],[1,1+epsilon]])
    eta = torch.randn(2,N)
    Y = torch.mm(A_epsilon,X) + 0.01*eta

    net = torch.nn.Sequential(
        torch.nn.Linear(2,4),
        torch.nn.ReLU(),
        torch.nn.Linear(4,2)
    )
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(),lr = LEARNING_RATE)

    # training (full batch; samples are columns, nn.Linear expects rows)
    for epoch in range(EPOCH):
        Y_pred = net(X.T)
        loss = criterion(Y_pred,Y.T)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        MSE.append(loss.item())

    # test on fresh data; no_grad avoids building an autograd graph
    torch.manual_seed(2)
    n = 2000
    x = torch.randn(2,n)
    e = torch.randn(2,n)
    y = torch.mm(A_epsilon,x) + 0.01 * e
    with torch.no_grad():
        mse = criterion(net(x.T),y.T).item()
    print('epsilon: ',epsilon, 'MSE: ',mse)

    plt.plot(range(EPOCH),MSE[EPOCH*k:EPOCH*(k+1)],'-',label=r'$\epsilon$ = '+ str(epsilon))
    plt.legend()
    plt.ylabel(r'$MSE=\frac{1}{n}\Vert W_{x}X-Y\Vert ^{2} $')
    plt.xlabel('Number of Epoch')
    plt.title('Forward Problem (Basic Matrix)')
plt.savefig('nn(DPBMnoOutliers).pdf', format='pdf')

In [None]:
# numpy version (with weight restriction)
# Forward problem with a "weight-restricted" network
#     Y_pred = w2^T relu(w0 W X),   w0 = w2 = [[1,0],[-1,0],[0,1],[0,-1]],
# where only the 2x2 matrix W is trained, with hand-written backpropagation.
# Since w2^T relu(w0 v) = relu(v) - relu(-v) = v, the model is exactly W X.
np.random.seed(1)
EPOCH = 1000
# The old backprop never applied the ReLU mask and came out exactly 2x the true
# gradient; with the correct gradient, 0.002 reproduces the old 0.001 step size.
LEARNING_RATE = 0.002
N , D_in , H , D_out = 10000 , 2, 4 , 2
# N is the number of samples, D_in the input dim,
# H the number of nodes in the hidden layer, D_out the output dim
TRAIN_FRACTION = 0.7  # fraction of the samples drawn (without replacement) per step
train_n = int(TRAIN_FRACTION * N)

w0 = np.array([[1,0],[-1,0],[0,1],[0,-1]]) # (H,2)
w2 = np.array([[1,0],[-1,0],[0,1],[0,-1]]) # (H,D_out)

X = np.random.randn(D_in,N) # input (D_in,N)
MSE = []
for k in range(5):
    epsilon = 10**(-k)
    A_epsilon = np.array([[1,1],[1,1+epsilon]]) # (2,2)
    eta = np.random.randn(D_in,N) # (D_in,N)
    Y = np.dot(A_epsilon,X) + 0.01*eta # (D_out,N)

    # Randomly initialize the trainable matrix W
    W = np.random.rand(2,2)
    w1 = np.dot(w0,W) # (H,D_in)

    for step in range(EPOCH):
        # np.random is seeded above; the `random` module was not, so subsets
        # drawn with random.sample were not reproducible.
        idx = np.random.choice(N, train_n, replace=False)
        X_training = X[:,idx] # (D_in,train_n)
        Y_training = Y[:,idx] # (D_out,train_n)

        # forward propagation
        z1 = np.dot(w1,X_training) # (H,train_n)
        # np.maximum returns a new array; `z2 = z1; z2[z2<0] = 0` clipped z1 in
        # place, which made the ReLU mask below always all ones.
        z2 = np.maximum(z1, 0) # (H,train_n)
        Y_pred_training = np.dot(w2.T,z2) # (D_out,train_n)

        # the loss function (squared error per sample)
        loss = np.linalg.norm(Y_pred_training - Y_training)**2/train_n
        MSE.append(loss)

        # back propagation: chain rule through w2^T, the ReLU and w1 = w0 W
        dl_dy_pred = (2/train_n) * (Y_pred_training - Y_training) # (D_out,train_n)
        dl_dz1 = np.dot(w2, dl_dy_pred) * (z1 > 0) # (H,train_n)
        dl_dW = np.dot(np.dot(w0.T, dl_dz1), X_training.T) # (2,D_in)

        # update the weight matrix
        W -= LEARNING_RATE * dl_dW
        w1 = np.dot(w0,W)
        if step % 500 == 0 or step == EPOCH-1:
            print('epsilon: 1e{}\nepoch: {}\n{}'.format(-k,step, W))

    # test on fresh data (numpy seed 2, n = 2000)
    n = 2000
    np.random.seed(2)
    x = np.random.randn(D_in,n)
    e = np.random.randn(D_in,n)
    y = np.dot(A_epsilon,x) + 0.01*e
    y_pred = np.dot(w2.T, np.maximum(np.dot(w1,x), 0)) # (D_out,n)
    test_loss = np.linalg.norm(y_pred - y)**2/y_pred.shape[1]
    print('====='*5,'MSE: ')
    print(test_loss)

    # plot the figure
    plt.plot(range(EPOCH),MSE[EPOCH*k:EPOCH*(k+1)],'-',label=r'$\epsilon$ = '+ str(epsilon))
    plt.legend()
    plt.ylabel(r'$MSE=\frac{1}{n}\Vert W_{x}X-Y\Vert ^{2} $')
    plt.xlabel('Number of Epoch')
    plt.title('Forward Problem')
plt.savefig('nn(DP).pdf', format='pdf')

### Inverse problem

In [None]:
# torch version (without weight restriction)
# Inverse problem: fit a 2 -> 4 -> 2 ReLU network mapping Y = A_eps X + noise back
# to X for each eps, then report the MSE on a fresh test set (torch seed 2, n = 2000).
EPOCH = 2000
LEARNING_RATE = 0.05
N = 10000
MSE = []  # training loss per epoch, EPOCH entries per eps, concatenated

for k in range(5):
    # data generation (seed reset, so X is the same for every eps)
    torch.manual_seed(1)
    X = torch.randn(2,N)
    epsilon = 10**(-k)
    A_epsilon = torch.Tensor([[1,1],[1,1+epsilon]])
    eta = torch.randn(2,N)
    Y = torch.mm(A_epsilon,X) + 0.01*eta

    net = torch.nn.Sequential(
        torch.nn.Linear(2,4),
        torch.nn.ReLU(),
        torch.nn.Linear(4,2)
    )
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(),lr = LEARNING_RATE)

    # training (full batch; samples are columns, nn.Linear expects rows)
    for epoch in range(EPOCH):
        X_pred = net(Y.T)
        loss = criterion(X_pred,X.T)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        MSE.append(loss.item())

    # test on fresh data; no_grad avoids building an autograd graph
    torch.manual_seed(2)
    n = 2000
    x = torch.randn(2,n)
    e = torch.randn(2,n)
    y = torch.mm(A_epsilon,x) + 0.01 * e
    with torch.no_grad():
        mse = criterion(net(y.T),x.T).item()
    print('epsilon: ',epsilon, 'MSE: ',mse)

    plt.plot(range(EPOCH),MSE[EPOCH*k:EPOCH*(k+1)],'-',label=r'$\epsilon$ = '+ str(epsilon))
    plt.legend()
    plt.ylabel(r'$MSE=\frac{1}{n}\Vert W_{y}Y-X\Vert ^{2} $')
    plt.xlabel('Number of Epoch')
    plt.title('Inverse Problem (Basic Matrix)')
plt.savefig('nn(IPBMnoOutliers).pdf', format='pdf')

In [None]:
# torch version (with weight restriction)
# Inverse problem with the weight-restricted network
#     X_pred^T = w2^T relu(w0 W Y^T),   w0 = w2 = [[1,0],[-1,0],[0,1],[0,-1]],
# where only W is trained, with hand-written backprop on a random subset per step.
# Data is stored one sample per row.
torch.manual_seed(1)
EPOCH = 3000
LEARNING_RATE = 0.1
N , D_in , H , D_out = 10000 , 2, 4 , 2
# N is the number of samples, D_in the input dim,
# H the number of nodes in the hidden layer, D_out the output dim

w0 = torch.Tensor([[1,0],[-1,0],[0,1],[0,-1]]) # (H,2)
w2 = torch.Tensor([[1,0],[-1,0],[0,1],[0,-1]]) # (H,D_out)

X = torch.randn(N,D_in) # (N,D_in)
MSE = []
for k in range(5):
    epsilon = 10**(-k)
    A_epsilon = torch.Tensor([[1,1],[1,1+epsilon]]) # (2,2)
    eta = torch.randn(N,D_in) # (N,D_in)
    # Row layout gives X A_eps; A_eps is symmetric, so this equals (A_eps X^T)^T.
    Y = X.mm(A_epsilon) + 0.01*eta # (N,D_out)

    # Randomly initialize the trainable matrix W
    W = torch.rand(2,2)
    w1 = w0.mm(W) # (H,D_in)

    for step in range(EPOCH):
        # Random subset size in [1, N]: int(rand * N) could be 0 and divide by zero.
        train_n = int(torch.randint(1, N + 1, (1,)))
        # torch.randperm follows the torch seed (the `random` module was unseeded).
        idx = torch.randperm(N)[:train_n]
        Y_in = Y[idx]     # network input (train_n,D_in)
        X_target = X[idx] # target (train_n,D_out)

        # forward propagation
        z1 = w1.mm(Y_in.t()) # (H,train_n)
        z2 = z1.clamp(min=0) # (H,train_n)
        X_pred = w2.t().mm(z2).t() # (train_n,D_out)

        # the loss function (squared error per sample)
        loss = torch.norm(X_pred - X_target).pow(2)/train_n
        MSE.append(loss.item())

        # back propagation: chain rule through w2^T, the ReLU and w1 = w0 W
        dl_dx_pred = (2/train_n) * (X_pred - X_target) # (train_n,D_out)
        dl_dz1 = w2.mm(dl_dx_pred.t()) * (z1 > 0) # (H,train_n)
        dl_dW = w0.t().mm(dl_dz1).mm(Y_in) # (2,D_in)

        # update the weight matrix
        W -= LEARNING_RATE * dl_dW
        w1 = w0.mm(W) # (H,D_in)
        if step % 500 == 0 or step == EPOCH-1:
            print('epsilon: 1e{}\nepoch: {}\n{}'.format(-k,step, W))

    # test on fresh data (numpy seed 2, n = 2000), evaluated entirely in torch
    n = 2000
    np.random.seed(2)
    x = torch.from_numpy(np.random.randn(D_in,n)).float()
    e = torch.from_numpy(np.random.randn(D_in,n)).float()
    y = A_epsilon.mm(x) + 0.01 * e
    x_pred = w2.t().mm(w1.mm(y).clamp(min=0)) # (D_out,n)
    test_loss = torch.norm(x_pred - x).pow(2).item()/n
    print('====='*5,'MSE: ')
    print(test_loss)

    # plot the figure
    plt.plot(range(EPOCH),MSE[EPOCH*k:EPOCH*(k+1)],'-',label=r'$\epsilon$ = '+ str(epsilon))
    plt.ylim(0,3)
    plt.xlim(0)
    plt.legend()
    plt.ylabel(r'$MSE=\frac{1}{n}\Vert W_{y}Y-X\Vert ^{2} $')
    plt.xlabel('Number of Epoch')
    plt.title('Inverse Problem')
plt.savefig('nn(IP).svg', format='svg')

# Hilbert Linear Problem

## Condition number

In [None]:
# Build the Hilbert matrices H[i, j] = 1 / (i + j + 1) (0-based) for several
# sizes and record their 2-norm condition numbers, both keyed by 'H<dim>'.
Hilbert = {}
cond = {}
for dim in (2, 4, 9, 16, 25, 32):
    idx = np.arange(dim)
    H = 1. / (idx[:, np.newaxis] + idx + 1)
    key = 'H' + str(dim)
    Hilbert[key] = H
    cond[key] = np.linalg.cond(H)
cond

## Tikhonov Regularization

### Optimal parameter

In [None]:
# training
# Oracle ("optimal") Tikhonov parameter for every matrix in `Hilbert`: sweep lambda
# over a grid and keep the value whose reconstruction is closest to the true X.
# This uses the ground truth, so it is a best-case reference. NOTE: 1e5 lambdas per
# matrix with a dim x dim inverse each; the dim = 32 case is slow.
np.random.seed(1)
N = 10000 # number of data size
SCALE = 1e6
lambdas = np.arange(1, 100001) / SCALE  # lambda grid 1/SCALE, ..., 0.1
LAMBDA = []
for key, value in Hilbert.items():
    # create data sample
    dim = value.shape[0]
    X = np.random.randn(dim,N)
    eta = np.random.randn(dim,N)
    Y = np.dot(value,X) + 0.01 * eta
    gram = np.dot(value.T,value)  # loop-invariant H^T H

    MSE = []
    for a in lambdas:
        # Tikhonov solution
        X_tik = np.linalg.inv(gram + a * np.eye(dim)).dot(value.T).dot(Y)
        # per-sample squared reconstruction error
        MSE.append(np.linalg.norm(X_tik-X)**2/X.shape[1])

    best = int(np.argmin(MSE))
    Lambda = float(lambdas[best])
    LAMBDA.append(Lambda)
    plt.plot(lambdas,MSE,'-',label='dim = '+ str(dim))
    plt.legend()
    # the plotted quantity is the *squared* norm divided by n
    plt.ylabel(r'$MSE=\frac{1}{n}\Vert X_{tik}-X \Vert^{2}$')
    plt.xlabel(r'$\lambda$')
    plt.title('Optimal')
    plt.show()
    print('dim = ', str(dim),' MSE:',MSE[best],' lambda:',Lambda ,'\n')

In [None]:
# test
# Score the oracle lambdas on fresh data (N = 2000 samples per Hilbert matrix).
N = 2000 # number of data size
MSE = []
for dim,a in zip([2,4,9,16,25,32],LAMBDA):
    # Hilbert matrix H[i, j] = 1 / (i + j + 1)
    H = 1. / (np.arange(1,dim+1) + np.arange(0,dim)[:, np.newaxis])
    # The samples come from np.random, so seed numpy; torch.manual_seed does not
    # affect np.random and left this test data unreproducible.
    np.random.seed(2)
    X = np.random.randn(dim,N)
    eta = np.random.randn(dim,N)
    Y = np.dot(H,X) + 0.01 * eta

    # Tikhonov solution
    X_tik = np.linalg.inv(np.dot(H.T,H) + a * np.eye(dim)).dot(H.T).dot(Y)

    # error (squared, per sample)
    error = np.linalg.norm(X_tik-X)**2/X.shape[1]
    MSE.append(error)
print(MSE)

### L-curve

In [None]:
# training
# Discrete L-curve for every matrix in `Hilbert`: trace (||H X_tik - Y||, ||X_tik||)
# over lambda = 0.01, 0.02, ..., 10 and take lambda at the maximum-curvature point.
np.random.seed(1)
N = 10000 # number of data size
LAMBDA = []
for key, value in Hilbert.items():
    solution, residual = [], []
    dim = value.shape[0]
    X = np.random.randn(dim, N)
    eta = np.random.randn(dim, N)
    Y = np.dot(value, X) + 0.01 * eta

    for step in range(1, 1001):
        scale = 1e2
        a = step / scale
        gram_reg = np.dot(value.T, value) + a * np.eye(dim)
        X_tik = np.dot(np.dot(np.linalg.inv(gram_reg), value.T), Y)
        solution.append(np.linalg.norm(X_tik))
        residual.append(np.linalg.norm(np.dot(value, X_tik) - Y))

    # L-curve figure on log-log axes
    plt.plot(residual, solution, '.', label='dim = ' + str(dim))
    plt.legend()
    plt.ylabel(r'$\log\Vert X_{Tik} \Vert$')
    plt.xlabel(r'$\log\Vert AX_{Tik}-Y^{\eta} \Vert$')
    plt.title('Discrete L-curve for Tikhonov regularization')
    plt.yscale('log')
    plt.xscale('log')
    plt.show()

    # entry i of the curve corresponds to lambda = (i + 1) / scale
    Lambda = (curvature(residual, solution) + 1) / scale
    LAMBDA.append(Lambda)
    chosen_inv = np.linalg.inv(np.dot(value.T, value) + Lambda * np.eye(dim))
    X_tik_Lambda = np.dot(np.dot(chosen_inv, value.T), Y)
    MSE = np.linalg.norm(X_tik_Lambda - X)**2 / Y.shape[1]

    print('dim: ', dim, 'lambda: ', Lambda, ' MSE: ', MSE)

In [None]:
# test
# Score the L-curve lambdas on fresh data (N = 2000 samples per Hilbert matrix).
N = 2000 # number of data size
MSE = []
for dim,a in zip([2,4,9,16,25,32],LAMBDA):
    # Hilbert matrix H[i, j] = 1 / (i + j + 1)
    H = 1. / (np.arange(1,dim+1) + np.arange(0,dim)[:, np.newaxis])
    # Seed numpy: the samples come from np.random, which torch.manual_seed
    # does not touch.
    np.random.seed(2)
    X = np.random.randn(dim,N)
    eta = np.random.randn(dim,N)
    Y = np.dot(H,X) + 0.01 * eta

    # Tikhonov solution
    X_tik = np.linalg.inv(np.dot(H.T,H) + a * np.eye(dim)).dot(H.T).dot(Y)

    # error (squared, per sample)
    error = np.linalg.norm(X_tik-X)**2/X.shape[1]
    MSE.append(error)
print(MSE)

### GCV

In [None]:
# training
# GCV for every matrix in `Hilbert`:
#   GCV(lambda) = ||H X_tik - Y||^2 / trace(I - H (H^T H + lambda I)^(-1) H^T)^2
np.random.seed(1)
N = 10000 # number of data size
scale = 1e5
# Grid starts at 1/scale: lambda = 0 would invert H^T H, which is numerically
# singular for the larger Hilbert matrices.
lambdas = [a / scale for a in range(1, 1001)]
LAMBDA = []
for key, value in Hilbert.items():
    GCV = []
    # create data sample
    dim = value.shape[0]
    X = np.random.randn(dim,N)
    eta = np.random.randn(dim,N)
    Y = np.dot(value,X) + 0.01 * eta
    gram = np.dot(value.T,value)

    for a in lambdas:
        inv_reg = np.linalg.inv(gram + a * np.eye(dim))  # reused by both GCV terms
        X_tik = np.dot(np.dot(inv_reg,value.T),Y)
        upper = np.linalg.norm(np.dot(value,X_tik)-Y)**2
        lower = np.trace(np.eye(dim) - np.dot(np.dot(value,inv_reg),value.T))**2
        GCV.append(upper/lower)

    plt.plot(lambdas,GCV,'-',label='dim = '+ str(dim))
    plt.legend()
    plt.ylabel(r'$GCV(\lambda)$')
    plt.xlabel(r'$\lambda$')
    plt.title('GCV')
    plt.ylim(0,10)

    # Stored, plotted and printed lambda now all come from the same grid entry.
    best = int(np.argmin(GCV))
    Lambda = lambdas[best]
    LAMBDA.append(Lambda)
    X_tik_Lambda = np.dot(np.dot(np.linalg.inv(gram + Lambda * np.eye(dim)),value.T),Y)
    MSE = np.linalg.norm(X_tik_Lambda-X)**2 / Y.shape[1]
    print('dim:',dim,' GCV:',GCV[best],' lambda:',Lambda ,' MSE: ',MSE,'\n')

In [None]:
# test
# Score the GCV lambdas on fresh data (N = 2000 samples per Hilbert matrix).
N = 2000 # number of data size
MSE = []
for dim,a in zip([2,4,9,16,25,32],LAMBDA):
    # Hilbert matrix H[i, j] = 1 / (i + j + 1)
    H = 1. / (np.arange(1,dim+1) + np.arange(0,dim)[:, np.newaxis])
    # Reseed numpy (not torch) so every run draws the same test samples.
    np.random.seed(2)
    X = np.random.randn(dim,N)
    eta = np.random.randn(dim,N)
    Y = np.dot(H,X) + 0.01 * eta

    # Tikhonov solution
    X_tik = np.linalg.inv(np.dot(H.T,H) + a * np.eye(dim)).dot(H.T).dot(Y)

    # error (squared, per sample)
    error = np.linalg.norm(X_tik-X)**2/X.shape[1]
    MSE.append(error)
print(MSE)

## Neural Network

### Forward problem

In [1]:
# torch version (without weight restriction)
# Forward problem for each Hilbert matrix H (dim x dim): fit a dim -> 2*dim -> dim
# ReLU network to Y = H X + noise, then report the MSE on a fresh test set.
EPOCH = 500
LEARNING_RATE = 0.005
N = 10000

MSE = []  # training loss per epoch, EPOCH entries per dim, concatenated
for k,dim in enumerate([2,4,9,16,25,32]):
    # Hilbert matrix H[i, j] = 1 / (i + j + 1)
    H = 1. / (np.arange(1,dim+1) + np.arange(0,dim)[:, np.newaxis])
    H = torch.tensor(H).float()
    # create data sample (same torch seed for every dim)
    torch.manual_seed(1)
    X = torch.randn(dim,N)
    eta = torch.randn(dim,N)
    Y = torch.mm(H,X) + 0.01 * eta

    hidden_layer = dim * 2
    net = torch.nn.Sequential(
        torch.nn.Linear(dim,hidden_layer),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden_layer,dim)
    )
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(),lr = LEARNING_RATE)

    # training (full batch; samples are columns, nn.Linear expects rows)
    for epoch in range(EPOCH):
        Y_pred = net(X.T)
        loss = criterion(Y_pred,Y.T)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        MSE.append(loss.item())

    # test on fresh data; no_grad avoids building an autograd graph
    torch.manual_seed(2)
    n = 2000
    x = torch.randn(dim,n)
    e = torch.randn(dim,n)
    y = torch.mm(H,x) + 0.01 * e
    with torch.no_grad():
        mse = criterion(net(x.T),y.T).item()
    print('dim: ',dim, 'MSE: ',mse)

    plt.plot(range(EPOCH),MSE[EPOCH*k:EPOCH*(k+1)],'-',label= 'dim: '+ str(dim))
    plt.legend()
    plt.ylabel(r'$MSE=\frac{1}{n}\Vert W_{x}X-Y\Vert ^{2} $')
    plt.xlabel('Number of Epoch')
    plt.title('Forward Problem (Hilbert Matrix)')
plt.savefig('nn(DPHMnoOutlier).pdf', format='pdf')

NameError: name 'np' is not defined

### Inverse problem

In [None]:
# torch version (without weight restriction)
import torch
from torch.autograd import Variable
from torch.distributions import normal

# Inverse problem for each Hilbert matrix H: fit a dim -> 2*dim -> dim ReLU network
# mapping Y = H X + 0.01 * eta (eta ~ N(0, std^2)) back to X, then test on fresh data.
EPOCH = 1000
LEARNING_RATE = 0.05
N = 10000
std = 10  # std of eta; the added noise 0.01 * eta therefore has std 0.01 * std
MSE = []  # training loss per epoch, EPOCH entries per dim, concatenated
for k,dim in enumerate([2,4,9,16,25,32]):
    # Hilbert matrix H[i, j] = 1 / (i + j + 1)
    H = 1. / (np.arange(1,dim+1) + np.arange(0,dim)[:, np.newaxis])
    H = torch.tensor(H).float()
    # create data sample (same torch seed for every dim)
    torch.manual_seed(1)
    X = torch.randn(dim,N)
    eta = torch.empty(dim,N).normal_(mean=0,std=std)
    Y = torch.mm(H,X) + 0.01 * eta

    hidden_layer = dim * 2
    net = torch.nn.Sequential(
        torch.nn.Linear(dim,hidden_layer),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden_layer,dim)
    )
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(),lr = LEARNING_RATE)

    # training (full batch; samples are columns, nn.Linear expects rows)
    for epoch in range(EPOCH):
        X_pred = net(Y.T)
        loss = criterion(X_pred,X.T)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        MSE.append(loss.item())
    print('train MSE: ', MSE[-1])

    # test on fresh data; no_grad avoids building an autograd graph
    torch.manual_seed(2)
    n = 2000
    x = torch.randn(dim,n)
    e = torch.empty(dim,n).normal_(mean=0,std=std)
    y = torch.mm(H,x) + 0.01 * e
    with torch.no_grad():
        mse = criterion(net(y.T),x.T).item()
    print('dim: ',dim, 'MSE: ',mse)

    plt.plot(range(EPOCH),MSE[EPOCH*k:EPOCH*(k+1)],'-',label= 'dim: '+ str(dim))
    plt.legend()
    plt.ylabel(r'$MSE=\frac{1}{n}\Vert W_{y}Y-X\Vert ^{2} $')
    plt.xlabel('Number of Epoch')
    plt.title('Inverse Problem (Hilbert Matrix)')
plt.savefig('nn(IPHMnoOutlier).pdf', format='pdf')

### ISTA

In [None]:
def shrinkage(x, theta):
    """Soft-thresholding operator: sign(x) * max(|x| - theta, 0), elementwise.

    Parameters
    ----------
    x : np.ndarray or float
    theta : float
        Threshold (non-negative).

    Returns
    -------
    np.ndarray with the shape of x.
    """
    return np.sign(x) * np.maximum(np.abs(x) - theta, 0)

def ista(Y, X, A, max_iter, eps,a=None, L=None):
    """Iterative shrinkage-thresholding for min_X 0.5 * ||A X - Y||_F^2 + a * ||X||_1.

    Each step is X <- shrinkage(X - (1/L) A^T (A X - Y), a / L), starting from zeros.

    Parameters
    ----------
    Y : np.ndarray, shape (n, k)
        Measurements, one column per sample.
    X : np.ndarray, shape (m, k)
        Reference solution; only used to record the error after each iteration.
    A : np.ndarray, shape (n, m)
    max_iter : int
        Maximum number of iterations (0 returns the zero initial guess).
    eps : float
        Stop once sum(|X_new - X_old|) / k < eps.
    a : float, optional
        l1 weight; defaults to 1e-4 * L.
    L : float, optional
        Step constant; must exceed the largest eigenvalue of A^T A.
        Defaults to that eigenvalue + 0.1.

    Returns
    -------
    X_hat : np.ndarray, shape (m, k)
    ista_err : list of float
        ||X_hat - X||_F^2 / k after every performed iteration.
    """
    # A^T A is symmetric: eigvalsh returns real eigenvalues, whereas eig can return
    # complex pairs for the clustered tiny eigenvalues of ill-conditioned A.
    lam_max = np.max(np.linalg.eigvalsh(A.T.dot(A)))
    if L is None:
        L = lam_max + 0.1
    else:
        assert L > lam_max
    if a is None:
        a = 0.0001 * L

    B = (1/L)*A.T
    n_cols = Y.shape[1]

    ista_err = []
    X_hat = np.zeros((A.shape[1], n_cols))
    for _ in range(max_iter):
        X_next = shrinkage(X_hat - B.dot(A.dot(X_hat) - Y), a/L)
        converged = np.sum(np.abs(X_next - X_hat))/n_cols < eps
        X_hat = X_next
        # record every iterate, including the last, so the list is never empty
        # after at least one iteration
        ista_err.append(np.linalg.norm(X_hat - X)**2/n_cols)
        if converged:
            break

    return X_hat, ista_err

In [None]:
# Run ISTA on each Hilbert system and plot the error to the true X per iteration.
np.random.seed(1)
N = 10000 # number of data size
max_iter = 5000

for key, value in Hilbert.items():
    # create data sample
    dim = value.shape[0]
    X = np.random.randn(dim,N)
    eta = np.random.randn(dim,N)
    Y = np.dot(value,X) + 0.01 * eta

    X_ISTA, ISTAerror = ista(Y, X, value, max_iter=max_iter, eps=1e-6)
    plt.plot(range(len(ISTAerror)),ISTAerror, '-', label='dim = '+ str(dim))
    plt.legend()
    plt.xlabel('Iteration')
    plt.ylabel('MSE')
    plt.show()

    # guard: the error list can be empty (e.g. max_iter == 0)
    final_mse = ISTAerror[-1] if ISTAerror else float('nan')
    print('dim:',dim,' iter:',len(ISTAerror), ' MSE:',final_mse)

### LISTA

In [None]:
class LISTA(nn.Module):
    """Learned ISTA: unrolled ISTA iterations with trainable W and S.

    The first layer computes x = softshrink(W y, theta); each further layer computes
    x = softshrink(W y + S x, theta). ``max_iter`` is the total number of layers
    (values below 1 behave like 1).

    Parameters
    ----------
    n : int
        Measurement dimension (input features).
    m : int
        Code dimension (output features).
    W_e : torch.Tensor, shape (n, m)
        Dictionary A used by ``weights_init``.
    max_iter : int
        Number of layers.
    theta : float
        Soft-threshold level.
    L : float, optional
        Step constant used by ``weights_init``.
    """

    def __init__(self, n, m, W_e, max_iter, theta, L=None):
        super(LISTA, self).__init__()
        self._W = nn.Linear(in_features=n, out_features=m, bias=False)
        # S maps the m-dim code to an m-dim code (out_features=n broke n != m).
        self._S = nn.Linear(in_features=m, out_features=m, bias=False)
        self.shrinkage = nn.Softshrink(theta)
        self.theta = theta
        self.max_iter = max_iter
        self.A = W_e
        self.L = L

    # weights initialization based on the dictionary
    def weights_init(self):
        """Set W = A^T / L and S = I - A^T A / L, so each layer is one ISTA step.

        Raises
        ------
        ValueError
            If ``L`` was not given.
        """
        if self.L is None:
            raise ValueError('weights_init requires the step constant L')
        L = self.L
        A = self.A.cpu().numpy()
        # Place the new weights on the device the module already lives on
        # instead of relying on a global `device`.
        target = self._W.weight.device
        S = torch.from_numpy(np.eye(A.shape[1]) - (1/L)*np.dot(A.T, A)).float().to(target)
        W = torch.from_numpy((1/L)*A.T).float().to(target)

        self._S.weight = nn.Parameter(S)
        self._W.weight = nn.Parameter(W)

    def forward(self, y):
        """Apply the LISTA layers to y of shape (batch, n); returns (batch, m)."""
        x = self.shrinkage(self._W(y))
        # The first layer is above, so run max_iter - 1 more (the old loop ran
        # max_iter more, i.e. max_iter + 1 layers whenever max_iter > 1).
        for _ in range(1, self.max_iter):
            x = self.shrinkage(self._W(y) + self._S(x))
        return x

def train_lista(Y, X, dictionary, lr, max_iter=15,epoch=100,L=None, batch_size=None,eps=1e-6):
    """Fit a LISTA network that maps measurements Y to codes X (full-batch Adam).

    Parameters
    ----------
    Y : np.ndarray, shape (n, n_samples)
        Measurements, one column per sample.
    X : np.ndarray, shape (m, n_samples)
        Target codes.
    dictionary : np.ndarray, shape (n, m)
        Matrix used to initialize the network (W = A^T / L, S = I - A^T A / L).
    lr : float
        Adam learning rate.
    max_iter : int
        Number of LISTA layers.
    epoch : int
        Maximum number of full-batch updates.
    L : float, optional
        Must exceed the largest eigenvalue of dictionary^T dictionary;
        defaults to that eigenvalue + 0.1. The threshold is 1 / L.
    batch_size : unused
        Kept for backward compatibility.
    eps : float
        Stop early once sum(|output change|) / n_samples < eps.

    Returns
    -------
    net : LISTA
    loss_list : list of float
        Training MSE of every completed update.
    """
    # eigvalsh: real eigenvalues of the symmetric Gram matrix (eig can go complex).
    lam_max = np.max(np.linalg.eigvalsh(dictionary.T.dot(dictionary)))
    if L is not None:
        # np.max, not torch.max: the eigenvalues are a numpy array.
        assert L > lam_max
    else:
        L = lam_max + 0.1

    n, m = dictionary.shape
    n_samples = Y.shape[1]

    # convert the data into tensors
    Y = torch.from_numpy(Y).float().to(device)
    X = torch.from_numpy(X).float().to(device)
    W_d = torch.from_numpy(dictionary).float().to(device)

    net = LISTA(n, m, W_d, max_iter=max_iter, L=L, theta=1/L)
    net = net.float().to(device)
    net.weights_init()

    # build the optimizer and criterion
    criterion1 = nn.MSELoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=lr)

    loss_list = []
    X_h_old = torch.zeros((m, n_samples), device=device)  # outputs have m rows
    for e in range(epoch):
        optimizer.zero_grad()

        # get the outputs
        X_h = net(Y.T).T
        if torch.sum(torch.abs(X_h.detach() - X_h_old))/n_samples < eps:
            break

        loss = criterion1(X, X_h)
        loss.backward()
        optimizer.step()
        # detach so the previous epoch's autograd graph can be freed
        X_h_old = X_h.detach()
        loss_list.append(loss.item())

    plt.plot(range(len(loss_list)),loss_list,'-',label=r'$\dim$ = '+ str(m))
    plt.legend()
    plt.ylabel(r'$MSE=\frac{1}{n}\Vert W_{y}Y-X\Vert ^{2} $')
    plt.xlabel('Number of Epoch')
    plt.show()
    if loss_list:
        print('Epoch: ',len(loss_list),'MSE: ',loss_list[-1])
    else:
        print('Epoch: ',0,'MSE: ','n/a (stopped before the first update)')

    return net, loss_list

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import orth

# Train a one-layer LISTA on each Hilbert system (numpy seed 1, N = 10000), then
# report its MSE on a fresh test set (numpy seed 2, n = 2000).
N = 10000 # number of data size
for key, value in Hilbert.items():
    # create data sample
    dim = value.shape[0]
    np.random.seed(1)
    X = np.random.randn(dim,N)
    eta = np.random.randn(dim,N)
    Y = np.dot(value,X) + 0.01 * eta
    lista, lista_err = train_lista(Y,X, value,max_iter=1, lr=0.05,epoch=5000)

    # Test stage: fresh Gaussian X_test and its noisy measurements Y_test
    n = 2000
    np.random.seed(2)
    X_test = np.random.randn(dim,n)
    eta_test = np.random.randn(dim,n)
    Y_test = np.dot(value,X_test) + 0.01 * eta_test

    with torch.no_grad():
        X_LISTA = lista(torch.from_numpy(Y_test.T).float().to(device)).T
        loss = nn.MSELoss()(X_LISTA, torch.from_numpy(X_test).float().to(device))
    print(loss.item())

# DnCNN

In [None]:
# add noise
def addnoise(img,types,n,mean=0,sd=1,lam=1,multi=False):
    """Return a noisy copy of a (C, H, W) image tensor; `img` itself is never modified.

    Args:
        img: image tensor of shape (C, H, W).
        types: 'original' (no noise), 'saltpepper', or one of the noise distributions
            'gaussian', 'poisson', 'uniform', 'exponential', 'lognormal', 'rayleigh'.
        n: for 'saltpepper', the number of corrupted pixel positions (n // 2 set to 1 across
            all channels, then n // 2 set to 0). For the distributions, every entry is
            corrupted independently with probability n / (H * W) (clipped to [0, 1]),
            i.e. roughly n entries per channel.
        mean, sd: location / scale for 'gaussian' and 'lognormal'.
        lam: rate for 'poisson', scale for 'exponential' and 'rayleigh'.
        multi: if True apply multiplicative noise (image + image * noise),
            otherwise additive noise (image + noise).

    Returns:
        The noisy image tensor (same shape, on the same device as `img`).

    Raises:
        ValueError: if `img` is not 3-D or `types` is not a known noise type.
    """
    image = img.clone()
    if image.ndim != 3:
        raise ValueError('addnoise expects a (C, H, W) image, got shape {}'.format(tuple(image.shape)))
    c, h, w = image.shape

    if types == 'original':
        return image

    if types == 'saltpepper':
        half = n // 2
        for value in (1, 0):  # salt first, then pepper
            for _ in range(half):
                # The row is drawn over H and the column over W, and indexed in that order
                # (the old code swapped them, which failed on non-square images).
                row = int(np.random.random() * h)
                col = int(np.random.random() * w)
                image[:, row, col] = value
        return image

    size = (c, h, w)
    samplers = {
        'gaussian': lambda: np.random.normal(loc=mean, scale=sd, size=size),
        'poisson': lambda: np.random.poisson(lam=lam, size=size),
        'uniform': lambda: np.random.random(size=size),
        'exponential': lambda: np.random.exponential(scale=lam, size=size),
        'lognormal': lambda: np.random.lognormal(mean=mean, sigma=sd, size=size),
        'rayleigh': lambda: np.random.rayleigh(scale=lam, size=size),
    }
    if types not in samplers:
        raise ValueError('unknown noise type {!r}'.format(types))

    # Clip the corruption probability so numpy accepts `p` even when n > H * W.
    prob = min(max(n / (h * w), 0.0), 1.0)
    mask = np.random.choice((0, 1), size=size, p=[1 - prob, prob])
    noise = torch.as_tensor(samplers[types]() * mask, device=image.device)
    if image.is_floating_point():
        noise = noise.to(image.dtype)

    if multi:
        return image + image * noise
    return image + noise

In [None]:
# create dataset    
def dataset(data,size,seed=0,noise_type='gaussian',noise_num=250,multi=False):
    """Build parallel lists of clean and noisy images from an (image, label) dataset.

    Args:
        data: indexable dataset whose items are (image, label) pairs.
        size: 'all' to use every item in order, or the number of items to sample.
            Sampling uses np.random.randint, i.e. it is WITH replacement.
        seed: NumPy global seed set before sampling (only when size != 'all').
        noise_type, noise_num, multi: forwarded to addnoise as types, n and multi.

    Returns:
        (img_gt, img_noise): lists of clean images and their noisy versions, same order.
    """
    # The old `else:` was attached to the `for` loop instead of the `if`, so both
    # branches crashed; the branches are now explicit.
    if size == 'all':
        samples = data
    else:
        np.random.seed(seed)
        index = np.random.randint(0, len(data), size=size)
        samples = [data[i] for i in index]

    img_gt = []
    img_noise = []
    for clean, _label in samples:
        img_gt.append(clean)
        img_noise.append(addnoise(clean, types=noise_type, n=noise_num, multi=multi))
    return img_gt, img_noise
  
# show image
def imgshow(data,fig_num,fig_size):
    """Show the first `fig_num` images of `data` in a 5-column grid.

    Args:
        data: sized sequence of images accepted by transforms.ToPILImage.
        fig_num: maximum number of images to show; 0, None or negative shows all.
        fig_size: matplotlib figure size (width, height) in inches.

    Returns:
        The value of plt.show() (None).
    """
    n_cols = 5
    if not fig_num or fig_num < 1:
        n_show = len(data)
    else:
        n_show = min(fig_num, len(data))
    # One row per 5 images; the old ceil(len(data) / 50) row count ran out of subplot
    # slots as soon as more than 5 * rows images were requested.
    n_rows = max(1, math.ceil(n_show / n_cols))

    to_img = transforms.ToPILImage()
    plt.figure(figsize=fig_size)
    for idx, image in zip(range(n_show), data):
        ax = plt.subplot(n_rows, n_cols, idx + 1)  # rows, columns, 1-based index
        ax.imshow(to_img(image))
        ax.set_xticks([])
        ax.set_yticks([])
    return plt.show()
  

# create DnCNN
class DnCNN(nn.Module):
    """DnCNN-style CNN: conv+ReLU, (depth - 2) x [conv (+ BatchNorm) + ReLU], conv.

    ``forward(x)`` returns ``x - body(x)``. The layers are moved to the notebook-global
    ``device`` at construction, and inputs are moved there in ``forward``.

    Args:
        depth: total number of conv layers.
        n_channels: hidden feature channels.
        image_channels: channels of the input/output image.
        use_bnorm: insert BatchNorm2d after each hidden conv. This flag used to be ignored
            (BatchNorm was always added); the default True keeps that architecture.
        kernel_size: conv kernel size. Use an odd value so that padding = kernel_size // 2
            preserves H and W. It used to be overwritten with 3; the default is still 3.
    """

    def __init__(self, depth=12, n_channels=64, image_channels=3, use_bnorm=True, kernel_size=3):
        super(DnCNN, self).__init__()
        padding = kernel_size // 2

        layers = [
            nn.Conv2d(in_channels=image_channels, out_channels=n_channels,
                      kernel_size=kernel_size, padding=padding, bias=True),
            nn.ReLU(),
        ]
        for _ in range(depth - 2):
            layers.append(nn.Conv2d(in_channels=n_channels, out_channels=n_channels,
                                    kernel_size=kernel_size, padding=padding, bias=True))
            if use_bnorm:
                # PyTorch momentum weights the newest batch statistic:
                # running = (1 - momentum) * running + momentum * batch.
                layers.append(nn.BatchNorm2d(n_channels, eps=0.0001, momentum=0.9))
            layers.append(nn.ReLU())
        layers.append(nn.Conv2d(in_channels=n_channels, out_channels=image_channels,
                                kernel_size=kernel_size, padding=padding, bias=True))

        self.dncnn = nn.Sequential(*layers).to(device)
        self._initialize_weights()

    def forward(self, x):
        """Return ``x - dncnn(x)`` with ``x`` moved to ``device``."""
        x = x.to(device)
        return x - self.dncnn(x)

    def _initialize_weights(self):
        """Orthogonal conv weights with zero bias; BatchNorm weight 1 and bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                init.orthogonal_(m.weight)
                if m.bias is not None:
                    init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias, 0)

# Default denoiser instance (all constructor defaults spelled out); net_training uses it as its default `net`.
dncnn = DnCNN(depth=12, n_channels=64, image_channels=3, use_bnorm=True, kernel_size=3)
print(dncnn)

def gaussian(window_size, sigma):
    """Return a normalised 1-D Gaussian kernel of length `window_size`.

    The kernel is centred on index window_size // 2, has standard deviation `sigma`
    and sums to 1.
    """
    center = window_size // 2
    # math.exp: the notebook imports `math`, not a bare `exp`, so the old code raised NameError.
    gauss = torch.Tensor([math.exp(-(x - center) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
    return gauss / gauss.sum()
 
def create_window(window_size, channel=3, sigma=1.5):
    """Build a depthwise 2-D Gaussian window for F.conv2d(..., groups=channel).

    Args:
        window_size: side length of the square window.
        channel: number of channels (output dim 0); each gets the same kernel.
        sigma: Gaussian standard deviation (previously hard-coded to 1.5, still the default).

    Returns:
        Contiguous float tensor of shape (channel, 1, window_size, window_size): the outer
        product of the 1-D kernel gaussian(window_size, sigma) with itself.
    """
    kernel_1d = gaussian(window_size, sigma).unsqueeze(1)
    kernel_2d = kernel_1d.mm(kernel_1d.t()).float().unsqueeze(0).unsqueeze(0)
    return kernel_2d.expand(channel, 1, window_size, window_size).contiguous()
 
# SSIM
def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
    """Structural similarity (SSIM) between two image batches.

    Args:
        img1, img2: 4-D tensors (N, C, H, W); the shape is unpacked from img1.size().
        window_size: side of the Gaussian window built when `window` is None
            (clipped to the image height and width).
        window: optional precomputed depthwise window for F.conv2d with groups=C.
        size_average: True gives the mean over everything; False gives one value per sample.
        full: if True, also return the mean contrast-sensitivity term `cs`.
        val_range: dynamic range L. When None it is guessed from img1: the upper bound is
            255 if max > 128 else 1, and the lower bound is -1 if min < -0.5 else 0.

    Returns:
        ret, or (ret, cs) when `full` is True.
    """
    if val_range is not None:
        L = val_range
    else:
        upper = 255 if torch.max(img1) > 128 else 1
        lower = -1 if torch.min(img1) < -0.5 else 0
        L = upper - lower

    _, channel, height, width = img1.size()
    if window is None:
        window = create_window(min(window_size, height, width), channel=channel).to(img1.device)

    def _blur(t):
        # Gaussian-weighted local mean, "valid" convolution (no padding), per channel.
        return F.conv2d(t, window, padding=0, groups=channel)

    mu1 = _blur(img1)
    mu2 = _blur(img2)
    mu1_sq, mu2_sq, mu1_mu2 = mu1.pow(2), mu2.pow(2), mu1 * mu2

    sigma1_sq = _blur(img1 * img1) - mu1_sq
    sigma2_sq = _blur(img2 * img2) - mu2_sq
    sigma12 = _blur(img1 * img2) - mu1_mu2

    C1 = (0.01 * L) ** 2
    C2 = (0.03 * L) ** 2

    contrast_num = 2.0 * sigma12 + C2
    contrast_den = sigma1_sq + sigma2_sq + C2
    cs = torch.mean(contrast_num / contrast_den)  # contrast sensitivity

    ssim_map = ((2 * mu1_mu2 + C1) * contrast_num) / ((mu1_sq + mu2_sq + C1) * contrast_den)
    ret = ssim_map.mean() if size_average else ssim_map.mean(1).mean(1).mean(1)
    return (ret, cs) if full else ret
 
class SSIM(torch.nn.Module):
    """nn.Module wrapper around `ssim` that caches the Gaussian window.

    Args:
        window_size: side of the Gaussian window.
        size_average: forwarded to `ssim`.
        val_range: forwarded to `ssim` (None lets `ssim` guess the range from img1).
            The old code stored it but never passed it on.
    """

    def __init__(self, window_size=11, size_average=True, val_range=None):
        super(SSIM, self).__init__()
        self.window_size = window_size
        self.size_average = size_average
        self.val_range = val_range

        # Start with a 1-channel window so the cached window matches self.channel
        # (the old code cached a 3-channel window while claiming 1 channel).
        self.channel = 1
        self.window = create_window(window_size, 1)

    def forward(self, img1, img2):
        """Return ssim(img1, img2) using a window matching img1's channels, dtype and device."""
        channel = img1.size(1)
        window = self.window
        # The window is a plain attribute, so Module.to(device) does not move it:
        # rebuild on any channel, dtype or device mismatch.
        if channel != self.channel or window.dtype != img1.dtype or window.device != img1.device:
            window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
            self.window = window
            self.channel = channel
        return ssim(img1, img2, window=window, window_size=self.window_size,
                    size_average=self.size_average, val_range=self.val_range)

In [None]:
def _denoising_loss(criterion):
    """Return (loss_fn, label) for the criterion names accepted by net_training."""
    if criterion == 'MSE':
        return nn.MSELoss().to(device), 'MSE'
    if criterion == 'SSIM':
        ssim_module = SSIM().to(device)
        return (lambda out, ref: 1 - ssim_module(out, ref)), '1-SSIM'
    if criterion == 'PSNR':
        # PSNR = -10 * log10(MSE) for peak value 1, so minimising 10 * log10(MSE) maximises PSNR.
        return (lambda out, ref: 10 * torch.log10(F.mse_loss(out, ref))), '-PSNR'
    raise ValueError("criterion must be 'MSE', 'SSIM' or 'PSNR', got {!r}".format(criterion))


def _psnr_from_batches(net, noisy_batches, clean_batches):
    """PSNR (peak 1, rounded to 3 decimals) of ``input - net(input)`` from the mean of per-batch MSEs."""
    batch_mse = []
    for noisy, clean in zip(noisy_batches, clean_batches):
        noisy = noisy.float().to(device)
        clean = clean.float().to(device)
        denoised = noisy - net(noisy)
        batch_mse.append(F.mse_loss(denoised, clean).item())
    if not batch_mse:
        return float('nan')
    mse = float(np.mean(batch_mse))
    return float('inf') if mse == 0 else round(-10 * math.log10(mse), 3)


def _show_denoising_sample(noisy, clean, pred_noise):
    """Plot ground truth, noisy input, denoised image, predicted noise and true noise in one row."""
    to_img = transforms.ToPILImage()
    panels = (clean, noisy, noisy - pred_noise, pred_noise, noisy - clean)
    plt.figure(figsize=(20, 100))
    for pos, panel in enumerate(panels, start=1):
        ax = plt.subplot(1, 5, pos)
        # ToPILImage scales floats by 255 and casts to uint8; clamp so out-of-range values
        # (e.g. negative noise) map to valid pixels.
        ax.imshow(to_img(panel.detach().cpu().clamp(0, 1)))
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()


def _show_samples(net, noisy_batches, clean_batches, sample_idx):
    """For every batch with more than `sample_idx` items, plot the panels of that sample."""
    for noisy, clean in zip(noisy_batches, clean_batches):
        if len(noisy) <= sample_idx:
            continue
        noisy = noisy.float().to(device)
        clean = clean.float().to(device)
        pred_noise = net(noisy)
        _show_denoising_sample(noisy[sample_idx], clean[sample_idx], pred_noise[sample_idx])


def net_training(data_noise,data_gt,test_noise,test_gt,epoch,lr,criterion,name,net=dncnn,optimizer='adam',fig_num=range(0,100,9),loss_min=0):
    """Train a noise-predicting denoiser, save it, and report train/test PSNR.

    The network output is trained to match the noise ``input - target``, so the
    denoised image is ``input - net(input)``.

    Args:
        data_noise, data_gt: iterables of noisy / clean training batches. They are zipped in
            lockstep, so both must yield samples in the same order (e.g. shuffle=False loaders).
        test_noise, test_gt: the same for the test set.
        epoch: maximum number of passes over the training batches.
        lr: learning rate.
        criterion: 'MSE', 'SSIM' (loss 1 - SSIM) or 'PSNR' (loss -PSNR with peak 1).
        name: checkpoint stem; the whole module is saved with torch.save to ``name + '.pkl'``.
        net: model to train. The default is the global ``dncnn`` bound when this cell ran.
        optimizer: 'adam', 'sgd', or a ready-made torch optimizer instance.
        fig_num: unused; kept for backward compatibility.
        loss_min: stop training entirely as soon as a batch loss drops below this value.

    Returns:
        List of per-batch training losses (0-d numpy arrays). `net` is left in eval mode.

    Raises:
        ValueError: unknown `criterion` or `optimizer` name.
    """
    loss_fn, loss_label = _denoising_loss(criterion)

    if optimizer == 'adam':
        opt = torch.optim.Adam(net.parameters(), lr=lr)
    elif optimizer == 'sgd':
        opt = optim.SGD(net.parameters(), lr=lr)
    elif isinstance(optimizer, str):
        raise ValueError("optimizer must be 'adam' or 'sgd', got {!r}".format(optimizer))
    else:
        opt = optimizer

    print("Training loop:")
    net.train()
    losses = []
    stop = False
    for _ in range(epoch):
        for noisy, clean in zip(data_noise, data_gt):
            noisy = noisy.float().to(device)
            clean = clean.float().to(device)
            opt.zero_grad()
            loss = loss_fn(net(noisy), noisy - clean)
            loss.backward()
            opt.step()
            losses.append(loss.detach().cpu().numpy())
            # Early stopping must leave both loops, not just the current epoch.
            if loss.item() < loss_min:
                stop = True
                break
        if stop:
            break

    torch.save(net, name + '.pkl')

    plt.plot(losses, '-')
    plt.ylabel(loss_label)
    plt.xlabel('Iteration (mini-batch)')
    plt.show()

    net.eval()
    with torch.no_grad():
        # PSNR is always computed from MSE, whatever the training criterion, and the
        # train and test averages are computed separately.
        print('training PSNR: ', _psnr_from_batches(net, data_noise, data_gt))
        print('test PSNR: ', _psnr_from_batches(net, test_noise, test_gt))

        _show_samples(net, data_noise, data_gt, sample_idx=24)
        print('========' * 5)
        _show_samples(net, test_noise, test_gt, sample_idx=9)
    return losses

In [None]:
# Load images and add additive Gaussian noise.
train_gt_gau,train_noise_gau = dataset(train_set,size=1024,seed=123,noise_type='gaussian',noise_num=1000,multi=False)
test_gt_gau,test_noise_gau = dataset(test_set,size=256,seed=123,noise_type='gaussian',noise_num=1000,multi=False)

batch_size = 64
# shuffle=False on every loader: net_training zips the noisy and clean loaders,
# so they must yield the same samples in the same order.
train_noise_gau_batch = DataLoader(train_noise_gau, batch_size=batch_size, shuffle=False, num_workers=0)
train_gt_gau_batch = DataLoader(train_gt_gau, batch_size=batch_size, shuffle=False, num_workers=0)

test_noise_gau_batch = DataLoader(test_noise_gau, batch_size=batch_size, shuffle=False, num_workers=0)
test_gt_gau_batch = DataLoader(test_gt_gau, batch_size=batch_size, shuffle=False, num_workers=0)

# DnCNN training; the checkpoint is named after the noise actually used (Gaussian, additive, n=1000).
a = net_training(train_noise_gau_batch, train_gt_gau_batch, test_noise_gau_batch, test_gt_gau_batch, net=dncnn,
                 epoch=200, optimizer='adam', criterion='MSE', lr=0.0001, fig_num=range(0, 200, 9), loss_min=1e-6,
                 name='gaussianAdd1000')

# ISTA-Net

In [None]:
# Define ISTA-Net Block
import torch.nn.init as init
import torch.nn.functional as F
class BasicBlock(torch.nn.Module):
    """One unrolled ISTA-Net iteration.

    forward(x, PhiTPhi, PhiTb) performs:
      1. a gradient step on 0.5 * ||Phi x - b||^2:  x - lambda * PhiTPhi x + lambda * PhiTb;
      2. a learned transform (conv_forward);
      3. element-wise soft thresholding  sign(v) * relu(|v| - soft_thr);
      4. the learned inverse transform (conv_backward), whose output is the new estimate.
    It returns [x_pred, symloss, x_st], where symloss = conv_backward(conv_forward(x_input)) - x_input
    is the symmetry residual used as an auxiliary loss.

    Args:
        n_channels: feature channels of the learned transform.
        image_channels: channels of the image estimate.
        kernel_size: conv kernel size; padding is kernel_size // 2, so odd sizes keep H and W.
            It used to be overwritten with 3; the default is still 3.
    """

    def __init__(self,n_channels=128, image_channels=3,kernel_size=3):
        super(BasicBlock, self).__init__()
        padding = kernel_size // 2

        self.lambda_step = nn.Parameter(torch.Tensor([1]))
        self.soft_thr = nn.Parameter(torch.Tensor([0.5]))

        self.conv_forward = nn.Sequential(
            nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True),
        )
        self.conv_backward = nn.Sequential(
            nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=True),
        )

    def forward(self, x, PhiTPhi, PhiTb):
        x = x - self.lambda_step * torch.matmul(PhiTPhi, x)
        x_input = x + self.lambda_step * PhiTb
        # Reshape to a 4-D batch (a no-op for (B, C, H, W) input); the channel count is read
        # from the first conv so the block is not tied to 3 x 64 x 64 images.
        channels = self.conv_forward[0].in_channels
        x_input = x_input.view(-1, channels, x_input.shape[-2], x_input.shape[-1])

        x_forward = self.conv_forward(x_input)

        # Element-wise soft thresholding; torch.matmul here was a batched matrix product.
        x_st = torch.mul(torch.sign(x_forward), F.relu(torch.abs(x_forward) - self.soft_thr))

        x_pred = self.conv_backward(x_st)

        x_est = self.conv_backward(x_forward)
        symloss = x_est - x_input

        return [x_pred, symloss, x_st]

In [None]:
# Define ISTA-Net
class ISTANet(torch.nn.Module):
    """ISTA-Net: `LayerNo` unrolled `BasicBlock` iterations applied to an initial estimate.

    NOTE(review): `BasicBlock` is resolved when the model is constructed, and the ISTA-Net+
    cell later in this notebook rebinds that name. Models built after that cell has run
    get the ISTA-Net+ block.
    """

    def __init__(self, LayerNo):
        super(ISTANet, self).__init__()
        self.LayerNo = LayerNo
        self.fcs = nn.ModuleList([BasicBlock() for _ in range(LayerNo)])

    def forward(self, Phix, Phi, Qinit):
        """Run the unrolled iterations starting from Qinit @ Phix.

        Returns:
            [x_final, layers_sym, layer_st]: the last estimate; the list of per-layer symmetry
            residuals; and the soft-thresholded features of the LAST layer only
            ([] when LayerNo == 0).
        """
        gram = torch.matmul(Phi.T, Phi)
        back_projection = torch.matmul(Phi.T, Phix)
        estimate = torch.matmul(Qinit, Phix)

        sym_residuals = []
        last_st = []
        for idx in range(self.LayerNo):
            estimate, residual, last_st = self.fcs[idx](estimate, gram, back_projection)
            sym_residuals.append(residual)

        return [estimate, sym_residuals, last_st]

In [None]:
def _cs_sensing_matrix(cs):
    """Return the first `cs` rows of a fixed (NumPy seed 123) random orthogonal 64x64 matrix on `device`.

    Reseeds NumPy's global RNG, as the original inline code did.
    """
    np.random.seed(123)
    phi = orth(np.random.normal(size=(64, 64)).T).T
    return torch.tensor(phi[:cs]).float().to(device)


def _cs_measure(Phi, target):
    """Measure `target` column-wise as Phi @ target + 0.01 * N(0, 1) and build the init matrix Qinit.

    NOTE(review): Qinit = X Y^T (Y Y^T)^-1 is computed from the ground-truth `target` of the
    same batch, including at evaluation time, so the initial estimate already uses the answer.
    ISTA-Net presumably estimates Qinit once from training data (TODO confirm). It is kept
    as-is so the numbers stay comparable with CS_istanetp.
    """
    clean = torch.matmul(Phi, target).float()
    Phix = clean + 0.01 * torch.randn_like(clean)
    Phix_t = Phix.transpose(-2, -1)
    Qinit = torch.matmul(target, Phix_t).matmul(torch.linalg.inv(Phix.matmul(Phix_t)))
    return Phix, Qinit


def _cs_show_sample(to_img, target, Phix, x_output, idx):
    """Show target, measurement, reconstruction and its three channels for sample `idx`."""
    panels = [target[idx], Phix[idx], x_output[idx],
              x_output[idx][0], x_output[idx][1], x_output[idx][2]]
    plt.figure(figsize=(20, 120))
    for pos, panel in enumerate(panels, start=1):
        ax = plt.subplot(1, 6, pos)
        plt.imshow(to_img(panel.detach().cpu()))
        ax.set_xticks([])
        ax.set_yticks([])
    plt.show()


def _cs_evaluate(model, Phi, batches, to_img, show_idx=1):
    """Return the mean per-batch reconstruction MSE over `batches`, plotting sample `show_idx` of each batch."""
    mse = []
    with torch.no_grad():
        for batch in batches:
            target = torch.as_tensor(batch).float().to(device)
            # Same noise model as training (the old code used randn_like(Phi), one broadcast draw).
            Phix, Qinit = _cs_measure(Phi, target)
            x_output, _, _ = model(Phix, Phi, Qinit)
            mse.append(torch.mean(torch.pow(x_output - target, 2)).item())
            if len(target) > show_idx:
                _cs_show_sample(to_img, target, Phix, x_output, show_idx)
    return float(np.mean(mse)) if mse else float('nan')


def _cs_psnr(mse):
    """PSNR in dB for peak value 1, rounded to 3 decimals."""
    return float('inf') if mse == 0 else round(-10 * math.log10(mse), 3)


def CS_istanet(img_batch,test_batch,layer_num,learning_rate,epoch,gamma,cs_ratio):
    """Train ISTA-Net for column-wise compressed sensing and report training/test MSE and PSNR.

    Each image is measured as Phi @ X (+ 0.01 Gaussian noise), where Phi is the first
    int(64 * cs_ratio) rows of a fixed random orthogonal 64x64 matrix, so the second-to-last
    image dimension must be 64 (images presumably (B, 3, 64, 64) — TODO confirm).

    Args:
        img_batch, test_batch: iterables of clean image batches.
        layer_num: number of unrolled ISTA-Net layers.
        learning_rate: Adam learning rate.
        epoch: maximum number of passes over img_batch.
        gamma: weight of the symmetry-constraint loss. It was previously ignored in favour of
            a hard-coded 0.01, so passing gamma=0.01 reproduces the old loss.
        cs_ratio: measurement ratio in (0, 1].

    Returns:
        List of per-iteration training losses (0-d numpy arrays).

    Side effects: saves the model to 'ISTA-Net{layer_num}{cs_ratio}.pkl', reseeds NumPy's
    global RNG, prints metrics and shows figures.
    """
    cs = int(64 * cs_ratio)
    model = ISTANet(layer_num).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    to_img = transforms.ToPILImage()
    Phi = _cs_sensing_matrix(cs)

    loss_list = []
    prev_loss = 0.0
    stop = False
    for _ in range(epoch):
        for batch in img_batch:
            target = torch.as_tensor(batch).float().to(device)
            Phix, Qinit = _cs_measure(Phi, target)
            x_output, loss_layers_sym, loss_st = model(Phix, Phi, Qinit)

            loss_discrepancy = torch.mean(torch.pow(x_output - target, 2))
            loss_constraint = sum(torch.mean(torch.pow(sym, 2)) for sym in loss_layers_sym)
            # loss_st is the last layer's x_st tensor; iterating over it walks the batch dimension.
            sparsity_constraint = sum(torch.mean(torch.abs(st)) for st in loss_st)
            loss_all = loss_discrepancy + gamma * loss_constraint + 0.001 * sparsity_constraint

            optimizer.zero_grad()
            loss_all.backward()
            optimizer.step()

            loss_value = loss_all.detach().cpu().numpy()
            loss_list.append(loss_value)
            # Stop once two consecutive batch losses differ by less than 1e-8.
            if abs(prev_loss - float(loss_value)) < 1e-8:
                stop = True
                break
            prev_loss = float(loss_value)
        if stop:
            break

    torch.save(model, 'ISTA-Net' + str(layer_num) + str(cs_ratio) + '.pkl')

    print('layer_num: ', layer_num, ' lr: ', learning_rate)
    plt.plot(range(len(loss_list)), loss_list, '-')
    plt.ylabel('training loss')
    plt.show()
    if loss_list:
        print(' iter:', len(loss_list), ' loss:', loss_list[-1])

    model.eval()
    train_mse = _cs_evaluate(model, Phi, img_batch, to_img)
    print('training MSE: ', train_mse)
    print('training PSNR: ', _cs_psnr(train_mse))
    print('========' * 5)

    test_mse = _cs_evaluate(model, Phi, test_batch, to_img)
    print('test MSE: ', test_mse)
    print('test PSNR: ', _cs_psnr(test_mse))
    print('======' * 5)
    return loss_list

In [None]:
%%time
CS_RATIOS = [0.25, 0.5, 0.75, 1]
ISTA_LAYERS = [3]

# Loss curves keyed by CS ratio only: with several layer counts, later runs would overwrite earlier ones.
istanet = {}
for i in CS_RATIOS:
    for j in ISTA_LAYERS:
        istanet[str(i)] = CS_istanet(
            train_gt_gau_batch, test_gt_gau_batch,
            layer_num=j, learning_rate=0.0001, epoch=200, gamma=0.01, cs_ratio=i,
        )
        print('CS ratio: ', i, '\n Layer number: ', j)

# ISTA-Net+

In [None]:
# Define ISTA-Net+ Block
import torch.nn.init as init
import torch.nn.functional as F
class BasicBlock(torch.nn.Module):
    """One unrolled ISTA-Net+ iteration.

    Same as the ISTA-Net block except that the inverse transform predicts a residual:
    x_pred = relu(x_input + conv_backward(x_st)).

    forward(x, PhiTPhi, PhiTb) performs:
      1. a gradient step on 0.5 * ||Phi x - b||^2:  x - lambda * PhiTPhi x + lambda * PhiTb;
      2. a learned transform (conv_forward);
      3. element-wise soft thresholding  sign(v) * relu(|v| - soft_thr);
      4. the learned inverse transform (conv_backward) plus the residual connection.
    It returns [x_pred, symloss, x_st], where symloss = conv_backward(conv_forward(x_input)) - x_input.

    NOTE(review): this rebinds the name `BasicBlock` defined for ISTA-Net earlier in the
    notebook. ISTANet models constructed after this cell runs will use this variant.

    Args:
        n_channels: feature channels of the learned transform.
        image_channels: channels of the image estimate.
        kernel_size: conv kernel size; padding is kernel_size // 2, so odd sizes keep H and W.
            It used to be overwritten with 3; the default is still 3.
    """

    def __init__(self,n_channels=128, image_channels=3,kernel_size=3):
        super(BasicBlock, self).__init__()
        padding = kernel_size // 2

        self.lambda_step = nn.Parameter(torch.Tensor([1]))
        self.soft_thr = nn.Parameter(torch.Tensor([0.5]))

        self.conv_forward = nn.Sequential(
            nn.Conv2d(in_channels=image_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True),
        )
        self.conv_backward = nn.Sequential(
            nn.Conv2d(in_channels=n_channels, out_channels=n_channels, kernel_size=kernel_size, padding=padding, bias=True),
            nn.ReLU(),
            nn.Conv2d(in_channels=n_channels, out_channels=image_channels, kernel_size=kernel_size, padding=padding, bias=True),
        )

    def forward(self, x, PhiTPhi, PhiTb):
        x = x - self.lambda_step * torch.matmul(PhiTPhi, x)
        x_input = x + self.lambda_step * PhiTb
        # Reshape to a 4-D batch (a no-op for (B, C, H, W) input); the channel count is read
        # from the first conv so the block is not tied to 3 x 64 x 64 images.
        channels = self.conv_forward[0].in_channels
        x_input = x_input.view(-1, channels, x_input.shape[-2], x_input.shape[-1])

        x_forward = self.conv_forward(x_input)

        # Element-wise soft thresholding; torch.matmul here was a batched matrix product.
        x_st = torch.mul(torch.sign(x_forward), F.relu(torch.abs(x_forward) - self.soft_thr))

        x_backward = self.conv_backward(x_st)
        x_pred = F.relu(x_input + x_backward)

        x_est = self.conv_backward(x_forward)
        symloss = x_est - x_input

        return [x_pred, symloss, x_st]

# Define ISTA-Net+
class ISTANetp(torch.nn.Module):
    """ISTA-Net+: `LayerNo` unrolled `BasicBlock` iterations (the ISTA-Net+ block defined in this cell)."""

    def __init__(self, LayerNo):
        super(ISTANetp, self).__init__()
        self.LayerNo = LayerNo
        self.fcs = nn.ModuleList([BasicBlock() for _ in range(LayerNo)])

    def forward(self, Phix, Phi, Qinit):
        """Run the unrolled iterations starting from Qinit @ Phix.

        Returns:
            [x_final, layers_sym, layer_st]: the last estimate; the list of per-layer symmetry
            residuals; and the soft-thresholded features of the LAST layer only
            ([] when LayerNo == 0).
        """
        gram = torch.matmul(Phi.T, Phi)
        back_projection = torch.matmul(Phi.T, Phix)
        estimate = torch.matmul(Qinit, Phix)

        sym_residuals = []
        last_st = []
        for idx in range(self.LayerNo):
            estimate, residual, last_st = self.fcs[idx](estimate, gram, back_projection)
            sym_residuals.append(residual)

        return [estimate, sym_residuals, last_st]

def CS_istanetp(img_batch,test_batch,layer_num,learning_rate,epoch,gamma,cs_ratio):
    """Train ISTA-Net+ for compressive sensing and report train/test MSE and PSNR.

    Measurements are ``Phi @ target + 0.01 * noise``, where ``Phi`` is the first
    ``int(64 * cs_ratio)`` rows of a seeded random orthogonal 64x64 matrix.

    Args:
        img_batch: iterable (e.g. DataLoader) of training batches; assumed to be
            (B, 3, 64, 64) images with peak value 1 -- TODO confirm against ``dataset``.
        test_batch: iterable of test batches with the same layout.
        layer_num: number of ISTA-Net+ phases.
        learning_rate: Adam learning rate.
        epoch: maximum number of passes over ``img_batch``.
        gamma: weight of the symmetry constraint (0.01 matches the former
            hard-coded weight).
        cs_ratio: fraction of the 64 rows that are measured.

    Returns:
        List of per-iteration training losses.

    Side effects: saves the whole model to ``'ISTA-Netp<layer_num><cs_ratio>.pkl'``,
    prints metrics and shows matplotlib figures.
    """
    cs = int(64*cs_ratio)
    model = ISTANetp(layer_num).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    to_img = transforms.ToPILImage()

    def make_phi():
        # Seeded so training and both evaluations use the same sensing matrix.
        np.random.seed(123)
        phi = orth(np.random.normal(size=(64,64)).T).T
        return torch.tensor(phi[:cs]).float().to(device)

    def measure(batch):
        # Noise takes the measurements' own shape; randn_like(Phi) broadcast one
        # (cs, 64) noise pattern over every image and channel of the batch.
        target = torch.as_tensor(batch).float().to(device)
        clean = torch.matmul(Phi, target)
        return target, clean + 0.01 * torch.randn_like(clean)

    def fit_qinit():
        # Least-squares initialiser Q (64 x cs) with target ~= Q @ (Phi @ target),
        # fitted once on noise-free training measurements as in ISTA-Net. The former
        # per-batch Qinit used each batch's own ground truth, so Qinit @ Phix already
        # reproduced the target (exactly when cs == 64), even at test time.
        xy = torch.zeros(64, cs, device=device)
        yy = torch.zeros(cs, cs, device=device)
        with torch.no_grad():
            for batch in img_batch:
                target = torch.as_tensor(batch).float().to(device)
                y = torch.matmul(Phi, target)
                xy += target.matmul(y.transpose(-1, -2)).reshape(-1, 64, cs).sum(dim=0)
                yy += y.matmul(y.transpose(-1, -2)).reshape(-1, cs, cs).sum(dim=0)
        return xy.matmul(torch.linalg.inv(yy))

    def show(target, phix, x_output):
        # Panels: ground truth, measurements, reconstruction, reconstruction channels 0-2.
        idx = 1 if target.shape[0] > 1 else 0
        panels = [target[idx], phix[idx], x_output[idx],
                  x_output[idx][0], x_output[idx][1], x_output[idx][2]]
        plt.figure(figsize=(20,120))
        for k, panel in enumerate(panels, start=1):
            ax = plt.subplot(1, 6, k)
            # ToPILImage scales floats by 255 before a uint8 cast; clamp so values
            # outside [0, 1] display sensibly (display only).
            plt.imshow(to_img(panel.detach().clamp(0, 1).cpu()))
            ax.set_xticks([])
            ax.set_yticks([])
        plt.show()

    def evaluate(loader):
        # Returns (batch-averaged MSE, PSNR in dB assuming peak value 1).
        torch.manual_seed(123)  # reproducible measurement noise
        model.eval()
        mse = []
        with torch.no_grad():  # evaluation needs no autograd graph
            for batch in loader:
                target, phix = measure(batch)
                x_output = model(phix, Phi, Qinit)[0]
                mse.append(torch.mean(torch.pow(x_output - target, 2)).cpu().numpy())
                show(target, phix, x_output)
        model.train()
        mean_mse = np.mean(mse)
        return mean_mse, round(20 * math.log10(1 / math.sqrt(mean_mse)), 3)

    Phi = make_phi()
    Qinit = fit_qinit()

    loss_list = []
    loss_old = None
    stop = False
    for _ in range(epoch):
        for batch in img_batch:
            target, phix = measure(batch)
            x_output, loss_layers_sym, loss_st = model(phix, Phi, Qinit)

            loss_discrepancy = torch.mean(torch.pow(x_output - target, 2))
            loss_constraint = sum(torch.mean(torch.pow(s, 2)) for s in loss_layers_sym)
            sparsity_constraint = sum(torch.mean(torch.abs(s)) for s in loss_st)
            loss_all = loss_discrepancy + gamma * loss_constraint + 0.001 * sparsity_constraint

            optimizer.zero_grad()
            loss_all.backward()
            optimizer.step()

            loss_value = loss_all.detach().cpu().numpy()
            loss_list.append(loss_value)
            # Early stop once consecutive batch losses stop changing.
            if loss_old is not None and np.abs(loss_old - loss_value) < 1e-8:
                stop = True
                break
            loss_old = loss_value
        if stop:
            break

    path = 'ISTA-Netp'+ str(layer_num) + str(cs_ratio) + '.pkl'
    # NOTE(review): Qinit is not saved; reloading requires recomputing it from the
    # same seeded Phi and training data.
    torch.save(model,path)

    print('layer_num: ',layer_num,' lr: ',learning_rate)
    plt.plot(range(len(loss_list)),loss_list, '-')
    plt.ylabel('training loss')
    plt.show()
    if loss_list:
        print(' iter:',len(loss_list), ' MSE:',loss_list[-1])

    for label, loader, rule in (('training', img_batch, '========'), ('testing', test_batch, '======')):
        MSE, PSNR = evaluate(loader)
        print(label + ' MSE: ', MSE)
        print(label + ' PSNR: ', PSNR)
        print(rule*5)
    return loss_list

In [None]:
%%time
istanetp = {}
for i in [0.25,0.5,0.75,1]:
    for j in [3]:
        istanetp[str(i)] = CS_istanetp(train_gt_gau_batch,test_gt_gau_batch,layer_num=j,learning_rate=0.0001,epoch=200,gamma=0.01,cs_ratio=i)
        print('CS ratio: ',i,'\n layer number: ',j)

# FISTA-Net

In [None]:
# define basic block of FISTA-Net
class BasicBlock(nn.Module):
    """One FISTA-Net phase: gradient step, learned transform, soft threshold, inverse.

    Args:
        n_channels: feature channels of the learned transform.
        image_channels: channels of the image estimate (``forward`` reshapes to
            (-1, 3, 64, 64), so only 3 works there as written).
        kernel_size: convolution kernel size.
        padding: convolution padding (the defaults keep the spatial size).
    """

    def __init__(self, n_channels=128, image_channels=3, kernel_size=3,padding=1):
        super(BasicBlock, self).__init__()
        self.Sp = nn.Softplus()

        # Not read by forward() (FISTANet passes step and threshold in); kept so the
        # parameter list and saved state_dict keys stay unchanged.
        self.lambda_step = nn.Parameter(torch.Tensor([1]))
        self.soft_thr = nn.Parameter(torch.Tensor([0.5]))

        def conv(c_in, c_out):
            return nn.Conv2d(in_channels=c_in,out_channels=c_out,kernel_size=kernel_size,padding=padding,bias=True)

        # image -> features: four convolutions with ReLU in between
        self.conv_forward = nn.Sequential(
            conv(image_channels, n_channels), nn.ReLU(),
            conv(n_channels, n_channels), nn.ReLU(),
            conv(n_channels, n_channels), nn.ReLU(),
            conv(n_channels, n_channels),
        )
        # features -> image: mirror of conv_forward
        self.conv_backward = nn.Sequential(
            conv(n_channels, n_channels), nn.ReLU(),
            conv(n_channels, n_channels), nn.ReLU(),
            conv(n_channels, n_channels), nn.ReLU(),
            conv(n_channels, image_channels),
        )

    def forward(self, x, PhiTPhi, PhiTb, lambda_step, soft_thr):
        """Run one phase.

        Args:
            x: current (momentum) estimate, reshapeable to (-1, 3, 64, 64).
            PhiTPhi: Phi^T Phi.
            PhiTb: back-projected measurements Phi^T b.
            lambda_step: raw step size; softplus makes it positive.
            soft_thr: raw threshold; softplus makes it positive.

        Returns:
            [x_pred, symloss, x_st]: non-negative estimate, symmetry residual
            ``conv_backward(conv_forward(x_input)) - x_input``, thresholded features.
        """
        step = self.Sp(lambda_step)
        # Gradient step: x - step * (Phi^T Phi x - Phi^T b)
        x = x - step * PhiTPhi.matmul(x) + step * PhiTb

        x_input = x.view(-1,3,64,64)

        x_forward = self.conv_forward(x_input)

        # Element-wise soft thresholding
        x_st = torch.mul(torch.sign(x_forward), F.relu(torch.abs(x_forward) - self.Sp(soft_thr)))

        x_backward = self.conv_backward(x_st)

        # Skip connection; ReLU keeps the output non-negative.
        x_pred = F.relu(x_input + x_backward)

        # Symmetry residual. The previous code overwrote x_input with this estimate
        # and returned x_input - x_input, i.e. a loss that was always zero.
        x_est = self.conv_backward(x_forward)
        symloss = x_est - x_input

        return [x_pred, symloss, x_st]

class FISTANet(nn.Module):
    """FISTA-Net: ``LayerNo`` unrolled FISTA phases with learned step, threshold, momentum.

    Per phase i the raw threshold is ``w_theta*i + b_theta`` and the raw step is
    ``w_mu*i + b_mu`` (both passed through softplus inside ``BasicBlock``).
    """

    def __init__(self, LayerNo):
        super(FISTANet, self).__init__()
        self.LayerNo = LayerNo
        self.fcs = nn.ModuleList([BasicBlock() for _ in range(LayerNo)])

        # thresholding value
        self.w_theta = nn.Parameter(torch.Tensor([-0.5]))
        self.b_theta = nn.Parameter(torch.Tensor([-2]))
        # gradient step
        self.w_mu = nn.Parameter(torch.Tensor([-0.2]))
        self.b_mu = nn.Parameter(torch.Tensor([0.1]))
        # two-step update weight
        self.w_rho = nn.Parameter(torch.Tensor([0.5]))
        self.b_rho = nn.Parameter(torch.Tensor([0]))

        self.Sp = nn.Softplus()

    def forward(self, Phix, Phi, Qinit):
        """Reconstruct from measurements ``Phix``.

        Returns:
            [x_final, layers_sym, layers_st]: final estimate (``Qinit @ Phix`` when
            LayerNo == 0), per-phase symmetry residuals, per-phase thresholded features.
        """
        PhiTPhi = torch.matmul(Phi.T, Phi)
        PhiTb = torch.matmul(Phi.T, Phix)

        x = torch.matmul(Qinit, Phix)

        xold = x
        y = x
        xnew = x            # returned as-is when there are no phases
        layers_sym = []     # for computing symmetric loss
        layers_st = []      # for computing sparsity constraint

        for i, block in enumerate(self.fcs):
            theta_ = self.w_theta * i + self.b_theta
            mu_ = self.w_mu * i + self.b_mu
            xnew, layer_sym, layer_st = block(y, PhiTPhi, PhiTb, mu_, theta_)
            # Momentum weight 1 - Sp(b)/Sp(w*i + b): 0 at the first phase, growing
            # with i for w > 0. The old Sp(w + b) numerator made it negative at i = 0.
            rho_ = 1 - self.Sp(self.b_rho) / self.Sp(self.w_rho * i + self.b_rho)
            y = xnew + rho_ * (xnew - xold)  # two-step update
            xold = xnew
            layers_st.append(layer_st)
            layers_sym.append(layer_sym)

        return [xnew, layers_sym, layers_st]

def CS_fistanet(img_batch,test_batch,layer_num,learning_rate,epoch,gamma,cs_ratio):
    """Train FISTA-Net for compressive sensing and report train/test MSE and PSNR.

    Measurements are ``Phi @ target + 0.01 * noise``, where ``Phi`` is the first
    ``int(64 * cs_ratio)`` rows of a seeded random orthogonal 64x64 matrix.

    Args:
        img_batch: iterable (e.g. DataLoader) of training batches; assumed to be
            (B, 3, 64, 64) images with peak value 1 -- TODO confirm against ``dataset``.
        test_batch: iterable of test batches with the same layout.
        layer_num: number of FISTA-Net phases.
        learning_rate: Adam learning rate.
        epoch: maximum number of passes over ``img_batch``.
        gamma: weight of the symmetry constraint (0.01 matches the former
            hard-coded weight).
        cs_ratio: fraction of the 64 rows that are measured.

    Returns:
        List of per-iteration training losses.

    Side effects: saves the state_dict to
    ``'FISTA-Net_parameter<layer_num><cs_ratio>.pkl'``, prints metrics and shows
    matplotlib figures.
    """
    cs = int(64*cs_ratio)
    model = FISTANet(layer_num).to(device)
    optimizer = torch.optim.Adam(model.parameters(),lr=learning_rate)
    to_img = transforms.ToPILImage()

    def make_phi():
        # Seeded so training and both evaluations use the same sensing matrix.
        np.random.seed(123)
        phi = orth(np.random.normal(size=(64,64)).T).T
        return torch.tensor(phi[:cs]).float().to(device)

    def measure(batch):
        # Noise takes the measurements' own shape; randn_like(Phi) broadcast one
        # (cs, 64) noise pattern over every image and channel of the batch.
        target = torch.as_tensor(batch).float().to(device)
        clean = torch.matmul(Phi, target)
        return target, clean + 0.01 * torch.randn_like(clean)

    def fit_qinit():
        # Least-squares initialiser Q (64 x cs) with target ~= Q @ (Phi @ target),
        # fitted once on noise-free training measurements as in ISTA-Net. The former
        # per-batch Qinit used each batch's own ground truth, so Qinit @ Phix already
        # reproduced the target (exactly when cs == 64), even at test time.
        xy = torch.zeros(64, cs, device=device)
        yy = torch.zeros(cs, cs, device=device)
        with torch.no_grad():
            for batch in img_batch:
                target = torch.as_tensor(batch).float().to(device)
                y = torch.matmul(Phi, target)
                xy += target.matmul(y.transpose(-1, -2)).reshape(-1, 64, cs).sum(dim=0)
                yy += y.matmul(y.transpose(-1, -2)).reshape(-1, cs, cs).sum(dim=0)
        return xy.matmul(torch.linalg.inv(yy))

    def show(target, phix, x_output):
        # Panels: ground truth, measurements, reconstruction, reconstruction channels 0-2.
        idx = 1 if target.shape[0] > 1 else 0
        panels = [target[idx], phix[idx], x_output[idx],
                  x_output[idx][0], x_output[idx][1], x_output[idx][2]]
        plt.figure(figsize=(20,120))
        for k, panel in enumerate(panels, start=1):
            ax = plt.subplot(1, 6, k)
            # ToPILImage scales floats by 255 before a uint8 cast; clamp so values
            # outside [0, 1] display sensibly (display only).
            plt.imshow(to_img(panel.detach().clamp(0, 1).cpu()))
            ax.set_xticks([])
            ax.set_yticks([])
        plt.show()

    def evaluate(loader):
        # Returns (batch-averaged MSE, PSNR in dB assuming peak value 1).
        torch.manual_seed(123)  # reproducible measurement noise
        model.eval()
        mse = []
        with torch.no_grad():  # evaluation needs no autograd graph
            for batch in loader:
                target, phix = measure(batch)
                x_output = model(phix, Phi, Qinit)[0]
                mse.append(torch.mean(torch.pow(x_output - target, 2)).cpu().numpy())
                show(target, phix, x_output)
        model.train()
        mean_mse = np.mean(mse)
        return mean_mse, round(20 * math.log10(1 / math.sqrt(mean_mse)), 3)

    Phi = make_phi()
    Qinit = fit_qinit()

    loss_list = []
    loss_old = None
    stop = False
    for _ in range(epoch):
        for batch in img_batch:
            target, phix = measure(batch)
            x_output, loss_layers_sym, loss_st = model(phix, Phi, Qinit)

            # Data consistency plus symmetry and sparsity regularisers
            loss_discrepancy = torch.mean(torch.pow(x_output - target, 2))
            loss_constraint = sum(torch.mean(torch.pow(s, 2)) for s in loss_layers_sym)
            sparsity_constraint = sum(torch.mean(torch.abs(s)) for s in loss_st)
            loss_all = loss_discrepancy + gamma * loss_constraint + 0.001 * sparsity_constraint

            optimizer.zero_grad()
            loss_all.backward()
            optimizer.step()

            loss_value = loss_all.detach().cpu().numpy()
            loss_list.append(loss_value)
            # Early stop once consecutive batch losses stop changing.
            if loss_old is not None and np.abs(loss_old - loss_value) < 1e-8:
                stop = True
                break
            loss_old = loss_value
        if stop:
            break

    path = 'FISTA-Net_parameter'+ str(layer_num) + str(cs_ratio) + '.pkl'
    # NOTE(review): Qinit is not saved; reloading requires recomputing it from the
    # same seeded Phi and training data.
    torch.save(model.state_dict(),path)

    print('layer_num: ',layer_num,' lr: ',learning_rate)
    plt.plot(range(len(loss_list)),loss_list, '-')
    plt.ylabel('training loss')
    plt.show()
    if loss_list:
        print(' iter:',len(loss_list), ' MSE:',loss_list[-1])

    for label, loader, rule in (('training', img_batch, '========'), ('testing', test_batch, '======')):
        MSE, PSNR = evaluate(loader)
        print(label + ' MSE: ', MSE)
        print(label + ' PSNR: ', PSNR)
        print(rule*5)
    return loss_list

In [None]:
# Addictive Gaussian
train_gt_gau,train_noise_gau = dataset(train_set,size=1024,seed=123,noise_type='gaussian',noise_num=500,multi=True)
test_gt_gau,test_noise_gau = dataset(test_set,size=256,seed=123,noise_type='gaussian',noise_num=500,multi=True)
batch_size = 64
train_noise_gau_batch = DataLoader(train_noise_gau, batch_size=batch_size, shuffle=False)
train_gt_gau_batch = DataLoader(train_gt_gau, batch_size=batch_size, shuffle=False)

test_noise_gau_batch = DataLoader(test_noise_gau, batch_size=batch_size, shuffle=False,num_workers=0)
test_gt_gau_batch = DataLoader(test_gt_gau, batch_size=batch_size, shuffle=False,num_workers=0)

%%time
fista = {}
for i in [0.25,0.5,0.75,1]:
    for j in [3]:
        fista[str(i)]=CS_fistanet(train_gt_gau_batch,test_gt_gau_batch,layer_num=j,learning_rate=0.0001,epoch=200,gamma=0.01,cs_ratio=i)
        print('CS ratio: ',i,'\n layer number: ',j)