In [1]:
import os
import sys
import socket
import torch.multiprocessing

torch.multiprocessing.set_sharing_strategy('file_system')
import warnings

warnings.filterwarnings("ignore")

conf_path = os.getcwd()
sys.path.append(conf_path)
sys.path.append(conf_path + '/datasets')
sys.path.append(conf_path + '/backbone')
sys.path.append(conf_path + '/models')
from datasets import Priv_NAMES as DATASET_NAMES
from models import get_all_models
from argparse import ArgumentParser
from utils.args import add_management_args
from datasets import get_prive_dataset
from models import get_model
from utils.training import train
from utils.best_args import best_args
from utils.conf import set_random_seed
import setproctitle

import torch
import uuid
import datetime

In [2]:
def parse_args():
    parser = ArgumentParser(description='You Only Need Me', allow_abbrev=False)
    parser.add_argument('--device_id', type=int, default=0, help='The Device Id for Experiment')

    parser.add_argument('--communication_epoch', type=int, default=2, help='The Communication Epoch in Federated Learning')
    parser.add_argument('--local_epoch', type=int, default=3, help='The Local Epoch for each Participant')
    parser.add_argument('--parti_num', type=int, default=3, help='The Number for Participants')

    parser.add_argument('--seed', type=int, default=0, help='The random seed.')
    parser.add_argument('--rand_dataset', type=dict, default={'mnist': 1, 'usps': 1, 'svhn': 1, 'syn': 2}, help='The random seed.')

    parser.add_argument('--model', type=str, default='fpl',  # moon fedinfonce
                        help='Model name.', choices=get_all_models())
    parser.add_argument('--structure', type=str, default='homogeneity')
    parser.add_argument('--dataset', type=str, default='fl_digits',  # fl_officecaltech fl_digits
                        choices=DATASET_NAMES, help='Which scenario to perform experiments on.')

    parser.add_argument('--pri_aug', type=str, default='weak',  # weak strong
                        help='Augmentation for Private Data')
    parser.add_argument('--online_ratio', type=float, default=1, help='The Ratio for Online Clients')
    parser.add_argument('--learning_decay', type=bool, default=False, help='The Option for Learning Rate Decay')
    parser.add_argument('--averaing', type=str, default='weight', help='The Option for averaging strategy')

    parser.add_argument('--infoNCET', type=float, default=0.02, help='The InfoNCE temperature')
    parser.add_argument('--T', type=float, default=0.05, help='The Knowledge distillation temperature')
    parser.add_argument('--weight', type=int, default=1, help='The Wegith for the distillation loss')

    parser.add_argument('--reserv_ratio', type=float, default=0.1, help='Reserve ratio for prototypes')

    torch.set_num_threads(4)
    def add_management_args(parser: ArgumentParser) -> None:
        parser.add_argument('--csv_log', action='store_true',
                            help='Enable csv logging',default=False)
    
    add_management_args(parser)
    
    args, unknown = parser.parse_known_args()

    best = best_args[args.dataset][args.model]

    for key, value in best.items():
        setattr(args, key, value)

    if args.seed is not None:
        set_random_seed(args.seed)
    return args

In [3]:
# Parse arguments (tuned values from best_args override CLI/defaults) and build the private dataset.
args = parse_args()

priv_dataset = get_prive_dataset(args)

In [4]:
# NOTE(review): still empty here (see output); loaders appear to be built later by get_data_loaders inside train()
priv_dataset.train_loaders

[]

In [5]:
# Inspect the dataset object's attributes, including the resolved args namespace.
print(vars(priv_dataset))

{'train_loaders': [], 'test_loader': [], 'args': Namespace(device_id=0, communication_epoch=2, local_epoch=3, parti_num=3, seed=0, rand_dataset={'mnist': 1, 'usps': 1, 'svhn': 1, 'syn': 2}, model='fpl', structure='homogeneity', dataset='fl_digits', pri_aug='weak', online_ratio=1, learning_decay=False, averaing='weight', infoNCET=0.02, T=0.05, weight=1, reserv_ratio=0.1, csv_log=False, local_lr=0.01, local_batch_size=64, Note='+ MSE')}


In [6]:
# Backbones for args.parti_num participants; NOTE(review): meaning of the second argument (None) — TODO confirm.
backbones_list = priv_dataset.get_backbone(args.parti_num, None)

In [7]:
# Instantiate the federated method selected by args.model with the participants' backbones and the dataset transform.
model = get_model(backbones_list, args, priv_dataset.get_transform())

In [8]:
# Debug helper (disabled): inspect the model's attributes.
# print(vars(model))

In [9]:
# Debug helper (disabled): display the first participant's network.
# model.nets_list[0]

In [10]:
import torch
from argparse import Namespace
from models.utils.federated_model import FederatedModel
from datasets.utils.federated_dataset import FederatedDataset
from typing import Tuple
from torch.utils.data import DataLoader
import numpy as np
from utils.logger import CsvWriter
from collections import Counter


def global_evaluate(model: "FederatedModel", test_dl: DataLoader, setting: str, name: str) -> list:
    """Evaluate the federation's global model on every domain's test loader.

    Args:
        model: Object exposing ``global_net`` (a torch module) and ``device``.
        test_dl: Iterable of test loaders, one per domain; each yields ``(images, labels)``.
        setting: Unused; kept for call compatibility.
        name: Unused; kept for call compatibility (an earlier variant chose top-1 vs
            top-5 accuracy by dataset name).

    Returns:
        One top-1 accuracy per loader, in loader order, as a percentage rounded to
        2 decimals. A loader that yields no samples gets ``float('nan')`` instead of
        raising ZeroDivisionError.

    The network's train/eval mode is restored afterwards, even if evaluation raises.
    """
    accs = []
    net = model.global_net
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for dl in test_dl:
                correct, total = 0.0, 0.0
                for images, labels in dl:
                    images, labels = images.to(model.device), labels.to(model.device)
                    outputs = net(images)
                    # Top-1 prediction per sample; k=1 also works when there are < 5 classes.
                    top1_pred = torch.topk(outputs, 1, dim=-1).indices
                    labels = labels.view(-1, 1)
                    correct += (labels == top1_pred).sum().item()
                    total += labels.size(0)
                accs.append(round(100 * correct / total, 2) if total else float('nan'))
    finally:
        net.train(was_training)
    return accs


def train(model: FederatedModel, private_dataset: FederatedDataset, args: Namespace) -> None:
    if args.csv_log:
        csv_writer = CsvWriter(args, private_dataset)

    model.N_CLASS = private_dataset.N_CLASS
    domains_list = private_dataset.DOMAINS_LIST
    domains_len = len(domains_list)

    if args.rand_dataset:
        max_num = 10
        is_ok = False

        while not is_ok:
            if model.args.dataset == 'fl_officecaltech':
                selected_domain_list = np.random.choice(domains_list, size=args.parti_num - domains_len, replace=True, p=None)
                selected_domain_list = list(selected_domain_list) + domains_list
            elif model.args.dataset == 'fl_digits':
                selected_domain_list = np.random.choice(domains_list, size=args.parti_num, replace=True, p=None)

            result = dict(Counter(selected_domain_list))

            for k in result:
                if result[k] > max_num:
                    is_ok = False
                    break
            else:
                is_ok = True
    else:
        selected_domain_dict = {'mnist': 6, 'usps': 4, 'svhn': 3, 'syn': 7}  # base
        selected_domain_dict = {'mnist': 3, 'usps': 7, 'svhn': 6, 'syn': 4}  # article

        # selected_domain_dict = {'mnist': 1, 'usps': 1, 'svhn': 9, 'syn': 9}  # 20

        # selected_domain_dict = {'mnist': 3, 'usps': 2, 'svhn': 1, 'syn': 4}  # 10

        selected_domain_list = []
        for k in selected_domain_dict:
            domain_num = selected_domain_dict[k]
            for i in range(domain_num):
                selected_domain_list.append(k)

        selected_domain_list = np.random.permutation(selected_domain_list)

        result = Counter(selected_domain_list)
    print(result)

    print(selected_domain_list)
    pri_train_loaders, test_loaders, label_dict, train_dataset_list, test_dataset_list = private_dataset.get_data_loaders(selected_domain_list)
    model.trainloaders = pri_train_loaders
    if hasattr(model, 'ini'):
        model.ini()

    accs_dict = {}
    mean_accs_list = []

    Epoch = args.communication_epoch
    for epoch_index in range(Epoch):
        model.epoch_index = epoch_index
        if hasattr(model, 'loc_update'):
            epoch_loc_loss_dict = model.loc_update(pri_train_loaders)

        accs = global_evaluate(model, test_loaders, private_dataset.SETTING, private_dataset.NAME)
        mean_acc = round(np.mean(accs, axis=0), 3)
        mean_accs_list.append(mean_acc)
        for i in range(len(accs)):
            if i in accs_dict:
                accs_dict[i].append(accs[i])
            else:
                accs_dict[i] = [accs[i]]

        print('The ' + str(epoch_index) + ' Communcation Accuracy:', str(mean_acc), 'Method:', model.args.model)
        print(accs)

    if args.csv_log:
        csv_writer.write_acc(accs_dict, mean_accs_list)
        
    return label_dict


In [11]:
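# Results noted from earlier runs (round index 49, i.e. 50 communication rounds),
# apparently comparing local batch sizes; kept as comments so the cell is valid Python.
# The 49 Communcation Accuracy: 74.34 Method: fpl
# [97.63, 88.94, 68.75, 42.04]
#
# batch 256
#
# The 49 Communcation Accuracy: 80.12 Method: fpl
# [98.08, 90.33, 79.53, 52.54]
# batch 64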

SyntaxError: invalid syntax (1885791128.py, line 1)

In [12]:
# Runs the notebook's own train() (shadowing utils.training.train); keeps the label_dict it returns.
label_dict = train(model, priv_dataset, args)

{'mnist': 1, 'syn': 1, 'usps': 1}
['mnist' 'syn' 'usps']


Local Pariticipant 1 CE = 2.064,InfoNCE = 0.000: 100%|██████████| 3/3 [00:05<00:00,  1.88s/it]
Local Pariticipant 0 CE = 1.500,InfoNCE = 0.000: 100%|██████████| 3/3 [00:40<00:00, 13.39s/it]
Local Pariticipant 2 CE = 2.202,InfoNCE = 0.000: 100%|██████████| 3/3 [00:05<00:00,  1.71s/it]


Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
The 0 Communcation Accuracy: 13.385 Method: fpl
[11.35, 13.15, 19.04, 10.0]


Local Pariticipant 0 CE = 0.738,InfoNCE = 1.237: 100%|██████████| 3/3 [00:45<00:00, 15.25s/it]
Local Pariticipant 2 CE = 1.789,InfoNCE = 2.047: 100%|██████████| 3/3 [00:05<00:00,  1.97s/it]
Local Pariticipant 1 CE = 2.244,InfoNCE = 2.405: 100%|██████████| 3/3 [00:08<00:00,  2.68s/it]


Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
The 1 Communcation Accuracy: 19.06 Method: fpl
[35.89, 13.25, 16.75, 10.35]


In [47]:
# Domains whose label sequences were returned by train().
label_dict.keys()

dict_keys(['mnist', 'syn', 'usps'])

In [48]:
# Length of each domain's label sequence, in label_dict order.
for d in label_dict.values():
    print(len(d))

60000
10000
7291


In [74]:
import torch

def get_sorted_label_indices(label_dict):
    """
    Group sample positions by label, separately for every dataset.

    Args:
        label_dict (dict): Maps a dataset name to its labels in sample order. Values may be
            a list, a torch tensor, a numpy array, a pandas Series, or anything else that
            exposes ``tolist()``.

    Returns:
        dict: ``{dataset_name: {label: set_of_positions}}``. Each inner dict is ordered by
        ascending label; positions are 0-based offsets into that dataset's label sequence.
        Sets make membership checks O(1).
    """
    index_map = {}

    for dataset_name, y_values in label_dict.items():
        # Tensors/arrays -> plain Python scalars. Iterating a tensor would yield 0-d
        # tensors that hash by identity, and numpy would yield numpy scalar keys.
        if hasattr(y_values, 'tolist'):
            y_values = y_values.tolist()

        positions_by_label = {}
        for index, label in enumerate(y_values):
            positions_by_label.setdefault(label, set()).add(index)

        index_map[dataset_name] = dict(sorted(positions_by_label.items()))

    return index_map

In [75]:
# `avaliable_indexes` (sic) is the name later cells use: domain -> {label: set of positions}.
avaliable_indexes = get_sorted_label_indices(label_dict)

# Print, for every domain, how many samples carry each label, then the domain's total.
for dataset, mapping in avaliable_indexes.items():
    print(f"\nDataset: {dataset}")
    aux = 0  # running sample total for this domain
    for label, indices in mapping.items():
        print(f"  Label {label}: {len(indices)} occurrences")
        aux+=len(indices)
    print(aux)


Dataset: mnist
  Label 0: 5923 occurrences
  Label 1: 6742 occurrences
  Label 2: 5958 occurrences
  Label 3: 6131 occurrences
  Label 4: 5842 occurrences
  Label 5: 5421 occurrences
  Label 6: 5918 occurrences
  Label 7: 6265 occurrences
  Label 8: 5851 occurrences
  Label 9: 5949 occurrences
60000

Dataset: syn
  Label 0: 1000 occurrences
  Label 1: 1000 occurrences
  Label 2: 1000 occurrences
  Label 3: 1000 occurrences
  Label 4: 1000 occurrences
  Label 5: 1000 occurrences
  Label 6: 1000 occurrences
  Label 7: 1000 occurrences
  Label 8: 1000 occurrences
  Label 9: 1000 occurrences
10000

Dataset: usps
  Label 0: 1194 occurrences
  Label 1: 1005 occurrences
  Label 2: 731 occurrences
  Label 3: 658 occurrences
  Label 4: 652 occurrences
  Label 5: 556 occurrences
  Label 6: 664 occurrences
  Label 7: 645 occurrences
  Label 8: 542 occurrences
  Label 9: 644 occurrences
7291


In [62]:
from Tbx_Sample import *

In [63]:
class Participant:
    """A federated client whose per-label sample counts evolve over time.

    Each label gets its own ``StochasticProcess`` wrapped in a ``Sample``; both names come
    from ``Tbx_Sample``. The ``Sample`` draws concrete dataset indices.
    """

    def __init__(self, labels, stochastic_params, mnist_population, available_indexes):
        """
        Args:
            labels: Labels this participant holds (e.g. ``[0, 3, 7]``).
            stochastic_params: One parameter dict per label, paired with ``labels`` by position
                (``zip`` semantics: surplus entries on either side are ignored).
            mnist_population: Population handed to every ``Sample``.
            available_indexes: Index pool handed to every ``Sample``; presumably
                ``{label: set_of_indices}`` — TODO confirm against ``Sample``.
        """
        self.labels = labels
        self.stochastic_params = stochastic_params
        self.mnist_population = mnist_population
        self.available_indexes = available_indexes
        self.Samples = self._generate_Samples()

    def _generate_Samples(self):
        """Build one ``Sample`` per (label, params) pair.

        Best effort: a label whose process or sample cannot be built is reported with
        ``print`` and left out of the returned dict.

        Returns:
            dict label -> Sample.
        """
        samples = {}
        for label, params in zip(self.labels, self.stochastic_params):
            try:
                stochastic_process = StochasticProcess(
                    initial_population=params['initial_population'],
                    steps=params['steps'],
                    process_type=params['process_type'],
                    combine_with=params['combine_with'],
                    rate=params['rate'],
                    angular_coef=params['angular_coef'],
                    period=params['period'],
                    frequency=params['frequency']
                )
                samples[label] = Sample(
                    label=label,
                    stochastic_process=stochastic_process,
                    mnist_population=self.mnist_population,
                    available_indexes=self.available_indexes
                )
            except KeyError as e:
                print(f"Missing parameter {e} for label {label}")
            except Exception as e:
                print(f"An error occurred for label {label}: {e}")
        return samples

    def get_sample_label(self, label, time_step):
        """Return ``Sample.get_samples_at(time_step)`` for ``label``.

        Returns None (after printing a message) when the label has no Sample or the lookup raises.
        """
        try:
            sample = self.Samples[label]
            return sample.get_samples_at(time_step)
        except KeyError:
            print(f"Sample for label {label} not found.")
        except Exception as e:
            print(f"An error occurred: {e}")

    def report_data(self):
        """Tabulate every label's process trajectory.

        Returns:
            pandas.DataFrame with one row per label that has a Sample, in ``self.labels`` order.
            Column ``id`` holds the label and ``value_1 ... value_k`` hold the values returned
            by ``Sample.get_process_data()``.
        """
        # pd is not imported in any visible cell; it may only arrive via `from Tbx_Sample import *`.
        import pandas as pd

        # Skip labels whose Sample failed to build in _generate_Samples instead of raising KeyError.
        aux_data = {label: self.Samples[label].get_process_data()
                    for label in self.labels if label in self.Samples}

        df = pd.DataFrame.from_dict(aux_data, orient='index').reset_index()
        df.columns = ['id'] + [f'value_{i + 1}' for i in range(df.shape[1] - 1)]
        return df

In [117]:
import itertools
import random

def generate_limited_values(start, stop, step, max_values=None):
    """Return the values of ``frange(start, stop, step)``, randomly thinned to ``max_values``.

    When thinning happens (``random.sample``) the result is in random order.
    """
    values = list(frange(start, stop, step))
    if max_values and len(values) > max_values:
        values = random.sample(values, max_values)
    return values


def frange(start, stop, step):
    """Float analogue of ``range``: yield ``start, start + step, ...`` while below ``stop``.

    Yielded values are rounded to 10 decimals to hide floating-point noise. The running value
    is not rounded, so accumulated error may occasionally admit a value that rounds to ``stop``.

    Raises:
        ValueError: On the first ``next()`` if ``step <= 0`` over a non-empty interval
            (which would otherwise loop forever).
    """
    if step <= 0 and start < stop:
        raise ValueError(f"frange step must be positive, got {step!r}")
    while start < stop:
        yield round(start, 10)
        start += step


def generate_stochastic_combinations(init_pop=200, init_steps=20, init_freq=None, init_period=None,
                                     init_rate=None, init_angular=None, max_combinations=None):
    """Sample parameter dicts for ``StochasticProcess``.

    Each spec defines a candidate grid, randomly thinned to at most ``max_combinations`` values:

    * ``init_freq``: ``None`` -> ``[0.1, 0.5) * init_pop`` step ``0.05 * init_pop``;
      a number ``f`` -> ``[f, 5f) * init_pop`` step ``0.1 * f * init_pop``;
      a list/tuple ``[lo, hi]`` -> ``[lo, hi) * init_pop`` in about 10 steps.
    * ``init_period``: ``None`` -> ``[0, 0.5)`` step 0.07; a number ``p`` -> ``[0, p)`` step ``0.07 p``.
    * ``init_rate``: ``None`` -> ``[-0.1, 0.1)`` step 0.025; a number ``r`` -> ``[-r, r)`` step ``0.05 r``.
    * ``init_angular``: ``None`` -> ``[-4, 4)`` step 0.5; a number ``a`` -> ``[-a, a)`` step ``0.05 a``.

    The product rate x angular x ``combine_with`` option is formed. Options are one or two of
    'exponential', 'sine' and 'exp_decay', never both exponential kinds. The product is thinned
    to ``max_combinations``, and every entry gets a random period and frequency from the grids.

    Returns:
        list[dict] with keys ``initial_population, steps, process_type, combine_with, rate,
        angular_coef, period, frequency``.
    """
    if init_freq is None:
        freq_values = generate_limited_values(0.1 * init_pop, 0.5 * init_pop, 0.05 * init_pop, max_combinations)
    elif isinstance(init_freq, (list, tuple)):
        # Explicit [low, high] bounds as fractions of init_pop. The original tested init_period
        # here by mistake, so a list frequency crashed.
        low, high = init_freq[0], init_freq[1]
        freq_values = generate_limited_values(low * init_pop, high * init_pop,
                                              (high - low) * 0.1 * init_pop, max_combinations)
    else:
        freq_values = generate_limited_values(init_freq * init_pop, 5 * init_freq * init_pop,
                                              0.1 * init_freq * init_pop, max_combinations)

    if init_period is None:
        period_values = generate_limited_values(0, 0.5, 0.07, max_combinations)
    else:
        period_values = generate_limited_values(0, init_period, 0.07 * init_period, max_combinations)

    if init_rate is None:
        rate_values = generate_limited_values(-0.1, 0.1, 0.025, max_combinations)
    else:
        rate_values = generate_limited_values(-init_rate, init_rate, 0.05 * init_rate, max_combinations)

    if init_angular is None:
        angular_values = generate_limited_values(-4, 4, 0.5, max_combinations)
    else:
        angular_values = generate_limited_values(-init_angular, init_angular, 0.05 * init_angular, max_combinations)

    combine_with_values = ['exponential', 'sine', 'exp_decay']
    process_type_values = ['combined']

    # Combinations of size 1 and 2, excluding the exponential + exp_decay pairing.
    combine_with_combinations = [
        comb for r in range(1, 3)
        for comb in itertools.combinations(combine_with_values, r)
        if not ('exponential' in comb and 'exp_decay' in comb)
    ]

    param_combinations = list(itertools.product(rate_values, angular_values,
                                                combine_with_combinations, process_type_values))

    if max_combinations and len(param_combinations) > max_combinations:
        param_combinations = random.sample(param_combinations, max_combinations)

    parameter_dicts = []
    for rate, angular, combine, process in param_combinations:
        parameter_dicts.append({
            'initial_population': init_pop,
            'steps': init_steps,
            'process_type': process,
            'combine_with': list(combine),
            'rate': rate,
            'angular_coef': angular,
            'period': random.choice(period_values),
            'frequency': random.choice(freq_values),
        })

    return parameter_dicts

In [118]:
import random

def select_random_labels(stochastic_params, label_range=(0, 10)):
    """Draw one distinct label per entry of ``stochastic_params``.

    Labels come from ``range(label_range[0], label_range[1])`` without replacement, using the
    global ``random`` module. ``random.sample`` raises ValueError when more labels are
    requested than the range contains.
    """
    how_many = len(stochastic_params)
    candidates = range(label_range[0], label_range[1])
    return random.sample(candidates, how_many)

In [119]:
def add_participant_to_federation(
    federation, available_indexes,
    init_pop, init_steps, init_freq, init_period, init_rate, init_angular, max_comb
):
    """
    Create one randomly configured Participant and append it to ``federation``.

    Args:
        federation (list): List the new Participant is appended to (mutated in place).
        available_indexes (dict): Index pool of ONE domain, e.g. ``avaliable_indexes['mnist']``
            (label -> set of sample positions). It is passed to Participant as both
            ``mnist_population`` and ``available_indexes``.
        init_pop (int): Initial population of every stochastic process.
        init_steps (int): Number of time steps of every process.
        init_freq (None | float | [low, high]): Frequency grid spec (see generate_stochastic_combinations).
        init_period (None | float): Period grid spec.
        init_rate (None | float): Rate grid spec.
        init_angular (None | float): Angular-coefficient grid spec.
        max_comb (int): Upper bound on the number of parameter combinations, i.e. on the
            number of labels the participant gets. Labels are drawn without replacement
            from 0-9, so more than 10 combinations makes ``random.sample`` raise.

    Returns:
        Participant: The participant that was appended. Earlier versions returned None;
        callers that ignore the result are unaffected.
    """
    # One parameter dict per label the participant will hold.
    param_combinations = generate_stochastic_combinations(
        init_pop=init_pop, init_steps=init_steps,
        init_freq=init_freq, init_period=init_period,
        init_rate=init_rate, init_angular=init_angular,
        max_combinations=max_comb
    )

    # One distinct label per parameter dict.
    random_labels = select_random_labels(param_combinations)

    # The same per-label index pool serves as both population and available indexes.
    participant = Participant(
        labels=random_labels,
        stochastic_params=param_combinations,
        mnist_population=available_indexes,
        available_indexes=available_indexes
    )

    federation.append(participant)

    print(f"Added new participant with {len(random_labels)} labels to the federation.")
    return participant

In [120]:
# Domains with per-label index pools available for building participants.
avaliable_indexes.keys()

dict_keys(['mnist', 'syn', 'usps'])

In [152]:
# Define federation list
federation = []

# Client A
init_pop = 50
init_steps = 20
init_period = int(0.2 * init_pop)
init_freq = 1
max_comb = 6
init_rate = None
init_angular = None
domain = 'mnist'

add_participant_to_federation(
    federation=federation,
    available_indexes=avaliable_indexes[domain],
    init_pop=init_pop,
    init_steps=init_steps,
    init_freq=init_freq,
    init_period=init_period,
    init_rate = init_rate,
    init_angular = init_angular,
    max_comb=max_comb
)

Added new participant with 6 labels to the federation.


In [153]:
# Client B: period grid [0, int(0.2 * init_pop)), default frequency grid, at most 5 labels.
init_pop, init_steps = 50, 20
init_period = int(0.2 * init_pop)
init_freq = None
max_comb = 5
init_rate = init_angular = None
domain = 'mnist'

add_participant_to_federation(
    federation=federation, available_indexes=avaliable_indexes[domain],
    init_pop=init_pop, init_steps=init_steps,
    init_freq=init_freq, init_period=init_period,
    init_rate=init_rate, init_angular=init_angular,
    max_comb=max_comb,
);

Added new participant with 5 labels to the federation.


In [154]:
# Client C: default period grid, frequency grid from f = 1, at most 8 labels.
init_pop, init_steps = 50, 20
init_period = None
init_freq = 1
max_comb = 8
init_rate = init_angular = None
domain = 'mnist'

add_participant_to_federation(
    federation=federation, available_indexes=avaliable_indexes[domain],
    init_pop=init_pop, init_steps=init_steps,
    init_freq=init_freq, init_period=init_period,
    init_rate=init_rate, init_angular=init_angular,
    max_comb=max_comb,
);

Added new participant with 8 labels to the federation.


In [155]:
# Client D: every grid at its default, at most 8 labels.
init_pop, init_steps = 50, 20
init_period = None
init_freq = None
max_comb = 8
init_rate = init_angular = None
domain = 'mnist'

add_participant_to_federation(
    federation=federation, available_indexes=avaliable_indexes[domain],
    init_pop=init_pop, init_steps=init_steps,
    init_freq=init_freq, init_period=init_period,
    init_rate=init_rate, init_angular=init_angular,
    max_comb=max_comb,
);

Added new participant with 8 labels to the federation.


In [166]:
# Client E: large population with near-zero period/frequency/rate/angular grids
# (an almost constant process), at most 8 labels.
init_pop, init_steps = 600, 20
init_period = 1e-7
init_freq = 1e-7
max_comb = 8
init_rate = 1e-8
init_angular = 1e-8
domain = 'mnist'

add_participant_to_federation(
    federation=federation, available_indexes=avaliable_indexes[domain],
    init_pop=init_pop, init_steps=init_steps,
    init_freq=init_freq, init_period=init_period,
    init_rate=init_rate, init_angular=init_angular,
    max_comb=max_comb,
);

Added new participant with 8 labels to the federation.


In [167]:
import pandas as pd  # not imported in any visible cell; otherwise it may only come from `from Tbx_Sample import *`

# One wide frame per participant: Digit, value_1..value_k, Participant.
aux_plot = [
    part.report_data()
        .rename(columns={'id': 'Digit'})
        .assign(Participant=f'Part_{i}')
    for i, part in enumerate(federation)
]

df_wide = pd.concat(aux_plot)

# Long format, one row per (Digit, Participant, Time). 'value_1' -> 0, 'value_2' -> 1, ...
# Raw string: '\d' in a plain literal is an invalid escape (SyntaxWarning on Python 3.12+);
# expand=False returns a Series rather than a one-column DataFrame.
df_plot = (
    df_wide
    .melt(id_vars=['Digit', 'Participant'], var_name='Time', value_name='Count')
    .assign(Time=lambda d: d['Time'].str.extract(r'(\d+)', expand=False).astype(int) - 1)
)

df_plot.shape

(1239, 4)

In [168]:
import altair as alt

def plot_mnist_distribution(plot_data, plot_width=800, plot_height=400, line_width=300):
    """
    Build an interactive three-panel Altair chart of per-participant digit counts.

    Panels, left to right:
    - bar chart: total ``Count`` per participant at the selected time step;
    - heatmap: ``Count`` per (Digit, Participant) at the selected time step, overlaid with
      black points sized by ``Count`` on the currently selected cells;
    - line chart with markers: ``Count`` over ``Time`` per digit, filtered by the selections.

    A slider bound to ``Time`` filters the bar chart and the heatmap. Clicking selects on the
    clicked mark's x field (``bar_point``); the heatmap additionally selects on y
    (``client_selection``).

    Parameters:
    - plot_data: long-format DataFrame with columns ``Digit``, ``Participant``, ``Time`` (int)
      and ``Count``
    - plot_width: width split between the bar chart (40%) and the heatmap (60%) (default: 800)
    - plot_height: height of every panel (default: 400)
    - line_width: width of the line chart (default: 300, the previously hard-coded value)

    Returns:
    - alt.HConcatChart with independent color and size legends
    """
    # Plain Python ints for the slider; numpy scalars from pandas may not serialize
    # cleanly into the Vega-Lite JSON spec.
    time_min = int(plot_data['Time'].min())
    time_max = int(plot_data['Time'].max())

    # Slider selecting a single 'Time' step, starting at the earliest one
    time_slider = alt.binding_range(
        min=time_min,
        max=time_max,
        step=1,
        name="Select Time: "
    )
    time_param = alt.selection_point(bind=time_slider, fields=['Time'], value=time_min)

    # Click selections on the x encoding and on the y encoding
    bar_point = alt.selection_point(encodings=['x'])
    client_selection = alt.selection_point(encodings=['y'])

    # Split the total width between the bar chart and the heatmap
    bar_width = plot_width * 0.4
    heatmap_width = plot_width * 0.6
    bar_height = plot_height
    heatmap_height = plot_height

    # Heatmap (Digit vs. Participant) filtered by the selected Time
    heatmap = alt.Chart(plot_data).mark_rect().encode(
        x=alt.X('Digit:O', title="MNIST Digit (0-9)"),
        y=alt.Y('Participant:N', title="Participants"),
        color=alt.condition(
            bar_point,
            alt.Color('Count:Q', scale=alt.Scale(scheme='greenblue'), title="Total Count", legend=alt.Legend(orient="top")),
            alt.ColorValue("grey")  # non-selected cells are grey
        ),
        tooltip=[alt.Tooltip('Digit:O', title="Digit"),
                 alt.Tooltip('Participant:N', title="Participant"),
                 alt.Tooltip('Count:Q', title="Total Count")]
    ).transform_filter(
        time_param
    ).add_params(bar_point, client_selection, time_param).properties(width=heatmap_width, height=heatmap_height)

    # Points overlaid on the selected heatmap cells
    highlight_points = heatmap.mark_point().encode(
        color=alt.ColorValue('black'),
        size=alt.Size('Count:Q', title="Selected Images")
    ).transform_filter(
        bar_point
    ).transform_filter(
        client_selection
    )

    # Bar chart (total images per participant) filtered by the selected Time
    bar_chart = alt.Chart(plot_data).mark_bar().encode(
        x="Participant:N",
        y="sum(Count):Q",
        color=alt.condition(bar_point, alt.ColorValue("steelblue"), alt.ColorValue("grey"))
    ).transform_filter(
        time_param
    ).add_params(bar_point).properties(width=bar_width, height=bar_height)

    # Line chart of each digit's count over time
    line_chart = alt.Chart(plot_data).mark_line().encode(
        x='Time:O',
        y='Count:Q',
        color=alt.Color('Digit:O', title="Stoc.Process"),
        tooltip=['Digit', 'Time', 'Count']
    ).properties(
        title="Process over Time",
        width=line_width,
        height=plot_height
    ).transform_filter(
        bar_point
    ).transform_filter(
        client_selection
    )

    # Circle markers at the line chart's data points
    marker_chart = alt.Chart(plot_data).mark_circle(size=50).encode(
        x='Time:O',
        y='Count:Q',
        color=alt.Color('Digit:O', title="Stoc.Process"),
        tooltip=['Digit', 'Time', 'Count']
    ).transform_filter(
        bar_point
    ).transform_filter(
        client_selection
    )

    chart = line_chart + marker_chart

    final_chart = alt.hconcat(
        bar_chart,
        heatmap + highlight_points,
        chart
    ).resolve_legend(
        color="independent",
        size="independent"
    )

    return final_chart

# Build the interactive distribution chart at a compact size; the bare expression renders it.
final_chart = plot_mnist_distribution(plot_data=df_plot, plot_width=350, plot_height=250)
final_chart

In [170]:
from torch.utils.data import DataLoader, SubsetRandomSampler


# Indices drawn for label 0 at time step 0 by the second-to-last participant.
sample_indices = federation[-2].get_sample_label(0, 0)
# get_sample_label prints a message and returns None when the label is missing or sampling fails.
if sample_indices is None:
    raise ValueError("federation[-2] returned no samples for label 0 at time step 0")
# SubsetRandomSampler indexes into `indices` (indices[i]) and a set is not indexable,
# so pass a sorted list; sorting also fixes the order independently of set iteration.
train_sampler = SubsetRandomSampler(sorted(sample_indices))

# NOTE(review): `train_dataset` is not defined in any visible cell (train() keeps its dataset
# lists local), so this line raises NameError on a fresh kernel until it is provided.
DataLoader(train_dataset, batch_size=64, sampler=train_sampler)

{1152,
 1625,
 2539,
 5978,
 6207,
 6619,
 7506,
 9196,
 11091,
 11605,
 12457,
 14258,
 14472,
 14649,
 14722,
 16043,
 17738,
 18858,
 19166,
 19550,
 19822,
 21273,
 21658,
 22857,
 23589,
 24952,
 26183,
 26298,
 28131,
 29250,
 30088,
 34623,
 39710,
 40390,
 40418,
 40633,
 42217,
 42764,
 44121,
 47000,
 47166,
 47599,
 51348,
 53344,
 54368,
 54783,
 57202,
 57843,
 58819,
 59337}