In [17]:
from tinygrad.nn.datasets import mnist
from tinygrad import nn, Tensor
from tinygrad.helpers import getenv, colored, trange

class Model:
    """Small convolutional classifier used here for MNIST digits.

    Layers, applied in order by ``__call__``:
      conv 5x5 (1->32) + ReLU, conv 5x5 (32->32) + ReLU, BatchNorm(32), max-pool,
      conv 3x3 (32->64) + ReLU, conv 3x3 (64->64) + ReLU, BatchNorm(64), max-pool,
      flatten, Linear(576 -> 10).
    The result is one unnormalized score (logit) per class for 10 classes.
    """

    def __init__(self):
        # Stage one: two 5x5 convolutions, then normalize and downsample.
        stage_one = [
            nn.Conv2d(1, 32, 5), Tensor.relu,
            nn.Conv2d(32, 32, 5), Tensor.relu,
            nn.BatchNorm(32), Tensor.max_pool2d,
        ]
        # Stage two: two 3x3 convolutions, then normalize and downsample.
        stage_two = [
            nn.Conv2d(32, 64, 3), Tensor.relu,
            nn.Conv2d(64, 64, 3), Tensor.relu,
            nn.BatchNorm(64), Tensor.max_pool2d,
        ]
        # Head: flatten all but the batch axis. 576 = 64 * 3 * 3 for 28x28 inputs,
        # assuming Tensor.max_pool2d's default 2x2 window.
        classifier = [lambda x: x.flatten(1), nn.Linear(576, 10)]
        # Modules are created in the same order as listed, so parameter init order is fixed.
        self.layers = stage_one + stage_two + classifier

    def __call__(self, x):
        """Run ``x`` through every entry of ``self.layers`` in order."""
        return x.sequential(self.layers)


# Configuration. The defaults reproduce the original run: 70 steps, batch 512, eval every 10 steps.
SEED = 42
BATCH_SIZE = 512
STEPS = 70
EVAL_EVERY = 10

# Seed tinygrad's RNG before the model is built. This makes weight init and the
# random mini-batch indices identical on every Restart & Run All.
Tensor.manual_seed(SEED)

X_train, Y_train, X_test, Y_test = mnist()

model = Model()

opt = nn.optim.Adam(nn.state.get_parameters(model))

@Tensor.train()
def trainstep() -> Tensor:
    """Run one Adam step on a random mini-batch and return the loss tensor.

    ``Tensor.train()`` enables tinygrad's training mode for the duration of the call
    (presumably so BatchNorm uses batch statistics -- confirm against tinygrad docs).
    """
    opt.zero_grad()
    # Draw BATCH_SIZE independent training indices; duplicates are possible.
    samples = Tensor.randint(BATCH_SIZE, high=X_train.shape[0])
    loss = model(X_train[samples]).sparse_categorical_crossentropy(Y_train[samples]).backward()
    opt.step()
    return loss

def get_test_acc() -> Tensor:
    """Return accuracy on the full test set, in percent, as a scalar Tensor."""
    return (model(X_test).argmax(axis=1) == Y_test).mean() * 100

test_acc = float("nan")  # displayed until the first evaluation runs
for i in (t := trange(STEPS)):
    loss = trainstep()
    # Evaluate every EVAL_EVERY steps, and always on the last step so the final
    # accuracy is reported even when STEPS is not a multiple of EVAL_EVERY.
    if i % EVAL_EVERY == EVAL_EVERY - 1 or i == STEPS - 1:
        test_acc = get_test_acc().item()
    t.set_description(f"loss: {loss.item():6.2f} test_accuracy: {test_acc:5.2f}%")



loss:   0.07 test_accuracy: 97.83%: 100%|███████| 70/70 [00:21<00:00,  3.26it/s]


In [None]:
# Sanity check: classify one randomly chosen test image and compare it with its label.
# NOTE(review): the name `random` shadows the stdlib `random` module; consider renaming
# it (e.g. `sample_idx`) once no later cell depends on it.
random = Tensor.randint(high=X_test.shape[0]).item()  # random test-set index, as a Python int
# A one-element slice keeps the leading axis, so the model receives a batch of one image.
result = model(X_test[random:random + 1])
print(f"Predict: {result.argmax().item()}, actual value: {Y_test[random].item()}")


(1, 28, 28)
[[[[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
      0   0   0   0   0   0   0   0   0   0   0]]

  [[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
      0   0   0   0   0   0   0   0   0   0   0]]

  [[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
      0   0   0   0   0   0   0   0   0   0   0]]

  [[  0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0   0
      0   0   0   0   0   0   0   0   0   0   0]]

  [[  0   0   0   0   0   0   0   0   0   0   0   0  11 150 253 202  31
      0   0   0   0   0   0   0   0   0   0   0]]

  [[  0   0   0   0   0   0   0   0   0   0   0   0  37 251 251 253 107
      0   0   0   0   0   0   0   0   0   0   0]]

  [[  0   0   0   0   0   0   0   0   0   0   0  21 197 251 251 253 107
      0   0   0   0   0   0   0   0   0   0   0]]

  [[  0   0   0   0   0   0   0   0   0   0 110 190 251 251 251 253 169
    109  62   0   0   0   0   0   0   0   0   0]]

  [[