In [1]:
# Reload edited Python modules (e.g. the object_detection_fastai helpers) automatically
# before each cell executes, so library edits are picked up without a kernel restart.
%reload_ext autoreload
%autoreload 2
# Render matplotlib figures inline below the producing cell.
%matplotlib inline

In [2]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [3]:
import numpy as np
from tqdm import tqdm
from pathlib import Path
import openslide
import pandas as pd
import pickle

In [4]:
from fastai.callbacks.csv_logger import CSVLogger

In [5]:
from object_detection_fastai.helper.object_detection_helper import *
from object_detection_fastai.helper.wsi_loader import *
from object_detection_fastai.loss.RetinaNetFocalLoss import RetinaNetFocalLoss
from object_detection_fastai.models.RetinaNet import RetinaNet
from object_detection_fastai.callbacks.callbacks import BBLossMetrics, BBMetrics, PascalVOCMetric, PascalVOCMetricByDistance

In [6]:
# Slides used for training and for validation.
# The ablation loop trains on slides_train[:i+1], so the ORDER of slides_train defines
# which slides each ablation step sees. dict.fromkeys() drops duplicates while keeping
# the written order; list(set(...)) must not be used here because the iteration order
# of a set of str depends on hash randomization and changes between kernel restarts.
slides_train = list(dict.fromkeys([
    '3106_20 BB BAL Human neu.svs',
    '10080_19 humane BAL Berliner Blau-001.svs',
    '11480_19 humane BAL Berliner Blau.svs',
    '11480_19 humane BAL Turnbull Blau.svs',
    '10854_19 humane BAL Turnbull Blau.svs',
]))

slides_val = list(dict.fromkeys([
    '10120_19 humane BAL Turnbull Blau.svs',
    '10277_19 humane BAL Turnbull Blau.svs',
    '10672_19 humane BAL Turnbull Blau.svs',
]))

# Guard against train/validation leakage.
assert not set(slides_train) & set(slides_val), "a slide is in both the training and validation sets"

In [8]:
# Load every EIPH annotation and keep two views: rows on training slides and rows on
# validation slides (selected by the slide file name in "image_name").
# NOTE: read_pickle unpickles arbitrary objects -- only point this at trusted files.
annotations_path = Path("../../Statistics/EIPH_Annotations.pkl")
annotations = pd.read_pickle(annotations_path)

annotations_train, annotations_val = (
    annotations.loc[annotations["image_name"].isin(slide_subset)]
    for slide_subset in (slides_train, slides_val)
)
annotations_train.head()

Unnamed: 0,id,image_id,image_set,species,image_name,image_type,grade,vector,unique_identifier,user_id,deleted,last_editor,data_set_name,version
119329,2200930,3623,251,human,3106_20 BB BAL Human neu.svs,TurnbullBlue,0,"{'x1': 21647, 'x2': 21751, 'y1': 18133, 'y2': ...",22ffc5a6-56da-41ef-bef4-69e54413bddd,1,False,1,SDATA,Inference
119330,2205439,3623,251,human,3106_20 BB BAL Human neu.svs,TurnbullBlue,0,"{'x1': 4512, 'x2': 4667, 'y1': 8242, 'y2': 8397}",3d0d194e-0fef-4b91-b2d4-957e8007c3b9,1,False,1,SDATA,Inference
119331,2209572,3623,251,human,3106_20 BB BAL Human neu.svs,TurnbullBlue,0,"{'x1': 1640, 'x2': 1778, 'y1': 17751, 'y2': 17...",0a3b7d83-2f8d-4ed1-a194-e3eaaaa4e7a9,1,False,1,SDATA,Inference
119332,2209835,3623,251,human,3106_20 BB BAL Human neu.svs,TurnbullBlue,0,"{'x1': 26641, 'x2': 26808, 'y1': 10599, 'y2': ...",54541283-3220-42a6-b9b0-9bbf5af7eda6,1,False,1,SDATA,Inference
119333,2210512,3623,251,human,3106_20 BB BAL Human neu.svs,TurnbullBlue,0,"{'x1': 9363, 'x2': 9514, 'y1': 21971, 'y2': 22...",47aaea20-97d3-42f2-abd3-88b8d7853ac4,1,False,1,SDATA,Inference


In [9]:
# Map slide file name -> path for every training/validation slide found anywhere
# below slides_path.
slides_path = Path("../../../Slides")
wanted_slides = set(slides_train) | set(slides_val)
files = {slide.name: slide for slide in slides_path.rglob("*.svs") if slide.name in wanted_slides}

# Fail here with the offending names instead of a bare KeyError inside the training loop.
missing_slides = sorted(wanted_slides.difference(files))
assert not missing_slides, "slides not found under {}: {}".format(slides_path, missing_slides)
files

