In [1]:
# In this notebook, we apply various preprocessing treatments to create different image variants for vessels,
# red lesions and bright lesions.
# The preprocessing methods used are:
# - Red channel only
# - Green channel only
# - Blue channel only
# - CLAHE
# - Median filter
# - Concatenation of raw and Green
# - Morphopreprocess
# We then train different models for each of these
# Metrics and tests are done in a different notebook

In [2]:
%load_ext autoreload
%autoreload 2
from sys import path
path.append('/home/clement/Documents/Code/JuNNo/lib')
path.append('../code/')
import junno.datasets as J
import cv2
import numpy as np
import torch
from utils.eval.tester import Tester
from manager import Trainer
import yaml
from easydict import EasyDict
import pprint
from sklearn.utils import class_weight
from networks.unet import UNet
import datetime
pp = pprint.PrettyPrinter(indent=4)


In [3]:
# RAW
# Load the raw fundus images; the vessel input variants below are derived from `d_raw`.

# Target size passed as `reshape=` to every J.images call in this notebook.
size = 512
# NOTE(review): absolute, machine-specific path; consider making it configurable.
PATH = '/home/clement/Documents/database/arnaud/messidor/database/'
# keep_proportion='pad' presumably pads instead of stretching -- TODO confirm against JuNNo.
d_raw = J.images(PATH+'images/', reshape=size, keep_proportion='pad')
d_raw

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [4]:
# BRIGHT LESIONS
# Load the fundus images and the three bright-lesion annotation folders.
# NOTE(review): the variables below keep `red_*` names although they hold bright-lesion annotations
# (cotton wool spots, drusen, exudates); the red-lesion cell rebinds the same names afterwards.

images = J.images(PATH+'images', reshape=size)
red_hemorrhages = J.images(PATH+'Bright/Cotton Wool Spots', recursive=True, normalize=False, reshape=size)
red_pre_re_hem = J.images(PATH+'Bright/Drusen', recursive=True, normalize=False, reshape=size)
red_sub_re_hem = J.images(PATH+'Bright/Exudates', recursive=True, normalize=False, reshape=size)
def join_dataset(*args):
    """Join several datasets with J.join, exposing each one's data column as a keyword argument.

    :param args: datasets providing `col.name`, `col.data` and `_name`
    :return: whatever J.join returns for the collected columns

    Each keyword is the dataset's `_name` with '-' and ' ' replaced by '_', so that it is a valid
    Python identifier (merge_columns uses these identifiers as parameter names).
    """
    join_names = []
    arguments = {}
    for arg in args:
        # Rebound on every pass: the 'name' keyword ends up being the last dataset's name column.
        arguments['name'] = arg.col.name
        join_names.append(arg.col.name)
        arguments[arg._name.replace('-', '_').replace(' ', '_')] = arg.col.data
    return J.join(join_names, **arguments)
def merge_columns(Cotton_Wool_Spots, Drusen, Exudates):
    """Element-wise union of the three bright-lesion masks.

    :return: boolean array, True wherever at least one input is non-zero
    """
    stacked = np.asarray((Cotton_Wool_Spots, Drusen, Exudates))
    return stacked.any(axis=0)
# Join the images with the three annotation sets, then merge the annotation columns into one 'gt' column.
# presumably concat(x='images') exposes the image column as 'x' -- TODO confirm against JuNNo.
j_brt_raw = join_dataset(images, red_hemorrhages, 
                      red_pre_re_hem, red_sub_re_hem).concat(x='images').apply('gt', merge_columns, keep_parent=False)
j_brt_raw

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [5]:
# RED LESIONS
# Load the fundus images and the four red-lesion annotation folders.
# Rebinds `images` and the `red_*` names already assigned in the bright-lesion cell.

