# INITIALISATION

In [7]:
from time import time
from pathlib import Path
from IPython.display import Image, display
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.utils import data
from torchvision import datasets, transforms
from torchvision.transforms.functional import to_pil_image, resize, to_tensor
from torchvision.transforms.functional import normalize
import os
import shutil
import random
from copy import deepcopy
from zipfile import ZipFile, ZIP_DEFLATED
from IPython.core.interactiveshell import InteractiveShell

InteractiveShell.ast_node_interactivity = "all"

In [48]:
# Seeds available for repeated runs (presumably one run per seed; the code
# consuming SEEDS is not in this part of the notebook).
SEEDS = [0,1,2,3,4]
# When True, distribute_data_rd reuses a single label swap (FORCING1/FORCING2)
# for every "strats" node instead of drawing a new swap per node.
FORCE = False
# Opacity (alpha) of the variance band drawn around each curve by plot_var.
INTENS = 0.4
# Default value of every experiment parameter; add_defaults() copies this
# dictionary and overrides the keys given by a specific configuration.
DEFAULTS = {
            "w0": 0.2,  # float >= 0, regularisation parameter
            "w": 0.2,   # float >= 0, harmonisation parameter
            "lr_gen": 0.02,     # float > 0, learning rate of global model
            "lr_node": 0.02,    # float > 0, learning rate of local models
            "NN" : "base",     # "base" or "conv", neural network architecture
            "opt": optim.Adam,    # any torch optimizer
            "gen_freq": 1,         # int >= 1, number of global steps for 1 local step

            "nbn": 1000,    # int >= 1, number of nodes
            "nbd": 60_000,  # int >= 1, nbd/nbn must be in [1, 60_000], total data
            "fracdish": 0,   # float in [0,1]
            "typ_dish": "zeros",# in ["honest", "zeros", "jokers", "one_evil", 
                                   # "byzantine", "trolls", "strats"]
            "heter": 0,        # int >= 0, heterogeneity of data repartition
            "nb_epochs": 100 # int >= 1, number of training epochs
            }
def defaults_help():
    ''' Documentation-only helper: describes the DEFAULTS dictionary.

        "w0": 0.2,          # float >= 0, regularisation parameter
        "w": 0.2,           # float >= 0, harmonisation parameter
        "lr_gen": 0.02,     # float > 0, learning rate of global model
        "lr_node": 0.02,    # float > 0, learning rate of local models
        "NN" : "base",      # "base" or "conv", neural network architecture
        "opt": optim.Adam,  # any torch optimizer
        "gen_freq": 1,      # int >= 1, number of global steps for
                                                        1 local step

        "nbn": 1000,        # int >= 1, number of nodes
        "nbd": 60_000,      # int >= 1, total data
                                    - nbd/nbn must be in [1, 60_000]
        "fracdish": 0,      # float in [0,1]
        "typ_dish": "zeros",# in ["honest", "zeros", "jokers", "one_evil",
                                "byzantine", "trolls", "strats"]
        "heter": 0,         # int >= 0, heterogeneity of data repartition
        "nb_epochs": 100,   # int >= 1, number of training epochs

        Calling the function does nothing; read it with help(defaults_help).
    '''
    return None

# Description of the metrics stored in Flower.history, in the same order
# (METRICS[i] describes history[i]):
#   "lab"    : legend label of the curve (used by plot_var)
#   "ord"    : y-axis label of the plot (used by plotfull_var)
#   "f_name" : suffix appended to the saved file name (used by plotfull_var)
METRICS = ({"lab":"fit", "ord": "Training Loss", "f_name": "loss"}, 
           {"lab":"gen", "ord": "Training Loss", "f_name": "loss"}, 
           {"lab":"reg", "ord": "Training Loss", "f_name": "loss"}, 
           {"lab":"acc", "ord": "Accuracy", "f_name": "acc"}, 
           {"lab":"l2_dist", "ord": "l2 norm", "f_name": "l2dist"}, 
           {"lab":"l2_norm", "ord": "l2 norm", "f_name": "l2dist"}, 
           {"lab":"grad_sp", "ord": "Scalar Product", "f_name": "grad"}, 
           {"lab":"grad_norm", "ord": "Scalar Product", "f_name": "grad"}
           )


In [9]:
# Work inside /content/distribution, creating the folder on first run.
# NOTE(review): "/content" is a hard-coded absolute path (presumably the
# Colab working directory) - this cell fails on machines without it.
os.chdir("/content")
os.makedirs("distribution", exist_ok=True)
os.chdir("/content/distribution")

# DATA

## functions

In [10]:
# data import and management

def _load_mnist_split(train, img_size):
    ''' load one MNIST split and return (images, labels)

        train : bool, True for the training split, False for the test split
        img_size : int, side passed to resize (skipped when equal to 28)

        returns : (float tensor of stacked images, list of int labels)
    '''
    mnist = datasets.MNIST('data', train=train, download=True)
    # read the targets directly: indexing the dataset (mnist[i][1]) would
    # decode every image a second time just to obtain its label
    labels = mnist.targets.tolist()

    pics = []
    for pic in mnist.data:
        pic = to_pil_image(pic)
        if img_size != 28:
            pic = resize(pic, img_size)  # Resize image if needed
        pics.append(to_tensor(pic))      # Tensor conversion normalizes in [0,1]
    return torch.stack(pics), labels


def load_mnist(img_size=32):
    ''' return data and labels for train and test mnist dataset

        img_size : int, target image side (images are resized unless it is 28)

        returns : ((data_train, labels_train), (data_test, labels_test))
                  data_* are stacked image tensors, labels_* are lists of ints
    '''
    train_split = _load_mnist_split(True, img_size)
    test_split = _load_mnist_split(False, img_size)
    return train_split, test_split

def query(datafull, nb, bias=0, fav=0):
    ''' draw -nb distinct random samples from -datafull

        datafull : (images, labels) couple
        nb : number of samples to draw
        bias : 0 for a uniform draw, otherwise max number of redraws
               allowed to one_query to obtain the favorite label
        fav : favorite label (only used when bias != 0)

        returns : (tensor of shape (nb, 1, h, w), list of nb labels)
    '''
    data, labels = datafull
    available = list(range(len(data)))
    height, width = data[0][0].shape
    batch = torch.empty(nb, 1, height, width)

    if bias == 0:
        picked = random.sample(available, nb)  # uniform draw, no repetition
    else:
        picked = []
        for _ in range(nb):
            choice = one_query(labels, available, bias, fav)
            picked.append(choice)
            available.remove(choice)  # an index can be drawn only once

    for slot, source in enumerate(picked):  # filling our query
        batch[slot] = data[source]
    picked_labels = [labels[source] for source in picked]
    return batch, picked_labels

