# Study the simple setting

Fine-tune two models and merge them; the same workflow can also be run in the federated learning (FL) setting.

In [1]:
from merging import MergingFactory
from tqdm import tqdm
from copy import deepcopy
import data_utils
import torch
from train_utils import get_metric
from copy import deepcopy
from torch import nn

  from .autonotebook import tqdm as notebook_tqdm
2025-04-04 12:21:24.495088: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-04-04 12:21:24.514730: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
class args:
    # Lightweight stand-in for a parsed command-line namespace; the values are
    # read as plain class attributes (`args.<name>`) throughout the notebook.

    # --- data / federation --------------------------------------------------
    dataset = "20newsgroups"  # dataset name handed to data_utils / models
    client_batch = 16         # batch size argument passed to build_dataset
    clients = 50              # number of clients requested from build_dataset
    # Original note: "the larger, the more non-iid".
    # NOTE(review): direction not verifiable from this notebook — confirm
    # against data_utils.build_dataset.
    iid_alpha = 0.1
    seed = 0                  # shared seed for every RNG and the data split
    eval_frac = 1.0           # evaluation fraction argument of build_dataset

    # --- LoRA adapters ------------------------------------------------------
    lora_rank = 16
    lora_alpha = 16
    freeze_a = "false"        # string flag; "true" freezes the lora_A params

    # --- merging ------------------------------------------------------------
    merging_strategy = "average"
    
# Seed every random number generator used by the notebook with the same value
# (args.seed) so that repeated runs are comparable while debugging.
import random
import numpy as np
import torch
from transformers import set_seed
# NOTE(review): per the transformers documentation, set_seed is expected to seed
# random / numpy / torch as well; the explicit calls below are kept so the
# seeding does not depend on that helper's behaviour.
set_seed(args.seed)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed_all(args.seed)
# Request deterministic cuDNN kernels and disable the cuDNN autotuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

## prepare models and dataloaders

In [3]:
# Build the per-client loaders together with the validation / test loaders.
# `test_batch` is a callable: elsewhere in this notebook it is invoked as
# `test_batch(model, x, y)` and unpacked as `(loss, stats)`.
clients, valloader, testloader, test_batch = data_utils.build_dataset(args.dataset,
                                                                      args.client_batch,
                                                                      args.clients,
                                                                      args.iid_alpha, args.seed, args.eval_frac)
# Keep only the first two client loaders; they are the two fine-tuning tasks.
# NOTE(review): this rebinds `clients`, so the full client list is no longer
# reachable after this cell has run.
clients = [clients[0], clients[1]]  # we take out two loaders to use

import models
# Base model for the chosen dataset.
model = models.build_model(args.dataset)
# Trainable-parameter count BEFORE adapters are attached; reported below as the
# "original" size in the parameter summary.
total = sum(p.numel() for p in model.parameters() if p.requires_grad)

# Attach adapters configured by lora_rank / lora_alpha. The return value is
# ignored, so `model` is presumably modified in place — TODO confirm in models.
models.add_adapters_dataset(args.dataset, model, args.lora_rank, args.lora_alpha)

def str2bool(s):
    """Interpret a configuration flag as a boolean.

    Args:
        s: Either a real ``bool`` (returned unchanged) or a string. A string is
            considered true only when it equals ``"true"`` ignoring case;
            every other string is false.

    Returns:
        bool: The parsed flag.
    """
    # Config values may be set as real booleans instead of argparse-style
    # strings; pass those through rather than failing on `.lower()`.
    if isinstance(s, bool):
        return s
    return s.lower() == 'true'

# Optionally freeze every parameter whose name contains "lora_A", so that only
# the remaining trainable parameters are updated.
if str2bool(args.freeze_a):
    for n,p in model.named_parameters():
        if "lora_A" in n:
            p.requires_grad = False

# Trainable-parameter count after adapters (and the optional freeze), reported
# relative to `total`, which was measured before the adapters were added.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Training {trainable} parameters ({100*trainable/total:.2f}% of original {total})")
# Use the first GPU when available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

Found 20 classes


Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Training 2374656 parameters (1.91% of original 124455168)