images = J.images(PATH+'images', reshape=size)
red_hemorrhages = J.images(PATH+'Red/Hemorrhages', recursive=True, normalize=False, reshape=size)
red_pre_re_hem = J.images(PATH+'Red/Pre-retinal hemorrhage', recursive=True, normalize=False, reshape=size)
red_sub_re_hem = J.images(PATH+'Red/Sub-retinal hemorrhage', recursive=True, normalize=False, reshape=size)
red_ma = J.images(PATH+'Red/Microaneurysms', recursive=True, normalize=False, reshape=size)
def join_dataset(*args):
    """Join several datasets with J.join, exposing each one's data column as a keyword argument.

    Same definition as in the bright-lesion cell, repeated so that this cell can run on its own.

    :param args: datasets providing `col.name`, `col.data` and `_name`
    :return: whatever J.join returns for the collected columns

    Each keyword is the dataset's `_name` with '-' and ' ' replaced by '_', so that it is a valid
    Python identifier (merge_columns uses these identifiers as parameter names).
    """
    join_names = []
    arguments = {}
    for arg in args:
        # Rebound on every pass: the 'name' keyword ends up being the last dataset's name column.
        arguments['name'] = arg.col.name
        join_names.append(arg.col.name)
        arguments[arg._name.replace('-', '_').replace(' ', '_')] = arg.col.data
    return J.join(join_names, **arguments)
def merge_columns(Pre_retinal_hemorrhage, Hemorrhages, Sub_retinal_hemorrhage, Microaneurysms):
    """Element-wise union of the four red-lesion masks.

    :return: boolean array, True wherever at least one input is non-zero
    """
    stacked = np.asarray((Pre_retinal_hemorrhage, Hemorrhages, Sub_retinal_hemorrhage, Microaneurysms))
    return stacked.any(axis=0)
# Join the images with the four annotation sets, then merge the annotation columns into one 'gt' column.
# presumably concat(x='images') exposes the image column as 'x' -- TODO confirm against JuNNo.
j_red_raw = join_dataset(images, red_hemorrhages, 
                      red_pre_re_hem, red_sub_re_hem, red_ma).concat(x='images').apply('gt', merge_columns, keep_parent=False)
j_red_raw

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [6]:
# GREEN CHANNEL

def extract_green(data):
    """Return plane 1 of the first axis (used as the green channel), keeping a leading axis of length 1."""
    green_plane = data[1, :, :]
    return green_plane[np.newaxis, ...]

# Green-channel variant of the raw images (the last line displays it).
d_green = d_raw.apply('data', extract_green)
d_green

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [7]:
# BRIGHT LESIONS - GREEN CHANNEL
# Bright-lesion variant obtained by applying extract_green to column 'x'.

j_brt_g = j_brt_raw.apply('x', extract_green)
j_brt_g

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [8]:
# RED LESIONS - GREEN CHANNEL
# Red-lesion variant obtained by applying extract_green to column 'x'.

j_red_g = j_red_raw.apply('x', extract_green)
j_red_g

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [9]:
# BLUE CHANNEL

def extract_blue(data):
    """Return plane 0 of the first axis (used as the blue channel), keeping a leading axis of length 1."""
    blue_plane = data[0, :, :]
    return blue_plane[np.newaxis, ...]

# Blue-channel variant of the raw images (the last line displays it).
d_blue = d_raw.apply('data', extract_blue)
d_blue

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [10]:
# BRIGHT LESIONS - BLUE CHANNEL
# Bright-lesion variant obtained by applying extract_blue to column 'x'.

j_brt_b = j_brt_raw.apply('x', extract_blue)
j_brt_b

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [11]:
# RED LESIONS - BLUE CHANNEL
# Red-lesion variant obtained by applying extract_blue to column 'x'.

j_red_b = j_red_raw.apply('x', extract_blue)
j_red_b

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [12]:
# RED CHANNEL

def extract_red(data):
    """Return plane 2 of the first axis (used as the red channel), keeping a leading axis of length 1."""
    red_plane = data[2, :, :]
    return red_plane[np.newaxis, ...]

# Red-channel variant of the raw images (the last line displays it).
d_red = d_raw.apply('data', extract_red)
d_red

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [13]:
# BRIGHT LESIONS - RED CHANNEL
# Bright-lesion variant obtained by applying extract_red to column 'x'.

j_brt_r = j_brt_raw.apply('x', extract_red)
j_brt_r

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [14]:
# RED LESIONS - RED CHANNEL
# Red-lesion variant obtained by applying extract_red to column 'x'.

j_red_r = j_red_raw.apply('x', extract_red)
j_red_r

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [15]:
# RED LESIONS - MEDIAN FILTER

def median_filter(im):
    """Return `im` smoothed by cv2.medianBlur with a kernel size of 3."""
    kernel_size = 3
    return cv2.medianBlur(im, kernel_size)

