In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys
sys.path.append("../LLM-Pruner/")

In [3]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from trl import SFTTrainer
import itertools
import pandas as pd



In [4]:
import os
from os import listdir

In [5]:
from prunning_utils import get_mlps

In [6]:
import copy
from datasets import load_dataset

In [7]:
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

## Model and Tokenizer Setup

In [8]:
model_id = "microsoft/phi-1_5"
model_revision = "349cf8b5e81fd5f791d1740da5de1313a0419bbd" # latest as of feb 1st

In [9]:
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

In [10]:
vocab = tokenizer.get_vocab()
len(vocab)

50295

In [11]:
# tokenizer.decode(token_info.get_prefixes(top_tokens[1000][0], 9, 10)[0])

In [12]:
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=model_revision,
    trust_remote_code=True,
    # be careful with this?
    # torch_dtype=torch.float16,
    # attn_implementation="flash_attention_2",
)

In [13]:
# model = model.cuda()

In [14]:
model

PhiForCausalLM(
  (model): PhiModel(
    (embed_tokens): Embedding(51200, 2048)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-23): 24 x PhiDecoderLayer(
        (self_attn): PhiAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (dense): Linear(in_features=2048, out_features=2048, bias=True)
          (rotary_emb): PhiRotaryEmbedding()
        )
        (mlp): PhiMLP(
          (activation_fn): NewGELUActivation()
          (fc1): Linear(in_features=2048, out_features=8192, bias=True)
          (fc2): Linear(in_features=8192, out_features=2048, bias=True)
        )
        (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (resid_dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (final_layernorm): LayerNorm((2048,), e

## Dataset

In [15]:
from LLMPruner.datasets.example_samples import get_examples

In [16]:
# examples = get_examples("bookcorpus", tokenizer, n_samples=10)
# pd.to_pickle(examples.cpu(), "./examples.pkl")
examples = pd.read_pickle("./examples.pkl")

In [17]:
# examples

In [18]:
# llm_prunner_dataset = load_dataset("yahma/alpaca-cleaned")

In [19]:
# llm_prunner_dataset["train"][6]

In [20]:
# train_val = data["train"].train_test_split(
#     test_size=2000 shuffle=True, seed=42
# )
# train_data = (
#     train_val["train"].shuffle().map(generate_and_tokenize_prompt)
# )
# val_data = {
#     args.data_path: train_val["test"].shuffle().map(generate_and_tokenize_prompt),
# }

In [1]:
import numpy as np
from datasets import Dataset
import os

  from .autonotebook import tqdm as notebook_tqdm


In [22]:
def get_baseline_dataset(filename="./baseline_dataset.pkl"):
    if os.path.isfile(filename):
        print("reading pickle")
        return pd.read_pickle(filename)
    dataset = load_dataset("nampdn-ai/tiny-textbooks")
    
    np.random.seed(123)
    data_idxs = np.random.permutation(np.arange(len(dataset["train"])))[:52000]
    
    train_data = dataset["train"]
    train_data_pd = train_data.to_pandas()
    train_data_pd = train_data_pd.iloc[data_idxs]
    train_data_pd = train_data_pd.reset_index(drop=True)
    dataset = Dataset.from_pandas(train_data_pd)
    dataset = dataset.train_test_split(test_size=2000, shuffle=True, seed=123)
    pd.to_pickle(dataset, "./baseline_dataset.pkl")
    return dataset

## Prunning

In [23]:
def get_lm_prunner_style_importances(model):
    mlps = get_mlps(model)
    imps = {}
    imps_list = pd.read_pickle("./imps_llm_prunner_style.pkl")
    for mlp, imp in zip(mlps, imps_list):
        imps[mlp] = imp
    return imps

In [24]:
@torch.no_grad()
def prune_mlp(mlp, importances, prune_ratio):
    sorted_imps_idx = torch.argsort(importances) # sorts from least to most important
    num_prune_cells = int(sorted_imps_idx.shape[0] * prune_ratio)
    keep_cells = sorted_imps_idx[num_prune_cells:]
    keep_cells = torch.sort(keep_cells).values

    fc1 = mlp.fc1
    dtype = fc1.weight.dtype
    fc1_pruned = torch.nn.Linear(fc1.weight.shape[1], keep_cells.shape[0], dtype=dtype)
    with torch.no_grad():
        fc1_pruned.weight.data = torch.clone(fc1.weight[keep_cells])
        fc1_pruned.bias.data = torch.clone(fc1.bias[keep_cells])

    fc2 = mlp.fc2
    fc2_pruned = torch.nn.Linear(keep_cells.shape[0], fc2.weight.shape[0], dtype=dtype)
    with torch.no_grad():
        fc2_pruned.weight.data = torch.clone(fc2.weight[:, keep_cells])
    
    mlp.fc1 = fc1_pruned
    mlp.fc2 = fc2_pruned

In [25]:
lm_prunner_style_imps = get_lm_prunner_style_importances(model)

In [26]:
for mlp, imp in lm_prunner_style_imps.items():
    prune_mlp(mlp, imp, 0.2)

In [27]:
model

PhiForCausalLM(
  (model): PhiModel(
    (embed_tokens): Embedding(51200, 2048)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-23): 24 x PhiDecoderLayer(
        (self_attn): PhiAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (dense): Linear(in_features=2048, out_features=2048, bias=True)
          (rotary_emb): PhiRotaryEmbedding()
        )
        (mlp): PhiMLP(
          (activation_fn): NewGELUActivation()
          (fc1): Linear(in_features=2048, out_features=6554, bias=True)
          (fc2): Linear(in_features=6554, out_features=2048, bias=True)
        )
        (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (resid_dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (final_layernorm): LayerNorm((2048,), e

In [28]:
model = model.cuda()

In [29]:
with torch.no_grad():
    res = model(examples.cuda(), labels=examples.cuda())

## LoRA training

In [30]:
from peft import LoraConfig, PeftConfig
import transformers

In [31]:
lora_config = LoraConfig(
    r=8,
    lora_alpha=16,
    target_modules=[
        'fc1', # re-train prunned layers for now
        'fc2',
    ],
    bias="none",
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)

In [32]:
model.config.use_cache = False
model.config.pretraining_tp = 1
model.gradient_checkpointing_enable()

In [33]:
# Orig params:
# batch_size = 64
# micro_batch_size = 4
batch_size = 60
micro_batch_size = 6
gradient_accumulation_steps = batch_size // micro_batch_size

training_arguments = transformers.TrainingArguments(
    per_device_train_batch_size=micro_batch_size,
    gradient_accumulation_steps=gradient_accumulation_steps,
    warmup_steps=100,
    num_train_epochs=2,
    learning_rate=1e-4,
    fp16=True,
    logging_steps=10,
    logging_first_step=True,
    optim="adamw_torch",
    evaluation_strategy="steps",
    save_strategy="steps",
    eval_steps=100,
    save_steps=200,
    output_dir="./baseline_models/lm_prunner_style/",
    save_total_limit=20,
    load_best_model_at_end=True,
    ddp_find_unused_parameters=None,
    group_by_length=False,
    # metric_for_best_model="{}_loss".format(args.data_path),
)

In [34]:
dataset = get_baseline_dataset()
train_data, eval_data = dataset["train"], dataset["test"]

reading pickle


In [35]:
tokenizer.pad_token = tokenizer.eos_token

In [36]:
trainer = SFTTrainer(
    model=model,
    train_dataset=train_data,
    eval_dataset=eval_data,
    peft_config=lora_config,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=False,
    dataset_text_field="text",
    max_seq_length=1024, # tweak this
    # TODO: think harder about the datacollator
    # data_collator=transformers.DataCollatorForSeq2Seq(
    #     tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
    # ),
)

Map:   0%|          | 0/50000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.


In [37]:
trainer.model_wrapped.print_trainable_parameters()  

trainable params: 3,303,168 || all params: 1,260,512,624 || trainable%: 0.26204957706159393


In [39]:
trainer.evaluate()

You're using a CodeGenTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'eval_loss': 3.571244716644287,
 'eval_runtime': 169.9131,
 'eval_samples_per_second': 11.771,
 'eval_steps_per_second': 1.471}

In [None]:
train_res = trainer.train()

Step,Training Loss,Validation Loss


In [None]:
pd.to_pickle(train_res, "./baseline_models/lm_prunner_style/training_res_lm_prunner_style.pkl")

## Misc

In [78]:
trainer.train()

ValueError: Attempting to unscale FP16 gradients.

In [68]:
mlp.fc1 = fc1_pruned

In [69]:
mlp

PhiMLP(
  (activation_fn): NewGELUActivation()
  (fc1): Linear(in_features=6144, out_features=8192, bias=True)
  (fc2): Linear(in_features=8192, out_features=2048, bias=True)
)

In [58]:
fc2.weight[:, keep_cells]

tensor([[-0.0084,  0.0056, -0.0038,  ..., -0.0163, -0.0076, -0.0128],
        [ 0.0177, -0.0159,  0.0009,  ..., -0.0150, -0.0069, -0.0016],
        [-0.0303, -0.0123,  0.0026,  ..., -0.0058,  0.0014, -0.0098],
        ...,
        [ 0.0236, -0.0162,  0.0163,  ...,  0.0057, -0.0237, -0.0002],
        [ 0.0253, -0.0157, -0.0003,  ...,  0.0082,  0.0202,  0.0056],
        [-0.0006,  0.0209, -0.0300,  ..., -0.0006, -0.0326,  0.0050]],
       dtype=torch.float16, grad_fn=<IndexBackward0>)

In [None]:
fc2_pruned = torch.nn.Linear(fc2.weight.shape[1], fc2.weight.shape[0])

In [35]:
fc1_pruned.weight.shape

torch.Size([8192, 2048])

In [44]:
with torch.no_grad():
    fc1_pruned.bias = torch.nn.Parameter(torch.clone(fc1_pruned.bias[keep_cells]))

In [45]:
with torch.no_grad():
    fc1_pruned.weight = torch.nn.Parameter(torch.clone(fc1_pruned.weight[keep_cells]))

In [51]:
fc1_pruned = torch.nn.Linear(fc1.weight.shape[1], fc1.weight.shape[0])
with torch.no_grad():
    fc1_pruned.weight.data = torch.clone(fc1.weight)
    fc1_pruned.bias.data = torch.clone(fc1.bias)

In [None]:
fc2_pruned = torch.nn.Linear(fc2.weight.shape[1], fc2.weight.shape[0])
with torch.no_grad():
    fc1_pruned.weight.data = torch.clone(fc1.weight)
    fc1_pruned.bias.data = torch.clone(fc1.bias)

In [54]:
with torch.no_grad():
    fc1_pruned.weight.data = torch.clone(fc1.weight)
    fc1_pruned.bias.data = torch.clone(fc1.bias)

In [None]:
fc1P

In [None]:
with torch.no_grad():
    fc1_pruned.bieas.data = torch.clone(fc1.weight)

In [38]:
fc1_pruned.weight = copy.deepcopy(fc1_pruned.weight[keep_cells])

torch.Size([6144, 2048])

In [95]:
mlps[0]??

[0;31mSignature:[0m       [0mmlps[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m([0m[0;34m*[0m[0margs[0m[0;34m,[0m [0;34m**[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0;31mType:[0m            PhiMLP
[0;31mString form:[0m    
PhiMLP(
  (activation_fn): NewGELUActivation()
  (fc1): Linear(in_features=2048, out_features=8192, bias=True)
  (fc2): Linear(in_features=8192, out_features=2048, bias=True)
)
[0;31mFile:[0m            ~/.cache/huggingface/modules/transformers_modules/microsoft/phi-1_5/349cf8b5e81fd5f791d1740da5de1313a0419bbd/modeling_phi.py
[0;31mSource:[0m         
[0;32mclass[0m [0mPhiMLP[0m[0;34m([0m[0mnn[0m[0;34m.[0m[0mModule[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m    [0;32mdef[0m [0m__init__[0m[0;34m([0m[0mself[0m[0;34m,[0m [0mconfig[0m[0;34m)[0m[0;34m:[0m[0;34m[0m
[0;34m[0m        [0msuper[0m[0;34m([0m[0;34m)[0m[0;34m.[0m[0m__init__[0m[0;34m([0m[0;34m)[0m[0;34m[0m
[0;34m[0m        

In [89]:
fc1_pruned.weight.shape

torch.Size([8192, 2048])

In [None]:
for 

In [None]:
def prune_mlp(mlp):
    # modification happens inplace
    

In [76]:
mlps[0].fc1 = 

Linear(in_features=2048, out_features=8192, bias=True)

In [24]:
imps = get_avg_computed_importances("./importances")

In [25]:
len(imps)

10000

In [32]:
list(imps.values())[0][0]

torch.Size([8192])