In [None]:
import os
import sys
sys.path.append(os.path.abspath(".."))

In [None]:
import torch
import seaborn as sns
import ristretto.activations as ra
import ristretto.models as rm
import ristretto.utils as ru
import pandas as pd
import numpy as np
# from torch.utils.tensorboard import SummaryWriter

In [None]:
# Set PyTorch's default floating-point dtype to float32 for the first experiment.
# Tensors and modules created after this cell (including the models below)
# pick up this default dtype.
# Note: torch.set_default_tensor_type() is deprecated as of PyTorch 2.1.
# set_default_dtype() alone is sufficient here because this notebook never
# changes the default device.
torch.set_default_dtype(torch.float32)

In [None]:
@torch.no_grad()
def get_weight_sum(model):
    """Return the sum of all weight entries over the layers in ``model.linear``.

    Serves as a cheap scalar fingerprint of a model's parameters: comparing
    it across runs shows whether (and how far) the weights have diverged.

    Each layer's weights are summed with float64 accumulation, so the metric
    is not limited by the model's own precision. For example, in bfloat16,
    257 ones sum to 256. Low-precision accumulation would hide or distort the
    small differences this notebook is looking for.

    Args:
        model: any object with an iterable ``linear`` attribute whose items
            expose a ``weight`` tensor.

    Returns:
        float: the total weight sum, or 0.0 if ``model.linear`` is empty.
    """
    # Start from 0.0 so an empty ``model.linear`` yields a float instead of
    # the int 0 (which has no ``.item()``).
    return sum(
        (layer.weight.sum(dtype=torch.float64).item() for layer in model.linear),
        0.0,
    )

In [None]:
# Three ResNets, all constructed with seed 42:
#   models[0], models[1] -> identical configuration (activation ra.ReLU(0)),
#   models[2]            -> differs only in the argument to ra.ReLU (1).
# NOTE(review): the meaning of ra.ReLU's argument lives in ristretto.activations
# (presumably the value used for ReLU'(0)) -- confirm there.
# Each factory is a zero-argument callable, matching what rm.ResNet receives as
# ``activation``; they are built in the same order as listed.
models = [
    rm.ResNet(activation=make_activation, seed=42)
    for make_activation in (
        lambda: ra.ReLU(0),
        lambda: ra.ReLU(0),
        lambda: ra.ReLU(1),
    )
]

In [None]:
# Train the three models on MNIST for 2 epochs, recording the weight-sum
# fingerprint through this notebook's own get_weight_sum helper (defined
# above) rather than the separate ru.get_weight_sum, so the metric being
# plotted is the one visible in this notebook.
# NOTE(review): metrics_fn is assumed to be called by ru.train_multiple_models
# as (model, predictions, targets); only the model is used here. The parameter
# names are kept in case the caller passes them by keyword.
metrics = ru.train_multiple_models(
    models,
    ru.default.DATA_LOADERS['MNIST'],
    epochs=2,
    metrics_fn=lambda m, p, y: {"weight_sum": get_weight_sum(m)}
)

In [None]:
# Absolute difference of the recorded weight sums, measured against model 0:
#   "0 vs 0": model 1 has the same configuration and seed as model 0, so any
#             non-zero value means the two runs were not bitwise reproducible;
#   "0 vs 1": model 2, which differs only in the ra.ReLU argument (1 vs 0).
diff = pd.DataFrame({
    "0 vs 0": np.abs(metrics[0]["train"]['weight_sum'] - metrics[1]["train"]['weight_sum']),
    "0 vs 1": np.abs(metrics[0]["train"]['weight_sum'] - metrics[2]["train"]['weight_sum'])
})
# Label the figure so it stands alone; the trailing ';' suppresses the repr
# of the list returned by Axes.set.
sns.lineplot(data=diff, dashes=False).set(
    title="float32 training: |weight_sum(model 0) - weight_sum(model k)|",
    xlabel="train metric record",
    ylabel="absolute weight-sum difference",
);

### When trained in 16-bit (bfloat16) precision, the difference between the models is even greater

In [None]:
# Set PyTorch's default floating-point dtype to bfloat16 for the second
# experiment. Modules created after this cell (the models below) should get
# bfloat16 parameters, assuming rm.ResNet uses the default dtype.
# NOTE(review): the MNIST loader's inputs must also be bfloat16 (or be cast by
# the training loop) for the forward pass to match -- confirm in ristretto.utils.
# torch.set_default_tensor_type() is deprecated as of PyTorch 2.1 and is not
# needed here because this notebook never changes the default device.
torch.set_default_dtype(torch.bfloat16)

In [None]:
# Rebuild the same three ResNets (seed 42) so they are created under the
# default dtype currently in effect:
#   models[0], models[1] -> identical configuration (activation ra.ReLU(0)),
#   models[2]            -> differs only in the argument to ra.ReLU (1).
# NOTE(review): the meaning of ra.ReLU's argument lives in ristretto.activations
# (presumably the value used for ReLU'(0)) -- confirm there.
models = [
    rm.ResNet(activation=make_activation, seed=42)
    for make_activation in (
        lambda: ra.ReLU(0),
        lambda: ra.ReLU(0),
        lambda: ra.ReLU(1),
    )
]

In [None]:
# Train the three models on MNIST for 2 epochs, recording the weight-sum
# fingerprint through this notebook's own get_weight_sum helper (defined
# above) rather than the separate ru.get_weight_sum, so the metric being
# plotted is the one visible in this notebook.
# NOTE(review): metrics_fn is assumed to be called by ru.train_multiple_models
# as (model, predictions, targets); only the model is used here. The parameter
# names are kept in case the caller passes them by keyword.
metrics = ru.train_multiple_models(
    models,
    ru.default.DATA_LOADERS['MNIST'],
    epochs=2,
    metrics_fn=lambda m, p, y: {"weight_sum": get_weight_sum(m)}
)

In [None]:
# Absolute difference of the recorded weight sums, measured against model 0:
#   "0 vs 0": model 1 has the same configuration and seed as model 0, so any
#             non-zero value means the two runs were not bitwise reproducible;
#   "0 vs 1": model 2, which differs only in the ra.ReLU argument (1 vs 0).
diff = pd.DataFrame({
    "0 vs 0": np.abs(metrics[0]["train"]['weight_sum'] - metrics[1]["train"]['weight_sum']),
    "0 vs 1": np.abs(metrics[0]["train"]['weight_sum'] - metrics[2]["train"]['weight_sum'])
})
# Label the figure so it stands alone; the trailing ';' suppresses the repr
# of the list returned by Axes.set.
sns.lineplot(data=diff, dashes=False).set(
    title="bfloat16 training: |weight_sum(model 0) - weight_sum(model k)|",
    xlabel="train metric record",
    ylabel="absolute weight-sum difference",
);