In [1]:
import os

# GPU selection must happen before anything can initialise CUDA (torch, transformers,
# or llmfact at import time); once a CUDA context exists, changing
# CUDA_VISIBLE_DEVICES has no effect and the run silently uses the wrong devices.
os.environ['CUDA_VISIBLE_DEVICES'] = '4,5,6'
os.environ["TOKENIZERS_PARALLELISM"] = "true"

from llmfact import LayerOutputExtractor, FBNFeatureExtractor, GroupFBNFeatureExtractor, FBNExtractor, LLMFC
from llmfact.extractor import MutiLayerAnalysis, MutiLayerAnalysis2
from llmfact.extractor import SingleLayerAnalysis
from llmfact.mask import MaskedGPT2ForSequenceClassification, MaskedGPT2AmplifiedForSequenceClassification, MaskedGPT2LMModel, MaskedModel
from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, GPT2ForSequenceClassification, Trainer, TrainingArguments
from transformers import GPT2Tokenizer
from transformers import AutoTokenizer, AutoModel, AutoConfig, AutoModelForCausalLM, AutoModelForQuestionAnswering
from datasets import load_dataset, get_dataset_config_names, get_dataset_split_names
from torch.utils.data import DataLoader
from evaluate import load

from llmfact.utils import IoU, correlation_activation, thresholding, write_layer_txt, evaluate_iou
from llmfact.pruner.llama import LayerBiasCompute
from llmfact.stat import StatICA, StatDictionaryLearning
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.linear_model import LogisticRegression
from sklearn.manifold import TSNE
from sklearn.decomposition import FastICA
import seaborn as sns
import pandas as pd
from tqdm.auto import tqdm

In [2]:
model_name = "lmsys/vicuna-7b-v1.5"

# Left padding: for a batch of prompts, the last real token of every row sits at
# the final position.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    padding_side='left',
)

# device_map="auto" lets accelerate place the layers across the visible GPUs.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# Fully-qualified names of every MLP down-projection (model.layers.{i}.mlp.down_proj);
# this list is the set of target layers handed to the later analysis cells.
include_layers = [
    module_name
    for module_name, _module in model.named_modules()
    if "mlp.down" in module_name
]
include_layers

