In [None]:
from utils import make_worm
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
from torch.optim import Adam
import torch
plt.rcParams["animation.html"] = "jshtml"

In [None]:
# D is the input dimensionality fed to the model (see MyMLP(D, ...) below).
# T is passed through to make_worm; presumably the number of samples — confirm in utils.
D, T = 4, 200
X, y = make_worm(D=D, T=T, sigma=0.01)

# Quick look at the first two input dimensions, coloured by class label.
fig_data, ax_data = plt.subplots()
ax_data.scatter(X[:, 0], X[:, 1], c=y)
ax_data.set(xlabel=r"$x_1$", ylabel=r"$x_2$", title="Worm dataset (first two dimensions)");

In [None]:
class MyMLP(nn.Module):
    """Fully connected network: ``Linear -> activation`` per hidden layer, then a linear output layer.

    Parameters
    ----------
    input_dimension : int
        Number of input features (size of the last dimension of ``x`` in ``forward``).
    hidden_layers : sequence of int
        Width of each hidden layer, in order. May be empty, in which case the
        network is a single linear layer.
    activation : str
        ``"relu"`` or ``"sigmoid"``. Any other value deliberately falls back to ReLU.
    output_dim : int, optional
        Number of output units (raw logits). Defaults to 2, the previously hard-coded value.
    """

    # Map names to classes (not instances) so only the selected module is constructed.
    _ACTIVATIONS = {"relu": nn.ReLU, "sigmoid": nn.Sigmoid}

    def __init__(self, input_dimension, hidden_layers, activation, output_dim=2):
        super().__init__()

        # Unknown names fall back to ReLU rather than raising.
        activation_cls = self._ACTIVATIONS.get(activation, nn.ReLU)
        self.activation = activation_cls()

        layers = []
        prev_dim = input_dimension
        for hidden_dim in hidden_layers:
            layers.append(nn.Linear(prev_dim, hidden_dim))
            # The activation module is stateless, so one instance is shared by all hidden layers.
            layers.append(self.activation)
            prev_dim = hidden_dim

        # Output layer: no activation, so forward() returns logits (as nn.CrossEntropyLoss expects).
        layers.append(nn.Linear(prev_dim, output_dim))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw logits of shape ``(..., output_dim)`` for input ``x`` of shape ``(..., input_dimension)``."""
        return self.model(x)


In [None]:
# Evaluation grid used to visualise the model's output: res x res points
# spanning [0, 1] in each of the first two input dimensions.
res = 50
grid_1d = np.linspace(0, 1.0, res)
xx, yy = np.meshgrid(grid_1d, grid_1d)

# One row per grid point: column 0 comes from `yy`, column 1 from `xx`.
X_grid = np.stack([yy.ravel(), xx.ravel()], axis=1)

# Remaining input dimensions (if any) are held fixed at 0.5 so each grid
# row has the D features the model expects.
if D > 2:
    X_grid = np.concatenate([X_grid, np.full((X_grid.shape[0], D - 2), 0.5)], axis=1)

X_grid_tensor = torch.from_numpy(X_grid).float()

In [None]:
epochs = 1000
log_every = 100  # print progress every `log_every` epochs instead of on every epoch

torch.manual_seed(0)  # reproducible weight initialisation (seeds torch only)
mod = MyMLP(D, [20, 20], activation="relu")
optimizer = Adam(mod.parameters(), lr=0.001)
loss_fun = nn.CrossEntropyLoss()

X_tensor = torch.from_numpy(X).float()
y_tensor = torch.from_numpy(y).long()

hist_pred = []       # per epoch: softmax class probabilities on the training points, after the update
hist_loss = []       # per epoch: training loss, computed before the update
hist_dec_bound = []  # per epoch: softmax class probabilities on X_grid_tensor, after the update

for epoch in range(epochs):
    pred = mod(X_tensor)
    loss = loss_fun(pred, y_tensor)  # CrossEntropyLoss takes raw logits

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    hist_loss.append(loss.item())

    # Monitoring-only forward passes: no_grad avoids building autograd graphs that are never used.
    with torch.no_grad():
        # Store probabilities (not logits) so plots with a fixed [0, 1] colour scale are meaningful.
        hist_pred.append(torch.softmax(mod(X_tensor), dim=1).numpy())
        hist_dec_bound.append(torch.softmax(mod(X_grid_tensor), dim=1).numpy())

    if epoch == 0 or (epoch + 1) % log_every == 0:
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")

hist_loss = np.array(hist_loss)

In [None]:
from matplotlib.animation import FuncAnimation
# 50 evenly spaced epoch indices to render as animation frames.
frames = np.linspace(1, len(hist_pred)-1, num=50, dtype=np.int64)

mosaic = """
BB
BB
AA
"""
fig, ax = plt.subplot_mosaic(mosaic, figsize=(12,8))

# Bottom panel: training loss vs. epoch.
line_objs, = ax["A"].plot(0, hist_loss[0])
ax["A"].set_ylim(0, hist_loss.max()*1.2)
ax["A"].set_xlabel("iteration")
ax["A"].set_ylabel(r"$\mathcal{L}$")

# Top panel: training points coloured by the model's class-0 output.
# NOTE(review): the colour scale is fixed to [0, 1], so hist_pred entries are
# expected to be class probabilities rather than raw logits.
preds_scatter = ax["B"].scatter(X[:,0], X[:,1], c=hist_pred[0][:,0], vmin=0., vmax=1.)
ax["B"].set_xlabel(r"$x_1$")
ax["B"].set_ylabel(r"$x_2$")

fig.tight_layout()

def update(i):
    """Draw frame ``i``: the loss curve up to epoch ``frames[i]`` and the predictions at that epoch."""
    epoch = frames[i]
    # x-data and x-limits are both in epochs, so the curve is never clipped.
    line_objs.set_data(np.arange(epoch), hist_loss[:epoch])
    ax["A"].set_xlim(0, epoch)

    preds_scatter.set_array(hist_pred[epoch][:,0])
    return line_objs, preds_scatter

ani = FuncAnimation(
    fig, update, frames=len(frames)
)
plt.close()
ani

In [None]:
from matplotlib.animation import FuncAnimation
# 100 evenly spaced epoch indices to render as animation frames.
frames = np.linspace(1, len(hist_pred)-1, 100, dtype=np.int64)

mosaic = """
BBB
BBB
BBB
AAA
"""
fig, ax = plt.subplot_mosaic(mosaic, figsize=(8,8), per_subplot_kw={"B":{"projection": "3d"}})

# Grid coordinates in (res, res) layout, computed once and reused for every frame.
grid_x1 = X_grid[:, 0].reshape((res, res))
grid_x2 = X_grid[:, 1].reshape((res, res))


def draw_decision_surface(epoch):
    """Draw the class-0 probability over the grid at ``epoch``.

    Draws a 3D surface plus a filled-contour projection on the plane z = -0.5.
    The same colormap is used for every frame, including the first.
    Returns ``(surface, contour_set)`` so the caller can remove them later.
    """
    prob_class0 = hist_dec_bound[epoch][:, 0].reshape((res, res))
    surface = ax["B"].plot_surface(grid_x1, grid_x2, prob_class0, cmap='viridis', vmin=0., vmax=1.)
    projection = ax["B"].contourf(grid_x1, grid_x2, prob_class0, zdir='z', offset=-0.5,
                                  cmap='coolwarm', vmin=0., vmax=1.)
    return surface, projection


# Plot points at height z = 1 - y (assumes binary labels y in {0, 1}, so class 0 sits at z = 1).
ax["B"].scatter(X[:,0], X[:,1], 1-y, marker='o', s=20, c="goldenrod", alpha=0.6)
dec_bound, dec_bound_proj = draw_decision_surface(0)
ax["B"].set_xlabel(r"$x_1$")
ax["B"].set_ylabel(r"$x_2$")
ax["B"].set_zlabel("Class")

ax["B"].set(zlim=(-0.6,1))
ax["B"].view_init(elev=10., azim=20)

# Plot obj function
line_objs, = ax["A"].plot(0, hist_loss[0])
ax["A"].set_ylim(0, hist_loss.max()*1.2)
ax["A"].set_xlabel("iteration")
ax["A"].set_ylabel(r"$\mathcal{L}$")

fig.tight_layout()

def update(i):
    """Draw frame ``i``: the loss curve up to epoch ``frames[i]``, the decision surface at that epoch, and a rotated view."""
    global dec_bound, dec_bound_proj
    epoch = frames[i]

    # Update obj function plot
    line_objs.set_data(range(epoch), hist_loss[:epoch])
    ax["A"].set_xlim(0, epoch)

    # Replace the previous frame's surface and projection with freshly drawn ones.
    dec_bound.remove()
    dec_bound_proj.remove()
    dec_bound, dec_bound_proj = draw_decision_surface(epoch)

    # Rotate the camera 3 degrees per frame.
    ax["B"].view_init(elev=10., azim=20+3*i)

    return line_objs, dec_bound

ani = FuncAnimation(
    fig, update, frames=len(frames)
)
plt.close()
ani

In [None]:
# Write the most recently built animation (`ani`, the 3D decision-surface one) to disk.
# NOTE(review): .mp4 output presumably relies on an ffmpeg writer being installed.
ani.save(filename="NN_dec.mp4")