# Preprocessing for Weak Supervision

In this notebook, we perform segmentation on the cardiac MRI images. 

As input, we expect gamma-corrected MRI images. As output, we save `failed_indexes.npy`, the list of indexes of images that could not be properly segmented, and `vocab_matrix.npy`, which contains the region properties of the segmented aortic valve for each image.

In [None]:
import numpy as np

import sys
import os

# viz
from matplotlib import pyplot as plt
import matplotlib.patches as mpatches

# segmentation 
from skimage.filters import threshold_otsu
from skimage.measure import label, regionprops
from skimage.morphology import closing, square

### 1. Specify the path to (a) input images and (b) output directory

In [None]:
# Path to the .npy file holding the gamma-corrected MRI images.
INPUT_IMAGES_PATH = "6_brightest_gamma_images.npy"
# Directory that vocab_matrix.npy and failed_indexes.npy are written to.
OUTPUT_DIR = "./"

In [None]:
# Memory-map the image file read-only rather than loading it all into RAM.
images = np.load(INPUT_IMAGES_PATH, mmap_mode='r')
# Collects (image index, stage name) tuples for every image a stage fails on.
failed_indexes = []

### 2. Perform segmentation using scikit-image functions

In [None]:
def extract_region_label(image):
    '''
    Segment a single image into labelled connected regions.

    The image is binarised with Otsu's threshold, the binary mask is closed
    with a 2x2 square footprint, and the result is split into connected
    components using connectivity=2 (8-connected for a 2-D image).

    Args:
        image: intensity image to segment.

    Returns:
        The label image produced by `skimage.measure.label`.

    Raises:
        ValueError: if Otsu thresholding fails. The args are the message
            'invalid image for otsu thresholding' and the image maximum;
            the original error is attached as `__cause__`.
    '''
    try:
        thresh = threshold_otsu(image)
    except Exception as err:
        # `except Exception` (not a bare except) so Ctrl-C / SystemExit
        # still propagate; chain the cause to keep the real failure visible.
        raise ValueError('invalid image for otsu thresholding',
                         np.max(image)) from err

    bw = closing(image > thresh, square(2))
    return label(bw, connectivity=2)

def extract_region_labels_from_images(images, indexes=None):
    '''
    Run `extract_region_label` on every image of a stack.

    Args:
        images: array of images indexed along the first axis.
        indexes: currently unused; kept so existing callers keep working.

    Returns:
        Array with the same shape as `images` holding one label image per
        input image. An image that could not be segmented keeps an all-zero
        (background-only) label image.

    Side effects:
        Writes a progress line to stdout and, for each image that raises
        ValueError, prints the error and appends
        (index, 'extract_region_label') to the module-level
        `failed_indexes` list.
    '''
    # Zero-initialise instead of np.empty: a failed image must not leave
    # uninitialised memory behind that later stages would read as labels.
    region_labels = np.zeros(images.shape)
    total = len(images)

    for i, image in enumerate(images):
        sys.stdout.write('\r' + str(i+1) + " / " + str(total) + " region labels extracted")
        sys.stdout.flush()
        try:
            region_labels[i] = extract_region_label(image)
        except ValueError as err:
            print ('\r image', i, err.args)
            failed_indexes.append((i, 'extract_region_label'))

    return region_labels

# Segment every loaded image; images that fail are recorded in failed_indexes.
region_labels = extract_region_labels_from_images(images) 

### 3. Extract the vocab matrix from each segmented region label
This section uses hand-tuned thresholds (defined in `find_target_region`) to select the correct segmented region from a number of candidates.

