In [3]:
import gc
import torch
import torch.nn as nn
from torch.nn import init
import torch.nn.functional as F
import scipy.io as sio
import numpy as np
import os
from torch.utils.data import Dataset, DataLoader
import platform
from argparse import ArgumentParser
import random
import csdata_fast
import cv2
import glob
import math
from DGUNet import DGUNet
import tensorflow as tf

In [None]:
# Multi-GPU configuration. NOTE: the single-GPU cell below re-assigns every one of these
# names, so when both cells run the later one wins.
gpu_list = '0,1,2'  # CUDA device ids exposed to this kernel
num_gpu = len(gpu_list.split(','))  # derived from gpu_list so the two cannot drift apart
start_epoch = 0  # epochs already completed; > 0 resumes from a saved checkpoint
end_epoch = 1  # last epoch to train (inclusive)
learning_rate = 1e-4  # Adam learning rate
layer_num_ICNN = 15  # ICNN layer/phase count; in the visible cells only used in checkpoint/log names
layer_num_IFC = 5  # IFC layer/phase count; in the visible cells only used in checkpoint/log names
group_num = 1  # group number; only used in checkpoint/log names
cs_ratio = 25  # CS sampling ratio, one of {10, 25, 30, 40, 50}; selects the Phi matrix file
rb_type = 1  # from {1, 2}; not read by any visible cell
rb_num = 2  # NOTE(review): earlier note gave the range {3-10}, but 2 is set; not read by any visible cell
batch_size = 32
patch_size = 32  # side length of the square image blocks
matrix_dir = 'sampling_matrix_new'  # directory holding the phi_sampling_*.npy matrices


In [9]:
# Single-GPU configuration. It sets the same names as the multi-GPU cell above and
# therefore overrides it when run after it.
gpu_list = '0'  # CUDA device ids exposed to this kernel
num_gpu = len(gpu_list.split(','))  # derived from gpu_list so the two cannot drift apart
start_epoch = 0  # epochs already completed; > 0 resumes from a saved checkpoint
end_epoch = 1  # last epoch to train (inclusive)
learning_rate = 1e-4  # Adam learning rate
layer_num_ICNN = 15  # ICNN layer/phase count; in the visible cells only used in checkpoint/log names
layer_num_IFC = 5  # IFC layer/phase count; in the visible cells only used in checkpoint/log names
group_num = 1  # group number; only used in checkpoint/log names
cs_ratio = 25  # CS sampling ratio, one of {10, 25, 30, 40, 50}; selects the Phi matrix file
rb_type = 1  # from {1, 2}; not read by any visible cell
rb_num = 2  # NOTE(review): earlier note gave the range {3-10}, but 2 is set; not read by any visible cell
batch_size = 32
patch_size = 32  # side length of the square image blocks
matrix_dir = 'sampling_matrix_new'  # directory holding the phi_sampling_*.npy matrices


In [None]:
def col2im_CS_py(X_col, row, col, row_new, col_new, block_size=None):
    """Reassemble column-stacked square blocks into an image and crop off the padding.

    Column ``k`` of ``X_col`` holds the ``k``-th ``block_size`` x ``block_size`` block,
    flattened row-major; blocks are placed left to right, then top to bottom, on a
    ``row_new`` x ``col_new`` canvas (the same order ``img2col_py`` produces).

    Args:
        X_col: array of shape (block_size**2, n_blocks).
        row, col: size of the original image; the canvas is cropped to it.
        row_new, col_new: size of the padded canvas.
        block_size: block side length; defaults to the notebook-level ``patch_size``.

    Returns:
        float64 array, the canvas cropped to ``[:row, :col]``.
    """
    if block_size is None:
        block_size = patch_size
    X0_rec = np.zeros([row_new, col_new])
    count = 0
    for x in range(0, row_new - block_size + 1, block_size):
        for y in range(0, col_new - block_size + 1, block_size):
            X0_rec[x:x + block_size, y:y + block_size] = X_col[:, count].reshape([block_size, block_size])
            count += 1
    return X0_rec[:row, :col]

