In [None]:
from __future__ import print_function
import keras
from keras.datasets import mnist
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Sequential
from keras.layers import Dense, Dropout, Activation, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.layers.convolutional import ZeroPadding2D
from keras.models import load_model

from keras.layers import Input, Conv2D, Conv2DTranspose, Add, MaxPooling2D, Dropout
from keras.models import Model
from keras.optimizers import Adam

from sklearn.utils import shuffle

import h5py
from skimage import io, exposure

# convenient imports
import tensorflow as tf
from keras import backend as K

import cv2

from skimage.transform import resize

#import sys
#np.set_printoptions(threshold=sys.maxsize)

In [None]:
from skimage import io, exposure

In [None]:
import os, random
import numpy as np

import seaborn as sns

import matplotlib
matplotlib.use('Agg');
import matplotlib.pyplot as plt
plt.set_cmap('Greys');

%matplotlib inline

In [None]:
# Report the library versions this notebook was run with.
for lib_name, lib in (("keras", keras), ("tensorflow", tf)):
    print(lib_name, lib.__version__)

In [None]:
# Check the backend and the channel ordering it uses for images.
print(keras.backend.backend())
# image_dim_ordering() is the old Keras 1 API and no longer exists in newer Keras
# releases; only call it when available so this cell does not crash there.
# image_data_format() below is its modern equivalent.
if hasattr(keras.backend, 'image_dim_ordering'):
    print(keras.backend.image_dim_ordering())
print(K.image_data_format())

### Load parameters:

In [None]:
# Project root: must contain the 'models/' and 'npy_data/' folders used below.
parent_dir = 'YOURPATH'

# Weights to load are read from models/<y_file>/<weights_file>.hdf5
y_file, weights_file = 'Y', '0600_0.0001'

# Images are im_size x im_size pixels; the network outputs num_classes channels.
im_size, num_classes = 1024, 40

### Load the model:

In [None]:
def cnn(input_shape, n_output_channels, n_conv=None):
    """Build and compile the fully convolutional heatmap-regression network.

    Architecture (all convolutions 3x3 with 'same' padding and ReLU unless noted):
      * stage 1: 2 x Conv2D(n_conv)   + BatchNorm -> MaxPool 2x2 -> Dropout 0.25
      * stage 2: 2 x Conv2D(2*n_conv) + BatchNorm -> MaxPool 2x2 -> Dropout 0.25
      * stage 3: 2 x Conv2D(4*n_conv) + BatchNorm
      * Conv2DTranspose(2*n_conv, stride 2) + Conv2D(2*n_conv) + BatchNorm + Dropout 0.25
      * Conv2DTranspose(n_output_channels, stride 2) with linear activation
    The two stride-2 poolings are undone by the two stride-2 transposed
    convolutions, so for spatial sizes divisible by 4 the output has the same
    height/width as the input.

    Args:
        input_shape: shape of one input image, ``(height, width, channels)``.
            A 2-tuple ``(height, width)`` is also accepted and gets a single
            channel axis appended, because Conv2D requires a channel dimension.
        n_output_channels: number of output maps (``num_classes`` in this notebook).
        n_conv: number of filters in the first stage. It must match the value
            the loaded weights were trained with. When None, a module-level
            ``n_conv`` is used, for backward compatibility.

    Returns:
        A keras ``Model`` named "CNN", compiled with Adam and mean-squared-error loss.

    Raises:
        ValueError: if ``n_conv`` is not given and no module-level ``n_conv`` exists.
    """
    # Local import: BatchNormalization is used below but is not imported at the top.
    from keras.layers import BatchNormalization

    if n_conv is None:
        n_conv = globals().get('n_conv')
        if n_conv is None:
            raise ValueError(
                "cnn(): pass n_conv (the base number of conv filters used in training) "
                "or define a global n_conv before calling")

    input_shape = tuple(input_shape)
    if len(input_shape) == 2:
        # Grayscale image given as (H, W): add the channel axis -> (H, W, 1).
        input_shape = input_shape + (1,)

    x_in = Input(input_shape)

    # Encoder stage 1
    x1 = Conv2D(n_conv, kernel_size=3, padding="same", activation="relu")(x_in)
    x1 = Conv2D(n_conv, kernel_size=3, padding="same", activation="relu")(x1)
    x1 = BatchNormalization()(x1)

    x1_pool = MaxPooling2D(pool_size=2, strides=2, padding="same")(x1)
    x1_pool = Dropout(0.25)(x1_pool)

    # Encoder stage 2
    x2 = Conv2D(n_conv*2, kernel_size=3, padding="same", activation="relu")(x1_pool)
    x2 = Conv2D(n_conv*2, kernel_size=3, padding="same", activation="relu")(x2)
    x2 = BatchNormalization()(x2)

    x2_pool = MaxPooling2D(pool_size=2, strides=2, padding="same")(x2)
    x2_pool = Dropout(0.25)(x2_pool)

    # Bottleneck
    x3 = Conv2D(n_conv*4, kernel_size=3, padding="same", activation="relu")(x2_pool)
    x3 = Conv2D(n_conv*4, kernel_size=3, padding="same", activation="relu")(x3)
    x3 = BatchNormalization()(x3)

    # Decoder: first upsampling back to 1/2 resolution
    x4 = Conv2DTranspose(n_conv*2, kernel_size=3, strides=2, padding="same", activation="relu")(x3)
    x4 = Conv2D(n_conv*2, kernel_size=3, padding="same", activation="relu")(x4)
    x4 = BatchNormalization()(x4)
    x4 = Dropout(0.25)(x4)

    # Final upsampling to full resolution; linear output for regression targets
    x_out = Conv2DTranspose(n_output_channels, kernel_size=3, strides=2, padding="same", activation="linear")(x4)

    # Compile
    CNN = Model(inputs=x_in, outputs=x_out, name="CNN")
    CNN.compile(optimizer=Adam(), loss="mean_squared_error")

    return CNN

