In [None]:
import os
import datetime

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split
from skimage.metrics import peak_signal_noise_ratio as psnr
from skimage.metrics import structural_similarity as ssim

import tensorflow as tf
from tensorflow import keras

from data import read_data
from utils import add_noise_est, normalize, add_noise, read_div2k_data

from model_mwcnn import MWCNN
from model_baseline import Unet
from model_kpn import KPN, LossFunc, LossBasic

# tf.test.is_gpu_available() is deprecated; checking the list of visible GPU devices
# answers the same question without the deprecation warning.
gpu_ok = len(tf.config.list_physical_devices('GPU')) > 0
print("tf version:", tf.__version__)
print("use GPU:", gpu_ok)

In [None]:
# Seed every RNG used in this notebook (KMeans is seeded separately via random_state).
seed = 42
np.random.seed(seed)
tf.random.set_seed(seed)

# Model

In [None]:
# Kernel-prediction network under evaluation.
# Other architectures tried with this notebook:
#   Unet(color=False, filter_size=(5,5), channel_att=False, spatial_att=True, if_wavelet=False)
#   MWCNN(color=False, filter_size=(5,5), channel_att=False, spatial_att=False)
model = KPN(
    color=False,
    burst_length=1,
    blind_est=True,
    sep_conv=False,
    kernel_size=[3],
    channel_att=False,
    spatial_att=True,
    core_bias=True,
    use_bias=True,
)

filename = 'kpn_ks3_satt_bias_combinedloss'

# Toggle to evaluate a freshly initialised network instead of the saved checkpoint.
load_model = True
if load_model:
    weights_path = "model_weights/transfer_to_div2k/" + filename + ".ckpt"
    model.load_weights(filepath=weights_path)

# Prepare the data

In [None]:
# read_data returns ((train_X, train_Y), (test_X, test_Y)).
train_split, test_split = read_data('div2k')
train_X_p, train_Y_p = train_split
test_X_p, test_Y_p = test_split
N_ims = len(train_X_p)

In [None]:
# Read the test images (normalized) from the DIV2K validation set.
num = 5
fname = "./DIV2K_database/DIV2K_valid_HR"
test_X, test_Y = read_div2k_data(fname, num, if_normalized=True)


def _is_multiple_of_8(image):
    """Return True when both leading (spatial) dimensions of `image` are divisible by 8."""
    return image.shape[0] % 8 == 0 and image.shape[1] % 8 == 0


# Each entry of test_X / test_Y is used as (label, image). Only images whose height and
# width are multiples of 8 are kept — presumably required by the network's
# down-/up-sampling path (TODO confirm). One shared mask is used for inputs, targets and
# labels so the three lists stay index-aligned (previously `labels` was not filtered).
n_loaded = min(num, len(test_X), len(test_Y))
keep = [i for i in range(n_loaded)
        if _is_multiple_of_8(test_X[i][1]) and _is_multiple_of_8(test_Y[i][1])]

labels = [test_X[i][0] for i in keep]
# Add a trailing channel axis and a leading axis -> (1, h, w, 1) per image.
test_X = [tf.expand_dims(tf.expand_dims(test_X[i][1], axis=-1), axis=0) for i in keep]
test_Y = [tf.expand_dims(tf.expand_dims(test_Y[i][1], axis=-1), axis=0) for i in keep]

num_tested = len(test_X)
print('Totally', num_tested, 'images to be tested')

In [None]:
def generator():
    """Yield (input, target) test pairs one by one.

    The test images have different spatial sizes, so they cannot be stacked for
    Dataset.from_tensor_slices; a Python generator side-steps that restriction.
    """
    yield from zip(test_X, test_Y)


test_dataset = (
    tf.data.Dataset.from_generator(generator=generator, output_types=(tf.float32, tf.float32))
    .batch(1)
    .prefetch(1)
)

# Test on the full images

In [None]:
# Run the network on every test image; predictions are stored in dataset order.
pred_Y = []
for test_x, test_y in test_dataset.take(num_tested):
    print(test_x.shape)
    # Second model argument: the first channel of the input with its channel axis restored.
    first_channel = tf.expand_dims(test_x[..., 0], axis=-1)
    pred_y, _, _ = model(test_x, first_channel)
    pred_Y.append(pred_y)

In [None]:
# One row per test image: model input | ground truth | network prediction.
assert num_tested > 0, 'no test images to plot'
out_dir = './results/images/transfer_to_div2k/full_images'
os.makedirs(out_dir, exist_ok=True)  # savefig fails if the directory does not exist

