# Import Libraries

In [None]:

from google.colab import drive

# Mount Google Drive (where the project folder lives). force_remount=True
# remounts even if Drive is already mounted, e.g. when this cell is re-run.
drive.mount("/content/gdrive/", force_remount=True)

In [None]:
# Work from the project folder on Google Drive so the relative imports
# (utils/, models/) and data paths (data/...) used below resolve.
%cd "/content/gdrive/My Drive/Colab Notebooks/LADMM_Net_Pytorch"
# Show the folder contents as a quick check that the mount and %cd worked.
%ls
# sewar provides the SAM (spectral angle mapper) metric used in the evaluation.
# NOTE(review): pin a version (sewar==X.Y.Z) for reproducible installs.
%pip install sewar

In [None]:
import os
import numpy as np
import numpy.matlib
import matplotlib.pyplot as plt
import scipy.io as sio
from time import time

# our libraries
from utils import featurefusionpkg as ff

from sewar.full_ref import sam
from skimage.metrics import structural_similarity as ssim


from models.LadmmNet import LADMMcsifusionfastNet
# Pytorch libraries
import torch
import torch.nn as nn
import torch.nn.functional as F

# Expose only GPU '0' (numbered in PCI bus order) to PyTorch, and fall back to
# the CPU when CUDA is not available.
gpu_list = '0'
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": gpu_list,
})
if torch.cuda.is_available():
  device = torch.device("cuda:0")
else:
  device = torch.device("cpu")

# Measurement Matrices

In [None]:
def _load_coded_aperture(path, key, device):
  """Load a stack of coded apertures from a MATLAB file as a float32 tensor.

  Args:
    path: path to the ``.mat`` file.
    key: name of the variable inside the file; it must be a 4-D array indexed
      as (shots, rows, cols, bands).
    device: torch device the tensor is moved to.

  Returns:
    (apertures, (shots, rows, cols, bands)) where ``apertures`` is a float32
    tensor viewed as (shots, bands, rows, cols).
  """
  cca = sio.loadmat(path)[key]
  shots, rows, cols, bands = cca.shape
  # Each shot is flattened in Fortran (column-major) order and the shots are
  # concatenated. For one shot, reshape(-1, order='F') equals
  # transpose(2, 1, 0).reshape(-1), so one transpose of the whole stack gives
  # the same buffer as a per-shot loop.
  flat = cca.transpose(0, 3, 2, 1).reshape(-1).astype(np.float64)
  # NOTE: within each band the buffer is laid out (cols, rows) but viewed as
  # (rows, cols). The ground-truth cubes are flattened the same way during
  # evaluation, so the two layouts stay consistent.
  apertures = torch.from_numpy(flat).float().view(-1, bands, rows, cols).to(device)
  return apertures, (shots, rows, cols, bands)


# Coded apertures of the selected measurement set
# (an alternative set lives in 'data/Harvard/csi_measurements/375').
data_path = os.path.join(os.getcwd(), 'data/Harvard/csi_measurements/50')

# Hyperspectral (HS) coded aperture: L bands on an M_hs x N_hs grid.
fname1 = 'cca_hs.mat'
ccahs, (shots_hs, M_hs, N_hs, L) = _load_coded_aperture(os.path.join(data_path, fname1), 'cca_hs', device)

# Multispectral (MS) coded aperture: L_ms bands on the full M x N grid.
fname1 = 'cca_ms.mat'
ccams, (shots_ms, M, N, L_ms) = _load_coded_aperture(os.path.join(data_path, fname1), 'cca_ms', device)

# Sensor degradation factors: HS is spatially downsampled by p, MS spectrally by q.
p = 4
q = 2

# LADMM Parameters, Model Loading and Test-Set Evaluation

In [None]:
layer_num               = 10      # network depth (presumably the number of unrolled LADMM stages)
learning_rate           = 0.0005
epochs                  = 256
epochs                  = epochs + 1
num_training_samples    = 48
num_samples             = 48
batch_size              = 1
compression_ratio       = 50      # selects the checkpoint directory below

model     = LADMMcsifusionfastNet(layer_num)
# NOTE(review): wrapped in DataParallel before loading, presumably because the
# checkpoint keys carry the "module." prefix of a wrapped model — confirm.
model     = nn.DataParallel(model).to(device)
# NOTE(review): the optimizer is not used by the evaluation code in this notebook.
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

