## Section 4

### Extensions with NumPy and SciPy: a parameterized example

In [2]:
import torch
import numpy as np
import scipy

from matplotlib import pyplot as plt
%matplotlib inline

In [3]:
from numpy import flip
from scipy.signal import convolve2d, correlate2d
from torch.nn.modules.module import Module
from torch.nn.parameter import Parameter

Now let's create a custom layer with learnable parameters (a filter and a bias), whose gradients are computed by our own `backward` implementation.

Special thanks to [Adam Paszke](https://github.com/apaszke) for the example used in this notebook.

In [4]:
class ScipyConv2dFunction(torch.autograd.Function):
    """Skeleton of a custom autograd op: both passes are placeholders.

    A custom ``torch.autograd.Function`` supplies a static ``forward`` (compute
    the output) and a static ``backward`` (map the upstream gradient to one
    gradient per forward input). Neither is implemented yet.
    """

    @staticmethod
    def forward(ctx, my_input, my_filter, bias):
        """Placeholder forward pass; always raises ``NotImplementedError``."""
        raise NotImplementedError

    @staticmethod
    def backward(ctx, grad_output):
        """Placeholder backward pass; always raises ``NotImplementedError``."""
        raise NotImplementedError

In [None]:
class ScipyConv2dFunction(torch.autograd.Function):
    """Intermediate version: SciPy-backed forward pass, backward not yet written.

    The forward computes a 2-D "valid" cross-correlation of ``my_input`` with
    ``my_filter`` and adds ``bias``, returning a tensor in ``my_input``'s dtype.
    """

    @staticmethod
    def forward(ctx, my_input, my_filter, bias):
        # SciPy works on NumPy arrays, so drop autograd tracking before converting.
        inp = my_input.detach()
        kernel = my_filter.detach()
        offset = bias.detach()

        out = correlate2d(inp.numpy(), kernel.numpy(), mode='valid')
        out += offset.numpy()  # in-place add: bias broadcasts over every output element

        # Keep the (detached) operands for the future backward pass.
        ctx.save_for_backward(inp, kernel, offset)
        return torch.as_tensor(out, dtype=inp.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Placeholder backward pass; always raises ``NotImplementedError``."""
        raise NotImplementedError

In [5]:
class ScipyConv2dFunction(torch.autograd.Function):
    """2-D "valid" cross-correlation plus bias, computed with SciPy.

    Operands are converted with ``.numpy()``, so they must be 2-D CPU tensors:
    ``my_input`` of shape (H, W), ``my_filter`` of shape (kH, kW), and a
    ``bias`` broadcastable to the (H - kH + 1, W - kW + 1) output.
    """

    @staticmethod
    def forward(ctx, my_input, my_filter, bias):
        # detach so we can cast to NumPy
        my_input, my_filter, bias = my_input.detach(), my_filter.detach(), bias.detach()
        result = correlate2d(my_input.numpy(), my_filter.numpy(), mode='valid')
        result += bias.numpy()
        ctx.save_for_backward(my_input, my_filter, bias)
        return torch.as_tensor(result, dtype=my_input.dtype)

    @staticmethod
    def backward(ctx, grad_output):
        """Return the gradients w.r.t. (my_input, my_filter, bias).

        Each gradient is returned in the dtype of the corresponding forward input.
        """
        my_input, my_filter, bias = ctx.saved_tensors
        grad_output = grad_output.detach().numpy()
        # The bias is added to every output element, so its gradient is the
        # sum of the upstream gradient (kept 2-D to match the (1, 1) bias).
        grad_bias = np.sum(grad_output, keepdims=True)
        # The gradient of a "valid" correlation w.r.t. its input is a "full"
        # convolution with the same filter; equivalently:
        # grad_my_input = correlate2d(grad_output, flip(flip(my_filter.numpy(), axis=0), axis=1), mode='full')
        grad_my_input = convolve2d(grad_output, my_filter.numpy(), mode='full')
        # The gradient w.r.t. the filter is the "valid" correlation of the input with grad_output.
        grad_my_filter = correlate2d(my_input.numpy(), grad_output, mode='valid')
        # Parenthesized so the tuple may span several lines. Cast to each input's
        # own dtype instead of hard-coding float32, so float64 filters/biases
        # (e.g. under gradcheck) get float64 gradients.
        return (
            torch.from_numpy(grad_my_input).to(my_input.dtype),
            torch.from_numpy(grad_my_filter).to(my_filter.dtype),
            torch.from_numpy(grad_bias).to(bias.dtype),
        )

In [7]:
class ScipyConv2d(Module):
    """Learnable layer that applies ``ScipyConv2dFunction`` to its input.

    Holds a randomly initialised ``(my_filter_width, my_filter_height)`` filter
    and a ``(1, 1)`` bias, both registered as trainable parameters.
    """

    def __init__(self, my_filter_width, my_filter_height):
        super().__init__()
        filter_shape = (my_filter_width, my_filter_height)
        # Registration order (filter first, then bias) fixes the parameters() order.
        self.my_filter = Parameter(torch.randn(*filter_shape))
        self.bias = Parameter(torch.randn(1, 1))

    def forward(self, my_input):
        """Correlate ``my_input`` with this layer's filter and add its bias."""
        op = ScipyConv2dFunction.apply
        return op(my_input, self.my_filter, self.bias)

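**Check the gradients:** `gradcheck` compares the analytic gradients returned by our `backward` against finite-difference estimates computed in double precision.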



In [10]:
from torch.autograd.gradcheck import gradcheck

# Seed so the randomly initialised filter/bias and the test input are reproducible.
torch.manual_seed(0)

moduleConv = ScipyConv2d(3, 3)

# gradcheck on the module only perturbs `my_input`: the filter and bias are
# module parameters (float32) and are not among the checked inputs.
my_input = [torch.randn(20, 20, dtype=torch.double, requires_grad=True)]
test = gradcheck(moduleConv, my_input, eps=1e-6, atol=1e-4)
print("Are the gradients correct: ", test)

# To also verify the filter and bias gradients produced by the custom backward,
# call the Function directly with double-precision copies of all three tensors.
all_inputs = (
    my_input[0],
    moduleConv.my_filter.detach().double().requires_grad_(),
    moduleConv.bias.detach().double().requires_grad_(),
)
test_all = gradcheck(ScipyConv2dFunction.apply, all_inputs, eps=1e-6, atol=1e-4)
print("Are the filter and bias gradients correct: ", test_all)

Are the gradients correct:  True
