# DISCUS Implementation for LGE data: Studies (III, IV)

In [None]:
########################################################## Import libraries
from __future__ import print_function
import matplotlib.pyplot as plt
%matplotlib inline
from scipy.io import savemat
from skimage.metrics import structural_similarity as ssim
from skimage.metrics import peak_signal_noise_ratio as psnr
import numpy as np
import pickle
#from models.resnet import ResNet
# from models.unet import UNet
from models.skip_dropout_LatestOrder import skip

import torch
from torch import optim
import scipy.io
import time

from utils.common_utils import *
from utils.my_utils import *
from utils.fftc import * # ra: added the pytorch fft routine from fastmri

import os# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
# Enable cuDNN and its benchmark mode (auto-tuner).
torch.backends.cudnn.enabled = True
torch.backends.cudnn.benchmark = True
# Tensor type applied everywhere through `.type(dtype)`. Fall back to the CPU
# float type when no CUDA device is visible, so the notebook does not fail on
# the first `.type(dtype)` call on a machine without a GPU.
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
PLOT = True  # show intermediate reconstructions during DISCUS training

In [None]:
# Report the CUDA hardware visible to this process.
def list_cuda_devices():
    """Print name and memory figures (in MB) of every visible CUDA device.

    Prints a single notice instead when CUDA is not available.
    """
    if not torch.cuda.is_available():
        print("CUDA is not available.")
        return

    bytes_per_mb = 1024 ** 2
    device_total = torch.cuda.device_count()
    print(f"Number of available CUDA devices: {device_total}")
    for dev in range(device_total):
        print(f"Device {dev}: {torch.cuda.get_device_name(dev)}")
        allocated_mb = torch.cuda.memory_allocated(dev) / bytes_per_mb
        print(f"  Memory Allocated: {allocated_mb:.2f} MB")
        reserved_mb = torch.cuda.memory_reserved(dev) / bytes_per_mb
        print(f"  Memory Cached: {reserved_mb:.2f} MB")
        total_mb = torch.cuda.get_device_properties(dev).total_memory / bytes_per_mb
        print(f"  Total Memory: {total_mb:.2f} MB")


list_cuda_devices()

In [None]:
# print(torch.cuda.device_count()) # restart shell if it doesnt show 2

# torch.cuda.set_device(0) # 0/1


# print(torch.cuda.current_device())
# # Change directory
# %pwd # check current directory
# # Move the directory where the data is
# import os
# os.chdir("/home/ahmad.sultan/.cache/gvfs/smb-share:server=ak-isi01-sh2.prdnas1.osumc.edu,share=dhlri$/labs/CMRCT Lab Team/_ahmad_sultan/_shared_GPU_station/LGE Patient Datasets Preprocessing")
# %ls


## Set study parameters and select dataset:

In [None]:
# Study parameters:
FS = 0 # non-zero: retrospective study (a fully sampled reference xRef is loaded); 0: prospectively undersampled data
R=2 # acceleration rate; only honoured when FS is non-zero -- the next cell overwrites R when FS=0
nt= 32 # choose num of frames to recon. Max = 32.

# Dataset ids handled by the selection cell below:
# ["SJK_SAX_BASE", "SJK_SAX_MOCO", "SJK_3CH", "JBH_3CH", "JGR_3CH", "JGR_SAX_BASE", "GRH_3CH", "AT_3CH"]
# (only the SJK_* ids and JBH_3CH have an entry for FS != 0)
id="JBH_3CH"  # choose one of the 8 preprocessed datasets; NOTE: this name shadows the builtin `id`

In [None]:
## Resolve the data location and per-dataset settings for the `id` chosen above.
# Every supported (id, fully-sampled?) combination is listed once in a table.
# An unsupported combination raises immediately instead of leaving `subject` /
# `file` undefined (or silently reusing values from a previous run of this cell).
# NOTE: `id` and `set` shadow Python builtins; the names are kept because other
# cells read them.

# id prefix -> subject folder
_subject_dirs = {
    "SJK": "20230125_CS_LGE_SKJ/",
    "JBH": "20230118_CS_LGE_JB/",
    "JGR": "20230308_CS_LGE_JGR/",
    "GRH": "20230118_CS_LGE_GR/",
    "AT": "20230816_CS_LGE_AT/",
}

# (id, fully sampled?) -> (measurement folder, crop_set, R)
# R is None for fully sampled entries: the user-selected R is kept.
_datasets = {
    ("SJK_SAX_BASE", False): ("/meas_MID00257_FID117434_SS_TRUFI_CS_PSIR_SAX_BASE_SAX/", 0.9, 5.069307),
    ("SJK_SAX_MOCO", False): ("/meas_MID00258_FID117435_SS_TRUFI_CS_PSIR_SAX_SAX_MOCO_54_6/", 0.9, 5.069307),
    ("SJK_3CH", False): ("/meas_MID00259_FID117436_SS_TRUFI_CS_PSIR_SAX_3CH/", 0.9, 5.069307),
    ("SJK_SAX_BASE", True): ("/meas_MID00260_FID117437_SS_TRUFI_PSIR_SAX_FULL_SAMP_BASE_SAX/", 0.9, None),
    ("SJK_SAX_MOCO", True): ("/meas_MID00261_FID117438_SS_TRUFI_PSIR_SAX_FULL_SAMP_SAX_MOCO_54_6/", 0.9, None),
    ("SJK_3CH", True): ("/meas_MID00262_FID117439_SS_TRUFI_PSIR_SAX_FULL_SAMP_3CH/", 0.9, None),
    ("JBH_3CH", False): ("/meas_MID00381_FID115208_SS_TRUFI_CS_PSIR_3CH/", 0.8, 5.069307),
    ("JBH_3CH", True): ("/meas_MID00383_FID115210_SS_TRUFI_PSIR_3CH_FULL_SAMP/", 0.9, None),
    ("JGR_3CH", False): ("/meas_MID00097_FID134147_SS_TRUFI_CS_PSIR_3CH/", 0.6, 6.336634),
    ("JGR_SAX_BASE", False): ("/meas_MID00096_FID134146_SS_TRUFI_CS_PSIR_BASE_SAX/", 0.6, 6.336634),
    ("GRH_3CH", False): ("/meas_MID00488_FID115315_SS_TRUFI_CS_PSIR_3CH/", 0.8, 7.603960),
    ("AT_3CH", False): ("/meas_MID00400_FID199583_SS_TRUFI_CS_PSIR_3_CH_SLICE_54_5/", 0.9, 5.069307),
}

_key = (id, bool(FS))
if _key not in _datasets:
    raise ValueError(
        "No preprocessed dataset for id=%r with FS=%r; supported (id, fully sampled) "
        "combinations: %s" % (id, FS, sorted(_datasets))
    )

subject = _subject_dirs[id.split("_")[0]]
file, crop_set, _dataset_R = _datasets[_key]
if _dataset_R is not None:
    R = _dataset_R  # prospective data: the rate is a property of the acquisition

