In [1]:
import os 
import torch
print(torch.__version__)
# os.environ["CUDA_VISIBLE_DEVICES"] = '1'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'  # For synchronous execution of CPU and GPU
client_device = torch.device("cuda:2")
server_device = torch.device("cuda:1")
model_device = torch.device("cuda:0")

  from .autonotebook import tqdm as notebook_tqdm


1.13.1+cu116


In [14]:
from transformers import OPTForCausalLM
import torch
from typing import Optional, List
from torch.nn import CrossEntropyLoss
from transformers.modeling_outputs import CausalLMOutputWithPast


class ClientSideOPTForCausalLM(OPTForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        split_outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # print(hidden_outputs)
        return split_outputs, attention_mask, head_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict

class ServerSideOPTForCausalLM(OPTForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        split_layer = 2
        self.model.decoder.layers = self.model.decoder.layers[split_layer:]
        self.model.decoder.embed_positions = None
        self.model.decoder.embed_tokens = None
        self.model.decoder.final_layer_norm = None
        

    def forward(
        self,
        split_outputs: torch.Tensor,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        
        outputs = self.model.decoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            past_key_values=past_key_values,
            inputs_embeds=split_outputs,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        logits = self.lm_head(outputs[0]).contiguous()

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        

In [3]:
from transformers import OPTConfig
config = OPTConfig(num_hidden_layers=1)
client_model = ClientSideOPTForCausalLM(config)
num_hidden_layers=11
config = OPTConfig()
server_model = ServerSideOPTForCausalLM(config)
    

In [4]:
print(client_model)

ClientSideOPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): Embedding(50272, 768, padding_idx=1)
      (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
      (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (layers): ModuleList(
        (0): OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), 

In [5]:
print(server_model)

ServerSideOPTForCausalLM(
  (model): OPTModel(
    (decoder): OPTDecoder(
      (embed_tokens): None
      (embed_positions): None
      (final_layer_norm): None
      (layers): ModuleList(
        (0): OPTDecoderLayer(
          (self_attn): OPTAttention(
            (k_proj): Linear(in_features=768, out_features=768, bias=True)
            (v_proj): Linear(in_features=768, out_features=768, bias=True)
            (q_proj): Linear(in_features=768, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=768, bias=True)
          )
          (activation_fn): ReLU()
          (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (fc1): Linear(in_features=768, out_features=3072, bias=True)
          (fc2): Linear(in_features=3072, out_features=768, bias=True)
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        )
        (1): OPTDecoderLayer(
          (self_attn): OPTAttention(
    

In [6]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import LoraConfig, get_peft_model
import torch

MODEL_NAME = "facebook/opt-125m"
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_NAME, 
    trust_remote_code=True,
    cache_dir="/app/.huggingface_cache/model/"
)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    cache_dir="/app/.huggingface_cache/model/",
    trust_remote_code=True,
    torch_dtype=torch.float16
)


lora_modules=["q_proj", "k_proj", "v_proj", "out_proj", "fc1", "fc2", "lm_head"] 
lora_config = LoraConfig(
    r=1,  # dimension of the updated matrices
    lora_alpha=64,  # parameter for scaling
    target_modules=lora_modules,
    lora_dropout=0.1,  # dropout probability for layers
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
client_model = get_peft_model(client_model, lora_config)
server_model = get_peft_model(server_model, lora_config)

In [7]:
print(model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): OPTForCausalLM(
      (model): OPTModel(
        (decoder): OPTDecoder(
          (embed_tokens): Embedding(50272, 768, padding_idx=1)
          (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (layers): ModuleList(
            (0): OPTDecoderLayer(
              (self_attn): OPTAttention(
                (k_proj): Linear(
                  in_features=768, out_features=768, bias=True
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.1, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=768, out_features=1, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=1, out_features=768, bias=False)
                  )
                  (lora_embeddin

In [8]:
print(server_model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): ServerSideOPTForCausalLM(
      (model): OPTModel(
        (decoder): OPTDecoder(
          (embed_tokens): None
          (embed_positions): None
          (final_layer_norm): None
          (layers): ModuleList(
            (0): OPTDecoderLayer(
              (self_attn): OPTAttention(
                (k_proj): Linear(
                  in_features=768, out_features=768, bias=True
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.1, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=768, out_features=1, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=1, out_features=768, bias=False)
                  )
                  (lora_embedding_A): ParameterDict()
                  (lora_embedding_B): ParameterDict()
                )
             

In [9]:
print(client_model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): ClientSideOPTForCausalLM(
      (model): OPTModel(
        (decoder): OPTDecoder(
          (embed_tokens): Embedding(50272, 768, padding_idx=1)
          (embed_positions): OPTLearnedPositionalEmbedding(2050, 768)
          (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
          (layers): ModuleList(
            (0): OPTDecoderLayer(
              (self_attn): OPTAttention(
                (k_proj): Linear(
                  in_features=768, out_features=768, bias=True
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.1, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=768, out_features=1, bias=False)
                  )
                  (lora_B): ModuleDict(
                    (default): Linear(in_features=1, out_features=768, bias=False)
                  )
                  (lor

In [10]:
from datasets import load_dataset
from transformers import default_data_collator
from data_utils import transform_data_to_fedml_format, group_texts, tokenize_function
from functools import partial
from torch.utils.data import DataLoader
import torch

tokenizer = AutoTokenizer.from_pretrained(
    "facebook/opt-125m", 
    trust_remote_code=True,
    cache_dir="/app/.huggingface_cache/model/"
)
inputs = tokenizer("Hello world!",return_tensors="pt")
print(inputs)

# resize embeddings
embedding_size = client_model.get_input_embeddings().weight.shape[0]
if len(tokenizer) > embedding_size:
    client_model.resize_token_embeddings(len(tokenizer))


# load data
print(tokenizer)
block_size = 1024
raw_datasets = load_dataset(
    "wikitext",
    "wikitext-2-raw-v1",
    cache_dir="/app/.huggingface_cache/dataset/",
    streaming=False
)
print(raw_datasets)
column_names = list(raw_datasets["train"].features)
print(column_names)

# data preprocessing 
__tokenize_function = partial(tokenize_function, text_column_name='text', tokenizer=tokenizer)
tokenized_datasets = raw_datasets.map(
                    __tokenize_function,
                    batched=True,
                    remove_columns=column_names,
                    desc="Running tokenizer on dataset",
                )
print(tokenized_datasets)

__group_texts = partial(group_texts, block_size=block_size)
lm_datasets = tokenized_datasets.map(
                __group_texts,
                batched=True,
                # num_proc=1,
                # load_from_cache_file=not data_args.overwrite_cache,
                # desc=f"Grouping texts in chunks of {block_size}",
            )

print(lm_datasets['train'])
lm_datasets['train'].set_format("torch", device=client_device)
train_dataloader = DataLoader(lm_datasets['train'], shuffle=True, collate_fn=default_data_collator, batch_size=1)
print(train_dataloader)
no_decay = ["bias", "layer_norm.weight"]
client_grouped_parameters = [
    {
        "params": [p for n, p in client_model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.01,
    },
    {
        "params": [p for n, p in client_model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
server_grouped_parameters = [
    {
        "params": [p for n, p in server_model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.01,
    },
    {
        "params": [p for n, p in server_model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
model_grouped_paramters = [
    {
        "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        "weight_decay": 0.01,
    },
    {
        "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        "weight_decay": 0.0,
    },
]
server_optimizer = torch.optim.AdamW(server_grouped_parameters, lr=0.001)
client_optimizer = torch.optim.AdamW(client_grouped_parameters, lr=0.001)
model_optimizer = torch.optim.AdamW(model_grouped_paramters, lr=0.001)

{'input_ids': tensor([[    2, 31414,   232,   328]]), 'attention_mask': tensor([[1, 1, 1, 1]])}
GPT2TokenizerFast(name_or_path='facebook/opt-125m', vocab_size=50265, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'eos_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'unk_token': AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=True), 'pad_token': AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=True)}, clean_up_tokenization_spaces=True)
DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})
['text']
DatasetDict({


In [11]:
def unfreeze_params(model):
    for param in model.parameters():
        param.requires_grad = True
        
def print_trainable_params(model):
    for name, param in model.named_parameters():
        if(param.requires_grad):
            print(name, param.requires_grad)
        
def print_grad(model):
    for name, param in model.named_parameters():
        if(param.requires_grad):
            print(name, param.grad)
            
# unfreeze_params(client_model)
# unfreeze_params(server_model)
print("\n\n---- trainable params of server_model ----")
print_trainable_params(server_model)
print("\n\n---- trainable params of client_model ----")
print_trainable_params(client_model)
print(f"\n {len(train_dataloader)}")



---- trainable params of server_model ----
base_model.model.model.decoder.layers.0.self_attn.k_proj.lora_A.default.weight True
base_model.model.model.decoder.layers.0.self_attn.k_proj.lora_B.default.weight True
base_model.model.model.decoder.layers.0.self_attn.v_proj.lora_A.default.weight True
base_model.model.model.decoder.layers.0.self_attn.v_proj.lora_B.default.weight True
base_model.model.model.decoder.layers.0.self_attn.q_proj.lora_A.default.weight True
base_model.model.model.decoder.layers.0.self_attn.q_proj.lora_B.default.weight True
base_model.model.model.decoder.layers.0.self_attn.out_proj.lora_A.default.weight True
base_model.model.model.decoder.layers.0.self_attn.out_proj.lora_B.default.weight True
base_model.model.model.decoder.layers.0.fc1.lora_A.default.weight True
base_model.model.model.decoder.layers.0.fc1.lora_B.default.weight True
base_model.model.model.decoder.layers.0.fc2.lora_A.default.weight True
base_model.model.model.decoder.layers.0.fc2.lora_B.default.weight 

In [16]:
import math
import time
import numpy as np

print_grad = False
model.train()
client_model.train()
server_model.train()
print(len(train_dataloader))

latency = {
    "client": [],
    "server": [],
    "end-to-end": [],
    "model": []
}


scaler = torch.cuda.amp.GradScaler()
for step, batch in enumerate(train_dataloader):
    print(f"step = {step} ")
    
    for key in batch.keys():
        batch[key] = batch[key].to(device=client_device)
    
    client_model, server_model = client_model.to(dtype=torch.float16, device=client_device), server_model.to(dtype=torch.float16, device=server_device)
    # Split training
    # with torch.autocast(server_model.device.type, dtype=torch.float16, enabled=True):
    client_start_time = time.perf_counter()
    # input_ids, past_key_values, attention_mask, labels, acts = client_model(**batch)
    split_output, attention_mask, head_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict = client_model(**batch)
    client_end_time = time.perf_counter()
    split_output, attention_mask, head_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict
    
    split_output = split_output.to(server_device)
    # acts.retain_grad()
    # input_ids, attention_mask, labels, acts = input_ids.to(device=server_device), attention_mask.to(device=server_device), labels.to(device=server_device), acts.to(device=server_device)
    # split_output, attention_mask, head_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict = split_output, attention_mask, head_mask, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict.to(server_device)
    
    server_start_time = time.perf_counter()
    # split_outputs = server_model(input_ids=input_ids, past_key_values=past_key_values, attention_mask=attention_mask, labels=labels, hidden_layer_input=acts)
    split_outputs = server_model()
    server_end_time = time.perf_counter()
        
    split_loss = split_outputs.loss
    # scaler.scale(split_loss).backward()
    split_loss.backward()
    split_perplexity = math.exp(split_loss)
    
    # scaler.step(server_optimizer)
    # scaler.step(client_optimizer)
    server_optimizer.step()
    client_optimizer.step()
    
    # scaler.update()
    if(print_grad):
        print("\n\n---- grad on server_model ----")
        print_grad(server_model)    
        print("\n\n---- grad on client_model ----")
        print_grad(client_model)
    
    latency["client"].append(client_end_time-client_start_time)
    latency["server"].append(server_end_time-server_start_time)
    latency["end-to-end"].append(server_end_time-client_start_time)
    print(f"  - split_loss = {split_loss}")
    print(f"  - split_perplexity = {split_perplexity}")
    print(f"  - Latency (sec): Client = {np.average(latency['client'])}  ||  Server = {np.average(latency['server'])}  || End-to-end = {np.average(latency['end-to-end'])}")
    
    server_optimizer.zero_grad()
    client_optimizer.zero_grad()

2355
step = 0 
BaseModelOutputWithPast(last_hidden_state=tensor([[[ 0.4541,  0.0764,  0.3213,  ...,  0.9365,  0.0616, -0.5015],
         [-0.8564,  0.6431, -0.2213,  ...,  0.7637,  0.3423, -2.0781],
         [-0.2100,  0.0800, -0.9648,  ...,  0.0854,  2.5332, -2.0527],
         ...,
         [ 0.3835,  0.0983,  0.2330,  ..., -0.0782,  0.1295,  0.1847],
         [ 0.8315, -1.9922,  0.0691,  ...,  0.6812, -1.2275,  0.4658],
         [-0.4324, -0.7661, -1.2197,  ..., -1.1074, -0.6133, -1.2627]]],
       device='cuda:2', dtype=torch.float16, grad_fn=<NativeLayerNormBackward0>), past_key_values=((tensor([[[[ 7.9102e-01,  3.6670e-01, -8.7402e-01,  ...,  2.0557e-01,
            9.8206e-02, -7.5439e-01],
          [-5.7666e-01,  4.5068e-01, -4.8364e-01,  ..., -2.6514e-01,
           -8.4326e-01, -1.0675e-01],
          [-1.5625e-01,  7.1094e-01,  4.7913e-02,  ..., -1.1719e-01,
           -3.4473e-01, -7.5537e-01],
          ...,
          [-2.8027e-01, -2.0544e-01,  1.4807e-01,  ...,  9.0674e-

AttributeError: 'BaseModelOutputWithPast' object has no attribute 'to'