# Creating segmentation masks using SIFT keypoints

This notebook creates segmentation masks with the SIFT feature detector. SIFT keypoints are detected with OpenCV and marked in the image, the image's pixel intensities are then clustered with KMeans, and the resulting cluster map is smoothed with a 5×5 mean (box) filter. Pixels whose smoothed value is at or below a threshold (derived from the mean of the smoothed map) are classified as non-cell; pixels above it are classified as cell.


In [None]:
import pathlib
import pandas as pd
import os
import cv2
import matplotlib.pyplot as plt
import numpy as np
from sklearn.cluster import KMeans
from skimage.io import imread
%matplotlib inline

In [None]:
# Repository root = grandparent (two levels up) of the notebook's working
# directory, so the notebook must be started from a folder nested two levels
# inside the repository for the data paths below to resolve.
curr_dir = os.getcwd()
cwd_path = pathlib.Path(curr_dir)
parent_dir = cwd_path.parents[1]

In [None]:
# Load train.csv and keep a single row per image id together with its cell type
# (the file can contain several rows for the same id, hence drop_duplicates).
df = pd.read_csv(str(parent_dir) + '/data/data_original/train.csv')
df = df[['id', 'cell_type']]
df = df.drop_duplicates(subset=['id'])

path_list_astro = df.query('cell_type == "astro"')

# Full path of every astrocyte training image. The id list is materialised
# once; the previous loop rebuilt it twice per iteration (quadratic work).
train_img_dir = str(parent_dir) + '/data/data_original/train/'
astros = [train_img_dir + str(img_id) + '.png'
          for img_id in path_list_astro['id'].to_list()]

In [None]:
def save_seg_img(img_list, path_list, out_dir=None):
    """Save segmentation masks as 8-bit PNG files named after their source images.

    Masks and paths are paired positionally; as with ``zip``, surplus items in
    the longer sequence are ignored. For a source path ``.../<id>.png`` the
    mask is written to ``<out_dir>/<id>_seg.png``.

    Parameters
    ----------
    img_list : sequence of array-like
        Masks to save. Values are multiplied by 255 (so a 0/1 mask becomes a
        black/white image), then rounded and clipped to [0, 255] as uint8.
    path_list : sequence of str
        Paths of the source images; only the file name without its extension
        is used to name the output.
    out_dir : str or path-like, optional
        Destination directory. Defaults to
        ``<parent_dir>/data/data_preprocessed/mask_cluster/before_preprocessing/segmented_img_sift/``.
        Created if it does not already exist.

    Raises
    ------
    OSError
        If ``cv2.imwrite`` reports that a file could not be written.
    """
    if out_dir is None:
        out_dir = (str(parent_dir)
                   + '/data/data_preprocessed/mask_cluster/before_preprocessing/segmented_img_sift/')
    # cv2.imwrite does not create missing directories; it only signals failure
    # through its return value, so make sure the target exists first.
    os.makedirs(out_dir, exist_ok=True)

    for path, img in zip(path_list, img_list):
        # basename/splitext keeps ids containing '.' intact and is separator-aware.
        id_img = os.path.splitext(os.path.basename(path))[0]
        # Explicit uint8 conversion instead of handing a float64 array to imwrite.
        mask_u8 = np.clip(np.rint(np.asarray(img, dtype=np.float64) * 255.0),
                          0, 255).astype(np.uint8)
        out_path = os.path.join(out_dir, id_img + '_seg.png')
        if not cv2.imwrite(out_path, mask_u8):
            raise OSError(f'cv2.imwrite failed to write {out_path}')

In [None]:
# --- SIFT detector settings ---------------------------------------------------
# Passed to SIFT_create positionally in this order:
# (nfeatures, nOctaveLayers, contrastThreshold, edgeThreshold, sigma).
# For reference, OpenCV's defaults are 0, 3, 0.04, 10 and 1.6.
nOctaveLayers = 40
nfeatures = 20000
contrastThreshold = 0.02
edgeThreshold = 0.01
sigma = 1

# --- Clustering / smoothing / thresholding settings ---------------------------
n_clusters = 80        # KMeans clusters over pixel intensities
kernel_size = 5        # side length of the square mean (box) filter
threshold_offset = 7   # mask threshold = int(mean of smoothed map) + threshold_offset
plot_masks = True      # show an overlay figure for every processed image
max_images = None      # e.g. 2 to process only the first images while debugging

