In [1]:
import numpy as np
import matplotlib.pyplot as plt
from ipywidgets import interactive_output, FloatSlider, HBox, VBox, Label, Layout, Output
from IPython.display import display, clear_output

from trainers.trainer import standalone, centralized, fedavg, bruteforce
from utils.synthetic.data_generator import Data_Generator

class Analyzer:
    """Interactive dashboard comparing per-client and global models on synthetic data.

    Each client gets its own set of sliders. Whenever a slider changes, the
    client's data is regenerated with ``Data_Generator``, ``standalone`` and
    ``bruteforce`` are run on it, and the client's plot is redrawn. The pooled
    data of all clients is then used to run ``centralized``, ``fedavg`` and
    ``bruteforce``, and the global plot is redrawn.

    Result objects returned by the trainers are indexed positionally here:
    indices 0, 1 and 2 are displayed as ``Acc``, ``EO`` and ``DP``; index 3 is
    printed next to the method name and drawn as a dashed vertical line
    (presumably the learned decision threshold -- confirm against ``trainers``).

    Args:
        n_client: Number of simulated clients (one slider panel and plot each).
        fedavg_epochs: Epoch count forwarded to ``fedavg``.
        centralized_epochs: Epoch count forwarded to ``centralized``.
        bf_range: Forwarded to ``bruteforce`` as its range argument.
        bf_step: Forwarded to ``bruteforce`` as its step argument.
        centralized_lr: Learning rate forwarded to ``centralized``.
        fedavg_lr: Learning rate forwarded to ``fedavg``.
    """

    # Group keys, in the order used when pooling and flattening the data.
    _GROUPS = ('A0', 'A1', 'B0', 'B1')
    # (group key, legend label, colour), in the order the histograms are drawn.
    _GROUP_STYLES = (
        ('A0', 'Group A, Y = 0', 'red'),
        ('B0', 'Group B, Y = 0', 'orange'),
        ('A1', 'Group A, Y = 1', 'blue'),
        ('B1', 'Group B, Y = 1', 'green'),
    )

    def __init__(self, n_client, fedavg_epochs, centralized_epochs, bf_range, bf_step,
                 centralized_lr=0.03, fedavg_lr=0.03):
        self.n_client = n_client
        self.centralized_lr = centralized_lr
        self.fedavg_lr = fedavg_lr
        self.fedavg_epochs = fedavg_epochs
        self.centralized_epochs = centralized_epochs
        self.bf_range = bf_range
        self.bf_step = bf_step
        # Keyword arguments for the FloatSlider created for each parameter.
        self.sliders_config = {
            'N': {'min': 100, 'max': 20000, 'step': 100, 'value': 10000, 'description': 'N'},
            'n_a': {'min': 0.0, 'max': 1.0, 'step': 0.01, 'value': 0.5, 'description': 'n_a'},
            'alpha_a': {'min': 0.0, 'max': 1.0, 'step': 0.01, 'value': 0.5, 'description': 'alpha_a'},
            'alpha_b': {'min': 0.0, 'max': 1.0, 'step': 0.01, 'value': 0.5, 'description': 'alpha_b'},
            'mean_A0': {'min': -10, 'max': 10, 'step': 0.1, 'value': -2, 'description': 'mean_A0'},
            'mean_B0': {'min': -10, 'max': 10, 'step': 0.1, 'value': -6, 'description': 'mean_B0'},
            'mean_A1': {'min': -10, 'max': 10, 'step': 0.1, 'value': 6, 'description': 'mean_A1'},
            'mean_B1': {'min': -10, 'max': 10, 'step': 0.1, 'value': 2, 'description': 'mean_B1'},
            'std_A0': {'min': 0.1, 'max': 3, 'step': 0.1, 'value': 1, 'description': 'std_A0'},
            'std_B0': {'min': 0.1, 'max': 3, 'step': 0.1, 'value': 1, 'description': 'std_B0'},
            'std_A1': {'min': 0.1, 'max': 3, 'step': 0.1, 'value': 1, 'description': 'std_A1'},
            'std_B1': {'min': 0.1, 'max': 3, 'step': 0.1, 'value': 1, 'description': 'std_B1'},
            'epochs': {'min': 10, 'max': 2000, 'step': 10, 'value': 200, 'description': 'epochs'},
            'lr': {'min': 0.01, 'max': 1.0, 'step': 0.01, 'value': 0.03, 'description': 'lr'}
        }
        # Each inner list becomes one horizontal row of sliders.
        self.layout_groups = [
            ['N', 'n_a'],
            ['alpha_a', 'alpha_b'],
            ['mean_A0', 'mean_B0'],
            ['mean_A1', 'mean_B1'],
            ['std_A0', 'std_B0'],
            ['std_A1', 'std_B1'],
            ['epochs', 'lr']]
        # Latest generated data per client; None until the client has run once.
        self.global_data = [None] * n_client
        self.client_outputs = [Output() for _ in range(self.n_client)]
        self.client_layouts = []
        self.global_layout = Output()

        # interactive_output() runs its callback once when it is created, so
        # building the client UIs already calls update_client() for every
        # client. Hold back the global training until all of them are done and
        # then run it exactly once, instead of retraining on partial data after
        # every client and once more at the end.
        self._clients_ready = False
        self.init_clients()
        self._clients_ready = True
        self.update_global()

    ##### UI setting ###########################################################
    def create_client_ui(self, client_idx):
        """Build the widget layout (plot output plus slider panel) for one client.

        Args:
            client_idx: Index of the client the sliders control.

        Returns:
            A ``VBox`` holding the ``interactive_output`` widget on top of a row
            with the client's plot ``Output`` and its slider panel.
        """
        # One FloatSlider per configured parameter,
        # e.g. 'N': FloatSlider(min=100, max=20000, step=100, value=10000, description='N').
        sliders = {key: FloatSlider(**config) for key, config in self.sliders_config.items()}
        slider_rows = [HBox([sliders[key] for key in group]) for group in self.layout_groups]
        slider_title = Label(f'Client {client_idx} Parameters', layout=Layout(height='30px',
                                align_self='center', justify_content='center'))
        slider_ui = VBox([slider_title] + slider_rows)
        # Re-run the client update whenever any slider value changes.
        output_graph = interactive_output(lambda **kwargs: self.update_client(client_idx, kwargs), sliders)
        return VBox([output_graph, HBox([self.client_outputs[client_idx], slider_ui])])

    def init_clients(self):
        """Create the UI of every client and append it to ``self.client_layouts``."""
        self.client_layouts.extend(
            self.create_client_ui(client_idx) for client_idx in range(self.n_client))

    @staticmethod
    def _result_lines(name, result):
        """Return the four mathtext lines describing one trainer result."""
        return [
            f'${name}={result[3]:.2f}$',
            f'$Acc={result[0]:.4f}$',
            f'$EO={result[1]:.3f}$',
            f'$DP={result[2]:.3f}$',
        ]

    def _prepare_axes(self, ax, title, named_results):
        """Set the title, the metrics text box and the fixed axis limits of ``ax``."""
        ax.set_title(title)
        textstr = '\n'.join(
            line for name, result in named_results for line in self._result_lines(name, result))
        props = dict(boxstyle='round', facecolor='white', alpha=0.5)
        ax.text(0.05, 0.95, textstr, transform=ax.transAxes, fontsize=10, verticalalignment='top', bbox=props)
        ax.set_xlim(-10, 10)
        ax.set_ylim(0, 1)

    def _plot_groups(self, ax, data):
        """Draw one density histogram per group, then the legend and y-axis label."""
        for group, label, color in self._GROUP_STYLES:
            ax.hist(data[group]['x'], bins=100, alpha=0.5, density=True, label=label, color=color)
        ax.legend(loc='upper right')
        ax.set_ylabel('Density')

    @staticmethod
    def _show_and_close(fig):
        """Render the figure, then release it so repeated redraws do not pile up."""
        plt.show()
        plt.close(fig)

    def draw_client(self, sa_result, bf_result, client_data, client_idx):
        """Redraw one client's plot in the currently active output area.

        Args:
            sa_result: Result of ``standalone`` (red dashed line).
            bf_result: Result of ``bruteforce`` (green dashed line).
            client_data: Mapping from group key to a dict whose ``'x'`` entry
                holds the samples to histogram.
            client_idx: Client index shown in the title.
        """
        clear_output(wait=True)
        fig, ax = plt.subplots(figsize=(6, 2.5))
        self._prepare_axes(ax, f'Client {client_idx}',
                           [('standalone', sa_result), ('bruteforce', bf_result)])
        ax.axvline(x=sa_result[3], color='r', linestyle='--')
        ax.axvline(x=bf_result[3], color='g', linestyle='--')
        self._plot_groups(ax, client_data)
        self._show_and_close(fig)

    def draw_global(self, centralized_result, fedavg_result, bf_result, data):
        """Redraw the plot of the pooled data in the currently active output area.

        Args:
            centralized_result: Result of ``centralized`` (red dashed line).
            fedavg_result: Result of ``fedavg`` (blue dashed line).
            bf_result: Result of ``bruteforce`` (green dashed line).
            data: Pooled data in the format returned by ``create_combined_data``.
        """
        clear_output(wait=True)
        fig, ax = plt.subplots(figsize=(8, 4))
        self._prepare_axes(ax, 'Global Data', [
            ('centralized', centralized_result),
            ('fedavg', fedavg_result),
            ('bruteforce', bf_result),
        ])
        ax.axvline(x=fedavg_result[3], color='b', linestyle='--')
        ax.axvline(x=centralized_result[3], color='r', linestyle='--')
        ax.axvline(x=bf_result[3], color='g', linestyle='--')
        self._plot_groups(ax, data)
        self._show_and_close(fig)

    #### Helper Functions for update_global() ##################################
    def create_combined_data(self, global_data):
        """Pool the per-client data group by group.

        Args:
            global_data: Iterable of per-client data dicts; ``None`` entries
                (clients without data yet) are skipped.

        Returns:
            Dict mapping each group key to ``{'x': [...], 'y': [...], 's': [...]}``
            with the values of all clients concatenated in client order.
        """
        combined_data = {group: {'x': [], 'y': [], 's': []} for group in self._GROUPS}
        for data in global_data:
            if data is None:
                continue
            for group in self._GROUPS:
                for field in ('x', 'y', 's'):
                    combined_data[group][field].extend(data[group][field])
        return combined_data

    def convert_xys(self, data):
        """Flatten grouped data into three arrays.

        Args:
            data: Dict mapping each group key to a dict with ``'x'``, ``'y'``
                and ``'s'`` sequences.

        Returns:
            Tuple ``(x, y, s)`` of NumPy arrays, concatenated over the groups in
            the order A0, A1, B0, B1.
        """
        x, y, s = [], [], []
        for group in self._GROUPS:
            x.extend(data[group]['x'])
            y.extend(data[group]['y'])
            s.extend(data[group]['s'])
        return np.array(x), np.array(y), np.array(s)

    #### Model training & evaluation ###########################################
    def update_client(self, client_idx, params):
        """Regenerate one client's data, retrain its models and refresh the plots.

        Args:
            client_idx: Index of the client to update.
            params: Dict of slider values keyed by the names in
                ``self.sliders_config``.
        """
        # generate training data based on the given parameters.
        means = [params['mean_A1'], params['mean_A0'], params['mean_B1'], params['mean_B0']]
        stds = [params['std_A1'], params['std_A0'], params['std_B1'], params['std_B0']]
        # Every control is a FloatSlider, so the sample count arrives as a float;
        # pass it on as an integer, as is done for the epoch count below.
        n_samples = int(round(params['N']))
        data = Data_Generator(n_samples, params['n_a'], params['alpha_a'], params['alpha_b'], means, stds)
        lr = params['lr']
        # adjust data format
        client_data = data.get_client()
        x, y, s = data.get_xys()
        # train and evaluate standalone and bruteforce models.
        sa_result = standalone(x, y, s, lr, epochs=int(params['epochs']))
        bf_result = bruteforce(x, y, s, self.bf_range, self.bf_step, warm_start=sa_result[3])
        # update global data.
        self.global_data[client_idx] = client_data
        # update the specified client's plot.
        with self.client_outputs[client_idx]:
            self.draw_client(sa_result, bf_result, client_data, client_idx)
        # update the global models to reflect the changed data; while the
        # constructor is still creating clients this is deferred to __init__.
        if self._clients_ready:
            self.update_global()

    def update_global(self):
        """Retrain the global models on the pooled client data and redraw the plot.

        Does nothing while no client has produced data yet.
        """
        # Only clients that have generated data take part; None placeholders are
        # never handed to the trainers.
        available = [data for data in self.global_data if data is not None]
        if not available:
            return
        # adjust data format
        combined_data = self.create_combined_data(available)
        x, y, s = self.convert_xys(combined_data)
        # train and eval centralized, fedavg, and bruteforce models.
        centralized_result = centralized(combined_data, self.centralized_lr, self.centralized_epochs)
        fedavg_result = fedavg(combined_data, available, self.fedavg_lr, epochs=self.fedavg_epochs)
        bf_result = bruteforce(x, y, s, self.bf_range, self.bf_step, warm_start=centralized_result[3])
        # update the result.
        with self.global_layout:
            self.draw_global(centralized_result, fedavg_result, bf_result, combined_data)

    #### Display function ######################################################
    def display_result(self):
        """Display all client layouts followed by the global plot, stacked vertically."""
        layouts = self.client_layouts[:self.n_client] + [self.global_layout]
        display(VBox(layouts))

In [2]:
# Build a two-client analyzer and render its widgets.
# Constructor arguments: n_client, fedavg_epochs, centralized_epochs,
# bf_range, bf_step (the last two are forwarded to the brute-force search).
analyzer = Analyzer(
    n_client=2,
    fedavg_epochs=200,
    centralized_epochs=200,
    bf_range=0.2,
    bf_step=0.01,
)
analyzer.display_result()

VBox(children=(VBox(children=(Output(), HBox(children=(Output(), VBox(children=(Label(value='Client 0 Paramete…