## Import Packages

In [1]:
from __future__ import absolute_import, division, print_function

import pprint
import argparse
import logging
import os
import random
import sys
import pickle
import copy
import collections
import math

import numpy as np
import numpy
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler,TensorDataset

from torch.nn import CrossEntropyLoss, MSELoss

from transformer import BertForSequenceClassification,WEIGHTS_NAME, CONFIG_NAME
from transformer.modeling_quant import BertForSequenceClassification as QuantBertForSequenceClassification
from transformer import BertTokenizer
from transformer import BertAdam
from transformer import BertConfig
from transformer import QuantizeLinear, BertSelfAttention, FP_BertSelfAttention, BertAttention, FP_BertAttention
from utils_glue import *

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch.nn.functional as F

# Module-level MSE criterion instance.
mse_func = MSELoss()

# GLUE task name -> processor class. The classes are expected to come from
# the star-import of utils_glue above.
processors = dict(
    [
        ("cola", ColaProcessor),
        ("mnli", MnliProcessor),
        ("mnli-mm", MnliMismatchedProcessor),
        ("mrpc", MrpcProcessor),
        ("sst-2", Sst2Processor),
        ("sts-b", StsbProcessor),
        ("qqp", QqpProcessor),
        ("qnli", QnliProcessor),
        ("rte", RteProcessor),
    ]
)

# GLUE task name -> kind of prediction head ("classification" or "regression").
# Every task registered in `processors` needs an entry here, otherwise
# `output_modes[task_name]` raises KeyError; "mnli-mm" (the mismatched MNLI
# dev split) was previously missing.
output_modes = {
    "cola": "classification",
    "mnli": "classification",
    "mnli-mm": "classification",
    "mrpc": "classification",
    "sst-2": "classification",
    "sts-b": "regression",
    "qqp": "classification",
    "qnli": "classification",
    "rte": "classification",
}

# Per-task defaults: maximum tokenized sequence length, batch size and
# evaluation interval. "mnli-mm" mirrors "mnli" so that every task registered
# in `processors` has defaults (it was previously missing, which left
# batch_size / max_seq_length / eval_step undefined for that task).
default_params = {
    "cola": {"max_seq_length": 64, "batch_size": 1, "eval_step": 50},  # eval_step - no aug: 50, aug: 400
    "mnli": {"max_seq_length": 128, "batch_size": 1, "eval_step": 8000},
    "mnli-mm": {"max_seq_length": 128, "batch_size": 1, "eval_step": 8000},
    "mrpc": {"max_seq_length": 128, "batch_size": 1, "eval_step": 100},
    "sst-2": {"max_seq_length": 64, "batch_size": 1, "eval_step": 100},
    "sts-b": {"max_seq_length": 128, "batch_size": 1, "eval_step": 100},
    "qqp": {"max_seq_length": 128, "batch_size": 1, "eval_step": 1000},
    "qnli": {"max_seq_length": 128, "batch_size": 1, "eval_step": 1000},
    "rte": {"max_seq_length": 128, "batch_size": 1, "eval_step": 20},
}

def get_tensor_data(output_mode, features):
    """Pack a list of feature objects into a ``TensorDataset``.

    Args:
        output_mode: ``"classification"`` (labels stored as ``torch.long``) or
            ``"regression"`` (labels stored as ``torch.float``).
        features: iterable of objects exposing ``label_id``, ``seq_length``,
            ``input_ids``, ``input_mask`` and ``segment_ids`` attributes.

    Returns:
        A tuple ``(tensor_data, all_label_ids)`` where ``tensor_data`` yields
        ``(input_ids, input_mask, segment_ids, label_id, seq_length)`` per
        example and ``all_label_ids`` is the label tensor on its own.

    Raises:
        ValueError: if ``output_mode`` is not one of the two supported modes.
    """
    if output_mode == "classification":
        label_dtype = torch.long
    elif output_mode == "regression":
        label_dtype = torch.float
    else:
        # Previously an unknown mode fell through both branches and failed
        # further down with an opaque UnboundLocalError.
        raise ValueError(
            "Unsupported output_mode {!r}; expected 'classification' or 'regression'".format(output_mode)
        )

    all_label_ids = torch.tensor([f.label_id for f in features], dtype=label_dtype)
    all_seq_lengths = torch.tensor([f.seq_length for f in features], dtype=torch.long)
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    tensor_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids, all_seq_lengths)
    return tensor_data, all_label_ids

