# Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [None]:
import os
import pandas as pd
import torch
from dataset_preprocessing import TokenInfo
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import itertools
import pandas as pd
from tqdm import tqdm
import copy
import numpy as np

In [2]:
def get_r(prune_ratio, h=8192, i=2048):
    """Return the (fractional) rank that balances the cost model below.

    With ``gamma = prune_ratio`` the function solves, for ``r``::

        keep_cost + (2*gamma*h + i) + 4*r*(gamma*h + i) == reg_cost

    where ``keep_cost`` is the cost expression evaluated at ``1 - gamma`` and
    ``reg_cost`` is the same expression evaluated at the fixed factor 0.5.

    Args:
        prune_ratio: the ratio ``gamma`` used in the formula.
        h: first dimension of the cost model (default 8192).
        i: second dimension of the cost model (default 2048).
            NOTE(review): these look like the MLP intermediate width and the
            model hidden size of phi-1.5 -- confirm before reusing elsewhere.

    Returns:
        ``r`` as a float; callers truncate it with ``int(...)``.
    """
    gamma = prune_ratio
    keep_cost = 4*(1-gamma)*h*i + 2*(1-gamma)*h + i
    # Reference budget: the same expression with the factor fixed at 0.5.
    reg_cost = 4*0.5*h*i + 2*0.5*h + i

    r = (reg_cost - keep_cost - (2*gamma*h + i)) / (4*(gamma*h + i))
    return r

In [3]:
# Ratio handed both to get_r (here) and to prune_mlps_individually (later cell).
PRUNE_RATIO = 0.72
# get_r returns a float; the rank is truncated to an integer.
RANK = int(get_r(PRUNE_RATIO))

In [4]:
PRUNE_RATIO, RANK

(0.72, 464)

# Importances

## Compute

In [2]:
def get_importances(imp_dir="./new_importances_data2"):
    """Load every ``.pkl`` file in ``imp_dir`` and merge them into one dict.

    Args:
        imp_dir: directory holding the pickled importance dicts
            (pick the right dir for the run being analysed).

    Returns:
        A single dict built by ``dict.update``-ing each unpickled file in
        sorted filename order, so that on duplicate keys the file that sorts
        last wins deterministically.
    """
    # NOTE(security): pd.read_pickle executes arbitrary pickled code; only
    # point this at files produced by a trusted run.
    imp_files = sorted(name for name in os.listdir(imp_dir) if name.endswith(".pkl"))
    importances = {}
    for imp_file in tqdm(imp_files):
        importances.update(pd.read_pickle(os.path.join(imp_dir, imp_file)))
    return importances

In [22]:
def get_avg_imporances(importances):
    """Weighted average of per-layer importance tensors over all keys.

    Args:
        importances: mapping whose keys are indexable with ``key[2]`` giving
            the weight of that entry, and whose values are sequences where
            element ``0`` is an iterable of per-layer tensors.
            NOTE(review): ``key[2]`` is presumably a token count -- confirm
            against the code that produced the pickles.

    Returns:
        A list with one tensor per layer: ``sum_k imps_k[layer] * w_k`` with
        ``w_k = key_k[2] / sum(key[2])``.

    Raises:
        ValueError: if ``importances`` is empty.
    """
    if not importances:
        raise ValueError("importances is empty; nothing to average")
    ttl = sum(token[2] for token in importances.keys())
    # Peek at the first entry for shapes without copying all values into a list.
    first_layers = next(iter(importances.values()))[0]
    avg_imps = [torch.zeros_like(imp) for imp in first_layers]
    for token, imps in tqdm(importances.items()):
        weight = token[2] / ttl  # same for every layer of this entry
        for i, layer_imps in enumerate(imps[0]):
            avg_imps[i] += layer_imps * weight
    return avg_imps

In [23]:
# avg_importances = get_avg_imporances(imps)

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 49700/49700 [00:56<00:00, 880.25it/s]


In [24]:
# pd.to_pickle(avg_importances, "./new_weighted_avg_importances.pkl")

## Read

In [None]:
# Load the precomputed per-layer average importances.
# NOTE(review): the commented-out save cell above writes
# "./new_weighted_avg_importances.pkl", not this path -- confirm which file is intended.
avg_importances = pd.read_pickle("./avg_importances.pkl")

In [None]:
len(avg_importances)

# Model

