In [16]:
import argparse
import time
import flwr as fl
from flwr.client import ClientApp
from flwr.common import Context
from flwr.server import ServerApp, ServerConfig, ServerAppComponents
from flwr.server.strategy import FedAvg
from flwr.simulation import run_simulation
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft.tuners.lora import LoraConfig
from peft.utils.peft_types import TaskType
from peft.mapping import get_peft_model
import wandb
from dataset import load_validation_data
import shutil
import os
from torch.utils.data import ConcatDataset
from torch.utils.data import DataLoader

from client import GPT2FLClient
from dataset import load_data
from model import test, train

model_path = "openai-community/gpt2"

tokenizer = AutoTokenizer.from_pretrained(model_path)

model = AutoModelForSequenceClassification.from_pretrained(
        model_path,
        num_labels=3,
    )

Some weights of GPT2ForSequenceClassification were not initialized from the model checkpoint at openai-community/gpt2 and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [53]:
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
)

model = get_peft_model(model, peft_config)

In [49]:
peft_config.target_modules

{'c_attn'}

In [50]:
peft_config.layers_to_transform

[0]

In [52]:
def print_trainable_parameters(model):
    """
    Prints the number and names of trainable parameters in the model.
    """
    trainable_params = 0
    all_param = 0
    print("Trainable layers:")
    for name, param in model.named_parameters():
        all_param += param.numel()
        if param.requires_grad:
            print(f"\n{name}")
            print(f"Shape: {param.shape}")
            trainable_params += param.numel()
    
    print(f"\nSummary:")
    print(f"trainable params: {trainable_params:,d}")
    print(f"all params: {all_param:,d}")
    print(f"trainable%: {100 * trainable_params / all_param:.2f}%")


print_trainable_parameters(model)

Trainable layers:

base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.transformer.h.0.attn.c_attn.lora_A.default.weight
Shape: torch.Size([8, 768])

base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.transformer.h.0.attn.c_attn.lora_B.default.weight
Shape: torch.Size([2304, 8])

base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.transformer.h.1.attn.c_attn.lora_A.default.weight
Shape: torch.Size([8, 768])

base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.transformer.h.1.attn.c_attn.lora_B.default.weight
Shape: torch.Size([2304, 8])

base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.base_model.model.transformer.h.2.attn.c_attn.lora_A.default.weight
Shape: torch.Size(

In [23]:
TaskType.TOKEN_CLS

<TaskType.TOKEN_CLS: 'TOKEN_CLS'>