Preprocess images to prepare for CNN landmark detection

Requires:
* original images
* GT landmark positions
* mask images (optional)

Does:
* Resizes the image to the analysis size 
* Normalizes intensities to zero mean and unit variance (statistics computed over the whole image set)
* Masks (if available/desirable)

In [None]:
import os
import numpy as np
import glob
from skimage import io
from skimage.filters import gaussian
import math
from skimage.transform import resize
import tifffile as tif

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns
# Plot styling for every figure in this notebook.
# NOTE: the inline-backend magic is activated first because it may reset some
# rcParams (e.g. font sizes); the seaborn context/style are then applied once,
# so they are not overridden.
get_ipython().run_line_magic('matplotlib', 'inline')
sns.set_context('poster')
sns.set_style('white')
plt.set_cmap('gray')

In [None]:
## Show landmarks on image
def show_landmarks_on_image(im, landmarks):
    """Display ``im`` with its landmarks overlaid as red dots.

    Parameters
    ----------
    im : array-like
        Image passed to ``plt.imshow``.
    landmarks : ndarray
        Landmark coordinates; column 0 is plotted on the x axis (image
        column) and column 1 on the y axis (image row).
    """
    plt.figure(figsize=(10, 10))
    plt.imshow(im)
    plt.xticks([])
    plt.yticks([])
    plt.scatter(landmarks[:, 0], landmarks[:, 1], c='r', s=40)
    # Despine must run before plt.show(): after the figure is shown (and, with
    # the inline backend, closed) it would act on a new, empty figure.
    sns.despine(bottom=True, left=True)
    plt.show()

In [None]:
# Resize function with anti-alias

import skimage
# Display the installed scikit-image version. skimage.transform.resize is
# called with anti_aliasing=True further down; on old versions (< 0.15) a
# newer copy of resize may be required — TODO confirm for your install.
skimage.__version__

# If the version is < 0.15, copy `skimage.transform.resize` from a version >= 0.15 (its anti-aliasing behavior is used below).

### Set parameters

In [None]:
# Project root; replace the placeholder with the real path. The image folder
# (data/images), the landmark folder (ldmks/raw) and the output folder
# (data_CNN) are all built relative to it.
parent_dir = '/YOURPATH/'

# Size to run the analysis on:
# NOTE(review): resize_size is only referenced in commented-out code in this
# notebook; the resize step near the end uses its own hard-coded `siz`.
resize_size = 1024

In [None]:
# Define number of landmarks:
# Landmark files whose line count differs from n_lm are discarded below.
n_lm = 40

# Sigma of the Gaussian used to turn each landmark dot into a heat map target.
# A Gaussian filter with sigma 3 worked best in the benchmark (difference not
# significant).
#gauss_range = [x for x in range(3,4)]
sig = 3

In [None]:
## input dirs:
# Raw images and their ground-truth landmark files; an image and its landmark
# file are paired by base name further down.
imgs_dir = os.path.join(parent_dir, 'data/images/')
ldmks_dir = os.path.join(parent_dir, 'ldmks/raw/')

In [None]:
# Every derived array / list in this notebook is saved here (created if missing).
output_dir = os.path.join(parent_dir, 'data_CNN')
os.makedirs(output_dir, exist_ok=True)

### Load the images and the landmarks

In [None]:
# All files in the image and landmark folders (any extension). Sorted because
# glob returns files in arbitrary, file-system dependent order.
all_images_files = sorted(glob.glob(os.path.join(imgs_dir, '*')))
all_landmarks_files = sorted(glob.glob(os.path.join(ldmks_dir, '*')))

In [None]:
# Base names without extension, used to pair each image with its landmark file.
# os.path.splitext copes with any extension length ('.tif', '.tiff', '.txt', ...),
# whereas slicing off a fixed 4 characters breaks on '.tiff' or extension-less files.
all_images_names = [os.path.splitext(os.path.basename(f))[0] for f in all_images_files]
all_landmarks_names = [os.path.splitext(os.path.basename(f))[0] for f in all_landmarks_files]

