In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
from spn.structure.Base import Sum, Product, Context, get_nodes_by_type
from spn.structure.leaves.parametric.Parametric import Gaussian, Categorical
from spn.algorithms.LearningWrappers import learn_structure, learn_mspn
from sklearn.datasets import make_classification
from sklearn.metrics import f1_score
from spn.algorithms.MPE import mpe
from spn.algorithms.EM import EM_optimization
from spn.structure.leaves.piecewise.PiecewiseLinear import create_histogram_leaf
from spn.algorithms.splitting.RDC import get_split_cols_RDC_py
from spn.algorithms.splitting.Clustering import get_split_rows_KMeans
from sklearn.model_selection import train_test_split
from utils import random_region_graph, region_graph_to_spn, reassign_node_ids
from datasets.utils import get_vertical_train_data, get_test_data
from scipy.special import logsumexp

from spn.algorithms.Gradient import gradient_backward
from spn.algorithms.Inference import log_likelihood
from spn.algorithms.Validity import is_valid

from spn.structure.Base import Sum, get_nodes_by_type, get_number_of_nodes
from itertools import product
import numpy as np
import warnings
from sklearn.cluster import KMeans

from datetime import datetime as dt
import timeit
warnings.filterwarnings('ignore')

In [2]:
def map_scopes(spn, inds):
    """Rewrite every node's scope through the lookup ``position -> inds[position]``.

    Each entry of a node's scope is treated as a position into ``inds`` and is
    replaced by the value stored there. The scope of every node is overwritten
    in place with a new list, and the same ``spn`` object is returned.

    Calling this twice remaps the already remapped scope again, so it is not
    idempotent.
    """
    all_nodes = get_nodes_by_type(spn)
    position_to_index = dict(enumerate(inds))
    for node in all_nodes:
        node.scope = [position_to_index[position] for position in node.scope]
    return spn

def reassign_ids(spn):
    """Number the nodes of ``spn`` consecutively from 0.

    Ids follow the order in which ``get_nodes_by_type`` yields the nodes.
    The nodes are modified in place and ``spn`` itself is returned.
    """
    for new_id, node in enumerate(get_nodes_by_type(spn)):
        node.id = new_id
    return spn

def make_dataset(num_samples, num_features, n_informative, n_redundant, n_classes, n_clusters, n_repeated,
                 random_state=None):
    """Generate a synthetic classification problem and split it 70/30.

    Parameters
    ----------
    num_samples, num_features : int
        Size of the generated design matrix.
    n_informative, n_redundant, n_repeated, n_classes : int
        Forwarded unchanged to ``make_classification``.
    n_clusters : int
        Forwarded as ``n_clusters_per_class``.
    random_state : int, RandomState instance or None, default None
        Seed forwarded to both ``make_classification`` and
        ``train_test_split``. ``None`` keeps the previous unseeded behaviour;
        pass an int to make the dataset and the split reproducible.

    Returns
    -------
    (train_data, test_data) : tuple of ndarray
        Feature columns with the class label appended as the last column.
    """
    x, y = make_classification(n_samples=num_samples, n_features=num_features, n_informative=n_informative,
                               n_redundant=n_redundant, n_classes=n_classes, n_clusters_per_class=n_clusters,
                               n_repeated=n_repeated, class_sep=2, random_state=random_state)
    train_x, test_x, train_y, test_y = train_test_split(x, y, test_size=0.3, random_state=random_state)
    # The label becomes the last column so callers can address it as shape[1] - 1.
    train_data = np.hstack([train_x, train_y.reshape(-1, 1)])
    test_data = np.hstack([test_x, test_y.reshape(-1, 1)])
    return train_data, test_data

# Column / row splitting strategies passed to learn_structure in the
# spn_struct == 'learned' branches of the training functions below.
# NOTE(review): 0.3 is presumably the RDC dependence threshold and False the
# one-hot-encoding flag; 2 is presumably the number of KMeans clusters per row
# split — confirm against the installed SPFlow signatures.
split_cols = get_split_cols_RDC_py(0.3, False)
split_rows = get_split_rows_KMeans(2, standardize=False)