{'10080_19 humane BAL Berliner Blau-001.svs': PosixPath('../../../Slides/Human/10080_19 humane BAL Berliner Blau-001.svs'),
 '10120_19 humane BAL Turnbull Blau.svs': PosixPath('../../../Slides/Human/10120_19 humane BAL Turnbull Blau.svs'),
 '10277_19 humane BAL Turnbull Blau.svs': PosixPath('../../../Slides/Human/10277_19 humane BAL Turnbull Blau.svs'),
 '10672_19 humane BAL Turnbull Blau.svs': PosixPath('../../../Slides/Human/10672_19 humane BAL Turnbull Blau.svs'),
 '10854_19 humane BAL Turnbull Blau.svs': PosixPath('../../../Slides/Human/10854_19 humane BAL Turnbull Blau.svs'),
 '11480_19 humane BAL Berliner Blau.svs': PosixPath('../../../Slides/Human/11480_19 humane BAL Berliner Blau.svs'),
 '11480_19 humane BAL Turnbull Blau.svs': PosixPath('../../../Slides/Human/11480_19 humane BAL Turnbull Blau.svs'),
 '3106_20 BB BAL Human neu.svs': PosixPath('../../../Slides/Human/3106_20 BB BAL Human neu.svs')}

In [10]:
# Data augmentation: random flips (horizontal and vertical -- shown as dihedral_affine in
# the output) plus a small random rotation with probability p_affine. Lighting, zoom and
# warp are switched off, so no brightness/contrast/zoom/warp transforms are generated.
tfms = get_transforms(
    do_flip=True,
    flip_vert=True,
    p_affine=0.5,       # probability used for the rotation transform
    max_zoom=1.,        # 1.0 => no zoom
    max_warp=0.0,       # no perspective warp
    max_lighting=0.0,   # no brightness/contrast change
    p_lighting=0.0,
)
tfms

([RandTransform(tfm=TfmCrop (crop_pad), kwargs={'row_pct': (0, 1), 'col_pct': (0, 1), 'padding_mode': 'reflection'}, p=1.0, resolved={}, do_run=True, is_random=True),
  RandTransform(tfm=TfmAffine (dihedral_affine), kwargs={}, p=1.0, resolved={}, do_run=True, is_random=True),
  RandTransform(tfm=TfmAffine (rotate), kwargs={'degrees': (-10.0, 10.0)}, p=0.5, resolved={}, do_run=True, is_random=True)],
 [RandTransform(tfm=TfmCrop (crop_pad), kwargs={}, p=1.0, resolved={}, do_run=True, is_random=True)])

In [11]:
# Shared run configuration used by the training cell.
size, level = 1024, 0   # patch width/height in pixels; slide level (presumably 0 = full resolution)
bs = 16                 # mini-batch size passed to databunch()
train_images, val_images = 2500, 1500  # number of SlideContainer items sampled for train / valid

In [12]:
def get_y_func(x):
    """Label function for ``label_from_func``: hand back the item's stored target.

    In the training cell the items are SlideContainer objects created with
    ``y=[bboxes, labels]``; this returns that ``y`` unchanged.
    """
    return getattr(x, "y")

In [13]:
# One 32x32 base size, square aspect ratio only, five scales.
# Presumably 1 ratio x 5 scales = 5 anchors per location -- keep in sync with
# n_anchors=5 passed to RetinaNet in the training cell.
anchors = create_anchors(
    sizes=[(32, 32)],
    ratios=[1],
    scales=[0.6, 0.7, 0.9, 1.25, 1.5],
)

In [None]:
# Ablation study: run i trains a RetinaNet on the first i+1 slides of slides_train.
# Every run is scored on the SAME validation item list (sampled once, seeded), so the
# runs differ only in how much training data they see.
# NOTE(review): only the item-list sampling is seeded; fastai augmentation, weight init
# and cuDNN remain nondeterministic.

ABLATION_SEED = 42          # seed for sampling the train/valid SlideContainer lists
GRADES = [0, 1, 2, 3, 4]    # every grade that must appear in each slide's label list


