In [37]:
from __future__ import absolute_import, division, print_function

import pprint
import argparse
import logging
import os
os.environ["CUDA_VISIBLE_DEVICES"]="3" # Set GPU Index to use
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
import random
import sys
import pickle
import copy
import collections
import math

import numpy as np
import numpy
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler,TensorDataset
# from torch.utils.tensorboard import SummaryWriter

from torch.nn import CrosnsEntropyLoss, MSELoss
from tqdm import tqdm
from transformer import BertForSequenceClassification,WEIGHTS_NAME, CONFIG_NAME
from transformer.modeling_quant import BertForSequenceClassification as QuantBertForSequenceClassification
from transformer import BertTokenizer
from transformer import BertAdam
from transformer import BertConfig
from transformer import QuantizeLinear, QuantizeAct, BertSelfAttention, FP_BertSelfAttention, ClipLinear
from utils_glue import *
from bertviz import model_view

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch.nn.functional as F

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0 
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def cv_initialize(model, loader, ratio, device):
    
    def initialize_hook(module, input, output):
        if isinstance(module, (QuantizeLinear, QuantizeAct, ClipLinear)):
            """KDLSQ-BERT ACT Quant init Method
            Ref: https://arxiv.org/abs/2101.05938
            """
            if not isinstance(input, torch.Tensor):
                input = input[0]
        
            n = torch.numel(input)
            input_sorted, index = torch.sort(input.reshape(-1), descending=False)
            
            index_min = torch.round(ratio * n / 2)
            index_max = n - index_min
            
            s_init = (input_sorted[int(index_min)].to(device), input_sorted[int(index_max)].to(device))
            
            # MATPLOT
            
            fig, [ax1, ax2, ax3] = plt.subplots(1,3, figsize=(16, 4))            
            
            sns.histplot(data=input.reshape(-1).detach().cpu().numpy(), kde = True, bins=100, ax=ax1)
            sns.rugplot(data=input.reshape(-1).detach().cpu().numpy(), ax=ax1)
            sns.histplot(data=module.weight.reshape(-1).detach().cpu().numpy(), kde = True, bins=100, ax=ax2)
            sns.rugplot(data=module.weight.reshape(-1).detach().cpu().numpy(), ax=ax2)
            sns.histplot(data=output.reshape(-1).detach().cpu().numpy(), kde = True, bins=100, ax=ax3)
            sns.rugplot(data=output.reshape(-1).detach().cpu().numpy(), ax=ax3)
            # fig, [ax1, ax2] = plt.subplots(1,2, figsize=(12, 4))            
            
            # sns.distplot(input.reshape(-1).detach().cpu().numpy() , hist = True, rug = True, kde = True, bins=100, norm_hist=False, kde_kws=dict(linewidth=0.5), rug_kws=dict(linewidth=0.5), ax=ax1)
            # sns.distplot(output.reshape(-1).detach().cpu().numpy() , hist = True, rug = True, kde = True, bins=100, norm_hist=False, kde_kws=dict(linewidth=0.5), rug_kws=dict(linewidth=0.5), ax=ax2)
            # # plt.axvline(x=s_init[0].detach().cpu().numpy(), color='r', linestyle='--')
            # # plt.axvline(x=s_init[1].detach().cpu().numpy(), color='r', linestyle='--')

            ax1.set_xlabel("Input Activation")
            # ax2.set_xlabel("Output Activation")
            ax2.set_xlabel("Module Weight")
            ax3.set_xlabel("Output Activation")
            
            ax1.set_ylabel("Density")
            ax2.set_ylabel("Density")
            ax3.set_ylabel("Density")

            ax1.set_title(f"{module.name} Input ACT histogram")
            # ax2.set_title(f"{module.name} Output ACT histogram")
            ax2.set_title(f"{module.name} Weight histogram")
            ax3.set_title(f"{module.name} Output ACT histogram")
            # plt.savefig(f"plt_storage/hook_inputs/sst-2-fp/{module.name}.png")
            plt.show()
            plt.close(fig)
            # module.clip_initialize(s_init)
            # logger.info(f"{module} : min {s_init[0].item()} max {s_init[1].item()}") 

    
    hooks = []

    for name, module in model.named_modules():
        hook = module.register_forward_hook(initialize_hook)
        hooks.append(hook)
    
    model.train()
    model.to(device)
    
    for step, batch in enumerate(loader):
        batch = tuple(t.to("cuda") for t in batch)
        input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch        
        with torch.no_grad():
            student_logits, student_atts, student_reps, student_probs, student_values = model(input_ids, segment_ids, input_mask, teacher_probs=None)
        break
    
    for hook in hooks:
        hook.remove()

def str2bool(v):
    if isinstance(v, bool):
        return v
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')