def one_query(labels, idxs, redraws, fav):
    ''' labels : list of labels
        idxs : list of available indexes
        draws an index with a favorite label choice 
        fav : favorite label
        redraws : max nb of random redraws while fav not found
    '''
    lab = -1 
    while lab != fav and redraws >= 0:
        idx = idxs[random.randint(0, len(idxs)-1)]
        lab = labels[idx]
        redraws -= 1
    return idx

def list_to_longtens(l):
    ''' change a list of labels into a 1-D torch.long tensor

        l : list of integer labels

        returns : tensor of dtype torch.long and length len(l)
    '''
    # a single conversion instead of filling the tensor element by element
    return torch.tensor(l, dtype=torch.long)

def swap(l, n, m):
    ''' return a copy of list -l where values n and m are exchanged '''
    swapped = []
    for value in l:
        if value == n:
            swapped.append(m)
        elif value == m:
            swapped.append(n)
        else:
            swapped.append(value)
    return swapped


def distribute_data_rd(datafull, distrib, fav_lab=(0,0), 
                       dish=False, dish_lab=0, gpu=True): 
    '''draw random data for N nodes following distrib
        datafull : (raw data, raw labels) couple
        distrib : int list, nb of data points for each node
        fav_lab : (favorite label, strength of preference (int))
        dish : boolean, if nodes are dishonest 
        dish_lab : 0 to 4, labelisation method of dishonest nodes
            0 : random labels
            1 : all labels set to zero
            2 : labels 1 and 7 swapped
            3 : two random labels swapped (maybe the same one)
            4 : every label shifted by +1 (modulo 10)
        gpu : boolean, if tensors must be moved to the GPU

        returns : (list of batches of images, list of batches of labels)
        raises : ValueError if dish is True and dish_lab is not in 0..4
    '''
    # FORCING1/FORCING2 hold the swap shared by all nodes when FORCE is set.
    # NOTE(review): they must exist at module level (with -1 meaning "not
    # drawn yet") - they are not defined in the visible part of the notebook.
    global FORCING1
    global FORCING2
    fav, strength = fav_lab
    data_dist = []      # one batch of images per node
    labels_dist = []    # one batch of labels per node

    for number in distrib: # for each node
        d, l = query(datafull, number, strength, fav)
        imgs = torch.FloatTensor(d)
        data_dist.append(imgs.cuda() if gpu else imgs)

        if not dish:            # honest node
            tens = list_to_longtens(l) # needed for CrossEntropy later
        elif dish_lab == 0:     # random
            tens = torch.randint(10, (number,), dtype=torch.long)
        elif dish_lab == 1:     # zeros
            tens = torch.zeros(number, dtype=torch.long)
        elif dish_lab == 2:     # swap 1-7
            tens = list_to_longtens(swap(l, 1, 7))
        elif dish_lab == 3:     # swap 2 random (maybe same)
            if FORCE: # to force same swap multiple times
                if FORCING1 == -1:
                    FORCING1, FORCING2 = random.randint(0,9), random.randint(0,9)
                first, second = FORCING1, FORCING2
            else:
                first, second = random.randint(0,9), random.randint(0,9)
            tens = list_to_longtens(swap(l, first, second))
        elif dish_lab == 4:     # label +1
            tens = (list_to_longtens(l) + 1) % 10
        else:
            # previously fell through and failed later on an unbound variable
            raise ValueError(
                "dish_lab must be in 0..4, got {!r}".format(dish_lab))

        labels_dist.append(tens.cuda() if gpu else tens)

    return data_dist, labels_dist

def zipping(dir_name):
    '''zip a local folder into -dir_name.zip (same parent directory)

        Only the direct entries of the folder are written; entries whose
        name starts with "." are skipped.
    '''
    # the context manager closes the archive even if a write fails
    with ZipFile(dir_name + '.zip', mode='w', compression=ZIP_DEFLATED) as archive:
        for fil in os.listdir(dir_name):
            if not fil.startswith("."):
                archive.write(dir_name + '/' + fil)

## get data

In [None]:
# downloading data
# Load MNIST once per kernel session: the guard skips the (slow) loading when
# the cell is re-run and `train` already exists.
# `test_gpu` (CUDA copy of the test images and labels) is only created when a
# GPU is available; get_flower(gpu=True) needs it.
if 'train' not in globals(): # to avoid loading data every time
    train, test = load_mnist()
    if torch.cuda.is_available():
        test_gpu = torch.tensor(test[0]).cuda(), torch.tensor(test[1]).cuda()

# MODEL

In [12]:
#model structure

def get_base_classifier(gpu=True):
    ''' build the linear baseline classifier

        The model flattens its input and applies a single linear layer
        from 1024 features to 10 outputs.
        gpu : boolean, if the model must be moved to the GPU
    '''
    layers = [nn.Flatten(), nn.Linear(1024, 10)]
    model = nn.Sequential(*layers)
    return model.cuda() if gpu else model

class classifier(nn.Module):
    '''CNN model : 2 x (3x3 convolution, ReLU, 2x2 max-pooling), then a
       linear layer producing 10 outputs.
       The linear layer expects 32 * 6 * 6 features, which is what the two
       convolution/pooling stages produce for 1-channel 32x32 inputs.
    '''
    def __init__(self):
        super().__init__()

        # first convolution stage : 1 -> 16 channels
        self.cnn1 = nn.Conv2d(1, 16, kernel_size=3, stride=1, padding=0)
        self.relu1 = nn.ReLU()
        self.maxpool1 = nn.MaxPool2d(kernel_size=2)
        # second convolution stage : 16 -> 32 channels
        self.cnn2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=0)
        self.relu2 = nn.ReLU()
        self.maxpool2 = nn.MaxPool2d(kernel_size=2)
        # dense output layer
        self.fc1 = nn.Linear(32 * 6 * 6, 10)

    def forward(self, x):
        ''' return the 10 output scores of each image of batch -x '''
        features = self.maxpool1(self.relu1(self.cnn1(x)))
        features = self.maxpool2(self.relu2(self.cnn2(features)))
        flat = features.view(features.size(0), -1)  # one row per image
        return self.fc1(flat)
def get_conv_classifier(gpu=True):
    ''' build the convolutional classifier (on GPU if -gpu) '''
    model = classifier()
    return model.cuda() if gpu else model

# Model factories indexed by the value of the "NN" parameter; Flower calls
# the selected factory with its gpu flag to build the global and local models.
MODELS = {"base": get_base_classifier, "conv": get_conv_classifier}

# TRAINING STRUCTURE

## Losses

In [13]:
#loss and scoring functions 

