Using gradient descent to find the global minimum of a high-dimensional loss function can be difficult.

Gradient descent will usually find *a* minimum, but in general there is no way to tell whether it is the global minimum, or even a good one.

A significant issue is that the end point of gradient descent is fully determined by its starting point.
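To make this dependence concrete, here is a minimal sketch of plain (full-batch) gradient descent on a hypothetical one-dimensional loss $f(x) = x^4 - 3x^2 + x$. This function has a global minimum near $x \approx -1.30$ and a shallower local minimum near $x \approx 1.13$, separated by a local maximum near $x \approx 0.17$. The function, learning rate, and step count are illustrative assumptions, not part of the original material.

```python
import numpy as np

def f(x):
    # Hypothetical non-convex 1-D "loss" with two minima (illustration only)
    return x**4 - 3 * x**2 + x

def df(x):
    # Derivative of f
    return 4 * x**3 - 6 * x + 1

def gradient_descent(x0, lr=0.01, steps=1000):
    # Deterministic gradient descent: same start always gives the same end
    x = x0
    for _ in range(steps):
        x = x - lr * df(x)
    return x

for x0 in [-2.0, -0.5, 0.5, 2.0]:
    x_end = gradient_descent(x0)
    print(f"start {x0:+.1f} -> end {x_end:+.3f}, loss {f(x_end):.3f}")
```

Starting points to the left of the local maximum end in the global minimum; those to the right stay trapped in the local one, and nothing in the algorithm can move them out.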

**Stochastic gradient descent (SGD)** tries to address this by introducing some randomness into the gradient at each step.


The method for incorporating randomness is straightforward.

In each iteration, the algorithm selects a random subset of the training data and computes the gradient from these examples only.

This subset is commonly referred to as a **minibatch**, or simply a **batch**.
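As a sketch of what "computing the gradient from a batch only" means, assume (as in the Keras example later in this notebook) a linear model $\hat{y} = \phi_0 + \phi_1 x$ and, as an illustrative choice, a squared per-example loss $l_i = (\phi_0 + \phi_1 x_i - y_i)^2$. The helper `batch_gradient` below is hypothetical. Different random batches give different gradients at the same parameters, which is exactly the randomness SGD relies on, while the batch gradients of a partition of the data add up to the full-data gradient.

```python
import numpy as np

# Same 12 training points as the Keras example
x = np.array([0.03, 0.19, 0.34, 0.46, 0.78, 0.81, 1.08, 1.18, 1.39, 1.60, 1.65, 1.90])
y = np.array([0.67, 0.85, 1.05, 1.0, 1.40, 1.5, 1.3, 1.54, 1.55, 1.68, 1.73, 1.6])

def batch_gradient(phi, idx):
    """Sum over i in idx of d l_i / d phi, with l_i = (phi[0] + phi[1]*x_i - y_i)**2."""
    r = phi[0] + phi[1] * x[idx] - y[idx]          # residuals on the batch only
    return np.array([2 * r.sum(), 2 * (r * x[idx]).sum()])

rng = np.random.default_rng(0)
phi = np.array([0.0, 0.0])
for _ in range(3):
    idx = rng.choice(len(x), size=3, replace=False)  # random minibatch of 3 examples
    print(idx, batch_gradient(phi, idx))
print("full-data gradient:", batch_gradient(phi, np.arange(len(x))))
```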

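Consequently, writing $B_t$ for the set of indices of the examples in the batch at iteration $t$, $l_i$ for the loss on the $i$-th training example, and $\alpha$ for the learning rate, the update rule for the model parameters $\boldsymbol{\phi}_t$ is: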


$$
\boldsymbol{\phi}_{t+1} \leftarrow \boldsymbol{\phi}_t - \alpha \cdot \sum_{i \in B_t} \frac{\partial l_i(\boldsymbol{\phi}_t)}{\partial \boldsymbol{\phi}}
$$
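The update rule translates almost line for line into code. The sketch below keeps the assumptions of a linear model and squared per-example loss; the names `batch_gradient` and `sgd`, the learning rate, and the step count are illustrative. For simplicity each batch $B_t$ is drawn independently at every step rather than epoch by epoch. With a small constant learning rate the parameters settle into a small neighbourhood of the least-squares solution, which `np.polyfit` computes directly.

```python
import numpy as np

x = np.array([0.03, 0.19, 0.34, 0.46, 0.78, 0.81, 1.08, 1.18, 1.39, 1.60, 1.65, 1.90])
y = np.array([0.67, 0.85, 1.05, 1.0, 1.40, 1.5, 1.3, 1.54, 1.55, 1.68, 1.73, 1.6])

def batch_gradient(phi, idx):
    # Sum over i in B_t of d l_i / d phi, for l_i = (phi_0 + phi_1 * x_i - y_i)^2
    r = phi[0] + phi[1] * x[idx] - y[idx]
    return np.array([2 * r.sum(), 2 * (r * x[idx]).sum()])

def sgd(alpha=0.01, batch_size=3, steps=5000, seed=0):
    rng = np.random.default_rng(seed)
    phi = np.zeros(2)  # phi[0]: intercept, phi[1]: slope
    for _ in range(steps):
        B_t = rng.choice(len(x), size=batch_size, replace=False)  # random batch
        phi = phi - alpha * batch_gradient(phi, B_t)  # phi_{t+1} <- phi_t - alpha * sum
    return phi

phi = sgd()
print("SGD:          ", phi)
print("least squares:", np.polyfit(x, y, 1)[::-1])  # polyfit returns [slope, intercept]
```

Note that the equation sums the per-example gradients, whereas Keras's `"mse"` loss averages over the batch by default; the two differ only by a factor of $|B_t|$ absorbed into $\alpha$.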

Typically, the batches are drawn from the dataset without replacement, so no example is reused until every example has been used once.

A complete pass through the entire training dataset is known as an **epoch**.
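Sampling without replacement is commonly implemented by shuffling the example indices once per epoch and slicing the shuffled order into consecutive batches. The generator `epoch_batches` below is a hypothetical helper sketching that scheme. Keras's `model.fit` behaves similarly by default (`shuffle=True`), so with 12 examples and `batch_size=3`, as in the example that follows, each epoch consists of 4 parameter updates.

```python
import numpy as np

def epoch_batches(n_examples, batch_size, rng):
    """Yield index batches that cover every example exactly once (without replacement)."""
    order = rng.permutation(n_examples)         # fresh random order for this epoch
    for start in range(0, n_examples, batch_size):
        yield order[start:start + batch_size]   # last batch is smaller if sizes don't divide

rng = np.random.default_rng(0)
for epoch in range(2):
    batches = list(epoch_batches(12, 3, rng))
    print(f"epoch {epoch}:", [b.tolist() for b in batches])
```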

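```python
import tensorflow as tf

# Data: 12 one-dimensional training examples
x = tf.constant([0.03, 0.19, 0.34, 0.46, 0.78, 0.81, 1.08, 1.18, 1.39, 1.60, 1.65, 1.90])
y = tf.constant([0.67, 0.85, 1.05, 1.0, 1.40, 1.5, 1.3, 1.54, 1.55, 1.68, 1.73, 1.6])

# Model: a single linear unit, y_hat = w * x + b
model = tf.keras.Sequential([
    tf.keras.Input(shape=(1,)),
    tf.keras.layers.Dense(1),
])

# Compile: mean squared error loss, plain SGD (0.01 is the Keras default learning rate)
model.compile(loss="mse",
              optimizer=tf.keras.optimizers.SGD(learning_rate=0.01))

# Train: batches of 3 examples, so one epoch is 12 / 3 = 4 updates.
# fit() reshuffles the data every epoch (shuffle=True by default).
# x is reshaped to (12, 1) to match the model's input shape.
model.fit(x[:, tf.newaxis], y, epochs=30, batch_size=3)
```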

```python
import matplotlib.pyplot as plt
import numpy as np

# Predict on a dense grid of inputs, shaped (100, 1) to match the model's input
x_values = np.linspace(0, 2, 100).reshape(-1, 1)
y_pred = model.predict(x_values)

# Plot the training data against the fitted line
plt.figure(figsize=(8, 6))
plt.scatter(x.numpy(), y.numpy(), label='Actual')
plt.plot(x_values, y_pred, color='red', label='Predicted')
plt.xlim([0, 2])
plt.ylim([0, 2])
plt.xlabel('x')
plt.ylabel('y')
plt.title('Actual vs Predicted')
plt.legend()
plt.show()
```