In [4]:
import torch.nn as nn
import torch
import MST
import numpy as np

In [5]:
np.random.seed(42)

In [6]:
in_C = 16
out_C = 128
stride = 1
padding = 1
dilation = 1
kernel = 3

params = [in_C, out_C, kernel, stride, padding, dilation]

print(f"{(16 - kernel + padding*2 - (kernel-1)*(dilation-1))/stride + 1}")

16.0


In [7]:
with torch.no_grad():
    myConv = MST.Conv2d(*params)
    torchConv = nn.Conv2d(*params)

    torchConv.weight = nn.Parameter(torch.tensor(myConv._w))
    torchConv.bias = nn.Parameter(torch.tensor(myConv._bias).flatten())


In [8]:
image = np.ones((1, in_C, 16, 16))

In [9]:
with torch.no_grad():
    myres = myConv(image)
    nnres = torchConv(torch.Tensor(image))
    print(myres.shape)
    print(nnres.shape)

    nnres = nnres.numpy().round(3).flatten()
    myres = myres.round(3).flatten()
    bads = np.where(abs(nnres - myres) > 0.001)
    print(myres[:20])
    print(nnres[:20])
    print(len(bads[0]))
    for pos in bads[0][:10]:
        print(f"[{pos}]: {myres[pos]:.3f} {nnres[pos]:.3f}")

(1, 128, 16, 16)
torch.Size([1, 128, 16, 16])
[-0.412 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521
 -0.521 -0.521 -0.521 -0.521 -0.521 -0.583 -1.004 -1.153 -1.153 -1.153]
[-0.412 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521
 -0.521 -0.521 -0.521 -0.521 -0.521 -0.583 -1.004 -1.153 -1.153 -1.153]
0


In [10]:
%%timeit
myres = myConv(image)

1.41 ms ± 12.4 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [11]:
%%timeit
nnres = torchConv(torch.Tensor(image))

115 µs ± 835 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [12]:

import torch.nn as nn
import torch
import MST
import numpy as np
np.random.seed(42)
in_C = 2
out_C = 128
stride = 2
padding = 12
dilation =1
kernel = 6
inH, InW = 32, 32
outH, outW = inH//stride, InW//stride

params = [in_C, out_C, kernel, stride, padding, dilation]

print(f"{(16 - kernel + padding*2 - (kernel-1)*(dilation-1))/stride + 1}")
myConv = MST.Conv2d(*params)
torchConv = nn.Conv2d(*params)

torchConv.weight = nn.Parameter(torch.tensor(myConv._w))
torchConv.bias = nn.Parameter(torch.tensor(myConv._bias).flatten())
image = np.ones((1, in_C, inH, InW))

18.0


In [13]:
torchImage = torch.ones((1, in_C, inH, InW), requires_grad=True)

out_sample = myConv(image)

myConv(image).backward(np.ones(out_sample.shape))
torchConv(torchImage).backward(torch.ones(out_sample.shape))


nngrad = torchImage.grad.numpy().round(3).flatten()
mygrad = myConv._dinX.round(3).flatten()
bads = np.where(abs(mygrad - nngrad) > 0.001)
print(len(bads[0]))
for pos in bads[0][:10]:
    print(f"[{pos}]: {mygrad[pos]:.3f} {nngrad[pos]:.3f}")


nngrad = torchConv.weight.grad.numpy().round(3).flatten()
mygrad = myConv._dw.round(3).flatten()
bads = np.where(abs(mygrad - nngrad) > 0.001)
print(len(bads[0]))
for pos in bads[0][:10]:
    print(f"[{pos}]: {mygrad[pos]:.3f} {nngrad[pos]:.3f}")

torchConv.weight.grad = torch.zeros_like(torchConv.weight.grad)

nngrad = torchConv.bias.grad.numpy().round(3).flatten()
mygrad = myConv._dbias.round(3).flatten()
bads = np.where(abs(mygrad - nngrad) > 0.001)
print(len(bads[0]))
for pos in bads[0][:10]:
    print(f"[{pos}]: {mygrad[pos]:.3f} {nngrad[pos]:.3f}")

torchConv.bias.grad = torch.zeros_like(torchConv.bias.grad)

0
0
0


In [14]:
with torch.no_grad():
    myres = myConv(image)
    nnres = torchConv(torch.Tensor(image))
    print(myres.shape)
    print(nnres.shape)

    nnres = nnres.numpy().round(3).flatten()
    myres = myres.round(3).flatten()
    bads = np.where(abs(nnres - myres) > 0.001)
    print(myres[:20])
    print(nnres[:20])
    print(len(bads[0]))
    for pos in bads[0][:10]:
        print(f"[{pos}]: {myres[pos]:.3f} {nnres[pos]:.3f}")

(1, 128, 26, 26)
torch.Size([1, 128, 26, 26])
[0.151 0.151 0.151 0.151 0.151 0.151 0.151 0.151 0.151 0.151 0.151 0.151
 0.151 0.151 0.151 0.151 0.151 0.151 0.151 0.151]
[0.151 0.151 0.151 0.151 0.151 0.151 0.151 0.151 0.151 0.151 0.151 0.151
 0.151 0.151 0.151 0.151 0.151 0.151 0.151 0.151]
0
