In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Seed the RNG explicitly so Restart & Run All reproduces the same samples.
SEED = 0
rng = np.random.default_rng(SEED)

N = 100                                   # number of training samples
x = np.sqrt(20) * rng.standard_normal(N)  # inputs: zero-mean normal, variance 20
eps = 0.1 * rng.standard_normal(N)        # additive target noise, std 0.1

# True biases of the data-generating model relu(x + b1) - relu(x + b2).
b1_star = -2.0
b2_star = -4.0

def relu(z):
    """Rectified linear unit, applied elementwise.

    Parameters
    ----------
    z : scalar or array-like
        Input value(s).

    Returns
    -------
    The elementwise maximum of 0 and ``z`` (via ``np.maximum``).
    """
    floor = 0
    rectified = np.maximum(floor, z)
    return rectified

# Noisy targets from the true two-ReLU function with biases (b1_star, b2_star).
y = relu(x + b1_star) - relu(x + b2_star) + eps

# Noise-free curve of the same true function on an evenly spaced grid.
x_grid = np.linspace(-10, 10, 100)
y_grid = relu(x_grid + b1_star) - relu(x_grid + b2_star)

fig, ax = plt.subplots()
# Labelled 'True function' (not 'Model') because it uses the true biases.
ax.plot(x_grid, y_grid, color='r', label='True function')
ax.scatter(x, y, label='Training samples')
ax.set(xlabel='x', ylabel='y', title='Training data and true function')
ax.legend()

plt.show()


In [None]:
# Candidate biases, different from the true (b1_star, b2_star), used to
# compare the resulting model against the training data.
b1 = 5.0
b2 = 4.0
x_grid = np.linspace(-10, 10, 100)
y_grid = relu(x_grid + b1) - relu(x_grid + b2)

fig, ax = plt.subplots()
ax.plot(x_grid, y_grid, color='r', label='Model')
ax.scatter(x, y, label='Training samples')
ax.set(xlabel='x', ylabel='y', title=f'Model with b1={b1}, b2={b2}')
ax.legend()

plt.show()


In [None]:
import numpy as np
import plotly.graph_objects as go


def J(b1, b2, x_data=None, y_data=None):
    """Mean squared error of the two-ReLU model at biases (b1, b2).

    The model is ``relu(x + b1) - relu(x + b2)``. This is the same form used
    to generate the training targets and to draw the model curves.

    Parameters
    ----------
    b1, b2 : float
        Biases of the first and second ReLU unit.
    x_data, y_data : array-like, optional
        Inputs and targets to evaluate on. If omitted, they default to the
        notebook-level training samples ``x`` and ``y``.

    Returns
    -------
    float
        Mean of the squared residuals ``(y - prediction) ** 2``.
    """
    xs = x if x_data is None else np.asarray(x_data)
    ys = y if y_data is None else np.asarray(y_data)
    # Both units must use `+ bias` to match the data-generating model;
    # otherwise the minimum of J is not at (b1_star, b2_star).
    predictions = relu(xs + b1) - relu(xs + b2)
    return np.mean((ys - predictions) ** 2)

def plot_J_surface(x, y, alpha=1.0, b_min=-6.0, b_max=6.0, num=100):
    """Build a 3-D plotly surface of J(b1, b2) over a square grid of biases.

    Parameters
    ----------
    x, y : array-like
        Inputs and targets that J is evaluated on.
    alpha : float, default 1.0
        Surface opacity.
    b_min, b_max : float, default -6.0 and 6.0
        Range covered by both the b1 and b2 axes.
    num : int, default 100
        Number of grid points per axis.

    Returns
    -------
    plotly.graph_objects.Figure
        The figure, for further traces to be added before ``fig.show()``.
    """
    b1_values = np.linspace(b_min, b_max, num)
    b2_values = np.linspace(b_min, b_max, num)

    # Rows follow b2 and columns follow b1. This is the row=y / column=x
    # layout that go.Surface uses for 1-D x and y coordinate arrays.
    Z = np.array([[J(b1, b2, x, y) for b1 in b1_values] for b2 in b2_values])

    fig = go.Figure(data=[go.Surface(z=Z, x=b1_values, y=b2_values, opacity=alpha)])

    fig.update_layout(title='J(b1, b2) surface plot',
                      scene=dict(xaxis_title='b1',
                                 yaxis_title='b2',
                                 zaxis_title='MSE'),
                      width=800, height=600)

    return fig

def plot_line_segment(fig, b1_1, b2_1, b1_2, b2_2, color='blue'):
    """Add a straight 3-D segment between two points on the J surface.

    Each endpoint is drawn at height J(b1, b2). The segment is the straight
    chord between them, not a path along the surface.

    Parameters
    ----------
    fig : plotly.graph_objects.Figure
        Figure that the trace is added to (mutated in place).
    b1_1, b2_1 : float
        Biases of the first endpoint.
    b1_2, b2_2 : float
        Biases of the second endpoint.
    color : str, default 'blue'
        Line colour.
    """
    start_height = J(b1_1, b2_1)
    end_height = J(b1_2, b2_2)

    segment = go.Scatter3d(
        x=[b1_1, b1_2],
        y=[b2_1, b2_2],
        z=[start_height, end_height],
        mode='lines',
        line=dict(color=color, width=2),
        name=f'Line Segment ({b1_1}, {b2_1}) -> ({b1_2}, {b2_2})',
    )
    fig.add_trace(segment)

def plot_point(fig, b1, b2, color='red'):
    """Add a single marker at (b1, b2, J(b1, b2)) to a plotly 3-D figure.

    Parameters
    ----------
    fig : plotly.graph_objects.Figure
        Figure that the trace is added to (mutated in place).
    b1, b2 : float
        Biases at which the loss is evaluated and marked.
    color : str, default 'red'
        Marker colour.
    """
    height = J(b1, b2)
    marker_trace = go.Scatter3d(
        x=[b1],
        y=[b2],
        z=[height],
        mode='markers',
        marker=dict(size=5, color=color),
        name=f'Point (b1={b1}, b2={b2})',
    )
    fig.add_trace(marker_trace)


In [None]:
# Draw the loss surface at 50% opacity and mark the candidate biases (b1, b2).
fig = plot_J_surface(x, y, alpha=0.5)
plot_point(fig, b1=b1, b2=b2)
fig.show()