In [1]:
import os
import sys
# Add the directory three levels up to the import search path.
# NOTE(review): presumably this is the repository root that provides the `utils`
# package imported in the next cell — confirm.
sys.path.append("../../../")
# Set before any tokenizer is loaded.
# NOTE(review): presumably silences the tokenizers parallelism warning that appears
# when DataLoader worker processes are forked — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
import copy
import torch
from datetime import datetime
from utils.helper import ModelConfig, color_print
from utils.dataset_utils.load_dataset import (
    load_data,
)
from utils.model_utils.load_model import load_model
from utils.model_utils.evaluate import evaluate_model, get_sparsity
from utils.dataset_utils.sampling import SamplingDataset
from utils.prune_utils.prune import LayerWrapper, find_layers

In [3]:
# Experiment configuration: dataset/model key and the compute device
# (first CUDA device when available, otherwise CPU).
name = "YahooAnswersTopics"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Placeholder only: `checkpoint` is rebound by the `load_model` call below.
checkpoint = None
model_config = ModelConfig(name, device)
# Number of target classes, read from the project config (10 in the recorded output).
num_labels = model_config.config["num_labels"]

model, tokenizer, checkpoint = load_model(model_config)

Loading the model.
{'model_name': 'fabriceyhc/bert-base-uncased-yahoo_answers_topics', 'task_type': 'classification', 'architectures': 'bert', 'dataset_name': 'YahooAnswersTopics', 'num_labels': 10, 'cache_dir': 'Models'}
The model fabriceyhc/bert-base-uncased-yahoo_answers_topics is loaded.


In [4]:
# Build the train/validation/test loaders for the dataset selected by `name`.
# NOTE(review): num_workers=48 is hard-coded — confirm it suits the host's CPU count.
train_dataloader, valid_dataloader, test_dataloader = load_data(
    name, batch_size=32, num_workers=48
)

{'dataset_name': 'YahooAnswersTopics', 'path': 'yahoo_answers_topics', 'config_name': 'yahoo_answers_topics', 'text_column': 'question_title', 'label_column': 'topic', 'cache_dir': 'Datasets/Yahoo', 'task_type': 'classification'}
Loading cached dataset YahooAnswersTopics.
The dataset YahooAnswersTopics is loaded


In [None]:
import pandas as pd
import seaborn as sns
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import copy
from tqdm import tqdm
import numpy as np
import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset, Subset
from torch.nn import CrossEntropyLoss, MSELoss
from functools import partial
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report

In [None]:
# Batch size is 1 because `compute_heads_importance` calls `prediction.item()`,
# which only works when every batch holds a single example.
# NOTE(review): `test_dataset`, `TestDataset` and `validation_df` are not defined in
# any cell visible here; this cell fails on a fresh kernel unless they are created
# elsewhere.

test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)

validation_dataset = TestDataset(validation_df, tokenizer)
validation_loader = DataLoader(validation_dataset, batch_size=1, shuffle=False)

In [None]:
class Args:
    """Plain settings holder passed to `compute_heads_importance`."""

    def __init__(self):
        self.device = device
        self.local_rank = -1  # assumes a single GPU is used
        self.output_mode = "classification"  # or "regression", depending on the task
        self.num_labels = model.config.num_labels
        # When False, head importance is divided by its per-layer L2 norm.
        self.dont_normalize_importance_by_layer = False
        # When False, head importance is min-max scaled over all layers and heads.
        self.dont_normalize_global_importance = False

args = Args()

In [None]:
def entropy(p):
    """Return the entropy (natural log) of `p` taken over its last dimension.

    Entries of `p` that are exactly zero contribute nothing to the sum.
    """
    weighted_log = torch.log(p) * p
    # log(0) is -inf and 0 * -inf is NaN, so zero-probability terms are reset to 0.
    zero_positions = p == 0
    weighted_log[zero_positions] = 0
    return -(weighted_log.sum(dim=-1))

In [None]:
# Module-level accumulators that `compute_heads_importance` updates.
# One 12x12 importance matrix per class; the sizes 12/12/10 are hard-coded here.
# NOTE(review): presumably layers x heads and the number of classes — verify against the model.
per_class_importance_list = [torch.zeros(12, 12).to(args.device) for _ in range(10)]
# One token-count slot per class, used to normalize the matrices above.
per_class_token_list=[0.0 for _ in range(10)]
multihead_outputs_list = []  # the forward hooks append each layer's output to this list

