In [1]:
import torch.nn as nn
import torch
import MST
import numpy as np

In [2]:
np.random.seed(42)

In [3]:
in_C = 16
out_C = 128
stride = 1
padding = 1
dilation = 1
kernel = 3

params = [in_C, out_C, kernel, stride, padding, dilation]

print(f"{(16 - kernel + padding*2 - (kernel-1)*(dilation-1))/stride + 1}")

16.0


In [4]:
with torch.no_grad():
    myConv = MST.Conv2d(*params)
    torchConv = nn.Conv2d(*params)

    torchConv.weight = nn.Parameter(torch.tensor(myConv._w))
    torchConv.bias = nn.Parameter(torch.tensor(myConv._bias).flatten())


In [5]:
image = np.ones((1, in_C, 16, 16))

In [12]:
with torch.no_grad():
    myres = myConv(image)
    nnres = torchConv(torch.Tensor(image))
    print(myres.shape)
    print(nnres.shape)

    nnres = nnres.numpy().round(3).flatten()
    myres = myres.round(3).flatten()
    bads = np.where(abs(nnres - myres) > 0.001)
    print(myres[:20])
    print(nnres[:20])
    print(len(bads[0]))
    for pos in bads[0][:10]:
        print(f"[{pos}]: {myres[pos]:.3f} {nnres[pos]:.3f}")

(1, 128, 16, 16)
torch.Size([1, 128, 16, 16])
[-0.412 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521
 -0.521 -0.521 -0.521 -0.521 -0.521 -0.583 -1.004 -1.153 -1.153 -1.153]
[-0.412 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521
 -0.521 -0.521 -0.521 -0.521 -0.521 -0.583 -1.004 -1.153 -1.153 -1.153]
0


In [14]:
%%timeit
myres = myConv(image)

2.11 ms ± 729 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
%%timeit
nnres = torchConv(torch.Tensor(image))

130 µs ± 3.07 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [1]:

import torch.nn as nn
import torch
import MST
import numpy as np
np.random.seed(42)
in_C = 16
out_C = 128
stride = 1
padding = 1
dilation = 1
kernel = 3

params = [in_C, out_C, kernel, stride, padding, dilation]

print(f"{(16 - kernel + padding*2 - (kernel-1)*(dilation-1))/stride + 1}")
myConv = MST.Conv2d(*params)
torchConv = nn.Conv2d(*params)

torchConv.weight = nn.Parameter(torch.tensor(myConv._w))
torchConv.bias = nn.Parameter(torch.tensor(myConv._bias).flatten())
image = np.ones((1, in_C, 16, 16))

16.0


In [4]:
torchImage = torch.ones((1, in_C, 16, 16), requires_grad=True)

myConv(image).backward(np.ones((1, 128, 16, 16)))
torchConv(torchImage).backward(torch.ones(1, 128, 16, 16))


nngrad = torchImage.grad.numpy().round(3).flatten()
mygrad = myConv._dinX.round(3).flatten()
bads = np.where(abs(mygrad - nngrad) > 0.001)
print(len(bads[0]))
for pos in bads[0][:10]:
    print(f"[{pos}]: {mygrad[pos]:.3f} {nngrad[pos]:.3f}")


nngrad = torchConv.weight.grad.numpy().round(3).flatten()
mygrad = myConv._dw.round(3).flatten()
bads = np.where(abs(mygrad - nngrad) > 0.001)
print(len(bads[0]))
for pos in bads[0][:10]:
    print(f"[{pos}]: {mygrad[pos]:.3f} {nngrad[pos]:.3f}")

nngrad = torchConv.bias.grad.numpy().round(3).flatten()
mygrad = myConv._dbias.round(3).flatten()
bads = np.where(abs(mygrad - nngrad) > 0.001)
print(len(bads[0]))
for pos in bads[0][:10]:
    print(f"[{pos}]: {mygrad[pos]:.3f} {nngrad[pos]:.3f}")

4095
[0]: 3.557 -0.260
[1]: 1.435 2.414
[2]: -2.466 2.414
[3]: -5.622 2.414
[4]: -4.877 2.414
[5]: 1.165 2.414
[6]: 6.463 2.414
[7]: 7.750 2.414
[8]: 4.135 2.414
[9]: 0.521 2.414
18432
[0]: 225.000 450.000
[1]: 240.000 480.000
[2]: 225.000 450.000
[3]: 240.000 480.000
[4]: 256.000 512.000
[5]: 240.000 480.000
[6]: 225.000 450.000
[7]: 240.000 480.000
[8]: 225.000 450.000
[9]: 225.000 450.000
128
[0]: 256.000 512.000
[1]: 256.000 512.000
[2]: 256.000 512.000
[3]: 256.000 512.000
[4]: 256.000 512.000
[5]: 256.000 512.000
[6]: 256.000 512.000
[7]: 256.000 512.000
[8]: 256.000 512.000
[9]: 256.000 512.000


In [5]:
with torch.no_grad():
    myres = myConv(image)
    nnres = torchConv(torch.Tensor(image))
    print(myres.shape)
    print(nnres.shape)

    nnres = nnres.numpy().round(3).flatten()
    myres = myres.round(3).flatten()
    bads = np.where(abs(nnres - myres) > 0.001)
    print(myres[:20])
    print(nnres[:20])
    print(len(bads[0]))
    for pos in bads[0][:10]:
        print(f"[{pos}]: {myres[pos]:.3f} {nnres[pos]:.3f}")

(1, 128, 16, 16)
torch.Size([1, 128, 16, 16])
[-0.412 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521
 -0.521 -0.521 -0.521 -0.521 -0.521 -0.583 -1.004 -1.153 -1.153 -1.153]
[-0.412 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521 -0.521
 -0.521 -0.521 -0.521 -0.521 -0.521 -0.583 -1.004 -1.153 -1.153 -1.153]
0
