In [1]:
import matplotlib.pyplot as plt
import numpy as np
from numpy.linalg import norm
import itertools
import copy

import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
%matplotlib inline

In [2]:
class LittleBlock(nn.Module):
    """One bias-free linear layer followed by a ReLU: x -> relu(W x).

    Args:
        input_dim: size of the last axis of the input.
        output_dim: number of output features.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.fc = nn.Linear(input_dim, output_dim, bias = False)

    def forward(self, x):
        pre_activation = self.fc(x)
        return torch.relu(pre_activation)

In [3]:
class Block(nn.Module):
    """Bias-free two-layer MLP: x -> W2 relu(W1 x), with no activation on the output.

    Args:
        input_dim: size of the last axis of the input.
        hidden_dim: width of the ReLU layer.
        output_dim: number of output features.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim, bias = False)
        self.fc2 = nn.Linear(hidden_dim, output_dim, bias = False)

    def forward(self, x):
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

In [4]:
#F1: NN + NN
class Symmetric(nn.Module):
    """F1 ("NN + NN"): permutation-invariant set model rho(mean_i phi(x_i)).

    phi is a LittleBlock (bias-free linear + ReLU, input_dim -> hidden_dim_phi)
    applied to every set element. The per-element features are averaged over
    the set and passed to rho, a bias-free Block
    (hidden_dim_phi -> hidden_dim_rho -> 1).
    """

    def __init__(self, input_dim, hidden_dim_phi, hidden_dim_rho):
        super(Symmetric, self).__init__()

        self.hidden_dim_phi = hidden_dim_phi
        self.hidden_dim_rho = hidden_dim_rho
        self.input_dim = input_dim

        self.rho = None
        self.phi = None
        self.reinit()

    def reinit(self):
        """(Re)create phi and rho with freshly initialized weights."""
        self.rho = Block(self.hidden_dim_phi, self.hidden_dim_rho, 1)
        self.phi = LittleBlock(self.input_dim, self.hidden_dim_phi)

    def forward(self, x):
        """Map a (batch, set_size, input_dim) tensor to a (batch, 1) output."""
        batch_size, input_set_dim, input_dim = x.shape

        # reshape (not view) so non-contiguous inputs, e.g. sliced/permuted sets, also work
        z = self.phi(x.reshape(-1, input_dim))
        z = z.view(batch_size, input_set_dim, -1)
        z = torch.mean(z, 1)
        return self.rho(z)

    def regularize(self, lamb):
        """Weight penalty lamb * sum_k |w_k| * sum_j |W2_kj| * ||W1_j||_2.

        W1 = phi.fc.weight, W2 = rho.fc1.weight, w = rho.fc2.weight.

        Returns:
            A 0-d tensor that stays attached to the autograd graph. Converting it
            to a Python float (e.g. with .item()) would make the penalty a
            constant, so adding it to the loss would not change any gradient.
        """
        W1_row_norms = torch.norm(self.phi.fc.weight, dim = 1, keepdim = True)  # (hidden_dim_phi, 1)
        W2_abs = torch.abs(self.rho.fc1.weight)  # (hidden_dim_rho, hidden_dim_phi)
        w_abs = torch.abs(self.rho.fc2.weight)  # (1, hidden_dim_rho)

        # (1, 1) product -> .sum() gives a 0-d tensor that adds cleanly to a scalar loss
        return lamb * torch.matmul(w_abs, torch.matmul(W2_abs, W1_row_norms)).sum()

In [5]:
class DeepSets(Symmetric):
    """Symmetric variant that pools the per-element phi features with a sum instead of a mean."""

    def __init__(self, input_dim, hidden_dim_phi, hidden_dim_rho):
        super().__init__(input_dim, hidden_dim_phi, hidden_dim_rho)

    def forward(self, x):
        """Map a (batch, set_size, input_dim) tensor to a (batch, 1) output."""
        n_sets, set_size, feat_dim = x.shape
        embedded = self.phi(x.view(-1, feat_dim)).view(n_sets, set_size, -1)
        pooled = embedded.sum(dim = 1)
        return self.rho(pooled)

