-
Notifications
You must be signed in to change notification settings - Fork 297
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
149 changed files
with
11,736 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
|
||
<p align="center"><img src="assets/mdt_logo_2.png" width=450></p><br> | ||
|
||
## Overview | ||
This is a fully automated framework for object detection featuring: | ||
- 2D + 3D implementations of prevalent object detectors: e.g. Mask R-CNN [1], Retina Net [2], Retina U-Net [3]. | ||
- Modular and light-weight structure ensuring sharing of all processing steps (incl. backbone architecture) for comparability of models. | ||
- training with bounding box and/or pixel-wise annotations. | ||
- dynamic patching and tiling of 2D + 3D images (for training and inference). | ||
- weighted consolidation of box predictions across patch-overlaps, ensembles, and dimensions [3]. | ||
- monitoring + evaluation simultaneously on object and patient level. | ||
- 2D + 3D output visualizations. | ||
- integration of COCO mean average precision metric [5]. | ||
- integration of MIC-DKFZ batch generators for extensive data augmentation [6]. | ||
- easy modification to evaluation of instance segmentation and/or semantic segmentation. | ||
<br/> | ||
[1] He, Kaiming, et al. <a href="https://arxiv.org/abs/1703.06870">"Mask R-CNN"</a> ICCV, 2017<br> | ||
[2] Lin, Tsung-Yi, et al. <a href="https://arxiv.org/abs/1708.02002">"Focal Loss for Dense Object Detection"</a> TPAMI, 2018.<br> | ||
[3] Jaeger, Paul et al. <a href="http://arxiv.org/abs/1811.08661"> "Retina U-Net: Embarrassingly Simple Exploitation | ||
of Segmentation Supervision for Medical Object Detection" </a>, 2018 | ||
|
||
[5] https://github.com/cocodataset/cocoapi/blob/master/PythonAPI/pycocotools/cocoeval.py<br/> | ||
[6] https://github.com/MIC-DKFZ/batchgenerators<br/><br> | ||
|
||
## Installation | ||
Setup package in virtual environment | ||
``` | ||
git clone https://github.com/pfjaeger/medicaldetectiontoolkit.git . | ||
cd medicaldetectiontoolkit | ||
virtualenv -p python3 venv | ||
source venv/bin/activate | ||
pip3 install -e . | ||
``` | ||
Install MIC-DKFZ batch-generators | ||
``` | ||
cd .. | ||
git clone https://github.com/MIC-DKFZ/batchgenerators | ||
cd batchgenerators | ||
pip3 install -e . | ||
cd mdt | ||
``` | ||
|
||
## Prepare the Data | ||
This framework is meant for you to be able to train models on your own data sets. | ||
An example data loader is provided in medicaldetectiontoolkit/experiments including thorough documentation to ensure a quick start for your own project. | ||
|
||
## Execute | ||
1. Set I/O paths, model and training specifics in the configs file: medicaldetectiontoolkit/experiments/your_experiment/configs.py | ||
2. Train the model: | ||
|
||
``` | ||
python exec.py --mode train --exp_source experiments/my_experiment --exp_dir path/to/experiment/directory | ||
``` | ||
This copies snapshots of configs and model to the specified exp_dir, where all outputs will be saved. By default, the data is split into 60% training and 20% validation and 20% testing data to perform a 5-fold cross validation (can be changed to hold-out test set in configs) and all folds will be trained iteratively. In order to train a single fold, specify it using the folds arg: | ||
``` | ||
python exec.py --folds 0 1 2 .... # specify any combination of folds [0-4] | ||
``` | ||
3. Run inference: | ||
``` | ||
python exec.py --mode test --exp_dir path/to/experiment/directory | ||
``` | ||
This runs the prediction pipeline and saves all results to exp_dir. | ||
|
||
|
||
## Models | ||
|
||
This framework features all models explored in [3] (implemented in 2D + 3D): The proposed Retina U-Net, a simple but effective Architecture fusing state-of-the-art semantic segmentation with object detection,<br><br> | ||
<p align="center"><img src="assets/retu_figure.png" width=50%></p><br> | ||
also implementations of prevalent object detectors, such as Mask R-CNN, Faster R-CNN+ (Faster R-CNN w\ RoIAlign), Retina Net, U-Faster R-CNN+ (the two stage counterpart of Retina U-Net: Faster R-CNN with auxiliary semantic segmentation), DetU-Net (a U-Net like segmentation architecture with heuristics for object detection.)<br><br><br> | ||
<p align="center"><img src="assets/baseline_figure.png" width=85%></p><br> | ||
|
||
## Training annotations | ||
This framework features training with pixelwise and/or bounding box annotations. To overcome the issue of box coordinates in | ||
data augmentation, we feed the annotation masks through data augmentation (create a pseudo mask, if only bounding box annotations provided) and draw the boxes afterwards.<br><br> | ||
<p align="center"><img src="assets/annotations.png" width=85%></p><br> | ||
|
||
|
||
## Prediction pipeline | ||
This framework provides an inference module, which automatically handles patching of inputs, and tiling, ensembling, and weighted consolidation of output predictions:<br><br><br> | ||
<img src="assets/prediction_pipeline.png" ><br><br> | ||
|
||
|
||
## Consolidation of predictions (Weighted Box Clustering) | ||
Multiple predictions of the same image (from test time augmentations, tested epochs and overlapping patches), result in a high amount of boxes (or cubes), which need to be consolidated. In semantic segmentation, the final output would typically be obtained by averaging every pixel over all predictions. As described in [3], **weighted box clustering** (WBC) does this for box predictions:<br> | ||
<p align="center"><img src="assets/wcs_text.png" width=650><br><br></p> | ||
<p align="center"><img src="assets/wcs_readme.png" width=800><br><br></p> | ||
|
||
|
||
|
||
## Visualization / Monitoring | ||
By default, loss functions and performance metrics are monitored:<br><br><br> | ||
<img src="assets/loss_monitoring.png" width=700><br> | ||
<hr> | ||
Histograms of matched output predictions for training/validation/testing are plotted per foreground class:<br><br><br> | ||
<img src="assets/hist_example.png" width=550> | ||
<hr> | ||
Input images + ground truth annotations + output predictions of a sampled validation abtch are plotted after each epoch (here 2D sampled slice with +-3 neighbouring context slices in channels):<br><br><br> | ||
<img src="assets/output_monitoring_1.png" width=750> | ||
<hr> | ||
Zoomed into the last two lines of the plot:<br><br><br> | ||
<img src="assets/output_monitoring_2.png" width=700> | ||
|
||
## How to cite this code | ||
Please cite the original publication [3]. | ||
|
||
## License | ||
The code is published under the [Apache License Version 2.0](LICENSE). | ||
|
||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
[Dolphin] | ||
Timestamp=2018,11,4,16,51,18 | ||
Version=3 | ||
ViewMode=1 |
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
File renamed without changes.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Empty file.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
|
||
from torch.utils.ffi import _wrap_function | ||
from ._nms import lib as _lib, ffi as _ffi | ||
|
||
__all__ = [] | ||
def _import_symbols(locals): | ||
for symbol in dir(_lib): | ||
fn = getattr(_lib, symbol) | ||
if callable(fn): | ||
locals[symbol] = _wrap_function(fn, _ffi) | ||
else: | ||
locals[symbol] = fn | ||
__all__.append(symbol) | ||
|
||
_import_symbols(locals()) |
Binary file not shown.
Binary file not shown.
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import os | ||
import torch | ||
from torch.utils.ffi import create_extension | ||
|
||
|
||
sources = ['src/nms.c'] | ||
headers = ['src/nms.h'] | ||
defines = [] | ||
with_cuda = False | ||
|
||
if torch.cuda.is_available(): | ||
print('Including CUDA code.') | ||
sources += ['src/nms_cuda.c'] | ||
headers += ['src/nms_cuda.h'] | ||
defines += [('WITH_CUDA', None)] | ||
with_cuda = True | ||
|
||
this_file = os.path.dirname(os.path.realpath(__file__)) | ||
print(this_file) | ||
extra_objects = ['src/cuda/nms_kernel.cu.o'] | ||
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] | ||
|
||
ffi = create_extension( | ||
'_ext.nms', | ||
headers=headers, | ||
sources=sources, | ||
define_macros=defines, | ||
relative_to=__file__, | ||
with_cuda=with_cuda, | ||
extra_objects=extra_objects | ||
) | ||
|
||
if __name__ == '__main__': | ||
ffi.build() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,39 @@ | ||
import torch | ||
from ._ext import nms | ||
|
||
|
||
def nms_gpu(dets, thresh): | ||
""" | ||
dets has to be a tensor | ||
""" | ||
|
||
scores = dets[:, 4] | ||
order = scores.sort(0, descending=True)[1] | ||
dets = dets[order].contiguous() | ||
|
||
keep = torch.LongTensor(dets.size(0)) | ||
num_out = torch.LongTensor(1) | ||
nms.gpu_nms(keep, num_out, dets, thresh) | ||
return order[keep[:num_out[0]].cuda()].contiguous() | ||
|
||
|
||
|
||
def nms_cpu(dets, thresh): | ||
|
||
dets = dets.cpu() | ||
x1 = dets[:, 0] | ||
y1 = dets[:, 1] | ||
x2 = dets[:, 2] | ||
y2 = dets[:, 3] | ||
scores = dets[:, 4] | ||
|
||
areas = (x2 - x1 + 1) * (y2 - y1 + 1) | ||
order = scores.sort(0, descending=True)[1] | ||
# order = torch.from_numpy(np.ascontiguousarray(scores.numpy().argsort()[::-1])).long() | ||
|
||
keep = torch.LongTensor(dets.size(0)) | ||
num_out = torch.LongTensor(1) | ||
nms.cpu_nms(keep, num_out, dets, order, areas, thresh) | ||
|
||
return keep[:num_out[0]] | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
// ------------------------------------------------------------------ | ||
// Faster R-CNN | ||
// Copyright (c) 2015 Microsoft | ||
// Licensed under The MIT License [see fast-rcnn/LICENSE for details] | ||
// Written by Shaoqing Ren | ||
// ------------------------------------------------------------------ | ||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
|
||
#include <math.h> | ||
#include <stdio.h> | ||
#include <float.h> | ||
#include "nms_kernel.h" | ||
|
||
__device__ inline float devIoU(float const * const a, float const * const b) { | ||
float left = fmaxf(a[0], b[0]), right = fminf(a[2], b[2]); | ||
float top = fmaxf(a[1], b[1]), bottom = fminf(a[3], b[3]); | ||
float width = fmaxf(right - left + 1, 0.f), height = fmaxf(bottom - top + 1, 0.f); | ||
float interS = width * height; | ||
float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1); | ||
float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1); | ||
return interS / (Sa + Sb - interS); | ||
} | ||
|
||
__global__ void nms_kernel(const int n_boxes, const float nms_overlap_thresh, | ||
const float *dev_boxes, unsigned long long *dev_mask) { | ||
const int row_start = blockIdx.y; | ||
const int col_start = blockIdx.x; | ||
|
||
// if (row_start > col_start) return; | ||
|
||
const int row_size = | ||
fminf(n_boxes - row_start * threadsPerBlock, threadsPerBlock); | ||
const int col_size = | ||
fminf(n_boxes - col_start * threadsPerBlock, threadsPerBlock); | ||
|
||
__shared__ float block_boxes[threadsPerBlock * 5]; | ||
if (threadIdx.x < col_size) { | ||
block_boxes[threadIdx.x * 5 + 0] = | ||
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 0]; | ||
block_boxes[threadIdx.x * 5 + 1] = | ||
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 1]; | ||
block_boxes[threadIdx.x * 5 + 2] = | ||
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 2]; | ||
block_boxes[threadIdx.x * 5 + 3] = | ||
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 3]; | ||
block_boxes[threadIdx.x * 5 + 4] = | ||
dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 5 + 4]; | ||
} | ||
__syncthreads(); | ||
|
||
if (threadIdx.x < row_size) { | ||
const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x; | ||
const float *cur_box = dev_boxes + cur_box_idx * 5; | ||
int i = 0; | ||
unsigned long long t = 0; | ||
int start = 0; | ||
if (row_start == col_start) { | ||
start = threadIdx.x + 1; | ||
} | ||
for (i = start; i < col_size; i++) { | ||
if (devIoU(cur_box, block_boxes + i * 5) > nms_overlap_thresh) { | ||
t |= 1ULL << i; | ||
} | ||
} | ||
const int col_blocks = DIVUP(n_boxes, threadsPerBlock); | ||
dev_mask[cur_box_idx * col_blocks + col_start] = t; | ||
} | ||
} | ||
|
||
|
||
void _nms(int boxes_num, float * boxes_dev, | ||
unsigned long long * mask_dev, float nms_overlap_thresh) { | ||
|
||
dim3 blocks(DIVUP(boxes_num, threadsPerBlock), | ||
DIVUP(boxes_num, threadsPerBlock)); | ||
dim3 threads(threadsPerBlock); | ||
nms_kernel<<<blocks, threads>>>(boxes_num, | ||
nms_overlap_thresh, | ||
boxes_dev, | ||
mask_dev); | ||
} | ||
|
||
#ifdef __cplusplus | ||
} | ||
#endif |
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
#ifndef _NMS_KERNEL | ||
#define _NMS_KERNEL | ||
|
||
#ifdef __cplusplus | ||
extern "C" { | ||
#endif | ||
|
||
#define DIVUP(m,n) ((m) / (n) + ((m) % (n) > 0)) | ||
int const threadsPerBlock = sizeof(unsigned long long) * 8; | ||
|
||
void _nms(int boxes_num, float * boxes_dev, | ||
unsigned long long * mask_dev, float nms_overlap_thresh); | ||
|
||
#ifdef __cplusplus | ||
} | ||
#endif | ||
|
||
#endif | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
#include <TH/TH.h> | ||
#include <math.h> | ||
|
||
int cpu_nms(THLongTensor * keep_out, THLongTensor * num_out, THFloatTensor * boxes, THLongTensor * order, THFloatTensor * areas, float nms_overlap_thresh) { | ||
// boxes has to be sorted | ||
THArgCheck(THLongTensor_isContiguous(keep_out), 0, "keep_out must be contiguous"); | ||
THArgCheck(THLongTensor_isContiguous(boxes), 2, "boxes must be contiguous"); | ||
THArgCheck(THLongTensor_isContiguous(order), 3, "order must be contiguous"); | ||
THArgCheck(THLongTensor_isContiguous(areas), 4, "areas must be contiguous"); | ||
// Number of ROIs | ||
long boxes_num = THFloatTensor_size(boxes, 0); | ||
long boxes_dim = THFloatTensor_size(boxes, 1); | ||
|
||
long * keep_out_flat = THLongTensor_data(keep_out); | ||
float * boxes_flat = THFloatTensor_data(boxes); | ||
long * order_flat = THLongTensor_data(order); | ||
float * areas_flat = THFloatTensor_data(areas); | ||
|
||
THByteTensor* suppressed = THByteTensor_newWithSize1d(boxes_num); | ||
THByteTensor_fill(suppressed, 0); | ||
unsigned char * suppressed_flat = THByteTensor_data(suppressed); | ||
|
||
// nominal indices | ||
int i, j; | ||
// sorted indices | ||
int _i, _j; | ||
// temp variables for box i's (the box currently under consideration) | ||
float ix1, iy1, ix2, iy2, iarea; | ||
// variables for computing overlap with box j (lower scoring box) | ||
float xx1, yy1, xx2, yy2; | ||
float w, h; | ||
float inter, ovr; | ||
|
||
long num_to_keep = 0; | ||
for (_i=0; _i < boxes_num; ++_i) { | ||
i = order_flat[_i]; | ||
if (suppressed_flat[i] == 1) { | ||
continue; | ||
} | ||
keep_out_flat[num_to_keep++] = i; | ||
ix1 = boxes_flat[i * boxes_dim]; | ||
iy1 = boxes_flat[i * boxes_dim + 1]; | ||
ix2 = boxes_flat[i * boxes_dim + 2]; | ||
iy2 = boxes_flat[i * boxes_dim + 3]; | ||
iarea = areas_flat[i]; | ||
for (_j = _i + 1; _j < boxes_num; ++_j) { | ||
j = order_flat[_j]; | ||
if (suppressed_flat[j] == 1) { | ||
continue; | ||
} | ||
xx1 = fmaxf(ix1, boxes_flat[j * boxes_dim]); | ||
yy1 = fmaxf(iy1, boxes_flat[j * boxes_dim + 1]); | ||
xx2 = fminf(ix2, boxes_flat[j * boxes_dim + 2]); | ||
yy2 = fminf(iy2, boxes_flat[j * boxes_dim + 3]); | ||
w = fmaxf(0.0, xx2 - xx1 + 1); | ||
h = fmaxf(0.0, yy2 - yy1 + 1); | ||
inter = w * h; | ||
ovr = inter / (iarea + areas_flat[j] - inter); | ||
if (ovr >= nms_overlap_thresh) { | ||
suppressed_flat[j] = 1; | ||
} | ||
} | ||
} | ||
|
||
long *num_out_flat = THLongTensor_data(num_out); | ||
*num_out_flat = num_to_keep; | ||
THByteTensor_free(suppressed); | ||
return 1; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
int cpu_nms(THLongTensor * keep_out, THLongTensor * num_out, THFloatTensor * boxes, THLongTensor * order, THFloatTensor * areas, float nms_overlap_thresh); |
Oops, something went wrong.