In [1]:
import random
from pathlib import Path
import tarfile
from typing import Any
from logging import INFO
from collections import defaultdict, OrderedDict
from collections.abc import Sequence, Callable
import numbers
import copy

import numpy as np
import torch
from torch import nn
from torch.nn import Module
from torch.utils.data import DataLoader, Dataset
from enum import IntEnum
import flwr
from flwr.server import History, ServerConfig
from flwr.server.strategy import FedAvgM as FedAvg, Strategy
from c2m3.flower.fed_frank_wolfe_strategy import FrankWolfeSync
from flwr.common import log, NDArrays, Scalar, Parameters, ndarrays_to_parameters
from flwr.client.client import Client
from c2m3.match.utils import apply_permutation_to_statedict
from c2m3.match.frank_wolfe_sync_matching import frank_wolfe_synchronized_matching
from c2m3.match.permutation_spec import CNNPermutationSpecBuilder
from c2m3.modules.pl_module import MyLightningModule
NUM_CLASSES_FEMNIST = 62

from c2m3.common.client_utils import (
    Net,
    load_femnist_dataset,
    get_network_generator_cnn as get_network_generator,
    train_femnist,
    test_femnist,
    save_history,
    get_model_parameters,
    set_model_parameters
)


# Add new seeds here for easy autocomplete
class Seeds(IntEnum):
    """Named integer seeds for the random number generators.

    Being an ``IntEnum``, each member is an ``int`` and is passed directly to
    the seeding calls that follow this class (``np.random.seed``,
    ``random.seed`` and ``torch.manual_seed``).
    """

    DEFAULT = 1337


# Seed NumPy, the stdlib `random` module and PyTorch with the same value, in
# that order, then set the cuDNN flags (deterministic on, benchmark off).
for _seed_rng in (np.random.seed, random.seed, torch.manual_seed):
    _seed_rng(Seeds.DEFAULT)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


# Type alias covering the three accepted spellings of an optional path:
# a `Path`, a plain string, or `None`.
PathType = Path | str | None


def get_device() -> str:
    """Return the name of the preferred torch device.

    Returns:
        ``"cuda"`` when CUDA is available; otherwise ``"mps"`` when the MPS
        backend is both available and built; otherwise ``"cpu"``.
    """
    if torch.cuda.is_available():
        return "cuda"

    mps_backend = torch.backends.mps
    if mps_backend.is_available() and mps_backend.is_built():
        return "mps"

    return "cpu"

2025-03-15 16:59:26.782231: I tensorflow/tsl/cuda/cudart_stub.cc:28] Could not find cuda drivers on your machine, GPU will not be used.
2025-03-15 16:59:26.848236: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm
2025-03-15 16:59:29,300	INFO util.py:159 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.


In [2]:
# Locations of the FEMNIST data, all derived from the parent of the current
# working directory.
home_dir = Path.cwd() / ".."
dataset_dir: Path = home_dir / "femnist"
data_dir: Path = dataset_dir / "data"

# Every client-to-data mapping lives under the same sub-directory.
_mappings_dir: Path = dataset_dir / "client_data_mappings"
centralized_partition: Path = _mappings_dir / "centralized"
centralized_mapping: Path = centralized_partition / "0"
federated_partition: Path = _mappings_dir / "fed_natural"

In [3]:
# Set random seed for reproducibility (this overrides the earlier
# torch.manual_seed(Seeds.DEFAULT) call for everything that follows).
torch.manual_seed(33)

n = Net()

# Independent copy of the network; it receives the permuted weights later on.
n_permuted = copy.deepcopy(n)

# Read the layer widths from the network instead of hard-coding them.
first_layer_size = n.conv1.out_channels  # 6
second_layer_size = n.conv2.out_channels  # 16
third_layer_size = n.fc1.out_features  # 120
fourth_layer_size = n.fc2.out_features  # 84
fifth_layer_size = n.fc3.out_features  # 62: one output per FEMNIST class (not 10)

# The classifier head must agree with the class count handed to MyLightningModule.
assert fifth_layer_size == NUM_CLASSES_FEMNIST, (
    f"fc3 has {fifth_layer_size} outputs, expected {NUM_CLASSES_FEMNIST}"
)
print(f'{first_layer_size=}\n{second_layer_size=}\n{third_layer_size=}\n{fourth_layer_size=}\n{fifth_layer_size=}')

first_layer_size=6
second_layer_size=16
third_layer_size=120
fourth_layer_size=84
fifth_layer_size=62


In [4]:
# Permutations used to scramble `n_permuted`. Sizes come from the layer widths
# measured on the network rather than repeated literals.
#   - P_conv1 is the identity (torch.arange), so conv1's channels stay in place;
#     swap in torch.randperm(first_layer_size) to permute them as well.
#   - P_fc1 and P_fc2 are random; swap in torch.arange(...) to disable either.
# The two randperm calls are kept in this order so the RNG stream is consumed
# exactly as before.
perm_matrices_old = {
    'P_conv1': torch.arange(first_layer_size),   # first conv layer (identity)
    'P_fc1': torch.randperm(third_layer_size),   # first FC layer
    'P_fc2': torch.randperm(fourth_layer_size),  # second FC layer
}

# Set up parameters for frank_wolfe_synchronized_matching
mlm = MyLightningModule(n, num_classes=NUM_CLASSES_FEMNIST)
ref_model = copy.deepcopy(mlm)

# The permutation spec is derived from a copy of the wrapped reference model.
permutation_spec_builder = CNNPermutationSpecBuilder()
perm_spec = permutation_spec_builder.create_permutation_spec(ref_model=ref_model)

# Use the permutation spec to apply the permutations to n_permuted's weights.
permuted_state_dict = apply_permutation_to_statedict(perm_spec, perm_matrices_old, n_permuted.state_dict())
n_permuted.load_state_dict(permuted_state_dict)

