In [1]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
import torchinfo

# Implementing convolution from scratch
Reference: https://stackoverflow.com/questions/43086557/convolve2d-just-by-using-numpy

## Convolution (N-batch=1, N-inchannels=1, N-outchannels=1)

In [108]:
# Random single-channel 6x9 image and a 3x3 kernel, rounded to 2 decimals so the printout stays readable.
img= np.random.uniform(0, 1, size=(6, 9)).round(2)
w= np.random.uniform(-1, 1, size=(3, 3)).round(2)
print(f'img= {img}')
print(f'w= {w}')

img= [[0.08 0.93 0.67 0.68 0.33 0.15 0.16 0.76 0.28]
 [0.48 0.57 0.15 0.68 0.49 0.69 0.39 0.99 0.54]
 [0.8  0.42 0.61 0.07 0.7  0.22 0.45 0.26 0.88]
 [0.11 0.35 0.65 0.08 0.65 0.52 0.23 0.86 0.9 ]
 [0.06 0.68 0.63 0.34 0.46 0.96 0.3  0.82 0.78]
 [0.22 0.88 0.9  0.02 0.26 0.14 0.57 0.45 0.23]]
w= [[-0.92  0.33 -0.08]
 [-0.4   0.85 -0.37]
 [-0.91  0.31  0.42]]


### Ground truth: torch.conv2d

In [109]:
# Reference result. torch's conv2d wants (N, C, H, W) tensors, so prepend batch and channel axes of size 1.
img_torch= torch.from_numpy(img).float()[None, None]
w_torch= torch.from_numpy(w).float()[None, None]
torch.nn.functional.conv2d(img_torch, w_torch)

tensor([[[[ 0.0751, -1.2047, -0.3211, -0.3938, -0.4005, -0.4083,  0.6076],
          [-0.1728, -0.2881, -0.6900,  0.3136, -0.8475, -0.4277,  0.0754],
          [-0.2124, -0.0885, -1.3013,  0.7135, -0.5055, -0.8416,  0.2162],
          [ 0.7338, -0.4896, -1.4604,  0.1203,  0.1224, -0.6654,  0.0060]]]])

### With numpy stride_tricks.as_strided and einsum

In [110]:
# 'valid' convolution: the output is (H-kh+1, W-kw+1) and each output pixel sees one kh x kw window,
# so the window view has shape (out_h, out_w, kh, kw).
view_shape= tuple(np.subtract(img.shape, w.shape)+1)+w.shape
print(f'view_shape= {view_shape}')
# Stepping to the next window and stepping inside a window both move by one image row/column,
# so the view simply repeats the image's own (byte) strides.
strides= 2*img.strides
print(f'strides= {strides}')
sub_matrices= np.lib.stride_tricks.as_strided(img, shape=view_shape, strides=strides)
# Two equivalent reductions of window*kernel over the kernel axes.
conv1= np.sum(sub_matrices*w, axis=(2,3))
print(f'conv1= {conv1}')
conv2= np.einsum('ij,klij->kl', w, sub_matrices)
print(f'conv2= {conv2}')

view_shape= (4, 7, 3, 3)
strides= (72, 8, 72, 8)
conv1= [[ 0.0751 -1.2047 -0.3211 -0.3938 -0.4005 -0.4083  0.6076]
 [-0.1728 -0.2881 -0.69    0.3136 -0.8475 -0.4277  0.0754]
 [-0.2124 -0.0885 -1.3013  0.7135 -0.5055 -0.8416  0.2162]
 [ 0.7338 -0.4896 -1.4604  0.1203  0.1224 -0.6654  0.006 ]]
conv2= [[ 0.0751 -1.2047 -0.3211 -0.3938 -0.4005 -0.4083  0.6076]
 [-0.1728 -0.2881 -0.69    0.3136 -0.8475 -0.4277  0.0754]
 [-0.2124 -0.0885 -1.3013  0.7135 -0.5055 -0.8416  0.2162]
 [ 0.7338 -0.4896 -1.4604  0.1203  0.1224 -0.6654  0.006 ]]


### With Torch

In [111]:
# Same computation in torch, with (1, 1) batch/channel axes in front.
img_torch= torch.from_numpy(img).float()[None, None]
w_torch= torch.from_numpy(w).float()[None, None]
kernel_hw= w_torch.shape[2:]
view_shape= img_torch.shape[:2]+tuple(np.subtract(img_torch.shape[2:],kernel_hw)+1)+kernel_hw
print(f'view_shape= {view_shape}')
# torch strides count elements (numpy's count bytes); the spatial pair is repeated for the window axes.
strides= (*img_torch.stride(), *img_torch.stride()[2:])
print(f'strides= {strides}')
sub_matrices= torch.as_strided(img_torch, view_shape, strides)
conv1= torch.sum(sub_matrices*w_torch, dim=(4,5))
print(f'conv1= {conv1}')
conv2= torch.einsum('klij,klmnij->klmn', w_torch, sub_matrices)
print(f'conv2= {conv2}')

view_shape= torch.Size([1, 1, 4, 7, 3, 3])
strides= (54, 54, 9, 1, 9, 1)
conv1= tensor([[[[ 0.0751, -1.2047, -0.3211, -0.3938, -0.4005, -0.4083,  0.6076],
          [-0.1728, -0.2881, -0.6900,  0.3136, -0.8475, -0.4277,  0.0754],
          [-0.2124, -0.0885, -1.3013,  0.7135, -0.5055, -0.8416,  0.2162],
          [ 0.7338, -0.4896, -1.4604,  0.1203,  0.1224, -0.6654,  0.0060]]]])
conv2= tensor([[[[ 0.0751, -1.2047, -0.3211, -0.3938, -0.4005, -0.4083,  0.6076],
          [-0.1728, -0.2881, -0.6900,  0.3136, -0.8475, -0.4277,  0.0754],
          [-0.2124, -0.0885, -1.3013,  0.7135, -0.5055, -0.8416,  0.2162],
          [ 0.7338, -0.4896, -1.4604,  0.1203,  0.1224, -0.6654,  0.0060]]]])


## Convolution (N-batch=1, N-inchannels=2, N-outchannels=3)
Note: The following code also works with batch sizes n_bch>=2.

In [153]:
# Batch size, input channels, output channels and (square) kernel size.
n_bch, in_ch, out_ch, ks= 1, 2, 3, 3
img= np.random.uniform(0, 1, size=(n_bch, in_ch, 6, 9)).round(2)
# Weight in torch.nn.Conv2d layout: (out_ch, in_ch, kh, kw).
w= np.random.uniform(-1, 1, size=(out_ch, in_ch, ks, ks)).round(2)
print(f'img= {img}')
print(f'w= {w}')

img= [[[[0.58 0.38 0.12 0.17 0.97 0.18 0.95 0.95 0.27]
   [0.09 0.86 0.63 0.99 0.84 0.24 0.21 0.26 0.21]
   [0.89 0.43 0.9  0.75 0.87 0.16 0.41 0.08 0.11]
   [0.16 0.81 0.15 0.56 0.35 0.11 0.41 0.9  0.09]
   [0.73 0.43 0.98 0.79 0.52 0.54 0.15 0.25 0.19]
   [0.99 0.19 0.04 0.77 0.36 0.47 0.14 0.52 0.15]]

  [[0.89 0.98 0.89 0.14 0.78 0.14 0.95 0.95 0.73]
   [0.12 0.13 0.87 0.72 0.27 0.96 0.91 0.84 0.14]
   [0.37 0.56 0.06 0.5  0.27 0.07 0.57 0.86 0.44]
   [0.46 0.13 0.26 0.31 0.9  0.86 0.93 0.24 0.58]
   [0.84 0.62 0.58 0.91 0.86 0.59 0.16 0.31 0.32]
   [0.39 0.44 0.67 0.37 0.96 0.83 0.26 0.91 0.58]]]]
w= [[[[-1.   -0.12 -0.24]
   [ 0.5   0.92  0.81]
   [ 0.47  0.18 -0.04]]

  [[-0.11  0.6   0.16]
   [-0.19  0.49 -0.37]
   [-0.09  0.44  0.28]]]


 [[[ 0.5  -0.2   0.45]
   [-0.49 -0.34  0.54]
   [-0.6   0.16 -0.82]]

  [[ 0.42 -0.53 -0.91]
   [-0.79  0.3   0.42]
   [-0.93 -0.9  -0.54]]]


 [[[-0.69 -0.89  0.13]
   [-0.94 -0.65  0.37]
   [-0.46  0.03 -0.45]]

  [[-0.3  -0.93  0.54]
   [ 

### Ground truth: torch.conv2d

In [154]:
# Reference result from torch; w is already in Conv2d's (out_ch, in_ch, kh, kw) weight layout.
img_torch= torch.as_tensor(img, dtype=torch.float32)
w_torch= torch.as_tensor(w, dtype=torch.float32)
conv_gt= torch.nn.functional.conv2d(img_torch, w_torch)
print(f'conv_gt= {conv_gt}')

conv_gt= tensor([[[[ 1.7332,  2.4102,  2.5449,  1.8440,  0.4113,  1.2922,  0.8077],
          [ 1.9147,  1.4130,  1.9298,  1.3351,  1.1848,  1.2055,  1.1881],
          [ 0.7150,  1.1221,  0.9520,  0.8854,  0.1931,  1.9831,  1.2294],
          [ 2.1038,  1.2396,  2.1700,  2.3855,  1.6488,  1.1085,  0.7497]],

         [[-2.4579, -1.1591, -2.2823, -2.4235, -1.1120, -2.3014, -2.7134],
          [-1.6237, -2.0818, -1.0713, -3.0737, -3.6110, -2.8510, -1.7890],
          [-2.7452, -2.2581, -2.3141, -2.4440, -1.6706, -2.3933, -2.0621],
          [-2.0524, -1.7402, -2.7161, -3.8014, -3.4150, -2.6239, -1.6516]],

         [[-2.0643, -2.1252, -0.7881, -2.5149, -1.4129, -0.2671, -1.1814],
          [-0.6381, -2.0971, -2.4688, -1.4834, -0.8247, -0.7548, -0.6436],
          [-1.5008, -0.9920, -1.2218, -1.5451,  0.3733,  0.1928, -1.1368],
          [-0.6915, -0.7645, -0.4683, -0.9015, -0.1823, -0.8607, -0.6381]]]])


### With numpy stride_tricks.as_strided and einsum

In [155]:
# Window view of the (N, C, H, W) image: (N, C, out_h, out_w, kh, kw).
kernel_hw= w.shape[-2:]
view_shape= img.shape[:2]+tuple(np.subtract(img.shape[2:], kernel_hw)+1)+kernel_hw
print(f'view_shape= {view_shape}')
# Window axes reuse the spatial (byte) strides of the image.
strides= (*img.strides, *img.strides[2:])
print(f'strides= {strides}')
sub_matrices= np.lib.stride_tricks.as_strided(img, shape=view_shape, strides=strides)
print(f'sub_matrices.shape= {sub_matrices.shape}')
print(f'w.shape= {w.shape}')
# i: out channel, j: in channel (summed), b: batch, (m, n): output pixel, (k, l): kernel offset (summed).
conv_np= np.einsum('ijkl,bjmnkl->bimn', w, sub_matrices)
print(f'conv_np= {conv_np}, {torch.all(torch.isclose(conv_gt,torch.from_numpy(conv_np).float()))}')

view_shape= (1, 2, 4, 7, 3, 3)
strides= (864, 432, 72, 8, 72, 8)
sub_matrices.shape= (1, 2, 4, 7, 3, 3)
w.shape= (3, 2, 3, 3)
conv_np= [[[[ 1.7332  2.4102  2.5449  1.844   0.4113  1.2922  0.8077]
   [ 1.9147  1.413   1.9298  1.3351  1.1848  1.2055  1.1881]
   [ 0.715   1.1221  0.952   0.8854  0.1931  1.9831  1.2294]
   [ 2.1038  1.2396  2.17    2.3855  1.6488  1.1085  0.7497]]

  [[-2.4579 -1.1591 -2.2823 -2.4235 -1.112  -2.3014 -2.7134]
   [-1.6237 -2.0818 -1.0713 -3.0737 -3.611  -2.851  -1.789 ]
   [-2.7452 -2.2581 -2.3141 -2.444  -1.6706 -2.3933 -2.0621]
   [-2.0524 -1.7402 -2.7161 -3.8014 -3.415  -2.6239 -1.6516]]

  [[-2.0643 -2.1252 -0.7881 -2.5149 -1.4129 -0.2671 -1.1814]
   [-0.6381 -2.0971 -2.4688 -1.4834 -0.8247 -0.7548 -0.6436]
   [-1.5008 -0.992  -1.2218 -1.5451  0.3733  0.1928 -1.1368]
   [-0.6915 -0.7645 -0.4683 -0.9015 -0.1823 -0.8607 -0.6381]]]], True


### With Torch

In [156]:
img_torch= torch.as_tensor(img, dtype=torch.float32)
w_torch= torch.as_tensor(w, dtype=torch.float32)
kernel_hw= w_torch.shape[-2:]
view_shape= img_torch.shape[:2]+tuple(np.subtract(img_torch.shape[2:],kernel_hw)+1)+kernel_hw
print(f'view_shape= {view_shape}')
# torch strides count elements, not bytes; the spatial pair is repeated for the window axes.
strides= (*img_torch.stride(), *img_torch.stride()[2:])
print(f'strides= {strides}')
sub_matrices= torch.as_strided(img_torch, view_shape, strides)
print(f'sub_matrices.shape= {sub_matrices.shape}')
conv_tch= torch.einsum('ijkl,bjmnkl->bimn', w_torch, sub_matrices)
print(f'conv_tch= {conv_tch}, {torch.all(torch.isclose(conv_gt,conv_tch))}')

view_shape= torch.Size([1, 2, 4, 7, 3, 3])
strides= (108, 54, 9, 1, 9, 1)
sub_matrices.shape= torch.Size([1, 2, 4, 7, 3, 3])
conv_tch= tensor([[[[ 1.7332,  2.4102,  2.5449,  1.8440,  0.4113,  1.2922,  0.8077],
          [ 1.9147,  1.4130,  1.9298,  1.3351,  1.1848,  1.2055,  1.1881],
          [ 0.7150,  1.1221,  0.9520,  0.8854,  0.1931,  1.9831,  1.2294],
          [ 2.1038,  1.2396,  2.1700,  2.3855,  1.6488,  1.1085,  0.7497]],

         [[-2.4579, -1.1591, -2.2823, -2.4235, -1.1120, -2.3014, -2.7134],
          [-1.6237, -2.0818, -1.0713, -3.0737, -3.6110, -2.8510, -1.7890],
          [-2.7452, -2.2581, -2.3141, -2.4440, -1.6706, -2.3933, -2.0621],
          [-2.0524, -1.7402, -2.7161, -3.8014, -3.4150, -2.6239, -1.6516]],

         [[-2.0643, -2.1252, -0.7881, -2.5149, -1.4129, -0.2671, -1.1814],
          [-0.6381, -2.0971, -2.4688, -1.4834, -0.8247, -0.7548, -0.6436],
          [-1.5008, -0.9920, -1.2218, -1.5451,  0.3733,  0.1928, -1.1368],
          [-0.6915, -0.7645, -0.4683

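## Pixel-variant kernel (N-batch=1, N-inchannels=1, N-outchannels=1)

Each output pixel has its own kernel (a locally connected layer), so `w` has shape `out_shape + kernel_shape` instead of just `kernel_shape`.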

In [112]:
kernel_shape= (3,3)
img= np.random.uniform(0, 1, size=(6, 9)).round(2)
# 'valid' output size per spatial axis: in - k + 1.
out_shape= tuple(np.subtract(img.shape, kernel_shape)+1)
# One kernel per output pixel: (out_h, out_w, kh, kw).
w= np.random.uniform(-1, 1, size=out_shape+kernel_shape).round(2)
print(f'img= {img}')
print(f'w= {w[0:1,0:1,:,:]}..., {w.shape}')

img= [[0.78 0.79 0.32 0.44 0.79 0.26 0.59 0.34 0.38]
 [0.94 0.62 0.43 0.03 0.82 0.45 0.63 0.64 0.43]
 [0.67 0.69 0.81 0.63 0.37 0.91 0.35 0.36 0.56]
 [0.02 0.74 0.88 0.68 0.51 0.15 0.41 0.68 0.29]
 [0.58 0.9  0.23 0.1  0.26 0.06 0.35 0.13 0.84]
 [0.57 0.67 0.11 0.6  0.3  0.58 0.75 0.15 0.06]]
w= [[[[ 0.12  0.07  0.34]
   [ 0.07  0.27  0.67]
   [ 0.55 -0.11  0.77]]]]..., (4, 7, 3, 3)


### With numpy stride_tricks.as_strided and einsum

In [113]:
view_shape= out_shape+kernel_shape
print(f'view_shape= {view_shape}')
# Both window axes and in-window axes step one image row/column, hence the repeated strides.
strides= 2*img.strides
print(f'strides= {strides}')
sub_matrices= np.lib.stride_tricks.as_strided(img, shape=view_shape, strides=strides)
print(f'sub_matrices.shape= {sub_matrices.shape}')
# w and the windows share all four axes; only the kernel axes are reduced.
res1= np.sum(sub_matrices*w, axis=(2,3))
print(f'res1= {res1}')
res2= np.einsum('ijkl,ijkl->ij', w, sub_matrices)
print(f'res2= {res2}')

view_shape= (4, 7, 3, 3)
strides= (72, 8, 72, 8)
sub_matrices.shape= (4, 7, 3, 3)
res1= [[ 1.6953 -2.4214  0.5561  0.3529 -0.6678  0.8172  0.7055]
 [-1.5766 -1.3174 -1.0877  1.7223  1.1279 -1.3055  1.6701]
 [ 0.2433 -1.561   0.8805 -0.0295  0.9864  0.448   0.4326]
 [-1.2691 -0.5238 -0.5407 -1.1373 -0.2999  1.304  -1.7477]]
res2= [[ 1.6953 -2.4214  0.5561  0.3529 -0.6678  0.8172  0.7055]
 [-1.5766 -1.3174 -1.0877  1.7223  1.1279 -1.3055  1.6701]
 [ 0.2433 -1.561   0.8805 -0.0295  0.9864  0.448   0.4326]
 [-1.2691 -0.5238 -0.5407 -1.1373 -0.2999  1.304  -1.7477]]


### With Torch

In [114]:
# Same computation in torch, with batch and channel axes of size 1 in front.
img_torch= torch.from_numpy(img).float()[None, None]
w_torch= torch.from_numpy(w).float()[None, None]
view_shape= (1, 1, *out_shape, *kernel_shape)
print(f'view_shape= {view_shape}')
strides= (*img_torch.stride(), *img_torch.stride()[2:])
print(f'strides= {strides}')
sub_matrices= torch.as_strided(img_torch, view_shape, strides)
print(f'sub_matrices.shape= {sub_matrices.shape}')
res1= torch.sum(sub_matrices*w_torch, dim=(4,5))
print(f'res1= {res1}')
res2= torch.einsum('ijklmn,ijklmn->ijkl', w_torch, sub_matrices)
print(f'res2= {res2}')

view_shape= (1, 1, 4, 7, 3, 3)
strides= (54, 54, 9, 1, 9, 1)
sub_matrices.shape= torch.Size([1, 1, 4, 7, 3, 3])
res1= tensor([[[[ 1.6953, -2.4214,  0.5561,  0.3529, -0.6678,  0.8172,  0.7055],
          [-1.5766, -1.3174, -1.0877,  1.7223,  1.1279, -1.3055,  1.6701],
          [ 0.2433, -1.5610,  0.8805, -0.0295,  0.9864,  0.4480,  0.4326],
          [-1.2691, -0.5238, -0.5407, -1.1373, -0.2999,  1.3040, -1.7477]]]])
res2= tensor([[[[ 1.6953, -2.4214,  0.5561,  0.3529, -0.6678,  0.8172,  0.7055],
          [-1.5766, -1.3174, -1.0877,  1.7223,  1.1279, -1.3055,  1.6701],
          [ 0.2433, -1.5610,  0.8805, -0.0295,  0.9864,  0.4480,  0.4326],
          [-1.2691, -0.5238, -0.5407, -1.1373, -0.2999,  1.3040, -1.7477]]]])


## Pixel-variant kernel (N-batch=1, N-inchannels=2, N-outchannels=3)
Note: The following code also works with batch sizes n_bch>=2.

In [192]:
n_bch, in_ch, out_ch, ks= 1, 2, 3, 3
img= np.random.uniform(0, 1, size=(n_bch, in_ch, 6, 9)).round(2)
kernel_shape= (ks, ks)
# 'valid' output size per spatial axis: in - k + 1.
out_shape= tuple(np.subtract(img.shape[2:], kernel_shape)+1)
# One kernel per (out channel, in channel, output pixel): (out_ch, in_ch, out_h, out_w, kh, kw).
w= np.random.uniform(-1, 1, size=(out_ch, in_ch)+out_shape+kernel_shape).round(2)
print(f'img= {img}, {img.shape}')
print(f'w= {w[0:1,0:1,0:1,0:1,:,:]}..., {w.shape}')

img= [[[[0.98 0.13 0.47 0.79 0.96 0.87 0.65 0.45 0.27]
   [0.22 0.94 1.   0.39 0.41 0.7  0.98 0.95 0.64]
   [0.82 0.94 0.75 0.48 0.7  0.42 0.91 0.3  0.82]
   [0.94 0.42 0.02 0.8  0.63 0.36 0.68 0.71 0.93]
   [0.44 0.55 0.94 0.22 0.06 0.9  0.69 0.86 0.24]
   [0.86 0.46 0.65 0.13 0.25 0.69 0.26 0.54 0.14]]

  [[0.11 0.43 0.89 0.6  0.   0.24 0.92 0.36 0.05]
   [0.78 0.69 0.43 0.87 0.3  0.19 0.06 0.84 0.6 ]
   [0.09 0.88 0.89 0.15 0.66 0.97 0.47 0.97 0.4 ]
   [0.52 0.31 0.89 0.1  0.22 0.37 0.22 0.92 0.85]
   [0.65 0.21 0.31 0.53 0.29 0.66 0.13 0.67 0.48]
   [0.33 0.31 0.56 0.84 0.41 0.45 0.82 0.6  0.61]]]], (1, 2, 6, 9)
w= [[[[[[ 0.36  0.26  0.67]
     [-0.35  0.67  0.49]
     [-0.13  0.26  0.62]]]]]]..., (3, 2, 4, 7, 3, 3)


### Test: comparison with torch.conv2d

In [189]:
# Sanity check for the pixel-variant code: copy one shared kernel to every output pixel,
# which must reproduce an ordinary torch.conv2d. Note this overwrites w.
w_base= np.random.uniform(-1, 1, size=(out_ch,in_ch)+(1,1)+kernel_shape).round(2)
w= np.tile(w_base, (1,1)+out_shape+(1,1))
print(f'img= {img}, {img.shape}')
print(f'w= {w[0:1,0:1,0:1,0:1,:,:]}..., {w.shape}, {w_base.shape}')
img_torch= torch.as_tensor(img, dtype=torch.float32)
# Drop the two singleton output-pixel axes -> Conv2d weight layout (out_ch, in_ch, kh, kw).
w_base_torch= torch.as_tensor(w_base[:, :, 0, 0], dtype=torch.float32)
res_gt= torch.nn.functional.conv2d(img_torch, w_base_torch)
print(f'res_gt= {res_gt}')

img= [[[[0.72 0.12 0.27 0.56 0.25 0.86 0.9  0.18 0.46]
   [0.56 0.12 0.72 0.28 0.79 0.21 0.61 0.75 0.96]
   [0.3  0.67 0.49 0.28 0.84 0.06 0.42 0.7  0.43]
   [0.13 0.62 0.24 0.1  0.05 0.59 0.43 0.41 0.62]
   [0.69 0.75 0.12 0.09 0.62 0.95 0.78 0.6  0.88]
   [0.15 0.86 0.43 0.78 0.6  0.18 0.25 0.34 0.63]]

  [[0.34 0.19 0.52 0.19 0.25 0.97 0.71 0.98 0.97]
   [0.73 0.56 0.74 0.81 0.37 0.48 0.4  0.39 0.32]
   [0.78 0.37 0.78 0.4  0.39 0.67 0.27 0.3  0.17]
   [0.15 0.81 0.4  0.11 0.38 0.61 0.78 0.37 0.58]
   [0.56 0.47 0.54 0.11 0.39 0.94 0.16 0.75 0.12]
   [0.66 0.43 0.54 0.02 0.41 0.49 0.41 0.23 0.15]]]], (1, 2, 6, 9)
w= [[[[[[ 0.9   0.67  0.26]
     [ 0.59  0.43 -0.64]
     [-0.07  0.92 -0.77]]]]]]..., (3, 2, 4, 7, 3, 3), (3, 2, 1, 1, 3, 3)
res_gt= tensor([[[[-0.3094, -0.5318, -0.6258,  0.2221, -1.0388, -1.1220, -1.0099],
          [-1.3240, -0.4891, -0.8321, -1.0456,  0.1132, -0.8761,  0.5982],
          [-0.1455, -0.1262, -0.7849, -1.3490, -0.5358,  0.0394, -0.5605],
          [ 0.378

### With numpy stride_tricks.as_strided and einsum

In [193]:
# Window view of the (N, C, H, W) image: (N, C, out_h, out_w, kh, kw).
view_shape= img.shape[:2]+out_shape+kernel_shape
print(f'view_shape= {view_shape}')
strides= (*img.strides, *img.strides[2:])
print(f'strides= {strides}')
sub_matrices= np.lib.stride_tricks.as_strided(img, shape=view_shape, strides=strides)
print(f'sub_matrices.shape= {sub_matrices.shape}')
# Unlike the shared-kernel case, w also carries the output-pixel axes (m, n), so those are kept, not summed.
res_np= np.einsum('ijmnkl,bjmnkl->bimn', w, sub_matrices)
print(f'res_np= {res_np}')
# res_gt (torch.conv2d) is only a valid reference when w came from the repeated-kernel test cell above.

view_shape= (1, 2, 4, 7, 3, 3)
strides= (864, 432, 72, 8, 72, 8)
sub_matrices.shape= (1, 2, 4, 7, 3, 3)
res_np= [[[[ 3.0778  1.7253 -0.5113 -0.0817 -0.581  -1.9058  0.7106]
   [-1.9801  1.1653  0.3585 -1.7316  1.3211  0.6694  0.5069]
   [ 0.6119  0.563  -0.7654  2.5973 -1.971   0.6882  0.8126]
   [ 0.7336  0.9963  1.6773 -1.0295  1.2635  2.6903 -0.332 ]]

  [[-1.364  -0.0372  1.4103 -0.1755  1.194  -0.2636  0.3186]
   [-1.9607  0.4     1.9175 -1.7362  0.4124  1.0521 -0.605 ]
   [-1.7676  2.4814  2.3001  0.1844 -0.6243  2.2404  0.7621]
   [-1.7687 -0.2826  0.0179 -0.3561  0.365   0.8226  0.6107]]

  [[-0.4571  1.1342  0.5213 -0.5604 -1.0119  0.4766  0.1784]
   [ 0.0692 -0.2373 -1.4442 -0.3252 -2.1491  2.8538  0.116 ]
   [-2.1918  0.3397  1.0173 -0.2993 -2.663   0.7206  0.1622]
   [-0.5263  1.0427  0.5387  0.6553 -1.3975  0.8994 -1.7948]]]]


### With Torch

In [194]:
img_torch= torch.as_tensor(img, dtype=torch.float32)
w_torch= torch.as_tensor(w, dtype=torch.float32)
view_shape= img_torch.shape[:2]+out_shape+kernel_shape
print(f'view_shape= {view_shape}')
# Element strides; the spatial pair is reused for the in-window axes.
strides= (*img_torch.stride(), *img_torch.stride()[2:])
print(f'strides= {strides}')
sub_matrices= torch.as_strided(img_torch, view_shape, strides)
print(f'sub_matrices.shape= {sub_matrices.shape}')
res_tch= torch.einsum('ijmnkl,bjmnkl->bimn', w_torch, sub_matrices)
print(f'res_tch= {res_tch}, {torch.all(torch.isclose(res_tch,torch.from_numpy(res_np).float()))}')

view_shape= torch.Size([1, 2, 4, 7, 3, 3])
strides= (108, 54, 9, 1, 9, 1)
sub_matrices.shape= torch.Size([1, 2, 4, 7, 3, 3])
res_tch= tensor([[[[ 3.0778,  1.7253, -0.5113, -0.0817, -0.5810, -1.9058,  0.7106],
          [-1.9801,  1.1653,  0.3585, -1.7316,  1.3211,  0.6694,  0.5069],
          [ 0.6119,  0.5630, -0.7654,  2.5973, -1.9710,  0.6882,  0.8126],
          [ 0.7336,  0.9963,  1.6773, -1.0295,  1.2635,  2.6903, -0.3320]],

         [[-1.3640, -0.0372,  1.4103, -0.1755,  1.1940, -0.2636,  0.3186],
          [-1.9607,  0.4000,  1.9175, -1.7362,  0.4124,  1.0521, -0.6050],
          [-1.7676,  2.4814,  2.3001,  0.1844, -0.6243,  2.2404,  0.7621],
          [-1.7687, -0.2826,  0.0179, -0.3561,  0.3650,  0.8226,  0.6107]],

         [[-0.4571,  1.1342,  0.5213, -0.5604, -1.0119,  0.4766,  0.1784],
          [ 0.0692, -0.2373, -1.4442, -0.3252, -2.1491,  2.8538,  0.1160],
          [-2.1918,  0.3397,  1.0173, -0.2993, -2.6630,  0.7206,  0.1622],
          [-0.5263,  1.0427,  0.5387,

# Advanced version (with padding, stride, bias)

## Convolution (N-batch=1, N-inchannels=2, N-outchannels=3, padding=(ks-1)//2, stride=2)
Note: The following code also works with batch sizes n_bch>=2.

In [56]:
# Hyper-parameters for the padded/strided version; names mirror torch.nn.Conv2d's arguments.
n_bch, in_ch, out_ch, ks, stride= 1, 2, 3, 3, 2
padding= (ks-1)//2  # keeps H, W unchanged for an odd ks at stride 1
padding_mode= 'zeros'
valid_padding_modes= {'zeros', 'reflect', 'replicate', 'circular'}
factory_kwargs= {'device': None, 'dtype': None}
img= np.random.uniform(0, 1, size=(n_bch, in_ch, 6, 10)).round(2)
w= np.random.uniform(-1, 1, size=(out_ch, in_ch, ks, ks)).round(2)
# One bias value per output channel, as in torch.nn.Conv2d.
b= np.random.uniform(-1, 1, size=(out_ch,)).round(2)
print(f'img= {img} ({img.shape})')
print(f'w= {w} ({w.shape})')
print(f'b= {b} ({b.shape})')

img= [[[[0.97 0.88 0.33 0.56 0.96 0.28 0.76 0.06 0.02 0.62]
   [0.1  0.5  0.85 0.32 0.25 0.31 0.1  0.96 0.6  0.75]
   [0.03 0.72 0.6  0.81 0.95 0.67 0.24 0.73 0.74 0.78]
   [0.9  0.55 0.08 0.91 0.22 0.16 0.16 0.73 0.02 0.11]
   [0.4  0.27 0.73 0.86 0.25 0.26 0.38 0.37 0.11 0.71]
   [0.84 0.44 0.03 0.76 0.25 0.95 0.65 0.8  0.15 0.78]]

  [[0.52 0.01 0.   0.05 0.77 0.86 1.   0.71 0.43 0.3 ]
   [0.36 0.89 0.47 0.79 0.23 0.06 0.84 0.39 0.83 0.08]
   [0.27 0.96 0.15 0.31 0.67 0.51 0.91 0.14 0.09 0.41]
   [0.01 0.32 0.14 0.55 0.49 0.29 0.43 0.35 0.59 0.34]
   [0.47 0.98 0.97 0.4  0.58 0.25 0.47 0.82 0.89 0.77]
   [0.54 0.44 0.2  0.2  0.72 0.37 0.75 0.91 0.65 0.64]]]] ((1, 2, 6, 10))
w= [[[[-0.25 -0.42  0.4 ]
   [-0.49  0.75 -0.85]
   [ 0.45  0.22 -0.62]]

  [[ 0.58  0.37  0.56]
   [-0.02  0.58  0.63]
   [-0.13 -0.48  0.81]]]


 [[[ 0.4  -0.87  0.65]
   [ 0.39  0.09 -0.22]
   [-0.89 -0.11 -0.4 ]]

  [[-0.18 -0.13  0.32]
   [ 0.26 -0.46  0.47]
   [-0.82 -0.16 -0.8 ]]]


 [[[ 0.77 -0.78 -0.97]


### Ground truth: torch.conv2d

In [57]:
# Reference result from PyTorch's own conv2d (with bias, stride and padding).
img_torch= torch.from_numpy(img).float()
w_torch= torch.from_numpy(w).float()
b_torch= torch.from_numpy(b).float()
stride_= torch.nn.modules.utils._pair(stride)
padding_= torch.nn.modules.utils._pair(padding)
assert padding_mode in valid_padding_modes, f'invalid padding_mode: {padding_mode}'
if padding_mode=='zeros':
  # conv2d's built-in padding is zero padding.
  conv_gt= torch.nn.functional.conv2d(img_torch, w_torch, b_torch, stride_, padding_)
else:
  # Other modes: pad explicitly (F.pad takes (left, right, top, bottom) for the last two dims),
  # then convolve without extra padding. The padded tensor gets its own name so img_torch
  # always holds the raw input, whatever padding_mode is.
  img_padded= torch.nn.functional.pad(img_torch, (padding_[1], padding_[1], padding_[0], padding_[0]), mode=padding_mode)
  conv_gt= torch.nn.functional.conv2d(img_padded, w_torch, b_torch, stride_, 0)
print(f'conv_gt= {conv_gt} ({conv_gt.shape})')

conv_gt= tensor([[[[ 1.2575,  0.5938,  1.7473,  1.5731,  0.3075],
          [ 1.7829,  1.1979,  1.9626,  1.0695,  1.1495],
          [ 1.7046,  1.6697,  0.8839,  2.4489,  1.4549]],

         [[-0.8814, -1.3877, -0.4227, -0.4648, -1.2804],
          [ 0.5107, -0.9642, -0.9943, -0.0307, -0.3449],
          [-0.3702,  0.0088, -0.7570, -1.2568, -1.4838]],

         [[-0.2464, -0.9405,  0.6403,  1.7610,  0.8318],
          [ 1.3089, -0.3375, -0.3856, -0.3134, -0.8664],
          [ 0.5074,  0.1467,  1.5369,  0.8652,  1.0076]]]]) (torch.Size([1, 3, 3, 5]))


### With Torch (without torch.conv2d)

In [60]:
# Same layer as the previous cell, rebuilt from a strided window view + einsum instead of torch.conv2d.
img_torch= torch.from_numpy(img).float()
w_torch= torch.from_numpy(w).float()
b_torch= torch.from_numpy(b).float()
stride_= torch.nn.modules.utils._pair(stride)
padding_= torch.nn.modules.utils._pair(padding)
assert padding_mode in valid_padding_modes, f'invalid padding_mode: {padding_mode}'
# Output size: same formula as torch.nn.Conv2d with dilation=1. Floor division keeps it exact
# integer arithmetic (int(a/b) goes through floats and truncates toward zero).
view_h= (img_torch.shape[2]+2*padding_[0]-(w_torch.shape[-2]-1)-1)//stride_[0]+1
view_w= (img_torch.shape[3]+2*padding_[1]-(w_torch.shape[-1]-1)-1)//stride_[1]+1
view_shape= img_torch.shape[:2]+(view_h,view_w)+w_torch.shape[-2:]
print(f'view_shape= {view_shape}')
# F.pad takes (left, right, top, bottom) for the last two dims. The padded tensor gets its own
# name so img_torch keeps holding the raw input.
pad_lrtb= (padding_[1], padding_[1], padding_[0], padding_[0])
if padding_mode=='zeros':
  img_padded= torch.nn.functional.pad(img_torch, pad_lrtb, mode='constant', value=0.0)
else:
  img_padded= torch.nn.functional.pad(img_torch, pad_lrtb, mode=padding_mode)
# Moving to the next output pixel jumps `stride` input pixels; moving inside a window jumps one.
# Built from plain Python ints so the printed tuple is readable regardless of the NumPy version.
strides= img_padded.stride()[:2]+tuple(s*st for s, st in zip(img_padded.stride()[2:], stride_))+img_padded.stride()[2:]
print(f'strides= {strides}')
sub_matrices= torch.as_strided(img_padded, view_shape, strides)
print(f'sub_matrices.shape= {sub_matrices.shape}')
# Bias is per output channel; reshape to (C_out, 1, 1) so it broadcasts over (N, C_out, H, W).
conv_tch= torch.einsum('ijkl,bjmnkl->bimn', w_torch, sub_matrices)+b_torch.reshape((-1,1,1))
print(f'conv_tch= {conv_tch} ({conv_tch.shape}), {torch.all(torch.isclose(conv_gt,conv_tch))}')

view_shape= torch.Size([1, 2, 3, 5, 3, 3])
strides= (192, 96, 24, 2, 12, 1)
sub_matrices.shape= torch.Size([1, 2, 3, 5, 3, 3])
conv_tch= tensor([[[[ 1.2575,  0.5938,  1.7473,  1.5731,  0.3075],
          [ 1.7829,  1.1979,  1.9626,  1.0695,  1.1495],
          [ 1.7046,  1.6697,  0.8839,  2.4489,  1.4549]],

         [[-0.8814, -1.3877, -0.4227, -0.4648, -1.2804],
          [ 0.5107, -0.9642, -0.9943, -0.0307, -0.3449],
          [-0.3702,  0.0088, -0.7570, -1.2568, -1.4838]],

         [[-0.2464, -0.9405,  0.6403,  1.7610,  0.8318],
          [ 1.3089, -0.3375, -0.3856, -0.3134, -0.8664],
          [ 0.5074,  0.1467,  1.5369,  0.8652,  1.0076]]]]) (torch.Size([1, 3, 3, 5])), True


## Pixel-variant kernel (N-batch=1, N-inchannels=2, N-outchannels=3, padding=(ks-1)//2, stride=2)

In [106]:
n_bch,in_ch,out_ch,ks,stride= 1,2,3,3,2
padding=(ks-1)//2
padding_mode='zeros'
valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
factory_kwargs= {'device': None, 'dtype': None}
img= np.round(np.random.uniform(0,1,size=(n_bch,in_ch,6,10)),2)
kernel_shape= (ks,ks)
stride_= torch.nn.modules.utils._pair(stride)
padding_= torch.nn.modules.utils._pair(padding)
# Output size: same formula as torch.nn.Conv2d with dilation=1, using exact integer floor division
# instead of int(float division).
out_h= (img.shape[2]+2*padding_[0]-(kernel_shape[0]-1)-1)//stride_[0]+1
out_w= (img.shape[3]+2*padding_[1]-(kernel_shape[1]-1)-1)//stride_[1]+1
out_shape= (out_h,out_w)
# One kernel per (out channel, in channel, output pixel) and one bias per (out channel, output pixel).
w= np.round(np.random.uniform(-1,1,size=(out_ch,in_ch)+out_shape+kernel_shape),2)
b= np.round(np.random.uniform(-1,1,size=(out_ch,)+out_shape),2)
print(f'img= {img}, {img.shape}')
print(f'w= {w[0:1,0:1,0:1,0:1,:,:]}..., {w.shape}')
print(f'b= {b}..., {b.shape}')
# With fully random per-pixel weights there is no torch.conv2d reference; the test cell below
# replaces w/b with spatially repeated ones and sets res_gt.
res_gt= None

img= [[[[0.13 0.78 0.13 0.46 0.61 0.87 0.13 0.55 0.95 0.35]
   [0.1  0.03 0.73 0.81 0.12 0.3  0.51 0.13 0.86 0.95]
   [0.93 0.3  0.85 0.47 0.38 0.36 0.92 0.43 0.63 0.25]
   [0.93 0.04 0.81 0.17 0.29 0.9  0.47 0.48 0.43 0.66]
   [0.8  0.35 0.97 0.59 0.23 0.72 0.81 0.09 0.5  0.51]
   [0.57 0.73 0.67 0.61 0.51 0.57 0.66 0.13 0.61 0.94]]

  [[0.87 0.19 0.36 0.82 0.91 0.79 0.34 0.95 0.36 0.98]
   [1.   0.54 0.34 0.78 0.02 0.78 1.   0.92 0.73 0.6 ]
   [0.21 0.38 0.18 0.91 0.34 0.74 0.37 0.85 0.79 0.83]
   [0.17 0.28 0.27 0.44 0.5  0.81 0.25 0.16 0.97 0.72]
   [0.51 0.54 0.22 0.69 0.71 0.06 0.58 0.19 0.94 0.57]
   [0.85 0.86 0.53 0.22 0.25 0.54 0.26 0.95 0.51 0.46]]]], (1, 2, 6, 10)
w= [[[[[[ 0.16 -0.65  0.45]
     [ 1.    0.82  0.01]
     [ 0.51  0.06  0.14]]]]]]..., (3, 2, 3, 5, 3, 3)
b= [[[ 0.73  0.63  0.63 -0.83  0.54]
  [ 0.3  -0.38 -0.97 -0.35 -0.15]
  [ 0.24 -0.31  0.02  0.65  0.82]]

 [[ 0.65 -0.18  0.84  0.03  0.19]
  [-0.64  0.51 -0.69  0.76 -0.88]
  [ 0.24 -0.21  0.42  0.37  0.6 ]]

### Test: comparison with torch.conv2d

In [104]:
# Ground truth for the pixel-variant layer: give every output pixel the same kernel and bias,
# which must reproduce torch.conv2d with bias, stride and padding. Note this overwrites w and b.
img_torch= torch.from_numpy(img).float()
w_base= np.round(np.random.uniform(-1,1,size=(out_ch,in_ch)+(1,1)+kernel_shape),2)
b_base= np.round(np.random.uniform(-1,1,size=(out_ch,)+(1,1)),2)
w= np.repeat(np.repeat(w_base,out_shape[1],axis=3),out_shape[0],axis=2)
b= np.repeat(np.repeat(b_base,out_shape[1],axis=2),out_shape[0],axis=1)
print(f'w= {w[0:1,0:1,0:1,0:1,:,:]}..., {w.shape}, {w_base.shape}')
print(f'b= {b}..., {b.shape}, {b_base.shape}')
# Conv2d layouts: weight (out_ch, in_ch, kh, kw), bias (out_ch,).
w_base_torch= torch.from_numpy(w_base).reshape(w_base.shape[:2]+w_base.shape[4:]).float()
b_base_torch= torch.from_numpy(b_base).reshape(-1).float()
assert padding_mode in valid_padding_modes, f'invalid padding_mode: {padding_mode}'
if padding_mode=='zeros':
  res_gt= torch.nn.functional.conv2d(img_torch, w_base_torch, b_base_torch, stride_, padding_)
else:
  # Pad into a separate tensor so img_torch keeps holding the raw (unpadded) input.
  img_padded= torch.nn.functional.pad(img_torch, (padding_[1], padding_[1], padding_[0], padding_[0]), mode=padding_mode)
  res_gt= torch.nn.functional.conv2d(img_padded, w_base_torch, b_base_torch, stride_, 0)
print(f'res_gt= {res_gt} ({res_gt.shape})')

w= [[[[[[ 0.03 -0.9  -0.43]
     [ 0.22 -0.47 -0.62]
     [-0.01 -0.74 -0.1 ]]]]]]..., (3, 2, 3, 5, 3, 3), (3, 2, 1, 1, 3, 3)
b= [[[ 0.22  0.22  0.22  0.22  0.22]
  [ 0.22  0.22  0.22  0.22  0.22]
  [ 0.22  0.22  0.22  0.22  0.22]]

 [[ 0.65  0.65  0.65  0.65  0.65]
  [ 0.65  0.65  0.65  0.65  0.65]
  [ 0.65  0.65  0.65  0.65  0.65]]

 [[-0.58 -0.58 -0.58 -0.58 -0.58]
  [-0.58 -0.58 -0.58 -0.58 -0.58]
  [-0.58 -0.58 -0.58 -0.58 -0.58]]]..., (3, 3, 5), (3, 1, 1)
res_gt= tensor([[[[-1.0330, -1.0288, -1.2870, -1.1536, -1.1331],
          [-1.8166, -3.0955, -0.9112, -2.0081, -3.2835],
          [-1.6533, -2.0332, -2.4586, -2.3034, -3.2587]],

         [[ 0.5351,  1.2488,  2.7257,  0.8186,  1.9153],
          [-0.4047, -0.5823,  0.7149,  0.8679, -0.3799],
          [-0.6785, -1.2071, -0.0506,  0.7773, -1.4025]],

         [[ 0.6793, -0.2399, -0.3794,  0.3111,  0.5520],
          [-0.2498, -1.8309, -1.8522, -0.1971, -2.2337],
          [-0.8209, -2.7957, -0.3283, -1.4841, -2.5652]]]]) (torch

### With Torch (without torch.conv2d)

In [108]:
img_torch= torch.from_numpy(img).float()
w_torch= torch.from_numpy(w).float()
b_torch= torch.from_numpy(b).float()
view_shape= img_torch.shape[:2]+out_shape+kernel_shape
print(f'view_shape= {view_shape}')
# F.pad takes (left, right, top, bottom) for the last two dims. The padded tensor gets its own
# name so img_torch keeps holding the raw input.
pad_lrtb= (padding_[1], padding_[1], padding_[0], padding_[0])
if padding_mode=='zeros':
  img_padded= torch.nn.functional.pad(img_torch, pad_lrtb, mode='constant', value=0.0)
else:
  img_padded= torch.nn.functional.pad(img_torch, pad_lrtb, mode=padding_mode)
# Output-pixel axes step `stride` input pixels, kernel axes step one; plain Python ints throughout.
strides= img_padded.stride()[:2]+tuple(s*st for s, st in zip(img_padded.stride()[2:], stride_))+img_padded.stride()[2:]
print(f'strides= {strides}')
sub_matrices= torch.as_strided(img_padded, view_shape, strides)
print(f'sub_matrices.shape= {sub_matrices.shape}')
# b_torch has shape (out_ch, out_h, out_w): one bias per output channel and output pixel.
res_tch= torch.einsum('ijmnkl,bjmnkl->bimn', w_torch, sub_matrices)+b_torch
print(f'res_tch= {res_tch} ({res_tch.shape}), {torch.all(torch.isclose(res_tch,res_gt)) if res_gt is not None else "N/A"}')

view_shape= torch.Size([1, 2, 3, 5, 3, 3])
strides= (192, 96, 24, 2, 12, 1)
sub_matrices.shape= torch.Size([1, 2, 3, 5, 3, 3])
res_tch= tensor([[[[ 1.9333,  1.3860, -0.3644, -2.6978, -1.1301],
          [ 0.5294, -1.7122, -0.4804, -2.8249, -1.0778],
          [ 2.2135,  0.8713, -2.0962,  0.0470,  0.1462]],

         [[ 0.8562,  0.4286,  0.1503,  0.8111,  1.5066],
          [-0.7512,  1.2271,  2.1082, -2.6052,  1.4239],
          [ 1.3573, -3.5836, -0.0464, -1.4093,  0.1911]],

         [[ 0.0108, -0.4411, -1.1051,  4.2101,  1.4716],
          [-0.4302,  0.0436,  0.8183,  1.9086,  0.6192],
          [-0.9203, -0.4216,  0.5768,  1.5044,  0.8749]]]]) (torch.Size([1, 3, 3, 5])), N/A


### Modular version

In [116]:
class TLocallyConnected2d(torch.nn.Module):
  '''2-D locally connected layer: like torch.nn.Conv2d, but every output pixel has its own kernel.

  Because the weight is not shared across spatial positions, the input spatial size is fixed
  at construction time.

  Args:
    in_shape: Tuple or list of (in_channels, in_h, in_w).
    out_channels: Number of output channels.
    kernel_size: int or (kh, kw).
    stride: int or (sh, sw). Default: 1.
    padding: int or (ph, pw); added on both sides of each spatial axis. Default: 0.
    bias: If True, add a learnable bias with one value per output channel and output pixel,
      shape (out_channels, out_h, out_w).
    padding_mode: One of 'zeros', 'reflect', 'replicate', 'circular'.
    device, dtype: Passed to the parameter factories.

  Shapes:
    weight: (out_channels, in_channels, out_h, out_w, kh, kw)
    input:  (N, in_channels, in_h, in_w)
    output: (N, out_channels, out_h, out_w)
  '''
  def __init__(self, in_shape, out_channels, kernel_size, stride=1, padding=0, bias=True, padding_mode='zeros', device=None, dtype=None):
    super().__init__()
    valid_padding_modes= {'zeros', 'reflect', 'replicate', 'circular'}
    assert padding_mode in valid_padding_modes, f'padding_mode must be one of {sorted(valid_padding_modes)}, got {padding_mode!r}'
    self.in_shape= tuple(int(s) for s in in_shape)
    self.kernel_size= torch.nn.modules.utils._pair(kernel_size)
    self.stride= torch.nn.modules.utils._pair(stride)
    self.padding= torch.nn.modules.utils._pair(padding)
    # Same output-size formula as torch.nn.Conv2d (dilation=1), in exact integer arithmetic.
    out_h= (self.in_shape[1]+2*self.padding[0]-(self.kernel_size[0]-1)-1)//self.stride[0]+1
    out_w= (self.in_shape[2]+2*self.padding[1]-(self.kernel_size[1]-1)-1)//self.stride[1]+1
    self.out_shape= (out_h,out_w)
    self.padding_mode= padding_mode
    factory_kwargs= dict(device=device, dtype=dtype)
    self.weight= torch.nn.Parameter(torch.empty((out_channels,self.in_shape[0])+self.out_shape+self.kernel_size, **factory_kwargs))
    if bias:
      self.bias= torch.nn.Parameter(torch.empty((out_channels,)+self.out_shape, **factory_kwargs))
    else:
      self.register_parameter('bias', None)
    # torch.empty leaves the memory uninitialized, so give the parameters well-defined values.
    self.reset_parameters()

  def reset_parameters(self):
    '''Re-initialize weight (and bias) from U(-1/sqrt(fan_in), 1/sqrt(fan_in)), fan_in = in_channels*kh*kw.

    This is the bound torch.nn.Conv2d documents for its default init. fan_in is counted per output
    pixel: the weight's spatial axes hold separate kernels, not extra inputs to one kernel.
    '''
    fan_in= self.weight.shape[1]*self.kernel_size[0]*self.kernel_size[1]
    bound= fan_in**-0.5 if fan_in>0 else 0.0
    torch.nn.init.uniform_(self.weight, -bound, bound)
    if self.bias is not None:
      torch.nn.init.uniform_(self.bias, -bound, bound)

  def forward(self, x):
    '''Apply the layer to x of shape (N,)+in_shape; returns (N, out_channels, out_h, out_w).

    Raises:
      ValueError: if x.shape[1:] differs from the in_shape given at construction.
    '''
    # The window view below hard-codes out_shape, so a different input size would silently read
    # the wrong windows (or out of bounds).
    if tuple(x.shape[1:])!=self.in_shape:
      raise ValueError(f'expected input of shape (N,)+{self.in_shape}, got {tuple(x.shape)}')
    view_shape= x.shape[:2]+self.out_shape+self.kernel_size
    # F.pad takes (left, right, top, bottom) for the last two dims.
    pad_lrtb= (self.padding[1], self.padding[1], self.padding[0], self.padding[0])
    if self.padding_mode=='zeros':
      x= torch.nn.functional.pad(x, pad_lrtb, mode='constant', value=0.0)
    else:
      x= torch.nn.functional.pad(x, pad_lrtb, mode=self.padding_mode)
    # Window view (N, C, out_h, out_w, kh, kw): output-pixel axes step `stride` input pixels,
    # kernel axes step one. as_strided returns a view, so no data is copied here.
    s_n, s_c, s_h, s_w= x.stride()
    windows= torch.as_strided(x, view_shape, (s_n, s_c, s_h*self.stride[0], s_w*self.stride[1], s_h, s_w))
    # Contract over in-channels j and kernel offsets (k, l); the output-pixel axes (m, n) are shared.
    out= torch.einsum('ijmnkl,bjmnkl->bimn', self.weight, windows)
    return out if self.bias is None else out+self.bias

# Build the module with the same hyper-parameters and load the same w/b as the previous cell,
# so res_lc must match res_tch.
lc= TLocallyConnected2d(img.shape[1:], out_ch, ks, stride=stride, padding=padding, bias=True, padding_mode=padding_mode)
with torch.no_grad():
  # copy_ (unlike assigning .data) keeps the Parameter objects and checks that the shapes match;
  # it also casts the float64 numpy data to the parameters' dtype.
  lc.weight.copy_(torch.from_numpy(w))
  lc.bias.copy_(torch.from_numpy(b))
img_torch= torch.from_numpy(img).float()
res_lc= lc(img_torch)
print(f'res_lc= {res_lc} ({res_lc.shape}), {torch.all(torch.isclose(res_tch,res_lc))}')

res_lc= tensor([[[[ 1.9333,  1.3860, -0.3644, -2.6978, -1.1301],
          [ 0.5294, -1.7122, -0.4804, -2.8249, -1.0778],
          [ 2.2135,  0.8713, -2.0962,  0.0470,  0.1462]],

         [[ 0.8562,  0.4286,  0.1503,  0.8111,  1.5066],
          [-0.7512,  1.2271,  2.1082, -2.6052,  1.4239],
          [ 1.3573, -3.5836, -0.0464, -1.4093,  0.1911]],

         [[ 0.0108, -0.4411, -1.1051,  4.2101,  1.4716],
          [-0.4302,  0.0436,  0.8183,  1.9086,  0.6192],
          [-0.9203, -0.4216,  0.5768,  1.5044,  0.8749]]]],
       grad_fn=<AddBackward0>) (torch.Size([1, 3, 3, 5])), True


In [141]:
# Tiny end-to-end check that gradients flow through the locally connected layer:
# lc -> flatten -> linear -> scalar, one backward pass of an MSE loss against 0 (no optimizer step).
with torch.no_grad():
  # Features per sample (batch dim excluded) so fc does not depend on the batch size of img_torch.
  n_features= torch.flatten(lc(img_torch), start_dim=1).shape[1]
fc= torch.nn.Linear(n_features, 1)
net= torch.nn.Sequential(lc, torch.nn.Flatten(), fc)
# The summary object is only auto-displayed when it is a cell's last expression, so print it.
print(torchinfo.summary(net))
opt= torch.optim.Adam(net.parameters())
f_loss= torch.nn.MSELoss()
opt.zero_grad()
pred= net(img_torch)
# Target with the same shape as pred, whatever the batch size.
loss= f_loss(pred, torch.zeros_like(pred))
loss.backward()
print(f'pred= {pred}')
print(f'loss= {loss}')
print(f'fc.weight.shape= {fc.weight.shape}')
print(f'fc.weight.grad.shape= {fc.weight.grad.shape}')
print(f'fc.bias.shape= {fc.bias.shape}')
print(f'fc.bias.grad.shape= {fc.bias.grad.shape}')
print(f'lc.weight.shape= {lc.weight.shape}')
print(f'lc.weight.grad.shape= {lc.weight.grad.shape}')
print(f'lc.bias.shape= {lc.bias.shape}')
print(f'lc.bias.grad.shape= {lc.bias.grad.shape}')

pred= tensor([[0.2714]], grad_fn=<AddmmBackward>)
loss= 0.07367976754903793
fc.weight.shape= torch.Size([1, 45])
fc.weight.grad.shape= torch.Size([1, 45])
fc.bias.shape= torch.Size([1])
fc.bias.grad.shape= torch.Size([1])
lc.weight.shape= torch.Size([3, 2, 3, 5, 3, 3])
lc.weight.grad.shape= torch.Size([3, 2, 3, 5, 3, 3])
lc.bias.shape= torch.Size([3, 3, 5])
lc.bias.grad.shape= torch.Size([3, 3, 5])
