In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
import random

In [3]:
import cv2
from fastai import *
from fastai.vision import *
from fastai.callbacks import *

In [4]:
from dataset_spliter import SplitByPatient
from metrics import *#F1Weighted, MCC
from losses import *

In [5]:
path = Path('/data/Datasets/WhiteBloodCancer/train/')

In [6]:
np.random.seed(42)

In [7]:
fnames = get_image_files(path, recurse=True)
fnames[:5]

[PosixPath('/data/Datasets/WhiteBloodCancer/train/fold_1/hem/UID_H10_43_1_hem.bmp'),
 PosixPath('/data/Datasets/WhiteBloodCancer/train/fold_1/hem/UID_H22_31_15_hem.bmp'),
 PosixPath('/data/Datasets/WhiteBloodCancer/train/fold_1/hem/UID_H14_9_11_hem.bmp'),
 PosixPath('/data/Datasets/WhiteBloodCancer/train/fold_1/hem/UID_H14_28_6_hem.bmp'),
 PosixPath('/data/Datasets/WhiteBloodCancer/train/fold_1/hem/UID_H10_189_1_hem.bmp')]

In [8]:
hem_regex = re.compile(r'UID_(H[0-9]+)_', re.IGNORECASE)
all_regex = re.compile(r'UID_([0-9]+)_', re.IGNORECASE)

In [9]:
hem_patient_ids = list(set([hem_regex.search(str(fn)).group(1)
                            for fn in fnames if hem_regex.search(str(fn)) is not None]))
all_patint_ids = list(set([all_regex.search(str(fn)).group(1)
                           for fn in fnames if all_regex.search(str(fn)) is not None]))

hem_patients = dict((k,[]) for k in hem_patient_ids)
all_patints = dict((k,[]) for k in all_patint_ids)

[all_patints[key].append(fn) for key in all_patints.keys() for fn in fnames if 'UID_{0}_'.format(key) in str(fn)]
[hem_patients[key].append(fn) for key in hem_patients.keys() for fn in fnames if 'UID_{0}_'.format(key) in str(fn)]
print()




## Split data into train val 

In [10]:
split_handler = SplitByPatient(hem_patients, all_patints)

### Split by regex

In [11]:
train_regex = re.compile(r'(fold_0|fold_1|fold_2)')
val_regex = re.compile(r'(fold_3)')

hem_train, all_train, hem_val, all_val = split_handler.split_by_regex(train_regex, val_regex)

In [12]:
print('Train Total: {0}'.format(len(hem_train)+len(all_train)))
print('Val Total: {0}'.format(len(hem_val)+len(all_val)))
print("")
print('Hem train: {}'.format(len(hem_train)))
print('All train: {}'.format(len(all_train)))
print('Hem val: {}'.format(len(hem_val)))
print('All val: {}'.format(len(all_val)))

Train Total: 10661
Val Total: 1867

Hem train: 3389
All train: 7272
Hem val: 648
All val: 1219


In [13]:
pat = re.compile(r'^.*(hem|all).bmp$')

def get_label(fn):
    return pat.search(str(fn)).group(1)

### Use complete image

In [14]:
train = ImageList(hem_train + all_train) #optinal scale up classes 
valid = ImageList(hem_val + all_val)

In [15]:
item_list = ItemLists(path, train, valid)
lls = item_list.label_from_func(get_label)

#### Data augmentation

In [16]:
xtra_tfms=[cutout(n_holes=5, length=0.2)]#squish(scale=0.66), 
tfms = get_transforms(do_flip=True, 
                      flip_vert=True, 
                      #max_rotate=90,  
                      #max_lighting=0.15, 
                      #max_zoom=1.5, 
                      #max_warp=0.2,
                      #p_affine=0.75,
                      #p_lighting=0.75,  
                      #xtra_tfms=xtra_tfms,
                     )

#### Create dataset 

In [17]:
def get_data(bs, size):
    data  = ImageDataBunch.create_from_ll(lls, size=size, bs=bs, 
                                      ds_tfms=tfms, padding_mode='zeros',
                                      resize_method=ResizeMethod.PAD)
    data = data.normalize()
    #data = data.normalize((channel_mean, channel_std))
    return data
    

In [19]:
size = 256
bs = 96
data = get_data(bs, size)

experiment_name = "baseline_rn50"
learn = create_cnn(data, models.resnet50, 
                       metrics=[error_rate, F1Weighted(), MCC()], #  
                       #loss_func=FocalLoss(num_classes=1),
                       #ps=0.75,
                       #wd=0.1,
                       loss_func = LabelSmoothingCrossEntropy(),
                       callback_fns=[partial(SaveModelCallback, name='stage1-{}-{}'.format(experiment_name, size), monitor='f1_weighted', mode="max")],

                  )#

  warn("`create_cnn` is deprecated and is now named `cnn_learner`.")


In [20]:
learn.freeze()
lr = 1e-2
learn.fit_one_cycle(5, lr)

learn.unfreeze()
learn.fit_one_cycle(10, slice(1e-5,lr/5))

epoch,train_loss,valid_loss,error_rate,f1_weighted,mcc,time
0,0.562665,0.590148,0.230316,0.746469,0.469540,01:26
1,0.485522,0.748089,0.341189,0.530178,0.101319,01:23
2,0.446896,0.645597,0.259775,0.717448,0.389577,01:24
3,0.425882,0.763776,0.312266,0.602321,0.238474,01:24
4,0.409865,0.626442,0.250670,0.716957,0.417783,01:24


Better model found at epoch 0 with f1_weighted value: 0.7464690166185467.