['model.layers.0.mlp.down_proj',
 'model.layers.1.mlp.down_proj',
 'model.layers.2.mlp.down_proj',
 'model.layers.3.mlp.down_proj',
 'model.layers.4.mlp.down_proj',
 'model.layers.5.mlp.down_proj',
 'model.layers.6.mlp.down_proj',
 'model.layers.7.mlp.down_proj',
 'model.layers.8.mlp.down_proj',
 'model.layers.9.mlp.down_proj',
 'model.layers.10.mlp.down_proj',
 'model.layers.11.mlp.down_proj',
 'model.layers.12.mlp.down_proj',
 'model.layers.13.mlp.down_proj',
 'model.layers.14.mlp.down_proj',
 'model.layers.15.mlp.down_proj',
 'model.layers.16.mlp.down_proj',
 'model.layers.17.mlp.down_proj',
 'model.layers.18.mlp.down_proj',
 'model.layers.19.mlp.down_proj',
 'model.layers.20.mlp.down_proj',
 'model.layers.21.mlp.down_proj',
 'model.layers.22.mlp.down_proj',
 'model.layers.23.mlp.down_proj',
 'model.layers.24.mlp.down_proj',
 'model.layers.25.mlp.down_proj',
 'model.layers.26.mlp.down_proj',
 'model.layers.27.mlp.down_proj',
 'model.layers.28.mlp.down_proj',
 'model.layers.29.mlp.do

In [4]:
# Inspect the unmodified module tree (layer names, shapes, bias flags) before pruning.
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05

In [5]:
# WikiText-2 (raw, preprocessed) training split; its first rows are used later as
# calibration text.
wiki_dataset = load_dataset(
    path="Self-GRIT/wikitext-2-raw-v1-preprocessed",
    split="train",
)
print(wiki_dataset)

Dataset({
    features: ['text'],
    num_rows: 15313
})


In [6]:
# ICA result for all 32 layers' MLP neurons, shape (9984, 704512) (~7e9 elements,
# tens of GB). What the row axis represents is not visible here (presumably ICA
# components -- TODO confirm). It is only scanned when building the mask and then
# deleted, so memory-map it read-only instead of pulling the whole file into RAM.
# A read-only memmap also makes any accidental in-place write fail loudly.
normal_components = np.load(
    "./data/FBN/text3200-mlp.act-CanICA-SingleICA-max_iter-300_vicuna-7b-v1.5-muti-layer-wise_128_normal_mixing_std_True.npy",
    mmap_mode="r",
)
normal_components.shape

(9984, 704512)

In [6]:
def cut_par_num(neuron_num_list, total_par=6738415616, hidden_size=4096,
                intermediate_size=11008, num_layers=32, n_proj=3):
    """Print (and return) how many parameters MLP-neuron pruning removes.

    Each decoder layer's MLP has ``n_proj`` weight matrices of size
    ``hidden_size x intermediate_size`` (gate_proj, up_proj and down_proj in
    LLaMA), so keeping ``k`` intermediate neurons in a layer keeps
    ``k * hidden_size * n_proj`` of that layer's MLP weights.

    Args:
        neuron_num_list: number of intermediate neurons KEPT in each layer,
            one entry per layer. Integral floats (e.g. the row sums of a
            float 0/1 mask) are accepted.
        total_par: parameter count of the unpruned model; the default is the
            value previously hard-coded here for vicuna-7b-v1.5.
        hidden_size, intermediate_size, num_layers, n_proj: MLP geometry; the
            defaults match the vicuna-7b-v1.5 module printout.

    Returns:
        dict with keys ``total_par``, ``total_mlp``, ``total_cut``,
        ``mlp_cut_ratio``, ``cut_ratio`` and ``params_after_cut``.

    Raises:
        ValueError: if a kept-neuron count is not a whole number in
            ``[0, intermediate_size]``.
    """
    per_layer_mlp = hidden_size * intermediate_size * n_proj
    params_per_neuron = hidden_size * n_proj
    total_mlp = num_layers * per_layer_mlp
    print("total parameters:", total_par)
    print("total mlp parameters:", total_mlp)

    total_cut = 0
    for kept in neuron_num_list:
        kept_int = int(kept)
        # Counts usually arrive as floats from mask.sum(axis=1). Reject anything
        # that is not a whole, in-range count instead of silently producing a
        # wrong (or negative) cut.
        if kept_int != kept or not 0 <= kept_int <= intermediate_size:
            raise ValueError(
                f"invalid kept-neuron count {kept!r}; expected an integer in [0, {intermediate_size}]"
            )
        total_cut += per_layer_mlp - kept_int * params_per_neuron
    print("total cut parameters num:", total_cut)

    mlp_cut_ratio = total_cut / total_mlp
    cut_ratio = total_cut / total_par
    params_after_cut = total_par - total_cut
    print(f"total cut mlp parameters: {mlp_cut_ratio:.4f}")
    print(f"total cut parameters: {cut_ratio:.4f}")
    # A parameter count is an integer; it was previously formatted with ".4f".
    print(f"parameters after cut: {params_after_cut}")

    return {
        "total_par": total_par,
        "total_mlp": total_mlp,
        "total_cut": total_cut,
        "mlp_cut_ratio": mlp_cut_ratio,
        "cut_ratio": cut_ratio,
        "params_after_cut": params_after_cut,
    }

In [8]:
# Build the per-layer keep-mask: a neuron is kept if its |value| exceeds THRESHOLD
# in at least one row of normal_components.
THRESHOLD = 5.98               # cut-off value (also encoded in the saved mask file name)
N_LAYERS = 32                  # decoder layers in vicuna-7b
N_NEURONS = 11008              # MLP intermediate size per layer
KEEP_FIRST, KEEP_LAST = 3, 2   # boundary layers left completely unpruned
ROW_CHUNK = 256                # rows per step; bounds the size of the |x| temporary

# Same result as np.any(np.abs(normal_components) > THRESHOLD, axis=0), but reduced
# in row chunks so no temporary the size of normal_components is ever allocated.
column_hit = np.zeros(normal_components.shape[1], dtype=bool)
for row_start in range(0, normal_components.shape[0], ROW_CHUNK):
    rows = normal_components[row_start:row_start + ROW_CHUNK]
    column_hit |= np.any(np.abs(rows) > THRESHOLD, axis=0)
any_mask = column_hit.reshape(1, -1)
print(any_mask.sum())

# Columns are grouped as (layer, 2, neuron). What the size-2 axis represents is not
# visible here (presumably two recorded activations per layer -- TODO confirm); a
# neuron is kept if either of its two entries is selected.
mask = any_mask.reshape(N_LAYERS, 2, -1)
assert mask.shape == (N_LAYERS, 2, N_NEURONS), mask.shape
mask_matrix = np.ones((N_LAYERS, N_NEURONS))
pruned_layers = slice(KEEP_FIRST, N_LAYERS - KEEP_LAST)
mask_matrix[pruned_layers] = mask[pruned_layers].any(axis=1)
print(mask_matrix.sum())
print(mask_matrix.sum(axis=1))
cut_par_num(mask_matrix.sum(axis=1))
mask_matrix.shape

283466
247204.0
[11008. 11008. 11008.  4913.  5599.  6116.  6687.  7346.  7303.  7405.
  7558.  7618.  7614.  7685.  7729.  7668.  8015.  7769.  7727.  7345.
  7426.  7331.  7481.  7237.  7329.  6992.  6733.  6717.  6400.  6421.
 11008. 11008.]
total parameters: 6738415616
total mlp parameters: 4328521728
total cut parameters num: 1290878976.0
total cut mlp parameters: 0.2982
total cut parameters: 0.1916
parameters after cut: 5447536640.0000


(32, 11008)

In [9]:
# Cast the 0/1 float keep-mask to bool (True = neuron kept). `np.bool` was deprecated
# in NumPy 1.20, removed in 1.24 and only reinstated in 2.0, so use the builtin `bool`.
mask_matrix = mask_matrix.astype(bool)

In [7]:
def remove_hooks(module, include_pre_hooks=True):
    """Remove every hook registered on ``module`` and on all of its submodules.

    Hooks left over from earlier analysis cells (presumably the llmfact
    extractors) would otherwise keep firing on every forward/backward call.

    Args:
        module: root ``torch.nn.Module``. ``module.modules()`` visits it and
            every descendant exactly once, including shared submodules.
        include_pre_hooks: also clear forward and backward *pre*-hooks, which
            the previous version left in place. Pass ``False`` for the old
            behaviour.
    """
    hook_attrs = [
        "_forward_hooks",
        "_backward_hooks",
        # bookkeeping dicts keyed by the ids stored in _forward_hooks
        "_forward_hooks_with_kwargs",
        "_forward_hooks_always_called",
    ]
    if include_pre_hooks:
        hook_attrs += [
            "_forward_pre_hooks",
            "_forward_pre_hooks_with_kwargs",
            "_backward_pre_hooks",
        ]
    for submodule in module.modules():
        for attr in hook_attrs:
            hooks = getattr(submodule, attr, None)
            # The set of hook dicts varies across torch versions; skip missing ones.
            if hooks is not None:
                hooks.clear()

# Call this function to clear the hooks on the entire model.
remove_hooks(model)

In [11]:
import gc

# The large component array is no longer needed once the mask exists. The guard keeps
# this cell idempotent: re-running it no longer raises NameError.
if "normal_components" in globals():
    del normal_components
gc.collect();  # trailing ';' hides the unreachable-object count in the notebook

69

In [7]:
import gc

# Collect unreachable Python objects first so any CUDA tensors they own are released,
# then return the allocator's now-unused cached blocks.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [15]:
# Total number of neurons kept across all layers (True entries of the keep-mask).
mask_matrix.sum()

np.int64(247204)

In [18]:
# Persist the (32, 11008) keep-mask so later sessions can skip the threshold cells.
np.save(
    file="./data/FBN/vicuna-7B-mask_0.2_wiki_2000+_threshold_5.98.npy",
    arr=mask_matrix,
)

In [8]:
# Reload the saved keep-mask (True = neuron kept). Cast with the builtin `bool`:
# `np.bool` raises AttributeError on NumPy 1.24-1.26 (it was only reinstated in 2.0).
mask_matrix = np.load("./data/FBN/vicuna-7B-mask_0.2_wiki_2000+_threshold_5.98.npy").astype(bool)
assert mask_matrix.shape == (32, 11008), mask_matrix.shape

In [9]:
# Start from a hook-free model so LayerBiasCompute's own passes are not perturbed.
remove_hooks(model)

removed_neurons = ~mask_matrix               # True where a neuron is NOT kept
calib_texts = wiki_dataset['text'][:3000]    # first 3000 WikiText-2 training rows
# NOTE(review): meaning of the trailing 32 is not visible here. The progress bar shows
# one step per text, so it is not a text batch size; possibly the layer count -- TODO confirm.
add_bias = LayerBiasCompute(model, include_layers, tokenizer, removed_neurons, calib_texts, 32)
add_bias.fit()

100%|██████████| 3000/3000 [34:31<00:00,  1.45it/s]


In [10]:
# Spot check: bias vector computed for entry 6 of bias_dict (presumably layer index 6).
add_bias.bias_dict[6]

tensor([ 0.1001,  0.1137,  0.0005,  ..., -0.0333,  0.0772,  0.0722],
       device='cuda:0')

In [11]:
# Spot check: the bias now installed on layer 6's down_proj should match bias_dict[6].
add_bias.model.model.layers[6].mlp.down_proj.bias

Parameter containing:
tensor([ 0.1001,  0.1137,  0.0005,  ..., -0.0333,  0.0772,  0.0722],
       device='cuda:0', requires_grad=True)

In [12]:
# Re-inspect after add_bias.fit(): the mlp.down_proj layers now report bias=True.
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 4096, padding_idx=0)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (v_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (up_proj): Linear(in_features=4096, out_features=11008, bias=False)
          (down_proj): Linear(in_features=11008, out_features=4096, bias=True)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)