def softmax(vec, temperature):
    """
    Turn ``vec`` into a normalized probability vector.

    Computes ``exp(vec / temperature)`` normalized along the first axis.
    The maximum is subtracted before exponentiating, which leaves the result
    mathematically unchanged but prevents overflow (and the resulting NaNs)
    for large logits.

    Parameters
    ----------
    vec : array-like of float
        Unnormalized scores.
    temperature : float
        Divisor applied to the scores before exponentiation.

    Returns
    -------
    ndarray
        Probabilities with the same shape as ``vec``; an empty input yields an
        empty float array.
    """
    scaled = np.asarray(vec, dtype=float) / temperature
    if scaled.size == 0:
        return scaled
    # exp(x - max) / sum(exp(x - max)) == exp(x) / sum(exp(x)), but never overflows.
    shifted = scaled - scaled.max(axis=0)
    exp_shifted = np.exp(shifted)
    return exp_shifted / exp_shifted.sum(axis=0)

def cond_sum_em_update(allowed_nodes):
    """Build an EM weight update that only touches selected sum nodes.

    Parameters
    ----------
    allowed_nodes : container of int
        Node ids whose weights may be updated. Nodes with any other id are
        left untouched by the returned callback.

    Returns
    -------
    callable
        ``sum_em_update(node, node_gradients, root_lls, all_lls, **kwargs)``.
        It replaces ``node.weights`` with a freshly normalized ndarray.
        ``node_gradients`` and ``root_lls`` are per-sample vectors, and
        ``all_lls`` is indexed as ``all_lls[:, child.id]``.
    """
    def sum_em_update(node, node_gradients=None, root_lls=None, all_lls=None, **kwargs):
        if node.id not in allowed_nodes:
            return

        # Work on a local float array: this accepts plain-list weights and keeps
        # node.weights valid (never half-updated / in log space) if a check fails.
        old_weights = np.asarray(node.weights, dtype=float)
        RinvGrad = node_gradients - root_lls

        # Unnormalized log-weight of child i: logsumexp over samples of
        # (gradient - root ll) + child ll + log(old weight).
        log_weights = np.empty(len(node.children))
        for i, c in enumerate(node.children):
            new_w = RinvGrad + (all_lls[:, c.id] + np.log(old_weights[i]))
            log_weights[i] = logsumexp(new_w)

        assert not np.any(np.isnan(log_weights))

        # exp(-100) keeps every weight strictly positive so log() stays finite
        # in the next iteration.
        new_weights = np.exp(log_weights - logsumexp(log_weights)) + np.exp(-100)

        # Normalize twice: the first division can leave the sum a few ulps off 1.
        new_weights = new_weights / new_weights.sum()
        new_weights = new_weights / new_weights.sum()

        if new_weights.sum() > 1:
            # Push any remaining rounding excess onto the largest weight.
            new_weights[np.argmax(new_weights)] -= new_weights.sum() - 1

        assert not np.any(np.isnan(new_weights))
        assert np.isclose(np.sum(new_weights), 1)
        assert not np.any(new_weights < 0)
        assert new_weights.sum() <= 1, "sum: {}, node weights: {}".format(new_weights.sum(), new_weights)

        node.weights = new_weights
    return sum_em_update

# Default dispatch table (node type -> EM update callback). Out of the box only
# Sum nodes are handled, and only the node with id 0 has its weights updated.
_node_updates = {}
_node_updates[Sum] = cond_sum_em_update([0])

def add_node_em_update(node_type, lambda_func):
    """Register (or replace) the EM update callback used for ``node_type``.

    The registration is stored in the module-level ``_node_updates`` table.
    """
    _node_updates.update({node_type: lambda_func})