In [None]:
def psnr(img1, img2, pixel_max=255.0):
    """Peak signal-to-noise ratio in dB between two same-shaped images.

    Both inputs are converted to float64 before differencing, so integer images
    (e.g. uint8) cannot wrap around when subtracted or squared.

    Args:
        img1, img2: images on the same intensity scale.
        pixel_max: maximum possible pixel value (255 for 8-bit data).

    Returns:
        PSNR in dB, or 100 when the images are identical (MSE == 0).
    """
    diff = np.asarray(img1, dtype=np.float64) - np.asarray(img2, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return 100
    return 20 * math.log10(pixel_max / math.sqrt(mse))

In [None]:
def imread_CS_py(Iorg, block_size=None):
    """Zero-pad a 2-D image on the bottom and right to a multiple of the block size.

    Args:
        Iorg: 2-D image array.
        block_size: block side length; defaults to the notebook-level ``patch_size``.

    Returns:
        ``[Iorg, row, col, Ipad, row_new, col_new]``: the input unchanged, its size,
        the padded float64 image, and the padded size.
    """
    if block_size is None:
        block_size = patch_size
    [row, col] = Iorg.shape
    # Pad only up to the next multiple; an already-aligned dimension gets no padding
    # (block_size - row % block_size alone would add a whole extra block of zeros).
    row_pad = (block_size - row % block_size) % block_size
    col_pad = (block_size - col % block_size) % block_size
    Ipad = np.concatenate((Iorg, np.zeros([row, col_pad])), axis=1)
    Ipad = np.concatenate((Ipad, np.zeros([row_pad, col + col_pad])), axis=0)
    [row_new, col_new] = Ipad.shape

    return [Iorg, row, col, Ipad, row_new, col_new]

In [None]:

def img2col_py(Ipad, block_size):
    """Split a 2-D image into non-overlapping square blocks, one block per column.

    Blocks are taken left to right, then top to bottom; each block is flattened
    row-major. Trailing rows/columns that do not fill a whole block are ignored
    (no all-zero placeholder columns are emitted for them).

    Args:
        Ipad: 2-D image array (normally already padded to a multiple of block_size).
        block_size: block side length.

    Returns:
        C-contiguous float64 array of shape (block_size**2, n_blocks).
    """
    [row, col] = Ipad.shape
    # Floor division: the old float division could overcount blocks for unaligned inputs.
    row_block = row // block_size
    col_block = col // block_size
    # (rb, bs, cb, bs) -> (rb, cb, bs, bs) so block (x, y) lands at row x * cb + y.
    blocks = (Ipad[:row_block * block_size, :col_block * block_size]
              .reshape(row_block, block_size, col_block, block_size)
              .swapaxes(1, 2)
              .reshape(row_block * col_block, block_size * block_size))
    return np.ascontiguousarray(blocks.T, dtype=np.float64)

In [None]:
class RandomDataset(Dataset):
    """Map-style dataset serving ``data[index, :]`` as a float32 tensor.

    ``length`` is what ``len()`` reports; it is taken as given and not checked
    against the number of rows in ``data``.
    """

    def __init__(self, data, length):
        self.data, self.len = data, length

    def __len__(self):
        return self.len

    def __getitem__(self, index):
        row = self.data[index, :]
        return torch.Tensor(row).float()

In [10]:
# Expose only the GPUs in gpu_list, numbered in PCI bus order.
# NOTE: only effective if CUDA has not already been initialised in this kernel.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_list
# TensorFlow is imported by this notebook; with memory growth enabled it allocates GPU
# memory on demand rather than reserving it up front alongside PyTorch.
gpus = tf.config.experimental.list_physical_devices(device_type='GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
# cuda:0 is the first *visible* device, i.e. the first id in gpu_list.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [11]:
# Sampling-ratio -> index lookup over the cs_ratio values allowed by the config cell.
# NOTE(review): ratio_dict, n_output and nrtrain are not read by any visible cell.
ratio_dict = dict(zip([10, 25, 30, 40, 50], range(5)))
# NOTE(review): 1089 == 33 * 33, which does not match patch_size * patch_size
# (32 * 32 == 1024) from the config cell — confirm before relying on it.
n_output = 1089
nrtrain = 88912  # number of training blocks
# Best-checkpoint bookkeeping, updated by the training loop.
psnr_best, best_epoch = 0, 0

In [12]:
# Test images used for the per-epoch PSNR evaluation.
test_name = 'myTest'  # test image folder under Datasets/
test_dir = os.path.join('Datasets', test_name)
# sorted(): glob order is filesystem-dependent; sorting makes evaluation order reproducible.
filepaths = sorted(glob.glob(os.path.join(test_dir, '*.jpg')))
ImgNum = len(filepaths)
# Fail now rather than after a full training epoch (the PSNR average divides by ImgNum).
if ImgNum == 0:
    raise FileNotFoundError('No .jpg test images found in %s' % test_dir)

In [13]:
# Load the CS sampling matrix Phi saved for the configured ratio and block size.
Phi_data_Name = os.path.join(
    matrix_dir, 'phi_sampling_%d_%dx%d.npy' % (cs_ratio, patch_size, patch_size))
Phi_input = np.load(Phi_data_Name)
# float32 CPU tensor; the visible cells only read its shape.
Phi = torch.from_numpy(Phi_input).float()

In [None]:
# NOTE(review): Phi's two dimensions are only printed here; the names "length" and
# "in_channels" are presumably (measurements, pixels per block) — confirm.
length, in_channels = Phi.shape
print('Length:', length, ' In_channels:', in_channels)
model = DGUNet(in_c=1, out_c=1, cs_ratio=cs_ratio)
gpu_ids = list(range(num_gpu))
print('GPU: ', gpu_ids)
# Wrap for data-parallel execution over the configured GPUs, then move to `device`.
model = nn.DataParallel(model, device_ids=gpu_ids).to(device)

In [None]:
# Optionally list the size of every parameter tensor and the total parameter count.
print_flag = 1  # print parameter number
if print_flag:
    num_count = 0
    num_params = 0
    for num_count, para in enumerate(model.parameters(), start=1):
        print('Layer %d' % num_count)
        print(para.size())
        num_params += para.numel()
    print("total para num: %d" % num_params)

In [None]:
# NOTE(review): `args` is not defined in any visible cell; this notebook was ported from a
# script that used argparse. If no earlier cell defines it, fall back to a Namespace built
# from the config-cell values. Which attributes csdata_fast.SlowDataset actually reads is
# not visible here (it may need others, e.g. a data directory) — confirm against csdata_fast.
if 'args' not in globals():
    from argparse import Namespace
    args = Namespace(
        start_epoch=start_epoch, end_epoch=end_epoch, learning_rate=learning_rate,
        layer_num_ICNN=layer_num_ICNN, layer_num_IFC=layer_num_IFC, group_num=group_num,
        cs_ratio=cs_ratio, rb_type=rb_type, rb_num=rb_num, batch_size=batch_size,
        patch_size=patch_size, matrix_dir=matrix_dir, gpu_list=gpu_list)
training_data = csdata_fast.SlowDataset(args)

In [None]:
# Shuffled mini-batch loader over the training dataset.
if platform.system() == "Windows":
    # num_workers left at its default (0): batches are loaded in the main process.
    rand_loader = DataLoader(dataset=training_data, batch_size=batch_size,
                             shuffle=True)
else:
    # Up to 24 worker processes, capped at the CPU count (PyTorch warns when
    # num_workers exceeds its suggested maximum for the machine).
    loader_workers = min(24, os.cpu_count() or 1)
    print('%s: using %d DataLoader workers' % (platform.system(), loader_workers))
    rand_loader = DataLoader(dataset=training_data, batch_size=batch_size,
                             num_workers=loader_workers, shuffle=True)

In [None]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

# NOTE(review): the output roots and algorithm name were argparse options in the script
# this notebook was ported from and are not defined in any visible cell; the values below
# are placeholders — adjust to your layout.
model_root = 'model'
log_dir = 'log'
algo_name = 'DGUNet'

# Build run-specific paths from the roots into separate names so that re-running this
# cell is idempotent (model_dir is never formatted from its own previous value).
run_tag = "CS_%s_layerICNN_%d_layerIFC_%d_group_%d_ratio_%d" % (
    algo_name, layer_num_ICNN, layer_num_IFC, group_num, cs_ratio)
model_dir = "./%s/%s" % (model_root, run_tag)
log_file_name = "./%s/Log_%s.txt" % (log_dir, run_tag)

In [None]:
# Checkpoint directory for this run.
os.makedirs(model_dir, exist_ok=True)

if start_epoch > 0:
    # Resume weights. The training loop in this notebook saves net_best.pkl and
    # net_last.pkl rather than per-epoch net_params_<epoch>.pkl files, so fall back to
    # net_last.pkl when the per-epoch file does not exist.
    ckpt_path = os.path.join(model_dir, 'net_params_%d.pkl' % start_epoch)
    if not os.path.exists(ckpt_path):
        ckpt_path = os.path.join(model_dir, 'net_last.pkl')
    # map_location lets a checkpoint saved on a GPU load onto the current device.
    model.load_state_dict(torch.load(ckpt_path, map_location=device))

step_all = len(rand_loader)  # batches per epoch, used in progress messages

In [None]:
# Training loop: each epoch trains on `rand_loader`, logs the last progress line, evaluates
# PSNR on the test images, and saves net_last.pkl (plus net_best.pkl on a new best PSNR).

# NOTE(review): `loss_mod` is not defined in any visible cell. With loss_mod == 1 the loss
# sums the MSE of every output the model returns; otherwise only output [0] is used.
# A value defined elsewhere is kept; the fallback of 1 is an assumption — confirm.
loss_mod = globals().get('loss_mod', 1)


def _train_one_epoch(epoch_i):
    """Train `model` for one pass over `rand_loader`.

    Returns the last progress line printed (every 40 steps), or None when the
    loader yielded no batches.
    """
    model.train()
    output_data = None
    for step, data in enumerate(rand_loader):
        # One patch_size x patch_size block per sample; view() also validates the size.
        batch_x = data.to(device)
        batch_x = batch_x.view(batch_x.shape[0], 1, patch_size, patch_size)
        x_output_f = model(batch_x)
        if loss_mod == 1:
            # Sum of the MSE terms of all model outputs (each clamped to [0, 1]).
            loss_list = [torch.mean(torch.pow(torch.clamp(out, 0, 1) - batch_x, 2)) for out in x_output_f]
            loss_all = torch.sum(torch.stack(loss_list))
        else:
            loss_all = torch.mean(torch.pow(torch.clamp(x_output_f[0], 0, 1) - batch_x, 2))

        optimizer.zero_grad()
        loss_all.backward()
        optimizer.step()
        if step % 40 == 0:
            output_data = "[Epoch: %02d/%02d Step: %d/%d] Total Loss: %.4f" % (
                epoch_i, end_epoch, step, step_all, loss_all.item())
            print(output_data)
    return output_data


def _evaluate_psnr():
    """Mean PSNR (dB) of the model's reconstruction of each test image's Y channel.

    Returns None when there are no test images instead of dividing by zero.
    Raises FileNotFoundError if a listed test image cannot be read.
    """
    if not filepaths:
        return None
    model.eval()
    psnr_sum = 0.0
    with torch.no_grad():
        for img_path in filepaths:
            Img = cv2.imread(img_path, 1)
            if Img is None:
                raise FileNotFoundError('Could not read test image: %s' % img_path)
            # Channel 0 of YCrCb is the luma (Y) plane; PSNR is measured on it alone.
            Iorg_y = cv2.cvtColor(Img, cv2.COLOR_BGR2YCrCb)[:, :, 0]
            [Iorg, row, col, Ipad, row_new, col_new] = imread_CS_py(Iorg_y)
            # One row per block, scaled from 0-255 to [0, 1].
            Icol = img2col_py(Ipad, patch_size).transpose() / 255.0
            batch_x = torch.from_numpy(Icol).type(torch.FloatTensor).to(device)
            batch_x = batch_x.view(batch_x.shape[0], 1, patch_size, patch_size)
            # NOTE(review): output [0] is also what the single-output loss trains on;
            # confirm it is the final reconstruction.
            x_output = model(batch_x)[0]
            Prediction_value = x_output.view(x_output.shape[0], -1).cpu().numpy()
            X_rec = np.clip(col2im_CS_py(Prediction_value.transpose(), row, col, row_new, col_new), 0, 1)
            psnr_sum += psnr(X_rec * 255, Iorg.astype(np.float64))
    return psnr_sum / len(filepaths)


for epoch_i in range(start_epoch + 1, end_epoch + 1):
    # rand_loader must stay alive: every epoch iterates it again.
    output_data = _train_one_epoch(epoch_i)
    if output_data is not None:
        # Append mode creates the file on first use; the with-block closes it every epoch.
        os.makedirs(os.path.dirname(log_file_name) or '.', exist_ok=True)
        with open(log_file_name, 'a') as log_file:
            log_file.write(output_data + '\n')

    psnr_ave = _evaluate_psnr()
    if psnr_ave is not None and psnr_ave > psnr_best:
        best_epoch = epoch_i
        psnr_best = psnr_ave
        torch.save(model.state_dict(), os.path.join(model_dir, 'net_best.pkl'))  # parameters only
    torch.save(model.state_dict(), os.path.join(model_dir, 'net_last.pkl'))  # parameters only
    if psnr_ave is None:
        print('no test images found; skipped PSNR evaluation for epoch %d' % epoch_i)
    else:
        print('best psnr is %.4f in epoch %d psnr_rec: %.4f' % (psnr_best, best_epoch, psnr_ave))