In [1]:
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
)
from peft import PeftModel
from datasets import DatasetDict, load_dataset
from data import set_seed, k_split
from tqdm import trange
import torch

In [31]:
task = 'sst2'
data_name = 'glue' if task in ['mnli','qnli','sst2','qqp'] else 'bigbench'
seed = 42
num_clients = 10
num_error_clients = 3
number = 5

In [32]:
model_name_or_path = 'google/flan-t5-base'
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
set_seed(seed)
if data_name == 'bigbench':
    dataset = load_dataset("tasksource/bigbench", task).shuffle(seed=seed)
    dataset = dataset.rename_columns({'inputs':'source','targets':'target'})
else:
    dataset = load_dataset("JsSparkYyx/NLP524", task).shuffle(seed=seed)

In [33]:
train_ds = k_split(num_clients,num_error_clients,dataset['train'])
if data_name == 'glue':
    valid_ds = k_split(num_clients,num_error_clients,dataset['valid'])
else:
    valid_ds = k_split(num_clients,num_error_clients,dataset['validation'])
dataset = DatasetDict({'train':train_ds[number],'valid':valid_ds[number]})
def tokenize_function(examples):
    # max_length=None => use the model max length (it's actually the default)
    model_inputs = tokenizer(examples['source'], truncation=True, max_length=None,padding=True,return_tensors='pt')
    if data_name == 'glue':
        model_inputs['labels'] = tokenizer(examples['target'], truncation=True, max_length=None,padding=True,return_tensors='pt')["input_ids"]
    else:
        model_inputs['labels'] = tokenizer([_[0] for _ in examples['target']], truncation=True, max_length=None,padding=True,return_tensors='pt')["input_ids"]
    return model_inputs
ds = (train_ds, valid_ds)

In [34]:
def retrive_data(ds,number):
    (train_ds, valid_ds) = ds
    return DatasetDict({'train':train_ds[number],'valid':valid_ds[number]})

def accuracy_score(outputs, ground_truths):
    correct = 0
    total = 0
    for output, truth in zip(outputs, ground_truths):
        if data_name == "bigbench":
            truth = truth[0]
        if output.strip().lower().replace(".", "") == truth.strip().lower().replace(".", ""):
            correct += 1
        total += 1
    return correct / total * 100

In [35]:
data = retrive_data(ds,number)

In [36]:
model_1 = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, return_dict=True)
lora_model = PeftModel.from_pretrained(model_1,f'JsSparkYyx/flan-t5-base-finetuned-lora-{task}-{number}')
model_2 = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, return_dict=True)
error_model = PeftModel.from_pretrained(model_2,f'JsSparkYyx/flan-t5-base-finetuned-lora-{task}-0')

In [37]:
def evaluation(data, model, tokenizer, batch_size = 128):
    example_predictions = []
    eval_set = "valid"
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.eval()
    model.to(device)
    with torch.no_grad():
        for i in trange(0, len(data[eval_set]["source"]), batch_size):
            inputs = tokenizer(
                    data[eval_set]["source"][i : i + batch_size],
                    max_length=2048,
                    return_tensors="pt",
                    padding=True,
                ).to(device)
            outputs = model.generate(
                input_ids=inputs["input_ids"], max_new_tokens=256
            )
            outputs = tokenizer.batch_decode(
                outputs.to("cpu"), skip_special_tokens=True
            )
            example_predictions.extend(outputs)

    task_perf = accuracy_score(example_predictions, data[eval_set]["target"])
    return task_perf, example_predictions

In [38]:
task_perf, example_predictions = evaluation(data,lora_model,tokenizer, batch_size=8)
task_perf_error, example_predictions_error = evaluation(data,error_model,tokenizer, batch_size=8)
print(f"ACC of error model: {task_perf_error}, ACC of lora model: {task_perf}")

100%|██████████| 11/11 [00:02<00:00,  4.69it/s]
100%|██████████| 11/11 [00:01<00:00, 10.89it/s]

ACC of error model: 55.172413793103445, ACC of lora model: 94.25287356321839





In [39]:
from peft import get_peft_model_state_dict, set_peft_model_state_dict
lora_adaptors = []
for i in range(num_clients):
    base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, return_dict=True)
    lora_model = PeftModel.from_pretrained(base_model,f'JsSparkYyx/flan-t5-base-finetuned-lora-{task}-{i}')
    lora_adaptors.append(get_peft_model_state_dict(lora_model))

Downloading adapter_config.json:   0%|          | 0.00/497 [00:00<?, ?B/s]

Downloading (…)er_model.safetensors:   0%|          | 0.00/14.2M [00:00<?, ?B/s]

Downloading adapter_config.json:   0%|          | 0.00/497 [00:00<?, ?B/s]

