In [1]:
import os
import numpy as np
import torch
import itertools
import matplotlib.pyplot as plt
%matplotlib inline
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import math

from tqdm import tqdm
import time

In [2]:
# Corollary 2.4 in Mohammadi 2014 - for 1d
def alpha_estimator_one(m, X):
    """Estimate the tail index alpha of a 1-d sample via block sums.

    The sample is cut into ``n = N // m`` consecutive blocks of ``m`` values
    and every block is summed.  With ``Y`` the block sums and ``X`` the
    (truncated) sample, the estimate is::

        1 / alpha = (mean(log|Y|) - mean(log|X|)) / log(m)

    Parameters
    ----------
    m : int
        Block size.  Must be >= 2, because ``log(1) == 0`` leaves the
        estimator undefined.
    X : array_like
        Sample holding at least ``m`` values.  Trailing values that do not
        fill a complete block are dropped.

    Returns
    -------
    numpy.float64
        Estimated alpha (``inf`` if both log-means coincide).

    Raises
    ------
    ValueError
        If ``m < 2`` or ``X`` has fewer than ``m`` entries.
    """
    if m < 2:
        raise ValueError("block size m must be >= 2, got {}".format(m))

    # Accept lists/tuples as well as ndarrays (reshape below needs an ndarray).
    X = np.asarray(X)
    N = len(X)
    n = N // m  # floor division: exact for any N, unlike int(N / m)
    if n == 0:
        raise ValueError(
            "need at least m={} samples, got {}".format(m, N))

    X = X[:n * m]
    Y = X.reshape(n, m).sum(axis=1)

    # Tiny offset so that exact zeros do not produce log(0) = -inf.
    eps = np.spacing(1)
    Y_log_norm = np.log(np.abs(Y) + eps).mean()
    X_log_norm = np.log(np.abs(X) + eps).mean()
    diff = (Y_log_norm - X_log_norm) / math.log(m)
    return 1 / diff

In [3]:
# Corollary 2.4 in Mohammadi 2014 - for multi-d
def alpha_estimator_multi(m, X):
    """Estimate the tail index alpha of a multi-dimensional sample via block sums.

    The rows of ``X`` are cut into ``n = N // m`` consecutive blocks of ``m``
    rows and every block is summed.  With ``Y`` the block sums, the estimate
    is::

        1 / alpha = (mean(log||Y_i||) - mean(log||X_j||)) / log(m)

    where the norms are taken along ``dim=1``.

    Parameters
    ----------
    m : int
        Block size.  Must be >= 2, because ``log(1) == 0`` leaves the
        estimator undefined.
    X : torch.Tensor or array_like
        ``N x d`` matrix with ``N >= m`` rows.  Trailing rows that do not fill
        a complete block are dropped.

    Returns
    -------
    float
        Estimated alpha; ``math.inf`` when both log-means coincide exactly.

    Raises
    ------
    ValueError
        If ``m < 2`` or ``X`` has fewer than ``m`` rows.
    """
    if m < 2:
        raise ValueError("block size m must be >= 2, got {}".format(m))

    # Accept NumPy arrays / nested lists in addition to tensors (no copy for tensors).
    X = torch.as_tensor(X)
    N = X.size(0)
    n = N // m  # floor division: exact for any N, unlike int(N / m)
    if n == 0:
        raise ValueError(
            "need at least m={} rows, got {}".format(m, N))

    X = X[:n * m, :]
    Y = X.reshape(n, m, -1).sum(dim=1)

    # Tiny offset so that all-zero rows do not produce log(0) = -inf.
    eps = float(np.spacing(1))
    Y_log_norm = torch.log(Y.norm(dim=1) + eps).mean()
    X_log_norm = torch.log(X.norm(dim=1) + eps).mean()
    diff = ((Y_log_norm - X_log_norm) / math.log(m)).item()
    if diff == 0.0:
        # Degenerate input (e.g. all-zero rows).  Plain float division would
        # raise ZeroDivisionError here, whereas the NumPy-based 1-d estimator
        # yields inf in the same situation; return inf for consistency.
        return math.inf
    return 1 / diff

In [4]:
# A simple FCN
class simpleNet(nn.Module):
    """Bias-free fully connected ReLU network.

    Layout of ``self.fc``: ``Linear(input_dim, width)`` + ReLU, followed by
    ``depth - 2`` blocks of ``Linear(width, width)`` + ReLU, followed by
    ``Linear(width, num_classes)``.  None of the linear layers has a bias.
    """

    def __init__(self, input_dim=28*28, width=128, depth=3, num_classes=10):
        super().__init__()
        self.input_dim = input_dim
        self.width = width
        self.depth = depth
        self.num_classes = num_classes

        # The hidden blocks are instantiated before the first and last layer,
        # which fixes the order in which random initial weights are drawn.
        hidden_blocks = self.get_layers()
        first = nn.Linear(self.input_dim, self.width, bias=False)
        last = nn.Linear(self.width, self.num_classes, bias=False)

        self.fc = nn.Sequential(first, nn.ReLU(inplace=True), *hidden_blocks, last)

    def get_layers(self):
        """Return the ``depth - 2`` hidden (Linear, ReLU) pairs as a flat list."""
        modules = []
        for _ in range(self.depth - 2):
            modules.extend((
                nn.Linear(self.width, self.width, bias=False),
                nn.ReLU(),
            ))
        return modules

    def forward(self, x):
        """Flatten each sample to ``input_dim`` features and apply the stack."""
        flat = x.view(x.size(0), self.input_dim)
        return self.fc(flat)

