In [1]:
from utils import build_mlp_model
import torch
model_1 = build_mlp_model()
model_2 = build_mlp_model()
from utils import initialization_with_seed
model_1 = initialization_with_seed(model_1, seed=42)
model_2 = initialization_with_seed(model_2, seed=65)
state_dict_1 = model_1.state_dict()
state_dict_2 = model_2.state_dict()
for key in state_dict_1.keys():
    print(torch.sum(state_dict_1[key] - state_dict_2[key]))
    

tensor(16.6607)
tensor(0.)
tensor(-14.4691)
tensor(0.)
tensor(-3.2008)
tensor(0.)


In [2]:
from utils import prepare_dataset
from torchmetrics import Accuracy
trainset, testset = prepare_dataset("MNIST", model_type="MLP")
accuracy = Accuracy(task="multiclass", num_classes=10)
accuracy = accuracy.to("cuda")

In [3]:
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm
from typing import Literal
def train(model: nn.Module, 
                  train_loader: DataLoader, 
                  optimizer: torch.optim.Optimizer, 
                  criterion: nn.Module,
                  epoch: int, 
                  max_epochs: int, 
                  device: Literal["cpu", "cuda"]):
            model.train()
            model.to(device)
            total_loss = 0
            pbar = tqdm(train_loader)
            for inputs, targets in pbar:
                inputs = inputs.to(device)
                targets = targets.to(device)
                optimizer.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                total_loss += loss.item()
                acc = accuracy(outputs, targets)
                loss.backward()
                optimizer.step()
                pbar.set_description(f"Epoch: {epoch+1}/{max_epochs}, Step Loss: {loss.item():.4f}, Step Acc: {acc.item():.4f}")
            return {'Train Loss': total_loss/len(train_loader), 
                    'Train Acc': accuracy.compute().item()}

In [4]:
MAX_EPOCHS = 5
TRAIN_BATCH_SIZE = 2048
TEST_BATCH_SIZE = 1000
train_loader = DataLoader(trainset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)
test_loader = DataLoader(testset, batch_size=TEST_BATCH_SIZE, shuffle=False)

In [5]:
optimizer = torch.optim.SGD(model_1.parameters(), lr=1e-1)
optimizer2 = torch.optim.SGD(model_2.parameters(), lr=1e-1)

In [6]:
criterion = nn.CrossEntropyLoss()
device = "cuda" if torch.cuda.is_available() else "cpu"
model_1_history = []
for epoch in range(MAX_EPOCHS):
    result = train(model_1, train_loader, optimizer, criterion, epoch, MAX_EPOCHS, device)
    model_1_history.append(result)
    

Epoch: 1/5, Step Loss: 2.7604, Step Acc: 0.1992:  40%|████      | 12/30 [02:48<03:43, 12.43s/it]

In [58]:
new_model = build_mlp_model()
new_model.load_state_dict(model_1.state_dict())

<All keys matched successfully>

In [59]:
def test(
    model: nn.Module,
    test_loader: DataLoader,
    criterion: nn.Module,
    device: Literal["cpu", "cuda"],
):
    model.eval()
    model.to(device)
    total_loss = 0
    pbar = tqdm(test_loader)
    with torch.no_grad():
        for inputs, targets in pbar:
            inputs = inputs.to(device)
            targets = targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            total_loss += loss.item()
            accuracy(outputs, targets)
    return {
        "Testl Loss": total_loss / len(test_loader),
        "Test Acc": accuracy.compute().item(),
    }

In [60]:
test(new_model, test_loader, criterion, "cuda")

100%|██████████| 10/10 [00:01<00:00,  5.06it/s]


{'Testl Loss': 0.2260461539030075, 'Test Acc': 0.8295442461967468}

In [61]:
for name, param in new_model.named_parameters():
    print(torch.sum(param.data - model_1.state_dict()[name].data))

tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')
tensor(0., device='cuda:0')


In [62]:
test(model_1, test_loader, criterion, device="cuda")

100%|██████████| 10/10 [00:02<00:00,  4.93it/s]


{'Testl Loss': 0.2260461539030075, 'Test Acc': 0.8312322497367859}

---

In [64]:
for epoch in range(MAX_EPOCHS):
    train(model_2, train_loader, optimizer2, criterion, epoch, MAX_EPOCHS, device="cuda")

