In [83]:
import torch
from torch import nn
import numpy as np

## Grouped Transposed Conv 2d

In [84]:
conv = nn.ConvTranspose2d(2, 4, (3,3), groups=2)
x = torch.from_numpy(np.arange(2*3*3, dtype=np.float32).reshape(1, 2, 3, 3)).to(torch.float32)
x

tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.],
          [ 6.,  7.,  8.]],

         [[ 9., 10., 11.],
          [12., 13., 14.],
          [15., 16., 17.]]]])

In [85]:
x.shape

torch.Size([1, 2, 3, 3])

In [92]:
for par in conv.parameters():
    par.requires_grad = False

In [93]:
with torch.no_grad():
    conv.weight.fill_(1.)

In [94]:
conv.weight.shape

torch.Size([2, 2, 3, 3])

In [95]:
conv.weight

Parameter containing:
tensor([[[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]],


        [[[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]],

         [[1., 1., 1.],
          [1., 1., 1.],
          [1., 1., 1.]]]])

In [96]:
#conv.
y = conv(x)
y.shape

torch.Size([1, 4, 5, 5])

In [116]:
y

tensor([[[[-4.8214e-02,  9.5179e-01,  2.9518e+00,  2.9518e+00,  1.9518e+00],
          [ 2.9518e+00,  7.9518e+00,  1.4952e+01,  1.1952e+01,  6.9518e+00],
          [ 8.9518e+00,  2.0952e+01,  3.5952e+01,  2.6952e+01,  1.4952e+01],
          [ 8.9518e+00,  1.9952e+01,  3.2952e+01,  2.3952e+01,  1.2952e+01],
          [ 5.9518e+00,  1.2952e+01,  2.0952e+01,  1.4952e+01,  7.9518e+00]],

         [[-2.3228e-01,  7.6772e-01,  2.7677e+00,  2.7677e+00,  1.7677e+00],
          [ 2.7677e+00,  7.7677e+00,  1.4768e+01,  1.1768e+01,  6.7677e+00],
          [ 8.7677e+00,  2.0768e+01,  3.5768e+01,  2.6768e+01,  1.4768e+01],
          [ 8.7677e+00,  1.9768e+01,  3.2768e+01,  2.3768e+01,  1.2768e+01],
          [ 5.7677e+00,  1.2768e+01,  2.0768e+01,  1.4768e+01,  7.7677e+00]],

         [[ 9.1800e+00,  1.9180e+01,  3.0180e+01,  2.1180e+01,  1.1180e+01],
          [ 2.1180e+01,  4.4180e+01,  6.9180e+01,  4.8180e+01,  2.5180e+01],
          [ 3.6180e+01,  7.5180e+01,  1.1718e+02,  8.1180e+01,  4.2180e+

## Transposed conv 3d

In [163]:
net = nn.ConvTranspose3d(1, 2, (2, 2, 2), bias=False)
x = torch.from_numpy(np.arange(2*2*2, dtype=np.float32).reshape(1, 1, 2, 2, 2)).to(torch.float32)
x

tensor([[[[[0., 1.],
           [2., 3.]],

          [[4., 5.],
           [6., 7.]]]]])

In [164]:
for par in net.parameters():
    par.requires_grad = False
with torch.no_grad():
    net.weight.fill_(1.)

In [165]:
net.weight.shape

torch.Size([1, 2, 2, 2, 2])

In [166]:
y = net(x)
y.shape

torch.Size([1, 2, 3, 3, 3])

In [167]:
y

tensor([[[[[ 0.,  1.,  1.],
           [ 2.,  6.,  4.],
           [ 2.,  5.,  3.]],

          [[ 4., 10.,  6.],
           [12., 28., 16.],
           [ 8., 18., 10.]],

          [[ 4.,  9.,  5.],
           [10., 22., 12.],
           [ 6., 13.,  7.]]],


         [[[ 0.,  1.,  1.],
           [ 2.,  6.,  4.],
           [ 2.,  5.,  3.]],

          [[ 4., 10.,  6.],
           [12., 28., 16.],
           [ 8., 18., 10.]],

          [[ 4.,  9.,  5.],
           [10., 22., 12.],
           [ 6., 13.,  7.]]]]])

In [168]:
y.flatten()

tensor([ 0.,  1.,  1.,  2.,  6.,  4.,  2.,  5.,  3.,  4., 10.,  6., 12., 28.,
        16.,  8., 18., 10.,  4.,  9.,  5., 10., 22., 12.,  6., 13.,  7.,  0.,
         1.,  1.,  2.,  6.,  4.,  2.,  5.,  3.,  4., 10.,  6., 12., 28., 16.,
         8., 18., 10.,  4.,  9.,  5., 10., 22., 12.,  6., 13.,  7.])

In [154]:
y.dtype

torch.float32

In [158]:
y.round()

tensor([[[[[-0.,  1.,  1.],
           [ 2.,  6.,  4.],
           [ 2.,  5.,  3.]],

          [[ 4., 10.,  6.],
           [12., 28., 16.],
           [ 8., 18., 10.]],

          [[ 4.,  9.,  5.],
           [10., 22., 12.],
           [ 6., 13.,  7.]]],


         [[[-0.,  1.,  1.],
           [ 2.,  6.,  4.],
           [ 2.,  5.,  3.]],

          [[ 4., 10.,  6.],
           [12., 28., 16.],
           [ 8., 18., 10.]],

          [[ 4.,  9.,  5.],
           [10., 22., 12.],
           [ 6., 13.,  7.]]]]])

In [161]:
a = x + 0.13
a

tensor([[[[[0.1300, 1.1300],
           [2.1300, 3.1300]],

          [[4.1300, 5.1300],
           [6.1300, 7.1300]]]]])

In [162]:
net(x)

tensor([[[[[-0.1201,  0.8799,  0.8799],
           [ 1.8799,  5.8799,  3.8799],
           [ 1.8799,  4.8799,  2.8799]],

          [[ 3.8799,  9.8799,  5.8799],
           [11.8799, 27.8799, 15.8799],
           [ 7.8799, 17.8799,  9.8799]],

          [[ 3.8799,  8.8799,  4.8799],
           [ 9.8799, 21.8799, 11.8799],
           [ 5.8799, 12.8799,  6.8799]]],


         [[[-0.0686,  0.9314,  0.9314],
           [ 1.9314,  5.9314,  3.9314],
           [ 1.9314,  4.9314,  2.9314]],

          [[ 3.9314,  9.9314,  5.9314],
           [11.9314, 27.9314, 15.9314],
           [ 7.9314, 17.9314,  9.9314]],

          [[ 3.9314,  8.9314,  4.9314],
           [ 9.9314, 21.9314, 11.9314],
           [ 5.9314, 12.9314,  6.9314]]]]])