In [15]:
# parameters
import torch
import torch.nn as nn

# Input / output tensor geometry (NCHW layout).
BATCH_SIZE = 2
CHANNEL_IN = 6
CHANNEL_OUT = 3
HEIGHT_IN = 8
WIDTH_IN = 8

# Convolution hyper-parameters.
KERNEL_SIZE = 3
PADDING = 0
STRIDE = 1
GROUP = 3


def conv_output_size(size_in, kernel_size, padding, stride):
    """Return the output length of a convolution along one spatial dimension.

    Computes floor((size_in + 2 * padding - kernel_size) / stride) + 1 using
    integer arithmetic only (dilation is assumed to be 1), so the result is an
    exact ``int`` with no float round trip.

    Args:
        size_in: input length along the dimension.
        kernel_size: kernel length along the dimension.
        padding: zero padding added to each side.
        stride: step between kernel applications.

    Returns:
        The number of output positions along the dimension.
    """
    return (size_in + 2 * padding - kernel_size) // stride + 1


HEIGHT_OUT = conv_output_size(HEIGHT_IN, KERNEL_SIZE, PADDING, STRIDE)
WIDTH_OUT = conv_output_size(WIDTH_IN, KERNEL_SIZE, PADDING, STRIDE)

# Batch-norm statistics: one entry per output channel, i.e. each list must have
# CHANNEL_OUT elements.
# The active values equal the per-channel mean and biased variance of the
# convolution output this notebook prints for the inputs it builds.
# Alternative values kept for experimentation:
# RUNNING_MEAN = [0.2, 0.2, 0.2]
# RUNNING_VAR = [0.25, 0.25, 0.25]
RUNNING_MEAN = [ 82, 227, 444]
RUNNING_VAR = [ 945, 3780, 8505]
gamma = 0.5    # batch-norm scale, written into every element of BatchNorm2d.weight
beta = 0.2     # batch-norm shift, written into every element of BatchNorm2d.bias
momentum = 0   # passed to BatchNorm2d(momentum=...)

In [16]:
# convolution
# Deterministic input: x[n, c, h, w] = h + c.
# Built by broadcasting index vectors onto a zero tensor instead of writing each
# element from a Python loop; adding to the zeros keeps the default float dtype.
x = torch.zeros(BATCH_SIZE, CHANNEL_IN, HEIGHT_IN, WIDTH_IN)
x += torch.arange(HEIGHT_IN).view(1, 1, HEIGHT_IN, 1)
x += torch.arange(CHANNEL_IN).view(1, CHANNEL_IN, 1, 1)

# With grouped convolution each filter only sees CHANNEL_IN / GROUP input channels.
KERNEL_CHANNEL = CHANNEL_IN // GROUP

# Deterministic weights: kernel[k, l, i, j] = j + k.
kernel = torch.zeros((CHANNEL_OUT, KERNEL_CHANNEL, KERNEL_SIZE, KERNEL_SIZE))
kernel += torch.arange(KERNEL_SIZE).view(1, 1, 1, KERNEL_SIZE)
kernel += torch.arange(CHANNEL_OUT).view(CHANNEL_OUT, 1, 1, 1)

# Deterministic bias: bias[c] = c + 10.
bias = torch.zeros((CHANNEL_OUT,))
bias += torch.arange(CHANNEL_OUT) + 10

conv = torch.nn.Conv2d(in_channels=CHANNEL_IN, out_channels=CHANNEL_OUT,
                       kernel_size=KERNEL_SIZE, bias=True, stride=STRIDE, padding=PADDING, groups=GROUP)

# Overwrite the randomly initialised parameters in place. copy_ under no_grad
# keeps the Parameter objects (and their shapes) intact, unlike rebinding .data.
with torch.no_grad():
    conv.weight.copy_(kernel)
    conv.bias.copy_(bias)

afterConv = conv(x)

# Sanity check against the output geometry computed in the parameters cell.
assert afterConv.shape == (BATCH_SIZE, CHANNEL_OUT, HEIGHT_OUT, WIDTH_OUT)

print(afterConv)

