In [1]:
import matplotlib.pyplot as plt
from trainers.trainer import standalone, centralized, fedavg, bruteforce
from utils.synthetic.data_generator import Data_Generator

In [36]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interactive_output, FloatSlider, HBox, VBox, Label, Layout, Output
from IPython.display import display, clear_output

class Analyzer:
    """Interactive widget dashboard comparing training strategies on synthetic client data.

    For each of ``n_client`` clients, a panel of sliders parameterizes a
    ``Data_Generator``. Every slider change regenerates that client's data,
    trains a ``standalone`` model, runs ``bruteforce`` on it and redraws the
    client's histogram. A global panel pools all clients' data and reports
    ``centralized``, ``fedavg`` and ``bruteforce`` results side by side.
    """

    # Group keys, in the order used when flattening data in ``get_xys``.
    _GROUPS = ('A0', 'A1', 'B0', 'B1')

    # (group key, legend label, colour) in histogram draw order. The order
    # matters because later histograms are painted over earlier ones.
    _HIST_STYLES = (
        ('A0', 'Group A, Y = 0', 'red'),
        ('B0', 'Group B, Y = 0', 'orange'),
        ('A1', 'Group A, Y = 1', 'blue'),
        ('B1', 'Group B, Y = 1', 'green'),
    )

    # Slider keys per row of a client's parameter panel (two per row, epochs alone).
    _SLIDER_ROWS = (
        ('N', 'n_a'),
        ('alpha_a', 'alpha_b'),
        ('mean_A0', 'mean_B0'),
        ('mean_A1', 'mean_B1'),
        ('std_A0', 'std_B0'),
        ('std_A1', 'std_B1'),
        ('epochs',),
    )

    def __init__(self, n_client, fedavg_epochs, centralized_epochs, bruteforce_range, bruteforce_spte):
        """Build one slider panel per client plus the global results panel.

        Args:
            n_client: number of simulated clients (one parameter panel each).
            fedavg_epochs: passed as ``epochs`` to ``fedavg``.
            centralized_epochs: epoch count passed to ``centralized``.
            bruteforce_range: forwarded unchanged to every ``bruteforce`` call.
            bruteforce_spte: step forwarded unchanged to every ``bruteforce`` call.
                The misspelled name is kept so existing keyword callers keep working.
        """
        self.n_client = n_client
        self.fedavg_epochs = fedavg_epochs
        self.centralized_epochs = centralized_epochs
        self.bf_range = bruteforce_range
        self.bf_step = bruteforce_spte
        self.layouts = []
        # Per-client Output widgets that receive each client's figure.
        self.client_outputs = [Output() for _ in range(n_client)]
        # One slot per client, filled with that client's local data on each update.
        self.global_data = [None] * n_client
        self.global_layout = Output()
        self.init_clients()
        # NOTE(review): if interactive_output already ran every client callback
        # inside init_clients, the global graph was drawn there and this repeats
        # the expensive global training once more. Confirm before dropping it.
        self.update_global_graph()

    def init_clients(self):
        """Create and store the widget layout for every client, in index order."""
        for client_idx in range(self.n_client):
            self.layouts.append(self.create_ui(client_idx))

    def create_ui(self, client_idx):
        """Build the slider panel for one client and wire it to ``update_client``.

        Returns:
            A ``VBox`` containing the interactive output area, followed by the
            client's figure output next to its labelled slider grid.
        """
        sliders = {
            'N': FloatSlider(min=100, max=20000, step=100, value=10000, description='N'),
            'n_a': FloatSlider(min=0.0, max=1.0, step=0.01, value=0.5, description='n_a'),
            'alpha_a': FloatSlider(min=0.0, max=1.0, step=0.01, value=0.5, description='alpha_a'),
            'alpha_b': FloatSlider(min=0.0, max=1.0, step=0.01, value=0.5, description='alpha_b'),
            'mean_A0': FloatSlider(min=-10, max=10, step=0.1, value=-2, description='mean_A0'),
            'mean_B0': FloatSlider(min=-10, max=10, step=0.1, value=-6, description='mean_B0'),
            'mean_A1': FloatSlider(min=-10, max=10, step=0.1, value=6, description='mean_A1'),
            'mean_B1': FloatSlider(min=-10, max=10, step=0.1, value=2, description='mean_B1'),
            'std_A0': FloatSlider(min=0.1, max=3, step=0.1, value=1, description='std_A0'),
            'std_B0': FloatSlider(min=0.1, max=3, step=0.1, value=1, description='std_B0'),
            'std_A1': FloatSlider(min=0.1, max=3, step=0.1, value=1, description='std_A1'),
            'std_B1': FloatSlider(min=0.1, max=3, step=0.1, value=1, description='std_B1'),
            'epochs': FloatSlider(min=10, max=2000, step=10, value=100, description='epochs'),
        }
        rows = [HBox([sliders[key] for key in row]) for row in self._SLIDER_ROWS]
        title = Label(f'Client {client_idx} Parameters',
                      layout=Layout(height='30px', align_self='center', justify_content='center'))
        ui = VBox([title] + rows)
        # client_idx is a parameter of this method (not a loop variable), so the
        # lambda captures the correct index for each client.
        output = interactive_output(lambda **kwargs: self.update_client(client_idx, kwargs), sliders)
        return VBox([output, HBox([self.client_outputs[client_idx], ui])])

    def create_combined_data(self, global_data):
        """Pool per-client data into one ``{group: {'x', 'y', 's'}}`` dict.

        ``None`` entries (clients without data yet) are skipped. Each client's
        lists are appended in client order.
        """
        combined_data = {group: {'x': [], 'y': [], 's': []} for group in self._GROUPS}
        for data in global_data:
            if data is None:
                continue
            for group in self._GROUPS:
                for field in ('x', 'y', 's'):
                    combined_data[group][field].extend(data[group][field])
        return combined_data

    def get_xys(self, data):
        """Flatten a grouped data dict into ``(x, y, s)`` numpy arrays.

        The groups are concatenated in the order A0, A1, B0, B1.
        """
        x, y, s = [], [], []
        for group in self._GROUPS:
            x.extend(data[group]['x'])
            y.extend(data[group]['y'])
            s.extend(data[group]['s'])
        return np.array(x), np.array(y), np.array(s)

    def update_client(self, client_idx, params):
        """Regenerate one client's data from slider values, retrain it, and redraw.

        Args:
            client_idx: index of the client whose sliders changed.
            params: slider values keyed by slider name (see ``create_ui``).
        """
        # Means and stds are passed to Data_Generator in the order A1, A0, B1, B0.
        means = [params['mean_A1'], params['mean_A0'], params['mean_B1'], params['mean_B0']]
        stds = [params['std_A1'], params['std_A0'], params['std_B1'], params['std_B0']]
        # FloatSlider always yields floats. N is a sample count and epochs an
        # iteration count, so round to the nearest int instead of passing a
        # float or truncating float drift (e.g. 109.9999 -> 109).
        n_samples = int(round(params['N']))
        epochs = int(round(params['epochs']))
        data = Data_Generator(n_samples, params['n_a'], params['alpha_a'], params['alpha_b'], means, stds)
        local_data = data.local_data
        x, y, s = data.get_client()
        x = x.reshape(-1, 1)
        acc, s_eo, s_dp, decision_boundary = standalone(x, y, s, epochs=epochs)
        bf_acc, bf_s_eo, bf_s_dp, bf_decision_boundary = bruteforce(x, y, s, self.bf_range, self.bf_step)
        self.global_data[client_idx] = local_data
        with self.client_outputs[client_idx]:
            self.draw_client(acc, s_eo, s_dp, decision_boundary[0], local_data, client_idx,
                             bf_acc, bf_s_eo, bf_s_dp, bf_decision_boundary)
        self.update_global_graph()

    def update_global_graph(self):
        """Retrain the centralized, FedAvg and brute-force models on pooled data and redraw.

        Does nothing until every client has produced data. Otherwise ``fedavg``
        would receive ``None`` placeholders and the global models would be
        trained on a partial dataset.
        """
        with self.global_layout:
            if any(data is None for data in self.global_data):
                return
            combined_data = self.create_combined_data(self.global_data)
            c_acc, c_s_eo, c_s_dp, c_decision_boundary = centralized(combined_data, self.centralized_epochs)
            f_acc, f_s_eo, f_s_dp, f_decision_boundary = fedavg(combined_data, self.global_data,
                                                                epochs=self.fedavg_epochs)
            x, y, s = self.get_xys(combined_data)
            bf_acc, bf_s_eo, bf_s_dp, bf_decision_boundary = bruteforce(x, y, s, self.bf_range, self.bf_step)
            self.draw_global(c_acc, c_s_eo, c_s_dp, c_decision_boundary[0], combined_data,
                             f_acc, f_s_eo, f_s_dp, f_decision_boundary[0],
                             bf_acc, bf_s_eo, bf_s_dp, bf_decision_boundary)

    @staticmethod
    def _metric_lines(name, boundary, acc, eo, dp):
        """Return the four mathtext lines (boundary, Acc, EO, DP) for one method."""
        return [
            (r'$\mathrm{%s}=' % name) + (r'%.3f$' % boundary),
            r'$\mathrm{Acc}=%.4f$' % acc,
            r'$\mathrm{EO}=%.3f$' % eo,
            r'$\mathrm{DP}=%.3f$' % dp,
        ]

    @staticmethod
    def _draw_metrics_box(ax, lines):
        """Render ``lines`` in a rounded, semi-transparent box at the top-left of ``ax``."""
        props = dict(boxstyle='round', facecolor='white', alpha=0.5)
        ax.text(0.05, 0.95, '\n'.join(lines), transform=ax.transAxes, fontsize=10,
                verticalalignment='top', bbox=props)

    def _draw_group_histograms(self, ax, data):
        """Overlay density histograms of ``x`` for each group, then add the legend and y-label."""
        for group, label, color in self._HIST_STYLES:
            ax.hist(data[group]['x'], bins=100, alpha=0.5, density=True, label=label, color=color)
        ax.legend(loc='upper right')
        ax.set_ylabel('Density')

    def draw_client(self, acc, s_eo, s_dp, decision_boundary, data, client_idx, bf_acc, bf_s_eo, bf_s_dp, bf_decision_boundary):
        """Draw one client's histogram with its standalone (red) and brute-force (green) boundaries.

        Intended to run inside the client's Output widget, whose previous
        figure is cleared first.
        """
        clear_output(wait=True)
        _, ax = plt.subplots(figsize=(6, 2.5))
        ax.set_title(f'Client {client_idx}')
        self._draw_metrics_box(ax, self._metric_lines('standalone', decision_boundary, acc, s_eo, s_dp)
                               + self._metric_lines('bruteforce', bf_decision_boundary, bf_acc, bf_s_eo, bf_s_dp))
        ax.set_xlim(-10, 10)
        ax.set_ylim(0, 1)
        ax.axvline(x=decision_boundary, color='r', linestyle='--')
        ax.axvline(x=bf_decision_boundary, color='g', linestyle='--')
        self._draw_group_histograms(ax, data)
        plt.show()

    def draw_global(self, centralized_acc, centralized_s_eo, centralized_s_dp, centralized_decision_boundary, data,
                    fedavg_acc, fedavg_s_eo, fedavg_s_dp, fedavg_decision_boundary,
                    bruteforce_acc, bruteforce_s_eo, bruteforce_s_dp, bruteforce_decision_boundary):
        """Draw the pooled histogram with FedAvg (blue), centralized (red) and brute-force (green) boundaries.

        Intended to run inside the global Output widget, whose previous figure
        is cleared first.
        """
        clear_output(wait=True)
        _, ax = plt.subplots(figsize=(8, 4))
        ax.set_title('Global Data')
        self._draw_metrics_box(
            ax,
            self._metric_lines('centralized', centralized_decision_boundary,
                               centralized_acc, centralized_s_eo, centralized_s_dp)
            + self._metric_lines('fedavg', fedavg_decision_boundary, fedavg_acc, fedavg_s_eo, fedavg_s_dp)
            + self._metric_lines('bruteforce', bruteforce_decision_boundary,
                                 bruteforce_acc, bruteforce_s_eo, bruteforce_s_dp))
        ax.set_xlim(-10, 10)
        ax.set_ylim(0, 1)
        ax.axvline(x=fedavg_decision_boundary, color='b', linestyle='--')
        ax.axvline(x=centralized_decision_boundary, color='r', linestyle='--')
        # Previously reported in the text box but never drawn; this matches draw_client.
        ax.axvline(x=bruteforce_decision_boundary, color='g', linestyle='--')
        self._draw_group_histograms(ax, data)
        plt.show()

    def display_result(self):
        """Display every client panel followed by the global panel in a single VBox."""
        display(VBox(self.layouts[:self.n_client] + [self.global_layout]))

In [39]:
# Build a two-client dashboard and render it. The keyword ``bruteforce_spte``
# (sic) must match the parameter name declared by Analyzer.__init__.
analyzer = Analyzer(
    n_client=2,
    fedavg_epochs=1000,
    centralized_epochs=1000,
    bruteforce_range=[-0.1, 0.1],
    bruteforce_spte=0.01,
)
analyzer.display_result()

VBox(children=(VBox(children=(Output(), HBox(children=(Output(), VBox(children=(Label(value='Client 0 Paramete…