fig, axes = plt.subplots(num_tested, 3, figsize=(50, 10 * num_tested), squeeze=False)
for n in range(num_tested):
    panels = ((test_X[n], 'input'), (test_Y[n], 'ground truth'), (pred_Y[n], 'prediction'))
    for ax, (image, title) in zip(axes[n], panels):
        ax.imshow(image.numpy().squeeze(), cmap='gray')
        ax.set_title(title)
        ax.axis('off')

fig.savefig(os.path.join(out_dir, 'outputs_ks3.png'))
plt.show()

In [None]:
def error(x1, x2, mode='mse'):
    """Return the mean pixel-wise error between two equally shaped arrays.

    Args:
        x1, x2: array-likes of the same shape.
        mode: 'mse' (mean squared error) or 'mae' (mean absolute error).

    Raises:
        ValueError: for any other mode (previously None was returned silently,
            which only surfaced later as a TypeError when summing).
    """
    diff = np.asarray(x1) - np.asarray(x2)
    if mode == 'mse':
        return np.mean(np.square(diff))
    if mode == 'mae':
        return np.mean(np.abs(diff))
    raise ValueError("mode must be 'mse' or 'mae', got {!r}".format(mode))


assert num_tested > 0, 'no test images to evaluate'
psnr_gt_n = ssim_gt_n = error_gt_n = 0
psnr_r_n = ssim_r_n = error_r_n = 0
for x_t, y_t, p_t in zip(test_X, test_Y, pred_Y):
    # Convert each tensor once instead of once per metric.
    noisy, target, recovered = (t.numpy().squeeze() for t in (x_t, y_t, p_t))

    psnr_gt_n += psnr(target, noisy, data_range=1)
    ssim_gt_n += ssim(noisy, target, data_range=1)
    error_gt_n += error(noisy, target)

    psnr_r_n += psnr(target, recovered, data_range=1)
    ssim_r_n += ssim(recovered, target, data_range=1)
    error_r_n += error(recovered, target)

print('Evaluation of ground truth and noised images:')
print('psnr:{:.3f}\tssim:{:.3f}\tmse:{:.3f}'.format(psnr_gt_n/num_tested, ssim_gt_n/num_tested, error_gt_n/num_tested))

print('\nEvaluation of recovered images and ground truth:')
print('psnr:{:.3f}\tssim:{:.3f}\tmse:{:.3f}'.format(psnr_r_n/num_tested, ssim_r_n/num_tested, error_r_n/num_tested))

# Approximating the predicted kernels by their k-means cluster centres

In [None]:
def apply_filtering(frames, core, bias, kernel_size):
    """Filter `frames` with per-pixel kernels, then average over kernel sizes and frames.

    Args:
        frames: 5-D tensor (batch, N, height, width, color); the rank is fixed by the
            padding spec below.
        core: dict mapping each kernel size K to per-pixel weights that must broadcast
            against the (batch, N, height, width, color, K*K) patch stack (the clustered
            cores built in this notebook are reshaped to exactly that shape). NumPy
            arrays are accepted and cast to the dtype of `frames`.
        bias: added to the per-frame prediction; assumed broadcastable to
            (batch, N, height, width, color) — TODO confirm against KPN.
        kernel_size: iterable of odd kernel sizes, e.g. [3] or [3, 5, 7], in any order.

    Returns:
        (pred_img, pred_img_i): the prediction averaged over the N frames,
        (batch, height, width, color), and the per-frame prediction,
        (batch, N, height, width, color).

    Raises:
        ValueError: if `kernel_size` is empty.
    """
    # Largest kernel first: its patch stack is built once and cropped for smaller kernels.
    # (Previously `kernel_size[::-1]`, which silently broke for non-ascending input.)
    kernels = sorted(kernel_size, reverse=True)
    if not kernels:
        raise ValueError('kernel_size must contain at least one kernel size')
    # Spatial size comes from the input itself; it used to be read from notebook globals.
    height, width = frames.shape[2], frames.shape[3]

    img_stack = None
    pred_img = []
    for index, K in enumerate(kernels):
        if img_stack is None:
            pad = K // 2
            frame_pad = tf.pad(frames, paddings=[[0, 0], [0, 0], [pad, pad], [pad, pad], [0, 0]], mode='constant')
            # Shifted copies in row-major (i, j) order -> patch index i*K + j.
            img_stack = tf.stack(
                [frame_pad[:, :, i:i + height, j:j + width, :] for i in range(K) for j in range(K)],
                axis=-1)                                                 # (bs, N, h, w, color, K*K)
        else:
            K_prev = kernels[index - 1]
            k_diff = (K_prev - K) // 2
            # Indices of the central K x K window inside the flattened K_prev x K_prev stack.
            window = range(k_diff, K_prev - k_diff)
            k_chosen = [i * K_prev + j for i in window for j in window]
            img_stack = tf.gather(img_stack, k_chosen, axis=-1)          # (bs, N, h, w, color, K*K)
        weights = tf.cast(core[K], img_stack.dtype)
        pred_img.append(tf.reduce_sum(tf.math.multiply(weights, img_stack), axis=-1, keepdims=False))
    pred_img = tf.stack(pred_img, axis=0)                           # (nb_kernels, bs, N, h, w, color)
    pred_img_i = tf.reduce_mean(pred_img, axis=0, keepdims=False)   # (bs, N, h, w, color)

    pred_img_i += bias

    pred_img = tf.reduce_mean(pred_img_i, axis=1, keepdims=False)   # (bs, h, w, color)
    return pred_img, pred_img_i