In [None]:
def find_target_region(region_label, image,
                       width_delta_threshold=25,
                       height_delta_threshold=15,
                       area_threshold=30,
                       eccentricity_threshold=0.98):
    '''
    Given region labels, extracts the region of interest using these steps:
        1) filter out regions that are too small
        2) filter out regions with absurdly large eccentricity
           (image artifacts)
        3) narrow down to the (up to) 2 brightest regions by mean intensity
        4)  a) if two regions remain and they are close to one another,
               pick the one further to the right
            b) else, pick the region whose centroid is closest to the
               bottom-left corner of the image

    TUNEABLE PARAMS (keyword arguments; defaults are the hand-tuned values):
        - width_delta_threshold: num horizontal pixels between centroids
                below which two regions are considered 'close'
        - height_delta_threshold: num vertical pixels between centroids
                below which two regions are considered 'close'
        - area_threshold: minimum area a region must have to be kept
        - eccentricity_threshold: regions with eccentricity at or above
                this value are discarded

    Args:
        region_label: label image for `image`; cast to int before use.
        image: intensity image the labels were computed from.

    Returns:
        The selected region properties object from `regionprops`.

    Raises:
        ValueError: if no region passes the area/eccentricity filter. The
            args are 'bad threshold' and a list of (area, eccentricity)
            pairs for every region found.
    '''

    def two_regions_close(two_regions):
        ''' Helper fn to determine if regions are considered 'close' '''
        if len(two_regions) != 2:
            return False

        first, second = two_regions
        width_delta = abs(first.centroid[1] - second.centroid[1])
        height_delta = abs(first.centroid[0] - second.centroid[0])
        return (width_delta < width_delta_threshold
                and height_delta < height_delta_threshold)

    regions = regionprops(region_label.astype(int), image)

    # drop blobs that are too small or too elongated
    candidates = [r for r in regions
                  if r.area >= area_threshold
                  and r.eccentricity < eccentricity_threshold]

    if not candidates:
        raise ValueError('bad threshold',
                         [(r.area, r.eccentricity) for r in regions])

    # keep the two regions with the highest mean intensity
    brightest = sorted(candidates, key=lambda r: -r.mean_intensity)[:2]

    if two_regions_close(brightest):
        # right-most: largest centroid column
        return max(brightest, key=lambda r: r.centroid[1])

    # bottom-left-most: centroid nearest to (last row, first column)
    bottom_left = np.array([image.shape[0], 0])
    return min(brightest,
               key=lambda r: np.linalg.norm(bottom_left - r.centroid))

*Note*: you can specify whether you want to visualize the target region you've chosen with the `show_images` flag.

In [None]:
def visualize_target_region(image, region):
    ''' Show `image` with the bounding box of `region` outlined in red. '''
    _, ax = plt.subplots(1)
    ax.imshow(image)

    # bbox is unpacked as (min row, min col, max row, max col); the patch
    # takes an (x, y) anchor followed by width and height
    top, left, bottom, right = region.bbox
    outline = mpatches.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        fill=False,
        edgecolor='red',
        linewidth=2,
    )
    ax.add_patch(outline)
    plt.show()

def extract_vocab_from_region_labels(region_labels, images, indexes=None, show_images=False):
    '''
    Build the vocab matrix: one column of region properties per image.

    For each image the target region is chosen with `find_target_region`
    and ten of its properties are written to that image's column.

    Args:
        region_labels: label images, indexed like `images`.
        images: intensity images, indexed along the first axis.
        indexes: not used by this function.
        show_images: when True, print each image's shape and plot the
            chosen region via `visualize_target_region`.

    Returns:
        Array of shape (10, len(images)). The column of an image that
        raised ValueError stays all zeros.

    Side effects:
        Writes progress to stdout and appends (index, 'find_target_region')
        to the module-level `failed_indexes` list for each failed image.
    '''
    # row order of the returned matrix
    fields = ['area', 'eccentricity', 'equivalent_diameter', 'extent', 'major_axis_length', 'minor_axis_length', 'perimeter', 'max_intensity', 'mean_intensity', 'min_intensity']

    def extract_vocab(region):
        column = np.zeros(len(fields))
        for slot, name in enumerate(fields):
            column[slot] = region[name]
        return column

    total = len(images)
    vocab_matrix = np.zeros((len(fields), total))

    for i in range(total):
        sys.stdout.write('\r' + str(i+1) + " / " + str(total) + " vocab extracted\r")
        sys.stdout.flush()

        try:
            target = find_target_region(region_labels[i], images[i])

            if show_images:
                print (images[i].shape)
                visualize_target_region(images[i], target)

            vocab_matrix[:, i] = extract_vocab(target)
        except ValueError as err:
            print ('\r image', i, err.args)
            failed_indexes.append((i, 'find_target_region'))

    return vocab_matrix

# Build the vocab matrix for all images; show_images=True prints the shape of
# every image and plots its chosen region.
vocab_matrix = extract_vocab_from_region_labels(region_labels, images, show_images=True)

### 4. Save `vocab_matrix` and `failed_indexes`

In [None]:
# Persist the region-property matrix.
out_file = os.path.join(OUTPUT_DIR, 'vocab_matrix.npy')
print ('saving vocab matrix to: %s' % out_file)
np.save(out_file, vocab_matrix)

# Only write the failure file when at least one image failed.
if failed_indexes:
    failed_indexes_file = os.path.join(OUTPUT_DIR, 'failed_indexes.npy')
    print ('failed', len(failed_indexes), 'indexes:', failed_indexes)
    print ('saving failed indexes to: %s' % failed_indexes_file)
    # Entries are (index, stage name); keep just the integer index without
    # round-tripping through a string array.
    np.save(failed_indexes_file,
            np.array([index for index, _ in failed_indexes], dtype=int))
