# In [4]:
import os
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from torchvision import transforms, datasets
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
from tqdm import tqdm
import numpy as np
import copy
from collections import defaultdict, namedtuple
from typing import NamedTuple
from scipy.optimize import linear_sum_assignment
import logging
from torch.cuda.amp import autocast, GradScaler
import torchvision
import jax.numpy as jnp
from jax import random


# Select the compute device: the GPU when PyTorch can see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Send INFO-and-above log records to stdout with a short 12-hour timestamp.
logging.basicConfig(
    level=logging.INFO,
    stream=sys.stdout,
    datefmt='%I:%M:%S',
    format='%(asctime)s %(levelname)s: %(message)s',
)

# --------------------------- RESNET MODEL DEFINITION --------------------------- #
# Basic Residual Block
class BasicBlock(nn.Module):
    """Residual block made of two 3x3 conv + BatchNorm layers plus a skip path.

    The first convolution carries `stride`, so it alone sets the output's
    spatial size. When the stride or channel count changes, the skip path is a
    1x1 conv + BatchNorm projection; otherwise it is an empty `nn.Sequential`
    (identity).
    """

    # Output channels = planes * expansion (basic blocks do not widen).
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super().__init__()
        # Main path. Registration order (conv1, bn1, conv2, bn2, shortcut)
        # fixes the state_dict key order and the RNG draw order at init.
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        # Skip path: project only when the shapes of input and output differ.
        needs_projection = stride != 1 or in_planes != planes
        if needs_projection:
            skip_layers = [
                nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes),
            ]
        else:
            skip_layers = []
        self.shortcut = nn.Sequential(*skip_layers)

    def forward(self, x):
        """Return relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))."""
        hidden = F.relu(self.bn1(self.conv1(x)))
        hidden = self.bn2(self.conv2(hidden))
        summed = hidden + self.shortcut(x)
        return F.relu(summed)

# ResNet for CIFAR-10
class ResNet(nn.Module):
    """CIFAR-style ResNet: 3x3 stem, three residual stages, global average pool, linear head.

    Args:
        block: residual block class, constructed as ``block(in_planes, planes, stride)``
            and exposing an ``expansion`` attribute (output channels = planes * expansion).
        num_blocks: number of blocks in each of the three stages (indexable, length >= 3).
        num_classes: number of output logits.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Running channel count consumed and updated by _make_layer.
        self.in_planes = 16

        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(16)
        self.layer1 = self._make_layer(block, 16, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 32, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 64, num_blocks[2], stride=2)
        # The last stage emits 64 * expansion channels; for BasicBlock (expansion == 1)
        # this is exactly the original nn.Linear(64, num_classes).
        self.linear = nn.Linear(64 * block.expansion, num_classes)

    def _make_layer(self, block, planes, num_blocks, stride):
        """Stack `num_blocks` blocks; only the first one uses `stride` (downsampling)."""
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for s in strides:
            layers.append(block(self.in_planes, planes, s))
            self.in_planes = planes * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map an image batch to class logits of shape (batch, num_classes)."""
        out = F.relu(self.bn1(self.conv1(x)))  # 32x32 for a 32x32 input
        out = self.layer1(out)  # 32x32
        out = self.layer2(out)  # 16x16
        out = self.layer3(out)  # 8x8
        # Global average pool to 1x1 regardless of the remaining spatial size.
        # A fixed 8x8 window (as before) is only "global" for 32x32 inputs and makes
        # the linear layer fail on any other resolution; for 8x8 maps both agree.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)

def ResNet20():
    """Build ResNet20 for 10 classes: three stages of three BasicBlocks each.

    The 20 weight layers are 1 stem conv + 3 stages * 3 blocks * 2 convs + 1 linear.
    """
    blocks_per_stage = [3, 3, 3]
    return ResNet(BasicBlock, blocks_per_stage)

# --------------------------- WEIGHT MATCHING CODE --------------------------- #

class PermutationSpec(NamedTuple):
    """Two-way index between named permutations and the parameter axes they act on."""

    # permutation name -> list of (parameter key, axis index) pairs that it permutes
    perm_to_axes: dict
    # parameter key -> tuple holding one permutation name (or None) per axis
    axes_to_perm: dict

def permutation_spec_from_axes_to_perm(axes_to_perm: dict) -> PermutationSpec:
    """Build a PermutationSpec by inverting a {param key: per-axis perm names} mapping.

    Axes whose entry is None are not permuted and are left out of the inverse.
    Permutation names appear in `perm_to_axes` in order of first occurrence.
    """
    inverse = {}
    for weight_key, axis_perms in axes_to_perm.items():
        for axis_idx, perm_name in enumerate(axis_perms):
            if perm_name is None:
                continue
            inverse.setdefault(perm_name, []).append((weight_key, axis_idx))
    return PermutationSpec(perm_to_axes=inverse, axes_to_perm=axes_to_perm)

def norm(name, p):
    """Map every BatchNorm tensor of layer `name` to a single axis permuted by `p`.

    Covers the affine parameters (weight/bias) and the running-statistics
    buffers, which all live on the channel axis.
    """
    tensor_names = ("weight", "bias", "running_mean", "running_var")
    return {f"{name}.{tensor_name}": (p,) for tensor_name in tensor_names}

def resnet_permutation_spec() -> PermutationSpec:
    """Build the PermutationSpec for ResNet20 parameters and BatchNorm buffers.

    Each stage shares one residual-stream permutation (P_bg0, P_bg1, P_bg2).
    Every block additionally gets its own inner permutation ``P_<block>_inner``.
    The stem's input-channel axis and the logits axis are not permuted (None).
    """
    def conv(name, p_in, p_out):
        # Conv weight layout: (out_channels, in_channels, kh, kw).
        return {f"{name}.weight": (p_out, p_in, None, None)}

    def norm(name, p):
        return {f"{name}.{t}": (p,) for t in ("weight", "bias", "running_mean", "running_var")}

    def dense(name, p_in, p_out):
        return {f"{name}.weight": (p_out, p_in), f"{name}.bias": (p_out,)}

    def block(name, p_in, p_out, with_shortcut):
        inner = f"P_{name}_inner"
        spec = {}
        spec.update(conv(f"{name}.conv1", p_in, inner))
        spec.update(norm(f"{name}.bn1", inner))
        spec.update(conv(f"{name}.conv2", inner, p_out))
        spec.update(norm(f"{name}.bn2", p_out))
        if with_shortcut:
            spec.update(conv(f"{name}.shortcut.0", p_in, p_out))
            spec.update(norm(f"{name}.shortcut.1", p_out))
        return spec

    # Insertion order is kept identical to a hand-written listing. It fixes the
    # order of permutation names in the resulting spec.
    axes = {}
    axes.update(conv("conv1", None, "P_bg0"))
    axes.update(norm("bn1", "P_bg0"))

    previous = "P_bg0"
    for stage, current in enumerate(("P_bg0", "P_bg1", "P_bg2"), start=1):
        for idx in range(3):
            block_name = f"layer{stage}.{idx}"
            # Only the first block of a stage that changes the stream width has a projection shortcut.
            if idx == 0 and previous != current:
                axes.update(block(block_name, previous, current, with_shortcut=True))
            else:
                axes.update(block(block_name, current, current, with_shortcut=False))
        previous = current

    axes.update(dense("linear", "P_bg2", None))
    return permutation_spec_from_axes_to_perm(axes)

def get_permuted_param(ps: PermutationSpec, perm, k: str, params, except_axis=None):
    """Return `params[k]` with every permutation from `ps` applied along its axis.

    Axes mapped to None are left alone. The axis equal to `except_axis`, if
    any, is also skipped; this lets a caller leave the axis being optimised
    untouched.
    """
    permuted = params[k]
    for axis_idx, perm_name in enumerate(ps.axes_to_perm[k]):
        if axis_idx != except_axis and perm_name is not None:
            permuted = jnp.take(permuted, perm[perm_name], axis=axis_idx)
    return permuted

def apply_permutation(ps: PermutationSpec, perm, params):
    """Return a new dict holding every entry of `params` permuted according to `perm`."""
    permuted = {}
    for key in params:
        permuted[key] = get_permuted_param(ps, perm, key, params)
    return permuted

def rngmix(rng, i):
    """Derive a new JAX PRNG key from `rng` by folding in the integer `i`."""
    derived_key = random.fold_in(rng, i)
    return derived_key


def weight_matching(rng,
                    ps: PermutationSpec,
                    params_a,
                    params_b,
                    max_iter=100,
                    init_perm=None,
                    silent=True):
    """Find a permutation of `params_b` to make them match `params_a`.

    Coordinate descent over the permutations named in `ps`:
    - Each step holds all permutations but one fixed.
    - It builds an n x n similarity matrix between the corresponding axes of
      `params_a` and the (otherwise permuted) `params_b`.
    - It solves a linear assignment problem that maximises that similarity.

    Every sweep visits the permutations in a freshly shuffled order. Sweeps
    repeat until one makes no improvement or `max_iter` sweeps have run.

    Args:
      rng: JAX PRNG key used to shuffle the per-sweep update order.
      ps: PermutationSpec linking permutation names to parameter axes.
      params_a: reference parameters (dict of arrays); not modified.
      params_b: parameters to align to `params_a` (dict of arrays); not modified.
      max_iter: maximum number of sweeps over all permutations.
      init_perm: optional starting {name: index array}. Defaults to the
        identity. It is copied, so the caller's dict is never modified.
      silent: when False, print the objective change for every update.

    Returns:
      dict mapping each permutation name to an index array, suitable for
      `apply_permutation(ps, perm, params_b)`.
    """
    perm_sizes = {p: params_a[axes[0][0]].shape[axes[0][1]]
                  for p, axes in ps.perm_to_axes.items()}

    # Copy a caller-supplied start point so the in-place updates below do not leak out.
    if init_perm is None:
        perm = {p: jnp.arange(n) for p, n in perm_sizes.items()}
    else:
        perm = dict(init_perm)
    perm_names = list(perm.keys())

    for iteration in range(max_iter):
        progress = False
        for p_ix in random.permutation(rngmix(rng, iteration), len(perm_names)):
            p = perm_names[p_ix]
            n = perm_sizes[p]
            # A[i, j]: summed inner product between slice i of params_a and slice j of
            # params_b along every axis that permutation p acts on.
            A = jnp.zeros((n, n))
            for wk, axis in ps.perm_to_axes[p]:
                w_a = params_a[wk]
                w_b = get_permuted_param(ps, perm, wk, params_b, except_axis=axis)
                w_a = jnp.moveaxis(w_a, axis, 0).reshape((n, -1))
                w_b = jnp.moveaxis(w_b, axis, 0).reshape((n, -1))
                A += w_a @ w_b.T

            ri, ci = linear_sum_assignment(A, maximize=True)
            assert (ri == jnp.arange(len(ri))).all()

            # Objective = sum_i A[i, perm(i)], computed via a permutation matrix.
            oldL = jnp.vdot(A, jnp.eye(n)[perm[p]])
            newL = jnp.vdot(A, jnp.eye(n)[ci, :])
            if not silent:
                print(f"{iteration}/{p}: {newL - oldL}")
            progress = progress or newL > oldL + 1e-12

            perm[p] = jnp.array(ci)

        if not progress:
            break

    # Return only after the sweep loop ends. Previously this sat inside the loop,
    # so the search always stopped after a single sweep.
    return perm

# --------------------------- WEIGHT MATCHING AND FLOW MATCHING PIPELINE --------------------------- #


def get_test_loader(batch_size=128):
    """Return an unshuffled DataLoader over the CIFAR-10 test split.

    Images are converted to tensors and normalised per channel. The dataset
    is downloaded into ./data if it is not already present.
    """
    # Per-channel normalisation constants.
    # NOTE(review): these must match whatever normalisation the loaded checkpoints were trained with.
    cifar_mean = (0.4914, 0.4822, 0.4465)
    cifar_std = (0.2023, 0.1994, 0.2010)
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])

    dataset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=preprocess,
    )
    return DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=2)


def evaluate(model, test_loader, device='cuda' if torch.cuda.is_available() else 'cpu'):
    """Return the top-1 accuracy of `model` on `test_loader`, as a percentage.

    Side effects: switches `model` to eval mode and moves it to `device`.
    Raises ZeroDivisionError if the loader yields no samples.
    """
    model.eval()
    model.to(device)

    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch_x, batch_y in test_loader:
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)
            predicted = model(batch_x).max(dim=1).indices
            n_seen += batch_y.size(0)
            n_correct += (predicted == batch_y).sum().item()

    return 100.0 * n_correct / n_seen

from collections import OrderedDict
import jax.numpy as jnp
from jax import random

def test_weight_matching(model_dir="imagenet_resnet_models"):
    """Tests weight matching between two ResNet20 models trained separately.

    Loads the checkpoints `resnet_weights_0.pt` (reference) and
    `resnet_weights_100.pt` (to be permuted) from `model_dir`. It aligns the
    second model to the first with `weight_matching`, writes the permuted
    weights back, and prints the second model's CIFAR-10 test accuracy before
    and after.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model_a = ResNet20().to(device)
    model_b = ResNet20().to(device)

    ref_point, perm_point = 0, 100
    for model, point in ((model_a, ref_point), (model_b, perm_point)):
        checkpoint_path = f"{model_dir}/resnet_weights_{point}.pt"
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))

    ps = resnet_permutation_spec()

    def to_jax(model):
        # Only tensors covered by the spec are matched (e.g. num_batches_tracked is skipped).
        return {
            name: jnp.array(tensor.clone().detach().cpu().numpy())
            for name, tensor in model.state_dict().items()
            if name in ps.axes_to_perm
        }

    params_a_jax = to_jax(model_a)
    params_b_jax = to_jax(model_b)

    test_loader = get_test_loader()
    accuracy_before_matching = evaluate(model_b, test_loader)

    perm = weight_matching(random.PRNGKey(123), ps, params_a_jax, params_b_jax)
    permuted_params_b = apply_permutation(ps, perm, params_b_jax)

    # Overwrite the matched entries of model_b's state dict and reload it.
    state_dict_b = model_b.state_dict()
    state_dict_b.update({k: torch.tensor(np.array(v)) for k, v in permuted_params_b.items()})
    model_b.load_state_dict(state_dict_b)

    accuracy_after_matching = evaluate(model_b, test_loader)

    print(f"Before matching accuracy: {accuracy_before_matching:.2f}%")
    print(f"After matching accuracy: {accuracy_after_matching:.2f}%")
    
# Run the experiment when this cell executes. It loads both checkpoints, fetches
# the CIFAR-10 test split if needed, and prints accuracy before/after matching.
test_weight_matching()

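# Output:
# Before matching accuracy: 74.13%
# After matching accuracy: 74.13%
# Identical accuracies are expected: matching only permutes model_b's hidden
# units consistently across layers, which leaves the function it computes unchanged.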