# 'a' is the original network, 'b' the permuted copy.
mlm_permuted = MyLightningModule(n_permuted, num_classes=NUM_CLASSES_FEMNIST)
models = {
    'a': mlm,
    'b': mlm_permuted
}

  rank_zero_warn(


In [7]:
# One of the cycles contained in perm_matrices_old["P_fc1"]:
# 17 -> 30 -> 79 -> 106 -> 17
# The matching run further down does NOT resolve this cycle (the later
# comparison reports exactly these four fc1 rows as mismatched), although the
# other cycles are resolved.
for unit in (17, 30, 79, 106):
    print(f'perm_matrices_old["P_fc1"][{unit}]={perm_matrices_old["P_fc1"][unit]!r}')

perm_matrices_old["P_fc1"][17]=tensor(30)
perm_matrices_old["P_fc1"][30]=tensor(79)
perm_matrices_old["P_fc1"][79]=tensor(106)
perm_matrices_old["P_fc1"][106]=tensor(17)


In [8]:
# argsort of a permutation's index vector gives the index vector of its inverse.
inverse_perm_matrices = {
    perm_name: torch.argsort(perm_matrices_old[perm_name])
    for perm_name in ('P_conv1', 'P_fc1', 'P_fc2')
}

# Undo the permutation on n_permuted's weights ...
unpermuted_state_dict = apply_permutation_to_statedict(perm_spec, inverse_perm_matrices, n_permuted.state_dict())

# ... and load the result into a fresh network so n_permuted itself is untouched.
n_unpermuted = copy.deepcopy(n_permuted)
n_unpermuted.load_state_dict(unpermuted_state_dict)

# Every parameter of the round-tripped network should equal the original.
print("\nChecking if unpermuted network matches original:")
for layer_label, layer_attr in (
    ("Conv1", "conv1"),
    ("Conv2", "conv2"),
    ("FC1", "fc1"),
    ("FC2", "fc2"),
    ("FC3", "fc3"),
):
    original_layer = getattr(n, layer_attr)
    restored_layer = getattr(n_unpermuted, layer_attr)
    for param_kind in ("weight", "bias"):
        is_match = torch.allclose(
            getattr(original_layer, param_kind).data,
            getattr(restored_layer, param_kind).data,
        )
        print(f"{layer_label} {param_kind} matches? ", is_match)



Checking if unpermuted network matches original:
Conv1 weight matches?  True
Conv1 bias matches?  True
Conv2 weight matches?  True
Conv2 bias matches?  True
FC1 weight matches?  True
FC1 bias matches?  True
FC2 weight matches?  True
FC2 bias matches?  True
FC3 weight matches?  True
FC3 bias matches?  True


In [9]:
# A parameter counts as "permuted" when the original and the permuted network
# no longer hold (numerically) the same tensor for it.
print("\nChecking if layers were permuted:")
for layer_label, layer_attr in (
    ("Conv1", "conv1"),
    ("Conv2", "conv2"),
    ("FC1", "fc1"),
    ("FC2", "fc2"),
    ("FC3", "fc3"),
):
    original_layer = getattr(n, layer_attr)
    permuted_layer = getattr(n_permuted, layer_attr)
    for param_kind in ("weight", "bias"):
        unchanged = torch.allclose(
            getattr(original_layer, param_kind).data,
            getattr(permuted_layer, param_kind).data,
        )
        print(f"{layer_label} {param_kind} permuted? ", not unchanged)



Checking if layers were permuted:
Conv1 weight permuted?  False
Conv1 bias permuted?  False
Conv2 weight permuted?  False
Conv2 bias permuted?  False
FC1 weight permuted?  True
FC1 bias permuted?  True
FC2 weight permuted?  True
FC2 bias permuted?  True
FC3 weight permuted?  True
FC3 bias permuted?  False


In [10]:
# List every entry of the network's state dict together with its tensor shape.
print("Parameter shapes:")
for param_name, param_tensor in n.state_dict().items():
    print(f"{param_name}: {param_tensor.shape}")

Parameter shapes:
conv1.weight: torch.Size([6, 1, 5, 5])
conv1.bias: torch.Size([6])
conv2.weight: torch.Size([16, 6, 5, 5])
conv2.bias: torch.Size([16])
fc1.weight: torch.Size([120, 256])
fc1.bias: torch.Size([120])
fc2.weight: torch.Size([84, 120])
fc2.bias: torch.Size([84])
fc3.weight: torch.Size([62, 84])
fc3.bias: torch.Size([62])


In [11]:
# Symbols identifying the models to match; defined once and reused below so the
# bookkeeping and the matcher cannot drift apart.
symbols = ['a', 'b']

# Raw state dicts of the wrapped networks, keyed by model symbol.
params = {symbol: module.model.state_dict() for symbol, module in models.items()}
models_permuted_to_universe = {symbol: copy.deepcopy(model) for symbol, model in models.items()}
merged_model = copy.deepcopy(models[symbols[0]])

# Match the two models against each other. The result is indexed as
# perm_matrices[0][symbol][perm_name] by the cells that follow.
perm_matrices = frank_wolfe_synchronized_matching(
    params=params,
    perm_spec=perm_spec,  # built from ref_model in the permutation-setup cell
    symbols=symbols,
    combinations=[(symbols[0], symbols[1])],  # the single pair ('a', 'b')
    max_iter=500,
    initialization_method='identity',
    keep_soft_perms=False,
    device='cpu',
    verbose=True
)

Weight matching:   0%|          | 0/500 [00:00<?, ?it/s]

Weight matching:   0%|          | 2/500 [00:00<00:33, 14.92it/s]

Weight matching:   1%|          | 5/500 [00:00<00:27, 17.91it/s]

Weight matching:   1%|▏         | 7/500 [00:00<00:30, 16.07it/s]

Weight matching:   2%|▏         | 9/500 [00:00<00:31, 15.74it/s]

Weight matching:   2%|▏         | 11/500 [00:00<00:31, 15.29it/s]

Weight matching:   3%|▎         | 13/500 [00:00<00:34, 13.92it/s]

Weight matching:   3%|▎         | 15/500 [00:01<00:36, 13.15it/s]

Weight matching:   3%|▎         | 17/500 [00:01<00:34, 13.88it/s]

Weight matching:   4%|▍         | 19/500 [00:01<00:34, 14.04it/s]

Weight matching:   4%|▍         | 21/500 [00:01<00:36, 13.09it/s]

Weight matching:   5%|▍         | 23/500 [00:01<00:37, 12.75it/s]

Weight matching:   5%|▌         | 25/500 [00:01<00:35, 13.20it/s]

Weight matching:   5%|▌         | 27/500 [00:01<00:34, 13.69it/s]

Weight matching:   6%|▌         | 29/500 [00:02<00:34, 13.64it/s]

Weight matching:   6%|▌         | 31/500 [00:02<00:34, 13.52it/s]

Weight matching:   7%|▋         | 33/500 [00:02<00:37, 12.57it/s]

Weight matching:   7%|▋         | 35/500 [00:02<00:34, 13.43it/s]

Weight matching:   7%|▋         | 37/500 [00:02<00:32, 14.04it/s]

Weight matching:   8%|▊         | 39/500 [00:02<00:33, 13.92it/s]

Weight matching:   8%|▊         | 41/500 [00:02<00:32, 14.08it/s]

Weight matching:   9%|▊         | 43/500 [00:03<00:31, 14.62it/s]

Weight matching:   9%|▉         | 45/500 [00:03<00:30, 15.10it/s]

Weight matching:   9%|▉         | 47/500 [00:03<00:29, 15.26it/s]

Weight matching:  10%|▉         | 49/500 [00:03<00:29, 15.40it/s]

Weight matching:  10%|█         | 51/500 [00:03<00:29, 15.15it/s]

Weight matching:  11%|█         | 53/500 [00:03<00:29, 15.06it/s]

Weight matching:  11%|█         | 55/500 [00:03<00:28, 15.43it/s]

Weight matching:  11%|█▏        | 57/500 [00:03<00:28, 15.29it/s]

Weight matching:  12%|█▏        | 59/500 [00:04<00:30, 14.25it/s]

Weight matching:  12%|█▏        | 61/500 [00:04<00:30, 14.60it/s]

Weight matching:  13%|█▎        | 63/500 [00:04<00:30, 14.39it/s]

Weight matching:  13%|█▎        | 65/500 [00:04<00:32, 13.45it/s]

Weight matching:  13%|█▎        | 67/500 [00:04<00:31, 13.82it/s]

Weight matching:  14%|█▍        | 69/500 [00:04<00:30, 14.07it/s]

Weight matching:  14%|█▍        | 71/500 [00:04<00:29, 14.38it/s]

Weight matching:  15%|█▍        | 73/500 [00:05<00:29, 14.48it/s]

Weight matching:  15%|█▌        | 75/500 [00:05<00:30, 13.86it/s]

Weight matching:  15%|█▌        | 77/500 [00:05<00:31, 13.40it/s]

Weight matching:  16%|█▌        | 79/500 [00:05<00:30, 13.58it/s]

Weight matching:  16%|█▌        | 81/500 [00:05<00:34, 12.21it/s]

Weight matching:  17%|█▋        | 83/500 [00:05<00:35, 11.91it/s]

Weight matching:  17%|█▋        | 85/500 [00:06<00:33, 12.51it/s]

Weight matching:  17%|█▋        | 87/500 [00:06<00:32, 12.84it/s]

Weight matching:  18%|█▊        | 89/500 [00:06<00:31, 13.26it/s]

Weight matching:  18%|█▊        | 91/500 [00:06<00:30, 13.37it/s]

Weight matching:  19%|█▊        | 93/500 [00:06<00:30, 13.56it/s]

Weight matching:  19%|█▉        | 95/500 [00:06<00:29, 13.50it/s]

Weight matching:  19%|█▉        | 97/500 [00:07<00:32, 12.39it/s]

Weight matching:  20%|█▉        | 99/500 [00:07<00:33, 12.14it/s]

Weight matching:  20%|██        | 101/500 [00:07<00:38, 10.37it/s]

Weight matching:  21%|██        | 103/500 [00:07<00:39,  9.94it/s]

Weight matching:  21%|██        | 105/500 [00:07<00:40,  9.87it/s]

Weight matching:  21%|██▏       | 107/500 [00:08<00:39, 10.01it/s]

Weight matching:  22%|██▏       | 109/500 [00:08<00:40,  9.60it/s]

Weight matching:  22%|██▏       | 111/500 [00:08<00:38, 10.16it/s]

Weight matching:  23%|██▎       | 113/500 [00:08<00:34, 11.16it/s]

Weight matching:  23%|██▎       | 115/500 [00:08<00:32, 11.98it/s]

Weight matching:  23%|██▎       | 117/500 [00:08<00:29, 12.86it/s]

Weight matching:  24%|██▍       | 119/500 [00:09<00:28, 13.23it/s]

Weight matching:  24%|██▍       | 121/500 [00:09<00:27, 13.80it/s]

Weight matching:  25%|██▍       | 123/500 [00:09<00:28, 13.43it/s]

Weight matching:  25%|██▌       | 125/500 [00:09<00:30, 12.18it/s]

Weight matching:  25%|██▌       | 127/500 [00:09<00:29, 12.64it/s]

Weight matching:  26%|██▌       | 129/500 [00:09<00:29, 12.65it/s]

Weight matching:  26%|██▌       | 131/500 [00:10<00:31, 11.65it/s]

Weight matching:  27%|██▋       | 133/500 [00:10<00:32, 11.18it/s]

Weight matching:  27%|██▋       | 135/500 [00:10<00:35, 10.25it/s]

Weight matching:  27%|██▋       | 137/500 [00:10<00:34, 10.58it/s]

Weight matching:  28%|██▊       | 139/500 [00:10<00:32, 11.12it/s]

Weight matching:  28%|██▊       | 141/500 [00:10<00:30, 11.77it/s]

Weight matching:  29%|██▊       | 143/500 [00:11<00:28, 12.34it/s]

Weight matching:  29%|██▉       | 145/500 [00:11<00:28, 12.30it/s]

Weight matching:  29%|██▉       | 147/500 [00:11<00:28, 12.30it/s]

Weight matching:  30%|██▉       | 149/500 [00:11<00:29, 12.05it/s]

Weight matching:  30%|███       | 151/500 [00:11<00:28, 12.31it/s]

Weight matching:  31%|███       | 153/500 [00:11<00:27, 12.45it/s]

Weight matching:  31%|███       | 155/500 [00:12<00:27, 12.61it/s]

Weight matching:  31%|███▏      | 157/500 [00:12<00:27, 12.64it/s]

Weight matching:  32%|███▏      | 159/500 [00:12<00:26, 12.71it/s]

Weight matching:  32%|███▏      | 161/500 [00:12<00:26, 12.56it/s]

Weight matching:  33%|███▎      | 163/500 [00:12<00:26, 12.62it/s]

Weight matching:  33%|███▎      | 165/500 [00:12<00:27, 12.28it/s]

Weight matching:  33%|███▎      | 167/500 [00:12<00:25, 12.82it/s]

Weight matching:  34%|███▍      | 169/500 [00:13<00:24, 13.28it/s]

Weight matching:  34%|███▍      | 171/500 [00:13<00:24, 13.49it/s]

Weight matching:  35%|███▍      | 173/500 [00:13<00:24, 13.48it/s]

Weight matching:  35%|███▌      | 175/500 [00:13<00:24, 13.54it/s]

Weight matching:  35%|███▌      | 177/500 [00:13<00:24, 13.40it/s]

Weight matching:  36%|███▌      | 179/500 [00:13<00:23, 13.45it/s]

Weight matching:  36%|███▌      | 181/500 [00:13<00:23, 13.70it/s]

Weight matching:  37%|███▋      | 183/500 [00:14<00:23, 13.40it/s]

Weight matching:  37%|███▋      | 185/500 [00:14<00:24, 13.01it/s]

Weight matching:  37%|███▋      | 187/500 [00:14<00:24, 12.75it/s]

Weight matching:  38%|███▊      | 189/500 [00:14<00:24, 12.53it/s]

Weight matching:  38%|███▊      | 191/500 [00:14<00:25, 12.30it/s]

Weight matching:  39%|███▊      | 193/500 [00:14<00:24, 12.55it/s]

Weight matching:  39%|███▉      | 195/500 [00:15<00:24, 12.49it/s]

Weight matching:  39%|███▉      | 197/500 [00:15<00:24, 12.57it/s]

Weight matching:  40%|███▉      | 199/500 [00:15<00:24, 12.47it/s]

Weight matching:  40%|████      | 201/500 [00:15<00:23, 12.59it/s]

Weight matching:  41%|████      | 203/500 [00:15<00:23, 12.57it/s]

Weight matching:  41%|████      | 205/500 [00:15<00:23, 12.37it/s]

Weight matching:  41%|████▏     | 207/500 [00:16<00:23, 12.51it/s]

Weight matching:  42%|████▏     | 209/500 [00:16<00:22, 12.76it/s]

Weight matching:  42%|████▏     | 211/500 [00:16<00:22, 12.86it/s]

Weight matching:  43%|████▎     | 213/500 [00:16<00:22, 12.85it/s]

Weight matching:  43%|████▎     | 215/500 [00:16<00:22, 12.58it/s]

Weight matching:  43%|████▎     | 217/500 [00:16<00:22, 12.74it/s]

Weight matching:  44%|████▍     | 219/500 [00:17<00:21, 12.97it/s]

Weight matching:  44%|████▍     | 221/500 [00:17<00:20, 13.38it/s]

Weight matching:  45%|████▍     | 223/500 [00:17<00:20, 13.43it/s]

Weight matching:  45%|████▌     | 225/500 [00:17<00:20, 13.18it/s]

Weight matching:  45%|████▌     | 227/500 [00:17<00:20, 13.22it/s]

Weight matching:  46%|████▌     | 229/500 [00:17<00:20, 13.28it/s]

Weight matching:  46%|████▌     | 231/500 [00:17<00:19, 13.52it/s]

Weight matching:  47%|████▋     | 233/500 [00:18<00:21, 12.53it/s]

Weight matching:  47%|████▋     | 235/500 [00:18<00:22, 11.73it/s]

Weight matching:  47%|████▋     | 237/500 [00:18<00:23, 11.25it/s]

Weight matching:  48%|████▊     | 239/500 [00:18<00:24, 10.62it/s]

Weight matching:  48%|████▊     | 241/500 [00:18<00:24, 10.41it/s]

Weight matching:  49%|████▊     | 243/500 [00:19<00:23, 10.78it/s]

Weight matching:  49%|████▉     | 245/500 [00:19<00:24, 10.62it/s]

Weight matching:  49%|████▉     | 247/500 [00:19<00:24, 10.46it/s]

Weight matching:  50%|████▉     | 249/500 [00:19<00:24, 10.32it/s]

Weight matching:  50%|█████     | 251/500 [00:19<00:22, 10.84it/s]

Weight matching:  51%|█████     | 253/500 [00:19<00:22, 11.00it/s]

Weight matching:  51%|█████     | 255/500 [00:20<00:22, 10.97it/s]

Weight matching:  51%|█████▏    | 257/500 [00:20<00:21, 11.38it/s]

Weight matching:  52%|█████▏    | 259/500 [00:20<00:20, 11.54it/s]

Weight matching:  52%|█████▏    | 261/500 [00:20<00:20, 11.61it/s]

Weight matching:  53%|█████▎    | 263/500 [00:20<00:20, 11.70it/s]

Weight matching:  53%|█████▎    | 265/500 [00:20<00:19, 11.76it/s]

Weight matching:  53%|█████▎    | 267/500 [00:21<00:20, 11.45it/s]

Weight matching:  54%|█████▍    | 269/500 [00:21<00:20, 11.14it/s]

Weight matching:  54%|█████▍    | 271/500 [00:21<00:21, 10.57it/s]

Weight matching:  55%|█████▍    | 273/500 [00:21<00:21, 10.63it/s]

Weight matching:  55%|█████▌    | 275/500 [00:21<00:20, 10.99it/s]

Weight matching:  55%|█████▌    | 277/500 [00:22<00:19, 11.19it/s]

Weight matching:  56%|█████▌    | 279/500 [00:22<00:19, 11.50it/s]

Weight matching:  56%|█████▌    | 281/500 [00:22<00:19, 11.43it/s]

Weight matching:  57%|█████▋    | 283/500 [00:22<00:18, 11.57it/s]

Weight matching:  57%|█████▋    | 285/500 [00:22<00:20, 10.45it/s]

Weight matching:  57%|█████▋    | 287/500 [00:23<00:20, 10.57it/s]

Weight matching:  58%|█████▊    | 289/500 [00:23<00:19, 10.74it/s]

Weight matching:  58%|█████▊    | 291/500 [00:23<00:22,  9.41it/s]

Weight matching:  59%|█████▊    | 293/500 [00:23<00:20, 10.08it/s]

Weight matching:  59%|█████▉    | 295/500 [00:23<00:19, 10.79it/s]

Weight matching:  59%|█████▉    | 297/500 [00:24<00:20, 10.00it/s]

Weight matching:  60%|█████▉    | 299/500 [00:24<00:18, 10.76it/s]

Weight matching:  60%|██████    | 301/500 [00:24<00:18, 11.00it/s]

Weight matching:  61%|██████    | 303/500 [00:24<00:17, 11.44it/s]

Weight matching:  61%|██████    | 305/500 [00:24<00:18, 10.43it/s]

Weight matching:  61%|██████▏   | 307/500 [00:24<00:18, 10.48it/s]

Weight matching:  62%|██████▏   | 309/500 [00:25<00:17, 10.77it/s]

Weight matching:  62%|██████▏   | 311/500 [00:25<00:18, 10.20it/s]

Weight matching:  63%|██████▎   | 313/500 [00:25<00:17, 10.52it/s]

Weight matching:  63%|██████▎   | 315/500 [00:25<00:16, 11.11it/s]

Weight matching:  63%|██████▎   | 317/500 [00:25<00:15, 11.60it/s]

Weight matching:  64%|██████▍   | 319/500 [00:25<00:15, 11.89it/s]

Weight matching:  64%|██████▍   | 321/500 [00:26<00:14, 12.16it/s]

Weight matching:  65%|██████▍   | 323/500 [00:26<00:14, 12.37it/s]

Weight matching:  65%|██████▌   | 325/500 [00:26<00:14, 12.32it/s]

Weight matching:  65%|██████▌   | 327/500 [00:26<00:14, 12.18it/s]

Weight matching:  66%|██████▌   | 329/500 [00:26<00:14, 12.12it/s]

Weight matching:  66%|██████▌   | 331/500 [00:26<00:14, 11.95it/s]

Weight matching:  67%|██████▋   | 333/500 [00:27<00:14, 11.79it/s]

Weight matching:  67%|██████▋   | 335/500 [00:27<00:14, 11.76it/s]

Weight matching:  67%|██████▋   | 337/500 [00:27<00:13, 11.70it/s]

Weight matching:  68%|██████▊   | 339/500 [00:27<00:13, 11.68it/s]

Weight matching:  68%|██████▊   | 341/500 [00:27<00:16,  9.80it/s]

Weight matching:  69%|██████▊   | 343/500 [00:28<00:15, 10.10it/s]

Weight matching:  69%|██████▉   | 345/500 [00:28<00:14, 10.50it/s]

Weight matching:  69%|██████▉   | 347/500 [00:28<00:15, 10.01it/s]

Weight matching:  70%|██████▉   | 349/500 [00:28<00:14, 10.47it/s]

Weight matching:  70%|███████   | 351/500 [00:28<00:13, 10.77it/s]

Weight matching:  71%|███████   | 353/500 [00:29<00:13, 10.77it/s]

Weight matching:  71%|███████   | 355/500 [00:29<00:13, 10.99it/s]

Weight matching:  71%|███████▏  | 357/500 [00:29<00:13, 10.35it/s]

Weight matching:  72%|███████▏  | 359/500 [00:29<00:13, 10.63it/s]

Weight matching:  72%|███████▏  | 361/500 [00:29<00:13, 10.61it/s]

Weight matching:  73%|███████▎  | 363/500 [00:29<00:12, 11.02it/s]

Weight matching:  73%|███████▎  | 365/500 [00:30<00:12, 11.09it/s]

Weight matching:  73%|███████▎  | 367/500 [00:30<00:13, 10.04it/s]

Weight matching:  74%|███████▍  | 369/500 [00:30<00:13,  9.78it/s]

Weight matching:  74%|███████▍  | 370/500 [00:30<00:13,  9.74it/s]

Weight matching:  74%|███████▍  | 372/500 [00:30<00:12, 10.29it/s]

Weight matching:  75%|███████▍  | 374/500 [00:31<00:11, 10.63it/s]

Weight matching:  75%|███████▌  | 376/500 [00:31<00:11, 10.63it/s]

Weight matching:  76%|███████▌  | 378/500 [00:31<00:11, 11.01it/s]

Weight matching:  76%|███████▌  | 380/500 [00:31<00:11, 10.62it/s]

Weight matching:  76%|███████▋  | 382/500 [00:31<00:11, 10.56it/s]

Weight matching:  77%|███████▋  | 384/500 [00:32<00:11, 10.49it/s]

Weight matching:  77%|███████▋  | 386/500 [00:32<00:10, 10.51it/s]

Weight matching:  78%|███████▊  | 388/500 [00:32<00:10, 10.46it/s]

Weight matching:  78%|███████▊  | 390/500 [00:32<00:10, 10.23it/s]

Weight matching:  78%|███████▊  | 392/500 [00:32<00:10, 10.20it/s]

Weight matching:  79%|███████▉  | 394/500 [00:32<00:10, 10.43it/s]

Weight matching:  79%|███████▉  | 396/500 [00:33<00:09, 10.40it/s]

Weight matching:  80%|███████▉  | 398/500 [00:33<00:09, 10.69it/s]

Weight matching:  80%|████████  | 400/500 [00:33<00:10,  9.42it/s]

Weight matching:  80%|████████  | 402/500 [00:33<00:10,  9.71it/s]

Weight matching:  81%|████████  | 404/500 [00:33<00:09, 10.13it/s]

Weight matching:  81%|████████  | 406/500 [00:34<00:09, 10.32it/s]

Weight matching:  82%|████████▏ | 408/500 [00:34<00:08, 10.35it/s]

Weight matching:  82%|████████▏ | 410/500 [00:34<00:08, 10.74it/s]

Weight matching:  82%|████████▏ | 412/500 [00:34<00:07, 11.00it/s]

Weight matching:  83%|████████▎ | 414/500 [00:34<00:07, 10.91it/s]

Weight matching:  83%|████████▎ | 416/500 [00:35<00:07, 10.96it/s]

Weight matching:  84%|████████▎ | 418/500 [00:35<00:07, 10.96it/s]

Weight matching:  84%|████████▍ | 420/500 [00:35<00:07, 11.08it/s]

Weight matching:  84%|████████▍ | 422/500 [00:35<00:06, 11.15it/s]

Weight matching:  85%|████████▍ | 424/500 [00:35<00:06, 11.22it/s]

Weight matching:  85%|████████▌ | 426/500 [00:35<00:06, 11.00it/s]

Weight matching:  86%|████████▌ | 428/500 [00:36<00:06, 11.26it/s]

Weight matching:  86%|████████▌ | 430/500 [00:36<00:06, 11.38it/s]

Weight matching:  86%|████████▋ | 432/500 [00:36<00:05, 11.50it/s]

Weight matching:  87%|████████▋ | 434/500 [00:36<00:06, 10.88it/s]

Weight matching:  87%|████████▋ | 436/500 [00:36<00:05, 10.76it/s]

Weight matching:  88%|████████▊ | 438/500 [00:37<00:05, 11.06it/s]

Weight matching:  88%|████████▊ | 440/500 [00:37<00:05, 11.20it/s]

Weight matching:  88%|████████▊ | 442/500 [00:37<00:05, 11.31it/s]

Weight matching:  89%|████████▉ | 444/500 [00:37<00:04, 11.44it/s]

Weight matching:  89%|████████▉ | 446/500 [00:37<00:04, 11.07it/s]

Weight matching:  90%|████████▉ | 448/500 [00:37<00:04, 11.22it/s]

Weight matching:  90%|█████████ | 450/500 [00:38<00:04, 11.40it/s]

Weight matching:  90%|█████████ | 452/500 [00:38<00:04, 11.58it/s]

Weight matching:  91%|█████████ | 454/500 [00:38<00:04, 11.36it/s]

Weight matching:  91%|█████████ | 456/500 [00:38<00:03, 11.26it/s]

Weight matching:  92%|█████████▏| 458/500 [00:38<00:03, 11.17it/s]

Weight matching:  92%|█████████▏| 460/500 [00:38<00:03, 11.19it/s]

Weight matching:  92%|█████████▏| 462/500 [00:39<00:03, 11.34it/s]

Weight matching:  93%|█████████▎| 464/500 [00:39<00:03, 11.66it/s]

Weight matching:  93%|█████████▎| 466/500 [00:39<00:02, 11.96it/s]

Weight matching:  94%|█████████▎| 468/500 [00:39<00:02, 11.85it/s]

Weight matching:  94%|█████████▍| 470/500 [00:39<00:02, 11.37it/s]

Weight matching:  94%|█████████▍| 472/500 [00:40<00:02, 11.20it/s]

Weight matching:  95%|█████████▍| 474/500 [00:40<00:02, 11.04it/s]

Weight matching:  95%|█████████▌| 476/500 [00:40<00:02, 10.86it/s]

Weight matching:  96%|█████████▌| 478/500 [00:40<00:01, 11.15it/s]

Weight matching:  96%|█████████▌| 480/500 [00:40<00:01, 11.51it/s]

Weight matching:  96%|█████████▌| 481/500 [00:40<00:01, 11.75it/s]


In [12]:
# Show the permutation found for every layer of both models ('a' first, then 'b').
for symbol in ("a", "b"):
    for perm_name in ("P_conv1", "P_fc1", "P_fc2"):
        print(f'perm_matrices[0]["{symbol}"]["{perm_name}"]={perm_matrices[0][symbol][perm_name]!r}')

perm_matrices[0]["a"]["P_conv1"]=tensor([0, 1, 2, 3, 4, 5])
perm_matrices[0]["a"]["P_fc1"]=tensor([  0,  29,   3,  67, 107,  49,   6,   7,  68,   9,  16,  76, 103, 119,
         54,  15,  28,  17,  65,  63,  59,  98,  92,  80,  85,  37,  70,  55,
         66,  50,  30,  31, 114,  11,  41,  21,  99,  74,  38,  22,  93,  90,
         81,  43,  69,   1,  45,  48,  36,  40,  10, 101,  13,  94,  95,  52,
         19, 109,  83,  42,  44, 116, 100, 113,  64,  72,  61,  89,  34,  84,
        104,  47,  77, 112,  39,  73,  78, 108,  60,  79, 102,  96, 115,  62,
         91,   2,  53,  97,  33, 110,  46,   8, 105,  86, 118,  88,  75,  12,
         58,  57,   4,  82,  14,  51, 117,  35, 106,  32,  25,  56, 111,  18,
         24,  23,  20,   5,  87,  71,  27,  26])
perm_matrices[0]["a"]["P_fc2"]=tensor([ 0,  1, 41, 45,  4, 43,  6,  7, 16, 79, 10, 48, 11, 67, 14, 25, 73, 68,
        18, 39, 20, 21, 22, 23, 63,  5, 26, 27, 28, 29, 15, 31, 32, 33, 34, 35,
        36, 12, 38, 71, 78, 17, 42, 51, 44,  

In [13]:
# Apply the matched permutations to copies of n and n_permuted, then compare
# the two results layer by layer.
from c2m3.match.utils import perm_indices_to_perm_matrix, perm_matrix_to_perm_indices


def _report_vector_mismatches(vectors_a, vectors_b, label, threshold=0.99, top_k=3):
    """Print diagnostics for vectors of ``vectors_a`` that do not line up with ``vectors_b``.

    For each index ``i`` the cosine similarity between ``vectors_a[i]`` and every
    vector of ``vectors_b`` is computed. A line is printed when the similarity
    with ``vectors_b[i]`` is below ``threshold``, and the ``top_k`` most similar
    candidates are listed when the best candidate is not ``i`` itself.

    Args:
        vectors_a: 2-D tensor whose first dimension enumerates the vectors.
        vectors_b: 2-D tensor of the same shape as ``vectors_a``.
        label: Word used in the printed messages (``'Row'`` or ``'Column'``).
        threshold: Similarity below which a pair is reported as different.
        top_k: Number of best candidates to list for a misplaced vector.
    """
    for idx in range(vectors_a.shape[0]):
        # One broadcast call yields the similarity of vector `idx` to all candidates.
        sims = torch.cosine_similarity(vectors_a[idx].unsqueeze(0), vectors_b, dim=1).tolist()
        own_sim = sims[idx]

        # Stable sort, highest similarity first (ties keep ascending index order).
        ranked = sorted(enumerate(sims), key=lambda pair: pair[1], reverse=True)
        best_idx = ranked[0][0]

        if own_sim < threshold:
            print(f'  {label} {idx} has low similarity: {own_sim:.5f}')
        if best_idx != idx:
            print(f"  {label} {idx} best matches in other tensor:")
            for other_idx, sim in ranked[:top_k]:
                print(f"    {label} {other_idx}: similarity {sim:.5f}")


n2 = copy.deepcopy(n)
n_permuted2 = copy.deepcopy(n_permuted)

# Process both models consistently
for model_symbol, model in [('a', n2), ('b', n_permuted2)]:
    symbol_perms = perm_matrices[0][model_symbol]

    # Indices -> matrix -> transpose -> indices, for every permutation of this model.
    perms_to_apply = {
        perm_name: perm_matrix_to_perm_indices(
            perm_indices_to_perm_matrix(symbol_perms[perm_name]).T
        )
        for perm_name in symbol_perms.keys()
    }

    # Apply the permutation to the model
    updated_params = apply_permutation_to_statedict(
        perm_spec, perms_to_apply, model.state_dict()
    )
    model.load_state_dict(updated_params)

# Check if both results are the same
state_a = n2.state_dict()
state_b = n_permuted2.state_dict()
for layer, tensor_a in state_a.items():
    tensor_b = state_b[layer]

    if torch.allclose(tensor_a, tensor_b, rtol=1e-4, atol=1e-4):
        print(f'{layer} is the same')
    else:
        print(f'{layer} not the same')

        # Only weight tensors (2-D linear, 4-D conv) get the per-vector diagnostics.
        if tensor_a.dim() in (2, 4):
            # Flatten conv kernels so each output unit becomes one row; 2-D stays as is.
            rows_a = tensor_a.reshape(tensor_a.shape[0], -1)
            rows_b = tensor_b.reshape(tensor_b.shape[0], -1)
            _report_vector_mismatches(rows_a, rows_b, 'Row')

            # Column-wise check only for genuine 2-D matrices.
            if tensor_a.dim() == 2:
                _report_vector_mismatches(rows_a.T, rows_b.T, 'Column')

        print(f'n2[{layer}]={tensor_a}')
        print(f'n_permuted2[{layer}]={tensor_b}')

    similarity = torch.cosine_similarity(tensor_a.flatten(), tensor_b.flatten(), dim=0)
    # The original message carried a stray closing parenthesis after the number.
    print(f'{layer} similarity: {similarity:.5f}')



conv1.weight is the same
conv1.weight similarity: 1.00000)
conv1.bias is the same
conv1.bias similarity: 1.00000)
conv2.weight is the same
conv2.weight similarity: 1.00000)
conv2.bias is the same
conv2.bias similarity: 1.00000)
fc1.weight not the same
  Row 17 has low similarity: 0.08021
  Row 17 best matches in other tensor:
    Row 106: similarity 1.00000
    Row 40: similarity 0.17030
    Row 51: similarity 0.14671
  Row 30 has low similarity: 0.10414
  Row 30 best matches in other tensor:
    Row 17: similarity 1.00000
    Row 67: similarity 0.13691
    Row 77: similarity 0.12938
  Row 79 has low similarity: -0.08946
  Row 79 best matches in other tensor:
    Row 30: similarity 1.00000
    Row 4: similarity 0.12599
    Row 50: similarity 0.12106
  Row 106 has low similarity: -0.11349
  Row 106 best matches in other tensor:
    Row 79: similarity 1.00000
    Row 119: similarity 0.15562
    Row 43: similarity 0.12833
  Column 0 has low similarity: 0.94036
  Column 1 has low similarit