In [5]:
# Experiment configuration; these notebook-level names are the arguments later
# passed to compute_alphas_centralized (num_nets is read by it as a global).

# Root directory; compute_alphas_centralized looks for one 'LR<k>/' sub-directory
# of checkpoints per entry of lr_list underneath it.
PATH = './3FCN-MNIST/'
# Alternative learning-rate grid kept for reference:
# lr_list = np.linspace(0.001,0.1,20)
lr_list = [0.05,0.06,0.07,0.08,0.09,0.1,0.11,0.12]

# Number of weight layers analysed per network (columns of the result array).
depth = 3
# Number of checkpoints ('model1.pth' ... 'model<num_nets>.pth') read per learning rate.
num_nets = 1000
# NOTE(review): `nets` is not used by any code shown in this notebook.
nets = []
        

In [4]:
def get_layerWise(net):
    """Collect the weight tensors of the trainable layers of ``net`` as NumPy arrays.

    A parameter is treated as a layer weight when it requires gradients and has
    more than one dimension; 1-d parameters (biases) are skipped.  Selecting by
    dimensionality instead of by position keeps the result correct for networks
    built with ``bias=False``, where weights and biases do not alternate.

    Parameters
    ----------
    net : torch.nn.Module
        Network whose parameters are inspected.

    Returns
    -------
    numpy.ndarray
        If all weights share one shape, a stacked array of shape
        ``(n_layers, *weight_shape)``.  Otherwise a 1-d ``object`` array whose
        entries are the per-layer weight arrays, in parameter order.
    """
    w = [
        p.detach().cpu().numpy()
        for p in net.parameters()
        if p.requires_grad and p.dim() > 1
    ]

    # Equal shapes (or no layers at all): a regular stacked array is well defined.
    if len({layer.shape for layer in w}) <= 1:
        return np.array(w)

    # Differently shaped layers: np.array(w) raises on recent NumPy releases,
    # so build the object array explicitly, one slot per layer.
    res = np.empty(len(w), dtype=object)
    for i, layer in enumerate(w):
        res[i] = layer
    return res

In [6]:
def _mean_layer_weights(model_dir, n_models):
    """Average each weight tensor over 'model1.pth' .. 'model<n_models>.pth' in ``model_dir``."""
    layer_means = []
    for i in range(n_models):
        checkpoint = model_dir + 'model{}'.format(i + 1) + '.pth'
        # NOTE(security): torch.load unpickles the complete stored module, so only
        # load checkpoints from a trusted source.  On PyTorch >= 2.6 the default
        # weights_only=True rejects pickled modules; pass weights_only=False there.
        net = torch.load(checkpoint, map_location='cpu')

        # Weight tensors have more than one dimension, biases are 1-d.  Selecting
        # by dimensionality (rather than every second parameter) also works for
        # networks built with bias=False.
        layers = [p.detach().numpy() for p in net.parameters() if p.dim() > 1]

        if i == 0:
            # Division creates fresh arrays, so the model parameters are never mutated.
            layer_means = [layer / float(n_models) for layer in layers]
            continue
        if len(layers) != len(layer_means):
            raise ValueError(
                "{} has {} weight layers, expected {}".format(
                    checkpoint, len(layers), len(layer_means)))
        for acc, layer in zip(layer_means, layers):
            acc += layer / float(n_models)
    return layer_means


def _layer_alpha(weights, block_sizes=(2, 5, 10)):
    """Median multi-d alpha estimate of one mean-centred weight tensor over ``block_sizes``."""
    # Remove the mean taken over the first axis.
    centered = weights - np.mean(weights, axis=0)[np.newaxis, ...]
    if centered.ndim == 4:
        # 4-d (conv-style) weights: merge the first two axes into rows and
        # flatten the remaining two into columns.
        centered = centered.reshape(centered.shape[0] * centered.shape[1], -1)
    samples = torch.from_numpy(centered)
    return np.median([alpha_estimator_multi(mm, samples) for mm in block_sizes])