In [None]:
# Check that all images files have landmarks file:
print(f'number of images: {len(all_images_names)}')
# Names present in only one of the two folders; an empty set means every image
# has a landmark file and vice versa.
set(all_images_names).symmetric_difference(all_landmarks_names)

# another way to test this:
#set(all_images_names) ^ set(all_landmarks_names)

In [None]:
# Pair images with landmark files by base name. Only names present in BOTH
# folders are kept, so the two sorted lists line up index-by-index (the
# filtering below relies on that). Without this, one missing file would shift
# every later image/landmark pair.
common_names = set(all_images_names) & set(all_landmarks_names)

images_lists_s = sorted(
    (name, path)
    for name, path in zip(all_images_names, all_images_files)
    if name in common_names
)
landmarks_lists_s = sorted(
    (name, path)
    for name, path in zip(all_landmarks_names, all_landmarks_files)
    if name in common_names
)

### Delete images that can't be analysed:

In [None]:
# Delete if not all landmarks are found for image:

# The filtering below pairs the two lists by position, so they must align.
if len(images_lists_s) != len(landmarks_lists_s):
    raise ValueError(
        f'{len(images_lists_s)} images vs {len(landmarks_lists_s)} landmark files; '
        'pair them by name before filtering'
    )

# Number of lines (= landmarks) in each landmark file; each file is closed
# right after counting.
num_lm_list = []
for _, lm_path in landmarks_lists_s:
    with open(lm_path) as lm_file:
        num_lm_list.append(sum(1 for _ in lm_file))

# True where the file holds exactly n_lm landmarks.
bool_lm = [count == n_lm for count in num_lm_list]

# Keep only the image / landmark pairs with the expected number of landmarks.
images_lists_s = [im for im, ok in zip(images_lists_s, bool_lm) if ok]
landmarks_lists_s = [lms for lms, ok in zip(landmarks_lists_s, bool_lm) if ok]

In [None]:
# All images landmarks read - to list of np arrays:
# One array per landmark file, one row per line, whitespace-separated columns
# (columns 0 and 1 are used as x and y further down).
# float32 rather than float16: float16 has an 11-bit significand, so
# coordinates above 1024 px lose sub-pixel precision and coordinates above
# 2048 px are rounded to even integers.
ims_landmarks = []
for _, lm_path in landmarks_lists_s:
    with open(lm_path) as lm_file:
        # split() (no argument) tolerates repeated spaces, tabs and '\r'.
        rows = [line.split() for line in lm_file]
    ims_landmarks.append(np.asarray(rows).astype(np.float32))

In [None]:
# Check the axes in the landmarks:
# Largest coordinate along each landmark column over the whole data set
# (column 0 = x, column 1 = y); 0 when there are no landmark files.
max_x = max((np.max(lms[:, 0]) for lms in ims_landmarks), default=0)
max_y = max((np.max(lms[:, 1]) for lms in ims_landmarks), default=0)

print(f'max landmarks: {max_x} {max_y}')

# Check the axes in the images:
from PIL import Image
from PIL.TiffTags import TAGS  # handy for inspecting TIFF metadata by hand

# Size is read from the FIRST image only; stacking all images into X below
# assumes every image has these dimensions. Image.size is (width, height) for
# any format, whereas looking up TAGS[key] for every raw TIFF tag could raise
# KeyError on tag ids that PIL does not know.
with Image.open(images_lists_s[0][1]) as img:
    im_width, im_length = img.size

print(f'image dims: {im_width} {im_length}')

In [None]:
# Delete if not all landmarks are inside the image
# (sometimes a wing is a bit outside of image, and the landmark will be outside).
# Both bounds are checked: a negative coordinate is also outside the image and
# would otherwise wrap around (negative NumPy indexing) when the landmark dot
# images are built below.
idx2rm = set()
for i, lms in enumerate(ims_landmarks):
    xs, ys = lms[:, 0], lms[:, 1]
    if (np.min(xs) < 0 or np.min(ys) < 0
            or np.max(xs) >= im_width or np.max(ys) >= im_length):
        idx2rm.add(i)

# A set makes the membership tests O(1).
images_lists_s = [im for i, im in enumerate(images_lists_s) if i not in idx2rm]
ims_landmarks = [lms for i, lms in enumerate(ims_landmarks) if i not in idx2rm]