In [None]:
model_id = "microsoft/phi-1_5"
model_revision = "349cf8b5e81fd5f791d1740da5de1313a0419bbd" # latest as of feb 1st
# Pin the tokenizer to the same commit as the model: with trust_remote_code=True an
# unpinned load would run whatever code is at the repo head, and could drift from the weights.
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
    revision=model_revision,
    trust_remote_code=True,
)
vocab = tokenizer.get_vocab()

model = AutoModelForCausalLM.from_pretrained(
    model_id,
    revision=model_revision,
    trust_remote_code=True,
    # be careful with this?
    # torch_dtype=torch.float16,
    # attn_implementation="flash_attention_2",
)

# Low rank matrix initialization

In [None]:
@torch.no_grad()
def get_mlp_loras_simple(mlp, most_imp_cells, rank=64):
    """Build low-rank factors by copying the weights of the first ``rank`` cells.

    Args:
        mlp: object exposing ``fc1.weight``, ``fc1.bias`` and ``fc2.weight``.
        most_imp_cells: indices into the rows of ``fc1.weight`` / columns of
            ``fc2.weight``; assumed to be in descending order of importance,
            so the first ``rank`` entries are the ones kept exactly.
        rank: number of leading cells copied into the factors.

    Returns:
        ``(fc1_A, fc1_B, fc1_bias, fc2_B, fc2_A)`` where
        ``fc1_A`` is the ``rank`` selected fc1 rows, ``fc1_B`` is ``[I; 0]`` of
        shape ``(len(most_imp_cells), rank)``, ``fc1_bias`` covers all cells,
        ``fc2_B`` is the ``rank`` selected fc2 columns and ``fc2_A`` is
        ``[I, 0]`` of shape ``(rank, len(most_imp_cells))``. No fc2 bias is
        returned.

    Raises:
        ValueError: if ``rank`` exceeds ``len(most_imp_cells)``.
    """
    print("get_mlp_loras_simple")
    n_cells = len(most_imp_cells)
    if rank > n_cells:
        raise ValueError(f"rank ({rank}) exceeds number of cells ({n_cells})")
    top_cells = most_imp_cells[:rank]

    fc1_A = mlp.fc1.weight[top_cells]
    fc1_bias = mlp.fc1.bias[most_imp_cells]  # bias is full dimension
    fc2_B = mlp.fc2.weight[:, top_cells]
    # no need of fc2 bias

    # Rectangular identities ([I; 0] and [I, 0]) created with the dtype/device of
    # the weights they pair with, so the factors stay consistent for non-default
    # dtypes or when the source model already lives on an accelerator.
    fc1_B = torch.eye(n_cells, rank, dtype=fc1_A.dtype, device=fc1_A.device)
    fc2_A = torch.eye(rank, n_cells, dtype=fc2_B.dtype, device=fc2_B.device)
    return fc1_A.clone(), fc1_B.clone(), fc1_bias.clone(), fc2_B.clone(), fc2_A.clone()


In [None]:
from functools import lru_cache
class SVDRes():
    """Container for the factors of a reduced SVD, ``A = U @ diag(S) @ V.T``."""

    def __init__(self, U, S, V):
        self.U = U
        self.S = S
        self.V = V

def get_svd(tens):
    """Compute the reduced SVD of ``tens`` and return the factors on the CPU.

    The decomposition runs on the GPU when CUDA is available and otherwise on
    whatever device ``tens`` already lives on, so the function also works on
    CPU-only machines.

    Args:
        tens: 2-D (or batched) tensor to decompose.

    Returns:
        ``SVDRes`` with ``U``, ``S`` and ``V`` (not ``Vh``), all on the CPU.
    """
    if torch.cuda.is_available():
        tens = tens.cuda()
    # torch.svd is deprecated; torch.linalg.svd returns Vh, so take its adjoint
    # to keep exposing V as before.
    U, S, Vh = torch.linalg.svd(tens, full_matrices=False)
    return SVDRes(U.cpu(), S.cpu(), Vh.mH.cpu())

