In [223]:
import argparse
import copy
import logging
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import pickle
from scipy.optimize import linear_sum_assignment
from sklearn.metrics import confusion_matrix


In [224]:
args_init_seed = 0                 # RNG seed for reproducible weight initialization and shuffling
args_net_config = [3072, 100, 10]  # network layer sizes: 3072-unit input → 100-unit hidden → 10-unit output
args_trials = 1                    # number of independent experimental repeats to average results over
args_epochs = 5                    # number of epochs to train each local model per communication round
args_reg = 1e-5                    # L2 weight‐decay (regularization) coefficient
args_alpha = 0.5                   # interpolation weight (e.g. for mixing new vs. old parameters)
args_communication_rounds = 5      # how many federated communication rounds to perform
args_iter_epochs = None            # optional override to train multiple epochs per iteration (None = use `args_epochs`)


In [225]:
args_pdm_sig = 1.0
args_pdm_sig0 = 1.0  
args_pdm_gamma = 7.0

In [226]:
args_model = "lenet"

In [227]:
logging.basicConfig()
logger = logging.getLogger()
logger.setLevel(logging.INFO)

In [228]:
def add_fit_args(parser):
    """
    parser : argparse.ArgumentParser
    return a parser added with args required by fit
    """
    # Training settings
    parser.add_argument('--model', type=str, default='lenet', metavar='N',
                        help='neural network used in training')
    parser.add_argument('--dataset', type=str, default='cifar10', metavar='N',
                        help='dataset used for training')
    parser.add_argument('--partition', type=str, default='homo', metavar='N',
                        help='how to partition the dataset on local workers')
    parser.add_argument('--batch-size', type=int, default=128, metavar='N',
                        help='input batch size for training (default: 64)')
    parser.add_argument('--lr', type=float, default=0.1, metavar='LR',
                        help='learning rate (default: 0.01)')
    parser.add_argument('--retrain_lr', type=float, default=0.1, metavar='RLR',
                        help='learning rate using in specific for local network retrain (default: 0.01)')
    parser.add_argument('--fine_tune_lr', type=float, default=0.1, metavar='FLR',
                        help='learning rate using in specific for fine tuning the softmax layer on the data center (default: 0.01)')
    parser.add_argument('--epochs', type=int, default=1, metavar='EP',
                        help='how many epochs will be trained in a training process')
    parser.add_argument('--retrain_epochs', type=int, default=1, metavar='REP',
                        help='how many epochs will be trained in during the locally retraining process')
    parser.add_argument('--fine_tune_epochs', type=int, default=1, metavar='FEP',
                        help='how many epochs will be trained in during the fine tuning process')
    parser.add_argument('--partition_step_size', type=int, default=6, metavar='PSS',
                        help='how many groups of partitions we will have')
    parser.add_argument('--local_points', type=int, default=5000, metavar='LP',
                        help='the approximate fixed number of data points we will have on each local worker')
    parser.add_argument('--partition_step', type=int, default=0, metavar='PS',
                        help='how many sub groups we are going to use for a particular training process')                          
    parser.add_argument('--n_nets', type=int, default=2, metavar='NN',
                        help='number of workers in a distributed cluster')
    parser.add_argument('--oneshot_matching', type=bool, default=False, metavar='OM',
                        help='if the code is going to conduct one shot matching')
    parser.add_argument('--retrain', type=bool, default=False, 
                            help='whether to retrain the model or load model locally')
    parser.add_argument('--rematching', type=bool, default=False, 
                            help='whether to recalculating the matching process (this is for speeding up the debugging process)')
    parser.add_argument('--comm_type', type=str, default='layerwise', 
                            help='which type of communication strategy is going to be used: layerwise/blockwise')    
    parser.add_argument('--comm_round', type=int, default=10, 
                            help='how many round of communications we shoud use')  
    args = parser.parse_args([])  # parse empty list = use all defaults
    return args

In [229]:
args = add_fit_args(argparse.ArgumentParser(description='Probabilistic Federated CNN Matching'))


In [230]:
input_dim = 25 * 4 * 4  # from inspecting conv outputs
hidden_dims = [120]

In [231]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LeNetContainer(nn.Module):
    def __init__(self, num_filters, kernel_size, hidden_dims, output_dim=10, input_shape=(1, 28, 28)): # kernel size = 5, hidden dimensions = [500], num_filters = [20,50]
        super(LeNetContainer, self).__init__()
        self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size, 1)
        self.conv2 = nn.Conv2d(num_filters[0], num_filters[1], kernel_size, 1)

        # Dynamically compute input_dim by forwarding a dummy input
        with torch.no_grad():
            dummy_input = torch.rand(1, *input_shape)
            x = self.conv1(dummy_input)
            x = F.max_pool2d(x, 2, 2)
            x = F.relu(x)
            x = self.conv2(x)
            x = F.max_pool2d(x, 2, 2)
            x = F.relu(x)
            input_dim = x.view(1, -1).size(1)

        self.fc1 = nn.Linear(input_dim, hidden_dims[0])
        self.fc2 = nn.Linear(hidden_dims[0], output_dim)

    def forward(self, x):
        x = self.conv1(x)
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.max_pool2d(x, 2, 2)
        x = F.relu(x)
        x = x.view(-1, x.size(1) * x.size(2) * x.size(3))
        x = self.fc1(x)
        x = self.fc2(x)
        return x


In [232]:
def record_net_data_stats(y_train, net_dataidx_map, logdir):

    net_cls_counts = {}

    for net_i, dataidx in net_dataidx_map.items():
        unq, unq_cnt = np.unique(y_train[dataidx], return_counts=True)
        tmp = {unq[i]: unq_cnt[i] for i in range(len(unq))}
        net_cls_counts[net_i] = tmp
    logging.debug('Data statistics: %s' % str(net_cls_counts))
    return net_cls_counts


In [233]:
def init_client_models(n_nets, model_type="lenet", seeds=None, use_container=True):
    """
    Initialize multiple client models with optional container and seeds.
    
    Args:
        n_nets (int): Number of client models.
        model_type (str): Type of model to initialize. Currently supports 'lenet'.
        seeds (list, optional): List of seeds to initialize models with. Must match n_nets length.
        use_container (bool): Whether to use LeNetContainer instead of base LeNet.

    Returns:
        nets_list (list): List of initialized models.
        model_meta_data (list): List of parameter shapes.
        layer_type (list): List of parameter names.
    """
    if seeds is None:
        seeds = list(range(n_nets))
    assert len(seeds) == n_nets, "Number of seeds must match number of networks."

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    np.random.seed(0)
    torch.manual_seed(0)

    nets_list = []
    model_meta_data = []
    layer_type = []

    for i in range(n_nets):
        torch.manual_seed(seeds[i])
        if model_type == "lenet":
            if use_container:
                model = LeNetContainer(
                    num_filters=[20, 50],
                    kernel_size=5,
                    hidden_dims=[500],
                    output_dim=10
                )
        else:
            raise NotImplementedError(f"Model type '{model_type}' is not supported.")

        nets_list.append(model)
        logger.info(f"Initialized client {i} with seed {seeds[i]}")

    # Get meta data from the first model
    for k, v in nets_list[0].state_dict().items():
        model_meta_data.append(v.shape)
        layer_type.append(k)

    return nets_list, model_meta_data, layer_type


In [234]:
n_nets = 5
seeds = [42, 123, 456, 789, 999]
nets_list, model_meta_data, layer_type = init_client_models(n_nets, model_type="lenet", seeds=seeds, use_container=True)


INFO:root:Using device: cpu
INFO:root:Initialized client 0 with seed 42
INFO:root:Initialized client 1 with seed 123
INFO:root:Initialized client 2 with seed 456
INFO:root:Initialized client 3 with seed 789
INFO:root:Initialized client 4 with seed 999


In [235]:
print(nets_list)

