## Step 1. 加载模型与Tokenizer

In [2]:

import os
import sys

# Make the directory two levels above the working directory importable
# (presumably the repository root that contains the `sfl` package — TODO confirm).
# Guarded so that re-running this cell does not keep appending duplicate entries.
_repo_root = os.path.abspath('../..')
if _repo_root not in sys.path:
    sys.path.append(_repo_root)

In [4]:


from sfl.utils.training import get_best_gpu
from transformers import AutoTokenizer
from sfl.model.gpt2.gpt2_split import GPT2SplitLMHeadModel
from sfl import config

# Pick a device through the project helper, then load the tokenizer and the split GPT-2 model
# from the same local checkpoint directory (path built once so both loads cannot drift apart).
device = get_best_gpu()
gpt2_dir = os.path.join(config.model_download_dir, "gpt2/")
tokenizer = AutoTokenizer.from_pretrained(gpt2_dir, padding_side='left')
model = GPT2SplitLMHeadModel.from_pretrained(gpt2_dir)
# Reuse the EOS token for padding. The id is taken from the tokenizer itself rather than
# hard-coding 50256, so the pad token and pad id always agree with the loaded vocabulary.
tokenizer.pad_token = tokenizer.eos_token
tokenizer.pad_token_id = tokenizer.eos_token_id
model.to(device)

In [2]:
from sfl.utils.model import generate


# Quick check of the model's raw per-position predictions
def get_output(text, md=model):
    """Decode the greedy (argmax) per-position prediction of ``md`` for ``text``.

    The text is tokenized with the notebook-level ``tokenizer`` (no special tokens
    added), pushed through a single forward pass of ``md``, and the argmax token at
    every position of the last sequence in the batch is decoded back to a string.

    Args:
        text: prompt passed to the tokenizer.
        md: model to query; must expose ``device`` and return an object with
            ``logits``. Defaults to the notebook-level ``model``.

    Returns:
        The decoded string of argmax tokens, with special tokens skipped.
    """
    import torch  # local import: the notebook only imports torch in a later cell

    t = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    # Run the model that was actually passed in. The original called the global `model`
    # here while moving the tensors to `md.device`, so a non-default `md` was ignored.
    with torch.no_grad():  # inference only, no autograd graph needed
        res = md(t['input_ids'].to(md.device), attention_mask=t['attention_mask'].to(md.device))
    r = tokenizer.decode(res.logits.argmax(dim=-1)[-1], skip_special_tokens=True)
    return r

# Smoke test on a single prompt: first the greedy per-position decode from `get_output`,
# then the result of the project's `generate` helper for the same prompt.
text = "To mix food coloring with sugar, you can"
print(get_output(text))
print(generate(text, tokenizer, model))

 the the with with water, you can add
To mix food coloring with sugar, you can also use it as a sweetener.

If you want to add more sugar to the mix, add a little more water and mix well. If you add too much water, the mixture will be too thick, and you will end up with a mess. You can use a spoon to scoop out the excess water from the mixing bowl, but it's best to leave it at room temperature for at least 30 minutes before adding the rest of the water.


In [5]:
from sfl.simulator.dataset import CodeAlpacaFedDataset, PIQAFedDataset, WikiTextFedDataset

# WikiText dataset wrapper built with an empty client list; in this notebook it is only
# used to obtain the evaluation loader below.
test_dataset = WikiTextFedDataset(tokenizer=tokenizer,client_ids=[])
# Loader over the 'test' split. The first positional argument (1) is presumably the batch size
# and shrink_frac=1.0 presumably keeps the whole split — TODO confirm against the dataset class.
test_loader = test_dataset.get_dataloader_unsliced(1, 'test', shrink_frac=1.0)

In [15]:
# NOTE(review): `dataset` is not referenced by any other cell visible in this notebook;
# it looks like a leftover from an earlier experiment and is a candidate for removal.
dataset = CodeAlpacaFedDataset(tokenizer=tokenizer,client_ids=[])

In [4]:
from sfl.utils.training import evaluate_perplexity
# evaluate(model, tokenizer)
from tqdm.notebook import tqdm_notebook
from sfl.utils.model import calculate_rouge
import torch