### Write image names to file
(in the same sorted order as the data arrays)

In [None]:
# Number of images that passed both filters above.
n_ims = len(images_lists_s)
n_ims

In [None]:
# Base names of the kept images, in the same sorted order used to build X / Y
# (which hold the first entries of this list). Works for an empty list, where
# transposing with zip(*...) and taking [0] raised IndexError.
image_names = [name for name, _ in images_lists_s]
original_names_outfile = os.path.join(output_dir, 'original_files_list.txt')

# No trailing newline, so reading back with .split("\n") gives one entry per image.
with open(original_names_outfile, mode='wt', encoding='utf-8') as myfile:
    myfile.write('\n'.join(image_names))

### Get landmark info
e.g. distances between pairs of landmarks

In [None]:
##### Analyse spot relationships -
### whats the distance between the spots:
# For each landmark pair (i, j): the largest and smallest Euclidean distance
# observed over all images. Vectorised over images; coords has shape
# (n_images, n_lm, 2) with (x, y) = columns 0 and 1 of each landmark array.
# Computed in float64: if landmarks are stored as float16, squaring the
# differences overflows (float16 max is 65504, so differences of ~256 px or
# more became inf).
coords = np.asarray(
    [lms[:, :2] for lms in ims_landmarks], dtype=np.float64
).reshape(-1, n_lm, 2)
diffs = coords[:, :, np.newaxis, :] - coords[:, np.newaxis, :, :]  # (n_images, n_lm, n_lm, 2)
dists = np.sqrt(np.sum(diffs ** 2, axis=-1))                     # (n_images, n_lm, n_lm)

if dists.shape[0]:
    max_dis_matrix = dists.max(axis=0)
    min_dis_matrix = dists.min(axis=0)
else:
    # No images: reducing over an empty axis would raise.
    max_dis_matrix = np.zeros((n_lm, n_lm))
    min_dis_matrix = np.zeros((n_lm, n_lm))

In [None]:
# Integer (truncated) copies of both distance matrices, for saving and plotting.
max_dis_matrix_int, min_dis_matrix_int = (
    matrix.astype(int) for matrix in (max_dis_matrix, min_dis_matrix)
)

In [None]:
# Persist both integer distance matrices (np.save appends '.npy').
for matrix_name, matrix_int in (('max_ldmk_dis_matrix_int', max_dis_matrix_int),
                                ('min_ldmk_dis_matrix_int', min_dis_matrix_int)):
    np.save(os.path.join(output_dir, matrix_name), matrix_int)

In [None]:
# Largest entry of the per-pair minimum-distance matrix (display only).
np.max(min_dis_matrix_int)

In [None]:
# Heat maps of the per-pair minimum (top) and maximum (bottom) landmark
# distances, every cell annotated with its integer value; saved to
# mat_max_min.png in the current working directory.
fig, ax = plt.subplots(2, 1, figsize=(40, 80))

for axis, matrix in zip(ax, (min_dis_matrix_int, max_dis_matrix_int)):
    mappable = axis.matshow(matrix, cmap=plt.cm.Blues)
    # Attach each colorbar to its own axes; without `ax=` older matplotlib
    # versions put every colorbar next to the current (last) axes.
    fig.colorbar(mappable, ax=axis)
    n_rows, n_cols = matrix.shape
    for i in range(n_cols):
        for j in range(n_rows):
            # matshow draws column i along x and row j along y.
            axis.text(i, j, str(matrix[j, i]), va='center', ha='center', fontsize=15)

fig.savefig('mat_max_min.png', bbox_inches="tight")

### Load all images and landmarks:
For now, each landmark is encoded as a dot image: one channel per landmark with a single pixel set to 1.

In [None]:
# Save x and y as numpy arrays:

## I'm only taking a 100 images as yann wants training on a small dataset
# to see if the method can be applied on such.
# Capped at the number of available images so a smaller data set does not
# raise IndexError.
n = min(100, len(images_lists_s))

# X: one image per row; every image must have the size read from the first one.
# Y: one binary channel per landmark, 1 at the landmark pixel and 0 elsewhere.
X = np.zeros((n, im_length, im_width), dtype=np.float16)
Y = np.zeros((n, im_length, im_width, n_lm), dtype='uint8')

