In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from DatasetHandler import DatasetHandler
from Model import CNNSpeckleFilter
from skimage.metrics import peak_signal_noise_ratio, structural_similarity
from utils import *

In [None]:
def enl(img, k=16):
    """Estimate the Equivalent Number of Looks (ENL) of an image.

    The image is tiled into non-overlapping ``k x k`` windows over its first
    two axes; any trailing axes (e.g. a channel axis) are kept whole. For each
    window ENL = mean**2 / variance, ignoring NaN pixels. The largest
    per-window value is returned, i.e. the value of the most homogeneous
    window (lowest variance relative to its mean).

    Windows whose variance is zero or undefined are skipped. Such windows
    would otherwise yield ``inf``/``nan`` and swamp the maximum.

    Args:
        img: array with at least two dimensions.
        k: window side length in pixels (default 16).

    Returns:
        The maximum per-window ENL, or ``nan`` if no window has a finite,
        non-zero variance (including an empty image).
    """
    import numpy as np

    window_enls = []
    for i in range(0, img.shape[0], k):
        for j in range(0, img.shape[1], k):
            patch = img[i:i + k, j:j + k, ...]
            variance = np.nanvar(patch)
            # Constant or all-NaN windows carry no ENL information.
            if not np.isfinite(variance) or variance == 0:
                continue
            window_enls.append(np.nanmean(patch) ** 2 / variance)

    if not window_enls:
        return float('nan')
    return max(window_enls)

# Load the dataset

In [None]:
handler = DatasetHandler('dataset_v2')

# Report how many samples each split of the dataset contains.
for message, split_paths in (
    ("Training dataset size: ", handler.train_paths),
    ("Validation dataset size: ", handler.val_paths),
    ("Testing dataset size: ", handler.test_paths),
):
    print(message, len(split_paths))

In [None]:
# Image shape passed both to the data loader (img_shape) and to the model (input_shape).
# NOTE(review): presumably (height, width, channels); batches are indexed as
# batch[n, :, :, 0] further down, consistent with a single trailing channel — confirm.
IMG_SHAPE = (256,256,1)

In [None]:
# Pull one training sample; out_noise=True makes the loader also return the
# noise component alongside the speckled input and the clean target.
batch_speckle, batch_clean, batch_noise = next(
    iter(
        handler.data_loader(
            paths=handler.train_paths,
            img_shape=IMG_SHAPE,
            batch_size=1,
            out_noise=True,
        )
    )
)

In [None]:
# Visual check of one training sample: the speckled input, its clean target and
# the noise component (top row), each with a 100-bin histogram of its pixel
# values underneath (bottom row).
sample_panels = [
    (batch_speckle, 'Input with speckle'),
    (batch_clean, 'Ground Truth'),
    (batch_noise, 'Speckle'),
]

fig, axes = plt.subplots(nrows = 2, ncols = 3, figsize = (15,10))
for col, (batch, title) in enumerate(sample_panels):
    image = batch[0, :, :, 0]
    axes[0, col].imshow(image)
    axes[0, col].axis(False)
    axes[0, col].set_title(title)
    axes[1, col].hist(image.flatten(), 100)
    axes[1, col].set_xlabel('Pixel value')
    axes[1, col].set_ylabel('Count')

plt.show()

# CNN speckle filter training

## Initialize the model


In [None]:
import os

# Set to True to train from scratch; otherwise pretrained weights are loaded.
TRAIN = False

N_LAYER = 30
# Single weights path used for both saving and loading, so the loaded file
# always matches the depth selected by N_LAYER.
WEIGHTS_PATH = f'weights/new_model_{N_LAYER}.h5'

speckle_filter = CNNSpeckleFilter(input_shape=IMG_SHAPE, n_layers=N_LAYER)


if TRAIN:
    epochs = 100
    batch_size = 12

    train_gen = handler.data_loader(
        paths      = handler.train_paths,
        batch_size = batch_size,
        img_shape  = IMG_SHAPE,
        out_noise = False)

    val_gen = handler.data_loader(
        paths      = handler.val_paths,
        batch_size = batch_size,
        img_shape  = IMG_SHAPE,
        out_noise = False)

    train_step = len(handler.train_paths)//batch_size
    val_step = len(handler.val_paths)//batch_size
    # Zero steps per epoch would mean no training/validation happens at all.
    assert train_step > 0 and val_step > 0, "batch_size is larger than a dataset split"

    history = speckle_filter.train_model(epochs, train_gen, val_gen, train_step, val_step)

    # The weights directory may not exist on a fresh checkout.
    os.makedirs(os.path.dirname(WEIGHTS_PATH), exist_ok=True)
    speckle_filter.model.save_weights(WEIGHTS_PATH)

    fig, axes = plt.subplots(nrows = 1, ncols = 1, figsize = (15,5))
    axes.plot(history.history['loss'], label = 'Training')
    axes.plot(history.history['val_loss'], label = 'Validation')
    axes.set_title('Training history')
    axes.set_xlabel('Epoch')
    axes.set_ylabel('Loss')
    axes.legend()
    plt.show()