In [None]:
def hook_fn(module, input, output, layer_index):
    """Forward hook that records a module's first output and keeps its gradient.

    The first element of `output` is appended to the module-level
    `multihead_outputs_list`, and `retain_grad()` is called on it so that its
    `.grad` is populated by a later `backward()`.

    Args:
        module: The module the hook is attached to (unused).
        input: The module's positional inputs (unused).
        output: Tuple returned by the module; only element 0 is used, so any
            further elements (e.g. attention probabilities) are optional.
        layer_index: Index of the layer the hook belongs to (unused; supplied
            by `register_hooks`).
    """
    # Index instead of unpacking so the hook does not fail when the module
    # returns a tuple that is not exactly two elements long.
    attention_value = output[0]
    attention_value.requires_grad_(True)  # make sure the tensor takes part in autograd
    attention_value.retain_grad()  # keep .grad on this (usually non-leaf) tensor
    multihead_outputs_list.append(attention_value)

In [None]:
def register_hooks(model):
    """Attach `hook_fn` as a forward hook to each encoder layer's self-attention module.

    Returns the hook handles in layer order so they can be removed later.
    """
    return [
        encoder_layer.attention.self.register_forward_hook(
            partial(hook_fn, layer_index=position)
        )
        for position, encoder_layer in enumerate(model.bert.encoder.layer)
    ]

In [None]:
def remove_hooks(handles):
    """Call `remove()` on every hook handle in `handles`, in order."""
    for registered_hook in handles:
        registered_hook.remove()

In [None]:
# Loader whose examples are used to score head importance.
eval_dataloader = validation_loader  # or validation_loader

In [None]:
# compute_entropy와 compute_importance는 True로 설정하여 기능을 활성화합니다.
# Both flags are set to True to enable the entropy and importance computations.
compute_entropy = True
compute_importance = True

# No head mask: all heads are used when computing the importance scores.
head_mask = None

# Metric tables filled by `evaluating`: 10 rows x 11 columns, indexed as
# [class_index][prune_num].
precision_list = [[0 for _ in range(11)] for _ in range(10)]
recall_list = [[0 for _ in range(11)] for _ in range(10)]
f1score_list = [[0 for _ in range(11)] for _ in range(10)]
global_accuracy_list = [[0 for _ in range(11)] for _ in range(10)]

# Same layout, filled by `evaluating_all` for the run that ranks heads globally.
total_precision_list = [[0 for _ in range(11)] for _ in range(10)]
total_recall_list = [[0 for _ in range(11)] for _ in range(10)]
total_f1score_list = [[0 for _ in range(11)] for _ in range(10)]
total_global_accuracy_list = [[0 for _ in range(11)] for _ in range(10)]