def load_vocab(vocab_file):
    """Loads a vocabulary file into a dictionary."""
    vocab = collections.OrderedDict()
    index = 0
    with open(vocab_file, "r", encoding="utf-8") as reader:
        while True:
            token = reader.readline()
            if not token:
                break
            token = token.strip()
            #vocab[token] = index
            vocab[index] = token
            index += 1
    return vocab

def attention_pattern(model, loader, device):
    
    def initialize_hook(module, input, output):
        if isinstance(module, BertSelfAttention):
            
            attn_mask = input[1]
            attention_output = output[-2]["attn"]
            
            seq_length = (attn_mask == 0).sum()
            
            print(attention_output[0,:,:seq_length,seq_length-1].mean().item())
            

    hooks = []

    for name, module in model.named_modules():
        hook = module.register_forward_hook(initialize_hook)
        hooks.append(hook)
    
    model.eval()
    model.to(device)
    
    for step, batch in enumerate(loader):
        batch = tuple(t.to("cuda") for t in batch)
        input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch        
        with torch.no_grad():
            student_logits, student_atts, student_reps, student_probs, student_values = model(input_ids, segment_ids, input_mask)
        break
    
    for hook in hooks:
        hook.remove()
        
def get_tensor_data(output_mode, features):
    if output_mode == "classification":
        all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
    elif output_mode == "regression":
        all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float)


    all_seq_lengths = torch.tensor([f.seq_length for f in features], dtype=torch.long)
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    tensor_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,all_label_ids, all_seq_lengths)
    return tensor_data, all_label_ids