def local_loss(model_loc, x, y):  
    ''' classification loss : cross entropy (default mean reduction)
        between the outputs of -model_loc on -x and the labels -y
    '''
    return F.cross_entropy(model_loc(x), y)

def models_dist(model_loc, model_glob, pow=(1,1)):  
    ''' distance between local and global parameters (l1 by default),
        meant to be multiplied by w_n 
        pow : (internal power, external power)
              result = (sum |theta - rho| ** internal) ** external
    '''
    inner, outer = pow
    total = 0
    for theta, rho in zip(model_loc.parameters(), model_glob.parameters()):
        total = total + ((theta - rho) ** inner).abs().sum()
    return total ** outer

def model_norm(model_glob, pow=(2,1)): 
    ''' norm of the global parameters (l2 squared by default),
        meant to be multiplied by w_0 
        pow : (internal power, external power)
              result = (sum |param ** internal|) ** external
    '''
    inner, outer = pow
    total = 0
    for param in model_glob.parameters():
        total = total + (param ** inner).abs().sum()
    return total ** outer

def round_loss(tens, dec=0): 
    ''' round a scalar to -dec decimals
        tens : python int/float, or any object exposing .item()
               (e.g. a scalar tensor)
    '''
    is_plain_number = type(tens) in (int, float)
    value = tens if is_plain_number else tens.item()
    return round(value, dec)

def tens_count(tens, val):
    ''' count how many entries of tensor -tens are equal to -val '''
    # entries different from val are exactly the non-zeros of (tens - val)
    nb_different = torch.count_nonzero(tens - val).item()
    return len(tens) - nb_different

def score(model, datafull):
    ''' return the accuracy of -model on -datafull

        model : callable mapping a batch of images to class scores
        datafull : (images, ground truths), ground truths being a list of
                   ints or a tensor

        returns : float, fraction of correctly classified images
    '''
    inputs, targets = datafull[0], datafull[1]
    # evaluation only : no autograd graph is needed for the forward pass
    with torch.no_grad():
        predictions = torch.max(model(inputs), 1)[1]
        targets = torch.as_tensor(targets, device=predictions.device)
        # one vectorised comparison instead of one python-level test per sample
        nb_correct = (predictions == targets).sum().item()
    return nb_correct / len(inputs)

## Flower

### flower class

In [14]:
# nodes repartition

