# Double Descent visualization

By [Danya Merkulov](https://t.me/fminxyz).

The code below was inspired by:
* [a post by @adad8m on X](https://x.com/adad8m/status/1582231644223987712)
* [`ddpol.py` from fleuret.org](https://fleuret.org/git-extract/pytorch/ddpol.py)

In [24]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

# Experiment configuration
D_max = 100              # highest polynomial degree that gets fitted
nb_train_samples = 50    # number of training points
train_noise_std = 0.25   # std of the Gaussian noise added to the training targets

# Seed NumPy's global RNG so the noisy training targets are reproducible
np.random.seed(8)

def f(x):
    """Return the noise-free target function sin(4.5 * x).

    Works element-wise on NumPy arrays as well as on scalars.
    """
    phase = 4.5 * x
    return np.sin(phase)

# Training set: evenly spaced inputs on [a, b], targets are f plus Gaussian noise
a, b = -1, 1
x_train = np.linspace(a, b, num=nb_train_samples)
y_train = f(x_train) + np.random.normal(
    loc=0, scale=train_noise_std, size=nb_train_samples
)

# Test set: a 100-point grid on the same interval with noise-free targets
x_test = np.linspace(a, b, num=100)
y_test = f(x_test)

# Per-degree results; entry k belongs to the degree-k fit
mse_train_list, mse_test_list, polynomials = [], [], []

# Fit one polynomial per degree 0..D_max in the Chebyshev basis and record
# its coefficients together with the train/test mean squared error.
chebvander = np.polynomial.chebyshev.chebvander
for D in range(D_max + 1):
    # Design matrices: column k holds the k-th Chebyshev polynomial
    # evaluated at the sample points.
    X_train, X_test = chebvander(x_train, D), chebvander(x_test, D)

    # The pseudo-inverse yields the least-squares coefficients; once there
    # are more coefficients than samples it selects the minimum-norm solution.
    beta = np.linalg.pinv(X_train) @ y_train
    polynomials.append(beta)

    y_train_pred, y_test_pred = X_train @ beta, X_test @ beta

    mse_train = np.mean(np.square(y_train_pred - y_train))
    mse_test = np.mean(np.square(y_test_pred - y_test))

    mse_train_list.append(mse_train)
    mse_test_list.append(mse_test)

# Build the two-panel figure. Artists that are updated on every frame
# (line1, line2, line3) are created empty here and filled in by animate().
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(9, 4))

# Left panel: training points, the noise-free target (dashed) and the
# current polynomial fit (red).
ax1.set_ylim(-1.1, 1.1)
(line1,) = ax1.plot([], [], linewidth=2, color='red')
scatter = ax1.scatter(x_train, y_train, color='blue')
# NOTE: despite its name, cos_line draws y_test, i.e. f evaluated on x_test.
(cos_line,) = ax1.plot(x_test, y_test, color='black', linestyle="--")
ax1.set_title('Polynomial Fitting')
ax1.grid(linestyle=":")
ax1.set_axisbelow(True)

# Right panel: train and test MSE against polynomial degree, log-scaled y axis.
ax2.set_yscale('log')
ax2.set_ylim(1e-5, 1e18)
ax2.set_xlim(0, D_max)
ax2.set_xlabel('Polynomial degree')
ax2.set_ylabel('MSE')
(line2,) = ax2.plot([], [], linewidth=2, color='red', label='Test')
(line3,) = ax2.plot([], [], linewidth=2, color='blue', label='Train')
ax2.legend(frameon=False)
ax2.grid(linestyle=":")
ax2.set_axisbelow(True)

# Author watermark: x is given in axes-fraction and y in figure-fraction
# coordinates, then the text is shifted up by 20 points.
ax2.annotate(
    '@fminxyz',
    xy=(1.1, -0.043),
    xycoords=('axes fraction', 'figure fraction'),
    xytext=(0, 20),
    textcoords='offset points',
    ha='right',
    va='bottom',
    fontsize=16,
    color='grey',
    zorder=20,
)

fig.tight_layout()
# Called without arguments, so no spacing parameter is overridden.
fig.subplots_adjust()

# Initialization function: plot the background of each frame
def init():
    """Clear the three animated lines before the first frame is drawn.

    Returns:
        The tuple ``(line1, line2, line3)`` of artists to be redrawn.
    """
    animated_lines = (line1, line2, line3)
    for artist in animated_lines:
        artist.set_data([], [])
    return animated_lines

# Animation function which updates figure data. This is called sequentially
def animate(i):
    """Draw frame ``i``: the degree-``i`` fit and the MSE curves up to degree ``i``.

    Args:
        i: Frame index; also used as the polynomial degree, so
            ``polynomials[i]`` holds the matching Chebyshev coefficients.

    Returns:
        The tuple ``(line1, line2, line3)`` of updated artists.
    """
    upto = i + 1
    degrees = range(upto)

    # Evaluate the stored degree-i coefficients on the test grid.
    basis = np.polynomial.chebyshev.chebvander(x_test, i)
    fitted_curve = basis @ polynomials[i]

    line1.set_data(x_test, fitted_curve)
    line2.set_data(degrees, mse_test_list[:upto])
    line3.set_data(degrees, mse_train_list[:upto])

    return line1, line2, line3

# Build the animation: one frame per polynomial degree 0..D_max, 100 ms apart.
# blit=True means only the artists returned by init/animate are re-drawn.
anim = FuncAnimation(
    fig,
    animate,
    init_func=init,
    frames=D_max + 1,
    interval=100,
    blit=True,
)

# Render to an HTML5 video tag.
# This took about half a minute on a MacBook Pro with M1 Pro.
video = anim.to_html5_video()

from IPython import display
html = display.HTML(video)
display.display(html)
# Close the figure now that the video has been rendered
# (presumably to keep the notebook from also showing a static copy).
plt.close(fig)

# Optional: save the animation as an MP4 file with high DPI
# anim.save('DD_pro.mp4', dpi=300, writer='ffmpeg')