## Section 4

### Extensions with numpy and scipy: parameter-less example

In [1]:
import torch
import numpy as np
import scipy

from matplotlib import pyplot as plt
%matplotlib inline

Let's create a dummy function that does nothing but multiply its argument by $2.5$.

First, we need to create a new class for our function. It should subclass `torch.autograd.Function`.

In [18]:
class NoParameterNumpyFunction(torch.autograd.Function):
    def forward(self, my_input):
        raise NotImplementedError
    
    def backward(self, grad_output):
        raise NotImplementedError

The class must provide `forward` and `backward` methods.

Both methods take `torch.Tensor` objects as input and return `torch.Tensor` objects.

Let's implement the __forward function__:

In [19]:
class NoParameterNumpyFunction(torch.autograd.Function):
    def forward(self, my_input):
        numpy_input = my_input.detach().numpy()
        return my_input.new(numpy_input * 2.5)
        
    def backward(self, grad_output):
        raise NotImplementedError
        

To hand the data to numpy, the tensor must first be detached from the autograd graph and then converted to a `numpy.ndarray`.
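
As a small illustration of why the `detach()` call is needed, the sketch below converts a tensor that tracks gradients to numpy with and without detaching it. The variable names (`x`, `x_np`, `scaled`, `direct_conversion_failed`) are chosen here for illustration and do not come from the notebook.

```python
import numpy as np
import torch

x = torch.ones(3, requires_grad=True)

# Calling .numpy() directly on a tensor that tracks gradients is refused.
try:
    x.numpy()
    direct_conversion_failed = False
except RuntimeError:
    direct_conversion_failed = True

# detach() returns a tensor without gradient tracking, so .numpy() works.
x_np = x.detach().numpy()

# Going back to a tensor: the result carries no autograd history.
scaled = torch.from_numpy(x_np * 2.5)

print(direct_conversion_failed, type(x_np).__name__, scaled.requires_grad)
```

The tensor built from the numpy result has no link back to `x`, which is why a custom `backward` has to supply the gradient by hand.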

Now let's define the __backward function__.

In [20]:
class NoParameterNumpyFunction(torch.autograd.Function):
    def forward(self, my_input):
        numpy_input = my_input.detach().numpy()
        return my_input.new(numpy_input * 2.5)
    
    def backward(self, grad_output):
        numpy_go = grad_output.numpy()
        return grad_output.new(2.5 * numpy_go)

The backward pass receives plain gradient values, so there is no need to detach them.
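
The factor of $2.5$ in `backward` follows from the chain rule. As a short derivation, write $y = 2.5\,x$ for the forward pass and let $L$ be any scalar computed from $y$ (the symbols $x$, $y$ and $L$ are introduced here for illustration only):

$$
\frac{\partial L}{\partial x_{ij}}
= \sum_{k,l} \frac{\partial L}{\partial y_{kl}} \, \frac{\partial y_{kl}}{\partial x_{ij}}
= 2.5 \, \frac{\partial L}{\partial y_{ij}}
$$

In this notation `grad_output` holds $\partial L / \partial y$, and the value returned by `backward` is $\partial L / \partial x$.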

Finally, let's define a function that applies `NoParameterNumpyFunction` to the input data, and test it.

In [23]:
def my_function(my_input):
    return NoParameterNumpyFunction()(my_input)

In [24]:
my_input = torch.randn(5, 5, requires_grad=True)
result = my_function(my_input)
print('Result: \n', result)
result.backward(torch.randn(result.size()))
print('Input: \n', my_input)

Result: 
 tensor([[ 1.5014,  2.0003, -1.9374, -2.5834, -2.2760],
        [-4.2454, -1.9620,  2.7827,  2.4674, -1.7583],
        [-3.3536,  3.2906,  0.1479, -4.6234, -0.8852],
        [ 4.0107,  0.1398,  3.8121,  1.5927, -1.5370],
        [-1.0974, -1.5101,  2.7156,  1.2644, -2.6461]],
       grad_fn=<NoParameterNumpyFunction>)
Input: 
 tensor([[ 0.6006,  0.8001, -0.7750, -1.0334, -0.9104],
        [-1.6982, -0.7848,  1.1131,  0.9869, -0.7033],
        [-1.3415,  1.3162,  0.0591, -1.8494, -0.3541],
        [ 1.6043,  0.0559,  1.5249,  0.6371, -0.6148],
        [-0.4390, -0.6040,  1.0863,  0.5058, -1.0584]], requires_grad=True)
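
The class above uses instance methods and is called as `NoParameterNumpyFunction()(my_input)`. Recent PyTorch releases instead expect `forward` and `backward` to be static methods that take a context object as their first argument, with the function invoked through `.apply`. The following is a minimal sketch of the same scaling function in that style, not the notebook author's code: the class name `ScaleByNumpy` and the variable `upstream` are hypothetical, and `torch.as_tensor` is used here in place of `my_input.new(...)`.

```python
import torch


class ScaleByNumpy(torch.autograd.Function):
    """Hypothetical static-method variant of NoParameterNumpyFunction."""

    @staticmethod
    def forward(ctx, my_input):
        # Leave autograd, compute in numpy, and wrap the result in a tensor
        # with the same dtype and device as the input.
        numpy_input = my_input.detach().cpu().numpy()
        return torch.as_tensor(
            numpy_input * 2.5, dtype=my_input.dtype, device=my_input.device
        )

    @staticmethod
    def backward(ctx, grad_output):
        # d(2.5 * x)/dx = 2.5, so the incoming gradient is scaled by 2.5.
        numpy_go = grad_output.detach().cpu().numpy()
        return torch.as_tensor(
            2.5 * numpy_go, dtype=grad_output.dtype, device=grad_output.device
        )


x = torch.randn(5, 5, requires_grad=True)
y = ScaleByNumpy.apply(x)

# Use a known upstream gradient so the result is easy to verify.
upstream = torch.ones_like(y)
y.backward(upstream)

print(torch.allclose(y, 2.5 * x.detach()))
print(torch.allclose(x.grad, 2.5 * upstream))
```

Passing a known upstream gradient such as `torch.ones_like(y)` makes the check simple: every entry of `x.grad` should equal $2.5$.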


The same approach, applied to an 'incorrect FFT' example written by [Adam Paszke](https://github.com/apaszke), is shown below.

In [25]:
from numpy.fft import rfft2, irfft2


class BadFFTFunction(torch.autograd.Function):

    def forward(self, my_input):
        numpy_input = my_input.detach().numpy()
        result = abs(rfft2(numpy_input))
        return my_input.new(result)

    def backward(self, grad_output):
        numpy_go = grad_output.numpy()
        result = irfft2(numpy_go)
        return grad_output.new(result)

# since this layer does not have any parameters, we can
# simply declare this as a function, rather than as an nn.Module class


def incorrect_fft(my_input):
    return BadFFTFunction()(my_input)

# named fft_input to avoid shadowing the built-in input()
fft_input = torch.randn(8, 8, requires_grad=True)
result = incorrect_fft(fft_input)
print(result)
result.backward(torch.randn(result.size()))
print(fft_input)

tensor([[ 2.4602,  4.0459,  9.1712, 11.2801,  2.8586],
        [ 9.1331, 10.2702,  2.3852,  9.6039,  6.9081],
        [ 8.9842,  5.9188,  3.4215,  1.7207,  5.4507],
        [ 6.7364,  3.4633, 17.2621,  8.7835,  4.2332],
        [11.9287, 12.4571,  9.6719,  7.2652,  0.3746],
        [ 6.7364,  7.7286,  6.8314,  7.5754,  4.2332],
        [ 8.9842,  4.8460,  4.1552,  4.7081,  5.4507],
        [ 9.1331,  6.1811, 11.5793,  3.3977,  6.9081]],
       grad_fn=<BadFFTFunction>)
tensor([[ 1.3022,  0.1690,  0.9718, -0.2892,  0.1146,  0.8799,  0.4793, -0.0734],
        [-0.7867, -0.9935, -0.3931,  0.9037,  1.3205, -1.4621, -0.6775, -0.9220],
        [-0.4453, -1.0873, -0.2218,  0.1006, -0.3520, -0.2610,  0.0036, -1.5339],
        [ 1.2376, -0.1043, -1.9211,  1.0928,  0.4114,  0.2686, -1.0042, -2.1474],
        [ 0.8602,  1.1203,  0.9459,  0.5719, -0.4241, -1.8078,  0.3367,  1.0221],
        [ 1.7610,  0.5820, -1.1527, -0.4963,  0.3826,  2.6382,  0.0099, -1.9355],
        [ 1.0803,  0.7274, -0.1693