This notebook demonstrates the REPAIR ablation study and its variants, including sequential REPAIR/RESCALE and data-independent re-normalization.

In [1]:
import sys
sys.path.append("..")

In [7]:
import torch
from source.utils.utils import load_model
from source.utils.data_funcs import load_data
from source.utils.weight_matching import weight_matching
from source.layers.batch_norm import bn_calibration_init
from source.utils.connect import interpolate_state_dicts, repair, reset_bn_stats
from source.utils.logger import Logger
from torch.utils.data import DataLoader
from source.utils.train import validate
from copy import deepcopy

In [3]:
class config:
    # Experiment settings. The cells below pass this class to load_model / validate
    # and read config.model, config.dataset and config.data_dir directly.
    model = "cifar_vgg16"
    dataset = "cifar10"
    # 'vgg_init' (kaiming init) or None (uniform init)
    special_init = "vgg_init"
    print_freq = 100
    data_dir = "../../Linear_Mode_Connectivity/data"
    # number of interpolation points
    n = 3

In [4]:
# A logger has to exist before the train/validate/eval_line etc. functions are used.
Logger.setup_logging()
logger = Logger()

# Device and loss shared by every evaluation in this notebook.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
criterion = torch.nn.CrossEntropyLoss()

# Default train/test splits and their loaders.
trainset, testset = load_data(config.data_dir, config.dataset)
trainloader = DataLoader(trainset, batch_size=128, shuffle=True)
testloader = DataLoader(testset, batch_size=256, shuffle=False)

# Second copy of the training split, requested with no_random_aug=True
# (the accompanying test split is discarded).
trainset_noaug, _ = load_data(config.data_dir, config.dataset, no_random_aug=True)
trainloader_noaug = DataLoader(trainset_noaug, batch_size=128, shuffle=True)

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [5]:
# Two fresh instances of the architecture named by config.model.
model_1 = load_model(config).to(device)
model_2 = load_model(config).to(device)

# Checkpoint for the first model, remapped onto the active device.
sd_1 = torch.load(f'../../Linear_Mode_Connectivity/same_init_ex/{config.dataset}/{config.model}/diff_init/seed_20/model_1_160.pt', map_location=device)
model_1.load_state_dict(sd_1)

# Checkpoint for the second model; this stays the last expression so the cell
# still displays the load_state_dict result.
sd_2 = torch.load(f'../../Linear_Mode_Connectivity/same_init_ex/{config.dataset}/{config.model}/diff_init/seed_20/model_2_160.pt', map_location=device)
model_2.load_state_dict(sd_2)

<All keys matched successfully>

In [None]:
# The first return value of weight_matching is used as the state dict for a fresh
# copy of model 2 (presumably sd_2 permuted to align with sd_1 -- confirm in
# source.utils.weight_matching). The second return value is discarded.
sd_2_wm, _ = weight_matching(config.model, sd_1, sd_2, device=device)
model_2_wm = load_model(config).to(device)
model_2_wm.load_state_dict(sd_2_wm)

# State dict interpolated between model 1 and the matched model 2 with coefficient 0.5,
# loaded into another fresh model.
sd_wm_mid = interpolate_state_dicts(sd_1, sd_2_wm, 0.5)
model_wm_mid = load_model(config).to(device)
model_wm_mid.load_state_dict(sd_wm_mid)

# Four repair() calls on the same loader, endpoint models and interpolated model;
# only `variant` / `average` differ between them.
# NOTE(review): whether repair() modifies model_wm_mid in place is not visible here;
# if it does, each later call (and the "Original" validation below) sees an altered model.
model_wm_mid_repaired = repair(trainloader, [model_1, model_2_wm], model_wm_mid, device, name=config.model)
model_wm_mid_rescaled = repair(trainloader, [model_1, model_2_wm], model_wm_mid, device, variant='rescale', name=config.model)
model_wm_mid_reshifted = repair(trainloader, [model_1, model_2_wm], model_wm_mid, device, variant='reshift', name=config.model)
model_wm_mid_rescaled_ave = repair(trainloader, [model_1, model_2_wm], model_wm_mid, device, variant='rescale', average=True, name=config.model)


# validate returns four values; the first two are kept as test loss and test accuracy
# (the accuracy is later printed with a '%' suffix and no scaling).
test_loss_wm_mid, test_acc_wm_mid, _, _ = validate(testloader, model_wm_mid, criterion, device, config)
test_loss_wm_mid_repaired, test_acc_wm_mid_repaired, _, _ = validate(testloader, model_wm_mid_repaired, criterion, device, config)
test_loss_wm_mid_rescaled, test_acc_wm_mid_rescaled, _, _ = validate(testloader, model_wm_mid_rescaled, criterion, device, config)
test_loss_wm_mid_reshifted, test_acc_wm_mid_reshifted, _, _ = validate(testloader, model_wm_mid_reshifted, criterion, device, config)
test_loss_wm_mid_rescaled_ave, test_acc_wm_mid_rescaled_ave, _, _ = validate(testloader, model_wm_mid_rescaled_ave, criterion, device, config)