class Flower():
    ''' Training structure including local models and general one 
        Allowing to add and remove nodes at will
        .pop        : population count per node type
        .add_nodes  : append new nodes (data, model, optimizer)
        .rem_nodes  : remove a range of nodes
        .train      : run the training loop
        .display    : print accuracies of one node or of the general model
        .check      : consistency test of the internal lists
    '''

    def __init__(self, test, gpu=True, **kwargs):
        ''' test : test data couple (imgs,labels)
            gpu : boolean, if models and data live on the GPU
            kwargs (all required) :
                w0 : regularisation strength
                opt : optimizer class
                lr_node, lr_gen : learning rates of local / general models
                gen_freq : generalisation frequency (>=1)
                NN : key of MODELS, architecture of every model
        '''
        self.d_test = test
        self.w0 = kwargs["w0"]
        self.gpu = gpu

        self.opt = kwargs["opt"]
        self.lr_node = kwargs["lr_node"]
        self.lr_gen = kwargs["lr_gen"]
        self.gen_freq = kwargs["gen_freq"]  # generalisation frequency (>=1)

        self.get_classifier = MODELS[kwargs["NN"]]
        self.general_model = self.get_classifier(gpu)
        self.init_model = deepcopy(self.general_model)  # kept to measure drift
        self.last_grad = None   # gradients of general model at previous epoch
        self.opt_gen = self.opt(self.general_model.parameters(), lr=self.lr_gen)
        self.pow_gen = (1,1)  # choice of norms for Licchavi loss 
        self.pow_reg = (2,1)  # (internal power, external power)
        # per-node lists, all of length nb_nodes
        self.data = []
        self.labels = [] 
        self.typ = []
        self.models = []
        self.weights = []
        self.age = []
        self.opt_nodes = []
        self.nb_nodes = 0
        # node type -> labelisation method of distribute_data_rd (-1 : honest)
        self.dic = {"honest" : -1, "trolls" : 0, "zeros" : 1, 
                    "one_evil" : 2, "strats" : 3, "jokers" : 4, "byzantine" : -1}
        # one list per metric, same order as METRICS :
        # fit, gen, reg, acc, l2_dist, l2_norm, grad_sp, grad_norm
        self.history = ([], [], [], [], [], [], [], []) 
        self.localtest = ([], []) # (which to pick for each node, list of (data,labels) pairs)
        self.size = nb_params(self.general_model) / 10_000

    # ------------ population methods --------------------
    def set_localtest(self, datafull, size, nodes, fav_lab=(0,0), typ="honest"):
        ''' create a local test data set for some nodes
            datafull : source data
            size : size of test sample
            nodes : list of nodes which use this data           
            fav_lab : (label, strength)
            typ : node type, selects how labels are altered
        '''
        dish_id = self.dic[typ]
        dish = (dish_id != -1) # boolean for dishonesty
        dt, lb = distribute_data_rd(datafull, [size], fav_lab,
                                    dish, dish_lab=dish_id, gpu=self.gpu)
        self.localtest[1].append((dt[0], lb[0]))
        test_id = len(self.localtest[1]) - 1  # position of the new test set
        for n in nodes:
            self.localtest[0][n] = test_id

    def add_nodes(self, datafull, pop, typ, fav_lab=(0,0), verb=1, **kwargs):
        ''' add nodes to the Flower 
            datafull : data to put on node (sampled from it)
            pop : (nb of nodes, size of nodes)
            typ : type of nodes (str keywords)
            fav_lab : (favorite label, strength)
            verb : if truthy, print a summary
            kwargs : must contain "w", weight of the new nodes
        '''
        w = kwargs["w"] # required keyword (KeyError if missing)
        nb, size = pop
        dish_id = self.dic[typ]
        dish = (dish_id != -1) # boolean for dishonesty
        dt, lb = distribute_data_rd(datafull, [size] * nb, fav_lab,
                                    dish, dish_lab=dish_id, gpu=self.gpu)
        self.data += dt
        self.labels += lb
        self.typ += [typ] * nb

        self.models += [self.get_classifier(self.gpu) for i in range(nb)]
        self.weights += [w] * nb
        self.age += [0] * nb
        self.localtest[0].extend([-1] * nb)  # -1 : no local test data yet
        self.nb_nodes += nb
        self.opt_nodes += [self.opt(self.models[n].parameters(), lr=self.lr_node) 
            for n in range(self.nb_nodes - nb, self.nb_nodes) 
            ]
        if verb:
            print("Added {} {} nodes of {} data points".format(nb, typ, size))
            print("Total number of nodes : {}".format(self.nb_nodes))

    def rem_nodes(self, first, last, verb=1):
        ''' remove nodes of indexes -first (included) to -last (excluded)
            Nothing is removed (a message is printed) when the range is not
            included in [0, nb_nodes] or is reversed.
        '''
        nb = last - first
        if last > self.nb_nodes:
            print("-last is out of range, remove canceled")
        elif first < 0 or nb < 0:
            # such ranges used to desynchronise nb_nodes from the lists
            print("invalid range, remove canceled")
        else:
            del self.data[first : last]
            del self.labels[first : last] 
            del self.typ[first : last]
            del self.models[first : last]
            del self.weights[first : last]
            del self.age[first : last]
            del self.opt_nodes[first : last]
            del self.localtest[0][first : last]
            self.nb_nodes -= nb
            if verb: print("Removed {} nodes".format(nb))
        
    def hm(self, ty):
        ''' count nb of nodes of this type '''
        return self.typ.count(ty)
    
    def pop(self):
        ''' return dictionary {node type : nb of nodes of this type} '''
        return {ty: self.hm(ty) for ty in self.dic.keys()}

    # ------------- scoring methods -----------
    def score_glob(self, datafull): 
        ''' return accuracy of general model provided images and GTs '''
        return score(self.general_model, datafull)
    
    def test_loc(self, node):
        ''' score of node on local test data (None if it has none) '''
        id_data = self.localtest[0][node]
        if id_data == -1:
            return None
        return score(self.models[node], self.localtest[1][id_data])

    def test_full(self, node):
        ''' score of node on global test data '''
        return score(self.models[node], self.d_test)

    def test_train(self, node):
        ''' score of node on its train data '''
        return score(self.models[node], (self.data[node], self.labels[node]))

    def display(self, node):
        ''' display accuracy for selected node
            node = -1 for global model
        '''
        if node == -1: # global model
            print("global model")
            print("accuracy on test data :", 
                  self.score_glob(self.d_test))
        else: # we asked for a node
            loc_train = self.test_train(node)
            loc_test = self.test_loc(node)
            full_test = self.test_full(node)
            print("node number :", node, ", dataset size :",
                len(self.labels[node]), ", type :", self.typ[node], 
                ", age :", self.age[node])
            print("accuracy on local train data :", loc_train)
            print("accuracy on local test data :", loc_test)
            print("accuracy on global test data :", full_test)
            repart = {str(k) : tens_count(self.labels[node], k) 
                for k in range(10)}
            print("labels repartition :", repart)
    
    # ---------- methods for training ------------

    def _set_lr(self):
        '''set learning rates of optimizers according to Flower setting'''
        for n in range(self.nb_nodes):  # updating lr in optimizers
            self.opt_nodes[n].param_groups[0]['lr'] = self.lr_node
        self.opt_gen.param_groups[0]['lr'] = self.lr_gen

    def _zero_opt(self):
        '''reset gradients of all models'''
        for n in range(self.nb_nodes):
            self.opt_nodes[n].zero_grad()      
        self.opt_gen.zero_grad()

    def _update_hist(self, epoch, test_freq, fit, gen, reg, verb=1):
        ''' append the metrics of this epoch to the history
            epoch : epoch number (starting at 1)
            test_freq : accuracy is computed every -test_freq epochs and
                        recorded -test_freq times at once
            fit, gen, reg : loss terms of the epoch
        '''
        if epoch  % test_freq == 0:   # printing accuracy on test data
            acc = self.score_glob(self.d_test)
            if verb: print("TEST ACCURACY : ", acc)
            for i in range(test_freq):
                self.history[3].append(acc) 
        self.history[0].append(round_loss(fit))
        self.history[1].append(round_loss(gen))
        self.history[2].append(round_loss(reg))

        # l2 distance to the initial model and l2 norm of the general model
        dist = models_dist(self.init_model, self.general_model, pow=(2,0.5)) 
        norm = model_norm(self.general_model, pow=(2,0.5))
        self.history[4].append(round_loss(dist, 1))
        self.history[5].append(round_loss(norm, 1))
        grad_gen = extract_grad(self.general_model)
        if epoch > 1: # no last model for first epoch
            scal_grad = sp(self.last_grad, grad_gen)
            self.history[6].append(scal_grad)
        else:
            self.history[6].append(0) # default value for first epoch
        self.last_grad = deepcopy(grad_gen) 
        grad_norm = sp(grad_gen, grad_gen)  # use sqrt ?
        self.history[7].append(grad_norm)

    def _old(self, years):
        ''' increment age (after training) '''
        for i in range(self.nb_nodes):
            self.age[i] += years

    def _counters(self, c_gen, c_fit):
        ''' update internal training counters
            returns : (fit_step, c_gen, c_fit), fit_step being True when the
                      next step must update the local models
        '''
        fit_step = (c_fit >= c_gen) 
        if fit_step:
            c_gen += self.gen_freq
        else:
            c_fit += 1 
        return fit_step, c_gen, c_fit

    def _do_step(self, fit_step):
        '''step for appropriate optimizer(s)'''
        if fit_step:       # updating local or global alternatively
            for n in range(self.nb_nodes): 
                self.opt_nodes[n].step()      
        else:
            self.opt_gen.step()  

    def _print_losses(self, tot, fit, gen, reg):
        '''print losses'''
        print("total loss : ", tot) 
        print("fitting : ", round_loss(fit),
                ', generalisation : ', round_loss(gen),
                ', regularisation : ', round_loss(reg))

    # ====================  TRAINING ================== 

    def train(self, nb_epochs=None, test_freq=1, verb=1):   
        ''' training loop
            nb_epochs : number of epochs; None falls back on the module-level
                        EPOCHS (NOTE(review): not defined in the visible part
                        of the notebook - confirm it exists)
            test_freq : test accuracy is computed every -test_freq epochs
            verb : 0 silent, 1 per-epoch prints, >= 2 per-step prints

            returns : self.history
        '''
        nb_epochs = EPOCHS if nb_epochs is None else nb_epochs
        time_train = time()
        self._set_lr()

        # initialisation to avoid undefined variables at epoch 1
        loss, fit_loss, gen_loss, reg_loss = 0, 0, 0, 0
        diff = 0
        c_fit, c_gen = 0, 0

        fit_scale = 20 / self.nb_nodes
        gen_scale = 1 / self.nb_nodes / self.size
        reg_scale = self.w0 / self.size

        reg_loss = reg_scale * model_norm(self.general_model, self.pow_reg)  

        # training loop 
        nb_steps = self.gen_freq + 1
        for epoch in range(1, nb_epochs + 1):
            if verb: print("\nepoch {}/{}".format(epoch, nb_epochs))
            time_ep = time()

            for step in range(1, nb_steps + 1):
                fit_step, c_gen, c_fit = self._counters(c_gen, c_fit)
                if verb >= 2: 
                    txt = "(fit)" if fit_step else "(gen)" 
                    print("step :", step, '/', nb_steps, txt)
                self._zero_opt() # resetting gradients


                #----------------    Licchavi loss  -------------------------
                 # only first 2 terms of loss updated
                if fit_step:
                    fit_loss, gen_loss, diff = 0, 0, 0
                    for n in range(self.nb_nodes):   # for each node
                        if self.typ[n] == "byzantine":
                            # byzantine nodes maximise their fitting loss;
                            # diff restores the true value for the displays
                            fit = local_loss(self.models[n], 
                                             self.data[n], self.labels[n])
                            fit_loss -= fit
                            diff += 2 * fit # dirty trick CHANGE
                        else:
                            fit_loss += local_loss(self.models[n], 
                                                self.data[n], self.labels[n])
                        g = models_dist(self.models[n], 
                                        self.general_model, self.pow_gen)
                        gen_loss +=  self.weights[n] * g  # generalisation term
                    fit_loss *= fit_scale
                    gen_loss *= gen_scale
                    loss = fit_loss + gen_loss 
                          
                # only last 2 terms of loss updated 
                else:        
                    gen_loss, reg_loss = 0, 0
                    for n in range(self.nb_nodes):   # for each node
                        g = models_dist(self.models[n], 
                                        self.general_model, self.pow_gen)
                        gen_loss += self.weights[n] * g  # generalisation term    
                    reg_loss = model_norm(self.general_model, self.pow_reg) 
                    gen_loss *= gen_scale
                    reg_loss *= reg_scale
                    loss = gen_loss + reg_loss

                total_out = round_loss(fit_loss + diff
                    + gen_loss + reg_loss)
                if verb >= 2:
                    self._print_losses(total_out, fit_loss + diff,
                                       gen_loss, reg_loss)
                # Gradient descent 
                loss.backward() 
                self._do_step(fit_step)   
 
            if verb: print("epoch time :", round(time() - time_ep, 2)) 
            self._update_hist(epoch, test_freq, fit_loss, gen_loss, reg_loss, verb)
            self._old(1)  # aging all nodes
             
        # ----------------- end of training -------------------------------  
        # to maintain same history length : epochs after the last test have
        # no accuracy yet, repeat the last known one (the former code read an
        # undefined local variable here)
        missing = nb_epochs % test_freq
        if missing:
            if self.history[3]:
                last_acc = self.history[3][-1]
            else:   # no test was ever run
                last_acc = self.score_glob(self.d_test)
            self.history[3].extend([last_acc] * missing)
        print("training time :", round(time() - time_train, 2)) 
        return self.history


    # ------------ to check for problems --------------------------
    def check(self):
        ''' perform some tests on internal parameters adequation
            (prints the verdict, works on an empty Flower too)
        '''
        # population check
        b1 =  (self.nb_nodes == len(self.data) == len(self.labels) 
            == len(self.typ) == len(self.models) == len(self.opt_nodes) 
            == len(self.weights) == len(self.age) == len(self.localtest[0]))
        # history check
        oldest = max(self.age, default=0)
        b2 = True
        for l in self.history:
            b2 = b2 and (len(l) == len(self.history[0]) >= oldest)
        # local test data check
        b3 = (max(self.localtest[0], default=-1) + 1 <= len(self.localtest[1]) )
        if (b1 and b2 and b3):
            print("No Problem")
        else:
            print("OULALA non ça va pas là")

