In [None]:
import sys
import os
sys.path.insert(0, os.path.abspath('..')) # Add parent directory to path

import plotly.graph_objects as go
import numpy as np
from matplotlib.animation import FuncAnimation
from nn.nets import Net
from nn.layers import InLayer, Dense
from nn.activation_functions import ReLU, linear
from nn.cost_functions import MSE

In [None]:
# Architecture: an input layer sized for the 2 features (x1, x2), two hidden
# ReLU layers of 8 units each, and a single linear output unit.
l0 = InLayer(units=2)
l1, l2 = (Dense(units=8, activation=ReLU) for _ in range(2))
l3 = Dense(units=1, activation=linear)

In [None]:
# Stack the layers in forward order; MSE is the cost function used for training.
net = Net(
    [l0, l1, l2, l3],
    cost_function=MSE,
)

In [None]:
# Training data: sample the target surface y = sin(x1 + x2) on a square grid.
GRID_LIMIT = 5      # both axes span [-GRID_LIMIT, GRID_LIMIT]
GRID_POINTS = 300   # samples per axis -> GRID_POINTS**2 training rows

x1 = x2 = np.linspace(-GRID_LIMIT, GRID_LIMIT, GRID_POINTS)
x1v, x2v = np.meshgrid(x1, x2)

# Inputs and targets are flattened in the same row-major order, so a prediction
# vector can later be reshaped straight back onto x1v.shape.
X_train = np.stack((x1v, x2v), axis=-1).reshape(-1, 2)
Y_train = np.sin(x1v + x2v).reshape(-1, 1)

assert X_train.shape == (GRID_POINTS**2, 2), X_train.shape
assert Y_train.shape == (GRID_POINTS**2, 1), Y_train.shape
X_train.shape, Y_train.shape

In [None]:
# Accumulator for predicted surfaces; the training loop appends one per snapshot.
frames_raw = list()

In [None]:


# Run gradient descent in chunks and snapshot the full predicted surface after
# each chunk; the snapshots become the animation frames.
# NOTE: re-running only this cell appends more frames to `frames_raw` (and presumably
# keeps training the existing `net`); rebuild `net` and reset `frames_raw` for a clean run.
N_SNAPSHOTS = 100
EPOCHS_PER_SNAPSHOT = 1000

for _ in range(N_SNAPSHOTS):
    net.gradient_descent(
        X_train, Y_train,
        alpha=0.01,
        epochs=EPOCHS_PER_SNAPSHOT,
        verbose=False,
        batch_size=1000,
    )

    # Reshape onto the input grid directly rather than inferring a square side from
    # sqrt(len(Y_pred)): this stays correct if the grid is not square.
    Y_pred = net.predict(X_train)
    frames_raw.append(Y_pred.reshape(x1v.shape))

In [None]:
# Sanity check at the origin: the target there is sin(0 + 0) = 0, so the
# prediction (first element) should be close to the target (second element).
origin = np.array([[0, 0]])
net.predict(origin), np.sin(origin.sum(axis=1, keepdims=True))

In [None]:
# Spot-check the final snapshot at one grid point against the true target value.
# Grid index (90, 90) is NOT the origin: with 300 samples on [-5, 5] it sits at
# roughly (-1.99, -1.99).
probe_row, probe_col = 90, 90
probe_x1, probe_x2 = x1v[probe_row, probe_col], x2v[probe_row, probe_col]
(probe_x1, probe_x2), frames_raw[-1][probe_row, probe_col], np.sin(probe_x1 + probe_x2)

In [None]:
# Animate the learned surface: one Plotly frame per training snapshot in `frames_raw`.
if not frames_raw:
    raise ValueError("frames_raw is empty - run the training cell first")

# Lock the axes so the view does not rescale between frames. x/y follow the input
# grid; the target sin(x1 + x2) spans [-1, 1], so z gets a small margin around that
# (a [0, 1] range would clip the whole negative half of the surface).
x_range = [float(x1v.min()), float(x1v.max())]
y_range = [float(x2v.min()), float(x2v.max())]
z_range = [-1.2, 1.2]

frames = [
    go.Frame(data=[go.Surface(z=z, x=x1v, y=x2v)], name=str(i))
    for i, z in enumerate(frames_raw)
]

# Initial view is the first snapshot; `frames[0].data` is already a sequence of traces.
fig = go.Figure(data=frames[0].data, frames=frames)

fig.update_layout(
    title='Network prediction of sin(x1 + x2) during training',
    scene=dict(
        xaxis=dict(title='x1', range=x_range, autorange=False),
        yaxis=dict(title='x2', range=y_range, autorange=False),
        zaxis=dict(title='prediction', range=z_range, autorange=False),
    ),
    updatemenus=[{
        'type': 'buttons',
        'buttons': [{
            'label': 'Play',
            'method': 'animate',
            'args': [None, {
                'frame': {'duration': 30, 'redraw': True},  # 3D traces need a redraw per frame
                'transition': {'duration': 0},
                'fromcurrent': True,
            }],
        }],
    }],
)

fig.show()