[LeNetContainer(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
), LeNetContainer(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
), LeNetContainer(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Linear(in_features=500, out_features=10, bias=True)
), LeNetContainer(
  (conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=800, out_features=500, bias=True)
  (fc2): Li

In [236]:
print(model_meta_data)

[torch.Size([20, 1, 5, 5]), torch.Size([20]), torch.Size([50, 20, 5, 5]), torch.Size([50]), torch.Size([500, 800]), torch.Size([500]), torch.Size([10, 500]), torch.Size([10])]


In [237]:
print(layer_type)

['conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias', 'fc1.weight', 'fc1.bias', 'fc2.weight', 'fc2.bias']


[torch.Size([20, 1, 5, 5]), torch.Size([20]),
 torch.Size([50, 20, 5, 5]), torch.Size([50]),
 torch.Size([500, 800]), torch.Size([500]),
 torch.Size([10, 500]), torch.Size([10])]

 ## the learnable paramter size so 20 filters of 5x5 

In [238]:
device = "cpu"

In [281]:
def pdm_prepare_full_weights_cnn(nets, device="cpu"):
    """
    we extract all weights of the conv nets out here:
    """
    weights = [] # this is where all the clients net_weights go 
    for net_i, net in enumerate(nets):
        # outer loop is for each client 
        net_weights = [] # this is for one client datam later appended to weights 
        statedict = net.state_dict()
        # print("Statedict: ", statedict)
        for param_id, (k, v) in enumerate(statedict.items()):
            print("------Printing K:------ ",k)
            # print("------Printing V:------ ",v)
            if device == "cpu":
                if 'fc' in k or 'classifier' in k:
                    print("-------------------------Inside fc check !----------------------------")
                    if 'weight' in k:
                        net_weights.append(v.numpy().T)
                        print("V.T:\n")
                        print(v.shape)
                    else:
                        net_weights.append(v.numpy())
                        print("V:\n")
                        print(v.shape)
                elif 'conv' in k or 'features' in k:
                    print("-------------------------Inside conv check !----------------------------")
                    if 'weight' in k: # this checks if its a conv layer and does something
                        _weight_shape = v.size() 
                        print("SIZE: ", _weight_shape)
                        if len(_weight_shape) == 4:
                            net_weights.append(v.numpy().reshape(_weight_shape[0], _weight_shape[1]*_weight_shape[2]*_weight_shape[3]))
                            print("V:")
                            print(v.numpy().reshape(_weight_shape[0], _weight_shape[1]*_weight_shape[2]*_weight_shape[3]).shape)
                        else:
                            pass
                    else:
                        net_weights.append(v.numpy()) # biases are directly appended 
            print("Net_Wights\n")
            print(len(net_weights), len(net_weights[0])) 
        weights.append(net_weights)
    return weights


In [282]:
a = pdm_prepare_full_weights_cnn(nets=nets_list) # net_weights goes till 8 bcs there are 4 layers and each has a bias 

------Printing K:------  conv1.weight
-------------------------Inside conv check !----------------------------
SIZE:  torch.Size([20, 1, 5, 5])
V:
(20, 25)
Net_Wights

1 20
------Printing K:------  conv1.bias
-------------------------Inside conv check !----------------------------
Net_Wights

2 20
------Printing K:------  conv2.weight
-------------------------Inside conv check !----------------------------
SIZE:  torch.Size([50, 20, 5, 5])
V:
(50, 500)
Net_Wights

3 20
------Printing K:------  conv2.bias
-------------------------Inside conv check !----------------------------
Net_Wights

4 20
------Printing K:------  fc1.weight
-------------------------Inside fc check !----------------------------
V.T:

torch.Size([500, 800])
Net_Wights

5 20
------Printing K:------  fc1.bias
-------------------------Inside fc check !----------------------------
V:

torch.Size([500])
Net_Wights

6 20
------Printing K:------  fc2.weight
-------------------------Inside fc check !-------------------------

In [None]:
# print(len(a[0][5]))

500


In [240]:
def block_patching(w_j, L_next, assignment_j_c, layer_index, model_meta_data, 
                   matching_shapes=None, 
                   layer_type="fc", 
                   network_name="lenet"):
    """
    Applies neuron/block reordering to the weight matrix w_j based on assignment_j_c.
    """
    
    print(f"\n=== block_patching for layer {layer_index} ({layer_type}) ===")
    print(f"Original w_j shape: {w_j.shape}")
    
    if assignment_j_c is None:
        print("No assignment provided — returning weights unchanged.")
        return w_j

    layer_meta_data = model_meta_data[2 * layer_index - 2]
    prev_layer_meta_data = model_meta_data[2 * layer_index - 2 - 2]
    
    print(f"Current layer meta shape: {layer_meta_data}")
    print(f"Previous layer meta shape: {prev_layer_meta_data}")

    if layer_type == "conv":    
        # Compute new weight matrix shape for conv: flatten input filter blocks
        new_w_j = np.zeros((w_j.shape[0], L_next * (layer_meta_data[-1] ** 2)))
        print(f"New conv w_j shape after patching: {new_w_j.shape}")
        
        block_indices = [
            np.arange(i * layer_meta_data[-1] ** 2, (i + 1) * layer_meta_data[-1] ** 2)
            for i in range(L_next)
        ]
        ori_block_indices = [
            np.arange(i * layer_meta_data[-1] ** 2, (i + 1) * layer_meta_data[-1] ** 2)
            for i in range(layer_meta_data[1])
        ]

        print(f"Block indices count: {len(block_indices)}")
        print(f"Original block indices count: {len(ori_block_indices)}")

        for ori_id in range(layer_meta_data[1]):
            if assignment_j_c[ori_id] < len(block_indices):  # Add bounds check
                print(f"Mapping conv block {ori_id} → {assignment_j_c[ori_id]}")
                new_w_j[:, block_indices[assignment_j_c[ori_id]]] = w_j[:, ori_block_indices[ori_id]]

    elif layer_type == "fc":
        if network_name == "lenet" and matching_shapes is not None:
            prev_layer_output_size = prev_layer_meta_data[0] if len(prev_layer_meta_data) > 0 else w_j.shape[1]
            new_w_j = np.zeros((w_j.shape[0], L_next))
            print(f"New fc w_j shape after patching: {new_w_j.shape}")
            
            for ori_id in range(min(len(assignment_j_c), w_j.shape[1])):
                if assignment_j_c[ori_id] < L_next:
                    print(f"Mapping neuron {ori_id} → {assignment_j_c[ori_id]}")
                    new_w_j[:, assignment_j_c[ori_id]] = w_j[:, ori_id]
        else:
            new_w_j = np.zeros((w_j.shape[0], L_next))
            print(f"New fallback fc w_j shape: {new_w_j.shape}")
            
            for ori_id in range(min(len(assignment_j_c), w_j.shape[1])):
                if assignment_j_c[ori_id] < L_next:
                    print(f"Mapping neuron {ori_id} → {assignment_j_c[ori_id]}")
                    new_w_j[:, assignment_j_c[ori_id]] = w_j[:, ori_id]
                    
    return new_w_j


In [241]:
def process_softmax_bias(batch_weights, last_layer_const, sigma, sigma0):
    J = len(batch_weights)
    sigma_bias = sigma
    sigma0_bias = sigma0
    mu0_bias = 0.1
    softmax_bias = [batch_weights[j][-1] for j in range(J)]
    softmax_inv_sigma = [s / sigma_bias for s in last_layer_const]
    softmax_bias = sum([b * s for b, s in zip(softmax_bias, softmax_inv_sigma)]) + mu0_bias / sigma0_bias
    softmax_inv_sigma = 1 / sigma0_bias + sum(softmax_inv_sigma)
    return softmax_bias, softmax_inv_sigma


In [242]:
def row_param_cost(global_weights, weights_j_l, global_sigmas, sigma_inv_j):

    match_norms = ((weights_j_l + global_weights) ** 2 / (sigma_inv_j + global_sigmas)).sum(axis=1) - (
                global_weights ** 2 / global_sigmas).sum(axis=1)

    return match_norms


def row_param_cost_simplified(global_weights, weights_j_l, sij_p_gs, red_term):
    match_norms = ((weights_j_l + global_weights) ** 2 / sij_p_gs).sum(axis=1) - red_term
    return match_norms

In [243]:
def compute_cost(global_weights, weights_j, global_sigmas, sigma_inv_j, prior_mean_norm, prior_inv_sigma,
                 popularity_counts, gamma, J):

    param_cost_start = time.time()
    Lj = weights_j.shape[0]
    counts = np.minimum(np.array(popularity_counts, dtype=np.float32), 10)

    sij_p_gs = sigma_inv_j + global_sigmas
    red_term = (global_weights ** 2 / global_sigmas).sum(axis=1)
    stupid_line_start = time.time()

    param_cost = np.array([row_param_cost_simplified(global_weights, weights_j[l], sij_p_gs, red_term) for l in range(Lj)], dtype=np.float32)
    stupid_line_dur = time.time() - stupid_line_start

    param_cost += np.log(counts / (J - counts))
    param_cost_dur = time.time() - param_cost_start

    nonparam_start = time.time()
    L = global_weights.shape[0]
    max_added = min(Lj, max(700 - L, 1))
    nonparam_cost = np.outer((((weights_j + prior_mean_norm) ** 2 / (prior_inv_sigma + sigma_inv_j)).sum(axis=1) - (
                prior_mean_norm ** 2 / prior_inv_sigma).sum()), np.ones(max_added, dtype=np.float32))
    cost_pois = 2 * np.log(np.arange(1, max_added + 1))
    nonparam_cost -= cost_pois
    nonparam_cost += 2 * np.log(gamma / J)

    nonparam_dur = time.time() - nonparam_start

    full_cost = np.hstack((param_cost, nonparam_cost)).astype(np.float32)
    return full_cost


In [244]:
from scipy.optimize import linear_sum_assignment
import numpy as np
import time

def matching_upd_j(weights_j, global_weights, sigma_inv_j, global_sigmas, prior_mean_norm, prior_inv_sigma,
                   popularity_counts, gamma, J):

    L = global_weights.shape[0]
    print(f"\n=== matching_upd_j ===")
    print(f"weights_j shape          : {weights_j.shape}")
    print(f"global_weights shape     : {global_weights.shape}")
    print(f"global_sigmas shape      : {global_sigmas.shape}")
    print(f"sigma_inv_j shape        : {sigma_inv_j.shape}")
    print(f"prior_mean_norm shape    : {prior_mean_norm.shape}")
    print(f"prior_inv_sigma shape    : {prior_inv_sigma.shape}")
    print(f"popularity_counts (start): {popularity_counts}")
    print(f"L (global comps)         : {L}")

    compute_cost_start = time.time()
    full_cost = compute_cost(
        global_weights.astype(np.float32),
        weights_j.astype(np.float32),
        global_sigmas.astype(np.float32),
        sigma_inv_j.astype(np.float32),
        prior_mean_norm.astype(np.float32),
        prior_inv_sigma.astype(np.float32),
        popularity_counts,
        gamma,
        J
    )
    compute_cost_dur = time.time() - compute_cost_start
    print(f"full_cost shape          : {full_cost.shape}")
    print(f"compute_cost duration    : {compute_cost_dur:.4f}s")

    start_time = time.time()
    row_ind, col_ind = linear_sum_assignment(-full_cost)
    solve_dur = time.time() - start_time
    print(f"linear_sum_assignment time: {solve_dur:.4f}s")
    print(f"row_ind: {row_ind}")
    print(f"col_ind: {col_ind}")

    assignment_j = []
    new_L = L

    for l, i in zip(row_ind, col_ind):
        print(f"Assigning local {l} → global {i}")
        if i < L:
            popularity_counts[i] += 1
            assignment_j.append(i)
            global_weights[i] += weights_j[l]
            global_sigmas[i] += sigma_inv_j
        else:
            popularity_counts += [1]
            assignment_j.append(new_L)
            print(f"New global component added: {new_L}")
            new_L += 1
            global_weights = np.vstack((global_weights, prior_mean_norm + weights_j[l]))
            global_sigmas = np.vstack((global_sigmas, prior_inv_sigma + sigma_inv_j))

    print(f"Updated global_weights shape: {global_weights.shape}")
    print(f"Updated global_sigmas shape : {global_sigmas.shape}")
    print(f"Updated popularity_counts    : {popularity_counts}")
    print(f"assignment_j                 : {assignment_j}")

    return global_weights, global_sigmas, popularity_counts, assignment_j


def objective(global_weights, global_sigmas):
    print("\n=== objective ===")
    print(f"global_weights shape : {global_weights.shape}")
    print(f"global_sigmas shape  : {global_sigmas.shape}")
    obj = ((global_weights) ** 2 / global_sigmas).sum()
    print(f"Objective value      : {obj:.4f}")
    return obj


def patch_weights(w_j, L_next, assignment_j_c):
    print("\n=== patch_weights ===")
    print(f"w_j shape            : {w_j.shape}")
    print(f"L_next               : {L_next}")
    print(f"assignment_j_c       : {assignment_j_c}")

    if assignment_j_c is None:
        return w_j

    new_w_j = np.zeros((w_j.shape[0], L_next))
    try:
        new_w_j[:, assignment_j_c] = w_j
        print(f"patched w_j shape     : {new_w_j.shape}")
    except Exception as e:
        print(f"Error during patching: {e}")
    
    return new_w_j


In [245]:
def match_layer(weights_bias, sigma_inv_layer, mean_prior, sigma_inv_prior, gamma, it):
    J = len(weights_bias)

    group_order = sorted(range(J), key=lambda x: -weights_bias[x].shape[0])

    batch_weights_norm = [w * s for w, s in zip(weights_bias, sigma_inv_layer)]
    prior_mean_norm = mean_prior * sigma_inv_prior

    global_weights = prior_mean_norm + batch_weights_norm[group_order[0]]
    global_sigmas = np.outer(np.ones(global_weights.shape[0]), sigma_inv_prior + sigma_inv_layer[group_order[0]])

    popularity_counts = [1] * global_weights.shape[0]

    assignment = [[] for _ in range(J)]

    assignment[group_order[0]] = list(range(global_weights.shape[0]))

    ## Initialize
    for j in group_order[1:]:
        global_weights, global_sigmas, popularity_counts, assignment_j = matching_upd_j(batch_weights_norm[j],
                                                                                        global_weights,
                                                                                        sigma_inv_layer[j],
                                                                                        global_sigmas, prior_mean_norm,
                                                                                        sigma_inv_prior,
                                                                                        popularity_counts, gamma, J)
        assignment[j] = assignment_j

    ## Iterate over groups
    for iteration in range(it):
        random_order = np.random.permutation(J)
        for j in random_order:  # random_order:
            to_delete = []
            ## Remove j
            Lj = len(assignment[j])
            for l, i in sorted(zip(range(Lj), assignment[j]), key=lambda x: -x[1]):
                popularity_counts[i] -= 1
                if popularity_counts[i] == 0:
                    del popularity_counts[i]
                    to_delete.append(i)
                    for j_clean in range(J):
                        for idx, l_ind in enumerate(assignment[j_clean]):
                            if i < l_ind and j_clean != j:
                                assignment[j_clean][idx] -= 1
                            elif i == l_ind and j_clean != j:
                                logger.info('Warning - weird unmatching')
                else:
                    global_weights[i] = global_weights[i] - batch_weights_norm[j][l]
                    global_sigmas[i] -= sigma_inv_layer[j]

            global_weights = np.delete(global_weights, to_delete, axis=0)
            global_sigmas = np.delete(global_sigmas, to_delete, axis=0)

            ## Match j
            global_weights, global_sigmas, popularity_counts, assignment_j = matching_upd_j(batch_weights_norm[j],
                                                                                            global_weights,
                                                                                            sigma_inv_layer[j],
                                                                                            global_sigmas,
                                                                                            prior_mean_norm,
                                                                                            sigma_inv_prior,
                                                                                            popularity_counts, gamma, J)
            assignment[j] = assignment_j

    logger.info('Number of global neurons is %d, gamma %f' % (global_weights.shape[0], gamma))
    logger.info("***************Shape of global weights after match: {} ******************".format(global_weights.shape))
    return assignment, global_weights, global_sigmas


In [246]:
def dlayer_wise_group_descent(batch_weights, layer_index, batch_frequencies, sigma_layers, 
                                sigma0_layers, gamma_layers, it, 
                                model_layer_type,
                                n_layers,):
    print("\n=== layer_wise_group_descent ===")
    print(f"Layer index: {layer_index}, Iteration: {it}, Total layers: {n_layers}")
    print(f"Batch size (J): {len(batch_weights)}")
    
    if type(sigma_layers) is not list:
        sigma_layers = (n_layers - 1) * [sigma_layers]
    if type(sigma0_layers) is not list:
        sigma0_layers = (n_layers - 1) * [sigma0_layers]
    if type(gamma_layers) is not list:
        gamma_layers = (n_layers - 1) * [gamma_layers]

    last_layer_const = []
    total_freq = sum(batch_frequencies)
    for f in batch_frequencies:
        last_layer_const.append(f / total_freq)

    J = len(batch_weights)
    print(f"Total clients (J): {J}")
    
    init_num_kernel = batch_weights[0][0].shape[0]
    print(f"Initial num kernel: {init_num_kernel}")

    init_channel_kernel_dims = []
    for bw in batch_weights[0]:
        if len(bw.shape) > 1:
            init_channel_kernel_dims.append(bw.shape[1])
    print(f"init_channel_kernel_dims: {init_channel_kernel_dims}")

    sigma_bias_layers = sigma_layers
    sigma0_bias_layers = sigma0_layers
    mu0 = 0.
    mu0_bias = 0.1
    assignment_c = [None for j in range(J)]
    L_next = None

    sigma = sigma_layers[layer_index - 1]
    sigma_bias = sigma_bias_layers[layer_index - 1]
    gamma = gamma_layers[layer_index - 1]
    sigma0 = sigma0_layers[layer_index - 1]
    sigma0_bias = sigma0_bias_layers[layer_index - 1]

    print(f"σ, σ0, γ: {sigma}, {sigma0}, {gamma}")

    # Layer-specific logic
    if layer_index <= 1:
        print("Branch A: First layer or shallow network (n_layers == 2)")
        weights_bias = [np.hstack((batch_weights[j][0], batch_weights[j][layer_index * 2 - 1].reshape(-1, 1))) for j in range(J)]
        sigma_inv_prior = np.array(init_channel_kernel_dims[layer_index - 1] * [1 / sigma0] + [1 / sigma0_bias])
        mean_prior = np.array(init_channel_kernel_dims[layer_index - 1] * [mu0] + [mu0_bias])

        if n_layers == 2:
            sigma_inv_layer = [np.array(D * [1 / sigma] + [1 / sigma_bias] + [y / sigma for y in last_layer_const[j]]) for j in range(J)]
        else:
            sigma_inv_layer = [np.array(init_channel_kernel_dims[layer_index - 1] * [1 / sigma] + [1 / sigma_bias]) for j in range(J)]

    elif layer_index == (n_layers - 1) and n_layers > 2:
        print("Branch B: Final layer of deeper network")
        layer_type = model_layer_type[2 * layer_index - 2]
        prev_layer_type = model_layer_type[2 * layer_index - 4]
        first_fc_identifier = (('fc' in layer_type or 'classifier' in layer_type) and ('conv' in prev_layer_type or 'features' in layer_type))
        print(f"first_fc_identifier: {first_fc_identifier}")

        weights_bias = [np.hstack((batch_weights[j][2 * layer_index - 2].T, 
                                   batch_weights[j][2 * layer_index - 1].reshape(-1, 1))) for j in range(J)]

        sigma_inv_prior = np.array([1 / sigma0_bias] + (weights_bias[0].shape[1] - 1) * [1 / sigma0])
        mean_prior = np.array([mu0_bias] + (weights_bias[0].shape[1] - 1) * [mu0])
        sigma_inv_layer = [np.array([1 / sigma_bias] + (weights_bias[j].shape[1] - 1) * [1 / sigma]) for j in range(J)]

    elif (layer_index > 1 and layer_index < (n_layers - 1)):
        print("Branch C: Intermediate hidden layers")
        layer_type = model_layer_type[2 * layer_index - 2]
        prev_layer_type = model_layer_type[2 * layer_index - 4]

        if 'conv' in layer_type or 'features' in layer_type:
            weights_bias = [np.hstack((batch_weights[j][2 * layer_index - 2], 
                                       batch_weights[j][2 * layer_index - 1].reshape(-1, 1))) for j in range(J)]
        elif 'fc' in layer_type or 'classifier' in layer_type:
            weights_bias = [np.hstack((batch_weights[j][2 * layer_index - 2].T, 
                                       batch_weights[j][2 * layer_index - 1].reshape(-1, 1))) for j in range(J)]

        sigma_inv_prior = np.array([1 / sigma0_bias] + (weights_bias[0].shape[1] - 1) * [1 / sigma0])
        mean_prior = np.array([mu0_bias] + (weights_bias[0].shape[1] - 1) * [mu0])
        sigma_inv_layer = [np.array([1 / sigma_bias] + (weights_bias[j].shape[1] - 1) * [1 / sigma]) for j in range(J)]

    print(f"weights_bias[0] shape: {weights_bias[0].shape}")
    print(f"sigma_inv_prior shape: {sigma_inv_prior.shape}")
    print(f"mean_prior shape: {mean_prior.shape}")
    print(f"sigma_inv_layer[0] shape: {sigma_inv_layer[0].shape}")

    # Run matching
    assignment_c, global_weights_c, global_sigmas_c = match_layer(
        weights_bias, sigma_inv_layer, mean_prior, sigma_inv_prior, gamma, it
    )

    L_next = global_weights_c.shape[0]
    print(f"Matched global component count (L_next): {L_next}")

    # Reconstruct weight structure
    if layer_index <= 1:
        if n_layers == 2:
            softmax_bias, softmax_inv_sigma = process_softmax_bias(batch_weights, last_layer_const, sigma, sigma0)
            global_weights_out = [softmax_bias]
            global_inv_sigmas_out = [softmax_inv_sigma]
        else:
            global_weights_out = [
                global_weights_c[:, :init_channel_kernel_dims[int(layer_index / 2)]],
                global_weights_c[:, init_channel_kernel_dims[int(layer_index / 2)]]
            ]
            global_inv_sigmas_out = [
                global_sigmas_c[:, :init_channel_kernel_dims[int(layer_index / 2)]],
                global_sigmas_c[:, init_channel_kernel_dims[int(layer_index / 2)]]
            ]
        print("Branch A Output Shapes:", [g.shape for g in global_weights_out])

    elif layer_index == (n_layers - 1) and n_layers > 2:
        softmax_bias, softmax_inv_sigma = process_softmax_bias(batch_weights, last_layer_const, sigma, sigma0)
        layer_type = model_layer_type[2 * layer_index - 2]
        gwc_shape = global_weights_c.shape

        if "conv" in layer_type or 'features' in layer_type:
            global_weights_out = [global_weights_c[:, :-1], global_weights_c[:, -1]]
            global_inv_sigmas_out = [global_sigmas_c[:, :-1], global_sigmas_c[:, -1]]
        elif "fc" in layer_type or 'classifier' in layer_type:
            global_weights_out = [global_weights_c[:, :-1].T, global_weights_c[:, -1]]
            global_inv_sigmas_out = [global_sigmas_c[:, :-1].T, global_sigmas_c[:, -1]]
        print("Branch B Output Shapes:", [g.shape for g in global_weights_out])

    elif (layer_index > 1 and layer_index < (n_layers - 1)):
        layer_type = model_layer_type[2 * layer_index - 2]
        gwc_shape = global_weights_c.shape
        if "conv" in layer_type or 'features' in layer_type:
            global_weights_out = [global_weights_c[:, :-1], global_weights_c[:, -1]]
            global_inv_sigmas_out = [global_sigmas_c[:, :-1], global_sigmas_c[:, -1]]
        elif "fc" in layer_type or 'classifier' in layer_type:
            global_weights_out = [global_weights_c[:, :-1].T, global_weights_c[:, -1]]
            global_inv_sigmas_out = [global_sigmas_c[:, :-1].T, global_sigmas_c[:, -1]]
        print("Branch C Output Shapes:", [g.shape for g in global_weights_out])

    print("global_inv_sigmas_out shapes:", [g.shape for g in global_inv_sigmas_out])
    map_out = [g_w / g_s for g_w, g_s in zip(global_weights_out, global_inv_sigmas_out)]
    print("Final map_out shapes:", [m.shape for m in map_out])
    
    return map_out, assignment_c, L_next


In [247]:
num_workers = len(nets_list)
n_classes = args_net_config[-1]
averaging_weights = np.full((num_workers, n_classes), 1.0 / num_workers)
batch_freqs = np.full((len(nets_list), n_classes), 1.0 / n_classes)

In [248]:
print(batch_freqs)

[[0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
 [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
 [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
 [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]
 [0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1 0.1]]


In [249]:
print(averaging_weights)

[[0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2]
 [0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2]
 [0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2]
 [0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2]
 [0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2 0.2]]


In [None]:
import numpy as np

def layer_wise_group_descent(batch_weights, layer_index, batch_frequencies, sigma_layers, 
                             sigma0_layers, gamma_layers, it, 
                             model_layer_type,
                             n_layers,):
    """
    We implement a layer-wise matching here:
    """
    print(f"===> Starting layer_wise_group_descent for layer_index={layer_index}")

    if type(sigma_layers) is not list:
        sigma_layers = (n_layers - 1) * [sigma_layers]
    if type(sigma0_layers) is not list:
        sigma0_layers = (n_layers - 1) * [sigma0_layers]
    if type(gamma_layers) is not list:
        gamma_layers = (n_layers - 1) * [gamma_layers]

    print(f"Total layers: {n_layers}, Iterations: {it}")
    print(f"Sigma for layer {layer_index}: {sigma_layers[layer_index - 1]}")
    print(f"Gamma for layer {layer_index}: {gamma_layers[layer_index - 1]}")

    total_freq = sum(batch_frequencies)
    last_layer_const = [f / total_freq for f in batch_frequencies]
    print(f"Batch frequencies: {batch_frequencies}")
    print(f"Normalized frequencies: {last_layer_const}")

    J = len(batch_weights)
    print(f"Number of clients (J): {J}")
    print(f"Type of batch_weights[0]: {type(batch_weights[0])}, length: {len(batch_weights[0])}")

    if len(batch_weights[0]) > 0:
        print(f"Shape of batch_weights[0][0]: {batch_weights[0][0].shape}")
    init_num_kernel = batch_weights[0][0].shape[0]

    init_channel_kernel_dims = []
    for bw in batch_weights[0]:
        if len(bw.shape) > 1:
            init_channel_kernel_dims.append(bw.shape[1])
    print("init_channel_kernel_dims: %s" % init_channel_kernel_dims)

    sigma = sigma_layers[layer_index - 1]
    sigma_bias = sigma_layers[layer_index - 1]
    gamma = gamma_layers[layer_index - 1]
    sigma0 = sigma0_layers[layer_index - 1]
    sigma0_bias = sigma0_layers[layer_index - 1]
    mu0 = 0.
    mu0_bias = 0.1

    if layer_index <= 1:
        print("Layer Type: First Layer")

        weights_bias = [np.hstack((batch_weights[j][0], batch_weights[j][layer_index * 2 - 1].reshape(-1, 1))) for j in range(J)]

        sigma_inv_prior = np.array(init_channel_kernel_dims[layer_index - 1] * [1 / sigma0] + [1 / sigma0_bias])
        mean_prior = np.array(init_channel_kernel_dims[layer_index - 1] * [mu0] + [mu0_bias])

        if n_layers == 2:
            # Assuming 'D' is defined or accessible in this context, otherwise this line will cause an error.
            # For the purpose of converting logger to print, 'D' is left as is.
            sigma_inv_layer = [
                np.array(D * [1 / sigma] + [1 / sigma_bias] + [y / sigma for y in last_layer_const[j]]) for j in range(J)]
        else:
            sigma_inv_layer = [np.array(init_channel_kernel_dims[layer_index - 1] * [1 / sigma] + [1 / sigma_bias]) for j in range(J)]

    elif layer_index == (n_layers - 1) and n_layers > 2:
        print("Layer Type: Final Layer")

        layer_type = model_layer_type[2 * layer_index - 2]
        prev_layer_type = model_layer_type[2 * layer_index - 4]
        first_fc_identifier = (('fc' in layer_type or 'classifier' in layer_type) and ('conv' in prev_layer_type or 'features' in layer_type))

        weights_bias = [np.hstack((batch_weights[j][2 * layer_index - 2].T, 
                                   batch_weights[j][2 * layer_index - 1].reshape(-1, 1))) for j in range(J)]

        sigma_inv_prior = np.array([1 / sigma0_bias] + (weights_bias[0].shape[1] - 1) * [1 / sigma0])
        mean_prior = np.array([mu0_bias] + (weights_bias[0].shape[1] - 1) * [mu0])
        sigma_inv_layer = [np.array([1 / sigma_bias] + (weights_bias[j].shape[1] - 1) * [1 / sigma]) for j in range(J)]

    elif (layer_index > 1 and layer_index < (n_layers - 1)):
        print("Layer Type: Middle Layer")

        layer_type = model_layer_type[2 * layer_index - 2]
        prev_layer_type = model_layer_type[2 * layer_index - 4]

        if 'conv' in layer_type or 'features' in layer_type:
            print("SubType: Convolutional")
            weights_bias = [np.hstack((batch_weights[j][2 * layer_index - 2], 
                                       batch_weights[j][2 * layer_index - 1].reshape(-1, 1))) for j in range(J)]
        elif 'fc' in layer_type or 'classifier' in layer_type:
            print("SubType: Fully Connected")
            first_fc_identifier = (('fc' in layer_type or 'classifier' in layer_type) and ('conv' in prev_layer_type or 'features' in layer_type))
            weights_bias = [np.hstack((batch_weights[j][2 * layer_index - 2].T, 
                                       batch_weights[j][2 * layer_index - 1].reshape(-1, 1))) for j in range(J)]

        sigma_inv_prior = np.array([1 / sigma0_bias] + (weights_bias[0].shape[1] - 1) * [1 / sigma0])
        mean_prior = np.array([mu0_bias] + (weights_bias[0].shape[1] - 1) * [mu0])
        sigma_inv_layer = [np.array([1 / sigma_bias] + (weights_bias[j].shape[1] - 1) * [1 / sigma]) for j in range(J)]

    print(f"weights_bias[0] shape: {weights_bias[0].shape}")
    print(f"sigma_inv_prior shape: {sigma_inv_prior.shape}")
    print(f"mean_prior shape: {mean_prior.shape}")
    print(f"sigma_inv_layer[0] shape: {sigma_inv_layer[0].shape}")

    # Matching
    # You'll need to ensure `match_layer` and `process_softmax_bias` are defined elsewhere
    # or replace them with your actual implementation.
    # For demonstration, I'm assuming they exist and work as intended.
    assignment_c, global_weights_c, global_sigmas_c = match_layer(weights_bias, sigma_inv_layer, mean_prior,
                                                                   sigma_inv_prior, gamma, it)
    L_next = global_weights_c.shape[0]

    # Build outputs
    if layer_index <= 1:
        if n_layers == 2:
            print("Handling softmax for 2-layer network")
            softmax_bias, softmax_inv_sigma = process_softmax_bias(batch_weights, last_layer_const, sigma, sigma0)
            global_weights_out = [softmax_bias]
            global_inv_sigmas_out = [softmax_inv_sigma]
        else:
            global_weights_out = [
                global_weights_c[:, :init_channel_kernel_dims[int(layer_index / 2)]],
                global_weights_c[:, init_channel_kernel_dims[int(layer_index / 2)]]
            ]
            global_inv_sigmas_out = [
                global_sigmas_c[:, :init_channel_kernel_dims[int(layer_index / 2)]],
                global_sigmas_c[:, init_channel_kernel_dims[int(layer_index / 2)]]
            ]
        print(f"Branch A: global_weights_out shapes: {[gwo.shape for gwo in global_weights_out]}")

    elif layer_index == (n_layers - 1) and n_layers > 2:
        print("Handling final softmax layer")
        softmax_bias, softmax_inv_sigma = process_softmax_bias(batch_weights, last_layer_const, sigma, sigma0)

        layer_type = model_layer_type[2 * layer_index - 2]
        gwc_shape = global_weights_c.shape
        if "conv" in layer_type or 'features' in layer_type:
            global_weights_out = [global_weights_c[:, :-1], global_weights_c[:, -1]]
            global_inv_sigmas_out = [global_sigmas_c[:, :-1], global_sigmas_c[:, -1]]
        elif "fc" in layer_type or 'classifier' in layer_type:
            global_weights_out = [global_weights_c[:, :-1].T, global_weights_c[:, -1]]
            global_inv_sigmas_out = [global_sigmas_c[:, :-1].T, global_sigmas_c[:, -1]]
        print(f"Branch B: global_weights_out shapes: {[gwo.shape for gwo in global_weights_out]}")

    elif (layer_index > 1 and layer_index < (n_layers - 1)):
        print("Handling middle layer output assembly")
        layer_type = model_layer_type[2 * layer_index - 2]
        gwc_shape = global_weights_c.shape

        if "conv" in layer_type or 'features' in layer_type:
            global_weights_out = [global_weights_c[:, :-1], global_weights_c[:, -1]]
            global_inv_sigmas_out = [global_sigmas_c[:, :-1], global_sigmas_c[:, -1]]
        elif "fc" in layer_type or 'classifier' in layer_type:
            global_weights_out = [global_weights_c[:, :-1].T, global_weights_c[:, -1]]
            global_inv_sigmas_out = [global_sigmas_c[:, :-1].T, global_sigmas_c[:, -1]]
        print(f"Branch Mid: global_weights_out shapes: {[gwo.shape for gwo in global_weights_out]}")

    print(f"Final global_inv_sigmas_out shapes: {[giso.shape for giso in global_inv_sigmas_out]}")

    map_out = [g_w / g_s for g_w, g_s in zip(global_weights_out, global_inv_sigmas_out)]
    print(f"map_out shapes: {[mo.shape for mo in map_out]}")

    return map_out, assignment_c, L_next

## Corrected BBP_MAP

In [250]:
def BBP_MAP(nets_list, model_meta_data, layer_type, averaging_weights, batch_freqs, args, device="cpu"):
    """
    Fixed version of BBP_MAP function
    """
    models = nets_list
    n_classes = args_net_config[-1]
    it = 5
    
    # Fix: Define missing variables with default values THESE werent defined
    sigma = getattr(args, 'pdm_sig', 1.0) # these are all hyperparameters floating point values
    sigma0 = getattr(args, 'pdm_sig0', 1.0) # these are all hyperparameters floating point values
    gamma = getattr(args, 'pdm_gamma', 7.0) # these are all hyperparameters floating point values
    
    assignments_list = []
    
    batch_weights = pdm_prepare_full_weights_cnn(models, device=device)
    
    logging.info("=="*15)
    logging.info("Weights shapes: {}".format([bw.shape for bw in batch_weights[0]]))

    n_layers = int(len(batch_weights[0]) / 2)
    num_workers = len(nets_list)
    matching_shapes = []
    first_fc_index = None

    for layer_index in range(1, n_layers):
        try:
            print("Layer Wise Group Descent for Layer Index:", layer_index)
            layer_hungarian_weights, assignment, L_next = layer_wise_group_descent(
                batch_weights=batch_weights, 
                layer_index=layer_index,
                sigma0_layers=sigma0, 
                sigma_layers=sigma, 
                batch_frequencies=batch_freqs, 
                it=it, 
                gamma_layers=gamma, 
                model_layer_type=layer_type,
                n_layers=n_layers,
            )
            assignments_list.append(assignment)
            
            # Check if we have valid assignments
            if not assignment or len(assignment) != num_workers:
                logging.warning(f"Invalid assignment at layer {layer_index}")
                continue
            
            type_of_patched_layer = layer_type[2 * (layer_index + 1) - 2]
            if 'conv' in type_of_patched_layer or 'features' in type_of_patched_layer:
                l_type = "conv"
            elif 'fc' in type_of_patched_layer or 'classifier' in type_of_patched_layer:
                l_type = "fc"

            type_of_this_layer = layer_type[2 * layer_index - 2]
            type_of_prev_layer = layer_type[2 * layer_index - 2 - 2]
            first_fc_identifier = (('fc' in type_of_this_layer or 'classifier' in type_of_this_layer) and 
                                 ('conv' in type_of_prev_layer or 'features' in type_of_this_layer))
            
            if first_fc_identifier:
                first_fc_index = layer_index
            
            matching_shapes.append(L_next)
            tempt_weights = [([batch_weights[w][i] for i in range(2 * layer_index - 2)] + 
                            copy.deepcopy(layer_hungarian_weights)) for w in range(num_workers)]

            for worker_index in range(num_workers):
                try:
                    # Add bounds checking for layer access
                    if 2 * (layer_index + 1) - 2 >= len(batch_weights[worker_index]):
                        logging.warning(f"Layer index out of bounds for worker {worker_index}")
                        continue
                        
                    if first_fc_index is None:
                        if l_type == "conv":
                            patched_weight = block_patching(
                                batch_weights[worker_index][2 * (layer_index + 1) - 2], 
                                L_next, assignment[worker_index], 
                                layer_index+1, model_meta_data,
                                matching_shapes=matching_shapes, layer_type=l_type,
                                network_name=args.model)
                        elif l_type == "fc":
                            patched_weight = block_patching(
                                batch_weights[worker_index][2 * (layer_index + 1) - 2].T, 
                                L_next, assignment[worker_index], 
                                layer_index+1, model_meta_data,
                                matching_shapes=matching_shapes, layer_type=l_type,
                                network_name=args.model).T
                    elif layer_index >= first_fc_index:
                        patched_weight = patch_weights(
                            batch_weights[worker_index][2 * (layer_index + 1) - 2].T, 
                            L_next, assignment[worker_index]).T

                    tempt_weights[worker_index].append(patched_weight)
                    
                except Exception as e:
                    logging.error(f"Error patching weights for worker {worker_index}, layer {layer_index}: {e}")
                    continue

            # Update batch_weights for next iteration
            for worker_index in range(num_workers):
                for lid in range(2 * (layer_index + 1) - 1, len(batch_weights[0])):
                    if lid < len(batch_weights[worker_index]):
                        tempt_weights[worker_index].append(batch_weights[worker_index][lid])
            
            batch_weights = tempt_weights
            
        except Exception as e:
            logging.error(f"Error in layer {layer_index}: {e}")
            continue

    # Continue with the rest of the function...
    matched_weights = []
    num_layers = len(batch_weights[0]) if batch_weights and len(batch_weights) > 0 else 0
    
    if num_layers == 0:
        logging.error("No valid batch weights found")
        return [], assignments_list

    # Create cache directory if it doesn't exist
    import os
    os.makedirs('./matching_weights_cache', exist_ok=True)
    
    try:
        with open('./matching_weights_cache/matched_layerwise_weights', 'wb') as weights_file:
            pickle.dump(batch_weights, weights_file)
    except Exception as e:
        logging.warning(f"Could not save weights cache: {e}")

    last_layer_weights_collector = []

    for i in range(num_workers):
        if len(batch_weights[i]) >= 2:
            bias_shape = batch_weights[i][-1].shape
            last_layer_bias = batch_weights[i][-1].reshape((1, bias_shape[0]))
            last_layer_weights = np.concatenate((batch_weights[i][-2], last_layer_bias), axis=0)
            last_layer_weights_collector.append(last_layer_weights)

    if not last_layer_weights_collector:
        logging.error("No valid last layer weights found")
        return [], assignments_list

    last_layer_weights_collector = np.array(last_layer_weights_collector)
    avg_last_layer_weight = np.zeros(last_layer_weights_collector[0].shape, dtype=np.float32)

    for i in range(min(n_classes, avg_last_layer_weight.shape[1])):
        avg_weight_collector = np.zeros(last_layer_weights_collector[0][:, 0].shape, dtype=np.float32)
        for j in range(num_workers):
            if j < len(averaging_weights) and i < len(averaging_weights[j]):
                avg_weight_collector += averaging_weights[j][i] * last_layer_weights_collector[j][:, i]
        avg_last_layer_weight[:, i] = avg_weight_collector

    for i in range(num_layers - 2):
        if i < len(batch_weights[0]):
            matched_weights.append(batch_weights[0][i])

    matched_weights.append(avg_last_layer_weight[0:-1, :])
    matched_weights.append(avg_last_layer_weight[-1, :])
    
    return matched_weights, assignments_list

In [251]:
hungarian_weights, assignments_list = BBP_MAP(
        nets_list, model_meta_data, layer_type, 
        averaging_weights, batch_freqs, args, 
        device=device
    )
print("\n✅ Success! Matching completed.\n")

Statedict:  OrderedDict({'conv1.weight': tensor([[[[ 0.1529,  0.1660, -0.0469,  0.1837, -0.0438],
          [ 0.0404, -0.0974,  0.1175,  0.1763, -0.1467],
          [ 0.1738,  0.0374,  0.1478,  0.0271,  0.0964],
          [-0.0282,  0.1542,  0.0296, -0.0934,  0.0510],
          [-0.0921, -0.0235, -0.0812,  0.1327, -0.1579]]],


        [[[-0.0922, -0.0565, -0.1203,  0.0189, -0.1975],
          [ 0.1806, -0.1699,  0.1544,  0.0333, -0.0649],
          [ 0.1236,  0.0312,  0.1616,  0.0219, -0.0631],
          [ 0.0537, -0.0542,  0.0842,  0.1786,  0.1156],
          [-0.0874,  0.1155,  0.0358,  0.1016, -0.1219]]],


        [[[-0.1980, -0.0773, -0.1534,  0.1641,  0.0576],
          [ 0.0828,  0.0633, -0.0035,  0.1565, -0.1421],
          [ 0.0126, -0.1365,  0.0617, -0.0689,  0.0613],
          [-0.0417,  0.1659, -0.1185, -0.1193, -0.1193],
          [ 0.1799,  0.0667,  0.1925, -0.1651, -0.1984]]],


        [[[-0.1565, -0.1345,  0.0810,  0.0716,  0.1662],
          [-0.1033, -0.1363,  0.106

INFO:root:Weights shapes: [(20, 25), (20,), (50, 500), (50,), (800, 500), (500,), (500, 10), (10,)]


OrderedDict({'conv1.weight': tensor([[[[ 0.0710,  0.0612, -0.1817,  0.1770, -0.0030],
          [ 0.1994,  0.1034, -0.1873, -0.0893,  0.1836],
          [-0.1801, -0.1286, -0.0362,  0.0209,  0.1691],
          [-0.1807,  0.1202, -0.0677,  0.0130, -0.1172],
          [-0.1092,  0.1389,  0.1056,  0.0888,  0.0997]]],


        [[[-0.0476,  0.0397, -0.1289, -0.1700,  0.0622],
          [ 0.0201, -0.1864,  0.1072, -0.0138,  0.0885],
          [-0.0538,  0.0686,  0.0423, -0.0871,  0.1258],
          [-0.0400,  0.1944, -0.0676,  0.1185,  0.1420],
          [ 0.0326,  0.0254, -0.0362, -0.1602,  0.1134]]],


        [[[ 0.1644,  0.1622, -0.1984, -0.1506,  0.1195],
          [ 0.0748,  0.1322,  0.1766, -0.1325, -0.1591],
          [-0.1824, -0.1459, -0.1373,  0.1573, -0.0853],
          [-0.1099, -0.1544,  0.1438,  0.1502, -0.0869],
          [-0.0852,  0.1792, -0.1000, -0.0173,  0.1843]]],


        [[[ 0.0865, -0.0596, -0.1200, -0.0519,  0.1367],
          [-0.1583, -0.1621,  0.1944, -0.0006, 

INFO:root:Number of global neurons is 28, gamma 7.000000
INFO:root:***************Shape of global weights after match: (28, 26) ******************


full_cost shape          : (20, 46)
compute_cost duration    : 0.0065s
linear_sum_assignment time: 0.0004s
row_ind: [ 0  1  2  3  4  5  6  7  8  9 10 11 12 13 14 15 16 17 18 19]
col_ind: [ 1 12 14  2  7 17  4  8 16  6  3  5  9 10 27 15  0 26 11 13]
Assigning local 0 → global 1
Assigning local 1 → global 12
Assigning local 2 → global 14
Assigning local 3 → global 2
Assigning local 4 → global 7
Assigning local 5 → global 17
Assigning local 6 → global 4
Assigning local 7 → global 8
Assigning local 8 → global 16
Assigning local 9 → global 6
Assigning local 10 → global 3
Assigning local 11 → global 5
Assigning local 12 → global 9
Assigning local 13 → global 10
Assigning local 14 → global 27
New global component added: 26
Assigning local 15 → global 15
Assigning local 16 → global 0
Assigning local 17 → global 26
New global component added: 27
Assigning local 18 → global 11
Assigning local 19 → global 13
Updated global_weights shape: (28, 26)
Updated global_sigmas shape : (28, 26)
Updated pop

INFO:root:Number of global neurons is 58, gamma 7.000000
INFO:root:***************Shape of global weights after match: (58, 701) ******************


Assigning local 34 → global 28
Assigning local 35 → global 29
Assigning local 36 → global 45
Assigning local 37 → global 47
Assigning local 38 → global 35
Assigning local 39 → global 43
Assigning local 40 → global 27
Assigning local 41 → global 13
Assigning local 42 → global 15
Assigning local 43 → global 18
Assigning local 44 → global 22
Assigning local 45 → global 57
New global component added: 57
Assigning local 46 → global 4
Assigning local 47 → global 2
Assigning local 48 → global 34
Assigning local 49 → global 3
Updated global_weights shape: (58, 701)
Updated global_sigmas shape : (58, 701)
Updated popularity_counts    : [5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
assignment_j                 : [np.int64(5), np.int64(23), np.int64(16), np.int64(37), np.int64(32), np.int64(26), np.int64(38), np.int64(14), np.int64(11), np.int64(20), np.int64(19), np.in

INFO:root:Number of global neurons is 508, gamma 7.000000
INFO:root:***************Shape of global weights after match: (508, 59) ******************


linear_sum_assignment time: 0.0293s
row_ind: [  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71
  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89
  90  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107
 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143
 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161
 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179
 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197
 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215
 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233
 234 2

In [252]:
print("🔹 Matched Weights Summary:")
for i, w in enumerate(hungarian_weights):
    print(f"Layer {i}: shape = {w.shape}, type = {type(w)}")

print("\n🔹 Assignments List:")
for i, assignment in enumerate(assignments_list):
    print(f"Layer {i+1} assignment: {assignment}")

🔹 Matched Weights Summary:
Layer 0: shape = (28, 25), type = <class 'numpy.ndarray'>
Layer 1: shape = (28,), type = <class 'numpy.ndarray'>
Layer 2: shape = (58, 700), type = <class 'numpy.ndarray'>
Layer 3: shape = (58,), type = <class 'numpy.ndarray'>
Layer 4: shape = (58, 508), type = <class 'numpy.ndarray'>
Layer 5: shape = (508,), type = <class 'numpy.ndarray'>
Layer 6: shape = (508, 10), type = <class 'numpy.ndarray'>
Layer 7: shape = (10,), type = <class 'numpy.ndarray'>

🔹 Assignments List:
Layer 1 assignment: [[np.int64(15), np.int64(5), 26, np.int64(2), np.int64(3), np.int64(4), np.int64(7), np.int64(6), 27, np.int64(8), np.int64(9), np.int64(10), np.int64(11), np.int64(1), np.int64(12), np.int64(13), np.int64(14), np.int64(0), np.int64(16), np.int64(17)], [np.int64(4), np.int64(6), np.int64(12), np.int64(1), np.int64(5), np.int64(7), np.int64(0), np.int64(14), np.int64(8), np.int64(16), np.int64(17), 24, np.int64(13), np.int64(9), np.int64(11), np.int64(15), 25, np.int64(3),

## Checking the results of BBP_MAP

In [253]:
print(assignments_list)

[[[np.int64(15), np.int64(5), 26, np.int64(2), np.int64(3), np.int64(4), np.int64(7), np.int64(6), 27, np.int64(8), np.int64(9), np.int64(10), np.int64(11), np.int64(1), np.int64(12), np.int64(13), np.int64(14), np.int64(0), np.int64(16), np.int64(17)], [np.int64(4), np.int64(6), np.int64(12), np.int64(1), np.int64(5), np.int64(7), np.int64(0), np.int64(14), np.int64(8), np.int64(16), np.int64(17), 24, np.int64(13), np.int64(9), np.int64(11), np.int64(15), 25, np.int64(3), np.int64(2), np.int64(10)], [np.int64(3), np.int64(9), 20, np.int64(4), np.int64(13), np.int64(12), np.int64(1), np.int64(16), np.int64(0), np.int64(14), np.int64(10), np.int64(8), np.int64(2), np.int64(17), 21, np.int64(6), np.int64(15), np.int64(5), np.int64(11), np.int64(7)], [np.int64(8), np.int64(6), np.int64(12), np.int64(10), np.int64(11), np.int64(4), np.int64(0), np.int64(7), np.int64(16), 22, np.int64(15), np.int64(5), np.int64(14), np.int64(17), np.int64(1), np.int64(13), 23, np.int64(3), np.int64(9), np.i

In [254]:
for layer_idx, layer_assignment in enumerate(assignments_list, start=1):
    print(f"Layer {layer_idx} assignments:")
    for client_idx, assign in enumerate(layer_assignment):
        # assign might be a NumPy array or list of ints
        print(f"  Client {client_idx:>2}: {list(assign)}")
    print()  # blank line between layers


Layer 1 assignments:
  Client  0: [np.int64(15), np.int64(5), 26, np.int64(2), np.int64(3), np.int64(4), np.int64(7), np.int64(6), 27, np.int64(8), np.int64(9), np.int64(10), np.int64(11), np.int64(1), np.int64(12), np.int64(13), np.int64(14), np.int64(0), np.int64(16), np.int64(17)]
  Client  1: [np.int64(4), np.int64(6), np.int64(12), np.int64(1), np.int64(5), np.int64(7), np.int64(0), np.int64(14), np.int64(8), np.int64(16), np.int64(17), 24, np.int64(13), np.int64(9), np.int64(11), np.int64(15), 25, np.int64(3), np.int64(2), np.int64(10)]
  Client  2: [np.int64(3), np.int64(9), 20, np.int64(4), np.int64(13), np.int64(12), np.int64(1), np.int64(16), np.int64(0), np.int64(14), np.int64(10), np.int64(8), np.int64(2), np.int64(17), 21, np.int64(6), np.int64(15), np.int64(5), np.int64(11), np.int64(7)]
  Client  3: [np.int64(8), np.int64(6), np.int64(12), np.int64(10), np.int64(11), np.int64(4), np.int64(0), np.int64(7), np.int64(16), 22, np.int64(15), np.int64(5), np.int64(14), np.int6

In [255]:
#per‐layer permutation vectors that BBP‐MAP has found for each client. 
"""
Layer 1 assignments:
  Client  0: [15, 5, 26, 2, 3, 4, …, 0, 16, 17]
               ^  ^   ^  ^  ^  ^        ^
               |  |   |  |  |  |        +-- original block 19 → global block 17
               |  |   |  |  |  +----------- original block 0  → global block 15
               |  |   |  |  +-------------- original block 3  → global block 2
               ·  ·   ·  ·  
"""

'\nLayer 1 assignments:\n  Client  0: [15, 5, 26, 2, 3, 4, …, 0, 16, 17]\n               ^  ^   ^  ^  ^  ^        ^\n               |  |   |  |  |  |        +-- original block 19 → global block 17\n               |  |   |  |  |  +----------- original block 0  → global block 15\n               |  |   |  |  +-------------- original block 3  → global block 2\n               ·  ·   ·  ·  \n'

In [256]:
for idx, client in enumerate(nets_list):
    print(f"\n=== Client {idx} parameters ===")
    for name, param in client.named_parameters():
        print(f"{name:30}  {tuple(param.shape)}  mean={param.data.mean():.4f}")


=== Client 0 parameters ===
conv1.weight                    (20, 1, 5, 5)  mean=-0.0091
conv1.bias                      (20,)  mean=-0.0235
conv2.weight                    (50, 20, 5, 5)  mean=0.0000
conv2.bias                      (50,)  mean=0.0056
fc1.weight                      (500, 800)  mean=-0.0000
fc1.bias                        (500,)  mean=-0.0011
fc2.weight                      (10, 500)  mean=0.0002
fc2.bias                        (10,)  mean=0.0061

=== Client 1 parameters ===
conv1.weight                    (20, 1, 5, 5)  mean=0.0012
conv1.bias                      (20,)  mean=-0.0136
conv2.weight                    (50, 20, 5, 5)  mean=-0.0000
conv2.bias                      (50,)  mean=-0.0041
fc1.weight                      (500, 800)  mean=0.0000
fc1.bias                        (500,)  mean=-0.0005
fc2.weight                      (10, 500)  mean=0.0005
fc2.bias                        (10,)  mean=0.0114

=== Client 2 parameters ===
conv1.weight                    (20

In [257]:
print(hungarian_weights)

[array([[-6.49053110e-02,  1.16629486e-02, -8.81744910e-02,
        -2.17946184e-02,  5.93166242e-02,  7.70698432e-02,
        -4.15283367e-02,  3.73556279e-03,  2.95193767e-02,
         3.19954948e-02,  3.62161975e-02,  2.53420640e-02,
        -7.05319547e-02, -4.84311115e-02,  4.81213753e-03,
         8.44100962e-02,  5.88592341e-02,  1.31092152e-02,
         4.95940751e-02,  1.36419202e-01, -4.80760516e-02,
        -8.98706758e-02,  4.93932504e-03,  8.50024155e-02,
        -1.07170306e-01],
       [-7.20851444e-02, -3.51366817e-02, -8.21623479e-02,
        -1.26698202e-02, -7.01947200e-03,  5.37607688e-02,
         7.99585208e-02, -5.81740987e-02, -6.76749231e-02,
         1.15144953e-02, -5.88538895e-02, -7.16517912e-02,
         8.91824532e-02, -2.38491068e-02,  6.42496274e-02,
        -9.29510904e-02,  7.40384627e-02,  2.19847721e-02,
        -4.43810344e-02, -1.00660952e-01, -4.20958642e-02,
         1.34098784e-01,  2.09462745e-02,  2.35479918e-02,
         2.42167517e-02],
   

In [258]:
len(hungarian_weights)

8

In [259]:
len(hungarian_weights[0])

28

# ppt for every block with pictorial with n_lets + code snippet 
# if you change no of clients how does the assignent list and hungarian weights chnage, input size etc for lenet
# check how its used for fedma itslef output of bbp_map
# kernel to kernel matching check how they are appended 
# also reverse bbp map order 
# clean the file make it super nice know kernel and channel and enuron difference 