In [None]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

In [None]:
import numpy as np
from tqdm import tqdm
from pathlib import Path
import openslide
import pandas as pd

In [None]:
from fastai.callbacks.csv_logger import CSVLogger

In [None]:
from object_detection_fastai.helper.object_detection_helper import *
from object_detection_fastai.helper.wsi_loader import *
from object_detection_fastai.loss.RetinaNetFocalLoss import RetinaNetFocalLoss
from object_detection_fastai.models.RetinaNet import RetinaNet
from object_detection_fastai.callbacks.callbacks import BBLossMetrics, BBMetrics, PascalVOCMetric, PascalVOCMetricByDistance

In [None]:
# Slides used for training. The set removes accidental duplicates before the
# names are turned back into a list. Slides blue_6 ... blue_14 are
# intentionally not part of the training set at the moment.
slides_train = list({
    'BAL_Cat_Turnbull blue_1.svs',
    'BAL_Cat_Turnbull blue_2.svs',
    'BAL_Cat_Turnbull blue_3.svs',
    'BAL_Cat_Turnbull blue_4.svs',
    'BAL_Cat_Turnbull blue_5.svs',
})

# Slides used for validation, de-duplicated the same way.
slides_val = list({
    '10120_19 humane BAL Turnbull Blau.svs',
    '10277_19 humane BAL Turnbull Blau.svs',
    '10672_19 humane BAL Turnbull Blau.svs',
})

In [None]:
# Load the pickled annotation table. Pickle files can execute code when
# loaded, so only point this at a trusted file.
annotations_path = Path("../../Statistics/EIPH_Annotations.pkl")
annotations = pd.read_pickle(annotations_path)

# Split the table by slide name into the training and validation subsets.
is_train_slide = annotations["image_name"].isin(slides_train)
is_val_slide = annotations["image_name"].isin(slides_val)
annotations_train = annotations[is_train_slide]
annotations_val = annotations[is_val_slide]

# Preview the first training rows as the cell output.
annotations_train.head()

In [None]:
# Root folder that is searched recursively for whole-slide images.
slides_path = Path("../../Slides")

# Build the lookup of wanted slide names once. The previous version
# concatenated both lists and scanned the result for every file found.
wanted_slides = set(slides_train) | set(slides_val)

# Map slide file name -> full path for every wanted .svs file below slides_path.
# If the same file name occurs in several sub-folders, the last one found wins.
files = {
    slide.name: slide
    for slide in slides_path.rglob("*.svs")
    if slide.name in wanted_slides
}
files

In [None]:
# Patch geometry: `size` is used as both width and height of a patch,
# `level` is the slide level handed to SlideContainer.
size, level = 1024, 0

# Batch size for the databunch.
bs = 16

# Number of slide containers drawn (with replacement) for each split.
train_images, val_images = 2500, 1500

# Used as the file name stem for the CSV log and the saved model.
experiment_name = "CatVsHuman"

In [None]:
def _build_slide_containers(annotation_table):
    """Create one SlideContainer per slide that appears in `annotation_table`.

    Args:
        annotation_table: DataFrame with the columns "image_name", "deleted",
            "grade" and "vector" (each vector provides "x1", "y1", "x2", "y2").

    Returns:
        A list with one SlideContainer per unique "image_name", in order of
        first appearance. Annotations flagged as deleted are left out.

    Raises:
        KeyError: if a slide name has no entry in the global `files` mapping.
    """
    containers = []
    for image_name in annotation_table["image_name"].unique():
        # Use a local name here: the loop must not overwrite the notebook-wide
        # `annotations` table loaded earlier.
        slide_annotations = annotation_table[annotation_table["image_name"] == image_name]
        # Element-wise comparison on the column, so `== False` is intentional.
        slide_annotations = slide_annotations[slide_annotations["deleted"] == False]

        slide_path = files[image_name]
        labels = list(slide_annotations["grade"])
        bboxes = [[vector["x1"], vector["y1"], vector["x2"], vector["y2"]]
                  for vector in slide_annotations["vector"]]

        containers.append(SlideContainer(slide_path, y=[bboxes, labels],
                                         level=level, width=size, height=size))
    return containers


# One container per annotated slide for each split.
train_files = _build_slide_containers(annotations_train)
val_files = _build_slide_containers(annotations_val)

# Draw containers with replacement to get a fixed number of entries per split.
# NOTE: the sampling is unseeded, so the selection differs between runs.
train_files = list(np.random.choice(train_files, train_images))
valid_files = list(np.random.choice(val_files, val_images))

