In [1]:
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
from peft import PeftModel
from datasets import DatasetDict, load_dataset
from utils import set_seed, k_split
from tqdm import trange
import torch

In [2]:
# --- Experiment configuration -------------------------------------------
# GLUE tasks are loaded from JsSparkYyx/NLP524; any other task name is looked
# up in tasksource/bigbench (see the data-loading cell).
task = 'sst2'
data_name = 'bigbench' if task not in ('mnli', 'qnli', 'sst2', 'qqp') else 'glue'
seed = 42
num_clients = 10          # number of shards k_split produces (one per simulated client)
num_error_clients = 3     # passed to k_split; presumably the count of noisy clients — TODO confirm
number = 5                # index of the client whose data split and LoRA adapter are evaluated

In [3]:
# Base checkpoint, its tokenizer, and the shuffled raw dataset for `task`.
model_name_or_path = 'google/flan-t5-base'
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
set_seed(seed)

if data_name != 'bigbench':
    # GLUE-style data is expected to already expose 'source'/'target' columns.
    dataset = load_dataset("JsSparkYyx/NLP524", task).shuffle(seed=seed)
else:
    # bigbench names its columns 'inputs'/'targets'; align them with the GLUE data.
    dataset = (
        load_dataset("tasksource/bigbench", task)
        .shuffle(seed=seed)
        .rename_columns({'inputs': 'source', 'targets': 'target'})
    )

In [4]:
# Split the raw train/validation data into `num_clients` per-client shards via the
# project helper k_split (presumably the first `num_error_clients` shards are the
# noisy ones — TODO confirm against utils.k_split).
# The GLUE mirror names its validation split 'valid'; bigbench uses 'validation'.
valid_split = 'valid' if data_name == 'glue' else 'validation'
train_ds = k_split(num_clients, num_error_clients, dataset['train'])
valid_ds = k_split(num_clients, num_error_clients, dataset[valid_split])
# Bind the selected client's view under its own name instead of overwriting the
# raw `dataset`, so re-running this cell does not re-split an already-split shard
# (glue) or fail on a missing 'validation' key (bigbench).
client_dataset = DatasetDict({'train': train_ds[number], 'valid': valid_ds[number]})
def tokenize_function(examples):
    """Tokenize a batch of examples into encoder inputs plus a ``labels`` entry.

    Both sources and targets are tokenized with ``truncation=True``,
    ``max_length=None`` (i.e. the tokenizer's model max length), padding to the
    longest item in the batch, and PyTorch tensors as output.

    For bigbench, each ``examples['target']`` item is indexed with ``[0]`` (so it is
    a sequence of candidate targets) and only that first entry is used as the label.

    NOTE(review): labels keep the tokenizer's pad id in padded positions rather
    than -100; if this is fed to a training loss, padding may be scored — TODO confirm.
    """
    # Identical tokenizer settings for inputs and labels.
    tok_kwargs = dict(truncation=True, max_length=None, padding=True, return_tensors='pt')
    encoded = tokenizer(examples['source'], **tok_kwargs)
    if data_name == 'glue':
        label_texts = examples['target']
    else:
        label_texts = [candidates[0] for candidates in examples['target']]
    encoded['labels'] = tokenizer(label_texts, **tok_kwargs)["input_ids"]
    return encoded
# Keep every client's shards together so retrive_data can select any client later.
ds = train_ds, valid_ds

In [5]:
def retrive_data(ds, number):
    """Assemble a single client's train/valid splits into a DatasetDict.

    Args:
        ds: a 2-tuple ``(train_shards, valid_shards)``; each element is indexed by
            client number.
        number: index of the client to select.

    Returns:
        DatasetDict with keys ``'train'`` and ``'valid'``.
    """
    train_shards, valid_shards = ds
    client_splits = {'train': train_shards[number], 'valid': valid_shards[number]}
    return DatasetDict(client_splits)

