In [1]:
import numpy as np
import pandas as pd
import torch
import os
from tqdm import tqdm

import plotly.express as px
import plotly.graph_objects as go

In [2]:
from torch import nn

class ShallowNN(nn.Module):
    """One-hidden-layer ReLU network: Linear -> ReLU -> Linear.

    Args:
        in_dim: Number of features in each input sample.
        out_dim: Number of features in each output sample.
        d_m: Number of hidden units (width of the single hidden layer).
    """

    def __init__(self, in_dim: int, out_dim: int, d_m: int):
        super().__init__()

        # Submodules are registered in the order they are applied in forward().
        self.layer_in = nn.Linear(in_features=in_dim, out_features=d_m)
        self.relu = nn.ReLU()
        self.layer_out = nn.Linear(in_features=d_m, out_features=out_dim)

    def forward(self, x):
        """Apply the network to ``x``, whose last dimension must be ``in_dim``.

        Returns a tensor with the same leading dimensions and a last
        dimension of ``out_dim``.
        """
        pre_activation = self.layer_in(x)
        activated = self.relu(pre_activation)
        return self.layer_out(activated)


In [3]:
from torch.optim import SGD
# Train a one-hidden-layer ReLU network to fit sin(x) on [-RANGE, RANGE].
# Re-run this cell with a larger D_M to see how the fit improves as the hidden
# width grows.

SEED = 0
torch.manual_seed(SEED)  # fix the weight initialisation so reruns print the same losses

RANGE = 5
N = 10000

# Evenly spaced training inputs and their sin(x) targets.
X = torch.linspace(-RANGE, RANGE, N)
Y = torch.sin(X)
X_col = X.unsqueeze(1)  # shape (N, 1); built once instead of on every training step

N_STEPS = 10_000
# Integer logging interval: a float interval (N_STEPS / 10) would make the modulo
# test below miss almost every step whenever N_STEPS is not a multiple of 10.
LOG_FREQ = max(1, N_STEPS // 10)

D_M = 10
wide_model = ShallowNN(1, 1, D_M)

optimizer = SGD(wide_model.parameters(), lr = 0.001)
criterion = nn.MSELoss()
for i in range(N_STEPS):
    # Full-batch gradient descent: every step uses all N points.
    optimizer.zero_grad()

    y_hat = wide_model(X_col)
    # Drop the trailing feature axis so predictions line up with Y, shape (N,).
    loss = criterion(y_hat.squeeze(dim = 1), Y)

    loss.backward()
    optimizer.step()

    if i % LOG_FREQ == 0:
        print(loss.item())

1.207705020904541
0.2668953239917755
0.14975234866142273
0.0833459198474884
0.04902288690209389
0.032367512583732605
0.024362854659557343
0.02027054689824581
0.017894281074404716
0.016299648210406303


In [7]:
# Evaluate the trained network on a dense grid and compare it with sin(x).
X_sin = torch.linspace(-RANGE, RANGE, N)
y_sin = torch.sin(X_sin)

# Pure inference: run under no_grad() so no autograd graph is built, which also
# lets the results be converted to NumPy without a separate .detach().
with torch.no_grad():
    y_hats = wide_model(X_sin.unsqueeze(1)).squeeze(dim=1).numpy()

    # A hidden unit relu(w * x + b) switches on/off at x = -b / w. These "joints"
    # are where the piecewise-linear prediction can change slope.
    # NOTE(review): assumes no input weight is exactly zero; a zero weight would
    # give an inf/nan joint.
    joints_x = -wide_model.layer_in.bias / wide_model.layer_in.weight.reshape(-1)
    joints_y = wide_model(joints_x.unsqueeze(1)).squeeze(dim=1).numpy()
    joints_x = joints_x.numpy()

# Model prediction curve.
scatter = go.Scatter(x=X_sin.numpy(), y=y_hats, mode='lines', name='Model Predictions')

# Joint locations, drawn as markers on the prediction curve.
joint_plot = go.Scatter(x = joints_x, y = joints_y, mode = "markers", name = "Joints")

# Ground-truth sin(x) curve.
line = go.Scatter(x=X_sin.numpy(), y=y_sin.numpy(), mode='lines', name='True sin(x)')

# Combine the three traces into one figure.
fig = go.Figure(data=[scatter, joint_plot, line])

# Fix the axis ranges slightly beyond the training interval; joints that fall
# outside this window are not visible until the plot is zoomed out.
fig.update_layout(title='Model Predictions vs True sin(x)',
                  xaxis_title='x',
                  yaxis_title='y',
                  legend_title='Legend',
                  xaxis=dict(range=[-RANGE - 1, RANGE + 1]),
                  yaxis=dict(range=[-2, 2]))

# Show the plot
fig.show()

In [115]:
# Joint locations on their own, without the prediction and sin(x) curves.
px.scatter(
    x=joints_x,
    y=joints_y,
)

In [112]:
# Debug check of the joints_y array shape.
# NOTE(review): the saved output of this cell, (100, 1), does not match the cell
# that builds joints_y (it squeezes dim 1, giving a 1-D array) — the output looks
# stale from an earlier kernel state; re-run to refresh.
joints_y.shape

(100, 1)

In [22]:
# Scratch cell: draws one random integer from [0, 100) using NumPy's global RNG.
np.random.randint(100)

59