In [1]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import sqlite3
import numpy as np
from SlideRunner.dataAccess.database import Database
from tqdm import tqdm
from pathlib import Path
import openslide
import time
import pickle
import cv2

In [3]:
import torchvision.transforms as transforms

In [4]:
from fastai import *
from fastai.vision import *
from fastai.callbacks import *

from data_loader import *

from helper.object_detection_helper import *
from loss.RetinaNetFocalLoss import RetinaNetFocalLoss
from models.RetinaNet import RetinaNet

In [5]:
# Patch geometry handed to every SlideContainer: pyramid level to read from
# and the edge length (in pixels) of the square patches.
level = 0
size = 1024

# Filled by the slide-loading cell with one SlideContainer per slide.
files = list()

# Root folder of the EIPH whole-slide-image dataset; the SlideRunner
# annotation database lives directly inside it.
path = Path('/data/Datasets/EIPH_WSI/')

database = Database()
database.open(str(path.joinpath('EIPH.sqlite')))

In [10]:
getslides = """SELECT uid, filename FROM Slides"""

# `files` is created in an earlier cell. Empty it first so that re-running this
# cell on a live kernel does not register every slide a second time.
files.clear()

for currslide, filename in tqdm(database.execute(getslides).fetchall()):
    database.loadIntoMemory(currslide)

    # Slides whose file name contains 'erliner' are looked up in the
    # 'Berliner Blau' folder, all others in 'Turnbull Blue'.
    check = 'erliner' in filename
    slidetype = 'Berliner Blau/' if check else 'Turnbull Blue/'

    slide_path = path / slidetype / filename

    # Kept as notebook globals: after the loop they describe the last slide
    # that was opened, and other cells may read them.
    slide = openslide.open_slide(str(slide_path))
    level_dimension = slide.level_dimensions[level]
    down_factor = slide.level_downsamples[level]

    # The same `size` is passed for both patch dimensions (square patches).
    files.append(SlideContainer(slide_path, [[0], [1]], level, size, size))

  0%|          | 0/24 [00:00<?, ?it/s]

Loading DB into memory ...


  4%|▍         | 1/24 [00:00<00:08,  2.77it/s]

Loading DB into memory ...


  8%|▊         | 2/24 [00:01<00:15,  1.45it/s]

Loading DB into memory ...


 12%|█▎        | 3/24 [00:02<00:13,  1.52it/s]

Loading DB into memory ...


 21%|██        | 5/24 [00:03<00:08,  2.15it/s]

Loading DB into memory ...
Loading DB into memory ...
Loading DB into memory ...


 33%|███▎      | 8/24 [00:03<00:05,  2.84it/s]

Loading DB into memory ...


 38%|███▊      | 9/24 [00:03<00:04,  3.50it/s]

Loading DB into memory ...
Loading DB into memory ...


 46%|████▌     | 11/24 [00:04<00:02,  4.38it/s]

Loading DB into memory ...
Loading DB into memory ...


 58%|█████▊    | 14/24 [00:04<00:02,  4.83it/s]

Loading DB into memory ...
Loading DB into memory ...
Loading DB into memory ...
Loading DB into memory ...


 67%|██████▋   | 16/24 [00:05<00:02,  3.82it/s]

Loading DB into memory ...


 75%|███████▌  | 18/24 [00:05<00:01,  4.45it/s]

Loading DB into memory ...
Loading DB into memory ...


 79%|███████▉  | 19/24 [00:05<00:00,  5.18it/s]

Loading DB into memory ...
Loading DB into memory ...


 88%|████████▊ | 21/24 [00:06<00:00,  4.26it/s]

Loading DB into memory ...
Loading DB into memory ...


100%|██████████| 24/24 [00:07<00:00,  3.09it/s]

Loading DB into memory ...





In [11]:
fname = "pferd_0_1024_reg.pth"

# NOTE(review): torch.load unpickles arbitrary Python objects -- only load
# checkpoints that come from a trusted source.
# Remap the stored tensors to the CPU only when fastai's default device is the
# CPU; otherwise call torch.load without a map_location, as before.
load_kwargs = {'map_location': 'cpu'} if defaults.device == torch.device('cpu') else {}
state = torch.load(Path(path) / fname, **load_kwargs)

model = state.pop('model')
# The model is only used for inference below. Switch it to evaluation mode so
# layers such as dropout and batch normalisation do not behave as in training.
model.eval()

# Normalisation statistics stored alongside the model; applied to every patch
# before it is fed to the network.
mean = state['data']['normalize']['mean']
std = state['data']['normalize']['std']

In [12]:
# Post-processing thresholds used by the inference cell: `detect_thresh` is
# handed to process_output and `nms_thresh` to nms.
nms_thresh = 0.3
detect_thresh = 0.1

# Result accumulators, keyed by slide file name in the inference cell.
result_boxes, result_regression = dict(), dict()

# Anchor boxes handed to process_output together with the raw predictions.
anchors = create_anchors(scales=[0.6, 0.7, 0.9, 1.25, 1.5], ratios=[1], sizes=[(32,32)])

