In [1]:
from Tbx_Sample import *

In [2]:
import os
import sys
import socket
import torch.multiprocessing

import copy

import warnings

# Share tensors between worker processes with torch's 'file_system' strategy.
torch.multiprocessing.set_sharing_strategy('file_system')

# Silence every warning for the rest of the session (this also hides
# deprecation notices coming from torch / torchvision).
warnings.filterwarnings("ignore")

# Make the project's top-level packages importable relative to the notebook's cwd.
conf_path = os.getcwd()
sys.path.extend([conf_path + suffix for suffix in ('', '/datasets', '/backbone', '/models')])
from datasets import Priv_NAMES as DATASET_NAMES
from models import get_all_models
from argparse import ArgumentParser
from utils.args import add_management_args
from datasets import get_prive_dataset
from models import get_model
from utils.training import train
from utils.best_args import best_args
from utils.conf import set_random_seed
import setproctitle

import torch
import uuid
import datetime

In [3]:
def _str2bool(value):
    """Convert a command-line token such as 'true', 'False', '1' or 'no' to a bool.

    argparse applies ``type`` to the raw string, and ``bool('False')`` is True,
    so ``type=bool`` would turn every explicitly passed value into True.

    Raises:
        argparse.ArgumentTypeError: for tokens that are not recognisable booleans.
    """
    from argparse import ArgumentTypeError

    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise ArgumentTypeError(f'expected a boolean value, got {value!r}')


def _literal_dict(value):
    """Parse a command-line token like "{'mnist': 1, 'syn': 2}" into a dict.

    ``type=dict`` cannot do this: ``dict('...')`` raises on any non-empty string.
    ``ast.literal_eval`` only evaluates Python literals, never arbitrary code.

    Raises:
        argparse.ArgumentTypeError: if the token is not a dict literal.
    """
    import ast
    from argparse import ArgumentTypeError

    if isinstance(value, dict):
        return value
    try:
        parsed = ast.literal_eval(value)
    except (ValueError, SyntaxError, TypeError) as err:
        raise ArgumentTypeError(f'expected a dict literal, got {value!r}') from err
    if not isinstance(parsed, dict):
        raise ArgumentTypeError(f'expected a dict literal, got {value!r}')
    return parsed


def parse_args():
    """Build the experiment configuration from the command line plus ``best_args``.

    ``parse_known_args`` is used so unrecognised arguments (e.g. the kernel's
    own launch flags when running inside Jupyter) are ignored instead of
    aborting. After parsing, every entry of ``best_args[args.dataset][args.model]``
    overwrites the matching attribute, including values given on the command line.

    Side effects: calls ``torch.set_num_threads(4)`` and, when ``args.seed`` is
    not None, ``set_random_seed(args.seed)``.

    Returns:
        argparse.Namespace: the merged configuration.

    Raises:
        KeyError: if ``best_args`` has no entry for the dataset/model pair.
    """
    parser = ArgumentParser(description='You Only Need Me', allow_abbrev=False)
    parser.add_argument('--device_id', type=int, default=0, help='The Device Id for Experiment')

    parser.add_argument('--communication_epoch', type=int, default=2, help='The Communication Epoch in Federated Learning')
    parser.add_argument('--local_epoch', type=int, default=3, help='The Local Epoch for each Participant')
    parser.add_argument('--parti_num', type=int, default=3, help='The Number for Participants')

    parser.add_argument('--seed', type=int, default=0, help='The random seed.')
    # Defaults are not passed through `type`, so the dict default is kept as-is;
    # a command-line value is parsed as a Python dict literal.
    parser.add_argument('--rand_dataset', type=_literal_dict, default={'mnist': 1, 'usps': 1, 'svhn': 1, 'syn': 2},
                        help='Per-domain settings for random domain assignment (Python dict literal)')

    parser.add_argument('--model', type=str, default='fpl',  # moon fedinfonce
                        help='Model name.', choices=get_all_models())
    parser.add_argument('--structure', type=str, default='homogeneity')
    parser.add_argument('--dataset', type=str, default='fl_digits',  # fl_officecaltech fl_digits
                        choices=DATASET_NAMES, help='Which scenario to perform experiments on.')

    parser.add_argument('--pri_aug', type=str, default='weak',  # weak strong
                        help='Augmentation for Private Data')
    parser.add_argument('--online_ratio', type=float, default=1, help='The Ratio for Online Clients')
    parser.add_argument('--learning_decay', type=_str2bool, default=False, help='The Option for Learning Rate Decay')
    # NOTE: 'averaing' (sic) is the attribute name other code reads; do not rename.
    parser.add_argument('--averaing', type=str, default='weight', help='The Option for averaging strategy')

    parser.add_argument('--infoNCET', type=float, default=0.02, help='The InfoNCE temperature')
    parser.add_argument('--T', type=float, default=0.05, help='The Knowledge distillation temperature')
    parser.add_argument('--weight', type=int, default=1, help='The Wegith for the distillation loss')

    parser.add_argument('--reserv_ratio', type=float, default=0.1, help='Reserve ratio for prototypes')

    torch.set_num_threads(4)

    # Deliberately shadows the add_management_args imported from utils.args:
    # only --csv_log is registered here.
    def add_management_args(parser: ArgumentParser) -> None:
        parser.add_argument('--csv_log', action='store_true',
                            help='Enable csv logging', default=False)

    add_management_args(parser)

    args, _unknown = parser.parse_known_args()

    try:
        best = best_args[args.dataset][args.model]
    except KeyError as err:
        raise KeyError(f'best_args has no entry for dataset={args.dataset!r}, '
                       f'model={args.model!r}') from err

    for key, value in best.items():
        setattr(args, key, value)

    if args.seed is not None:
        set_random_seed(args.seed)
    return args

In [4]:
# Parse the experiment configuration (CLI + best_args overrides) and build
# the federated dataset wrapper for args.dataset.
args = parse_args()

priv_dataset = get_prive_dataset(args)

In [5]:
# Still empty at this point: per-client loaders are returned by get_data_loaders(), which train() calls.
priv_dataset.train_loaders

[]

In [6]:
# Inspect the wrapper's attributes; args now also carries the best_args
# overrides (e.g. local_lr, local_batch_size) that are not parser options.
print(vars(priv_dataset))

{'train_loaders': [], 'test_loader': [], 'args': Namespace(device_id=0, communication_epoch=2, local_epoch=3, parti_num=3, seed=0, rand_dataset={'mnist': 1, 'usps': 1, 'svhn': 1, 'syn': 2}, model='fpl', structure='homogeneity', dataset='fl_digits', pri_aug='weak', online_ratio=1, learning_decay=False, averaing='weight', infoNCET=0.02, T=0.05, weight=1, reserv_ratio=0.1, csv_log=False, local_lr=0.01, local_batch_size=64, Note='+ MSE')}


In [7]:
# One backbone per participant; names_list=None selects the dataset's default backbone for all of them.
backbones_list = priv_dataset.get_backbone(args.parti_num, None)

In [8]:
# Instantiate the federated method named by args.model over those backbones.
model = get_model(backbones_list, args, priv_dataset.get_transform())

In [9]:
# print(vars(model))

In [10]:
# model.nets_list[0]

In [None]:

# NOTE(review): this cell is marked In [None] (not executed in the saved session);
# `priv_dataset` above was built by get_prive_dataset, not from this definition.
class FedLeaDigits(FederatedDataset):
    """Digits domain-skew benchmark over the MNIST, USPS, SVHN and SYN domains.

    Each training dataset corresponds to one entry of the selected domain list
    (one per participant); evaluation always uses the test split of every
    domain in ``DOMAINS_LIST``.
    """
    NAME = 'fl_digits'
    SETTING = 'domain_skew'
    DOMAINS_LIST = ['mnist', 'usps', 'svhn', 'syn']
    percent_dict = {'mnist': 0.01, 'usps': 0.01, 'svhn': 0.01, 'syn': 0.01}

    N_SAMPLES_PER_Class = None
    N_CLASS = 10

    # Channel statistics shared by every Normalize transform of this dataset.
    NORM_MEAN = (0.485, 0.456, 0.406)
    NORM_STD = (0.229, 0.224, 0.225)

    # Training augmentation for the 3-channel domains (svhn, syn).
    Nor_TRANSFORM = transforms.Compose(
        [transforms.Resize((32, 32)),
         transforms.RandomCrop(32, padding=4),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Normalize(NORM_MEAN, NORM_STD)])

    # Same augmentation for the single-channel domains (mnist, usps): the channel
    # is repeated 3 times so every domain feeds the backbone the same shape.
    Singel_Channel_Nor_TRANSFORM = transforms.Compose(
        [transforms.Resize((32, 32)),
         transforms.RandomCrop(32, padding=4),
         transforms.RandomHorizontalFlip(),
         transforms.ToTensor(),
         transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
         transforms.Normalize(NORM_MEAN, NORM_STD)])

    @staticmethod
    def _make_dataset(domain, train, rgb_transform, single_channel_transform):
        """Instantiate one domain's train or test dataset with the matching transform."""
        if domain == 'syn':
            return ImageFolder_Custom(data_name=domain, root=data_path(), train=train,
                                      transform=rgb_transform)
        transform = single_channel_transform if domain in ('mnist', 'usps') else rgb_transform
        return MyDigits(data_path(), train=train, download=True,
                        transform=transform, data_name=domain)

    def get_data_loaders(self, selected_domain_list=None):
        """Build per-participant train loaders and per-domain test loaders.

        Args:
            selected_domain_list: sequence of domain names, one per participant
                (repeats allowed; lists and numpy arrays both work). ``None``
                or an empty sequence means one entry per domain in ``DOMAINS_LIST``.

        Returns:
            ``(traindls, testdls, label_dict, train_dataset_list, test_dataset_list)``;
            the first three come from ``partition_digits_domain_skew_loaders``.
        """
        # len() instead of `== []` so numpy arrays are accepted as well.
        if selected_domain_list is None or len(selected_domain_list) == 0:
            using_list = self.DOMAINS_LIST
        else:
            using_list = selected_domain_list

        test_transform = transforms.Compose(
            [transforms.Resize((32, 32)),
             transforms.ToTensor(),
             self.get_normalization_transform()])

        sin_chan_test_transform = transforms.Compose(
            [transforms.Resize((32, 32)),
             transforms.ToTensor(),
             transforms.Lambda(lambda x: x.repeat(3, 1, 1)),
             self.get_normalization_transform()])

        train_dataset_list = [
            self._make_dataset(domain, train=True,
                               rgb_transform=self.Nor_TRANSFORM,
                               single_channel_transform=self.Singel_Channel_Nor_TRANSFORM)
            for domain in using_list]

        # Evaluation always covers every domain, whichever ones were assigned to participants.
        test_dataset_list = [
            self._make_dataset(domain, train=False,
                               rgb_transform=test_transform,
                               single_channel_transform=sin_chan_test_transform)
            for domain in self.DOMAINS_LIST]

        traindls, testdls, label_dict = partition_digits_domain_skew_loaders(train_dataset_list, test_dataset_list, self)

        return traindls, testdls, label_dict, train_dataset_list, test_dataset_list

    @staticmethod
    def get_transform():
        """Training transform for raw tensors/arrays: convert to PIL, then apply Nor_TRANSFORM."""
        transform = transforms.Compose(
            [transforms.ToPILImage(), FedLeaDigits.Nor_TRANSFORM])
        return transform

    @staticmethod
    def get_backbone(parti_num, names_list):
        """Return ``parti_num`` networks with ``N_CLASS`` outputs.

        ``names_list=None`` gives every participant a resnet10; otherwise
        ``names_list[j]`` (a key of ``nets_dict``) picks participant j's network.
        """
        nets_dict = {'resnet10': resnet10, 'resnet12': resnet12, 'efficient': EfficientNetB0, 'mobilnet': MobileNetV2}
        if names_list is None:
            return [resnet10(FedLeaDigits.N_CLASS) for _ in range(parti_num)]
        return [nets_dict[names_list[j]](FedLeaDigits.N_CLASS) for j in range(parti_num)]

    @staticmethod
    def get_normalization_transform():
        """Normalize transform using the dataset's channel mean/std."""
        transform = transforms.Normalize(FedLeaDigits.NORM_MEAN, FedLeaDigits.NORM_STD)
        return transform

    @staticmethod
    def get_denormalization_transform():
        """Inverse of get_normalization_transform (same mean/std)."""
        transform = DeNormalize(FedLeaDigits.NORM_MEAN, FedLeaDigits.NORM_STD)
        return transform


In [11]:
import torch
from argparse import Namespace
from models.utils.federated_model import FederatedModel
from datasets.utils.federated_dataset import FederatedDataset
from typing import Tuple
from torch.utils.data import DataLoader
import numpy as np
from utils.logger import CsvWriter
from collections import Counter


def global_evaluate(model: "FederatedModel", test_dl: "list[DataLoader]", setting: str, name: str) -> "list[float]":
    """Top-1 accuracy of ``model.global_net`` on each test loader.

    Args:
        model: object exposing ``global_net`` (a torch module) and ``device``.
        test_dl: iterable of DataLoaders (one per domain) yielding
            ``(images, labels)`` batches.
        setting: unused; kept for call-site compatibility.
        name: unused; kept for call-site compatibility.

    Returns:
        One accuracy per loader, in loader order, as a percentage rounded to
        2 decimals.

    Raises:
        ZeroDivisionError: if a loader yields no samples.

    The network's train/eval mode is restored afterwards, even if evaluation raises.
    """
    accs = []
    net = model.global_net
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for dl in test_dl:
                top1, total = 0.0, 0.0
                for images, labels in dl:
                    images, labels = images.to(model.device), labels.to(model.device)
                    outputs = net(images)
                    # Top-1 prediction via argmax: unlike topk(outputs, 5) this also
                    # works for models with fewer than 5 output classes.
                    preds = outputs.argmax(dim=-1)
                    labels = labels.view(-1)
                    top1 += (preds == labels).sum().item()
                    total += labels.numel()
                accs.append(round(100 * top1 / total, 2))
    finally:
        net.train(was_training)
    return accs


def _select_domains(model, args, domains_list, max_num=10):
    """Assign one domain per participant; return ``(selected_domain_list, counts)``.

    With ``args.rand_dataset`` truthy, domains are drawn with replacement via
    ``np.random.choice`` and redrawn until no domain is used more than
    ``max_num`` times (for ``fl_officecaltech`` every domain is also appended
    once). Otherwise a fixed per-domain count is expanded and shuffled.

    Raises:
        ValueError: for datasets without a random-selection rule, or when the
            ``max_num`` cap makes the requested participant count unreachable.
    """
    domains_len = len(domains_list)
    if args.rand_dataset:
        dataset = model.args.dataset
        if dataset not in ('fl_officecaltech', 'fl_digits'):
            raise ValueError(f'random domain selection is not defined for dataset {dataset!r}')
        # The rejection loop below could never terminate past this bound.
        if args.parti_num > max_num * domains_len:
            raise ValueError(f'parti_num={args.parti_num} exceeds {max_num} participants '
                             f'per domain x {domains_len} domains')
        while True:
            if dataset == 'fl_officecaltech':
                selected_domain_list = np.random.choice(domains_list, size=args.parti_num - domains_len, replace=True, p=None)
                selected_domain_list = list(selected_domain_list) + domains_list
            else:
                selected_domain_list = np.random.choice(domains_list, size=args.parti_num, replace=True, p=None)
            result = dict(Counter(selected_domain_list))
            if all(count <= max_num for count in result.values()):
                return selected_domain_list, result

    # Fixed "article" assignment. Alternatives used previously:
    #   base: {'mnist': 6, 'usps': 4, 'svhn': 3, 'syn': 7}
    #   20:   {'mnist': 1, 'usps': 1, 'svhn': 9, 'syn': 9}
    #   10:   {'mnist': 3, 'usps': 2, 'svhn': 1, 'syn': 4}
    # NOTE(review): this yields 20 participants regardless of args.parti_num — confirm
    # it matches the number of networks the model was built with.
    selected_domain_dict = {'mnist': 3, 'usps': 7, 'svhn': 6, 'syn': 4}
    selected_domain_list = [domain for domain, count in selected_domain_dict.items() for _ in range(count)]
    selected_domain_list = np.random.permutation(selected_domain_list)
    return selected_domain_list, Counter(selected_domain_list)


def train(model: "FederatedModel", private_dataset: "FederatedDataset", args: Namespace) -> tuple:
    """Run federated training and evaluate the global model after every round.

    Notebook-local replacement for the ``train`` imported from utils.training
    (it shadows that import) that also returns the data produced by
    ``get_data_loaders``.

    Each of ``args.communication_epoch`` rounds calls ``model.loc_update`` (if
    defined), then evaluates on all test loaders and prints the mean and
    per-domain top-1 accuracy. With ``args.csv_log`` the accuracy history is
    written through ``CsvWriter``.

    Returns:
        ``(label_dict, train_dataset_list, test_dataset_list)``.
    """
    if args.csv_log:
        csv_writer = CsvWriter(args, private_dataset)

    model.N_CLASS = private_dataset.N_CLASS
    domains_list = private_dataset.DOMAINS_LIST

    selected_domain_list, result = _select_domains(model, args, domains_list)
    print(result)

    print(selected_domain_list)
    pri_train_loaders, test_loaders, label_dict, train_dataset_list, test_dataset_list = private_dataset.get_data_loaders(selected_domain_list)
    model.trainloaders = pri_train_loaders
    if hasattr(model, 'ini'):
        model.ini()

    accs_dict = {}  # test-loader index -> accuracy per round
    mean_accs_list = []

    for epoch_index in range(args.communication_epoch):
        model.epoch_index = epoch_index
        if hasattr(model, 'loc_update'):
            model.loc_update(pri_train_loaders)  # returned loss dict is not used here

        accs = global_evaluate(model, test_loaders, private_dataset.SETTING, private_dataset.NAME)
        mean_acc = round(np.mean(accs, axis=0), 3)
        mean_accs_list.append(mean_acc)
        for i, acc in enumerate(accs):
            accs_dict.setdefault(i, []).append(acc)

        print('The ' + str(epoch_index) + ' Communcation Accuracy:', str(mean_acc), 'Method:', model.args.model)
        print(accs)

    if args.csv_log:
        csv_writer.write_acc(accs_dict, mean_accs_list)

    return label_dict, train_dataset_list, test_dataset_list


In [12]:
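# Reference runs of fpl: final round (index 49, i.e. 50 communication rounds).
# Per-domain accuracies are in DOMAINS_LIST order: [mnist, usps, svhn, syn];
# the reported accuracy is their mean.
#
# batch size 256 -> mean accuracy 74.34
# [97.63, 88.94, 68.75, 42.04]
#
# batch size 64 -> mean accuracy 80.12
# [98.08, 90.33, 79.53, 52.54]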

In [13]:
# Run federated training; also keep the label dict and dataset lists built by get_data_loaders.
label_dict, train_dataset_list, test_dataset_list = train(model, priv_dataset, args)

{'mnist': 1, 'syn': 1, 'usps': 1}
['mnist' 'syn' 'usps']


Local Pariticipant 2 CE = 2.308,InfoNCE = 0.000: 100%|██████████| 3/3 [00:01<00:00,  2.31it/s]
Local Pariticipant 1 CE = 2.129,InfoNCE = 0.000: 100%|██████████| 3/3 [00:01<00:00,  1.91it/s]
Local Pariticipant 0 CE = 1.131,InfoNCE = 0.000: 100%|██████████| 3/3 [00:06<00:00,  2.18s/it]


Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
The 0 Communcation Accuracy: 13.51 Method: fpl
[11.35, 13.15, 19.59, 9.95]


Local Pariticipant 1 CE = 2.195,InfoNCE = 2.302: 100%|██████████| 3/3 [00:04<00:00,  1.54s/it]
Local Pariticipant 0 CE = 0.735,InfoNCE = 1.106: 100%|██████████| 3/3 [00:24<00:00,  8.13s/it]
Local Pariticipant 2 CE = 1.984,InfoNCE = 2.124: 100%|██████████| 3/3 [00:05<00:00,  1.85s/it]


Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
Partition 0: 1 clusters
The 1 Communcation Accuracy: 18.002 Method: fpl
[35.75, 16.44, 10.42, 9.4]


In [14]:
# Keys match the domains selected (and printed) by train() in this run.
label_dict.keys()

dict_keys(['mnist', 'syn', 'usps'])

In [15]:
# Size of each domain's entry in label_dict, one per line.
for size in map(len, label_dict.values()):
    print(size)

60000
10000
7291


In [17]:
# Per-domain index pools passed to add_participant_to_federation below.
# (get_sorted_label_indices is not defined in this notebook — presumably from the Tbx_Sample star import; confirm.)
avaliable_indexes = get_sorted_label_indices(label_dict)

In [126]:
# Snapshot one domain's index pool as deep copies.
# `domain` is bound explicitly: this cell was run out of order (In [126]) and
# otherwise relies on the `domain` global set by the client cells below,
# so it raised NameError on a fresh Restart & Run All.
domain = 'mnist'
original_data = copy.deepcopy(avaliable_indexes[domain])

teste = copy.deepcopy(avaliable_indexes[domain])

In [18]:
# Define federation list
federation = []

# Client A
init_pop = 50
init_steps = 20
init_period = int(0.2 * init_pop)
init_freq = 1
max_comb = 6
init_rate = None
init_angular = None
domain = 'mnist'

add_participant_to_federation(
    federation=federation,
    available_indexes=avaliable_indexes[domain],
    init_pop=init_pop,
    init_steps=init_steps,
    init_freq=init_freq,
    init_period=init_period,
    init_rate = init_rate,
    init_angular = init_angular,
    max_comb=max_comb
)

Added new participant with 6 labels to the federation.


In [19]:
# Client B — like A, but init_freq is left unset and max_comb is 5.
init_pop, init_steps = 50, 20
init_period = int(0.2 * init_pop)
init_freq = None
max_comb = 5
init_rate = init_angular = None
domain = 'mnist'

add_participant_to_federation(federation=federation,
                              available_indexes=avaliable_indexes[domain],
                              init_pop=init_pop, init_steps=init_steps,
                              init_freq=init_freq, init_period=init_period,
                              init_rate=init_rate, init_angular=init_angular,
                              max_comb=max_comb)

Added new participant with 5 labels to the federation.


In [20]:
# Client C — init_period left unset, init_freq = 1, max_comb = 8.
init_pop, init_steps = 50, 20
init_period = None
init_freq = 1
max_comb = 8
init_rate = init_angular = None
domain = 'mnist'

add_participant_to_federation(federation=federation,
                              available_indexes=avaliable_indexes[domain],
                              init_pop=init_pop, init_steps=init_steps,
                              init_freq=init_freq, init_period=init_period,
                              init_rate=init_rate, init_angular=init_angular,
                              max_comb=max_comb)

Added new participant with 8 labels to the federation.


In [21]:
# Client D — neither init_period nor init_freq set, max_comb = 8.
init_pop, init_steps = 50, 20
init_period = init_freq = None
max_comb = 8
init_rate = init_angular = None
domain = 'mnist'

add_participant_to_federation(federation=federation,
                              available_indexes=avaliable_indexes[domain],
                              init_pop=init_pop, init_steps=init_steps,
                              init_freq=init_freq, init_period=init_period,
                              init_rate=init_rate, init_angular=init_angular,
                              max_comb=max_comb)

Added new participant with 8 labels to the federation.


In [22]:
# Client E — larger population with near-zero period/frequency/rate/angular values.
init_pop, init_steps = 600, 20
init_period = init_freq = 1e-7
max_comb = 8
init_rate = init_angular = 1e-8
domain = 'mnist'

add_participant_to_federation(federation=federation,
                              available_indexes=avaliable_indexes[domain],
                              init_pop=init_pop, init_steps=init_steps,
                              init_freq=init_freq, init_period=init_period,
                              init_rate=init_rate, init_angular=init_angular,
                              max_comb=max_comb)

Added new participant with 8 labels to the federation.


In [23]:
# Long-format plotting table: one row per (participant, digit, time step).
# rename/assign return new frames, so whatever report_data() returns is not modified in place.
aux_plot = [
    part.report_data()
    .rename(columns={'id': 'Digit'})
    .assign(Participant=f'Part_{i}')
    for i, part in enumerate(federation)
]

df_plot = pd.concat(aux_plot)

# Melt the per-step value columns into (Time, Count) pairs
df_plot = df_plot.melt(id_vars=['Digit', 'Participant'],
                       var_name='Time',
                       value_name='Count')

# Convert 'Time' to 0-based step numbers (value_1 → 0, value_2 → 1, ...).
# Raw string so '\d' reaches the regex engine instead of being an invalid string
# escape; expand=False returns a Series that can be assigned to the column directly.
df_plot['Time'] = df_plot['Time'].str.extract(r'(\d+)', expand=False).astype(int) - 1

df_plot.shape

(735, 4)

In [24]:
import altair as alt

def plot_mnist_distribution(plot_data, plot_width=800, plot_height=400, line_width=300):
    """
    Plot an interactive MNIST label-distribution dashboard with Altair.

    Three panels are concatenated horizontally:
    1. bar chart of the total image count per participant,
    2. heatmap of the count per (digit, participant), overlaid with black points
       sized by count for the current click selection,
    3. line chart with markers of the count over time for that selection.
    A slider bound to 'Time' filters the bar chart and the heatmap; the
    time-series panel always shows every time step.

    Parameters:
    - plot_data: long-format DataFrame with 'Digit', 'Participant', 'Time' and 'Count' columns
    - plot_width: width split 40% / 60% between the bar chart and the heatmap (default: 800)
    - plot_height: height of every panel (default: 400)
    - line_width: width of the time-series panel (default: 300)

    Returns:
    - alt.HConcatChart combining the three panels
    """

    # Plain Python ints for the slider bounds/initial value, so numpy scalars
    # from the pandas reductions never reach the Vega-Lite spec.
    time_min = int(plot_data['Time'].min())
    time_max = int(plot_data['Time'].max())

    # Create a slider for selecting 'Time'
    time_slider = alt.binding_range(
        min=time_min,
        max=time_max,
        step=1,
        name="Select Time: "
    )
    time_param = alt.selection_point(bind=time_slider, fields=['Time'], value=time_min)

    # Altair selection objects for interactive filtering
    bar_point = alt.selection_point(encodings=['x'])
    client_selection = alt.selection_point(encodings=['y'])

    # Calculate the split dimensions for the bar chart and heatmap
    bar_width = plot_width * 0.4   # 40% of the width for the bar chart
    heatmap_width = plot_width * 0.6  # 60% of the width for the heatmap
    bar_height = plot_height       # Full height for the bar chart
    heatmap_height = plot_height   # Full height for the heatmap

    # Heatmap (Digit vs. Participant) filtered by selected Time
    heatmap = alt.Chart(plot_data).mark_rect().encode(
        x=alt.X('Digit:O', title="MNIST Digit (0-9)"),
        y=alt.Y('Participant:N', title="Participants"),
        color=alt.condition(
            bar_point,
            alt.Color('Count:Q', scale=alt.Scale(scheme='greenblue'), title="Total Count", legend=alt.Legend(orient="top")),
            alt.ColorValue("grey")  # Non-selected values will be grey
        ),
        tooltip=[alt.Tooltip('Digit:O', title="Digit"),
                 alt.Tooltip('Participant:N', title="Participant"),
                 alt.Tooltip('Count:Q', title="Total Count")]
    ).transform_filter(
        time_param
    ).add_params(bar_point, client_selection, time_param).properties(width=heatmap_width, height=heatmap_height)

    # Overlay points for highlighting
    highlight_points = heatmap.mark_point().encode(
        color=alt.ColorValue('black'),
        size=alt.Size('Count:Q', title="Selected Images")
    ).transform_filter(
        bar_point
    ).transform_filter(
        client_selection
    )

    # Bar chart (total images per participant) filtered by selected Time
    bar_chart = alt.Chart(plot_data).mark_bar().encode(
        x="Participant:N",
        y="sum(Count):Q",
        color=alt.condition(bar_point, alt.ColorValue("steelblue"), alt.ColorValue("grey"))
    ).transform_filter(
        time_param
    ).add_params(bar_point).properties(width=bar_width, height=bar_height)

    # Line chart of Count over Time for the current selection (not filtered by the slider)
    line_chart = alt.Chart(plot_data).mark_line().encode(
        x='Time:O',  # Ordinal axis for time
        y='Count:Q',  # Quantitative axis for process value
        color=alt.Color('Digit:O', title="Stoc.Process"),
        tooltip=['Digit', 'Time', 'Count']
    ).properties(
        title="Process over Time",
        width=line_width,
        height=plot_height
    ).transform_filter(
        bar_point
    ).transform_filter(
        client_selection
    )

    # Add markers (circles) at data points
    marker_chart = alt.Chart(plot_data).mark_circle(size=50).encode(
        x='Time:O',
        y='Count:Q',
        color=alt.Color('Digit:O', title="Stoc.Process"),
        tooltip=['Digit', 'Time', 'Count']
    ).transform_filter(
        bar_point
    ).transform_filter(
        client_selection
    )

    # Overlay the markers on top of the line chart
    chart = line_chart + marker_chart

    # Combine the charts horizontally
    final_chart = alt.hconcat(
        bar_chart,
        heatmap + highlight_points,
        chart
    ).resolve_legend(
        color="independent",
        size="independent"
    )

    return final_chart

# Example of using the function with custom plot size
# (plot_width is split 40/60 between the bar chart and the heatmap; the
# time-series panel's width is set separately).
final_chart = plot_mnist_distribution(plot_data=df_plot, plot_width=350, plot_height=250)
final_chart

In [73]:
# Underlying torchvision dataset of the first training entry (MNIST in this run's domain draw).
train_dataset_list[0].dataset

Dataset MNIST
    Number of datapoints: 60000
    Root location: C:/Users/arthu/USPy/0_BEPE/1_FPL/datasets/
    Split: Train
    StandardTransform
Transform: Compose(
               Resize(size=(32, 32), interpolation=bilinear, max_size=None, antialias=True)
               RandomCrop(size=(32, 32), padding=4)
               RandomHorizontalFlip(p=0.5)
               ToTensor()
               Lambda()
               Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
           )

In [74]:
# Label tensor of that dataset.
train_dataset_list[0].dataset.targets

tensor([5, 0, 4,  ..., 5, 6, 8])

In [83]:
# First participant of the federation.
client_tgt = federation[0]

# Presumably the dataset indices this participant samples at step 0 (see Tbx_Sample) — confirm.
idxs_tgt = client_tgt.get_sample_step(0)
idxs_tgt[:5]

[1409, 2820, 58759, 49287, 17810]

In [85]:
from torch.utils.data import DataLoader, SubsetRandomSampler

# Loader over the first participant's step-0 indices.
# NOTE(review): train_dataset_list follows the randomly drawn domain order printed
# by train() (['mnist' 'syn' 'usps'] in this run), so index 0 is MNIST only by
# chance — confirm before reusing with another seed or domain assignment.
mnist_train = train_dataset_list[0].dataset
# Fail here rather than with an IndexError deep inside DataLoader iteration.
assert all(0 <= int(idx) < len(mnist_train) for idx in idxs_tgt), \
    'sampled index outside the target dataset'

train_sampler = SubsetRandomSampler(idxs_tgt)

tr_dl = DataLoader(mnist_train, batch_size=64, sampler=train_sampler)