Skip to content

Commit

Permalink
Initial attempt at PyTorch 1.4 compatibility
Browse files (browse the repository at this point in the history)
  • Loading branch information
suhangpro committed Mar 4, 2020
1 parent 12d52b6 commit 7ca0a3f
Showing 1 changed file with 14 additions and 42 deletions.
56 changes: 14 additions & 42 deletions pac.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
from torch.autograd.function import Function, once_differentiable
from torch.nn.parameter import Parameter
from torch.nn.modules.utils import _pair
from torch._thnn import type2backend

try:
import pyinn as P
Expand Down Expand Up @@ -106,7 +105,6 @@ def forward(ctx, input, kernel_size, stride, padding, dilation, channel_wise):
if not channel_wise:
diff_sq = diff_sq.sum(dim=1, keepdim=True)
output = torch.exp(-0.5 * diff_sq)
ctx._backend = type2backend[input.type()]
ctx.save_for_backward(input, output)

return output
Expand All @@ -126,15 +124,9 @@ def backward(ctx, grad_output):
grad_diff = grad.expand_as(cols) * (2 * diff)
grad_diff[:, :, center_y:center_y + 1, center_x:center_x + 1, :, :] -= \
grad_diff.sum(dim=2, keepdim=True).sum(dim=3, keepdim=True)
grad_input = grad_output.new()
ctx._backend.Im2Col_updateGradInput(ctx._backend.library_state,
grad_diff.view(bs, ch * ctx.kernel_size[0] * ctx.kernel_size[1], -1),
grad_input,
in_h, in_w,
ctx.kernel_size[0], ctx.kernel_size[1],
ctx.dilation[0], ctx.dilation[1],
ctx.padding[0], ctx.padding[1],
ctx.stride[0], ctx.stride[1])

grad_input = F.fold(grad_diff.view(bs, ch * ctx.kernel_size[0] * ctx.kernel_size[1], -1),
(in_h, in_w), ctx.kernel_size, ctx.dilation, ctx.padding, ctx.stride)

return grad_input, None, None, None, None, None

Expand All @@ -155,7 +147,6 @@ def forward(ctx, input, kernel, weight, bias=None, stride=1, padding=0, dilation
ctx.save_for_backward(input if (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]) else None,
kernel if (ctx.needs_input_grad[0] or ctx.needs_input_grad[2]) else None,
weight if (ctx.needs_input_grad[0] or ctx.needs_input_grad[1]) else None)
ctx._backend = type2backend[input.type()]

cols = F.unfold(input, ctx.kernel_size, ctx.dilation, ctx.padding, ctx.stride)

Expand Down Expand Up @@ -190,17 +181,11 @@ def backward(ctx, grad_output):
in_cols = F.unfold(input, ctx.kernel_size, ctx.dilation, ctx.padding, ctx.stride)
in_cols = in_cols.view(bs, in_ch, ctx.kernel_size[0], ctx.kernel_size[1], out_sz[0], out_sz[1])
if ctx.needs_input_grad[0]:
grad_input = grad_output.new()
grad_im2col_output = grad_in_mul_k * kernel
grad_im2col_output = grad_im2col_output.view(bs, -1, out_sz[0] * out_sz[1])
ctx._backend.Im2Col_updateGradInput(ctx._backend.library_state,
grad_im2col_output,
grad_input,
ctx.input_size[0], ctx.input_size[1],
ctx.kernel_size[0], ctx.kernel_size[1],
ctx.dilation[0], ctx.dilation[1],
ctx.padding[0], ctx.padding[1],
ctx.stride[0], ctx.stride[1])