def EM_optimization_network(spn, data, iterations=5, node_updates=_node_updates, skip_validation=False, **kwargs):
    """Run EM-style parameter updates on ``spn`` using the registered callbacks.

    Parameters
    ----------
    spn : root node of the network to optimize (modified in place).
    data : 2-D array; one row per sample.
    iterations : number of forward / backward / update rounds.
    node_updates : mapping ``node type -> callback``; defaults to the shared
        module-level ``_node_updates`` table.
    skip_validation : when False, ``is_valid(spn)`` must succeed first.
    **kwargs : forwarded verbatim to every callback.

    Column 0 of the likelihood matrix is used as the root log-likelihood, so
    the root node is expected to carry id 0. Nothing is returned.
    """
    if not skip_validation:
        valid, err = is_valid(spn)
        assert valid, "invalid spn: " + err

    n_samples = data.shape[0]
    lls_per_node = np.zeros((n_samples, get_number_of_nodes(spn)))

    for _iteration in range(iterations):
        # Bottom-up pass: fills lls_per_node with per-node log-likelihoods.
        log_likelihood(spn, data, dtype=data.dtype, lls_matrix=lls_per_node)

        # Top-down pass over the same matrix.
        gradients = gradient_backward(spn, lls_per_node)

        R = lls_per_node[:, 0]

        for node_type, update in node_updates.items():
            matching_nodes = get_nodes_by_type(spn, node_type)
            for node in matching_nodes:
                update(
                    node,
                    node_lls=lls_per_node[:, node.id],
                    node_gradients=gradients[:, node.id],
                    root_lls=R,
                    all_lls=lls_per_node,
                    all_gradients=gradients,
                    data=data,
                    **kwargs
                )

def build_fedspn_head(client_cluster_spns):
    """Join per-client cluster SPNs under one uniformly weighted sum root.

    For every combination of one cluster SPN per client a chain of binary
    Product nodes is built: client 0 x client 1, then that product x client 2,
    and so on. The last layer of products becomes the children of the root.

    Compared to the previous implementation, each product layer is built
    exactly once instead of once per full cluster combination. Clients may
    contribute different numbers of cluster SPNs, and a single client yields a
    mixture over its cluster SPNs instead of a childless root.

    Parameters
    ----------
    client_cluster_spns : sequence of sequences of SPN nodes
        ``client_cluster_spns[i][j]`` is the SPN of cluster ``j`` on client ``i``.

    Returns
    -------
    The root Sum node after ``reassign_node_ids``.

    Raises
    ------
    ValueError
        If there is no client, or a client has no cluster SPN.
    """
    if len(client_cluster_spns) == 0 or any(len(spns) == 0 for spns in client_cluster_spns):
        raise ValueError("every client must provide at least one cluster SPN")

    # layer maps a tuple of cluster indices (one per client joined so far) to
    # the node that models exactly this combination.
    layer = {(j,): cluster_spn for j, cluster_spn in enumerate(client_cluster_spns[0])}
    for cluster_spns in client_cluster_spns[1:]:
        next_layer = {}
        for prefix, partial in layer.items():
            for j, cluster_spn in enumerate(cluster_spns):
                # connect the node of the previous layer with the j-th SPN of the next client
                prod = Product([partial, cluster_spn])
                prod.scope = list(set().union(set(partial.scope), set(cluster_spn.scope)))
                next_layer[prefix + (j,)] = prod
        layer = next_layer

    all_scopes = set()
    for cluster_spns in client_cluster_spns:
        for s in cluster_spns:
            all_scopes = all_scopes.union(set(s.scope))

    root_children = list(layer.values())
    # Uniform initial mixture weights.
    weights = softmax(np.zeros(len(root_children)), 1)
    root = Sum(weights, root_children)
    root.scope = list(all_scopes)
    root = reassign_node_ids(root)
    return root

In [3]:
# Synthetic benchmark: one (train, test) pair per number of redundant features,
# keyed by that number. No seed is passed, so the data differ between runs.
num_features = 20
datasets = {}
for num_corrs in [0, 10, 15]:
    # informative + redundant features always add up to num_features
    inf = num_features - num_corrs
    # trailing positional args: n_classes=2, n_clusters=1, n_repeated=0
    train, test = make_dataset(2000, num_features, inf, num_corrs, 2, 1, 0)
    datasets[num_corrs] = (train, test)

