In [1]:
import matplotlib.pyplot as plt
import numpy as np
from numpy.linalg import norm
import itertools
import copy

import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
%matplotlib inline

In [2]:
# Batch sizes are written as (elements per set) * (sets per batch): the loaders
# below carry set_size 4 (train) and 6 (test), and the training loop reshapes
# every batch into groups of that many samples.
batch_size_train = 4 * 64
batch_size_test = 6 * 32

In [3]:
# Training split of MNIST, stored under ./data and downloaded on first use.
# Images are converted to tensors and normalized with a fixed mean/std
# (presumably the MNIST training-set statistics -- TODO confirm).
train_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        root='data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
        ]),
    ),
    batch_size=batch_size_train,
    shuffle=True,
)
# Number of samples grouped into one input set; read back from the loader
# by the training/evaluation loop.
train_loader.set_size = 4

# Test split with the same preprocessing, a different batch size and set size.
test_loader = torch.utils.data.DataLoader(
    torchvision.datasets.MNIST(
        root='data',
        train=False,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
        ]),
    ),
    batch_size=batch_size_test,
    shuffle=True,
)
test_loader.set_size = 6

In [4]:
# Draw a single batch from the training loader and show the image / label
# tensor shapes, one per line.
examples = enumerate(train_loader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_data.shape, example_targets.shape, sep="\n")

torch.Size([256, 1, 28, 28])
torch.Size([256])


In [5]:
# Draw a single batch from the test loader and show the image / label
# tensor shapes, one per line.
examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_data.shape, example_targets.shape, sep="\n")

torch.Size([192, 1, 28, 28])
torch.Size([192])


In [6]:
class LittleBlock(nn.Module):
    """One fully connected layer followed by a ReLU.

    Args:
        input_dim: size of the last dimension of the input.
        output_dim: number of output features.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # The attribute name `fc` is part of the interface: other classes in
        # this notebook read `phi.fc.weight` / `phi.fc.bias` directly.
        self.fc = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        """Return relu(fc(x))."""
        pre_activation = self.fc(x)
        return F.relu(pre_activation)

In [7]:
class Block(nn.Module):
    """Two-layer perceptron: linear -> ReLU -> linear (no final activation).

    Args:
        input_dim: size of the last dimension of the input.
        hidden_dim: width of the hidden layer.
        output_dim: number of output features.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # `fc1` / `fc2` are read by name from outside this class, and they are
        # created in this order so random initialisation is unchanged.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return fc2(relu(fc1(x)))."""
        hidden = F.relu(self.fc1(x))
        return self.fc2(hidden)

In [8]:
#F1: NN + NN
class Symmetric(nn.Module):
    """Set regressor of the form rho(mean_i phi(x_i)).

    `phi` (a `LittleBlock`) is applied to every element of every set, the
    resulting features are averaged over the set dimension, and `rho`
    (a `Block` with a single output) maps the pooled feature to a scalar.

    Args:
        input_dim: feature size of one set element.
        hidden_dim_phi: output width of `phi` (and input width of `rho`).
        hidden_dim_rho: hidden width of `rho`.
    """

    def __init__(self, input_dim, hidden_dim_phi, hidden_dim_rho):
        super(Symmetric, self).__init__()

        self.hidden_dim_phi = hidden_dim_phi
        self.hidden_dim_rho = hidden_dim_rho
        self.input_dim = input_dim

        self.rho = None
        self.phi = None
        self.reinit()

    def reinit(self):
        """Replace `rho` and `phi` with freshly initialised sub-networks."""
        self.rho = Block(self.hidden_dim_phi, self.hidden_dim_rho, 1)
        self.phi = LittleBlock(self.input_dim, self.hidden_dim_phi)

    def forward(self, x):
        """Map a batch of sets to one prediction per set.

        Args:
            x: tensor of shape (batch, set_size, input_dim).

        Returns:
            Tensor of shape (batch, 1).
        """
        batch_size, input_set_dim, input_dim = x.shape

        # Flatten sets so phi sees one element per row, then restore the set
        # axis and average over it.
        x = x.view(-1, input_dim)
        z = self.phi(x)
        z = z.view(batch_size, input_set_dim, -1)
        z = torch.mean(z, 1)
        return self.rho(z)

    def regularize(self, lamb):
        """Return the weight penalty lamb * |w| @ |W2| @ rownorm(W1).

        W1 is the `phi` weight matrix (reduced to its per-row 2-norms), W2 the
        first `rho` layer and w the output `rho` layer.

        Args:
            lamb: penalty strength.

        Returns:
            A 0-dim tensor that is still attached to the autograd graph, so
            adding it to the loss actually regularises the weights. (Calling
            `.item()` here would turn it into a constant with no gradient.)
        """
        W1 = torch.norm(self.phi.fc.weight, dim = 1, keepdim = True)
        W2 = torch.abs(self.rho.fc1.weight)
        w = torch.abs(self.rho.fc2.weight)

        # (1, h_rho) @ (h_rho, h_phi) @ (h_phi, 1) -> (1, 1); squeeze to a
        # scalar so it can be added to a scalar loss.
        reg_loss = torch.matmul(w, torch.matmul(W2, W1)).squeeze()

        return lamb * reg_loss

