In [1]:
import os

# Make CUDA kernel launches synchronous so host-side timings and error
# tracebacks correspond to the Python statement that issued the work.
# Set before torch is imported so the variable is already in the environment
# whenever CUDA gets initialized.
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # For synchronous execution of CPU and GPU
# os.environ["CUDA_VISIBLE_DEVICES"] = '1'

import torch

# Prefer the second GPU (the original hard-coded choice), but degrade gracefully
# on hosts that have a single GPU or no GPU at all instead of failing later
# when tensors/models are moved to a device that does not exist.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from modeling_mixformer_sequential import MixFormerSequentialForCausalLM, InferenceParams
from configuration_mixformer_sequential import MixFormerSequentialConfig

from transformers.modeling_outputs import CausalLMOutputWithPast
import torch
from typing import Any, Dict, Optional, Tuple, Union

class ClientSideMixFormerSequentialForCausalLM(MixFormerSequentialForCausalLM):
    """Client half of a layer-wise split of ``MixFormerSequentialForCausalLM``.

    The client keeps ``layers[0]`` (the embedding) and the blocks with index
    below ``split_layer``. Every block from ``split_layer`` up to, but not
    including, the last entry of ``self.layers`` is replaced by ``None``; the
    last entry is left in place, exactly as the original hard-coded range did
    for a 20-block model.
    """

    def __init__(self, config, split_layer: int = 2):
        """Build the full model, then drop the blocks the client does not run.

        Args:
            config: Model configuration forwarded to the parent constructor.
            split_layer: Index in ``self.layers`` of the first block that is
                NOT executed on the client. Defaults to 2 (embedding + 1 block).

        Raises:
            ValueError: If ``split_layer`` does not leave the embedding on the
                client or points past the last entry of ``self.layers``.
        """
        super().__init__(config)
        # Layout relied on below: [embedding, block_1 .. block_n, final layer].
        total_entries = len(self.layers)
        if not 1 <= split_layer <= total_entries - 1:
            raise ValueError(
                f"split_layer must be in [1, {total_entries - 1}], got {split_layer}"
            )
        self.split_layer = split_layer
        # Number of blocks between the embedding and the final layer; derived
        # from the model instead of being hard-coded to 20.
        self.num_layers = total_entries - 2
        for i in range(self.split_layer, self.num_layers + 1):
            self.layers[i] = None

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Run the embedding and the client-side blocks.

        Returns:
            A tuple ``(input_ids, past_key_values, attention_mask, labels,
            hidden_layer)`` where ``hidden_layer`` is the activation produced by
            the last client-side layer; the other entries are passed through
            unchanged so they can be forwarded to the server half.
        """
        if attention_mask is not None and self.training:
            print("`attention_mask` is not supported during training. Using it might lead to unexpected results.")

        if past_key_values is None and attention_mask is None:
            print("[Client] past_key_values & attention_mask is None!")
        else:
            print("[Client] forward with past_key_values or attention_mask!")

        # Both cases share one code path: calling ``self.layers(input_ids)``
        # would iterate into the ``None`` placeholders and fail, so only the
        # layers that still exist on the client are executed.
        hidden_layer = self.layers[0](input_ids)
        for module in self.layers[1: self.split_layer]:  # return intermediate tensor
            hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
        return input_ids, past_key_values, attention_mask, labels, hidden_layer



class ServerSideMixFormerSequentialForCausalLM(MixFormerSequentialForCausalLM):
    """Server half of a layer-wise split of ``MixFormerSequentialForCausalLM``.

    The entries of ``self.layers`` with index below ``split_layer`` are replaced
    by ``None``; the server runs the remaining blocks and the final layer on an
    activation tensor received from the client half.
    """

    def __init__(self, config, split_layer: int = 2):
        """Build the full model, then drop the layers that run on the client.

        Args:
            config: Model configuration forwarded to the parent constructor.
            split_layer: Index in ``self.layers`` of the first block executed on
                the server. Defaults to 2.

        Raises:
            ValueError: If ``split_layer`` is outside ``[1, len(self.layers) - 1]``.
        """
        super().__init__(config)
        last_index = len(self.layers) - 1
        if not 1 <= split_layer <= last_index:
            raise ValueError(
                f"split_layer must be in [1, {last_index}], got {split_layer}"
            )
        self.split_layer = split_layer
        for i in range(0, self.split_layer):
            self.layers[i] = None

    def forward(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
        attention_mask: Optional[torch.BoolTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        hidden_layer_input: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> CausalLMOutputWithPast:
        """Finish the forward pass starting from the client's activation.

        Args:
            input_ids: Accepted for interface compatibility; not read here
                because the embedding layer is not present on the server.
            past_key_values: Forwarded to every remaining block.
            attention_mask: Forwarded to every remaining block.
            labels: When given, ``self.loss(lm_logits, labels)`` is computed.
            hidden_layer_input: Activation produced by the client half.

        Returns:
            ``CausalLMOutputWithPast`` holding the loss (or ``None``), the
            logits from the final layer, and ``past_key_values`` unchanged.

        Raises:
            ValueError: If ``hidden_layer_input`` is ``None``.
        """
        if attention_mask is not None and self.training:
            print("`attention_mask` is not supported during training. Using it might lead to unexpected results.")

        if hidden_layer_input is None:
            # The first ``split_layer`` entries are ``None`` on the server, so
            # there is nothing that could turn ``input_ids`` into activations.
            raise ValueError("hidden_layer_input is required: the server half has no embedding layer.")

        if past_key_values is None and attention_mask is None:
            print("[Server] past_key_values & attention_mask is None!")
        else:
            print("[Server] forward with past_key_values or attention_mask!")

        # Single code path for both cases: ``self.layers(input_ids)`` would hit
        # the ``None`` placeholders immediately and fail.
        hidden_layer = hidden_layer_input
        for module in self.layers[self.split_layer:-1]:  # Compute the remaining block
            hidden_layer = module(hidden_layer, past_key_values=past_key_values, attention_mask=attention_mask)
        lm_logits = self.layers[-1](hidden_layer)

        loss = None
        if labels is not None:
            loss = self.loss(lm_logits, labels)

        return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)



[2023-10-13 15:54:18,968] [INFO] [real_accelerator.py:133:get_accelerator] Setting ds_accelerator to cuda (auto detect)


In [3]:
# One configuration object, created with constructor defaults, is shared by
# both halves so their layer shapes agree.
# NOTE(review): neither half loads weights from a checkpoint in this cell —
# confirm that training from the constructor's initialization is intended.
config = MixFormerSequentialConfig()
client_model, server_model = (
    ClientSideMixFormerSequentialForCausalLM(config),
    ServerSideMixFormerSequentialForCausalLM(config),
)

In [4]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
# Candidate checkpoints; index 1 ("microsoft/phi-1_5") is the one used below.
model_name = ("gpt2", "microsoft/phi-1_5")

# Tokenizer and reference model come from the same checkpoint and cache.
tokenizer = AutoTokenizer.from_pretrained(
    model_name[1],
    cache_dir="/app/.huggingface_cache/model/",
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    model_name[1],
    trust_remote_code=True,
    cache_dir="/app/.huggingface_cache/model/",
)
print(model)

# LoRA adapters are attached to every module whose name matches "Wqkv".
lora_modules = ["Wqkv"]
lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=lora_modules,
    r=2,  # dimension of the updated matrices
    lora_alpha=64,  # parameter for scaling
    lora_dropout=0.1,  # dropout probability for layers
    bias="none",
)

# Wrap the reference model and both split halves with the same LoRA settings.
model = get_peft_model(model, lora_config)
client_model = get_peft_model(client_model, lora_config)
server_model = get_peft_model(server_model, lora_config)

MixFormerSequentialForCausalLM(
  (layers): Sequential(
    (0): Embedding(
      (wte): Embedding(51200, 2048)
      (drop): Dropout(p=0.0, inplace=False)
    )
    (1): ParallelBlock(
      (ln): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
      (resid_dropout): Dropout(p=0.0, inplace=False)
      (mixer): MHA(
        (rotary_emb): RotaryEmbedding()
        (Wqkv): Linear(in_features=2048, out_features=6144, bias=True)
        (out_proj): Linear(in_features=2048, out_features=2048, bias=True)
        (inner_attn): SelfAttention(
          (drop): Dropout(p=0.0, inplace=False)
        )
        (inner_cross_attn): CrossAttention(
          (drop): Dropout(p=0.0, inplace=False)
        )
      )
      (mlp): MLP(
        (fc1): Linear(in_features=2048, out_features=8192, bias=True)
        (fc2): Linear(in_features=8192, out_features=2048, bias=True)
        (act): NewGELUActivation()
      )
    )
    (2): ParallelBlock(
      (ln): LayerNorm((2048,), eps=1e-05, elementwis

In [5]:
from datasets import load_dataset
from transformers import default_data_collator
from data_utils import transform_data_to_fedml_format, group_texts, tokenize_function
from functools import partial
from torch.utils.data import DataLoader
import torch

# Grow the client's embedding table only when the tokenizer has more entries
# than the table has rows.
embedding_size = client_model.get_input_embeddings().weight.shape[0]
tokenizer_size = len(tokenizer)
if embedding_size < tokenizer_size:
    client_model.resize_token_embeddings(tokenizer_size)

# Load the raw corpus; each grouped sample will be `block_size` tokens long.
block_size = tokenizer.model_max_length
raw_datasets = load_dataset(
    "wikitext",
    "wikitext-2-raw-v1",
    streaming=False,
    cache_dir="/app/.huggingface_cache/dataset/",
)
column_names = list(raw_datasets["train"].features)

# Step 1: tokenize, dropping the original columns.
__tokenize_function = partial(tokenize_function, text_column_name="text", tokenizer=tokenizer)
tokenized_datasets = raw_datasets.map(
    __tokenize_function,
    batched=True,
    remove_columns=column_names,
    desc="Running tokenizer on dataset",
)

# Step 2: regroup the tokenized text into fixed-size blocks.
__group_texts = partial(group_texts, block_size=block_size)
lm_datasets = tokenized_datasets.map(__group_texts, batched=True)

# Step 3: expose the training split as torch tensors on `device` and batch it
# one sample at a time in shuffled order.
train_split = lm_datasets['train']
train_split.set_format("torch", device=device)
train_dataloader = DataLoader(
    train_split,
    batch_size=1,
    shuffle=True,
    collate_fn=default_data_collator,
)
print(train_split)
print(train_dataloader)
# Name fragments of parameters that must not receive weight decay.
# The model printed above names its LayerNorm modules `ln`, so the original
# "layer_norm.weight" fragment never matched any parameter; "ln.weight" is
# added so LayerNorm weights land in the zero-decay group as intended.
no_decay = ["bias", "layer_norm.weight", "ln.weight"]


def _grouped_parameters(model, weight_decay=0.01):
    """Split a model's parameters into a decayed and a non-decayed group.

    Args:
        model: Module whose ``named_parameters()`` are partitioned.
        weight_decay: Decay applied to parameters whose name contains none of
            the ``no_decay`` fragments.

    Returns:
        A two-element list of optimizer parameter-group dicts: first the
        decayed parameters, then those matching ``no_decay`` (decay 0.0).
    """
    decayed, exempt = [], []
    for name, param in model.named_parameters():
        bucket = exempt if any(nd in name for nd in no_decay) else decayed
        bucket.append(param)
    return [
        {"params": decayed, "weight_decay": weight_decay},
        {"params": exempt, "weight_decay": 0.0},
    ]


client_grouped_parameters = _grouped_parameters(client_model)
server_grouped_parameters = _grouped_parameters(server_model)
server_optimizer = torch.optim.AdamW(server_grouped_parameters, lr=0.001)
client_optimizer = torch.optim.AdamW(client_grouped_parameters, lr=0.001)

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 1150
})
<torch.utils.data.dataloader.DataLoader object at 0x7f954406da30>


In [6]:
def unfreeze_params(model):
    """Mark every parameter of ``model`` as trainable (``requires_grad=True``)."""
    for weight in model.parameters():
        weight.requires_grad_(True)
        
def print_trainable_params(model):
    """Print the name and ``requires_grad`` flag of each trainable parameter."""
    trainable = (
        (param_name, tensor)
        for param_name, tensor in model.named_parameters()
        if tensor.requires_grad
    )
    for param_name, tensor in trainable:
        print(param_name, tensor.requires_grad)
        
def print_grad(model):
    """Print the name and current ``.grad`` of each trainable parameter."""
    for param_name, tensor in model.named_parameters():
        if not tensor.requires_grad:
            continue
        print(param_name, tensor.grad)
            
# Full fine-tuning is intentionally left disabled; enable by calling
# unfreeze_params(client_model) and unfreeze_params(server_model) here.

# List what is trainable on each half, server first.
for banner, split_half in (
    ("\n\n---- trainable params of server_model ----", server_model),
    ("\n\n---- trainable params of client_model ----", client_model),
):
    print(banner)
    print_trainable_params(split_half)



---- trainable params of server_model ----
base_model.model.layers.2.mixer.Wqkv.lora_A.default.weight True
base_model.model.layers.2.mixer.Wqkv.lora_B.default.weight True
base_model.model.layers.3.mixer.Wqkv.lora_A.default.weight True
base_model.model.layers.3.mixer.Wqkv.lora_B.default.weight True
base_model.model.layers.4.mixer.Wqkv.lora_A.default.weight True
base_model.model.layers.4.mixer.Wqkv.lora_B.default.weight True
base_model.model.layers.5.mixer.Wqkv.lora_A.default.weight True
base_model.model.layers.5.mixer.Wqkv.lora_B.default.weight True
base_model.model.layers.6.mixer.Wqkv.lora_A.default.weight True
base_model.model.layers.6.mixer.Wqkv.lora_B.default.weight True
base_model.model.layers.7.mixer.Wqkv.lora_A.default.weight True
base_model.model.layers.7.mixer.Wqkv.lora_B.default.weight True
base_model.model.layers.8.mixer.Wqkv.lora_A.default.weight True
base_model.model.layers.8.mixer.Wqkv.lora_B.default.weight True
base_model.model.layers.9.mixer.Wqkv.lora_A.default.weight 

In [7]:
import math
import time

        
# Put the two halves that are actually optimized into training mode (the
# original call targeted the unsplit `model`, which this loop never runs).
client_model.train()
server_model.train()

# Loop-invariant: the models only need to be moved to the device once.
client_model, server_model = client_model.to(device), server_model.to(device)

max_steps = 2  # smoke test: stop after this many optimizer steps
counter = 0
for step, batch in enumerate(train_dataloader):
    if counter == max_steps:
        # `exit()` would shut down the notebook kernel and discard all state;
        # leaving the loop is what is wanted here.
        break
    print(f"step = {step} ")
    batch = {key: value.to(device) for key, value in batch.items()}
    print(f"batch: {batch}")

    # Client half: embedding + first blocks, returns the activation to ship.
    client_start_time = time.perf_counter()
    input_ids, past_key_values, attention_mask, labels, acts = client_model(**batch)
    client_end_time = time.perf_counter()
    # acts.retain_grad()

    # Server half: remaining blocks, final layer and loss.
    server_start_time = time.perf_counter()
    outputs = server_model(
        input_ids=input_ids,
        past_key_values=past_key_values,
        attention_mask=attention_mask,
        labels=labels,
        hidden_layer_input=acts,
    )
    server_end_time = time.perf_counter()

    loss = outputs.loss
    try:
        perplexity = math.exp(loss.item())
    except OverflowError:
        # exp() of a very large loss does not fit in a float.
        perplexity = float("inf")

    # `acts` is not detached, so one backward pass reaches both halves.
    loss.backward()

    server_optimizer.step()
    client_optimizer.step()

    print("\n\n---- grad on server_model ----")
    print_grad(server_model)
    print("\n\n---- grad on client_model ----")
    print_grad(client_model)

    server_optimizer.zero_grad()
    client_optimizer.zero_grad()
    print(f"  - loss = {loss}")
    print(f"  - perplexity = {perplexity}")
    print(f"  - Client = {client_end_time-client_start_time}  ||  Server = {server_end_time-server_start_time}")
    counter = counter + 1


step = 0 
batch: {'input_ids': tensor([[ 373, 6928,  416,  ...,   67,  446,  550]], device='cuda:1'), 'attention_mask': tensor([[1, 1, 1,  ..., 1, 1, 1]], device='cuda:1'), 'labels': tensor([[ 373, 6928,  416,  ...,   67,  446,  550]], device='cuda:1')}
`attention_mask` is not supported during training. Using it might lead to unexpected results.
[Client] forward with past_key_values or attention_mask!
`attention_mask` is not supported during training. Using it might lead to unexpected results.
[Server] forward with past_key_values or attention_mask!


---- grad on server_model ----
base_model.model.layers.2.mixer.Wqkv.lora_A.default.weight tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], device='cuda:1')
base_model.model.layers.2.mixer.Wqkv.lora_B.default.weight tensor([[-3.1633e-05, 

KeyboardInterrupt: 

: 