Commit

add partialconv3d
liuguilin1225 committed Feb 14, 2019
1 parent 58048b4 commit 5318f86
Showing 2 changed files with 95 additions and 7 deletions.
16 changes: 9 additions & 7 deletions models/partialconv2d.py
@@ -40,21 +40,23 @@ def __init__(self, *args, **kwargs):
         self.update_mask = None
         self.mask_ratio = None
 
-    def forward(self, input, mask=None):
+    def forward(self, input, mask_in=None):
 
-        if mask is not None or self.last_size != (input.data.shape[2], input.data.shape[3]):
+        if mask_in is not None or self.last_size != (input.data.shape[2], input.data.shape[3]):
             self.last_size = (input.data.shape[2], input.data.shape[3])
 
             with torch.no_grad():
                 if self.weight_maskUpdater.type() != input.type():
                     self.weight_maskUpdater = self.weight_maskUpdater.to(input)
 
-                if mask is None:
+                if mask_in is None:
                     # if mask is not provided, create a mask
                     if self.multi_channel:
                         mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2], input.data.shape[3]).to(input)
                     else:
                         mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3]).to(input)
+                else:
+                    mask = mask_in
 
                 self.update_mask = F.conv2d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)
 
@@ -63,11 +65,11 @@ def forward(self, input, mask=None):
                 self.update_mask = torch.clamp(self.update_mask, 0, 1)
                 self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)
 
-        if self.update_mask.type() != input.type() or self.mask_ratio.type() != input.type():
-            self.update_mask.to(input)
-            self.mask_ratio.to(input)
+        # if self.update_mask.type() != input.type() or self.mask_ratio.type() != input.type():
+        #     self.update_mask.to(input)
+        #     self.mask_ratio.to(input)
 
-        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask is not None else input)
+        raw_out = super(PartialConv2d, self).forward(torch.mul(input, mask) if mask_in is not None else input)
 
         if self.bias is not None:
             bias_view = self.bias.view(1, self.out_channels, 1, 1)
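For reference, a minimal usage sketch (not part of the commit) of the renamed keyword: callers now pass the mask as mask_in rather than mask. The import path, constructor arguments, and tensor shapes below are illustrative assumptions; the return value assumes the 2-D class mirrors the return_mask default of the new 3-D class.

```python
# Hypothetical usage sketch -- assumes PartialConv2d accepts the usual nn.Conv2d
# constructor arguments and lives at models/partialconv2d.py.
import torch
from models.partialconv2d import PartialConv2d

conv = PartialConv2d(3, 16, kernel_size=3, padding=1)
image = torch.randn(1, 3, 64, 64)          # (N, C, H, W)
mask = torch.ones(1, 1, 64, 64)            # 1 = valid pixel, 0 = hole (single-channel mask)
mask[:, :, 24:40, 24:40] = 0               # masked-out region
out = conv(image, mask_in=mask)            # keyword renamed from `mask` in this commit
print(out.shape)                           # torch.Size([1, 16, 64, 64])
```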
86 changes: 86 additions & 0 deletions models/partialconv3d.py
@@ -0,0 +1,86 @@
+###############################################################################
+# BSD 3-Clause License
+#
+# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
+#
+# Author & Contact: Guilin Liu (guilinl@nvidia.com)
+###############################################################################
+
+import torch
+import torch.nn.functional as F
+from torch import nn, cuda
+from torch.autograd import Variable
+
+class PartialConv3d(nn.Conv3d):
+    def __init__(self, *args, **kwargs):
+
+        # whether the mask is multi-channel or not
+        if 'multi_channel' in kwargs:
+            self.multi_channel = kwargs['multi_channel']
+            kwargs.pop('multi_channel')
+        else:
+            self.multi_channel = False
+
+        if 'return_mask' in kwargs:
+            self.return_mask = kwargs['return_mask']
+            kwargs.pop('return_mask')
+        else:
+            self.return_mask = False
+
+        super(PartialConv3d, self).__init__(*args, **kwargs)
+
+        if self.multi_channel:
+            self.weight_maskUpdater = torch.ones(self.out_channels, self.in_channels, self.kernel_size[0], self.kernel_size[1], self.kernel_size[2])
+        else:
+            self.weight_maskUpdater = torch.ones(1, 1, self.kernel_size[0], self.kernel_size[1], self.kernel_size[2])
+
+        self.slide_winsize = self.weight_maskUpdater.shape[1] * self.weight_maskUpdater.shape[2] * self.weight_maskUpdater.shape[3] * self.weight_maskUpdater.shape[4]
+
+        self.last_size = (None, None, None)
+        self.update_mask = None
+        self.mask_ratio = None
+
+    def forward(self, input, mask_input=None):
+
+        if mask_input is not None or self.last_size != (input.data.shape[2], input.data.shape[3], input.data.shape[4]):
+            self.last_size = (input.data.shape[2], input.data.shape[3], input.data.shape[4])
+
+            with torch.no_grad():
+                if self.weight_maskUpdater.type() != input.type():
+                    self.weight_maskUpdater = self.weight_maskUpdater.to(input)
+
+                if mask_input is None:
+                    # if mask is not provided, create a mask
+                    if self.multi_channel:
+                        mask = torch.ones(input.data.shape[0], input.data.shape[1], input.data.shape[2], input.data.shape[3], input.data.shape[4]).to(input)
+                    else:
+                        mask = torch.ones(1, 1, input.data.shape[2], input.data.shape[3], input.data.shape[4]).to(input)
+                else:
+                    mask = mask_input
+
+                self.update_mask = F.conv3d(mask, self.weight_maskUpdater, bias=None, stride=self.stride, padding=self.padding, dilation=self.dilation, groups=1)
+
+                self.mask_ratio = self.slide_winsize/(self.update_mask + 1e-8)
+                # self.mask_ratio = torch.max(self.update_mask)/(self.update_mask + 1e-8)
+                self.update_mask = torch.clamp(self.update_mask, 0, 1)
+                self.mask_ratio = torch.mul(self.mask_ratio, self.update_mask)
+
+        # if self.update_mask.type() != input.type() or self.mask_ratio.type() != input.type():
+        #     self.update_mask.to(input)
+        #     self.mask_ratio.to(input)
+
+        raw_out = super(PartialConv3d, self).forward(torch.mul(input, mask_input) if mask_input is not None else input)
+
+        if self.bias is not None:
+            bias_view = self.bias.view(1, self.out_channels, 1, 1, 1)
+            output = torch.mul(raw_out - bias_view, self.mask_ratio) + bias_view
+            if mask_input is not None:
+                output = torch.mul(output, self.update_mask)
+        else:
+            output = torch.mul(raw_out, self.mask_ratio)
+
+        if self.return_mask:
+            return output, self.update_mask
+        else:
+            return output

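Below is a minimal usage sketch (not part of the commit) exercising the new class: the layer renormalizes each output location by slide_winsize divided by the number of valid voxels under the kernel window, and with return_mask=True it also returns the updated mask. The import path, tensor shapes, and the hole location are illustrative assumptions.

```python
# Hypothetical usage sketch -- PartialConv3d forwards its remaining
# arguments to nn.Conv3d; multi_channel and return_mask are popped first.
import torch
from models.partialconv3d import PartialConv3d

conv = PartialConv3d(3, 8, kernel_size=3, padding=1,
                     multi_channel=True, return_mask=True)
video = torch.randn(1, 3, 4, 64, 64)       # (N, C, D, H, W)
mask = torch.ones(1, 3, 4, 64, 64)         # 1 = valid voxel, 0 = hole
mask[:, :, :, 16:32, 16:32] = 0            # punch a hole in every frame

out, updated_mask = conv(video, mask_input=mask)
print(out.shape)                           # torch.Size([1, 8, 4, 64, 64])
print(updated_mask.shape)                  # torch.Size([1, 8, 4, 64, 64])
```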