In [1]:
from transformers import GPTNeoXForCausalLM
from peft import PeftModel, LoraConfig, get_peft_model
import torch
import torch.nn as nn
import math

  from .autonotebook import tqdm as notebook_tqdm


In [36]:
# Configuration for the adapter-expansion run, followed by loading the small
# pretrained base model that the fine-tuned LoRA adapter will be attached to.

# Alternative (smaller) model pairing, kept for reference:
# small_name = "EleutherAI/pythia-70m"
small_name = "EleutherAI/pythia-160m"
# large_name = "EleutherAI/pythia-160m"
large_name = "EleutherAI/pythia-410m"
# Directory holding the LoRA adapter that was fine-tuned on the small model.
small_adapter = "weight/pythia_160m_r=64_0.0001_schedule"

# Output directory for the expanded adapter (small -> large, "padding" method).
expanded_model = "weight/pythia_160m_410m_padding_r=64_schedule/"


# Target (in_features, out_features) of the expanded LoRA matrices.
# lora_A input width = hidden size of the large model: 768 for pythia-160m and
# 1024 for pythia-410m per the model reprs in this notebook (512 is presumably
# pythia-70m -- TODO confirm).
new_in_a, new_out_a = 1024, 64   # 512, 768, 1024
# lora_B output width = out_features of `query_key_value`: 2304 for pythia-160m
# and 3072 for pythia-410m per the model reprs (value for pythia-70m unknown).
new_in_b, new_out_b = 64, 3072  # ? , 2304, 3072

# Load the small pretrained base model by name.
model_small_pt = GPTNeoXForCausalLM.from_pretrained(small_name)

In [37]:
# Display the small base model's module tree as the cell output.
model_small_pt

GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 768)
    (emb_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-11): 12 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (post_attention_dropout): Dropout(p=0.0, inplace=False)
        (post_mlp_dropout): Dropout(p=0.0, inplace=False)
        (attention): GPTNeoXAttention(
          (rotary_emb): GPTNeoXRotaryEmbedding()
          (query_key_value): Linear(in_features=768, out_features=2304, bias=True)
          (dense): Linear(in_features=768, out_features=768, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=768, out_features=3072, bias=True)
          (dense_4h_to_h): Linear(in_features=3072, out_features=768, bias=True)
          

In [38]:
# Attach the fine-tuned LoRA adapter stored at `small_adapter` to the small base model.
model_small_ft = PeftModel.from_pretrained(model_small_pt, small_adapter)

In [39]:
# Display the adapter-wrapped small model; `query_key_value` now appears as a lora.Linear.
model_small_ft

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GPTNeoXForCausalLM(
      (gpt_neox): GPTNeoXModel(
        (embed_in): Embedding(50304, 768)
        (emb_dropout): Dropout(p=0.0, inplace=False)
        (layers): ModuleList(
          (0-11): 12 x GPTNeoXLayer(
            (input_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (post_attention_layernorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
            (post_attention_dropout): Dropout(p=0.0, inplace=False)
            (post_mlp_dropout): Dropout(p=0.0, inplace=False)
            (attention): GPTNeoXAttention(
              (rotary_emb): GPTNeoXRotaryEmbedding()
              (query_key_value): lora.Linear(
                (base_layer): Linear(in_features=768, out_features=2304, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default): 

In [40]:
# Load the large pretrained base model and display its module tree so the
# in_features / out_features of `query_key_value` can be read off; these are
# the target widths for the expanded LoRA matrices.

model_large_pt = GPTNeoXForCausalLM.from_pretrained(large_name)
model_large_pt

GPTNeoXForCausalLM(
  (gpt_neox): GPTNeoXModel(
    (embed_in): Embedding(50304, 1024)
    (emb_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-23): 24 x GPTNeoXLayer(
        (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
        (post_attention_dropout): Dropout(p=0.0, inplace=False)
        (post_mlp_dropout): Dropout(p=0.0, inplace=False)
        (attention): GPTNeoXAttention(
          (rotary_emb): GPTNeoXRotaryEmbedding()
          (query_key_value): Linear(in_features=1024, out_features=3072, bias=True)
          (dense): Linear(in_features=1024, out_features=1024, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (mlp): GPTNeoXMLP(
          (dense_h_to_4h): Linear(in_features=1024, out_features=4096, bias=True)
          (dense_4h_to_h): Linear(in_features=4096, out_features=1024, bias=True)
  

In [41]:
# expand module

def expand_lora_copy(old_module, new_in, new_out):
    """Expand a linear layer, filling the new slots with copies of old weights.

    Layout of the returned weight (shape ``(new_out, new_in)``):

    * top-left ``old_out x old_in`` block: the old weight, unchanged;
    * extra columns ``i >= old_in`` of the first ``old_out`` rows: copies of
      old column ``i % old_in``;
    * extra rows ``j >= old_out`` over the first ``old_in`` columns: copies of
      old row ``j % old_out``;
    * bottom-right corner (extra rows x extra columns): zeros.

    If the old module has a bias, the new bias holds the old bias followed by
    entries tiled with the same ``j % old_out`` rule used for the extra rows.

    Args:
        old_module: ``nn.Linear`` to expand.
        new_in: ``in_features`` of the expanded layer (>= old ``in_features``).
        new_out: ``out_features`` of the expanded layer (>= old ``out_features``).

    Returns:
        A new ``nn.Linear(new_in, new_out)``; ``old_module`` is not modified.

    Raises:
        ValueError: if a requested size is smaller than the old one.
    """
    old_in, old_out = old_module.in_features, old_module.out_features
    if new_in < old_in or new_out < old_out:
        raise ValueError(
            f"cannot expand Linear({old_in}, {old_out}) to a smaller "
            f"Linear({new_in}, {new_out})"
        )

    has_bias = old_module.bias is not None
    new_module = nn.Linear(new_in, new_out, bias=has_bias)
    index_device = old_module.weight.device

    # Every in-place write to the parameters (bias included) must happen under
    # no_grad: they are leaf tensors that require grad, and writing to a slice
    # of one outside no_grad raises a RuntimeError.
    with torch.no_grad():
        new_module.weight.zero_()
        new_module.weight[:old_out, :old_in].copy_(old_module.weight)

        if new_in > old_in:
            # Source column for each extra column, wrapping around the old width.
            src_cols = torch.arange(old_in, new_in, device=index_device) % old_in
            new_module.weight[:old_out, old_in:].copy_(old_module.weight[:, src_cols])

        if new_out > old_out:
            # Source row for each extra row, wrapping around the old height.
            src_rows = torch.arange(old_out, new_out, device=index_device) % old_out
            new_module.weight[old_out:, :old_in].copy_(old_module.weight[src_rows, :])

        if has_bias:
            new_module.bias[:old_out].copy_(old_module.bias)
            if new_out > old_out:
                new_module.bias[old_out:].copy_(old_module.bias[src_rows])

    return new_module


def expand_lora_padding(old_module, new_in, new_out):
    """Expand a linear layer by zero-padding its weight (and bias).

    The old weight is placed in the top-left ``old_out x old_in`` block of the
    new ``(new_out, new_in)`` weight and every other entry is zero. If the old
    module has a bias, the new bias is the old bias followed by zeros, so on a
    zero-padded input the first ``old_out`` outputs match the old layer and the
    remaining outputs are zero.

    Args:
        old_module: ``nn.Linear`` to expand.
        new_in: ``in_features`` of the expanded layer (>= old ``in_features``).
        new_out: ``out_features`` of the expanded layer (>= old ``out_features``).

    Returns:
        A new ``nn.Linear(new_in, new_out)``; ``old_module`` is not modified.

    Raises:
        ValueError: if a requested size is smaller than the old one.
    """
    old_in, old_out = old_module.in_features, old_module.out_features
    if new_in < old_in or new_out < old_out:
        raise ValueError(
            f"cannot expand Linear({old_in}, {old_out}) to a smaller "
            f"Linear({new_in}, {new_out})"
        )

    has_bias = old_module.bias is not None
    new_module = nn.Linear(new_in, new_out, bias=has_bias)

    # Every in-place write to the parameters (bias included) must happen under
    # no_grad: they are leaf tensors that require grad, and writing to a slice
    # of one outside no_grad raises a RuntimeError.
    with torch.no_grad():
        new_module.weight.zero_()
        new_module.weight[:old_out, :old_in].copy_(old_module.weight)

        if has_bias:
            # nn.Linear initialises the bias randomly; clear it so the padded
            # outputs really are zero.
            new_module.bias.zero_()
            new_module.bias[:old_out].copy_(old_module.bias)

    return new_module

def expand_lora_normal(old_module, new_in, new_out):
    """Expand a linear layer, randomly initialising the newly added entries.

    The whole new weight is first drawn with Kaiming-uniform initialisation
    (``a=sqrt(5)``); the top-left ``old_out x old_in`` block is then overwritten
    with the old weight. If the old module has a bias, its values are copied
    into the first ``old_out`` bias entries; the remaining entries keep the
    initial values ``nn.Linear`` assigned at construction.

    Args:
        old_module: ``nn.Linear`` to expand.
        new_in: ``in_features`` of the expanded layer (>= old ``in_features``).
        new_out: ``out_features`` of the expanded layer (>= old ``out_features``).

    Returns:
        A new ``nn.Linear(new_in, new_out)``; ``old_module`` is not modified.

    Raises:
        ValueError: if a requested size is smaller than the old one.
    """
    old_in, old_out = old_module.in_features, old_module.out_features
    if new_in < old_in or new_out < old_out:
        raise ValueError(
            f"cannot expand Linear({old_in}, {old_out}) to a smaller "
            f"Linear({new_in}, {new_out})"
        )

    has_bias = old_module.bias is not None
    new_module = nn.Linear(new_in, new_out, bias=has_bias)

    # Initialize with Kaiming-uniform initialization
    nn.init.kaiming_uniform_(new_module.weight, a=math.sqrt(5))

    # Every in-place write to the parameters (bias included) must happen under
    # no_grad: they are leaf tensors that require grad, and writing to a slice
    # of one outside no_grad raises a RuntimeError.
    with torch.no_grad():
        new_module.weight[:old_out, :old_in].copy_(old_module.weight)
        if has_bias:
            new_module.bias[:old_out].copy_(old_module.bias)

    return new_module

def expand_lora_noop_normal(old_module, new_in, new_out):
    """Build a fresh Kaiming-uniform linear layer of the requested size.

    Nothing is copied from ``old_module``; it is consulted only to decide
    whether the new layer carries a bias term.

    Args:
        old_module: reference ``nn.Linear`` (only its ``bias`` presence is used).
        new_in: ``in_features`` of the new layer.
        new_out: ``out_features`` of the new layer.

    Returns:
        A new ``nn.Linear(new_in, new_out)`` with Kaiming-uniform weights.
    """
    wants_bias = old_module.bias is not None
    fresh_layer = nn.Linear(new_in, new_out, bias=wants_bias)

    # Re-draw the weight explicitly with Kaiming-uniform (a = sqrt(5)).
    negative_slope = math.sqrt(5)
    nn.init.kaiming_uniform_(fresh_layer.weight, a=negative_slope)

    return fresh_layer

def expand_lora_noop_zero(old_module, new_in, new_out):
    """Build a linear layer of the requested size with an all-zero weight.

    Nothing is copied from ``old_module``; it is consulted only to decide
    whether the new layer carries a bias term. The bias, when present, keeps
    the values ``nn.Linear`` assigned at construction.

    Args:
        old_module: reference ``nn.Linear`` (only its ``bias`` presence is used).
        new_in: ``in_features`` of the new layer.
        new_out: ``out_features`` of the new layer.

    Returns:
        A new ``nn.Linear(new_in, new_out)`` whose weight is all zeros.
    """
    wants_bias = old_module.bias is not None
    blank_layer = nn.Linear(new_in, new_out, bias=wants_bias)

    # Zero the weight in place without recording the write in autograd.
    with torch.no_grad():
        blank_layer.weight.zero_()

    return blank_layer

In [42]:
# Wrap the large base model with a freshly initialised LoRA adapter on the
# `query_key_value` projections. The rank r=64 equals `new_out_a` / `new_in_b`
# from the configuration cell, and lora_dropout=0.05 matches the dropout shown
# in the small adapter's repr, so the expanded matrices fit these slots.

config_lora = LoraConfig(
    r=64, 
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=["query_key_value"],
    bias="none",
    task_type="CAUSAL_LM"
)
model_large_ft = get_peft_model(model_large_pt, config_lora)

In [43]:
# Display the adapter-wrapped large model before the adapter weights are replaced.
model_large_ft

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GPTNeoXForCausalLM(
      (gpt_neox): GPTNeoXModel(
        (embed_in): Embedding(50304, 1024)
        (emb_dropout): Dropout(p=0.0, inplace=False)
        (layers): ModuleList(
          (0-23): 24 x GPTNeoXLayer(
            (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (post_attention_dropout): Dropout(p=0.0, inplace=False)
            (post_mlp_dropout): Dropout(p=0.0, inplace=False)
            (attention): GPTNeoXAttention(
              (rotary_emb): GPTNeoXRotaryEmbedding()
              (query_key_value): lora.Linear(
                (base_layer): Linear(in_features=1024, out_features=3072, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (defaul

In [44]:
# Transplant the small model's LoRA matrices into the large model.
# The feature sizes and the expansion method are chosen here: each entry is
# (module-name suffix, target in_features, target out_features), and
# `expand_lora_padding` below can be swapped for another expand_lora_* function.
expansion_targets = (
    ("lora_A", new_in_a, new_out_a),
    ("lora_B", new_in_b, new_out_b),
)

# NOTE(review): only module paths that exist in the small model are visited.
# The small model has 12 layers and the large one 24, so the large model's
# remaining layers keep their freshly initialised LoRA weights -- confirm this
# is intended.
for name, module in model_small_ft.named_modules():
    matched = next(
        (spec for spec in expansion_targets if name.endswith(spec[0])),
        None,
    )
    if matched is None:
        continue

    suffix, target_in, target_out = matched
    expanded = expand_lora_padding(module.default, target_in, target_out)

    # Walk the same dotted path inside the large model to reach the owner of
    # the lora_A / lora_B container, then replace that container wholesale.
    *owner_path, leaf = name.split('.')
    owner = model_large_ft
    for attr in owner_path:
        owner = getattr(owner, attr)

    setattr(owner, leaf, nn.ModuleDict({"default": expanded}))

In [45]:
# Display the large model again after its lora_A / lora_B modules were replaced.
model_large_ft

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GPTNeoXForCausalLM(
      (gpt_neox): GPTNeoXModel(
        (embed_in): Embedding(50304, 1024)
        (emb_dropout): Dropout(p=0.0, inplace=False)
        (layers): ModuleList(
          (0-23): 24 x GPTNeoXLayer(
            (input_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (post_attention_layernorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
            (post_attention_dropout): Dropout(p=0.0, inplace=False)
            (post_mlp_dropout): Dropout(p=0.0, inplace=False)
            (attention): GPTNeoXAttention(
              (rotary_emb): GPTNeoXRotaryEmbedding()
              (query_key_value): lora.Linear(
                (base_layer): Linear(in_features=1024, out_features=3072, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (defaul

In [46]:
# Save the expanded model to the directory named by `expanded_model`.
# NOTE(review): for a PEFT-wrapped model this presumably writes only the
# adapter weights and config, not the base model -- confirm.

model_large_ft.save_pretrained(expanded_model)

In [24]:
# Naming template that the `expanded_model` output directory follows:
# f"./weight/pythia_{args.small_model}_{args.large_model}_{args.expand_method}_r=64_schedule/"