# Red-lesion variant with median_filter applied to column 'x' through apply_cv.
j_red_median = j_red_raw.apply_cv('x', median_filter)

In [16]:
# BRIGHT LESIONS - MEDIAN FILTER
# Bright-lesion variant with median_filter applied to column 'x' through apply_cv.

j_brt_median = j_brt_raw.apply_cv('x', median_filter)

In [17]:
# VESSELS - GREEN AND BLUE CHANNELS

def extract_green_and_blue(data):
    """Return the first two planes of the first axis (indices 0 and 1, used as blue and green)."""
    first_two_planes = data[:2, :, :]
    return first_two_planes

# Two-channel (blue + green) variant of the raw images (the last line displays it).
d_gb = d_raw.apply('data', extract_green_and_blue)
d_gb

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [18]:
# CLAHE

def median_filter_clahe(im):
    """Median-blur `im` with an odd kernel size derived from its largest dimension.

    The kernel size is 2 * (largest dimension // 20) + 1.
    """
    kernel_size = 2 * (np.max(im.shape) // 20) + 1
    return cv2.medianBlur(im, kernel_size)
def subtract_median_bg_image(im):
    """Return 4 * im - 4 * median_filter_clahe(im) + 128, computed by cv2.addWeighted."""
    background = median_filter_clahe(im)
    return cv2.addWeighted(im, 4, background, -4, 128)

def subtract_gaussian_bg_image(im):
    """Return 4 * im - 4 * blurred + 128, where `blurred` is a Gaussian blur of `im`.

    The blur sigma is one tenth of the largest dimension; the (0, 0) kernel size leaves the
    kernel dimensions to OpenCV.
    """
    sigma = np.max(im.shape) / 10
    background = cv2.GaussianBlur(im, (0, 0), sigma)
    return cv2.addWeighted(im, 4, background, -4, 128)


# Shared CLAHE operator (clip limit 2.0, 8x8 tile grid); LAB_clahe applies it to the L plane.
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
def ROI(img, threshold=5):
    """
    Compute the region-of-interest mask of a fundus image.

    Assumes img to be a uint8 image in (h, w, c) format.
    :param img: Array
    :param threshold: intensity that the median-blurred channel 2 must exceed to count as fundus
    :return: A boolean mask, True where the circular fundus is and False elsewhere
    """
    smoothed_channel = cv2.medianBlur(img[:, :, 2], 21)
    return smoothed_channel > threshold


# Switch read by LAB_clahe at call time.
normalize = False 
# If true, the function will standardize the LAB components according to a given mean and std, channel wise.
# Target mean / standard deviation for the L, A and B channels.
# NOTE(review): where these statistics come from is not documented here.
L_MEAN = 31.319101
A_MEAN = 17.877468
B_MEAN = 29.181826
L_STD = 12.684237
A_STD = 7.002096
B_STD = 10.272004

def LAB_clahe(img):
    """
    Subtract the median background of a fundus image, then apply CLAHE to its LAB lightness plane.

    Assumes img to be a uint8 image in (h, w, c) format with BGR channel order (it is converted
    with cv2.COLOR_BGR2LAB).
    :param img: Array
    :return: The processed image converted back to BGR and zeroed outside the mask returned by ROI()
    """
    mask = ROI(img)
    mask_3d = np.expand_dims(mask, 2)
    # Per-channel median inside the ROI, added back once the median background has been subtracted.
    channel_medians = [np.median(img[:, :, channel][mask]) for channel in range(3)]
    img = np.clip(
        img.astype(np.float32) - median_filter_clahe(img) * mask_3d + np.asarray(channel_medians).astype(np.uint8),
        0, 255)
    img = cv2.GaussianBlur(img, (3, 3), 0)
    lab = cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2LAB)
    # cv2.split returns a tuple in recent OpenCV releases; a list is required to replace the L plane.
    lab_planes = list(cv2.split(lab))
    lab_planes[0] = clahe.apply(lab_planes[0])
    lab = cv2.merge(lab_planes)

    if normalize:
        lab = (lab * mask_3d).astype(np.float32)
        targets = ((L_MEAN, L_STD), (A_MEAN, A_STD), (B_MEAN, B_STD))
        for channel, (target_mean, target_std) in enumerate(targets):
            plane = lab[:, :, channel]  # view: the in-place operations below modify `lab`
            plane -= plane.mean()
            plane *= target_std / (plane.std() + 1e-7)
            plane += target_mean
        # NOTE(review): `lab` is float32 here, so the conversion below runs on float input -- confirm
        # that the value ranges match what cv2.COLOR_LAB2BGR expects for float images.

    bgr = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
    return bgr * mask_3d


