Skip to content

Commit

Permalink
Readme
Browse files Browse the repository at this point in the history
  • Loading branch information
HKervadec committed Jun 15, 2018
1 parent 853fdb6 commit 6f95384
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 6 deletions.
7 changes: 5 additions & 2 deletions losses.py
Expand Up @@ -9,6 +9,11 @@
class Partial_CE(torch.autograd.Function):
def forward(self, input, target, weakLabels):
self.save_for_backward(input, target, weakLabels)
# b, c, w, h = input.shape
# assert target.shape == input.shape
# assert weakLabels.shape == (b, 1, w, h)

# assert np.allclose(input[:, 0, ...].cpu().numpy(), 1 - input[:, 1, ...].cpu().numpy(), atol=1e-2)

eps = 1e-20

Expand All @@ -25,8 +30,6 @@ def forward(self, input, target, weakLabels):
lossT = torch.FloatTensor(1)
lossT.fill_(np.float32(loss).item())

lossT = lossT.cuda()

return lossT.cuda() # a single number (averaged loss over batch samples)

def backward(self, grad_output):
Expand Down
3 changes: 1 addition & 2 deletions main_MIDL.py 100644 → 100755
Expand Up @@ -4,13 +4,12 @@

import torch
import numpy as np
import medicalDataLoader
from torch import nn
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import transforms


import medicalDataLoader
from ENet import ENet
from utils import to_var
from utils import computeDiceOneHotBinary, predToSegmentation, inference, DicesToDice, printProgressBar
Expand Down
4 changes: 2 additions & 2 deletions medicalDataLoader.py
Expand Up @@ -28,7 +28,7 @@ def make_dataset(root, mode):
labels.sort()
labels_weak.sort()

for it_im, it_gt, it_w in zip(images, labels,labels_weak):
for it_im, it_gt, it_w in zip(images, labels, labels_weak):
item = (os.path.join(train_img_path, it_im), os.path.join(train_mask_path, it_gt), os.path.join(train_mask_weak_path, it_w))
items.append(item)

Expand All @@ -45,7 +45,7 @@ def make_dataset(root, mode):
labels.sort()
labels_weak.sort()

for it_im, it_gt, it_w in zip(images, labels,labels_weak):
for it_im, it_gt, it_w in zip(images, labels, labels_weak):
item = (os.path.join(train_img_path, it_im), os.path.join(train_mask_path, it_gt), os.path.join(train_mask_weak_path, it_w))
items.append(item)
else:
Expand Down
23 changes: 23 additions & 0 deletions readme.md
@@ -0,0 +1,23 @@
# Constrained-CNN losses for weakly supervised segmentation
Code of our submission https://openreview.net/forum?id=BkIBHb2sG at MIDL 2018

To run it, simply run `main_MIDL.py` (python3.6+, the requirements are specified in the `requirements.txt` file).

The partial ground truth that we used are provided, but not the original dataset: https://www.creatis.insa-lyon.fr/Challenge/acdc/databases.html
You will need to download and pre-process it yourselves first.

The code was developed for PyTorch 0.3.1, and has been modified slightly to work with PyTorch 0.4. However, a lot of cleanup (removing the variables for instance) still need to be done.

## Loss functions
The loss functions are located in `losses.py`, and are defined as autograd functions. We implemented manually both the forward and the backward passes with numpy. We use a batch size of 1, and the code might need to be modified before working for more.

The inputs (predictions, labels and weak labels) are all represented as 4-D tensors:
```python
b, c, w, h = input.shape
assert target.shape == input.shape
assert weakLabels.shape == (b, 1, w, h)
```
`b` is the batch size, `c` the number of classes (2), and `w, h` the image size. Since this is a binary problem, the two classes are complementary (minus the rounding errors), both for the predictions and labels:
```python
assert np.allclose(input[:, 0, ...].cpu().numpy(), 1 - input[:, 1, ...].cpu().numpy(), atol=1e-2)
```
3 changes: 3 additions & 0 deletions requirements.txt
@@ -0,0 +1,3 @@
numpy
pytorch>=0.4
torchvision

0 comments on commit 6f95384

Please sign in to comment.