# MNIST MLP Classifiers

A quick-start notebook for aligning MNIST MLP classifiers with deep-align.

**Make sure you change the runtime type to GPU before starting!**

In [None]:
# If not installed already, please install deep-align and all its dependencies.
! git clone https://github.com/AvivNavon/deep-align.git

In [None]:
cd deep-align

In [None]:
!pip install -e .

## Get Data
Next, we download the MNIST MLPs dataset and place it in `dataset/mnist_models`.

In [1]:
!mkdir dataset

In [None]:
!wget https://www.dropbox.com/sh/56pakaxe58z29mq/AABtWNkRYroLYe_cE3c90DXVa?dl=0 -O data_files.zip

In [2]:
# NOTE(review): the download cell saves its archive as data_files.zip, but this cell expects
# mnist_classifiers.zip — confirm whether the downloaded archive must be extracted or renamed first.
!unzip -q mnist_classifiers.zip -d dataset

## Import Dependencies

In [29]:
import logging
from argparse import ArgumentParser

import numpy as np
import torch
import torch.nn.functional as F
import wandb
from sklearn.model_selection import train_test_split
from torch import nn
from torch.utils.data import DataLoader
from sklearn.metrics import f1_score
from tqdm import trange
from dataclasses import dataclass
import matplotlib.pyplot as plt

from experiments.utils.data.generate_splits import generate_splits

from deepalign.losses.mlp_losses import calc_lmc_loss, calc_recon_loss, calc_gt_perm_loss
from deepalign.utils import extract_pred
from experiments.utils import (
    common_parser, count_parameters, get_device, set_logger, set_seed, str2bool,
)
from experiments.utils.data import MultiViewMatchingBatch, MatchingModelsDataset
from deepalign.sinkhorn import matching
from deepalign import DWSMatching
from experiments.utils.data.image_data import get_mnist_dataloaders

# Set up logging through the project helper before the first logging.info call below.
set_logger()

## Generate Data Splits
Next, create the data splits using a subset of the extracted dataset. We use 1000 models for training and 100 each for validation and test.

In [6]:
# Write the train/val/test split file for the extracted MNIST models.
generate_splits(
    data_root="dataset/mnist_models",
    save_path="dataset/splits.json",
    test_size=100,
    val_size=100,
    max_models=1200,
)

2024-05-03 22:04:51,787 - root - INFO - train size: 1000, val size: 100, test size: 100


## MLP Dataset

We create MLP Datasets and Dataloaders.



In [2]:
path = "dataset/splits.json"  # split file written by the generate_splits cell
batch_size = 8  # batch size passed to each model DataLoader
num_workers = 4  # num_workers passed to each model DataLoader

In [3]:
def _make_loader(dataset, shuffle):
    """Wrap ``dataset`` in a DataLoader using the notebook-wide ``batch_size`` and ``num_workers``.

    All three splits share the same settings (including ``pin_memory=True``) and
    differ only in whether the data is shuffled.
    """
    return torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        pin_memory=True,
    )


# One dataset per split, all read from the same split file.
train_set = MatchingModelsDataset(path=path, split="train")
val_set = MatchingModelsDataset(path=path, split="val")
test_set = MatchingModelsDataset(path=path, split="test")

# Only the training data is shuffled; val/test keep a fixed order.
train_loader = _make_loader(train_set, shuffle=True)
val_loader = _make_loader(val_set, shuffle=False)
test_loader = _make_loader(test_set, shuffle=False)

In [4]:
# Pull a single batch from the training loader and list the fields it exposes.
batch = next(iter(train_loader))
batch.as_dict().keys()

dict_keys(['weights_view_0', 'biases_view_0', 'weights_view_1', 'biases_view_1', 'perm_weights_view_0', 'perm_biases_view_0', 'perm_weights_view_1', 'perm_biases_view_1', 'perms_view_0', 'perms_view_1'])

In [5]:
# Report the number of items in each split.
split_sizes = {"train": len(train_set), "val": len(val_set), "test": len(test_set)}
logging.info(", ".join(f"{name} size {count}" for name, count in split_sizes.items()))

