In [1]:
import re
def _compile_target_modules(target_modules):
    if target_modules is None:
        return None
    if isinstance(target_modules, str):
        target_modules = [target_modules]
    
    compiled_patterns = []
    for pattern in target_modules:
        if pattern.startswith("re:"):
            # If the pattern starts with 're:', compile it as a regular expression
            compiled_patterns.append(re.compile(pattern[3:], re.IGNORECASE))
        else:
            # Compile it to match the pattern anywhere in the string
            escaped_pattern = re.escape(pattern).replace(r'\*', '.*')
            compiled_patterns.append(re.compile(f".*{escaped_pattern}.*", re.IGNORECASE))
    
    return compiled_patterns

In [2]:
# Parameter-name fragments to select; the recorded output of the counting
# cell shows the first four under `self_attn` and the last three under `mlp`.
target_modules = [
    "q_proj.weight",
    "k_proj.weight",
    "v_proj.weight",
    "o_proj.weight",
    "gate_proj.weight",
    "up_proj.weight",
    "down_proj.weight",
]
# Pre-compile once so later cells can reuse the regex objects.
compiled_patterns = _compile_target_modules(target_modules)

In [3]:
from transformers import AutoModelForCausalLM, AutoTokenizer

# Single source of truth for the checkpoint id and the local download cache,
# so the model and the tokenizer loaded below cannot drift apart.
MODEL_NAME = "Qwen/Qwen2-0.5B-Instruct"
CACHE_DIR = "./cache"

# Load the causal LM; dtype and device placement are left to the library via
# the "auto" settings.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype="auto",
    device_map="auto",
    cache_dir=CACHE_DIR,
)
# Load the tokenizer from the same checkpoint, requesting the fast implementation.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True, cache_dir=CACHE_DIR)

In [4]:
# Walk every named parameter of the model, print the ones whose name is hit by
# at least one compiled pattern, and add up their element counts.
total_target_params = 0
for name, param in model.named_parameters():
    if not any(pattern.search(name) for pattern in compiled_patterns):
        continue  # name matches none of the target patterns
    print(name)
    total_target_params = total_target_params + param.numel()
print(f"Total number of candidate parameters for pruning: {total_target_params:,}")

model.layers.0.self_attn.q_proj.weight
model.layers.0.self_attn.k_proj.weight
model.layers.0.self_attn.v_proj.weight
model.layers.0.self_attn.o_proj.weight
model.layers.0.mlp.gate_proj.weight
model.layers.0.mlp.up_proj.weight
model.layers.0.mlp.down_proj.weight
model.layers.1.self_attn.q_proj.weight
model.layers.1.self_attn.k_proj.weight
model.layers.1.self_attn.v_proj.weight
model.layers.1.self_attn.o_proj.weight
model.layers.1.mlp.gate_proj.weight
model.layers.1.mlp.up_proj.weight
model.layers.1.mlp.down_proj.weight
model.layers.2.self_attn.q_proj.weight
model.layers.2.self_attn.k_proj.weight
model.layers.2.self_attn.v_proj.weight
model.layers.2.self_attn.o_proj.weight
model.layers.2.mlp.gate_proj.weight
model.layers.2.mlp.up_proj.weight
model.layers.2.mlp.down_proj.weight
model.layers.3.self_attn.q_proj.weight
model.layers.3.self_attn.k_proj.weight
model.layers.3.self_attn.v_proj.weight
model.layers.3.self_attn.o_proj.weight
model.layers.3.mlp.gate_proj.weight
model.layers.3.mlp.up_

In [5]:
from datasets import load_dataset
from multiprocessing import cpu_count

# Load the "train_sft" split of UltraChat 200k, caching downloads under ./cache.
# Deliberately no `trust_remote_code=True`: that flag allows a loading script
# hosted in the Hub repository to be executed on this machine.
# NOTE(review): assumes this dataset ships plain data files and needs no
# loading script — confirm against the dataset repository.
raw_dataset = load_dataset(
    "HuggingFaceH4/ultrachat_200k",
    cache_dir="./cache",
    split="train_sft",
)

In [6]:
def apply_chat_template(messages, tokenizer):
    """Attach the chat-templated rendering of ``messages["messages"]``.

    Args:
        messages: Mapping with a ``"messages"`` entry; it is modified in
            place by adding a ``"text"`` entry.
        tokenizer: Object exposing ``apply_chat_template(..., tokenize=False)``.

    Returns:
        The same ``messages`` mapping, now containing ``"text"``.
    """
    # tokenize=False is passed so the tokenizer is asked for text, not token ids.
    rendered = tokenizer.apply_chat_template(messages["messages"], tokenize=False)
    messages["text"] = rendered
    return messages

# Render each conversation to templated text in batches of 1000. Every
# original column is listed for removal, leaving the "text" column that
# `apply_chat_template` adds.
chat_dataset = raw_dataset.map(
    apply_chat_template,
    fn_kwargs={"tokenizer": tokenizer},
    remove_columns=raw_dataset.column_names,
    batched=True,
    batch_size=1000,
    num_proc=cpu_count(),  # as many worker processes as cpu_count() reports
    desc="Applying chat template",
)

In [7]:
def tokenize(messages, tokenizer):
    """Run ``tokenizer`` on the ``"text"`` entry of ``messages``.

    Args:
        messages: Mapping with a ``"text"`` entry.
        tokenizer: Callable applied to ``messages["text"]``.

    Returns:
        Whatever ``tokenizer`` returns for that text.
    """
    encoded = tokenizer(messages["text"])
    return encoded

# Tokenize the templated text in batches of 1000. All existing columns of
# `chat_dataset` are listed for removal, so only what `tokenize` returns is kept.
tokenized_dataset = chat_dataset.map(
    tokenize,
    fn_kwargs={"tokenizer": tokenizer},
    remove_columns=chat_dataset.column_names,
    batched=True,
    batch_size=1000,
    num_proc=cpu_count(),  # as many worker processes as cpu_count() reports
    desc="Tokenizing",
)

In [8]:
from trl import DataCollatorForCompletionOnlyLM

# Turn markers used to locate assistant and user turns in the templated text.
# Per the original author's note, this setup only works for Qwen2 when the
# conversation includes a system prompt.
# NOTE(review): presumably because both markers start with "\n" and therefore
# need preceding text — verify against the rendered template.
response_template = "\n<|im_start|>assistant\n"
instruction_template = "\n<|im_start|>user\n"

collator = DataCollatorForCompletionOnlyLM(
    tokenizer=tokenizer,
    response_template=response_template,
    instruction_template=instruction_template,
    mlm=False,
)

In [13]:
from torch.utils.data import DataLoader

# Serve the tokenized dataset in its stored order (shuffle disabled), eight
# examples per batch, with `collator` assembling each batch.
train_dataloader = DataLoader(
    tokenized_dataset,
    batch_size=8,
    shuffle=False,
    collate_fn=collator,
    pin_memory=True,
)

In [14]:
# Sanity check: take only the first batch from the loader and print the shape
# of its "input_ids" entry, then stop iterating.
for batch in train_dataloader:
    print(batch["input_ids"].shape)
    break  # one batch is enough for the shape check

torch.Size([8, 3026])
