In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Restrict this process to GPU 0. This cell runs before torch is imported (next cell),
# so the CUDA runtime is only ever told about that one device.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})

In [3]:
import os
import pandas as pd
import torch
from dataset_preprocessing import TokenInfo
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import itertools
import pandas as pd
from tqdm import tqdm

## Importances

In [4]:
def get_importances(importances_dir="./new_importances_data"):
    """Load and merge every ``*.pkl`` importance shard found in ``importances_dir``.

    Each shard is expected to unpickle to a dict (presumably token -> per-layer
    importance tensors; TODO confirm against the script that writes them). Shards
    are merged with ``dict.update`` in sorted filename order, so when two shards
    share a key the shard whose filename sorts last wins.

    Args:
        importances_dir: directory holding the shards. Defaults to the path that
            used to be hard-coded here, so existing ``get_importances()`` calls
            are unchanged.

    Returns:
        dict: union of all shard dicts (empty if the directory has no ``.pkl`` files).

    Raises:
        FileNotFoundError: if ``importances_dir`` does not exist.
    """
    # os.listdir returns entries in arbitrary order; sort so the merge order (and
    # therefore which value survives a key collision) is deterministic.
    imp_files = sorted(f for f in os.listdir(importances_dir) if f.endswith(".pkl"))
    importances = {}
    for imp_file in tqdm(imp_files):
        # NOTE: pd.read_pickle can execute arbitrary code - only load trusted files.
        importances.update(pd.read_pickle(os.path.join(importances_dir, imp_file)))
    return importances

In [5]:
# imps = get_importances()

In [6]:
def get_avg_imporances(importances):
    """Average per-layer importance tensors element-wise across all entries.

    Args:
        importances: mapping whose values are equal-length sequences of tensors,
            one tensor per layer (presumably token -> per-layer importances as
            returned by ``get_importances``; TODO confirm). Tensors at the same
            layer index must share a shape.

    Returns:
        list[torch.Tensor]: one tensor per layer holding the arithmetic mean over
        all entries. Floating inputs keep their dtype; integer inputs are averaged
        in torch's default floating dtype.

    Raises:
        ValueError: if ``importances`` is empty, or an entry has a different
            number of layers than the first entry.
    """
    if not importances:
        raise ValueError("importances is empty; nothing to average")
    n_entries = len(importances)
    first = next(iter(importances.values()))
    # An integer accumulator cannot absorb the fractional terms added below in
    # place, so promote integer layers to a floating accumulator.
    avg_imps = [
        torch.zeros_like(
            imp,
            dtype=imp.dtype if imp.is_floating_point() else torch.get_default_dtype(),
        )
        for imp in first
    ]
    for imps in tqdm(importances.values()):
        if len(imps) != len(avg_imps):
            raise ValueError(
                f"entry has {len(imps)} layers, expected {len(avg_imps)}"
            )
        for i, layer_imps in enumerate(imps):
            # Divide before accumulating so a low-precision (e.g. fp16) running
            # total cannot overflow when there are many entries.
            avg_imps[i] += layer_imps / n_entries
    # TODO think harder about averaging method
    return avg_imps

In [7]:
# avg_importances = get_avg_imporances(imps)

In [8]:
# pd.to_pickle(avg_importances, "./avg_importances.pkl")

In [9]:
# Averaged per-layer importances are cached on disk because recomputing them scans
# every shard in ./new_importances_data. Recompute only when the cache is missing,
# so Restart & Run All also works on a fresh checkout.
AVG_IMPORTANCES_PATH = "./avg_importances.pkl"
if os.path.exists(AVG_IMPORTANCES_PATH):
    # NOTE: pd.read_pickle can execute arbitrary code - only load trusted files.
    avg_importances = pd.read_pickle(AVG_IMPORTANCES_PATH)
else:
    avg_importances = get_avg_imporances(get_importances())
    pd.to_pickle(avg_importances, AVG_IMPORTANCES_PATH)

In [10]:
# One averaged importance entry per layer; this must line up with the model's MLPs,
# which are zipped against it in the Prune Model section.
len(avg_importances)

24

## Model

In [11]:
# Hugging Face Hub checkpoint that this notebook fine-tunes and prunes.
model_id = "microsoft/phi-1_5"
# Pin an exact Hub commit so results stay reproducible if the repo is updated.
model_revision = "349cf8b5e81fd5f791d1740da5de1313a0419bbd" # latest commit as of Feb 1st