# Weight and bias shapes as reported by the sampled batch; used to build the model later.
weight_shapes, bias_shapes = batch.get_weight_shapes()

logging.info(f"weight shapes: {weight_shapes}, bias shapes: {bias_shapes}")

2024-05-03 22:46:53,094 - root - INFO - train size 1000, val size 100, test size 100
2024-05-03 22:46:53,097 - root - INFO - weight shapes: (torch.Size([784, 128]), torch.Size([128, 128]), torch.Size([128, 128]), torch.Size([128, 10])), bias shapes: (torch.Size([128]), torch.Size([128]), torch.Size([128]), torch.Size([10]))



## Image Dataloaders

In [17]:
# NOTE(review): this uses "datasets/" while the model data lives under "dataset/" — confirm the separate directory is intended.
image_data_path = "datasets/MNIST"
image_batch_size = 12  # batch size passed to get_mnist_dataloaders
allow_download = True  # allow downloading MNIST
image_flatten_size = 28 * 28  # passed to the loss functions as image_flatten_size

In [7]:
# Build the MNIST image loaders, then unpack them into train/val/test.
image_loaders = get_mnist_dataloaders(
    image_data_path,
    batch_size=image_batch_size,
    allow_download=allow_download,
)
train_image_loader, val_image_loader, test_image_loader = image_loaders

## Initialize DWSNet

In [8]:
# get device
device = get_device()
logging.info(f"device = {device}")

2024-05-03 22:46:53,149 - root - INFO - device = cpu


In [15]:
@dataclass
class Args:
    """Hyperparameters for the DWSMatching model and its training loss.

    Every attribute carries a type annotation so that ``@dataclass`` registers it
    as a field. Un-annotated class attributes are ignored by ``dataclass``, which
    would make overrides such as ``Args(hidden_dim=32)`` raise ``TypeError``.
    """

    # arguments forwarded to the DWSMatching constructor
    hidden_dim: int = 64
    n_hidden: int = 4
    output_features: int = 128
    input_dim_downsample: int = 8
    add_bn: bool = True
    diagonal: bool = True
    # loss weights
    supervised_loss_weight: float = 1
    recon_loss_weight: float = 1
    # switches forwarded to calc_recon_loss
    add_task_loss: bool = True
    add_l2_loss: bool = True


args = Args()

In [10]:
# Gather the constructor arguments first, then build the network and move it to the device.
dws_kwargs = dict(
    weight_shapes=weight_shapes,
    bias_shapes=bias_shapes,
    input_features=1,
    hidden_dim=args.hidden_dim,
    n_hidden=args.n_hidden,
    output_features=args.output_features,
    input_dim_downsample=args.input_dim_downsample,
    bn=args.add_bn,
    diagonal=args.diagonal,
)
model = DWSMatching(**dws_kwargs).to(device)

logging.info(f"number of parameters: {count_parameters(model)}")

2024-05-03 22:46:53,722 - root - INFO - number of parameters: 31873905


## Eval Function