In [6]:
#F2: K + NN
class KNN(Symmetric):
    """F2 ("K + NN"): phi's weight is frozen and only rho is trained.

    After the parent reinit draws fresh weights, phi's weight gets
    requires_grad=False and each of its rows is rescaled to unit L2 norm.
    """

    def __init__(self, input_dim, hidden_dim_phi, hidden_dim_rho):
        super(KNN, self).__init__(input_dim, hidden_dim_phi, hidden_dim_rho)

    def reinit(self):
        """Recreate phi and rho, then freeze phi's weight with unit-norm rows."""
        super(KNN, self).reinit()

        weight = self.phi.fc.weight
        weight.requires_grad = False
        # in-place rescale of a parameter: keep it out of autograd explicitly
        with torch.no_grad():
            weight.div_(torch.norm(weight, dim = 1, keepdim = True))

    def regularize(self, lamb):
        """Weight penalty lamb * sum_k |w_k| * ||W2_k||_2 over the trainable rho layers.

        W2 = rho.fc1.weight, w = rho.fc2.weight. Phi is frozen, so it is not penalized.

        Returns:
            A 0-d tensor attached to the autograd graph. A Python float (e.g.
            from .item()) would contribute no gradient when added to the loss.
        """
        W2_row_norms = torch.norm(self.rho.fc1.weight, dim = 1, keepdim = True)  # (hidden_dim_rho, 1)
        w_abs = torch.abs(self.rho.fc2.weight)  # (1, hidden_dim_rho)

        return lamb * torch.matmul(w_abs, W2_row_norms).sum()

In [7]:
#F3: K + K
class KK(KNN):
    """F3 ("K + K"): phi's weight and rho's first layer are both frozen.

    Each frozen layer has unit-L2-norm rows. Only rho.fc2 remains trainable.
    """

    def __init__(self, input_dim, hidden_dim_phi, hidden_dim_rho):
        super().__init__(input_dim, hidden_dim_phi, hidden_dim_rho)

    def reinit(self):
        """Recreate the layers (freezing phi via KNN), then also freeze rho.fc1 with unit-norm rows."""
        super().reinit()

        frozen = self.rho.fc1.weight
        frozen.requires_grad = False
        frozen.div_(torch.norm(frozen, dim = 1, keepdim = True))

    def regularize(self, lamb):
        """lamb times the Frobenius norm of the only trainable weight, rho.fc2."""
        return lamb * torch.norm(self.rho.fc2.weight)

In [8]:
def generate_narrow_data(N, batch_size, input_dim, objective):
    """Sample `batch_size` sets of N points uniformly from [-1, 1]^input_dim.

    Args:
        N: number of elements per set.
        batch_size: number of sets.
        input_dim: per-element feature size; the last coordinate becomes a constant 1 (bias).
        objective: callable mapping the float64 NumPy sample array to NumPy targets.

    Returns:
        (x, y) as float32 torch tensors; x has shape (batch_size, N, input_dim).
    """
    samples = np.random.uniform(low = -1, high = 1, size = (batch_size, N, input_dim))
    targets = objective(samples)

    # NOTE(review): the objective sees the random last coordinate, which is then
    # replaced by the bias constant, so targets depend on a feature the model never
    # observes. Confirm whether this is intended.
    samples[:, :, -1] = 1

    return torch.from_numpy(samples).float(), torch.from_numpy(targets).float()

In [9]:
def generate_wide_data(N, batch_size, input_dim, objective):
    x = np.zeros((batch_size, N, input_dim))
    for i in range(input_dim):
        
        a = np.random.uniform(low = -1, high = 1, size = batch_size)
        b = np.random.uniform(low = -1, high = 1, size = batch_size)
        a, b = np.minimum(a,b), np.maximum(a,b)

        x_fill = np.random.uniform(low = np.tile(a, (N,1)), high = np.tile(b, (N,1)))
        x[:,:,i] = x_fill.T
    
    y = objective(x)
    
    #bias term
    x[:,:,-1] = 1
    
    x = torch.from_numpy(x).float()
    y = torch.from_numpy(y).float()
    return (x,y)

