In [1]:
import torch
from re import L
import numpy as np
from typing import NamedTuple
from collections import defaultdict
from scipy.optimize import linear_sum_assignment

In [144]:
class PermutationSpec(NamedTuple):
    perm_to_axes: dict
    axes_to_perm: dict
    
def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec:
    perm_to_axes = defaultdict(list)
    for wk, axis_perms in axes_to_perm.items():
        for axis, perm in enumerate(axis_perms):
            if perm is not None:
                perm_to_axes[perm].append((wk, axis))
    return PermutationSpec(perm_to_axes=dict(perm_to_axes), axes_to_perm=axes_to_perm)

def mlp_permutation_spec(num_hidden_layers: int) -> PermutationSpec:
    """We assume that one permutation cannot appear in two axes of the same weight array."""
    assert num_hidden_layers >= 1
    return permutation_spec_from_axes_to_perm({
        "layer0.weight": ("P_0", None),
        **{f"layer{i}.weight": ( f"P_{i}", f"P_{i-1}")
           for i in range(1, num_hidden_layers)},
        **{f"layer{i}.bias": (f"P_{i}", )
           for i in range(num_hidden_layers)},
        f"layer{num_hidden_layers}.weight": (None, f"P_{num_hidden_layers-1}"),
        f"layer{num_hidden_layers}.bias": (None, ),
    })

In [145]:
def get_permuted_param(ps: PermutationSpec, perm, k: str, params, except_axis=None):
    """Get parameter `k` from `params`, with the permutations applied."""
    w = params[k]
    for axis, p in enumerate(ps.axes_to_perm[k]):
        # Skip the axis we're trying to permute.
        if axis == except_axis:
            continue
        # None indicates that there is no permutation relevant to that axis.
        if p is not None:
            w = torch.index_select(w, axis, perm[p].int())
    return w

In [146]:
def apply_permutation(ps: PermutationSpec, perm, params):
    """Apply a `perm` to `params`."""
    return {k: get_permuted_param(ps, perm, k, params) for k in params.keys()}

In [174]:
def weight_matching(ps: PermutationSpec, params_a, params_b, max_iter=300, init_perm=None):
    
    """Find a permutation of `params_b` to make them match `params_a`."""
    perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]] for p, axes in ps.perm_to_axes.items()}
    print(perm_sizes)
    perm = {p: torch.arange(n) for p, n in perm_sizes.items()} if init_perm is None else init_perm
    
    perm_names = list(perm.keys())
    
    for iteration in range(max_iter):
        progress = False
        for p_ix in torch.randperm(len(perm_names)):
            p = perm_names[p_ix]
            n = perm_sizes[p]
            A = torch.zeros((n, n))
            for wk, axis in ps.perm_to_axes[p]:
                w_a = params_a[wk]
                w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
                w_a = torch.moveaxis(w_a, axis, 0).reshape((n, -1))
                w_b = torch.moveaxis(w_b, axis, 0).reshape((n, -1))
                A += w_a @ w_b.T

        ri, ci = linear_sum_assignment(A.detach().numpy(), maximize=True)
        assert (torch.tensor(ri) == torch.arange(len(ri))).all()
        oldL = torch.einsum('ij,ij->i', A, torch.eye(n)[perm[p].long()]).sum()
        newL = torch.einsum('ij,ij->i', A,torch.eye(n)[ci, :]).sum()
        print(f"{iteration} / {p}: {newL - oldL}")
        progress = progress or newL > oldL + 1e-12
        perm[p] = torch.Tensor(ci)

        if not progress:
            break

    return perm

In [175]:
def test_weight_matching():
    """If we just have a single hidden layer then it should converge after just one step."""
    ps = mlp_permutation_spec(num_hidden_layers=3)
    print("axes_to_perm", ps.axes_to_perm)
    print("perm_to_axes", ps.perm_to_axes)
    
    rng = torch.Generator()
    rng.manual_seed(13)
    num_hidden = 32
    
    shapes = {
        "layer0.weight": (20, num_hidden),
        "layer0.bias": (num_hidden, ),
        "layer1.weight": (num_hidden, num_hidden),
        "layer1.bias": (num_hidden, ),
        "layer2.weight": (num_hidden, num_hidden),
        "layer2.bias": (num_hidden, ),
        "layer3.weight": (num_hidden, 10),
        "layer3.bias": (10, )
    }
    
    params_a = {k: torch.randn(shape, generator=rng) for k, shape in shapes.items()}
    params_b = {k: torch.randn(shape, generator=rng) for k, shape in shapes.items()}
    
    print(params_a.keys(), params_b.keys())
    
    perm = weight_matching(ps, params_a, params_b)
    print(perm)
    
    return params_a, params_b

In [176]:
mlp_permutation_spec(num_hidden_layers=3).axes_to_perm, \
mlp_permutation_spec(num_hidden_layers=3).perm_to_axes

({'layer0.weight': ('P_0', None),
  'layer1.weight': ('P_1', 'P_0'),
  'layer2.weight': ('P_2', 'P_1'),
  'layer0.bias': ('P_0',),
  'layer1.bias': ('P_1',),
  'layer2.bias': ('P_2',),
  'layer3.weight': (None, 'P_2'),
  'layer3.bias': (None,)},
 {'P_0': [('layer0.weight', 0), ('layer1.weight', 1), ('layer0.bias', 0)],
  'P_1': [('layer1.weight', 0), ('layer2.weight', 1), ('layer1.bias', 0)],
  'P_2': [('layer2.weight', 0), ('layer2.bias', 0), ('layer3.weight', 1)]})

In [177]:
param_a, param_b = test_weight_matching()

axes_to_perm {'layer0.weight': ('P_0', None), 'layer1.weight': ('P_1', 'P_0'), 'layer2.weight': ('P_2', 'P_1'), 'layer0.bias': ('P_0',), 'layer1.bias': ('P_1',), 'layer2.bias': ('P_2',), 'layer3.weight': (None, 'P_2'), 'layer3.bias': (None,)}
perm_to_axes {'P_0': [('layer0.weight', 0), ('layer1.weight', 1), ('layer0.bias', 0)], 'P_1': [('layer1.weight', 0), ('layer2.weight', 1), ('layer1.bias', 0)], 'P_2': [('layer2.weight', 0), ('layer2.bias', 0), ('layer3.weight', 1)]}
dict_keys(['layer0.weight', 'layer0.bias', 'layer1.weight', 'layer1.bias', 'layer2.weight', 'layer2.bias', 'layer3.weight', 'layer3.bias']) dict_keys(['layer0.weight', 'layer0.bias', 'layer1.weight', 'layer1.bias', 'layer2.weight', 'layer2.bias', 'layer3.weight', 'layer3.bias'])
{'P_0': 20, 'P_1': 32, 'P_2': 32}


RuntimeError: shape '[20, -1]' is invalid for input of size 1024

In [178]:
# param_a["layer0.weight"].shape, param_a["layer0.bias"].shape, \
# param_a["layer1.weight"].shape, param_a["layer1.bias"].shape, \
# param_a["layer2.weight"].shape, param_a["layer2.bias"].shape, \
# param_a["layer3.weight"].shape, param_a["layer3.bias"].shape