In [None]:
@torch.no_grad()
def get_mlp_loras_svd(mlp, most_imp_cells, rank=64, init_bias=True):
    """Build low-rank factors from a truncated SVD of the selected MLP weights.

    Args:
        mlp: object exposing ``fc1.weight``, ``fc1.bias`` and ``fc2.weight``.
        most_imp_cells: indices selecting rows of ``fc1.weight`` and columns
            of ``fc2.weight``; all of them enter the decomposition.
        rank: number of leading singular triplets kept.
        init_bias: must be ``True``; the other mode is not implemented.

    Returns:
        ``(fc1_A, fc1_B, fc1_bias, fc2_B, fc2_A)`` with ``B = U[:, :rank] * S[:rank]``
        and ``A = V.T[:rank]`` for each matrix, so ``B @ A`` is the rank-``rank``
        approximation of the selected weights. No fc2 bias is returned.

    Raises:
        NotImplementedError: if ``init_bias`` is false.
    """
    print("get_mlp_loras_svd")
    if not init_bias:
        # An explicit raise (instead of `assert False`) still fires under `python -O`.
        raise NotImplementedError("not implemented yet")

    fc1_mat = mlp.fc1.weight[most_imp_cells]
    fc1_bias = mlp.fc1.bias[most_imp_cells]
    fc2_mat = mlp.fc2.weight[:, most_imp_cells]
    # no need of fc2 bias

    fc1_svd = get_svd(fc1_mat)
    fc2_svd = get_svd(fc2_mat)

    # Scaling the columns by broadcasting equals U @ diag(S) without building
    # the dense diagonal matrix.
    fc1_B = fc1_svd.U[:, :rank] * fc1_svd.S[:rank]
    fc1_A = fc1_svd.V.T[:rank]

    fc2_B = fc2_svd.U[:, :rank] * fc2_svd.S[:rank]
    fc2_A = fc2_svd.V.T[:rank]
    return fc1_A.clone(), fc1_B.clone(), fc1_bias.clone(), fc2_B.clone(), fc2_A.clone()

# Independent copy of the freshly loaded model; initialize_loras later reads the
# original MLP weights from it after `model`'s MLPs have been replaced by Experts.
orig_model = copy.deepcopy(model)

In [None]:
@torch.no_grad()
def initialize_loras(orig_model, model, init_cells, lora_func=get_mlp_loras_simple):
    for layer_i, (orig_layer, layer) in tqdm(enumerate(zip(orig_model.model.layers, model.model.layers))):
        num_experts = len(init_cells[layer_i])
        for expert_i in range(num_experts): # 
            most_imp_cells = init_cells[layer_i][expert_i]
            fc1_A, fc1_B, fc1_bias, fc2_B, fc2_A = lora_func(orig_layer.mlp, most_imp_cells, RANK)

            # import pdb; pdb.set_trace()
            layer.mlp.experts_fc1[expert_i].orig_lora.lora_A.default.weight.data = fc1_A
            layer.mlp.experts_fc1[expert_i].orig_lora.lora_B.default.weight.data = fc1_B
            layer.mlp.experts_fc1[expert_i].lora_bias.data = fc1_bias
            
            layer.mlp.experts_fc2[expert_i].orig_lora.lora_A.default.weight.data = fc2_A
            layer.mlp.experts_fc2[expert_i].orig_lora.lora_B.default.weight.data = fc2_B
            # no bias needed for fc2

In [None]:
from prunners import prune_mlps_individually
from importances import get_mlps

# Prune Model

In [None]:
mlps = get_mlps(model)
# Keep the per-MLP mapping under its own name: rebinding `avg_importances`
# (a list) to a dict made this cell give different results when re-run.
importances_by_mlp = dict(zip(mlps, avg_importances))
pruned_cells = prune_mlps_individually(importances_by_mlp, PRUNE_RATIO)

# svd init cells per expert (only 1 expert in this case)
svd_init_cells = [[p] for p in pruned_cells]

# Dataset

In [None]:
from other_datasets import get_minipile, get_c4, get_wikitext2_filtered, get_bookcorpus, get_alpaca, QADataCollator, to_dataset
from dataset import get_baseline_dataset
# tiny_text supplies the "train"/"test" splits used for training below.
tiny_text = get_baseline_dataset()
# Alpaca subset requested with n=2000 and no split; kept as an extra eval set.
alpaca = get_alpaca(tokenizer, n=2000, do_split=False)

# Callbacks

In [None]:
from transformers import TrainerCallback
from evaluation import evaluate_on_nlp_tasks