In [None]:
if 'model' in globals(): # remove a model left over from a previous run of this cell
    del model 

## Load model + weights separately
# Conv2D needs a channel axis, and the inputs are (N, im_size, im_size, 1),
# so the model is built for single-channel images of shape (im_size, im_size, 1).
model = cnn((im_size, im_size, 1), num_classes)
model.load_weights(os.path.join(parent_dir, 'models', y_file, f'{weights_file}.hdf5'))
model.summary()

#model = load_model('wing_models/nn-wing-gauss5.h5py')

### Load the Data

In [None]:
# Load the masked input images (X) and the target maps (Y, file 'Y_sig3').
x = np.load(os.path.join(parent_dir, 'npy_data', 'X_masked.npy'))
y = np.load(os.path.join(parent_dir, 'npy_data', 'Y_sig3.npy'))

for label, arr in (('x', x), ('y', y)):
    print(f'{label} size: {arr.shape}')

In [None]:
# Load a second copy of the data straight from disk, so it keeps the on-disk
# (unshuffled) order. Presumably it is kept so it can later be re-aligned with the
# list of original file names using the same shuffle (see the commented cells at the end).
x_copy = np.load(os.path.join(parent_dir, 'npy_data', 'X_masked.npy'))
y_copy = np.load(os.path.join(parent_dir, 'npy_data', 'Y_sig3.npy'))
# Report the shapes of the arrays loaded here (previously this printed x/y by mistake).
print(f'x_copy size: {x_copy.shape}')
print(f'y_copy size: {y_copy.shape}')

### Split train test

In [None]:
# Shuffle with a fixed seed. The result gets new names instead of overwriting x, y:
# overwriting made every re-run of this cell reshuffle already-shuffled data and
# silently produce a different train/test split.
x_shuffled, y_shuffled = shuffle(x, y, random_state=0)

n_train = 550
assert 0 < n_train < len(x_shuffled), 'n_train must leave at least one test sample'

(x_train, y_train) = x_shuffled[:n_train], y_shuffled[:n_train]
(x_test, y_test) = x_shuffled[n_train:], y_shuffled[n_train:]

print(f'x train size: {x_train.shape}')
print(f'x test size: {x_test.shape}')

print(f'y train size: {y_train.shape}')
print(f'y test size: {y_test.shape}')

# Add the trailing channel axis expected by Conv2D: (N, H, W) -> (N, H, W, 1).
x_train = x_train.astype('float32')[:, :, :, None]
x_test = x_test.astype('float32')[:, :, :, None]