In [12]:
# Load the tokenizer from the same pinned commit as the model, so the vocabulary and
# special tokens cannot drift away from the weights if the Hub repo is updated.
tokenizer = AutoTokenizer.from_pretrained(model_id, revision=model_revision, trust_remote_code=True)

In [13]:
# Tokenizer vocabulary size. The model's embedding matrix is larger (51200 rows in the
# model repr further down), so some embedding rows do not correspond to any token.
vocab = tokenizer.get_vocab()
len(vocab)

50295

In [14]:
# tokenizer.decode(token_info.get_prefixes(top_tokens[1000][0], 9, 10)[0])

In [15]:
# Load the causal LM weights from the pinned commit. Half precision and
# flash-attention are intentionally left disabled (flagged as needing care):
# enabling them changes numerics and should be validated first.
model_load_kwargs = dict(
    revision=model_revision,
    trust_remote_code=True,
    # torch_dtype=torch.float16,
    # attn_implementation="flash_attention_2",
)
model = AutoModelForCausalLM.from_pretrained(model_id, **model_load_kwargs)

## New Dataset

In [16]:
from other_datasets import get_minipile, get_c4, get_wikitext2_filtered, get_bookcorpus, get_alpaca, QADataCollator, to_dataset
from dataset import get_baseline_dataset



In [17]:
# Number of examples drawn from each eval-only corpus.
N_EVAL_EXAMPLES = 2000

# Alpaca supplies both the fine-tuning split and an in-domain held-out split.
alpaca_train, alpaca_eval = get_alpaca(tokenizer, do_split=True)

# Eval-only corpora: not split, the whole sample is used for evaluation.
# Disabled alternatives, kept for reference:
# tiny_text = get_baseline_dataset()["test"]
# c4 = get_c4(n=N_EVAL_EXAMPLES, do_split=False)
minipile = get_minipile(n=N_EVAL_EXAMPLES, do_split=False)
wikitext = get_wikitext2_filtered(n=N_EVAL_EXAMPLES, do_split=False)
bookcorpus = get_bookcorpus(n=N_EVAL_EXAMPLES, do_split=False)

2024-03-06 14:49:47.112 | INFO     | other_datasets:__init__:309 - Mean length of tokens per window: 111.64232
2024-03-06 14:49:48.736 | INFO     | other_datasets:__init__:309 - Mean length of tokens per window: 109.2675


## Metric Callback

In [18]:
from transformers import TrainerCallback

In [19]:
from evaluation import evaluate_on_nlp_tasks

In [20]:
class AccEvalCallback(TrainerCallback):
    """Run the NLP benchmark suite whenever the Trainer evaluates.

    For every task returned by ``evaluate_on_nlp_tasks``, the ``"acc,none"``
    metric is appended to ``state.log_history`` as its own record
    (``{task: acc, "epoch": ..., "step": ...}``) and the task->accuracy dict is
    printed.

    ``on_evaluate`` fires once per eval dataset. This notebook passes a dict of
    several eval datasets, so the benchmarks run at most once per global step.

    Args:
        tokenizer: tokenizer forwarded to ``evaluate_on_nlp_tasks``. ``None``
            (the default) means the notebook-level ``tokenizer`` global is used
            at evaluation time, which matches the previous behaviour.
        limit: per-task example limit forwarded to ``evaluate_on_nlp_tasks``.
    """

    def __init__(self, tokenizer=None, limit=1000):
        super().__init__()
        self.last_step = -1
        self.tokenizer = tokenizer
        self.limit = limit

    def on_evaluate(self, args, state, control, model, **kwargs):
        # Skip repeat calls for the same step (one call per eval dataset).
        if state.global_step == self.last_step:
            return
        self.last_step = state.global_step

        # Fall back to the notebook-level tokenizer when none was injected.
        tok = self.tokenizer if self.tokenizer is not None else tokenizer
        was_training = model.training
        prev_tqdm_disable = os.environ.get("TQDM_DISABLE")
        # NOTE(review): intended to silence progress bars, but the 10000-step bar that
        # flooded the output during trainer.evaluate suggests it is not honoured here
        # (tqdm may read TQDM_* variables only at import time) - confirm.
        os.environ["TQDM_DISABLE"] = "1"
        model.eval()
        try:
            with torch.no_grad():
                results = evaluate_on_nlp_tasks(model, tok, limit=self.limit)["results"]
        finally:
            # Restore the caller's environment and train/eval mode even if evaluation raises.
            if prev_tqdm_disable is None:
                os.environ.pop("TQDM_DISABLE", None)
            else:
                os.environ["TQDM_DISABLE"] = prev_tqdm_disable
            model.train(was_training)

        eval_res = {task: metrics["acc,none"] for task, metrics in results.items()}
        for task, acc in eval_res.items():
            state.log_history.append(
                {
                    task: acc,
                    "epoch": state.epoch,
                    "step": state.global_step,
                }
            )
        print(eval_res)

