In [1]:
from itertools import islice
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import torch

# Seed every RNG the notebook uses so Restart & Run All reproduces the outputs.
SEED = 0
np.random.seed(SEED)     # numpy RNG: drives the input sampling (np.random.uniform)
torch.manual_seed(SEED)  # torch RNG: drives the weight initialisation (nn.Linear, nn.init.*)

In [2]:
class FeedForwardNN(torch.nn.Module):
    """Fully connected ReLU network that also reports its ReLU activation pattern.

    With the default ``layer_sizes=(1, 2, 2, 1)`` the network is
    Linear(1, 2) -> ReLU -> Linear(2, 2) -> ReLU -> Linear(2, 1).

    Args:
        layer_sizes: Widths of the input, each hidden layer, and the output.
            At least three entries (one hidden layer) are required, because
            ``forward`` concatenates the hidden-layer activation masks.

    Attributes:
        layers: ``ModuleList`` of the ``Linear`` layers, ordered input to output.
        n_hidden_layers: Number of ReLU (hidden) layers, i.e. ``len(layers) - 1``.
    """

    def __init__(self, layer_sizes=(1, 2, 2, 1)):
        super().__init__()
        sizes = tuple(layer_sizes)
        if len(sizes) < 3:
            raise ValueError("layer_sizes needs an input size, at least one hidden size and an output size")
        # Layers are built in input->output order, so torch's RNG is consumed in
        # the same order as when they were appended one by one.
        self.layers = torch.nn.ModuleList(
            [torch.nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        )
        self.init_weights()
        self.n_hidden_layers = len(self.layers) - 1

    def init_weights(self):
        """Re-draw every weight and bias from U(0, 1) (``uniform_`` defaults)."""
        for layer in self.layers:
            torch.nn.init.uniform_(layer.weight)
            torch.nn.init.uniform_(layer.bias)

    def init_weights_const(self):
        """Deterministic init: layer ``i`` gets weight and bias all equal to ``i + 1``."""
        for i, layer in enumerate(self.layers):
            torch.nn.init.constant_(layer.weight, i + 1)
            torch.nn.init.constant_(layer.bias, i + 1)

    def forward(self, x):
        """Run the network and record which hidden ReLU units are active.

        Args:
            x: Input whose last dimension equals ``layer_sizes[0]``; a leading
                batch dimension is optional.

        Returns:
            ``(output, activation)``. ``output`` is the last Linear layer's result
            (no ReLU). ``activation`` is an int32 mask with 1 where a hidden
            unit's ReLU output is > 0 and 0 otherwise, concatenated over the
            hidden layers (in order) along the last dimension.
        """
        activation = []
        for layer in self.layers[:-1]:
            x = torch.relu(layer(x))
            # No .squeeze(): it would drop the batch dim for a single sample
            # (or a width-1 layer) and break the concatenation below.
            activation.append((x > 0).int())

        x = self.layers[-1](x)
        activation = torch.cat(activation, dim=-1)
        return x, activation

model = FeedForwardNN()

In [3]:
# Draw 1-D inputs uniformly from [-10, 10] and sort them so that line plots
# trace the model left to right; shape is (n_samples, 1) for the Linear(1, .) input.
n_samples = 50000
X = torch.from_numpy(np.sort(np.random.uniform(-10, 10, n_samples))).float().unsqueeze(1)

with torch.no_grad():
    Y, activations = model(X)

# Group samples by their ReLU on/off pattern: inverse_indices[k] is the row of
# unique_activations holding sample k's pattern. With a pattern fixed, every
# ReLU acts as a constant mask, so the model is affine on each group.
unique_activations, inverse_indices = torch.unique(activations, dim=0, return_inverse=True)
print(f"Unique activations: {len(unique_activations)}")

Unique activations: 3


In [4]:
# Plot the 1-D model output, and along y = 0 the input interval covered by each
# distinct activation pattern (one coloured trace per pattern).
x = X.squeeze(1).numpy()
y = Y.squeeze(1).numpy()

fig = go.Figure()
fig.add_trace(go.Scatter(x=x, y=y, mode='lines', name='Model'))
for i, activation in enumerate(unique_activations):
    region_x = x[(inverse_indices == i).numpy()]
    # y must have the same length as this region's x, not the full sample count.
    fig.add_trace(go.Scatter(
        x=region_x,
        y=np.zeros_like(region_x),
        mode='lines',
        name=f'Activation {activation.tolist()}',
    ))
fig.update_layout(title='Model Function', xaxis_title='x', yaxis_title='model output')
fig.show()

In [5]:
class FeedForwardNN2D(torch.nn.Module):
    """Fully connected ReLU network on 2-D inputs that also reports its activation pattern.

    With the default ``layer_sizes=(2, 2, 2, 1)`` the network is
    Linear(2, 2) -> ReLU -> Linear(2, 2) -> ReLU -> Linear(2, 1),
    initialised with ``init_weights_random``.

    Args:
        layer_sizes: Widths of the input, each hidden layer, and the output.
            At least three entries (one hidden layer) are required, because
            ``forward`` concatenates the hidden-layer activation masks.

    Attributes:
        layers: ``ModuleList`` of the ``Linear`` layers, ordered input to output.
        n_hidden_layers: Number of ReLU (hidden) layers, i.e. ``len(layers) - 1``.
    """

    def __init__(self, layer_sizes=(2, 2, 2, 1)):
        super().__init__()
        sizes = tuple(layer_sizes)
        if len(sizes) < 3:
            raise ValueError("layer_sizes needs an input size, at least one hidden size and an output size")
        # Built in input->output order: torch's RNG is consumed exactly as when
        # the layers were appended one by one.
        self.layers = torch.nn.ModuleList(
            [torch.nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        )
        self.init_weights_random()
        self.n_hidden_layers = len(self.layers) - 1

    def init_weights(self):
        """Deterministic init: every weight 1, every bias 0."""
        for layer in self.layers:
            torch.nn.init.ones_(layer.weight)
            torch.nn.init.zeros_(layer.bias)

    def init_weights_random(self):
        """Draw every weight and bias from N(0, 1) (``normal_`` defaults)."""
        for layer in self.layers:
            torch.nn.init.normal_(layer.weight)
            torch.nn.init.normal_(layer.bias)

    def forward(self, x):
        """Run the network and record which hidden ReLU units are active.

        Args:
            x: Input whose last dimension equals ``layer_sizes[0]``; a leading
                batch dimension is optional.

        Returns:
            ``(output, activation)``. ``output`` is the last Linear layer's result
            (no ReLU). ``activation`` is an int32 mask with 1 where a hidden
            unit's ReLU output is > 0 and 0 otherwise, concatenated over the
            hidden layers (in order) along the last dimension.
        """
        activation = []
        for layer in self.layers[:-1]:
            x = torch.relu(layer(x))
            # No .squeeze(): it would drop the batch dim for a single sample
            # (or a width-1 layer) and break the concatenation below.
            activation.append((x > 0).int())

        x = self.layers[-1](x)
        activation = torch.cat(activation, dim=-1)
        return x, activation

model = FeedForwardNN2D()

In [6]:
# Evaluate the 2-D model on a regular n_samples x n_samples grid covering
# [-boundary, boundary]^2 and count the distinct ReLU activation patterns.
n_samples = 800
boundary = 100
X1, X2 = np.meshgrid(
    np.linspace(-boundary, boundary, n_samples),
    np.linspace(-boundary, boundary, n_samples),
)
# Row-major flatten into (n_samples**2, 2) points, so any per-point result can
# be reshaped back to (n_samples, n_samples) to line up with X1 / X2.
X = torch.from_numpy(np.column_stack((X1.ravel(), X2.ravel()))).float()

with torch.no_grad():
    Y, activations = model(X)

# inverse_indices[k] is the row of unique_activations holding point k's pattern.
unique_activations, inverse_indices = torch.unique(activations, dim=0, return_inverse=True)
print(f"Unique activations: {len(unique_activations)}")

Unique activations: 9


In [13]:
# Model output as a 3-D surface over the input grid, plus a flat,
# semi-transparent plane floating above it that is coloured by the index of
# the activation pattern at each grid point (the input-space partition).
PARTITION_PLANE_OFFSET = 5  # gap between max(model output) and the partition plane
Y_grid = Y.squeeze(1).numpy().reshape(n_samples, n_samples)
region_grid = inverse_indices.numpy().reshape(n_samples, n_samples)

surfaces = [
    go.Surface(
        z=Y_grid,
        x=X1,
        y=X2,
        colorscale='Viridis',
        showscale=False,
        opacity=1,
    ),
    go.Surface(
        z=np.full_like(X1, Y_grid.max() + PARTITION_PLANE_OFFSET),
        x=X1,
        y=X2,
        colorscale='Viridis',
        showscale=False,
        opacity=0.5,
        # Static label: this single plane shows every activation region at once.
        name='Activation regions',
        surfacecolor=region_grid,
    ),
]
fig = go.Figure(data=surfaces)
fig.update_layout(title='Model Function')
Path("./Plots").mkdir(parents=True, exist_ok=True)  # make sure the output directory exists
fig.write_html("./Plots/surface+partition.html")

In [17]:
# Top-down view of the input-space partition: each grid point is coloured by the
# index of its activation pattern. x / y come from the meshgrid so the axes are
# in input coordinates rather than grid indices (meshgrid: X1 varies along
# columns, X2 along rows).
# NOTE(review): region ids are categorical, so contour levels between
# non-adjacent ids are interpolation artefacts; go.Heatmap may be more faithful.
fig = go.Figure()
fig.add_trace(
    go.Contour(
        z=inverse_indices.numpy().reshape(n_samples, n_samples),
        x=X1[0, :],
        y=X2[:, 0],
        colorscale='Viridis',
        showscale=False,
    )
)
fig.update_layout(title='Input space partition', xaxis_title='x1', yaxis_title='x2')
Path("./Plots").mkdir(parents=True, exist_ok=True)  # make sure the output directory exists
fig.write_html("./Plots/partition.html")