print(f'x train size: {x_train.shape}')
print(f'x test size: {x_test.shape}')

### Save the test images - X & Y:

In [None]:
# Disabled cell: would save each test input, histogram-equalised, as <index>.png.
# NOTE(review): `max_temp` below is computed but never used.
# # Save x_test:
# xs_images_dir = os.path.join(parent_dir, 'images_xs_notnorm')
# os.makedirs(xs_images_dir, exist_ok=True)

# for i in range(x_test.shape[0]):
    
#     temp = x_test[i,:,:,0]
#     max_temp = max(np.absolute([np.min(x_test), np.max(x_test)]))
#     temp = exposure.equalize_hist(temp)
#     io.imsave(os.path.join(xs_images_dir, f'{i}.png'), temp)  # was `xs_dir`, which is never defined
    

In [None]:
# Save y_test - true values, not predictions.
# One float32 image per (test image, output channel), named <image index>_<channel>.png.
ys_im_dir = os.path.join(parent_dir, 'images_ys', 'sig3')
os.makedirs(ys_im_dir, exist_ok=True)

n_test = x_test.shape[0]
for img_idx in range(n_test):
    target_maps = y_test[img_idx].astype(np.float32)
    for ch in range(num_classes):
        io.imsave(os.path.join(ys_im_dir, f'{img_idx}_{ch}.png'), target_maps[:, :, ch])

In [None]:
# Save predictions: negatives clipped to 0, then each image's maps divided by the
# image's overall maximum (over all channels together, not per channel).
# NOTE(review): io.imsave presumably converts these floats to an integer PNG (lossy) — confirm.

pred_dir = os.path.join(parent_dir, 'images_predictions', 'sig3_xMasked')
os.makedirs(pred_dir, exist_ok=True)

for i in range(x_test.shape[0]):

    sample = x_test[[i]]  # list index keeps the batch axis: (1, H, W, 1)
    predicted = model.predict(sample)
    predicted[predicted<0] = 0

    peak = np.max(predicted)
    if peak > 0:
        # Guard: an all-zero prediction would otherwise give 0/0 -> NaN images.
        predicted = predicted/peak

    for j in range(num_classes):

        io.imsave(os.path.join(pred_dir, f'{i}_{j}.png'), predicted[0,:,:,j])

In [None]:
# Save predictions not normed: negatives clipped to 0, values stored as fixed-point
# int16 (value * 1000, truncated).

pred_notnormed_dir = os.path.join(parent_dir, 'images_predictions_notnormed', 'sig3_xMasked')
os.makedirs(pred_notnormed_dir, exist_ok=True)

for i in range(x_test.shape[0]):

    sample = x_test[[i]]  # list index keeps the batch axis: (1, H, W, 1)
    predicted = model.predict(sample)

    # Clip to [0, int16 max] before casting: out-of-range float->int casts are
    # undefined in NumPy, so activations above ~32.767 could otherwise come out
    # as garbage/negative values instead of saturating.
    predicted = np.clip(predicted*1000, 0, np.iinfo(np.int16).max).astype(np.int16)

    for j in range(num_classes):

        io.imsave(os.path.join(pred_notnormed_dir, f'{i}_{j}.png'), predicted[0,:,:,j])

In [None]:
# Quick look at output channel 20 of the last `predicted` computed in the cells above.
plt.imshow(predicted[0, :, :, 20]);

### Display the REAL landmarks on the images:

In [None]:
# How to display Y images - each landmark gets its own colour.
# Preview one such colour ramp: 256 light-to-full shades of the second colour
# of seaborn's 3-colour 'colorblind' palette.
base_color = sns.color_palette('colorblind', 3)[1]
sns.palplot(sns.light_palette(base_color, 256))

# Array of all those colors
#my_color_scale = sns.light_palette('green',256)

### Big Images:
Display of the real landmarks, each drawn as a patch of fixed radius around its position.

In [None]:
# import glob

# mama_dir = '/home/ella/Desktop/wings_for_RF/'

# img_folder = os.path.join(mama_dir, 'data/images/')
# landmarks_folder = os.path.join(mama_dir, 'ldmks/raw/')

