In [1]:
from transformers import GPTNeoXForCausalLM
from peft import PeftModel, LoraConfig, get_peft_model
import torch
import torch.nn as nn

import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class args:
    # Base checkpoints: the small (source) and large (target) models.
    small_model = "EleutherAI/pythia-410m"
    large_model = "EleutherAI/pythia-1.4b"

    # LoRA hyperparameters passed to LoraConfig for the large model.
    rank = 8
    lora_alpha = 32

    # Adapter loaded into the small model, and where the expanded model is saved.
    small_adapter = "./models/raw/pythia_410m_r=8_0.0001_gsm8k"
    expanded_model = "./models/expanded/pythia_410m_1.4b_r=8_0.0001_gsm8k"


In [3]:
# Small model: load the base checkpoint, then attach the adapter stored at
# args.small_adapter.
model_small = GPTNeoXForCausalLM.from_pretrained(args.small_model)
model_small.load_adapter(args.small_adapter)

# Large model: base checkpoint only; its adapter is attached in a later cell.
model_large = GPTNeoXForCausalLM.from_pretrained(args.large_model)

In [12]:
# Report every LoRA parameter of the small model that equals neither the first
# layer's lora_A weight nor the first layer's lora_B weight. Nothing is
# asserted; mismatching parameters are only printed.

first_layer_qkv = model_small.gpt_neox.layers[0].attention.query_key_value
first_layer_lora_A = first_layer_qkv.lora_A.default.weight
first_layer_lora_B = first_layer_qkv.lora_B.default.weight

print(f"First layer A: {first_layer_lora_A}")
print(f"First layer B: {first_layer_lora_B}")

for name, param in model_small.named_parameters():
    # Only LoRA parameters are of interest.
    if "lora_" not in name:
        continue

    # A parameter counts as tied when it matches either reference tensor.
    if torch.equal(param, first_layer_lora_A) or torch.equal(param, first_layer_lora_B):
        continue

    print(f"Weights are not tied for {name}!")
    print(param)

First layer A: Parameter containing:
tensor([[-0.0450,  0.0096, -0.0191,  ..., -0.0397,  0.0226, -0.0111],
        [-0.0042,  0.0038, -0.0143,  ...,  0.0165,  0.0095,  0.0102],
        [ 0.0273,  0.0124,  0.0265,  ...,  0.0227,  0.0082, -0.0271],
        ...,
        [ 0.0103,  0.0151,  0.0282,  ...,  0.0160, -0.0006,  0.0154],
        [ 0.0276, -0.0378,  0.0052,  ...,  0.0096,  0.0358, -0.0144],
        [ 0.0390, -0.0173,  0.0081,  ...,  0.0212,  0.0357,  0.0103]])
First layer B: Parameter containing:
tensor([[ 0.0136,  0.0021, -0.0162,  ..., -0.0092,  0.0207,  0.0147],
        [-0.0034,  0.0105,  0.0107,  ...,  0.0013, -0.0084, -0.0162],
        [-0.0090, -0.0129, -0.0078,  ...,  0.0048,  0.0092, -0.0026],
        ...,
        [-0.0036,  0.0067,  0.0023,  ..., -0.0013,  0.0042, -0.0018],
        [ 0.0036,  0.0010, -0.0139,  ..., -0.0004,  0.0097,  0.0135],
        [-0.0016,  0.0130, -0.0097,  ..., -0.0022,  0.0044, -0.0125]])
Weights are not tied for gpt_neox.layers.1.attention.query

In [4]:
# copy with same rank

# LoRA settings for the large model: rank and alpha come from `args`, and only
# the fused query_key_value projection is targeted.
config = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=["query_key_value"],
    r=args.rank,
    lora_alpha=args.lora_alpha,
    lora_dropout=0.05,
    bias="none",
)

# Rebinds model_large to the PEFT-wrapped model.
model_large = get_peft_model(model_large, config)

In [5]:
# new_in_a, new_out_a = 1024, args.rank   # 512, 768, 1024
# new_in_b, new_out_b = args.rank, 3072  # ? , 2304, 3072


In [6]:
# Display the module tree of the LoRA-wrapped large model to check the
# in_features and out_features of the query_key_value (qkv) projection.