def evaluate_piqa(loader, model, tokenizer):
    """Average ROUGE F-scores of generated text over ``loader``.

    For every batch the ``'q_text'`` entries are tokenized (padded and truncated),
    ``model.generate`` is run with ``max_length=25``, and the generated ids are scored
    against ``batch['input_text']`` with ``calculate_rouge``.

    Args:
        loader: iterable of batches providing ``'q_text'`` and ``'input_text'``.
        model: model exposing ``training``, ``train``, ``device`` and ``generate``.
        tokenizer: callable tokenizer that also exposes ``eos_token_id``.

    Returns:
        Tuple ``(rouge_1_f, rouge_2_f, rouge_l_f)`` averaged over the batches.

    Raises:
        ValueError: if ``loader`` yields no batches (previously a bare ZeroDivisionError).
    """
    # Remember the caller's mode so evaluation does not silently flip it.
    was_training = model.training
    model.train(False)
    rouge_sums = {'rouge-1': 0, 'rouge-2': 0, 'rouge-l': 0}
    num_batches = 0  # was named `len`, which shadowed the builtin
    try:
        for batch in tqdm_notebook(loader):
            q = tokenizer(batch['q_text'], return_tensors='pt', padding=True, truncation=True)
            input_ids = q['input_ids'].to(model.device)
            attention_mask = q['attention_mask'].to(model.device)
            result = model.generate(input_ids=input_ids, attention_mask=attention_mask, max_length=25,
                                    pad_token_id=tokenizer.eos_token_id)
            rouge = calculate_rouge(tokenizer, result, batch['input_text'], is_tokens=True)
            for key in rouge_sums:
                rouge_sums[key] += rouge[key]['f']
            num_batches += 1
    finally:
        # Restore the previous mode even if generation or scoring raised.
        model.train(was_training)
    if num_batches == 0:
        raise ValueError("evaluate_piqa: loader yielded no batches")
    return (rouge_sums['rouge-1'] / num_batches,
            rouge_sums['rouge-2'] / num_batches,
            rouge_sums['rouge-l'] / num_batches)

def evaluate_loss(model, loader):
    """Mean loss of ``model`` over ``loader``, using the inputs as their own labels.

    Each batch's ``'input_ids'`` and ``'input_att_mask'`` are moved to ``model.device``
    and passed to the model with ``labels=input_ids``; the returned ``loss`` values are
    averaged over the batches. Gradients are disabled for the forward passes.

    Args:
        model: model exposing ``training``, ``train`` and ``device``, whose call
            returns an object with a ``loss`` attribute.
        loader: iterable of batches providing ``'input_ids'`` and ``'input_att_mask'``.

    Returns:
        The sum of the per-batch losses divided by the number of batches.

    Raises:
        ValueError: if ``loader`` yields no batches (previously a bare ZeroDivisionError).
    """
    # Remember the caller's mode; the original left the model in eval mode afterwards.
    was_training = model.training
    model.train(False)
    total_loss = 0  # accumulates the loss (the original called this `ppl`)
    num_batches = 0  # was named `len`, which shadowed the builtin
    try:
        for batch in loader:
            input_ids = batch['input_ids'].to(model.device)
            attention_mask = batch['input_att_mask'].to(model.device)
            with torch.no_grad():
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
                total_loss += outputs.loss
            num_batches += 1
    finally:
        # Restore the previous mode even if a forward pass raised.
        model.train(was_training)
    if num_batches == 0:
        raise ValueError("evaluate_loss: loader yielded no batches")
    return total_loss / num_batches


# Baseline: score the freshly loaded model on the test loader with the project's
# `evaluate_perplexity` helper, before any federated training is run in this notebook.
ppl= evaluate_perplexity(model, test_loader)
# print('initial_test_result: ', ppl)
# Alternative kept for reference: the mean-loss evaluation defined above.
# ppl = evaluate_loss(model,test_loader)
print(ppl)

tensor(53.2021, device='cuda:0')


## Step 2. 设置联邦训练流程

In [5]:

from typing import Iterator
from sfl.utils.model import calculate_rouge
from sfl.simulator.simulator import SFLSimulator
from torch.utils.data import DataLoader
from tqdm.notebook import tqdm_notebook
from sfl.model.split_model import SplitModel
from sfl.simulator.strategy import FLStrategy

from torch.optim import AdamW
from sfl.config import FLConfig