In [None]:
def compute_heads_importance(args, model, eval_dataloader, compute_entropy=True, compute_importance=True,
                             head_mask=None):
    """Score every attention head by gradient-based importance and attention entropy.

    The importance estimate follows http://arxiv.org/abs/1905.10650: for each
    example the loss is back-propagated and |output . d(loss)/d(output)| of each
    self-attention output is summed per head.

    Side effects: the module-level accumulators `per_class_importance_list` and
    `per_class_token_list` are updated, bucketing every example by its *predicted*
    class, and `per_class_importance_list` is then normalized like the global
    score. The accumulators are not reset here, so re-create them before calling
    this function a second time.

    Args:
        args: Settings object providing `device`, `local_rank`, `output_mode`,
            `num_labels`, `dont_normalize_importance_by_layer` and
            `dont_normalize_global_importance`.
        model: Sequence-classification model exposing `model.bert`.
        eval_dataloader: Yields `(inputs, labels)` pairs where `inputs` is a mapping
            of tensors with a singleton dimension 1. Every batch must contain
            exactly one example because the predicted class is read with `.item()`.
        compute_entropy: Accumulate per-head attention entropy when True.
        compute_importance: Accumulate per-head importance when True.
        head_mask: Accepted for interface compatibility; not used.

    Returns:
        Tuple `(attn_entropy, head_importance, preds, labels)`: two
        `(n_layers, n_heads)` tensors plus the stacked logits and labels as NumPy
        arrays (`None` when the loader is empty).

    Raises:
        ValueError: If `compute_importance` is True and `args.output_mode` is
            neither "classification" nor "regression".
    """
    # Prepare our tensors
    bert_config = model.bert.config
    n_layers, n_heads = bert_config.num_hidden_layers, bert_config.num_attention_heads
    # Explicit head size so a mismatching activation shape still raises instead of
    # being silently mis-split.
    head_dim = bert_config.hidden_size // n_heads
    n_classes = len(per_class_importance_list)
    head_importance = torch.zeros(n_layers, n_heads).to(args.device)
    attn_entropy = torch.zeros(n_layers, n_heads).to(args.device)
    preds = None
    labels = None
    tot_tokens = 0.0

    # Drop activations left behind by an earlier, interrupted run; otherwise the
    # layer enumeration below would be shifted.
    multihead_outputs_list.clear()
    handles = register_hooks(model)
    try:
        for batch in tqdm(eval_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0]):
            input_ids, label_ids = batch
            input_ids = {k: v.squeeze(1).to(args.device) for k, v in input_ids.items()}
            label_ids = label_ids.to(args.device)
            actual_batch_size = input_ids['input_ids'].size(0)

            # Do a forward pass (not with torch.no_grad() since we need gradients for importance score - see below)
            outputs = model(**input_ids, output_attentions=True)
            all_attentions = outputs[1]
            logits = outputs[0]

            if compute_entropy:
                # Update head attention entropy
                for layer, attn in enumerate(all_attentions):
                    masked_entropy = entropy(attn.detach())
                    attn_entropy[layer] += masked_entropy.sum(-1).sum(0).detach()

            # Importance contributed by this batch only; stays zero when
            # compute_importance is False so the per-class update is always defined.
            example_importance = torch.zeros(n_layers, n_heads).to(args.device)
            if compute_importance:
                # Update head importance scores with regards to our loss
                # First, backpropagate to populate the gradients
                if args.output_mode == "classification":
                    loss_fct = CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, args.num_labels), label_ids.view(-1))
                elif args.output_mode == "regression":
                    loss_fct = MSELoss()
                    loss = loss_fct(logits.view(-1), label_ids.view(-1))
                else:
                    raise ValueError(f"Unsupported output_mode: {args.output_mode!r}")
                loss.backward()

                # Second, compute importance scores according to http://arxiv.org/abs/1905.10650
                for layer, mh_layer_output in enumerate(multihead_outputs_list):
                    seq_len = mh_layer_output.size(1)
                    split_shape = (actual_batch_size, seq_len, n_heads, head_dim)
                    # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
                    head_values = mh_layer_output.reshape(split_shape).permute(0, 2, 1, 3)
                    head_grads = mh_layer_output.grad.reshape(split_shape).permute(0, 2, 1, 3)
                    dot = torch.einsum("bhli,bhli->bhl", [head_grads, head_values])
                    layer_importance = dot.abs().sum(-1).sum(0).detach()
                    head_importance[layer] += layer_importance
                    example_importance[layer] += layer_importance
                # Parameter gradients are not needed; clear them so they do not
                # accumulate over the whole dataset.
                model.zero_grad()

            # The hooks append on every forward pass, so always release the stored
            # activations (and their graphs) before the next batch.
            multihead_outputs_list.clear()

            # Also store our logits/labels if we want to compute metrics afterwards
            batch_logits = logits.detach().cpu().numpy()
            batch_labels = label_ids.detach().cpu().numpy()
            if preds is None:
                preds = batch_logits
                labels = batch_labels
            else:
                preds = np.append(preds, batch_logits, axis=0)
                labels = np.append(labels, batch_labels, axis=0)

            # `.item()` requires exactly one example per batch.
            predicted_class = np.argmax(batch_logits, axis=1).item()

            per_class_importance_list[predicted_class] += example_importance
            # Counts ids different from 0 — assumes 0 is the padding id; TODO confirm.
            per_class_token = (input_ids['input_ids'] != 0).float().sum().item()
            # Accumulate (not overwrite) so each class is normalized by all tokens
            # of the examples assigned to it.
            per_class_token_list[predicted_class] += per_class_token
            tot_tokens += per_class_token
    finally:
        # Always detach the hooks, even when a batch fails.
        remove_hooks(handles)
        multihead_outputs_list.clear()

    # Normalize
    attn_entropy /= tot_tokens
    head_importance /= tot_tokens
    for i in range(n_classes):
        # A class that was never predicted keeps its all-zero matrix instead of 0/0 = NaN.
        if per_class_token_list[i] > 0:
            per_class_importance_list[i] /= per_class_token_list[i]

    # Layerwise importance normalization
    if not args.dont_normalize_importance_by_layer:
        exponent = 2
        norm_by_layer = torch.pow(torch.pow(head_importance, exponent).sum(-1), 1 / exponent)
        head_importance /= norm_by_layer.unsqueeze(-1) + 1e-20
        for i in range(n_classes):
            norm_by_layer = torch.pow(torch.pow(per_class_importance_list[i], exponent).sum(-1), 1 / exponent)
            per_class_importance_list[i] /= norm_by_layer.unsqueeze(-1) + 1e-20

    if not args.dont_normalize_global_importance:
        head_importance = (head_importance - head_importance.min()) / (head_importance.max() - head_importance.min())
        for i in range(n_classes):
            class_scores = per_class_importance_list[i]
            lowest = class_scores.min()
            span = class_scores.max() - lowest
            # Skip constant matrices (e.g. never-predicted classes) to avoid 0/0.
            if span > 0:
                per_class_importance_list[i] = (class_scores - lowest) / span
    return attn_entropy, head_importance, preds, labels