In [9]:
#F2: K + NN
class KNN(Symmetric):
    """`Symmetric` variant whose `phi` layer is frozen at its random init.

    After every `reinit`, the weight and bias of `phi.fc` stop requiring
    gradients and each weight row is rescaled to unit 2-norm; only `rho`
    is trained.
    """

    def __init__(self, input_dim, hidden_dim_phi, hidden_dim_rho):
        super(KNN, self).__init__(input_dim, hidden_dim_phi, hidden_dim_rho)

    def reinit(self):
        """Re-create the sub-networks, then freeze and row-normalise `phi`."""
        super(KNN, self).reinit()

        self.phi.fc.weight.requires_grad = False
        self.phi.fc.bias.requires_grad = False

        # In-place edit of a parameter: keep it out of autograd explicitly.
        with torch.no_grad():
            self.phi.fc.weight.div_(torch.norm(self.phi.fc.weight, dim = 1, keepdim = True))

    def regularize(self, lamb):
        """Return the weight penalty lamb * |w| @ rownorm(W2).

        W2 is the first `rho` layer (reduced to its per-row 2-norms) and w the
        output `rho` layer.

        Args:
            lamb: penalty strength.

        Returns:
            A 0-dim tensor attached to the autograd graph, so the penalty
            contributes gradients when added to the loss. (Calling `.item()`
            here would turn it into a constant with no gradient.)
        """
        W2 = torch.norm(self.rho.fc1.weight, dim = 1, keepdim = True)
        w = torch.abs(self.rho.fc2.weight)

        # (1, h_rho) @ (h_rho, 1) -> (1, 1); squeeze to a scalar.
        reg_loss = torch.matmul(w, W2).squeeze()

        return lamb * reg_loss

In [10]:
#F3: K + K
class KK(KNN):
    """`KNN` variant that also freezes the first layer of `rho`.

    After every `reinit`, `rho.fc1` stops requiring gradients and its weight
    rows are rescaled to unit 2-norm, leaving `rho.fc2` as the only layer
    that still requires gradients.
    """

    def __init__(self, input_dim, hidden_dim_phi, hidden_dim_rho):
        super().__init__(input_dim, hidden_dim_phi, hidden_dim_rho)

    def reinit(self):
        """Re-create the sub-networks, then freeze and row-normalise `rho.fc1`."""
        super().reinit()

        first_layer = self.rho.fc1
        for frozen in (first_layer.weight, first_layer.bias):
            frozen.requires_grad = False

        row_norms = torch.norm(first_layer.weight, dim = 1, keepdim = True)
        first_layer.weight.div_(row_norms)

    def regularize(self, lamb):
        """Return lamb times the 2-norm of the output-layer weights (a tensor)."""
        return lamb * torch.norm(self.rho.fc2.weight)