def build_slide_containers(slide_names, slide_annotations):
    """Create one SlideContainer per slide from its non-deleted annotations.

    Parameters
    ----------
    slide_names : iterable of str
        Slide file names; each must be a key of ``files``.
    slide_annotations : pandas.DataFrame
        Rows with ``image_name``, ``deleted``, ``grade`` and ``vector`` columns, where
        ``vector`` is a dict with ``x1``, ``y1``, ``x2``, ``y2`` keys.

    Returns
    -------
    list
        SlideContainers in ``slide_names`` order, each labelled ``y=[bboxes, labels]``.
    """
    containers = []
    for image_name in slide_names:
        rows = slide_annotations[(slide_annotations["image_name"] == image_name)
                                 & slide_annotations["deleted"].eq(False)]
        labels = list(rows["grade"])
        bboxes = [[vector["x1"], vector["y1"], vector["x2"], vector["y2"]] for vector in rows["vector"]]

        # Append an empty [0, 0, 0, 0] box for each grade missing on this slide,
        # presumably so the label list always contains all classes.
        present = set(labels)
        for grade in GRADES:
            if grade not in present:
                bboxes.append([0, 0, 0, 0])
                labels.append(grade)

        containers.append(SlideContainer(files[image_name], y=[bboxes, labels],
                                         level=level, width=size, height=size))
    return containers


def sample_items(items, n, rng):
    """Draw ``n`` items uniformly with replacement (like ``np.random.choice(items, n)``)."""
    return [items[k] for k in rng.integers(0, len(items), size=n)]


ablation_rng = np.random.default_rng(ABLATION_SEED)

# Containers are loop-invariant: build them once and slice per ablation step.
all_train_containers = build_slide_containers(slides_train, annotations_train)
val_containers = build_slide_containers(annotations_val["image_name"].unique(), annotations_val)
valid_files = sample_items(val_containers, val_images, ablation_rng)

for i in range(len(slides_train)):

    torch.cuda.empty_cache()

    experiment_name = "HumanVsHuman-Ablation_{}".format(i)

    train_files = sample_items(all_train_containers[:i + 1], train_images, ablation_rng)

    train = ObjectItemListSlide(train_files, path=slides_path)
    valid = ObjectItemListSlide(valid_files, path=slides_path)
    item_list = ItemLists(slides_path, train, valid)
    lls = item_list.label_from_func(get_y_func, label_cls=SlideObjectCategoryList)
    lls = lls.transform(tfms, tfm_y=True, size=size)
    data = lls.databunch(bs=bs, collate_fn=bb_pad_collate, num_workers=0).normalize()

    crit = RetinaNetFocalLoss(anchors)
    encoder = create_body(models.resnet18, True, -2)
    # n_anchors must match the anchors created by create_anchors (1 ratio x 5 scales).
    model = RetinaNet(encoder, n_classes=data.train_ds.c, n_anchors=5, sizes=[32], chs=128, final_bias=-4., n_conv=3)

    # classes[1:] presumably skips a background class at index 0 -- TODO confirm.
    class_names = [str(cls) for cls in data.train_ds.y.classes[1:]]
    voc = PascalVOCMetricByDistance(anchors, size, class_names, radius=25)
    # CSVLogger appends, so re-running the cell adds rows to the same experiment log.
    learn = Learner(data, model, loss_func=crit,
                    callback_fns=[BBMetrics, partial(CSVLogger, append=True, filename=experiment_name)],
                    metrics=[voc])

    # Two split points -> three layer groups; freeze_to(-2) trains only the last two.
    learn.split([model.encoder[6], model.c5top5])
    learn.freeze_to(-2)

    learn.fit_one_cycle(3, 1e-3)

    learn.unfreeze()
    learn.fit_one_cycle(10, 1e-3)
    learn.fit_one_cycle(10, 1e-3)

    learn.destroy()

epoch,train_loss,valid_loss,pascal_voc_metric_by_distance,BBloss,focal_loss,AP-0,AP-1,AP-2,AP-3,AP-4,time
0,0.48955,0.540574,0.389531,0.087883,0.452691,0.224461,0.111506,0.673068,0.73117,0.207447,24:13
1,0.260245,0.279059,0.558233,0.062843,0.216215,0.264098,0.688167,0.699919,0.844784,0.294197,35:14
2,0.208488,0.236469,0.613664,0.055796,0.180674,0.497984,0.645417,0.836221,0.829168,0.259531,29:39


epoch,train_loss,valid_loss,pascal_voc_metric_by_distance,BBloss,focal_loss,AP-0,AP-1,AP-2,AP-3,AP-4,time
0,0.206753,0.224231,0.643149,0.055981,0.16825,0.469223,0.741796,0.841932,0.773232,0.389561,35:18
1,0.247534,0.43416,0.487602,0.07161,0.362551,0.246748,0.442353,0.794152,0.805994,0.148764,23:46
2,0.250504,0.486991,0.509512,0.106095,0.380896,0.23948,0.591627,0.843349,0.600735,0.272368,25:10
3,0.233274,0.325934,0.588025,0.069423,0.25651,0.414976,0.701856,0.737864,0.820136,0.265294,29:22
