In [1]:
%pip install git+https://github.com/tinygrad/tinygrad.git


Collecting git+https://github.com/tinygrad/tinygrad.git
  Cloning https://github.com/tinygrad/tinygrad.git to /private/var/folders/x4/ygyvps6n5rdf76sm9_9t__gw0000gn/T/pip-req-build-cljughaa
  Running command git clone --filter=blob:none --quiet https://github.com/tinygrad/tinygrad.git /private/var/folders/x4/ygyvps6n5rdf76sm9_9t__gw0000gn/T/pip-req-build-cljughaa
  Resolved https://github.com/tinygrad/tinygrad.git to commit e2b380b743a3e938a8505a4df652765c9dae74ce
  Preparing metadata (setup.py) ... [?25ldone
[?25h
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.3.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [2]:
from tinygrad import Device

# Route all tinygrad work in this notebook to the CLANG (CPU) backend.
# NOTE(review): presumably this must be set before any Tensor is created — keep it near the top.
Device.DEFAULT = "CLANG"
print(Device.DEFAULT)

CLANG


In [3]:
from tinygrad import Tensor, nn

class Model:
  """Small MNIST CNN: two conv -> ReLU -> 2x2 max-pool stages, then dropout and a linear classifier.

  Takes image batches like the MNIST tensors loaded in this notebook, shaped (N, 1, 28, 28),
  and produces 10 outputs per sample from the final nn.Linear(1600, 10).
  """
  def __init__(self):
    self.l1 = nn.Conv2d(1, 32, kernel_size=(3,3))
    self.l2 = nn.Conv2d(32, 64, kernel_size=(3,3))
    # 1600 = 64 channels * 5 * 5, from 28 -conv3-> 26 -pool2-> 13 -conv3-> 11 -pool2-> 5.
    # Assumes unpadded stride-1 convs and stride-2 pooling — TODO re-derive if the layers change.
    self.l3 = nn.Linear(1600, 10)

  def __call__(self, x:Tensor) -> Tensor:
    # Both feature stages share the same conv -> relu -> max_pool2d((2,2)) pattern.
    for conv in (self.l1, self.l2):
      x = conv(x).relu().max_pool2d((2,2))
    # Dropout only takes effect while Tensor.training is True (step() turns it on).
    features = x.flatten(1).dropout(0.5)
    return self.l3(features)

In [4]:
from tinygrad.nn.datasets import mnist

# Unpacked as train images, train labels, test images, test labels.
X_train, Y_train, X_test, Y_test = mnist()
train_summary = (X_train.shape, X_train.dtype, Y_train.shape, Y_train.dtype)
print(*train_summary)
# expected: (60000, 1, 28, 28) dtypes.uchar (60000,) dtypes.uchar

(60000, 1, 28, 28) dtypes.uchar (60000,) dtypes.uchar


In [5]:
# NOTE: re-running this cell creates a fresh model; re-run the optimizer/step cell afterwards,
# otherwise `optim` keeps updating the parameters of the previous model.
model = Model()
# Evaluate without dropout. Set explicitly: step() switches Tensor.training on, so relying on
# the default would silently evaluate with dropout active if this cell is re-run later.
Tensor.training = False
acc = (model(X_test).argmax(axis=1) == Y_test).mean()
# NOTE: tinygrad is lazy, and hasn't actually run anything by this point
print(acc.item())  # ~10% accuracy, as expected from a random model

0.09709999710321426


In [6]:
optim = nn.optim.Adam(nn.state.get_parameters(model))
batch_size = 128

def step():
  """Run one Adam update on a random minibatch of `batch_size` training samples.

  Returns the (lazy) loss tensor for that minibatch. Training mode is enabled only for the
  duration of the call. The previous `Tensor.training` value is restored afterwards, so later
  evaluation cells are not silently run with dropout active.
  """
  prev_training = Tensor.training
  Tensor.training = True  # makes dropout work
  try:
    # Sample indices with replacement from the whole training set.
    samples = Tensor.randint(batch_size, high=X_train.shape[0])
    X, Y = X_train[samples], Y_train[samples]
    optim.zero_grad()
    loss = model(X).sparse_categorical_crossentropy(Y).backward()
    optim.step()
  finally:
    Tensor.training = prev_training
  return loss

In [7]:
import timeit

# Seconds for each of 5 separate, un-jitted step() calls. Each call also applies a real optimizer update.
timeit.repeat(stmt=step, number=1, repeat=5)

[6.141228911001235,
 3.917855568928644,
 3.409235392929986,
 3.7306120509747416,
 4.309540601912886]

In [8]:
from tinygrad import GlobalCounters, Context

# Zero tinygrad's global counters, then run one un-jitted step with DEBUG=2 so the kernels it
# executes are logged. This also performs a real optimizer update on `model`.
GlobalCounters.reset()
with Context(DEBUG=2):
  step()

scheduled 45 kernels
*** CLANG      1 E_n11                                     arg  1 mem  0.06 GB tm      6.07us/     0.01ms (     0.00 GFLOPS    0.0|0.0     GB/s) ['__imul__']
*** CLANG      2 E_n12                                     arg  1 mem  0.06 GB tm      2.38us/     0.01ms (     0.00 GFLOPS    0.0|0.0     GB/s) ['__imul__']
*** CLANG      3 E_n6                                      arg  1 mem  0.06 GB tm      2.74us/     0.01ms (     0.00 GFLOPS    0.0|0.0     GB/s) ['randint']
*** CLANG      4 r_20000_15000_3_4                         arg  1 mem  0.06 GB tm     28.80us/     0.04ms (     0.00 GFLOPS    8.3|8.3     GB/s) ['__getitem__']
*** CLANG      5 r_10_10n1                                 arg  1 mem  0.06 GB tm      2.55us/     0.04ms (     0.14 GFLOPS    0.0|0.0     GB/s) ['sparse_categorical_crossentropy']
*** CLANG      6 E_n10 ... (output truncated)

In [9]:
from tinygrad import TinyJit

# JIT-compile the training step. NOTE(review): presumably the first call(s) trace/capture the
# kernels and later calls replay them without re-running step()'s Python body — see TinyJit docs.
jit_step = TinyJit(step)

In [10]:
import timeit

# Seconds for each of 5 separate jit_step() calls, for comparison with the un-jitted timings.
# NOTE(review): the earliest JIT calls may include capture overhead.
timeit.repeat(stmt=jit_step, number=1, repeat=5)

[3.9964950480498374,
 3.7686408278532326,
 3.5536490818485618,
 4.27838852093555,
 3.5634236389305443]

In [11]:
NUM_STEPS = 1500   # total jitted training steps
EVAL_EVERY = 100   # evaluate on the test set every this many steps

# Use `i` as the loop index. Naming it `step` would shadow the step() function, so re-running
# the timing/debug/JIT cells afterwards would fail.
for i in range(NUM_STEPS):
  loss = jit_step()
  if i % EVAL_EVERY == 0:
    Tensor.training = False  # evaluate without dropout
    acc = (model(X_test).argmax(axis=1) == Y_test).mean().item()
    print(f"step {i:4d}, loss {loss.item():.2f}, acc {acc*100.:.2f}%")

step    0, loss 4.21, acc 71.44%
step  100, loss 0.39, acc 94.09%
step  200, loss 0.31, acc 96.27%
step  300, loss 0.13, acc 97.08%
step  400, loss 0.25, acc 97.39%
step  500, loss 0.13, acc 97.62%
step  600, loss 0.14, acc 97.65%
step  700, loss 0.12, acc 98.13%
step  800, loss 0.16, acc 97.71%
step  900, loss 0.27, acc 98.14%
step 1000, loss 0.17, acc 97.92%
step 1100, loss 0.26, acc 98.15%
step 1200, loss 0.18, acc 98.00%
step 1300, loss 0.13, acc 97.88%
step 1400, loss 0.13, acc 98.14%