In [None]:
class AccEvalCallback(TrainerCallback):
    """Run the NLP-task accuracy evaluation once per global step on evaluate.

    Each task's ``"acc,none"`` value is appended to ``state.log_history`` as its
    own entry (with ``epoch`` and ``step``) and the dict of accuracies is printed.
    Uses the notebook-level ``tokenizer``.
    """

    def __init__(self):
        super().__init__()
        # Global step of the last evaluation, to avoid repeating it within one step.
        self.last_step=-1

    def on_evaluate(self, args, state, control, model, **kwargs):
        if state.global_step == self.last_step:
            return
        self.last_step = state.global_step
        was_training = model.training
        prev_tqdm_disable = os.environ.get("TQDM_DISABLE")
        model.eval()
        os.environ["TQDM_DISABLE"] = "1"
        try:
            with torch.no_grad():
                eval_res = evaluate_on_nlp_tasks(model, tokenizer, limit=100, do_shuffle=True)["results"]
            eval_res = {k:v["acc,none"] for k,v in eval_res.items()}
            for k, v in eval_res.items():
                state.log_history.append(
                    {
                        k:v,
                        "epoch":state.epoch,
                        "step":state.global_step,
                    }
                )
            print(eval_res)
        finally:
            # Always undo the temporary changes, even if the evaluation raised:
            # restore the previous TQDM_DISABLE value and the train/eval mode.
            if prev_tqdm_disable is None:
                os.environ.pop("TQDM_DISABLE", None)
            else:
                os.environ["TQDM_DISABLE"] = prev_tqdm_disable
            model.train(was_training)

class SaveCallback(TrainerCallback):
    """Write the model's state dict to ``save_path`` on evaluate, once per global step.

    Every save targets the same path. Failures are reported with a printed
    message instead of interrupting the run.
    """

    def __init__(self, save_path):
        super().__init__()
        self.last_step = -1
        self.save_path = save_path

    def on_evaluate(self, args, state, control, model, **kwargs):
        current_step = state.global_step
        if current_step != self.last_step:
            self.last_step = current_step
            try:
                weights = model.state_dict()
                torch.save(weights, self.save_path)
            except Exception as e:
                print(f"error saving {e}")

class EnableMLPBias(TrainerCallback):
    """At trainer init, mark parameters named like ``*base_layer*bias*`` as trainable."""

    def on_init_end(self, args, state, control, model, **kwargs):
        required_tags = ("base_layer", "bias")
        for param_name, param in model.named_parameters():
            # Both substrings must occur somewhere in the parameter name.
            if all(tag in param_name for tag in required_tags):
                param.requires_grad = True


# Replace modules and init them

In [None]:
from experts import Experts, EmbeddingTokenIdxTracker, mark_adapters_and_routers_as_trainable, prepare_as_if_peft_model, prepare_model_for_gradient_checkpointing
from importances import get_mlps
from post_training import get_lora_config, get_training_arguments

In [None]:
# LoRA config built at the rank derived from PRUNE_RATIO; training output goes to ./tmp.
print(f"Using rank: {RANK}")
lora_config = get_lora_config(r=RANK)
training_arguments = get_training_arguments("./tmp")

# prepare_as_if_peft_model takes the model and returns the training arguments used from here on.
training_arguments = prepare_as_if_peft_model(model, training_arguments, lora_config)

def get_layers(model):
    """Return the ``model.layers`` submodule of ``model`` (looked up by dotted path)."""
    layers_path = "model.layers"
    return model.get_submodule(layers_path)

In [None]:
layers = get_layers(model)

In [None]:
# Replace every layer's MLP with an Experts wrapper built around the existing MLP
# and its config, with a single expert per layer.
# NOTE(review): re-running this cell would wrap the already-wrapped MLPs again;
# reload the model before executing it a second time.
for i, layer in enumerate(layers):
    layer.mlp = Experts(
        model,
        layer.mlp,
        lora_config,
        i,
        layer.mlp.config,
        num_experts=1,
        cluster_init_router=False, # do not initialize mlp router
        use_improved_lora=True,
        lora_at_base_improved_lora=True,
    )

In [None]:
# Fill the expert LoRA weights from an SVD of the original (pre-replacement) MLP
# weights, then set up trainability / gradient checkpointing and move to the GPU.
initialize_loras(orig_model, model, init_cells=svd_init_cells, lora_func=get_mlp_loras_svd)
mark_adapters_and_routers_as_trainable(model, lora_config)
prepare_model_for_gradient_checkpointing(model)
# Trailing semicolon suppresses the module repr in the cell output.
model.cuda();

