
*This notebook walks through a toy example that explains the forward and backward propagation through rolling convolution filters.*






# Rolling Convolution Filters : **Back Propagation Using PyTorch**



In [1]:
# restarting runtime

import torch
import torch.nn as nn
import torch.nn.functional as F 

# sample input x: two 3x3 channels plus a leading batch axis -> shape (1, 2, 3, 3)
x = torch.tensor(
    [
        [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        [[9, -8, 7], [9, 8, 7], [3, -2, -1]],
    ],
    dtype=torch.float32,
).unsqueeze(0)
# enable gradients in place so x remains a leaf tensor and x.grad is filled by backward()
x.requires_grad_(True)
print("x=\n\n", x)
print("\n")
print("Shape of x : ", x.shape)
print('\n\n')
# sample weight w: a single filter with two 2x2 channels -> shape (1, 2, 2, 2)
w = nn.Parameter(
    torch.tensor(
        [
            [[1, -1], [-1, 1]],
            [[1, 2], [-1, -2]],
        ],
        dtype=torch.float32,
    ).unsqueeze(0)
)
print("w=\n\n", w)
print("\n")
print("Shape of w : ", w.shape)


x=

 tensor([[[[ 1.,  2.,  3.],
          [ 4.,  5.,  6.],
          [ 7.,  8.,  9.]],

         [[ 9., -8.,  7.],
          [ 9.,  8.,  7.],
          [ 3., -2., -1.]]]], requires_grad=True)


Shape of x :  torch.Size([1, 2, 3, 3])



w=

 Parameter containing:
tensor([[[[ 1., -1.],
          [-1.,  1.]],

         [[ 1.,  2.],
          [-1., -2.]]]], requires_grad=True)


Shape of w :  torch.Size([1, 2, 2, 2])


In [2]:
# define the rolling convolution

def roll(input, weight, stride = 1, padding=0):
    """Apply a rolling convolution followed by ReLU.

    The filter bank is built from `weight` and its cyclic shifts along the
    input-channel axis (dim 1): the unshifted weight first, then the weight
    rolled by 1, 2, ..., C-1 channels, all stacked along dim 0. With `weight`
    of shape (K, C, kH, kW) the bank therefore holds K*C filters, so the
    result has K*C output channels.

    Args:
        input: tensor passed to `F.conv2d` as the input.
        weight: base filter(s); dim 1 is the axis that gets rolled.
        stride: forwarded to `F.conv2d`.
        padding: forwarded to `F.conv2d`.

    Returns:
        `F.relu` of the bias-free convolution of `input` with the rolled
        filter bank.
    """
    n_rolls = weight.size(1)
    # torch.roll by k equals k successive rolls by 1, so the filter order matches
    # an incremental roll-and-concatenate loop while concatenating only once.
    bank = torch.cat(
        [weight] + [torch.roll(weight, shifts=k, dims=1) for k in range(1, n_rolls)],
        dim=0,
    )
    return F.relu(F.conv2d(input, bank, None, stride, padding))

# run the rolling convolution on the toy input

out1 = roll(x, w)
print("out1=\n\n", out1)
print("\n")
print("Shape of out1 : ", out1.shape)

# collapse the activation map to a scalar with the mean operator

print("\n")
out2 = torch.mean(out1)
print("out2=\n\n", out2)
print("\n")

# take y = 10 as the ground truth and use a squared-error loss

y = 10

# back-propagate the loss; autograd fills in w.grad and x.grad
(y - out2).pow(2).backward()

# gradient of the loss with respect to w
print("Grad wrt w :\n\n ", w.grad.data)

print("\n")

# gradient of the loss with respect to x
print("Grad wrt x :\n\n ", x.grad.data)


out1=

 tensor([[[[ 0.,  0.],
          [26., 26.]],

         [[ 7.,  0.],
          [ 0.,  0.]]]], grad_fn=<ReluBackward0>)


Shape of out1 :  torch.Size([1, 2, 2, 2])


out2=

 tensor(7.3750, grad_fn=<MeanBackward0>)


Grad wrt w :

  tensor([[[[-11.8125,  -1.9688],
          [-15.7500, -16.4062]],

         [[-11.8125, -11.1562],
          [ -3.2812,  -1.3125]]]])


Grad wrt x :

  tensor([[[[-0.6562, -1.3125,  0.0000],
          [ 0.0000,  1.3125,  0.6562],
          [ 0.6562,  0.0000, -0.6562]],

         [[-0.6562,  0.6562,  0.0000],
          [ 0.0000, -2.6250, -1.3125],
          [ 0.6562,  1.9688,  1.3125]]]])


# Rolling Convolution Filters : **User Defined Back Propagation**



In [13]:
import torch
import torch.nn as nn
import torch.nn.functional as F 

# sample input x: two 3x3 channels plus a leading batch axis -> shape (1, 2, 3, 3)
x = torch.tensor(
    [
        [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        [[9, -8, 7], [9, 8, 7], [3, -2, -1]],
    ],
    dtype=torch.float32,
).unsqueeze(0)
# enable gradients in place so x remains a leaf tensor
x.requires_grad_(True)
print("x=\n\n", x)
print("\n")
print("Shape of x : ", x.shape)
print('\n\n')
# sample weight w: a single filter with two 2x2 channels -> shape (1, 2, 2, 2)
w = nn.Parameter(
    torch.tensor(
        [
            [[1, -1], [-1, 1]],
            [[1, 2], [-1, -2]],
        ],
        dtype=torch.float32,
    ).unsqueeze(0)
)
print("w=\n\n", w)
print("\n")
print("Shape of w : ", w.shape)


x=

 tensor([[[[ 1.,  2.,  3.],
          [ 4.,  5.,  6.],
          [ 7.,  8.,  9.]],

         [[ 9., -8.,  7.],
          [ 9.,  8.,  7.],
          [ 3., -2., -1.]]]], requires_grad=True)


Shape of x :  torch.Size([1, 2, 3, 3])



w=

 Parameter containing:
tensor([[[[ 1., -1.],
          [-1.,  1.]],

         [[ 1.,  2.],
          [-1., -2.]]]], requires_grad=True)


Shape of w :  torch.Size([1, 2, 2, 2])


In [14]:
# define the rolling convolution

def roll(input, weight, stride = 1, padding=0):
    """Apply a rolling convolution followed by ReLU.

    The filter bank is built from `weight` and its cyclic shifts along the
    input-channel axis (dim 1): the unshifted weight first, then the weight
    rolled by 1, 2, ..., C-1 channels, all stacked along dim 0. With `weight`
    of shape (K, C, kH, kW) the bank therefore holds K*C filters, so the
    result has K*C output channels.

    Args:
        input: tensor passed to `F.conv2d` as the input.
        weight: base filter(s); dim 1 is the axis that gets rolled.
        stride: forwarded to `F.conv2d`.
        padding: forwarded to `F.conv2d`.

    Returns:
        `F.relu` of the bias-free convolution of `input` with the rolled
        filter bank.
    """
    n_rolls = weight.size(1)
    # torch.roll by k equals k successive rolls by 1, so the filter order matches
    # an incremental roll-and-concatenate loop while concatenating only once.
    rolled = [torch.roll(weight, shifts=k, dims=1) for k in range(1, n_rolls)]
    bank = torch.cat([weight, *rolled], dim=0)
    return F.relu(F.conv2d(input, bank, None, stride, padding))

# perform the rolling convolution operation

out1 = roll(x,w)
print("out1=\n\n",out1)
print("\n")
print("Shape of out1 : ",out1.shape)

# get a scalar output by using mean operator

print("\n")
out2 = out1.mean()
print("out2=\n\n",out2)
print("\n")

# let y = 10 be the groundtruth, assume mse loss

y =10

# ---------------------------------------------------------------------------
# Hand-written backward pass.
# It is specific to this toy setup: batch of 1, two input channels, one base
# filter, stride 1 and no padding in the forward convolution.
# Everything below works on detached tensors so no autograd graph is built.
# ---------------------------------------------------------------------------
x_d = x.detach()
w_d = w.detach()
x_ch1, x_ch2 = x_d[:1, 0:1], x_d[:1, 1:2]   # each (1, 1, H, W)
w_ch1, w_ch2 = w_d[:1, 0:1], w_d[:1, 1:2]   # each (1, 1, kH, kW)

# grad wrt out2: d/d(out2) of (y - out2)^2

grad_out2 = -1 * 2 * (y - out2.detach())

# grad wrt out1: the mean hands every element of out1 an equal share,
# so divide by the element count instead of a hard-coded 8

temp = (grad_out2 / out1.numel()).item()
grad_out1 = torch.full_like(out1.detach(), temp)

# grad after crossing relu: only positions with a positive activation let gradient through

grad_relu = ((out1.detach() > 0) * grad_out1).float()
grad_y1, grad_y2 = grad_relu[:1, 0:1], grad_relu[:1, 1:2]

# grad wrt W (channel 1)

# grad_w1 = x_channel_1 * grad_y1 + x_channel_2 * grad_y2 (note that x_channel_2 can be seen as the first channel of phi(x), where phi is the channel roll operator)

grad_w1 = F.conv2d(x_ch1, grad_y1) + F.conv2d(x_ch2, grad_y2)

# grad wrt W (channel 2)

# grad_w2 = x_channel_2 * grad_y1 + x_channel_1 * grad_y2 (note that x_channel_1 can be seen as the second channel of phi(x), where phi is the channel roll operator)

grad_w2 = F.conv2d(x_ch2, grad_y1) + F.conv2d(x_ch1, grad_y2)

# stack along the channel axis (dim=1) so the result has the same shape as w
grad_w = torch.cat((grad_w1, grad_w2), dim=1)
print("Grad wrt w :\n\n ", grad_w)

print("\n")

# The gradient wrt x is a "full" convolution of the upstream gradient with the
# spatially flipped kernel; padding by (kernel size - 1) yields exactly the
# spatial size of x, so no cropping is needed afterwards.
full_pad = (w_d.shape[-2] - 1, w_d.shape[-1] - 1)
w_ch1_flip = torch.flip(w_ch1, [2, 3])
w_ch2_flip = torch.flip(w_ch2, [2, 3])

# grad wrt x (channel 1)

# grad_x1 = flip(w_channel_1) * grad_y1 + flip(w_channel_2) * grad_y2 (note that w_channel_2 can be seen as the first channel of phi(w), phi is the channel roll operator)

grad_x1 = (F.conv2d(grad_y1, w_ch1_flip, padding=full_pad)
           + F.conv2d(grad_y2, w_ch2_flip, padding=full_pad))

# grad wrt x (channel 2)

# grad_x2 = flip(w_channel_2) * grad_y1 + flip(w_channel_1) * grad_y2 (note that w_channel_1 can be seen as the second channel of phi(w), phi is the channel roll operator)

grad_x2 = (F.conv2d(grad_y1, w_ch2_flip, padding=full_pad)
           + F.conv2d(grad_y2, w_ch1_flip, padding=full_pad))

# stack along the channel axis (dim=1) so the result has the same shape as x
grad_x = torch.cat((grad_x1, grad_x2), dim=1)
print("Grad wrt x :\n\n ", grad_x)

print("\n")

# sanity check: the hand-derived gradients must agree with autograd
auto_grad_w, auto_grad_x = torch.autograd.grad((y - out2) ** 2, (w, x))
assert torch.allclose(grad_w, auto_grad_w, atol=1e-6), "manual grad wrt w differs from autograd"
assert torch.allclose(grad_x, auto_grad_x, atol=1e-6), "manual grad wrt x differs from autograd"

out1=

 tensor([[[[ 0.,  0.],
          [26., 26.]],

         [[ 7.,  0.],
          [ 0.,  0.]]]], grad_fn=<ReluBackward0>)


Shape of out1 :  torch.Size([1, 2, 2, 2])


out2=

 tensor(7.3750, grad_fn=<MeanBackward0>)


Grad wrt w :

  tensor([[[[-11.8125,  -1.9688],
          [-15.7500, -16.4062]]],


        [[[-11.8125, -11.1562],
          [ -3.2812,  -1.3125]]]])


Grad wrt x :

  tensor([[[[-0.6562, -1.3125,  0.0000],
          [ 0.0000,  1.3125,  0.6562],
          [ 0.6562,  0.0000, -0.6562]]],


        [[[-0.6562,  0.6562,  0.0000],
          [ 0.0000, -2.6250, -1.3125],
          [ 0.6562,  1.9688,  1.3125]]]])