In [None]:
# Replace every predicted 3x3 kernel by the centre of the k-means cluster it belongs to,
# then filter the input with those clustered kernels.
n_clusters = 1
kmeans = KMeans(n_clusters=n_clusters, random_state=0)
K = 3

pred_Y_clustered = []
for test_x, test_y in test_dataset.take(num_tested):
    print(test_x.shape)
    frames_1ch = tf.expand_dims(test_x[..., 0], axis=-1)
    _, _, core = model(test_x, frames_1ch)
    batch_size, N, height, width, color = frames_1ch.shape
    core, bias = model.kernel_pred._convert_dict(core, batch_size, N, height, width, color)

    core3_all = tf.reshape(core[K], [-1, K * K]).numpy()   # one row per pixel kernel
    print(core3_all.shape)

    cluster_ids = kmeans.fit_predict(core3_all)
    core3_all_clustered = kmeans.cluster_centers_[cluster_ids]
    core3_all_clustered = core3_all_clustered.reshape(batch_size, N, height, width, color, -1)
    core3_all_clustered = {K: core3_all_clustered}
    print(core3_all_clustered[K].shape)

    pred_test_y3_clustered, _ = apply_filtering(test_x, core3_all_clustered, bias, kernel_size=[K])
    pred_Y_clustered.append(pred_test_y3_clustered)

In [None]:
# One row per test image: model input | ground truth | prediction from clustered kernels.
assert num_tested > 0, 'no test images to plot'
out_dir = './results/images/transfer_to_div2k/full_images'
os.makedirs(out_dir, exist_ok=True)  # savefig fails if the directory does not exist

fig, axes = plt.subplots(num_tested, 3, figsize=(50, 10 * num_tested), squeeze=False)
for n in range(num_tested):
    panels = ((test_X[n], 'input'), (test_Y[n], 'ground truth'),
              (pred_Y_clustered[n], 'prediction (clustered kernels)'))
    for ax, (image, title) in zip(axes[n], panels):
        ax.imshow(image.numpy().squeeze(), cmap='gray')
        ax.set_title(title)
        ax.axis('off')

fig.savefig(os.path.join(out_dir, 'outputs_ks3_by_clustered_kernels_n1.png'))
plt.show()

In [None]:
# Same metrics as the full-image evaluation, now for reconstructions from clustered kernels.
# `error`, `psnr` and `ssim` are defined/imported earlier in the notebook; re-defining
# `error` here only risked the two copies drifting apart.
assert num_tested > 0, 'no test images to evaluate'
psnr_gt_n = ssim_gt_n = error_gt_n = 0
psnr_r_n = ssim_r_n = error_r_n = 0
for x_t, y_t, p_t in zip(test_X, test_Y, pred_Y_clustered):
    # Convert each tensor once instead of once per metric.
    noisy, target, recovered = (t.numpy().squeeze() for t in (x_t, y_t, p_t))

    psnr_gt_n += psnr(target, noisy, data_range=1)
    ssim_gt_n += ssim(noisy, target, data_range=1)
    error_gt_n += error(noisy, target)

    psnr_r_n += psnr(target, recovered, data_range=1)
    ssim_r_n += ssim(recovered, target, data_range=1)
    error_r_n += error(recovered, target)

print('Evaluation of ground truth and noised images:')
print('psnr:{:.3f}\tssim:{:.3f}\tmse:{:.3f}'.format(psnr_gt_n/num_tested, ssim_gt_n/num_tested, error_gt_n/num_tested))

print('\nEvaluation of recovered images and ground truth:')
print('psnr:{:.3f}\tssim:{:.3f}\tmse:{:.3f}'.format(psnr_r_n/num_tested, ssim_r_n/num_tested, error_r_n/num_tested))

# Using patch information to decide which clustered kernel to apply

In [None]:
# Fit k-means on the 3x3 kernels predicted for the first test image only.
n_clusters = 5
kmeans = KMeans(n_clusters=n_clusters, random_state=0)

