In [1]:
import matplotlib.pyplot as plt
import numpy as np
from numpy.linalg import norm
import itertools
import copy

import torch
import torchvision
import torchvision.transforms as transforms

import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

from model import Symmetric, DeepSets, KNN, KK

%matplotlib inline

In [2]:
batch_size_train = 32
batch_size_test = 32

In [22]:
class PointCloud(object):

    def __init__(self, cloud_size):
        self.cloud_size = cloud_size

    def __call__(self, image):

        flat = image.flatten()
        args = torch.argsort(flat)[-self.cloud_size:].int()
        args = args[torch.randperm(self.cloud_size)]
        rows = torch.floor_divide(args, 28)
        cols = torch.fmod(args, 28)
        
        image = torch.zeros(self.cloud_size, 3)
        image[:,0] = (rows - 14) / 28.
        image[:,1] = (cols - 14) / 28.
        image[:,2] = 1 #bias term

        return image

In [23]:
cloud_size = 100
train_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('data', train=True, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                                 PointCloud(cloud_size)
                             ])),
  batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('data', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                                 PointCloud(cloud_size)
                             ])),
  batch_size=batch_size_test, shuffle=True)

In [24]:
examples = enumerate(train_loader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_data.shape)
print(example_targets.shape)

torch.Size([32, 100, 3])
torch.Size([32])


In [39]:
def forward(model, dataloader, iterations, lamb = 0.1, train = True):
    criterion = nn.CrossEntropyLoss()
    
    if train:
        model.train()
        optimizer = optim.Adam(model.parameters(), lr=0.01)
    else:
        model.eval()

    losses = []
    for i in range(iterations):
        print("iter", i)
        for batch_idx, (x, y) in enumerate(dataloader):
            
            outputs = model(x)
            loss = criterion(outputs, y)

            if train:
                optimizer.zero_grad()
                loss += model.regularize(lamb)
                loss.backward()
                optimizer.step()

            losses.append(loss.item())
            
        if not train:
            break
    
    model.eval()
    return losses

In [40]:
def cross_validate(model, dataloader, iterations, lambs, verbose):
    models = []
    for lamb in lambs:
        model_copy = copy.deepcopy(model)
        losses = forward(model_copy, dataloader, iterations, lamb, train = True)
        models.append(model_copy)
        if verbose:
            print(losses[::100])
    return models

In [41]:
def compare_models(hidden_dim, iterations, input_dim = 2, verbose = False):
        
    f1 = Symmetric(input_dim, hidden_dim, hidden_dim, 10)
    f2 = KNN(input_dim, hidden_dim, hidden_dim, 10)
    f3 = KK(input_dim, hidden_dim, hidden_dim, 10)

    f1.__name__ = "S1"
    f2.__name__ = "S2"
    f3.__name__ = "S3"

    models = [f1, f2, f3]
    
    lambs = [0., 1e-6, 1e-4, 1e-2]

    for model in models:
        print("model", model.__name__)
        cv_models = cross_validate(model, train_loader, iterations, lambs, verbose)
        
        validation_errors = np.zeros_like(lambs)
        for i, cv_model in enumerate(cv_models):
            errors = forward(cv_model, train_loader, iterations, lamb = 0., train = False)
            validation_errors[i] = np.mean(np.array(errors))
        
        i = np.argmin(validation_errors)
        lamb = lambs[i]
            
        runs = 3
        run_errors = np.zeros(runs)
        for i in range(runs):
            print("run", i)
            model_copy = copy.deepcopy(model)
            model_copy.reinit()
            forward(model_copy, train_loader, iterations, lamb, train = True)
            errors = forward(model_copy, test_loader, iterations, lamb = 0., train = False)
            run_errors[i] = np.mean(np.array(errors))
        
        mean_error = np.mean(run_errors)
        std_error = np.std(run_errors)
        
        print("mean: {}, std: {}".format(mean_error, std_error))
        
#         if log_plot:
#             plt.semilogy(N_list, mean_error, label = model.__name__)
#         else:
#             plt.plot(N_list, mean_error, label = model.__name__)
#         plt.fill_between(N_list, mean_error - std_error, mean_error + std_error, alpha = 0.2)

    
#     plt.legend()
#     plt.ylim([1e-5, 1e-1]) 
#     plt.xlabel("N")
#     plt.ylabel("Mean Square Error")
#     narrow_str = "Narrow" if narrow else "Wide"
#     plt.title(narrow_str + " generalization for " + objective.__name__)
#     scale_str = "" if not scaleup else "scaled"
#     plt.savefig("plots_high_dim/" + objective.__name__ + "_" + narrow_str + "_" + str(input_dim) + scale_str)
# #     plt.show()
#     plt.close()

In [42]:
compare_models(100, 3, verbose = True)