def compute_alphas_centralized(etas, PATH, depth, n_models=None):
    """Estimate a per-layer tail index for the weight-averaged network of each learning rate.

    For every position ``k`` in ``etas`` the checkpoints
    ``PATH + 'LR<k>/model<j>.pth'`` (``j = 1 .. n_models``) are loaded, their
    weight tensors are averaged across checkpoints, and for each layer the
    mean-centred average is passed to ``alpha_estimator_multi`` with block sizes
    2, 5 and 10; the median of the three estimates is stored.

    Parameters
    ----------
    etas : sequence
        Learning rates.  Only the positions are used (directories are named by
        index, not by value).
    PATH : str
        Root directory, including the trailing separator.
    depth : int
        Number of weight layers per network; must match the checkpoints.
    n_models : int, optional
        Number of checkpoints per learning rate.  Defaults to the notebook-level
        ``num_nets``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(len(etas), depth)`` with the alpha estimates.

    Raises
    ------
    ValueError
        If the checkpoints do not contain exactly ``depth`` weight layers.
    """
    if n_models is None:
        n_models = num_nets  # notebook-level default

    alphas_multi = np.zeros((len(etas), depth)) - 1

    # Wrapping `etas` (not the enumerate object) lets tqdm display the total.
    for ei, _eta in enumerate(tqdm(etas)):
        tmp_path = PATH + 'LR{}/'.format(ei)
        print(tmp_path)

        layer_means = _mean_layer_weights(tmp_path, n_models)
        if len(layer_means) != depth:
            raise ValueError(
                "found {} weight layers under {}, expected depth={}".format(
                    len(layer_means), tmp_path, depth))

        for i, weights in enumerate(layer_means):
            alphas_multi[ei, i] = _layer_alpha(weights)

    return alphas_multi


In [41]:
# Sanity check: print the shape of every parameter tensor of a freshly built simpleNet.
# NOTE(review): the saved output under this cell lists 4-d conv-style weights and
# bias vectors, which does not match simpleNet as defined in this notebook
# (bias-free Linear layers) -- the output looks stale; re-run to refresh it.
tmp_mod = simpleNet()
idx = 0  # NOTE(review): never incremented or read; leftover from the disabled print below
for i in tmp_mod.parameters():
    print(i.shape)
#     print(idx)

torch.Size([6, 3, 5, 5])
torch.Size([6])
torch.Size([16, 6, 5, 5])
torch.Size([16])
torch.Size([120, 400])
torch.Size([120])
torch.Size([84, 120])
torch.Size([84])
torch.Size([10, 84])
torch.Size([10])


In [11]:
# Run the analysis on the checkpoints stored under PATH, one 'LR<k>/' directory
# per learning rate in lr_list.
# NOTE(review): the saved output below prints './CNN-CIFAR10-unif/LR<k>/' paths and
# the displayed result has 5 columns, so it was produced with a different PATH and
# depth than the values set here -- re-run this cell to refresh the output.
PATH = './3FCN-MNIST-unif/'
lr_list = [0.05,0.06,0.07,0.08,0.09,0.1,0.11,0.12]

# Number of weight layers per network (columns of the result).
depth = 3
# Checkpoints per learning rate; read as a global inside compute_alphas_centralized.
num_nets = 1000
# NOTE(review): `nets` is not used by any code shown in this notebook.
nets = []
# Result has shape (len(lr_list), depth).
alphas_mc_cent = compute_alphas_centralized(lr_list, PATH, depth)


0it [00:00, ?it/s]

./CNN-CIFAR10-unif/LR0/


1it [00:07,  7.41s/it]

./CNN-CIFAR10-unif/LR1/


2it [00:14,  7.27s/it]

./CNN-CIFAR10-unif/LR2/


3it [00:21,  7.20s/it]

./CNN-CIFAR10-unif/LR3/


4it [00:28,  7.21s/it]

./CNN-CIFAR10-unif/LR4/


5it [00:36,  7.28s/it]

./CNN-CIFAR10-unif/LR5/


6it [00:43,  7.30s/it]

./CNN-CIFAR10-unif/LR6/


7it [00:50,  7.34s/it]

./CNN-CIFAR10-unif/LR7/


8it [00:58,  7.30s/it]


In [12]:
# Display the alpha estimates: one row per learning rate in lr_list, one column per layer.
alphas_mc_cent

array([[1.75965741, 2.10764968, 2.08125705, 2.10009752, 2.32603469],
       [1.72196667, 1.88239159, 2.0448582 , 2.09197958, 1.95549839],
       [1.66763325, 1.97760457, 2.05219056, 2.06174552, 1.94725793],
       [1.52074925, 1.9689747 , 2.0669305 , 1.98348258, 2.01566   ],
       [1.69568467, 2.12406174, 2.04512429, 2.1079554 , 2.05447616],
       [1.69549531, 2.12672936, 2.04408338, 2.03716302, 1.84752138],
       [1.76361753, 2.07184698, 2.03102426, 2.10495943, 1.96493566],
       [1.63770094, 2.1224194 , 2.03361919, 1.97968842, 1.82101094]])

In [13]:
# Average the per-layer estimates (axis 1) to get one alpha per learning rate.
np.mean(alphas_mc_cent, axis = 1)

array([2.07493927, 1.93933889, 1.94128636, 1.91115941, 2.00546045,
       1.95019849, 1.98727677, 1.91888778])