In [1]:
import json
import os
import random
from tqdm import tqdm
import numpy as np
import torch, torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data.dataset import Dataset, TensorDataset  
torch.backends.cudnn.benchmark=True
from pyhessian import hessian # Hessian computation



In [2]:
class Network(nn.Module):
    def __init__(self):
        super().__init__()
        
        # Inputs to hidden layer linear transformation
        self.hidden = nn.Linear(784, 784)
        self.hidden2 = nn.Linear(784, 600)
        self.hidden3 = nn.Linear(600, 400)
        self.hidden4 = nn.Linear(400, 200)
        # Output layer, 62 units 
        self.output = nn.Linear(200, 62)
        
        # Define sigmoid activation and softmax output 
        self.ReLu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)
        
    def forward(self, x):
        # Pass the input tensor through each of our operations
        x = torch.reshape(x, (-1, 784))
        x = self.hidden(x)
        x = self.ReLu(x)
        x = self.hidden2(x)
        x = self.ReLu(x)
        x = self.hidden3(x)
        x = self.ReLu(x)
        x = self.hidden4(x)
        x = self.ReLu(x)
        x = self.output(x)
        x = self.softmax(x)
        
        return x

In [3]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=5)
        self.conv2 = nn.Conv2d(32, 128, kernel_size=5)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(2048, 1024)
        self.fc3 = nn.Linear(1024, 512)
        self.fc5 = nn.Linear(512, 62)
        self.softmax = nn.Softmax(dim=1)
    def forward(self, x):
        
        x = torch.unsqueeze(x,1)
        x = F.relu(F.max_pool2d(self.conv1(x), 2))
        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
        #print(x.size())
        x = x.view(-1, 2048)
        x = F.relu(self.fc1(x))
        x = F.dropout(x, training=self.training)
        x = F.relu(self.fc3(x))
        x = F.dropout(x, training=self.training)
        x = self.fc5(x)
        x = self.softmax(x)
        return x

In [4]:
#################################
##### Neural Network model #####
#################################