## Train model

In [21]:
from peft import LoraConfig, PeftConfig
import transformers

In [22]:
from post_training import get_lora_config, get_training_arguments
from other_datasets import SFTTrainer_
from trl import SFTTrainer

In [23]:
# LoRA adapter config and Trainer arguments come from the shared post_training helpers.
# "./tmp" is presumably the Trainer output_dir - TODO confirm in get_training_arguments.
lora_config = get_lora_config()
training_arguments = get_training_arguments("./tmp")

In [24]:
# Move the base model onto the visible GPU before the trainer (which receives
# peft_config) is built; the trailing ';' suppresses the model repr.
model.cuda();

In [25]:
# Training-time model tweaks.
# The KV cache is not needed for training and conflicts with gradient checkpointing.
model.config.use_cache = False
# NOTE(review): pretraining_tp is a Llama-style config knob; presumably a no-op for Phi - confirm or drop.
model.config.pretraining_tp = 1
# Trade extra compute for lower activation memory during backprop.
model.gradient_checkpointing_enable()

In [26]:
# Fine-tune on Alpaca. Alternative kept for reference: minipile["train"] / minipile["test"].
train_data = alpaca_train
eval_data = alpaca_eval

# Named eval datasets handed to the trainer as one dict: "alpaca" is the in-domain
# held-out split; the others are the eval-only corpora loaded above.
# Disabled entries: "c4" (c4) and "tiny_text" (tiny_text).
eval_datasets = dict(
    alpaca=alpaca_eval,
    minipile=minipile,
    wikitext=wikitext,
    bookcorpus=bookcorpus,
)

In [27]:
# Run the NLP benchmark callback whenever the trainer evaluates.
callbacks = [AccEvalCallback()]

In [28]:
# Reuse EOS as the padding token for batching.
# NOTE(review): presumably the phi-1_5 tokenizer has no pad token - confirm; also check
# that QADataCollator does not drop genuine EOS tokens from the loss because of this.
tokenizer.pad_token = tokenizer.eos_token

In [29]:
# Disable checkpoint saving for this run.
# NOTE(review): mutated after the TrainingArguments object was built, so any validation
# done at construction time is skipped - consider passing it to get_training_arguments.
training_arguments.save_strategy="no"

In [30]:
# Evaluate (and hence fire AccEvalCallback) every 100 steps.
# NOTE(review): only takes effect if the evaluation strategy is "steps" - confirm in get_training_arguments.
training_arguments.eval_steps = 100

In [31]:
# Supervised fine-tuning with a LoRA adapter; eval_dataset is a dict of named
# datasets, and AccEvalCallback adds benchmark accuracies at each evaluation.
trainer = SFTTrainer_(
    # model, adapter and run configuration
    model=model,
    peft_config=lora_config,
    args=training_arguments,
    callbacks=callbacks,
    # data pipeline
    tokenizer=tokenizer,
    train_dataset=train_data,
    # for quick runs: {k: v.select(range(0, 100)) for k, v in eval_datasets.items()}
    eval_dataset=eval_datasets,
    dataset_text_field="text",
    max_seq_length=1024,  # tweak this
    packing=False,
    # TODO: think harder about the data collator; the earlier alternative was
    # transformers.DataCollatorForSeq2Seq(tokenizer, pad_to_multiple_of=8,
    # return_tensors="pt", padding=True).
    data_collator=QADataCollator(tokenizer),
)

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]

Map:   0%|          | 0/2000 [00:00<?, ? examples/s]



In [35]:
# Held-out Alpaca loss at the trainer's current step (step 0 in the recorded run).
# This also fires AccEvalCallback, which runs the NLP benchmarks.
# Keep the metrics in a named variable instead of relying on IPython's `_` later.
alpaca_baseline_metrics = trainer.evaluate(eval_datasets["alpaca"])
alpaca_baseline_metrics

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
 77%|██████████████████████████▉        | 7697/10000 [08:14<02:09, 17.72it/s]IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [36]:
