<a href="https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2025/blob/ml-foundations-prac-3/practicals/ML_Foundation/Part_3/Machine_learning_evaluation_and_generalisation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Introduction to DL -- Evaluation, Generalization, and Optimization Algorithms**

<img src="https://incubator.ucf.edu/wp-content/uploads/2023/07/artificial-intelligence-new-technology-science-futuristic-abstract-human-brain-ai-technology-cpu-central-processor-unit-chipset-big-data-machine-learning-cyber-mind-domination-generative-ai-scaled-1-1500x1000.jpg" width="600"/>


<a href="https://colab.research.google.com/github/deep-learning-indaba/indaba-pracs-2025/blob/main/practicals/ML_Foundation/Part_3/Machine_learning_evaluation_and_generalisation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>


© Deep Learning Indaba 2025. Apache License 2.0.

**Authors:** Ulrich Mbou Sob, Geraud Nangue Tasse

**Reviewers:**

**Introduction:**
In machine learning, our main goal is to **train a model on data** so that it can perform well at a specific task such as **classification** or **regression**.

When training a model, we usually minimize a **loss function**.  
The loss guides the learning process, but it doesn’t always match how we actually measure performance in practice. In this tutorial, we focus on evaluating machine learning models and introduce optimization techniques that can be leveraged to improve a model's performance.

**Topics:**

Content: <font color='green'>`Supervised Learning, Evaluation, Optimization`</font>

Level: <font color='grey'>`Beginner`</font>

**Aims/Learning Objectives:**

In this tutorial we will learn the following key concepts:

- **Model Evaluation** → How do we measure how good a model really is?  
- **Generalization** → How well does a model perform on data it has never seen before?  
- **Optimization & Regularization** → What algorithms and techniques can we use to train models more effectively and prevent overfitting?

**Prerequisites:**

- Practical 1
  - Regression
  - Basic knowledge of Jax
- Practical 2
  - Machine learning classification

**Outline:**

>[Part 1 - Evaluation and Generalization](#scrollTo=zHc7_PbomVIN)

>>[Model Evaluation](#scrollTo=nHCa0Tj-0_ZC)

>>>[Breast cancer classification](#scrollTo=VnKyBQeS7auc)

>>>[Evaluation metrics](#scrollTo=S8cWX6wEe0rN)

>>>[✅ Accuracy](#scrollTo=nw5RpOOekkII)

>>>[🎯 Precision](#scrollTo=41S2rXHulrai)

>>>[🚨 Recall (Sensitivity)](#scrollTo=BshBAzvElsyw)

>>>[📊 Aggregate Metrics](#scrollTo=aUlfhmD9tED-)

>>>[Cross validation and Generalisation](#scrollTo=KpjY8k_64kjT)

>[Part 2 - Optimization algorithms, learning rate schedulers, and hyperparameter tuning](#scrollTo=5HNpEM4DnMNe)

>[Appendix](#scrollTo=8dmPgHGhH8oU)

>>[References](#scrollTo=d6YYbpyXpqib)

>[Feedback](#scrollTo=o1ndpYE50BpG)

**Note:** To get the most out of this tutorial, try answering the questions, quizzes, and code tasks on your own before checking the solutions. Actively working through them is the most effective way to learn.

**Before you start:**

Run the "Installation and Imports" cell below.

### Installation and Imports

In [None]:
!pip install jax flax optax clu --quiet

import numpy as np
import random
import matplotlib.pyplot as plt
from flax import nnx
import jax
import jax.numpy as jnp
import pandas as pd
import copy
import math
from matplotlib import cm
import tensorflow as tf
import optax
import flax
from clu import metrics
from flax import struct
import flax.linen as nn
import tensorflow_datasets as tfds
from sklearn.model_selection import train_test_split

### Helper functions (Run Cell)

In [None]:
from IPython.display import clear_output
from flax.core import freeze, unfreeze

@struct.dataclass
class TrainMetrics(metrics.Collection):
  loss: metrics.Average.from_output('loss')

@struct.dataclass
class EvalMetrics(metrics.Collection):
  loss: metrics.Average.from_output('loss')

def train_step(params, model, optimizer, opt_state, loss_grad_fn, metrics, batch):
  """Train for a single step."""
  (loss, logits), grads = loss_grad_fn(params, model, batch)
  updates, opt_state = optimizer.update(grads, opt_state)
  params = optax.apply_updates(params, updates)
  labels = batch[1].astype(jnp.int32)
  metric_updates = TrainMetrics.single_from_model_output(
    logits=logits, labels=labels, loss=loss)

  metrics = metrics.merge(metric_updates)
  return params, opt_state, metrics

def eval_step(params, model, loss_fn, metrics, batch):
  loss, logits = loss_fn(params, model, batch)
  labels = batch[1].astype(jnp.int32)
  metric_updates = EvalMetrics.single_from_model_output(
    logits=logits, labels=labels, loss=loss)

  metrics = metrics.merge(metric_updates)
  return metrics


def train(
    epochs, params, model, optimizer, opt_state, loss_grad_fn,
    loss_fn, train_ds, test_ds, metrics_history,
  ):

  for i in range(epochs):
    train_metrics = TrainMetrics.empty()
    for step, batch in enumerate(train_ds.as_numpy_iterator()):
      params, opt_state, train_metrics = train_step(params, model, optimizer, opt_state, loss_grad_fn, train_metrics, batch)

    for metric, value in train_metrics.compute().items():
      metrics_history[f"train_{metric}"].append(value)

    eval_metrics = EvalMetrics.empty()
    for step, batch in enumerate(test_ds.as_numpy_iterator()):
      eval_metrics = eval_step(params, model, loss_fn, eval_metrics, batch)

    for metric, value in eval_metrics.compute().items():
      metrics_history[f"test_{metric}"].append(value)

  clear_output(wait=True)
  # Plot the train and test loss curves
  fig, ax = plt.subplots(figsize=(7, 5))
  ax.set_title('Loss')
  for dataset in ('train', 'test'):
    ax.plot(metrics_history[f'{dataset}_loss'], label=f'{dataset}')
  ax.legend()
  plt.show()

  return params, opt_state, metrics_history


# Recursively print all parameter names and their shapes
def print_param_shapes(params, prefix=""):
    for key, val in params.items():
        if isinstance(val, dict):
            print_param_shapes(val, prefix=f"{prefix}{key}/")
        else:
            print(f"{prefix}{key}: shape={val.shape}")

## Evaluation and Generalization

### Model Evaluation

In simple terms:  
- **Evaluation** tells us how well our model is performing on a given dataset.  
- **Generalization** tells us how well the model can perform on data it has never seen before.  

Our goal is not just to build a model that performs well on the **training data**, but one that learns the **underlying patterns in the data**.  
This way, it can make good predictions on **unseen examples** — and even handle slightly different (out-of-distribution) cases.


#### **Breast cancer classification**
For this section, we will revisit the breast cancer classification task from Part 2 of the practicals. Here we will pay closer attention to the performance of our model.

Let's load the data

In [None]:
from sklearn.datasets import load_breast_cancer

# Load breast cancer dataset from sklearn
data = load_breast_cancer()

# Convert the dataset to a pandas DataFrame
df = pd.DataFrame(data.data, columns=data.feature_names)

# Add the target variable to the DataFrame
# We flip the labels because sklearn encodes malignant (cancerous) as 0
# and benign (non-cancerous) as 1. After flipping: 1 = malignant, 0 = benign.
df['target'] = 1 - data.target
df.head()

In [None]:
df.info()

In [None]:
# Check the proportion of 1s and 0s
proportion = df["target"].value_counts(normalize=True)

print("Counts:\n", df["target"].value_counts())
print("\nProportions:\n", proportion)

🤔 **Pause and reflect:** If we look at the proportions of our labels, more than 62% belong to one class. How do you think this can affect the performance of our model?

In [None]:
# split dataset into test and train
train_set, test_set = train_test_split(df, test_size=0.2, random_state=42, stratify=df["target"])

# split each set into input and target
y_train = train_set.pop('target').astype(np.int32)
x_train = train_set

y_test = test_set.pop('target').astype(np.int32)
x_test = test_set

print(f"training data input shape {x_train.shape}")
print(f"training target shape {y_train.shape}")

train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
test_dataset = tf.data.Dataset.from_tensor_slices((x_test, y_test))

Before diving into different evaluation metrics, let's implement and train a model on our dataset. The [helper functions](#scrollTo=QeUWtf7PtE-K) contain the training loop we will use.

💻 Code Task: Train a Multi-Layer Perceptron (MLP) Binary Classifier

Your goal is to complete the definition of an MLP model using Flax’s Linen API. This model will take input features from a tumor dataset and output a single logit representing whether the tumor is malignant or benign (binary classification).

You'll implement the model by completing the `__call__` method of the `MLP` class.

1. Architecture

   - Design your own architecture using the activation functions we learned previously.
   - The input to your MLP should be a single list of layer sizes (the hidden layers followed by the output layer). Previously, we passed each hidden layer size separately.

2. Implement the loss function.

   - Use the binary cross entropy loss.

3.  Call the training loop.
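
For reference, the binary cross-entropy loss in step 2 can be written in terms of the model's raw output. The symbols here are notation introduced for this note: $z$ is the logit produced by the MLP for one example, $y \in \{0, 1\}$ is its label, and $\sigma$ is the sigmoid function. The loss for a batch is the mean of this quantity over its examples.

$$
\mathcal{L}(z, y) = -\Big[\, y \log \sigma(z) + (1 - y) \log\big(1 - \sigma(z)\big) \Big], \qquad \sigma(z) = \frac{1}{1 + e^{-z}}
$$

When $y = 1$ only the first term is active, so the loss is small when $\sigma(z)$ is close to 1. When $y = 0$ only the second term is active, so the loss is small when $\sigma(z)$ is close to 0.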

In [None]:
import flax.linen as nn

class MLP(nn.Module):

  layer_sizes: list  # e.g., [64, 32, 10, 1] (hidden layers + output layer)

  @nn.compact
  def __call__(self, x):

    # Implement the hidden layers with your chosen activation function.
    # Hint: loop through all the hidden layers first, then implement the output layer outside the loop.
    for size ... # update me
      x =  ... # update me



    # Output layer
    x = ... # update me


    return x

In [None]:
# @title 🔓Solution - MLP in JAX (Try not to peek until you've given it a good try!)
class MLP(nn.Module):
  layer_sizes: list  # e.g., [64, 32, 10, 1] (hidden layers + output layer)

  @nn.compact
  def __call__(self, x):
      # Apply all hidden layers with GeLU
      for size in self.layer_sizes[:-1]:
          x = nn.Dense(size)(x)
          x = nn.gelu(x)

      # Output layer (no activation)
      x = nn.Dense(self.layer_sizes[-1])(x)
      return x

In [None]:

def get_model_and_optimizer(input_size, output_sizes=None, seed=32, lr=1e-3):
  # Helper function to quickly initialise an MLP model and its optimizer.
  # output_sizes is the list of layer sizes, e.g. [10, 10, 1].
  # The final value should be 1, since the MLP outputs a single logit.

  model = MLP(output_sizes)

  key = jax.random.PRNGKey(seed)

  dummy_data = jnp.zeros((1, input_size), dtype=float)


  params = model.init(key, dummy_data)

  # Print model parameters
  print_param_shapes(params['params'])

  optimizer = optax.adam(learning_rate=lr)
  opt_state = optimizer.init(params)

  return model, params, optimizer, opt_state


In [None]:
def loss_func(params, model, batch):
  """Compute the sigmoid binary cross-entropy loss and return logits."""

  # Extract inputs and labels from batch

  # Calculate the logits

  labels = jnp.reshape(labels, logits.shape) # Reshape the labels to match the shape of the logits.

  # Compute binary cross-entropy loss

  return loss, logits

# Calculate gradients on loss here
loss_grad_fn = ...

In [None]:
# @title 🔓Solution - loss and grads computations (Try not to peek until you've given it a good try!)
def loss_func(params, model, batch):
  # Extract inputs and labels from the batch
  inputs = batch[0]
  labels = batch[1]

  logits = model.apply(params, inputs)
  labels = jnp.reshape(labels, logits.shape)
  loss = optax.sigmoid_binary_cross_entropy(
      logits=logits, labels=labels
  ).mean()
  return loss, logits

# Calculate gradients on loss here
loss_grad_fn = jax.value_and_grad(loss_func, has_aux=True)

In [None]:
batch_size = 32 # you can modify this if you wish
epochs = 100 # you can modify this if you wish
seed = 32 # you can modify this if you wish
lr = 1e-3 # you can modify this if you wish

metrics_history = {
    "train_loss": [],
    "test_loss": [],
}

train_ds = train_dataset.shuffle(1000).batch(batch_size)
test_ds = test_dataset.batch(batch_size)

input_size = 30
output_sizes = ... # update this based on your MLP design
model, params, optimizer, opt_state = get_model_and_optimizer(input_size, output_sizes, seed=seed, lr=lr)

params, opt_state, metrics_history = train(epochs, params, model, optimizer, opt_state,
                                           loss_grad_fn, loss_func, train_ds, test_ds, metrics_history)


In [None]:
# @title 🔓Solution - calling the training loop (Try not to peek until you've given it a good try!)
batch_size = 32
epochs = 100

metrics_history = {
    "train_loss": [],
    "test_loss": [],
}

train_ds = train_dataset.shuffle(1000).batch(batch_size)
test_ds = test_dataset.batch(batch_size)

input_size = 30
output_sizes = [30,30,20,1]
model, params, optimizer, opt_state = get_model_and_optimizer(input_size, output_sizes, seed=32, lr=1e-3)

params, opt_state, metrics_history = train(epochs, params, model, optimizer, opt_state,
                                           loss_grad_fn, loss_func, train_ds, test_ds, metrics_history)


#### Evaluation metrics

Now that we have trained our model, we will use different metrics to gauge its performance.

💻 Code Task: Implement a prediction function for our model.
This function should take as input the model, the parameters, the input features, and the classification threshold.

In [None]:
def predict(params, model, x, threshold=0.5):
    """
    Apply model and return class predictions based on threshold.

    Args:
        params: trained model parameters
        model: Flax MLP model
        x: input array
        threshold: decision threshold (default=0.5)

    Returns:
        jnp.array of predictions (0 or 1)
    """
    logits = model.apply... # update me
    probs = ... # update me
    preds = ... # update me using threshold


    return preds.squeeze()

In [None]:
# @title 🔓Solution - prediction function (Try not to peek until you've given it a good try!)
def predict(params, model, x, threshold=0.5):
    """
    Apply model and return class predictions based on threshold.

    Args:
        params: trained model parameters
        model: Flax MLP model
        x: input array
        threshold: decision threshold (default=0.5)

    Returns:
        jnp.array of predictions (0 or 1)
    """
    logits = model.apply(params, x)

    probs = nn.sigmoid(logits)
    preds = (probs >= threshold).astype(jnp.int32).squeeze()

    return preds

Let's define the most common metrics used in classification tasks.

#### ✅ Accuracy

Accuracy measures the proportion of correct predictions out of total predictions:

$$
\text{Accuracy} = \frac{\text{Number of correct predictions}}{\text{Total number of predictions}}
$$

While useful, accuracy can be **misleading**, especially on **imbalanced datasets**. Imagine if only 5% of tumors are malignant. A model that always predicts "benign" will still have 95% accuracy, but be completely useless.
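
The scenario above can be checked with a few lines of NumPy. This is a minimal sketch on hypothetical labels (1000 tumors, 5% malignant), not on the breast cancer dataset used in this practical, and the "model" simply predicts benign for everything.

```python
import numpy as np

# Hypothetical labels: 1000 tumors, 5% malignant (1) and 95% benign (0).
y_true = np.zeros(1000, dtype=np.int32)
y_true[:50] = 1

# A "model" that ignores its input and always predicts benign.
y_pred = np.zeros_like(y_true)

baseline_accuracy = np.mean(y_true == y_pred)
# Fraction of the malignant tumors that the baseline actually catches.
malignant_found = np.sum((y_true == 1) & (y_pred == 1)) / np.sum(y_true == 1)

print(f"Accuracy: {baseline_accuracy:.2f}")
print(f"Malignant tumors detected: {malignant_found:.2f}")
```

The accuracy comes out at 0.95 even though not a single malignant tumor is detected, which is why accuracy should not be the only metric we look at.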

#### 🎯 Precision

Precision tells us **how many predicted positives were actually correct**:

$$
\text{Precision} = \frac{TP}{TP + FP}
$$

Useful when **false positives are costly**, e.g., incorrectly diagnosing a healthy patient as having cancer.

#### 🚨 Recall (Sensitivity)

Recall tells us **how many of the actual positives the model correctly identified**:

$$
\text{Recall} = \frac{TP}{TP + FN}
$$

Important when **missing a true positive case is dangerous**, e.g., failing to identify a malignant tumor.
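
As a small worked example with made-up counts (not results from our model): suppose a classifier flags 10 tumors as malignant, of which 8 really are malignant ($TP = 8$, $FP = 2$), and it misses 4 malignant tumors ($FN = 4$). Then:

$$
\text{Precision} = \frac{8}{8 + 2} = 0.80, \qquad \text{Recall} = \frac{8}{8 + 4} \approx 0.67
$$

So 80% of the alarms raised are correct, but one in three malignant tumors goes undetected.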

💻 Code Task: Complete the functions below to implement the three metrics above: accuracy, precision, and recall.

In [None]:
def accuracy(y_true, y_pred):
    """Compute accuracy = correct predictions / total"""
    acc = ... # update me
    return acc


def precision(y_true, y_pred):
    """Compute precision = TP / (TP + FP)"""
    tp = ... # update me -- true positive
    fp = ... # update me -- false positive
    result = ... # update me

    return result


def recall(y_true, y_pred):
    """Compute recall = TP / (TP + FN)"""
    tp = ... # update me -- true positive
    fn = ... # update me -- false negative
    result = ... # update me

    return result


In [None]:
# @title 🔓Solution - accuracy, precision, recall (Try not to peek until you've given it a good try!)
def accuracy(y_true, y_pred):
    """Compute accuracy = correct predictions / total"""
    return jnp.mean(y_true == y_pred)

def precision(y_true, y_pred):
    """Compute precision = TP / (TP + FP)"""
    tp = jnp.sum((y_true == 1) & (y_pred == 1))
    fp = jnp.sum((y_true == 0) & (y_pred == 1))
    return tp / (tp + fp + 1e-8)  # add epsilon to avoid division by zero


def recall(y_true, y_pred):
    """Compute recall = TP / (TP + FN)"""
    tp = jnp.sum((y_true == 1) & (y_pred == 1))
    fn = jnp.sum((y_true == 1) & (y_pred == 0))
    return tp / (tp + fn + 1e-8) # add epsilon to avoid division by zero

Let's compute these metrics on the training and test sets.

In [None]:
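# Pass the features as plain arrays rather than pandas DataFrames
ypreds_train = predict(params, model, x_train.to_numpy(), threshold=0.5)
ypreds_test = predict(params, model, x_test.to_numpy(), threshold=0.5)

# Convert the label Series to JAX arrays (np.asarray keeps this cell safe to re-run)
y_train = jnp.asarray(np.asarray(y_train))
y_test = jnp.asarray(np.asarray(y_test))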

train_acc = accuracy(y_train, ypreds_train)
test_acc = accuracy(y_test, ypreds_test)

train_precision = precision(y_train, ypreds_train)
train_recall = recall(y_train, ypreds_train)

test_precision = precision(y_test, ypreds_test)
test_recall = recall(y_test, ypreds_test)

print(f"Train Accuracy: {train_acc:.4f}, Test Accuracy: {test_acc:.4f}")
print(f"Train Precision: {train_precision:.4f}, Test Precision: {test_precision:.4f}")
print(f"Train Recall: {train_recall:.4f}, Test Recall: {test_recall:.4f}")

🤔 Pause and reflect: What can you say about your model performance?

Is your model performing similarly on the training set and the test set?

Are you satisfied with your precision and recall scores given their implications?

How can we improve our model?

#### 📊 Aggregate Metrics
In the preceding section, we defined several metrics that can be analysed independently. In machine learning and scientific studies, we often want to summarise performance in a single number. We now look at a few metrics that try to do exactly that.

**ROC Curve**

The **Receiver Operating Characteristic (ROC) curve** shows how well a binary classifier can separate the two classes across different thresholds.  
- The x-axis is the **False Positive Rate (FPR)**  
- The y-axis is the **True Positive Rate (TPR / Recall)**  
- The closer the curve is to the **top-left corner**, the better the model.  
- The **Area Under the Curve (AUC)** summarizes performance:  
  - AUC = 1 → perfect classifier  
  - AUC = 0.5 → random guessing
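
To make the threshold sweep concrete, here is a minimal NumPy sketch that builds the ROC points and computes the AUC with the trapezoidal rule. The helper names `roc_curve_points` and `auc` and the toy scores are assumptions made for illustration. For the model in this practical, the scores would be the sigmoid of the logits.

```python
import numpy as np

def roc_curve_points(y_true, scores):
    """Return (fpr, tpr) arrays obtained by sweeping the decision threshold."""
    y_true = np.asarray(y_true)
    scores = np.asarray(scores, dtype=float)
    # Start above every score (nothing predicted positive), then lower the
    # threshold one distinct score at a time until everything is positive.
    thresholds = np.concatenate(([np.inf], np.sort(np.unique(scores))[::-1]))
    n_pos = np.sum(y_true == 1)
    n_neg = np.sum(y_true == 0)
    tpr = np.array([np.sum((scores >= t) & (y_true == 1)) / n_pos for t in thresholds])
    fpr = np.array([np.sum((scores >= t) & (y_true == 0)) / n_neg for t in thresholds])
    return fpr, tpr

def auc(fpr, tpr):
    """Area under the ROC curve using the trapezoidal rule."""
    return float(np.sum(np.diff(fpr) * (tpr[1:] + tpr[:-1]) / 2.0))

# Hypothetical toy data: true labels and predicted probabilities.
y_true_toy = np.array([0, 0, 1, 1])
scores_toy = np.array([0.1, 0.4, 0.35, 0.8])

fpr, tpr = roc_curve_points(y_true_toy, scores_toy)
print("FPR:", fpr)
print("TPR:", tpr)
print("AUC:", auc(fpr, tpr))
```

On this toy data the AUC is 0.75: one benign example (score 0.4) is ranked above one malignant example (score 0.35), so the classifier is better than random but not perfect.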

**F1-Score**

The **F1-score** balances **precision** and **recall** in one number.  
It is useful when classes are imbalanced.  
$$
F1 = \frac{2 \cdot (\text{Precision} \cdot \text{Recall})}{\text{Precision} + \text{Recall}}
$$

- High F1 means the model has both **good precision** (few false positives) and **good recall** (few false negatives).  
- Useful for medical tasks like cancer detection where both errors matter.
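
A quick worked example with illustrative values (not results from our model): with $\text{Precision} = 0.9$ and $\text{Recall} = 0.5$,

$$
F1 = \frac{2 \cdot (0.9 \cdot 0.5)}{0.9 + 0.5} = \frac{0.9}{1.4} \approx 0.64
$$

The simple average of the two would be 0.70. Because F1 is a harmonic mean, it is pulled towards the lower of the two values, so a model cannot obtain a high F1 by doing well on only one of them.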

**Matthews Correlation Coefficient (MCC)**

The **MCC** is a more balanced evaluation metric that uses all four confusion matrix values i.e. TP, TN, FP, and FN.  

$$
\text{MCC} = \frac{TP \cdot TN - FP \cdot FN}{\sqrt{(TP+FP)(TP+FN)(TN+FP)(TN+FN)}}
$$

- MCC = +1 → perfect prediction  
- MCC = 0 → no better than random prediction  
- MCC = -1 → total disagreement  

MCC is especially good for **imbalanced datasets**, where accuracy alone can be misleading.
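
As a worked example with made-up counts on an imbalanced set of 100 samples, take $TP = 3$, $TN = 90$, $FP = 5$, $FN = 2$. The accuracy is $(3 + 90) / 100 = 0.93$, while:

$$
\text{MCC} = \frac{3 \cdot 90 - 5 \cdot 2}{\sqrt{(3+5)(3+2)(90+5)(90+2)}} = \frac{260}{\sqrt{349600}} \approx 0.44
$$

The accuracy looks excellent because the many true negatives dominate it, whereas the MCC of about 0.44 reflects that the model is only moderately good at finding the rare positive class.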

💻 Code Task: Complete the functions below to implement the F1 score and the MCC.

In [None]:
def f1_score(y_true, y_pred):
    p = ... # update me -- precision
    r = ... # update me -- recall
    result = ... # update me
    return result

def matthews_corrcoef(y_true, y_pred):
    tp = np.sum((y_true == 1) & (y_pred == 1))
    tn = np.sum((y_true == 0) & (y_pred == 0))
    fp = ... # update me --- false positive
    fn = ... # update me --- false negative

    numerator = (tp * tn) - (fp * fn)
    denominator = ... # update me
    return numerator / denominator

In [None]:
# @title 🔓Solution - F1 score and MCC (Try not to peek until you've given it a good try!)
def f1_score(y_true, y_pred):
    p = precision(y_true, y_pred)
    r = recall(y_true, y_pred)
    return 2 * (p * r) / (p + r + 1e-8)

def matthews_corrcoef(y_true, y_pred):
    tp = float(np.sum((y_true == 1) & (y_pred == 1)))
    tn = float(np.sum((y_true == 0) & (y_pred == 0)))
    fp = float(np.sum((y_true == 0) & (y_pred == 1)))
    fn = float(np.sum((y_true == 1) & (y_pred == 0)))

    numerator = (tp * tn) - (fp * fn)
    denominator = np.sqrt((tp+fp) * (tp+fn) * (tn+fp) * (tn+fn) + 1e-8)
    return numerator / denominator

Let's compute the F1 score and MCC of our model

In [None]:
train_f1_score = f1_score(y_train, ypreds_train)
test_f1_score = f1_score(y_test, ypreds_test)

train_mcc = matthews_corrcoef(y_train, ypreds_train)
test_mcc = matthews_corrcoef(y_test, ypreds_test)

print(f"Train F1 score: {train_f1_score:.4f}, Test F1 score: {test_f1_score:.4f}")
print(f"Train MCC: {train_mcc:.4f}, Test MCC: {test_mcc:.4f}")

🤔 Pause and reflect: What can you say about your model performance based on these metrics?


### Cross validation and Generalisation

In the previous sections, we discussed different metrics that can be used to measure the performance of machine learning models on classification tasks.

So far, we have split the data into a single training set and a single test set. This means we are only measuring the generalisation ability of the model on one fixed split.

🤔 Pause and reflect: How could we change our train/test splitting strategy to get a more reliable estimate of how well the model generalises?


**Cross-validation** is a popular machine learning technique used to test the **generalization capability** of models.  

In this approach, we split the dataset into several parts (called *folds*). The model is trained on some of these folds and tested on the remaining ones. This process is repeated so that each fold serves as the test set once. Finally, we evaluate the model’s performance on each test set and take the **average score** as the overall performance.  

📖 You can read more about different cross-validation techniques [here](https://www.geeksforgeeks.org/machine-learning/cross-validation-machine-learning/).
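
The procedure described above can be sketched in plain NumPy. This is an illustrative sketch under stated assumptions: the data is a synthetic pair of Gaussian blobs rather than the breast cancer dataset, the folds are plain (not stratified) k-fold, and a simple nearest-centroid classifier stands in for the MLP so that the example runs quickly. The helper names are hypothetical.

```python
import numpy as np

def k_fold_indices(n_samples, k, seed=0):
    """Shuffle the sample indices and split them into k roughly equal folds."""
    rng = np.random.default_rng(seed)
    return np.array_split(rng.permutation(n_samples), k)

def nearest_centroid_accuracy(x_train, y_train, x_test, y_test):
    """Stand-in model: assign each test point to the class with the closest mean."""
    centroids = np.stack([x_train[y_train == c].mean(axis=0) for c in (0, 1)])
    distances = np.linalg.norm(x_test[:, None, :] - centroids[None, :, :], axis=-1)
    return float(np.mean(distances.argmin(axis=1) == y_test))

def cross_validate(x, y, k=5, seed=0):
    """Train on k-1 folds, test on the held-out fold, and repeat for every fold."""
    folds = k_fold_indices(len(x), k, seed)
    scores = []
    for i, test_idx in enumerate(folds):
        train_idx = np.concatenate([f for j, f in enumerate(folds) if j != i])
        scores.append(
            nearest_centroid_accuracy(x[train_idx], y[train_idx], x[test_idx], y[test_idx])
        )
    return np.array(scores)

# Hypothetical synthetic data: two Gaussian blobs, one per class.
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(0.0, 1.0, (100, 2)), rng.normal(3.0, 1.0, (100, 2))])
y = np.concatenate([np.zeros(100, dtype=int), np.ones(100, dtype=int)])

scores = cross_validate(x, y, k=5)
print("Per-fold accuracy:", np.round(scores, 3))
print(f"Mean accuracy: {scores.mean():.3f} (std {scores.std():.3f})")
```

Reporting the spread across folds alongside the mean gives a sense of how much the score depends on the particular split. With imbalanced classes, a stratified variant that preserves the class proportions in every fold is usually a better choice.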


## Optimization algorithms, learning rate schedulers, and hyperparameter tuning

## Appendix


### References
1. Flax module documentation: https://flax-linen.readthedocs.io/en/latest/api_reference/flax.linen/module.html


## Feedback

Please provide feedback that we can use to improve our practicals in the future.

In [None]:
# @title Generate Feedback Form. (Run Cell)
from IPython.display import HTML

HTML(
    """
<iframe
  src="https://forms.gle/CJCNwwcLW9Y3jZDG7"
  width="80%"
  height="1200px">
  Loading...
</iframe>
"""
)

<img src="https://baobab.deeplearningindaba.com/static/media/indaba-logo-dark.d5a6196d.png" width="50%" />