In [13]:
from llmfact.pruner.pruner import PrunedLlamaModel
# Prune the MLP neurons that are False in mask_matrix. fit() returns the pruned model,
# which replaces `model` for the evaluation cells (the printed counts should agree
# with cut_par_num's estimate).
pruner = PrunedLlamaModel(model, mask_matrix)
model = pruner.fit()

total parameters before pruned: 6738546688
total parameters after pruned: 5447536640
total cut num: 1291010048
pruned rate: 0.1916


In [14]:
import lm_eval
from lm_eval import evaluator

# Hand the already-pruned in-memory model to lm-eval (it is not reloaded from the Hub).
wrapper_model = lm_eval.models.huggingface.HFLM(
    pretrained=model,
    trust_remote_code=True,
)

`pretrained` model kwarg is not of type `str`. Many other model arguments may be ignored. Please do not launch via accelerate or use `parallelize=True` if passing an existing model this way.
Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration


In [15]:
# Zero-shot WikiText-2 perplexity of the pruned model.
# `model_args` is deliberately omitted. lm-eval only uses it when `model` is a
# registry name (string); with a pre-built HFLM it is not used for loading. It was
# also malformed (lm-eval expects "key=value" pairs) and would only mislabel the
# recorded run config as the stock checkpoint. This matches the zero-shot call below.
results = evaluator.simple_evaluate(
    model=wrapper_model,
    tasks=["wikitext"],
    num_fewshot=0,
    task_manager=lm_eval.tasks.TaskManager(),
    batch_size=1,
)
results['results']