cfg = {
    'VGG11': [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG13': [64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    'VGG16': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512, 'M'],
    'VGG19': [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
}

class VGG(nn.Module):
    def __init__(self, vgg_name):
        super(VGG, self).__init__()
        self.features = self._make_layers(cfg[vgg_name])
        self.classifier = nn.Sequential(
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, 512),
            nn.ReLU(True),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        out = self.features(x)
        out = out.view(out.size(0), -1)
        out = self.classifier(out)
        output = F.log_softmax(out, dim=1)
        return output

    def _make_layers(self, cfg):
        layers = []
        in_channels = 3
        for x in cfg:
            if x == 'M':
                layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
            else:
                layers += [nn.Conv2d(in_channels, x, kernel_size=3, padding=1),
                           nn.BatchNorm2d(x),
                           nn.ReLU(inplace=True)]
                in_channels = x
        layers += [nn.AvgPool2d(kernel_size=1, stride=1)]
        return nn.Sequential(*layers)

In [5]:
def client_update(client_model, optimizer, train_loader, mode, epoch=5):
    """
    This function updates/trains client model on client data
    """
    client_model.train()
    for e in range(epoch):
        for batch_idx, (inputs, target) in enumerate(train_loader):
            inputs, target = inputs.cuda(), target.cuda()
            optimizer.zero_grad()
            output = client_model(inputs)
#             target = torch.nn.functional.one_hot(target, 62)
#             print(target.shape)
            loss = nn.functional.nll_loss(output, target)
            loss.backward()
            optimizer.step()
    if mode == 'Average':
        return loss.item()
    if mode == 'HessFuse':
        client_model.eval()
        for batch_idx, (inputs, target) in enumerate(train_loader):
                inputs, target = inputs.cuda(), target.cuda()
                target = torch.nn.functional.one_hot(target)
                loss2 = torch.nn.CrossEntropyLoss()
                hessian_comp = hessian(client_model, loss2, data=(inputs, target), cuda=True)
                top_eigenvalues, top_eigenvector = hessian_comp.eigenvalues()
                break

        return loss.item(), top_eigenvalues

In [6]:
def server_aggregate(global_model, client_models, weights):
    """
    This function has aggregation method 'mean'
    """
    ### This will take simple mean of the weights of models ###
    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack([(weights[0]*(client_models[0].state_dict()[k].float())) for i in range(len(client_models))], 0).mean(0)
            
    global_model.load_state_dict(global_dict)
    for model in client_models:
        model.load_state_dict(global_model.state_dict())

In [7]:
def test(global_model, test_loader):
    """This function test the global model on test data and returns test loss and test accuracy """
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        count = 0
        for data, target in test_loader:
            data, target = data.cuda(), target.cuda()
#             target = torch.nn.functional.one_hot(target)
            output = global_model(data)
            test_loss += nn.functional.nll_loss(output, target, reduction='sum').item()  # sum up batch loss
            pred = output.argmax(dim=1, keepdim=True)  # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()
            count += 1
        
    test_loss /= len(test_loader.dataset)
    acc = correct / len(test_loader.dataset)

    return test_loss, acc

In [8]:
num_clients = 1
num_selected = 1
num_rounds = 150
epochs = 5
batch_size = 100
# random.seed(13)
# np.random.seed(13)
# torch.manual_seed(13)

#############################################################
##### Creating desired data distribution among clients  #####
#############################################################


# # Image augmentation 
# transform_train = transforms.Compose([
#     transforms.RandomCrop(32, padding=4),
#     transforms.RandomHorizontalFlip(),
#     transforms.ToTensor(),
#     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
# ])

# # Loading CIFAR10 using torchvision.datasets
# traindata = datasets.CIFAR10('./data', train=True, download=True,
#                        transform= transform_train)

# # Dividing the training data into num_clients, with each client having equal number of images
# # Normalizing the test images
# transform_test = transforms.Compose([
#     transforms.ToTensor(),
#     transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
# ])

# # Loading the test iamges and thus converting them into a test_loader
# test_loader = torch.utils.data.DataLoader(
#         datasets.CIFAR10('./data', train=False, transform=transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
#         ), batch_size=batch_size, shuffle=True)
import matplotlib.pyplot as plt
with open('/home/asahebpa/leaf/data/femnist/data/train/all_data_5_niid_0_keep_0_train_9.json') as f:
  data = json.load(f)
Usersdata = data['user_data']
UsersID = data['users']
train_loader = [torch.utils.data.DataLoader(TensorDataset(torch.FloatTensor(np.reshape(Usersdata[UsersID[i]]['x'],(-1,28,28))),torch.tensor(Usersdata[UsersID[i]]['y'])), batch_size=batch_size, shuffle=True) for i in range(num_clients)]

with open('/home/asahebpa/leaf/data/femnist/data/train/all_data_5_niid_0_keep_0_train_9.json') as f:
  data = json.load(f)
Usersdata = data['user_data']
UsersID = data['users']
testx = torch.FloatTensor(np.reshape(Usersdata[UsersID[0]]['x'],(-1,28,28)))
testy = torch.tensor(Usersdata[UsersID[0]]['y'])
for i in range(99):
    testx = torch.cat((testx,torch.FloatTensor(np.reshape(Usersdata[UsersID[i+1]]['x'],(-1,28,28)))),dim=0)
    testy = torch.hstack((testy, torch.tensor(Usersdata[UsersID[i+1]]['y'])))

test_loader = torch.utils.data.DataLoader(TensorDataset(testx, testy), batch_size=batch_size, shuffle=True)
    

In [9]:
############################################
#### Initializing models and optimizer  ####
############################################

# #### global model ##########
# global_model =  VGG('VGG19').cuda()
global_model =  Network().cuda()
# ############## client models ##############
# client_models = [ VGG('VGG19').cuda() for _ in range(num_selected)]
client_models = [ Network().cuda() for _ in range(num_selected)]
for model in client_models:
    model.load_state_dict(global_model.state_dict()) ### initial synchronizing with global model 

############### optimizers ################
opt = [optim.SGD(model.parameters(), lr=0.1) for model in client_models]


###### List containing info about learning #########
losses_train = []
losses_test = []
acc_train = []
acc_test = []
# Runnining FL
mode = 'Average'
for r in range(num_rounds):
    # select random clients
    client_idx = np.random.permutation(num_clients)[:num_selected]
    # client update
    losstot = 0
    eigs = np.ones(num_selected)
    for i in tqdm(range(num_selected)):
        if mode == 'HessFuse':
            loss, eigss = client_update(client_models[0], opt[0], train_loader[client_idx[0]], mode, epoch=epochs)
            eigs[i] = eigss[0]
        if mode == 'Average':
            loss = client_update(client_models[0], opt[0], test_loader, mode, epoch=epochs)
        losstot += loss
        
    weights = eigs/(np.sum(eigs))
    losses_train.append(loss)
    # server aggregate
    server_aggregate(global_model, client_models, weights*num_selected)
    
    test_loss, acc = test(client_models[0], test_loader)
    losses_test.append(test_loss)
    acc_test.append(acc)
    print('%d-th round' % r)
    print('average train loss %0.3g | test loss %0.3g | test acc: %0.3f' % (loss / num_selected, test_loss, acc))

100%|██████████| 1/1 [00:01<00:00,  1.26s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

0-th round
average train loss -0.0185 | test loss -0.0182 | test acc: 0.067


100%|██████████| 1/1 [00:01<00:00,  1.27s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

1-th round
average train loss -0.115 | test loss -0.078 | test acc: 0.078


100%|██████████| 1/1 [00:01<00:00,  1.25s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

2-th round
average train loss -0.096 | test loss -0.078 | test acc: 0.078


100%|██████████| 1/1 [00:01<00:00,  1.20s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

3-th round
average train loss -0.134 | test loss -0.078 | test acc: 0.078


100%|██████████| 1/1 [00:01<00:00,  1.18s/it]
  0%|          | 0/1 [00:00<?, ?it/s]

4-th round
average train loss -0.0209 | test loss -0.0781 | test acc: 0.078


  0%|          | 0/1 [00:01<?, ?it/s]


KeyboardInterrupt: 

In [None]:
# import random
# choose = (random.randrange(len(datareally['y'])))

# datareally = (dd['f1816_24'])
# x1 = datareally['x'][choose]
# print(datareally['y'][choose])

# import numpy as np
# x1 = np.zeros((784,63))
# counters = np.zeros((63))
# print(x1.shape)
# for i in range(len(datareally['y'])):
#     print(np.array(datareally['x'][i]).shape)
#     label = datareally['y'][i]
#     counters[label] += 1 
#     x1[:, label] += np.array(datareally['x'][i])

# import matplotlib.pyplot as plt
# fig, axes = plt.subplots(8,8, figsize=(8,8))
# for i,ax in enumerate(axes.flat):
#     tempol = x1[:, i]/counters[i]
#     xplot = np.reshape(np.ravel(tempol), (28, 28))
#     ax.imshow(xplot)

