In [1]:
import torch

In [2]:
# Create a 4D tensor with dimensions: [batch_size, channels, height, width]
# batch_size=2, channels=3, height=5, width=5
input_tensor = torch.ones(2, 3, 5, 5, requires_grad=True)
input_tensor = input_tensor * torch.arange(1, input_tensor.numel()+1).reshape(input_tensor.shape)
input_tensor.retain_grad()
print("Input tensor shape:", input_tensor.shape)
print("First batch, first channel slice:")
print(input_tensor)

Input tensor shape: torch.Size([2, 3, 5, 5])
First batch, first channel slice:
tensor([[[[  1.,   2.,   3.,   4.,   5.],
          [  6.,   7.,   8.,   9.,  10.],
          [ 11.,  12.,  13.,  14.,  15.],
          [ 16.,  17.,  18.,  19.,  20.],
          [ 21.,  22.,  23.,  24.,  25.]],

         [[ 26.,  27.,  28.,  29.,  30.],
          [ 31.,  32.,  33.,  34.,  35.],
          [ 36.,  37.,  38.,  39.,  40.],
          [ 41.,  42.,  43.,  44.,  45.],
          [ 46.,  47.,  48.,  49.,  50.]],

         [[ 51.,  52.,  53.,  54.,  55.],
          [ 56.,  57.,  58.,  59.,  60.],
          [ 61.,  62.,  63.,  64.,  65.],
          [ 66.,  67.,  68.,  69.,  70.],
          [ 71.,  72.,  73.,  74.,  75.]]],


        [[[ 76.,  77.,  78.,  79.,  80.],
          [ 81.,  82.,  83.,  84.,  85.],
          [ 86.,  87.,  88.,  89.,  90.],
          [ 91.,  92.,  93.,  94.,  95.],
          [ 96.,  97.,  98.,  99., 100.]],

         [[101., 102., 103., 104., 105.],
          [106., 107., 108., 

In [3]:
# Create a convolution kernel with shape [out_channels, in_channels, kernel_height, kernel_width]
# out_channels=1, in_channels=3, kernel_size=3x3
# Create a convolution kernel with shape [out_channels, in_channels, kernel_height, kernel_width]
# out_channels=1, in_channels=3, kernel_size=3x3
conv_kernel = torch.nn.Parameter(torch.tensor([[
    [[-1.0, 0.0, 1.0], [-2.0, 0.0, 2.0], [-1.0, 0.0, 1.0]],  # Channel 1 - horizontal Sobel
    [[1.0, 2.0, 1.0], [0.0, 0.0, 0.0], [-1.0, -2.0, -1.0]],  # Channel 2 - vertical Sobel
    [[0.0, 1.0, 0.0], [1.0, -4.0, 1.0], [0.0, 1.0, 0.0]]     # Channel 3 - Laplacian
]]).float())

# Apply convolution using the existing conv_kernel
# The kernel has 1 output channels, 3 input channels, and 3x3 spatial dimensions
# Create a convolution layer
conv = torch.nn.Conv2d(in_channels=3, out_channels=1, kernel_size=3, bias=False)

# Set the weights to our existing kernel
conv.weight = conv_kernel

print(conv)

Conv2d(3, 1, kernel_size=(3, 3), stride=(1, 1), bias=False)


In [4]:
conv.weight.shape

torch.Size([1, 3, 3, 3])

In [5]:
# Apply the convolution to the input tensor
output = conv(input_tensor)

print("Convolution output shape:", output.shape)
print("First batch, first channel of the output:")
print(output)

Convolution output shape: torch.Size([2, 1, 3, 3])
First batch, first channel of the output:
tensor([[[[-32., -32., -32.],
          [-32., -32., -32.],
          [-32., -32., -32.]]],


        [[[-32., -32., -32.],
          [-32., -32., -32.],
          [-32., -32., -32.]]]], grad_fn=<ConvolutionBackward0>)


In [6]:
grad_tensor = torch.ones_like(output)
output.backward(grad_tensor)

In [7]:
conv.weight.grad

tensor([[[[ 801.,  819.,  837.],
          [ 891.,  909.,  927.],
          [ 981.,  999., 1017.]],

         [[1251., 1269., 1287.],
          [1341., 1359., 1377.],
          [1431., 1449., 1467.]],

         [[1701., 1719., 1737.],
          [1791., 1809., 1827.],
          [1881., 1899., 1917.]]]])

In [8]:
input_tensor.grad

tensor([[[[-1., -1.,  0.,  1.,  1.],
          [-3., -3.,  0.,  3.,  3.],
          [-4., -4.,  0.,  4.,  4.],
          [-3., -3.,  0.,  3.,  3.],
          [-1., -1.,  0.,  1.,  1.]],

         [[ 1.,  3.,  4.,  3.,  1.],
          [ 1.,  3.,  4.,  3.,  1.],
          [ 0.,  0.,  0.,  0.,  0.],
          [-1., -3., -4., -3., -1.],
          [-1., -3., -4., -3., -1.]],

         [[ 0.,  1.,  1.,  1.,  0.],
          [ 1., -2., -1., -2.,  1.],
          [ 1., -1.,  0., -1.,  1.],
          [ 1., -2., -1., -2.,  1.],
          [ 0.,  1.,  1.,  1.,  0.]]],


        [[[-1., -1.,  0.,  1.,  1.],
          [-3., -3.,  0.,  3.,  3.],
          [-4., -4.,  0.,  4.,  4.],
          [-3., -3.,  0.,  3.,  3.],
          [-1., -1.,  0.,  1.,  1.]],

         [[ 1.,  3.,  4.,  3.,  1.],
          [ 1.,  3.,  4.,  3.,  1.],
          [ 0.,  0.,  0.,  0.,  0.],
          [-1., -3., -4., -3., -1.],
          [-1., -3., -4., -3., -1.]],

         [[ 0.,  1.,  1.,  1.,  0.],
          [ 1., -2., -1., 