In [None]:
!pip install keras keras-hub matplotlib --upgrade -q

In [None]:
import os
os.environ["KERAS_BACKEND"] = "jax"

In [None]:
# @title
import os
from IPython.core.magic import register_cell_magic

@register_cell_magic
def backend(line, cell):
    current, required = os.environ.get("KERAS_BACKEND", ""), line.split()[-1]
    if current == required:
        get_ipython().run_cell(cell)
    else:
        print(
            f"This cell requires the {required} backend. To run it, change KERAS_BACKEND to "
            f"\"{required}\" at the top of the notebook, restart the runtime, and rerun the notebook."
        )

## Chapter 3 - Introduction to TensorFlow, PyTorch, JAX, and Keras
In this chapter, we will explore the most popular deep learning frameworks: TensorFlow, PyTorch, JAX, and Keras. We will discuss their features, advantages, and disadvantages, and provide examples of how to use each framework for building and training neural networks.
What does a deep learning framework provide?
1. **Tensor operations**: Efficient implementations of tensor operations for CPU (and optionally GPU).
1. **Gradient computation**: Automatic differentiation to compute gradients for backpropagation.

Those frameworks may also provide higher-level functionalities, such as:
1. **Neural network layers**: Abstractions for common neural network layers (e.g., Dense, Convolutional).
1. **Model abstractions**: High-level APIs for defining, training, and evaluating models.
1. **Loss functions and optimizers**: Predefined loss functions and optimization algorithms.

Keras, TensorFlow, PyTorch, and JAX do not all offer the same features. Keras is a high-level API that can run on top of the other three frameworks, providing a user-friendly interface for building and training models. TensorFlow and PyTorch are more comprehensive frameworks that provide both low-level tensor operations and high-level model abstractions, while JAX focuses mainly on low-level tensor operations, automatic differentiation, and compilation.

### TensorFlow
TensorFlow is an open-source deep learning framework developed by Google. It provides a comprehensive ecosystem for building and deploying machine learning models. TensorFlow supports both CPU and GPU computations and offers a wide range of tools and libraries for various machine learning tasks.
Let's see some syntax examples in TensorFlow.

#### Constant tensors
Constant tensors are immutable tensors whose values cannot be changed after they are created. In TensorFlow, you can create constant tensors using the functions below:

In [None]:
# Constant tensors
import tensorflow as tf

# Useful for initializations
t = tf.ones(shape=(2, 1))
print(t, "\n")

t = tf.zeros(shape=(2, 1))
print(t, "\n")

t = tf.constant([1, 2, 3], dtype="float32")
print(t, "\n")


# Random tensors
# Normal distribution with mean and standard deviation
t = tf.random.normal(shape=(3, 1), mean=0., stddev=1.)
print(t, "\n")


# Uniform distribution between minval (inclusive) and maxval (exclusive)
t = tf.random.uniform(shape=(3, 1), minval=0., maxval=1.)
print(t, "\n")

#### Variable tensors
Constant tensors are immutable, so they are not suitable for representing model parameters that need to be updated during training. For this purpose, TensorFlow provides variable tensors, which are mutable and can be updated using optimization algorithms. You can create variable tensors using the `tf.Variable` class. An initial value must be provided when creating a variable tensor.

In [None]:
# If uncommented, this code raises an exception: constant tensors do not support item assignment
# t = tf.ones(shape=(1, 3), dtype=tf.dtypes.float32)
# t[0, 1] = 0.0

v = tf.Variable(initial_value=tf.random.normal(shape=(3, 1)))
print(v, "\n")
v.assign(tf.ones((3, 1)))
print(v, "\n")

v[0, 0].assign(3.)
print(v, "\n")

v.assign_add(tf.ones((3, 1)))
print(v, "\n")


#### Tensor operations
Here are some examples of basic tensor operations in TensorFlow:

In [None]:
a = tf.constant(shape=(2, 2), value=[2.0, 2.0, 2.0, 2.0])
print(a, "\n")
b = tf.square(a)
print(b, "\n")
c = tf.sqrt(a)
print(c, "\n")
d = b + c
print(d, "\n")
e = a @ b # Equivalent to tf.matmul(a, b)
print(e, "\n")
f = tf.concat((a, b), axis=0)
print(f, "\n")
g = tf.concat((a, b), axis=1)
print(g, "\n")

With these operations, we can implement the Dense layer from scratch.

In [None]:
def dense(inputs, W, b):
    return tf.nn.relu((inputs @ W) + b)
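To see the function in action, here is a small usage sketch with hypothetical toy values: a batch of two samples with three features each, projected to two output units. The function is repeated so the block runs on its own. The second sample produces negative pre-activations, which the ReLU sets to zero.

```python
import tensorflow as tf

def dense(inputs, W, b):
    # Affine transformation followed by a ReLU non-linearity
    return tf.nn.relu((inputs @ W) + b)

# Hypothetical toy values: 2 samples x 3 features, 2 output units
inputs = tf.constant([[1.0, 2.0, 3.0],
                      [-1.0, -2.0, -3.0]])
W = tf.ones(shape=(3, 2))
b = tf.constant([0.5, -0.5])  # one bias per output unit

outputs = dense(inputs, W, b)
print(outputs.numpy().tolist())  # [[6.5, 5.5], [0.0, 0.0]]
```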

#### Gradient computation
The automatic differentiation feature in TensorFlow is exposed via the `tf.GradientTape` API, which records operations so that their gradients can be computed afterwards. To compute gradients, you must open a `tf.GradientTape` context around the operations of interest. Beware that only operations executed inside the context on tensors watched by the tape are recorded. By default, every trainable `tf.Variable` accessed inside the context is watched automatically; other tensors, such as constants, can be watched manually with the `watch` method of the `GradientTape` object. Here is an example of how to use `tf.GradientTape` to compute gradients:

In [None]:
input_var = tf.Variable(initial_value=3.0)
with tf.GradientTape() as tape:
    # function: y = x^2
    result = tf.square(input_var)
gradient = tape.gradient(result, input_var)
# gradient (in this case derivative) of the function: y' = 2 * x
# Evaluated in x = 3.0 --> y' = 2 * 3 = 6
print(gradient)

In [None]:
# Same as above, but with constants
input_const = tf.constant(3.0)
with tf.GradientTape() as tape:
    tape.watch(input_const) # required to compute the gradient w.r.t. constants
    result = tf.square(input_const)
gradient = tape.gradient(result, input_const)
print(gradient)
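The watching rules can be checked with a short sketch, assuming the default `watch_accessed_variables=True`. The function `y = w * x + b + frozen` and all variable names are illustrative. Trainable variables are watched automatically. A variable created with `trainable=False` is not watched, so its gradient comes back as `None`. Passing a list of sources to `tape.gradient` returns the gradients in the same order.

```python
import tensorflow as tf

x = tf.constant(2.0)                        # plain input: not watched
w = tf.Variable(3.0)                        # trainable: watched automatically
b = tf.Variable(1.0)                        # trainable: watched automatically
frozen = tf.Variable(5.0, trainable=False)  # non-trainable: NOT watched by default

with tf.GradientTape() as tape:
    y = w * x + b + frozen

# One gradient per source, in the same order as the list
dw, db, dfrozen = tape.gradient(y, [w, b, frozen])
print(dw, db, dfrozen)  # dw = x = 2.0, db = 1.0, dfrozen = None
```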

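It is possible to nest two `GradientTape` contexts to compute higher-order derivatives: the outer tape records the gradient computation performed by the inner tape. Here is an example that computes the second derivative of the position of an object in free fall, `z = -(9.81 / 2) * t^2`, with respect to time: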

In [None]:
time = tf.Variable([4.0])
with tf.GradientTape() as second_order_tape:
    with tf.GradientTape() as first_order_tape:
        z = -(9.81 / 2) * time ** 2
    z_dot = first_order_tape.gradient(z, time)
z_ddot = second_order_tape.gradient(z_dot, time)
print("Acceleration:", z_ddot)
print("Speed:", z_dot)
print("Position:", z)

#### Compilation with tf.function
TensorFlow provides the `tf.function` decorator to compile Python functions into TensorFlow graphs. This can significantly improve performance, because TensorFlow can optimize the graph as a whole instead of executing operations one by one. Passing `jit_compile=True` additionally compiles the graph with XLA (Accelerated Linear Algebra) to further optimize performance on supported hardware.

In [None]:
import numpy as np

input = np.ones(shape=(1, 3), dtype="float32")
W = tf.Variable(tf.ones(shape=(3, 1), dtype=tf.dtypes.float32))
b = tf.Variable(tf.ones(shape=(1,), dtype=tf.dtypes.float32))  # one bias per output unit

@tf.function
def dense(inputs, W, b):
    return tf.nn.relu(tf.matmul(inputs, W) + b)
print(dense(input, W, b), "\n")

@tf.function(jit_compile=True)
def dense_jit(inputs, W, b):
    return tf.nn.relu(tf.matmul(inputs, W) + b)
print(dense_jit(input, W, b))


As the previous examples show, you can pass a NumPy array (`np.ndarray`) to a TensorFlow function, and it is automatically converted to a TensorFlow tensor.
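One consequence of compiling with `tf.function` is that plain Python code in the function body runs only while TensorFlow traces the function into a graph. The sketch below uses a hypothetical `scaled_sum` function and a Python list as a trace counter. Under the default settings, calls with the same input shape and dtype reuse the existing graph, while a new shape triggers a retrace.

```python
import tensorflow as tf

traces = []  # Python-side log: appended only while TensorFlow traces the function

@tf.function
def scaled_sum(x):
    traces.append(x.shape)  # Python side effect: runs at trace time only
    return tf.reduce_sum(2.0 * x)

scaled_sum(tf.constant([1.0, 2.0]))                # first call: traces a graph
result = scaled_sum(tf.constant([3.0, 4.0]))       # same shape/dtype: graph reused
scaled_sum(tf.constant([1.0, 2.0, 3.0]))           # new shape: retrace
print(len(traces), result.numpy())
```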
#### Example: Linear regression with TensorFlow
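With what we have seen so far, we can already build a complete end-to-end example: a linear regression model trained with gradient descent in TensorFlow. We will use it to classify points from two classes in 2D space: the model outputs a score for each point, and points whose score is above 0.5 are assigned to the positive class. Let's first create our synthetic dataset: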

In [None]:
import numpy as np

num_samples_per_class = 1000
# Each class is sampled from a 2D Gaussian with its own mean and a shared covariance
negative_samples = np.random.multivariate_normal(
    mean=[-2.5, 0], cov=[[1, 0.5], [0.5, 1]], size=num_samples_per_class
)
positive_samples = np.random.multivariate_normal(
    mean=[2.5, 0], cov=[[1, 0.5], [0.5, 1]], size=num_samples_per_class
)

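dataset = np.concatenate((negative_samples, positive_samples), axis=0).astype(np.float32)
print("Shape dataset:", dataset.shape)

# Label 0 for the negative class, 1 for the positive class
labels = np.concatenate(
    (np.zeros(shape=(num_samples_per_class, 1)), np.ones(shape=(num_samples_per_class, 1))),
    axis=0,
).astype(np.float32)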
print("Shape labels:", labels.shape)


We can use Matplotlib to visualize our dataset:

In [None]:
import matplotlib.pyplot as plt

plt.scatter(dataset[:, 0], dataset[:, 1], c=labels[:, 0])
plt.show()

Now let's implement the linear regression model in TensorFlow: we define the model, the loss function (mean squared error), and a gradient descent training step, and then train the model on our dataset.

In [None]:
import tensorflow as tf
input_dim = 2
output_dim = 1
learning_rate = 1e-1

W = tf.Variable(tf.random.uniform(shape=(input_dim, output_dim), dtype=tf.dtypes.float32))
b = tf.Variable(tf.zeros(shape=(output_dim, ), dtype=tf.dtypes.float32))

def model(input, W, b):
    return tf.matmul(input, W) + b

def lossFunction(output, label):
    # Mean squared error between predictions and labels
    return tf.reduce_mean(tf.square(label - output))

# Uncomment one of the decorators below to compile the training step
# (with jit_compile=True, the graph is also compiled with XLA)
# @tf.function(jit_compile=True)
# @tf.function
def batchTrainingStep(input_batch, labels, W, b):
    with tf.GradientTape() as tape:
        prediction = model(input_batch, W, b)
        loss = lossFunction(prediction, labels)
    dW, db = tape.gradient(loss, [W, b])
    # Gradient descent update: move the parameters against the gradient
    W.assign_sub(learning_rate * dW)
    b.assign_sub(learning_rate * db)
    return loss

# Each step uses the whole dataset as a single batch
for i in range(40):
    loss = batchTrainingStep(dataset, labels, W, b)
    print(f"Loss at step {i}: {loss:.4f}")


In [None]:
predictions = model(dataset, W, b)
x = np.linspace(-2, 2, 100)
y = -W[0] / W[1] * x + (0.5 - b) / W[1]
plt.plot(x, y, "-r")
plt.scatter(dataset[:, 0], dataset[:, 1], c=predictions[:, 0] > 0.5)
plt.show()

### PyTorch
PyTorch is an open-source deep learning framework originally developed by Facebook's AI Research lab (now part of Meta). It has first-class support on Hugging Face, the popular model-sharing platform. It supports both CPU and GPU computations and offers a wide range of tools and libraries for various machine learning tasks. Beyond tensor operations and automatic differentiation, the PyTorch API also includes higher-level abstractions for layers, models, and optimizers, which make it easier to build and train neural networks.

Let's see some syntax examples in PyTorch.

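#### Creating tensors
In PyTorch, tensors are created with functions similar to TensorFlow's. Note the naming differences highlighted in the comments.

In [None]:
# Creating tensors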
import torch

# Useful for initializations
t = torch.ones(size=(2, 1)) # PyTorch calls this argument size; TensorFlow, NumPy, and JAX call it shape
print(t, "\n")

t = torch.zeros(size=(2, 1))
print(t, "\n")

t = torch.tensor([1, 2, 3], dtype=torch.float32)
print(t, "\n")

# Random tensors
# When mean and std are tensors, the output takes their shape
t = torch.normal(
    mean=torch.zeros(size=(3, 1)),
    std=torch.ones(size=(3, 1)),
)
print(t, "\n")

# Random tensor with uniform distribution in [0, 1)
t = torch.rand(3, 1) # Dimensions can be separate arguments (torch.rand((3, 1)) also works); tf.random.uniform needs a shape
print(t, "\n")

# Another difference from TensorFlow: elements of tensors are assignable
t[0, 0] = 1.
print(t)

Similar to `Variable`s in TensorFlow, PyTorch provides `Parameter`s: tensors meant to be optimized during training, which track gradients (`requires_grad=True`) by default.

In [None]:
t = torch.zeros(size=(2, 1))
p = torch.nn.Parameter(data=t) # torch.nn.Parameter is the usual alias of torch.nn.parameter.Parameter
print(p)
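Besides tracking gradients, a `Parameter` has a second role: when assigned as an attribute of a `torch.nn.Module`, it is registered automatically, so it shows up in `parameters()` and can be handed to an optimizer. A plain tensor attribute is not registered. The sketch below uses a hypothetical `Tiny` module to show the difference.

```python
import torch

p = torch.nn.Parameter(torch.zeros(2, 1))
print(p.requires_grad)  # True: Parameters track gradients by default

class Tiny(torch.nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        self.w = torch.nn.Parameter(torch.zeros(2, 1))  # registered as a parameter
        self.scale = torch.zeros(2, 1)                   # plain tensor: not registered

names = [name for name, _ in Tiny().named_parameters()]
print(names)  # ['w']
```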

#### Tensor operations
Here are some examples of basic tensor operations in PyTorch:

In [None]:
a = torch.tensor([[2.0, 2.0], [2.0, 2.0]], dtype=torch.float32)
print(a, "\n")
b = torch.square(a)
print(b, "\n")
c = torch.sqrt(a)
print(c, "\n")
d = b + c
print(d, "\n")
e = a @ b # Equivalent to torch.matmul(a, b)
print(e, "\n")
f = torch.cat((a, b), dim=0)
print(f, "\n")
g = torch.cat((a, b), dim=1)
print(g, "\n")


And that's how a dense layer can be implemented from scratch in PyTorch.

In [None]:
def dense(inputs, W, b):
    return torch.relu(torch.matmul(inputs, W) + b)

input = torch.tensor([1.0, 1.0], dtype=torch.float32)
W = torch.tensor([[2.0], [2.0]], dtype=torch.float32)
b = torch.tensor([[3.0]], dtype=torch.float32)

print(dense(input, W, b))

#### Gradient computation
Behind the scenes, PyTorch uses a dynamic computation graph to compute gradients. This means that the graph is built on-the-fly while we perform operations on tensors. To enable gradient computation for a tensor, we need to set the `requires_grad` attribute to `True` when creating the tensor. PyTorch will then automatically track all operations performed on that tensor and build the computation graph accordingly. Here is an example of how to compute gradients in PyTorch:

In [None]:
input_var = torch.tensor(3.0, requires_grad=True)
result = torch.square(input_var)
result.backward() # Performing back propagation on the output
print(input_var.grad) # the 'grad' attribute gets filled

result = torch.square(input_var)
result.backward()
print(input_var.grad) # Beware: it accumulates, it DOES NOT get overwritten

input_var.grad = None # Setting grad to None resets the accumulated gradient
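With several tensors, `backward()` fills the `grad` attribute of every leaf tensor created with `requires_grad=True`, while tensors without it keep `grad` set to `None`. Operations run under `torch.no_grad()` are not tracked at all. A minimal sketch with the illustrative function `y = w * x + b`:

```python
import torch

x = torch.tensor(2.0)                      # plain input: no gradient tracked
w = torch.tensor(3.0, requires_grad=True)
b = torch.tensor(1.0, requires_grad=True)

y = w * x + b
y.backward()
print(w.grad, b.grad, x.grad)  # dy/dw = x = 2.0, dy/db = 1.0, x.grad is None

# Inside no_grad, the result is not connected to the computation graph
with torch.no_grad():
    z = w * x
print(z.requires_grad)  # False
```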


#### Example: Linear regression with PyTorch
We can implement the same linear regression model with the PyTorch tools introduced so far.

In [None]:
input_dim = 2
output_dim = 1
learning_rate = 0.1

W = torch.rand(input_dim, output_dim, requires_grad=True)
b = torch.zeros(output_dim, requires_grad=True)

def model(inputs, W, b):
    return inputs @ W + b

def mean_squared_error(targets, predictions):
    per_sample_losses = torch.square(targets - predictions)
    return torch.mean(per_sample_losses)


def training_step(inputs, targets, W, b):
    predictions = model(inputs, W, b)
    loss = mean_squared_error(targets, predictions)
    loss.backward()
    grad_loss_wrt_W, grad_loss_wrt_b = W.grad, b.grad
    with torch.no_grad():
        W -= grad_loss_wrt_W * learning_rate
        b -= grad_loss_wrt_b * learning_rate
    W.grad = None
    b.grad = None
    return loss

dataset_torch = torch.tensor(dataset)
labels_torch = torch.tensor(labels)
for i in range(40):
    loss = training_step(dataset_torch, labels_torch, W, b)
    print(f"Loss at step {i}: {loss.item():.4f}")

predictions = model(dataset_torch, W, b)
W_plot = W.detach().numpy()
b_plot = b.detach().numpy()
x = np.linspace(-2, 2, 100)
y = -W_plot[0] / W_plot[1] * x + (0.5 - b_plot) / W_plot[1]
plt.plot(x, y, "-r")
plt.scatter(dataset[:, 0], dataset[:, 1], c=predictions[:, 0] > 0.5)
plt.show()

We implemented the linear regression model using only the low-level interface of PyTorch, without using any high-level abstractions.
One of the high-level abstractions provided by PyTorch is the `Module` class, which can be used to define neural network layers and models.

In [None]:
class LinearModel(torch.nn.Module):
    # Parameters are created in the constructor: assigning a 'Parameter'
    # as an attribute registers it with the module
    def __init__(self):
        super().__init__()
        self.W = torch.nn.Parameter(torch.rand(input_dim, output_dim))
        self.b = torch.nn.Parameter(torch.zeros(output_dim))

    # Subclasses must override the 'forward' method; calling model(x) runs it
    def forward(self, inputs):
        return torch.matmul(inputs, self.W) + self.b

# Let's call the model; the output is random because the weights are randomly initialized
model = LinearModel()
model(torch.tensor([2.0, 2.0]))


Another helpful high-level interface is the `torch.optim` module, which provides various optimization algorithms for training neural networks. We will use Stochastic Gradient Descent (SGD) to optimize our linear regression model.

In [None]:
learning_rate = 0.1

model = LinearModel()

optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

def training_step(inputs, targets):
    predictions = model(inputs)
    loss = mean_squared_error(targets, predictions)
    loss.backward()
    optimizer.step()
    model.zero_grad()
    return loss
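The `training_step` above is only defined, not run. Here is a self-contained sketch of the full loop, under a few assumptions. The data is hypothetical synthetic data similar to the earlier dataset: two 2D Gaussian blobs generated with `torch.randn`. `LinearModel` takes its dimensions as constructor arguments instead of reading globals. The built-in `torch.nn.MSELoss` replaces the hand-written `mean_squared_error`. Gradients are cleared with `optimizer.zero_grad()` before `backward()`, which has the same effect as calling `model.zero_grad()` after the update.

```python
import torch

torch.manual_seed(0)

# Hypothetical synthetic data: two 2D Gaussian blobs labelled 0 and 1
num_samples_per_class = 500
negative = torch.randn(num_samples_per_class, 2) + torch.tensor([-2.5, 0.0])
positive = torch.randn(num_samples_per_class, 2) + torch.tensor([2.5, 0.0])
inputs = torch.cat((negative, positive), dim=0)
targets = torch.cat((torch.zeros(num_samples_per_class, 1),
                     torch.ones(num_samples_per_class, 1)), dim=0)

class LinearModel(torch.nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.W = torch.nn.Parameter(torch.rand(input_dim, output_dim))
        self.b = torch.nn.Parameter(torch.zeros(output_dim))

    def forward(self, inputs):
        return inputs @ self.W + self.b

model = LinearModel(input_dim=2, output_dim=1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = torch.nn.MSELoss()

losses = []
for step in range(40):
    predictions = model(inputs)
    loss = loss_fn(predictions, targets)
    optimizer.zero_grad()  # clear gradients left over from the previous step
    loss.backward()        # fill the .grad attribute of every parameter
    optimizer.step()       # SGD update: param <- param - lr * param.grad
    losses.append(loss.item())
print(f"First loss: {losses[0]:.4f}, last loss: {losses[-1]:.4f}")

# Accuracy of the 0.5 threshold on the training data
with torch.no_grad():
    accuracy = ((model(inputs) > 0.5).float() == targets).float().mean().item()
print(f"Accuracy: {accuracy:.3f}")
```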

Finally, PyTorch can compile functions with the `torch.compile` decorator (which can also be called as a function on a model) to optimize their execution.

In [None]:
@torch.compile
def dense(inputs, W, b):
    return torch.relu(torch.matmul(inputs, W) + b)

However, compilation is opt-in, and most PyTorch code still runs in eager mode, executing one operation at a time. As a result, PyTorch is often the slowest of the frameworks presented here.

### JAX
JAX is an open-source library for high-performance numerical computing developed by Google and widely used for deep learning. It provides a powerful automatic differentiation system and supports CPU, GPU, and TPU. It is often the fastest of the frameworks presented here.

JAX encourages a functional programming style, where functions are pure and have no side effects. This is crucial: pure functions can be safely transformed, which lets JAX compile them and parallelize them across GPU and TPU hardware.

Let's see some syntax examples in JAX.

In [None]:
from jax import numpy as jnp

# Useful for initializations
t = jnp.ones(shape=(2, 1))
print(t, "\n")

t = jnp.zeros(shape=(2, 1))
print(t, "\n")

t = jnp.array([1, 2, 3], dtype="float32")
print(t, "\n")

The `jax.numpy` module mirrors the NumPy API. However, there are a few differences, most notably regarding array assignments and random number generation.

#### Tensor assignments
In JAX, arrays are immutable: once created, their values cannot be changed (in-place mutation would go against the functional programming paradigm). Instead of modifying an array in place, you create a new array with the desired changes through the `.at[]` property, which provides methods such as `set`, `add`, and `multiply`.

In [None]:
t = jnp.array([1, 2, 3], dtype="float32")
print(t, "\n")
new_t = t.at[0].set(10)
print(new_t, "\n")
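The `.at[]` property accepts the same indices and slices as NumPy and offers other update methods besides `set`. Every call returns a new array and leaves the original untouched. A short sketch:

```python
import jax.numpy as jnp

t = jnp.array([1.0, 2.0, 3.0])
added = t.at[1].add(5.0)           # new array with t[1] + 5
scaled = t.at[:2].multiply(10.0)   # new array with the first two entries scaled
print(t)       # [1. 2. 3.]: the original array is unchanged
print(added)   # [1. 7. 3.]
print(scaled)  # [10. 20.  3.]
```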

#### Pseudo-Random Number Generation
Generating random numbers with a Pseudo-Random Number Generator (PRNG) usually requires holding a state, which defines the current position of the PRNG in its sequence of random numbers. The stateless design of JAX requires you to pass this state, called a key, explicitly to every random number generation function. These functions do not update the key: calling them twice with the same key returns the same numbers, so new keys must be derived explicitly with `jax.random.split`. This design choice allows for better reproducibility and easier parallelization of code.

In [None]:
import jax

seed_key = jax.random.key(1337) # A seed is needed to generate a key, which will be used for the PRNGs
t = jax.random.normal(seed_key, shape=(3,))
print(t, "\n")

# Given the stateless design, if we provide the same key, we will get the same value
t = jax.random.normal(seed_key, shape=(3,))
print(t, "\n")

# The function split deterministically derives new keys from an existing one.
# In this way, once we choose a first seed, we can generate more numbers reproducibly
seed_key, new_seed_key = jax.random.split(seed_key)
t = jax.random.normal(new_seed_key, shape=(3,))
print(t, "\n")
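When many random draws are needed, a common pattern is to split the key at every step: one half is kept for future splits and the other is consumed exactly once. The sketch below wraps this in a hypothetical helper, `sample_batches`. Two runs started from the same seed produce identical numbers, while successive draws within a run differ.

```python
import jax

def sample_batches(seed, num_batches):
    # Hypothetical helper: draws num_batches random vectors from a single seed
    key = jax.random.key(seed)
    batches = []
    for _ in range(num_batches):
        # Keep 'key' for future splits; use 'subkey' exactly once
        key, subkey = jax.random.split(key)
        batches.append(jax.random.normal(subkey, shape=(3,)))
    return batches

run_a = sample_batches(1337, num_batches=2)
run_b = sample_batches(1337, num_batches=2)
print(run_a[0], run_a[1])
```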

#### Tensor operations
In JAX, you can perform basic tensor operations using the same functions as in NumPy. Here are some examples:

In [None]:
from jax import numpy as jnp
a = 2.0 * jnp.ones((2, 2))
print(a, "\n")
b = jnp.square(a)
print(b, "\n")
c = jnp.sqrt(a)
print(c, "\n")
d = b + c
print(d, "\n")
e = a @ b # Equivalent to jnp.matmul(a, b)
print(e, "\n")
e *= d # Element-wise multiplication: creates a new array and rebinds e (JAX arrays are immutable)
print(e)

And that's how a dense layer can be implemented from scratch in JAX.

In [None]:
def dense(inputs, W, b):
    return jax.nn.relu(jnp.matmul(inputs, W) + b)

#### Gradient computation
JAX provides the `jax.grad` function to compute gradients of functions. `jax.grad` is a function transformation: it takes a function as input and returns a new function that computes the gradient of the input function with respect to its first argument (other arguments can be selected with the `argnums` parameter). Here is an example of how to use `jax.grad` to compute gradients:

In [None]:
import jax

def compute_loss(input_var):
    return jnp.square(input_var)

# To get the gradient of a function, it should satisfy the following properties:
#   1. It should return a scalar loss value (optionally, it can have other outputs)
#   2. Its first argument (an array, or a list or dict of arrays) should contain
#      the arrays we want to compute the gradient for
grad_fn = jax.grad(compute_loss)

input_var = jnp.array(3.0)
grad_of_loss_wrt_input_var = grad_fn(input_var)
print(grad_of_loss_wrt_input_var, "\n")


# Thanks to 'jax.value_and_grad', it is possible to get loss and gradient in one pass
grad_fn = jax.value_and_grad(compute_loss)
output, grad_of_loss_wrt_input_var = grad_fn(input_var)
print("loss: ", output)
print("grad: ", grad_of_loss_wrt_input_var, "\n")

# If the function has more inputs, the gradient function takes the same inputs
# (no special flag is needed). If it has extra outputs, pass has_aux=True:
# the function must then return a (loss, auxiliary_data) pair
def dense(state, inputs):
    (W, b) = state
    return jax.nn.relu(jnp.matmul(inputs, W) + b)[0], inputs

grad_fn = jax.value_and_grad(dense, has_aux=True)

input = jnp.array([1.0, 0.0])
W = jnp.array([1.0, 1.0])
b = jnp.array([2.0])

# Returns first the output of the differentiated function (with the auxiliary data),
# then the gradient with respect to the first argument, the (W, b) state
(loss, other), grad_of_loss_wrt_state = grad_fn((W, b), input)
print("loss:", loss)
print("grad:", grad_of_loss_wrt_state)
print("inputs:", other)
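To differentiate with respect to several arguments at once, `jax.grad` accepts an `argnums` parameter; with a tuple, the gradient function returns one gradient per selected argument. A sketch with an illustrative squared-error function, evaluated at `w = 1.0`, `b = 0.5`, `x = 2.0`, `target = 0.0`:

```python
import jax

def squared_error(w, b, x, target):
    prediction = w * x + b
    return (prediction - target) ** 2

# Differentiate with respect to w (argument 0) and b (argument 1) only
grad_fn = jax.grad(squared_error, argnums=(0, 1))
dw, db = grad_fn(1.0, 0.5, 2.0, 0.0)
# prediction = 2.5, so dw = 2 * 2.5 * x = 10.0 and db = 2 * 2.5 = 5.0
print(dw, db)
```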


Finally, JAX also provides the `jax.jit` decorator to compile functions for improved performance. The `jax.jit` decorator compiles a function into a more efficient representation that can be executed faster on supported hardware (CPU, GPU, TPU). Here is an example of how to use `jax.jit` to compile a function:

In [None]:
@jax.jit
def dense(inputs, W, b):
    return jax.nn.relu(jnp.matmul(inputs, W) + b)
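As with `tf.function`, Python side effects inside a jitted function run only while JAX traces it. Later calls with the same input shapes and dtypes reuse the compiled version, and a new shape triggers a retrace. The sketch below uses a hypothetical `scaled_sum` function and a Python list as a trace counter.

```python
import jax
import jax.numpy as jnp

traces = []  # Python-side log: appended only while JAX traces the function

@jax.jit
def scaled_sum(x):
    traces.append(x.shape)  # Python side effect: runs at trace time only
    return jnp.sum(2.0 * x)

scaled_sum(jnp.array([1.0, 2.0]))              # first call: trace and compile
result = scaled_sum(jnp.array([3.0, 4.0]))     # same shape/dtype: cached version reused
scaled_sum(jnp.array([1.0, 2.0, 3.0]))         # new shape: retrace
print(len(traces), result)
```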

#### Example: Linear regression with JAX
We can implement the same linear regression model using JAX. One important aspect to consider when using JAX is that it encourages a functional programming style. This means that we need to avoid side effects and mutable state in our functions.

In [None]:
import jax

input_dim = 2
output_dim = 1
learning_rate = 0.1

def model(inputs, W, b):
    return jnp.matmul(inputs, W) + b

def mean_squared_error(targets, predictions):
    per_sample_losses = jnp.square(targets - predictions)
    return jnp.mean(per_sample_losses)

def compute_loss(state, inputs, targets):
    W, b = state
    predictions = model(inputs, W, b)
    loss = mean_squared_error(targets, predictions)
    return loss

grad_fn = jax.value_and_grad(compute_loss)

@jax.jit
def training_step(inputs, targets, W, b):
    loss, grads = grad_fn((W, b), inputs, targets)
    grad_wrt_W, grad_wrt_b = grads
    W = W - grad_wrt_W * learning_rate # learning_rate is baked in at trace time: changing it later has no effect
    b = b - grad_wrt_b * learning_rate
    return loss, W, b

W = jnp.array(np.random.uniform(size=(input_dim, output_dim)))
b = jnp.array(np.zeros(shape=(output_dim,)))
for step in range(40):
    # Pure update: training_step returns new W and b instead of mutating them
    loss, W, b = training_step(dataset, labels, W, b)
    print(f"Loss at step {step}: {loss:.4f}")

predictions = model(dataset, W, b)
x = np.linspace(-2, 2, 100)
y = -W[0] / W[1] * x + (0.5 - b) / W[1]
plt.plot(x, y, "-r")
plt.scatter(dataset[:, 0], dataset[:, 1], c=predictions[:, 0] > 0.5)
plt.show()

### Keras
Keras is a high-level deep learning API that can run on top of TensorFlow, PyTorch, and JAX. It provides a user-friendly interface for building and training neural networks, making it easier to develop deep learning models without needing to deal with low-level details.

Before using Keras, it is necessary to select the backend framework. By default, Keras uses TensorFlow as the backend, but this can be changed with the `KERAS_BACKEND` environment variable, which must be set before Keras is imported. The supported backends are TensorFlow, PyTorch, and JAX. Alternatively, it is possible to set the backend in the configuration file located at `~/.keras/keras.json`. Here is an example of how to set the backend to PyTorch in the configuration file:

```json
{
    "floatx": "float32",
    "epsilon": 1e-07,
    "backend": "torch",
    "image_data_format": "channels_last"
}
```

Here, `floatx` is the default floating-point precision and `epsilon` the default numerical fuzz factor; neither should typically be changed. `backend` can be set to `"tensorflow"`, `"jax"`, or `"torch"`. `image_data_format` is the default image layout, which we'll discuss in chapter 8.

#### Layers in Keras
The central abstraction in Keras is the `Layer` class, used to define neural network layers. Keras provides many predefined layers, such as `Dense`, `Conv2D`, and `LSTM`. A layer is an object that stores some state (its weights) and performs some computation (the forward pass). As an example, here is how to implement a dense layer as a Keras layer:

In [None]:
import keras

# All Keras layers inherit from the base Layer class.
class SimpleDense(keras.Layer):
    def __init__(self, units, activation=None):
        super().__init__()
        self.units = units
        self.activation = activation

    # Weight creation takes place in the build() method.
    def build(self, input_shape):
        batch_dim, input_dim = input_shape
        # add_weight is a shortcut method for creating weights. It's
        # also possible to create standalone variables and assign them
        # as layer attributes, like self.W = keras.Variable(shape=...,
        # initializer=...).
        self.W = self.add_weight(
            shape=(input_dim, self.units), initializer="random_normal"
        )
        self.b = self.add_weight(shape=(self.units,), initializer="zeros")

    # We define the forward pass computation in the call() method.
    def call(self, inputs):
        y = keras.ops.matmul(inputs, self.W) + self.b
        if self.activation is not None:
            y = self.activation(y)
        return y

my_dense = SimpleDense(units=32, activation=keras.ops.relu)
input_tensor = keras.ops.ones(shape=(2, 784))

# Calling the layer goes through __call__(). We could override it, but we would then
# break Keras' automatic shape inference, which relies on build(). When __call__() runs,
# the layer checks whether it has already been built; if not, it calls build() with the
# shape of the inputs, so the weights can be allocated with the right size
output_tensor = my_dense(input_tensor)
print(output_tensor.shape)

#### Models in Keras
Keras also provides the `Model` class, which organizes layers into a graph and adds training and inference features. Keras offers both a functional API and an object-oriented API for defining models; the functional API will be discussed in later chapters. Here is an example of how to create a simple model using `Sequential`, a subclass of `Model`:

In [None]:
from keras import models
from keras import layers

model = models.Sequential(
    [
        layers.Dense(32, activation="relu"),
        layers.Dense(64, activation="relu"),
        layers.Dense(32, activation="relu"),
        layers.Dense(10, activation="softmax"),
    ]
)
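Like its layers, a `Sequential` model only creates its weights once it knows the input shape. The sketch below builds the same architecture for a hypothetical input with 16 features and counts its parameters. Each `Dense` layer contributes `inputs * units + units` weights: 16 * 32 + 32 = 544, 32 * 64 + 64 = 2112, 64 * 32 + 32 = 2080, and 32 * 10 + 10 = 330, for a total of 5066.

```python
import keras
from keras import layers

model = keras.Sequential(
    [
        layers.Dense(32, activation="relu"),
        layers.Dense(64, activation="relu"),
        layers.Dense(32, activation="relu"),
        layers.Dense(10, activation="softmax"),
    ]
)
model.build(input_shape=(None, 16))  # hypothetical input: batches of 16 features
print(model.count_params())  # 5066
```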

The topology of the model is very important. It defines the hypothesis space of the model: choosing a model defines a specific space of functions that the model can represent, and with training, we are searching for the best function parameters within that space. It basically encodes our prior knowledge about the problem we want to solve. Selecting an appropriate model architecture is more an art than a science, and it often requires experimentation and domain knowledge.

Once a model is defined, it can be compiled with the `compile` method, which configures the model for training by specifying the optimizer, the loss function, and the metrics to monitor during training.

In [None]:
model = keras.Sequential([keras.layers.Dense(1)])

# These strings are used to instantiate the right objects
model.compile(
    optimizer="rmsprop",
    loss="mean_squared_error",
    metrics=["accuracy"],
)

# Alternatively, pass the objects directly
model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=0.1), # Objects can be configured, e.g. with a custom learning rate
    loss=keras.losses.MeanSquaredError(),
    metrics=[keras.metrics.BinaryAccuracy()],
)


After compiling the model, it is ready to be trained with the `fit` method, which runs the training loop on the given data. Its main arguments are:
1. The training data.
1. The training labels.
1. The number of epochs (iterations over the entire dataset).
1. The batch size (number of samples per gradient update, 32 by default).

You can also hold out part of the training data for validation with the `validation_split` argument, or provide a separate validation dataset with the `validation_data` argument.

The return type of the `fit` method is a `History` object, which contains information about the training process, such as the loss and metrics values for each epoch. This information can be useful for analyzing the model's performance and diagnosing potential issues during training.

In [None]:
import pprint

# Simple training without validation data
model.compile(
    optimizer="rmsprop",
    loss="mean_squared_error",
    metrics=["accuracy"],
)

history = model.fit(
    dataset,
    labels,
    epochs=5,
    batch_size=128,
)
pprint.pprint(history.history)

# Training with validation data and custom learning rate
model = keras.Sequential([keras.layers.Dense(1)])
model.compile(
    optimizer=keras.optimizers.RMSprop(learning_rate=0.1),
    loss=keras.losses.MeanSquaredError(),
    metrics=[keras.metrics.BinaryAccuracy()],
)

indices_permutation = np.random.permutation(len(dataset))
shuffled_inputs = dataset[indices_permutation]
shuffled_targets = labels[indices_permutation]

num_validation_samples = int(0.3 * len(dataset))
val_inputs = shuffled_inputs[:num_validation_samples]
val_targets = shuffled_targets[:num_validation_samples]
training_inputs = shuffled_inputs[num_validation_samples:]
training_targets = shuffled_targets[num_validation_samples:]
history = model.fit(
    training_inputs,
    training_targets,
    epochs=5,
    batch_size=16,
    validation_data=(val_inputs, val_targets),
)
pprint.pprint(history.history)
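For comparison, here is a self-contained sketch that uses `validation_split` instead of a manual split. The data is hypothetical random data whose labels are 1 when the first feature is positive. With a 0.2 split, Keras holds out the last 20% of the samples (taken before any shuffling) for validation, and the `History` object then contains a `val_` entry alongside each training quantity.

```python
import numpy as np
import keras

rng = np.random.default_rng(0)
x = rng.normal(size=(200, 2)).astype("float32")
y = (x[:, :1] > 0).astype("float32")  # hypothetical binary labels

model = keras.Sequential([keras.Input(shape=(2,)), keras.layers.Dense(1)])
model.compile(optimizer="rmsprop", loss="mean_squared_error")
history = model.fit(x, y, epochs=3, batch_size=32, validation_split=0.2, verbose=0)
print(sorted(history.history))
```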

Finally, the trained model can be used to make predictions on new data with the `predict` method, which takes the input data and returns the model's outputs.

In [None]:
predictions = model.predict(val_inputs, batch_size=128)
print(predictions[:10])