model_large

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GPTNeoXForCausalLM(
      (gpt_neox): GPTNeoXModel(
        (embed_in): Embedding(50304, 2048)
        (emb_dropout): Dropout(p=0.0, inplace=False)
        (layers): ModuleList(
          (0-23): 24 x GPTNeoXLayer(
            (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
            (post_attention_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
            (post_attention_dropout): Dropout(p=0.0, inplace=False)
            (post_mlp_dropout): Dropout(p=0.0, inplace=False)
            (attention): GPTNeoXAttention(
              (rotary_emb): GPTNeoXRotaryEmbedding()
              (query_key_value): lora.Linear(
                (base_layer): Linear(in_features=2048, out_features=6144, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (defaul

In [7]:
# expand module

def expand_lora_copy(old_module, new_in, new_out):
    """Expand a linear layer by cyclically repeating its existing weights.

    Layout of the new weight (rows = outputs, columns = inputs):

    * top-left ``(old_out, old_in)`` block: the original weight;
    * added columns of the original rows: original columns, wrapped around;
    * added rows of the original columns: original rows, wrapped around;
    * bottom-right block (added rows AND added columns): left at zero.

    If the layer has a bias, the added bias entries wrap around the original
    bias in the same way as the added rows.

    Args:
        old_module: Linear layer exposing ``weight``, ``bias``,
            ``in_features`` and ``out_features``.
        new_in: Input feature count of the expanded layer.
        new_out: Output feature count of the expanded layer.

    Returns:
        A new ``nn.Linear(new_in, new_out)`` that has a bias iff
        ``old_module`` has one.

    Raises:
        ValueError: If ``new_in`` or ``new_out`` is smaller than the
            corresponding size of ``old_module``.
    """
    old_in, old_out = old_module.in_features, old_module.out_features
    if new_in < old_in or new_out < old_out:
        raise ValueError(
            f"cannot expand a ({old_in} -> {old_out}) layer to ({new_in} -> {new_out}): "
            "target sizes must not be smaller than the source sizes"
        )

    has_bias = old_module.bias is not None
    new_module = nn.Linear(new_in, new_out, bias=has_bias)

    # All in-place writes to the parameters must happen under no_grad: they are
    # leaf tensors that require grad, so writing into a slice of them otherwise
    # raises a RuntimeError.
    with torch.no_grad():
        old_weight = old_module.weight
        new_module.weight.zero_()
        new_module.weight[:old_out, :old_in].copy_(old_weight)

        # Source index for every added column / row, wrapped into the old range.
        wrapped_cols = torch.arange(old_in, new_in, device=old_weight.device) % old_in
        wrapped_rows = torch.arange(old_out, new_out, device=old_weight.device) % old_out

        new_module.weight[:old_out, old_in:].copy_(old_weight[:, wrapped_cols])
        new_module.weight[old_out:, :old_in].copy_(old_weight[wrapped_rows, :])

        if has_bias:
            new_module.bias[:old_out].copy_(old_module.bias)
            new_module.bias[old_out:].copy_(old_module.bias[wrapped_rows])

    return new_module


def expand_lora_padding(old_module, new_in, new_out):
    """Expand a linear layer by zero-padding its weight (and bias).

    The original weight is placed in the top-left ``(old_out, old_in)`` block
    of the new weight; every other entry is zero. If the layer has a bias, the
    original bias fills the first ``old_out`` entries and the rest are zero, so
    the added outputs start out as exactly zero.

    Args:
        old_module: Linear layer exposing ``weight``, ``bias``,
            ``in_features`` and ``out_features``.
        new_in: Input feature count of the expanded layer.
        new_out: Output feature count of the expanded layer.

    Returns:
        A new ``nn.Linear(new_in, new_out)`` that has a bias iff
        ``old_module`` has one.

    Raises:
        ValueError: If ``new_in`` or ``new_out`` is smaller than the
            corresponding size of ``old_module``.
    """
    old_in, old_out = old_module.in_features, old_module.out_features
    if new_in < old_in or new_out < old_out:
        raise ValueError(
            f"cannot expand a ({old_in} -> {old_out}) layer to ({new_in} -> {new_out}): "
            "target sizes must not be smaller than the source sizes"
        )

    has_bias = old_module.bias is not None
    new_module = nn.Linear(new_in, new_out, bias=has_bias)

    # In-place writes into slices of the parameters must run under no_grad;
    # outside it, writing into a view of a leaf that requires grad raises.
    with torch.no_grad():
        new_module.weight.zero_()
        new_module.weight[:old_out, :old_in].copy_(old_module.weight)

        if has_bias:
            new_module.bias.zero_()
            new_module.bias[:old_out].copy_(old_module.bias)

    return new_module

def expand_lora_normal(old_module, new_in, new_out):
    """Expand a linear layer, filling the added weights with Kaiming-uniform noise.

    The whole new weight is first initialised with
    ``kaiming_uniform_(a=sqrt(5))``; the top-left ``(old_out, old_in)`` block
    is then overwritten with the original weight. If the layer has a bias, the
    original bias is copied into the first ``old_out`` entries and the
    remaining entries keep the initialisation done by the ``nn.Linear``
    constructor.

    Args:
        old_module: Linear layer exposing ``weight``, ``bias``,
            ``in_features`` and ``out_features``.
        new_in: Input feature count of the expanded layer.
        new_out: Output feature count of the expanded layer.

    Returns:
        A new ``nn.Linear(new_in, new_out)`` that has a bias iff
        ``old_module`` has one.

    Raises:
        ValueError: If ``new_in`` or ``new_out`` is smaller than the
            corresponding size of ``old_module``.
    """
    old_in, old_out = old_module.in_features, old_module.out_features
    if new_in < old_in or new_out < old_out:
        raise ValueError(
            f"cannot expand a ({old_in} -> {old_out}) layer to ({new_in} -> {new_out}): "
            "target sizes must not be smaller than the source sizes"
        )

    has_bias = old_module.bias is not None
    new_module = nn.Linear(new_in, new_out, bias=has_bias)

    # Initialize with Kaiming-uniform initialization
    nn.init.kaiming_uniform_(new_module.weight, a=math.sqrt(5))

    # In-place writes into slices of the parameters must run under no_grad;
    # outside it, writing into a view of a leaf that requires grad raises.
    with torch.no_grad():
        new_module.weight[:old_out, :old_in].copy_(old_module.weight)

        if has_bias:
            new_module.bias[:old_out].copy_(old_module.bias)

    return new_module

def expand_lora_noop_normal(old_module, new_in, new_out):
    """Build a fresh ``(new_in -> new_out)`` linear layer without reusing old weights.

    Only the presence of a bias is taken from ``old_module``; none of its
    values are copied. The weight is (re-)initialised with
    ``kaiming_uniform_(a=sqrt(5))``.

    Args:
        old_module: Layer whose ``bias`` attribute decides whether the new
            layer gets a bias.
        new_in: Input feature count of the new layer.
        new_out: Output feature count of the new layer.

    Returns:
        The newly created ``nn.Linear``.
    """
    wants_bias = old_module.bias is not None
    fresh = nn.Linear(in_features=new_in, out_features=new_out, bias=wants_bias)

    # Kaiming-uniform draw for the weight.
    negative_slope = math.sqrt(5)
    nn.init.kaiming_uniform_(fresh.weight, a=negative_slope)
    return fresh

def expand_lora_noop_zero(old_module, new_in, new_out):
    """Build a fresh ``(new_in -> new_out)`` linear layer with an all-zero weight.

    Only the presence of a bias is taken from ``old_module``; none of its
    values are copied. A bias, if present, keeps the initialisation done by
    the ``nn.Linear`` constructor.

    Args:
        old_module: Layer whose ``bias`` attribute decides whether the new
            layer gets a bias.
        new_in: Input feature count of the new layer.
        new_out: Output feature count of the new layer.

    Returns:
        The newly created ``nn.Linear``.
    """
    wants_bias = old_module.bias is not None
    fresh = nn.Linear(in_features=new_in, out_features=new_out, bias=wants_bias)

    # Zero the weight in place without recording the write in autograd.
    with torch.no_grad():
        fresh.weight.zero_()
    return fresh

In [8]:
# # Get the LoRA adapter for larger model

# config_lora = LoraConfig(
#     r=64, 
#     lora_alpha=32,
#     lora_dropout=0.05,
#     target_modules=["query_key_value"],
#     bias="none",
#     task_type="CAUSAL_LM"
# )
# model_large_ft = get_peft_model(model_large_pt, config_lora)

In [9]:
# new_in_a = model_large.config.hidden_size

In [10]:
# specify the feature sizes and expansion method here


# Walk the small model's LoRA containers and install an expanded copy of each
# one at the same dotted path inside the large model.
for name, module in model_small.named_modules():
    # Pick the (in_features, out_features) of the expanded layer:
    #   lora_A maps hidden_size -> rank, lora_B maps rank -> 3 * hidden_size.
    if name.endswith("lora_A"):
        target_in, target_out = model_large.config.hidden_size, args.rank
    elif name.endswith("lora_B"):
        target_in, target_out = args.rank, model_large.config.hidden_size * 3
    else:
        continue

    new_module = expand_lora_padding(module.default, target_in, target_out)

    # Resolve the parent of the container in the large model by following the
    # same attribute path, one component at a time.
    *parent_path, leaf_name = name.split('.')
    parent_module = model_large
    for attr in parent_path:
        parent_module = getattr(parent_module, attr)

    # Replace the whole container with one holding only the "default" adapter.
    setattr(parent_module, leaf_name, nn.ModuleDict({"default": new_module}))

In [None]:
# unwrapped_model = accelerator.unwrap_model(model)
#             unwrapped_model.save_pretrained(
#                 output,
#                 is_main_process=accelerator.is_main_process,
#                 save_function=accelerator.save,
#             )

In [14]:
# Save the expanded model to the directory given by args.expanded_model.
# NOTE(review): model_large is the object returned by get_peft_model, so this
# presumably writes only the adapter weights/config, not the base model — confirm.

model_large.save_pretrained(args.expanded_model)

In [24]:
# f"./weight/pythia_{args.small_model}_{args.large_model}_{args.expand_method}_r=64_schedule/"