def do_logging(run, student_model, teacher_model, test_dataloader, device, global_step, args, vocab):
    
    if args.bert == "large":
        layer_num = 24
        head_num = 16
    else:
        layer_num = 12
        head_num = 12
        
    nb_steps = 0
    
    kl_div_sum = [0 for i in range(layer_num)]
    st_sep_avg_sum = [0 for i in range(layer_num)]; st_cls_avg_sum = [0 for i in range(layer_num)]; tc_sep_avg_sum = [0 for i in range(layer_num)]; tc_cls_avg_sum = [0 for i in range(layer_num)]
    cover_sum = [0 for i in range(layer_num)]
    cover_teacher_sum = [0 for i in range(layer_num)]
    
    batch_num = 0
    
    for batch_ in tqdm(test_dataloader, desc="Logging Test", mininterval=0.01, ascii=True, leave=False):
        batch_ = tuple(t.to(device) for t in batch_)
        
        if batch_num >= 1: # Visualize Attention Map only First Batch 
            args.log_map = False
        
        with torch.no_grad():
            input_ids, input_mask, segment_ids, label_id, seq_length = batch_

            teacher_logits, teacher_atts, teacher_reps, teacher_probs, teacher_values = teacher_model(input_ids, segment_ids, input_mask)
            student_logits, student_atts, student_reps, student_probs, student_values = student_model(input_ids, segment_ids, input_mask, teacher_probs=teacher_probs)
            
            # Layer
            for i, (student_prob, teacher_prob) in enumerate(zip(student_probs, teacher_probs)): 

                # Head
                for head in range(head_num):
                    
                    if args.log_map:
                        
                        word_list = []
                        
                        for word in range(seq_length):
                            word_list.append(vocab[input_ids[0][word].item()])
                        
                        student_prob_map = student_prob[0][head][:seq_length,:seq_length].clone().detach().cpu().numpy()
                        teacher_prob_map = teacher_prob[0][head][:seq_length,:seq_length].clone().detach().cpu().numpy()
                        
                        fig, [ax1, ax2] = plt.subplots(1, 2, figsize=(16,8))
                        ax1.set_title(f"{i}th Layer {head}th Head Teacher")
                        heatmap = ax1.pcolor(teacher_prob_map, cmap=plt.cm.Blues)
    
                        ax1.set_xticks(numpy.arange(teacher_prob_map.shape[1]) + 0.5, minor=False)
                        ax1.set_yticks(numpy.arange(teacher_prob_map.shape[0]) + 0.5, minor=False)
                        
                        ax1.set_xlim(0, int(teacher_prob_map.shape[1]))
                        ax1.set_ylim(0, int(teacher_prob_map.shape[0]))

                        ax1.invert_yaxis()
                        ax1.xaxis.tick_top()

                        ax1.set_xticklabels(word_list, minor=False)
                        ax1.set_yticklabels(word_list, minor=False)

                        plt.xticks(rotation=45)
                        
                        ax2.set_title(f"{i}th Layer {head}th Head Student")
                        heatmap = ax2.pcolor(student_prob_map, cmap=plt.cm.Blues)

                        ax2.set_xticks(numpy.arange(student_prob_map.shape[1]) + 0.5, minor=False)
                        ax2.set_yticks(numpy.arange(student_prob_map.shape[0]) + 0.5, minor=False)

                        ax2.set_xlim(0, int(student_prob_map.shape[1]))
                        ax2.set_ylim(0, int(student_prob_map.shape[0]))

                        ax2.invert_yaxis()
                        ax2.xaxis.tick_top()

                        ax2.set_xticklabels(word_list, minor=False)
                        ax2.set_yticklabels(word_list, minor=False)

                        plt.xticks(rotation=45)
                        
                        plt_folder_name = os.path.join("plt_storage" + "/" + args.exp_name)
                        if not os.path.exists(plt_folder_name):
                            os.mkdir(plt_folder_name)          
                        plt_folder_name = os.path.join(plt_folder_name, f"step_{global_step}")
                        if not os.path.exists(plt_folder_name):
                            os.mkdir(plt_folder_name)                        
                        plt.savefig(plt_folder_name + "/" + f"L{i}_H{head}.png")
                        plt.close()
                        

                    if args.log_metric:
                        
                        student_prob = student_prob
                        teacher_prob = teacher_prob

                        # Attention Map
                        student_attn_map = student_prob[0][head][:seq_length,:seq_length].clone().detach()
                        teacher_attn_map = teacher_prob[0][head][:seq_length,:seq_length].clone().detach()

                        # KL Divergence
                        kl_div = F.kl_div(student_attn_map.log(), teacher_attn_map, reduction='batchmean')
                        kl_div_sum[i] += kl_div

                        # Special Token Prob Mean
                        st_sep_avg = student_attn_map[:,-1].mean()
                        st_cls_avg = student_attn_map[:,0].mean()
                        st_sep_avg_sum[i] += st_sep_avg
                        st_cls_avg_sum[i] += st_cls_avg
                        
                        # Ground Truth
                        tc_sep_avg = teacher_attn_map[:,-1].mean()
                        tc_cls_avg = teacher_attn_map[:,0].mean()
                        tc_sep_avg_sum[i] += tc_sep_avg
                        tc_cls_avg_sum[i] += tc_cls_avg

                        # Coverage Test
                        coverage_head_sum = 0
                        coverage_teacher_head_sum = 0
                        for k in range(student_attn_map.shape[0]):
                            st_argsort = student_attn_map[k].sort(descending=True)[1]
                            tc_argsort = teacher_attn_map[k].sort(descending=True)[1][:args.tc_top_k] # Top-5
                            
                            max_idx = 0
                            for idx in tc_argsort: # Teacher Top-5                             
                                tmp = torch.where(st_argsort == idx)
                                max_idx = max(tmp[0].item(), max_idx)
                            
                            coverage_ratio = max_idx / student_attn_map.shape[0]
                            coverage_teacher_ratio = (args.tc_top_k - 1) / student_attn_map.shape[0]
                            coverage_head_sum += coverage_ratio
                            coverage_teacher_head_sum += coverage_teacher_ratio
                        
                        coverage_head = coverage_head_sum / student_attn_map.shape[0]
                        coverage_teacher_head = coverage_teacher_head_sum / student_attn_map.shape[0]
                        
                        cover_sum[i] += coverage_head
                        cover_teacher_sum[i] += coverage_teacher_head
                        
                        nb_steps += 1
        
        batch_num = batch_num + 1
    
    if args.log_metric:
        nb_steps = nb_steps / 12
        
        for l in range(12):
            run[f"attn/L{l}_KLdiv_mean"].log(value=kl_div_sum[l] / nb_steps, step=global_step)
            run[f"attn/L{l}_st_SepProb_mean"].log(value=st_sep_avg_sum[l] / nb_steps, step=global_step)
            run[f"attn/L{l}_st_ClsProb_mean"].log(value=st_cls_avg_sum[l] / nb_steps, step=global_step)
            run[f"attn/L{l}_tc_SepProb_mean"].log(value=tc_sep_avg_sum[l] / nb_steps, step=global_step)
            run[f"attn/L{l}_tc_ClsProb_mean"].log(value=tc_cls_avg_sum[l] / nb_steps, step=global_step)
            run[f"attn/L{l}_st_cover_mean"].log(value=cover_sum[l] / nb_steps, step=global_step)
            run[f"attn/L{l}_tc_cover_mean"].log(value=cover_teacher_sum[l] / nb_steps, step=global_step)

    args.log_map = True                    


