In [1]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import math

from torch.nn import Conv2d, BatchNorm2d, MaxPool2d, Linear

In [2]:
from my import Conv2d as myConv2d
from my import BatchNorm2d as myBatchNorm2d
from my import MaxPool2d as myMaxPool2d
from my import Linear as myLinear

## Conv2d Test

In [3]:
t1 = torch.rand((10,3,64,64)) 
in_channels = 3
out_channels = 7
kernel_size = 9
stride = (1,2)
padding = (1,2)
bias = True

In [4]:
conv2d = Conv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
correct = conv2d(t1)
my_conv2d = myConv2d(in_channels, out_channels, kernel_size, stride, padding, bias)
my_conv2d.w = conv2d.weight
my_conv2d.b = conv2d.bias
my_ans = my_conv2d(t1)

In [5]:
torch.allclose(my_ans, correct)

False

In [6]:
(correct - my_ans).abs().max()

tensor(3.8743e-07, grad_fn=<MaxBackward1>)

## MaxPool2d Test

In [7]:
t2 = torch.rand((10,3,64,64))
kernel_size = 9

In [8]:
maxpol2d = MaxPool2d(kernel_size)
my_maxpol2d = myMaxPool2d(kernel_size)
correct = maxpol2d(t2)
my_ans = my_maxpol2d(t2)

In [9]:
torch.allclose(my_ans, correct)

True

In [10]:
(correct - my_ans).abs().max()

tensor(0.)

## Linear Test

In [11]:
t3 = torch.rand((10,3))
in_channels = 3
out_channels = 5
bias = True

In [13]:
linear = Linear(in_features=in_channels, out_features=out_channels, bias=bias)
my_linear = myLinear(in_channels, out_channels, bias)
my_linear.w = linear.weight
my_linear.b = linear.bias
correct = linear(t3)
my_ans = my_linear(t3)

In [14]:
torch.allclose(my_ans, correct)

True

In [15]:
(correct - my_ans).abs().max()

tensor(0., grad_fn=<MaxBackward1>)

## BatchNorm2d Test

In [16]:
t4 = torch.rand((10,3,64,64))
channels = 3
eps = 1e-5
momentum = 0.1

In [19]:
bn2d = BatchNorm2d(channels, eps, momentum)
my_bn2d = myBatchNorm2d(channels, eps, momentum)
my_bn2d.running_mean = bn2d.running_mean
my_bn2d.running_var = bn2d.running_var
my_bn2d.gamma = bn2d.weight
my_bn2d.beta = bn2d.bias
correct = bn2d(t4)
my_ans = my_bn2d(t4)

In [20]:
torch.allclose(my_ans, correct)

False

In [21]:
(correct - my_ans).abs().max()

tensor(2.1458e-05, grad_fn=<MaxBackward1>)