def accuracy_score(outputs, ground_truths, dataset_name=None):
    """Exact-match accuracy, in percent, between generated outputs and references.

    Both sides are normalised before comparison: surrounding whitespace is
    stripped, text is lower-cased, and every '.' character is removed.

    Args:
        outputs: iterable of predicted strings.
        ground_truths: iterable of references aligned with ``outputs``. For
            bigbench each reference is indexed with ``[0]`` and only that first
            entry is compared.
        dataset_name: 'bigbench' or another name; when omitted, the notebook-level
            ``data_name`` is used (the original behavior).

    Returns:
        Percentage of matching pairs; 0.0 when there are no examples (instead of
        raising ZeroDivisionError).

    Raises:
        ValueError: if ``outputs`` and ``ground_truths`` differ in length (the
            original silently truncated via ``zip``).
    """
    if dataset_name is None:
        dataset_name = data_name
    outputs = list(outputs)
    ground_truths = list(ground_truths)
    if len(outputs) != len(ground_truths):
        raise ValueError(
            f"got {len(outputs)} outputs but {len(ground_truths)} ground truths"
        )
    total = len(outputs)
    if total == 0:
        return 0.0

    def _normalize(text):
        # Case-, whitespace- and period-insensitive comparison key.
        return text.strip().lower().replace(".", "")

    correct = 0
    for output, truth in zip(outputs, ground_truths):
        if dataset_name == "bigbench":
            truth = truth[0]
        if _normalize(output) == _normalize(truth):
            correct += 1
    return correct / total * 100

In [6]:
# Train/valid splits of client `number` only.
data = retrive_data(ds=ds, number=number)