In [11]:
def forward(model, dataloader, iterations, lamb = 0.1, train = True):
    """Run training epochs, or a single evaluation pass, over `dataloader`.

    Every full batch is regrouped into sets of `dataloader.set_size` samples;
    the regression target of a set is the mean of its labels and the loss is
    the mean squared error against `model`'s output.

    Args:
        model: module mapping (num_sets, set_size, features) to (num_sets, 1)
            and exposing `regularize(lamb)`.
        dataloader: iterable of (x, y) batches with `batch_size` and
            `set_size` attributes.
        iterations: number of passes over `dataloader` when training; when
            evaluating, at most one pass is made.
        lamb: penalty strength forwarded to `model.regularize` (training only).
        train: if True, optimise with Adam (lr=0.01); otherwise only evaluate.

    Returns:
        List of per-batch loss values as Python floats. In training mode each
        value includes the regularisation term.

    The model is always left in eval mode on return.
    """
    criterion = nn.MSELoss()

    if train:
        model.train()
        optimizer = optim.Adam(model.parameters(), lr=0.01)
    else:
        model.eval()

    losses = []
    # Evaluation needs no autograd graph; disabling it saves memory and time.
    with torch.set_grad_enabled(train):
        for i in range(iterations):
            print("iter", i)
            for x, y in dataloader:

                # Skip the trailing partial batch so every batch splits into
                # whole sets of the same count.
                if x.shape[0] != dataloader.batch_size:
                    continue

                set_size = dataloader.set_size
                # Flatten each sample (28*28 values for MNIST images) and
                # group consecutive samples into sets.
                x = x.view(-1, set_size, x[0].numel())
                y = y.view(-1, set_size).float()
                # Target is the mean label per set (a sum target was tried
                # and is left disabled):
                #     y = torch.sum(y, dim = 1, keepdim = True).float()
                y = torch.mean(y, dim = 1, keepdim = True).float()

                outputs = model(x)
                loss = criterion(outputs, y)

                if train:
                    optimizer.zero_grad()
                    # Out-of-place add works whether the penalty is a float
                    # or a tensor.
                    loss = loss + model.regularize(lamb)
                    loss.backward()
                    optimizer.step()

                losses.append(loss.item())

            # One pass is enough when evaluating.
            if not train:
                break

    model.eval()
    return losses

In [12]:
def cross_validate(model, dataloader, iterations, lambs, verbose):
    """Train one independent copy of `model` per penalty strength.

    Args:
        model: template model; it is deep-copied and never trained itself.
        dataloader: data passed through to `forward` for training.
        iterations: number of training passes per copy.
        lambs: iterable of penalty strengths to try.
        verbose: if truthy, print every 10th training loss of each copy.

    Returns:
        List of trained copies, in the same order as `lambs`.
    """
    trained = []
    for strength in lambs:
        candidate = copy.deepcopy(model)
        history = forward(candidate, dataloader, iterations, strength, train = True)
        trained.append(candidate)
        if verbose:
            print(history[::10])
    return trained

In [13]:
def compare_models(hidden_dim, iterations, input_dim = 28*28, verbose = False):
    """Tune, retrain and report test error for the three model variants.

    For each of `Symmetric` ("S1"), `KNN` ("S2") and `KK` ("S3"):
      1. train one copy per penalty strength via `cross_validate`;
      2. keep the strength with the lowest mean loss on `train_loader`;
      3. retrain from a fresh initialisation three times with that strength
         and print the mean / std of the mean loss on `test_loader`.

    Args:
        hidden_dim: hidden width of `rho` for the first two variants (the
            third uses a fixed width of 1000).
        iterations: number of training passes for every fit.
        input_dim: feature size of one set element.
        verbose: forwarded to `cross_validate`.

    Returns:
        None; results are printed.
    """
    # NOTE(review): step 2 scores candidates on the same loader they were
    # trained on -- there is no held-out validation split here.
    phi_width = 1000

    # Built in this order (S1, S2, S3) so random initialisation is unchanged.
    f1 = Symmetric(input_dim, phi_width, hidden_dim)
    f2 = KNN(input_dim, phi_width, hidden_dim)
    f3 = KK(input_dim, phi_width, 1000)

    candidates = []
    for tag, net in zip(("S1", "S2", "S3"), (f1, f2, f3)):
        net.__name__ = tag
        candidates.append(net)

    lambs = [0., 1e-6, 1e-4, 1e-2]
    runs = 3

    for model in candidates:
        print("model", model.__name__)
        cv_models = cross_validate(model, train_loader, iterations, lambs, verbose)

        validation_errors = np.array([
            np.mean(np.array(forward(cv_model, train_loader, iterations, lamb = 0., train = False)))
            for cv_model in cv_models
        ])
        lamb = lambs[np.argmin(validation_errors)]

        run_errors = np.zeros(runs)
        for run in range(runs):
            print("run", run)
            fresh = copy.deepcopy(model)
            fresh.reinit()
            forward(fresh, train_loader, iterations, lamb, train = True)
            test_losses = forward(fresh, test_loader, iterations, lamb = 0., train = False)
            run_errors[run] = np.mean(np.array(test_losses))

        mean_error = np.mean(run_errors)
        std_error = np.std(run_errors)

        print("mean: {}, std: {}".format(mean_error, std_error))
        
