# DISCUS Implementation for Shepp-Logan Phantom: Study-I




### Load utilities and libraries

In [None]:
########################################################## Import libraries
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline
from skimage.metrics import structural_similarity as ssim
import numpy as np
import pickle
# from models.resnet import ResNet
# from models.unet import UNet
from models.skip import skip
import torch
from torch import optim

from scipy.io import savemat, loadmat



########################################################## from utils.inpainting_utils import *
from utils_SL.common_utils import *
from utils_SL.fftc import * # ra: added the pytorch fft routine from fastmri

import os# os.environ['CUDA_VISIBLE_DEVICES'] = '1'

# Turn cuDNN on and let it benchmark/auto-select convolution algorithms.
torch.backends.cudnn.enabled = torch.backends.cudnn.benchmark = True
# Tensors and modules below are cast with `.type(dtype)`, i.e. onto the CUDA float type.
dtype = torch.cuda.FloatTensor

PLOT = True  # show intermediate plots/metrics during optimization

### Check available hardware:

In [None]:
# check available hardware:
def list_cuda_devices():
    """Print the number of visible CUDA devices and, for each, its name and memory figures.

    Prints a single notice instead when CUDA is not available.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available.")
        return

    device_count = torch.cuda.device_count()
    print(f"Number of available CUDA devices: {device_count}")
    mib = 1024 ** 2  # bytes per MiB (reported as "MB")
    for idx in range(device_count):
        print(f"Device {idx}: {torch.cuda.get_device_name(idx)}")
        print(f"  Memory Allocated: {torch.cuda.memory_allocated(idx) / mib:.2f} MB")
        print(f"  Memory Cached: {torch.cuda.memory_reserved(idx) / mib:.2f} MB")
        print(f"  Total Memory: {torch.cuda.get_device_properties(idx).total_memory / mib:.2f} MB")

list_cuda_devices()

In [None]:
# print(torch.cuda.device_count()) # restart shell if it doesnt show 2
# device=0
# torch.cuda.set_device(device) # 0/1
# print(torch.cuda.current_device())


## Set study parameters and select dataset:

In [None]:
########################################################## Adjust parameters
data_path = '../data/SL-ph/'
R = 2   # net acceleration rate (also part of the data file names)
gm = 1  # gamma exponent applied to magnitude images for display
N = 64  # number of repetitions = frames

# Series to process; one of "rotation", "translation", "both".
series = 'both'

sv = 1  # 1 -> write reconstructions / models to disk
data_path_r = f"{data_path}{series}/"




### Load simulated data:


In [None]:
def _load_npy(stem):
  """Load `<data_path_r><stem>.npy` for the currently selected series."""
  return np.load(data_path_r + stem + '.npy')

# The name patterns must match the existing files on disk, including the
# unusual '%f' formatting of R in the mask file name.
yN   = _load_npy('y_N_%d_R_%d' % (N, R))
ynN  = _load_npy('yn_N_%d_R_%d' % (N, R))
yuN  = _load_npy('yu_N_%d_R_%d' % (N, R))
mskN = _load_npy('mask_R_%fN_%d_phantom' % (R, N))
xN   = _load_npy('xRef_N_%d' % N)
n = xN.shape[1:]  # trailing (spatial) dimensions of the reference stack
print("Image size: ", n)

### Compressed-sensing (CS) reconstruction baseline (per-frame L1 ADMM):

In [None]:
csRe = 1
tau = 1e-2 # regularization strength

########################################################## CS Recon
# Per-frame compressed-sensing baseline: every frame is reconstructed on its own by
# `admm_l1` (not defined in this notebook; presumably provided by the wildcard
# utils_SL imports) and scored against the reference stack xN.

if csRe==1:
  # ADMM L2-L1 settings; identical for every frame, so set them once.
  mu = 1e-1 # lagrangian parameter
  nIter = [50, 5] # outer and inner iterations
  ss = 0.9 # step size

  # Complex frames are stored as channel pairs: [2*i] = real part, [2*i+1] = imaginary part.
  nmse = np.zeros([N,1])
  ssm = np.zeros([N,1])
  xHat = np.zeros((2*N,n[0],n[1]))
  xHatAbs = np.zeros((N,n[0],n[1]))
  xAbs = np.zeros((N,n[0],n[1]))
  errMap = np.zeros((2*N,n[0],n[1]))

  for i in range(N):
    print('Slice: %2d' %(i+1), 'out of %2d' %N)
    x0 = (np.zeros(n)).astype(complex)  # fresh zero initial guess for every frame
    y = yuN[2*i,:,:] + 1j*yuN[2*i+1,:,:]
    msk = mskN[i, :,:]

    [xTmp,loss] = admm_l1(x0, y, msk, nIter, ss, mu, tau)
    xHat[2*i,:,:]   = np.real(xTmp)
    xHat[2*i+1,:,:] = np.imag(xTmp)
    xHatAbs[i,:,:] = np.sqrt(xHat[i*2,:,:]**2 + xHat[i*2+1,:,:]**2)
    xAbs[i,:,:] = np.sqrt(xN[i*2,:,:]**2 + xN[i*2+1,:,:]**2)

    ref = xN[i*2:(i+1)*2,:,:]
    est = xHat[i*2:(i+1)*2,:,:]
    nmse[i] = np.mean((ref-est)**2) / np.mean(ref**2)
    # NOTE: data_range is taken from the estimate's magnitude range, not the reference's.
    ssm[i] = ssim(xHatAbs[i,:,:], xAbs[i,:,:], data_range = xHatAbs[i,:,:].max() - xHatAbs[i,:,:].min())
    errMap[i*2:(i+1)*2,:,:] = ref-est

  # Convergence curve of the LAST frame only (`loss` is overwritten for every frame).
  fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
  plt.plot(np.log10(loss))
  plt.xlabel("No. of iterations")
  plt.ylabel("Total loss")
  plt.show()

  # nmse/ssm are (N, 1) arrays: iterate over flat scalars so the %-formatting does not
  # rely on NumPy's deprecated conversion of ndim > 0 arrays to float.
  print('Mean nmse: %1.2f,' %(10*np.log10(np.mean(nmse))), 'nmse: ',', '.join('%1.2f' % (10*np.log10(v)) for v in nmse.ravel()))
  print('Mean ssim: %1.3f,' %(np.mean(ssm)), 'ssim: ',', '.join('%1.3f' % v for v in ssm.ravel()))

  # Real/imag channels of all frames side by side.
  fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
  plt.imshow(np.reshape(np.transpose(xHat,[1,0,2]), [n[0],n[1]*2*N]), vmin=-1, vmax=1, cmap=plt.cm.Greys_r)
  plt.show()

  # Error maps (reference - estimate) for the same channels.
  fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
  plt.imshow(np.reshape(np.transpose(errMap,[1,0,2]), [n[0],n[1]*2*N]), vmin=-0.1, vmax=0.1, cmap=plt.cm.Greys_r)
  plt.show()

  # Magnitude of frame 0: reference next to the CS estimate.
  fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
  plt.imshow(np.reshape(np.transpose(np.concatenate((xAbs[0:1,:,:]**gm, xHatAbs[0:1,:,:]**gm), axis=2),[1,0,2]), [n[0],n[1]*2]), vmin=0, vmax=0.7, cmap=plt.cm.Greys_r)
  plt.show()

  if sv==1:
    np.save(data_path_r + 'xRefL1_'+'N_%d' % N + '.npy', xN)
    np.save(data_path_r + 'xHatL1_'+'N_%d' % N + '.npy', xHat)

In [None]:
# print(np.max(takeMag(yuN)))
# print(np.max(takeMag(xN)))



In [None]:
# N=64

# Optimization schedule
num_iter = 12000    # number of optimizer iterations (15000 was also tried)
show_every = 1000   # print/plot progress every `show_every` iterations
LR = 1e-4           # learning rate (original DIP default: 1e-2)

# Input / latent regularization
z_lamb0 = 4.3e2     # weight of the penalty on the frame-specific codes z
reg_sig0 = 0.005    # initial std of the noise added to the network inputs
z_sc = 0.1          # scale of the initial frame-specific codes

WtD = 2*1e-6        # weight decay
opt = 2             # select the flavor of the algorithm

# Network architecture
LSz = 128           # Number of channels in hidden layers
NLy = 7             # length of the per-level channel/filter lists passed to skip()

In [None]:
# print(n)

### DIP and DISCUS

In [None]:
# Explicit import: do not rely on the wildcard utils_SL imports to provide `time`.
import time

# get the start time (wall clock for the whole DIP/DISCUS run)
st = time.time()

########################################################## Image selection
# Number of outer-loop runs: opt 0 fits one network per image (L = N, one image per run);
# every other flavor fits a single network to all images jointly (L = 1).
if opt==0:
  L = N
else:
  L = 1
  NInd = 0 # If '0' don't select an individual image

for el in range(L):
  ########################################################## Configuration for this run
  # Each `opt` fixes the input/output layout of the network:
  #   Ns   - input channels shared by all frames (optimized tensor `zs`)
  #   Nz   - frame-specific input channels (optimized tensor `z`)
  #   Nout - number of frames produced per forward pass (2*Nout output channels,
  #          one real/imaginary pair per frame)
  #   Nin  - total number of network input channels
  if opt==0:
    NInd = el+1 # pick one image from N images
    Ns   = 4 # Number of channels common to all images
    Nz   = 0 # Number of image specific channels
    Nout = 1 # Number of outputs, 1 or N (2*Nout = output channels)
    Nin = Nz+Ns
  elif opt==1: # Fix noise stack in with image specific output channels
    Ns   = 4
    Nz   = 0
    Nout = N
    Nin = Nz+Ns
  elif opt==2: # [Ns; Nz] in with a single output channel
    Ns   = 3
    Nz   = 1
    Nout = 1
    Nin = Nz+Ns
  elif opt==3: # [Ns; Nz] in with image specific output channels
    Ns   = 3
    Nz   = 1
    Nout = N
    Nin = Nz+Ns
  elif opt==4: # [Ns + Nz] in with a single output channel
    Ns   = 1
    Nz   = Ns
    Nout = 1
    Nin = Nz
  elif opt==5: # H[Nz] in with image-specific first layer and single output channel
    Ns   = 0
    Nz   = 2
    Nout = 1
    Nin = Nz*N
  elif opt==6: # H[Nz] in with image-specific first layer and image specific output channels
    Ns   = 0
    Nz   = 2
    Nout = N
    Nin = Nz*N

  ########################################################## Data selection
  if NInd==0: # process all images
    x  = xN
    y  = yN
    yn = ynN
    yu = yuN
    msk= mskN
  elif NInd > 0: # Process only one image
    # NOTE: this rebinds the notebook-global N, so N == 1 after this cell when opt == 0.
    N  = 1
    x  = xN[(NInd-1)*2:(NInd-1)*2+2]
    y  = yN[(NInd-1)*2:(NInd-1)*2+2]
    yn = ynN[(NInd-1)*2:(NInd-1)*2+2]
    yu = yuN[(NInd-1)*2:(NInd-1)*2+2]
    msk= mskN[(NInd-1):NInd]

  ########################################################## Network setup
  pad = 'reflection' # 'zero'
  mse = torch.nn.MSELoss().type(dtype)
  mae = torch.nn.L1Loss().type(dtype)

  INPUT = 'noise' # 'noise', 'meshgrid', or 'hybrid'

  # The two network variants differ only in the width of the first down-sampling level:
  # opt 5/6 use 2 channels there, every other flavor uses LSz.
  # ra: note, setting num_channels_skip[0] = 0 may improve performance
  # RA: The number of layers = 2^layers < image size
  if opt==5 or opt==6:
    channels_down = [2]+[LSz]*(NLy-1)
  else:
    channels_down = [LSz]*NLy
  net = skip(Nin, 2*Nout,
    num_channels_down = channels_down,
    num_channels_up =   [LSz]*NLy,
    num_channels_skip = [LSz]*NLy,
    filter_size_up   = [3]*NLy,
    filter_size_down = [3]*NLy,
    filter_skip_size = 1, # kernel size for the filters along the skip connections
    upsample_mode='nearest',
    output_act = 0, # 0 for none, 1 for sigmoid, 2 for tanh
    need_bias=True,
    pad=pad,
    act_fun='LeakyReLU').type(dtype)

  s  = sum(np.prod(list(p.size())) for p in net.parameters())
  print ('Number of params: %d' % s)

  ########################################################## Setup network inputs
  x_tor   = np_to_torch(x).type(dtype)
  yu_tor  = np_to_torch(yu).type(dtype)
  yn_tor  = np_to_torch(yn).type(dtype)
  msk_tor = np_to_torch(msk).type(dtype)

  net = net.type(dtype)
  if opt==5 or opt==6:
    # z: Nz channels; frame i is selected by where z sits inside the N*Nz-channel input.
    z0 = torch.zeros(1,Nz,n[0],n[1]).type(dtype) # all zeros
    z = get_noise(Nz, INPUT, x.shape[1:], var=1/10).type(dtype) - 1/20
    z_saved = torch.clone(z)   # ra: use clone so that z and z_saved don't point to the same location
    z0_saved = torch.clone(z0)

  elif opt==4:
    z = torch.zeros(1,Nz*N,n[0],n[1]).type(dtype) # all zeros
    zs = get_noise(Ns, INPUT, x.shape[1:], var=1./10).type(dtype) - 1/20
    z_saved  = torch.clone(z)
    zs_saved = torch.clone(zs)
    z0 = torch.zeros(1,Nz*N,n[0],n[1]).type(dtype) # all zeros
    z0_saved = torch.clone(z0)

  else:
    # Shared code zs plus one Nz-channel code per frame, each frame code drawn
    # around a common base z0 (80% shared, 20% frame-specific) and scaled by z_sc.
    zs = get_noise(Nin-Nz, INPUT, x.shape[1:], var=1./10).type(dtype) - 1/20
    zs_saved = torch.clone(zs)  # ra: use clone so that zs and zs_saved don't point to the same location

    z0 = get_noise(Nz, INPUT, x.shape[1:], var=1./10).type(dtype) - 1/20
    z = torch.zeros(1,Nz*N,n[0],n[1]).type(dtype) # all zeros
    for i in range(N):
      z[:,i*Nz:(i+1)*Nz,:,:] = z_sc*(0.8*z0 + 0.2*(get_noise(Nz, INPUT, x.shape[1:], var=1./10).type(dtype) - 1/20))
    z0 = torch.tile(z0,[1,N,1,1])
    z_saved  = torch.clone(z)   # ra: use clone so that z and z_saved don't point to the same location
    z0_saved = torch.clone(z0)

  ########################################################## Main iterations
  ii = 0
  def closure():
      """Loss closure handed to `optimize`.

      Runs the network once per frame on noise-perturbed inputs, sums the masked
      Fourier-domain MSE against the undersampled data yu, adds the latent penalty
      for opt 2/3/4, back-propagates, records the loss in `running_loss`, and
      periodically prints/plots progress. Returns the total loss tensor.
      """
      global ii
      z_lamb = z_lamb0 #*(1 + 99 * ii/num_iter)
      losses = torch.empty([N,1]).type(dtype)
      xHat_tor = torch.empty(N,1,2*Nout, n[0], n[1]).type(dtype)
      # Input-perturbation std, annealed linearly to 10% of reg_sig0 over num_iter.
      reg_sig = reg_sig0*(1 - 0.9 * ii/num_iter)
      for i in range(N):
        if opt==5 or opt==6:
          xHat_tor[i,:,:,:,:] = net(torch.cat((torch.randn(1,Nz*i,n[0],n[1]).type(dtype) * reg_sig/(N**0.5), z + (torch.randn(1,Nz,n[0],n[1]).type(dtype) * reg_sig), torch.randn(1,N*Nz-Nz*(i+1),n[0],n[1]).type(dtype)* reg_sig/(N**0.5)), 1))
        elif opt==4:
          xHat_tor[i,:,:,:,:] = net(zs + z[:,i*Nz:(i+1)*Nz,:,:] + (torch.randn(1,Nz,n[0],n[1]).type(dtype) * reg_sig))
        else:
          xHat_tor[i,:,:,:,:] = net(torch.cat((zs + (torch.randn(1,Nin-Nz,n[0],n[1]).type(dtype) * reg_sig), z[:,i*Nz:(i+1)*Nz,:,:] + (torch.randn(1,Nz,n[0],n[1]).type(dtype) * reg_sig)), 1)) #<-- change *1 to *4 AND <--zs to zs_saved
        if Nout==N:
          xHatF_tor = fft2c_ra(xHat_tor[i,:,i*2:(i+1)*2,:,:], 'ortho')
        elif Nout<N:
          xHatF_tor = fft2c_ra(xHat_tor[i,:,:,:,:], 'ortho')

        losses[i] = mse(xHatF_tor*msk_tor[:,i:i+1,:,:],  yu_tor[:,i*2:(i+1)*2,:,:])

      if opt==2 or opt==3 or opt==4:
        # Group-sparsity penalty: per-pixel root-mean-square across the channels of z, averaged.
        z_loss = z_lamb * torch.mean(torch.sqrt(torch.mean(torch.abs(z)**2, axis=1)+1e-6))
      else:
        z_loss = 0

      total_loss = sum(losses) + z_loss

      total_loss.backward()
      running_loss[0,ii] = total_loss.item()

      if  PLOT and (ii ==0 or (ii+1) % show_every == 0):
        nmse = np.zeros([N,1])
        ssm = np.zeros([N,1])
        xHat = np.zeros((2*N,n[0],n[1]))
        xHatAbs = np.zeros((N,n[0],n[1]))
        xAbs = np.zeros((N,n[0],n[1]))
        errMap = np.zeros((2*N,n[0],n[1]))
        for i in range(N):
          if Nout==N:
            xHat[i*2:(i+1)*2,:,:] = torch_to_np(xHat_tor[i, :, i*2:(i+1)*2,:,:])
            xHatAbs[i:i+1,:,:] = np.sqrt((xHat[i*2,:,:])**2 + (xHat[i*2+1,:,:])**2)
            xAbs[i:i+1,:,:] = np.sqrt((x[i*2,:,:])**2 + (x[i*2+1,:,:])**2)
          elif Nout<N:
            xHat[i*2:(i+1)*2,:,:] = torch_to_np(xHat_tor[i,:,:,:,:])
            xHatAbs[i:i+1,:,:] = np.sqrt((xHat[i*2,:,:])**2 + (xHat[i*2+1,:,:])**2)
            xAbs[i:i+1,:,:] = np.sqrt((x[i*2,:,:])**2 + (x[i*2+1,:,:])**2)

          nmse[i] =  np.mean((x[i*2:(i+1)*2,:,:]-xHat[i*2:(i+1)*2,:,:])**2) / np.mean((x[i*2:(i+1)*2,:,:])**2)
          ssm[i] = ssim(xHatAbs[i,:,:], xAbs[i,:,:], data_range = xHatAbs[i,:,:].max() - xHatAbs[i,:,:].min())
          errMap[i*2:(i+1)*2,:,:] = x[i*2:(i+1)*2,:,:]-xHat[i*2:(i+1)*2,:,:]

        print('Individual losses x 1e4: ',', '.join('%1.3f' % (losses[j]*1e4) for j in range(len(losses))))
        # nmse is (N, 1): format flat scalars, not 1-element arrays (deprecated float conversion).
        print('Iteration: %1.3d,' %(ii+1), 'Loss x 1e4: %1.2f,' %(running_loss[0,ii]*1e4), 'Mean nmse: %1.2f,' %(10*np.log10(np.mean(nmse))), 'nmse: ',', '.join('%1.2f' % (10*np.log10(v)) for v in nmse.ravel()))
        fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
        plt.imshow(np.reshape(np.transpose(xHat,[1,0,2]), [n[0],n[1]*2*N]), vmin=-1, vmax=1, cmap=plt.cm.Greys_r)
        plt.show()

        fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
        plt.imshow(np.reshape(np.transpose(errMap,[1,0,2]), [n[0],n[1]*2*N]), vmin=-0.1, vmax=0.1, cmap=plt.cm.Greys_r)
        plt.show()

        fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
        if Nz == 0:
          plt.imshow(np.reshape(np.transpose(np.concatenate((xAbs[0:1,:,:]**gm, xHatAbs[0:1,:,:]**gm), axis=2),[1,0,2]), [n[0],n[1]*2]), vmin=0, vmax=0.7, cmap=plt.cm.Greys_r)
          plt.show()
        else:
          zNp = torch_to_np(z)
          plt.imshow(np.reshape(np.transpose(np.concatenate((xAbs[0:1,:,:]**gm, xHatAbs[0:1,:,:]**gm, 50*zNp[0:1,:,:]+0.7/2), axis=2),[1,0,2]), [n[0],n[1]*3]), vmin=0, vmax=0.7, cmap=plt.cm.Greys_r)
          plt.show()
      ii += 1
      return total_loss

  running_loss = torch.empty([1,num_iter]).type(dtype)

  print('=====optimizing over network and input=====')
  OPT_OVER = 'net,input'
  # Optimize the network weights together with whichever input codes exist.
  if Nz != 0 and Ns != 0:
    p = get_params(OPT_OVER, net, ([z,zs]))
  if Nz != 0 and Ns == 0:
    p = get_params(OPT_OVER, net, ([z]))
  if Ns != 0 and Nz == 0:
    p = get_params(OPT_OVER, net, ([zs]))
  optimize('adam', p, closure, LR, num_iter, WtD)

  fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
  plt.plot(np.log10(torch_to_np(running_loss)))
  plt.show()

  ########################################################## Display and saving
  # Noise-free forward pass with the optimized network and codes. No gradients are
  # needed here, so run under no_grad instead of building N autograd graphs.
  xHat_tor = torch.empty(N,1,2*Nout, n[0], n[1]).type(dtype)
  with torch.no_grad():
    for i in range(N):
      if opt==5 or opt==6:
        xHat_tor[i,:,:,:,:] = net(torch.cat((torch.zeros(1,Nz*i,n[0],n[1]).type(dtype), z , torch.zeros(1,N*Nz-Nz*(i+1),n[0],n[1]).type(dtype)), 1))
      elif opt==4:
        xHat_tor[i,:,:,:,:] = net(1*zs + 1*z[:,i*Nz:(i+1)*Nz,:,:])
      else:
        xHat_tor[i,:,:,:,:] = net(torch.cat((1*zs, 1*z[:,i*Nz:(i+1)*Nz,:,:]), 1))

  nmse = np.zeros([N,1])
  ssm = np.zeros([N,1])
  xHat = np.zeros((2*N,n[0],n[1]))
  xHatAbs = np.zeros((N,n[0],n[1]))
  xHatF = np.zeros((2*N,n[0],n[1]))
  xHatFAbs = np.zeros((N,n[0],n[1]))
  xAbs = np.zeros((N,n[0],n[1]))
  errMap = np.zeros((N,n[0],n[1]))
  errFMap = np.zeros((N,n[0],n[1]))

  for i in range(N):
    # Real/imag channel pair belonging to frame i (Nout is either 1 or N here).
    if Nout==N:
      frame_tor = xHat_tor[i, :, i*2:(i+1)*2,:,:]
    else:
      frame_tor = xHat_tor[i,:,:,:,:]
    xHat[i*2:(i+1)*2,:,:] = torch_to_np(frame_tor)
    xHatAbs[i,:,:] = np.sqrt((xHat[i*2,:,:])**2 + (xHat[i*2+1,:,:])**2)
    xHatF[i*2:(i+1)*2,:,:] = torch_to_np(fft2c_ra(frame_tor,'ortho'))
    xHatFAbs[i,:,:] = np.sqrt((xHatF[i*2,:,:])**2 + (xHatF[i*2+1,:,:])**2)
    xAbs[i,:,:] = np.sqrt((x[i*2,:,:])**2 + (x[i*2+1,:,:])**2)

    nmse[i] =  np.mean((x[i*2:(i+1)*2,:,:]-xHat[i*2:(i+1)*2,:,:])**2) / np.mean((x[i*2:(i+1)*2,:,:])**2)
    ssm[i] = ssim(xHatAbs[i,:,:], xAbs[i,:,:], data_range = xHatAbs[i,:,:].max() - xHatAbs[i,:,:].min())
    errMap[i,:,:] = np.sqrt(np.sum(np.abs(x[i*2:(i+1)*2,:,:]-xHat[i*2:(i+1)*2,:,:])**2, axis=0))
    errFMap[i,:,:] = np.sqrt(np.sum(np.abs(y[i*2:(i+1)*2,:,:]-xHatF[i*2:(i+1)*2,:,:])**2, axis=0))

  # nmse/ssm are (N, 1): format flat scalars, not 1-element arrays (deprecated float conversion).
  print('Mean nmse: %1.2f,' %(10*np.log10(np.mean(nmse))), 'nmse: ',', '.join('%1.2f' % (10*np.log10(v)) for v in nmse.ravel()))
  print('Mean ssim: %1.3f,' %(np.mean(ssm)), 'ssim: ',', '.join('%1.3f' % v for v in ssm.ravel()))

  fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
  plt.imshow(np.reshape(np.transpose(xHatFAbs**0.25,[1,0,2]), [n[0],n[1]*N]), cmap=plt.cm.Greys_r)
  plt.show()
  fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
  plt.imshow(np.reshape(np.transpose(errFMap**gm,[1,0,2]), [n[0],n[1]*N]), cmap=plt.cm.Greys_r)
  plt.show()

  fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
  plt.imshow(np.reshape(np.transpose(xHatAbs**gm,[1,0,2]), [n[0],n[1]*N]), vmin=0, vmax=0.7, cmap=plt.cm.Greys_r)
  plt.show()
  fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
  plt.imshow(np.reshape(np.transpose(errMap**gm,[1,0,2]), [n[0],n[1]*N]), vmin=0, vmax=0.2, cmap=plt.cm.Greys_r)
  plt.show()

  # NOTE(review): these DIP results go to data_path (not data_path_r) and their names
  # omit `series`, so runs on different series overwrite each other — confirm intent.
  if sv==1:
    dip_tag = 'opt_%d_N_%d_Ind_%d' % (opt, N, NInd)
    np.save(data_path + 'xRefDIP_' + dip_tag + '.npy', x)
    np.save(data_path + 'xHatDIP_' + dip_tag + '.npy', xHat)
    np.save(data_path + 'nmseDIP_' + dip_tag + '.npy', nmse)
    np.save(data_path + 'ssimDIP_' + dip_tag + '.npy', ssm)
    if Nz!=0:
      np.save(data_path + 'zDIP_' + dip_tag + '.npy', torch_to_np(z))
    if Ns!=0:
      np.save(data_path + 'zsDIP_' + dip_tag + '.npy', torch_to_np(zs))
    torch.save(net, data_path + 'model_' + dip_tag)

  ########################################################## Display/read the weights
  # Print one parameter tensor of the network (index into net.parameters(); if the
  # network has fewer tensors, the last one is used). Uses its own name instead of
  # rebinding the outer loop bound `L`.
  layer_idx = 4
  params = list(net.parameters())
  wt = params[min(layer_idx, len(params) - 1)].data

  print(wt.shape)
  print(wt.view(-1)) # view(-1) flattens the tensor

  ########################################################## Display code vectors
  if Ns !=0: # plotting channel 0 only, although there may be several shared channels
    fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
    plt.imshow(np.concatenate((torch_to_np(zs_saved[:,0,:,:]), torch_to_np(zs[:,0,:,:]), torch_to_np(zs[:,0,:,:]-zs_saved[:,0,:,:])), axis=1), cmap=plt.cm.Greys_r)
    plt.show()
    print(np.max(torch_to_np(zs[:,0,:,:]-zs_saved[:,0,:,:])))

  if Nz !=0:
    # Iterate over the channels z actually has: N*Nz for most flavors, but only Nz
    # for opt 5/6 (looping to N*Nz there indexed past the end of z).
    for i in range(z.shape[1]):
      fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
      plt.imshow(np.concatenate((torch_to_np(z_saved[:,i,:,:]), torch_to_np(z[:,i,:,:]), torch_to_np(z[:,i,:,:]-z_saved[:,i,:,:])) , axis=1), cmap=plt.cm.Greys_r)
      plt.show()

  if sv==1:
    discus_tag = 'N_%d_LR_%f_zlamb_%f_zSc_%f' % (N, LR, z_lamb0, z_sc) + '_R_%f'%R
    np.save(data_path_r + 'xRefDISCUS_' + discus_tag + '.npy', x)
    np.save(data_path_r + 'xHatDISCUS_' + discus_tag + '.npy', xHat)
    if Nz!=0:
      np.save(data_path_r + 'zDISCUS_' + discus_tag + '.npy', torch_to_np(z))
    if Ns!=0:
      np.save(data_path_r + 'zsDISCUS_' + discus_tag + '.npy', torch_to_np(zs))
    torch.save(net, data_path_r + 'model_DISCUS_' + discus_tag)

  # get the end time
  et = time.time()

  # get the execution time (cumulative since `st`, which is set once before the loop)
  elapsed_time = et - st
  print('Execution time:', elapsed_time/60, 'minutes')