In [4]:
def infer_node_type(data, categorical_threshold=100):
    """Guess a leaf distribution for every column of ``data``.

    A column with fewer than ``categorical_threshold`` distinct values is
    modelled as ``Categorical`` with a uniform ``p`` over its distinct values.
    Every other column is modelled as a standard ``Gaussian`` (mean 0, stdev 1).

    Parameters
    ----------
    data : 2-D ndarray
        One column per feature.
    categorical_threshold : int, default 100
        Exclusive upper bound on the number of distinct values for a column to
        be treated as categorical. The default reproduces the previously
        hard-coded cut-off.

    Returns
    -------
    list of (leaf class, parameter dict) tuples, one per column, in column order.

    NOTE(review): ``p`` has one entry per distinct value, which presumably
    requires categorical columns to be encoded as 0..k-1 — confirm.
    """
    types = []
    for col in range(data.shape[1]):
        n_unique = len(np.unique(data[:, col]))
        if n_unique < categorical_threshold:
            params = {'p': np.repeat(1 / n_unique, n_unique)}
            types.append((Categorical, params))
        else:
            params = {'mean': 0, 'stdev': 1}
            types.append((Gaussian, params))
    return types

In [5]:
def vertical_fl_e2e(train_data, test_data, s, class_idx, spn_struct='rat', n_runs=5):
    """End-to-end training: build untrained client SPNs, join them, then run EM once on the joint network.

    Parameters
    ----------
    train_data, test_data : 2-D ndarray
        Full data matrices; each client sees the columns listed in its entry of ``s``.
    s : iterable of index arrays
        Column indices owned by each client.
    class_idx : int
        Column index of the class label.
    spn_struct : {'rat', 'learned'}
        'rat' builds the client SPNs from a random region graph, 'learned'
        uses ``learn_structure``.
    n_runs : int, default 5
        Number of independent repetitions (previously hard-coded to 5).

    Returns
    -------
    (lls, runtimes) : two lists with one entry per run — the mean test
    log-likelihood, and the slowest client time plus the server time in seconds.

    Raises
    ------
    ValueError
        If ``spn_struct`` is not one of the supported values. Previously an
        unknown value surfaced as an unbound (or stale) ``spn_classification``.
    """
    if spn_struct not in ('learned', 'rat'):
        raise ValueError("unknown spn_struct: {!r} (expected 'learned' or 'rat')".format(spn_struct))

    lls = []
    runtimes = []
    for _ in range(n_runs):
        # train one spn on each client
        spns = []
        client_rt = 0
        for cl_idx in s:
            start = timeit.default_timer()
            client_data = train_data[:, cl_idx]
            client_features = client_data.shape[1] - 1
            kmeans = KMeans(2)
            # NOTE(review): the label client clusters on all columns (label included),
            # other clients drop their last feature column — confirm this is intended.
            if class_idx in cl_idx:
                context = Context(parametric_types=[Gaussian]*client_features + [Categorical]).add_domains(client_data)
                clusters = kmeans.fit_predict(client_data)
            else:
                context = Context(parametric_types=[Gaussian]*(client_features + 1)).add_domains(client_data)
                clusters = kmeans.fit_predict(client_data[:, :-1])
            cluster_spns = []
            for c in np.unique(clusters):
                idx = np.argwhere(clusters == c).flatten()
                subset = client_data[idx]
                if spn_struct == 'learned':
                    spn_classification = learn_structure(subset, context, split_rows, split_cols, create_histogram_leaf)
                    spn_classification = map_scopes(spn_classification, cl_idx)
                elif spn_struct == 'rat':
                    # NOTE(review): the region graph spans range(client_features) while dists
                    # has client_features + 1 entries — confirm the last column is covered.
                    rg = random_region_graph(0, list(range(client_features)), [])
                    if class_idx in cl_idx:
                        dists = {i: (Gaussian, {'mean': 0, 'stdev': 1}) for i in range(client_features)}
                        dists[client_features] = (Categorical, {'p': [0.5, 0.5]})
                    else:
                        dists = {i: (Gaussian, {'mean': 0, 'stdev': 1}) for i in range(client_features+1)}
                    # start from the region-graph nodes that have no predecessors
                    curr_layer = [n for n in rg.nodes if len(list(rg.pred[n])) == 0]
                    spn_classification = region_graph_to_spn(rg, curr_layer, dists)
                    spn_classification = map_scopes(spn_classification, cl_idx)
                    spn_classification = reassign_node_ids(spn_classification)
                cluster_spns.append(spn_classification)
            spns.append(cluster_spns)
            t = timeit.default_timer() - start
            # only the slowest client counts (presumably models clients working in parallel)
            client_rt = max(client_rt, t)

        start = timeit.default_timer()
        spn = build_fedspn_head(spns)
        # optimize server SPN
        # NOTE: It's legal to put client data in here since we can propagate likelihoods
        #   over the network without sending private information
        EM_optimization(spn, train_data)
        # evaluate server model
        server_rt = timeit.default_timer() - start
        overall_rt = client_rt + server_rt
        runtimes.append(overall_rt)
        ll = log_likelihood(spn, test_data)
        lls.append(np.mean(ll))
    return lls, runtimes

