In [63]:
import torch, sys
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
from tqdm.notebook import tqdm

import plotly.express as px
import plotly.graph_objects as go

In [97]:
# Small fully connected regression network: two hidden ReLU layers, linear output.
# Based on https://towardsdatascience.com/linear-regression-with-pytorch-eb6dedead817
class SimpleRegression(torch.nn.Module):
    """Two-hidden-layer MLP mapping ``input_dims`` features to ``output_dims`` outputs.

    Args:
        input_dims: number of input features (size of the last dimension of ``x``).
        output_dims: number of values predicted per sample.
        hidden_dims: width shared by both hidden layers (default 250).
    """

    def __init__(self, input_dims, output_dims, hidden_dims=250):
        super().__init__()
        # Creation order fixes how the global RNG is consumed during weight init,
        # and the attribute names fix the state_dict keys -- keep both stable.
        self.input_layer  = torch.nn.Linear(input_dims, hidden_dims)
        self.h1           = torch.nn.Linear(hidden_dims, hidden_dims)
        self.output_layer = torch.nn.Linear(hidden_dims, output_dims)

    def forward(self, x):
        """Apply Linear -> ReLU -> Linear -> ReLU -> Linear and return the result."""
        activations = x
        for hidden_layer in (self.input_layer, self.h1):
            activations = F.relu(hidden_layer(activations))
        # No activation on the output: this is unconstrained regression.
        return self.output_layer(activations)

In [163]:
# Deterministic (noise-free) toy dataset: the target curve sampled on a uniform float32 grid.
def hidden_function(x):
    """Ground-truth curve the network should learn: elementwise sin(x) cubed."""
    sine = np.sin(x)
    return sine ** 3

x = np.linspace(start=-10, stop=10, num=2100, dtype=np.float32)
y = hidden_function(x)

In [172]:
# Model / training hyperparameters.
# Network shape: one scalar input feature -> one scalar prediction.
input_dim, output_dim = 1, 1
# Optimisation settings.
lr = 0.01          # Adam learning rate
epochs = 250       # outer training-loop iterations
batch_size = 200   # samples per optimiser step

In [173]:
# Train the network with mini-batch Adam on MSE loss.
# Seed torch so weight init and per-epoch shuffling are reproducible on re-run.
torch.manual_seed(0)

model = SimpleRegression(input_dim, output_dim)
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
losses = []  # one entry per optimiser step

assert len(x) == len(y), "x and y must be paired samples"
# Column tensors of shape (N, 1); x / y themselves are left untouched so this
# cell can be re-run without altering the data produced by earlier cells.
x_all = torch.from_numpy(x).reshape(-1, 1)
y_all = torch.from_numpy(y).reshape(-1, 1)
n_samples = len(x_all)

with tqdm(total=epochs, file=sys.stdout) as pbar:
    for epoch in range(epochs):
        # Fresh shuffle every epoch so batch composition varies between epochs.
        perm = torch.randperm(n_samples)
        # Step over *all* samples, including a final partial batch when
        # n_samples is not a multiple of batch_size (2100 % 200 == 100 here);
        # floor division would silently skip those samples every epoch.
        for start in range(0, n_samples, batch_size):
            batch_idx = perm[start:start + batch_size]
            inputs = x_all[batch_idx]
            labels = y_all[batch_idx]

            optimizer.zero_grad()
            outputs = model(inputs)
            loss    = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            losses.append(loss.item())

        pbar.update(1)
        if epoch % 10 == 0:
            # Reports the loss of the last batch in this epoch.
            pbar.set_description(f'epoch: {epoch}, loss: {loss.item()}')

HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))




In [174]:
# Loss: per-optimiser-step MSE recorded by the training loop.
fig_loss = px.line(losses, title='Training loss per batch')
fig_loss.update_layout(xaxis_title='optimiser step', yaxis_title='MSE loss', showlegend=False)
fig_loss

In [175]:
# Prediction: compare the trained network with the true function on a fresh grid.
x_data = np.linspace(-10, 10, 100, dtype=np.float32).reshape(-1, 1)  # (100, 1) column
y_data = hidden_function(x_data)

model.eval()  # SimpleRegression has no dropout/batch-norm, but make inference mode explicit
with torch.no_grad():  # inference only: don't build an autograd graph
    model_data = model(torch.from_numpy(x_data)).numpy()

fig = go.Figure()
fig.add_trace(go.Scatter(x=x_data.squeeze(), y=y_data.squeeze(), mode='lines', name='actual'))
fig.add_trace(go.Scatter(x=x_data.squeeze(), y=model_data.squeeze(), mode='markers', name='predicted'))
fig.update_layout(title='Model prediction vs. hidden function', xaxis_title='x', yaxis_title='y')
fig.show()