### flower utility

In [15]:
def get_flower(gpu=True, **kwargs):
    ''' build a Flower with the matching test data :
        test_gpu when -gpu is truthy, test otherwise
        kwargs are forwarded to the Flower constructor
    '''
    test_data = test_gpu if gpu else test
    return Flower(test_data, gpu=gpu, **kwargs)

# def grad_sp(m1, m2):
#     ''' scalar product of gradients of 2 models '''
#     s = 0
#     for p1, p2 in zip(m1.parameters(), m2.parameters()):
#         s += (p1.grad * p2.grad).sum()
#     return s

def extract_grad(model):
    ''' return the list of the .grad attributes of the parameters of -model
        (same order as model.parameters(), no copy)
    '''
    gradients = []
    for param in model.parameters():
        gradients.append(param.grad)
    return gradients

def sp(l_grad1, l_grad2):
    ''' scalar product of 2 lists of gradients, rounded to 4 decimals
        (pairs of tensors are multiplied element-wise and summed)
    '''
    total = sum((first * second).sum()
                for first, second in zip(l_grad1, l_grad2))
    return round_loss(total, 4)

def nb_params(model):
    ''' return the total number of scalar parameters of -model '''
    count = 0
    for param in model.parameters():
        count += param.numel()
    return count

# GETTING PLOTS

## Plotting utilities

In [16]:
def seedall(s):
    ''' seed all sources of randomness (torch, random, numpy)

        s : int, seed
            s >= 0 : every generator is seeded with -s and cudnn is set to
                     its deterministic mode
            s < 0  : non reproducible mode, generators are re-seeded from
                     fresh entropy and cudnn benchmark mode is enabled
                     (np.random.seed rejects negative seeds, so forwarding
                     -s as is raised a ValueError)
    '''
    reproducible = (s >= 0)
    if reproducible:
        torch.manual_seed(s)
        random.seed(s)
        np.random.seed(s)
    else:
        torch.seed()
        random.seed()
        np.random.seed()
    torch.backends.cudnn.deterministic = reproducible
    torch.backends.cudnn.benchmark     = not reproducible
    print("\nSeeded all to", s)

def replace_dir(path):
    ''' make -path an empty directory : whatever exists there is removed
        first, then the directory (and missing parents) is created
    '''
    already_there = os.path.exists(path)
    if already_there:
        shutil.rmtree(path)
    os.makedirs(path)

def get_style():
    ''' generator cycling over 4 line styles for plots
        (10000 values in total, then it is exhausted)
    '''
    patterns = ("-", "-.", ":", "--")
    for _ in range(2500):  # 2500 rounds of 4 styles
        yield from patterns

def get_color():
    ''' generator cycling over 4 line colors for plots
        (10000 values in total, then it is exhausted)
    '''
    palette = ("red", "green", "blue", "grey")
    for _ in range(2500):  # 2500 rounds of 4 colors
        yield from palette

# Shared generators : every next(STYLES) / next(COLORS) made by the plotting
# functions advances them, so successive curves get different looks.
STYLES = get_style() # generator for looping styles
COLORS = get_color() # generator for looping colors