[Task: wikitext] metric word_perplexity is defined, but aggregation is not. using default aggregation=weighted_perplexity
[Task: wikitext] metric word_perplexity is defined, but higher_is_better is not. using default higher_is_better=False
[Task: wikitext] metric byte_perplexity is defined, but aggregation is not. using default aggregation=weighted_perplexity
[Task: wikitext] metric byte_perplexity is defined, but higher_is_better is not. using default higher_is_better=False
[Task: wikitext] metric bits_per_byte is defined, but aggregation is not. using default aggregation=bits_per_byte
[Task: wikitext] metric bits_per_byte is defined, but higher_is_better is not. using default higher_is_better=False
Overwriting default num_fewshot of wikitext from None to 0
100%|██████████| 62/62 [00:00<00:00, 399.36it/s]
  0%|          | 0/62 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (5943 > 4096). Running this sequence through t

{'wikitext': {'alias': 'wikitext',
  'word_perplexity,none': 16.366720485858323,
  'word_perplexity_stderr,none': 'N/A',
  'byte_perplexity,none': 1.6866186555790748,
  'byte_perplexity_stderr,none': 'N/A',
  'bits_per_byte,none': 0.7541338171961662,
  'bits_per_byte_stderr,none': 'N/A'}}

In [16]:
# Zero-shot commonsense / QA benchmarks on the same pruned model.
zero_shot_tasks = ["piqa", "hellaswag", "winogrande", "openbookqa", "arc_easy", "arc_challenge"]
results = evaluator.simple_evaluate(
    model=wrapper_model,
    tasks=zero_shot_tasks,
    num_fewshot=0,
    task_manager=lm_eval.tasks.TaskManager(),
    batch_size=1,
)
results['results']