In [6]:
def vertical_fl_two_step(train_data, test_data, s, class_idx, spn_struct='rat', n_runs=5):
    """Two-step training: fit every client SPN locally with EM, then tune only the joint head.

    Parameters
    ----------
    train_data, test_data : 2-D ndarray
        Full data matrices; each client sees the columns listed in its entry of ``s``.
    s : iterable of index arrays
        Column indices owned by each client.
    class_idx : int
        Column index of the class label.
    spn_struct : {'rat', 'learned'}
        'rat' builds the client SPNs from a random region graph, 'learned'
        uses ``learn_structure``.
    n_runs : int, default 5
        Number of independent repetitions (previously hard-coded to 5).

    Returns
    -------
    (lls, runtimes) : two lists with one entry per run — the mean test
    log-likelihood, and the slowest client time plus the server time in seconds.

    Raises
    ------
    ValueError
        If ``spn_struct`` is not one of the supported values. Previously an
        unknown value surfaced as an unbound (or stale) ``spn_classification``.
    """
    if spn_struct not in ('learned', 'rat'):
        raise ValueError("unknown spn_struct: {!r} (expected 'learned' or 'rat')".format(spn_struct))

    lls = []
    runtimes = []
    for _ in range(n_runs):
        # train one spn on each client
        spns = []
        client_rt = 0
        for cl_idx in s:
            start = timeit.default_timer()
            client_data = train_data[:, cl_idx]
            client_features = client_data.shape[1] - 1
            kmeans = KMeans(2)
            # leaf types are inferred per column from the client's own data
            node_types = infer_node_type(client_data)
            types = [t for t, _ in node_types]
            context = Context(parametric_types=types).add_domains(client_data)
            # NOTE(review): the label client clusters on all columns (label included),
            # other clients drop their last feature column — confirm this is intended.
            if class_idx in cl_idx:
                clusters = kmeans.fit_predict(client_data)
            else:
                clusters = kmeans.fit_predict(client_data[:, :-1])
            cluster_spns = []
            for c in np.unique(clusters):
                idx = np.argwhere(clusters == c).flatten()
                subset = client_data[idx]
                if spn_struct == 'learned':
                    spn_classification = learn_structure(subset, context, split_rows, split_cols, create_histogram_leaf)
                elif spn_struct == 'rat':
                    # NOTE(review): the region graph spans range(client_features) while dists
                    # has one entry per client column — confirm the last column is covered.
                    rg = random_region_graph(0, list(range(client_features)), [])
                    dists = {i: t for i, t in enumerate(node_types)}
                    # start from the region-graph nodes that have no predecessors
                    curr_layer = [n for n in rg.nodes if len(list(rg.pred[n])) == 0]
                    spn_classification = region_graph_to_spn(rg, curr_layer, dists)
                    spn_classification = reassign_node_ids(spn_classification)
                # Step 1: local EM on the client's cluster data. This runs before
                # map_scopes, i.e. while scopes still hold client-local positions.
                EM_optimization(spn_classification, subset)
                spn_classification = map_scopes(spn_classification, cl_idx)
                cluster_spns.append(spn_classification)
            spns.append(cluster_spns)
            t = timeit.default_timer() - start
            # only the slowest client counts (presumably models clients working in parallel)
            client_rt = max(client_rt, t)

        start = timeit.default_timer()
        spn = build_fedspn_head(spns)
        # Step 2: optimize server SPN
        # NOTE: It's legal to put client data in here since we can propagate likelihoods
        #   over the network without sending private information
        EM_optimization_network(spn, train_data)
        server_rt = timeit.default_timer() - start
        overall_rt = client_rt + server_rt
        runtimes.append(overall_rt)
        ll = log_likelihood(spn, test_data)
        lls.append(np.mean(ll))
    return lls, runtimes