In [None]:
# Augmentation set: horizontal and vertical flips are enabled, while lighting,
# zoom and warp are switched off via their magnitude/probability arguments.
# max_rotate and xtra_tfms are not overridden and keep the library defaults.
tfms = get_transforms(
    do_flip=True, flip_vert=True,
    max_zoom=1., max_warp=0.0, p_affine=0.5,
    max_lighting=0.0, p_lighting=0.0,
)
tfms

In [None]:
def get_y_func(x):
    """Return the label stored on item `x`, i.e. its `y` attribute."""
    label = x.y
    return label

In [None]:
# Wrap the sampled slide containers in item lists, one per split.
train = ObjectItemListSlide(train_files, path=slides_path)
valid = ObjectItemListSlide(valid_files, path=slides_path)
item_list = ItemLists(slides_path, train, valid)

# Attach labels through get_y_func, then apply the transforms to the images
# and their targets (tfm_y=True).
lls = (
    item_list
    .label_from_func(get_y_func, label_cls=SlideObjectCategoryList)
    .transform(tfms, tfm_y=True, size=size)
)

# Build the databunch in the main process (num_workers=0) and normalize it.
data = (
    lls
    .databunch(bs=bs, collate_fn=bb_pad_collate, num_workers=0)
    .normalize()
)

In [None]:
# Visual sanity check: draw a grid of training samples.
data.show_batch(
    rows=3,
    ds_type=DatasetType.Train,
    figsize=(15, 15),
)

In [None]:
# Anchor boxes: a single 32x32 grid, one aspect ratio and five scales.
anchors = create_anchors(
    sizes=[(32, 32)],
    ratios=[1],
    scales=[0.6, 0.7, 0.9, 1.25, 1.5],
)

In [None]:
# Loss function built on the anchors defined above.
crit = RetinaNetFocalLoss(anchors)

# ResNet-18 body; the positional arguments are passed through unchanged.
encoder = create_body(models.resnet18, True, -2)

# Detection model; the number of classes is taken from the training dataset.
model = RetinaNet(
    encoder,
    n_classes=data.train_ds.c,
    n_anchors=5,
    sizes=[32],
    chs=128,
    final_bias=-4.,
    n_conv=3,
)

In [None]:
# Class names for the metric, skipping the first entry of the class list
# (presumably the background class - confirm).
metric_class_names = list(map(str, data.train_ds.y.classes[1:]))
voc = PascalVOCMetricByDistance(anchors, size, metric_class_names, radius=40)

# Log every epoch to a CSV file named after the experiment, appending to an
# existing file.
csv_logger_fn = partial(CSVLogger, append=True, filename=experiment_name)

learn = Learner(
    data,
    model,
    loss_func=crit,
    callback_fns=[BBMetrics, csv_logger_fn],
    metrics=[voc],
)

In [None]:
# Modules at which the model is split into layer groups.
layer_group_starts = [model.encoder[6], model.c5top5]
learn.split(layer_group_starts)

# Freeze the earlier layer groups before the first training round
# (exact groups affected depend on Learner.freeze_to - confirm).
learn.freeze_to(-2)

In [None]:
# Run the learning-rate finder, then plot what the recorder collected.
learn.lr_find()
recorder = learn.recorder
recorder.plot()

In [None]:
# First training round with part of the model still frozen: 3 epochs.
learn.fit_one_cycle(cyc_len=3, max_lr=1e-3)

In [None]:
# Make all layer groups trainable and train for 10 epochs.
learn.unfreeze()
learn.fit_one_cycle(cyc_len=10, max_lr=1e-3)

In [None]:
# Another 10-epoch cycle with the same maximum learning rate.
learn.fit_one_cycle(cyc_len=10, max_lr=1e-3)

In [None]:
import pickle

# Bundle the trained weights together with the settings needed to rebuild the
# model and preprocess its input.
stats = {"anchors": anchors,
         "mean": to_np(data.stats[0]),
         "std": to_np(data.stats[1]),
         "size": size,
         # Taken from the dataset, exactly as in the RetinaNet(...) call, so the
         # stored class count cannot drift from the saved weights.
         "n_classes": data.train_ds.c,
         # The next four values are not available as variables; keep them in
         # sync with the RetinaNet(...) constructor call by hand.
         "n_anchors": 5,
         "sizes": [32],
         "chs": 128,
         "encoder": "RN-18",
         "n_conv": 3,
         # Reuse the configured slide level instead of a second hard-coded 0.
         "level": level,
         "model": get_model(learn.model).state_dict()
        }

# Written to "<experiment_name>.p" in the current working directory.
torch.save(stats, "{}.p".format(experiment_name))

In [None]:
# Show detection results for 20 images using the given detection and
# non-maximum-suppression thresholds.
slide_object_result(
    learn,
    anchors,
    detect_thresh=0.3,
    nms_thresh=0.2,
    image_count=20,
)