In [4]:
import numpy as np
from jax import grad, jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp

## Basic NumPy-Like Operations


In [6]:
import jax.numpy as jnp

# Create arrays
a = jnp.array([1, 2, 3])
b = jnp.array([4, 5, 6])

# Element-wise operations
c = a + b
d = a * b

# Mathematical functions
e = jnp.sin(a)
f = jnp.log(b)
print(e, f)

def selu(x, alpha=1.67, lmbda=1.05):
  return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.arange(5.0)
print(selu(x))

[0.84147096 0.9092974  0.14112   ] [1.3862944 1.609438  1.7917595]
[0.        1.05      2.1       3.1499999 4.2      ]


## grad

In [7]:
from jax import grad

def f(x):
    return x ** 2

# Compute the gradient of f
df_dx = grad(f)
print(df_dx(3.0))  # Output: 6.0


6.0


## JIT


In [None]:
from jax import jit

@jit
def matmul(a, b):
    return jnp.dot(a, b)

x = jnp.ones((1000, 1000))
y = jnp.ones((1000, 1000))
result = matmul(x, y)


## Train a Simple Neural Net

### Hyperparams


In [10]:
# Install Flax
!pip install flax

import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
import flax.linen as nn
import optax

# Define the neural network using Flax
class MLP(nn.Module):
    hidden_dim: int
    output_dim: int

    def setup(self):
        self.dense1 = nn.Dense(self.hidden_dim)
        self.dense2 = nn.Dense(self.output_dim)

    def __call__(self, x):
        x = jax.nn.relu(self.dense1(x))
        x = self.dense2(x)
        return x

# Define the loss function
def loss_fn(params, batch):
    x, y = batch
    logits = model.apply(params, x)
    loss = jnp.mean(optax.softmax_cross_entropy(logits=logits, labels=y))
    return loss

# Define a function to compute gradients
@jit
def compute_grads(params, batch):
    return grad(loss_fn)(params, batch)

# Initialize model and optimizer
key = jax.random.PRNGKey(0)
model = MLP(hidden_dim=64, output_dim=10)
params = model.init(key, jnp.ones((1, 28*28)))  # Example input shape for MNIST
optimizer = optax.adam(learning_rate=0.001)
opt_state = optimizer.init(params)

# Define a function to apply gradients and update parameters
@jit
def update(params, batch, opt_state):
    grads = compute_grads(params, batch)
    updates, new_opt_state = optimizer.update(grads, opt_state)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state

# Example synthetic data
def create_batch(batch_size):
    x = jax.random.normal(key, (batch_size, 28*28))
    y = jax.nn.one_hot(jax.random.randint(key, (batch_size,), 0, 10), 10)
    return (x, y)

# Training loop
batch_size = 64
num_epochs = 100

for epoch in range(num_epochs):
    batch = create_batch(batch_size)
    params, opt_state = update(params, batch, opt_state)
    if epoch % 1 == 0:
        current_loss = loss_fn(params, batch)
        print(f"Epoch {epoch}, Loss: {current_loss:.4f}")

# Define a batched version of the model prediction function
batched_predict = vmap(lambda x: model.apply(params, x))

# Example batch prediction
batch = jax.random.normal(key, (10, 28*28))
predictions = batched_predict(batch)
print(predictions)


Epoch 0, Loss: 2.1131
Epoch 1, Loss: 1.7660
Epoch 2, Loss: 1.4649
Epoch 3, Loss: 1.2084
Epoch 4, Loss: 0.9971
Epoch 5, Loss: 0.8239
Epoch 6, Loss: 0.6833
Epoch 7, Loss: 0.5669
Epoch 8, Loss: 0.4714
Epoch 9, Loss: 0.3924
Epoch 10, Loss: 0.3264
Epoch 11, Loss: 0.2719
Epoch 12, Loss: 0.2274
Epoch 13, Loss: 0.1907
Epoch 14, Loss: 0.1602
Epoch 15, Loss: 0.1350
Epoch 16, Loss: 0.1141
Epoch 17, Loss: 0.0967
Epoch 18, Loss: 0.0823
Epoch 19, Loss: 0.0703
Epoch 20, Loss: 0.0604
Epoch 21, Loss: 0.0521
Epoch 22, Loss: 0.0453
Epoch 23, Loss: 0.0395
Epoch 24, Loss: 0.0347
Epoch 25, Loss: 0.0307
Epoch 26, Loss: 0.0273
Epoch 27, Loss: 0.0244
Epoch 28, Loss: 0.0220
Epoch 29, Loss: 0.0199
Epoch 30, Loss: 0.0181
Epoch 31, Loss: 0.0166
Epoch 32, Loss: 0.0153
Epoch 33, Loss: 0.0141
Epoch 34, Loss: 0.0131
Epoch 35, Loss: 0.0122
Epoch 36, Loss: 0.0114
Epoch 37, Loss: 0.0108
Epoch 38, Loss: 0.0101
Epoch 39, Loss: 0.0096
Epoch 40, Loss: 0.0091
Epoch 41, Loss: 0.0087
Epoch 42, Loss: 0.0083
Epoch 43, Loss: 0.007