In [14]:
def reset_single_bn_stats(model, device, bn_loader, layer_name, batch_num=False):
    """Forward data through ``model`` with only one named layer left in train mode.

    The module whose name (as reported by ``model.named_modules()``) equals
    ``layer_name`` is passed to ``bn_calibration_init``; every other module,
    the root included, gets ``training = False``. Batches from ``bn_loader``
    are then forwarded under ``torch.no_grad()`` and the whole model is put
    back into eval mode.

    Args:
        model: module to forward data through; modified in place.
        device: device each input batch is moved to.
        bn_loader: iterable of ``(inputs, _)`` pairs; the second element is ignored.
        layer_name: name of the single module to (re)calibrate.
        batch_num: if truthy, stop after this many batches have been forwarded;
            the default ``False`` consumes the whole loader.
    """
    model.train()
    for module_name, module in model.named_modules():
        if module_name != layer_name:
            # Freeze everything except the target layer.
            module.training = False
        else:
            bn_calibration_init(module)

    with torch.no_grad():
        for step, (inputs, _) in enumerate(bn_loader):
            if batch_num and step == batch_num:
                break
            model(inputs.to(device))
    model.eval()

def default_bn_stats(model):
    """Call ``bn_calibration_init`` on ``model`` and on each of its submodules."""
    all_modules = model.modules()
    for layer in all_modules:
        bn_calibration_init(layer)

In [15]:
# Independent copies, taken before the resets just below, serve as the starting
# point for the layer-by-layer (sequential) variants.
model_wm_mid_sequential_repaired = deepcopy(model_wm_mid_repaired)
model_wm_mid_sequential_rescaled = deepcopy(model_wm_mid_rescaled)

# NOTE(review): in notebook order these two models were already validated in an
# earlier cell and are not evaluated again afterwards, so these resets do not feed
# into the reported REPAIR / RESCALE numbers -- confirm they are still needed.
reset_bn_stats(model_wm_mid_repaired, device, trainloader)
reset_bn_stats(model_wm_mid_rescaled, device, trainloader)

# Names of every module that exposes a `track_running_stats` attribute.
reset_layer_s = [
    name
    for name, m in model_wm_mid_sequential_repaired.named_modules()
    if hasattr(m, 'track_running_stats')
]

# Recalibrate one layer at a time, in module order, using 100 batches per layer.
for name in reset_layer_s:
    reset_single_bn_stats(model_wm_mid_sequential_repaired, device, trainloader, name, 100)
    reset_single_bn_stats(model_wm_mid_sequential_rescaled, device, trainloader, name, 100)

In [None]:
# Evaluate the two sequentially recalibrated models on the test loader; validate
# returns four values and only the first two (kept as loss and accuracy) are used.
test_loss_wm_mid_seq_repair, test_acc_wm_mid_seq_repair, _, _ = validate(testloader, model_wm_mid_sequential_repaired, criterion, device, config)
test_loss_wm_mid_seq_rescale, test_acc_wm_mid_seq_rescale, _, _ = validate(testloader, model_wm_mid_sequential_rescaled, criterion, device, config)

In [17]:
# Single print call; each argument becomes one line of the report.
print(
    'Test Accuracy:',
    f'Original: {test_acc_wm_mid:.2f}%',
    f'REPAIR: {test_acc_wm_mid_repaired:.2f}%',
    f'RESCALE: {test_acc_wm_mid_rescaled:.2f}%',
    f'RESHIFT: {test_acc_wm_mid_reshifted:.2f}%',
    f'RESCALE (average): {test_acc_wm_mid_rescaled_ave:.2f}%',
    f'SEQUENTIAL REPAIR: {test_acc_wm_mid_seq_repair:.2f}%',
    f'SEQUENTIAL RESCALE: {test_acc_wm_mid_seq_rescale:.2f}%',
    sep='\n',
)

Test Accuracy:
Original: 12.29%
REPAIR: 67.43%
RESCALE: 82.67%
RESHIFT: 10.00%
RESCALE (average): 76.51%
SEQUENTIAL REPAIR: 76.53%
SEQUENTIAL RESCALE: 82.64%


Data-independent variant

In [21]:
# Switch the architecture name for the data-independent section (presumably the
# BatchNorm variant of VGG-16, going by the name). This mutates the shared `config`
# class: earlier cells that call load_model(config) or build checkpoint paths from
# config.model will pick up this value if they are re-run afterwards.
config.model = 'cifar_vgg16_bn'

In [22]:
# Fresh instances of the architecture currently named by config.model.
model_1, model_2 = load_model(config).to(device), load_model(config).to(device)

