In [48]:
import sys 
sys.path.append("..")

import argparse
import time
import flwr as fl
from flwr.client import ClientApp
from flwr.common import Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft.tuners.lora import LoraConfig
from peft.utils.peft_types import TaskType
from peft.mapping import get_peft_model
import wandb
from dataset import load_validation_data
import shutil
import os
from torch.utils.data import ConcatDataset
from torch.utils.data import DataLoader

from client import GPT2FLClient
from dataset import load_data
from model import test, train

model_path = "distilbert/distilbert-base-cased"
# model_path = "gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_path)

model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        num_labels=3,
    )

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert/distilbert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [49]:
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_lin", "v_lin"], 
)

model = get_peft_model(model, peft_config)

In [50]:
peft_config.target_modules

{'q_lin', 'v_lin'}

In [51]:
peft_config.modules_to_save

['classifier', 'score']

In [52]:
peft_config.layers_to_transform

In [53]:
def print_trainable_parameters(model):
    """
    Prints the number and names of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    print("Trainable layers:")
    for name, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            print(f"\n{name}")
            print(f"Shape: {param.shape}")
            trainable_params += param.numel()
    
    print(f"\nSummary:")
    print(f"trainable params: {trainable_params:,d}")
    print(f"all params: {all_param:,d}")
    print(f"trainable%: {100 * trainable_params / all_param:.2f}%")


print_trainable_parameters(model)

Trainable layers:

base_model.model.distilbert.transformer.layer.0.attention.q_lin.lora_A.default.weight
Shape: torch.Size([8, 768])

base_model.model.distilbert.transformer.layer.0.attention.q_lin.lora_B.default.weight
Shape: torch.Size([768, 8])

base_model.model.distilbert.transformer.layer.0.attention.v_lin.lora_A.default.weight
Shape: torch.Size([8, 768])

base_model.model.distilbert.transformer.layer.0.attention.v_lin.lora_B.default.weight
Shape: torch.Size([768, 8])

base_model.model.distilbert.transformer.layer.1.attention.q_lin.lora_A.default.weight
Shape: torch.Size([8, 768])

base_model.model.distilbert.transformer.layer.1.attention.q_lin.lora_B.default.weight
Shape: torch.Size([768, 8])

base_model.model.distilbert.transformer.layer.1.attention.v_lin.lora_A.default.weight
Shape: torch.Size([8, 768])

base_model.model.distilbert.transformer.layer.1.attention.v_lin.lora_B.default.weight
Shape: torch.Size([768, 8])

base_model.model.distilbert.transformer.layer.2.attention.q_l

In [55]:
model

PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): DistilBertForSequenceClassification(
      (distilbert): DistilBertModel(
        (embeddings): Embeddings(
          (word_embeddings): Embedding(28996, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (transformer): Transformer(
          (layer): ModuleList(
            (0-5): 6 x TransformerBlock(
              (attention): DistilBertSdpaAttention(
                (dropout): Dropout(p=0.1, inplace=False)
                (q_lin): lora.Linear(
                  (base_layer): Linear(in_features=768, out_features=768, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.1, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=76