# Three simulated clients with the string ids "0", "1" and "2".
client_ids = [str(i) for i in range(3)]
# NOTE(review): this rebinds the name `config`, which an earlier cell imported as the
# `sfl.config` module; a later `config.model_download_dir` without re-importing would hit
# this FLConfig object instead. Consider a distinct name such as `fl_config`.
config = FLConfig(collect_intermediates=True,
                  global_round=4,
                  client_evaluate_freq=25,
                  client_epoch=1,  # local epochs per client per federated round (the original note said 2, but the value is 1)
                  split_point_1=2,
                  split_point_2=10,  # original note: [0,1 | 2,3,.... 29| 30, 31] — that layout does not match split points 2 and 10; TODO confirm the intended bottom/trunk/top split
                  use_lora_at_trunk=True,  # use LoRA on the trunk part
                  top_and_bottom_from_scratch='False',  # original note: top and bottom both skip the pretrained weights. NOTE(review): passed as the string 'False' — confirm FLConfig expects a string here, since a non-empty string is truthy
                  noise_mode="none",
                  noise_scale=0.0,  # noise magnitude
                  batch_size=2,
                  dataset_type='train'
                  )

# Federated WikiText dataset for the three clients; shrink_frac=0.04 presumably keeps
# about 4% of the data — TODO confirm against the dataset class.
fed_dataset = WikiTextFedDataset(tokenizer=tokenizer, client_ids=client_ids, shrink_frac=0.04)


# mirror_loader = fed_dataset.get_dataloader_unsliced(2, 'validation', shrink_frac=0.01)


