In [1]:
import copy
import itertools
import os

import matplotlib.pyplot as plt
import numpy as np
from numpy.linalg import norm

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
%matplotlib inline

In [2]:
class LittleBlock(nn.Module):
    """Single bias-free linear layer followed by a ReLU: ``x -> relu(W x)``.

    Args:
        input_dim: size of the last axis of the input.
        output_dim: number of output features.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # bias=False: the layer is a pure linear map before the nonlinearity.
        self.fc = nn.Linear(input_dim, output_dim, bias=False)

    def forward(self, x):
        pre_activation = self.fc(x)
        return torch.relu(pre_activation)

In [3]:
class Block(nn.Module):
    """Two-layer bias-free MLP: ``x -> W2 relu(W1 x)`` (no output nonlinearity).

    Args:
        input_dim: size of the last axis of the input.
        hidden_dim: width of the ReLU hidden layer.
        output_dim: number of output features.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias=False)
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias=False)

    def forward(self, x):
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

In [4]:
#F1: NN + NN
class Symmetric(nn.Module):
    """Permutation-invariant set model ``rho(mean_i phi(x_i))``.

    ``phi`` (a ``LittleBlock``: bias-free linear + ReLU) is applied to every set
    element, the resulting features are averaged over the set axis, and ``rho``
    (a ``Block``: two bias-free linear layers with a ReLU between) maps the
    pooled features to a scalar.

    Args:
        input_dim: dimension of each set element.
        hidden_dim_phi: number of ``phi`` features.
        hidden_dim_rho: hidden width of ``rho``.
    """

    def __init__(self, input_dim, hidden_dim_phi, hidden_dim_rho):
        super(Symmetric, self).__init__()

        self.hidden_dim_phi = hidden_dim_phi
        self.hidden_dim_rho = hidden_dim_rho
        self.input_dim = input_dim

        self.rho = None
        self.phi = None
        self.reinit()

    def reinit(self):
        """Replace ``rho`` and ``phi`` with freshly initialised modules."""
        self.rho = Block(self.hidden_dim_phi, self.hidden_dim_rho, 1)
        self.phi = LittleBlock(self.input_dim, self.hidden_dim_phi)

    def forward(self, x):
        """Map a batch of sets of shape ``(batch, set_size, input_dim)`` to ``rho``'s output."""
        batch_size, input_set_dim, input_dim = x.shape

        # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, also work.
        z = self.phi(x.reshape(-1, input_dim))
        z = z.reshape(batch_size, input_set_dim, -1)
        z = torch.mean(z, 1)
        return self.rho(z)

    def regularize(self, lamb):
        """Return ``lamb`` times the path norm of the network as a 0-dim tensor.

        Path norm: ``sum_k |w_k| * sum_j |W2[k, j]| * ||W1[j, :]||_2`` where
        ``W1 = phi.fc.weight``, ``W2 = rho.fc1.weight`` and ``w = rho.fc2.weight``.

        The value stays attached to the autograd graph so that adding it to a
        loss actually penalises the weights; converting it to a Python float
        (``.item()``) would make it a constant with zero gradient.
        """
        W1 = self.phi.fc.weight      # (hidden_dim_phi, input_dim)
        W2 = self.rho.fc1.weight     # (hidden_dim_rho, hidden_dim_phi)
        w = self.rho.fc2.weight      # (1, hidden_dim_rho)

        W1_norms = torch.norm(W1, dim = 1, keepdim = True)
        path_norm = torch.matmul(torch.abs(w), torch.matmul(torch.abs(W2), W1_norms))
        # sum() collapses the (1, 1) product into a scalar that can be added to a 0-dim loss.
        return lamb * path_norm.sum()