Overwriting default num_fewshot of arc_challenge from None to 0
Overwriting default num_fewshot of arc_easy from None to 0
Overwriting default num_fewshot of openbookqa from None to 0
Overwriting default num_fewshot of winogrande from None to 0
Overwriting default num_fewshot of hellaswag from None to 0
Overwriting default num_fewshot of piqa from None to 0
100%|██████████| 1172/1172 [00:01<00:00, 829.03it/s]
100%|██████████| 2376/2376 [00:02<00:00, 817.85it/s]
100%|██████████| 500/500 [00:00<00:00, 1662.19it/s]
100%|██████████| 1267/1267 [00:00<00:00, 85286.20it/s]
100%|██████████| 10042/10042 [00:07<00:00, 1355.82it/s]
100%|██████████| 1838/1838 [00:03<00:00, 471.59it/s]
Running loglikelihood requests: 100%|██████████| 62566/62566 [4:02:00<00:00,  4.31it/s]  


{'arc_challenge': {'alias': 'arc_challenge',
  'acc,none': 0.36689419795221845,
  'acc_stderr,none': 0.014084133118104419,
  'acc_norm,none': 0.39334470989761094,
  'acc_norm_stderr,none': 0.014275101465692932},
 'arc_easy': {'alias': 'arc_easy',
  'acc,none': 0.6839225589225589,
  'acc_stderr,none': 0.009540440071928223,
  'acc_norm,none': 0.6342592592592593,
  'acc_norm_stderr,none': 0.009882988069418874},
 'hellaswag': {'alias': 'hellaswag',
  'acc,none': 0.49741087432782316,
  'acc_stderr,none': 0.004989714512282005,
  'acc_norm,none': 0.6650069707229636,
  'acc_norm_stderr,none': 0.004710234188047076},
 'openbookqa': {'alias': 'openbookqa',
  'acc,none': 0.306,
  'acc_stderr,none': 0.020629569998345414,
  'acc_norm,none': 0.392,
  'acc_norm_stderr,none': 0.02185468495561119},
 'piqa': {'alias': 'piqa',
  'acc,none': 0.7301414581066377,
  'acc_stderr,none': 0.010356595421852079,
  'acc_norm,none': 0.735038084874864,
  'acc_norm_stderr,none': 0.010296557993316033},
 'winogrande': {'