def preprocess(data):
    """Thin wrapper around LAB_clahe, passed to apply_cv below.

    NOTE(review): the MORPHOPREPROCESS cell defines another function with this same name.
    """
    return LAB_clahe(data)


# CLAHE variant of the raw images (the last line displays it).
d_clahe = d_raw.apply_cv('data', preprocess)
d_clahe

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [19]:
# MORPHOPREPROCESS

def preprocess(x):
  """Subtract the median background (subtract_median_bg_image), then apply a 7x7 Gaussian blur.

  NOTE(review): rebinds the name `preprocess` already defined in the CLAHE cell.
  """
  return cv2.GaussianBlur(subtract_median_bg_image(x), (7,7), 0)
def morphopreprocess(x):
  """Median-background subtraction and blur, with the area outside the eroded mask filled by a dilation.

  Assumes x is a uint8 image in (h, w, c) format -- TODO confirm against apply_cv.
  :param x: Array
  :return: The processed image, zeroed outside the eroded mask
  """
  k = np.max(x.shape)//20*2+1
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k//2+1, k//2+1))
  # Pixels whose channel 2 exceeds 20, shrunk by a 15x15 erosion.
  field_of_view = (x[:, :, 2] > 20).astype(np.uint8)
  mask = np.expand_dims(cv2.erode(field_of_view, np.ones((15, 15), np.uint8)), 2)
  dilation = cv2.dilate(x, kernel, iterations=1)
  # Keep the original pixels inside the mask and the dilated image outside of it.
  filled = dilation*(1-mask) + mask*x
  # The blur step is written out here instead of calling the global `preprocess`: that name is bound to
  # two different functions in this notebook, so the call depended on which cell ran last.
  blurred = cv2.GaussianBlur(subtract_median_bg_image(filled), (7, 7), 0)
  return blurred*mask
# Morphologically preprocessed variant of the raw images (the last line displays it).
d_morph = d_raw.apply_cv('data', morphopreprocess)
d_morph 

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [20]:
# RED LESIONS - MORPH
# Red-lesion variant with morphopreprocess applied to column 'x' through apply_cv.

j_red_morph = j_red_raw.apply_cv('x', morphopreprocess)
j_red_morph

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [21]:
# BRIGHT LESIONS - MORPH
# Bright-lesion variant with morphopreprocess applied to column 'x' through apply_cv.

j_brt_morph = j_brt_raw.apply_cv('x', morphopreprocess)
j_brt_morph

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [22]:
# MORPHGREEN
# extract_green applied to the morphologically preprocessed images.

d_morph_g = d_morph.apply('data', extract_green)
d_morph_g

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [23]:
# MEDIAN FILTER

# Vessel variant: each image with its median background subtracted (subtract_median_bg_image).
# A former first assignment, d_raw.apply_cv('data', median_filter), was overwritten by this one
# before being used and has been dropped.
# NOTE(review): the lesion variants (j_red_median / j_brt_median) use the plain median_filter instead --
# confirm which of the two this variant is meant to use.
d_median = d_raw.apply_cv('data', subtract_median_bg_image)
d_median

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [24]:
# CONCAT RAW MORPH

def concat(raw, morph, name):
    """Stack `raw` and `morph` along axis 0; `name` is accepted but not used."""
    stacked = np.concatenate([raw, morph])
    return stacked

# Join the raw and morph datasets, then stack their data with concat (3 + 3 channels presumably -- the
# training loop sets input_chan = 6 for this variant).
# NOTE(review): `tmp` is reused by the "CONCAT RAW GREEN" cell.
tmp = J.join([d_raw.col.name, d_morph.col.name], raw=d_raw.col.data, morph=d_morph.col.data, name=d_raw.col.name)
d_concat_rm = tmp.apply('data', concat)
d_concat_rm

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [25]:
# CONCAT RAW GREEN