folder = ""  # set to "new sen cc maps/" to read the new sen cc maps
data_path = "./data/preprocessed/"+subject +folder+file
# Fully sampled studies keep one sub-folder per retrospective rate.
if FS:
    data_path_r = data_path + "R%d/"%R
else:
    data_path_r = data_path
set=0  # data-set index used in the file names
crop=crop_set
## display
gm = [1, 0.3, 0.7] # gamma correction for display
g=1
mps = [plt.cm.Greys_r, 'gray']
epsilon = 1e-10


## Load preprocessed data

In [None]:
# Load the undersampled (US) k-space, sampling mask and coil-sensitivity files.
_tag = "_R_%f"%R + "_set_%d"%set  # suffix shared by the per-R / per-set files
ksp = np.load(data_path_r + "yu" + _tag + ".npy")
print("Data size: ", ksp.shape)
N, Nc, Nx, Ny = np.shape(ksp)
n = (Nx, Ny)
# Rows skipped at the top and bottom of the image when computing metrics
# (20% of Nx each side) to discard inhomogeneity.
RO_offset = int(0.2*Nx)
msk = np.load(data_path_r + "mask_R_%f"%R + ".npy")
cc_maps_sens = np.load(data_path + "sen_cc_map" + ".npy")
maps = np.load(data_path_r + "ESPIRiT_crop_%f"%crop + _tag + ".npy")
sen_msk = np.load(data_path_r + "th_sen_map_crop_%f"%crop + _tag + ".npy")
if FS:
    # fully sampled reference images
    img = np.load(data_path + "xRef_N_%d"%N + "_set_%d"%set + ".npy")

# Real-valued arrays handed to DISCUS.
ku = complex_to_real_plus_imag_4d(ksp)
samp = msk
sen = complex_to_real_plus_imag_4d(np.expand_dims(maps, axis=0)) * cc_maps_sens
print("\ndata sizes going to DISCUS")
# expected, per the original notes: ku (2*N, CH, RO, PE), samp (N, CH, RO, PE), sen (2, CH, RO, PE)
for _arr in (ku, samp, sen):
    print(_arr.shape)
if FS:
    m = complex_to_real_plus_imag_3d(img)  # expected (2*N, RO, PE)

## Compressed sensing (CS) reconstruction

In [None]:
# Parameters of the CS (ADMM, wavelet-L1) reconstruction below.
PLOT_CS = 1          # non-zero: show intermediate results
show_every_CS = 100  # show the current estimate every this many outer iterations


mu = 1e-1 # lagrangian parameter
nIter = np.array([100, 5]) # outer and inner iterations
ss = 0.3 # step size

# Threshold weight: the CS cell thresholds with tau/mu and the save cell puts
# tau in the file name. It was only present as a commented-out assignment, so
# both cells failed with a NameError; the commented value is restored here.
tau = 0.005


In [None]:
########################################################## CS Recon
# Frame-by-frame CS reconstruction: for every frame, nIter[0] outer iterations,
# each made of nIter[1] gradient steps on (data consistency + mu-weighted
# wavelet term) followed by an update of the auxiliary variables d and b.
# The result is stored in xCS as stacked real/imaginary parts (2*N, n[0], n[1]).
#
# NOTE(review): yuN, mskN and ScN are not assigned in any cell above this one,
# and xN (needed when FS) is only assigned in the "Final pre-processed data"
# cell further down; they must already exist in the session when this runs.
# NOTE(review): `tau` must be defined before this cell runs.
# NOTE(review): `st` is called below as a function (presumably a
# soft-threshold helper from the utils star-imports -- confirm), but the
# DISCUS training cell rebinds `st = time.time()`; re-running this cell after
# that one would fail.


csRe=1

ar = 0.75 # aspect ratio for plotting xAbs

if csRe==1:
  # Running ADMM L2-L1

  xCS = np.zeros((2*N,n[0],n[1]))
  Sc = ScN