def title_save(title=None, path=None, suff=".png"):
    ''' optionally title the current plot and save it
        title : str or None, nothing is set when None
        path : str or None, the figure is saved to path + suff unless None
    '''
    wants_title = title is not None
    wants_file = path is not None
    if wants_title:
        plt.title(title)
    if wants_file:
        plt.savefig(path + suff)

def legendize(y):
    ''' label both axes of the current plot ("Epochs" on x, -y on y)
        and show the legend
    '''
    for set_label, text in ((plt.xlabel, "Epochs"), (plt.ylabel, y)):
        set_label(text)
    plt.legend()

def clean_dic(dic):
    ''' replace some values by more readable ones
        If -dic has an "opt" key, return a deep copy where the optimizer
        class is replaced by "Adam", "SGD" or None (any other optimizer);
        otherwise return -dic itself.
    '''
    if "opt" not in dic:
        return dic
    readable = deepcopy(dic)
    optimizer = readable["opt"]
    if optimizer == optim.Adam:
        name = "Adam"
    elif optimizer == optim.SGD:
        name = "SGD"
    else:
        name = None
    readable["opt"] = name
    return readable

def get_title(conf, ppl=4):
    ''' convert a dictionary into a "key: val" string of appropriate shape
        ppl : parameters per line (a line break follows every ppl-th one)
    '''
    pieces = []
    for rank, (key, val) in enumerate(clean_dic(conf).items(), start=1):
        separator = " \n" if (rank % ppl) == 0 else ', '
        pieces.append("{}: {}".format(key,val) + separator)
    return "".join(pieces)[:-2]  # drop the trailing 2-character separator

## Plotting from history

In [17]:
# functions to display training history 

def means_bounds(arr):
    ''' statistics along the first axis of -arr
        returns : (means, means - variance, means + variance)
    '''
    center = np.mean(arr, axis=0)
    spread = np.var(arr, axis=0)
    return center, center - spread, center + spread


# ----------- to display multiple accuracy curves on same plot -----------
def add_acc_var(arr, label):
    ''' add to the current plot the mean accuracy curve of -arr
        (metric of index 3) with its variance band
        arr : array indexed as [run, metric, epoch]
        label : legend label of the curve
    '''
    means, low, up = means_bounds(arr[:, 3, :])
    epochs = range(1, len(means) + 1)
    line_style = next(STYLES)
    plt.plot(epochs, means, label=label, linestyle=line_style)
    plt.fill_between(epochs, up, low, alpha=0.4)

def plot_runs_acc(l_runs, title=None, path=None, **kwargs):
    ''' plot several accuracy curves (one per run of -l_runs) on one graph
        kwargs : explored parameters, given to get_possibilities to build
                 the legend labels
        title, path : forwarded to title_save
    '''
    runs = np.asarray(l_runs)
    legends = get_possibilities(**kwargs)
    # adding one curve for each parameter combination (run)
    for run, legend in zip(runs, legends):
        add_acc_var(run, legend)
    plt.ylim([0,1])
    grid_common = {"axis": 'y'}
    plt.grid(True, which='major', linewidth=1, alpha=1, **grid_common)
    plt.minorticks_on()
    plt.grid(True, which='minor', linewidth=0.8, alpha=0.8, **grid_common)
    legendize("Test Accuracy")
    title_save(title, path, suff=".png")
    plt.show()



# ------------- utility for what follows -------------------------
def plot_var(l_hist, l_idx):
    ''' add to the current plot, for each metric index of -l_idx, the mean
        curve over the histories of -l_hist and its variance band
    '''
    histories = np.asarray(l_hist)
    epochs = range(1, histories.shape[2] + 1)
    for idx in l_idx:
        mean, lower, upper = means_bounds(histories[:, idx, :])
        style = next(STYLES)
        color = next(COLORS)
        plt.plot(epochs, mean, label=METRICS[idx]["lab"],
                 linestyle=style, color=color)
        plt.fill_between(epochs, upper, lower, alpha=INTENS, color=color)

def plotfull_var(l_hist, l_idx, title=None, path=None, show=True):
    ''' plot metrics asked in -l_idx and save if -path provided
        The first index of -l_idx selects the y-axis label and the suffix
        of the saved file.
    '''
    plot_var(l_hist, l_idx)
    main_metric = METRICS[l_idx[0]]
    legendize(main_metric["ord"])
    title_save(title, path, suff=" {}.png".format(main_metric["f_name"]))
    if not show:
        return
    plt.show()

# ------- groups of metrics on a same plot -----------
def loss_var(l_hist, title=None, path=None):
    ''' plot the 3 loss terms (metrics 0, 1, 2) with variance
        from a list of histories
    '''
    loss_indexes = [0,1,2]
    plotfull_var(l_hist, loss_indexes, title=title, path=path)

def acc_var(l_hist, title=None, path=None):
    ''' plot accuracy (metric 3) with variance from a list of histories,
        on a [0, 1] y-axis with horizontal grid lines
    '''
    plt.ylim([0,1])
    grid_common = {"axis": 'y'}
    plt.grid(True, which='major', linewidth=1, alpha=1, **grid_common)
    plt.minorticks_on()
    plt.grid(True, which='minor', linewidth=0.8, alpha=0.8, **grid_common)
    plotfull_var(l_hist, [3], title=title, path=path)

def l2_var(l_hist, title=None, path=None):
    ''' plot the l2 metrics of the general model (metrics 4 and 5 :
        l2_dist and l2_norm) from a list of histories
    '''
    l2_indexes = [4,5]
    plotfull_var(l_hist, l2_indexes, title=title, path=path)

def gradsp_var(l_hist, title=None, path=None):
    ''' plot the gradient metrics (metrics 6 and 7 : scalar product of
        gradients between 2 consecutive epochs, and gradient norm)
        from a list of histories
    '''
    grad_indexes = [6,7]
    plotfull_var(l_hist, grad_indexes, title=title, path=path)

# plotting all we have
def plot_metrics(l_hist, title=None, path=None):
    ''' Produce every available figure from a list of historys.

        The figures are drawn in this order : accuracy, losses, l2 norms,
        scalar products of gradients. -title and -path are handed to each
        plotting function unchanged.
    '''
    plotters = (acc_var, loss_var, l2_var, gradsp_var)
    for draw in plotters:
        draw(l_hist, title, path)

## Running, plotting, saving

### utilities

In [20]:
def adapt(obj):
    ''' Generator over the value(s) of a parameter.

        -obj is either one value or an iterable of values.
        An object exposing __iter__ is walked through element by element,
        except a str, which is yielded whole like any non-iterable value.
    '''
    several = hasattr(obj, '__iter__') and type(obj) != str
    if not several:
        yield obj
        return
    for value in obj:
        yield value

