In [1]:
import torch
from torch import nn

In [90]:
class PeriodicPadding2D(nn.Module):
    def __init__(self, pad_width, **kwargs):
        super().__init__(**kwargs)
        self.pad_width = pad_width

    def forward(self, inputs, **kwargs):
        if self.pad_width == 0:
            return inputs
        inputs_padded = torch.cat(
            (
                inputs[:, :, :, -self.pad_width :],
                inputs,
                inputs[:, :, :, : self.pad_width],
            ),
            dim=-1,
        )
        # Zero padding in the lat direction
        inputs_padded = nn.functional.pad(
            inputs_padded, (0, 0, self.pad_width, self.pad_width)
        )
        return inputs_padded

In [91]:
inputs=torch.arange(2*2).reshape(1,1,2,2).float()

In [92]:
print(inputs)

tensor([[[[0., 1.],
          [2., 3.]]]])


In [94]:
test = PeriodicPadding2D(1)
print(test(inputs))

tensor([[[[0., 0., 0., 0.],
          [1., 0., 1., 0.],
          [3., 2., 3., 2.],
          [0., 0., 0., 0.]]]])


In [62]:
m = nn.Conv2d(1, 1, 2, padding=1, stride=1, dilation=2)

In [63]:
print(m)

Conv2d(1, 1, kernel_size=(2, 2), stride=(1, 1), padding=(1, 1), dilation=(2, 2))


In [64]:
m(inputs)

tensor([[[[ 0.5404, -0.3138],
          [-0.0263, -0.2362]]]], grad_fn=<ConvolutionBackward0>)

In [65]:
print(m.weight)

Parameter containing:
tensor([[[[-0.3872,  0.2099],
          [-0.0388,  0.2589]]]], requires_grad=True)


In [66]:
print(m.weight.shape)

torch.Size([1, 1, 2, 2])


In [67]:
print(m.bias)

Parameter containing:
tensor([-0.2362], requires_grad=True)


In [61]:
(inputs*m.weight).sum()+m.bias

tensor([-1.1350], grad_fn=<AddBackward0>)

In [78]:
mg = nn.Conv2d(2, 2, 2, padding=1, stride=1, dilation=1, groups=1)
inputs=torch.arange(2*2*2).reshape(1,2,2,2).float()
print(inputs)

tensor([[[[0., 1.],
          [2., 3.]],

         [[4., 5.],
          [6., 7.]]]])


In [79]:
print("shape: ", mg.weight.shape)
print("values: ", mg.weight)

shape:  torch.Size([2, 2, 2, 2])
values:  Parameter containing:
tensor([[[[-0.0515,  0.2466],
          [-0.1008, -0.1423]],

         [[ 0.3036,  0.3058],
          [ 0.0263, -0.2055]]],


        [[[-0.1418, -0.2214],
          [ 0.0306, -0.1865]],

         [[ 0.1498,  0.0009],
          [ 0.1374, -0.3461]]]], requires_grad=True)