Downloading (…)er_model.safetensors:   0%|          | 0.00/14.2M [00:00<?, ?B/s]

Downloading adapter_config.json:   0%|          | 0.00/497 [00:00<?, ?B/s]

Downloading (…)er_model.safetensors:   0%|          | 0.00/14.2M [00:00<?, ?B/s]

Downloading adapter_config.json:   0%|          | 0.00/497 [00:00<?, ?B/s]

Downloading (…)er_model.safetensors:   0%|          | 0.00/14.2M [00:00<?, ?B/s]

Downloading adapter_config.json:   0%|          | 0.00/497 [00:00<?, ?B/s]

Downloading (…)er_model.safetensors:   0%|          | 0.00/14.2M [00:00<?, ?B/s]

Downloading adapter_config.json:   0%|          | 0.00/497 [00:00<?, ?B/s]

Downloading (…)er_model.safetensors:   0%|          | 0.00/14.2M [00:00<?, ?B/s]

Downloading adapter_config.json:   0%|          | 0.00/497 [00:00<?, ?B/s]

Downloading (…)er_model.safetensors:   0%|          | 0.00/14.2M [00:00<?, ?B/s]

Downloading adapter_config.json:   0%|          | 0.00/497 [00:00<?, ?B/s]

Downloading (…)er_model.safetensors:   0%|          | 0.00/14.2M [00:00<?, ?B/s]

In [40]:
def average_aggreation(lora_adaptors):
    weight = 1/len(lora_adaptors)
    final_state_dict = {}
    keys = lora_adaptors[0].keys()
    for i, lora_adaptor in enumerate(lora_adaptors):
        if i == 0:
            for key in keys:
                final_state_dict[key] = weight * lora_adaptor[key]
        else:
            for key in keys:
                final_state_dict[key] = (
                    final_state_dict[key] + weight * lora_adaptor[key]
                )
    return final_state_dict

In [46]:
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, return_dict=True)
base_lora = PeftModel.from_pretrained(base_model,f'JsSparkYyx/flan-t5-base-finetuned-lora-{task}-0')
final_state_dict = average_aggreation(lora_adaptors[:3]+[lora_adaptors[number]])
set_peft_model_state_dict(lora_model,final_state_dict)
lora_model = lora_model.merge_and_unload()

AttributeError: 'T5ForConditionalGeneration' object has no attribute 'merge_and_unload'

In [42]:
base_model = AutoModelForSeq2SeqLM.from_pretrained(model_name_or_path, return_dict=True)
prev_lora_model = PeftModel.from_pretrained(base_model,f'JsSparkYyx/flan-t5-base-finetuned-lora-{task}-{number}')
task_perf, example_predictions = evaluation(data,lora_model,tokenizer, batch_size=8)
task_perf_error, example_predictions_error = evaluation(data,prev_lora_model,tokenizer, batch_size=8)
print(f"ACC of previous model: {task_perf_error}, ACC of final model: {task_perf}")

100%|██████████| 11/11 [00:01<00:00,  7.54it/s]
100%|██████████| 11/11 [00:02<00:00,  4.28it/s]

ACC of previous model: 91.95402298850574, ACC of final model: 90.80459770114942





In [45]:
lora_adaptors[:3]

[{'base_model.model.encoder.block.0.layer.0.SelfAttention.q.lora_A.weight': tensor([[ 0.0003, -0.0750,  0.0387,  ..., -0.0273, -0.0068, -0.0013],
          [ 0.0217, -0.0300,  0.0405,  ...,  0.0027,  0.0384,  0.0119],
          [ 0.0811,  0.0197,  0.0032,  ..., -0.0224, -0.0163, -0.0154],
          ...,
          [-0.0228,  0.0379,  0.0379,  ..., -0.0083, -0.0454,  0.0343],
          [-0.0082, -0.0822, -0.0059,  ...,  0.0045,  0.0212, -0.0318],
          [ 0.0229,  0.0021, -0.0526,  ..., -0.0116, -0.0067,  0.0390]]),
  'base_model.model.encoder.block.0.layer.0.SelfAttention.q.lora_B.weight': tensor([[-0.0586, -0.0483, -0.0100,  ...,  0.0307, -0.0094,  0.0366],
          [ 0.0653,  0.0268, -0.0435,  ..., -0.0414,  0.0170, -0.0378],
          [-0.0357,  0.0194,  0.0498,  ...,  0.0472, -0.0396,  0.0111],
          ...,
          [ 0.0417,  0.0124,  0.0321,  ...,  0.0201, -0.0312, -0.0183],
          [-0.0270, -0.0388, -0.0118,  ...,  0.0017, -0.0074,  0.0080],
          [ 0.0253,  0.0134,