def is_end(it, dist=0):
    ''' Tell whether iterator -it runs out within its next -dist + 1
        elements (with the default dist=0 : whether it is empty).

        The test advances a deep copy, so -it itself is not consumed.
    '''
    probe = deepcopy(it)
    for _ in range(dist + 1):
        try:
            next(probe)
        except StopIteration:
            return True
    return False

def explore(dic):
    ''' Expand a dictionary of parameters into the grid of all combinations.

        dic : maps each parameter to one value or to an iterable of
              candidate values (a str counts as one value, see adapt)
        Return : list of dictionaries, one per combination, each giving
                 exactly one value to every key of -dic. Keys keep the
                 order of -dic and the first key varies slowest.
                 An empty -dic gives [{}]; a parameter with an empty
                 iterable of values gives [].
    '''
    from itertools import product  # local import, only needed here

    keys = list(dic)
    # every parameter is materialised once, so a one-shot iterable
    # (e.g. a generator) is still available for every combination
    choices = [tuple(adapt(dic[key])) for key in keys]
    return [dict(zip(keys, combo)) for combo in product(*choices)]

def add_defaults(config):
    ''' Complete -config with the DEFAULTS value of every parameter it
        does not specify.

        Return a new dictionary built from a deep copy of DEFAULTS, so
        neither DEFAULTS nor -config is modified.
    '''
    completed = deepcopy(DEFAULTS)
    completed.update(config.items())
    return completed

def my_confs(**kwargs):
    ''' Generator over the full configurations of the parameter grid.

        Each combination returned by explore(kwargs) is completed with
        the default values through add_defaults before being yielded.
    '''
    yield from map(add_defaults, explore(kwargs))

# FUSE THE 2 FUNCTIONS ?
def get_possibilities(**kwargs):
    ''' Build one legend per parameter combination.

        Only the parameters given with more than one value appear in a
        legend; the text is produced by get_title.
        Return : list of legends, in the order of explore(kwargs)
    '''
    combos = explore(kwargs)
    # parameters that actually vary across the grid
    varying = [name for name, values in kwargs.items()
               if len(list(adapt(values))) > 1]
    return [get_title({name: combo[name] for name in varying})
            for combo in combos]

def get_constants(**kwargs):
    ''' Build the text describing the parameters that stay constant.

        For every full configuration, get_title is applied to the DEFAULTS
        parameters that were not given with more than one value.
        Return : list with one text per configuration. The selected
                 parameters do not vary, so callers only need one entry.
    '''
    confs = my_confs(**kwargs)
    # parameters that vary across the grid, hence excluded from the text
    varying = [name for name, values in kwargs.items()
               if len(list(adapt(values))) > 1]
    return [get_title({k: conf[k] for k in DEFAULTS.keys() if k not in varying})
            for conf in confs]

def legend_to_name(legend):
    ''' Turn a legend text into a text usable in a filename.

        ": " becomes "_", a line break becomes a space and commas are
        removed, in that order. -legend is left untouched.
    '''
    substitutions = ((': ', '_'), ('\n', ' '), (',', ''))
    name = legend
    for old, new in substitutions:
        name = name.replace(old, new)
    return name

### core

In [35]:
def get_custom_flower(verb=1, gpu=True, **kwargs):
    ''' Create a flower with get_flower and add its honest and dishonest
        nodes.

        verb : verbosity, forwarded to add_nodes
        gpu  : forwarded to get_flower
        **kwargs : full configuration, forwarded to get_flower and add_nodes.
                   Read here : "nbn" (number of nodes), "nbd" (total data),
                   "fracdish" (fraction of dishonest nodes), "typ_dish"
                   (type of the dishonest nodes), "heter" (heterogeneity)
        Return : the flower

        With a falsy "heter", all honest nodes are added at once, then all
        dishonest ones. Otherwise one group of each kind is added for each
        of the 10 labels, with the pair (label, heter) passed to add_nodes.
    '''
    nb_nodes = kwargs["nbn"]
    pts_per_node = kwargs["nbd"] // nb_nodes
    # NOTE(review): int() truncates, so a product that floats store just
    # below an integer (0.29 * 100 is 28.99...) gives one node less
    nb_dishonest = int(kwargs["fracdish"] * nb_nodes)
    nb_honest = nb_nodes - nb_dishonest
    dishonest_type = kwargs["typ_dish"]
    heter = kwargs["heter"]
    flow = get_flower(gpu=gpu, **kwargs)

    if not heter:
        flow.add_nodes(train, (nb_honest, pts_per_node), "honest",
                       verb=verb, **kwargs)
        flow.add_nodes(train, (nb_dishonest, pts_per_node), dishonest_type,
                       verb=verb, **kwargs)
        return flow

    # NOTE(review): the integer divisions drop the remainder, so counts
    # that are not multiples of 10 give fewer nodes than "nbn"
    honest_per_label = nb_honest // 10
    dishonest_per_label = nb_dishonest // 10
    for lab in range(10):
        flow.add_nodes(train, (honest_per_label, pts_per_node), "honest",
                       (lab, heter), verb=verb, **kwargs)
        flow.add_nodes(train, (dishonest_per_label, pts_per_node),
                       dishonest_type, (lab, heter), verb=verb, **kwargs)
    # no local test set is configured here
    return flow

def run_whatever(config, path, verb=0, gpu=True):
    ''' Train one configuration once per seed of SEEDS and plot the result.

        config : dictionary holding all the parameters
        path   : forwarded to plot_metrics
        verb   : verbosity level
        gpu    : boolean, forwarded to get_custom_flower
        Return : list of the training historys, one per seed
    '''
    nb_epochs = config["nb_epochs"]

    def _train_with_seed(seed):
        ''' seed everything, build a new flower and return its history '''
        seedall(seed)
        flower = get_custom_flower(verb=verb, gpu=gpu, **config)
        return flower.train(nb_epochs, verb=verb)

    l_hist = [_train_with_seed(seed) for seed in SEEDS]
    plot_metrics(l_hist, get_title(config), path)
    return l_hist

