In [9]:
!pip3 install torch torchvision



In [10]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import math
import time

In [11]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [12]:
class LConv(nn.Module):
  def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, padding_mode='zeros'):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.kernel_size = kernel_size if type(kernel_size) in [tuple, list] and len(kernel_size) == 2 else (kernel_size, kernel_size)
    self.stride = stride if type(stride) in [tuple, list] and len(stride) == 2 else (stride, stride)
    self.padding = padding if type(padding) in [tuple, list] and len(padding) == 2 else (padding, padding)
    self.dilation = dilation if type(dilation) in [tuple, list] and len(dilation) == 2 else (dilation, dilation)
    self.groups = groups
    self.padding_mode = 'constant' if padding_mode=='zeros' else padding_mode
    abc = torch.Tensor(3 * in_channels // groups, out_channels)
    self.abc = nn.Parameter(abc)
    if bias:
      bias_ = torch.Tensor(out_channels)
      self.bias = nn.Parameter(bias_)
    else:
      self.bias = None
    self.reset_parameters()
  
  def reset_parameters(self):
    nn.init.kaiming_uniform_(self.abc, a=math.sqrt(5))
    if self.bias is not None:
      fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.abc)
      bound = 1 / math.sqrt(fan_in)
      nn.init.uniform_(self.bias, -bound, bound)

  def convert_to_conv(self):
    device = self.abc.device
    conv = nn.Conv2d(self.in_channels, self.out_channels, self.kernel_size, self.stride, self.padding, self.dilation, self.groups, (self.bias is not None), self.padding_mode if self.padding_mode!='constant' else 'zeros')
    a, b, c = self.abc.view(3, self.in_channels // self.groups, self.out_channels, 1, 1).transpose(1, 2).repeat(1, 1, 1, self.kernel_size[0], self.kernel_size[1])
    xi = torch.arange(self.kernel_size[0], device=device).float().sub_((self.kernel_size[0]-1)/2.0).view(self.kernel_size[0], 1).repeat(1, self.kernel_size[1])
    yi = torch.arange(self.kernel_size[1], device=device).float().sub_((self.kernel_size[1]-1)/2.0).view(1, self.kernel_size[1]).repeat(self.kernel_size[0], 1)
    conv.weight = nn.Parameter(a * xi + b * yi + c)
    conv.bias = self.bias
    return conv
  
  def convert_from_conv(self, conv):
    device = self.abc.device
    w = conv.weight.transpose(0, 1).to(device)
    xi = torch.arange(self.kernel_size[0], device=device).float().sub_((self.kernel_size[0]-1)/2.0).view(self.kernel_size[0], 1).repeat(1, self.kernel_size[1])
    yi = torch.arange(self.kernel_size[1], device=device).float().sub_((self.kernel_size[1]-1)/2.0).view(1, self.kernel_size[1]).repeat(self.kernel_size[0], 1)
    oi = torch.ones_like(xi, device=device)
    a = (xi * w).sum((2, 3)) / ((xi * xi).sum() + 1e-9)
    b = (yi * w).sum((2, 3)) / ((yi * yi).sum() + 1e-9)
    c = (oi * w).sum((2, 3)) / ((oi * oi).sum() + 1e-9)
    self.abc = nn.Parameter(torch.cat((a, b, c), 0))
    if conv.bias is None:
      self.bias = None
    else:
      self.bias = conv.bias.to(device)

  def forward(self, data):
    k0, k1 = self.kernel_size
    s0, s1 = self.stride
    p0, p1 = self.padding
    d0, d1 = self.dilation
    g = self.groups
    dk0, dk1 = k0 * d0, k1 * d1
    device = data.device
    ph, pw = data.size(2) + 2 * p0 + d0, data.size(3) + 2 * p1 + d1
    data_p = F.pad(F.pad(data, (p1, p1, p0, p0), self.padding_mode), (d1, (d1 - pw % d1) % d1, d0, (d0 - ph % d0) % d0))
    b, c, h, w = data_p.size()
    xi = torch.arange(h, device=device).float().div_(d0).view(h, 1).repeat(1, w)
    yi = torch.arange(w, device=device).float().div_(d1).view(1, w).repeat(h, 1)
    data_c = torch.cat((data_p * xi, data_p * yi, data_p), 1)
    data_c = data_c.view(b, 3 * c, h // d0, d0, w).cumsum_(2).view(b, 3 * c, h, w)[:, :, :ph]
    data_c = data_c[:, :, dk0:ph:s0] - data_c[:, :, 0:ph-dk0:s0]
    data_c = data_c.view(b, 3 * c, -1, w // d1, d1).cumsum_(3).view(b, 3 * c, -1, w)[:, :, :, :pw]
    data_c = data_c[:, :, :, dk1:pw:s1] - data_c[:, :, :, 0:pw-dk1:s1]

    data_c[:, 0:c].sub_(data_c[:, 2*c:] * (xi[0:ph-dk0:s0, 0:pw-dk1:s1] + (k0+1)/2))
    data_c[:, c:2*c].sub_(data_c[:, 2*c:] * (yi[0:ph-dk0:s0, 0:pw-dk1:s1] + (k1+1)/2))
    _, _, h, w = data_c.size()
    data_c = data_c.view(b, 3, g, c // g, h, w)
    data_c = torch.cat((data_c[:, 0], data_c[:, 1], data_c[:, 2]), 2).view(b, 3 * c, h, w)
    data_c.transpose_(1, 3)

    output_list = []
    ni, no = 3 * self.in_channels // g, self.out_channels // g
    for i in range(g):
      output_list.append(torch.matmul(data_c[:, :, :, i*ni:(i+1)*ni], self.abc[:, i*no:(i+1)*no]))
    output = torch.cat(output_list, 3)
    return (output.transpose(1, 3) if self.bias is None else output.add_(self.bias).transpose(1, 3))

In [13]:
def valid_test(data, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode):
  conv = nn.Conv2d(data.size()[1], out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)
  lconv = LConv(data.size()[1], out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)
  conv_from_lconv = lconv.convert_to_conv()
  conv_out = conv(data)
  lconv_out = lconv(data)
  conv_from_lconv_out = conv_from_lconv(data)
  lconv_from_conv_from_lconv = LConv(data.size()[1], out_channels=out_channels, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)
  lconv_from_conv_from_lconv.convert_from_conv(conv_from_lconv)
  lconv_from_conv_from_lconv_out = lconv_from_conv_from_lconv(data)
  assert conv_out.size() == lconv_out.size()
  assert lconv_out.size() == conv_from_lconv_out.size()
  assert conv_from_lconv_out.size() == lconv_from_conv_from_lconv_out.size()
  assert torch.abs(conv_from_lconv_out - lconv_out).max() < 1e-3
  assert torch.abs(lconv_from_conv_from_lconv_out - lconv_out).max() < 1e-3

with torch.no_grad():
  valid_test(data=torch.randn(1, 1, 32, 32), out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True, padding_mode='zeros')
  valid_test(data=torch.randn(2, 2, 31, 33), out_channels=33, kernel_size=5, stride=(2, 3), padding=2, dilation=(1, 1), groups=1, bias=True, padding_mode='reflect')
  valid_test(data=torch.randn(3, 3, 23, 45), out_channels=64, kernel_size=7, stride=1, padding=(2,5), dilation=(2, 3), groups=1, bias=False, padding_mode='replicate')
  valid_test(data=torch.randn(4, 4, 16, 16), out_channels=128, kernel_size=(3,3), stride=4, padding=(3,2), dilation=(2, 2), groups=2, bias=True, padding_mode='circular')
  valid_test(data=torch.randn(5, 64, 16, 48), out_channels=16, kernel_size=(3,5), stride=(1, 5), padding=3, dilation=2, groups=2, bias=True, padding_mode='zeros')
  valid_test(data=torch.randn(6, 16, 64, 32), out_channels=64, kernel_size=(5,7), stride=2, padding=(3,3), dilation=3, groups=1, bias=False, padding_mode='reflect')
  valid_test(data=torch.randn(7, 8, 64, 32), out_channels=6, kernel_size=(3,7), stride=1, padding=1, dilation=(1, 5), groups=2, bias=True, padding_mode='replicate')
  valid_test(data=torch.randn(8, 4, 64, 32), out_channels=32, kernel_size=(1,10), stride=(4, 2), padding=(0,5), dilation=(3, 3), groups=4, bias=True, padding_mode='circular')
  valid_test(data=torch.randn(9, 2, 64, 32), out_channels=64, kernel_size=(1,1), stride=1, padding=1, dilation=1, groups=2, bias=False, padding_mode='zeros')
  valid_test(data=torch.randn(10, 12, 64, 32), out_channels=24, kernel_size=1, stride=2, padding=1, dilation=(2, 2), groups=6, bias=True, padding_mode='reflect')
  valid_test(data=torch.randn(11, 16, 64, 32), out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1, groups=4, bias=True, padding_mode='replicate')
  valid_test(data=torch.randn(12, 16, 64, 32), out_channels=4, kernel_size=2, stride=(1, 2), padding=1, dilation=(2, 3), groups=4, bias=False, padding_mode='circular')
  valid_test(data=torch.randn(13, 12, 64, 32), out_channels=16, kernel_size=(1, 1), stride=(3, 2), padding=1, dilation=(1, 2), groups=2, bias=True, padding_mode='zeros')
  valid_test(data=torch.randn(14, 24, 64, 32), out_channels=32, kernel_size=(1, 5), stride=(2, 3), padding=1, dilation=3, groups=8, bias=True, padding_mode='reflect')
  valid_test(data=torch.randn(15, 23, 64, 32), out_channels=1, kernel_size=(2, 3), stride=(3, 4), padding=1, dilation=2, groups=1, bias=False, padding_mode='replicate')
  valid_test(data=torch.randn(16, 3, 64, 32), out_channels=7, kernel_size=(3, 4), stride=3, padding=1, dilation=1, groups=1, bias=True, padding_mode='circular')

In [14]:
kernel_size = 3
dim = 64
class Net(nn.Module):
  def __init__(self, dim=[], LConv_size=0, dataset='MNIST'):
    super(Net, self).__init__()
    if dataset == 'MNIST':
      shape = (1, 28, 28)
    elif dataset == 'CIFAR10':
      shape = (3, 32, 32)
    layers = []
    lconvs = 0
    for i in range(len(dim)):
      if dim[i] == 'M' or dim[i] == 'm':
        layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
        shape = (shape[0], shape[1] // 2, shape[2] // 2)
      else:
        if lconvs < LConv_size:
          lconvs += 1
          layers.append(LConv(shape[0], dim[i], kernel_size=3, padding=1))
        else:
          layers.append(nn.Conv2d(shape[0], dim[i], kernel_size=3, padding=1))
        shape = (dim[i], shape[1], shape[2])
        layers.append(nn.ReLU())
    layers.append(nn.Flatten())
    layers.append(nn.Linear(shape[0] * shape[1] * shape[2], 10, bias=True))
    self.model = nn.Sequential(*layers)

  def forward(self, x):
    out = self.model(x)
    return out

In [15]:
batch_size = 1000
transform = torchvision.transforms.ToTensor()

mnist_train = torchvision.datasets.MNIST(root='MNIST/',
                          train=True,
                          transform=transform,
                          download=True)
mnist_test = torchvision.datasets.MNIST(root='MNIST/',
                         train=False,
                         transform=transform,
                         download=True)
mnist_train_loader = torch.utils.data.DataLoader(dataset=mnist_train,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          drop_last=True)
mnist_test_loader = torch.utils.data.DataLoader(dataset=mnist_test,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          drop_last=True)

cifar10_train = torchvision.datasets.CIFAR10(root='CIFAR10/',
                          train=True,
                          transform=transform,
                          download=True)
cifar10_test = torchvision.datasets.CIFAR10(root='CIFAR10/',
                         train=False,
                         transform=transform,
                         download=True)
cifar10_train_loader = torch.utils.data.DataLoader(dataset=cifar10_train,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          drop_last=True)
cifar10_test_loader = torch.utils.data.DataLoader(dataset=cifar10_test,
                                          batch_size=batch_size,
                                          shuffle=True,
                                          drop_last=True)

Files already downloaded and verified
Files already downloaded and verified


In [16]:
for s in range(1, 29):
  conv = nn.Conv2d(1, 64, kernel_size=s, padding=0)
  tic = time.time()
  for data, label in mnist_train_loader:
    logit = conv(data)
  toc = time.time()
  print(s, toc - tic)

1 9.69036340713501
2 8.642049312591553
3 8.397260189056396
4 8.449869155883789
5 8.430787086486816
6 8.688027143478394
7 8.984625101089478
8 9.111706495285034
9 9.069044351577759
10 9.467499017715454
11 9.639098405838013
12 9.502684831619263
13 9.636601448059082
14 10.002632856369019
15 9.275612592697144
16 9.249860286712646
17 8.795132875442505
18 8.667381763458252
19 8.538036108016968
20 7.543653964996338
21 7.179687976837158
22 6.82577919960022
23 5.761156797409058
24 5.408653497695923
25 5.049642324447632
26 4.350834846496582
27 4.044261455535889
28 3.841463088989258


In [17]:
for s in range(1, 29):
  conv = LConv(1, 64, kernel_size=s, padding=0)
  tic = time.time()
  for data, label in mnist_train_loader:
    logit = conv(data)
  toc = time.time()
  print(s, toc - tic)

1 12.75806999206543
2 12.14669418334961
3 11.63218641281128
4 11.082746744155884
5 10.438928127288818
6 10.21586298942566
7 9.676838636398315
8 9.138279676437378
9 8.686416149139404
10 8.224821329116821
11 7.823349714279175
12 7.270512342453003
13 7.137585878372192
14 6.844459056854248
15 6.441086769104004
16 6.2011260986328125
17 5.858244180679321
18 5.648073196411133
19 5.408952713012695
20 5.1887712478637695
21 5.044458627700806
22 4.920837163925171
23 4.742007732391357
24 4.5799500942230225
25 4.483244895935059
26 4.449730157852173
27 4.4391560554504395
28 4.39461088180542


In [23]:
datas = []
for data, label in mnist_train_loader:
  datas.append(data.to(device))
for s in range(1, 29):
  conv = nn.Conv2d(1, 64, kernel_size=s, padding=0).to(device)
  tic = time.time()
  for data in datas:
    logit = conv(data)
  toc = time.time()
  print(s, toc - tic)

1 0.003435850143432617
2 0.003361225128173828
3 0.005074977874755859
4 0.004013776779174805
5 0.0033636093139648438
6 0.0033385753631591797
7 0.0033464431762695312
8 0.003398418426513672
9 0.0033898353576660156
10 0.0033440589904785156
11 0.004087686538696289
12 0.003576040267944336
13 0.0034589767456054688
14 0.003522157669067383
15 0.0035355091094970703
16 0.003570556640625
17 0.0037605762481689453
18 0.003644704818725586
19 0.0039522647857666016
20 0.0038242340087890625
21 0.004037141799926758
22 0.0038001537322998047
23 0.004098176956176758
24 0.0039882659912109375
25 0.003922700881958008
26 0.003940105438232422
27 0.003376483917236328
28 0.003069162368774414


In [24]:
datas = []
for data, label in mnist_train_loader:
  datas.append(data.to(device))
for s in range(1, 29):
  conv = LConv(1, 64, kernel_size=s, padding=0).to(device)
  tic = time.time()
  for data in datas:
    logit = conv(data)
  toc = time.time()
  print(s, toc - tic)

1 0.17536544799804688
2 0.15608477592468262
3 0.14111757278442383
4 0.13695955276489258
5 0.1241922378540039
6 0.11946487426757812
7 0.10746645927429199
8 0.10353803634643555
9 0.0931847095489502
10 0.08848309516906738
11 0.07819652557373047
12 0.07440519332885742
13 0.06546640396118164
14 0.06223893165588379
15 0.05530810356140137
16 0.05131840705871582
17 0.04411816596984863
18 0.0406949520111084
19 0.03696894645690918
20 0.03330254554748535
21 0.0333712100982666
22 0.03294038772583008
23 0.032996416091918945
24 0.03665447235107422
25 0.035073041915893555
26 0.033084869384765625
27 0.03610706329345703
28 0.031485557556152344


In [25]:
def get_n_params(model):
  pp = 0
  for p in list(model.parameters()):
    nn = 1
    for s in list(p.size()):
      nn = nn*s
    pp += nn
  return pp
criterion = torch.nn.CrossEntropyLoss().to(device)

In [26]:
def test(dim, lr = 1e-5, dataset='MNIST'):
  EPOCH = 32
  print(dim)
  lconvs = 0
  for d in dim:
    if type(d) == int:
      lconvs += 1
  models = []
  optimizers = []
  for i in range(lconvs + 1):
    models.append(Net(dim=dim, LConv_size=i, dataset=dataset).to(device))
    optimizers.append(torch.optim.Adam(models[i].parameters(), lr=lr))
  
  text = ''
  time_list = []
  for i in range(lconvs + 1):
    time_list.append(0)
    text += str(get_n_params(models[i])) + ' '
  print(text)

  if dataset == 'MNIST':
    train_loader = mnist_train_loader
    test_loader = mnist_test_loader
    shape = (-1, 1, 28, 28)
  elif dataset == 'CIFAR10':
    train_loader = cifar10_train_loader
    test_loader = cifar10_test_loader
    shape = (-1, 3, 32, 32)

  for epoch in range(EPOCH):
    text = 'L' + str(epoch) + ' '
    for i in range(lconvs + 1):
      tic = time.time()
      avg_loss = 0
      for data, label in train_loader:
        data = data.to(device)
        label = label.to(device)

        optimizers[i].zero_grad()
        logit = models[i](data)
        loss = criterion(logit, label)
        loss.backward()
        optimizers[i].step()

        avg_loss += loss / len(train_loader)
      text += str(avg_loss.item()) + ' '
      toc = time.time()
      time_list[i] += toc - tic
    print(text)


    text = 'A' + str(epoch) + ' '
    with torch.no_grad():
      for i in range(lconvs + 1):
        acc = 0
        for data, label in test_loader:
          data = data.to(device)
          label = label.to(device)

          logit = models[i](data)
          acc += (torch.argmax(logit, 1) == label).float().mean() / len(test_loader)
        text += str(acc.item()) + ' '
    print(text)
  text = 'T '
  for i in range(lconvs + 1):
    text += str(time_list[i]) + ' '
  print(text)

In [28]:
test(dim=[32, 'M', 64], lr=1e-5, dataset='MNIST')

[32, 'M', 64]
144266 144074 131786 
L0 2.2644388675689697 2.265523910522461 2.070502519607544 
A0 0.5679000020027161 0.5642000436782837 0.6644999980926514 
L1 2.1703145503997803 2.143772602081299 1.6017491817474365 
A1 0.7264999747276306 0.7265999913215637 0.7778000831604004 
L2 2.028402328491211 1.962406873703003 1.246524691581726 
A2 0.7833000421524048 0.7736000418663025 0.8187000751495361 
L3 1.8403475284576416 1.7267169952392578 0.9935271739959717 
A3 0.8022000193595886 0.7966001033782959 0.841400146484375 
L4 1.6272786855697632 1.4736323356628418 0.8178498148918152 
A4 0.8152999877929688 0.8106000423431396 0.8633999824523926 
L5 1.4151701927185059 1.242247223854065 0.6940924525260925 
A5 0.82669997215271 0.8184999823570251 0.8758000731468201 
L6 1.2238329648971558 1.052269697189331 0.604385256767273 
A6 0.8313000202178955 0.8303000330924988 0.8842000365257263 
L7 1.0631545782089233 0.9052274823188782 0.5372681021690369 
A7 0.8431999683380127 0.8418000340461731 0.8913000822067261 


In [29]:
test(dim=[32, 'M', 64, 'M', 128], lr=1e-5, dataset='MNIST')

[32, 'M', 64, 'M', 128]
155402 155210 142922 93770 
L0 2.291227340698242 2.288919687271118 2.2134029865264893 2.1313703060150146 
A0 0.39569997787475586 0.40050002932548523 0.5562000870704651 0.5597000122070312 
L1 2.2554612159729004 2.244903326034546 1.9876396656036377 1.683115005493164 
A1 0.5772000551223755 0.5997000336647034 0.7512000203132629 0.7308000326156616 
L2 2.1885275840759277 2.165449857711792 1.694122552871704 1.3106306791305542 
A2 0.6762000322341919 0.67330002784729 0.7795001268386841 0.791400134563446 
L3 2.0735769271850586 2.027067184448242 1.3672256469726562 1.021250605583191 
A3 0.7099000215530396 0.7371001243591309 0.8050000071525574 0.8212000131607056 
L4 1.9004786014556885 1.8150393962860107 1.0818815231323242 0.8183654546737671 
A4 0.7376999855041504 0.7708000540733337 0.8285000920295715 0.8468001484870911 
L5 1.6742795705795288 1.5498476028442383 0.873160183429718 0.6783624291419983 
A5 0.762700080871582 0.7788000702857971 0.8462000489234924 0.8692000508308411 

In [30]:
test(dim=[64, 'M', 64, 'M', 64], lr=1e-5, dataset='MNIST')

[64, 'M', 64, 'M', 64]
105866 105482 80906 56330 
L0 2.29548716545105 2.2907986640930176 2.2586421966552734 2.593106746673584 
A0 0.2973000407218933 0.37389999628067017 0.22600002586841583 0.2524000108242035 
L1 2.266211986541748 2.2622673511505127 2.13334059715271 1.9902207851409912 
A1 0.3775000274181366 0.6514000296592712 0.6011000275611877 0.4189000129699707 
L2 2.2115397453308105 2.20575213432312 1.9506161212921143 1.7273613214492798 
A2 0.5446000099182129 0.6976000070571899 0.7198000550270081 0.5913000106811523 
L3 2.1098239421844482 2.101944923400879 1.6975197792053223 1.4915103912353516 
A3 0.6783000230789185 0.7208001017570496 0.7768000960350037 0.6744000315666199 
L4 1.9399456977844238 1.935772180557251 1.4075963497161865 1.2786290645599365 
A4 0.7358000874519348 0.7478000521659851 0.8026000261306763 0.7398000359535217 
L5 1.699922800064087 1.6979312896728516 1.138649821281433 1.089029312133789 
A5 0.773099958896637 0.7730000019073486 0.822700023651123 0.7838000655174255 
L6 

In [31]:
test(dim=[32, 'M', 64], lr=1e-5, dataset='CIFAR10')

[32, 'M', 64]
183242 182666 170378 
L0 2.2881786823272705 2.2688047885894775 2.2136547565460205 
A0 0.15870000422000885 0.19110000133514404 0.2371000200510025 
L1 2.253120183944702 2.1791539192199707 2.059377908706665 
A1 0.24230001866817474 0.2827000021934509 0.29670003056526184 
L2 2.206246852874756 2.087877035140991 1.9800856113433838 
A2 0.2784999907016754 0.3092000186443329 0.3313000202178955 
L3 2.146627902984619 2.0110530853271484 1.9284279346466064 
A3 0.3125000298023224 0.32420000433921814 0.3461000323295593 
L4 2.0837948322296143 1.9555678367614746 1.8881633281707764 
A4 0.32670003175735474 0.3319999873638153 0.35770002007484436 
L5 2.026672124862671 1.9175012111663818 1.8533142805099487 
A5 0.3338000178337097 0.3451000154018402 0.36980006098747253 
L6 1.9803508520126343 1.8873735666275024 1.8209021091461182 
A6 0.3425000011920929 0.3606000542640686 0.3790000379085541 
L7 1.9445195198059082 1.8609123229980469 1.7912633419036865 
A7 0.3514000177383423 0.364300012588501 0.38820

In [32]:
test(dim=[32, 'M', 64, 'M', 128], lr=1e-5, dataset='CIFAR10')

[32, 'M', 64, 'M', 128]
175178 174602 162314 113162 
L0 2.298668622970581 2.292186975479126 2.3043882846832275 2.6031157970428467 
A0 0.10020001232624054 0.12130001932382584 0.16050000488758087 0.125900000333786 
L1 2.2888619899749756 2.270289659500122 2.244802474975586 2.187079429626465 
A1 0.11970000714063644 0.17810001969337463 0.2223999947309494 0.21660000085830688 
L2 2.2731080055236816 2.236983060836792 2.175191640853882 2.0600218772888184 
A2 0.1924000084400177 0.22310000658035278 0.2825999855995178 0.27970001101493835 
L3 2.246833086013794 2.1848068237304688 2.0932505130767822 1.9701156616210938 
A3 0.22350001335144043 0.26740002632141113 0.3224000334739685 0.31690001487731934 
L4 2.2060885429382324 2.1186444759368896 2.017594337463379 1.9015218019485474 
A4 0.2681000232696533 0.29840004444122314 0.3255999982357025 0.34870001673698425 
L5 2.1499669551849365 2.0521085262298584 1.9561223983764648 1.8475371599197388 
A5 0.2833000123500824 0.30949997901916504 0.3451000452041626 0.3

In [33]:
test(dim=[64, 'M', 64, 'M', 64], lr=1e-5, dataset='CIFAR10')

[64, 'M', 64, 'M', 64]
116618 115466 90890 66314 
L0 2.2988245487213135 2.298510789871216 2.316396713256836 3.6952052116394043 
A0 0.10120001435279846 0.11730001121759415 0.15320000052452087 0.08879999816417694 
L1 2.290137767791748 2.2894480228424072 2.258042812347412 2.5749168395996094 
A1 0.12480000406503677 0.12360002100467682 0.19949999451637268 0.12240000069141388 
L2 2.2764639854431152 2.2750096321105957 2.2103395462036133 2.36326265335083 
A2 0.16740001738071442 0.1421000063419342 0.23450002074241638 0.1648000031709671 
L3 2.2526392936706543 2.247962474822998 2.1464483737945557 2.220160484313965 
A3 0.21489998698234558 0.21150001883506775 0.2671999931335449 0.20910002291202545 
L4 2.2120718955993652 2.201694965362549 2.081286907196045 2.119041681289673 
A4 0.258400022983551 0.24480001628398895 0.2973000109195709 0.2429000288248062 
L5 2.153547525405884 2.1337413787841797 2.024365186691284 2.045530080795288 
A5 0.2775000333786011 0.2760000228881836 0.3111000061035156 0.274500012

In [34]:
test(dim=[16, 'M', 32, 'M', 64], lr=1e-5, dataset='MNIST')

[16, 'M', 32, 'M', 64]
54666 54570 51498 39210 
L0 2.295424461364746 2.2874581813812256 2.2737529277801514 3.163780450820923 
A0 0.21180002391338348 0.3372000455856323 0.22040000557899475 0.12030000239610672 
L1 2.2794501781463623 2.2577221393585205 2.1458733081817627 2.14357590675354 
A1 0.5146000385284424 0.5499000549316406 0.5760000348091125 0.34690001606941223 
L2 2.2555205821990967 2.212529182434082 2.0055830478668213 1.8880336284637451 
A2 0.6375000476837158 0.6592000126838684 0.7080000638961792 0.5261000394821167 
L3 2.2194745540618896 2.143247127532959 1.8376859426498413 1.6878423690795898 
A3 0.6778000593185425 0.7138000130653381 0.7590000629425049 0.6177999973297119 
L4 2.1691391468048096 2.0461201667785645 1.6435965299606323 1.4966814517974854 
A4 0.7031000256538391 0.7414000630378723 0.7873001098632812 0.6782000660896301 
L5 2.10271954536438 1.9222577810287476 1.4423431158065796 1.317352294921875 
A5 0.732900083065033 0.7651000022888184 0.8059000372886658 0.7380000352859497

In [35]:
test(dim=[16, 'M', 32, 'M', 64], lr=1e-5, dataset='CIFAR10')

[16, 'M', 32, 'M', 64]
64554 64266 61194 48906 
L0 2.3018016815185547 2.3024485111236572 2.2887487411499023 3.0361685752868652 
A0 0.12950001657009125 0.1210000067949295 0.17820000648498535 0.13330000638961792 
L1 2.2980902194976807 2.288473129272461 2.227820873260498 2.3880581855773926 
A1 0.1680000126361847 0.14900001883506775 0.2370000183582306 0.16750000417232513 
L2 2.2932910919189453 2.2755069732666016 2.1688125133514404 2.28360652923584 
A2 0.16680000722408295 0.19350001215934753 0.26110002398490906 0.1931000053882599 
L3 2.2862727642059326 2.259021043777466 2.1050140857696533 2.202667474746704 
A3 0.19440001249313354 0.21950002014636993 0.28630003333091736 0.21530000865459442 
L4 2.276947498321533 2.237271785736084 2.0477821826934814 2.1403467655181885 
A4 0.20930001139640808 0.24780000746250153 0.3019000291824341 0.2345000058412552 
L5 2.2645230293273926 2.2107346057891846 2.0037834644317627 2.0916316509246826 
A5 0.2289000153541565 0.2639000117778778 0.31550002098083496 0.256