def do_eval(model, task_name, eval_dataloader,
            device, output_mode, eval_labels, num_labels, teacher_model=None):
    eval_loss = 0
    nb_eval_steps = 0
    preds = []

    for batch_ in tqdm(eval_dataloader, desc="Inference"):
        batch_ = tuple(t.to(device) for t in batch_)
        
        with torch.no_grad():
            input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch_

            # teacher attnmap test
            if teacher_model is not None:
                logits, teacher_atts, _, teacher_probs, _ = teacher_model(input_ids, segment_ids, input_mask)
                # teacher_probs = 0
                logits, _, _, _, _ = model(input_ids, segment_ids, input_mask, teacher_probs=teacher_probs)
            else:
                logits, _, _, _, _ = model(input_ids, segment_ids, input_mask)
        
        # create eval loss and other metric required by the task
        if output_mode == "classification":
            loss_fct = CrossEntropyLoss()
            tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
        elif output_mode == "regression":
            loss_fct = MSELoss()
            tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))

        eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1
        if len(preds) == 0:
            preds.append(logits.detach().cpu().numpy())
        else:
            preds[0] = np.append(
                preds[0], logits.detach().cpu().numpy(), axis=0)

    eval_loss = eval_loss / nb_eval_steps

    preds = preds[0]
    if output_mode == "classification":
        preds = np.argmax(preds, axis=1)
    elif output_mode == "regression":
        preds = np.squeeze(preds)
    result = compute_metrics(task_name, preds, eval_labels.numpy())
    result['eval_loss'] = eval_loss
    return result

def soft_cross_entropy(predicts, targets):
    student_likelihood = torch.nn.functional.log_softmax(predicts, dim=-1)
    targets_prob = torch.nn.functional.softmax(targets, dim=-1)
    return torch.sum((- targets_prob * student_likelihood), dim=-1).mean()

processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mnli-mm": MnliMismatchedProcessor,
    "mrpc": MrpcProcessor,
    "sst-2": Sst2Processor,
    "sts-b": StsbProcessor,
    "qqp": QqpProcessor,
    "qnli": QnliProcessor,
    "rte": RteProcessor   
}

output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification"
}

default_params = {
        "cola": {"max_seq_length": 64,"batch_size":1,"eval_step": 50}, # No Aug : 50 Aug : 400
        "mnli": {"max_seq_length": 128,"batch_size":1,"eval_step":8000},
        "mrpc": {"max_seq_length": 128,"batch_size":1,"eval_step":100},
        "sst-2": {"max_seq_length": 64,"batch_size":1,"eval_step":100},
        "sts-b": {"max_seq_length": 128,"batch_size":1,"eval_step":100},
        "qqp": {"max_seq_length": 128,"batch_size":1,"eval_step":1000},
        "qnli": {"max_seq_length": 128,"batch_size":1,"eval_step":1000},
        "rte": {"max_seq_length": 128,"batch_size":1,"eval_step": 20}
    }

from bertviz import head_view, model_view
# from bertviz.transformers_neuron_view import BertModel, BertTokenizer
from bertviz.neuron_view import show
import bertviz

# GLUE Task Selection

In [38]:
task_name = "sts-b"
bert_size = "large"

## Model Dir, Device

In [39]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_dir = "models"
output_dir = "output"

if bert_size == "large":
    model_dir = os.path.join(model_dir, "BERT_large")
    output_dir = os.path.join(output_dir, "BERT_large")

student_model_dir = os.path.join(model_dir,task_name)
student_model_dir = os.path.join(output_dir, task_name, "quant", "ternary_save")
# student_model_dir = os.path.join(output_dir, task_name, "quant", "step_2_da_10") # DA-A4W2 51.2
teacher_model_dir = os.path.join(model_dir,task_name)


## Dataset 

In [42]:
# Processor & Task Info
processor = processors[task_name]()
output_mode = output_modes[task_name]
label_list = processor.get_labels()
num_labels = len(label_list)

if task_name in default_params:
    batch_size = default_params[task_name]["batch_size"]
    max_seq_length = default_params[task_name]["max_seq_length"]
    eval_step = default_params[task_name]["eval_step"]
    
# Tokenizer
tokenizer = BertTokenizer.from_pretrained(teacher_model_dir, do_lower_case=True)


# Load Dataset
data_dir = os.path.join("data",task_name)
processed_data_dir = os.path.join(data_dir,'preprocessed')

eval_examples = processor.get_dev_examples(data_dir)
eval_features = convert_examples_to_features(eval_examples, label_list, max_seq_length, tokenizer, output_mode)
# dev_file = train_file = os.path.join(processed_data_dir,'dev.pkl') 
# eval_features = pickle.load(open(dev_file,'rb'))

eval_data, eval_labels = get_tensor_data("classification", eval_features)
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=32)
eval_data, eval_labels = get_tensor_data(output_mode, eval_features)

eval_examples = processor.get_dev_examples(data_dir)

04/01 06:45:11 PM Writing example 0 of 1500
04/01 06:45:11 PM *** Example ***
04/01 06:45:11 PM guid: dev-0
04/01 06:45:11 PM tokens: [CLS] a man with a hard hat is dancing . [SEP] a man wearing a hard hat is dancing . [SEP]
04/01 06:45:11 PM input_ids: 101 1037 2158 2007 1037 2524 6045 2003 5613 1012 102 1037 2158 4147 1037 2524 6045 2003 5613 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
04/01 06:45:11 PM input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
04/01 06:45:11 PM segment_ids: 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0

  all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)


