# JIT Compilation

The JIT can also speed up the computation of certain neural networks.
Currently, it does not support models with varying input sizes or non-tinygrad operations.

To use JIT (just-in-time) compilation, add a function decorator to the forward pass of the neural network and ensure that the input and output are realized tensors.
Alternatively, as in this example, create a wrapper function around the network and decorate the wrapper to speed up evaluation.
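
To make the fixed-input-size restriction concrete, here is a minimal sketch of a decorator that records the shape of its first input and rejects any later call with a different shape. It is a toy stand-in written in plain Python, not how `TinyJit` works internally; the names `toy_jit` and `forward` are hypothetical and exist only for this illustration.

```python
import functools

def toy_jit(fn):
  """Toy stand-in for a JIT decorator (hypothetical, not tinygrad's TinyJit).

  It remembers the shape of the first input and rejects later inputs
  whose shape differs.
  """
  expected_shape = None

  @functools.wraps(fn)
  def wrapper(batch):
    nonlocal expected_shape
    shape = (len(batch), len(batch[0]))
    if expected_shape is None:
      expected_shape = shape  # the first call fixes the shape
    elif shape != expected_shape:
      raise ValueError(f"input shape changed from {expected_shape} to {shape}")
    return fn(batch)

  return wrapper

@toy_jit
def forward(batch):
  # stand-in for a forward pass: sum each row of the batch
  return [sum(row) for row in batch]

out_a = forward([[1, 2, 3], [4, 5, 6]])  # shape (2, 3) is recorded
out_b = forward([[0, 0, 1], [1, 1, 1]])  # same shape, accepted

try:
  forward([[1, 2, 3]])  # shape (1, 3) differs, rejected
  shape_error = None
except ValueError as e:
  shape_error = str(e)
```

In practice this means every batch passed to a decorated function should have the same shape, which is why the example below always samples 64 images per step.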

## Tinynet Model

import numpy as np
from fetch_mnist import fetch_mnist
from tinygrad.helpers import Timing
from tinygrad.tensor import Tensor

class Linear:
  def __init__(self, in_features, out_features, bias=True, initialization: str='kaiming_uniform'):
    self.weight = getattr(Tensor, initialization)(out_features, in_features)
    self.bias = Tensor.zeros(out_features) if bias else None

  def __call__(self, x):
    return x.linear(self.weight.transpose(), self.bias)

class TinyNet:
  def __init__(self):
    self.l1 = Linear(784, 128, bias=False)
    self.l2 = Linear(128, 10, bias=False)

  def __call__(self, x):
    x = self.l1(x)
    x = x.leakyrelu()
    x = self.l2(x)
    return x

## Example

from tinygrad.jit import TinyJit

net = TinyNet()
X_train, Y_train, X_test, Y_test = fetch_mnist()

@TinyJit
def jit(x):
  return net(x).realize()

with Timing("Time: "):
  avg_acc = 0
  for step in range(1000):
    # randomly sample a batch of 64 test images
    samp = np.random.randint(0, X_test.shape[0], size=(64,))
    batch = Tensor(X_test[samp], requires_grad=False)
    # get the corresponding labels
    labels = Y_test[samp]

    # forward pass with jit
    out = jit(batch)

    # calculate accuracy
    pred = out.argmax(axis=-1).numpy()
    avg_acc += (pred == labels).mean()
  print(f"Test Accuracy: {avg_acc / 1000}")

Test Accuracy: 0.103640625
Time: 3123.10 ms


Evaluation should now be much faster than without the JIT, and accelerator utilization should be much higher.
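
One way to check the speedup on your own machine is to time the same number of calls with and without the decorated wrapper. The sketch below uses only the standard library; `time_calls`, `plain_forward`, and `wrapped_forward` are hypothetical placeholders, so it measures nothing about tinygrad by itself.

```python
import time

def time_calls(fn, arg, steps=1000):
  """Return the total wall-clock time in milliseconds for `steps` calls of fn(arg)."""
  start = time.perf_counter()
  for _ in range(steps):
    fn(arg)
  return (time.perf_counter() - start) * 1e3

def plain_forward(x):
  # placeholder for the undecorated forward pass
  return [v * 2 for v in x]

def wrapped_forward(x):
  # placeholder for the decorated wrapper
  return plain_forward(x)

batch = list(range(64))
plain_ms = time_calls(plain_forward, batch)
wrapped_ms = time_calls(wrapped_forward, batch)
print(f"plain: {plain_ms:.2f} ms, wrapped: {wrapped_ms:.2f} ms")
```

To compare real numbers, substitute a function that calls `net(batch).realize()` for `plain_forward` and the decorated `jit` function for `wrapped_forward`. The first few calls of a JIT-compiled function may include compilation overhead, so it can help to run a few warm-up calls before timing.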