In [14]:
# Map n_permuted ('b') into the space of n ('a') in two steps: b -> universe -> a.
from c2m3.match.utils import perm_indices_to_perm_matrix, perm_matrix_to_perm_indices

n3 = copy.deepcopy(n)
n_permuted3 = copy.deepcopy(n_permuted)

# Step 1: bring n_permuted3 into universe space using the transpose of each of
# b's permutation matrices (indices -> matrix -> transpose -> indices).
perms_to_apply = {
    name: perm_matrix_to_perm_indices(
        perm_indices_to_perm_matrix(perm_matrices[0]['b'][name]).T
    )
    for name in perm_matrices[0]['b'].keys()
}
universe_params = apply_permutation_to_statedict(
    perm_spec, perms_to_apply, n_permuted3.state_dict()
)

# Step 2: go from universe space to n3's space using a's permutation matrices
# as they are (same index -> matrix -> index round trip, without the transpose).
perms_to_apply = {
    name: perm_matrix_to_perm_indices(
        perm_indices_to_perm_matrix(perm_matrices[0]['a'][name])
    )
    for name in perm_matrices[0]['a'].keys()
}
transformed_params = apply_permutation_to_statedict(
    perm_spec, perms_to_apply, universe_params
)

# Load the transformed parameters
n_permuted3.load_state_dict(transformed_params)