class AverageMeter(object):
    """Keeps the latest value plus a running, count-weighted mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Set the latest value, mean, weighted sum and count back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the mean."""
        weighted = val * n
        self.val = val
        self.sum = self.sum + weighted
        self.count = self.count + n
        self.avg = self.sum / self.count


11/11 08:56:01 AM Note: detected 78 virtual cores but NumExpr set to maximum of 64, check "NUMEXPR_MAX_THREADS" environment variable.
11/11 08:56:01 AM Note: NumExpr detected 78 cores but "NUMEXPR_MAX_THREADS" not set, so enforcing safe limit of 8.


Error.  nthreads cannot be larger than environment variable "NUMEXPR_MAX_THREADS" (64)

## Dataset & Model Setting

In [2]:
# Experiment selection: GLUE task and BERT variant.
task_name = "sst-2"
bert_size = "large"

# Layer / attention-head counts: 24/16 for "large", 12/12 for anything else.
layer_num, head_num = (24, 16) if bert_size == "large" else (12, 12)

# NOTE(review): this path is hard-coded for large / sst-2 and does not follow
# task_name or bert_size; update it by hand when changing the settings above.
model_dir = "models/BERT_large/sst-2"

## Prepare Dataset

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Processor & Task Info
processor = processors[task_name]()
output_mode = output_modes[task_name]
label_list = processor.get_labels()
num_labels = len(label_list)

# Fail early with a clear message instead of leaving batch_size /
# max_seq_length / eval_step undefined and hitting a NameError further down.
if task_name not in default_params:
    raise KeyError(f"No default parameters registered for task {task_name!r}")
batch_size = default_params[task_name]["batch_size"]
max_seq_length = default_params[task_name]["max_seq_length"]
eval_step = default_params[task_name]["eval_step"]

# Tokenizer
tokenizer = BertTokenizer.from_pretrained(model_dir, do_lower_case=True)

# Load Dataset
data_dir = os.path.join("data", task_name)

eval_examples = processor.get_dev_examples(data_dir)
eval_features = convert_examples_to_features(eval_examples, label_list, max_seq_length, tokenizer, output_mode)

# Build the dataset once, with the task's real output mode. The loader used to
# be built from a dataset created with a hard-coded "classification" mode
# (casting regression labels to long); the dataset was then rebuilt with the
# right mode only after the loader already existed.
eval_data, eval_labels = get_tensor_data(output_mode, eval_features)
# NOTE(review): RandomSampler is unseeded, so the sample batch below and the
# iteration order differ between runs.
eval_sampler = RandomSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=batch_size)

# Get input batch sample
batch = next(iter(eval_dataloader))
input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch
seq_length = seq_lengths[0]

11/11 08:56:01 AM Writing example 0 of 872
11/11 08:56:01 AM *** Example ***
11/11 08:56:01 AM guid: dev-1
11/11 08:56:01 AM tokens: [CLS] it ' s a charming and often affecting journey . [SEP]
11/11 08:56:01 AM input_ids: 101 2009 1005 1055 1037 11951 1998 2411 12473 4990 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
11/11 08:56:01 AM input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
11/11 08:56:01 AM segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
11/11 08:56:01 AM label: 1
11/11 08:56:01 AM label_id: 1


## Build Model

In [4]:
# Build Model: load the student checkpoint directory and its config, then
# instantiate the quantized sequence-classification model from it.
student_model_M_dir = "output/BERT_large/sst-2/exploration/sst-2_large_map_large_M_42"
student_config = BertConfig.from_pretrained(student_model_M_dir)
student_model_M = QuantBertForSequenceClassification.from_pretrained(
    student_model_M_dir,
    config=student_config,
    num_labels=num_labels,
)
print("Student Model - Attention Map KD-QAT")

Student Model - Attention Map KD-QAT


## Hessian Analysis

In [6]:
from hessian import hessian

# Pick the loss that matches the task's output mode.
if output_mode == "classification":
    loss_fct = CrossEntropyLoss()
elif output_mode == "regression":
    loss_fct = MSELoss()
else:
    # Without this branch loss_fct would be undefined and fail inside the loop.
    raise ValueError(f"Unsupported output_mode {output_mode!r}")

tc_max_eigens = []
model = student_model_M.to(device)
teacher_model = None
# Follow the device chosen earlier instead of hard-coding cuda=True, so the
# cell does not force CUDA on a CPU-only machine.
use_cuda = device.type == "cuda"