# CS-L1: individual recon (frame by frame); no common support bw frames
  for i in range(N): 
    loss = np.zeros(nIter[0]) # track training loss to see stability of learning

    print('\nSlice: %2d' %(i+1), 'out of %2d' %N)
    x0 = (np.zeros(n)).astype(complex) # only used for its shape when allocating d and b
    # rebuild the complex k-space of frame i from its stacked real/imag parts
    y0 = yuN[2*i,:,:,:] + 1j*yuN[2*i+1,:,:,:]
    msk0 = mskN[i,:,:,:]
    

    # ADMM
    # pA / pAt are presumably the forward operator and its adjoint -- confirm in utils

    B = 4 # wavelet bands
    u = pAt(y0, msk0, Sc) # starts with ZF-image
    d = np.zeros((B,x0.shape[0],x0.shape[1])).astype(complex)
    b = np.zeros((B,x0.shape[0],x0.shape[1])).astype(complex)

    for l in range(nIter[0]): # outer iter


      for j in range(nIter[1]): # inner iter

        lossA = pA(u, msk0, Sc) - y0 # Data consistency
        gradA = pAt(lossA, msk0, Sc)

        lossW = swt2_haar(u) - d + b # wavelet reg.
        gradW = mu * iswt2_haar(lossW) 

        u = u - ss * (gradA + gradW) # grad. update


      # update of the auxiliary variables d (thresholded with tau/mu) and b
      d = st(swt2_haar(u) + b, tau/mu) 
      b = b + (swt2_haar(u) - d)
      
      # logged from the residuals of the LAST inner iteration (computed before the d, b update)
      loss[l] = np.sum(np.abs(lossA)**2) + mu*np.sum(np.abs(lossW))
    
    
      if PLOT_CS:


        if  (l+1) % show_every_CS ==0:
          print('Iteration: %2d ' %(l+1))

          xHat_it = np.zeros((2,n[0],n[1]))
          if FS:
            # placeholders; both names are reassigned below
            xHat_itAbs = np.zeros((1, n[0],n[1]))
            x_itAbs = np.zeros((1, n[0],n[1]))

          # recon image: u
          xHat_it[0,:,:]  = np.real(u)
          xHat_it[1,:,:] = np.imag(u)
          xHat_itAbs = np.sqrt((xHat_it[0,:,:])**2 + (xHat_it[1,:,:])**2)
          if FS:
            # metrics of the current estimate against the reference frame xN
            x_itAbs = np.sqrt((xN[i*2,:,:])**2 + (xN[i*2+1,:,:])**2)
            nmse_it =  np.mean((xN[i*2:(i+1)*2,:,:]-xHat_it)**2) / np.mean((xN[i*2:(i+1)*2,:,:])**2)
            ssm_it = ssim(xHat_itAbs, x_itAbs, data_range = xHat_itAbs.max() - xHat_itAbs.min())
            psn_it = psnr(x_itAbs, xHat_itAbs, data_range = xHat_itAbs.max() - xHat_itAbs.min())
            print('\nNMSE : %1.2f' %(10*np.log10(nmse_it)))
            print('SSIM : %1.2f' %(ssm_it))
            print('PSNR : %1.2f' %(psn_it))
            fig = plt.figure(figsize=(6,5),facecolor='white', edgecolor=None)
            plt.imshow(np.reshape(np.transpose(np.concatenate((np.expand_dims(xHat_itAbs**g, axis=0), np.expand_dims(x_itAbs**g, axis=0)), axis=2),[1,0,2]), [n[0],n[1]*2]), cmap=plt.cm.Greys_r, aspect=(n[1]/n[0])/ar) # use a specific color map
            plt.title("xHat_itAbs, x_itAbs")
            plt.show()
          else:
            fig = plt.figure(figsize=(6,5),facecolor='white', edgecolor=None)
            plt.imshow(xHat_itAbs**g, cmap=plt.cm.Greys_r, aspect=(n[1]/n[0])/ar) # use a specific color map
            plt.title("xHat_itAbs")
            plt.show()    

        # loss curve (log10) is plotted once: last frame, last outer iteration
        if i==N-1 and l==nIter[0]-1: # last outer iter.
          fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
          plt.plot(np.log10(loss))
          plt.xlabel("No. of iterations")
          plt.ylabel("Total loss")
          plt.show()


    # store the reconstructed frame as stacked real / imaginary parts
    xCS[2*i,:,:]   = np.real(u)
    xCS[2*i+1,:,:] = np.imag(u)


  xCSAbs = takeMag(xCS)

  if FS:
    xAbs = takeMag(xN)
    errMap = xN-xCS
    nmse =  calc_nmse(real_plus_imag_to_complex_3d(xCS), real_plus_imag_to_complex_3d(xN))

    
  # all reconstructed frames, shown side by side
  fig = plt.figure(figsize=(16,8),facecolor='white', edgecolor=None)
  plt.imshow(np.reshape(np.transpose(xCSAbs**g,[1,0,2]), [n[0],n[1]*N]), cmap=plt.cm.Greys_r) # use a specific color map
  plt.title("xCSAbs")
  plt.show()
    
  if FS:
    print('\nMean nmse: %1.2f,' %(10*np.log10(np.mean(nmse))), 'nmse: ',', '.join('%1.2f' % (10*np.log10(nmse[j])) for j in range(len(nmse)))) 
   
    fig = plt.figure(figsize=(16,8),facecolor='white', edgecolor=None)
    plt.imshow(np.reshape(np.transpose(xAbs**g,[1,0,2]), [n[0],n[1]*N]), cmap=mps[0]) # use a specific color map
    plt.title("xAbs")
    plt.show()

      
    # NOTE(review): vmax is derived from the signed errMap, while the image shown is its magnitude
    fig = plt.figure(figsize=(16,8),facecolor='white', edgecolor=None)
    plt.imshow(np.reshape(np.transpose(takeMag(errMap),[1,0,2]), [n[0],n[1]*N]),vmin=0, vmax=(np.max(errMap)/5), cmap=plt.cm.Greys_r) # use a specific color map
    plt.title("errxMap")
    plt.show()

  if FS:
    ### NMSE on ROI: RO_offset rows are dropped at the top and bottom of each image
      print("NMSE on ROI: ")
      RO_offset = int(0.2*Nx)
      print(RO_offset)
      nmse =  calc_nmse(real_plus_imag_to_complex_3d(xCS[:,RO_offset:-RO_offset,:]), real_plus_imag_to_complex_3d(xN[:,RO_offset:-RO_offset,:]))
      print('\nMean nmse: %1.2f,' %(10*np.log10(np.mean(nmse))), 'nmse: ',', '.join('%1.2f' % (10*np.log10(nmse[j])) for j in range(len(nmse)))) 


In [None]:

# Save the CS reconstruction (and the reference, when available) next to the data.
sv_CS = 1  # 1: write the .npy files
CSresults_path = data_path_r
string = "sen_cc_"  # file-name prefix
if sv_CS == 1:
    np.save(CSresults_path + string+'xHatCSL1_'+'N_%d' % N + '_R_%f'%R +"_set_%d"%set+"_tau_%f"%tau+'.npy', xCS)
    if FS:
        np.save(CSresults_path + string+ 'xRefCSL1_'+'N_%d' % N + '_R_%f'%R+"_set_%d"%set+'.npy', xN)



## DISCUS training

In [None]:
# Release the unused cached GPU memory held by PyTorch before DISCUS training.
with torch.no_grad():
    torch.cuda.empty_cache()

In [None]:
## DISCUS architecture parameters:

opt = 2 # select the flavor of the algorithm (0-6, see the per-flavour settings in the training cell)
LSz = 128 # Number of channels in hidden layers
NLy = 6 # number of layers
Nm = 2**NLy # minimum matrix size for the network; depends on the given data matrix size
# code accepts both even and odd sized matrix sizes

num_iter = 12000 # number of iterations (e.g. 120 for a quick test)
show_every = 1000 # print metrics / show images every this many iterations

## tuning parameters...
# code vectors init
reg_sig0 = 0.01 # initial std of the noise added to the code vectors at every iteration (0.03 was also tried)
# code vector sparsity
z_scale = 1 # scale applied to the initial image-specific code vectors

# adam optimization
WtD = 1*1e-6 # weight decay 
p=0 # value passed as `dropout` to the network (0.01 was also tried); NOTE(review): original note says "keep drop_out" -- confirm whether this is a drop or keep probability

sv=0 # NOTE(review): not read in the visible part of the notebook -- presumably a "save results" switch
cf=30 # NOTE(review): not read in the visible part of the notebook


In [None]:
# for DISCUS
## Final pre-processed data:
# From here on N is the number of frames to reconstruct (it replaces the frame
# count read from the k-space file).
# NOTE(review): yuN, mskN and SN (and ScN for the CS cell) are read by other
# cells but are not assigned in any cell visible here -- presumably they are
# derived from ku, samp and sen; confirm.
N=nt

if FS:
    xN  = m[0:2*N, :,:]*sen_msk # reference: first 2*N entries of m, multiplied by sen_msk
    print(xN.shape) # (2*N, RO, PE)

In [None]:
# Start the wall-clock timer.
# NOTE(review): this rebinds `st`, which the CS cell calls as a function.
st = time.time()

########################################################## Image selection
# Number of passes of the training loop: one per image for opt 0, otherwise a
# single pass over all images (NInd = 0 means "don't select an individual image").
L = N if opt==0 else 1
if not opt==0:
  NInd = 0

