In [1]:
import torch
import torch.nn as nn

In [2]:
x = torch.randn(3, 4096, 33, 33, dtype=torch.float)

In [3]:
torch.backends.cudnn.deterministic = True
torch.manual_seed(999)

conv1 = nn.Sequential(nn.Conv2d(4096, 32, kernel_size=3, stride=1, padding=1),
                      nn.BatchNorm2d(32),
                      nn.ReLU(),
                      nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
                      nn.BatchNorm2d(32),
                      nn.ReLU())


In [4]:
conv1(x)

tensor([[[[2.3408e-01, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.4055e-01,
           0.0000e+00, 3.1740e-01],
          [6.0529e-02, 0.0000e+00, 1.0834e+00,  ..., 2.5167e-01,
           0.0000e+00, 4.1058e-02],
          ...,
          [4.7487e-02, 0.0000e+00, 0.0000e+00,  ..., 1.7381e-01,
           1.5450e-01, 0.0000e+00],
          [2.6681e-01, 1.8886e-01, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 1.8858e-01, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 1.2271e-02]],

         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.3624e-01,
           0.0000e+00, 0.0000e+00],
          [1.9470e-02, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 9.0795e-02, 3.4075e-02,  ..., 4.6195e-02,
           0.0000e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 3.1931e-01, 2.3666e-01,  ..., 2.3635

In [5]:
torch.backends.cudnn.deterministic = True
torch.manual_seed(999)

conv2list = nn.ModuleList([nn.Conv2d(4096, 32, kernel_size=3, stride=1, padding=1),
                      nn.BatchNorm2d(32),
                      nn.ReLU(),
                      nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
                      nn.BatchNorm2d(32),
                      nn.ReLU()])

In [6]:
conv2list

ModuleList(
  (0): Conv2d(4096, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (5): ReLU()
)

In [7]:
conv2 = nn.Sequential(*conv2list)

In [10]:
conv2(x)

tensor([[[[2.3408e-01, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.4055e-01,
           0.0000e+00, 3.1740e-01],
          [6.0529e-02, 0.0000e+00, 1.0834e+00,  ..., 2.5167e-01,
           0.0000e+00, 4.1058e-02],
          ...,
          [4.7487e-02, 0.0000e+00, 0.0000e+00,  ..., 1.7381e-01,
           1.5450e-01, 0.0000e+00],
          [2.6681e-01, 1.8886e-01, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 1.8858e-01, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 1.2271e-02]],

         [[0.0000e+00, 0.0000e+00, 0.0000e+00,  ..., 1.3624e-01,
           0.0000e+00, 0.0000e+00],
          [1.9470e-02, 0.0000e+00, 0.0000e+00,  ..., 0.0000e+00,
           0.0000e+00, 0.0000e+00],
          [0.0000e+00, 9.0795e-02, 3.4075e-02,  ..., 4.6195e-02,
           0.0000e+00, 0.0000e+00],
          ...,
          [0.0000e+00, 3.1931e-01, 2.3666e-01,  ..., 2.3635

In [11]:
p = conv2(x)

In [35]:
p.shape

torch.Size([3, 32, 33, 33])

In [12]:
N, _, _ , _ = p.size()

In [14]:
p.view(N, -1).shape

torch.Size([3, 34848])

In [15]:
p2 = torch.flatten(p, 1)

In [16]:
p2.shape

torch.Size([3, 34848])

In [30]:
l = nn.Linear(32*33*33,1)

In [33]:
l2 = nn.Conv2d(32, 32, kernel_size=1)

In [32]:
l(p2).shape

torch.Size([3, 1])

In [36]:
l2(p).shape

torch.Size([3, 32, 33, 33])

In [37]:
l2(p)

tensor([[[[ 2.1639e-01, -1.7975e-02,  1.5701e-02,  ...,  1.2980e-01,
            2.9597e-01, -5.0215e-02],
          [ 1.7488e-01,  9.6976e-02, -2.0132e-01,  ..., -3.7243e-02,
            1.1176e-01, -4.7357e-02],
          [ 1.1665e-01,  7.1419e-04, -2.1494e-01,  ...,  1.7126e-02,
           -3.1513e-01,  2.0884e-01],
          ...,
          [ 7.5671e-02, -2.2924e-01, -5.4103e-02,  ...,  9.0692e-02,
            1.1703e-01, -9.1827e-02],
          [-1.8555e-01,  2.5177e-01, -6.8018e-02,  ..., -1.2534e-01,
           -1.3236e-01,  1.4303e-02],
          [-1.0576e-02, -1.6121e-01, -5.5170e-02,  ...,  1.6286e-01,
            5.7840e-02,  5.8851e-02]],

         [[ 5.3380e-03, -2.9567e-02, -2.1464e-01,  ..., -2.0748e-01,
           -2.9022e-02, -2.3779e-02],
          [-3.4135e-02, -2.3336e-01, -1.9029e-01,  ..., -2.4160e-02,
           -1.2347e-01, -1.9874e-01],
          [ 1.4578e-02, -2.7529e-02,  8.1546e-04,  ..., -1.3702e-01,
           -1.3030e-01, -3.4425e-02],
          ...,
     

In [29]:
p2.view(-1, 32 , 33, 33).shape

torch.Size([3, 32, 33, 33])

In [28]:
torch.eq(p2.view(3, 32 , 33, 33), p).all()

tensor(1, dtype=torch.uint8)