In [None]:
model

# Train Model

In [None]:
from post_training import get_lora_config, get_training_arguments
from dataset import get_baseline_dataset
from trl import SFTTrainer
from peft import LoraConfig
import transformers
from trl import SFTTrainer
from other_datasets import SFTTrainer_

In [None]:
# Setup model for training
model.config.use_cache = False
model.config.pretraining_tp = 1
model.gradient_checkpointing_enable()

In [None]:
# train_data, eval_data = minipile["train"], minipile["test"]
train_data, eval_data = tiny_text["train"], tiny_text["test"]
eval_datasets = {
    "tiny_text":eval_data,
    "alpaca":alpaca,
    # "minipile":minipile,
    # "c4":c4,
    # "wikitext":wikitext,
    # "tiny_text":tiny_text,
    # "bookcorpus":bookcorpus,
}

In [None]:
# save_path = "./tmp/weighted_routing_better_lora_svd_init_moe50_model_state_dict"
# Accuracy evaluation on each evaluate event plus bias un-freezing at trainer init;
# checkpoint saving (SaveCallback) is currently disabled.
callbacks = [AccEvalCallback(), EnableMLPBias()] #, SaveCallback(save_path)]

In [None]:
# Reuse the EOS token for padding.
tokenizer.pad_token = tokenizer.eos_token
# No trainer-managed checkpoints; eval_steps is set to 100.
training_arguments.save_strategy="no"
training_arguments.eval_steps = 100

In [None]:
# Trainer over the already-modified model; no peft_config is passed (it stays
# commented out below), and QADataCollator is used as the collator.
trainer = SFTTrainer_(
    model=model,
    train_dataset=train_data,
    eval_dataset=eval_datasets["tiny_text"],
    # peft_config=lora_config,
    tokenizer=tokenizer,
    args=training_arguments,
    packing=False,
    dataset_text_field="text",
    max_seq_length=1024, # tweak this
    # TODO: think harder about the datacollator
    # data_collator=transformers.DataCollatorForSeq2Seq(
    #     tokenizer, pad_to_multiple_of=8, return_tensors="pt", padding=True
    # ),
    callbacks=callbacks,
    data_collator=QADataCollator(tokenizer),
)

In [None]:
trainer.evaluate()

In [6]:
trainer.train()

You're using a CodeGenTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Step,Training Loss,Validation Loss
100,3.5486,3.47771
200,3.4221,3.383874
300,3.4143,3.342912
400,3.3082,3.306208
500,3.2573,3.284032
600,3.3291,3.263538
700,3.2569,3.24959
800,3.3135,3.23414
900,3.0531,3.22707
1000,3.0447,3.22073




will shuffle dataset


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
100%|█| 1000/1000 [01:04<00:00, 1


{'hellaswag': 0.32, 'piqa': 0.69, 'boolq': 0.62, 'winogrande': 0.68}


fatal: not a git repository (or any parent up to mount point /)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).


will shuffle dataset


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
100%|█| 1000/1000 [00:59<00:00, 1


{'hellaswag': 0.34, 'piqa': 0.69, 'boolq': 0.6, 'winogrande': 0.64}


fatal: not a git repository (or any parent up to mount point /)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).


will shuffle dataset


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
100%|█| 1000/1000 [01:00<00:00, 1


{'hellaswag': 0.32, 'piqa': 0.71, 'boolq': 0.62, 'winogrande': 0.63}


fatal: not a git repository (or any parent up to mount point /)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).