# all_images_files = [f for f in glob.glob(os.path.join(img_folder,'*'))]
# all_landmarks_files = [f for f in glob.glob(os.path.join(landmarks_folder,'*'))] 

# all_images_names = [os.path.basename(f)[:-4] for f in all_images_files]
# all_landmarks_names = [os.path.basename(f)[:-4] for f in all_landmarks_files]

# # sort all lists to be on the same order
# images_lists = zip(all_images_names, all_images_files)
# landmarks_lists = zip(all_landmarks_names, all_landmarks_files)

# images_lists_s = list(sorted(images_lists))
# landmarks_lists_s = list(sorted(landmarks_lists))

# n_lm = 40

# # All images landmarks read - to list of np arrays:
# ims_landmarks = []
# for l in landmarks_lists_s:
#     lines = [line.rstrip('\n') for line in open(l[1])]
#     ims_landmarks.append(np.asarray([f.split(" ") for f in lines]).astype(np.float16))

In [None]:
# im = 1

# # Load image:
# x_im = io.imread(images_lists_s[im][1])   

# #thr = 0.95
# my_img = np.ones((3234, 3840,3))

# lms = ims_landmarks[im]

# r_lm = [0,1,5,10,2,7,8,9,3,11,12,13]

# r_lm = [i for i in range(40)]

# for l in range(len(r_lm)):
    
#     temp = np.ones((3234, 3840))
#     lm = lms[l]
    
#     temp[int(lm[1]),int(lm[0])] = 0
    
#     my_color_scale = sns.light_palette(sns.color_palette("Paired", 40)[l], 256)
#     my_color_scale[0] = [1,1,1]
    
#     color = my_color_scale[-1][0]
    
#     radi = 40
#     for i in range(radi):
#         for j in range(radi):
#             for ch in range(3):
#                 my_img[int(lm[1])-j,int(lm[0])-i,ch] = my_color_scale[-1][ch]
#                 my_img[int(lm[1])-j,int(lm[0])+i,ch] = my_color_scale[-1][ch]
#                 my_img[int(lm[1])+j,int(lm[0])-i,ch] = my_color_scale[-1][ch]
#                 my_img[int(lm[1])+j,int(lm[0])+i,ch] = my_color_scale[-1][ch]

# # Normalize X:
# x_show = ((x_im-np.min(x_im))/(np.max(x_im)-np.min(x_im))).astype(np.float64)
# x_show = np.repeat(x_show[:,:,np.newaxis],3,axis=2)


In [None]:
# new_img = cv2.addWeighted(x_show, 0.25, my_img, 0.8, 0)

# plt.figure(figsize=(20,10))
# plt.imshow(new_img, interpolation='none')

In [None]:
# new_img.shape

### Small Image:
Real values: the ground-truth Y maps (from Yann), overlaid on the input image.

In [None]:
# Overlay all ground-truth landmark maps of one test image, each in its own colour.
im = 6      # index into x_test / y_test
thr = 0.95  # map values below this are treated as background

lm_img_clr = np.ones((im_size,im_size,3))  # colour overlay on a white background
# Grey-level overlay. NOTE(review): it mixes the 1.0 background with 0..255 values.
lm_img_bw = np.ones((im_size,im_size))

lm = list(range(num_classes))
palette_colors = sns.color_palette("Paired", len(lm))

for l in range(len(lm)):

    # Take current landmark map, clip below thr, rescale the remaining range to 0..255.
    temp = y_test[im,:,:,lm[l]].copy()
    temp[temp<thr] = thr
    span = np.max(temp) - np.min(temp)
    if span == 0:
        # Nothing above thr in this channel; the division below would give 0/0 -> NaN.
        continue
    temp = ((temp-np.min(temp))/span * 255).astype(np.uint8)

    # (256, 3) RGB ramp from near-white to this landmark's colour.
    my_color_scale = np.asarray(sns.light_palette(palette_colors[l], 256))

    # Vectorised form of the former per-pixel loop: every non-zero pixel is averaged
    # with the value already there, so later landmarks blend over earlier ones.
    mask = temp != 0
    lm_img_bw[mask] = (lm_img_bw[mask] + temp[mask]) / 2
    lm_img_clr[mask] = (lm_img_clr[mask] + my_color_scale[temp[mask]]) / 2