# One Hessian computation per evaluation batch; the top-3 eigenvalues of each
# batch are appended to a single flat list.
for batch in tqdm(eval_dataloader):
    input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch
    hessian_comp = hessian(
        model,
        data=(input_ids, label_ids),
        criterion=loss_fct,
        cuda=use_cuda,
        input_zip=(input_ids, segment_ids, input_mask),
        teacher_model=teacher_model,
        kd_type=None,
    )
    top_eigenvalues, top_eigenvector = hessian_comp.eigenvalues(top_n=3)
    # In-place extend avoids re-copying the whole list on every iteration.
    tc_max_eigens.extend(top_eigenvalues)

## Save Hessian Max Eigenvalues

In [None]:
pt_folder_name = "hessian_value_pts"
# The original line had a stray "}" (SyntaxError) and interpolated the
# `hessian` class object into the name; a literal "hessian" suffix is used.
# NOTE(review): confirm the intended suffix - the plotting cell loads files
# named like "cola_sarq_init.pt", which this pattern does not produce.
file_name = f"{bert_size}_{task_name}_hessian.pt"

# Create the output folder if needed.
os.makedirs(pt_folder_name, exist_ok=True)

file_dir = os.path.join(pt_folder_name, file_name)

print(file_dir)
# Let a failed save raise normally instead of dropping into pdb through a
# bare except, which hangs non-interactive runs.
torch.save(tc_max_eigens, file_dir)
print("==> Model Eigen Value DONE!")

## Plot Hessian Eigenvalue Spectra (Figure 3)

In [None]:
import torch
import seaborn as sns
import matplotlib.pyplot as plt


def _split_by_sign(eigens):
    """Split eigenvalues into (strictly positive, non-positive) lists, keeping order."""
    positive, non_positive = [], []
    for eigen in eigens:
        (positive if eigen > 0 else non_positive).append(eigen)
    return positive, non_positive


# NOTE(review): these files are not produced by the save cell's naming scheme;
# they must already exist in hessian_value_pts/.
eigens_1 = torch.load("hessian_value_pts/cola_sarq_init.pt")
eigens_2 = torch.load("hessian_value_pts/cola_ternary_init.pt")

eigens_1_pos, eigens_1_neg = _split_by_sign(eigens_1)
eigens_2_pos, eigens_2_neg = _split_by_sign(eigens_2)

# Legend labels. These names were referenced but never defined (NameError).
# NOTE(review): the values are inferred from the file names above - adjust
# them to the labels wanted in the figure.
eigens_1_name = "SARQ init"
eigens_2_name = "Ternary init"
eigens_1_name_2 = eigens_1_name
eigens_2_name_2 = eigens_2_name

fs = 13   # base font size
lw = 2.5  # KDE line width

fig, axes = plt.subplots(1, 2, figsize=(10, 3.4), dpi=200)

color_1 = "tab:blue"
color_2 = "navy"
color_3 = "darkblue"  # not used in this cell
color_4 = "tab:red"   # not used in this cell

# Plot: positive eigenvalues on the right panel, non-positive on the left.
# sns.kdeplot returns the Axes it drew on, so pos_* is axes[1] and neg_* is axes[0].
pos_1 = sns.kdeplot(eigens_1_pos, color=color_1, label=eigens_1_name, linewidth=lw, ax=axes[1])
pos_2 = sns.kdeplot(eigens_2_pos, color=color_2, label=eigens_2_name, linewidth=lw, ax=axes[1])
neg_1 = sns.kdeplot(eigens_1_neg, color=color_1, label=eigens_1_name_2, linewidth=lw, ax=axes[0])
neg_2 = sns.kdeplot(eigens_2_neg, color=color_2, label=eigens_2_name_2, linewidth=lw, ax=axes[0])

# Font sizes, axis labels and legends.
pos_1.get_yaxis().set_visible(False)
pos_1.tick_params(axis='x', labelsize=fs)
neg_1.tick_params(axis='x', labelsize=fs)
neg_1.tick_params(axis='y', labelsize=fs)
neg_1.set_ylabel(ylabel="Density", fontsize=fs + 2)
pos_1.set_xlabel(xlabel="Positive Max Eigenvalue", fontsize=fs + 2)
neg_1.set_xlabel(xlabel="Negative Max Eigenvalue", fontsize=fs + 2)
axes[1].legend(fontsize=fs, loc="upper right")
axes[0].legend(fontsize=fs, loc="upper left")
axes[1].set_ylim(0, 0.025)
fig.tight_layout()