In [20]:
# parameters
import torch
import torch.nn as nn

# Input tensor geometry.
BATCH_SIZE = 2
CHANNEL_IN = 6
CHANNEL_OUT = 3
HEIGHT_IN = 8
WIDTH_IN = 8

# Convolution hyper-parameters.
KERNEL_SIZE = 3
PADDING = 0
STRIDE = 1
GROUP = 3

# A grouped convolution splits both channel counts evenly across the groups.
# Check it here so a bad combination fails next to the values that caused it.
assert CHANNEL_IN % GROUP == 0 and CHANNEL_OUT % GROUP == 0, (
    "CHANNEL_IN and CHANNEL_OUT must both be divisible by GROUP"
)

# Spatial size after the convolution (dilation 1):
#   out = floor((in + 2 * padding - kernel) / stride) + 1
# Integer floor division keeps the arithmetic exact and the result an int.
HEIGHT_OUT = (HEIGHT_IN + 2 * PADDING - KERNEL_SIZE) // STRIDE + 1
WIDTH_OUT = (WIDTH_IN + 2 * PADDING - KERNEL_SIZE) // STRIDE + 1

In [21]:
# convolution
# Deterministic input: x[n, c, h, w] = h + c (constant along batch and width).
# Built by broadcasting a row ramp against a channel ramp instead of writing
# every element from Python.
x = torch.zeros(BATCH_SIZE, CHANNEL_IN, HEIGHT_IN, WIDTH_IN)
x += (torch.arange(HEIGHT_IN).view(1, 1, HEIGHT_IN, 1)
      + torch.arange(CHANNEL_IN).view(1, CHANNEL_IN, 1, 1))

# With groups=GROUP each filter only sees CHANNEL_IN // GROUP input channels.
KERNEL_CHANNEL = CHANNEL_IN // GROUP

# Deterministic filters: kernel[k, l, i, j] = j + k.
kernel = torch.zeros((CHANNEL_OUT, KERNEL_CHANNEL, KERNEL_SIZE, KERNEL_SIZE))
kernel += (torch.arange(CHANNEL_OUT).view(CHANNEL_OUT, 1, 1, 1)
           + torch.arange(KERNEL_SIZE).view(1, 1, 1, KERNEL_SIZE))

# Deterministic bias: bias[c] = c + 10.
bias = torch.zeros((CHANNEL_OUT,))
bias += torch.arange(CHANNEL_OUT) + 10

conv = nn.Conv2d(in_channels=CHANNEL_IN, out_channels=CHANNEL_OUT,
                 kernel_size=KERNEL_SIZE, bias=True, stride=STRIDE, padding=PADDING, groups=GROUP)

# Copy the hand-built values into the layer's own parameters. Unlike rebinding
# `.data`, copy_ checks that the shapes are compatible, so a mismatch between
# `kernel` and the layer configuration fails here instead of inside the forward pass.
with torch.no_grad():
    conv.weight.copy_(kernel)
    conv.bias.copy_(bias)

afterConv = conv(x)

print(afterConv)