for i in range(n):
    # Load image (scaled to floats by img_as_float, stored as float16):
    X[i] = skimage.img_as_float(io.imread(images_lists_s[i][1])).astype(np.float16)

    # Landmark coordinates truncated to integer pixel indices; column 0 is the
    # image column (x) and column 1 the image row (y).
    landmarks = ims_landmarks[i].astype(int)
    channels = np.arange(len(landmarks))
    # Set pixel (row, col) to 1 in each landmark's own channel, without
    # allocating a full-size temporary image per landmark.
    Y[i, landmarks[:, 1], landmarks[:, 0], channels] = 1

In [None]:
# Persist the raw images (X) and the per-landmark dot images (Y).
for array_file, array_data in (('X.npy', X), ('Y.npy', Y)):
    np.save(os.path.join(output_dir, array_file), array_data)

In [None]:
### show landmarks on example images:
# Draw each landmark's index at its position on the first `nn` images and save
# every figure as im{i}_wing_w_numbers_ldmks.png in the current directory.
nn = min(2, len(X))

for i in range(nn):

    # plt doesnt support float16
    im = X[i].astype(np.float32)
    landmarks = ims_landmarks[i]

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(im)
    for j, lm in enumerate(landmarks):
        ax.text(lm[0], lm[1], f'{j}', color='r', fontsize=15, fontweight="bold")
    # Hiding the axes is per-figure, so it is done once, not once per landmark.
    ax.axis('off')

    fig.savefig(f'im{i}_wing_w_numbers_ldmks.png', bbox_inches="tight")

### Normalize X
zero mean and unit variance

In [None]:
# Data-set-wide statistics used for normalization (over all images at once).
# Both are accumulated and returned as float64; without dtype, np.mean on a
# float16 array returns a float16 value (only ~3 significant digits).
X_mean = np.mean(X, dtype=np.float64)
X_std = np.std(X, dtype=np.float64)
print("X mean value is", X_mean)
print("X std value is", X_std)
np.max(X), np.min(X)

In [None]:
# Standardize X in place with the data-set statistics (result stays float16).
np.subtract(X, X_mean, out=X)
np.divide(X, X_std, out=X)
# check again to double check
print("After normalization the data has mean value", np.mean(X))
print("After normalization the data has standard deviation", np.std(X, dtype=np.float64))

In [None]:
# Save the normalized images.
np.save(os.path.join(output_dir,'X_normed.npy'), X)

### Check the images:

In [None]:
# Reload the normalized images and the landmark dot images from disk.
X, Y = (np.load(os.path.join(output_dir, stored)) for stored in ('X_normed.npy', 'Y.npy'))

In [None]:
# Inspect the first normalized image.
im = X[0]
im.shape

In [None]:
# Sanity check for image 2: collapse all landmark channels into one image and
# list the pixels equal to 1 (one per landmark unless two share a pixel).
y = Y[2].sum(axis=2)
np.where(y == 1)

## Make Y images for CNN:

In [None]:
# If needed load Y:
# Number of images = length of Y's first axis.
n_ims = Y.shape[0]
n_ims

In [None]:
# Distinct values in image 2's landmark channels and their counts
# (expected: mostly 0, plus one 1 per landmark channel).
np.unique(Y[2], return_counts=True)

### Make all Gaussians:
(normalized between 0-1)

In [None]:
# Turn every single-pixel landmark channel into a Gaussian heat map with
# standard deviation `sig`, rescaled so its peak equals 1.
# Shapes are taken from Y itself, so this also works after Y was reloaded
# from disk without re-running the cells that defined the image size.
Y_sig = np.zeros(Y.shape, dtype='float16')
for i in range(Y.shape[0]):
    for j in range(Y.shape[-1]):
        im = gaussian(Y[i, :, :, j], sig)
        peak = np.max(im)
        # An empty channel (no landmark) would give 0/0 = NaN; leave it all zeros.
        if peak > 0:
            Y_sig[i, :, :, j] = im / peak

np.save(os.path.join(output_dir, f'Y_sig{sig}.npy'), Y_sig)