# Model Build

In [41]:
build_tc = 1
build_st = 1

if build_tc:
    # Teacher Model Build
    teacher_model = BertForSequenceClassification.from_pretrained(teacher_model_dir, num_labels=num_labels)
    teacher_model.to(device)
    teacher_model.eval()
    model = teacher_model

if build_st:
    # Student Model Build
    student_config = BertConfig.from_pretrained(student_model_dir,
                                                    quantize_act=True,
                                                    quantize_weight=True,
                                                    weight_bits = 2, # Always Ternary when "quantize_weight = True"
                                                    input_bits = 8,
                                                    clip_val = 2.5,
                                                    quantize = True,
                                                    ffn_q_1 = True,
                                                    ffn_q_2 = True,
                                                    qkv_q = True,
                                                    emb_q = True,
                                                    cls_q = True,
                                                    clipping = False,
                                                    layer_num = -1,
                                                    mean_scale = 0.7,
                                                    quantizer = "ternary",
                                                    act_quantizer = "ternary",
                                                    init_scaling = 1,
                                                    clip_ratio = 1,
                                                    gradient_scaling = False,
                                                    clip_method = "minmax",
                                                    teacher_attnmap = False,
                                                    parks = False,
                                                    stop_grad = False,
                                                    qk_FP = False,
                                                    map=False,
                                                    act_method = "clipping"
                                                    )

    student_model = QuantBertForSequenceClassification.from_pretrained(student_model_dir, config = student_config, num_labels=num_labels)
    student_model.to(device)
    model = student_model
    print()

    # Quantization Option ACT/WEIGHT
    for name, module in student_model.named_modules():
                if isinstance(module, (QuantizeLinear, QuantizeAct, ClipLinear)):    
                    module.act_flag = True
                    module.weight_flag = True

04/01 06:44:58 PM Loading model models/BERT_large/sts-b/pytorch_model.bin
04/01 06:44:59 PM loading model...
04/01 06:44:59 PM done!
04/01 06:44:59 PM loading configuration file output/BERT_large/sts-b/quant/ternary_save/config.json
04/01 06:45:05 PM Loading model output/BERT_large/sts-b/quant/ternary_save/pytorch_model.bin
04/01 06:45:06 PM loading model...
04/01 06:45:06 PM done!



## Activation Quantization Clip Value Initialization

In [1]:
# for name, module in student_model.named_modules():
#             if isinstance(module, (QuantizeLinear, QuantizeAct, ClipLinear)):    
#                 module.act_flag = False
#                 module.weight_flag = False
                
# cv_initialize(student_model, eval_dataloader, torch.Tensor([0.005]), device)

# # for name, module in student_model.named_modules():
# #             if isinstance(module, (QuantizeLinear, QuantizeAct, ClipLinear)):    
# #                 module.act_flag = True
# #                 module.weight_flag = False

## Model Evaluation

In [43]:
eval_st = 1
eval_tc = 0

if eval_st:
    print("Student Model Inferece")
    student_model.eval()
    student_result = do_eval(student_model, task_name, eval_dataloader, device, output_mode, eval_labels, num_labels, teacher_model=teacher_model)
    print(f"Student Result : {student_result}")

if eval_tc:
    print("Teacher Model Inferece")
    teacher_result = do_eval(teacher_model, task_name, eval_dataloader, device, output_mode, eval_labels, num_labels)
    print(f"Teacher Result : {teacher_result}")



Student Model Inferece


Inference: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 47/47 [00:27<00:00,  1.69it/s]

Student Result : {'pearson': 0.8887880506300738, 'spearmanr': 0.8850681732784779, 'corr': 0.8869281119542758, 'eval_loss': 0.6736781127909397}





## BERTViz Model View

In [14]:
# Sampling Sentence 
i = 0 
num = 2
for step, batch in enumerate(eval_dataloader):
    model.train()
            
    batch = tuple(t.to(device) for t in batch)
    input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch
    i = i + 1
    if i == num:
        break

seq_length = seq_lengths.item()
input_ids_sliced = input_ids[:,:seq_length]
input_id = []
for i in input_ids_sliced[0]:
    input_id.append(i.item())
tokens = tokenizer.convert_ids_to_tokens(input_id)



sample_sentence_a = str()
sample_sentence_b = str()
index = 0

for i, word in enumerate(tokens[1:-1]):
    if word == "[SEP]":
        break
    sample_sentence_a += word
    sample_sentence_a += " "
index = i

for i, word in enumerate(tokens[index+2:-1]):
    if word == "[SEP]":
        break
    sample_sentence_b += word
    sample_sentence_b += " "