In [None]:
def visualization_heatmap(file_name, array, num_layer, num_heads, label=0):
    """Save a layer-by-head heat map of importance scores as a PNG.

    The image is written to `./heatmap/head_importance_score/{file_name}.png`;
    the directory is created when it does not exist yet.

    Args:
        file_name: Base name of the PNG file (without extension).
        array: 2-D scores of shape `(num_layer, num_heads)`, given either as a
            torch tensor (any device) or as anything `np.asarray` accepts.
        num_layer: Number of rows; used to build the "Layer N" labels.
        num_heads: Number of columns; used to build the "Head N" labels.
        label: 0 selects the "Total" title; any other value is shown as the class
            number in the title.
    """
    # Move tensors to the CPU and convert to NumPy; NumPy input is used as is.
    if torch.is_tensor(array):
        scores = array.detach().cpu().numpy()
    else:
        scores = np.asarray(array)

    df = pd.DataFrame(scores)

    # Label rows/columns with 1-based layer and head numbers.
    df.index = [f"Layer {i + 1}" for i in range(num_layer)]
    df.columns = [f"Head {i + 1}" for i in range(num_heads)]

    output_dir = './heatmap/head_importance_score'
    # savefig does not create missing directories.
    os.makedirs(output_dir, exist_ok=True)

    # Draw and save the heat map on an explicit figure.
    fig, ax = plt.subplots(figsize=(12, 8))
    try:
        sns.heatmap(df, annot=True, fmt=".2f", cmap='viridis', ax=ax)
        if label == 0:
            ax.set_title("Total head important score")
        else:
            ax.set_title(f'Class {label} head important score')
        ax.set_xlabel('Head')
        ax.set_ylabel('Layer')
        fig.savefig(f'{output_dir}/{file_name}.png')
    finally:
        # Close even on failure so figures do not pile up in memory.
        plt.close(fig)

In [None]:
# Run the scoring pass over `eval_dataloader`. Besides the returned values, the call
# updates the module-level `per_class_importance_list` / `per_class_token_list`.
attn_entropy, head_importance, preds, labels = compute_heads_importance(args, model, eval_dataloader, compute_entropy, compute_importance, head_mask)

def calculate_prune_head(arr, i, group_size=12):
    """Return the indices of the `i`-th group of lowest-valued entries of `arr`.

    All entries of the 2-D array are ranked in ascending order of value (ties keep
    row-major order because the sort is stable). The ranking is cut into
    consecutive groups of `group_size`, and the index tuples of group `i` are
    returned.

    Args:
        arr: 2-D array-like of scores.
        i: Zero-based group number; 0 selects the smallest `group_size` entries.
        group_size: Entries per group. Defaults to 12, the previously hard-coded value.

    Returns:
        List of `(row, column)` tuples; shorter than `group_size` (possibly empty)
        when the ranking is exhausted.
    """
    # Pair every value with its index, then rank the pairs by value.
    ranked = sorted(np.ndenumerate(arr), key=lambda entry: entry[1])

    # Keep only the indices of the requested slice of the ranking.
    start = group_size * i
    return [index for index, _ in ranked[start:start + group_size]]

