# A simple ANN example with PyTorch

## Problem description 
We want to define a simple two-layer ANN for a regression problem:
- The task: given two real-valued inputs, predict the first minus the second (e.g., the inputs 7 and 5 give the output 2).
- The network has 2 input nodes, one hidden layer with 8 nodes, and 1 output node.
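As a quick sanity check on the network size, the number of learnable parameters follows directly from the layer widths: every fully connected layer with n_in inputs and n_out outputs has n_in × n_out weights plus n_out biases. The sketch below does this by hand. The helper `count_params` is hypothetical and not part of PyTorch.

```python
# Count the learnable parameters of a fully connected 2-8-1 network by hand.
layer_sizes = [2, 8, 1]  # input, hidden, output widths

def count_params(sizes):
    """Return the total number of weights and biases in a dense MLP (hypothetical helper)."""
    total = 0
    for n_in, n_out in zip(sizes[:-1], sizes[1:]):
        total += n_in * n_out + n_out  # weight matrix + bias vector
    return total

n_params = count_params(layer_sizes)
print(n_params)  # 2*8 + 8 + 8*1 + 1 = 33
```

Once the PyTorch model is built, the same number can be read off with `sum(p.numel() for p in model.parameters())`.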

In [None]:
# import libraries
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch
from torch.optim import SGD

import matplotlib.pyplot as plt
import numpy as np


In [None]:
# Training set
x = [[7,5],[6,3],[5,2],[4,1],[10,5]]
y = [[2],[3],[3],[3],[5]]

#device = 'cuda' if torch.cuda.is_available() else 'cpu'
#X=torch.tensor(x).float().to(device)
#Y=torch.tensor(y).float().to(device)

X=torch.tensor(x).float()
Y=torch.tensor(y).float()
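The targets are written as one-element lists so that Y has shape (5, 1), matching the (N, 1) output of a network with a single output node; the loss later compares predictions and targets element by element, so their shapes should agree. A minimal check of the shapes, and of the fact that every target equals the first input minus the second:

```python
import torch

x = [[7, 5], [6, 3], [5, 2], [4, 1], [10, 5]]
y = [[2], [3], [3], [3], [5]]
X = torch.tensor(x).float()
Y = torch.tensor(y).float()

# Each row of X is one sample with two features; each row of Y is its target.
print(X.shape)  # torch.Size([5, 2])
print(Y.shape)  # torch.Size([5, 1])

# Every target is the first input minus the second.
print(torch.equal(Y[:, 0], X[:, 0] - X[:, 1]))  # True
```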

In [None]:
# Wrap the tensors X, Y in a subclass of torch.utils.data.Dataset so they can be fed to a DataLoader
class ModelDataset(Dataset):
    def __init__(self,x,y):
        self.x = x
        self.y = y
    def __getitem__(self,idx):
        return self.x[idx], self.y[idx]
    
    def __len__(self):
        return len(self.x)
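For in-memory tensors like these, PyTorch also ships a ready-made equivalent, `torch.utils.data.TensorDataset`, which indexes all given tensors along their first dimension, as ModelDataset does by hand. A small sketch showing that it returns the same (input, target) pairs:

```python
import torch
from torch.utils.data import TensorDataset

X = torch.tensor([[7, 5], [6, 3], [5, 2], [4, 1], [10, 5]]).float()
Y = torch.tensor([[2], [3], [3], [3], [5]]).float()

# TensorDataset(X, Y)[i] returns (X[i], Y[i]); its length is X.shape[0].
ds_builtin = TensorDataset(X, Y)

print(len(ds_builtin))  # 5
xi, yi = ds_builtin[0]
print(xi, yi)  # tensor([7., 5.]) tensor([2.])
```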

In [None]:
ds = ModelDataset(X,Y)

# set the batch size here
dataloader = DataLoader(ds, batch_size=2, shuffle=True)
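With 5 samples and batch_size=2, one pass over the DataLoader yields ceil(5 / 2) = 3 mini-batches, the last one holding the single leftover sample (drop_last defaults to False). The sketch below uses TensorDataset as a stand-in for ModelDataset only to stay self-contained; shuffling changes which samples land in which batch, not the batch sizes.

```python
import math
import torch
from torch.utils.data import TensorDataset, DataLoader

X = torch.tensor([[7, 5], [6, 3], [5, 2], [4, 1], [10, 5]]).float()
Y = torch.tensor([[2], [3], [3], [3], [5]]).float()
loader = DataLoader(TensorDataset(X, Y), batch_size=2, shuffle=True)

# Collect the size of every mini-batch produced in one epoch.
batch_sizes = [xb.shape[0] for xb, yb in loader]
print(len(loader), math.ceil(len(X) / 2))  # 3 3
print(sorted(batch_sizes))  # [1, 2, 2]
```

This means each epoch of the training loop performs 3 optimizer steps, so 1000 epochs amount to 3000 parameter updates.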

Here we define the ANN model. We use nn.Sequential to define the network as a sequence of layers, one after the other.
Sequential is a Module that contains other Modules and applies them in order to produce its output.
The Linear Module computes its output from its input with a linear (affine) function and holds internal Tensors for its weight and bias; each output node computes $z = \sum_i w_i x_i + b$.
After a Linear Module we usually add a nonlinear activation function, e.g., ReLU. Without it, stacked Linear layers would collapse into a single linear function.
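To see what these two Modules compute, the sketch below applies an nn.Linear(2, 8) layer to two training samples and compares it with the formula written out as a matrix product, $z_j = \sum_i w_{ji} x_i + b_j$ (PyTorch stores the weight as an (out_features, in_features) matrix, so the batched form is X Wᵀ + b). It also checks that ReLU is simply max(0, z).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(2, 8)  # weight: (8, 2), bias: (8,)
X = torch.tensor([[7.0, 5.0], [6.0, 3.0]])

# nn.Linear computes X @ W^T + b: one pre-activation z_j per hidden node.
z = lin(X)
z_manual = X @ lin.weight.T + lin.bias

# ReLU zeroes out negative pre-activations: max(0, z).
a = nn.ReLU()(z)
a_manual = torch.clamp(z, min=0)

print(z.shape)  # torch.Size([2, 8])
```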

In [None]:

# define the model with 2 input nodes, 8 hidden nodes (in one hidden layer), and 1 node in the output layer

model = nn.Sequential(nn.Linear(2,8), nn.ReLU(), nn.Linear(8,1))
#model = nn.Sequential(nn.Linear(2,8), nn.ReLU(), nn.Linear(8,1)).to(device)


# define the loss function (PyTorch provides many predefined loss functions; here we use mean squared error)

loss = nn.MSELoss()

# define the optimization method (here mini-batch SGD) and the learning rate used to optimize the learnable
# model parameters (weights and biases)

opt = SGD(model.parameters(), lr=0.01)

# define a list to collect the loss value of every mini-batch (GD) step (we will plot it later)
loss_history = []

# specify the number of epochs
num_epochs = 1000
for _ in range(num_epochs):
    for data in dataloader:
        x,y = data
        
        opt.zero_grad()  # to flush out the previous gradients
        
        # Forward pass: compute predicted y by passing x to the model. When
        # doing so you pass a Tensor of input data to the Module and it produces
        # a Tensor of output data.
        outputs = model(x)
        
        # Compute the loss. We pass Tensors containing the predicted and true
        # values of y, and the loss function returns a Tensor containing the loss.
        loss_value = loss(outputs,y)
        
        # Backward pass: compute gradient of the loss with respect to all the learnable
        # parameters of the model. Internally, the parameters of each Module are stored
        # in Tensors with requires_grad=True, so this call will compute gradients for
        # all learnable parameters in the model.
        loss_value.backward() 
        
        # Update the weights using gradient descent.
        opt.step() 
        
        # record the history of the loss values
        loss_history.append(loss_value.item())  # .item() returns the loss as a plain Python float, detached from the autograd graph

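nn.MSELoss with its default reduction='mean' averages the squared differences between predictions and targets over all elements of the mini-batch. A small check on made-up numbers (the predictions and targets below are illustrative, not taken from the trained model):

```python
import torch
import torch.nn as nn

outputs = torch.tensor([[1.5], [3.0], [2.0]])  # example predictions (made up)
targets = torch.tensor([[2.0], [3.0], [3.0]])  # example targets (made up)

# Mean squared error: mean over all elements of (prediction - target)^2.
mse = nn.MSELoss()(outputs, targets)
mse_manual = ((outputs - targets) ** 2).mean()

print(mse.item())  # (0.25 + 0 + 1) / 3 ≈ 0.4167
```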

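What opt.step() does for plain SGD (no momentum, no weight decay, as configured above) is the update w ← w − lr · ∂L/∂w for every parameter. The sketch below, assuming the same 2-8-1 architecture and one made-up mini-batch, performs one step with torch.optim.SGD on one copy of the model and the same step by hand on an identical copy, then compares the results.

```python
import copy
import torch
import torch.nn as nn
from torch.optim import SGD

torch.manual_seed(0)
model_a = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 1))
model_b = copy.deepcopy(model_a)  # identical starting weights
xb = torch.tensor([[7.0, 5.0], [6.0, 3.0]])
yb = torch.tensor([[2.0], [3.0]])
lr = 0.01

# (a) One step with torch.optim.SGD.
opt = SGD(model_a.parameters(), lr=lr)
opt.zero_grad()
nn.MSELoss()(model_a(xb), yb).backward()
opt.step()

# (b) The same step written out by hand: w <- w - lr * dL/dw.
nn.MSELoss()(model_b(xb), yb).backward()
with torch.no_grad():  # in-place updates must not be tracked by autograd
    for p in model_b.parameters():
        p -= lr * p.grad

same = all(torch.allclose(pa, pb, atol=1e-7)
           for pa, pb in zip(model_a.parameters(), model_b.parameters()))
print(same)  # True
```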
In [None]:
# Print the model structure
print(model)
# Print the learned parameters (weights and biases) of each layer
print(model.state_dict())
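The keys of state_dict() are named "<index in Sequential>.<parameter name>". The ReLU at index 1 has no parameters, so it does not appear. A short sketch listing the parameter shapes for a freshly built model of the same architecture:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 1))

# Map each state_dict entry to its tensor shape.
shapes = {name: tuple(t.shape) for name, t in model.state_dict().items()}
print(shapes)
# {'0.weight': (8, 2), '0.bias': (8,), '2.weight': (1, 8), '2.bias': (1,)}
```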

In [None]:
# Plot the history of loss during training (one point per mini-batch step)
plt.plot(loss_history)
plt.xlabel('training steps (mini-batches)')
plt.ylabel('loss values')
plt.show()
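Because loss_history holds one entry per mini-batch, a curve with one point per epoch can be obtained by averaging consecutive groups of steps. A minimal sketch; the helper `per_epoch_mean` is hypothetical and assumes every epoch contributed the same number of steps (true here, since the DataLoader yields 3 batches per epoch):

```python
import numpy as np

def per_epoch_mean(step_losses, steps_per_epoch):
    """Average per-step losses into one value per epoch (hypothetical helper).

    Assumes len(step_losses) is a multiple of steps_per_epoch.
    """
    arr = np.asarray(step_losses, dtype=float)
    return arr.reshape(-1, steps_per_epoch).mean(axis=1)

# Toy history: 2 epochs x 3 steps (values are made up).
print(per_epoch_mean([1.0, 2.0, 3.0, 4.0, 5.0, 6.0], 3))  # [2. 5.]
```

With the loss values recorded as plain floats, this could be plotted with `plt.plot(per_epoch_mean(loss_history, len(dataloader)))` and an x-axis labeled 'epochs'.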


In [None]:
# Use the model for prediction
test_vals = [[10,2], [4,1],[5,3]]

#test_vals = torch.tensor(test_vals).float().to(device)
test_vals = torch.tensor(test_vals).float()

print(model(test_vals).detach().numpy())
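For pure inference, wrapping the forward pass in torch.no_grad() avoids building the autograd graph, so no .detach() is needed afterwards. The sketch below also computes the ground truth x1 − x2 for the test inputs to compare against. A freshly initialized network stands in for the trained model, so its predictions are not meaningful; only the pattern is shown. Note that the target 8 for [10, 2] lies outside the range of training targets (2 to 5), so that prediction is an extrapolation.

```python
import torch
import torch.nn as nn

# Stand-in for the trained network above (untrained here).
model = nn.Sequential(nn.Linear(2, 8), nn.ReLU(), nn.Linear(8, 1))
test_vals = torch.tensor([[10, 2], [4, 1], [5, 3]]).float()

model.eval()  # no effect on Linear/ReLU, but a good habit for inference
with torch.no_grad():  # do not track operations for gradients
    preds = model(test_vals)

targets = test_vals[:, :1] - test_vals[:, 1:]  # ground truth x1 - x2, shape (3, 1)
abs_err = (preds - targets).abs()
print(targets.squeeze(1).tolist())  # [8.0, 3.0, 2.0]
print(abs_err.squeeze(1).tolist())
```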