## training

In [4]:
# When True, the server's trainable parameters get requires_grad=False in the
# server-setup cell.
server_freeze = False
# Learning rate of the server-side SGD optimizer.
server_lr = 1 # for the original fedavg, this is one; for others like fedadam, we can set it to 1e-3

In [5]:
# The server model is the adapter-equipped model itself (same object, no copy).
server_model = model
# Untouched snapshot of the pre-fine-tuning weights; the merging section starts
# from a deep copy of this.
orig_server_model = deepcopy(server_model)  # backup if need

# Name -> parameter mapping restricted to the trainable parameters. Client
# deltas are later computed against these tensors.
server_params = {n:p for n,p in server_model.named_parameters() if p.requires_grad}
# All-ones mask with the same shapes as the trainable parameters.
server_mask = {n:torch.ones_like(p) for n,p in server_params.items()}
# `server_freeze` comes from the configuration cell above. It is deliberately
# NOT re-assigned here: resetting it to False in this cell would silently
# override the configured value and make the switch dead.
if server_freeze:
    for p in server_params.values():
        p.requires_grad = False

# Server-side optimizer over the trainable parameters.
server_opt = torch.optim.SGD(server_params.values(), lr=server_lr)

In [6]:
# Instantiate the merging strategy named in the config for the server model.
merging_strategy =  args.merging_strategy
merger = MergingFactory.get_merging_strategy(merging_strategy, server_model, args=args)
scaling_coefficient = 1.0  # no need to change

# Value shown in the training progress-bar description ("eval: ...").
# NOTE(review): nothing in this notebook updates it, so the bar always shows 0.
eval_accu = 0
def eval_loop(model, loader):
    """Run `model` over every batch of `loader` and return accumulated stats.

    The model is switched to eval mode (and left in it). Each batch is scored
    with the notebook-level `test_batch(model, x, y)` under `torch.no_grad()`;
    every entry of the per-batch stats dict is summed across batches, and the
    summed 'loss' is finally divided by the summed 'count'.

    Args:
        model: Model to evaluate; must already be on the device `test_batch`
            expects.
        loader: Iterable yielding `(x, y)` batches.

    Returns:
        dict: Summed statistics, with 'loss' normalised by 'count'.

    Raises:
        KeyError: If no batch contributed a 'loss' / 'count' entry (for example
            when `loader` is empty).
    """
    model.eval()
    totals = {}
    with torch.no_grad():
        for inputs, targets in loader:
            _, batch_stats = test_batch(model, inputs, targets)
            for key, value in batch_stats.items():
                totals[key] = totals.get(key, 0) + value
    totals['loss'] /= totals['count']
    return totals

In [7]:
# Number of clients taking part in a round (recorded as clients_this_round in
# the training loop).
server_batch = 2  # how many clients we use
# Learning rate of each client's local optimizer.
client_lr = 1e-3

# `clients` was already reduced to two loaders in the data-preparation cell, so
# re-selecting indices 0 and 1 here leaves it unchanged.
clients = [clients[0], clients[1]]  # the client loaders
client_epochs = 6  # local training epochs; in standard FL setting, this is 1
# Gradient-norm clipping threshold; clipping is applied only when this is > 0.
l2_clip_norm = 0.

The following cell only fine-tunes the two client models; no merging happens here. A simple FL setup can be recovered by increasing the number of global rounds (`rounds`) and adding the merging code inside the training loop.

In [8]:
# Number of global rounds. With a single round this cell simply fine-tunes each
# client model locally; no merging is performed here.
rounds = 1

pbar = tqdm(range(rounds))

# One independent copy of the server model per client. The copies are created
# once, outside the round loop, so with rounds > 1 each client would continue
# from its own previous weights rather than from a freshly merged server model.
client_models = [deepcopy(server_model) for _ in range(2)]
client_ids = [0, 1]
# client_loaders=[clients[0]]
client_loaders = clients