sep_index = torch.where(input_ids[0] == 102)[0]

if len(sample_sentence_b) > 1:
    sample_sentence_b_start = segment_ids[0].tolist().index(1)

print(f"input_ids : {input_ids_sliced}")
print(f"tokens : {tokens}")
print(f"A : {sample_sentence_a}")
print(f"B : {sample_sentence_b}")
print(sep_index)

input_ids : tensor([[ 101, 1037, 2402, 2775, 2003, 5559, 1037, 3586, 1012,  102, 1037, 2775,
         2003, 5559, 1037, 3586, 1012,  102]], device='cuda:0')
tokens : ['[CLS]', 'a', 'young', 'child', 'is', 'riding', 'a', 'horse', '.', '[SEP]', 'a', 'child', 'is', 'riding', 'a', 'horse', '.', '[SEP]']
A : a young child is riding a horse . 
B : a child is riding a horse . 
tensor([ 9, 17], device='cuda:0')


In [2]:
from bertviz.transformers_neuron_view import BertModel, BertTokenizer
from bertviz.neuron_view import show

bertviz_neuron_tc = 0
bertviz_neuron_st = 0
bertviz_model_tc = 1
bertviz_model_st = 1

# Quantization Setting
if bertviz_neuron_st or bertviz_model_st:
    for name, module in student_model.named_modules():
            if isinstance(module, (QuantizeLinear, ClipLinear)):    
                module.act_flag = False
                module.weight_flag = False
            if isinstance(module, QuantizeAct):    
                module.act_flag = False
                module.weight_flag = False

if bertviz_neuron_tc or bertviz_neuron_st:
    if bertviz_neuron_st:
        for name, module in student_model.named_modules():
                if isinstance(module, BertSelfAttention):    
                    module.output_bertviz = True
    if bertviz_neuron_tc:
        for name, module in teacher_model.named_modules():
                if isinstance(module, FP_BertSelfAttention):    
                    module.output_bertviz = True

    model_type = 'bert'
    model_version = 'bert-base-uncased'
    
    tokenizer = BertTokenizer.from_pretrained(model_version, do_lower_case=True)
    if bertviz_neuron_tc:
        if len(sample_sentence_b) > 1:
            show(teacher_model.cpu(), model_type, tokenizer, sample_sentence_a, sample_sentence_b, display_mode="light")
        else:
            show(teacher_model.cpu(), model_type, tokenizer, sample_sentence_a,display_mode="light")
    if bertviz_neuron_st:
        if len(sample_sentence_b) > 1:
            show(student_model.cpu(), model_type, tokenizer, sample_sentence_a, sample_sentence_b, display_mode="light")
        else:
            show(student_model.cpu(), model_type, tokenizer, sample_sentence_a,display_mode="light")

if bert_size == "large":
    layer_num = 24
    head_num = 16
else:
    layer_num = 12
    head_num = 12
    
all_layers = list(range(layer_num))
layers_to_show = all_layers[18:]

if bertviz_model_tc or bertviz_model_st:
    
    if bertviz_model_tc:
        print("teacher_map")
        for name, module in teacher_model.named_modules():
                    if isinstance(module, FP_BertSelfAttention):    
                        module.output_bertviz = False
        teacher_model.eval()
        teacher_model.to(device)
        teacher_logits, teacher_atts, teacher_reps, teacher_probs, teacher_values = teacher_model(input_ids_sliced.to(device))
        model_view(teacher_probs, tokens, include_layers=layers_to_show,  display_mode="light")
        
    if bertviz_model_st:
        print("student_map")
        for name, module in student_model.named_modules():
                    if isinstance(module, BertSelfAttention):    
                        module.output_bertviz = False
        student_model.eval()
        student_model.to(device)
        student_logits, student_atts, student_reps, student_probs, student_values = student_model(input_ids_sliced.to(device), teacher_probs=teacher_probs)
        model_view(student_probs, tokens, sample_sentence_b_start,include_layers=layers_to_show, display_mode="light")# , include_layers=[0, 1])
        
    


## Forward Check

In [45]:
from torch.nn import MSELoss
mse_func = MSELoss()

if bert_size == "large":
    layer_num = 24
    head_num = 16
else:
    layer_num = 12
    head_num = 12


attention_pattern_check = 0
cover_mean_check = 0
kl_div_check = 1
mse_check = 0
attnmap_mse_check = 0

exclude_sep = 0

for name, module in student_model.named_modules():
            if isinstance(module, BertSelfAttention):    
                module.output_bertviz = False
for name, module in teacher_model.named_modules():
            if isinstance(module, FP_BertSelfAttention):    
                module.output_bertviz = False
                
for name, module in student_model.named_modules():
            if isinstance(module, (QuantizeLinear, ClipLinear, QuantizeAct)):    
                module.act_flag = True
                module.weight_flag = True