for el in range(L):
  # Channel layout of every algorithm flavour:
  #   Ns   = number of input channels common to all images
  #   Nz   = number of image-specific input channels
  #   Nout = number of outputs, 1 or N (2*Nout = output channels)
  #   Nin  = number of network input channels
  # opt -> (Ns, Nz, one output per image?)
  #   0: one image at a time          1: fixed noise stack, image-specific outputs
  #   2: [Ns; Nz] in, single output   3: [Ns; Nz] in, image-specific outputs
  #   4: [Ns + Nz] in, single output
  #   5: image-specific first layer, single output
  #   6: image-specific first layer, image-specific outputs
  _flavours = {
    0: (4, 0, False),
    1: (4, 0, True),
    2: (3, 1, False),
    3: (3, 1, True),
    4: (1, 1, False),
    5: (0, 2, False),
    6: (0, 2, True),
  }
  if opt==0:
    NInd = el+1 # pick one image from N images
  if opt in _flavours:
    Ns, Nz, _per_image_out = _flavours[opt]
    Nout = N if _per_image_out else 1
    if opt==4:
      Nin = Nz       # common and image-specific codes are summed
    elif opt==5 or opt==6:
      Nin = Nz*N     # one group of Nz channels per image
    else:
      Nin = Nz+Ns    # common and image-specific codes are concatenated


  # Select the data used in this pass of the loop.
  # x is the reference and therefore only exists when FS is set; yu is the
  # undersampled k-space, msk the sampling mask and S the sensitivity maps.
  if NInd==0: # process all images
    if FS:
      x  = xN
    yu = yuN
    msk= mskN
    S = SN
  elif NInd > 0: # process only image number NInd (1-based)
    N  = 1 # from here on the loops over N see a single image
    _pair = slice((NInd-1)*2, (NInd-1)*2+2) # real/imag pair of the selected image
    if FS: # same guard as the all-images branch
      x  = xN[_pair]
    # The former `y` / `yn` selections read yN / ynN, which are not defined
    # anywhere (they are already disabled in the all-images branch), so they
    # are dropped here as well.
    yu = yuN[_pair]
    msk= mskN[(NInd-1):NInd]
    S = SN # previously left unset, although S is converted to S_tor below



  ########################################################## Network setup
  pad = 'reflection' # 'zero'
  mse = torch.nn.MSELoss().type(dtype)
  mae = torch.nn.L1Loss().type(dtype)

  INPUT = 'noise' # 'noise', 'meshgrid', or 'hybrid'

  # Keyword arguments shared by both network variants.
  # ra: note, setting num_channels_skip[0] = 0 may improve performance
  # RA: The number of layers = 2^layers < image size
  # Five layers for sl64 and six for other applications
  _skip_kwargs = dict(
    num_channels_down = [LSz]*NLy,
    num_channels_up   = [LSz]*NLy,
    num_channels_skip = [LSz]*NLy,
    filter_size_up    = [3]*NLy,
    filter_size_down  = [3]*NLy,
    filter_skip_size  = 1,      # kernel size for the filters along the skip connections
    upsample_mode     = 'nearest',
    output_act        = 0,      # 0 for none, 1 for sigmoid, 2 for tanh
    need_bias         = True,
    pad               = pad,
    act_fun           = 'LeakyReLU',
  )
  if opt==5 or opt==6:
    # image-specific first layer: 2 channels in the first down block, no dropout argument
    _skip_kwargs['num_channels_down'] = [2]+[LSz]*(NLy-1)
  else:
    _skip_kwargs['dropout'] = p
  net = skip(Nin, 2*Nout, **_skip_kwargs).type(dtype)


  # total number of scalar parameters of the network
  s  = sum(np.prod(list(_w.size())) for _w in net.parameters())
  print ('Number of params: %d' % s)


  ########################################################## Setup network inputs
  # Convert the data to torch tensors of type `dtype`.
  if FS:
    x_tor   = np_to_torch(x).type(dtype)
  yu_tor  = np_to_torch(yu).type(dtype)
  msk_tor = np_to_torch(msk).type(dtype)
  S_tor = np_to_torch(S).type(dtype)

  net = net.type(dtype)

  # Spatial size of the code vectors. It is taken from the k-space data in all
  # branches: the opt 4/5/6 branches used to read it from the reference `x`,
  # which only exists when FS is set.
  # assumes yu is (2*N, CH, RO, PE), so yu.shape[2:] equals x.shape[1:] -- TODO confirm
  _code_size = yu.shape[2:]

  if opt==5 or opt==6:
    z0 = torch.zeros(1,Nz,n[0],n[1]).type(dtype) # all zeros
    z = get_noise(Nz, INPUT, _code_size, var=1/10).type(dtype) - 1/20
    z_saved = torch.clone(z)   # ra: use clone so that z and z_saved don't point to the same location
    z0_saved = torch.clone(z0)

  elif opt==4:
    z = torch.zeros(1,Nz*N,n[0],n[1]).type(dtype) # all zeros
    zs = get_noise(Ns, INPUT, _code_size, var=1./10).type(dtype) - 1/20
    z_saved  = torch.clone(z)
    zs_saved = torch.clone(zs)
    z0 = torch.zeros(1,Nz*N,n[0],n[1]).type(dtype) # all zeros
    z0_saved = torch.clone(z0)

  else: 
    # zs: code channels common to all images
    zs = get_noise(Nin-Nz, INPUT, _code_size, var=1./10).type(dtype) - 1/20
    zs_saved = torch.clone(zs)  # ra: use clone so that zs and zs_saved don't point to the same location
    
    # z: image-specific codes, each 80% a shared draw z0 and 20% its own draw
    z0 = get_noise(Nz, INPUT, _code_size, var=1./10).type(dtype) - 1/20
    z = torch.zeros(1,Nz*N,n[0],n[1]).type(dtype) # all zeros
    for i in range(N):
      z[:,i*Nz:(i+1)*Nz,:,:] = z_scale * (0.8*z0 + 0.2*(get_noise(Nz, INPUT, _code_size, var=1./10).type(dtype) - 1/20))
    z0 = torch.tile(z0,[1,N,1,1])
    z_saved  = torch.clone(z)   # ra: use clone so that z and z_saved don't point to the same location
    z0_saved = torch.clone(z0)

  # print(z.shape)

  ### tuned DISCUS hyperparameters: learning rate LR and code-sparsity weight
  ### z_lamb0, selected by acceleration rate R and number of frames nt
  if R==2:
    LR, z_lamb0 = 1.7e-4, 1.5e1
  elif R==3:
    LR, z_lamb0 = 1e-4, 1e1
  elif R==4:
    LR, z_lamb0 = (5e-5, 5e0) if nt==16 else (7e-5, 7e0)
  else: # both retrospective R=5 and prospective R=5.069307
    _by_frames = {32: (1e-3, 1e1), 16: (5e-4, 7e0)}
    LR, z_lamb0 = _by_frames.get(nt, (1e-4, 4e0))



  ########################################################## Main iterations
  ii = 0 # iteration counter, advanced by closure()
  def closure():
      """Evaluate the DISCUS loss once, back-propagate it and return it.

      For every image i the network is evaluated on the noise-perturbed code
      vectors, its output is passed through the forward model (`fft2c_ra`, or
      `multc` with S_tor followed by `fft2c_pra`) and compared by MSE with the
      measured k-space yu_tor on the sampled locations msk_tor. For opt 2/3/4
      a z_lamb-weighted sparsity penalty on the image-specific codes z is
      added. `backward()` is called on the total; no optimizer step and no
      gradient reset happen here.

      Side effects: writes running_loss[0, ii] (and running_psn[0, ii] when
      FS), prints metrics / shows images every `show_every` iterations when
      PLOT is set, and increments the global counter `ii`.

      Returns:
          total_loss: data-consistency losses summed over images plus the penalty.
      """
      global ii
      z_lamb = z_lamb0
      losses = torch.empty([N,1]).type(dtype)
      xHat_tor = torch.empty(N,1,2*Nout, 1, n[0], n[1]).type(dtype)
      # std of the noise added to the codes: decays linearly from reg_sig0 to
      # 0.1*reg_sig0 over num_iter iterations
      reg_sig = reg_sig0*(1 - 0.9 * ii/num_iter)
      for i in range(N):
        # NOTE(review): the opt 4/5/6 assignments store the raw network output
        # without the unsqueeze used in the default branch -- confirm the shapes
        # before using these flavours.
        if opt==5 or opt==6:
          xHat_tor[i,:,:,:,:] = net(torch.cat((torch.randn(1,Nz*i,n[0],n[1]).type(dtype) * reg_sig/(N**0.5), z + (torch.randn(1,Nz,n[0],n[1]).type(dtype) * reg_sig), torch.randn(1,N*Nz-Nz*(i+1),n[0],n[1]).type(dtype)* reg_sig/(N**0.5)), 1))
        elif opt==4:
          xHat_tor[i,:,:,:,:] = net(zs + z[:,i*Nz:(i+1)*Nz,:,:] + (torch.randn(1,Nz,n[0],n[1]).type(dtype) * reg_sig))
        else: 
          ## opt = 2: common codes zs concatenated with the codes of image i, both perturbed by noise
          xHat_tor[i,:,:,:,:,:] = torch.unsqueeze(net(torch.cat((zs + (torch.randn(1,Nin-Nz,n[0],n[1]).type(dtype) * reg_sig), z[:,i*Nz:(i+1)*Nz,:,:] + (torch.randn(1,Nz,n[0],n[1]).type(dtype) * reg_sig)), 1)), -3)
        if Nout==N:
          xHatF_tor = fft2c_ra(xHat_tor[i,:,i*2:(i+1)*2,:,:], 'ortho')
        elif Nout<N:
          ## this case (opt = 2)
          xHatF_tor = fft2c_pra(multc(xHat_tor[i,:,:,:,:,:], S_tor), 'ortho')
        # data consistency on the sampled k-space locations of image i
        losses[i] = mse(xHatF_tor*msk_tor[:,i:i+1,:,:,:],  yu_tor[:,i*2:(i+1)*2,:,:,:])
      if opt==2 or opt==3 or opt==4:
        # mean over pixels of the l2 norm across the code channels (1e-6 keeps the sqrt differentiable at 0)
        z_loss = z_lamb * torch.mean(torch.sqrt(torch.mean(torch.abs(z)**2, axis=1)+1e-6)) # img spec.
      else:
        z_loss = 0
      total_loss = sum(losses) + z_loss
      total_loss.backward()
      # Log a detached value: storing the graph-attached tensor would make
      # running_loss part of the autograd graph and keep history of every
      # iteration alive.
      running_loss[0,ii] = (sum(losses) + z_loss/z_lamb).detach()
  
      if FS:
        # mean PSNR over images, logged at every iteration
        psn = np.zeros([N,1])

        xHat = np.zeros((2*N,n[0],n[1]))
        xHatS = np.zeros((2*N,n[0],n[1]))
        xHatAbs = np.zeros((N,n[0],n[1]))
        xAbs = np.zeros((N,n[0],n[1]))
          
        for i in range(N):
          xHat[i*2:(i+1)*2,:,:] = np.squeeze(torch_to_np(xHat_tor[i,:,:,:,:,:]))
          xHatS[i*2:(i+1)*2,:,:] = xHat[i*2:(i+1)*2,:,:] * sen_msk
          xHatAbs[i:i+1,:,:] = np.sqrt((xHatS[i*2,:,:])**2 + (xHatS[i*2+1,:,:])**2)
          xAbs[i:i+1,:,:] = np.sqrt((x[i*2,:,:])**2 + (x[i*2+1,:,:])**2)
          psn[i] = psnr(xAbs[i,RO_offset:-RO_offset,:], xHatAbs[i,RO_offset:-RO_offset,:], data_range = xHatAbs[i,RO_offset:-RO_offset,:].max() - xHatAbs[i,RO_offset:-RO_offset,:].min())
          
        running_psn[0,ii] = np.mean(psn)

      if  PLOT and ( ii % show_every == 0):
        if FS:
          nmse = np.zeros([N,1])
          ssm = np.zeros([N,1])
          psn = np.zeros([N,1])
          xAbs = np.zeros((N,n[0],n[1]))
          errMap = np.zeros((2*N,n[0],n[1]))
        
        xHat = np.zeros((2*N,n[0],n[1]))
        xHatS = np.zeros((2*N,n[0],n[1]))
        xHatAbs = np.zeros((N,n[0],n[1]))

        for i in range(N):
          if Nout==N:
            # NOTE(review): xHatS and xAbs are not filled in this branch, so the
            # metrics below compare against zeros -- confirm before using Nout == N.
            xHat[i*2:(i+1)*2,:,:] = torch_to_np(xHat_tor[i, :, i*2:(i+1)*2,:,:,:])
            xHatAbs[i:i+1,:,:] = np.sqrt((xHat[i*2,:,:])**2 + (xHat[i*2+1,:,:])**2)
          elif Nout<N:
            # this case (opt = 2)
            xHat[i*2:(i+1)*2,:,:] = np.squeeze(torch_to_np(xHat_tor[i,:,:,:,:,:]))
            xHatS[i*2:(i+1)*2,:,:] = xHat[i*2:(i+1)*2,:,:] * sen_msk
            xHatAbs[i:i+1,:,:] = np.sqrt((xHatS[i*2,:,:])**2 + (xHatS[i*2+1,:,:])**2)
            if FS:
              xAbs[i:i+1,:,:] = np.sqrt((x[i*2,:,:])**2 + (x[i*2+1,:,:])**2)
          if FS:
            # metrics skip RO_offset rows at the top and bottom of the image
            nmse[i] =  np.mean((x[i*2:(i+1)*2,RO_offset:-RO_offset,:]-xHatS[i*2:(i+1)*2,RO_offset:-RO_offset,:])**2) / np.mean((x[i*2:(i+1)*2,RO_offset:-RO_offset,:])**2)
            ssm[i] = ssim(xHatAbs[i,RO_offset:-RO_offset,:], xAbs[i,RO_offset:-RO_offset,:], data_range = xHatAbs[i,RO_offset:-RO_offset,:].max() - xHatAbs[i,RO_offset:-RO_offset,:].min())
            psn[i] = psnr(xAbs[i,RO_offset:-RO_offset,:], xHatAbs[i,RO_offset:-RO_offset,:], data_range = xHatAbs[i,RO_offset:-RO_offset,:].max() - xHatAbs[i,RO_offset:-RO_offset,:].min())
            errMap[i*2:(i+1)*2,:,:] = x[i*2:(i+1)*2,:,:]-xHatS[i*2:(i+1)*2,:,:]
        ## Printing and Plotting begins here...
        print('\nIteration: %1.3d,' %(ii+1), 'Loss x 1e4: %1.2f,' %(running_loss[0,ii]*1e4))
        print('Individual losses x 1e4: ',', '.join('%1.3f' % (losses[j]*1e4) for j in range(len(losses))))

        fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
        plt.imshow(np.reshape(np.transpose(xHat**g,[1,0,2]), [n[0],n[1]*2*N]), cmap=mps[0]) # use a specific color map
        plt.title("xHat")
        plt.show()

        if FS:
          fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
          plt.imshow(np.reshape(np.transpose(errMap,[1,0,2]), [n[0],n[1]*2*N]), vmin=-0.1, vmax=0.1, cmap=plt.cm.Greys_r) # use a specific color map
          plt.title("errxMap")
          plt.show()
          print('Mean nmse: %1.2f,' %(10*np.log10(np.mean(nmse))), 'nmse: ',', '.join('%1.2f' % (10*np.log10(nmse[j])) for j in range(len(nmse)))) 
          print('Mean ssim: %1.3f,' %(np.mean(ssm)), 'ssim: ',', '.join('%1.3f' % (ssm[j]) for j in range(len(ssm)))) 
          print('Mean psnr: %1.3f,' %(np.mean(psn)), 'psnr: ',', '.join('%1.3f' % (psn[j]) for j in range(len(psn)))) 

        fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
        if Nz == 0:
          if FS:
            plt.imshow(np.reshape(np.transpose(np.concatenate((xAbs[0:1,:,:]**gm[0], xHatAbs[0:1,:,:]**gm[0]), axis=2),[1,0,2]), [n[0],n[1]*2]), vmin=0, vmax=0.7, cmap=plt.cm.Greys_r) # use a specific color map
          else:
            # no reference without FS: xAbs does not exist, show the estimate alone
            plt.imshow(np.reshape(np.transpose(xHatAbs[0:1,:,:]**gm[0],[1,0,2]), [n[0],n[1]]), vmin=0, vmax=0.7, cmap=plt.cm.Greys_r) # use a specific color map
          plt.show() 
            
        else: # opt=2
          zNp = torch_to_np(z)
          if FS:
            plt.imshow(np.reshape(np.transpose(np.concatenate((xAbs[0:1,:,:]**gm[0], xHatAbs[0:1,:,:]**gm[0], 100*zNp[0:1,:,:]), axis=2),[1,0,2]), [n[0],n[1]*3]), cmap=plt.cm.Greys_r) # use a specific color map
            plt.title("xAbs, xHatAbs, z (i=1)")
            plt.show()
          else:
            plt.imshow(np.reshape(np.transpose(np.concatenate((xHatAbs[0:1,:,:]**g, 10*zNp[0:1,:,:]), axis=2),[1,0,2]), [n[0],n[1]*2]), cmap=plt.cm.Greys_r) # use a specific color map
            plt.title("xHatAbs, z (i=1)")
            plt.show()
      ii += 1
      return total_loss

  # History buffers, one slot per iteration. `closure` reads
  # running_loss[0, ii] when it reports progress; running_psn is only
  # allocated when FS enables the reference-based metrics.
  # NOTE(review): torch.empty leaves the buffers uninitialised, so every slot
  # is presumably written inside `closure` -- confirm.
  running_loss = torch.empty([1,num_iter]).type(dtype)
  if FS:
    running_psn = torch.empty([1,num_iter]).type(dtype)

  print('=====optimizing over network and input=====')
  OPT_OVER = 'net,input'
  # z_lamb0 = 8*2e0 #<-- default: 5e1 for mse, 2e0 for mae, and 2e0 for group sparsity
  # WtD = 0*1e-6 # weight decay

  # Code tensors handed to get_params together with the network: z only when
  # Nz != 0, zs only when Ns != 0, with z first (same order as before).
  code_inputs = []
  if Nz != 0:
    code_inputs.append(z)
  if Ns != 0:
    code_inputs.append(zs)
  if not code_inputs:
    # Previously `p` was left unassigned in this case, so the optimize() call
    # either raised NameError or silently reused a `p` from an earlier run.
    raise ValueError('Nz and Ns are both 0: there is no code tensor (z or zs) to optimize over')
  p = get_params(OPT_OVER, net, code_inputs)
  optimize('adam', p, closure, LR, num_iter, WtD)

  # Convergence curves recorded during optimisation.
  fig_kw = dict(figsize=(16,5), facecolor='white', edgecolor=None)

  # The curve drawn is log10 of the recorded loss, although the axis label
  # only says "Total loss".
  loss_trace = torch_to_np(running_loss)
  fig = plt.figure(**fig_kw)
  plt.plot(np.log10(loss_trace))
  plt.xlabel("No. of iterations")
  plt.ylabel("Total loss")
  plt.show()

  if FS:
    # Convert the PSNR history once and reuse it for the plot and the summary.
    psnr_trace = torch_to_np(running_psn)
    fig = plt.figure(**fig_kw)
    plt.plot(psnr_trace)
    plt.xlabel("No. of iterations")
    plt.ylabel("PSNR")
    plt.show()
    # np.argmax returns a 0-based position in the history.
    print("Max PSNR: ", np.max(psnr_trace), "at Iteration: ", np.argmax(psnr_trace), "/", num_iter)


  ########################################################## Display and saving
  # Final reconstruction: run the optimised network once per image and collect
  # the outputs in one 6-D buffer indexed by image along axis 0.
  xHat_tor = torch.empty(N,1,2*Nout, 1, n[0], n[1]).type(dtype)
  # Optimisation has finished, so these forward passes need no autograd graph.
  # no_grad() stops each pass from retaining its graph and activations and
  # does not change the values produced. The network's train/eval mode is
  # deliberately left untouched.
  with torch.no_grad():
    for i in range(N):
      if opt in (5, 6):
        # z is placed in the i-th group of Nz channels, zeros everywhere else.
        # NOTE(review): this branch and the opt==4 branch assign the raw
        # network output without the unsqueeze used in the last branch --
        # confirm the shapes still fit the 6-D buffer.
        xHat_tor[i,:,:,:,:] = net(torch.cat((torch.zeros(1,Nz*i,n[0],n[1]).type(dtype), z , torch.zeros(1,N*Nz-Nz*(i+1),n[0],n[1]).type(dtype)), 1))
      elif opt==4:
        # Common code plus the i-th image-specific code, added element-wise.
        xHat_tor[i,:,:,:,:] = net(1*zs + 1*z[:,i*Nz:(i+1)*Nz,:,:])
      else:
        # this case: common code zs concatenated with the i-th image-specific
        # code along the channel axis; a singleton axis is inserted at -3 so
        # the result matches the buffer layout.
        xHat_tor[i,:,:,:,:,:] = torch.unsqueeze(net(torch.cat((1*zs, 1*z[:,i*Nz:(i+1)*Nz,:,:]), 1)), -3)

  # Convert the final network outputs to NumPy and, when FS is set, score them
  # against the reference `x`. All reference-based buffers exist only under FS.
  if FS:
    nmse = np.zeros([N,1])            # per-image normalised MSE (linear scale)
    ssm = np.zeros([N,1])             # per-image SSIM
    psn = np.zeros([N,1])             # per-image PSNR
    xAbs = np.zeros((N,n[0],n[1]))    # magnitude of the reference
    errMap = np.zeros((N,n[0],n[1]))  # per-pixel error magnitude

  # Image i occupies the channel pair (2*i, 2*i+1); magnitudes below are
  # sqrt(a**2 + b**2) over that pair (presumably real/imaginary parts).
  xHat = np.zeros((2*N,n[0],n[1]))    # raw network output
  xHatS = np.zeros((2*N,n[0],n[1]))   # network output multiplied by sen_msk
  xHatAbs = np.zeros((N,n[0],n[1]))   # magnitude of the reconstruction
  xHatF = np.zeros((2*N,n[0],n[1]))   # only filled by the disabled k-space code
  xHatFAbs = np.zeros((N,n[0],n[1]))  # only filled by the disabled k-space code
  # yAbs = np.zeros((N,n[0],n[1]))
  # errFMap = np.zeros((N,n[0],n[1]))

  # Rows used for scoring: RO_offset rows are trimmed from both ends of axis 1.
  # As before, RO_offset == 0 yields an empty slice, so it must be positive.
  ro_rows = slice(RO_offset, -RO_offset)

  for i in range(N):
    ch_pair = slice(i*2, (i+1)*2)  # the two channels belonging to image i

    if Nout==N:
      xHat[ch_pair,:,:] = torch_to_np(xHat_tor[i, :, i*2:(i+1)*2,:,:,:])
      xHatAbs[i:i+1,:,:] = np.sqrt((xHat[i*2,:,:])**2 + (xHat[i*2+1,:,:])**2)
      # xHatF[i*2:(i+1)*2,:,:] = torch_to_np(fft2c_ra(xHat_tor[i, :, i*2:(i+1)*2,:,:,:],'ortho'))
      # xHatFAbs[i:i+1,:,:] = np.sqrt((xHatF[i*2,:,:])**2 + (xHatF[i*2+1,:,:])**2)
      # NOTE(review): xHatS is never filled in this branch, yet nmse/errMap
      # below (and the saved result) read it -- confirm this branch is in use.
      if FS:
        # Guarded like the Nout<N branch: xAbs is only allocated when FS is
        # set, so the unguarded write used to raise NameError otherwise.
        xAbs[i:i+1,:,:] = np.sqrt((x[i*2,:,:])**2 + (x[i*2+1,:,:])**2)
    elif Nout<N:
      ## this case
      xHat[ch_pair,:,:] = np.squeeze(torch_to_np(xHat_tor[i,:,:,:,:,:]))
      xHatS[ch_pair,:,:] = xHat[ch_pair,:,:] * sen_msk
      xHatAbs[i:i+1,:,:] = np.sqrt((xHatS[i*2,:,:])**2 + (xHatS[i*2+1,:,:])**2)

      # xHatF[i*2:(i+1)*2,:,:] = np.squeeze(torch_to_np(fft2c_pra(np_to_torch(np.expand_dims(xHatS[i*2:(i+1)*2,:,:], axis=1)),'ortho')))
      # xHatFAbs[i:i+1,:,:] = np.sqrt((xHatF[i*2,:,:])**2 + (xHatF[i*2+1,:,:])**2)
      if FS:
        xAbs[i:i+1,:,:] = np.sqrt((x[i*2,:,:])**2 + (x[i*2+1,:,:])**2)
      # yAbs[i:i+1,:,:] = np.sqrt((y[i*2,:,:])**2 + (y[i*2+1,:,:])**2)

    if FS:
      ref_roi = xAbs[i,ro_rows,:]
      rec_roi = xHatAbs[i,ro_rows,:]
      # The dynamic range handed to SSIM/PSNR is taken from the
      # reconstruction, matching the in-loop metrics computed by `closure`.
      rec_range = rec_roi.max() - rec_roi.min()

      nmse[i] = np.mean((x[ch_pair,ro_rows,:]-xHatS[ch_pair,ro_rows,:])**2) / np.mean((x[ch_pair,ro_rows,:])**2)
      ssm[i] = ssim(rec_roi, ref_roi, data_range = rec_range)
      psn[i] = psnr(ref_roi, rec_roi, data_range = rec_range)
      # The error map is computed over the full image, not the trimmed rows.
      errMap[i,:,:] = np.sqrt(np.sum(np.abs(x[ch_pair,:,:]-xHatS[ch_pair,:,:])**2, axis=0))
      # errFMap[i,:,:] = np.sqrt(np.sum(np.abs(y[i*2:(i+1)*2,:,:]-xHatF[i*2:(i+1)*2,:,:])**2, axis=0))
  