# Ranked labels are stored as uint8 below.
assert 1 <= n_clusters <= 256, 'n_clusters must fit in uint8 labels'

# SIFT lives in the main cv2 namespace since OpenCV 4.4; older builds only
# expose it through the contrib module cv2.xfeatures2d.
_sift_create = getattr(cv2, 'SIFT_create', None) or cv2.xfeatures2d.SIFT_create
# The settings are the same for every image, so build the detector once.
sift = _sift_create(nfeatures, nOctaveLayers, contrastThreshold, edgeThreshold, sigma)


def _keypoint_map(keypoints, shape):
    """Return a float array of ``shape`` holding 1 at every rounded keypoint location."""
    kp_map = np.zeros(shape)
    if len(keypoints) == 0:
        return kp_map
    coords = np.round([kp.pt for kp in keypoints]).astype(int)  # rows of (x, y)
    # A coordinate just below the image edge can round up to width/height,
    # which would index out of bounds; clamp it to the last valid pixel.
    cols = np.clip(coords[:, 0], 0, shape[1] - 1)
    rows = np.clip(coords[:, 1], 0, shape[0] - 1)
    kp_map[rows, cols] = 1
    return kp_map


def _ranked_label_map(pixels, shape):
    """Cluster pixel intensities with KMeans; return labels renumbered by intensity.

    KMeans assigns cluster indices arbitrarily, so raw labels carry no order.
    Renumbering them by ascending cluster-centre intensity makes a larger label
    mean a brighter cluster, which the smoothing + threshold step relies on.
    """
    kmeans = KMeans(n_clusters=n_clusters,
                    n_init=5,
                    max_iter=50,
                    random_state=42).fit(pixels)
    n_found = len(kmeans.cluster_centers_)
    order = np.argsort(kmeans.cluster_centers_[:, 0], kind='stable')
    rank_of_label = np.empty(n_found, dtype=np.uint8)
    rank_of_label[order] = np.arange(n_found, dtype=np.uint8)
    return rank_of_label[kmeans.labels_].reshape(shape)


def _smooth_and_threshold(label_map):
    """Mean-filter the label map and binarise it at int(mean) + threshold_offset (0/1 output)."""
    kernel = np.ones((kernel_size, kernel_size), np.float32) / (kernel_size ** 2)
    smoothed = cv2.filter2D(label_map, -1, kernel)
    threshold = int(smoothed.mean()) + threshold_offset
    # Same result as zeroing values <= threshold and setting the rest to 1.
    return (smoothed > threshold).astype(smoothed.dtype)


astro_img = []

# Iterate through the astrocyte images
for i, path in enumerate(astros):
    if max_images is not None and i >= max_images:
        break
    if i % 10 == 0:
        print(i, 'images processed')

    img = imread(path)            # used for keypoint detection and clustering
    img_i = cv2.imread(path)      # BGR copy, used only for plotting
    gray = cv2.cvtColor(img_i, cv2.COLOR_BGR2GRAY)

    # Only keypoint locations are needed, so skip descriptor computation.
    keypoints = sift.detect(img, None)
    kp_map = _keypoint_map(keypoints, img.shape[:2])

    # Flatten to one intensity feature per pixel for KMeans. Copy so that
    # marking keypoints does not modify `img` itself.
    pixels = img.reshape(img.shape[0] * img.shape[1], 1).copy()
    # Set keypoint pixels to 255 (the previous `==` was a no-op comparison,
    # so the keypoints never influenced the clustering).
    pixels[kp_map.reshape(-1) == 1] = 255

    label_map = _ranked_label_map(pixels, img.shape[:2])
    dst_new = _smooth_and_threshold(label_map)

    astro_img.append(dst_new)

    # Optional: overlay mask and keypoints on the image
    if plot_masks:
        fig, ax = plt.subplots(figsize=(24, 16))
        ax.imshow(img_i)
        ax.imshow(dst_new, alpha=0.2)
        ax.imshow(cv2.drawKeypoints(gray, keypoints, None), alpha=0.4)
        ax.set_title(os.path.basename(path))
        plt.show()
        # Close each figure; hundreds of open figures would exhaust memory.
        plt.close(fig)


In [None]:
#saving the images
save_seg_img(astro_img, astros)