model S1
iter 0
iter 1
iter 2
[2.3260128498077393, 2.2904090881347656, 2.314302921295166, 2.3051066398620605, 2.2975456714630127, 2.3014228343963623, 2.300121307373047, 2.313122034072876, 2.3045859336853027, 2.299238443374634, 2.2810049057006836, 2.3037593364715576, 2.2685256004333496, 2.2843332290649414, 2.224290370941162, 2.1974143981933594, 2.1876871585845947, 2.211489200592041, 2.1255316734313965, 2.082728624343872, 2.1830146312713623, 2.168833017349243, 2.0730292797088623, 2.1426496505737305, 2.1896181106567383, 2.091404438018799, 2.0159718990325928, 1.9270375967025757, 1.8788323402404785, 1.8394209146499634, 1.9410144090652466, 1.8235831260681152, 1.7708690166473389, 1.8570257425308228, 1.7368743419647217, 1.8891913890838623, 1.9322465658187866, 1.855883240699768, 1.7349711656570435, 1.9124987125396729, 1.841839075088501, 1.91781747341156, 1.9753161668777466, 1.8703199625015259, 1.9146056175231934, 1.6711887121200562, 1.7270101308822632, 1.6787742376327515, 1.568129539489746, 1.4

iter 1
iter 2
[2.2965550422668457, 2.307572603225708, 2.306950807571411, 2.2971725463867188, 2.3018996715545654, 2.304514169692993, 2.298502206802368, 2.280287504196167, 2.2515783309936523, 2.2303237915039062, 2.2784242630004883, 2.2585246562957764, 1.9348798990249634, 2.1723244190216064, 2.1151492595672607, 2.022503614425659, 1.8198447227478027, 2.127082109451294, 1.982092261314392, 1.9131855964660645, 1.674591064453125, 1.711733341217041, 1.8298388719558716, 1.9202839136123657, 1.8331726789474487, 1.894999384880066, 1.8572266101837158, 1.815189242362976, 1.8048105239868164, 1.6539119482040405, 1.6646924018859863, 1.8906500339508057, 1.7344752550125122, 1.866538643836975, 1.8592941761016846, 1.6230101585388184, 1.914730191230774, 1.7426155805587769, 1.574817180633545, 1.9678140878677368, 1.8349010944366455, 1.4782272577285767, 1.97422456741333, 1.5495644807815552, 1.670025110244751, 1.681551218032837, 2.0760433673858643, 1.827675700187683, 1.55072021484375, 1.739856243133545, 1.659596

iter 1
iter 2
[2.30230712890625, 2.3177826404571533, 2.3164405822753906, 2.3054769039154053, 2.3064627647399902, 2.3097152709960938, 2.3181209564208984, 2.3018195629119873, 2.3058154582977295, 2.2974841594696045, 2.3002614974975586, 2.3038980960845947, 2.336743116378784, 2.2806529998779297, 2.287659168243408, 2.123671770095825, 2.296574115753174, 2.242584228515625, 2.246840715408325, 2.169581890106201, 2.153048276901245, 2.2576513290405273, 2.099327564239502, 2.031433343887329, 2.265861749649048, 2.1298084259033203, 2.0714938640594482, 1.9667214155197144, 1.993407130241394, 1.8997167348861694, 2.0241687297821045, 1.9730218648910522, 2.002216339111328, 2.0087740421295166, 2.153886556625366, 1.9857635498046875, 2.3617770671844482, 1.8768131732940674, 2.2156004905700684, 1.9079346656799316, 1.86032235622406, 1.8729581832885742, 1.8459786176681519, 1.8390023708343506, 1.9515169858932495, 1.9082850217819214, 1.7770074605941772, 1.617967963218689, 2.096160411834717, 1.9240235090255737, 1.945

iter 1
iter 2
[2.7997372150421143, 2.8210229873657227, 2.803724765777588, 2.825474262237549, 2.828339099884033, 2.828144073486328, 2.8264389038085938, 2.853926181793213, 2.845416307449341, 2.8496549129486084, 2.819216251373291, 2.7974112033843994, 2.803910255432129, 2.819061279296875, 2.8906514644622803, 2.825824499130249, 2.689833164215088, 2.724248170852661, 2.6116597652435303, 2.5523569583892822, 2.594130754470825, 2.6172592639923096, 2.797297954559326, 2.7084786891937256, 2.749371290206909, 2.6487767696380615, 2.7036218643188477, 2.69423246383667, 2.6077163219451904, 2.9839024543762207, 2.6342434883117676, 2.6921756267547607, 2.5309159755706787, 2.448484182357788, 2.5829062461853027, 2.7546091079711914, 2.718390464782715, 2.542471408843994, 2.721461057662964, 2.4852209091186523, 2.671170234680176, 2.561317205429077, 2.5006051063537598, 2.580552577972412, 2.4535813331604004, 2.5236923694610596, 2.5233347415924072, 2.507620334625244, 2.73752760887146, 2.758772373199463, 2.63538932800

iter 0
iter 0
iter 0
run 0
iter 0
iter 1
iter 2
iter 0
run 1
iter 0
iter 1
iter 2
iter 0
run 2
iter 0


KeyboardInterrupt: 