# Replace each per-class importance tensor in the list by a NumPy array.
# NOTE(review): not idempotent — on a second run the entries are already NumPy
# arrays, which have no `.cpu()` method.
for i in range(10):
    per_class_importance_list[i] = per_class_importance_list[i].cpu().numpy()

# Independent copy that the per-class pruning run ranks.
per_class_head_importance_list = copy.deepcopy(per_class_importance_list)

# Find the maximum of every layer (row).
def layer_max(arr):
    """Return, for each row of the 2-D `arr`, the column index of its largest value.

    On ties the first (lowest) column index is returned.
    """
    # The maximum values themselves were computed but never used; only the
    # positions are needed.
    return np.argmax(arr, axis=1)

def preprocess_prunehead(arr):
    """Overwrite the top-scoring head of every layer with 100, for each class matrix.

    Each element of `arr` is a 2-D score matrix (layers x heads) and is modified
    in place: in every row, the entry holding the row maximum is set to 100.
    NOTE(review): presumably this keeps one head per layer out of the lowest-ranked
    groups returned by `calculate_prune_head` — confirm intent.

    Args:
        arr: Sequence of 2-D score matrices, one per class. When `None`, the
            module-level `per_class_head_importance_list` is used, which is what
            the function previously always operated on regardless of `arr`.

    Returns:
        None. The matrices are changed in place.
    """
    if arr is None:
        arr = per_class_head_importance_list
    for class_scores in arr:
        # Column index of the largest value in each row (layer).
        top_heads = np.argmax(class_scores, axis=1)
        for layer, head in enumerate(top_heads):
            class_scores[layer][head] = 100

def total_preprocess_prunehead(arr):
    """Overwrite the top-scoring head of every layer of `arr` with 100.

    Works for any number of rows (the row count was previously fixed at 12).

    Args:
        arr: 2-D score matrix (layers x heads); modified in place.

    Returns:
        The same `arr` object, after modification.
    """
    # Column index of the largest value in each row (layer).
    top_heads = np.argmax(arr, axis=1)
    for layer, head in enumerate(top_heads):
        arr[layer][head] = 100
    return arr

def prune_head(model, prune_list):
    """Prune the given (layer_index, head_index) pairs from `model`, one call per pair.

    The model is changed in place and also returned.
    """
    for pair in prune_list:
        layer_index, head_index = pair
        attention = model.bert.encoder.layer[layer_index].attention
        attention.prune_heads([head_index])
    return model

def print_prune_head_list(prune_list, trial):
    """Print the cumulative prune count (`len(prune_list) * trial`) and the list itself."""
    header = f"total prune number : {len(prune_list)*trial}"
    # Three lines of output: count, caption, then the list.
    print(header, "prune head list", prune_list, sep="\n")