epoch,train_loss,valid_loss,error_rate,f1_weighted,mcc,time
0,0.451728,0.632009,0.250134,0.730422,0.415462,01:48
1,0.444263,0.622380,0.236743,0.738913,0.452317,01:47
2,0.434279,0.624569,0.258704,0.711550,0.392741,01:46
3,0.427578,0.690088,0.318157,0.624861,0.210507,01:47
4,0.416923,0.661400,0.243171,0.733767,0.433938,01:46
5,0.409840,0.643303,0.249063,0.731138,0.418145,01:46
6,0.396985,0.613425,0.237279,0.741694,0.449222,01:47
7,0.387417,0.595043,0.236208,0.741734,0.452513,01:47
8,0.381273,0.549554,0.199250,0.785174,0.547405,01:47
9,0.378129,0.557154,0.200857,0.786052,0.541020,01:47


Better model found at epoch 0 with f1_weighted value: 0.7304217573352734.
Better model found at epoch 1 with f1_weighted value: 0.7389126166273247.
Better model found at epoch 6 with f1_weighted value: 0.7416938496648277.
Better model found at epoch 7 with f1_weighted value: 0.7417344675435839.
Better model found at epoch 8 with f1_weighted value: 0.7851736984887221.
Better model found at epoch 9 with f1_weighted value: 0.7860519410423052.


In [21]:
learn.export('baseline_rn50-1.pkl') 

In [22]:
size = 384
bs = 32
learn.data = get_data(bs, size)

learn.freeze()
lr = 1e-2
learn.fit_one_cycle(5, lr)

learn.unfreeze()
learn.fit_one_cycle(10, slice(1e-5,lr/5))

epoch,train_loss,valid_loss,error_rate,f1_weighted,mcc,time
0,0.463027,0.806196,0.302625,0.620575,0.274230,03:14
1,0.448465,0.850198,0.322978,0.570864,0.209141,03:08
2,0.432478,0.723320,0.353508,0.619779,0.149026,03:08
3,0.400484,0.606664,0.239957,0.745069,0.442686,03:08
4,0.387887,0.582767,0.217997,0.762480,0.500576,03:08


Better model found at epoch 0 with f1_weighted value: 0.6205746149445779.
Better model found at epoch 3 with f1_weighted value: 0.7450685675285298.
Better model found at epoch 4 with f1_weighted value: 0.762479537209054.


epoch,train_loss,valid_loss,error_rate,f1_weighted,mcc,time
0,0.391698,0.609168,0.226567,0.755622,0.476590,04:18
1,0.413995,0.574032,0.226567,0.752056,0.478538,04:17
2,0.409144,0.513617,0.162292,0.831535,0.632697,04:17
3,0.406943,0.545734,0.188538,0.803747,0.570032,04:17
4,0.401000,0.573326,0.198715,0.803617,0.575693,04:17
5,0.388690,0.581789,0.223889,0.755527,0.485302,04:17
6,0.382284,0.557856,0.192287,0.802028,0.562332,04:17
7,0.376960,0.531124,0.171934,0.820287,0.610153,04:17
8,0.373936,0.525750,0.167113,0.826689,0.621176,04:18
9,0.370027,0.541172,0.175683,0.815453,0.601629,04:17


Better model found at epoch 0 with f1_weighted value: 0.7556221252185121.
Better model found at epoch 2 with f1_weighted value: 0.8315348053483981.


In [23]:
learn.export('baseline_rn50-384-2.pkl') 

In [24]:
size = 450
bs = 16
learn.data = get_data(bs, size)

learn.freeze()
lr = 1e-2
learn.fit_one_cycle(5, lr)

learn.unfreeze()
learn.fit_one_cycle(10, slice(1e-5,lr/5))

epoch,train_loss,valid_loss,error_rate,f1_weighted,mcc,time
0,0.442555,0.549940,0.217997,0.760793,0.502261,04:51
1,0.482216,0.647299,0.218532,0.757863,0.503717,04:46
2,0.444281,0.842783,0.321907,0.576069,0.207767,04:46
3,0.401358,0.597151,0.215854,0.758130,0.515644,04:46
4,0.401263,0.542020,0.182646,0.802672,0.591196,04:46


Better model found at epoch 0 with f1_weighted value: 0.7607929597603703.
Better model found at epoch 4 with f1_weighted value: 0.8026721292111105.


epoch,train_loss,valid_loss,error_rate,f1_weighted,mcc,time
0,0.392612,0.568485,0.206749,0.771992,0.534516,06:40
1,0.420840,0.633482,0.251741,0.700536,0.436701,06:39
2,0.421374,0.607907,0.261382,0.702347,0.388007,06:39
3,0.418462,0.602157,0.244778,0.727344,0.431814,06:38
4,0.421682,0.591437,0.246385,0.711268,0.445442,06:39
5,0.407994,0.549831,0.210498,0.777576,0.516383,06:39
6,0.397126,0.533037,0.194430,0.789768,0.560730,06:39
7,0.385295,0.577722,0.214783,0.763459,0.512054,06:39
8,0.376932,0.543563,0.189073,0.799786,0.569730,06:38
9,0.382509,0.551124,0.185324,0.804245,0.578681,06:39


Better model found at epoch 0 with f1_weighted value: 0.77199156112703.
Better model found at epoch 5 with f1_weighted value: 0.7775755772401021.
Better model found at epoch 6 with f1_weighted value: 0.7897675927309015.
Better model found at epoch 8 with f1_weighted value: 0.7997855105729891.
Better model found at epoch 9 with f1_weighted value: 0.804245241511471.


In [25]:
learn.export('baseline_rn50-450-2.pkl') 