## Generate superpixel-based pseudolabels


### Overview

This is the third step for data preparation

Input: normalized images

Output: pseudolabel candidates for all the images

In [244]:
%reset
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import copy
import skimage

from skimage.segmentation import slic
from skimage.segmentation import mark_boundaries
from skimage.util import img_as_float
from skimage.measure import label
import scipy.ndimage.morphology as snm
from skimage import io
import argparse
import numpy as np
import glob

import SimpleITK as sitk
import os

def to01(x):
    """Linearly rescale an array so its minimum maps to 0 and its maximum to 1.

    Args:
        x: array exposing ``min()`` and ``max()`` (e.g. a numpy array).

    Returns:
        ``(x - x.min()) / (x.max() - x.min())``. When every element has the
        same value the range is zero; an all-zero array is returned in that
        case instead of dividing by zero.
    """
    lo, hi = x.min(), x.max()
    shifted = x - lo
    span = hi - lo
    if span == 0:
        # Constant input: `shifted` is already all zeros. Dividing by 1 only
        # converts it to the float dtype the normal division would produce.
        return np.true_divide(shifted, 1)
    return shifted / span



Once deleted, variables cannot be recovered. Proceed (y/[n])? y
The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


**Summary**

a. Generate a mask of the patient to avoid pseudolabels of empty regions in the background

b. Generate superpixels as pseudolabels

**Configurations of pseudolabels**

```python
# default setting of minimum superpixel sizes
segs = seg_func(img[ii, ...], min_size = 400, sigma = 1)
# you can also try other configs
segs = seg_func(img[ii, ...], min_size = 100, sigma = 0.8)
```


In [245]:
# Dataset selected for this run; must be a key of DATASET_CONFIG.
DOMAIN = 'CHAOST2'

# Per-dataset settings:
#   img_bname - glob pattern matching the input volumes
#   out_dir   - output directory (only referenced by the commented-out line below)
#   fg_thresh - intensity threshold handed to fg_mask2d
DATASET_CONFIG = {
    'SABS': {
        'img_bname': './SABS/sabs_CT_normalized/image_*.nii.gz',
        'out_dir': './SABS/sabs_CT_normalized',
        'fg_thresh': 1e-4,
    },
    'CHAOST2': {
        'img_bname': './CHAOST2/chaos_MR_T2_normalized/image_*.nii.gz',
        'out_dir': './CHAOST2/new_psudo',
        'fg_thresh': 1e-4 + 50,
    },
}

# Collect every file matching the selected dataset's pattern (glob order is arbitrary).
img_bname = DATASET_CONFIG[DOMAIN]['img_bname']
imgs = glob.glob(img_bname, recursive=True)
# out_dir = DATASET_CONFIG[DOMAIN]['out_dir']

In [246]:
imgs

