In [1]:
import torch
from torch import nn
import copy
import torch.nn.functional as F

class LoraLinear(nn.Module):
    def __init__(
            self,
            base_layer: nn.Linear,
            r: int = 8,
            alpha: int = 16,
            dropout: float = 0.1,
            test_mode:  bool = False
    ):
        super(LoraLinear, self).__init__()
        self.base_layer = copy.deepcopy(base_layer)
        self.r = r
        self.alpha = alpha
        self.dropout = nn.Dropout(dropout)

        self.lora_A = nn.Parameter(torch.empty((r, base_layer.in_features), dtype=base_layer.weight.dtype))
        self.loar_B = nn.Parameter(torch.empty((base_layer.out_features, r), dtype=base_layer.weight.dtype))

        nn.init.normal_(self.lora_A, mean=0.0, std=0.02)
        if test_mode:
            nn.init.normal_(self.loar_B, mean=0, std=0.02)
        else:
            nn.init.zeros_(self.loar_B)
        
        for param in self.base_layer.parameters():
            param.requires_grad = False
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        scaling = float(self.alpha) / float(self.r)
        lora_adjustment = F.linear(self.dropout(x), self.lora_A)
        lora_adjustment = F.linear(self.dropout(lora_adjustment), self.loar_B)
        return self.base_layer(x) + lora_adjustment * scaling

In [2]:
def replace_linear_with_lora(
        module: nn.Module,
        r: int = 8,
        alpha: int = 16,
        dropout: float = 0.1,
        embed_requires_grad: bool = False,
        norm_requires_grad: bool = False,
        head_requires_grad: bool = False,
        test_mode: bool = False
):
    for name, child in module.named_children():
        if any(s in name for s in ['embed', 'norm', 'lm_head']):
            requires_grad = embed_requires_grad if 'embed' in name \
                            else norm_requires_grad if 'norm' in name \
                            else head_requires_grad
            for param in child.parameters():
                param.requires_grad = requires_grad
        elif isinstance(child, nn.Linear):
            lora_linear = LoraLinear(child, r=r, alpha=alpha, dropout=dropout)
            setattr(module, name, lora_linear)
        else:
            replace_linear_with_lora(
                child, r, alpha, dropout,
                embed_requires_grad, norm_requires_grad, head_requires_grad,
                test_mode
            )

In [3]:
def print_trainable_parameters(model: nn.Module):
    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    trainable_percentage = 100 * trainable_params / total_params

    print(f'trainable params: {trainable_params:,} || all params: {total_params:,} || trainable%: {trainable_percentage:.4f}')

In [8]:
from transformers import AutoConfig

config = AutoConfig.for_model('llama')
config.hidden_size = 24
config.intermediate_size = config.hidden_size * 4
config.num_attention_heads = 4
config.num_hidden_layers = 4
config.num_key_value_heads = 2
config.vocab_size = 128

In [10]:
config

LlamaConfig {
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 1,
  "eos_token_id": 2,
  "head_dim": 128,
  "hidden_act": "silu",
  "hidden_size": 24,
  "initializer_range": 0.02,
  "intermediate_size": 96,
  "max_position_embeddings": 2048,
  "mlp_bias": false,
  "model_type": "llama",
  "num_attention_heads": 4,
  "num_hidden_layers": 4,
  "num_key_value_heads": 2,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000.0,
  "tie_word_embeddings": false,
  "transformers_version": "4.53.0",
  "use_cache": true,
  "vocab_size": 128
}

In [11]:
from transformers import AutoModel, AutoModelForCausalLM

raw_model = AutoModel.from_config(config)

print(raw_model)

LlamaModel(
  (embed_tokens): Embedding(128, 24)
  (layers): ModuleList(
    (0-3): 4 x LlamaDecoderLayer(
      (self_attn): LlamaAttention(
        (q_proj): Linear(in_features=24, out_features=512, bias=False)
        (k_proj): Linear(in_features=24, out_features=256, bias=False)
        (v_proj): Linear(in_features=24, out_features=256, bias=False)
        (o_proj): Linear(in_features=512, out_features=24, bias=False)
      )
      (mlp): LlamaMLP(
        (gate_proj): Linear(in_features=24, out_features=96, bias=False)
        (up_proj): Linear(in_features=24, out_features=96, bias=False)
        (down_proj): Linear(in_features=96, out_features=24, bias=False)
        (act_fn): SiLU()
      )
      (input_layernorm): LlamaRMSNorm((24,), eps=1e-06)
      (post_attention_layernorm): LlamaRMSNorm((24,), eps=1e-06)
    )
  )
  (norm): LlamaRMSNorm((24,), eps=1e-06)
  (rotary_emb): LlamaRotaryEmbedding()
)


