In [34]:
from transformers import GPTNeoXForCausalLM
from peft import PeftModel, LoraConfig, get_peft_model
import torch
import torch.nn as nn

In [35]:
# Load the base small model and attach the fine-tuned adapter to it.

small_name = "EleutherAI/pythia-70m"  # model id handed to from_pretrained for the base weights
small_adapter = "weight/pythia_70m_mora_r=64"  # relative path of the saved adapter (MoRA, r=64 per the directory name)

# model_small_pt is the plain pre-trained model; model_small_ft is that same
# model object wrapped by PEFT with the adapter loaded from `small_adapter`.
model_small_pt = GPTNeoXForCausalLM.from_pretrained(small_name)
model_small_ft = PeftModel.from_pretrained(model_small_pt, small_adapter)

In [36]:
# Display the module tree of the adapter-wrapped model (before any expansion).
model_small_ft

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GPTNeoXForCausalLM(
      (gpt_neox): GPTNeoXModel(
        (embed_in): Embedding(50304, 512)
        (emb_dropout): Dropout(p=0.0, inplace=False)
        (layers): ModuleList(
          (0-5): 6 x GPTNeoXLayer(
            (input_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (post_attention_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (post_attention_dropout): Dropout(p=0.0, inplace=False)
            (post_mlp_dropout): Dropout(p=0.0, inplace=False)
            (attention): GPTNeoXSdpaAttention(
              (rotary_emb): GPTNeoXRotaryEmbedding()
              (query_key_value): lora.Linear(
                (base_layer): Linear(in_features=512, out_features=1536, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default)

In [37]:
# config_mora = LoraConfig(
#     use_mora=True, 
#     mora_type=6,  # RoPE for small rank
#     r=64, 
#     target_modules=["query_key_value"], 
#     lora_dropout=0.05, 
#     task_type="CAUSAL_LM"
#     # MoRA does not use lora_alpha
# )
# model = GPTNeoXForCausalLM.from_pretrained(
#     "EleutherAI/pythia-410m"
# )
# model = get_peft_model(model, config_mora)

In [38]:
# Helpers that expand an adapter nn.Linear to a larger (in_features, out_features) shape

def expand_mora_copy(old_module, new_in, new_out):
    """Grow a linear layer to ``new_in -> new_out`` by repeating its weights.

    The old weight matrix is placed in the top-left corner of the new one.
    Added output rows repeat the old rows cyclically (over the old input
    columns only) and added input columns repeat the old columns cyclically
    (over the old output rows only). The block where both the row and the
    column are new stays zero.

    Args:
        old_module: ``nn.Linear`` to expand.
        new_in: ``in_features`` of the returned layer; must be >= the old value.
        new_out: ``out_features`` of the returned layer; must be >= the old value.

    Returns:
        A new ``nn.Linear(new_in, new_out)`` on the same device and with the
        same dtype as ``old_module``. If ``old_module`` has a bias, the old
        bias fills the first ``old_module.out_features`` entries and the rest
        are zero.

    Raises:
        ValueError: if ``new_in`` or ``new_out`` is smaller than the old size.
    """
    old_in, old_out = old_module.in_features, old_module.out_features
    if new_in < old_in or new_out < old_out:
        raise ValueError(
            f"cannot shrink a ({old_in} -> {old_out}) layer to ({new_in} -> {new_out})"
        )

    has_bias = old_module.bias is not None
    old_weight = old_module.weight
    new_module = nn.Linear(
        new_in,
        new_out,
        bias=has_bias,
        device=old_weight.device,
        dtype=old_weight.dtype,
    )

    # All writes below are in-place edits of leaf parameters, so they must run
    # with autograd disabled (bias included).
    with torch.no_grad():
        new_module.weight.zero_()

        # nn.Linear stores weight as (out_features, in_features):
        # rows index outputs, columns index inputs.
        new_module.weight[:old_out, :old_in].copy_(old_weight)

        if new_out > old_out:
            # Extra output rows cycle through the old rows.
            row_src = torch.arange(old_out, new_out, device=old_weight.device) % old_out
            new_module.weight[old_out:, :old_in].copy_(old_weight[row_src])

        if new_in > old_in:
            # Extra input columns cycle through the old columns.
            col_src = torch.arange(old_in, new_in, device=old_weight.device) % old_in
            new_module.weight[:old_out, old_in:].copy_(old_weight[:, col_src])

        if has_bias:
            # nn.Linear initialises bias randomly; clear it so that only the
            # copied entries are non-zero.
            new_module.bias.zero_()
            new_module.bias[:old_out].copy_(old_module.bias)

    return new_module


def expand_mora_padding(old_module, new_in, new_out):
    """Grow a linear layer to ``new_in -> new_out`` by zero-padding.

    The old weight matrix is placed in the top-left corner of the new one and
    every other weight entry is zero, so the new layer applied to an input
    whose first ``old_module.in_features`` entries hold the old input
    reproduces the old output in its first ``old_module.out_features`` entries.

    Args:
        old_module: ``nn.Linear`` to expand.
        new_in: ``in_features`` of the returned layer; must be >= the old value.
        new_out: ``out_features`` of the returned layer; must be >= the old value.

    Returns:
        A new ``nn.Linear(new_in, new_out)`` on the same device and with the
        same dtype as ``old_module``. If ``old_module`` has a bias, the old
        bias fills the first ``old_module.out_features`` entries and the rest
        are zero.

    Raises:
        ValueError: if ``new_in`` or ``new_out`` is smaller than the old size.
    """
    old_in, old_out = old_module.in_features, old_module.out_features
    if new_in < old_in or new_out < old_out:
        raise ValueError(
            f"cannot shrink a ({old_in} -> {old_out}) layer to ({new_in} -> {new_out})"
        )

    has_bias = old_module.bias is not None
    new_module = nn.Linear(
        new_in,
        new_out,
        bias=has_bias,
        device=old_module.weight.device,
        dtype=old_module.weight.dtype,
    )

    # In-place edits of leaf parameters must run with autograd disabled; this
    # includes the bias copy.
    with torch.no_grad():
        new_module.weight.zero_()
        # nn.Linear stores weight as (out_features, in_features).
        new_module.weight[:old_out, :old_in].copy_(old_module.weight)

        if has_bias:
            # Clear nn.Linear's random bias init so the padded tail is zero.
            new_module.bias.zero_()
            new_module.bias[:old_out].copy_(old_module.bias)

    return new_module

In [39]:
# Shape every adapter projection is expanded to.
# NOTE(review): 512 is presumably the adapter size expected by the model this
# adapter will be loaded into later — confirm.
new_in, new_out = 512, 512

# Collect the lora_A / lora_B containers up front so the module tree is not
# being walked while its entries are swapped out.
adapter_containers = [
    (name, module)
    for name, module in model_small_ft.named_modules()
    if name.endswith(("lora_A", "lora_B"))
]

for name, module in adapter_containers:
    # Split "path.to.parent.lora_A" into the parent path and the attribute name.
    parent_path, _, attr_name = name.rpartition(".")
    parent_module = model_small_ft.get_submodule(parent_path)

    # Only the "default" adapter is expanded; the container is replaced by a
    # fresh ModuleDict holding just that expanded layer.
    new_module = expand_mora_padding(module.default, new_in, new_out)
    setattr(parent_module, attr_name, nn.ModuleDict({"default": new_module}))

In [40]:
# Save the expanded model to a local directory. The "padding" in the directory
# name matches the expand_mora_padding helper used in the expansion loop.
# NOTE(review): save_pretrained on a PEFT wrapper presumably writes only the
# adapter weights and config, not the base model — confirm.

model_small_ft.save_pretrained("weight/pythia_70m_mora_expanded_padding_r=64")

In [41]:
# Display the module tree again, after the lora_A / lora_B replacement above.
model_small_ft

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): GPTNeoXForCausalLM(
      (gpt_neox): GPTNeoXModel(
        (embed_in): Embedding(50304, 512)
        (emb_dropout): Dropout(p=0.0, inplace=False)
        (layers): ModuleList(
          (0-5): 6 x GPTNeoXLayer(
            (input_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (post_attention_layernorm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
            (post_attention_dropout): Dropout(p=0.0, inplace=False)
            (post_mlp_dropout): Dropout(p=0.0, inplace=False)
            (attention): GPTNeoXSdpaAttention(
              (rotary_emb): GPTNeoXRotaryEmbedding()
              (query_key_value): lora.Linear(
                (base_layer): Linear(in_features=512, out_features=1536, bias=True)
                (lora_dropout): ModuleDict(
                  (default): Dropout(p=0.05, inplace=False)
                )
                (lora_A): ModuleDict(
                  (default)