# Normalize X to [0, 1] (min-max) and replicate it into 3 channels for blending.
x_show = x_test[im,:,:,0].astype(np.float64)
x_show = (x_show-np.min(x_show))/(np.max(x_show)-np.min(x_show))

x_show = np.repeat(x_show[:,:,np.newaxis],3,axis=2)

In [None]:
# Show the grey-level landmark overlay (lm_img_bw).
_, bw_ax = plt.subplots(figsize=(20, 10))
bw_ax.imshow(lm_img_bw, interpolation='none');

In [None]:
# Show color image: 20% input image + 83% colour landmark overlay.
# (A previous 0.2/0.8 blend was computed here and immediately overwritten; removed.)
new_img = cv2.addWeighted(x_show, 0.2, lm_img_clr, 0.83, 0)
# The weights sum to 1.03, so some pixels exceed 1; clip to imshow's valid float RGB range.
new_img = np.clip(new_img, 0, 1)

plt.figure(figsize=(20,10))
plt.imshow(new_img, interpolation='none')
plt.axis('off');

### Display the predictions:

In [None]:
#if 'model' in globals(): # check that the model is defined
#    del model 

#model = load_model('/home/ella/Desktop/wings_for_RF/forDL/wing_models/nn-wing-gauss5.h5py')

In [None]:
# Display 6 consecutive test images (2 x 3 grid) with the predicted landmarks overlaid.

fig, ax = plt.subplots(nrows=2, ncols=3, sharex=True, sharey=True, figsize=(20,13))

# How many highest-valued pixels to draw per landmark:
n_pix_per_lm = 5

im_num = 23  # index of the first test image shown
lm = list(range(num_classes))
# NOTE(review): 'bright' presumably has only 10 base colours, so with 40 landmarks the
# colours repeat — confirm against the seaborn palette docs.
palette_colors = sns.color_palette("bright", len(lm))

for col in ax.flat:  # row-major, same order as iterating rows then columns

    sample = x_test[[im_num]]  # list index keeps the batch axis: (1, H, W, 1)
    predicted = model.predict(sample)

    # Normalize X to [0, 1] (min-max) and replicate it into 3 channels.
    x_show = sample[0,:,:,0].astype(np.float64)
    x_show = (x_show-np.min(x_show))/(np.max(x_show)-np.min(x_show))
    x_show = np.repeat(x_show[:,:,np.newaxis],3,axis=2)

    pred_show = np.ones((im_size,im_size,3))  # white canvas

    for l in range(len(lm)):
        heatmap = predicted[0,:,:,lm[l]]

        # Value of the n-th largest pixel; np.partition avoids a full sort.
        val = np.partition(heatmap.ravel(), -n_pix_per_lm)[-n_pix_per_lm]
        mask = heatmap >= val
        if mask.all():
            continue  # constant map: no peak to draw

        # Full-strength end of this landmark's light-to-colour ramp.
        color = np.asarray(sns.light_palette(palette_colors[l], 256))[-1]
        # Average each selected pixel with what is already drawn there.
        pred_show[mask] = (pred_show[mask] + color) / 2

    new_img = cv2.addWeighted(x_show, 0.15, pred_show, 0.8, 0)
    col.imshow(new_img, vmin=0, vmax=1)
    col.axis('off')

    im_num+=1

fig.tight_layout()

In [None]:
# Display 6 consecutive test images (2 x 3 grid) with the predicted landmarks overlaid.
# NOTE(review): this cell repeats the previous one exactly (same images, same settings).

fig, ax = plt.subplots(nrows=2, ncols=3, sharex=True, sharey=True, figsize=(20,13))

# How many highest-valued pixels to draw per landmark:
n_pix_per_lm = 5

im_num = 23  # index of the first test image shown
lm = list(range(num_classes))
# NOTE(review): 'bright' presumably has only 10 base colours, so with 40 landmarks the
# colours repeat — confirm against the seaborn palette docs.
palette_colors = sns.color_palette("bright", len(lm))