def evaluating(model, class_index, prune_num):
    """Evaluate `model` on `test_loader` and record the metrics of one class.

    Precision, recall and F1 of `class_index`, plus the overall accuracy, are
    stored at `[class_index][prune_num]` in the module-level metric tables and
    printed.
    """
    predicted = []
    expected = []
    for inputs, labels in tqdm(test_loader, desc="Evaluating"):
        inputs = {k: v.squeeze(1).to(device) for k, v in inputs.items()}
        labels = labels.to(device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = model(**inputs)
        prediction = outputs.logits.argmax(dim=-1)

        predicted.extend(prediction.tolist())
        expected.extend(labels.tolist())

    report = classification_report(expected, predicted, output_dict=True)
    # The report dictionary is keyed by the class label as a string.
    class_report = report[str(class_index)]
    per_class_tables = (
        (precision_list, 'precision'),
        (recall_list, 'recall'),
        (f1score_list, 'f1-score'),
    )
    for table, metric in per_class_tables:
        table[class_index][prune_num] = class_report[metric]
    global_accuracy_list[class_index][prune_num] = report['accuracy']

    print(f"Class {class_index} Precision: {class_report['precision']}")
    print(f"Class {class_index} Recall: {class_report['recall']}")
    print(f"Class {class_index} F1-Score: {class_report['f1-score']}")
    print(f"Global Accuracy: {report['accuracy']}")
    print()

In [None]:
def head_importance_prunning():
    """Prune heads class by class, 12 at a time over 11 steps, evaluating after each step.

    Every class starts from a fresh deep copy of `model`; within a class the
    pruning accumulates from step to step.
    """
    for class_index in range(10):
        temp_model = copy.deepcopy(model)
        class_scores = per_class_head_importance_list[class_index]
        for num in range(11):
            print(f'Class {class_index+1} {(num+1)*12} prunning')
            # Step `num` takes the next 12 lowest-ranked heads of this class.
            prune_list = calculate_prune_head(class_scores, num)
            print_prune_head_list(prune_list, num+1)
            temp_model = prune_head(temp_model, prune_list)
            evaluating(temp_model, class_index, num)

head_importance_prunning()

In [None]:
def evaluating_all(model, prune_num):
    """Evaluate `model` on `test_loader` and record the metrics of all 10 classes.

    For every class, precision, recall, F1 and the overall accuracy are stored at
    `[class][prune_num]` in the module-level `total_*` tables and printed.
    """
    predicted = []
    expected = []
    for inputs, labels in tqdm(test_loader, desc="Evaluating"):
        inputs = {k: v.squeeze(1).to(device) for k, v in inputs.items()}
        labels = labels.to(device)

        # Inference only: no gradients are needed.
        with torch.no_grad():
            outputs = model(**inputs)
        prediction = outputs.logits.argmax(dim=-1)

        predicted.extend(prediction.tolist())
        expected.extend(labels.tolist())

    report = classification_report(expected, predicted, output_dict=True)
    for i in range(10):
        # The report dictionary is keyed by the class label as a string.
        class_report = report[str(i)]
        total_precision_list[i][prune_num] = class_report['precision']
        total_recall_list[i][prune_num] = class_report['recall']
        total_f1score_list[i][prune_num] = class_report['f1-score']
        total_global_accuracy_list[i][prune_num] = report['accuracy']

        print(f"Class {i} Precision: {class_report['precision']}")
        print(f"Class {i} Recall: {class_report['recall']}")
        print(f"Class {i} F1-Score: {class_report['f1-score']}")
        print(f"Global Accuracy: {report['accuracy']}")
        print()

temp_head_importance_score = copy.deepcopy(head_importance).cpu().numpy()
temp_head_importance_score = total_preprocess_prunehead(temp_head_importance_score)

temp_head_importance_score = copy.deepcopy(head_importance).cpu().numpy()
temp_head_importance_score = total_preprocess_prunehead(temp_head_importance_score)

In [None]:
def total_head_importance_prunning():
    """Prune heads by the global ranking, 12 at a time over 11 steps, evaluating each step.

    Works on one deep copy of `model`, so the pruning accumulates across steps.
    """
    temp_model = copy.deepcopy(model)
    for num in range(11):
        step_number = num + 1
        print(f'Total {(num+1)*12} prunning')
        # Step `num` takes the next 12 lowest-ranked heads overall.
        prune_list = calculate_prune_head(temp_head_importance_score, num)
        print_prune_head_list(prune_list, step_number)
        temp_model = prune_head(temp_model, prune_list)
        evaluating_all(temp_model, num)

total_head_importance_prunning()

In [5]:
# Index of the module processed in this run; it is also passed to `SamplingDataset` below.
i = 0
color_print("Start Time:" + datetime.now().strftime("%H:%M:%S"))
color_print("#Module " + str(i) + " in progress....")
num_samples = 64

# NOTE(review): the fifth positional argument is True only for `positive_samples`;
# presumably it selects examples of class `i` versus the remaining classes —
# confirm against `SamplingDataset`.
positive_samples = SamplingDataset(
    train_dataloader, i, num_samples, num_labels, True, 4, device=device
)
negative_samples = SamplingDataset(
    train_dataloader, i, num_samples, num_labels, False, 4, device=device
)
# NOTE(review): 200 and 20 are passed where the other calls pass `i` and
# `num_samples`; their meaning is not visible here — confirm.
all_samples = SamplingDataset(
    train_dataloader, 200, 20, num_labels, False, 4, device=device
)

Start Time:14:17:58
#Module 0 in progress....


In [6]:
print("Evaluate the original model")
# NOTE(review): the baseline evaluation is commented out, so this cell only prints the banner.
# result = evaluate_model(model, model_config, test_dataloader)

Evaluate the original model


In [7]:
# Independent copy of the loaded model; the pruning calls in the following cells receive it.
compare = copy.deepcopy(model)

In [10]:
from utils.prune_utils.prune import prune_magnitude
# Apply the project's magnitude pruning to `compare` for the listed layer groups.
# NOTE(review): the return value is unused, so `compare` is presumably modified in place — confirm.
prune_magnitude(
    compare, include_layers=["attention", "intermediate", "output"], sparsity_ratio=0.1
)

In [11]:
from utils.prune_utils.prune import prune_concern_identification
# Run the project's concern-identification pruning with the original `model`, the
# copy `compare` and the positive samples of class `i`.
# NOTE(review): the return value is unused; presumably `compare` is the model being
# pruned in place and `model` serves as the reference — confirm.
prune_concern_identification(model, compare, positive_samples, sparsity_ratio=0.5, include_layers=["attention", "intermediate","output"])

In [13]:
# Evaluate `compare` on the test loader; the recorded report follows the cell.
result = evaluate_model(compare, model_config, test_dataloader)

Evaluating: 100%|██████████| 1875/1875 [11:13<00:00,  2.78it/s]


Loss: 1.1742
Precision: 0.6665, Recall: 0.6354, F1-Score: 0.6406
              precision    recall  f1-score   support

           0       0.62      0.47      0.53      6000
           1       0.78      0.52      0.63      6000
           2       0.76      0.66      0.70      6000
           3       0.46      0.54      0.50      6000
           4       0.79      0.79      0.79      6000
           5       0.94      0.72      0.82      6000
           6       0.39      0.47      0.43      6000
           7       0.66      0.64      0.65      6000
           8       0.50      0.83      0.63      6000
           9       0.75      0.72      0.73      6000

    accuracy                           0.64     60000
   macro avg       0.67      0.64      0.64     60000
weighted avg       0.67      0.64      0.64     60000



In [17]:
# NOTE(review): same call as the earlier evaluation cell, but the recorded loss and
# scores differ, so `compare` was presumably changed between the two runs (cells
# executed out of order?) — confirm before relying on either report.
result = evaluate_model(compare, model_config, test_dataloader)
# result = evaluate_model(module, model_config, test_dataloader)

Evaluating: 100%|██████████| 1875/1875 [11:29<00:00,  2.72it/s]


Loss: 1.0199
Precision: 0.6792, Recall: 0.6760, F1-Score: 0.6742
              precision    recall  f1-score   support

           0       0.59      0.54      0.56      6000
           1       0.74      0.64      0.68      6000
           2       0.74      0.73      0.74      6000
           3       0.52      0.52      0.52      6000
           4       0.80      0.81      0.81      6000
           5       0.90      0.82      0.86      6000
           6       0.55      0.43      0.48      6000
           7       0.64      0.71      0.67      6000
           8       0.58      0.81      0.68      6000
           9       0.74      0.75      0.74      6000

    accuracy                           0.68     60000
   macro avg       0.68      0.68      0.67     60000
weighted avg       0.68      0.68      0.67     60000



In [19]:
# Show the sparsity of `compare`; the recorded output is a tuple of one overall value
# and a dictionary with one value per parameter name.
get_sparsity(compare)

(0.39683034509882176,
 {'bert.encoder.layer.0.attention.self.query.weight': 0.4000006781684028,
  'bert.encoder.layer.0.attention.self.query.bias': 0.0,
  'bert.encoder.layer.0.attention.self.key.weight': 0.4000006781684028,
  'bert.encoder.layer.0.attention.self.key.bias': 0.0,
  'bert.encoder.layer.0.attention.self.value.weight': 0.4001786973741319,
  'bert.encoder.layer.0.attention.self.value.bias': 0.0,
  'bert.encoder.layer.0.attention.output.dense.weight': 0.4000006781684028,
  'bert.encoder.layer.0.attention.output.dense.bias': 0.0,
  'bert.encoder.layer.0.intermediate.dense.weight': 0.40000025431315106,
  'bert.encoder.layer.0.intermediate.dense.bias': 0.0,
  'bert.encoder.layer.0.output.dense.weight': 0.40000025431315106,
  'bert.encoder.layer.0.output.dense.bias': 0.0,
  'bert.encoder.layer.1.attention.self.query.weight': 0.4000006781684028,
  'bert.encoder.layer.1.attention.self.query.bias': 0.0,
  'bert.encoder.layer.1.attention.self.key.weight': 0.4000006781684028,
  'bert