## printing and plotting after iterations:
  # Summary of the final reconstruction. Everything that needs the reference
  # (metric print-out, xAbs and error-map plots) lives in this one FS block;
  # the reconstruction itself is plotted unconditionally afterwards.
  fig_kw = dict(figsize=(16,5), facecolor='white', edgecolor=None)

  if FS:
    # nmse/ssm/psn are (N,1) arrays. Flatten them so each value is formatted
    # as a scalar: '%f' % <1-element array> relies on the size-1-array-to-
    # scalar conversion that NumPy deprecated in 1.25. Output is unchanged.
    nmse_db = 10*np.log10(nmse.ravel())
    print('Mean nmse: %1.2f,' %(10*np.log10(np.mean(nmse))), 'nmse: ',', '.join('%1.2f' % v for v in nmse_db))
    print('Mean ssim: %1.3f,' %(np.mean(ssm)), 'ssim: ',', '.join('%1.3f' % v for v in ssm.ravel()))
    print('Mean psnr: %1.3f,' %(np.mean(psn)), 'psnr: ',', '.join('%1.3f' % v for v in psn.ravel()))

    # Reference magnitudes, the N images laid side by side.
    fig = plt.figure(**fig_kw)
    plt.imshow(np.reshape(np.transpose(xAbs**g,[1,0,2]), [n[0],n[1]*N]), cmap=mps[0]) # use a specific color map
    plt.title("xAbs")
    plt.show()

    # Per-pixel error magnitude for each image.
    fig = plt.figure(**fig_kw)
    plt.imshow(np.reshape(np.transpose(errMap,[1,0,2]), [n[0],n[1]*N]), cmap=plt.cm.Greys_r) # use a specific color map
    plt.title("errxMap")
    plt.show()

  # Disabled k-space plots. The arrays they read (xHatFAbs, yAbs, errFMap) are
  # not filled anywhere in this section while the matching code stays disabled.
  # fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
  # plt.imshow(np.reshape(np.transpose(xHatFAbs**gm[1],[1,0,2]), [n[0],n[1]*N]), cmap=mps[1]) # use a specific color map
  # plt.title("xHatFAbs")
  # plt.show()

  # fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
  # plt.imshow(np.reshape(np.transpose(yAbs**gm[1],[1,0,2]), [n[0],n[1]*N]), cmap=mps[1]) # use a specific color map
  # plt.title("yAbs")
  # plt.show()

  # fig = plt.figure(figsize=(16,5),facecolor='white', edgecolor=None)
  # plt.imshow(np.reshape(np.transpose(errFMap**gm[1],[1,0,2]), [n[0],n[1]*N]), cmap=plt.cm.Greys_r) # use a specific color map
  # plt.title("errFMap")
  # plt.show()

  # Reconstructed magnitudes, raised to the power g, the N images side by side.
  fig = plt.figure(**fig_kw)
  plt.imshow(np.reshape(np.transpose(xHatAbs**g,[1,0,2]), [n[0],n[1]*N]), cmap=mps[0]) # use a specific color map
  plt.title("xHatAbs")
  plt.show()
  # model = torch.load(data_path + 'model_'+'opt_%d_N_%d_Ind_%d' % (opt, N, NInd))
  # model.eval()
  ########################################################## Display/read the weights
  # # Print model's state_dict
  # print("Model's state_dict:")
  # for param_tensor in net.state_dict():
  #     print(param_tensor, "\t", net.state_dict()[param_tensor].size())

  # # Print optimizer's state_dict
  # print("Optimizer's state_dict:")
  # for var_name in optimizer.state_dict():
  #     print(var_name, "\t", optimizer.state_dict()[var_name])

  # print(net) # this gives the structure of the network
  # Print one parameter tensor of the network for inspection.
  # L is a 0-based position in the order net.parameters() yields its tensors
  # (a parameter index, not necessarily a layer number). If the network yields
  # fewer than L+1 tensors, the last one is shown instead.
  L = 4
  for i, param in enumerate(net.parameters()):
    wt = param.data
    if i == L:
      break

  print(wt.shape)
  print(wt.view(-1)) # view(-1) flattens the tensor to 1-D
  # print(zs.shape)
  ########################################################## Display code vectors
  # Show each code vector next to its saved initial value and their difference
  # (initial | current | current - initial, concatenated along axis 1).
  fig_kw = dict(figsize=(16,5), facecolor='white', edgecolor=None)

  # Common code vectors zs, shared by all images (Ns = 3 for opt 2).
  # range(Ns) is empty when Ns == 0, so no separate guard is needed.
  for i in range(Ns):
    zs_init = torch_to_np(zs_saved[:,i,:,:])
    zs_now = torch_to_np(zs[:,i,:,:])
    zs_diff = torch_to_np(zs[:,i,:,:]-zs_saved[:,i,:,:])
    fig = plt.figure(**fig_kw)
    plt.imshow(np.concatenate((zs_init, zs_now, zs_diff), axis=1), cmap=plt.cm.Greys_r) # use a specific color map
    plt.title(f"Common Code Vector: {i+1} (zs_init, zs, diff)")
    plt.show()
    print("max diff. ",np.max(zs_diff))

  # Image-specific code vectors z: N*Nz channels in total (Nz = 1 for opt 2).
  # The range is empty when Nz == 0.
  for i in range(N*Nz):
    z_init = torch_to_np(z_saved[:,i,:,:])
    z_now = torch_to_np(z[:,i,:,:])
    z_diff = torch_to_np(z[:,i,:,:]-z_saved[:,i,:,:])
    fig = plt.figure(**fig_kw)
    plt.imshow(np.concatenate((z_init, z_now, z_diff), axis=1), cmap=plt.cm.Greys_r) # use a specific color map
    plt.title(f"Image-Specific Code Vector: {i+1} (z_init, z, diff)")
    plt.show()
    print("min: ",torch.min(z[:,i,:,:]))
    print("max: ",torch.max(z[:,i,:,:]))
  
    