['./CHAOST2/chaos_MR_T2_normalized/image_36.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_8.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_38.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_34.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_10.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_2.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_22.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_20.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_32.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_5.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_13.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_37.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_39.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_31.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_15.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_3.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_19.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_21.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_33.nii.gz

In [247]:
# Order the files by the integer that follows the last "_" in the path
# (a plain string sort would place image_10 before image_2).
imgs = sorted(
    imgs,
    key=lambda fid: int(fid.rsplit('_', 1)[-1].split('.nii.gz')[0]),
)

In [248]:
imgs

['./CHAOST2/chaos_MR_T2_normalized/image_1.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_2.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_3.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_5.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_8.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_10.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_13.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_15.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_19.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_20.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_21.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_22.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_31.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_32.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_33.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_34.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_36.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_37.nii.gz',
 './CHAOST2/chaos_MR_T2_normalized/image_38.nii.gz'

In [249]:
# Module-level switch read by superpix_vol, which raises NotImplementedError for any value other than 'MIDDLE'.
MODE = 'MIDDLE' # minimum size of pseudolabels. 'MIDDLE' is the default setting

# slice-by-slice 2D superpixel segmentation of a 3D volume
def superpix_vol(img, min_size, sigma, method = 'fezlen', **kwargs):
    """
    Segment every slice along the first axis of `img` independently.

    The volume is assumed to be laid out as (z, x, y). Only the 'fezlen'
    method (skimage.segmentation.felzenszwalb) is available; any other value
    raises NotImplementedError, as does a module-level MODE other than
    'MIDDLE' (checked once per slice). Extra keyword arguments are accepted
    but not used.

    Returns a float array of the same shape as `img` holding the per-slice
    label maps.
    """
    if method != 'fezlen':
        raise NotImplementedError
    segment_slice = skimage.segmentation.felzenszwalb

    label_vol = np.zeros(img.shape)
    for z, slice_2d in enumerate(img):
        if MODE != 'MIDDLE':
            raise NotImplementedError
        label_vol[z, ...] = segment_slice(slice_2d, min_size=min_size, sigma=sigma)

    return label_vol

# binary patient mask from a simple intensity threshold
def fg_mask2d(img_2d, thresh): # change this by your need
    """
    Build a foreground mask for one 2D slice.

    Pixels strictly above `thresh` are foreground. If no pixel passes the
    threshold, the all-zero float32 mask is returned as is. Otherwise only the
    largest connected component is kept and its holes are filled; that result
    is whatever `snm.binary_fill_holes` returns.
    """
    above = np.float32(img_2d > thresh)

    # Nothing above the threshold: there is no component to select.
    if above.max() < 0.999:
        return above

    components = label(above)
    assert( components.max() != 0 ) # assume at least 1 CC
    # bincount index 0 is the background, so skip it and shift the argmax back by one
    component_sizes = np.bincount(components.flat)[1:]
    largest = components == np.argmax(component_sizes) + 1
    return snm.binary_fill_holes(largest)

# remove superpixels within the empty regions
def superpix_masking(raw_seg2d, mask2d):
    """
    Zero out superpixel labels outside the foreground mask and relabel from 1.

    Args:
        raw_seg2d: 2D label map; values are cast to int32. Label 0 is a valid
            superpixel id here.
        mask2d: 2D mask of the same shape, multiplied into the label map
            (0 = background, 1 = foreground).

    Returns:
        float array of the same shape. 0 marks background; each input label
        gets an output id counted from 1 in ascending label order, with the
        original label 0 treated as the largest label. Ids of labels that are
        completely masked out are skipped, so the ids may have gaps.
    """
    # Work on a private copy so the caller's array is never modified.
    seg = np.array(raw_seg2d, dtype=np.int32)
    if seg.size == 0:
        return np.zeros(seg.shape)

    label_values = [int(v) for v in np.unique(seg)]

    # 0 is about to mean "masked out", so move the superpixel labelled 0 to a
    # fresh id, and register THAT id so the loop below does not drop it.
    zero_alias = max(label_values) + 1
    seg[seg == 0] = zero_alias
    label_values.append(zero_alias)

    seg = seg * mask2d

    out_seg2d = np.zeros(seg.shape)
    lb_new = 1
    for lbv in label_values:
        if lbv == 0:
            continue
        out_seg2d[seg == lbv] = lb_new
        lb_new += 1

    return out_seg2d
            
def superpix_wrapper(img, verbose = False, fg_thresh = 1e-4, min_size=400, sigma=1):
    """
    Produce foreground masks and masked superpixel labels for a whole volume.

    Runs superpix_vol on `img`, then for every slice along the first axis
    builds a foreground mask with fg_mask2d (threshold `fg_thresh`) and passes
    the raw labels plus that mask to superpix_masking.

    Returns (fg_mask_vol, processed_seg_vol), two float arrays with the shape
    of the raw segmentation volume.
    """
    raw_seg = superpix_vol(img, min_size=min_size, sigma=sigma)
    fg_mask_vol = np.zeros(raw_seg.shape)
    processed_seg_vol = np.zeros(raw_seg.shape)

    for z, raw_slice in enumerate(raw_seg):
        if verbose:
            print("doing {} slice".format(z))
        slice_mask = fg_mask2d(img[z, ...], fg_thresh)
        slice_labels = superpix_masking(raw_slice, slice_mask)
        fg_mask_vol[z] = slice_mask
        processed_seg_vol[z] = slice_labels

    return fg_mask_vol, processed_seg_vol
        
# transfer spacing, origin and direction from one sitk object to another
def copy_info(src, dst):
    """
    Copy the spacing, origin and direction of `src` onto `dst`.

    `dst` is modified in place through its Set* methods and also returned.
    """
    # (getter on src, setter on dst), applied in this order
    accessors = (
        ("GetSpacing", "SetSpacing"),
        ("GetOrigin", "SetOrigin"),
        ("GetDirection", "SetDirection"),
    )
    for getter, setter in accessors:
        apply = getattr(dst, setter)
        apply(getattr(src, getter)())
    # dst.CopyInfomation(src)
    return dst


def strip_(img, lb):
    """
    Keep only the requested label value(s) of a label map.

    Args:
        img: label map; cast to int32 before comparison.
        lb: a single label (int, float or numpy scalar; truncated with
            ``int()``) or a list/tuple of labels.

    Returns:
        Array with the shape of `img` in which pixels equal to a requested
        label hold that label as a float and every other pixel is 0.
        A single label yields float32; a list/tuple yields float64.

    Raises:
        TypeError: if `lb` is neither a number nor a list/tuple.
    """
    img = np.int32(img)

    if isinstance(lb, (list, tuple)):
        out = np.zeros(img.shape)
        for _lb in lb:
            out += np.float32(img == int(_lb)) * float(_lb)
        return out

    # Integer labels are as valid as float ones; numpy scalar types are
    # listed explicitly because not all of them subclass int/float.
    if isinstance(lb, (int, float, np.integer, np.floating)):
        lb = int(lb)
        return np.float32(img == lb) * float(lb)

    raise TypeError(
        "lb must be a number or a list/tuple of numbers, got {}".format(type(lb).__name__)
    )
        
        
def read_nii_bysitk(input_fid, peel_info=False):
    """ Load an image file through SimpleITK and return it as a numpy array.

        input_fid: path handed to sitk.ReadImage
        peel_info: when true, additionally return a dict with the image's
                   "spacing", "origin", "direction" and the "array_size"
                   (shape of the numpy array)
    """
    image = sitk.ReadImage(input_fid)
    voxels = sitk.GetArrayFromImage(image)

    if not peel_info:
        return voxels

    geometry = {
        "spacing": image.GetSpacing(),
        "origin": image.GetOrigin(),
        "direction": image.GetDirection(),
        "array_size": voxels.shape
    }
    return voxels, geometry

In [250]:
# Generate pseudolabels for every image (results are kept in memory; writing to disk is disabled below)

def get_psudo_labels(imgs, max_idx, min_size, sigma):
    """
    Compute foreground masks and superpixel pseudolabels for a list of volumes.

    Args:
        imgs: iterable of image file paths. The id of each image is the text
            between the last "_" of the file name and ".nii.gz".
        max_idx: index of the last entry to process (inclusive), so at most
            ``max_idx + 1`` images are handled.
        min_size, sigma: forwarded to superpix_wrapper.

    The foreground threshold is taken from DATASET_CONFIG[DOMAIN]['fg_thresh'].

    Returns:
        (p_ids, all_imgs, all_fgs, all_segs): parallel lists holding, per
        processed image, its id string, the image array, the foreground mask
        volume and the pseudolabel volume.
    """
    all_fgs, all_segs, all_imgs, p_ids = list(), list(), list(), list()
    for i, img_fid in enumerate(imgs):
        if i > max_idx:
            break

        idx = os.path.basename(img_fid).split("_")[-1].split(".nii.gz")[0]
        # Read the volume once and reuse the array for both the segmentation
        # and the returned image list (it used to be loaded from disk twice).
        img_np = read_nii_bysitk(img_fid)
        out_fg, out_seg = superpix_wrapper(img_np,
                                           fg_thresh = DATASET_CONFIG[DOMAIN]['fg_thresh'],
                                           min_size=min_size, sigma=sigma)
        p_ids.append(idx)
        all_fgs.append(out_fg)
        all_segs.append(out_seg)
        all_imgs.append(img_np)

        # Disabled: write the mask and pseudolabels next to the input geometry.
    #     im_obj = sitk.ReadImage(img_fid)
    #     out_fg_o = sitk.GetImageFromArray(out_fg )
    #     out_seg_o = sitk.GetImageFromArray(out_seg )

    #     out_fg_o = copy_info(im_obj, out_fg_o)
    #     out_seg_o = copy_info(im_obj, out_seg_o)
    #     seg_fid = os.path.join(out_dir, f'superpix-{MODE}_{idx}.nii.gz')
    #     msk_fid = os.path.join(out_dir, f'fgmask_{idx}.nii.gz')
    #     sitk.WriteImage(out_fg_o, msk_fid)
    #     sitk.WriteImage(out_seg_o, seg_fid)
        print(f'image with id {idx} has finished')
    return p_ids, all_imgs, all_fgs, all_segs

def get_seg(all_segs, patient_idx, seg_depth=None):
    """
    Fetch the pseudolabels of one patient.

    Args:
        all_segs: sequence indexed by patient position.
        patient_idx: position of the patient in `all_segs`.
        seg_depth: optional slice index. When given (including 0) only that
            slice is returned; when None the whole volume is returned.
    """
    # Compare against None explicitly: slice index 0 is falsy but valid.
    if seg_depth is not None:
        return all_segs[patient_idx][seg_depth]
    return all_segs[patient_idx]

# Run the pipeline on the file list; all results are kept in memory.
p_ids, all_imgs, all_fgs, all_segs = get_psudo_labels(
    imgs,
    max_idx=10,
    min_size=600,
    sigma=1,
)

image with id 1 has finished
image with id 2 has finished
image with id 3 has finished
image with id 5 has finished
image with id 8 has finished
image with id 10 has finished
image with id 13 has finished
image with id 15 has finished
image with id 19 has finished
image with id 20 has finished
image with id 21 has finished


In [267]:
'''Psudo label registeration research'''

import pandas as pd
import seaborn as sns
import io
from skimage import color
from PIL import Image
from matplotlib import cm
from sklearn.preprocessing import MinMaxScaler
from PIL import Image
from PIL import ImageFont
from PIL import ImageDraw 


def plot_all(img_list):
    """
    Show all images of `img_list` in one 20x12 figure, four per row, without
    axis ticks.

    The number of rows is rounded up so that lists whose length is not a
    multiple of four still fit in the grid.
    """
    n_cols = 4
    # Ceiling division; at least one row so an empty list still gets a figure.
    n_rows = max(1, -(-len(img_list) // n_cols))

    plt.figure(figsize=(20, 12))
    for i, img in enumerate(img_list):
        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(img)
        plt.xticks([])
        plt.yticks([])
        
def plot(pic):
    """Show `pic` with plt.imshow in a new 5x5 figure."""
    plt.figure(figsize=(5, 5))
    plt.imshow(pic)
    
def pl_hist(pl):
    """
    Plot how many pixels each non-zero pseudolabel occupies.

    pl stands for pseudolabel: an array whose flattened values are counted.
    Zeros are dropped before counting.
    """
    label_values = pd.Series(pl.ravel())
    label_values = label_values[label_values != 0]
    plt.figure(figsize=(10, 5))
    # Pass the data by keyword: the positional form is deprecated in seaborn
    # and is interpreted as `data` in newer releases.
    sns.countplot(x=label_values)
    
def plot_msk(pl, num):
    """Plot `pl` with every pixel that is not equal to `num` set to zero."""
    plot(np.multiply(pl, pl == num))
    
def pair_plot(p_ids, all_segs, all_imgs, p_idx, depth, min_size=800, sigma=1):
    """
    Show slices `depth` and `depth + 1` of one patient in a 2x2 grid:
    the two images on top and their pseudolabels below.

    Args:
        p_ids: patient ids, used for the captions.
        all_segs, all_imgs: per-patient pseudolabel / image volumes.
        p_idx: position of the patient in the lists above.
        depth: index of the first slice of the pair.
        min_size, sigma: superpixel settings printed in the pseudolabel
            caption. They only affect the text; pass the values the
            pseudolabels were generated with.
    """
    plt.figure(figsize=(10, 10))
    ly1, ly2 = all_imgs[p_idx][depth], all_imgs[p_idx][depth + 1]
    pl1, pl2 = all_segs[p_idx][depth], all_segs[p_idx][depth + 1]
    for i, pic in enumerate([ly1, ly2, pl1, pl2]):
        plt.subplot(2, 2, i+1)
        plt.imshow(pic)
        if i < 2:
            plt.title(f'slice {depth + i} patient {p_ids[p_idx]} image')
        else:
            # panels 2 and 3 are the label maps of slices depth and depth + 1
            plt.xlabel(f'slice {depth + i - 2} patient {p_ids[p_idx]} psudo label\n'\
                       + f'super pixels min size: {min_size} - sigma: {sigma}', fontsize=12)
            
def pair_plot2(seq, depth, pid, score, j, step):
    """
    Plot one matched superpixel pair side by side and save it as a PNG.

    Args:
        seq: [image_a, image_b, mask_a, mask_b]; each image is shown with
            100 * its mask added on top so the superpixel stands out.
        depth: slice index of image_a; image_b is labelled `depth + step`.
        pid: patient id used in the titles and the file name.
        score: value printed (rounded to 2 decimals) in both titles.
        j: running number of the pair, used in the file name.
        step: slice distance between the two images.

    The figure is written to ``P{pid}D{depth}N{j}.png`` in the working directory.
    """
    # One figure only: plt.subplots already creates it, so no separate
    # plt.figure call (which would leave an empty figure behind each time).
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 10))
    axes = (ax1, ax2)
    slice_numbers = (depth, depth + step)
    for i, pic in enumerate(seq[:2]):
        axes[i].imshow(pic + (100 * seq[i + 2]))
        axes[i].set(title=f'Patient {pid}   slice {slice_numbers[i]}   score {round(score, 2)}')
    fig.savefig(f'P{pid}D{depth}N{j}.png', transparent=True, bbox_inches='tight')

        
        
def plot_all_pairs(p_ids, all_segs, all_imgs, p_idx, from_depth, to_depth):
    """Call pair_plot for every start slice in [from_depth, to_depth)."""
    depth = from_depth
    while depth < to_depth:
        pair_plot(p_ids, all_segs, all_imgs, p_idx, depth)
        depth += 1
         
            
def play_images(p_ids, all_segs, all_imgs, p_idx, from_depth, to_depth):
    """
    Render the pseudolabel slices [from_depth, to_depth) of one patient as a
    video written to ``movie2s.mp4`` in the current working directory (6 fps).

    Each slice is min-max normalised as a whole, coloured with the viridis
    colormap and stamped with the patient id and slice number.
    `all_imgs` is accepted for signature symmetry but not used.
    """
    import imageio

    frames = []
    for d in range(from_depth, to_depth):
        seg = np.asarray(all_segs[p_idx][d], dtype=float)
        # Normalise over the whole slice. A per-column scaler (what
        # MinMaxScaler does on a 2D array) would map the same label to
        # different colours in different columns.
        lo, hi = seg.min(), seg.max()
        pic = (seg - lo) / (hi - lo) if hi > lo else np.zeros_like(seg)
        pic = Image.fromarray(np.uint8(cm.viridis(pic)*255))
        draw = ImageDraw.Draw(pic)
#         font = ImageFont.truetype("sans-serif.ttf", 16)
        draw.text((0, 10), f"Patient {p_ids[p_idx]} slice {d}", (255,255,255))
        frames.append(pic)
    imageio.mimwrite(os.path.join(os.getcwd(), 'movie2s.mp4'), frames, fps=6)
    

In [268]:
def get_masks(pl):
    """
    Split a label map into one array per distinct value.

    For every distinct value of `pl` (in order of first appearance in the
    flattened array) the result holds a copy of `pl` in which all other
    pixels are zero. The value 0, if present, therefore yields an all-zero
    array.
    """
    distinct_values = pd.Series(pl.ravel()).unique()
    return [np.multiply(pl, pl == value) for value in distinct_values]


def get_good_pairs(p_ids, all_segs, p_idx, depth, step, min_iou=0.7):
    """
    Find superpixels of slice `depth` that overlap well with a superpixel of
    slice `depth + step` for one patient.

    Every mask from get_masks on the first slice is compared with every mask
    of the second slice using intersection-over-union of their non-zero
    pixels (labels are assumed to be non-negative).

    Args:
        p_ids: patient ids (unused here; kept for a uniform signature).
        all_segs: per-patient pseudolabel volumes.
        p_idx: position of the patient in `all_segs`.
        depth: index of the first slice.
        step: distance to the second slice.
        min_iou: minimum IoU for a pair to be kept.

    Returns:
        list of ``(mask_a, mask_b, depth, score)`` tuples, ordered by the
        mask order of the first slice, then of the second.
    """
    pl1, pl2 = all_segs[p_idx][depth], all_segs[p_idx][depth + step]
    masks1, masks2 = get_masks(pl1), get_masks(pl2)

    # Binarise each mask once instead of once per pair.
    bin1 = [m > 0 for m in masks1]
    bin2 = [m > 0 for m in masks2]

    good_pairs = list()
    for i, b1 in enumerate(bin1):
        if not b1.any():
            # The all-zero "mask" of label 0 can never overlap anything and
            # would otherwise cause a 0/0 division against another empty mask.
            continue
        for j, b2 in enumerate(bin2):
            intersection = np.logical_and(b1, b2).sum()
            union = np.logical_or(b1, b2).sum()  # > 0 because b1 is non-empty
            score = intersection / union
            if score >= min_iou:
                good_pairs.append((masks1[i], masks2[j], depth, score))
    return good_pairs

In [269]:
def get_gps_by_depth(p_ids, all_imgs, all_segs, p_idx, step):
    """
    Run get_good_pairs for every start slice of one patient.

    Start slices run from 0 up to (number of slices - step), exclusive, where
    the number of slices is the first dimension of ``all_imgs[p_idx]``.
    Returns one list of good pairs per start slice and prints progress.
    """
    n_start_slices = all_imgs[p_idx].shape[0] - step
    pairs_per_depth = []
    for d in range(n_start_slices):
        print(f'starting layer {d}')
        pairs_per_depth.append(
            get_good_pairs(p_ids, all_segs, p_idx=p_idx, depth=d, step=step)
        )
    return pairs_per_depth

def plot_good_pairs(p_ids, all_imgs, all_segs, p_idx, step):
    """
    Plot (via pair_plot2) every good superpixel pair of one patient.

    Pairs come from get_gps_by_depth; each is drawn on top of the two image
    slices it belongs to (slice d and slice d + step).
    """
    pid = p_ids[p_idx]
    pic_3d = all_imgs[p_idx]
    pairs_per_depth = get_gps_by_depth(p_ids, all_imgs, all_segs, p_idx, step=step)
    for pairs in pairs_per_depth:
        # j restarts at 0 for every depth; it numbers the pairs of one slice
        for j, (m1, m2, d, s) in enumerate(pairs):
            pair_plot2([pic_3d[d], pic_3d[d + step], m1, m2],
                       depth=d, pid=pid, score=s, j=j, step=step)

In [271]:
# plot_good_pairs(p_ids, all_imgs, all_segs, p_idx=1, step=1);