In [5]:
class DeepSets(Symmetric):
    """Sum-pooling variant of ``Symmetric``: ``rho(sum_i phi(x_i))``.

    Identical to ``Symmetric`` except that the per-element ``phi`` features are
    summed over the set axis instead of averaged, so the pooled features scale
    with the set size.
    """

    def __init__(self, input_dim, hidden_dim_phi, hidden_dim_rho):
        super(DeepSets, self).__init__(input_dim, hidden_dim_phi, hidden_dim_rho)

    def forward(self, x):
        """Map a batch of sets of shape ``(batch, set_size, input_dim)`` to ``rho``'s output."""
        batch_size, input_set_dim, input_dim = x.shape

        # reshape (not view) so non-contiguous inputs are accepted.
        z = self.phi(x.reshape(-1, input_dim))
        z = z.reshape(batch_size, input_set_dim, -1)
        z = torch.sum(z, 1)
        return self.rho(z)

In [6]:
#F2: K + NN
class KNN(Symmetric):
    """``Symmetric`` whose ``phi`` layer is fixed after random initialisation.

    On every (re)initialisation the rows of ``phi.fc.weight`` are rescaled to
    unit L2 norm and frozen (``requires_grad = False``); ``rho`` stays trainable.
    """

    def __init__(self, input_dim, hidden_dim_phi, hidden_dim_rho):
        super(KNN, self).__init__(input_dim, hidden_dim_phi, hidden_dim_rho)

    def reinit(self):
        """Re-create all modules, then freeze ``phi`` and normalise its rows."""
        super(KNN, self).reinit()

        fc = self.phi.fc
        fc.weight.requires_grad = False
        # LittleBlock builds its Linear with bias=False, so ``fc.bias`` is None;
        # dereferencing it unconditionally raised AttributeError.
        if fc.bias is not None:
            fc.bias.requires_grad = False

        with torch.no_grad():
            fc.weight.div_(torch.norm(fc.weight, dim = 1, keepdim = True))

    def regularize(self, lamb):
        """Return ``lamb * sum_k |w_k| * ||W2[k, :]||_2`` as a 0-dim tensor.

        ``W2 = rho.fc1.weight`` and ``w = rho.fc2.weight``; ``phi`` is frozen and
        not penalised. The value stays attached to the autograd graph so it
        contributes gradients when added to a loss.
        """
        W2_norms = torch.norm(self.rho.fc1.weight, dim = 1, keepdim = True)  # (hidden_dim_rho, 1)
        w = torch.abs(self.rho.fc2.weight)                                   # (1, hidden_dim_rho)
        # sum() collapses the (1, 1) product into a scalar.
        return lamb * torch.matmul(w, W2_norms).sum()

In [7]:
#F3: K + K
class KK(KNN):
    """``KNN`` that additionally fixes the first ``rho`` layer.

    Besides the frozen ``phi`` inherited from ``KNN``, the rows of
    ``rho.fc1.weight`` are rescaled to unit L2 norm and frozen, so only the
    final readout ``rho.fc2`` is trained.
    """

    def __init__(self, input_dim, hidden_dim_phi, hidden_dim_rho):
        super(KK, self).__init__(input_dim, hidden_dim_phi, hidden_dim_rho)

    def reinit(self):
        """Re-create all modules via ``KNN.reinit``, then freeze and normalise ``rho.fc1``."""
        super(KK, self).reinit()

        fc1 = self.rho.fc1
        fc1.weight.requires_grad = False
        # Block builds fc1 with bias=False, so ``fc1.bias`` is None; dereferencing
        # it unconditionally raised AttributeError.
        if fc1.bias is not None:
            fc1.bias.requires_grad = False

        with torch.no_grad():
            fc1.weight.div_(torch.norm(fc1.weight, dim = 1, keepdim = True))

    def regularize(self, lamb):
        """Return ``lamb`` times the Frobenius norm of the trainable readout ``rho.fc2.weight``."""
        return lamb * torch.norm(self.rho.fc2.weight)