In [13]:
def rescale_box(bboxes, size: Tensor):
    """Convert centre/extent boxes in [-1, 1] coordinates to integer pixel boxes.

    The first two columns of ``bboxes`` are treated as the box centre and the
    last two as its extent. The centre is moved to the top-left corner
    (centre - extent / 2), mapped from [-1, 1] to [0, size], and the extent is
    scaled by ``size / 2``.

    Args:
        bboxes: tensor of shape (n, 4). The caller in this notebook unpacks the
            result as (y, x, h, w).
        size: tensor broadcastable against two columns, e.g. shape (1, 2),
            holding the target image size per axis.

    Returns:
        A new ``long`` tensor of shape (n, 4) holding (top-left, extent) in
        pixels. The input tensor is left untouched.
    """
    # Work on a copy: the arithmetic below assigns in place and would otherwise
    # silently overwrite the caller's predictions.
    bboxes = bboxes.clone()
    bboxes[:, :2] = bboxes[:, :2] - bboxes[:, 2:] / 2
    bboxes[:, :2] = (bboxes[:, :2] + 1) * size / 2
    bboxes[:, 2:] = bboxes[:, 2:] * size / 2
    return bboxes.long()

In [14]:
debug_level = 1  # pyramid level of the overview image that detections are drawn onto

# Loop invariants, built once instead of per patch / per detection:
# the patch size stored with the exported model state, the sliding-window
# stride (half a patch), the normalisation transform and the rescale target.
size = state['data']['tfmargs']['size']
stride = int(size / 2)
normalize = transforms.Normalize(mean, std)
t_sz = torch.Tensor([size, size])[None].float()

with torch.no_grad():
    for slide_container in tqdm(files):
        wsi = slide_container.slide
        slide_name = slide_container.file.name
        result_boxes[slide_name] = []
        result_regression[slide_name] = []

        # Downsample factor of THIS slide's overview level. The global `slide`
        # left over from the loading cell refers to the last slide opened there
        # and must not be used here.
        overview_down = wsi.level_downsamples[debug_level]

        # Whole slide at the overview level, RGB only (alpha channel dropped).
        basepic = np.array(wsi.read_region(location=(0, 0),
                                           level=debug_level,
                                           size=wsi.level_dimensions[debug_level]))
        basepic = basepic[:, :, :3].astype(np.uint8)

        # OpenSlide reports level dimensions as (width, height). x is used as
        # the horizontal coordinate below (it is added to x_box and passed
        # first to cv2.rectangle), so x sweeps the width and y the height.
        slide_width, slide_height = wsi.level_dimensions[level]
        for x in range(0, slide_width - 2 * size, stride):
            for y in range(0, slide_height - 2 * size, stride):
                # Patch origin; no down_factor scaling is applied.
                x_real = x
                y_real = y

                patch_ori = slide_container.get_patch(x, y)
                patch = pil2tensor(patch_ori / 255., np.float32)
                patch = normalize(patch)

                # Add a batch dimension of one; the third model output is unused.
                class_pred_batch, bbox_pred_batch, _, regression_pred, bbox_regression_pred = model(
                    patch[None, :, :, :])
                for clas_pred, bbox_pred, reg_pred, box_reg_pred in zip(class_pred_batch, bbox_pred_batch,
                                                                        regression_pred, bbox_regression_pred):

                    # One entry per patch: patch rectangle plus the patch-level regression output.
                    result_regression[slide_name].append(
                        np.array([x_real, y_real, x_real + size, y_real + size, reg_pred]))
                    bbox_pred, scores, preds = process_output(clas_pred, bbox_pred, anchors, detect_thresh)

                    if bbox_pred is not None:
                        to_keep = nms(bbox_pred, scores, nms_thresh)
                        bbox_pred, preds, scores = bbox_pred[to_keep].cpu(), preds[to_keep].cpu(), scores[to_keep].cpu()
                        box_reg_pred = box_reg_pred[to_keep].cpu()

                        bbox_pred = rescale_box(bbox_pred, t_sz)

                        patch_ori = patch_ori.astype(np.uint8)
                        for box, pred, score, bb_reg in zip(bbox_pred, preds, scores, box_reg_pred):
                            # Boxes are unpacked as (y, x, h, w) in patch pixels.
                            y_box, x_box = box[:2]
                            h, w = box[2:4]

                            # One entry per detection, shifted by the patch origin.
                            result_boxes[slide_name].append(np.array([x_box + x_real, y_box + y_real,
                                                                      x_box + x_real + w, y_box + y_real + h,
                                                                      pred, score, bb_reg]))

                            cv2.rectangle(patch_ori, (int(x_box), int(y_box)), (int(x_box + w), int(y_box + h)),
                                          (0, 0, 255), 1)

                            # Same box, scaled down to the overview level for the debug image.
                            y_box, x_box = box[:2] / overview_down
                            h, w = box[2:4] / overview_down
                            temp_x_real = x_real / overview_down
                            temp_y_real = y_real / overview_down

                            cv2.rectangle(basepic, (int(x_box + temp_x_real), int(y_box + temp_y_real)),
                                          (int(x_box + temp_x_real + w), int(y_box + temp_y_real + h)), (255, 0, 0), 1)

        # cv2 expects BGR channel order, hence the channel flip.
        cv2.imwrite("/server/born_pix_cm/{}.png".format(slide_container.file.stem), basepic[:, :, [2, 1, 0]])

        # Checkpoint the accumulated results after every slide; the context
        # managers make sure the files are flushed and closed each time.
        with open("inference_results_boxes.p", "wb") as boxes_file:
            pickle.dump(result_boxes, boxes_file)
        with open("inference_result_regression.p", "wb") as regression_file:
            pickle.dump(result_regression, regression_file)


100%|██████████| 24/24 [18:30:57<00:00, 2636.34s/it]  