will shuffle dataset


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
 16%|▏| 162/1000 [00:09<00:50, 16IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



will shuffle dataset


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
100%|█| 1000/1000 [00:59<00:00, 1


{'hellaswag': 0.37, 'piqa': 0.71, 'boolq': 0.61, 'winogrande': 0.66}


fatal: not a git repository (or any parent up to mount point /)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).


will shuffle dataset


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
100%|█| 1000/1000 [00:59<00:00, 1


{'hellaswag': 0.36, 'piqa': 0.7, 'boolq': 0.62, 'winogrande': 0.61}


fatal: not a git repository (or any parent up to mount point /)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).


will shuffle dataset


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
100%|█| 1000/1000 [00:59<00:00, 1


{'hellaswag': 0.34, 'piqa': 0.74, 'boolq': 0.62, 'winogrande': 0.65}


fatal: not a git repository (or any parent up to mount point /)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).


will shuffle dataset


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
100%|█| 1000/1000 [00:59<00:00, 1


{'hellaswag': 0.36, 'piqa': 0.72, 'boolq': 0.62, 'winogrande': 0.65}


fatal: not a git repository (or any parent up to mount point /)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).


will shuffle dataset


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
100%|█| 1000/1000 [00:59<00:00, 1


{'hellaswag': 0.37, 'piqa': 0.72, 'boolq': 0.6, 'winogrande': 0.63}


fatal: not a git repository (or any parent up to mount point /)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).


will shuffle dataset


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
100%|█| 1000/1000 [00:59<00:00, 1


{'hellaswag': 0.36, 'piqa': 0.72, 'boolq': 0.58, 'winogrande': 0.68}


fatal: not a git repository (or any parent up to mount point /)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).


will shuffle dataset


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
100%|█| 1000/1000 [00:59<00:00, 1


{'hellaswag': 0.37, 'piqa': 0.72, 'boolq': 0.59, 'winogrande': 0.68}


fatal: not a git repository (or any parent up to mount point /)
Stopping at filesystem boundary (GIT_DISCOVERY_ACROSS_FILESYSTEM not set).


will shuffle dataset


You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this dataset from the next major release of `datasets`.
 24%|▏| 236/1000 [00:14<00:45, 16IOPub message rate exceeded.
The Jupyter server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--ServerApp.iopub_msg_rate_limit`.

Current values:
ServerApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
ServerApp.rate_limit_window=3.0 (secs)



In [7]:
trainer_state = trainer.state

In [8]:
# Validation loss per step: keep only log entries that carry an eval_loss value.
valid_loss = pd.DataFrame(trainer_state.log_history)[["step", "eval_loss"]].set_index("step").dropna()

In [9]:
valid_loss

Unnamed: 0_level_0,eval_loss
step,Unnamed: 1_level_1
100,3.47771
200,3.383874
300,3.342912
400,3.306208
500,3.284032
600,3.263538
700,3.24959
800,3.23414
900,3.22707
1000,3.22073


# Evaluation

In [None]:
from evaluation import evaluate_on_nlp_tasks

In [None]:
model.eval();

In [None]:
with torch.no_grad():
    eval_res = evaluate_on_nlp_tasks(model, tokenizer, limit=300, do_shuffle=True)

In [None]:
eval_res["results"]

In [None]:
# Larger evaluation (limit=1000, 1000 bootstrap iterations) without shuffling.
# NOTE(review): this evaluates `model`, not `orig_model`; "orig" presumably refers to the
# unshuffled task order -- confirm. Unlike the cell above, it is not wrapped in torch.no_grad().
eval_res_orig = evaluate_on_nlp_tasks(model, tokenizer, limit=1000, bootstrap_iters=1000, do_shuffle=False)

In [None]:
eval_res_orig["results"]

In [None]:
eval_res = evaluate_on_nlp_tasks(model, tokenizer, limit=1000, bootstrap_iters=1000, do_shuffle=True)

In [None]:
eval_res["results"]

# Save

In [None]:
# model.cpu();

In [None]:
# torch.save(model.state_dict(), save_path)

# Stats

In [None]:
df = pd.DataFrame(trainer_state.log_history)

In [None]:
metrics_df = df[["step", "hellaswag", "piqa", "boolq", "winogrande"]]

In [None]:
metrics_df[["step", "boolq"]].dropna().set_index("step").plot()

In [None]:
metrics_df[["step", "hellaswag"]].dropna().set_index("step").plot()

In [None]:
# Plot piqa accuracy per step. The result of .plot() is deliberately not assigned:
# binding it to `metrics_df` replaced the metrics frame with an Axes object and
# broke re-running the earlier plot cells.
df[["step", "piqa"]].dropna().set_index("step").plot()

In [None]:
# Plot winogrande accuracy per step. The result of .plot() is deliberately not assigned:
# binding it to `metrics_df` replaced the metrics frame with an Axes object and
# broke re-running the earlier plot cells.
df[["step", "winogrande"]].dropna().set_index("step").plot()

In [None]:
model.model.layers[0].mlp.experts_fc1[0].orig_lora.lora_B.default.weight.shape