for round_idx, rnd in enumerate(pbar):
    # Per-round accumulators. `neg_client_deltas` and the three tensors below
    # are consumed later by the merging cells.
    neg_client_deltas = []
    stats_acc = {}
    sample_num_list = torch.Tensor([len(client_loader) for client_loader in client_loaders])  # record the number of batches for each client
    # Example count per client with one batch left out — presumably to skip a
    # possibly-partial last batch; TODO confirm the intent.
    nums_fisher_examples = torch.Tensor([(len(client_loader)-1)*client_loader.batch_size for client_loader in client_loaders])
    nums_regmean_examples = nums_fisher_examples.clone()
    clients_this_round = server_batch  # record the number of clients for this round

    for i, client_id in enumerate(client_ids):
        # Download Model
        client_model = client_models[i]
        client_model.to(device)

        # Local Training
        # client_opt = torch.optim.SGD(client_model.parameters(), lr=client_lr, momentum=0.9)
        client_opt = torch.optim.Adam(client_model.parameters(), lr=client_lr)
        client_loader = clients[client_id]
        client_acc = {}

        for epoch in range(client_epochs):
            for x,y in client_loader:
                # test_batch returns the batch loss and a dict of statistics.
                loss, stats = test_batch(client_model, x, y)

                client_opt.zero_grad()
                loss.backward()

                # Gradient clipping is disabled when l2_clip_norm is 0.
                if l2_clip_norm > 0:
                    torch.nn.utils.clip_grad_norm_(client_model.parameters(), l2_clip_norm)

                client_opt.step()

                # Sum the per-batch statistics over all batches and epochs.
                for k,v in stats.items():
                    client_acc[k] = client_acc.get(k, 0) + v
                # NOTE(review): eval_accu is never updated in this cell, so the
                # "eval:" field of the description is stale.
                pbar.set_description(f"eval: {eval_accu} | client {i}, epoch {epoch} | loss {loss:.4f}")

            # `epoch % 1 == 0` is always true, so this runs after every epoch.
            if epoch % 1 == 0:
                # Evaluate on a copy so that eval_loop's model.eval() call does
                # not change the mode of the model being trained.
                eval_model = deepcopy(client_model)
                eval_model.to(device)

                # Accuracy is measured on the client's own training loader.
                eval_results = eval_loop(eval_model, client_loader)
                print("Accuracy is {}".format(get_metric(eval_results, "accu")))

        # This is our delta parameter
        # Negative delta = server weights minus fine-tuned client weights,
        # restricted to the client's trainable parameters.
        client_model.to("cpu")  # move to cpu to save memory
        neg_client_delta = {
            n: server_params[n].data - cp.data for n,cp 
                            in client_model.named_parameters() if cp.requires_grad
        }
        neg_client_deltas.append(neg_client_delta)

        # Log last iteration
        # Fold this client's accumulated statistics into the round totals;
        # 'norm' is recorded as a constant 0.
        client_acc['norm'] = 0
        for k,v in client_acc.items():
            stats_acc[k] = stats_acc.get(k, 0) + v
    

eval: 0 | client 0, epoch 0 | loss 6.5722:   0%|          | 0/1 [00:00<?, ?it/s] We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.
eval: 0 | client 0, epoch 1 | loss 1.0522:   0%|          | 0/1 [00:02<?, ?it/s]

Accuracy is 0.5316455696202531


eval: 0 | client 0, epoch 2 | loss 0.7329:   0%|          | 0/1 [00:04<?, ?it/s]

Accuracy is 0.6455696202531646


eval: 0 | client 0, epoch 3 | loss 0.6980:   0%|          | 0/1 [00:06<?, ?it/s]

Accuracy is 0.890295358649789


eval: 0 | client 0, epoch 4 | loss 0.0898:   0%|          | 0/1 [00:08<?, ?it/s]

Accuracy is 0.9620253164556962


eval: 0 | client 0, epoch 5 | loss 0.0607:   0%|          | 0/1 [00:10<?, ?it/s]

Accuracy is 0.9915611814345991


eval: 0 | client 0, epoch 5 | loss 0.0023:   0%|          | 0/1 [00:12<?, ?it/s]

Accuracy is 1.0


eval: 0 | client 1, epoch 1 | loss 0.7112:   0%|          | 0/1 [00:14<?, ?it/s]