seed=42
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

student_model.eval()
teacher_model.eval()
student_model.to(device)
teacher_model.to(device)
teacher_logits, teacher_atts, teacher_reps, teacher_probs, teacher_values = teacher_model(input_ids_sliced.to(device))
student_logits, student_atts, student_reps, student_probs, student_values = student_model(input_ids_sliced.to(device), teacher_probs=teacher_probs)

probs = teacher_probs
if attention_pattern_check:
    print("Attention mean CHECK")
    for i in range(layer_num):
        if len(sep_index) == 2:
            print((probs[i][0,:,:,sep_index[0]].mean() + probs[i][0,:,:,sep_index[1]].mean()).item())
        else:
            print(probs[i][0,:,:,sep_index[0]].mean().item())
    
if cover_mean_check:
    print("COVER MEAN CHECK")
    top_k = 5

    for i in range(layer_num):
        teacher = teacher_probs[i][0]
        student = student_probs[i][0]

        head_sum = 0
        for h in range(head_num):
            coverage_head_sum = 0
            for row in range(seq_length-1):
                if exclude_sep:
                    tc_argsort = teacher[h][:seq_length-1,:seq_length-1].sort(descending=True)[1][row][:top_k] # top-k
                    st_argsort = student[h][:seq_length-1,:seq_length-1].sort(descending=True)[1][row]
                tc_argsort = teacher[h].sort(descending=True)[1][row][:top_k] # top-k
                st_argsort = student[h].sort(descending=True)[1][row]

                max_idx = 0
                for idx in tc_argsort:
                    tmp = torch.where(st_argsort == idx)
                    max_idx = max(tmp[0].item(), max_idx)

                coverage_ratio = max_idx / student.shape[1]
                coverage_head_sum += coverage_ratio

                # print(f"H{h} : {coverage_head_sum/seq_length}")

            head_sum += coverage_head_sum / seq_length
        print(head_sum / head_num)

if kl_div_check:
    print("KL DIV CHECK")
    for i in range(layer_num):
        if exclude_sep:
            if len(sep_index) == 2:
                teacher_atts[i][:,:,:,sep_index[0]] = -100000; teacher_atts[i][:,:,:,sep_index[1]] = -100000
                student_atts[i][:,:,:,sep_index[0]] = -100000; student_atts[i][:,:,:,sep_index[1]] = -100000
            else:
                teacher_atts[i][:,:,:,sep_index[0]] = -100000
                student_atts[i][:,:,:,sep_index[0]] = -100000
                
            teacher = torch.nn.Softmax(dim=-1)(teacher_atts[i])
            student = torch.nn.Softmax(dim=-1)(student_atts[i])
            
            student = torch.clamp_min(student, 1e-8)
            teacher = torch.clamp_min(teacher, 1e-8)
        else:    
            teacher = teacher_probs[i]
            student = student_probs[i]
        
        neg_cross_entropy = teacher * torch.log(student) 
        neg_cross_entropy = torch.sum(neg_cross_entropy, dim=-1)  # (b, h, s, s) -> (b, h, s)
        neg_cross_entropy = torch.sum(neg_cross_entropy, dim=-1) / seq_lengths.view(-1, 1)  # (b, h, s) -> (b, h)

        # p(t) log p(t) = negative entropy
        neg_entropy = teacher * torch.log(teacher) 
        neg_entropy = torch.sum(neg_entropy, dim=-1)  # (b, h, s, s) -> (b, h, s)
        neg_entropy = torch.sum(neg_entropy, dim=-1) / seq_lengths.view(-1, 1)  # (b, h, s) -> (b, h)

        kld_loss = neg_entropy - neg_cross_entropy

        kld_loss_sum = torch.sum(kld_loss)
        print(kld_loss_sum.item())

if mse_check:
    for i in range(layer_num):
        print(mse_func(teacher_atts[i], student_atts[i]).item())
        
if attnmap_mse_check:
    for i in range(layer_num):
        if exclude_sep:
            if len(sep_index) == 2:
                teacher_atts[i][:,:,:,sep_index[0]] = -100000; teacher_atts[i][:,:,:,sep_index[1]] = -100000
                student_atts[i][:,:,:,sep_index[0]] = -100000; student_atts[i][:,:,:,sep_index[1]] = -100000
            else:
                teacher_atts[i][:,:,:,sep_index[0]] = -100000
                student_atts[i][:,:,:,sep_index[0]] = -100000
                
            teacher = torch.nn.Softmax(dim=-1)(teacher_atts[i])
            student = torch.nn.Softmax(dim=-1)(student_atts[i])
            print(mse_func(teacher, student).item())
        else:    
            print(mse_func(teacher_probs[i], student_probs[i]).item())
    