def concat(raw, green, name):
    """Stack `raw` and `green` along axis 0; `name` is accepted but not used."""
    stacked = np.concatenate([raw, green])
    return stacked

# Join the raw and green datasets, then stack their data with concat (3 + 1 channels presumably -- the
# training loop sets input_chan = 4 for this variant).
# NOTE(review): rebinds `tmp` from the "CONCAT RAW MORPH" cell.
tmp = J.join([d_raw.col.name, d_green.col.name], raw=d_raw.col.data, green=d_green.col.data, name=d_raw.col.name)
d_concat_rg = tmp.apply('data', concat)
d_concat_rg

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [26]:
# GT - VESSELS - TRAINING

def binarize(data):
    """Return an integer array holding 1 where `data` is strictly positive and 0 elsewhere."""
    # np.int was only an alias of the builtin int; it was deprecated in NumPy 1.20 and removed in 1.24.
    return (data > 0).astype(int)

# Vessel ground truth used for training, binarized.
# NOTE(review): d_raw is loaded with keep_proportion='pad' but this call is not; confirm that images and
# ground truth end up spatially aligned.
d_gt = J.images(PATH + 'Vessels/Vessels - Uncertain/', reshape = size).apply('data', binarize)
d_gt

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [27]:
# GT - VESSELS - TESTING
# Vessel ground truth used by the testing loop, binarized.
# NOTE(review): loaded without keep_proportion='pad', unlike d_raw; confirm the alignment with the inputs.

test_d_gt = J.images(PATH + 'Vessels/MESSIDOR', reshape = size).apply('data', binarize)
test_d_gt

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