core3_all = []
for batch_test_X, batch_test_Y in test_dataset.take(1):
    single_channel = tf.expand_dims(batch_test_X[..., 0], axis=-1)
    pred_test_Y, _, core = model(batch_test_X, single_channel)

    batch_size, N, height, width, color = single_channel.shape
    core, bias = model.kernel_pred._convert_dict(core, batch_size, N, height, width, color)

    # One row per pixel kernel (3*3 = 9 weights).
    core3_all.append(tf.reshape(core[3], [-1, 9]).numpy())

core3_all = np.concatenate(core3_all, axis=0)
kmeans.fit(core3_all)

In [None]:
# K x K neighbourhood of every pixel of the first test image, built with the same padding
# and (i, j) ordering as the patch stack in apply_filtering. Row r is meant to line up with
# row r of core3_all / kmeans.labels_ — this assumes core[3] is laid out as
# (bs, N, h, w, color, K*K), the same layout the cluster centres are reshaped back into.
K = 3
pad = K // 2
frame_pad = tf.pad(batch_test_X, paddings=[[0, 0], [0, 0], [pad, pad], [pad, pad], [0, 0]], mode='constant')
batch_test_X_flatten = tf.stack(
    [frame_pad[:, :, i:i + height, j:j + width, :] for i in range(K) for j in range(K)],
    axis=-1)
print(batch_test_X_flatten.shape)

batch_test_X_flatten = batch_test_X_flatten.numpy().reshape(-1, K * K)
print(batch_test_X_flatten.shape)
assert len(batch_test_X_flatten) == len(kmeans.labels_), 'patches and clustered kernels are misaligned'

# Mean input patch (K x K) of each kernel cluster. Empty clusters are skipped: their mean
# would be an all-NaN patch (with a numpy RuntimeWarning).
test_all = dict()
for c in range(n_clusters):
    members = batch_test_X_flatten[kmeans.labels_ == c]
    if len(members):
        test_all[c] = members.mean(axis=0).reshape(K, K)

In [None]:
def simulate_patch(patch, test_all):
    """Find the cluster whose mean input patch is most similar to `patch`.

    Args:
        patch: flattened input patch (one row of batch_test_X_flatten).
        test_all: dict mapping cluster label -> mean input patch of that cluster.

    Returns:
        (label, representative): the best-matching cluster label and that cluster's
        flattened mean patch (previously the input patch itself was returned, making
        `simulated_patches` a plain copy of the input). Ties keep the first label.

    Raises:
        ValueError: if `test_all` is empty.

    Similarity is skimage SSIM with data_range=1, as in the metric cells; the images are
    loaded with if_normalized=True, presumably into [0, 1] (TODO confirm). Recent skimage
    versions reject float inputs without an explicit data_range.
    """
    if not test_all:
        raise ValueError('test_all must contain at least one cluster')
    scores = {k: ssim(patch, v.flatten(), data_range=1) for k, v in test_all.items()}
    # max() always yields a label; the old `best_ssim = 0` start left `label` unbound
    # (UnboundLocalError) whenever every SSIM was <= 0.
    label = max(scores, key=scores.get)
    return label, test_all[label].flatten()


# NOTE(review): one Python-level SSIM call per pixel and cluster — slow on full-resolution
# DIV2K images.
simulated_patches = []
labels = []
for patch in batch_test_X_flatten:
    label, simulated_patch = simulate_patch(patch, test_all)
    simulated_patches.append(simulated_patch)
    labels.append(label)
simulated_patches = np.array(simulated_patches)
labels = np.array(labels)

print(labels.shape)
print(simulated_patches.shape)

In [None]:
# Per-pixel kernels: the cluster centre chosen for each pixel's input patch, reshaped back
# to the per-pixel layout expected by apply_filtering.
centres_per_pixel = kmeans.cluster_centers_[labels].reshape(batch_size, N, height, width, color, -1)
core3_all_simulated = {3: centres_per_pixel}

pred_test_Y3_simulated, _ = apply_filtering(batch_test_X, core3_all_simulated, bias, kernel_size=[3])
pred_test_Y3_simulated.shape

In [None]:
# First test image (the one batch_test_X came from): input | ground truth | prediction
# obtained with the kernels selected from patch similarity.
out_dir = './results/images/transfer_to_div2k/full_images'
os.makedirs(out_dir, exist_ok=True)  # savefig fails if the directory does not exist

fig, axes = plt.subplots(1, 3, figsize=(50, 10))
panels = ((test_X[0], 'input'), (test_Y[0], 'ground truth'),
          (pred_test_Y3_simulated[0], 'prediction (patch-selected kernels)'))
for ax, (image, title) in zip(axes, panels):
    ax.imshow(image.numpy().squeeze(), cmap='gray')
    ax.set_title(title)
    ax.axis('off')

fig.savefig(os.path.join(out_dir, 'outputs_ks3_by_simulated_kernels.png'))
plt.show()