In [8]:
def generate_narrow_data(N, batch_size, input_dim, objective):
    """Sample ``batch_size`` sets of ``N`` points drawn i.i.d. from U[-1, 1]^input_dim.

    Args:
        N: number of elements per set.
        batch_size: number of sets.
        input_dim: dimension of each element.
        objective: callable mapping the NumPy sample array to targets.

    Returns:
        ``(x, y)`` as float32 tensors; ``x`` has shape ``(batch_size, N, input_dim)``
        and ``y`` is ``objective`` evaluated on the same samples.
    """
    samples = np.random.uniform(-1, 1, (batch_size, N, input_dim))
    targets = objective(samples)
    return torch.from_numpy(samples).float(), torch.from_numpy(targets).float()

In [9]:
def generate_wide_data(N, batch_size, input_dim, objective):
    """Sample sets whose coordinates are uniform on a random per-set sub-interval of [-1, 1].

    For every coordinate ``i`` and every set, an interval ``[lo, hi]`` is drawn
    (the sorted pair of two U[-1, 1] samples) and all ``N`` elements of that set
    get their ``i``-th coordinate uniformly from that interval.

    Returns:
        ``(x, y)`` as float32 tensors; ``x`` has shape ``(batch_size, N, input_dim)``
        and ``y`` is ``objective`` evaluated on the NumPy samples.
    """
    sets = np.zeros((batch_size, N, input_dim))
    for dim in range(input_dim):
        lo = np.random.uniform(low = -1, high = 1, size = batch_size)
        hi = np.random.uniform(low = -1, high = 1, size = batch_size)
        lo, hi = np.minimum(lo, hi), np.maximum(lo, hi)

        # Broadcasting the (batch_size,) bounds over size (N, batch_size) consumes the
        # RNG in the same order as tiling them to (N, batch_size).
        coords = np.random.uniform(low = lo, high = hi, size = (N, batch_size))
        sets[:, :, dim] = coords.T

    targets = objective(sets)
    return torch.from_numpy(sets).float(), torch.from_numpy(targets).float()

In [10]:
def generate_data(N, batch_size, input_dim, objective, narrow):
    """Sample training/test sets with the narrow (i.i.d. U[-1, 1]) or wide (per-set interval) scheme.

    Args are forwarded unchanged to ``generate_narrow_data`` when ``narrow`` is
    truthy and to ``generate_wide_data`` otherwise.
    """
    sampler = generate_narrow_data if narrow else generate_wide_data
    return sampler(N, batch_size, input_dim, objective)

