In [None]:
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
import time
import math

# Seed both random number generators this notebook draws from:
#  - NumPy's global generator feeds DeepSetNetwork.init_weights and sample_from_cube;
#  - torch's generator feeds the default nn.Linear initialisation and random_mlp_init.
# Seeding only torch leaves the sampled points and DeepSet weights different on every run.
np.random.seed(0)
torch.manual_seed(0)

In [None]:
class DeepSetNetwork(nn.Module):
    """One square linear layer, a ReLU, and a sum over dim 1.

    The layer weight is set to ``alpha * I + gamma * ones`` with a zero bias.
    A matrix of that form commutes with coordinate permutations, so together
    with the element-wise ReLU and the final sum the output does not change
    when the components of an input row are permuted.
    """

    def __init__(self, input_dim):
        """Create the ``input_dim x input_dim`` layer and apply the structured init."""
        super().__init__()
        self.linear1 = nn.Linear(input_dim, input_dim)
        self.init_weights()

    def init_weights(self):
        """Overwrite the layer with ``alpha * I + gamma * ones`` and a zero bias.

        ``alpha`` and ``gamma`` are drawn (in that order) from NumPy's global
        generator, uniformly in [-1, 1).
        """
        layer = self.linear1
        alpha = np.random.rand() * 2 - 1
        gamma = np.random.rand() * 2 - 1
        identity_part = alpha * torch.eye(layer.weight.shape[0])
        constant_part = gamma * torch.ones_like(layer.weight)
        layer.weight.data = identity_part + constant_part
        layer.bias.data = torch.zeros_like(layer.bias)

    def forward(self, x):
        """Return ``relu(linear1(x))`` summed over dim 1, as a column of shape (-1, 1)."""
        activated = torch.relu(self.linear1(x))
        return activated.sum(dim=1).view(-1, 1)
    
class MLP(nn.Module):
    """Fully connected ReLU network whose layer widths are given by ``dims``.

    Consecutive entries of ``dims`` define one ``nn.Linear`` each; a ReLU
    follows every linear layer except the last one.
    """

    def __init__(self, dims,init_fn=None):
        """Build the layer stack and optionally initialise it.

        Args:
            dims: sequence of layer widths, input width first.
            init_fn: optional callable passed to ``nn.Module.apply`` so it is
                invoked on every submodule (and on this module itself).
        """
        super().__init__()
        last_linear = len(dims) - 2
        modules = []
        for idx, (width_in, width_out) in enumerate(zip(dims, dims[1:])):
            modules.append(nn.Linear(width_in, width_out))
            # No activation after the output layer.
            if idx < last_linear:
                modules.append(nn.ReLU())
        self.network = nn.Sequential(*modules)
        if init_fn:
            self.apply(init_fn)

    def forward(self, x):
        """Run ``x`` through the sequential stack."""
        return self.network(x)
    