# Stop the wall-clock timer and report the run time in minutes.
# `st` is set outside this section (presumably a time.time() stamp taken
# before the run -- confirm).
et = time.time()
elapsed_time = et - st  # seconds
elapsed_minutes = elapsed_time/60
print('Execution time:', elapsed_minutes, 'minutes')

# sv=1
# if FS:
  # if sv==1: np.save(data_path_r + 'xRefDISCUS_'+'N_%d_LR_%f_zlamb_%f' % (N, LR, z_lamb0) + '_R_%f'%R + '_numIters_%d'%num_iter + '.npy', x)
  # if sv==1: np.save(data_path_r + 'nmseDISCUS_'+'N_%d_LR_%f_zlamb_%f' % (N, LR, z_lamb0) + '_R_%f'%R + '_numIters_%d'%num_iter + '.npy', nmse)
  # if sv==1: np.save(data_path_r + 'ssimDISCUS_'+'N_%d_LR_%f_zlamb_%f' % (N, LR, z_lamb0) + '_R_%f'%R + '_numIters_%d'%num_iter + '.npy', ssm)
  # if sv==1: np.save(data_path_r + 'psnrDISCUS_'+'N_%d_LR_%f_zlamb_%f' % (N, LR, z_lamb0) + '_R_%f'%R + '_numIters_%d'%num_iter + '.npy', psn)
  # if sv==1: np.save(data_path_r + 'PSNRtracking_DISCUS_'+'N_%d_LR_%f_zlamb_%f' % (N, LR, z_lamb0) + '_R_%f'%R + '_numIters_%d'%num_iter + '.npy', torch_to_np(running_psn))