# Define the clients' local training strategy
class QAFLStrategy(FLStrategy):
    """Local training strategy plugged into the SFL simulator.

    Relies on the notebook-level names ``tokenizer`` and ``test_loader``.
    """

    def __init__(self):
        super().__init__()
        self.attacker_rouge_b2tr = []

    def client_evaluate(self, global_round, client_id, log):
        """Store the perplexity of the simulator's model on ``test_loader`` in ``log['test-ppl']``."""
        log['test-ppl'] = evaluate_perplexity(self.simulator.llm, test_loader)

    def client_step(self, client_id: str, global_round, local_epoch, llm: SplitModel, iterator: Iterator,
                    cfg: FLConfig):
        """Run one pass over ``iterator`` for a client, optimising ``llm`` with AdamW (lr=1e-5).

        For every batch: forward with the inputs as labels, accumulate ROUGE-L F and loss,
        back-propagate, step the optimizer, refresh the progress bar, then hand the running
        averages to ``self.step_done``. The "self-pt" average is reported as well, but
        nothing in the active code adds to it (the code that did is disabled below).
        """
        optimizer = AdamW(llm.parameters(), lr=1e-5)
        rouge_sum = 0
        rouge_pt_sum = 0
        loss_sum = 0
        num_steps = 0
        with tqdm_notebook(total=cfg.client_steps) as progress:
            for step, batch in enumerate(iterator):
                optimizer.zero_grad()
                input_ids = batch['input_ids'].to(llm.device)
                attention_mask = batch['input_att_mask'].to(llm.device)
                outputs = llm(input_ids=input_ids, labels=input_ids, attention_mask=attention_mask)

                # Running statistics are taken from this forward pass, before the update.
                rouge_scores = calculate_rouge(tokenizer, outputs.logits, batch['input_text'])
                rouge_sum += rouge_scores['rouge-l']['f']
                loss_sum += outputs.loss.detach().cpu().item()
                num_steps += 1

                outputs.loss.backward()
                optimizer.step()

                progress.set_description(
                    f'Client {client_id} Epoch {local_epoch} Step{self.simulator.get_current_step(client_id, step)}, Loss {outputs.loss.item():.3f}')
                progress.update(1)

                running_logs = {
                    "self": rouge_sum / num_steps,
                    "loss": float(loss_sum) / num_steps,
                    "self-pt": rouge_pt_sum / num_steps,
                }
                self.step_done(client_id, step, batch, logs=running_logs)  # Collect gradients

            # ---- Disabled experiment (kept for reference): "mirror" attack that tries to
            # ---- recover inputs from intermediate features / gradients.
            # real_fx = deepcopy(llm.get_bottom_to_trunk_fx())
            # real_dfx = deepcopy(llm.get_top_to_trunk_grad())

            #
            # # def ccl(lm_logits, labels):
            # #     seq_len = labels.size(-1)
            # #     half_len = seq_len // 2
            # #     odd = seq_len % 2
            # #     shift_logits = lm_logits.contiguous()
            # #     shift_labels = torch.concat([labels[..., 1:half_len],  # size = half_len - 1
            # #                                  labels[..., :half_len + 1 + odd]],  # size = half_len+1+odd
            # #                                 dim=-1).contiguous()
            # #     return CrossEntropyLoss()(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            #
            # # def mirror_attack(**kwargs):
            # #     # freeze trunk
            # #     for nm, param in llm.get_trunk_params():
            # #         param.requires_grad = False
            # #     opt_mrr = AdamW(llm.parameters(), lr=1e-5)
            # #     with tqdm_notebook(total=500 * len(mirror_loader)) as pbar:
            # #         for i in range(500):
            # #             for step, size in enumerate(mirror_loader):
            # #                 opt_mrr.zero_grad()
            # #                 ii = size['input_ids'].to(llm.device)
            # #                 am = size['input_att_mask'].to(llm.device)
            # #                 outputs = llm(input_ids=ii, attention_mask=am)
            # #                 loss = calc_shifted_loss(outputs.logits, ii)
            # #                 loss.backward()
            # #                 opt_mrr.step()
            # #                 o2 = llm(input_ids=input_ids,attention_mask=attention_mask)
            # #                 rouge = calculate_rouge(tokenizer,o2.logits,batch['input_text'])
            # #                 pbar.set_description(f'MIRROR {loss.item():.3f}')
            # #                 pbar.set_postfix(rouge=rouge['rouge-l']['f'])
            # #                 pbar.update(1)
            #
            # def mirror_attack(**kwargs):
            #     # freeze trunk
            #     for nm, param in llm.get_trunk_params():
            #         param.requires_grad = False
            #     # opt_mrr = AdamW(llm.parameters(), lr=1e-5)
            #     batch_size, seq_len = input_ids.shape[:2]
            #     vocab_size = llm.config.vocab_size
            #     gt = torch.softmax(torch.randn((batch_size, seq_len, vocab_size)).to(llm.device), dim=-1)
            #     gt.requires_grad = True
            #     tuning_params = [param for nm, param in llm.get_top_params() if param.requires_grad]
            #     opt_mrr = torch.optim.AdamW([gt], lr=0.05, betas=(0.9, 0.999), eps=1e-6, weight_decay=0.01)
            #     opt_tmp = AdamW(tuning_params)
            #     with tqdm_notebook(total=1000) as pbar2:
            #         for i in range(1000):
            #             opt_mrr.zero_grad()
            #             opt_tmp.zero_grad()
            #             outputs = llm(input_ids=input_ids, attention_mask=attention_mask)
            #             loss = calc_shifted_loss_logits(torch.softmax(outputs.logits,dim=-1), torch.softmax(gt,dim=-1))
            #             # loss.backward()
            #             fx = llm.get_bottom_to_trunk_fx(detach=False)
            #             fx2 = llm.get_trunk_to_top_fx(detach=False)
            #             fx_diff = 0
            #             for f1,f2 in zip(real_fx.to(fx.device),fx):
            #                 fx_diff += ((f1-f2)**2).sum()
            #             dfx = torch.autograd.grad(loss, fx2, create_graph=True)
            #             dfx_diff = 0
            #             for f1,f2 in zip(real_dfx,dfx):
            #                 f2 = f2.to(f1.device)
            #                 dfx_diff += ((f1-f2)**2).sum() + torch.abs((f1 - f2)).sum()
            #             tag_loss = dfx_diff
            #             tag_loss.backward()
            #             opt_mrr.step()
            #             rouge = calculate_rouge(tokenizer, gt, batch['input_text'])
            #             sent = tokenizer.decode(torch.argmax(gt[0], dim=-1))
            #             pbar2.set_description(f'MIRROR {dfx_diff.item():.3f}')
            #             pbar2.set_postfix(rouge=rouge['rouge-l']['f'],sent=sent)
            #             pbar2.update(1)
            #     for nm, param in llm.get_trunk_params():
            #         param.requires_grad = True
            #     x = generate("To mix food coloring with sugar, you can", tokenizer, llm)
            #     print(x)
            #
            # simulator.restored_run(mirror_attack, ['top', 'bottom'], 'mirror')
            # outputs_pt = self.simulator.restored_forward('top', input_ids=input_ids, labels=input_ids,
            #                                               attention_mask=attention_mask)
            # rouge_pt_sum += calculate_rouge(tokenizer, outputs_pt.logits, batch['input_text'])['rouge-l']['f']

    def callback_intermediate_result(self, global_round, client_id, local_epoch, local_step,
                                     b2tr_fx, tr2b_grad,
                                     tr2t_fx, t2tr_grad,
                                     batch, logs):
        """Hook for intermediate features/gradients; intentionally does nothing in this strategy."""
        pass


