In [63]:
import torch, sys
import numpy as np
from torch.autograd import Variable
import torch.nn.functional as F
from tqdm.notebook import tqdm

import plotly.express as px
import plotly.graph_objects as go

In [258]:
# Generate training data: sample a hidden target function on an evenly spaced grid
def hidden_function(x):
    """Target function the network has to learn: the cube of sin(x), element-wise.

    Args:
        x: Scalar or NumPy array of inputs.

    Returns:
        sin(x) raised to the third power, computed element-wise by NumPy.
    """
    sine = np.sin(x)
    return sine ** 3

# 20,000 evenly spaced float32 inputs on [-15, 15] and their targets
x = np.linspace(start=-15, stop=15, num=20000, dtype=np.float32)
y = hidden_function(x)

In [241]:
class SimpleRegression(torch.nn.Module):
    """Fully connected regressor: two ReLU-activated stages followed by a linear head."""

    def __init__(self, input_dims, output_dims, hidden_dims=250):
        """Create the three linear layers.

        Args:
            input_dims: Feature width of each input sample.
            output_dims: Number of values produced per sample.
            hidden_dims: Width shared by both hidden stages.
        """
        super().__init__()
        make_linear = torch.nn.Linear
        # Creation order is kept: it fixes parameter order and RNG consumption.
        self.input_layer = make_linear(input_dims, hidden_dims)
        self.h1 = make_linear(hidden_dims, hidden_dims)
        self.output_layer = make_linear(hidden_dims, output_dims)

    def forward(self, x):
        """Run a batch through the network and return the raw (unactivated) output."""
        first_hidden = F.relu(self.input_layer(x))
        second_hidden = F.relu(self.h1(first_hidden))
        return self.output_layer(second_hidden)

In [259]:
class Agent():
    """Trains and evaluates a `SimpleRegression` model using Adam and MSE loss."""

    def __init__(self, dim_shape, lr=0.001):
        """Build the model, loss function and optimizer.

        Args:
            dim_shape: `(input_dim, output_dim)` pair giving the feature width
                of the model's inputs and outputs.
            lr: Learning rate passed to the Adam optimizer.
        """
        self.input_dim, self.output_dim = dim_shape
        # Use the unpacked attributes: bare `input_dim` / `output_dim` do not
        # exist in this scope and raised NameError on a fresh kernel.
        self.model = SimpleRegression(self.input_dim, self.output_dim)

        self.loss_fn = torch.nn.MSELoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

        # One entry per optimisation step, accumulated across calls to `fit`.
        self.losses = []

    @staticmethod
    def _as_batch(values, width):
        """Return `values` as a contiguous float32 array of shape `(n_samples, width)`."""
        return np.ascontiguousarray(values, dtype=np.float32).reshape(-1, width)

    def predict(self, x):
        """Return the model's predictions for `x`.

        Args:
            x: Array-like of inputs; reshaped to `(n_samples, input_dim)`.

        Returns:
            NumPy array of shape `(n_samples, output_dim)`.
        """
        inputs = torch.from_numpy(self._as_batch(x, self.input_dim))
        # Inference only: call this agent's own model and skip graph building.
        with torch.no_grad():
            predictions = self.model(inputs)
        return predictions.detach().numpy()

    def fit(self, x, y, epochs=100, batch_size=None, graph_loss=True):
        """Train the model with mini-batch gradient descent.

        The data is shuffled once up front (the caller's arrays are not
        modified) and the same batches are then visited every epoch. A trailing
        batch smaller than `batch_size` is trained on as well, so no sample is
        dropped.

        Args:
            x: Array-like of inputs; reshaped to `(n_samples, input_dim)`.
            y: Array-like of targets; reshaped to `(n_samples, output_dim)`.
            epochs: Number of passes over the data.
            batch_size: Samples per optimisation step. Defaults to a tenth of
                the data (at least 1); values above the sample count are
                clamped to it.
            graph_loss: When true, plot the per-step loss history afterwards.

        Raises:
            ValueError: If there are no samples, `x` and `y` disagree on the
                number of samples, or `batch_size` is smaller than 1.
        """
        features = self._as_batch(x, self.input_dim)
        targets = self._as_batch(y, self.output_dim)
        n_samples = len(features)
        if n_samples == 0:
            raise ValueError("fit() needs at least one sample")
        if len(targets) != n_samples:
            raise ValueError(
                f"x and y must hold the same number of samples, "
                f"got {n_samples} and {len(targets)}"
            )

        if batch_size is None:
            # `n_samples // 10` is 0 for tiny datasets; never go below 1.
            batch_size = max(1, n_samples // 10)
        elif batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        batch_size = min(batch_size, n_samples)

        # shuffle data
        shuffled_idx = np.arange(0, n_samples)
        np.random.shuffle(shuffled_idx)
        inputs = torch.from_numpy(features[shuffled_idx])
        labels = torch.from_numpy(targets[shuffled_idx])

        batch_starts = range(0, n_samples, batch_size)
        n_batches = len(batch_starts)

        with tqdm(total=epochs, file=sys.stdout) as pbar:
            for _ in range(epochs):
                for start in batch_starts:
                    stop = start + batch_size

                    # Perform training step
                    self.optimizer.zero_grad()
                    predictions = self.model(inputs[start:stop])
                    loss = self.loss_fn(predictions, labels[start:stop])
                    loss.backward()
                    self.optimizer.step()

                    # Save loss
                    self.losses.append(loss.item())

                # Update tqdm bar with the mean loss over this epoch's batches
                pbar.update(1)
                pbar.set_description(f'epoch loss: {np.mean(self.losses[-n_batches:])}')

        if graph_loss:
            fig = go.Figure()
            fig.add_trace(go.Scatter(y=self.losses, mode='lines', name='Loss'))
            fig.update_layout(
                title='Training loss',
                xaxis_title='Optimisation step',
                yaxis_title='MSE loss',
            )
            fig.show()

    def validate(self, x, y):
        """Plot predictions against the actual targets and report the MSE in the title.

        Args:
            x: Array-like of inputs; reshaped to `(n_samples, input_dim)`.
            y: Array-like of targets; reshaped to `(n_samples, output_dim)`.

        Raises:
            ValueError: If `x` and `y` disagree on the number of samples.
        """
        predictions = self.predict(x)
        targets = self._as_batch(y, self.output_dim)
        if targets.shape != predictions.shape:
            # Guard against silent broadcasting in the error computation.
            raise ValueError(
                f"x and y must hold the same number of samples, "
                f"got {len(predictions)} and {len(targets)}"
            )
        mse = np.mean((predictions - targets) ** 2)

        x_axis = np.asarray(x).squeeze()
        fig = go.Figure()
        fig.update_layout(title=f"MSE Loss: {mse}")
        fig.add_trace(go.Scatter(x=x_axis, y=targets.squeeze(), mode='lines', name='actual'))
        fig.add_trace(go.Scatter(x=x_axis, y=predictions.squeeze(), mode='markers', name='predicted'))
        fig.show()

In [260]:
# Seed both RNGs so weight initialisation (torch) and the training shuffle
# (numpy) are reproducible across Restart & Run All.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

agent = Agent(dim_shape=(1, 1), lr=0.001)
agent.fit(x, y, epochs=100, batch_size=100)

HBox(children=(FloatProgress(value=0.0), HTML(value='')))




In [261]:
# Validation: compare predictions with the true function on a 100-point column of inputs
x_data = np.linspace(-10, 10, num=100, dtype=np.float32).reshape(-1, 1)
y_data = hidden_function(x_data)

agent.validate(x_data, y_data)