In [13]:
print_trainable_parameters(raw_model)

trainable params: 178,392 || all params: 178,392 || trainable%: 100.0000


In [7]:
lora_model = copy.deepcopy(raw_model)
replace_linear_with_lora(lora_model, r=8, alpha=16)
print_trainable_parameters(lora_model)
print(lora_model)

trainable params: 63,744 || all params: 242,136 || trainable%: 26.3257
LlamaModel(
  (embed_tokens): Embedding(128, 24)
  (layers): ModuleList(
    (0-3): 4 x LlamaDecoderLayer(
      (self_attn): LlamaAttention(
        (q_proj): LoraLinear(
          (base_layer): Linear(in_features=24, out_features=512, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (k_proj): LoraLinear(
          (base_layer): Linear(in_features=24, out_features=256, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (v_proj): LoraLinear(
          (base_layer): Linear(in_features=24, out_features=256, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (o_proj): LoraLinear(
          (base_layer): Linear(in_features=512, out_features=24, bias=False)
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (mlp): LlamaMLP(
        (gate_proj): LoraLinear(
          (base_layer): Linear(in_features=24, out_

In [15]:
def print_model_parameters(model: nn.Module):
    print('Layer Name & Parameters')
    print('------------------------------')
    for name, param in model.named_parameters():
        print(f'{name:50} | Requires_grad: {param.requires_grad}')

In [None]:
from peft import LoraConfig, get_peft_model

lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules='all-linear', # 太低版本的 peft 不支持这种做法
)
peft_lora_model = copy.deepcopy(raw_model)
peft_lora_model = get_peft_model(peft_lora_model, lora_config)
peft_lora_model.print_trainable_parameters()

trainable params: 63,744 || all params: 242,136 || trainable%: 26.3257


In [17]:
peft_lora_model

PeftModel(
  (base_model): LoraModel(
    (model): LlamaModel(
      (embed_tokens): Embedding(128, 24)
      (layers): ModuleList(
        (0-3): 4 x LlamaDecoderLayer(
          (self_attn): LlamaAttention(
            (q_proj): lora.Linear(
              (base_layer): Linear(in_features=24, out_features=512, bias=False)
              (lora_dropout): ModuleDict(
                (default): Identity()
              )
              (lora_A): ModuleDict(
                (default): Linear(in_features=24, out_features=8, bias=False)
              )
              (lora_B): ModuleDict(
                (default): Linear(in_features=8, out_features=512, bias=False)
              )
              (lora_embedding_A): ParameterDict()
              (lora_embedding_B): ParameterDict()
              (lora_magnitude_vector): ModuleDict()
            )
            (k_proj): lora.Linear(
              (base_layer): Linear(in_features=24, out_features=256, bias=False)
              (lora_dropout): Module

In [18]:
from typing import List

def upload_lora(module: nn.Module, adapter_name: str ='adapter'):
    lora_parameters = {}
    def search_lora_linear(module: nn.Module, prefix: List[str]):
        for name, child in module.named_children():
            new_prefix = prefix + [name]
            if isinstance(child, LoraLinear):
                lora_parameters['.'.join(new_prefix)] = {
                    'lora_A_weight': child.lora_A.data.cpu(),
                    'lora_B_weight': child.lora_B.data.cpu(),
                    'r': child.r,
                    'alpha': child.alpha,
                    'dropout': child.dropout
                }
                setattr(module, name, child.base_layer)
            else:
                search_lora_linear(child, new_prefix)
    search_lora_linear(module, [])
    for name, param in module.named_parameters():
        param.requires_grad = True
    torch.save(lora_parameters, f'{adapter_name}.pt')

In [19]:
def load_lora(module: nn.Module, adapter_name: str = 'adapter'):
    lora_parameters = torch.load(f'{adapter_name}.pt')
    for name, lora_params in lora_parameters.items():
        child = dict(module.named_modules())[name]
        if isinstance(child, nn.Linear):
            lora_linear = LoraLinear(child, lora_params['r'], lora_params['alpha'], lora_params['dropout'])
            lora_linear.lora_A.data = lora_params['lora_A_weight'].to(lora_linear.lora_A.device)
            lora_linear.loar_B.data = lora_params['lora_B_weight'].to(lora_linear.loar_B.device)

            parts = name.split('.')
            obj = module
            for part in parts[:-1]:
                obj = getattr(obj, part)
            setattr(obj, parts[-1], lora_linear)
    for name, param in module.named_parameters():
        if any(s in name for s in ['embed', 'norm', 'lm_head']):
            param.requires_grad = False