In [7]:
def synth_experiment(datasets, n_clients, training='e2e'):
    """Run the vertical-FL experiment on the synthetic data.

    ``datasets`` maps a number of redundant features to a ``(train, test)``
    pair; only the entry for 0 redundant features is used. ``training='e2e'``
    selects end-to-end training, any other value selects two-step training.

    Returns ``(log_likelihoods, rt)``: an array with one row of per-run
    log-likelihoods per dataset, and the runtimes of the last dataset.
    """
    log_likelihoods = []
    for r in [0]:
        print(f"Train with r={r} redundant features")
        train_data, test_data = datasets[r]
        n_columns = train_data.shape[1]
        # vertical split: every client gets a contiguous block of columns
        s = np.array_split(np.arange(n_columns), n_clients)
        # the label is stored in the last column
        class_idx = n_columns - 1
        trainer = vertical_fl_e2e if training == 'e2e' else vertical_fl_two_step
        lls, rt = trainer(train_data, test_data, s, class_idx)
        log_likelihoods.append(lls)
    return np.array(log_likelihoods), rt

In [8]:
def income_experiment(n_clients, training='e2e'):
    """Run the vertical-FL experiment on the 'income' dataset.

    The per-client column blocks are stacked into one matrix. The first 70% of
    the rows are used for training and the remaining rows for testing.
    ``training='e2e'`` selects end-to-end training, any other value selects
    two-step training.

    Returns ``(log_likelihoods, rt)``: a one-row array of per-run
    log-likelihoods, and the list of runtimes.
    """
    client_columns, _ = get_vertical_train_data('income', n_clients)
    # due to compatibility first stack train_data columns
    stacked = np.column_stack(client_columns)
    split_idx = int(0.7*len(stacked))
    train_data = stacked[:split_idx]
    test_data = stacked[split_idx:]
    n_columns = train_data.shape[1]
    # vertical split: every client gets a contiguous block of columns
    s = np.array_split(np.arange(n_columns), n_clients)
    class_idx = n_columns - 1
    trainer = vertical_fl_e2e if training == 'e2e' else vertical_fl_two_step
    lls, rt = trainer(train_data, test_data, s, class_idx)
    return np.array([lls]), rt

In [9]:
def credit_experiment(n_clients, training='e2e'):
    """Run the vertical-FL experiment on the 'credit' dataset.

    Training data come from ``get_vertical_train_data`` (stacked column-wise),
    test data from ``get_test_data``. ``training='e2e'`` selects end-to-end
    training, any other value selects two-step training.

    Returns ``(lls, rt)`` exactly as produced by the training function, i.e.
    the log-likelihoods are a plain list here, not an array.
    """
    client_columns, _ = get_vertical_train_data('credit', n_clients)
    test_data = get_test_data('credit')
    # due to compatibility first stack train_data columns
    train_data = np.column_stack(client_columns)
    n_columns = train_data.shape[1]
    # vertical split: every client gets a contiguous block of columns
    s = np.array_split(np.arange(n_columns), n_clients)
    class_idx = n_columns - 1
    trainer = vertical_fl_e2e if training == 'e2e' else vertical_fl_two_step
    lls, rt = trainer(train_data, test_data, s, class_idx)
    return lls, rt

In [33]:
# Income data, 2 clients, end-to-end training (the default mode).
income_experiment(2)

(array([[-17.62337467, -17.62337467, -17.62337467, -17.62337467,
         -17.62337467]]),
 [15.583048567990772,
  15.518587407976156,
  15.651644947007298,
  15.533854084991617,
  15.590369198005646])

In [34]:
# Income data, 2 clients, two-step training (any mode other than 'e2e' selects it).
income_experiment(2, '2step')

(array([[-34.56542603, -34.5749303 , -34.63581021, -34.58977622,
         -34.57579776]]),
 [15.226520216994686,
  15.522485764988232,
  15.174898539989954,
  15.342063400981715,
  15.12149933699402])