In [10]:
def generate_data(N, batch_size, input_dim, objective, narrow):
    """Draw a dataset with the narrow sampler when `narrow` is truthy, else the wide one."""
    sampler = generate_narrow_data if narrow else generate_wide_data
    return sampler(N, batch_size, input_dim, objective)

In [11]:
def train(model, x, y, iterations, lamb = 0.1, lr = 0.003, minibatch_size = 20):
    """Train `model` in place with Adam on random minibatches of (x, y).

    The loss is MSE plus `model.regularize(lamb)`.

    Args:
        model: nn.Module exposing `regularize(lamb)`, which returns a tensor or float.
        x, y: full training inputs/targets; indexed along the first axis.
        iterations: number of optimizer steps (one random minibatch per step).
        lamb: regularization strength passed to `model.regularize`.
        lr: Adam learning rate.
        minibatch_size: target minibatch size. The sample indices are split into
            max(1, len(x) // minibatch_size) contiguous, near-equal chunks.

    Returns:
        List of per-step total loss values (data loss + penalty) as Python floats.
        The model is left in eval mode.
    """
    model.train()
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr = lr)

    # At least one chunk: np.array_split raises for 0 sections, which previously
    # happened for datasets with fewer than `minibatch_size` samples.
    n_chunks = max(1, x.shape[0] // minibatch_size)
    indices = np.array_split(np.arange(x.shape[0]), n_chunks)

    losses = []
    for _ in range(iterations):
        index = indices[np.random.randint(len(indices))]
        outputs = model(x[index])

        optimizer.zero_grad()
        # out-of-place add keeps the graph simple; the penalty must be a tensor to get gradients
        loss = criterion(outputs, y[index]) + model.regularize(lamb)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())

    model.eval()
    return losses

In [12]:
def generalization_error(N_list, batch_size, model, objective, narrow):
    """Held-out MSE of `model` on freshly sampled data for each set size in N_list.

    For every N, draws `batch_size` new sets of N elements via generate_data
    (narrow or wide sampling) and scores the model against `objective` with MSE.

    Returns:
        Float NumPy array of errors, aligned with N_list.
    """
    input_dim = model.input_dim
    criterion = nn.MSELoss()
    errors = []
    # Pure evaluation: skip building the autograd graph (large eval batches otherwise
    # hold all activations in memory for a backward pass that never happens).
    with torch.no_grad():
        for N in N_list:
            x, y = generate_data(N, batch_size, input_dim, objective, narrow)
            errors.append(criterion(model(x), y).item())
    return np.array(errors)

In [13]:
def relu(x):
    """Elementwise rectified linear unit for NumPy inputs: max(x, 0)."""
    floor = 0
    return np.maximum(x, floor)

In [14]:
# Set-level target functions. Each takes an array of shape (batch, set_size, dim),
# reduces every element to its L2 norm, and returns one value per set, shape (batch, 1).

# Temperature of `softmax`. It is bound as a default argument below so that a later
# top-level reassignment of the common name `lamb` cannot silently change the target.
lamb = 0.1
SOFTMAX_TEMPERATURE = lamb


def mean(x):
    """Mean of the element norms in each set."""
    return np.mean(norm(x, axis = 2), axis = 1, keepdims = True)


def median(x):
    """Median of the element norms in each set."""
    return np.median(norm(x, axis = 2), axis = 1, keepdims = True)


def maximum(x):
    """Largest element norm in each set."""
    return np.max(norm(x, axis = 2), axis = 1, keepdims = True)


def softmax(x, temperature = SOFTMAX_TEMPERATURE):
    """Smooth maximum t * log(mean(exp(norms / t))) of the element norms in each set.

    Evaluated in max-shifted form so that exp() cannot overflow for large norms
    or small temperatures; the result is mathematically unchanged.
    """
    norms = norm(x, axis = 2)
    peak = np.max(norms, axis = 1, keepdims = True)
    return peak + temperature * np.log(
        np.mean(np.exp((norms - peak) / temperature), axis = 1, keepdims = True))