In [None]:
# Shape of the Gaussian target stack.
Y_sig.shape

In [None]:
# dtype of `y`, the summed landmark image computed in the sanity check.
y.dtype

In [None]:
#Test it:
# Visual check: all Gaussian heat maps of image `i` summed into one picture.
# (To overlay them on the image itself, use X[i,:,:] + y instead.)
i = 2
y = Y_sig[i].sum(axis=2)
im = y
plt.figure(figsize=(10, 10))
implot = plt.imshow(np.asarray(im, dtype=np.float32))

##### Make Euclidean Distance (Vector) Image:
This yields 2 images per landmark: the x and y distances.  
Note: this didn't work as well as the Gaussian targets.

In [None]:
# Y_IJdist = np.zeros((n_ims, resize_size, resize_size, n_lm, 2), dtype=np.float32)
# for i in range(n_ims):
#     landmarks = ims_landmarks[i].copy()
    
#     landmarks[:,1] += 606
#     landmarks = landmarks/20
    
#     for j,lm in enumerate(landmarks):
        
#         Y_IJdist[i,:,:,j,0] = np.fromfunction(lambda ii,jj: lm[1]-ii, (resize_size, resize_size), dtype=np.float32)
#         Y_IJdist[i,:,:,j,1] = np.fromfunction(lambda ii,jj: lm[0]-jj, (resize_size, resize_size), dtype=np.float32)
        
# np.save(os.path.join(output_dir,f'Y_IJdist.npy'), Y_IJdist)

In [None]:
# #Test it:

# i = 2

# #Y_sig = np.load(os.path.join(output_dir,'Y_sig4.npy'))

# y = np.zeros((resize_size,resize_size))
# for j in range(1):
#     y[:,:] = y + Y_IJdist[i,:,:,j,0] #+ Y_IJdist[i,:,:,j,1]

# im = X[i,:,:] + y
# plt.figure(figsize=(10,10))
# implot = plt.imshow(im)

In [None]:
# # Make between -1 - 1
# Y_IJdist_normed = np.zeros((n_ims, resize_size, resize_size, n_lm, 2), dtype=np.float16)

# for i in range(n_ims):
#     for j,lm in enumerate(landmarks):
#         im_ii = Y_IJdist[i,:,:,j,0]
#         Y_IJdist_normed[i,:,:,j,0] = im_ii/np.max(np.absolute(im_ii))
#         im_jj = Y_IJdist[i,:,:,j,1]
#         Y_IJdist_normed[i,:,:,j,1] = im_jj/np.max(np.absolute(im_jj))
        
# np.save(os.path.join(output_dir,f'Y_IJdist_normed.npy'), Y_IJdist_normed)

In [None]:
# # Merge Dims:
# Y_IJdist_normed = np.load(os.path.join(output_dir,'Y_IJdist_normed.npy'))
# Y_IJdist_normed_80 = np.zeros((n_ims, resize_size, resize_size, n_lm*2), dtype=np.float32)

# for i in range(n_ims):
#     for j in range(n_lm):
#         Y_IJdist_normed_80[i,:,:,j*2] = Y_IJdist_normed[i,:,:,j,0]
#         Y_IJdist_normed_80[i,:,:,j*2+1] = Y_IJdist_normed[i,:,:,j,1]
        
# np.save(os.path.join(output_dir,f'Y_IJdist_normed_80.npy'), Y_IJdist_normed_80)

In [None]:
# # ~zero to 1 or -1, 1 and -1 to zero
# Y_IJdist_normed_flipped = np.where(Y_IJdist_normed<=0, Y_IJdist_normed+1, 10)
# Y_IJdist_normed_flipped = np.where(Y_IJdist_normed>0, Y_IJdist_normed-1, Y_IJdist_normed_flipped)

# np.save(os.path.join(output_dir,f'Y_IJdist_normed_flipped.npy'), Y_IJdist_normed_flipped)

In [None]:
# # Merge Dims:
# Y_IJdist_normed_flipped = np.load(os.path.join(output_dir,'Y_IJdist_normed_flipped.npy'))
# Y_IJdist_normed_flipped_80 = np.zeros((n_ims, resize_size, resize_size, n_lm*2), dtype=np.float32)