# Save the masked reconstruction (xHatS) under a file name that encodes the
# run settings N, LR, z_lamb0, R and num_iter.
# NOTE(review): `sv` is not assigned anywhere in this section (the nearby
# `sv=1` is commented out) -- confirm it is set earlier in the notebook.
if sv==1:
  run_tag = 'N_%d_LR_%f_zlamb_%f' % (N, LR, z_lamb0) + '_R_%f'%R + '_numIters_%d'%num_iter
  np.save(data_path_r + 'xHatDISCUS_' + run_tag + '.npy', xHatS)

# if Nz!=0:
#   if sv==1: np.save(data_path_r + 'zDISCUS_'+'N_%d_LR_%f_zlamb_%f' % (N, LR, z_lamb0) + '_R_%f'%R + '_numIters_%d'%num_iter + '.npy', torch_to_np(z))
# if Ns!=0: 
#   if sv==1: np.save(data_path_r + 'zsDISCUS_'+'N_%d_LR_%f_zlamb_%f' % (N, LR, z_lamb0) + '_R_%f'%R + '_numIters_%d'%num_iter + '.npy', torch_to_np(zs))

# if sv==1: np.save(data_path_r + 'LOSStracking_DISCUS_'+'N_%d_LR_%f_zlamb_%f' % (N, LR, z_lamb0) + '_R_%f'%R + '_numIters_%d'%num_iter + '.npy', torch_to_np(running_loss))

# if sv==1: torch.save(net, data_path_r + 'model_DISCUS_'+'N_%d_LR_%f_zlamb_%f' % (N, LR, z_lamb0) + '_R_%f'%R + '_numIters_%d'%num_iter)
# # if sv==1: torch.save(mlp, data_path + 'model_'+'opt_%d_N_%d_Ind_%d' % (opt, N, NInd))