# Check if results match n3
for layer in n3.state_dict():
    if not torch.allclose(n3.state_dict()[layer], n_permuted3.state_dict()[layer], rtol=1e-4, atol=1e-4):
        print(f'{layer} not the same')
        
        # Get the tensors
        tensor1 = n2.state_dict()[layer]
        tensor2 = n_permuted2.state_dict()[layer]
        
        # Check if this is a weight matrix with rows (2D or 4D tensor)
        if len(tensor1.shape) in [2, 4]:
            # For 4D convolutional weights, reshape to 2D for comparison
            if len(tensor1.shape) == 4:
                t1_reshaped = tensor1.reshape(tensor1.shape[0], -1)
                t2_reshaped = tensor2.reshape(tensor2.shape[0], -1)
            else:
                t1_reshaped = tensor1
                t2_reshaped = tensor2
                
            # Calculate row-wise cosine similarity
            for i in range(t1_reshaped.shape[0]):
                row_sim = torch.cosine_similarity(
                    t1_reshaped[i].flatten().unsqueeze(0), 
                    t2_reshaped[i].flatten().unsqueeze(0), 
                    dim=1
                )
                
                # Find if this row is similar to any row in the other tensor
                all_sims = []
                for j in range(t2_reshaped.shape[0]):
                    sim = torch.cosine_similarity(
                        t1_reshaped[i].flatten().unsqueeze(0),
                        t2_reshaped[j].flatten().unsqueeze(0),
                        dim=1
                    )
                    all_sims.append((j, sim.item()))
                
                # Sort by similarity and check if the best match is the same index
                all_sims.sort(key=lambda x: x[1], reverse=True)
                best_match_idx, best_match_sim = all_sims[0]
                
                # Only print details if the best match is not the same index
                # or if the similarity is below threshold
                if best_match_idx != i or row_sim < 0.99:
                    if row_sim < 0.99:  # Threshold for obviously different rows
                        print(f'  Row {i} has low similarity: {row_sim.item():.5f}')
                    
                    # Only print best matches if the best match isn't the same index
                    if best_match_idx != i:
                        print(f"  Row {i} best matches in other tensor:")
                        for j, sim in all_sims[:3]:
                            print(f"    Row {j}: similarity {sim:.5f}")
            
            # Now check column-wise similarity
            if len(tensor1.shape) == 2:  # Only for 2D matrices
                # Calculate column-wise cosine similarity
                for i in range(t1_reshaped.shape[1]):
                    col_sim = torch.cosine_similarity(
                        t1_reshaped[:, i].unsqueeze(0), 
                        t2_reshaped[:, i].unsqueeze(0), 
                        dim=1
                    )
                    
                    # Find if this column is similar to any column in the other tensor
                    all_sims = []
                    for j in range(t2_reshaped.shape[1]):
                        sim = torch.cosine_similarity(
                            t1_reshaped[:, i].unsqueeze(0),
                            t2_reshaped[:, j].unsqueeze(0),
                            dim=1
                        )
                        all_sims.append((j, sim.item()))
                    
                    # Sort by similarity and check if the best match is the same index
                    all_sims.sort(key=lambda x: x[1], reverse=True)
                    best_match_idx, best_match_sim = all_sims[0]
                    
                    # Only print details if the best match is not the same index
                    # or if the similarity is below threshold
                    if best_match_idx != i or col_sim < 0.99:
                        if col_sim < 0.99:  # Threshold for obviously different columns
                            print(f'  Column {i} has low similarity: {col_sim.item():.5f}')
                        
                        # Only print best matches if the best match isn't the same index
                        if best_match_idx != i:
                            print(f"  Column {i} best matches in other tensor:")
                            for j, sim in all_sims[:3]:
                                print(f"    Column {j}: similarity {sim:.5f}")
        
        print(f'n2[{layer}]={n2.state_dict()[layer]}')
        print(f'n_permuted2[{layer}]={n_permuted2.state_dict()[layer]}')
    else:
        print(f'{layer} is the same')
    similarity = torch.cosine_similarity(n3.state_dict()[layer].flatten(), n_permuted3.state_dict()[layer].flatten(), dim=0)
    print(f'{layer} similarity: {similarity:.5f})')


conv1.weight is the same
conv1.weight similarity: 1.00000)
conv1.bias is the same
conv1.bias similarity: 1.00000)
conv2.weight is the same
conv2.weight similarity: 1.00000)
conv2.bias is the same
conv2.bias similarity: 1.00000)
fc1.weight not the same
  Row 17 has low similarity: 0.08021
  Row 17 best matches in other tensor:
    Row 106: similarity 1.00000
    Row 40: similarity 0.17030
    Row 51: similarity 0.14671
  Row 30 has low similarity: 0.10414
  Row 30 best matches in other tensor:
    Row 17: similarity 1.00000
    Row 67: similarity 0.13691
    Row 77: similarity 0.12938
  Row 79 has low similarity: -0.08946
  Row 79 best matches in other tensor:
    Row 30: similarity 1.00000
    Row 4: similarity 0.12599
    Row 50: similarity 0.12106
  Row 106 has low similarity: -0.11349
  Row 106 best matches in other tensor:
    Row 79: similarity 1.00000
    Row 119: similarity 0.15562
    Row 43: similarity 0.12833
  Column 0 has low similarity: 0.94036
  Column 1 has low similarit