# for i in range(n_ims):
#     for j in range(n_lm):
#         Y_IJdist_normed_flipped_80[i,:,:,j*2] = Y_IJdist_normed_flipped[i,:,:,j,0]
#         Y_IJdist_normed_flipped_80[i,:,:,j*2+1] = Y_IJdist_normed_flipped[i,:,:,j,1]
        
# np.save(os.path.join(output_dir,f'Y_IJdist_normed_flipped_80.npy'), Y_IJdist_normed_flipped_80)

In [None]:
# # ABS
# Y_IJdist_normed_flipped_abs = np.absolute(Y_IJdist_normed_flipped)
# np.save(os.path.join(output_dir,f'Y_IJdist_normed_flipped_abs.npy'), Y_IJdist_normed_flipped_abs)

In [None]:
# # ABS merge dims:

# Y_IJdist_normed_flipped_abs_80 = np.zeros((n_ims, resize_size, resize_size, n_lm*2), dtype=np.float32)

# for i in range(n_ims):
#     for j,lm in enumerate(landmarks):
#         Y_IJdist_normed_flipped_abs_80[i,:,:,j*2] = Y_IJdist_normed_flipped_abs[i,:,:,j,0]
#         Y_IJdist_normed_flipped_abs_80[i,:,:,j*2+1] = Y_IJdist_normed_flipped_abs[i,:,:,j,1]
        
        
# np.save(os.path.join(output_dir,f'Y_IJdist_normed_flipped_abs_80.npy'), Y_IJdist_normed_flipped_abs_80)

#### Make euclidean distance images:

In [None]:
# Y_IJdist = np.load(os.path.join(output_dir,'Y_IJdist.npy'))
# Y_eucli = np.zeros((n_ims, resize_size, resize_size, n_lm), dtype=np.float32)

# for i in range(n_ims):
#     for j in range(n_lm):
#         Y_eucli[i,:,:,j] = np.sqrt(np.square(Y_IJdist[i,:,:,j,0]) + np.square(Y_IJdist[i,:,:,j,1]))

# Y_eucli_normed = np.zeros((n_ims, resize_size, resize_size, n_lm), dtype=np.float32)
# for i in range(n_ims):
#     for j in range(n_lm):
#         im = Y_eucli[i,:,:,j]
#         Y_eucli_normed[i,:,:,j] = im/np.max(im)
        
# np.save(os.path.join(output_dir,f'Y_eucli_normed.npy'), Y_eucli_normed)

In [None]:
# # ~zero to 1, 1 to zero
# Y_eucli_normed_flipped = np.absolute(Y_eucli_normed-1)

# np.save(os.path.join(output_dir,f'Y_eucli_normed_flipped.npy'), Y_eucli_normed_flipped)

In [None]:
# plt.imshow(Y_eucli_normed_flipped[3,:,:,1])

In [None]:
# # Only neighborhood (biggest values):
# Y_eucli_normed_flipped_hood = np.copy(Y_eucli_normed_flipped)
# Y_eucli_normed_flipped_hood[Y_eucli_normed_flipped_hood<0.9] = 0

# np.save(os.path.join(output_dir,f'Y_eucli_normed_flipped_hood.npy'), Y_eucli_normed_flipped_hood)

In [None]:
# # Make one euclidean distance image - the nearest point:
# Y_eucli_normed_flipped_closest_1 = np.zeros((n_ims, resize_size, resize_size), dtype=np.float32)

# for i in range(n_ims):
#     Y_eucli_normed_flipped_closest_1[i,:,:] = np.max(Y_eucli_normed_flipped[i,:,:,:], axis=2)
    
# np.save(os.path.join(output_dir,f'Y_eucli_normed_flipped_closest_1.npy'), Y_eucli_normed_flipped_closest_1)

In [None]:
# plt.imshow(Y_eucli_normed_flipped_closest_1[3,:,:])

In [None]:
# # Take only the neighborhood (biggest values):
# Y_eucli_normed_flipped_closest_hood_1 = np.copy(Y_eucli_normed_flipped_closest_1)
# Y_eucli_normed_flipped_closest_hood_1[Y_eucli_normed_flipped_closest_hood_1<0.9] = 0

