In [1]:
# an implementation of Self-Compressing Neural Networks
# https://arxiv.org/pdf/2301.13142
import os
import tqdm
#os.environ["DEBUG"] = '2'
#os.environ["JITBEAM"] = '2'   # make tinygrad fast, first run is slow but then it's fast
from tinygrad.nn.datasets import mnist
X_train, Y_train, X_test, Y_test = mnist()

In [2]:
from tinygrad import Tensor, nn, TinyJit, dtypes
from tinygrad.helpers import prod
import math, functools

class QConv2d:
  def __init__(self, in_channels, out_channels, kernel_size):
    self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
    scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
    self.weight = Tensor.uniform(out_channels, in_channels, *self.kernel_size, low=-scale, high=scale)
    self.e = Tensor.full((out_channels, 1, 1, 1), -8.)
    self.b = Tensor.full((out_channels, 1, 1, 1), 2.)  # start with 2 bits per weight

  def qbits(self):
    return self.b.relu().sum() * prod(self.weight.shape[1:])

  def qweight(self):
        return Tensor.minimum(Tensor.maximum(2**-self.e * self.weight, -2**(self.b.relu()-1)), 2**(self.b.relu()-1) - 1)
  
  def __call__(self, x:Tensor):
    qw = self.qweight()
    print(self.weight.shape)
    print(qw.shape)
    w = (qw.round() - qw).detach() + qw  # straight through estimator
    print((2**self.e * w).shape)
    return x.conv2d(2**self.e * w)

class Model:
  def __init__(self):
    self.layers: List[Callable[[Tensor], Tensor]] = [ 
      QConv2d(1, 32, 5), Tensor.relu,
      QConv2d(32, 32, 5), Tensor.relu,
      nn.BatchNorm(32, affine=False, track_running_stats=False),
      Tensor.max_pool2d,
      QConv2d(32, 64, 3), Tensor.relu,
      QConv2d(64, 64, 3), Tensor.relu,
      nn.BatchNorm(64, affine=False, track_running_stats=False), 
      Tensor.max_pool2d,
      # TODO: do we really need this reshape?
      lambda x: x.flatten(1).reshape(-1, 576, 1, 1),
      QConv2d(576, 10, 1), lambda x: x.flatten(1)]

  def __call__(self, x:Tensor) -> Tensor:
      return x.sequential(self.layers)

In [3]:
model = Model()
opt = nn.optim.Adam(nn.state.get_parameters(model))
test_accs, bytes_used = [], []
weight_count = sum(t.numel() for t in opt.params)
len(opt.params), weight_count

AttributeError: module 'tinygrad.nn' has no attribute 'BatchNorm'

In [None]:

def train_step() -> Tensor:
  with Tensor.train():
    samples = Tensor.randint(512, high=X_train.shape[0])
    loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples])
    Q = functools.reduce(lambda x,y: x+y, [l.qbits() for l in model.layers if isinstance(l, QConv2d)]) / weight_count
    loss = loss + 0.05*Q  
    loss.backward()
  return loss, Q


def get_test_acc() -> Tensor: return (model(X_test).argmax(axis=1) == Y_test).mean()*100

Tensor.training = True
from tqdm import trange
test_acc = float('nan')
for i in (t:=trange(20000)):
  loss, Q = train_step()
  print(loss, Q)
  model_bytes = Q.item()/8*weight_count
  print(model_bytes)
  if i%10 == 9:
    test_acc = get_test_acc().item()
  test_accs.append(test_acc)
  bytes_used.append(model_bytes)
  t.set_description(f"loss: {loss.item():6.2f}  bytes: {model_bytes:.1f}  acc: {test_acc:5.2f}%")

  0%|          | 0/20000 [00:00<?, ?it/s]

(32, 1, 5, 5)
(32, 1, 5, 5)
(32, 1, 5, 5)
(32, 32, 5, 5)
(32, 32, 5, 5)
(32, 32, 5, 5)
(64, 32, 3, 3)
(64, 32, 3, 3)
(64, 32, 3, 3)
(64, 64, 3, 3)
(64, 64, 3, 3)
(64, 64, 3, 3)
(10, 576, 1, 1)
(10, 576, 1, 1)
(10, 576, 1, 1)
<Tensor <LB NV () float (<BinaryOps.ADD: 1>, None)> on NV with grad <LB NV () float (<MetaOps.CONST: 2>, None)>> <Tensor <LB NV () float (<MetaOps.CONST: 2>, None)> on NV with grad <LB NV () float (<MetaOps.CONST: 2>, None)>>


  0%|          | 0/20000 [00:10<?, ?it/s]


RuntimeError: wait_result: 10000 ms TIMEOUT!