In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from itertools import islice
import plotly.graph_objects as go
from torchviz import make_dot
from sklearn.model_selection import train_test_split

# Seed every global random number generator used below so that
# Restart Kernel -> Run All reproduces the same data, initial weights and
# training run. (train_test_split below is seeded separately via random_state.)
SEED = 0
np.random.seed(SEED)
# The trailing ';' hides the torch.Generator repr that manual_seed returns.
torch.manual_seed(SEED);

In [2]:
# Generate data.
# Draw n random coordinates per axis, uniformly from [-2*pi, 2*pi), sort each
# axis, and take every (x1, x2) combination: n*n grid points in total, with
# target y = x1**2 + x2**2. Sorting only permutes the grid rows, but it still
# changes which points land in each split, because train_test_split shuffles
# by row position.
# TODO: Try without sorting
n = 100
x1 = np.sort(np.random.uniform(-2*np.pi, 2*np.pi, n))
x2 = np.sort(np.random.uniform(-2*np.pi, 2*np.pi, n))

# From here on x1 and x2 are the 2-D meshgrid arrays, and x is (n*n, 2).
x1, x2 = np.meshgrid(x1, x2)
x = np.column_stack((x1.ravel(), x2.ravel()))
y = x[:, 0] ** 2 + x[:, 1] ** 2

x_train, x_test, y_train, y_test = train_test_split(x, y, test_size=0.2, random_state=0)

# Convert each split to float32, the dtype torch.nn.Linear parameters use by default.
x_train, x_test, y_train, y_test = (
    torch.tensor(split, dtype=torch.float32)
    for split in (x_train, x_test, y_train, y_test)
)