def random_mlp_init(m):
    """Re-draw the parameters of an ``nn.Linear`` from zero-mean normals.

    Weights use standard deviation ``sqrt(2 / in_features)``; biases use
    ``sqrt(2 / len(bias))``. Modules that are not ``nn.Linear`` are left
    untouched. Intended to be passed to ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return
    n_out, n_in = m.weight.data.shape
    weight_std = torch.ones(n_out, n_in) * float(np.sqrt(2 / n_in))
    m.weight.data = torch.normal(0, weight_std)
    n_bias = len(m.bias.data)
    bias_std = torch.ones(n_bias) * float(np.sqrt(2 / n_bias))
    m.bias.data = torch.normal(0, bias_std)

In [None]:
def sample_from_cube(npoints, ndim, side_length):
    """Draw ``npoints`` points uniformly from the cube ``[-side_length, side_length)^ndim``.

    Note that ``side_length`` is the half-width of the cube: each coordinate
    ranges from ``-side_length`` to ``side_length``. Samples come from NumPy's
    global generator and are returned as a tensor of shape ``(npoints, ndim)``
    in torch's default floating dtype.
    """
    lower, upper = -side_length, side_length
    coords = np.random.uniform(lower, upper, size=(npoints, ndim))
    return torch.from_numpy(coords).to(torch.get_default_dtype())

def get_jacobians_with_outputs(points, model):
    """Evaluate ``model`` on ``points`` and return input gradients with the outputs.

    The gradient is that of the sum of all output entries with respect to the
    inputs (``grad_outputs`` is all ones). When each output row depends only on
    its own input row and the model has a single output column, row ``i`` of the
    result is the gradient of output ``i`` with respect to point ``i``.

    The computation runs on a detached copy of ``points`` and uses
    ``torch.autograd.grad``, so the caller's tensor is not modified, repeated
    calls on the same tensor do not accumulate into ``points.grad``, and the
    model parameters' ``.grad`` fields are left untouched.

    Args:
        points: tensor of input points, one per row.
        model: callable mapping a batch of points to a tensor of outputs.

    Returns:
        ``(jacobians, output)`` where ``jacobians`` has the shape of ``points``
        and ``output`` is whatever ``model`` returned for the batch.
    """
    inputs = points.detach().clone().requires_grad_(True)
    output = model(inputs)
    (jacobians,) = torch.autograd.grad(
        output,
        inputs,
        grad_outputs=torch.ones_like(output),
        allow_unused=True,
    )
    if jacobians is None:
        # The output does not depend on the inputs at all: the gradient is zero.
        jacobians = torch.zeros_like(inputs)
    return jacobians,output

def remove_duplicates(jacobians,points,outputs,model,midpoint_sampling=2):
    """Keep one Jacobian row per detected linear region.

    Samples are processed in order. A sample starts a new region when no
    already-kept sample has exactly the same Jacobian. If some kept samples do
    share its Jacobian, the segment from each of them to the new sample is
    probed at the interior fractions ``m / midpoint_sampling`` for
    ``m = 1 .. midpoint_sampling - 1``: a kept sample stays a candidate only
    while the model value at every probed point agrees (within 1e-4 in every
    output column) with the linear interpolation of the two endpoint outputs.
    The new sample is kept as a separate region once no candidate remains.

    Args:
        jacobians: 2-D array, one Jacobian per sample.
        points: 2-D array of the sample locations, same row order.
        outputs: 2-D array of model outputs at ``points``, same row order.
        model: callable taking a tensor of points and returning a tensor of outputs.
        midpoint_sampling: number of equal parts each segment is divided into;
            values <= 1 probe nothing, so equal Jacobians are always merged.

    Returns:
        2-D array with the Jacobians of the kept samples, in sample order.
        Empty input yields an empty array.
    """
    jacobians = np.asarray(jacobians)
    points = np.asarray(points)
    outputs = np.asarray(outputs)
    if jacobians.shape[0] == 0:
        return jacobians

    # Indices (into the input arrays) of the samples kept as region representatives.
    kept = [0]
    for k in range(1, jacobians.shape[0]):
        kept_idx = np.asarray(kept)
        same_jacobian = np.all(jacobians[kept_idx] == jacobians[k], axis=1)
        candidates = kept_idx[same_jacobian]
        if candidates.size == 0:
            kept.append(k)
            continue

        for m in range(1, midpoint_sampling):
            midpoints = points[candidates] + (points[k] - points[candidates]) * m / midpoint_sampling
            with torch.no_grad():  # only values are needed, no autograd graph
                model_midpoints = model(torch.tensor(midpoints)).detach().numpy()
            model_midpoints = model_midpoints.reshape(len(candidates), -1)
            linear_map_midpoints = outputs[candidates] + (outputs[k] - outputs[candidates]) * m / midpoint_sampling
            linear_map_midpoints = linear_map_midpoints.reshape(len(candidates), -1)

            # A candidate survives only if every output column matches at this probe.
            consistent = np.all(np.abs(model_midpoints - linear_map_midpoints) < 1e-4, axis=1)
            # Narrow the candidate set so later probes only test survivors; otherwise
            # different candidates could each pass a different probe and mask a new region.
            candidates = candidates[consistent]
            if candidates.size == 0:
                kept.append(k)
                break
    return jacobians[np.asarray(kept)]

def sort_point_components(tensor):
    """Return ``tensor`` with the entries of each row sorted ascending (along dim 1)."""
    ordered, _ = tensor.sort(dim=1)
    return ordered

In [None]:
def estimate_linear_regions(model,radius,n_points,point_dim):
    """Estimate the number of linear regions of ``model`` by random sampling.

    Draws ``n_points`` points from the cube with half-width ``radius`` in
    ``point_dim`` dimensions, computes the input gradient at each point, and
    counts the rows that ``remove_duplicates`` keeps after rounding the
    gradients to 10 decimals.
    """
    samples = sample_from_cube(npoints=n_points, ndim=point_dim, side_length=radius)
    grads, values = get_jacobians_with_outputs(samples, model)
    rounded_grads = np.around(grads.detach().numpy(), 10)
    distinct = remove_duplicates(
        rounded_grads,
        samples.detach().numpy(),
        values.detach().numpy(),
        model,
    )
    return distinct.shape[0]

def estimate_linear_regions_using_fundamental_domain(model, radius, n_points, point_dim):
    """Estimate the region count by sampling only points with sorted components.

    Sampled points have their components sorted ascending before evaluation.
    Each Jacobian kept by ``remove_duplicates`` is then weighted by the number
    of distinct rearrangements of its components,
    ``point_dim! / prod(count_i!)``, where ``count_i`` are the multiplicities
    of its distinct values. The weights are summed and returned.
    """
    sorted_samples = sort_point_components(
        sample_from_cube(npoints=n_points, ndim=point_dim, side_length=radius)
    )
    grads, values = get_jacobians_with_outputs(sorted_samples, model)
    representatives = remove_duplicates(
        np.around(grads.detach().numpy(), 10),
        sorted_samples.detach().numpy(),
        values.detach().numpy(),
        model,
    )
    total = 0
    for representative in representatives:
        _, multiplicities = np.unique(representative, return_counts=True)
        repeated_arrangements = np.prod([math.factorial(count) for count in multiplicities])
        total += math.factorial(point_dim) / repeated_arrangements
    return total

In [None]:
def sampling_fundamental_domain_experiment(num_models, point_dim, search_radius, base_num_points):
    """Compare plain sampling with fundamental-domain sampling on DeepSet models.

    For each of ``num_models`` freshly created ``DeepSetNetwork`` instances the
    region count is estimated twice: with ``base_num_points ** point_dim``
    unrestricted samples, and with that number integer-divided by
    ``point_dim!`` sorted samples. Both runs are timed with ``time.time``.

    Returns:
        ``(ratios, time_regular, time_fundamental)``: per-model lists of the
        fundamental-to-regular estimate ratio and the two wall-clock durations.
    """
    num_points_adjusted_for_dimension = base_num_points ** point_dim
    print(f"Running experiment with {num_models} models, {point_dim} input dimension, search radius {search_radius}, and {num_points_adjusted_for_dimension} points.")
    # All models are created up front, before any sampling takes place.
    networks = [DeepSetNetwork(point_dim) for _ in range(num_models)]

    def run_timed(estimator, network, n_points):
        """Call ``estimator`` and return its result with the elapsed seconds."""
        began = time.time()
        estimate = estimator(network, search_radius, n_points, point_dim)
        return estimate, time.time() - began

    ratios, time_regular, time_fundamental = [], [], []
    for network in networks:
        regular_estimate, elapsed = run_timed(
            estimate_linear_regions, network, num_points_adjusted_for_dimension
        )
        time_regular.append(elapsed)

        reduced_num_points = num_points_adjusted_for_dimension//(math.factorial(point_dim))
        fundamental_estimate, elapsed = run_timed(
            estimate_linear_regions_using_fundamental_domain, network, reduced_num_points
        )
        time_fundamental.append(elapsed)

        ratios.append(fundamental_estimate / regular_estimate)

    average_ratio = sum(ratios) / num_models
    print(f"Average ratio: {average_ratio} Time: {sum(time_regular)/num_models} Time Fundamental: {sum(time_fundamental)/num_models}")

    return ratios,time_regular,time_fundamental

class _MidpointEvaluationCounter:
    """Callable wrapper around a model that tallies how many input rows it evaluates."""

    def __init__(self, model):
        self.model = model
        self.num_rows = 0

    def __call__(self, inputs):
        self.num_rows += inputs.shape[0]
        return self.model(inputs)


def count_linear_regions(architectures, search_radius=20, num_points=2000, nn_samples=5, midpoint_sampling=2):
    """Print the average estimated number of linear regions per MLP architecture.

    For every architecture, ``nn_samples`` networks are created with
    ``random_mlp_init``. Each is evaluated on ``num_points`` random points and
    the distinct regions are counted with ``remove_duplicates``. One summary
    line per architecture reports the average region count, the average
    wall-clock time, and the average number of model evaluations (the initial
    samples plus the extra midpoint probes made by ``remove_duplicates``).

    Args:
        architectures: iterable of layer-width lists; the first entry of each
            is used as the input dimension.
        search_radius: half-width of the sampling cube.
        num_points: number of random sample points per network.
        nn_samples: number of random networks per architecture (must be >= 1).
        midpoint_sampling: forwarded to ``remove_duplicates``.

    Raises:
        ValueError: if ``nn_samples`` is smaller than 1.
    """
    if nn_samples < 1:
        raise ValueError("nn_samples must be at least 1")

    architecture_avg_counts=[]
    architecture_avg_times=[]
    architecture_avg_samples=[]
    for architecture in architectures:
        num_lin_regions=[]
        times=[]
        number_of_samples=[]
        point_dim=architecture[0]
        for _ in range(nn_samples):
            model=MLP(architecture,init_fn=random_mlp_init)
            points=sample_from_cube(npoints=num_points, ndim=point_dim, side_length=search_radius)
            start_time=time.time()
            jacobians,outputs=get_jacobians_with_outputs(points, model)
            jacobians=jacobians.detach().numpy()
            outputs=outputs.detach().numpy()
            # remove_duplicates returns only the kept Jacobians, so the number of
            # extra midpoint evaluations is tallied by wrapping the model.
            counting_model=_MidpointEvaluationCounter(model)
            unique_jacobians=remove_duplicates(np.around(jacobians, 10),points.detach().numpy(),outputs,counting_model,midpoint_sampling=midpoint_sampling)
            total_time=time.time()-start_time
            # The region estimate is the number of kept rows, not the array itself.
            num_lin_regions.append(unique_jacobians.shape[0])
            times.append(total_time)
            number_of_samples.append(num_points+counting_model.num_rows)
        architecture_avg_counts.append(sum(num_lin_regions)/len(num_lin_regions))
        architecture_avg_times.append(sum(times)/len(times))
        architecture_avg_samples.append(sum(number_of_samples)/len(number_of_samples))

        print(f"{architecture} num regions = {architecture_avg_counts[-1]}, time = {architecture_avg_times[-1]}, samples = {architecture_avg_samples[-1]}")

In [None]:
# Number of DeepSetNetwork models created and evaluated for each input dimension.
NUM_MODELS = 10
# Maps each input dimension to the per-model ratios and timings returned by the experiment.
point_dim_to_ratios = {}
# Sweep input dimensions 2 through 6. The experiment uses base_num_points ** point_dim
# samples for the unrestricted run, i.e. 10**2 up to 10**6 points here.
for point_dim in range(2, 7):
    ratios,times_regular,times_fundamental=sampling_fundamental_domain_experiment(num_models=NUM_MODELS, point_dim=point_dim, search_radius=20, base_num_points=10)
    # NOTE(review): the key "time_fundamental" is singular while "times_regular" is plural;
    # readers of this dict must use the keys exactly as spelled here.
    point_dim_to_ratios[point_dim] = {"ratios":ratios,"times_regular":times_regular,"time_fundamental":times_fundamental}

In [None]:
# Number of randomly initialised networks averaged per architecture.
NN_SAMPLES=25
# Layer widths of the MLPs to analyse; the first entry is the input dimension.
ARCHITECTURES=[[2,6,1],[3,5,1],[4,4,1],[5,3,1],[6,2,1],[3,2,2,1],[3,3,2,1]]

# (search_radius, num_points) combinations to run, in this order. One loop replaces
# four copy-pasted blocks; the printed headers and calls are the same, and
# search_radius / num_points still hold the last setting when the cell finishes.
SWEEP_SETTINGS=[(5,1000),(20,1000),(5,5000),(20,5000)]

for search_radius, num_points in SWEEP_SETTINGS:
    print(f"search radius {search_radius}, num_points {num_points}, nn_samples {NN_SAMPLES}")
    count_linear_regions(ARCHITECTURES,search_radius=search_radius,num_points=num_points,nn_samples=NN_SAMPLES)