KL DIV CHECK
2.9676687717437744
5.659456253051758
3.098494052886963
2.3057239055633545
2.587374210357666
1.6189770698547363
1.2007269859313965
1.662299633026123
2.1250391006469727
2.056593418121338
3.305050849914551
4.6170477867126465
4.530083656311035
4.024961948394775
4.323973178863525
3.5094804763793945
3.292818784713745
2.7897772789001465
4.441211700439453
8.378562927246094
20.496295928955078
37.17641830444336
40.135982513427734
76.01957702636719


In [35]:

i = 9
teacher = teacher_probs[1][:,i,:,:]
student = student_probs[1][:,i,:,:]



neg_cross_entropy = teacher * torch.log(student) 
neg_cross_entropy = torch.sum(neg_cross_entropy, dim=-1)  # (b, h, s, s) -> (b, h, s)
neg_cross_entropy = torch.sum(neg_cross_entropy, dim=-1) / seq_lengths.view(-1, 1)  # (b, h, s) -> (b, h)

# p(t) log p(t) = negative entropy
neg_entropy = teacher * torch.log(teacher) 
neg_entropy = torch.sum(neg_entropy, dim=-1)  # (b, h, s, s) -> (b, h, s)
neg_entropy = torch.sum(neg_entropy, dim=-1) / seq_lengths.view(-1, 1)  # (b, h, s) -> (b, h)

kld_loss = neg_entropy - neg_cross_entropy

kld_loss_sum = torch.sum(kld_loss)
print(kld_loss_sum.item())

1.155358076095581


In [43]:

head = 7
for head in range(16):
    teacher = teacher_probs[23][:,head,:,:]
    student = student_probs[23][:,head,:,:]
    neg_cross_entropy = teacher * torch.log(student) 
    neg_cross_entropy = torch.sum(neg_cross_entropy, dim=-1)  # (b, h, s, s) -> (b, h, s)
    neg_cross_entropy = torch.sum(neg_cross_entropy, dim=-1) / seq_lengths.view(-1, 1)  # (b, h, s) -> (b, h)

    # p(t) log p(t) = negative entropy
    neg_entropy = teacher * torch.log(teacher) 
    neg_entropy = torch.sum(neg_entropy, dim=-1)  # (b, h, s, s) -> (b, h, s)
    neg_entropy = torch.sum(neg_entropy, dim=-1) / seq_lengths.view(-1, 1)  # (b, h, s) -> (b, h)

    kld_loss = neg_entropy - neg_cross_entropy

    kld_loss_sum = torch.sum(kld_loss)
    print(kld_loss_sum.item())

0.4594290256500244
0.5127482414245605
0.5085060596466064
0.44210290908813477
0.3044252395629883
0.4317352771759033
0.43944454193115234
0.34624576568603516
0.45526742935180664
0.4619622230529785
0.2999594211578369
0.5414872169494629
0.4033381938934326
0.3882777690887451
0.3862929344177246
0.29909467697143555


In [39]:
student_probs[23][:,7,:,:]

tensor([[[2.1772e-02, 1.2781e-02, 4.9408e-03,  ..., 2.0158e-02,
          6.9626e-02, 1.4999e-02],
         [1.0919e-03, 2.1085e-03, 1.0577e-03,  ..., 1.6647e-04,
          4.3606e-04, 1.8829e-01],
         [2.0435e-04, 2.2089e-04, 1.0611e-04,  ..., 4.0684e-05,
          4.6589e-05, 1.9174e-01],
         ...,
         [1.2242e-03, 9.6587e-05, 5.1341e-05,  ..., 2.4791e-03,
          8.4426e-04, 2.3197e-01],
         [3.1989e-05, 1.6445e-05, 4.2517e-06,  ..., 5.4571e-05,
          1.5931e-05, 1.9737e-01],
         [2.7168e-02, 1.6892e-02, 1.5799e-02,  ..., 1.5508e-02,
          2.3137e-02, 4.5687e-02]]], device='cuda:0', grad_fn=<SliceBackward0>)

In [37]:
teacher_probs[1][:,9,:,0]

tensor([[0.8724, 0.4891, 0.6094, 0.6159, 0.4543, 0.4242, 0.5964, 0.1814, 0.7470,
         0.3960, 0.4766, 0.3705, 0.6179, 0.2615, 0.6612, 0.4897, 0.6067, 0.3899,
         0.4389, 0.7189, 0.3123, 0.2026, 0.3890, 0.2045, 0.5183, 0.4821, 0.4648,
         0.2569, 0.2587, 0.4544, 0.2224, 0.6254, 0.7420, 0.1829, 0.5220, 0.2288,
         0.6594, 0.6613, 0.4734, 0.7735, 0.7583, 0.2122, 0.6728, 0.2626, 0.7567,
         0.2797, 0.7548, 0.4164, 0.7275]], device='cuda:0',
       grad_fn=<SelectBackward0>)