# NOTE: `_` is IPython's "last output" variable; here it re-displays the metrics from
# the previous trainer.evaluate cell. It breaks if cells run out of order - prefer a
# named variable for that result.
_

{'eval_loss': 1.7173835039138794,
 'eval_runtime': 49.9415,
 'eval_samples_per_second': 40.047,
 'eval_steps_per_second': 5.006}

In [38]:
# Trainer log: the eval-loss record followed by one record per benchmark task appended by AccEvalCallback.
trainer.state.log_history

[{'eval_loss': 1.7173835039138794,
  'eval_runtime': 49.9415,
  'eval_samples_per_second': 40.047,
  'eval_steps_per_second': 5.006,
  'step': 0},
 {'hellaswag': 0.46, 'epoch': None, 'step': 0},
 {'piqa': 0.766, 'epoch': None, 'step': 0},
 {'boolq': 0.762, 'epoch': None, 'step': 0},
 {'winogrande': 0.72, 'epoch': None, 'step': 0}]

In [None]:
# Score the current model on the NLP benchmarks outside the Trainer.
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        eval_res = evaluate_on_nlp_tasks(model, tokenizer, limit=1000)
finally:
    # Restore the previous mode (rather than forcing train mode), even if evaluation
    # raises. Being inside `finally`, this no longer dumps the model repr as cell output.
    model.train(was_training)

In [85]:
# Per-task results: "acc,none", "acc_norm,none" where the task reports it, and the task alias.
eval_res["results"]

{'hellaswag': {'acc,none': 0.348,
  'acc_norm,none': 0.443,
  'alias': 'hellaswag'},
 'piqa': {'acc,none': 0.671, 'acc_norm,none': 0.665, 'alias': 'piqa'},
 'boolq': {'acc,none': 0.625, 'alias': 'boolq'},
 'winogrande': {'acc,none': 0.522, 'alias': 'winogrande'}}

## Prune Model

In [16]:
from prunners import prune_mlps_individually
from importances import get_mlps

In [17]:
# The MLP module of each decoder layer (24 for phi-1_5).
# NOTE(review): assumed to be in layer order so it lines up with avg_importances - confirm.
mlps = get_mlps(model)

In [18]:
# The next cell zips these two sequences, and zip silently truncates on a length
# mismatch, so fail loudly here instead of just displaying the counts.
assert len(mlps) == len(avg_importances), (
    f"{len(mlps)} MLPs vs {len(avg_importances)} importance entries"
)
len(mlps), len(avg_importances)

(24, 24)

In [19]:
# Key the per-layer importances by their MLP module. Guard against re-running this
# cell: a second run would otherwise zip the MLPs against the dict's own keys.
if not isinstance(avg_importances, dict):
    if len(mlps) != len(avg_importances):
        raise ValueError(
            f"expected one importance entry per MLP, got {len(avg_importances)} for {len(mlps)} MLPs"
        )
    avg_importances = dict(zip(mlps, avg_importances))

In [20]:
# Prune the MLPs with ratio 0.5, presumably scoring each MLP with its own importances.
# NOTE(review): presumably modifies the model in place, so re-running prunes again - confirm.
# The model repr printed next still shows fc1 out_features=4096, so this appears to mask
# weights rather than shrink the layers (or did not take effect) - verify.
prune_mlps_individually(avg_importances, 0.5)

In [21]:
# Inspect the architecture after pruning.
model

PhiForCausalLM(
  (model): PhiModel(
    (embed_tokens): Embedding(51200, 2048)
    (embed_dropout): Dropout(p=0.0, inplace=False)
    (layers): ModuleList(
      (0-23): 24 x PhiDecoderLayer(
        (self_attn): PhiAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (k_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (v_proj): Linear(in_features=2048, out_features=2048, bias=True)
          (dense): Linear(in_features=2048, out_features=2048, bias=True)
          (rotary_emb): PhiRotaryEmbedding()
        )
        (mlp): PhiMLP(
          (activation_fn): NewGELUActivation()
          (fc1): Linear(in_features=2048, out_features=4096, bias=True)
          (fc2): Linear(in_features=4096, out_features=2048, bias=True)
        )
        (input_layernorm): LayerNorm((2048,), eps=1e-05, elementwise_affine=True)
        (resid_dropout): Dropout(p=0.0, inplace=False)
      )
    )
    (final_layernorm): LayerNorm((2048,), e