In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import os
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt
import random
from sklearn.metrics import matthews_corrcoef

In [3]:
import cv2
from fastai import *
from fastai.vision import *
from fastai.callbacks import *

In [4]:
from dataset_spliter import SplitByPatient
from metrics import *#F1Weighted, MCC
from losses import *
from data_loader import ImageItemListCell
from augmentation import cutout

In [5]:
from fastai.callbacks.hooks import  params_size
def extract_train_information(learner: Learner):
    """Describe the training setup of `learner` as a multi-line string.

    The returned text contains, in this order: the batch size, the weight
    decay, the largest learning rate stored in ``learner.recorder.lrs``, the
    ``.size`` of the first validation item, and finally one line per entry of
    the training dataset's ``tfms``.

    Args:
        learner: the fastai learner to inspect.

    Returns:
        str: the summary; every entry is prefixed with a newline.

    Note:
        ``max(learner.recorder.lrs)`` presumably needs at least one recorded
        learning rate, i.e. a learner that has already been fitted - TODO confirm.
    """
    # Parameter-count statistics (via `params_size`) are intentionally not
    # part of the summary at the moment.
    databunch = learner.data

    batch_size = databunch.batch_size
    img_size = databunch.valid_ds[0][0].size
    weight_decay = learner.wd
    peak_lr = max(learner.recorder.lrs)

    parts = [
        f"\nBs: {batch_size}",
        f"\nwd: {weight_decay}",
        f"\nlr: {peak_lr}",
        f"\nImage: {img_size}\n",
    ]
    # One indented line per transform attached to the training dataset.
    parts.extend(f"\n {tf}" for tf in databunch.train_dl.dl.dataset.tfms)

    return "".join(parts)

In [6]:
# Root folder of the training images.
# NOTE(review): absolute, machine-specific location - adjust when running elsewhere.
path = Path('/data/Datasets/WhiteBloodCancer/train/')

In [7]:
# Collect the image files below `path` (sub-folders included via recurse=True)
# and preview the first five entries.
fnames = get_image_files(path, recurse=True)
fnames[:5]

[PosixPath('/data/Datasets/WhiteBloodCancer/train/fold_1/hem/UID_H10_43_1_hem.bmp'),
 PosixPath('/data/Datasets/WhiteBloodCancer/train/fold_1/hem/UID_H22_31_15_hem.bmp'),
 PosixPath('/data/Datasets/WhiteBloodCancer/train/fold_1/hem/UID_H14_9_11_hem.bmp'),
 PosixPath('/data/Datasets/WhiteBloodCancer/train/fold_1/hem/UID_H14_28_6_hem.bmp'),
 PosixPath('/data/Datasets/WhiteBloodCancer/train/fold_1/hem/UID_H10_189_1_hem.bmp')]

In [8]:
# Total number of image files found; the fold loop later checks its splits against this count.
len(fnames)

10661

In [9]:
# Patient-id extractors applied to file paths; group(1) is the id that follows "UID_".
# hem ids have the form "H<digits>" (matched case-insensitively),
# all ids consist of digits only (case-sensitive "UID_" prefix).
hem_regex = re.compile(r'UID_(H[0-9]+)_', re.IGNORECASE)
all_regex = re.compile(r'UID_([0-9]+)_')

In [10]:
def _group_files_by_patient(files, regex):
    """Group `files` by the patient id that `regex` captures from each path.

    Args:
        files: iterable of paths (anything accepted by ``str``).
        regex: compiled pattern whose group(1) is the patient id.

    Returns:
        dict mapping patient id -> list of the files carrying that id.
        Files keep their order from `files`; ids appear in first-seen order,
        so the result is identical from run to run. Files the pattern does
        not match are left out.
    """
    grouped = {}
    for fn in files:
        match = regex.search(str(fn))
        if match is not None:
            grouped.setdefault(match.group(1), []).append(fn)
    return grouped


# One pass over the file list per class instead of testing every
# (patient, file) pair; this also avoids list comprehensions used only for
# their side effects.
# Variable names (including the "patint" spelling) are kept unchanged because
# other cells refer to them.
hem_patients = _group_files_by_patient(fnames, hem_regex)
all_patints = _group_files_by_patient(fnames, all_regex)

hem_patient_ids = list(hem_patients)
all_patint_ids = list(all_patints)




#### Data augmentation

In [11]:
# Wrap `cutout` (imported from the local `augmentation` module) in TfmLighting
# so it can be instantiated like the other transforms below.
cutout_fn = TfmLighting(cutout)
# Extra transforms handed to get_transforms: a squish with scale 0.66 and a
# cutout configured with n_holes=5 and length=0.2.
xtra_tfms=[squish(scale=0.66), cutout_fn(n_holes=5, length=0.2)]
# Augmentation settings: flips (including vertical ones), rotation, lighting,
# zoom and warp limits; p_affine / p_lighting are both 0.75.
# `tfms` is consumed by get_data() through `ds_tfms=tfms`.
# NOTE(review): presumably `xtra_tfms` only affects the training transforms -
# confirm against the fastai get_transforms documentation.
tfms = get_transforms(do_flip=True, 
                      flip_vert=True, 
                      max_rotate=90,  
                      max_lighting=0.15, 
                      max_zoom=1.5, 
                      max_warp=0.2,
                      p_affine=0.75,
                      p_lighting=0.75,  
                      xtra_tfms=xtra_tfms,
                     )

#### Create dataset 

In [12]:
# Folder passed to the data bunch as its test set (`test=test_path` in get_data).
# NOTE(review): absolute, machine-specific location - adjust when running elsewhere.
test_path = Path('/data/Datasets/WhiteBloodCancer/test/')

