# Set-Based Neural Network Training with PyTorch

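In this notebook, I will show you how to train a neural network using set-based computation. As a simple example, we train the same network on a noisy 1-D binary classification task twice — once with a standard point-based cross-entropy loss, and once with the set-based (zonotope) classification loss from `ZonoTorch` — and then compare the two models' predictions.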

In [None]:
# Importing the libraries
from copy import deepcopy
import torch    
import numpy as np
import matplotlib.pyplot as plt
import sys
sys.path.append('..')
from SBML import ZonoTorch as zt

In [None]:
# --- Synthetic 1-D binary classification data ------------------------------
# Seed both RNGs: numpy drives the data below, torch drives weight init and
# training later on, so a fresh kernel reproduces the same run.
np.random.seed(1)
torch.manual_seed(1)

num_samples = 100

# Inputs drawn uniformly from [-5, 5); labels from the sign of the input:
# 1.0 for x > 0, 0.0 for x < 0 (0.5 only if x == 0 exactly).
x = 10 * np.random.rand(num_samples, 1) - 5
y = 0.5 * (np.sign(x) + 1)

# Gaussian noise is added to the inputs *after* labelling, so samples near 0
# can land on the "wrong" side of the boundary and the classes overlap.
x = x + 0.5 * np.random.randn(num_samples, 1)

# Visualise the raw samples.
fig, ax = plt.subplots()
ax.scatter(x, y, label='Data')
ax.legend()
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Data')
plt.show()


In [None]:
# Initializing the model
# 1 input feature -> two hidden ReLU layers (400, 300) -> 2 class scores,
# followed by a softmax so each output row sums to 1.
# NOTE(review): the point-based copy of this model is later trained with
# torch.nn.CrossEntropyLoss, which applies log-softmax internally and expects
# raw logits; feeding it softmax outputs applies softmax twice. The Softmax is
# kept because the set-based loss (zt.core.ZonotopeClassificationLoss) may
# rely on it — TODO confirm, and consider dropping it for the point model.
nn = torch.nn.Sequential(
    torch.nn.Linear(1, 400),
    torch.nn.ReLU(),
    torch.nn.Linear(400, 300),
    torch.nn.ReLU(),
    torch.nn.Linear(300, 2),
    # Inputs are (batch, features), so normalise over the class axis.
    # Omitting `dim` is deprecated and warns; for 2-D input the implicit
    # choice is also dim=1, so outputs are unchanged.
    torch.nn.Softmax(dim=1),
)

# Xavier/Glorot-uniform weights and zero biases for every Linear layer.
for layer in nn:
    if isinstance(layer, torch.nn.Linear):
        torch.nn.init.xavier_uniform_(layer.weight)
        torch.nn.init.zeros_(layer.bias)

In [None]:
# Training the model
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Two independent copies of the same initial network, so the point-based and
# set-based runs start from identical weights (Module.to returns the module).
nn_point = deepcopy(nn).to(device)
nn_set = deepcopy(nn).to(device)

# Convert the data to tensors created directly on `device`: float32 inputs and
# 1-D int64 class indices. torch.as_tensor accepts both the original numpy
# arrays and tensors left over from a previous run of this cell, so re-running
# the cell is safe and warning-free (torch.tensor on an existing tensor warns
# and makes an extra copy).
x = torch.as_tensor(x, dtype=torch.float32, device=device)
y = torch.as_tensor(y.squeeze(), dtype=torch.long, device=device)

In [None]:
# Hyper-parameters shared by both training runs.
LEARNING_RATE = 0.01
EPOCHS = 100
BATCH_SIZE = 64

# Point-based objective: standard cross-entropy.
# NOTE(review): the model ends in a Softmax layer while CrossEntropyLoss
# expects unnormalised logits — confirm this double softmax is intended.
loss_point = torch.nn.CrossEntropyLoss()
# Set-based objective from ZonoTorch.
# NOTE(review): the meaning of the two 1e-1 arguments is defined in SBML — confirm.
loss_set = zt.core.ZonotopeClassificationLoss(1e-1, 1e-1)

# One Adam optimizer per model, so the two runs never share optimizer state.
optimizer_point = torch.optim.Adam(nn_point.parameters(), lr=LEARNING_RATE)
optimizer_set = torch.optim.Adam(nn_set.parameters(), lr=LEARNING_RATE)

# Same data and settings for both runs; the set-based run additionally passes
# noise=1e-1 to zt.train.
nn_point_trained = zt.train(
    nn_point, loss_point, optimizer_point, x, y, EPOCHS,
    batchsize=BATCH_SIZE,
)
nn_set_trained = zt.train(
    nn_set, loss_set, optimizer_set, x, y, EPOCHS,
    batchsize=BATCH_SIZE, noise=1e-1,
)

In [None]:
# Plotting the results
# Evaluate both trained models on an evenly spaced grid over [-5, 5] and plot
# their predictions against the training data.
x_test = torch.linspace(-5, 5, 100).view(-1, 1).to(device)

# Inference only: no_grad avoids building an autograd graph for these outputs.
with torch.no_grad():
    y_pred_point = nn_point_trained(x_test)
    y_pred_set = nn_set_trained(x_test)

# Predicted class = index of the largest output in each row.
y_pred_point_class = torch.argmax(y_pred_point, dim=1)
y_pred_set_class = torch.argmax(y_pred_set, dim=1)

# Calculate accuracy on the grid. The predictions must be compared with labels
# for x_test itself, not with `y` (the labels of the randomly ordered training
# points, which have no correspondence to the grid). The grid labels use the
# same rule that generated `y`: class 1 for x > 0, class 0 for x < 0.
y_test = (x_test.squeeze(1) > 0).long()
accuracy_point = (y_pred_point_class == y_test).float().mean().item()
accuracy_set = (y_pred_set_class == y_test).float().mean().item()
print('Point Prediction Accuracy: {:.2f}%'.format(accuracy_point * 100))
print('Set Prediction Accuracy: {:.2f}%'.format(accuracy_set * 100))

grid = x_test.cpu()
fig, ax = plt.subplots()
ax.scatter(x.detach().cpu(), y.detach().cpu(), label='Data')
# Plot only output column 1 (the class-1 score): it lives on the same 0..1
# scale as the labels, whereas plotting the full (N, 2) output would draw two
# identically-labelled lines per model.
ax.plot(grid, y_pred_point[:, 1].cpu(), color='green', label='Point Prediction')
ax.plot(grid, y_pred_set[:, 1].cpu(), color='orange', label='Set Prediction')
ax.plot(grid, y_pred_point_class.cpu(), color='red', label='Point Prediction Class')
ax.plot(grid, y_pred_set_class.cpu(), color='blue', label='Set Prediction Class')
ax.legend()
ax.set_xlabel('x')
ax.set_ylabel('y')
ax.set_title('Point vs. Set Predictions')
plt.show()