for col in ax.flat:  # row-major, same order as iterating rows then columns

    sample = x_test[[im_num]]  # list index keeps the batch axis: (1, H, W, 1)
    predicted = model.predict(sample)

    # Normalize X to [0, 1] (min-max) and replicate it into 3 channels.
    x_show = sample[0,:,:,0].astype(np.float64)
    x_show = (x_show-np.min(x_show))/(np.max(x_show)-np.min(x_show))
    x_show = np.repeat(x_show[:,:,np.newaxis],3,axis=2)

    pred_show = np.ones((im_size,im_size,3))  # white canvas

    for l in range(len(lm)):
        heatmap = predicted[0,:,:,lm[l]]

        # Value of the n-th largest pixel; np.partition avoids a full sort.
        val = np.partition(heatmap.ravel(), -n_pix_per_lm)[-n_pix_per_lm]
        mask = heatmap >= val
        if mask.all():
            continue  # constant map: no peak to draw

        # Full-strength end of this landmark's light-to-colour ramp.
        color = np.asarray(sns.light_palette(palette_colors[l], 256))[-1]
        # Average each selected pixel with what is already drawn there.
        pred_show[mask] = (pred_show[mask] + color) / 2

    new_img = cv2.addWeighted(x_show, 0.15, pred_show, 0.8, 0)
    col.imshow(new_img, vmin=0, vmax=1)
    col.axis('off')

    im_num+=1

fig.tight_layout()

In [None]:
# Debug check: first channel of the bottom-right pixel of the most recently computed x_show.
x_show[-1,-1,0]

### Predictions on full size images:

In [None]:
# original_names_file = '/PARENT_DIR/original_files_list.txt'

# original_files = [line.rstrip('\n') for line in open(original_names_file)]

In [None]:
# x_4display, original_files = shuffle(x_copy, original_files, random_state=0)

In [None]:
# print(original_files[586])
# print(original_files[589])
# print(original_files[591])
# print(original_files[594])

In [None]:
# import glob

# img_folder = os.path.join(parent_dir, 'data/images/')

In [None]:
# ##### SUPER LONG RUN!!!!

# for im in range(550,600):

#     # Load image:
#     #x_im = io.imread(images_lists_s[im][1])   
#     x_im = io.imread(os.path.join(img_folder,f'{original_files[im]}.tif'))

#     #sample = x_test[[im-550]]
#     sample = x_test[[im-550]]
#     predicted = model.predict(sample)

#     # Normalize X:
#     x_show = x_im / np.max(x_im)
#     x_show = np.repeat(x_show[:,:,np.newaxis],3,axis=2)

#     pred_show = np.ones((3234, 3840,3))

#     n_pix_per_lm = 600

#     lm = [i for i in range(40)]
#     for l in range(len(lm)):

#         temp = resize(predicted[0,30:,:,lm[l]], (3234, 3840))
#         #io.imsave(f'/home/ella/Desktop/for_stephan/{im}_{l}.tif', temp.astype(np.float32))

#         # Take x highest values:
#         val = (np.sort(temp.reshape(-1)))[-n_pix_per_lm]
#         temp = np.where(temp>=val, 1, 0)

#         temp = ((temp-np.min(temp))/(np.max(temp)-np.min(temp)) * 255).astype(np.uint8)

#         my_color_scale = sns.light_palette(sns.color_palette("Paired", 40)[l], 256)
#         my_color_scale[0] = [1,1,1]

#         for i in range(temp.shape[0]):
#             for j in range(temp.shape[1]):
#                 if temp[i,j] != 0:
#                     pred_show[i,j,0] = np.mean([pred_show[i,j,0], my_color_scale[temp[i,j]][0]])
#                     pred_show[i,j,1] = np.mean([pred_show[i,j,1], my_color_scale[temp[i,j]][1]])
#                     pred_show[i,j,2] = np.mean([pred_show[i,j,2], my_color_scale[temp[i,j]][2]])

#     new_img = cv2.addWeighted(x_show, 0.2, pred_show, 0.8, 0)

#     io.imsave(f'/home/ella/Desktop/{im}.png', new_img)

In [None]:
# new_img = cv2.addWeighted(x_show, 0.2, pred_show, 0.8, 0)

In [None]:
# #new_img = cv2.addWeighted(x_show, 0.25, my_img, 0.8, 0)

# plt.figure(figsize=(20,10))
# plt.imshow(new_img, interpolation='none')
# plt.axis('off')