In [11]:
def train(model, x, y, iterations, lamb = 0.1, lr = 0.003, minibatch_size = 20):
    """Fit ``model`` to ``(x, y)`` with Adam on randomly chosen fixed minibatches.

    The sample indices are split once into ``max(1, n // minibatch_size)``
    contiguous chunks; each iteration picks one chunk at random.

    Args:
        model: module with a ``regularize(lamb)`` method returning a float or a
            scalar tensor that is added to the MSE loss.
        x, y: inputs and targets; the first axis indexes samples.
        iterations: number of optimizer steps.
        lamb: regularization strength passed to ``model.regularize``.
        lr: Adam learning rate.
        minibatch_size: approximate number of samples per minibatch.

    Returns:
        List of per-iteration (regularized) training losses as Python floats.
        The model is left in eval mode.
    """
    model.train()
    criterion = nn.MSELoss()
    # Frozen parameters (requires_grad=False) never receive gradients, so leave them out.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable, lr = lr)

    n_samples = x.shape[0]
    # Integer section count of at least 1: the former float count n/20 became 0
    # for fewer than 20 samples and np.array_split raised ValueError.
    n_sections = max(1, n_samples // minibatch_size)
    indices = np.array_split(np.arange(n_samples), n_sections)

    losses = []
    for _ in range(iterations):
        index = indices[np.random.randint(len(indices))]
        outputs = model(x[index])

        optimizer.zero_grad()
        loss = criterion(outputs, y[index])
        # Out-of-place add works whether regularize returns a float or a tensor.
        loss = loss + model.regularize(lamb)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    model.eval()
    return losses

In [12]:
def generalization_error(N_list, batch_size, model, objective, narrow):
    """Return the test MSE of ``model`` on freshly sampled data for each set size in ``N_list``.

    For each ``N`` a new batch of ``batch_size`` sets is drawn with
    ``generate_data`` (narrow or wide scheme) and the MSE between the model's
    output and the ``objective`` targets is recorded.

    Returns:
        np.ndarray with one error per entry of ``N_list``.
    """
    input_dim = model.input_dim
    criterion = nn.MSELoss()
    errors = []
    # Evaluation only: avoid building an autograd graph over the large test batches.
    with torch.no_grad():
        for N in N_list:
            x, y = generate_data(N, batch_size, input_dim, objective, narrow)
            outputs = model(x)
            errors.append(criterion(outputs, y).item())
    return np.array(errors)

In [13]:
def relu(x):
    """Elementwise rectified linear unit for NumPy inputs: ``max(x, 0)``."""
    floor = 0
    return np.maximum(x, floor)

In [14]:
# Set-function regression targets. Each maps a batch of sets x of shape
# (batch, N, dim) to a (batch, 1) array computed from the per-element
# Euclidean norms ||x_i||.

def mean(x):
    """Average element norm of each set."""
    return np.mean(norm(x, axis = 2), axis = 1, keepdims = True)

def median(x):
    """Median element norm of each set."""
    return np.median(norm(x, axis = 2), axis = 1, keepdims = True)

def maximum(x):
    """Largest element norm of each set."""
    return np.max(norm(x, axis = 2), axis = 1, keepdims = True)

lamb = 0.1

def softmax(x, temperature = lamb):
    """Smooth maximum of the element norms: ``t * log(mean_i exp(||x_i|| / t))``.

    ``temperature`` is bound to the value of ``lamb`` when this cell runs, so a
    later reassignment of the global ``lamb`` cannot silently change the
    objective. The per-set maximum norm is subtracted before exponentiating
    (log-sum-exp shift) so large norms do not overflow to ``inf``.
    """
    norms = norm(x, axis = 2)
    peak = np.max(norms, axis = 1, keepdims = True)
    shifted = np.exp((norms - peak) / temperature)
    return peak + temperature * np.log(np.mean(shifted, axis = 1, keepdims = True))

def second(x):
    """Second-largest element norm of each set (requires at least 2 elements per set)."""
    return np.sort(norm(x, axis = 2), axis = 1)[:, -2].reshape(-1, 1)

In [21]:
### Teacher for the "neuron" objective: a Symmetric model with one phi unit and
### one rho unit whose weights are redrawn from N(0, 1). You may need to sample
### several teachers to find one that isn't degenerate on the domain [-1, 1].

### NOTE: S1 may not do as well on this neuron now that the path-norm
# regularization has been fixed. Consider a more unusual weight distribution
# to make S2 and S3 perform even worse.

input_dim = 10
teacher = Symmetric(input_dim, 1, 1)

torch.nn.init.normal_(teacher.phi.fc.weight, std = 1.)
torch.nn.init.normal_(teacher.rho.fc1.weight, std = 1.)
torch.nn.init.normal_(teacher.rho.fc2.weight, std = 1.)

# Alternative, wider weight distribution:
# torch.nn.init.normal_(teacher.phi.fc.weight,std = 3.)
# torch.nn.init.normal_(teacher.rho.fc1.weight,std = 3.)
# torch.nn.init.normal_(teacher.rho.fc2.weight,std = 3.)

teacher.eval()

def neuron(x):
    """Evaluate ``teacher`` on a NumPy batch of sets and return a ``(batch, 1)`` float32 array."""
    # Pure evaluation: no autograd graph is needed for the targets.
    with torch.no_grad():
        y = teacher(torch.from_numpy(x).float())
    return y.numpy().reshape(-1, 1)

x, y = generate_narrow_data(3, 4, input_dim, neuron)
print(y)

tensor([[  0.9779],
        [ -9.3473],
        [-11.8528],
        [ -5.6205]])


In [27]:
### Teacher for the "smooth_neuron" objective: a Symmetric model with one phi
### unit and one rho unit, keeping PyTorch's default weight initialisation.
### You may need to sample several teachers to find one that isn't degenerate
### on the domain [-1, 1].

input_dim = 10
smooth_teacher = Symmetric(input_dim, 1, 1)

smooth_teacher.eval()

def smooth_neuron(x):
    """Evaluate ``smooth_teacher`` on a NumPy batch of sets and return a ``(batch, 1)`` float32 array."""
    # Pure evaluation: no autograd graph is needed for the targets.
    with torch.no_grad():
        y = smooth_teacher(torch.from_numpy(x).float())
    return y.numpy().reshape(-1, 1)

x, y = generate_narrow_data(3, 4, input_dim, smooth_neuron)
print(y)

tensor([[-0.1887],
        [-0.2303],
        [-0.1945],
        [-0.2155]])


In [28]:
# Give every objective a readable name; compare_models uses ``__name__`` in its
# progress message, plot title and output filename.
for _objective, _label in (
    (neuron, "neuron"),
    (smooth_neuron, "smooth_neuron"),
    (maximum, "maximum"),
    (softmax, "softmax"),
    (median, "median"),
    (mean, "mean"),
    (second, "second"),
):
    _objective.__name__ = _label
del _objective, _label

In [29]:
###############################################

In [30]:
def cross_validate(model, x, y, iterations, lambs, verbose):
    """Train one independent copy of ``model`` per regularization strength in ``lambs``.

    Every copy starts from the same weights (a deep copy of ``model``) and is
    trained on the same ``(x, y)``; ``model`` itself is not modified.

    Args:
        verbose: if true, print about ten evenly spaced training losses per copy.

    Returns:
        List of trained models, in the same order as ``lambs``.
    """
    # Guard the slice step: int(iterations / 10) was 0 for iterations < 10,
    # and a zero slice step raises ValueError.
    log_step = max(1, iterations // 10)
    models = []
    for lamb in lambs:
        model_copy = copy.deepcopy(model)
        losses = train(model_copy, x, y, iterations, lamb)
        models.append(model_copy)
        if verbose:
            print(losses[::log_step])
    return models

In [31]:
def _select_lambda(model, lambs, N_max, iterations, batch_size, input_dim, objective, narrow, verbose):
    """Return the entry of ``lambs`` whose trained copy has the lowest validation MSE at set size ``N_max``."""
    x, y = generate_data(N_max, batch_size, input_dim, objective, narrow)
    cv_models = cross_validate(model, x, y, iterations, lambs, verbose)
    validation_errors = np.array([
        generalization_error([N_max], 1000, cv_model, objective, narrow)[0]
        for cv_model in cv_models
    ])
    return lambs[int(np.argmin(validation_errors))]


def _run_trials(model, lamb, runs, N_list, N_max, iterations, batch_size, input_dim, objective, narrow):
    """Retrain ``runs`` re-initialised copies of ``model``; return test errors of shape ``(runs, len(N_list))``."""
    run_errors = np.zeros((runs, len(N_list)))
    for run in range(runs):
        x, y = generate_data(N_max, batch_size, input_dim, objective, narrow)
        model_copy = copy.deepcopy(model)
        model_copy.reinit()
        train(model_copy, x, y, iterations, lamb)
        run_errors[run] = generalization_error(N_list, 1000, model_copy, objective, narrow)
    return run_errors


def compare_models(N_max, hidden_dim, iterations, batch_size, input_dim, objective, narrow, verbose = False, log_plot = False, scaleup = False, plot_dir = "plots_high_dim", ylim = (1e-5, 1e-1), runs = 10):
    """Compare S1 (Symmetric), S2 (KNN) and S3 (KK) on ``objective`` and save a plot.

    For each model the regularization strength is chosen from a fixed grid by
    validation MSE at set size ``N_max``. The model is then re-initialised and
    retrained ``runs`` times on sets of size ``N_max``, and the mean +/- one
    standard deviation of the test MSE is plotted for set sizes
    ``2 .. N_max + 15``.

    Args:
        N_max: set size of the training data.
        hidden_dim: hidden width of ``rho`` for S1 and S2.
        iterations: optimizer steps per training run.
        batch_size: number of training sets per run.
        input_dim: dimension of each set element.
        objective: target set function; its ``__name__`` labels the output.
        narrow: choose the narrow (True) or wide (False) sampling scheme.
        verbose: print subsampled training losses during cross-validation.
        log_plot: use a logarithmic y-axis.
        scaleup: double the number of fixed ``phi`` features of S2 and S3.
        plot_dir: directory the figure is saved to (created if missing).
        ylim: y-axis limits of the figure.
        runs: number of independent retraining runs per model.
    """
    print("currently", objective.__name__)

    c = 1 if not scaleup else 2

    f1 = Symmetric(input_dim, 1000, hidden_dim)
    f2 = KNN(input_dim, c * 1000, hidden_dim)
    f3 = KK(input_dim, c * 1000, 1000)

    f1.__name__ = "S1"
    f2.__name__ = "S2"
    f3.__name__ = "S3"

    lambs = [0., 1e-6, 1e-4, 1e-2]
    N_list = np.arange(2, N_max + 16)

    # Explicit figure so the plot never draws onto a figure left open elsewhere.
    fig, ax = plt.subplots()
    for model in (f1, f2, f3):
        lamb = _select_lambda(model, lambs, N_max, iterations, batch_size, input_dim, objective, narrow, verbose)
        run_errors = _run_trials(model, lamb, runs, N_list, N_max, iterations, batch_size, input_dim, objective, narrow)

        mean_error = np.mean(run_errors, axis = 0)
        std_error = np.std(run_errors, axis = 0)
        plot_fn = ax.semilogy if log_plot else ax.plot
        plot_fn(N_list, mean_error, label = model.__name__)
        ax.fill_between(N_list, mean_error - std_error, mean_error + std_error, alpha = 0.2)

    ax.legend()
    ax.set_ylim(list(ylim))
    ax.set_xlabel("N")
    ax.set_ylabel("Mean Square Error")
    narrow_str = "Narrow" if narrow else "Wide"
    ax.set_title(narrow_str + " generalization for " + objective.__name__)
    scale_str = "" if not scaleup else "scaled"
    # savefig fails with FileNotFoundError if the output directory does not exist yet.
    os.makedirs(plot_dir, exist_ok = True)
    fig.savefig(os.path.join(plot_dir, objective.__name__ + "_" + narrow_str + "_" + str(input_dim) + scale_str))
    plt.close(fig)

In [32]:
# Settings used to generate the plots in Figure 1:
N_max, hidden_dim = 4, 50          # training set size; hidden width of rho for S1/S2

iterations, batch_size = 1000, 100  # optimizer steps; number of training sets

input_dim = 10                      # dimension of each set element

In [33]:
# Figure 1: compare S1/S2/S3 on each objective, first with wide and then with
# narrow data. Add any of mean, median, maximum, softmax, second to the list to
# regenerate their plots as well.
FIGURE1_OBJECTIVES = [neuron]

for fig1_objective in FIGURE1_OBJECTIVES:
    for fig1_narrow in (False, True):
        compare_models(N_max, hidden_dim, iterations, batch_size, input_dim, fig1_objective, narrow = fig1_narrow, log_plot = True)

currently neuron
currently neuron


In [None]:
#Plots for right half of Figure 2

# compare_models(N_max, hidden_dim, iterations, batch_size, input_dim , smooth_neuron, narrow = True, log_plot = True, scaleup = True)
# compare_models(N_max, hidden_dim, iterations, batch_size, input_dim , neuron, narrow = True, log_plot = True, scaleup = True)

In [34]:
# Exploratory check: repeatedly train fresh S1 / S2 / S3 models on ONE fixed
# dataset and print a subsampled training-loss curve and the test MSE at set
# size N_max for each. This cell is expensive (the saved run was interrupted).
N_max = 4
hidden_dim = 50
iterations = 1000
batch_size = 100

objective = neuron
narrow = False

input_dim = 10

n_trials = 20
log_step = max(1, iterations // 10)

x, y = generate_data(N_max, batch_size, input_dim, objective, narrow)

# (label, model class, constructor arguments, regularization strength)
trial_specs = [
    ("f1", Symmetric, (input_dim, 2000, hidden_dim), 0.000),
    ("f2", KNN, (input_dim, 2000, hidden_dim), 0.000),
    ("f3", KK, (input_dim, 2000, 1000), 0.0001),
]

for i in range(n_trials):
    for label, model_cls, model_args, trial_lamb in trial_specs:
        model = model_cls(*model_args)
        # train() switches the model to train mode and back to eval mode itself.
        losses = train(model, x, y, iterations, lamb = trial_lamb)
        print(losses[::log_step])
        print(label, generalization_error([N_max], 5000, model, objective, narrow))

[108.8774642944336, 0.12319805473089218, 0.06409303098917007, 0.031051063910126686, 0.00372758018784225, 0.012634518556296825, 0.011655419133603573, 0.0014871725579723716, 0.010766642168164253, 0.0005567565094679594]
f1 [0.88040781]
[68.0797348022461, 0.2454993724822998, 0.14803007245063782, 0.03867511451244354, 0.03269072622060776, 0.00997860822826624, 0.02678658626973629, 0.014266696758568287, 0.021281693130731583, 0.010137232020497322]
f2 [1.40248609]
[109.46990966796875, 23.109466552734375, 12.243729591369629, 7.640867710113525, 5.95836067199707, 3.8359882831573486, 5.479048728942871, 2.708665370941162, 2.0167436599731445, 1.3769158124923706]
f3 [10.7036314]
[183.14517211914062, 0.0831356942653656, 0.04076262190937996, 0.012397709302604198, 0.004988738335669041, 0.0012710250448435545, 0.0005181847373023629, 0.0004428067186381668, 0.0003207512490916997, 2.5524159354972653e-06]
f1 [0.73567802]
[173.53285217285156, 0.39136749505996704, 0.21716876327991486, 0.3039013743400574, 0.103648

KeyboardInterrupt: 

In [None]:
###############################################

In [None]:
#Plot for Figure 3

# N_max = 4
# hidden_dim = 50

# iterations = 1000
# batch_size = 100

# input_dim = 5

# objective = mean
# narrow = False

# log_plot = True

# f1 = Symmetric(input_dim, 1000, hidden_dim)
# f2 = DeepSets(input_dim, 1000, hidden_dim)
# f1.__name__ = "S1"
# f2.__name__ = "DeepSets"

# models = [f1, f2]

# N_list = np.arange(2, N_max + 16)

# for model in models:
#     x, y = generate_data(N_max, batch_size, input_dim, objective, narrow)

#     lamb = 0.

#     runs = 10
#     run_errors = np.zeros((runs, len(N_list)))
#     for i in range(runs):
#         x, y = generate_data(N_max, batch_size, input_dim, objective, narrow)
#         model_copy = copy.deepcopy(model)
#         model_copy.reinit()
#         train(model_copy, x, y, iterations, lamb)
#         errors = generalization_error(N_list, 1000, model_copy, objective, narrow)
#         run_errors[i] = np.array(errors)

#     mean_error = np.mean(run_errors, axis = 0)
#     std_error = np.std(run_errors, axis = 0)
#     if log_plot:
#         plt.semilogy(N_list, mean_error, label = model.__name__)
#     else:
#         plt.plot(N_list, mean_error, label = model.__name__)
#     plt.fill_between(N_list, mean_error - std_error, mean_error + std_error, alpha = 0.2)


# plt.legend()
# #     plt.ylim([1e-5, 1e10]) 
# plt.xlabel("N")
# plt.ylabel("Mean Square Error")
# narrow_str = "Narrow" if narrow else "Wide"
# plt.title("Normalized vs. Unnormalized generalization for " + objective.__name__)
# plt.savefig("plots_high_dim/" + "deepsets")
# #     plt.show()
# plt.close()