In [3]:
class MLP(torch.nn.Module):
    """Fully connected ReLU network that also reports its hidden activation pattern.

    Every hidden layer is followed by a ReLU; the last layer is linear. All
    weights and biases are re-drawn from N(0, 1) after construction
    (see ``init_weights_random``).

    Args:
        layer_sizes: Widths of the input, hidden and output layers, in order.
            Defaults to ``(2, 10, 5, 1)``, the original fixed architecture.

    Raises:
        ValueError: if fewer than two widths are given.
    """

    def __init__(self, layer_sizes=(2, 10, 5, 1)):
        super().__init__()
        if len(layer_sizes) < 2:
            raise ValueError("layer_sizes needs at least an input and an output width")
        # Layers are created in order, so the RNG is consumed exactly as before
        # and seeded runs stay reproducible.
        self.layers = torch.nn.ModuleList([
            torch.nn.Linear(n_in, n_out)
            for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
        ])
        self.init_weights_random()

    def init_weights_random(self):
        """Re-initialise every layer's weight and bias in place from N(0, 1)."""
        for layer in self.layers:
            torch.nn.init.normal_(layer.weight)
            torch.nn.init.normal_(layer.bias)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: Input batch of shape ``(batch, layer_sizes[0])``.

        Returns:
            ``(out, activation)``. ``out`` has shape ``(batch, layer_sizes[-1])``.
            ``activation`` is an int32 tensor of shape ``(batch, sum of hidden widths)``:
            1 where a hidden ReLU unit is active (> 0), else 0.
        """
        pattern = []
        for layer in self.layers[:-1]:
            x = torch.relu(layer(x))
            # No .squeeze() here: for a batch of one (or a hidden layer of width
            # one) it would drop a dimension and break torch.cat(dim=1) below.
            pattern.append((x > 0).int())

        out = self.layers[-1](x)
        if pattern:
            activation = torch.cat(pattern, dim=1)
        else:
            # No hidden layers means no ReLU units: return an empty pattern per row.
            activation = x.new_zeros((x.shape[0], 0), dtype=torch.int32)
        return out, activation


model = MLP()

In [4]:
def count_regions(activation, return_inverse=False):
    """Count the distinct activation patterns among the rows of ``activation``.

    Each distinct pattern (a unique slice along dim 0) is treated as one region.

    Args:
        activation: Tensor whose dim-0 slices are activation patterns, e.g. the
            second value returned by ``MLP.forward``.
        return_inverse: If truthy, also return, for every row, the index of its
            pattern among the (sorted) unique patterns.

    Returns:
        The number of unique patterns, or ``(number, inverse_indices)`` when
        ``return_inverse`` is truthy.
    """
    if not return_inverse:
        return torch.unique(activation, dim=0).size(0)
    patterns, region_of_row = torch.unique(activation, dim=0, return_inverse=True)
    return patterns.size(0), region_of_row

In [None]:
criterion = torch.nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

n_epochs = 1000
log_every = 100   # print the training loss every `log_every` epochs and at the final epoch
n_regions = []    # number of distinct activation patterns on x_train, per epoch (before the update)

# The evaluation cell switches the model to eval mode; restore training mode so
# re-running this cell after it still trains in the right mode.
model.train()
for epoch in range(n_epochs):   # TODO: Compare time without counting regions
    optimizer.zero_grad()
    y_pred, activation = model(x_train)
    regions = count_regions(activation, return_inverse=False)
    n_regions.append(regions)
    # squeeze(-1) removes only the trailing output dimension, so the batch
    # dimension survives even for a batch of one.
    loss = criterion(y_pred.squeeze(-1), y_train)
    loss.backward()
    optimizer.step()
    if epoch % log_every == 0 or epoch == n_epochs - 1:
        print(f'Epoch: {epoch}, Loss: {loss.item()}')

In [None]:
def _input_grid(n_samples):
    """Build a regular n_samples x n_samples grid over [-2*pi, 2*pi]^2.

    Returns:
        (x1, x2, x): x1 and x2 are the np.meshgrid coordinate arrays of shape
        (n_samples, n_samples); x is a float32 tensor of shape (n_samples**2, 2)
        holding the flattened (x1, x2) pairs in row-major order.
    """
    axis = np.linspace(-2*np.pi, 2*np.pi, n_samples)
    grid_x1, grid_x2 = np.meshgrid(axis, axis)
    points = np.concatenate((grid_x1.reshape(-1, 1), grid_x2.reshape(-1, 1)), axis=1)
    return grid_x1, grid_x2, torch.tensor(points, dtype=torch.float32)


# Test. Nothing below is trained, so run inference under torch.no_grad() to
# avoid recording an autograd graph (the partition grid has 640,000 points).
model.eval()
with torch.no_grad():
    y_pred, _ = model(x_test)
    loss = criterion(y_pred.squeeze(-1), y_test)
print(f'Test Loss: {loss.item()}')

# Plot the model function
n_samples = 100
x1, x2, x = _input_grid(n_samples)
with torch.no_grad():
    y_pred, activation = model(x)
y_pred = y_pred.detach().numpy().reshape(n_samples, n_samples)
fig = go.Figure(data=[go.Surface(z=y_pred, x=x1, y=x2)])
fig.update_layout(title='Model function', autosize=False, width=500, height=500,
                  margin=dict(l=65, r=50, b=65, t=90),
                  scene=dict(xaxis_title='x1', yaxis_title='x2', zaxis_title='model output'))
fig.show()

# Input space partition: colour each grid point by the index of its activation
# pattern (one pattern = one region).
n_samples = 800
x1, x2, x = _input_grid(n_samples)
with torch.no_grad():
    _, activation = model(x)
_, inverse_indices = count_regions(activation, return_inverse=True)
fig = go.Figure(data=[
    go.Contour(
        # With np.meshgrid's default 'xy' indexing, row i of the reshaped ids lies
        # at x2[i, 0] and column j at x1[0, j]; pass real coordinates, not indices.
        x=x1[0],
        y=x2[:, 0],
        z=inverse_indices.numpy().reshape(n_samples, n_samples),
        colorscale='Viridis',
        showscale=False,
    )
])
fig.update_layout(title='Input space partition', xaxis_title='x1', yaxis_title='x2')
fig.show()

# Plot the number of regions as a function of the number of epochs.
# Use len(n_regions) so the x axis matches even if training was interrupted.
fig = go.Figure(data=go.Scatter(x=np.arange(len(n_regions)), y=n_regions, mode='lines'))
fig.update_layout(title='Number of regions', xaxis_title='Epoch', yaxis_title='Number of regions')
fig.show()