tensor([[[[ 37.,  37.,  37.,  37.,  37.,  37.],
          [ 55.,  55.,  55.,  55.,  55.,  55.],
          [ 73.,  73.,  73.,  73.,  73.,  73.],
          [ 91.,  91.,  91.,  91.,  91.,  91.],
          [109., 109., 109., 109., 109., 109.],
          [127., 127., 127., 127., 127., 127.]],

         [[137., 137., 137., 137., 137., 137.],
          [173., 173., 173., 173., 173., 173.],
          [209., 209., 209., 209., 209., 209.],
          [245., 245., 245., 245., 245., 245.],
          [281., 281., 281., 281., 281., 281.],
          [317., 317., 317., 317., 317., 317.]],

         [[309., 309., 309., 309., 309., 309.],
          [363., 363., 363., 363., 363., 363.],
          [417., 417., 417., 417., 417., 417.],
          [471., 471., 471., 471., 471., 471.],
          [525., 525., 525., 525., 525., 525.],
          [579., 579., 579., 579., 579., 579.]]],


        [[[ 37.,  37.,  37.,  37.,  37.,  37.],
          [ 55.,  55.,  55.,  55.,  55.,  55.],
          [ 73.,  73.,  73.,  73

In [23]:
# batch normalization (evaluated with the fixed running statistics below)
# running_mean and running_var hold one value per channel, i.e. CHANNEL_OUT entries.
# NOTE(review): these constants appear to equal the per-channel batch mean and
# biased variance of the afterConv values printed by the convolution cell, so
# the numbers printed here are the same in training and in eval mode.
RUNNING_MEAN = [ 82, 227, 444]
RUNNING_VAR = [ 945, 3780, 8505]
weight = [0.5, 0.5, 0.5]
bias = [0.2, 0.2, 0.2]
# momentum = 0 keeps the running statistics fixed even if the module is ever
# run in training mode (new = (1 - momentum) * old + momentum * observed).
momentum = 0
eps = 1e-5

batch_norm = nn.BatchNorm2d(num_features=CHANNEL_OUT, momentum=momentum, eps=eps)

# Load the statistics and affine parameters into the module's existing tensors;
# copy_ rejects a list whose length does not match CHANNEL_OUT.
with torch.no_grad():
    batch_norm.running_mean.copy_(torch.tensor(RUNNING_MEAN, dtype=torch.float))
    batch_norm.running_var.copy_(torch.tensor(RUNNING_VAR, dtype=torch.float))
    batch_norm.weight.copy_(torch.tensor(weight, dtype=torch.float))
    batch_norm.bias.copy_(torch.tensor(bias, dtype=torch.float))

# A freshly constructed BatchNorm2d is in training mode and would normalize with
# the statistics of the current batch, ignoring the running values set above.
# eval() makes the forward pass use running_mean / running_var.
batch_norm.eval()

# NOTE: the layer-normalization cell also assigns afterNorm; the activation
# cells operate on whichever normalization cell ran last.
afterNorm = batch_norm(afterConv)

print("BN output:")
print(batch_norm.running_mean)
print(batch_norm.running_var)
print(afterNorm.shape)
print(afterNorm)

BN output:
tensor([ 82., 227., 444.])
tensor([ 945., 3780., 8505.])
torch.Size([2, 3, 6, 6])
tensor([[[[-0.5319, -0.5319, -0.5319, -0.5319, -0.5319, -0.5319],
          [-0.2392, -0.2392, -0.2392, -0.2392, -0.2392, -0.2392],
          [ 0.0536,  0.0536,  0.0536,  0.0536,  0.0536,  0.0536],
          [ 0.3464,  0.3464,  0.3464,  0.3464,  0.3464,  0.3464],
          [ 0.6392,  0.6392,  0.6392,  0.6392,  0.6392,  0.6392],
          [ 0.9319,  0.9319,  0.9319,  0.9319,  0.9319,  0.9319]],

         [[-0.5319, -0.5319, -0.5319, -0.5319, -0.5319, -0.5319],
          [-0.2392, -0.2392, -0.2392, -0.2392, -0.2392, -0.2392],
          [ 0.0536,  0.0536,  0.0536,  0.0536,  0.0536,  0.0536],
          [ 0.3464,  0.3464,  0.3464,  0.3464,  0.3464,  0.3464],
          [ 0.6392,  0.6392,  0.6392,  0.6392,  0.6392,  0.6392],
          [ 0.9319,  0.9319,  0.9319,  0.9319,  0.9319,  0.9319]],

         [[-0.5319, -0.5319, -0.5319, -0.5319, -0.5319, -0.5319],
          [-0.2392, -0.2392, -0.2392, -0.2392

In [18]:
# layer normalization
# The normalized dimensions are the trailing (H, W, C) axes of a permuted view
# of the convolution output. `eps` is the value defined in the
# batch-normalization cell, so that cell has to run first.
normalize_shape = (HEIGHT_OUT, WIDTH_OUT, CHANNEL_OUT)

# Deterministic affine parameters: weight[h, w, c] = h and bias[h, w, c] = w,
# built by broadcasting a ramp along the relevant axis.
weight = torch.zeros(normalize_shape) + torch.arange(HEIGHT_OUT).view(HEIGHT_OUT, 1, 1)
bias = torch.zeros(normalize_shape) + torch.arange(WIDTH_OUT).view(1, WIDTH_OUT, 1)

# (N, C, H, W) -> (N, H, W, C) so the normalized axes are the last three.
ln_in = afterConv.permute(0, 2, 3, 1)

ln_norm = nn.LayerNorm(normalize_shape, eps=eps, elementwise_affine=True)

# Copy into the module's own parameters; copy_ checks shape compatibility,
# which rebinding `.data` does not.
with torch.no_grad():
    ln_norm.weight.copy_(weight)
    ln_norm.bias.copy_(bias)
print("weight shape: ", ln_norm.weight.shape)
print("bias shape: ", ln_norm.bias.shape)

ln_out = ln_norm(ln_in)

# Back to (N, C, H, W).
# NOTE: this rebinds afterNorm, which the batch-normalization cell also assigns;
# the activation cells operate on whichever normalization cell ran last.
afterNorm = ln_out.permute(0, 3, 1, 2)
print("LN Out:")
print(afterNorm)

weight shape:  torch.Size([6, 6, 3])
bias shape:  torch.Size([6, 6, 3])
LN Out:
tensor([[[[ 0.0000,  1.0000,  2.0000,  3.0000,  4.0000,  5.0000],
          [-1.2031, -0.2031,  0.7969,  1.7969,  2.7969,  3.7969],
          [-2.1853, -1.1853, -0.1853,  0.8147,  1.8147,  2.8147],
          [-2.9465, -1.9465, -0.9465,  0.0535,  1.0535,  2.0535],
          [-3.4867, -2.4867, -1.4867, -0.4867,  0.5133,  1.5133],
          [-3.8059, -2.8059, -1.8059, -0.8059,  0.1941,  1.1941]],

         [[ 0.0000,  1.0000,  2.0000,  3.0000,  4.0000,  5.0000],
          [-0.4788,  0.5212,  1.5212,  2.5212,  3.5212,  4.5212],
          [-0.5156,  0.4844,  1.4844,  2.4844,  3.4844,  4.4844],
          [-0.1105,  0.8895,  1.8895,  2.8895,  3.8895,  4.8895],
          [ 0.7366,  1.7366,  2.7366,  3.7366,  4.7366,  5.7366],
          [ 2.0257,  3.0257,  4.0257,  5.0257,  6.0257,  7.0257]],

         [[ 0.0000,  1.0000,  2.0000,  3.0000,  4.0000,  5.0000],
          [ 0.6875,  1.6875,  2.6875,  3.6875,  4.6875,  5

In [4]:
# relu
# Applies ReLU to afterNorm. afterNorm is assigned by both the
# batch-normalization cell and the layer-normalization cell, so the result
# depends on which of them ran last. The saved output below matches the
# batch-normalization values with the negative entries replaced by 0.
relu = nn.ReLU()
out = relu(afterNorm)

print(out)

tensor([[[[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0536, 0.0536, 0.0536, 0.0536, 0.0536, 0.0536],
          [0.3464, 0.3464, 0.3464, 0.3464, 0.3464, 0.3464],
          [0.6392, 0.6392, 0.6392, 0.6392, 0.6392, 0.6392],
          [0.9319, 0.9319, 0.9319, 0.9319, 0.9319, 0.9319]],

         [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0536, 0.0536, 0.0536, 0.0536, 0.0536, 0.0536],
          [0.3464, 0.3464, 0.3464, 0.3464, 0.3464, 0.3464],
          [0.6392, 0.6392, 0.6392, 0.6392, 0.6392, 0.6392],
          [0.9319, 0.9319, 0.9319, 0.9319, 0.9319, 0.9319]],

         [[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
          [0.0536, 0.0536, 0.0536, 0.0536, 0.0536, 0.0536],
          [0.3464, 0.3464, 0.3464, 0.3464, 0.3464, 0.3464],
          [0.6392, 0.6392, 0.6392, 0

In [16]:
# silu
# Applies SiLU to afterNorm. afterNorm is assigned by both the
# batch-normalization cell and the layer-normalization cell, so the result
# depends on which of them ran last. The saved output below appears to be
# computed from the layer-normalization values.
silu = nn.SiLU()
out = silu(afterNorm)

print(out)

tensor([[[[ 0.0000,  0.7311,  1.7616,  2.8577,  3.9281,  4.9665],
          [-0.2778, -0.0913,  0.5493,  1.5413,  2.6361,  3.7135],
          [-0.2209, -0.2775, -0.0841,  0.5647,  1.5605,  2.6556],
          [-0.1470, -0.2432, -0.2646,  0.0275,  0.7811,  1.8201],
          [-0.1035, -0.1910, -0.2742, -0.1853,  0.3211,  1.2403],
          [-0.0828, -0.1600, -0.2549, -0.2488,  0.1065,  0.9165]],

         [[ 0.0000,  0.7311,  1.7616,  2.8577,  3.9281,  4.9665],
          [-0.1832,  0.3270,  1.2485,  2.3337,  3.4201,  4.4726],
          [-0.1928,  0.2997,  1.2101,  2.2932,  3.3807,  4.4343],
          [-0.0522,  0.6305,  1.6414,  2.7373,  3.8115,  4.8530],
          [ 0.4981,  1.4766,  2.5701,  3.6496,  4.6954,  5.7182],
          [ 1.7896,  2.8857,  3.9551,  4.9929,  6.0112,  7.0195]],

         [[ 0.0000,  0.7311,  1.7616,  2.8577,  3.9281,  4.9665],
          [ 0.4575,  1.4241,  2.5163,  3.5974,  4.6447,  5.6683],
          [ 1.8030,  2.8990,  3.9680,  5.0055,  6.0236,  7.0318],
      

In [6]:
# gelu
# Applies GELU to afterNorm via the functional API (the relu and silu cells
# use module instances instead). afterNorm is assigned by both the
# batch-normalization cell and the layer-normalization cell, so the result
# depends on which of them ran last. The saved output below appears to be
# computed from the batch-normalization values.
gelu = nn.functional.gelu
out = gelu(afterNorm)

print(out)

tensor([[[[-0.1582, -0.1582, -0.1582, -0.1582, -0.1582, -0.1582],
          [-0.0970, -0.0970, -0.0970, -0.0970, -0.0970, -0.0970],
          [ 0.0280,  0.0280,  0.0280,  0.0280,  0.0280,  0.0280],
          [ 0.2201,  0.2201,  0.2201,  0.2201,  0.2201,  0.2201],
          [ 0.4721,  0.4721,  0.4721,  0.4721,  0.4721,  0.4721],
          [ 0.7682,  0.7682,  0.7682,  0.7682,  0.7682,  0.7682]],

         [[-0.1582, -0.1582, -0.1582, -0.1582, -0.1582, -0.1582],
          [-0.0970, -0.0970, -0.0970, -0.0970, -0.0970, -0.0970],
          [ 0.0280,  0.0280,  0.0280,  0.0280,  0.0280,  0.0280],
          [ 0.2201,  0.2201,  0.2201,  0.2201,  0.2201,  0.2201],
          [ 0.4721,  0.4721,  0.4721,  0.4721,  0.4721,  0.4721],
          [ 0.7682,  0.7682,  0.7682,  0.7682,  0.7682,  0.7682]],

         [[-0.1582, -0.1582, -0.1582, -0.1582, -0.1582, -0.1582],
          [-0.0970, -0.0970, -0.0970, -0.0970, -0.0970, -0.0970],
          [ 0.0280,  0.0280,  0.0280,  0.0280,  0.0280,  0.0280],
      