In [19]:
@torch.no_grad()
def evaluate(model, loader, image_loader, add_task_loss=True, add_l2_loss=True):
    """Evaluate the matching model on a loader of paired MLP views.

    For every batch from ``loader`` one image batch is drawn from ``image_loader``
    and passed to the reconstruction and LMC loss functions. The notebook-level
    globals ``device`` and ``image_flatten_size`` are read inside the function.

    Args:
        model: network called on ``(weights, biases)`` tuples; its outputs are
            handed to ``extract_pred``.
        loader: iterable of ``MultiViewMatchingBatch`` objects.
        image_loader: iterable of image batches (each an iterable of tensors).
        add_task_loss: forwarded to ``calc_recon_loss``.
        add_l2_loss: forwarded to ``calc_recon_loss``.

    Returns:
        dict with keys ``avg_loss`` (mean GT-permutation loss per batch),
        ``avg_acc`` (mean per-batch matching accuracy), ``recon_loss`` (mean
        reconstruction loss per batch), ``predicted`` and ``gt`` (flattened int
        tensors), ``f1`` (macro F1 between them) and ``lmc_losses`` (dict of
        batch-averaged arrays: ``soft_alignment``, ``no_alignment``, ``alignment``).

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    model.eval()

    perm_loss = 0.0
    recon_loss = 0.0
    correct = 0.0
    total = 0
    predicted, gt = [], []
    recon_losses, baseline_losses, hard_recon_losses = [], [], []

    # Keep a single iterator over the image loader. Calling iter() for every batch
    # starts a fresh pass each time, so an unshuffled loader would always yield its
    # first batch.
    image_iter = iter(image_loader)
    for batch in loader:
        try:
            image_batch = next(image_iter)
        except StopIteration:
            # Image loader exhausted before the model loader: start another pass.
            image_iter = iter(image_loader)
            image_batch = next(image_iter)
        image_batch = tuple(t.to(device) for t in image_batch)
        batch: MultiViewMatchingBatch = batch.to(device)

        input_0 = (batch.weights_view_0, batch.biases_view_0)
        input_1 = (batch.weights_view_1, batch.biases_view_1)
        perm_input_0 = (batch.perm_weights_view_0, batch.perm_biases_view_0)

        out_0 = model(input_0)
        out_1 = model(input_1)
        perm_out_0 = model(perm_input_0)

        # view 0 vs. its permuted copy (ground-truth permutation is available)
        pred_matrices_perm_0 = extract_pred(out_0, perm_out_0)
        # view 0 vs. view 1
        pred_matrices = extract_pred(out_0, out_1)

        # loss from GT permutations
        curr_gt_loss = calc_gt_perm_loss(
            pred_matrices_perm_0, batch.perms_view_0, device=device
        )

        # reconstruction loss
        curr_recon_loss = calc_recon_loss(
            pred_matrices,
            input_0,
            input_1,
            image_batch=image_batch,
            sinkhorn_project=True,
            add_task_loss=add_task_loss,
            add_l2_loss=add_l2_loss,
            alpha=0.5,
            eval_mode=True,
            device=device,
            image_flatten_size=image_flatten_size,
        )

        # reconstruction loss and images
        results = calc_lmc_loss(
            pred_matrices,
            input_0,
            input_1,
            image_batch=image_batch,
            sinkhorn_project=True,
            device=device,
            image_flatten_size=image_flatten_size,
        )

        recon_losses.append(results["soft"]["losses"])
        hard_recon_losses.append(results["hard"]["losses"])
        baseline_losses.append(results["no_alignment"]["losses"])

        perm_loss += curr_gt_loss.item()
        recon_loss += curr_recon_loss.item()

        # Accuracy of the predicted assignment against the GT permutation,
        # averaged over the prediction matrices of this batch.
        batch_acc = 0.0
        for pred, gt_perm in zip(pred_matrices_perm_0, batch.perms_view_0):
            pred_perm = matching(pred).to(device).argmax(1)
            batch_acc += (pred_perm.eq(gt_perm) * 1.0).mean().item()
            predicted.append(pred_perm.reshape(-1))
            gt.append(gt_perm.reshape(-1))

        total += 1
        correct += batch_acc / len(pred_matrices_perm_0)

    if total == 0:
        raise ValueError("evaluate() received an empty loader; there is nothing to evaluate")

    predicted = torch.cat(predicted).int()
    gt = torch.cat(gt).int()

    avg_loss = perm_loss / total
    avg_acc = correct / total
    recon_loss = recon_loss / total

    # scikit-learn's signature is f1_score(y_true, y_pred)
    f1 = f1_score(
        gt.cpu().detach().numpy(),
        predicted.cpu().detach().numpy(),
        average="macro",
    )

    # LMC losses
    lmc_losses = dict(
        soft_alignment=np.stack(recon_losses).mean(0),  # NOTE: this is the soft alignment loss.
        no_alignment=np.stack(baseline_losses).mean(0),
        alignment=np.stack(hard_recon_losses).mean(0),
    )

    return dict(
        avg_loss=avg_loss,
        avg_acc=avg_acc,
        recon_loss=recon_loss,
        predicted=predicted,
        gt=gt,
        f1=f1,
        lmc_losses=lmc_losses,
    )

## Model training

In [12]:
# Adam over all model parameters; learning rate and weight decay are both 5e-4.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=5e-4,
    weight_decay=5e-4,
)

In [13]:
epochs = 5  # only 5 epochs for this quick start; the paper trains for 100 on a larger train set

In [None]:
epoch_iter = trange(epochs)
# Keep one iterator over the image loader alive across steps; calling iter() on it
# every step would start a fresh pass over the loader each time.
image_iter = iter(train_image_loader)
for epoch in epoch_iter:
    # Nothing inside the epoch switches the model to eval mode, so set train mode once here.
    model.train()
    for i, batch in enumerate(train_loader):
        optimizer.zero_grad()

        batch: MultiViewMatchingBatch = batch.to(device)
        try:
            image_batch = next(image_iter)
        except StopIteration:
            # Image loader exhausted: start another pass.
            image_iter = iter(train_image_loader)
            image_batch = next(image_iter)
        image_batch = tuple(t.to(device) for t in image_batch)

        input_0 = (batch.weights_view_0, batch.biases_view_0)
        input_1 = (batch.weights_view_1, batch.biases_view_1)
        perm_input_0 = (batch.perm_weights_view_0, batch.perm_biases_view_0)

        out_0 = model(input_0)
        out_1 = model(input_1)
        perm_out_0 = model(perm_input_0)

        # view 0 vs. its permuted copy (ground-truth permutation is available)
        pred_matrices_perm_0 = extract_pred(
            out_0,
            perm_out_0,
        )

        # view 0 vs. view 1
        pred_matrices = extract_pred(
            out_0,
            out_1,
        )

        # loss from GT permutations
        gt_perm_loss = calc_gt_perm_loss(
            pred_matrices_perm_0, batch.perms_view_0, device=device
        )

        # reconstruction loss
        recon_loss = calc_recon_loss(
            pred_matrices,
            input_0,
            input_1,
            image_batch=image_batch,
            sinkhorn_project=True,   # NOTE: if the perms are already bi-stochastic, this projection is not needed
            add_task_loss=args.add_task_loss,
            add_l2_loss=args.add_l2_loss,
            device=device,
            image_flatten_size=image_flatten_size,
        )

        # weighted sum of the supervised and reconstruction terms
        loss = gt_perm_loss * args.supervised_loss_weight + recon_loss * args.recon_loss_weight
        loss.backward()
        # clip the global gradient norm to 1.0 before the optimizer step
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        # .item() turns the scalar tensors into plain floats for formatting
        epoch_iter.set_description(
            f"[{epoch} {i+1}], train loss: {loss.item():.3f}, recon loss: {recon_loss.item():.3f}, supervised loss: {gt_perm_loss.item():.3f}"
        )

## Evaluate

In [27]:
# Evaluate the trained model on the test split, using the test image loader.
test_loss_dict = evaluate(
    model=model,
    loader=test_loader,
    image_loader=test_image_loader,
    add_task_loss=args.add_task_loss,
    add_l2_loss=args.add_l2_loss,
)

## Plot LMC Results

In [None]:
lmc_losses = test_loss_dict["lmc_losses"]
# Evenly spaced positions in [0, 1], one per recorded loss value.
x = torch.linspace(0.0, 1.0, len(lmc_losses["alignment"])).numpy().tolist()

# Explicit figure/axes interface, with labels so the figure stands on its own.
fig, ax = plt.subplots()
for k, v in lmc_losses.items():
    ax.plot(x, v, label=k)
ax.set_xlabel("interpolation coefficient")
ax.set_ylabel("loss")
ax.set_title("LMC losses on the test split")
ax.legend()
plt.show()