else:
    speckle_filter.model.load_weights(WEIGHTS_PATH)

In [None]:
# Draw one batch of 16 test images (this rebinds batch_speckle / batch_clean,
# replacing the single training sample loaded earlier) and run the CNN on it.
batch_speckle, batch_clean = next(
    iter(
        handler.data_loader(
            paths=handler.test_paths,
            img_shape=IMG_SHAPE,
            batch_size=16,
            out_noise=False,
        )
    )
)

batch_pred = speckle_filter.model.predict(batch_speckle)

In [None]:
# Presumably plots noisy input, ground truth and CNN prediction for the test batch
# (helper comes from utils via star import).
# NOTE(review): the meaning of `n = True` is defined in utils.plot_model_results — confirm.
plot_model_results(batch_speckle, batch_clean, batch_pred, n = True)

# Compare results

In [None]:
from test_models import *

In [None]:
from findpeaks import findpeaks
import findpeaks

def apply_filters(clean, speckle, proposed):
    """Run the baseline despeckling filters and collect all images for comparison.

    The speckled input (first channel) is rescaled with ``findpeaks.stats.scale``
    and passed to the classical filters (``test_classic``) and to BM3D
    (``test_BM3D``). Their outputs are divided by 255.0. Every image is cropped
    to its top-left 256 x 256 region.

    Returns a list of 12 images in this order: ground truth, speckled input,
    proposed (CNN) output, Lee, Lee enhanced, Kuan, Frost, mean, median,
    fast non-local means, bilateral, BM3D.
    """
    scaled_input = findpeaks.stats.scale(speckle[..., 0])

    # Classical filters first, then BM3D (same call order as before).
    (lee, lee_enhanced, kuan, frost,
     mean_filtered, median_filtered, fastnl, bilateral) = test_classic(scaled_input)
    bm3d = test_BM3D(scaled_input)

    crop = 256
    window = (slice(0, crop), slice(0, crop))

    references = [clean[window], speckle[0:crop, 0:crop, 0], proposed[window]]
    baselines = [lee, lee_enhanced, kuan, frost, mean_filtered,
                 median_filtered, fastnl, bilateral, bm3d]

    return references + [result[window] / 255.0 for result in baselines]

In [None]:
# Panel titles, in the order apply_filters returns its images.
labels = ['Ground Truth','Input With Speckle','Proposed', 'Lee', 'Lee Enhanced', 'Kuan', 'Frost', 'Mean', 'Median', 'Fastnl', 'Bilateral', 'BM3D']
N_ROWS, N_COLS = 3, 4
# Intensity range given explicitly to PSNR/SSIM. Without it, SSIM on float images
# either assumes a range of 2.0 (older skimage) or raises (newer skimage).
# NOTE(review): assumes images are normalised to [0, 1], matching vmin/vmax of
# the plots below — confirm against DatasetHandler.
DATA_RANGE = 1.0

# Per-image lists of PSNR / SSIM scores, one entry per label.
P1, S1 = [], []

for IMG_N in range(batch_speckle.shape[0]):
    clean_img = batch_clean[IMG_N, ..., 0]
    noisy_img = batch_speckle[IMG_N, ...]
    imgs = apply_filters(clean_img, noisy_img, batch_pred[IMG_N, ..., 0])
    assert len(imgs) == len(labels) == N_ROWS * N_COLS

    # These do not depend on the compared method, so compute them once per image.
    enl_clean = enl(clean_img)
    enl_noisy = enl(noisy_img)

    print('================================================ TEST %d ================================================' % (IMG_N))
    print('|----------------------------------------------------------------|')
    fig, axes = plt.subplots(nrows = N_ROWS, ncols = N_COLS, figsize = (24,18))

    pp1, ss1 = [], []

    # axes.flat walks the grid row by row, matching the order of imgs/labels.
    for ax, label, img in zip(axes.flat, labels, imgs):
        ax.imshow(img, cmap = 'gray', vmin = 0, vmax = 1)
        ax.set_title(label, fontsize = 18)
        ax.axis(False)

        psnr = peak_signal_noise_ratio(clean_img, img, data_range = DATA_RANGE)
        ssim = structural_similarity(clean_img, img, data_range = DATA_RANGE)

        print('{} --- PSNR  {} --- SSIM {}'.format(label, psnr, ssim))
        print('{} --- ENL G {} --- ENL N {} --- ENL P {}'.format(label, enl_clean, enl_noisy, enl(img)))

        pp1.append(psnr)
        ss1.append(ssim)

    P1.append(pp1)
    S1.append(ss1)

    plt.show()
    # Release the figure so a large test batch does not accumulate open figures.
    plt.close(fig)

In [None]:
# Stack the per-image score lists into 2-D arrays: one row per test image,
# one column per compared method.
P1, S1 = np.array(P1), np.array(S1)

In [None]:
# Mean PSNR, then mean SSIM, per method across all test images (NaNs ignored).
for score_table in (P1, S1):
    print(np.nanmean(score_table, axis = 0))