grad_input = F.fold(grad_im2col_output,
ctx.input_size[:2], ctx.kernel_size, ctx.dilation, ctx.padding, ctx.stride)
if ctx.needs_input_grad[1]:
grad_kernel = in_cols * grad_in_mul_k
grad_kernel = grad_kernel.sum(dim=1, keepdim=True)
Expand Down Expand Up @@ -234,7 +219,6 @@ def forward(ctx, input, kernel, weight, bias=None, stride=1, padding=0, output_p
ctx.save_for_backward(input if (ctx.needs_input_grad[1] or ctx.needs_input_grad[2]) else None,
kernel if (ctx.needs_input_grad[0] or ctx.needs_input_grad[2]) else None,
weight if (ctx.needs_input_grad[0] or ctx.needs_input_grad[1]) else None)
ctx._backend = type2backend[input.type()]

w = input.new_ones((ch, 1, 1, 1))
x = F.conv_transpose2d(input, w, stride=stride, groups=ch)
Expand Down Expand Up @@ -279,18 +263,12 @@ def backward(ctx, grad_output):
in_cols = F.unfold(x, ctx.kernel_size, ctx.dilation, _pair(0), _pair(1))
in_cols = in_cols.view(bs, in_ch, ctx.kernel_size[0], ctx.kernel_size[1], out_sz[0], out_sz[1])
if ctx.needs_input_grad[0]:
grad_input = grad_output.new()
grad_im2col_output = grad_in_mul_k * kernel
grad_im2col_output = grad_im2col_output.view(bs, -1, out_sz[0] * out_sz[1])
im2col_input_sz = [o + (k - 1) * d for (o, k, d) in zip(out_sz, ctx.kernel_size, ctx.dilation)]
ctx._backend.Im2Col_updateGradInput(ctx._backend.library_state,
grad_im2col_output,
grad_input,
im2col_input_sz[0], im2col_input_sz[1],
ctx.kernel_size[0], ctx.kernel_size[1],
ctx.dilation[0], ctx.dilation[1],
0, 0,
1, 1)

grad_input = F.fold(grad_im2col_output,
im2col_input_sz[:2], ctx.kernel_size, ctx.dilation, 0, 1)
grad_input = grad_input[:, :, pad[0][0]:-pad[0][1]:ctx.stride[0], pad[1][0]:-pad[1][1]:ctx.stride[1]]
if ctx.needs_input_grad[1]:
grad_kernel = in_cols * grad_in_mul_k
Expand Down Expand Up @@ -321,7 +299,6 @@ def forward(ctx, input, kernel, kernel_size, stride=1, padding=0, dilation=1):
ctx.stride = _pair(stride)
ctx.save_for_backward(input if ctx.needs_input_grad[1] else None,
kernel if ctx.needs_input_grad[0] else None)
ctx._backend = type2backend[input.type()]

cols = F.unfold(input, ctx.kernel_size, ctx.dilation, ctx.padding, ctx.stride)

Expand All @@ -337,17 +314,11 @@ def backward(ctx, grad_output):
grad_input = grad_kernel = None
(bs, ch), out_sz = grad_output.shape[:2], grad_output.shape[2:]
if ctx.needs_input_grad[0]:
grad_input = grad_output.new()
grad_im2col_output = torch.einsum('ijmn,izklmn->ijklmn', (grad_output, kernel))
grad_im2col_output = grad_im2col_output.view(bs, -1, out_sz[0] * out_sz[1])
ctx._backend.Im2Col_updateGradInput(ctx._backend.library_state,
grad_im2col_output,
grad_input,
ctx.input_size[0], ctx.input_size[1],
ctx.kernel_size[0], ctx.kernel_size[1],
ctx.dilation[0], ctx.dilation[1],
ctx.padding[0], ctx.padding[1],
ctx.stride[0], ctx.stride[1])

grad_input = F.fold(grad_im2col_output,
ctx.input_size[:2], ctx.kernel_size, ctx.dilation, ctx.padding, ctx.stride)
if ctx.needs_input_grad[1]:
cols = F.unfold(input, ctx.kernel_size, ctx.dilation, ctx.padding, ctx.stride)
cols = cols.view(bs, ch, ctx.kernel_size[0], ctx.kernel_size[1], out_sz[0], out_sz[1])
Expand Down Expand Up @@ -444,7 +415,8 @@ def packernel2d(input, mask=None, kernel_size=0, stride=1, padding=0, output_pad

if norm is not None:
empty_mask = (norm == 0)
output = output / (norm + torch.tensor(empty_mask, dtype=input.dtype, device=input.device))
# output = output / (norm + torch.tensor(empty_mask, dtype=input.dtype, device=input.device))
output = output / (norm + empty_mask.clone().detach())
output_mask = (1 - empty_mask) if output_mask else None
else:
output_mask = None
Expand Down

0 comments on commit 7ca0a3f

Please sign in to comment.