-
Notifications
You must be signed in to change notification settings - Fork 212
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
58048b4
commit 5318f86
Showing
2 changed files
with
95 additions
and
7 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
############################################################################### | ||
# BSD 3-Clause License | ||
# | ||
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Author & Contact: Guilin Liu (guilinl@nvidia.com) | ||
############################################################################### | ||
|
||
import torch | ||
import torch.nn.functional as F | ||
from torch import nn, cuda | ||
from torch.autograd import Variable | ||
|
||
class PartialConv3d(nn.Conv3d):
    """3D partial convolution (Liu et al., "Image Inpainting for Irregular
    Holes Using Partial Convolutions").

    Behaves like ``nn.Conv3d`` but renormalizes each output location by the
    fraction of valid (unmasked) input voxels inside its receptive field,
    and can also propagate an updated validity mask.

    Extra keyword arguments (consumed before delegating to ``nn.Conv3d``):
        multi_channel (bool): if True, the mask carries one channel per
            input channel; otherwise one channel is shared. Default False.
        return_mask (bool): if True, ``forward`` returns a tuple
            ``(output, updated_mask)`` instead of just ``output``.
            Default False.
    """

    def __init__(self, *args, **kwargs):
        # Pop our extra options first: nn.Conv3d.__init__ rejects unknown
        # keyword arguments.
        self.multi_channel = kwargs.pop('multi_channel', False)
        self.return_mask = kwargs.pop('return_mask', False)

        super(PartialConv3d, self).__init__(*args, **kwargs)

        # All-ones kernel used to count valid voxels per output location.
        # Kept as a plain tensor (not a registered buffer, to preserve the
        # original state_dict); it is moved to the input's dtype/device
        # lazily inside forward().
        if self.multi_channel:
            self.weight_maskUpdater = torch.ones(
                self.out_channels, self.in_channels,
                self.kernel_size[0], self.kernel_size[1], self.kernel_size[2])
        else:
            self.weight_maskUpdater = torch.ones(
                1, 1,
                self.kernel_size[0], self.kernel_size[1], self.kernel_size[2])

        # Elements in one sliding window (in_channels * kD * kH * kW);
        # numerator of the renormalization ratio.
        shape = self.weight_maskUpdater.shape
        self.slide_winsize = shape[1] * shape[2] * shape[3] * shape[4]

        # Cache of the last spatial size and the mask tensors derived from
        # it, so the implicit all-ones mask is not recomputed every call.
        self.last_size = (None, None, None)
        self.update_mask = None
        self.mask_ratio = None

    def forward(self, input, mask_input=None):
        """Apply the partial convolution.

        Args:
            input: tensor of shape (N, C_in, D, H, W).
            mask_input: optional validity mask, 1 for valid voxels and 0
                for holes; shape (N, C_in, D, H, W) when ``multi_channel``,
                else (1, 1, D, H, W). When None, all voxels are treated as
                valid (reusing cached statistics if the size is unchanged).

        Returns:
            The renormalized convolution output, or
            ``(output, updated_mask)`` when ``return_mask`` is True.
        """
        spatial = (input.shape[2], input.shape[3], input.shape[4])
        # Recompute mask statistics when an explicit mask is supplied or
        # the spatial size changed; otherwise reuse the cached tensors.
        if mask_input is not None or self.last_size != spatial:
            self.last_size = spatial

            with torch.no_grad():
                # Lazily sync the counting kernel with the input's
                # dtype/device (replaces the deprecated .type() comparison).
                if (self.weight_maskUpdater.dtype != input.dtype
                        or self.weight_maskUpdater.device != input.device):
                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)

                if mask_input is None:
                    # No mask provided: treat every voxel as valid.
                    if self.multi_channel:
                        mask = torch.ones(
                            input.shape[0], input.shape[1],
                            spatial[0], spatial[1], spatial[2]).to(input)
                    else:
                        mask = torch.ones(
                            1, 1, spatial[0], spatial[1], spatial[2]).to(input)
                else:
                    mask = mask_input

                # Count valid voxels inside each sliding window.
                self.update_mask = F.conv3d(
                    mask, self.weight_maskUpdater, bias=None,
                    stride=self.stride, padding=self.padding,
                    dilation=self.dilation, groups=1)

                # Renormalization: full window size over valid-voxel count;
                # the epsilon guards against division by zero in holes.
                self.mask_ratio = self.slide_winsize / (self.update_mask + 1e-8)
                self.update_mask = torch.clamp(self.update_mask, 0, 1)
                # Zero the ratio wherever the window saw no valid voxel.
                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)

        # Convolve only the valid input (holes zeroed out by the mask).
        raw_out = super(PartialConv3d, self).forward(
            torch.mul(input, mask_input) if mask_input is not None else input)

        if self.bias is not None:
            # Rescale the bias-free response, then add the bias back so it
            # is not amplified by the mask ratio.
            bias_view = self.bias.view(1, self.out_channels, 1, 1, 1)
            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
            if mask_input is not None:
                output = torch.mul(output, self.update_mask)
        else:
            output = torch.mul(raw_out, self.mask_ratio)

        if self.return_mask:
            return output, self.update_mask
        else:
            return output