Epoch: 1/10, Step Loss: 1.1503, Step Acc: 0.6086: 100%|██████████| 30/30 [00:30<00:00,  1.02s/it]
Epoch: 2/10, Step Loss: 0.5910, Step Acc: 0.8010: 100%|██████████| 30/30 [00:30<00:00,  1.02s/it]
Epoch: 3/10, Step Loss: 0.6078, Step Acc: 0.7780: 100%|██████████| 30/30 [00:30<00:00,  1.01s/it]
Epoch: 4/10, Step Loss: 0.4363, Step Acc: 0.8651: 100%|██████████| 30/30 [00:29<00:00,  1.00it/s]
Epoch: 5/10, Step Loss: 0.3407, Step Acc: 0.9013: 100%|██████████| 30/30 [00:29<00:00,  1.03it/s]
Epoch: 6/10, Step Loss: 0.3180, Step Acc: 0.9030: 100%|██████████| 30/30 [00:29<00:00,  1.02it/s]
Epoch: 7/10, Step Loss: 0.3054, Step Acc: 0.9243: 100%|██████████| 30/30 [00:29<00:00,  1.01it/s]
Epoch: 8/10, Step Loss: 0.3605, Step Acc: 0.8947: 100%|██████████| 30/30 [00:28<00:00,  1.05it/s]
Epoch: 9/10, Step Loss: 0.3017, Step Acc: 0.9030: 100%|██████████| 30/30 [00:28<00:00,  1.05it/s]
Epoch: 10/10, Step Loss: 0.2532, Step Acc: 0.9211: 100%|██████████| 30/30 [00:28<00:00,  1.04it/s]


In [70]:
test(model_2, train_loader, criterion, device="cuda")

100%|██████████| 30/30 [04:36<00:00,  9.22s/it]


{'Testl Loss': 2.302582621574402, 'Test Acc': 0.09847778081893921}

In [13]:
state_dict_1 = model_1.state_dict()
state_dict_2 = model_2.state_dict()
for key in state_dict_1.keys():
    diff = torch.sum(torch.abs(state_dict_1[key] - state_dict_2[key]))
    print(f"{key.upper()}: Diff: {diff}")

0.WEIGHT: Diff: 4709.8525390625
0.BIAS: Diff: 3.2348155975341797
2.WEIGHT: Diff: 1394.478515625
2.BIAS: Diff: 8.667121887207031
4.WEIGHT: Diff: 157.3886260986328
4.BIAS: Diff: 0.6120636463165283


In [18]:
alpha = 0
from interpolate import interpolate_weights
new_state_dict = interpolate_weights(state_dict_1, state_dict_2, alpha)
model_1.load_state_dict(new_state_dict)
test(model_1, test_loader, criterion, device="cuda")

100%|██████████| 10/10 [00:01<00:00,  5.05it/s]


{'Testl Loss': 0.07838473487645388, 'Test Acc': 0.947195291519165}

In [71]:
from permute import match_and_permute
valset = torch.utils.data.Subset(trainset, range(500))
val_loader = DataLoader(valset, batch_size=100, shuffle=False)
new_state_dict = match_and_permute(build_mlp_model, state_dict_1, state_dict_2, 
                                   val_loader, device="cuda")

model_1.load_state_dict(new_state_dict)
test(model_1, test_loader, criterion, device="cuda")

Layer 0: Cost before: 2543.345947265625, Cost after: 1425.186279296875
Layer 1: Cost before: 2463.74755859375, Cost after: 1404.76025390625
[array([ 17,  88, 119,  92, 127, 116, 122,  54,  84,   4,  24, 102,   9,
        81,  19,  60,  27,  61, 104,  63,   1,  21,  75,  78,  51,  85,
       120,  10,  98,  13,  44,  66,  36,  12, 114, 113,  64,  86,   6,
        45,  42, 109,  76,  79,  57,  28,  90,   5,  11,  25, 100, 111,
        16,  22,  18,  68,  94, 101, 103,  74,  73,  56,  83, 105,  95,
       107,  47,  39,  38, 121,  30, 117,  53,  29, 110, 125,  80,  37,
        32,  20,  58,  97,  77,  15,  91,  72,  48,  52,  26, 126,  41,
        67,  33,   8,  70,  50,   7,   2,  35,  40,  69,  46,  89, 106,
        71,  93,  31,  59,  82,  23, 112,  65,  14,  96, 123,  62, 118,
       124,  49,   0,  55,  34, 108,  99,  43,   3,  87, 115])]


IndexError: list index out of range