tensor([[[[ 37.,  37.,  37.,  37.,  37.,  37.],
          [ 55.,  55.,  55.,  55.,  55.,  55.],
          [ 73.,  73.,  73.,  73.,  73.,  73.],
          [ 91.,  91.,  91.,  91.,  91.,  91.],
          [109., 109., 109., 109., 109., 109.],
          [127., 127., 127., 127., 127., 127.]],

         [[137., 137., 137., 137., 137., 137.],
          [173., 173., 173., 173., 173., 173.],
          [209., 209., 209., 209., 209., 209.],
          [245., 245., 245., 245., 245., 245.],
          [281., 281., 281., 281., 281., 281.],
          [317., 317., 317., 317., 317., 317.]],

         [[309., 309., 309., 309., 309., 309.],
          [363., 363., 363., 363., 363., 363.],
          [417., 417., 417., 417., 417., 417.],
          [471., 471., 471., 471., 471., 471.],
          [525., 525., 525., 525., 525., 525.],
          [579., 579., 579., 579., 579., 579.]]],


        [[[ 37.,  37.,  37.,  37.,  37.,  37.],
          [ 55.,  55.,  55.,  55.,  55.,  55.],
          [ 73.,  73.,  73.,  73

In [17]:
# batch normalization
batch_norm = nn.BatchNorm2d(num_features=CHANNEL_OUT, momentum=momentum)

# Load the fixed per-channel statistics and affine parameters.
batch_norm.running_mean = torch.tensor(RUNNING_MEAN, dtype=torch.float)
batch_norm.running_var = torch.tensor(RUNNING_VAR, dtype=torch.float)
with torch.no_grad():
    batch_norm.weight.fill_(gamma)
    batch_norm.bias.fill_(beta)

# A freshly constructed module is in training mode, where BatchNorm2d normalises
# with the statistics of the current batch and ignores running_mean/running_var.
# Switch to eval mode so the values loaded above are the ones actually applied:
#   y = gamma * (x - running_mean) / sqrt(running_var + eps) + beta
batch_norm.eval()

afterBN = batch_norm(afterConv)

print("BN output:")
print(batch_norm.running_mean)
print(batch_norm.running_var)
print(afterBN.shape)
print(afterBN)

BN output:
tensor([ 82., 227., 444.])
tensor([ 945., 3780., 8505.])
torch.Size([2, 3, 6, 6])
tensor([[[[-0.5319, -0.5319, -0.5319, -0.5319, -0.5319, -0.5319],
          [-0.2392, -0.2392, -0.2392, -0.2392, -0.2392, -0.2392],
          [ 0.0536,  0.0536,  0.0536,  0.0536,  0.0536,  0.0536],
          [ 0.3464,  0.3464,  0.3464,  0.3464,  0.3464,  0.3464],
          [ 0.6392,  0.6392,  0.6392,  0.6392,  0.6392,  0.6392],
          [ 0.9319,  0.9319,  0.9319,  0.9319,  0.9319,  0.9319]],

         [[-0.5319, -0.5319, -0.5319, -0.5319, -0.5319, -0.5319],
          [-0.2392, -0.2392, -0.2392, -0.2392, -0.2392, -0.2392],
          [ 0.0536,  0.0536,  0.0536,  0.0536,  0.0536,  0.0536],
          [ 0.3464,  0.3464,  0.3464,  0.3464,  0.3464,  0.3464],
          [ 0.6392,  0.6392,  0.6392,  0.6392,  0.6392,  0.6392],
          [ 0.9319,  0.9319,  0.9319,  0.9319,  0.9319,  0.9319]],

         [[-0.5319, -0.5319, -0.5319, -0.5319, -0.5319, -0.5319],
          [-0.2392, -0.2392, -0.2392, -0.2392

In [18]:
# relu
# Apply ReLU element-wise to the batch-norm output: negative entries become 0,
# non-negative entries pass through unchanged.
relu = nn.ReLU()
out = relu(afterBN)

# NOTE: `out` is reassigned by the SiLU and GELU cells further down, so it only
# holds the ReLU result until those cells run.
print(out)

tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0536, 0.0536, 0.0536, 0.0536, 0.0536, 0.0536],
          [0.3464, 0.3464, 0.3464, 0.3464, 0.3464, 0.3464],
          [0.6392, 0.6392, 0.6392, 0.6392, 0.6392, 0.6392],
          [0.9319, 0.9319, 0.9319, 0.9319, 0.9319, 0.9319]],

         [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0536, 0.0536, 0.0536, 0.0536, 0.0536, 0.0536],
          [0.3464, 0.3464, 0.3464, 0.3464, 0.3464, 0.3464],
          [0.6392, 0.6392, 0.6392, 0.6392, 0.6392, 0.6392],
          [0.9319, 0.9319, 0.9319, 0.9319, 0.9319, 0.9319]],

         [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0536, 0.0536, 0.0536, 0.0536, 0.0536, 0.0536],
          [0.3464, 0.3464, 0.3464, 0.3464, 0.3464, 0.3464],
          [0.6392, 0.6392, 0.6392, 0

In [19]:
# silu
# Apply SiLU (x * sigmoid(x)) element-wise to the batch-norm output.
# It reads `afterBN`, not the previous activation's `out`, so each activation
# cell is evaluated on the same input.
silu = nn.SiLU()
out = silu(afterBN)

# NOTE: this rebinds `out`, replacing the result of the ReLU cell.
print(out)

tensor([[[[-0.1968, -0.1968, -0.1968, -0.1968, -0.1968, -0.1968],
          [-0.1053, -0.1053, -0.1053, -0.1053, -0.1053, -0.1053],
          [ 0.0275,  0.0275,  0.0275,  0.0275,  0.0275,  0.0275],
          [ 0.2029,  0.2029,  0.2029,  0.2029,  0.2029,  0.2029],
          [ 0.4184,  0.4184,  0.4184,  0.4184,  0.4184,  0.4184],
          [ 0.6686,  0.6686,  0.6686,  0.6686,  0.6686,  0.6686]],

         [[-0.1968, -0.1968, -0.1968, -0.1968, -0.1968, -0.1968],
          [-0.1053, -0.1053, -0.1053, -0.1053, -0.1053, -0.1053],
          [ 0.0275,  0.0275,  0.0275,  0.0275,  0.0275,  0.0275],
          [ 0.2029,  0.2029,  0.2029,  0.2029,  0.2029,  0.2029],
          [ 0.4184,  0.4184,  0.4184,  0.4184,  0.4184,  0.4184],
          [ 0.6686,  0.6686,  0.6686,  0.6686,  0.6686,  0.6686]],

         [[-0.1968, -0.1968, -0.1968, -0.1968, -0.1968, -0.1968],
          [-0.1053, -0.1053, -0.1053, -0.1053, -0.1053, -0.1053],
          [ 0.0275,  0.0275,  0.0275,  0.0275,  0.0275,  0.0275],
      

In [20]:
# gelu
# Apply GELU element-wise to the batch-norm output, using the functional form
# with its default arguments (the ReLU/SiLU cells use module instances instead).
gelu = nn.functional.gelu
out = gelu(afterBN)

# NOTE: this rebinds `out`, replacing the result of the SiLU cell.
print(out)

tensor([[[[-0.1582, -0.1582, -0.1582, -0.1582, -0.1582, -0.1582],
          [-0.0970, -0.0970, -0.0970, -0.0970, -0.0970, -0.0970],
          [ 0.0280,  0.0280,  0.0280,  0.0280,  0.0280,  0.0280],
          [ 0.2201,  0.2201,  0.2201,  0.2201,  0.2201,  0.2201],
          [ 0.4721,  0.4721,  0.4721,  0.4721,  0.4721,  0.4721],
          [ 0.7682,  0.7682,  0.7682,  0.7682,  0.7682,  0.7682]],

         [[-0.1582, -0.1582, -0.1582, -0.1582, -0.1582, -0.1582],
          [-0.0970, -0.0970, -0.0970, -0.0970, -0.0970, -0.0970],
          [ 0.0280,  0.0280,  0.0280,  0.0280,  0.0280,  0.0280],
          [ 0.2201,  0.2201,  0.2201,  0.2201,  0.2201,  0.2201],
          [ 0.4721,  0.4721,  0.4721,  0.4721,  0.4721,  0.4721],
          [ 0.7682,  0.7682,  0.7682,  0.7682,  0.7682,  0.7682]],

         [[-0.1582, -0.1582, -0.1582, -0.1582, -0.1582, -0.1582],
          [-0.0970, -0.0970, -0.0970, -0.0970, -0.0970, -0.0970],
          [ 0.0280,  0.0280,  0.0280,  0.0280,  0.0280,  0.0280],
      