def get_data(bs, size):
    """Create the normalized ``ImageDataBunch`` used for training.

    Relies on module-level state: the label lists ``lls`` (assigned inside the
    fold loop before this function is called), the augmentation pipeline
    ``tfms`` and the test folder ``test_path``.

    Args:
        bs: batch size forwarded to the data bunch.
        size: target image size; images are resized with ``ResizeMethod.PAD``
            using ``'zeros'`` as padding mode.

    Returns:
        The data bunch after ``normalize()`` was called without explicit
        statistics.
    """
    bunch_options = dict(
        size=size,
        bs=bs,
        ds_tfms=tfms,
        padding_mode='zeros',
        resize_method=ResizeMethod.PAD,
        test=test_path,
    )
    # Explicit per-channel statistics could be passed to normalize() instead;
    # currently none are given.
    return ImageDataBunch.create_from_ll(lls, **bunch_options).normalize()
    

In [13]:
# Class label = the "hem" / "all" token directly in front of the ".bmp" extension.
# The dot is escaped so that only a literal "." is accepted (an unescaped "."
# would also accept names such as "..._hemXbmp").
pat = re.compile(r'^.*(hem|all)\.bmp$')


def get_label(fn):
    """Return the class label (``'hem'`` or ``'all'``) encoded in a file name.

    Args:
        fn: file name or path (anything accepted by ``str``) ending in
            ``hem.bmp`` or ``all.bmp``.

    Returns:
        str: ``'hem'`` or ``'all'``.

    Raises:
        AttributeError: if the name does not end in ``hem.bmp`` / ``all.bmp``.
            The exception type is the one the unguarded ``.group`` call on a
            failed match produced before, now with a readable message.
    """
    match = pat.search(str(fn))
    if match is None:
        raise AttributeError(
            "cannot derive a label from {!r}: expected a name ending in "
            "'hem.bmp' or 'all.bmp'".format(str(fn)))
    return match.group(1)

## Split data into training and validation sets

In [14]:
# Splitter built from the two per-patient file dictionaries (hem and all).
# NOTE(review): presumably it keeps all images of one patient in the same split - confirm in dataset_spliter.
split_handler = SplitByPatient(hem_patients, all_patints)

### Split by Fold

In [15]:
# Partition the data into 5 folds; the training loop indexes `folds` with 0..4
# and treats each entry as a list of image files.
folds = split_handler.split_by_folds(5)

In [16]:
def flatten(l):
    """Concatenate the items of every sub-iterable of `l` into one flat list.

    Only one nesting level is removed; the relative order of items is kept.
    """
    flat = []
    for sublist in l:
        flat.extend(sublist)
    return flat
# Cross-validation: train one ResNet-34 per fold (frozen stage, then unfrozen
# stage) and export each learner with its validation MCC in the file name.

# Hyper-parameters shared by all folds; they do not change inside the loop.
n_folds = 5   # must match the value passed to split_handler.split_by_folds
size = 450    # image size handed to get_data
bs = 128      # batch size handed to get_data
lr = 1e-2     # stage-1 learning rate; stage 2 uses slice(1e-5, lr/5)

for fold in range(n_folds):
    gc.collect()
    experiment_name = 'rn34-Image-{}'.format(fold)

    # Fold `fold` is held out for validation, the remaining folds are merged
    # into the training set.
    valid_files = folds[fold]
    train_files = flatten([folds[f] for f in range(n_folds) if f != fold])

    # Sanity check: together the two splits must cover as many files as were found.
    n_assigned = len(train_files) + len(valid_files)
    if n_assigned != len(fnames):
        raise ArithmeticError(
            'fold {}: {} files assigned to train/valid, but {} image files were found'.format(
                fold, n_assigned, len(fnames)))

    valid = ImageItemList(valid_files)
    train = ImageItemList(train_files)

    item_list = ItemLists(path, train, valid)
    # `lls` must stay a module-level name: get_data() reads it as a global.
    lls = item_list.label_from_func(get_label)

    # Stage 1: the best checkpoint according to the "mcc" metric is saved
    # under a 'stage1-...' name.
    learn = create_cnn(get_data(bs, size), models.resnet34,
                       metrics=[error_rate, F1Weighted(), MCC()],
                       #loss_func=FocalLoss(num_classes=1, alpha=0.4, gamma=0.5),
                       #ps=0.75,
                       wd=0.001,
                       callback_fns=[ShowGraph,
                                     partial(SaveModelCallback,
                                             monitor="mcc",
                                             mode='max',
                                             name='stage1-{}-{}'.format(experiment_name, size))],
                       ).to_fp16().mixup()

    learn.fit_one_cycle(10, lr)

    # Stage 2: unfreeze and continue with lower learning rates, saving the
    # best checkpoint under a 'stage2-...' name.
    learn.unfreeze()
    # NOTE(review): index 2 is assumed to be the stage-1 SaveModelCallback
    # (after a default first entry and ShowGraph) - confirm against the
    # installed fastai version before changing the callback list.
    learn.callback_fns[2] = partial(SaveModelCallback,
                                    monitor="mcc",
                                    mode='max',
                                    name='stage2-{}-{}'.format(experiment_name, size))
    learn.fit_one_cycle(10, slice(1e-5, lr/5))

    # Score the held-out fold (DatasetType.Valid, despite the `_test` names).
    preds_test, y_test = learn.get_preds(ds_type=DatasetType.Valid)
    # Predicted class = index of the largest score per row. No sigmoid is
    # applied first: it is monotonic, so it cannot change which index wins.
    preds_test = preds_test.argmax(dim=1)
    # MCC as an integer percentage (truncated towards zero) for the file name.
    score = int(matthews_corrcoef(y_test, preds_test) * 100)

    learn.export('{}-{}-{}.pkl'.format(experiment_name, size, score))