#         if log_plot:
#             plt.semilogy(N_list, mean_error, label = model.__name__)
#         else:
#             plt.plot(N_list, mean_error, label = model.__name__)
#         plt.fill_between(N_list, mean_error - std_error, mean_error + std_error, alpha = 0.2)

    
#     plt.legend()
#     plt.ylim([1e-5, 1e-1]) 
#     plt.xlabel("N")
#     plt.ylabel("Mean Square Error")
#     narrow_str = "Narrow" if narrow else "Wide"
#     plt.title(narrow_str + " generalization for " + objective.__name__)
#     scale_str = "" if not scaleup else "scaled"
#     plt.savefig("plots_high_dim/" + objective.__name__ + "_" + narrow_str + "_" + str(input_dim) + scale_str)
# #     plt.show()
#     plt.close()

In [14]:
# Full comparison: rho hidden width 64, three training passes per fit,
# printing subsampled training losses.
compare_models(hidden_dim=64, iterations=3, verbose=True)

model S1
iter 0
iter 1
iter 2
[22.364660263061523, 17.521867752075195, 3.2143421173095703, 2.0942001342773438, 2.023754835128784, 1.5800632238388062, 1.2091484069824219, 1.3987340927124023, 1.1901233196258545, 1.358591914176941, 1.4799902439117432, 0.930798351764679, 1.2376352548599243, 1.0660595893859863, 1.017780065536499, 1.158918857574463, 0.7857360243797302, 1.0668962001800537, 0.5916444659233093, 0.6670721769332886, 0.6000826358795166, 0.7016699314117432, 0.731391191482544, 0.6356322169303894, 0.7438897490501404, 0.6143529415130615, 0.7291343808174133, 0.5926221609115601, 0.5311410427093506, 0.7640275359153748, 0.4832967221736908, 0.37477344274520874, 0.803891122341156, 0.628643274307251, 0.5441405177116394, 0.459367036819458, 0.5027850270271301, 0.363884299993515, 0.42216956615448, 0.6339213848114014, 0.6465443968772888, 0.4159659445285797, 0.692392110824585, 0.5154706239700317, 0.6862219572067261, 0.7934747934341431, 0.562089741230011, 0.44370678067207336, 0.4584224820137024, 0

iter 1
iter 2
[22.355693817138672, 1.5483862161636353, 1.0250723361968994, 1.143798589706421, 0.8113694190979004, 0.8473948836326599, 1.120110273361206, 1.1338553428649902, 0.8362168669700623, 0.9086422920227051, 0.6761785745620728, 1.000824213027954, 0.6184148192405701, 1.0730845928192139, 0.7609792947769165, 0.579908013343811, 1.0693010091781616, 1.156562089920044, 0.8650690913200378, 0.8240710496902466, 0.940274715423584, 0.5919563174247742, 0.5979586839675903, 0.8340192437171936, 0.6314544677734375, 0.6717199087142944, 0.5857135653495789, 0.6092777848243713, 0.42694365978240967, 0.5124323964118958, 0.6395403146743774, 0.5280913710594177, 0.6009917259216309, 0.49465349316596985, 0.5481878519058228, 0.5793842077255249, 0.42572152614593506, 0.49376145005226135, 0.7026071548461914, 0.5003964304924011, 0.5426134467124939, 0.6919183135032654, 0.9539603590965271, 1.3839397430419922, 0.6258089542388916, 0.6195499897003174, 0.5209871530532837, 0.5972344279289246, 0.8775133490562439, 0.66833

KeyboardInterrupt: 