def second(x):
    """Second-largest element norm in each set (requires set_size >= 2)."""
    return np.sort(norm(x, axis = 2), axis = 1)[:, -2].reshape(-1, 1)

In [237]:
### May need to sample several neurons to find one that isn't degenerate on the domain [-1,1]
# Random teacher network used as the `neuron` objective (no seed: re-run to resample).

input_dim = 10
teacher = Symmetric(input_dim, 50, 1)

# Redraw phi and rho's first layer uniformly in [-10, 10], then rescale each so its
# rows have unit average L2 norm. rho.fc2 keeps its default initialization.
for layer in (teacher.phi.fc, teacher.rho.fc1):
    torch.nn.init.uniform_(layer.weight, a = -10., b = 10.)
with torch.no_grad():
    for layer in (teacher.phi.fc, teacher.rho.fc1):
        layer.weight.div_(torch.mean(torch.norm(layer.weight, dim = 1)))
del layer

teacher.eval()
def neuron(x):
    """Evaluate `teacher` on a NumPy batch of sets; returns a NumPy column vector."""
    batch = torch.from_numpy(x).float()
    return teacher(batch).data.numpy().reshape(-1, 1)

# quick sanity check: a few targets (many zeros => the sampled neuron is degenerate)
x, y = generate_narrow_data(3, 6, input_dim, neuron)
print(y)

tensor([[-0.1110],
        [ 0.0000],
        [-0.1389],
        [ 0.0000],
        [-0.2638],
        [ 0.0000]])


In [213]:
### May need to sample several neurons to find one that isn't degenerate on the domain [-1,1]
# Second random teacher, used as the `smooth_neuron` objective (no seed: re-run to resample).

input_dim = 10
smooth_teacher = Symmetric(input_dim, 50, 1)

# All initialization below targets smooth_teacher (not `teacher`), so the weights
# behind the `neuron` objective are left untouched.
# Both rho layers are redrawn uniformly in [-2, 2]; phi keeps its default init.
torch.nn.init.uniform_(smooth_teacher.rho.fc1.weight, a = -2., b = 2.)
torch.nn.init.uniform_(smooth_teacher.rho.fc2.weight, a = -2., b = 2.)

# Rescale phi and rho.fc1 so their rows have unit average L2 norm.
with torch.no_grad():
    smooth_teacher.phi.fc.weight.div_(torch.mean(torch.norm(smooth_teacher.phi.fc.weight, dim = 1)))
    smooth_teacher.rho.fc1.weight.div_(torch.mean(torch.norm(smooth_teacher.rho.fc1.weight, dim = 1)))

smooth_teacher.eval()
def smooth_neuron(x):
    """Evaluate `smooth_teacher` on a NumPy batch of sets; returns a NumPy column vector."""
    x = torch.from_numpy(x).float()
    with torch.no_grad():
        y = smooth_teacher(x)
    return y.numpy().reshape(-1, 1)

# quick sanity check: a few targets (many zeros => the sampled neuron is degenerate)
x, y = generate_narrow_data(3, 6, input_dim, smooth_neuron)
print(y)

tensor([[ 0.0000],
        [-0.0524],
        [-0.0402],
        [-0.0202],
        [-0.0082],
        [-0.0435]])


In [115]:
# Give every objective a readable __name__. compare_models uses it in its log line,
# the plot title, and the saved file name.
for _fn, _label in (
    (neuron, "neuron"),
    (smooth_neuron, "smooth_neuron"),
    (maximum, "maximum"),
    (softmax, "softmax"),
    (median, "median"),
    (mean, "mean"),
    (second, "second"),
):
    _fn.__name__ = _label
del _fn, _label

In [93]:
###############################################