def run_whatever_mult(name="name", verb=0, gpu=True, **kwargs):
    ''' Run, plot and save a whole grid of experiments.

        Every parameter of DEFAULTS may be given as one value, given as an
        iterable of values, or omitted (its default is then used).
        All the combinations of the given values are trained.

        name : used for the folder name and as prefix of the filenames
        verb : 0, 1 or 2, verbosity level
        gpu : boolean
        **kwargs : structure and training parameters,
                    see "defaults_help?" for the full list
        Return : list holding, for each combination, the list of its
                 training historys
    '''
    l_runs = []
    replace_dir(name)
    prefix = "/".join((name, name)) + " "
    legends = get_possibilities(**kwargs)
    confs = my_confs(**kwargs)
    # legends and configurations follow the same grid order
    for legend, config in zip(legends, confs):
        run_path = prefix + legend_to_name(legend)
        l_runs.append(run_whatever(config, run_path, verb, gpu))
    # the constant parameters give the title of the comparison plot
    common_title = get_constants(**kwargs)[0]
    plot_runs_acc(l_runs, common_title, prefix, **kwargs)
    zipping(name)
    return l_runs

## Some more

In [18]:
# functions to train and display history at the end




# - heterogeneity of data with different styles of notation depending on nodes -

# def get_flower_heter_strats(heter, verb=0, gpu=True):
#     '''initialize and add nodes according to parameter'''
#     global FORCING1
#     global FORCING2
#     global FORCE    
#     nbn = NBN
#     flow = get_flower(gpu)
#     ppn = 60_000 // nbn # points per node
#     nb_lab = nbn // 10
#     FORCE = True
#     for lab in range(10):
#         for n in range(nb_lab):
#             FORCING1, FORCING2 = -1, -1
#             flow.add_nodes(train, (1, ppn), "strats", (lab, heter), verb=verb)
#             flow.set_localtest(test_gpu, 100, [lab * nb_lab + n], (lab, heter), typ="strats")
#     FORCE = False
#     return flow

# def run_heter_strats(heter, verb=0, gpu=True):
#     ''' create a flower of honest nodes and trains it for 200 eps
#         display graphs of loss and accuracy 
#         heter : heterogeneity of data
#     '''
#     flow = get_flower_heter_strats(heter, verb, gpu)
#     flow.gen_freq = 1
#     h = flow.train(epochs, verb=verb)
#     flow.check()
#     t1 = "heter : {}, nbn : {}, lrnode : {}, lrgen : {}, genfrq : {}" 
#     t2 = "\ntype : only strats" 
#     text = t1 + t2
#     title = text.format(heter, flow.nb_nodes, flow.lr_node, 
#                         flow.lr_gen, flow.gen_freq)
#     plot_metrics([h], title, path)    
#     return flow

# def compare(flow_centr, flow_distr): # for run_heter
#     ''' return average accuracy on local test sets 
#         for both centralized and distributed models
#     '''
#     central, gen, distr = 0, 0, 0
#     N = flow_distr.nb_nodes
#     for lab in range(10):
#         sc = score(flow_centr.models[0], flow_distr.localtest[1][lab])
#         central += sc
#     for lab in range(10):
#         sc = score(flow_distr.general_model, flow_distr.localtest[1][lab])
#         gen += sc
#     for n in range(N):
#         sc = flow_distr.test_loc(n)
#         distr += sc
#     distr = distr / N
#     central = central / 10
#     gen = gen / 10
#     return central, gen, distr

#  def compare2(flow_centr, flow_distr): # for run_heter_strats
#     ''' return average accuracy on local test sets 
#         for both centralized and distributed models
#     '''
#     central, gen, distr = 0, 0, 0
#     N = flow_distr.nb_nodes
#     for n in range(N):
#         sc = score(flow_centr.models[0], flow_distr.localtest[1][n])
#         central += sc
#     for n in range(N):
#         sc = score(flow_distr.general_model, flow_distr.localtest[1][n])
#         gen += sc
#     for n in range(N):
#         sc = flow_distr.test_loc(n)
#         distr += sc
#     distr = distr / N
#     central = central / N
#     gen = gen / N
#     return central, gen, distr 

# THAT'S WHERE YOU RUN STUFF

## Help

In [46]:
# Print the docstring of the main entry point, which lists its parameters.
help(run_whatever_mult)

Help on function run_whatever_mult in module __main__:

run_whatever_mult(name='name', verb=0, gpu=True, **kwargs)
    User-friendly running-and-plotting-and-saving interface
    Each parameter of DEFAULTS can be 
    inputted as single value, as an iterable of values or not inputted
    All parameters combinations are computed in a grid fashion 
    
    name : used for folder name and filenames
    verb : 0, 1 or 2, verbosity level
    gpu : boolean
    **kwargs : structure and training parameters, 
                see "defaults_help?" for full parameters list
    Return : all training historys



In [49]:
# Print the description of every DEFAULTS parameter
# (it is stored in the docstring of defaults_help).
help(defaults_help)

Help on function defaults_help in module __main__:

defaults_help()
    Structure of DEFAULTS dictionnary :
    
    "w0": 0.2,  # float >= 0, regularisation parameter
    "w": 0.2,   # float >= 0, harmonisation parameter
    "lr_gen": 0.02,     # float > 0, learning rate of global model
    "lr_node": 0.02,    # float > 0, learning rate of local models
    "NN" : "base",     # "base" or "conv", neural network architecture
    "opt": optim.Adam,    # any torch otpimizer
    "gen_freq": 1,     # int >= 1, number of global steps for 
                                                    1 local step
    
    "nbn": 1000,    # int >= 1, number of nodes
    "nbd": 60_000,  # int >= 1,  total data
                                - nbd/nbn must be in [1, 60_000]
    "fracdish": 0,   # float in [0,1]
    "typ_dish": "zeros",# in ["honest", "zeros", "jokers", "one_evil", 
                            "byzantine", "randoms", "trolls", "strats"]
    "heter": 0,        # int >= 0, heterogeneity of d

## Run

In [None]:
# Rebinds the global SEEDS set in the configuration cell;
# run_whatever trains once per seed of this list.
SEEDS = [1,2,3,4,5]
# Only nb_epochs is given: every other parameter keeps its DEFAULTS value
# and -name keeps its default "name".
historys = run_whatever_mult(nb_epochs=10, verb=0)

# MANUAL

In [None]:
# Manual run, without the run_whatever_mult interface.
seedall(51)
tulip = get_flower(**DEFAULTS)  # flower built from the default parameters
# (1, 60000) is (number of nodes, points per node), as in get_custom_flower
tulip.add_nodes(train, (1, 60000), "honest")
tulip.check()  # NOTE(review): presumably a state check before training - confirm
# tulip.lr_node = 0.2
# tulip.lr_gen = 0.05
# tulip.w0 = 0
h1 = tulip.train(2, verb=2)  # 2 epochs; h1 is the training history
tulip.check()

In [None]:
# Needs the `tulip` flower created in the previous cell.
# NOTE(review): presumably displays node 0 - confirm against the display method
tulip.display(0)

In [None]:
# Plot every metric of the manual run; -title and -path keep their None default.
plot_metrics([h1])