data_path       = os.path.join(os.getcwd(), 'data/Harvard/test_images')
model_dir       = "./train_parameters/Harvard/LADMM_Net_layer_%d_ratio_%d" % (layer_num, compression_ratio)
# Checkpoint saved after epoch `epochs - 1` (= 256).
checkpoint_path = os.path.join(model_dir, 'net_params_%d.pkl' % (epochs - 1))
# map_location='cpu' lets the checkpoint load on machines without a GPU.
# load_state_dict copies the weights into the parameters already placed on
# `device`, so no second .to(device) is needed afterwards.
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

def SpectralDegradationFilter(window_size, L, q):
  """Build a conv2d kernel that averages groups of ``q`` adjacent bands.

  Output channel ``i`` averages input bands ``i*q .. i*q + q - 1`` using only
  the centre tap of a ``window_size x window_size`` window, so the filter mixes
  bands but not pixels. Trailing bands beyond ``(L // q) * q`` get no weight.

  Returns:
    torch.Tensor of shape (L // q, L, window_size, window_size).
  """
  n_out = L // q
  centre = window_size // 2
  weights = torch.zeros((n_out, L, window_size, window_size))
  for out_band in range(n_out):
    first = out_band * q
    weights[out_band, first:first + q, centre, centre] = 1 / q
  return weights

def ProjectionFilter(window_size, L):
  """Build a single-output conv2d kernel that sums bands at the centre tap.

  NOTE(review): only bands ``1 .. L-1`` receive weight 1; band 0 stays zero.
  This may be an off-by-one (e.g. from a 1-based port). Confirm the intent
  before relying on it; the filter is not used in the code shown here.

  Returns:
    torch.Tensor of shape (1, L, window_size, window_size).
  """
  centre = window_size // 2
  weights = torch.zeros((1, L, window_size, window_size))
  weights[0, 1:, centre, centre] = 1
  return weights

def SpectralUpsamplingFilter(window_size, q, L):
  """Build a conv2d kernel that replicates each coarse band into ``q`` fine bands.

  Fine band ``i*q + j`` (``0 <= j < q``) copies coarse band ``i`` through the
  centre tap of a ``window_size x window_size`` window. If ``L`` is not a
  multiple of ``q``, the trailing fine bands stay zero.

  Returns:
    torch.Tensor of shape (L, L // q, window_size, window_size).
  """
  n_coarse = L // q
  centre = window_size // 2
  weights = torch.zeros((L, n_coarse, window_size, window_size))
  fine_idx = torch.arange(n_coarse * q)
  coarse_idx = torch.arange(n_coarse).repeat_interleave(q)
  weights[fine_idx, coarse_idx, centre, centre] = 1
  return weights

# Evaluate the trained network on the test scenes hri_049.mat .. hri_072.mat.
test_ids = range(49, 73)
num_test = len(test_ids)

# Flat staging buffer for the Fortran-order ground-truth cube (sized for a full batch).
fnp = np.zeros((M*N*L*batch_size))

psnr_vector = np.zeros(num_test)   # mean per-band PSNR [dB] per scene
ssim_vector = np.zeros(num_test)   # mean per-band SSIM per scene
samp_vector = np.zeros(num_test)   # spectral angle mapper (SAM) per scene
cmpt_vector = np.zeros(num_test)   # network forward time [s] per scene

# Loop-invariant operators of the acquisition model, built once instead of per scene.
hs_deg = nn.AvgPool2d(p)                                    # HS sensor: p x p spatial averaging
kernel = SpectralDegradationFilter(3, L, q).to(device)      # MS sensor: average q adjacent bands
upsamp = SpectralUpsamplingFilter(3, q, L_ms*q).to(device)  # MS adjoint: copy each band back q times


def _sync_device():
  """Block until queued CUDA kernels finish so wall-clock timings are meaningful.

  CUDA launches are asynchronous. Without this, ``time()`` around the forward
  pass would mostly measure kernel-launch overhead.
  """
  if device.type == 'cuda':
    torch.cuda.synchronize()