Accuracy is 0.48255813953488375


eval: 0 | client 1, epoch 2 | loss 0.7028:   0%|          | 0/1 [00:16<?, ?it/s]

Accuracy is 0.8197674418604651


eval: 0 | client 1, epoch 3 | loss 0.1618:   0%|          | 0/1 [00:17<?, ?it/s]

Accuracy is 0.936046511627907


eval: 0 | client 1, epoch 4 | loss 0.0761:   0%|          | 0/1 [00:19<?, ?it/s]

Accuracy is 0.9883720930232558


eval: 0 | client 1, epoch 5 | loss 0.0131:   0%|          | 0/1 [00:20<?, ?it/s]

Accuracy is 0.9941860465116279


eval: 0 | client 1, epoch 5 | loss 0.0016: 100%|██████████| 1/1 [00:21<00:00, 21.98s/it]

Accuracy is 1.0





Check the performance of the fine-tuned models.

In [11]:
# Evaluate the fine-tuned model of client 1 on client 1's own loader, i.e. the
# same data it was trained on. A copy is used so the stored client model is
# left untouched by eval_loop.
eval_model = deepcopy(client_models[1])
eval_model.to(device)

eval_results = eval_loop(eval_model, clients[1])

# Bare last expression: the accuracy is shown as the cell output.
get_metric(eval_results, "accu")

1.0

In [10]:
# Evaluate the fine-tuned model of client 0 on client 0's own loader, i.e. the
# same data it was trained on. A copy is used so the stored client model is
# left untouched by eval_loop.
eval_model = deepcopy(client_models[0])
eval_model.to(device)

eval_results = eval_loop(eval_model, clients[0])

# Bare last expression: the accuracy is shown as the cell output.
get_metric(eval_results, "accu")

1.0

## merging

All cells in this section need to be re-run every time we try a new merging method, because `cur_server_model` is edited in place.

In [30]:
# Start from a fresh copy of the backed-up (not fine-tuned) server model so
# that each merging experiment begins from the same weights.
cur_server_model = deepcopy(orig_server_model)
# Trainable parameters of the copy, and a server-side SGD optimizer over them.
cur_server_params = {n:p for n,p in cur_server_model.named_parameters() if p.requires_grad}
cur_server_opt = torch.optim.SGD(cur_server_params.values(), lr=server_lr)

In [None]:
# prepare for merging
# Strategy names listed by the author: average, task_arithmetic,
# fisher_merging, ties_merging, regmean_merging.
# NOTE(review): the strategy is hard-coded here and does not follow
# args.merging_strategy.
merger = MergingFactory.get_merging_strategy("average", cur_server_model, args=args)  # average, task_arithmetic, fisher_merging, ties_merging, regmean_merging
scaling_coefficient = 1.  # controls how far we want to go in the direction of the aggregated task vector, only used in task arithmetic
param_value_mask_rate = 0.5  # controls how much parameters we want to prune/drop in ties-merging (and dare)

In [32]:
# merging
aggregated_update = merger.aggregate_updates(neg_client_deltas,
                                                sample_num_list=sample_num_list,
                                                scaling_coefficient=scaling_coefficient,
                                                client_loaders=client_loaders,
                                                test_batch=test_batch,
                                                nums_fisher_examples=nums_fisher_examples,
                                                nums_regmean_examples=nums_regmean_examples,
                                                device=device,
                                                normalize_fisher_weight=True,
                                                minimal_fisher_weight = 1e-6,
                                                param_value_mask_rate= param_value_mask_rate)
merger.update_server_model(aggregated_update, cur_server_opt)

In [33]:
# Evaluate the merged server model on each client's loader in turn and print
# one accuracy per client (client 0 first, then client 1). A copy is evaluated
# so that cur_server_model itself is left untouched.
eval_model = deepcopy(cur_server_model)
eval_model.to(device)

for merged_eval_idx in (0, 1):
    eval_results = eval_loop(eval_model, clients[merged_eval_idx])
    print(get_metric(eval_results, "accu"))

0.379746835443038
0.005813953488372093
