In [31]:
import tinygrad
import numpy as np
from tinygrad.extra.datasets import fetch_cifar
import time

start = time.monotonic()
X_train, Y_train = fetch_cifar(train=True)
X_train = X_train.reshape(50000, -1)
print(X_train.shape)
print(X_train[2,:])



(50000, 3072)
[1030.2646  1022.1686  1022.1686  ...  315.58603  315.58603  319.40884]


In [32]:
from tinygrad.nn import Linear, optim
from tinygrad.tensor import Tensor

class TiniestCIFAR:
    def __init__(self):
        self.l1 = Linear(3072, 3072, bias=False)
        self.l2 = Linear (3072, 512, bias=False)
        self.l3 = Linear(512,10, bias=False)

    def __call__(self, x):
        x=self.l1(x).leakyrelu()
        x=self.l2(x).leakyrelu()
        return x.log_softmax()
    
net = TiniestCIFAR()


In [35]:
from tinygrad.nn.optim import SGD
opt = SGD([net.l1.weight, net.l2.weight], lr=3e-4)
from extra.training import sparse_categorical_crossentropy


def cross_entropy(out, Y):
  num_classes = out.shape[-1]
  YY = Y.flatten().astype(np.int32)
  y = np.zeros((YY.shape[0], num_classes), np.float32)
  y[range(y.shape[0]),YY] = -1.0*num_classes
  y = y.reshape(list(Y.shape)+[num_classes])
  y = Tensor(y)
  return out.mul(y).mean()

In [49]:
#run the model
BS=5
for step in range (5):
    samp = np.random.randint(0, X_train.shape[0], BS)
    batch = Tensor(X_train[samp], requires_grad=False)
    labels = Y_train[samp]
    out = net(batch)
    print(out.numpy())

    loss = cross_entropy(out, labels)
    opt.zero_grad()
    loss.backward()
    opt.step()

    pred = np.argmax(out.numpy(), axis=-1)
    print(pred)
    acc = (pred == labels).mean()
    #print (acc)
    if step % 100 == 0:
        print(f'step{step} | Loss: {loss.numpy()} | Accuracy: {acc}')    



[[-2.16672500e+04 -9.75807500e+04 -2.81981000e+05 ... -1.30394344e+08
  -1.28366750e+05 -2.00229580e+07]
 [-6.65920000e+05 -5.83253750e+05 -7.80757000e+05 ... -7.70481200e+07
  -7.52712750e+05 -1.24046540e+07]
 [-8.25180000e+04 -5.09902000e+05  0.00000000e+00 ... -5.32649320e+07
  -3.50760250e+05 -8.43354600e+06]
 [-1.39162000e+05 -1.46896750e+05 -3.53686000e+05 ... -7.47527120e+07
  -1.79480250e+05 -1.15683230e+07]
 [-2.56251000e+05 -5.27521375e+05 -2.26377375e+05 ... -6.70873960e+07
  -5.53306000e+05 -1.07157570e+07]]
[8 9 2 5 9]
step0 | Loss: 641617.375 | Accuracy: 0.0
[[-2.82252250e+05  0.00000000e+00 -2.88326250e+05 ... -9.29222560e+07
  -2.90263000e+05 -1.44368230e+07]
 [-1.82451750e+05 -7.46698750e+04 -1.04624500e+05 ... -5.67318160e+07
  -2.07165875e+05 -8.83959100e+06]
 [-3.34326875e+05 -8.34433750e+04 -3.47758500e+05 ... -5.78031120e+07
  -2.54673375e+05 -9.04291300e+06]
 [-7.54928750e+04 -8.99122500e+04  0.00000000e+00 ... -5.02520520e+07
  -1.54905750e+05 -7.80617900e+06]
 