# Wire the clients, strategy, model, tokenizer, federated dataset and FL configuration
# into the simulator. The run itself is started in a later cell.
simulator = SFLSimulator(client_ids=client_ids, strategy=QAFLStrategy(), llm=model, tokenizer=tokenizer,
                         dataset=fed_dataset, config=config)
# Optional debugging / immediate-run calls, left disabled:
# model.print_split_model()
# simulator.simulate()



## Step 3. 开始联邦模拟

In [6]:
import wandb

# Start a Weights & Biases run for this experiment, then launch the federated simulation.
wandb.init(
    project="sfl-eval",
    # NOTE(review): the run name says "gpt2-large", but the checkpoint loaded earlier in this
    # notebook comes from the "gpt2/" directory — confirm which model this run describes.
    name="gpt2-large-wikitext-ppl",
    # track hyperparameters and run metadata
    config={
        # The federated dataset and the test loader in this notebook are both WikiText;
        # the previous value 'code' was a leftover and mislabelled the run.
        "dataset": 'wikitext',
        "noise": "0.0"
    }
)

# Optional: log the pre-training perplexity as the first data point of the run.
# ppl = evaluate_perplexity(simulator.llm, test_loader)
# report = {}
# report['test-ppl'] = ppl
# wandb.log(report)
simulator.simulate()

[34m[1mwandb[0m: Currently logged in as: [33mstupidtree[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011113008360068004, max=1.0…



  0%|          | 0/50 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:75.00 MB, downlink:75.00 MB


  0%|          | 0/50 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:75.00 MB, downlink:75.00 MB


  0%|          | 0/50 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:75.00 MB, downlink:75.00 MB
SERVER: AGGREGATION


  0%|          | 0/50 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:150.00 MB, downlink:150.00 MB


  0%|          | 0/50 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:150.00 MB, downlink:150.00 MB


  0%|          | 0/50 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:150.00 MB, downlink:150.00 MB
SERVER: AGGREGATION
Global Round 0 communication overhead: uplink=450.00 MB, downlink=450.00 MB
SERVER: AGGREGATION


  0%|          | 0/50 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:75.00 MB, downlink:75.00 MB


  0%|          | 0/50 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:75.00 MB, downlink:75.00 MB


  0%|          | 0/50 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:75.00 MB, downlink:75.00 MB
SERVER: AGGREGATION


  0%|          | 0/50 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:150.00 MB, downlink:150.00 MB


  0%|          | 0/50 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:150.00 MB, downlink:150.00 MB


  0%|          | 0/50 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:150.00 MB, downlink:150.00 MB
SERVER: AGGREGATION
Global Round 1 communication overhead: uplink=450.00 MB, downlink=450.00 MB
SERVER: AGGREGATION


  0%|          | 0/50 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:75.00 MB, downlink:75.00 MB


  0%|          | 0/50 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:75.00 MB, downlink:75.00 MB


  0%|          | 0/50 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:75.00 MB, downlink:75.00 MB
SERVER: AGGREGATION


  0%|          | 0/50 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:150.00 MB, downlink:150.00 MB


  0%|          | 0/50 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:150.00 MB, downlink:150.00 MB
SERVER: AGGREGATION
Global Round 2 communication overhead: uplink=375.00 MB, downlink=375.00 MB
SERVER: AGGREGATION


  0%|          | 0/50 [00:00<?, ?it/s]

Client 0 communication overhead: uplink:75.00 MB, downlink:75.00 MB


  0%|          | 0/50 [00:00<?, ?it/s]

Client 2 communication overhead: uplink:75.00 MB, downlink:75.00 MB


  0%|          | 0/50 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:75.00 MB, downlink:75.00 MB
SERVER: AGGREGATION


  0%|          | 0/50 [00:00<?, ?it/s]

Client 1 communication overhead: uplink:150.00 MB, downlink:150.00 MB
SERVER: AGGREGATION
Global Round 3 communication overhead: uplink=300.00 MB, downlink=300.00 MB
SERVER: AGGREGATION
FL communication overhead: uplink=1.54 GB, downlink=1.54 GB
