# Installation
### To install Trax, use pip, the Python package installer. Open a terminal or command prompt and run the following command (in a Jupyter or Colab notebook cell, prefix it with `!`):

```shell
pip install trax
```


### Quick Start
#### Once Trax is installed, let's dive into a quick example to understand the basic workflow.

```python
from trax import layers as tl

# Define a plain Python function; its two parameters become the layer's two inputs
def multiply_by_scalar(x, scalar):
    return x * scalar

# Wrap the function as a Trax layer (tl.Fn infers n_in=2 from the signature)
multiply_layer = tl.Fn('MultiplyByScalar', multiply_by_scalar)

# Define some inputs
x = 3
scalar = 2

# Apply the layer; a layer with several inputs takes them as a tuple
result = multiply_layer((x, scalar))

print(result)
```

6


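In this example, we define a simple multiplication function and convert it into a Trax layer using `trax.layers.Fn`. Trax infers the number of inputs from the function's signature, so the layer takes two inputs; we pass them as the tuple `(x, scalar)` and obtain the result.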
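To make the mechanics concrete, here is a toy, pure-Python imitation of the idea behind `Fn`: count the wrapped function's parameters to decide how many inputs the layer takes, then unpack a tuple of inputs into the call. `make_fn_layer` is a hypothetical helper written only for illustration; it is not part of Trax, and the real layer also deals with weights, state, and JAX tracing.

```python
import inspect

def make_fn_layer(name, f):
    """Hypothetical helper (not Trax): wrap f so it can be called on a tuple of inputs."""
    n_in = len(inspect.signature(f).parameters)  # one input per parameter of f

    def layer(inputs):
        # A one-input layer receives the bare value; multi-input layers get a tuple.
        args = (inputs,) if n_in == 1 else tuple(inputs)
        if len(args) != n_in:
            raise ValueError(f"{name} expects {n_in} inputs, got {len(args)}")
        return f(*args)

    layer.name, layer.n_in = name, n_in
    return layer

def multiply_by_scalar(x, scalar):
    return x * scalar

multiply_layer = make_fn_layer('MultiplyByScalar', multiply_by_scalar)
print(multiply_layer.n_in)     # 2
print(multiply_layer((3, 2)))  # 6
```

Counting parameters is what lets a one-argument function act as a one-input layer while `multiply_by_scalar` becomes a two-input layer, with no extra configuration.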

### Building Models
#### Trax allows you to build complex models by stacking layers together. Here's an example of building a simple feedforward neural network:

```python
import jax
import numpy as np
from trax import layers as tl
from trax import shapes

# Define the model architecture: Dense -> Relu -> Dense
model = tl.Serial(
    tl.Dense(128),
    tl.Relu(),
    tl.Dense(10)
)

# Describe the input: a batch containing one 784-dimensional vector
input_signature = shapes.ShapeDtype((1, 784), dtype=np.float32)

# Initialize the model's weights from the input signature
model.init(input_signature)

# Generate random input
rng_key = jax.random.PRNGKey(0)  # Create a random number generator key
inputs = jax.random.normal(rng_key, shape=(1, 784))

# Run the input through the model
output = model(inputs)

# Print the output shape
print(output.shape)  # Output: (1, 10)
```

(1, 10)


In this example, we use `tl.Serial` to stack layers into a multi-layer perceptron (MLP): a `Dense` hidden layer of 128 units, a `Relu` activation, and a `Dense` output layer of 10 units. We initialize the model's weights from an input signature of shape (1, 784), generate a random input of that shape, and run it through the model to obtain an output of shape (1, 10).
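`tl.Serial` feeds each layer's output into the next, and a `Dense(n)` layer computes an affine map `x @ W + b`. The NumPy sketch below traces the same Dense(128), Relu, Dense(10) computation by hand to show how the shape goes from (1, 784) to (1, 10), and counts the parameters the two Dense layers hold. The random weights and their scale are assumptions for illustration, not the initializer Trax uses.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical random weights; the 0.05 scale is an arbitrary illustrative choice.
w1, b1 = rng.normal(scale=0.05, size=(784, 128)), np.zeros(128)
w2, b2 = rng.normal(scale=0.05, size=(128, 10)), np.zeros(10)

def forward(x):
    h = x @ w1 + b1         # Dense(128): (1, 784) -> (1, 128)
    h = np.maximum(h, 0.0)  # Relu: elementwise, shape unchanged
    return h @ w2 + b2      # Dense(10): (1, 128) -> (1, 10)

x = rng.normal(size=(1, 784))
print(forward(x).shape)  # (1, 10)

# Each Dense layer holds a weight matrix plus a bias vector:
# (784 * 128 + 128) + (128 * 10 + 10)
n_params = sum(a.size for a in (w1, b1, w2, b2))
print(n_params)  # 101770
```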



### Training a Model
#### Trax provides utilities for training models. Here's an example of training a simple model on the MNIST dataset:

```python
import numpy as np
import trax
from trax import layers as tl
from trax.supervised import training

# Load MNIST as streams of (image, label) tuples
train_stream = trax.data.TFDS('mnist', keys=('image', 'label'), train=True)()
eval_stream = trax.data.TFDS('mnist', keys=('image', 'label'), train=False)()

# The streams are generators, so normalize them lazily instead of
# materializing the whole dataset with list()
def normalize_images(stream):
    for image, label in stream:
        yield image.astype(np.float32) / 255.0, label  # Scale pixels to [0, 1]

# Build the input pipelines: normalize, (shuffle,) batch, and add loss weights,
# since the tasks consume (inputs, targets, weights) batches
train_pipeline = trax.data.Serial(
    normalize_images,
    trax.data.Shuffle(),
    trax.data.Batch(64),
    trax.data.AddLossWeights(),
)
eval_pipeline = trax.data.Serial(
    normalize_images,
    trax.data.Batch(64),
    trax.data.AddLossWeights(),
)
train_batches = train_pipeline(train_stream)
eval_batches = eval_pipeline(eval_stream)

# Define the model architecture
model = tl.Serial(
    tl.Flatten(),  # Flatten each image: (batch, 28, 28, 1) -> (batch, 784)
    tl.Dense(256),
    tl.Relu(),
    tl.Dense(10),
    tl.LogSoftmax()  # Log-probabilities over the 10 digit classes
)

# Initialize the model's weights from an input signature
model.init(trax.shapes.ShapeDtype((1, 28, 28, 1), dtype=np.float32))

# Define the training task. The loss layer receives the model output,
# the targets, and the weights; it must not call the model itself.
train_task = training.TrainTask(
    labeled_data=train_batches,
    loss_layer=tl.CrossEntropyLoss(),
    optimizer=trax.optimizers.Adam(learning_rate=0.01),
    n_steps_per_checkpoint=500
)

# Define the evaluation task
eval_task = training.EvalTask(
    labeled_data=eval_batches,
    metrics=[tl.CrossEntropyLoss(), tl.Accuracy()],
    n_eval_batches=10  # Average the metrics over 10 eval batches
)

# Create the training loop
training_loop = training.Loop(
    model,
    [train_task],
    eval_tasks=[eval_task],
    output_dir='./mnist_model'  # Directory for checkpoints
)

# Run the training loop
training_loop.run(n_steps=1000)
print("Done!")
```


In this example, we load the MNIST dataset using `trax.data.TFDS`, normalize the pixel values to [0, 1], and define a simple model with `tl.Serial` that ends in a `LogSoftmax` layer, so the cross-entropy loss can score its log-probabilities directly. We then wrap the training data, loss, and Adam optimizer in a `TrainTask`, wrap the evaluation data and metrics (cross-entropy and accuracy) in an `EvalTask`, and finally create a `training.Loop` and run it for a fixed number of steps (1,000 here).
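In Trax, loss and metric layers such as `tl.CrossEntropyLoss` are given per-example weights alongside the model output and the targets. The NumPy sketch below spells out what a weighted cross-entropy and a weighted accuracy compute on the log-probabilities produced by the model's final `LogSoftmax` layer: the weighted mean of the negative log-probability assigned to the true class, and the weighted fraction of examples whose argmax matches the label. The helper names are hypothetical; this mirrors the math, not Trax's internal code.

```python
import numpy as np

def log_softmax(logits):
    # Subtract the row max for numerical stability before exponentiating.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))

def weighted_cross_entropy(log_probs, targets, weights):
    # Log-probability of each example's true class, negated, then weighted mean.
    picked = log_probs[np.arange(len(targets)), targets]
    return -(picked * weights).sum() / weights.sum()

def weighted_accuracy(log_probs, targets, weights):
    # 1.0 where the most likely class equals the label, else 0.0; weighted mean.
    correct = (log_probs.argmax(axis=-1) == targets).astype(np.float32)
    return (correct * weights).sum() / weights.sum()

logits = np.array([[2.0, 0.5, -1.0],
                   [0.1, 0.2, 3.0]])
targets = np.array([0, 1])
weights = np.ones_like(targets, dtype=np.float32)  # Uniform: every example counts equally

log_probs = log_softmax(logits)
print(weighted_cross_entropy(log_probs, targets, weights))
print(weighted_accuracy(log_probs, targets, weights))  # 0.5
```

With uniform weights both quantities reduce to plain means; non-uniform weights let padded or ignored examples contribute nothing.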