In [37]:
# Synthetic data, 2 clients, end-to-end training (the default mode).
synth_experiment(datasets, 2)

Train with r=0 redundant features


(array([[-46.92523213, -46.92523213, -46.92523213, -46.92523213,
         -46.92523213]]),
 [3.4566459110064898,
  3.490660990035394,
  3.5004191609914415,
  3.465085667994572,
  3.543375704990467])

In [38]:
# Synthetic data, 2 clients, two-step training (any mode other than 'e2e' selects it).
synth_experiment(datasets, 2, '2step')

Train with r=0 redundant features


(array([[-45.6042452, -45.6042452, -45.6042452, -45.6042452, -45.6042452]]),
 [3.5810848900000565,
  3.597158725024201,
  3.631132123002317,
  3.6066061799938325,
  3.614835326996399])

In [10]:
# Credit data, 2 clients, end-to-end training (the default mode).
credit_experiment(2)

: 

In [10]:
# Credit data, 2 clients, two-step training. The mode string differs from the
# '2step' used in other cells, but any value other than 'e2e' selects two-step.
credit_experiment(2, 'two-step')



: 

# Overview TODOs
- show classification unaffected by hallucination
  - formally
  - experiments (just synth. data + income + other tab. dataset)
- write this reconstruction impossibility in paper
- e2e training vs. 2-step training
  - synth. data + income + other tab. dataset
- classification Income, Forest Cover Type, Higgs Dataset

# Martin's Feedback
- Change Title!
- Abstract more specific (just buzzwords until now)
- same for Intro
  - missing: more story, Inspiration TMLR paper
  - ask: what is missing in FL? What does our solution provide?
  - draw connection to PCs better
  - streamline contribution list

- Related work
  - why are the works relevant we list there?
  - can be a bit longer (maybe not necessary for WS paper)

- "A Probabilistic View on FL"
  - composability argument eher? decomposable? less probabilistic

- Differenz hybrid/vertical/horizontal visualisieren
  - Beispiel?
  - Pooja's Intro figure?

- 3.1.
    - Assumptions bisschen weicher introducen
    - connection zu Figure?
    - Notation + Intuition sollte da sein

- "Hallucination"
  - Formalie bei Einführung SPN?
  - Assumption Independence + Abmildern!

- Def. PC etwas abändern, mind. citation aber eher umschreiben
  - Fokus auf independence assumptions
  - einfach mit PCs gehen, gar nicht hardcore definitions geben
  - SPNs später

- Hallucination
  - falsche Annahme, die dazu führt
  - Assumption broken

- hybrid haben wir nicht, wir unifien
  - mehr betonen bei Intro

- Communication Costs auf jeden Fall für ICLR
  - mit oder ohne Ring Reduce

- Experiments (I)
  - non-i.i.d. anders schreiben, wir machen ja non-iid
  - "although MNIST from global joint, client only partially observes dist."
  - align with 3. marginalization

- Experiments (II)
  - binarize MNIST generated images!!
  - show likelihoods for Q1 (FedSPN vs. SPN)
  - for ICLR: larger dataset
  - WS: Only Q1 + Q4
  - ICLR: Q2 generell raus
  - larger scale tab. data (see TabNet) instead of images

## Dev's feedback
- related work hybrid FL
- medical, credit etc.
- SVHN
- read through reviews and see what they have missed

In [2]:
import torch.nn.functional as F
import torch

# Scratch check of F.one_hot on a column vector: no num_classes is given, and
# the (4, 1) input produces a (4, 1, 2) result (see the output below).
a = torch.tensor([1, 1, 1, 1]).reshape(-1, 1)
F.one_hot(a)

tensor([[[0, 1]],

        [[0, 1]],

        [[0, 1]],

        [[0, 1]]])

- give name to independence
    - see marginals, joint is max-entropy dist. then we can take prod.
    - argue: we want to maximize entropy and minimize uncertainty (over what?)
    - say: many joints lead to same marginal, but we want to have independent marginals

- However, ...
    - instead say: if we make the max. entropy assumption we get a correct model, and it might still be useful if the assumption is not met
    - then introduce the following assumption

- go for PCs instead of SPNs
- Einet and compare