# map_location=device remaps the checkpoint tensors onto the active device, the same
# way the cifar_vgg16 checkpoints are loaded earlier in this notebook. Without it,
# torch.load tries to restore tensors to the device they were saved from, which fails
# for CUDA-saved checkpoints on a CPU-only machine and can leave sd_1 / sd_2 on a
# different device than the one later passed to weight_matching.
sd_1 = torch.load(f'../../Linear_Mode_Connectivity/same_init_ex/{config.dataset}/{config.model}/diff_init/seed_10/model_1_160.pt', map_location=device)
sd_2 = torch.load(f'../../Linear_Mode_Connectivity/same_init_ex/{config.dataset}/{config.model}/diff_init/seed_10/model_2_160.pt', map_location=device)

model_1.load_state_dict(sd_1)
model_2.load_state_dict(sd_2)

<All keys matched successfully>

In [None]:
# The first return value of weight_matching is used as the state dict for a fresh
# copy of model 2 (presumably sd_2 permuted to align with sd_1 -- confirm in
# source.utils.weight_matching). The second return value is discarded.
sd_2_wm, _ = weight_matching(config.model, sd_1, sd_2, device=device)
model_2_wm = load_model(config).to(device)
model_2_wm.load_state_dict(sd_2_wm)

In [24]:
# Interpolate between model 1 and the matched model 2 with coefficient 0.5 and load
# the result into a fresh model.
sd_wm_mid = interpolate_state_dicts(sd_1, sd_2_wm, 0.5)
model_wm_mid = load_model(config).to(device)
model_wm_mid.load_state_dict(sd_wm_mid)
# Data-dependent baseline: reset_bn_stats is run with the training loader
# (presumably re-estimating BatchNorm running statistics -- see source.utils.connect)
# before evaluating on the test loader.
reset_bn_stats(model_wm_mid, device, trainloader)
test_loss_reset, test_acc_reset, _, _ = validate(testloader, model_wm_mid, criterion, device, config)

[32m[2024-01-29 16:33:53][0m (20320) [1;30m{utils.py:69}INFO[0m - Test: [ 1/40]	Time  0.034 ( 0.034)	Loss 4.7535e-01 (4.7535e-01)	Acc@1  88.28 ( 88.28)	Acc@5 100.00 (100.00)
[32m[2024-01-29 16:33:55][0m (20320) [1;30m{utils.py:74}INFO[0m -  *   Acc@1 87.900 Acc@5 99.400


In [25]:
def validate_train_mode(val_loader, model, criterion, device):
    """Evaluate ``model`` on ``val_loader`` while it is in train mode.

    Every module is first passed to ``bn_calibration_init`` and the model is
    then switched to train mode, so the forward passes run with train-mode
    layer behaviour (for standard PyTorch BatchNorm layers this means
    normalising with the statistics of the current batch). Gradients are
    disabled for the whole loop.

    Args:
        val_loader: iterable yielding ``(data, target)`` batches.
        model: module to evaluate; it is modified by ``bn_calibration_init``
            and left in train mode on return.
        criterion: callable ``criterion(output, target)`` returning a scalar
            tensor. It is assumed to return the mean loss over the batch
            (the default reduction of ``torch.nn.CrossEntropyLoss``).
        device: device each batch is moved to.

    Returns:
        ``(loss_avg, acc)``: the mean per-sample loss and the fraction (0-1)
        of samples whose argmax prediction equals the target.

    Raises:
        ZeroDivisionError: if ``val_loader`` yields no samples.
    """
    for module in model.modules():
        bn_calibration_init(module)
    model.train()

    total = 0
    correct = 0
    total_loss = 0.0
    with torch.no_grad():
        for data, target in val_loader:
            data = data.to(device)
            target = target.to(device)

            # compute output
            output = model(data)
            batch_size = len(data)
            # Weight each batch's mean loss by its size so that a smaller final
            # batch does not skew the average (dividing the sum of batch means
            # by the number of batches over-weights the samples in short batches).
            total_loss += criterion(output, target).item() * batch_size
            correct += torch.sum(output.argmax(dim=1) == target).item()
            total += batch_size
    loss_avg = total_loss / total
    acc = correct / total
    return loss_avg, acc

In [27]:
# Rebuild the interpolated state dict and load it back into model_wm_mid, which the
# earlier cell had passed through reset_bn_stats, so this evaluation starts from the
# plain interpolated weights again.
sd_wm_mid = interpolate_state_dicts(sd_1, sd_2_wm, 0.5)
model_wm_mid.load_state_dict(sd_wm_mid)
# Data-independent variant: evaluate in train mode without touching the training data.
# validate_train_mode returns accuracy as a fraction in [0, 1].
test_loss_reset_ind, test_acc_reset_ind = validate_train_mode(testloader, model_wm_mid, criterion, device)

In [33]:
# Single print call; each argument becomes one line of the report.
# test_acc_reset is printed as-is, while test_acc_reset_ind is a fraction and is scaled by 100.
print(
    'Test Accuracy:',
    f'Original RESET: {test_acc_reset:.2f}%',
    f'Data-independent RESET: {test_acc_reset_ind*100:.2f}%',
    sep='\n',
)

Test Accuracy:
Original RESET: 87.90%
Data-independent RESET: 86.37%
