Skip to content

Commit

Permalink
Merge pull request #19 from achaiah/WIP
Browse files Browse the repository at this point in the history
Wip
  • Loading branch information
achaiah committed May 6, 2021
2 parents b7acf4d + 6c19b15 commit 0127458
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 23 deletions.
4 changes: 4 additions & 0 deletions README.md
Expand Up @@ -39,6 +39,10 @@ have [docs](https://pywick.readthedocs.io/en/latest/)! They're still a
work in progress though so apologies for anything that's broken.

## What's New (highlights)
- **May 6, 2021**
- Many SoTA classification and segmentation models added: Swin-Transformer variants, NFNet variants (L0, L1), Halo nets, Lambda nets, ECA variants, Rexnet + others
- Many new loss functions added: RecallLoss, SoftInvDiceLoss, OhemBCEDicePenalizeBorderLoss, RMIBCEDicePenalizeBorderLoss + others
- Bug fixes
- **Jun. 15, 2020**
- 200+ models added from [rwightman's](https://github.com/rwightman/pytorch-image-models) repo via `torch.hub`! See docs for all the variants!
- Some minor bug fixes
Expand Down
4 changes: 4 additions & 0 deletions docs/source/README.md
Expand Up @@ -39,6 +39,10 @@ have [docs](https://pywick.readthedocs.io/en/latest/)! They're still a
work in progress though so apologies for anything that's broken.

## What's New (highlights)
- **May 6, 2021**
- Many SoTA classification and segmentation models added: Swin-Transformer variants, NFNets variants (L0, L1), Halo nets, Lambda nets, ECA variants + others
- Many new loss functions added: RecallLoss, SoftInvDiceLoss, OhemBCEDicePenalizeBorderLoss, RMIBCEDicePenalizeBorderLoss + others
- Bug fixes
- **Jun. 15, 2020**
- 200+ models added from [rwightman's](https://github.com/rwightman/pytorch-image-models) repo via `torch.hub`! See docs for all the variants!
- Some minor bug fixes
Expand Down
82 changes: 63 additions & 19 deletions pywick/losses.py
Expand Up @@ -1826,9 +1826,44 @@ def get_tp_fp_fn(net_output, gt, axes=None, mask=None, square=False):

return tp, fp, fn

# ===================== #
# Boundary Loss
# Source: https://github.com/JunMa11/SegLoss/blob/71b14900e91ea9405d9705c95b451fc819f24c70/test/loss_functions/boundary_loss.py#L102

def compute_sdf(img_gt, out_shape):
"""
compute the signed distance map of binary mask
input: segmentation, shape = (batch_size, x, y, z)
output: the Signed Distance Map (SDM)
sdf(x) = 0; x in segmentation boundary
-inf|x-y|; x in segmentation
+inf|x-y|; x out of segmentation
"""

from scipy.ndimage import distance_transform_edt
from skimage import segmentation as skimage_seg

img_gt = img_gt.astype(np.uint8)

gt_sdf = np.zeros(out_shape)

for b in range(out_shape[0]): # batch size
for c in range(1, out_shape[1]): # channel
posmask = img_gt[b][c].astype(np.bool)
if posmask.any():
negmask = ~posmask
posdis = distance_transform_edt(posmask)
negdis = distance_transform_edt(negmask)
boundary = skimage_seg.find_boundaries(posmask, mode='inner').astype(np.uint8)
sdf = negdis - posdis
sdf[boundary==1] = 0
gt_sdf[b][c] = sdf

return gt_sdf


class BDLoss(nn.Module):
def __init__(self, **kwargs):
def __init__(self):
"""
compute boudary loss
only compute the loss of foreground
Expand All @@ -1837,26 +1872,35 @@ def __init__(self, **kwargs):
super(BDLoss, self).__init__()
# self.do_bg = do_bg

def forward(self, logits, target, bound):
def forward(self, net_output, gt):
"""
Takes 2D or 3D logits.
logits: (batch_size, class, x,y,(z))
target: ground truth, shape: (batch_size, 1, x,y,(z))
bound: precomputed distance map, shape (batch_size, class, x,y,(z))
Torch Eigensum description: https://stackoverflow.com/questions/55894693/understanding-pytorch-einsum
net_output: (batch_size, class, x,y,z)
target: ground truth, shape: (batch_size, 1, x,y,z)
bound: precomputed distance map, shape (batch_size, class, x,y,z)
"""
compute_directive = "bcxy,bcxy->bcxy"
if len(logits) == 5:
compute_directive = "bcxyz,bcxyz->bcxyz"

net_output = softmax_helper(logits)
# print('net_output shape: ', net_output.shape)
pc = net_output[:, 1:, ...].type(torch.float32)
dc = bound[:,1:, ...].type(torch.float32)

multipled = torch.einsum(compute_directive, pc, dc)
net_output = softmax_helper(net_output)
with torch.no_grad():
if len(net_output.shape) != len(gt.shape):
gt = gt.view((gt.shape[0], 1, *gt.shape[1:]))

if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
# if this is the case then gt is probably already a one hot encoding
y_onehot = gt
else:
gt = gt.long()
y_onehot = torch.zeros(net_output.shape)
if net_output.device.type == "cuda":
y_onehot = y_onehot.cuda(net_output.device.index)
y_onehot.scatter_(1, gt, 1)
gt_sdf = compute_sdf(y_onehot.cpu().numpy(), net_output.shape)

phi = torch.from_numpy(gt_sdf)
if phi.device != net_output.device:
phi = phi.to(net_output.device).type(torch.float32)
# pred = net_output[:, 1:, ...].type(torch.float32)
# phi = phi[:,1:, ...].type(torch.float32)

multipled = torch.einsum("bcxyz,bcxyz->bcxyz", net_output[:, 1:, ...], phi[:, 1:, ...])
bd_loss = multipled.mean()

return bd_loss
Expand Down
20 changes: 16 additions & 4 deletions pywick/models/model_utils.py
@@ -1,3 +1,5 @@
from typing import Callable

from . import classification
from .segmentation import *
from . import segmentation
Expand Down Expand Up @@ -54,7 +56,13 @@ def get_fc_names(model_name, model_type=ModelType.CLASSIFICATION):
return [None]


def get_model(model_type, model_name, num_classes, pretrained=True, force_reload=False, **kwargs):
def get_model(model_type: ModelType,
model_name: str,
num_classes: int,
pretrained: bool = True,
force_reload: bool = False,
custom_load_fn: Callable = None,
**kwargs):
"""
:param model_type: (ModelType):
type of model we're trying to obtain (classification or segmentation)
Expand All @@ -72,20 +80,24 @@ def get_model(model_type, model_name, num_classes, pretrained=True, force_reload
:param force_reload: (bool):
Whether to force reloading the list of models from torch.hub. By default, a cache file is used if it is found locally and that can prevent
new or updated models from being found.
:param custom_load_fn: (Callable):
A custom callable function to use for loading models (typically used to load cutting-edge or custom models that are not in the publicly available list)
:return: model
"""

if model_name not in get_supported_models(model_type) and not model_name.startswith('TEST'):
raise ValueError('The supplied model name: {} was not found in the list of acceptable model names.'
' Use get_supported_models() to obtain a list of supported models.'.format(model_name))
if model_name not in get_supported_models(model_type) and not model_name.startswith('TEST') and custom_load_fn is None:
raise ValueError(f'The supplied model name: {model_name} was not found in the list of acceptable model names.'
' Use get_supported_models() to obtain a list of supported models or supply a custom_load_fn')

print("INFO: Loading Model: -- " + model_name + " with number of classes: " + str(num_classes))

if model_type == ModelType.CLASSIFICATION:
torch_hub_names = torch.hub.list(rwightman_repo, force_reload=force_reload)
if model_name in torch_hub_names:
model = torch.hub.load(rwightman_repo, model_name, pretrained=pretrained, num_classes=num_classes)
elif custom_load_fn is not None:
model = custom_load_fn(model_name, pretrained, num_classes, **kwargs)
else:
# 1. Load model (pretrained or vanilla)
import ssl
Expand Down
2 changes: 2 additions & 0 deletions requirements.txt
Expand Up @@ -4,6 +4,8 @@ numpy
pandas
pillow
pyyaml
scipy
scikit-image
six
torch >= 1.4.0
torchvision
Expand Down

0 comments on commit 0127458

Please sign in to comment.