# np.save(os.path.join(output_dir,f'Y_eucli_normed_flipped_closest_hood_1.npy'), Y_eucli_normed_flipped_closest_hood_1)
# plt.imshow(Y_eucli_normed_flipped_closest_hood_1[3,:,:])

### Create masked images

In [None]:
# Reload the normalized images; the masking below modifies this array in place.
x = np.load(os.path.join(output_dir, 'X_normed.npy'))

In [None]:
# Number of images in x.
n = x.shape[0]
n

In [None]:
# (n_images, rows, cols)
x.shape

In [None]:
# Image base names in the order of x's rows; x holds only the first n of the
# listed images, so the list is truncated to n.
with open(original_names_outfile, 'r') as names_file:
    filenames = names_file.read().split("\n")

filenames = filenames[:n]

In [None]:
# Mask folder: <output_dir>/../data/labels, i.e. <parent_dir>/data/labels.
masks_path = os.path.join(output_dir,'..','data','labels')

In [None]:
# One boolean mask per image, read from <masks_path>/<base name>.tif.
mask_files = [os.path.join(masks_path, f'{name}.tif') for name in filenames]
masks = np.asarray([io.imread(mask_file) for mask_file in mask_files], dtype=bool)

In [None]:
# Set every pixel where the mask is True to the brightest value in x.
# NOTE(review): which region the masks mark (wing vs. background) is not
# visible here — confirm True is the region meant to be blanked. Also assumes
# masks has exactly the same shape as x.
x[masks] = np.max(x)

In [None]:
# Visual check of one masked image (index 11; cast because plt can't show float16).
plt.imshow(x[11].astype(np.float32))
plt.axis('off')

In [None]:
# Save the masked, normalized images.
np.save(os.path.join(output_dir,f'X_masked.npy'), x)

### Resize images

In [None]:
# Input arrays for the resize step: masked images and sigma-3 Gaussian targets.
X_name = 'X_masked.npy'
Y_name = 'Y_sig3.npy'

In [None]:
# Both input arrays were saved to output_dir above. `input_dir` was never
# defined anywhere, so these loads raised NameError; it is now set explicitly
# (and reused when saving the resized arrays).
input_dir = output_dir
x = np.load(os.path.join(input_dir, X_name))
y = np.load(os.path.join(input_dir, Y_name))

In [None]:
# Keep only the first n images (and their targets) for resizing; slicing
# returns fewer than n if fewer are available.
n = 30
x = x[:n]
y = y[:n]

In [None]:
# Target size as (rows, cols).
siz = (1724, 2048)
# Allocate one row per image actually present: x[:n] may hold fewer than n
# images, and unfilled rows would silently stay all-zero.
x_resized = np.zeros((x.shape[0], siz[0], siz[1]), dtype=np.float16)
for i, im in enumerate(x):
    # resize works in float32 here; preserve_range keeps the normalized values.
    x_resized[i] = resize(im.astype(np.float32), siz, anti_aliasing=True, preserve_range=True)

In [None]:
# Target size as (rows, cols); each landmark channel is resized separately.
siz = (1724, 2048)
# One row per image actually present in y (may be fewer than n).
y_resized = np.zeros((y.shape[0], siz[0], siz[1], y.shape[-1]), dtype=np.float16)

for i in range(y.shape[0]):
    for j in range(y.shape[-1]):
        y_resized[i, :, :, j] = resize(y[i, :, :, j].astype(np.float32), siz,
                                       anti_aliasing=True, preserve_range=True)

In [None]:
# (n_images, rows, cols, n_landmarks)
y_resized.shape

In [None]:
# Save the resized arrays; both file names encode the size as <rows>_<cols>.
# The Y name previously used siz[1] twice ('..._2048_2048') although the
# arrays are 1724 x 2048 — update any loader that expects the old name.
np.save(os.path.join(input_dir, f'X_masked_resized_{siz[0]}_{siz[1]}'), x_resized)
np.save(os.path.join(input_dir, f'Y_sig3_resized_{siz[0]}_{siz[1]}'), y_resized)