def _simulate_measurements(f):
  """Simulate HS and MS coded-aperture snapshots of ``f`` and back-project them.

  Each set of snapshots is normalized to a peak value of 1 before the adjoint.

  Returns:
    (HTyhs, HTyms): HTyhs has shape (1, L, M_hs*p, N_hs*p) and HTyms has shape
    (1, L_ms*q, M, N); these are the network inputs.
  """
  # HS snapshots: spatially degraded cube modulated by each shot's aperture, averaged over bands.
  shot_data_hs = torch.mean(torch.mul(ccahs, hs_deg(f).repeat(shots_hs, 1, 1, 1)), (1))
  shot_data_hs = shot_data_hs * (1/torch.max(shot_data_hs))
  # HS adjoint: re-modulate by the apertures, average over shots, upsample by p.
  HTyhs = F.interpolate(
      torch.mean(torch.mul(shot_data_hs.view(shots_hs, 1, M_hs, N_hs).repeat(1, L, 1, 1), ccahs), (0)).view(1, L, M_hs, N_hs),
      scale_factor=(p, p))

  # MS snapshots: spectrally degraded cube modulated by each shot's aperture, averaged over bands.
  shot_data_ms = torch.mean(torch.mul(ccams, F.conv2d(f, kernel, padding=1).repeat(shots_ms, 1, 1, 1)), (1))
  shot_data_ms = shot_data_ms * (1/torch.max(shot_data_ms))
  # MS adjoint: re-modulate, average over shots, replicate each band q times.
  HTyms = F.conv2d(
      torch.mean(torch.mul(shot_data_ms.view(shots_ms, 1, M, N).repeat(1, L_ms, 1, 1), ccams), (0)).view(1, L_ms, M, N),
      upsamp, padding=1)
  return HTyhs, HTyms


def _band_metrics(x_cube, gt):
  """Score a reconstruction against its ground truth.

  Args:
    x_cube: reconstruction as a numpy array of shape (L, M, N). Band ``l`` is
      compared against ``gt[:, :, l].T``, the layout produced by the
      Fortran-order flattening used to build the network input.
    gt: ground-truth cube of shape (M, N, L), normalized to a peak of 1.

  Returns:
    (mean per-band PSNR in dB, mean per-band SSIM, SAM over the whole cube).
  """
  n_bands = x_cube.shape[0]
  Io = np.zeros(gt.shape)
  psnr_rec = np.zeros(n_bands)
  ssim_rec = np.zeros(n_bands)
  for l in range(n_bands):
    It1 = x_cube[l]
    It2 = np.transpose(gt[:, :, l])
    Io[:, :, l] = np.transpose(It1)
    # The ground truth has a peak of 1, so PSNR uses a peak value of 1.
    res = np.mean(np.power(It2 - It1, 2))
    psnr_rec[l] = 10*np.log10(1/res)
    ssim_rec[l] = ssim(It2, It1, data_range=1)
  return np.mean(psnr_rec), np.mean(ssim_rec), sam(Io, gt)


model.eval()  # inference mode (affects dropout / batch-norm layers, if the model has any)
for kk, fpointer in enumerate(test_ids):
  print('Iteration: %d'%(kk+1))
  fsamples_ms = 'hri_%03d.mat' % (fpointer)
  hri = sio.loadmat(os.path.join(data_path, fsamples_ms))['foo']
  fnp[0:M*N*L] = hri.reshape((M*N*L), order='F')

  # Ground-truth cube in network layout, scaled by 1/255 (presumably 8-bit data).
  f = torch.from_numpy(np.double(fnp)).type(torch.FloatTensor)*(1/255.0)
  f = f.view(-1, L, M, N).to(device)

  # No autograd graph is needed for evaluation; this saves memory and time.
  with torch.no_grad():
    HTyhs, HTyms = _simulate_measurements(f)
    _sync_device()
    start = time()
    [x_output, loss_layers_sym] = model(ccahs, ccams, HTyhs, HTyms, M, N, L, p, q, shots_hs, shots_ms)
    _sync_device()
    end = time()
  del HTyhs, HTyms, loss_layers_sym

  cmpt_vector[kk] = end - start

  # Use the loaded cube dimensions instead of a hard-coded (32, 512, 512).
  x_cube = x_output.view(-1, L, M, N)[0].cpu().detach().numpy()
  psnr_vector[kk], ssim_vector[kk], samp_vector[kk] = _band_metrics(x_cube, hri / np.max(hri))
  del x_output, x_cube, hri

# Summary over all test scenes: mean and standard deviation of each metric.
separator = '----------'
print(separator)
for fmt, values in (('PSNR: %.4f + %.4f dB', psnr_vector),
                    ('SSIM: %.4f + %.4f', ssim_vector),
                    ('SAM: %.4f + %.4f', samp_vector),
                    ('Time: %.4f + %.4f s', cmpt_vector)):
  print(fmt % (np.mean(values), np.std(values)))
print(separator)