DatabaseView(children=(HBox(children=(ToolBar(children=(ToolButton(html='<i class="fa fa-eye-slash" style="fon…

In [28]:
# JOINED SETS

def binarize(gt):
    """Return an integer array holding 1 where `gt` is strictly positive and 0 elsewhere."""
    # np.int was only an alias of the builtin int; it was deprecated in NumPy 1.20 and removed in 1.24.
    return (gt > 0).astype(int)

# Vessels: every input variant is paired with the same vessel ground truth (d_gt).
def _join_with_vessel_gt(inputs):
    """Join `inputs` with d_gt as columns 'x' and 'gt', then binarize 'gt'."""
    return J.join([inputs.col.name, d_gt.col.name], x=inputs.col.data, gt=d_gt.col.data).apply('gt', binarize)

j_raw = _join_with_vessel_gt(d_raw)
j_r = _join_with_vessel_gt(d_red)
j_g = _join_with_vessel_gt(d_green)
j_b = _join_with_vessel_gt(d_blue)
j_gb = _join_with_vessel_gt(d_gb)
j_clahe = _join_with_vessel_gt(d_clahe)
j_median = _join_with_vessel_gt(d_median)
j_morph = _join_with_vessel_gt(d_morph)
j_morph_g = _join_with_vessel_gt(d_morph_g)
j_concat_g = _join_with_vessel_gt(d_concat_rg)
j_concat_morph = _join_with_vessel_gt(d_concat_rm)
# Red lesions
j_rd_raw = j_red_raw.apply('gt', binarize)
j_rd_r = j_red_r.apply('gt', binarize)
j_rd_g = j_red_g.apply('gt', binarize)
j_rd_b = j_red_b.apply('gt', binarize)
j_rd_median = j_red_median.apply('gt', binarize)
j_rd_morph = j_red_morph.apply('gt', binarize)
# Bright lesions
# NOTE(review): unlike the red-lesion sets, these rebind the existing j_brt_* names, so re-running this
# cell applies binarize once more on top of the previous result.
j_brt_raw = j_brt_raw.apply('gt', binarize)
j_brt_r = j_brt_r.apply('gt', binarize)
j_brt_g = j_brt_g.apply('gt', binarize)
j_brt_b = j_brt_b.apply('gt', binarize)
j_brt_median = j_brt_median.apply('gt', binarize)
j_brt_morph = j_brt_morph.apply('gt', binarize)

# Alternative selections, kept for reference. `sets`, `test_sets` and `variant_name` are parallel lists
# indexed together by the training and testing loops.
# ('rd-b' was spelled 'rg-b' in the full list; the training loop tests startswith('rd').)
# sets = [j_raw, j_r, j_g, j_b, j_gb, j_clahe, j_median, j_morph, j_morph_g, j_concat_g, j_concat_morph,
#         j_rd_raw, j_rd_r, j_rd_g, j_rd_b, j_rd_median, j_rd_morph,
#         j_brt_raw, j_brt_r, j_brt_g, j_brt_b, j_brt_median, j_brt_morph]
# test_sets = [d_raw, d_red, d_green, d_blue, d_gb, d_clahe, d_median, d_morph, d_morph_g, d_concat_rg, d_concat_rm,
#              j_rd_raw, j_rd_r, j_rd_g, j_rd_b, j_rd_median, j_rd_morph,
#              j_brt_raw, j_brt_r, j_brt_g, j_brt_b, j_brt_median, j_brt_morph]
# variant_name = ['ves-raw', 'ves-r', 'ves-g', 'ves-b', 'ves-gb', 'ves-clahe', 'ves-median', 'ves-morph', 'ves-morph-g', 'ves-concat-g',
#                 'ves-concat-morph',
#                 'rd-raw', 'rd-r', 'rd-g', 'rd-b', 'rd-median', 'rd-morph',
#                 'brt-raw', 'brt-r', 'brt-g', 'brt-b', 'brt-median','brt-morph' ]

# sets = [j_raw, j_r, j_g, j_b, j_clahe, j_median, j_morph, j_morph_g]
# test_sets = [d_raw, d_red, d_green, d_blue, d_clahe, d_median, d_morph, d_morph_g]
# variant_name = ['ves-raw', 'ves-r', 'ves-g', 'ves-b', 'ves-clahe', 'ves-median', 'ves-morph', 'ves-morph-g']

# Active selection: green-channel vessel segmentation only.
sets = [j_g]
test_sets = [d_green]
variant_name = ['ves-g']

In [29]:
# Saving folder management

# Timestamp identifying this run: stored in config.variant.code by both loops and used by the testing loop
# to build the model and metrics file names.
code = datetime.datetime.now().isoformat()
# Index into variant_name, advanced by the training loop.
variant_nb = 0

In [None]:
# Main training loop

def _count_classes(batches):
    """Sum the per-class pixel counts of the 'gt' entry of every batch.

    Counts are accumulated per class value, so a batch in which a class is absent does not shift or
    break the totals.
    :param batches: iterable of mappings with a 'gt' array
    :return: array of counts ordered by increasing class value, or None when there is no batch
    """
    totals = {}
    for batch in batches:
        values, counts = np.unique(batch['gt'], return_counts=True)
        for value, count in zip(values.tolist(), counts.tolist()):
            totals[value] = totals.get(value, 0) + count
    if not totals:
        return None
    return np.array([totals[value] for value in sorted(totals)])


# enumerate() restarts the variant index on every run of this cell, so the cell can be re-run safely.
for variant_nb, dt in enumerate(sets):
    datasets = dt.split_sets(valid=0.10, train=-1)
    datasets['train'] = datasets['train'].augment(J.DataAugment().rotate(angle=(-90,90)))

    # The base configuration is reloaded for every variant because it is modified below.
    baseConfigPath = '../code/config.yaml'
    with open(baseConfigPath, 'r') as f:
        yaml_file = yaml.load(f, Loader=yaml.FullLoader)
        config = EasyDict(yaml_file)

    config.variant.type = variant_name[variant_nb]
    config.variant.code = code

    # Number of input channels implied by the variant name. Order matters: 'ves-concat-g' also ends
    # with '-g' and is overridden by the later check.
    if config.variant.type.endswith(('-g', '-b', '-r')):
        config.model.input_chan = 1

    if config.variant.type.endswith('-gb'):
        config.model.input_chan = 2

    if config.variant.type == 'ves-concat-g':
        config.model.input_chan = 4

    if config.variant.type == 'ves-concat-morph':
        config.model.input_chan = 6

#     if (config.variant.type).startswith('brt'):
#         config.hp.initial_lr = 0.000001

    trainer = Trainer(config)

    # Counting number of pixels occurence per class (see doc in the code for more details)
    class_count = _count_classes(datasets['train'].generator(n=20, columns=['gt']))
    # NOTE(review): the original cell also evaluated `class_count / class_count.sum()` and discarded the
    # result; raw counts are what setup_loss receives. Confirm whether frequencies were intended.

    y = (datasets['train'][:, 'gt']).flatten()
    # Keyword arguments: `classes` and `y` are keyword-only in current scikit-learn releases.
    balanced_weights = class_weight.compute_class_weight('balanced', classes=np.unique(y), y=y)
    if config.variant.type.startswith(('brt', 'rd')):
        class_weights = balanced_weights
    else:
        # Vessel variants use log-damped weights.
        class_weights = np.log(1 + balanced_weights)

    trainer.set_datasets(train=datasets['train'], valid=datasets['valid'])
    trainer.setup_loss(class_counts=class_count, class_weights=class_weights)
    trainer.set_model()
    trainer.train()

    # Clean memory
    del trainer
    torch.cuda.empty_cache()

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

VBox(children=(HBox(children=(HSpace(value='', layout=Layout(width='400px')), LogToolBar(children=(ToolButton(…



Selected optimization level O1:  Insert automatic casts around Pytorch functions and Tensor methods.

Defaults for this optimization level are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
Processing user overrides (additional kwargs that are not None)...
After processing overrides, optimization options are:
enabled                : True
opt_level              : O1
cast_model_type        : None
patch_torch_functions  : True
keep_batchnorm_fp32    : None
master_weights         : None
loss_scale             : dynamic
**************************************************
Epoch 1, iteration 6
Training training_loss 0.650321


  images = (255*images/np.max(images)).astype(np.uint)


**************************************************
Epoch 11, iteration 6
Training training_loss 0.458766
**************************************************
Epoch 21, iteration 6
Training training_loss 0.311421


In [None]:
# Testing loop

# Kept separate from PATH (the image database root used by the loading cells) so that running this cell
# does not break a later re-run of those cells.
MODEL_DIR = '/home/clement/Documents/Arnaud/models'

for i, test_set in enumerate(test_sets):
    # Loading model

    baseConfigPath = '../code/config.yaml'
    with open(baseConfigPath, 'r') as f:
        yaml_file = yaml.load(f, Loader=yaml.FullLoader)
        config = EasyDict(yaml_file)

    config.variant.type = variant_name[i]
    config.variant.code = code

    # Same channel rules as the training loop, so that the network matches the saved weights.
    if config.variant.type.endswith(('-g', '-b', '-r')):
        config.model.input_chan = 1

    if config.variant.type.endswith('-gb'):
        config.model.input_chan = 2

    # These names must be the ones used in variant_name ('ves-concat-g', 'ves-concat-morph'); the
    # previous 'concat-g' / 'concat-mor' never matched and left the wrong channel count.
    if config.variant.type == 'ves-concat-g':
        config.model.input_chan = 4

    if config.variant.type == 'ves-concat-morph':
        config.model.input_chan = 6

    model = UNet(config = config.model)
    model.load_state_dict(torch.load(MODEL_DIR + ('/uNet_%s_%s.pth' % (code, variant_name[i]))))
    model.eval()
    model = model.to('cuda')
    tester = Tester()

    if variant_name[i].startswith('ves'):
        d_pred = test_set.apply_torch('data', model, device='cuda')
        joined_pred = J.join([d_pred.col.name, test_d_gt.col.name], x=d_pred.col.data, gt=test_d_gt.col.data).apply('gt', binarize)

    # Tests for lesions are not defined yet; the code below is incorrect.
    else:
        d_pred = test_set.apply_torch('x', model, device='cuda')
        joined_pred = J.join([d_pred.col.name, test_d_gt.col.name], x=d_pred.col.x, gt=test_d_gt.col.data).apply('gt', binarize)

    tester.evaluate(joined_pred, 'x', 'gt', pred_index=1)
    tester.metrics(path=MODEL_DIR + '/', name="%s_metrics_%s" % (code, variant_name[i]))

In [None]:
# Debug view: prediction / ground-truth join left over from the last iteration of the testing loop.
joined_pred

In [None]:
# Debug view: training set at the index left over from the testing loop.
# NOTE(review): relies on `i` leaking out of that loop.
sets[i]

In [None]:
# Debug view: predictions left over from the last iteration of the testing loop.
d_pred