Until now, we have discussed binary classification. We now turn to multiclass classification, where each input is assigned to one of $I$ categories. Recall the deep neural network architecture:

\begin{align*}
h_1 & = \text{ReLU}(\beta_0 + \mathbf{\Omega}_0x), \\
h_2 & = \text{ReLU}(\beta_1 + \mathbf{\Omega}_1h_1), \\
h_3 & = \text{ReLU}(\beta_2 + \mathbf{\Omega}_2h_2), \\
& \vdots \\
h_K & = \text{ReLU}(\beta_{K-1} + \mathbf{\Omega}_{K-1}h_{K-1}), \\
y & = \beta_K + \mathbf{\Omega}_K h_K.
\end{align*}
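The recursion can be made concrete with a minimal NumPy sketch of the forward pass. The parameters are random and the layer sizes are chosen purely for illustration; the names `forward`, `betas` and `omegas` are hypothetical. Note that the output $y$ is a vector of arbitrary real numbers, not yet probabilities.

```python
import numpy as np

def relu(z):
    # Element-wise ReLU: max(0, z)
    return np.maximum(0.0, z)

def forward(x, betas, omegas):
    """Forward pass of the network above.

    betas[k] and omegas[k] are the bias vector and weight matrix of layer k;
    the last pair produces the output y with no activation.
    """
    h = x
    # Hidden layers h_1, ..., h_K apply ReLU after the affine map
    for beta, omega in zip(betas[:-1], omegas[:-1]):
        h = relu(beta + omega @ h)
    # Output layer: y = beta_K + Omega_K h_K (no activation)
    return betas[-1] + omegas[-1] @ h

rng = np.random.default_rng(0)
# Hypothetical sizes: 64 inputs, two hidden layers, I = 10 outputs
sizes = [64, 32, 16, 10]
betas = [rng.normal(size=n_out) for n_out in sizes[1:]]
omegas = [rng.normal(size=(n_out, n_in)) for n_in, n_out in zip(sizes[:-1], sizes[1:])]

x = rng.normal(size=64)
y = forward(x, betas, omegas)
print(y.shape)  # (10,)
```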

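First, the output $y$ must have $I$ elements $y_j \in \mathbb{R}$, one for each class $j = 1, \ldots, I$.

Second, we want each output to represent the probability that the input belongs to class $j$. The raw values $y_j$ cannot do this directly: they are arbitrary real numbers that may be negative and need not sum to 1.

To this end, we pass $y$ through the softmax activation function.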

The softmax function is used in multiclass classification problems, where an input must be classified into one of several categories. It converts a vector of real numbers into a probability distribution.

The $j$-th output of the softmax, $\text{softmax}_j(y)$, represents the probability that the input belongs to class $j$. Given an input vector $y = [y_1, y_2, \ldots, y_I]$, the softmax computes the exponential of each element $y_j$ and then normalizes these values by dividing by the sum of all the exponentials. The result is a vector of the same length as $y$ whose elements all lie between 0 and 1 and sum to 1.

Mathematically, this is represented as:

$$
\text{softmax}_j(y) = \frac{\exp\left[y_j\right]}{\sum_{i=1}^{I}\exp\left[y_i\right]}
$$
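As a worked example, take $y = [1, 2, 3]$. The exponentials are approximately $2.718$, $7.389$ and $20.086$, which sum to about $30.193$, so

$$
\text{softmax}(y) \approx \left[\frac{2.718}{30.193}, \frac{7.389}{30.193}, \frac{20.086}{30.193}\right] \approx [0.090, 0.245, 0.665].
$$

The largest score receives the largest probability, and the three values sum to 1. The formula also shows that adding the same constant $c$ to every score leaves the output unchanged, because the factor $\exp\left[c\right]$ cancels:

$$
\text{softmax}_j(y + c) = \frac{\exp\left[c\right]\exp\left[y_j\right]}{\exp\left[c\right]\sum_{i=1}^{I}\exp\left[y_i\right]} = \text{softmax}_j(y).
$$

This shift invariance is what lets implementations subtract $\max_i y_i$ before exponentiating, which avoids numerical overflow.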


In [None]:
import numpy as np

def softmax(y):
    # Subtract the maximum before exponentiating: softmax is unchanged by adding
    # a constant to every element, and this prevents np.exp from overflowing
    exps = np.exp(y - np.max(y))
    # Normalize the exponentials by dividing by their sum
    softmax_output = exps / np.sum(exps)
    return softmax_output

y = np.array([1.0, 2.0, 3.0])
print(softmax(y))
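Exponentiating large scores directly overflows. The sketch below compares a direct translation of the formula with the common max-subtraction trick, and also applies softmax row-wise to a batch of score vectors, which is the shape a model's predictions usually take. The function names `softmax_naive` and `softmax_stable` are illustrative only.

```python
import numpy as np

def softmax_naive(y):
    # Direct translation of the formula; np.exp overflows for large inputs
    exps = np.exp(y)
    return exps / np.sum(exps, axis=-1, keepdims=True)

def softmax_stable(y):
    # Subtract the row-wise maximum first; by shift invariance the result is
    # unchanged, but every exponent is now <= 0, so np.exp cannot overflow
    shifted = y - np.max(y, axis=-1, keepdims=True)
    exps = np.exp(shifted)
    return exps / np.sum(exps, axis=-1, keepdims=True)

large = np.array([1000.0, 1001.0, 1002.0])
with np.errstate(over="ignore", invalid="ignore"):
    print(softmax_naive(large))  # [nan nan nan]: exp(1000) overflows to inf
print(softmax_stable(large))     # same probabilities as for [1.0, 2.0, 3.0]

# Batched input: one row of scores per sample, softmax applied along the last axis
batch = np.array([[1.0, 2.0, 3.0],
                  [0.0, 0.0, 0.0]])
print(softmax_stable(batch).sum(axis=-1))  # each row sums to 1
```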

Let's tackle a real-world multiclass classification problem: recognizing images of handwritten digits from 0 to 9, using the 8×8 digits dataset bundled with scikit-learn.

In [None]:
import matplotlib.pyplot as plt

from sklearn import datasets
from sklearn.model_selection import train_test_split

digits = datasets.load_digits()

_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))
for ax, image, label in zip(axes, digits.images, digits.target):
    ax.set_axis_off()
    ax.imshow(image, cmap=plt.cm.gray_r, interpolation="nearest")
    ax.set_title("Label: %i" % label)

In [None]:
# Flatten each 8x8 image into a vector of 64 pixel values
n_samples = len(digits.images)
data = digits.images.reshape((n_samples, -1))

# Split data into 90% train and 10% test subsets
X_train, X_test, y_train, y_test = train_test_split(
    data, digits.target, test_size=0.1, shuffle=True
)
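A quick sanity check of the shapes produced by the flattening and the 90/10 split can catch mistakes early. This sketch reloads the dataset so it runs on its own; the `random_state` argument is our addition, used only to make the split reproducible.

```python
from sklearn import datasets
from sklearn.model_selection import train_test_split

digits = datasets.load_digits()
data = digits.images.reshape((len(digits.images), -1))

# random_state is added here only to make the split reproducible
X_train, X_test, y_train, y_test = train_test_split(
    data, digits.target, test_size=0.1, shuffle=True, random_state=0
)

print(digits.images.shape)          # (1797, 8, 8): 1797 images of 8x8 pixels
print(data.shape)                   # (1797, 64): each image flattened to 64 features
print(X_train.shape, X_test.shape)  # (1617, 64) (180, 64)
print(data.min(), data.max())       # pixel intensities lie in the range 0..16
```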

In [None]:
import tensorflow as tf

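# Define the model: three ReLU hidden layers followed by a 10-unit softmax output
model = tf.keras.models.Sequential([
    tf.keras.Input(shape=(64,)),  # each sample is a flattened 8x8 image
    tf.keras.layers.Dense(128, activation='relu'),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(32, activation='relu'),
    tf.keras.layers.Dense(10, activation='softmax')  # one probability per digit
])

# Compile the model; SparseCategoricalCrossentropy expects integer labels 0..9
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
    metrics=['accuracy'],
)

# Train the model, holding out 10% of the training data for validation
# so that the test set is used only for the final evaluation
history = model.fit(X_train, y_train, epochs=10, validation_split=0.1)

# Evaluate the model on the held-out test set
loss, accuracy = model.evaluate(X_test, y_test)
print(f'Test Accuracy: {accuracy*100:.2f}%')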
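The model is trained with `SparseCategoricalCrossentropy`, which takes integer labels (no one-hot encoding) and, with its default `from_logits=False`, expects probabilities such as the softmax outputs here. For each sample it is the negative log of the probability assigned to the true class, averaged over the batch. The NumPy sketch below illustrates that definition; the function name and the clipping constant are our own choices, not the Keras implementation.

```python
import numpy as np

def sparse_categorical_crossentropy(probs, labels, eps=1e-7):
    """Mean negative log-probability of the true class.

    probs:  (n_samples, I) softmax outputs, each row summing to 1
    labels: (n_samples,) integer class indices in 0..I-1
    """
    # Clip to avoid log(0); the constant eps is an illustrative choice
    probs = np.clip(probs, eps, 1.0 - eps)
    # Pick out, for every sample, the probability assigned to its true class
    true_class_probs = probs[np.arange(len(labels)), labels]
    return -np.mean(np.log(true_class_probs))

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1]])
labels = np.array([0, 1])
print(sparse_categorical_crossentropy(probs, labels))  # -(log 0.7 + log 0.8) / 2, about 0.290
```

Confident correct predictions drive the loss toward 0, while a small probability on the true class is penalized heavily through the logarithm.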

In [None]:
import numpy as np

# Predict the probabilities for each class
y_probs = model.predict(X_test)

# Get the class with the highest probability for each sample
y_pred = np.argmax(y_probs, axis=-1)

# Select four random indices from the test data
indices = np.random.choice(len(X_test), size=4, replace=False)

# Plot the images with the actual and predicted labels
_, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))
test_images = X_test[indices].reshape(-1, 8, 8)
for ax, image, true_label, pred_label in zip(
    axes, test_images, y_test[indices], y_pred[indices]
):
    ax.set_axis_off()
    ax.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    title = f'True: {true_label}, Pred: {pred_label}'
    ax.set_title(title, color='green' if true_label == pred_label else 'red')

plt.show()
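Beyond a few example images, the predicted labels can be summarized by accuracy and a confusion matrix. This minimal sketch uses small hypothetical arrays in place of `y_test` and `y_pred`; the helper names are illustrative (scikit-learn also offers `sklearn.metrics.confusion_matrix` with the same row/column convention). It also checks that taking the argmax of softmax probabilities gives the same class as the argmax of the raw scores.

```python
import numpy as np

def accuracy(y_true, y_pred):
    # Fraction of samples whose predicted class matches the true class
    return np.mean(y_true == y_pred)

def confusion_matrix(y_true, y_pred, num_classes):
    # Entry [i, j] counts samples of true class i predicted as class j
    cm = np.zeros((num_classes, num_classes), dtype=int)
    np.add.at(cm, (y_true, y_pred), 1)
    return cm

# Hypothetical labels and predictions, standing in for y_test and y_pred
y_true = np.array([0, 1, 2, 2, 1])
y_hat = np.array([0, 2, 2, 2, 1])
print(accuracy(y_true, y_hat))  # 0.8
print(confusion_matrix(y_true, y_hat, num_classes=3))

# All softmax entries share one denominator and exp is strictly increasing,
# so softmax preserves the order of the scores and argmax is unchanged
scores = np.array([[1.0, 3.0, 2.0]])
probs = np.exp(scores) / np.exp(scores).sum(axis=-1, keepdims=True)
print(np.argmax(scores, axis=-1), np.argmax(probs, axis=-1))  # [1] [1]
```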