In [218]:
def cross_validate(model, x, y, iterations, lambs, verbose):
    """Train one independent deep copy of `model` per regularization strength.

    Args:
        model: untrained template model; it is deep-copied, never modified.
        x, y: training data passed to `train`.
        iterations: optimizer steps per copy.
        lambs: iterable of regularization strengths.
        verbose: if true, print roughly 10 evenly spaced training losses for the
            unregularized (lamb == 0) copy, as an overfitting check.

    Returns:
        List of trained copies, in the same order as `lambs`.
    """
    # slice step must be >= 1, also when fewer than 10 iterations are requested
    log_every = max(1, iterations // 10)
    label = getattr(model, "__name__", type(model).__name__)

    models = []
    for lamb in lambs:
        model_copy = copy.deepcopy(model)
        losses = train(model_copy, x, y, iterations, lamb)
        models.append(model_copy)
        if verbose and lamb == 0:
            print("check for overfitting power of", label)
            print(losses[::log_every])
    return models

In [222]:
def _select_lambda(model, N_max, iterations, batch_size, input_dim, objective, narrow, lambs, verbose):
    """Return the entry of `lambs` whose trained copy of `model` has the lowest held-out MSE at N_max."""
    x, y = generate_data(N_max, batch_size, input_dim, objective, narrow)
    cv_models = cross_validate(model, x, y, iterations, lambs, verbose)

    # np.array of Python floats: always float dtype, even if every lamb is an int
    validation_errors = np.array(
        [generalization_error([N_max], 1000, cv_model, objective, narrow)[0] for cv_model in cv_models])
    return lambs[int(np.argmin(validation_errors))]


def _trial_error_curves(model, lamb, N_max, N_list, iterations, batch_size, input_dim, objective, narrow, runs):
    """Train `runs` re-initialized copies of `model` on fresh data.

    Returns an array of shape (runs, len(N_list)) with each copy's test MSE per set size.
    """
    run_errors = np.zeros((runs, len(N_list)))
    for run in range(runs):
        x, y = generate_data(N_max, batch_size, input_dim, objective, narrow)
        model_copy = copy.deepcopy(model)
        model_copy.reinit()
        train(model_copy, x, y, iterations, lamb)
        run_errors[run] = generalization_error(N_list, 1000, model_copy, objective, narrow)
    return run_errors


def compare_models(N_max, hidden_dim, iterations, batch_size, input_dim, objective, narrow, verbose = True, log_plot = False,
                   runs = 10, lambs = (0., 1e-6, 1e-4, 1e-2, 1.), out_dir = "plots_high_dim"):
    """Compare how S1 (Symmetric), S2 (KNN) and S3 (KK) generalize across set sizes.

    For each model:
      1. Pick a regularization strength from `lambs` by held-out MSE at N_max.
      2. Train `runs` fresh copies with sets of size N_max.
      3. Plot mean +/- std test MSE for N in [2, N_max + 15].

    The figure is saved to out_dir/<objective>_<Narrow|Wide>_<input_dim>.

    Args:
        hidden_dim: rho hidden width for S1 and S2 (S3 uses fixed widths of 1000).
        narrow: use the narrow (True) or wide (False) data sampler.
        verbose: forwarded to cross_validate (prints unregularized training losses).
        log_plot: use a logarithmic y axis.
        runs: number of independent trainings averaged per model.
        lambs: candidate regularization strengths.
        out_dir: directory for the saved figure (created if missing).
    """
    import os

    print("currently", objective.__name__)

    f1 = Symmetric(input_dim, 1000, hidden_dim)
    f2 = KNN(input_dim, 1000, hidden_dim)
    f3 = KK(input_dim, 1000, 1000)

    f1.__name__ = "S1"
    f2.__name__ = "S2"
    f3.__name__ = "S3"

    N_list = np.arange(2, N_max + 16)

    # Own figure: never draw onto a figure left open by an earlier cell.
    fig, ax = plt.subplots()
    for model in (f1, f2, f3):
        lamb = _select_lambda(model, N_max, iterations, batch_size, input_dim, objective, narrow, lambs, verbose)
        run_errors = _trial_error_curves(model, lamb, N_max, N_list, iterations, batch_size,
                                         input_dim, objective, narrow, runs)

        mean_error = np.mean(run_errors, axis = 0)
        std_error = np.std(run_errors, axis = 0)
        if log_plot:
            ax.semilogy(N_list, mean_error, label = model.__name__)
        else:
            ax.plot(N_list, mean_error, label = model.__name__)
        ax.fill_between(N_list, mean_error - std_error, mean_error + std_error, alpha = 0.2)

    narrow_str = "Narrow" if narrow else "Wide"
    ax.legend()
    ax.set_ylim([1e-5, 1e-1])
    ax.set_xlabel("N")
    ax.set_ylabel("Mean Square Error")
    ax.set_title(narrow_str + " generalization for " + objective.__name__)

    # savefig fails if the output directory does not exist yet
    os.makedirs(out_dir, exist_ok = True)
    fig.savefig(os.path.join(out_dir, objective.__name__ + "_" + narrow_str + "_" + str(input_dim)))
    plt.close(fig)

In [223]:
# Run to generate plots in Figure 1: experiment configuration used by the cell below.
N_max, hidden_dim = 4, 50            # training set size N; rho hidden width for S1/S2
iterations, batch_size = 1000, 100   # optimizer steps; number of training sets
input_dim = 10                       # per-element features (last one is the constant bias)

In [224]:
# Figure 1 sweep: each teacher objective with the wide sampler, then the narrow one.
# The same call was also used with the objectives mean, median, maximum, softmax
# and second (each with narrow = False and narrow = True).
for _objective in (neuron, smooth_neuron):
    for _narrow in (False, True):
        compare_models(N_max, hidden_dim, iterations, batch_size, input_dim, _objective, narrow = _narrow, log_plot = True)
del _objective, _narrow

currently mean
check for overfitting power
[2.3047752380371094, 0.002783059375360608, 0.00017399639182258397, 0.0002541202702559531, 0.0002493889187462628, 0.00045954054803587496, 3.897277565556578e-05, 1.0654335710569285e-05, 9.75105467659887e-06, 2.39704386331141e-05]
check for overfitting power
[2.2417304515838623, 0.002252849517390132, 0.0009137581218965352, 0.001335066626779735, 0.0011182911694049835, 0.0006375341326929629, 0.0005794171593151987, 0.0006989827379584312, 0.00013785649207420647, 6.958266749279574e-05]
check for overfitting power
[1.1356312036514282, 0.0015773542691022158, 0.0011178527493029833, 0.00048722411156632006, 0.0006756122456863523, 0.00025430653477087617, 5.262292688712478e-05, 4.309754149289802e-05, 1.506218086433364e-05, 0.00028825673507526517]
currently mean
check for overfitting power
[3.406557083129883, 0.0008517092210240662, 0.0010509208077564836, 0.00022463369532488286, 5.358149064704776e-05, 4.3858704884769395e-05, 1.2801954653696157e-05, 1.213838822

check for overfitting power
[0.04113614559173584, 0.0004033070872537792, 0.0002051381452474743, 0.00011936268856516108, 0.00045927875908091664, 0.0049672857858240604, 0.0025650514289736748, 0.00012907215568702668, 1.5132117368921172e-05, 0.00019784728647209704]
currently neuron
check for overfitting power
[0.003397454973310232, 0.0006860502762719989, 0.0004577323270495981, 0.00037512482958845794, 3.286478022346273e-05, 0.00019349031208548695, 1.674845589150209e-05, 3.9130485674832016e-05, 1.106493527913699e-06, 1.0388267668304252e-07]
check for overfitting power
[0.004624880850315094, 0.0025437786243855953, 0.0, 0.010122809559106827, 0.0, 0.0019508779514580965, 0.0, 0.0016945298993960023, 0.010122809559106827, 0.0019508779514580965]
check for overfitting power
[0.03023776412010193, 0.0006069460650905967, 0.00015993285342119634, 3.809927511611022e-05, 9.478800166107249e-06, 0.0009227569098584354, 6.551707338076085e-05, 8.316764024129952e-07, 1.836461524362676e-05, 0.00016059548943303525

In [None]:
#Plots for right half of Figure 2

# compare_models(N_max, hidden_dim, iterations, batch_size, input_dim , smooth_neuron, narrow = True, log_plot = True)
# compare_models(N_max, hidden_dim, iterations, batch_size, input_dim , neuron, narrow = True, log_plot = True)

In [None]:
# Spot check: train S1/S2/S3 repeatedly on one fixed dataset and print their
# held-out error at the training set size.
N_max = 4
hidden_dim = 50
iterations = 2000
batch_size = 100

objective = neuron
narrow = False

input_dim = 10

x, y = generate_data(N_max, batch_size, input_dim, objective, narrow)

# (label, constructor) for each model family; constructors read the globals above.
model_factories = [
    ("f1", lambda: Symmetric(input_dim, 2000, hidden_dim)),
    ("f2", lambda: KNN(input_dim, 2000, hidden_dim)),
    ("f3", lambda: KK(input_dim, 2000, 1000)),
]
log_every = max(1, iterations // 10)  # slice step must be >= 1

for i in range(20):
    for label, make_model in model_factories:
        model = make_model()
        # train() switches to train mode itself and leaves the model in eval mode
        losses = train(model, x, y, iterations, lamb = 0.00001)
        print(losses[::log_every])
        print(label, generalization_error([N_max], 5000, model, objective, narrow))

[0.011131564155220985, 0.0014981222338974476, 0.0005335101159289479, 0.0005359654896892607, 0.0005274279974400997, 0.0005361093790270388, 0.0005384979886002839, 0.0005354254972189665, 0.0005588674684986472, 0.0005240787868387997]
f1 [0.00656719]
[0.06285697966814041, 0.013469170778989792, 0.01869463175535202, 0.01869463175535202, 0.013469170778989792, 0.030276844277977943, 0.008029865100979805, 0.013469170778989792, 0.0118960440158844, 0.013469170778989792]
f2 [0.01477099]
[0.009445526637136936, 0.0003498205041978508, 0.001232254900969565, 0.0004458093026187271, 5.454664278659038e-05, 0.00011973125947406515, 2.8050522814737633e-05, 0.00255468743853271, 0.0023875662591308355, 0.0016647777520120144]
f3 [0.00882316]
[0.0073065925389528275, 0.001781738130375743, 0.0006876902189105749, 0.0006504563498310745, 0.0006505273631773889, 0.0006528985104523599, 0.0006745976861566305, 0.0007234070799313486, 0.0006582402274943888, 0.0006518846494145691]
f1 [0.00631934]


In [None]:
###############################################

In [None]:
#Plot for Figure 3

# N_max = 4
# hidden_dim = 50

# iterations = 1000
# batch_size = 100

# input_dim = 5

# objective = mean
# narrow = False

# log_plot = True

# f1 = Symmetric(input_dim, 1000, hidden_dim)
# f2 = DeepSets(input_dim, 1000, hidden_dim)
# f1.__name__ = "S1"
# f2.__name__ = "DeepSets"

# models = [f1, f2]

# N_list = np.arange(2, N_max + 16)

# for model in models:
#     x, y = generate_data(N_max, batch_size, input_dim, objective, narrow)

#     lamb = 0.

#     runs = 10
#     run_errors = np.zeros((runs, len(N_list)))
#     for i in range(runs):
#         x, y = generate_data(N_max, batch_size, input_dim, objective, narrow)
#         model_copy = copy.deepcopy(model)
#         model_copy.reinit()
#         train(model_copy, x, y, iterations, lamb)
#         errors = generalization_error(N_list, 1000, model_copy, objective, narrow)
#         run_errors[i] = np.array(errors)

#     mean_error = np.mean(run_errors, axis = 0)
#     std_error = np.std(run_errors, axis = 0)
#     if log_plot:
#         plt.semilogy(N_list, mean_error, label = model.__name__)
#     else:
#         plt.plot(N_list, mean_error, label = model.__name__)
#     plt.fill_between(N_list, mean_error - std_error, mean_error + std_error, alpha = 0.2)


# plt.legend()
# #     plt.ylim([1e-5, 1e10]) 
# plt.xlabel("N")
# plt.ylabel("Mean Square Error")
# narrow_str = "Narrow" if narrow else "Wide"
# plt.title("Normalized vs. Unnormalized generalization for " + objective.__name__)
# plt.savefig("plots_high_dim/" + "deepsets")
# #     plt.show()
# plt.close()