In [7]:
def evaluation(data, model, tokenizer, batch_size=128, eval_set="valid",
               max_length=2048, max_new_tokens=256):
    """Generate predictions for one split and score them with ``accuracy_score``.

    The model is switched to eval mode and moved to CUDA when available (CPU
    otherwise). Sources are tokenized in batches of ``batch_size``, padded to the
    longest item in the batch and truncated to ``max_length`` tokens. Decoded
    generations are then compared to the split's 'target' column.

    Args:
        data: mapping of split name -> dataset with 'source' and 'target' columns.
        model: seq2seq model (or PEFT wrapper) supporting ``generate``.
        tokenizer: tokenizer matching ``model``.
        batch_size: number of sources per ``generate`` call.
        eval_set: split of ``data`` to evaluate (default 'valid', as before).
        max_length: input truncation length in tokens.
        max_new_tokens: generation budget per example.

    Returns:
        ``(accuracy_percent, predictions)``, where ``predictions`` is the list of
        decoded strings in dataset order.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(device)
    # Read the columns once instead of re-indexing the dataset on every batch
    # (column access on a datasets split may rebuild the full list each time).
    sources = data[eval_set]["source"]
    targets = data[eval_set]["target"]
    example_predictions = []
    with torch.no_grad():
        for start in trange(0, len(sources), batch_size):
            inputs = tokenizer(
                sources[start : start + batch_size],
                max_length=max_length,
                truncation=True,  # explicit; max_length alone only truncates with a warning
                padding=True,
                return_tensors="pt",
            ).to(device)
            # Pass the attention mask so padded positions in shorter sequences are
            # ignored by the encoder; without it, predictions depend on batch padding.
            generated = model.generate(
                input_ids=inputs["input_ids"],
                attention_mask=inputs["attention_mask"],
                max_new_tokens=max_new_tokens,
            )
            example_predictions.extend(
                tokenizer.batch_decode(generated.to("cpu"), skip_special_tokens=True)
            )

    task_perf = accuracy_score(example_predictions, targets)
    return task_perf, example_predictions

In [8]:
# Disabled sanity check: evaluate client `number`'s own LoRA adapter and client 0's
# adapter (named `error_model` below, presumably one of the noisy clients produced
# by k_split — TODO confirm) on the same validation data and print both accuracies.
# Uncomment to re-run; note it loads two full base models.
# model_1 = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, return_dict=True)
# lora_model = PeftModel.from_pretrained(model_1,f'JsSparkYyx/flan-t5-base-finetuned-lora-{task}-{number}')
# model_2 = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, return_dict=True)
# error_model = PeftModel.from_pretrained(model_2,f'JsSparkYyx/flan-t5-base-finetuned-lora-{task}-0')
# task_perf, example_predictions = evaluation(data,lora_model,tokenizer, batch_size=8)
# task_perf_error, example_predictions_error = evaluation(data,error_model,tokenizer, batch_size=8)
# print(f"ACC of error model: {task_perf_error}, ACC of lora model: {task_perf}")

In [9]:
from peft import get_peft_model_state_dict

# Gather the LoRA adapter state dict of every client's published adapter
# ('...-{task}-{client_idx}'). A fresh base model is loaded per client, presumably
# so adapters are never injected on top of one another in a shared model.
lora_adaptors = []
for client_idx in range(num_clients):
    adapter_id = f'JsSparkYyx/flan-t5-base-finetuned-lora-{task}-{client_idx}'
    base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, return_dict=True)
    lora_model = PeftModel.from_pretrained(base_model, adapter_id)
    lora_adaptors.append(get_peft_model_state_dict(lora_model))

In [18]:
from algorithm import lorahub_aggregation

# Start from client `number`'s own adapter and run the project's LoRAHub
# aggregation over all clients' adapters. It returns the per-adapter weights
# and the resulting model.
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, return_dict=True)
base_lora = PeftModel.from_pretrained(base_model, f'JsSparkYyx/flan-t5-base-finetuned-lora-{task}-{number}')
# NOTE(review): the search is given data["valid"], which is also the split used to
# report accuracy later, so the LoRAHub score may be optimistic — consider a
# held-out split.
# Use the config-cell `seed` (was a hard-coded 42) so changing it re-seeds this search too.
weights, lorahub_model = lorahub_aggregation(
    base_lora, lora_adaptors, data["valid"], tokenizer,
    batch_size=5, sample_size=5, seed=seed,
)


Running tokenizer on dataset:   0%|          | 0/5 [00:00<?, ? examples/s]

> Begin to perform gradient-free optimization ...
Launching 1 jobs with new suggestions
Updating fitness with value 0.10594421625137329
39 remaining budget and 0 running jobs
Launching 1 jobs with new suggestions
Updating fitness with value 0.0731716775894165
38 remaining budget and 0 running jobs
Launching 1 jobs with new suggestions
Updating fitness with value 0.07984843134880067
37 remaining budget and 0 running jobs
Launching 1 jobs with new suggestions
Updating fitness with value 0.08163272023200989
36 remaining budget and 0 running jobs
Launching 1 jobs with new suggestions
Updating fitness with value 0.0603573489189148
35 remaining budget and 0 running jobs
Launching 1 jobs with new suggestions
Updating fitness with value 0.0538857501745224
34 remaining budget and 0 running jobs
Launching 1 jobs with new suggestions
Updating fitness with value 0.03443073987960815
33 remaining budget and 0 running jobs
Launching 1 jobs with new suggestions
Updating fitness with value 2.0814327239

In [24]:
from algorithm import average_aggregation


def load_client_lora(client_idx):
    """Load a fresh base model wrapped with client ``client_idx``'s LoRA adapter.

    A new base model is loaded on every call. The original reloaded before each
    aggregation, presumably because average_aggregation modifies the model it
    is given; this keeps that isolation.
    """
    base = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, return_dict=True)
    return PeftModel.from_pretrained(base, f'JsSparkYyx/flan-t5-base-finetuned-lora-{task}-{client_idx}')


# Baselines: average of all adapters, average of the clean adapters only, and
# client `number`'s own adapter.
avg_model = average_aggregation(load_client_lora(number), lora_adaptors)
# Assumes k_split puts the noisy clients first (the original hard-coded
# lora_adaptors[3:] with num_error_clients == 3) — TODO confirm against utils.k_split.
no_noise_model = average_aggregation(load_client_lora(number), lora_adaptors[num_error_clients:])
client_model = load_client_lora(number)

task_perf_avg, preds_avg = evaluation(data, avg_model, tokenizer, batch_size=8)
task_perf_client, preds_client = evaluation(data, client_model, tokenizer, batch_size=8)
task_perf_lorahub, preds_lorahub = evaluation(data, lorahub_model, tokenizer, batch_size=8)
task_perf_no_noise, preds_no_noise = evaluation(data, no_noise_model, tokenizer, batch_size=8)
print(f"ACC of client's model: {task_perf_client}, ACC of average aggregated model: {task_perf_avg}, ACC of lorahub model: {task_perf_lorahub}, ACC of no noise model: {task_perf_no_noise}")

100%|██████████| 11/11 [00:00<00:00, 26.80it/s]
100%|██████████| 11/11 [00:00<00:00, 15.01it/s]
100%|██████████| 11/11 [00:00<00:00, 28.21it/s]
100%|██████████| 11/11 [00:00<00:00, 28.57it/s]

ACC of client's model: 94.25287356321839, ACC of average aggregated model: 90.80459770114942, ACC of lorahub model: 94.25287356321839, ACC of no noise model: 91.95402298850574





In [20]:
# Weights returned by lorahub_aggregation; the printed array has one entry per
# collected adapter in lora_adaptors (num_clients of them).
print(weights)

[ 4.95099199e-01 -2.94326107e-04 -3.62552962e-04  5.00564884e-01
  5.00285282e-01  5.00857620e-01  2.40444463e-01 -8.58513193e-02
 -8.53954377e-02 -8.10476350e-02]
