In [1]:

from __future__ import absolute_import, division, print_function

import pprint
import argparse
import logging
import os
import random
import sys
import pickle
import copy
import collections
import math

import numpy as np
import numpy
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler,TensorDataset
os.environ["CUDA_VISIBLE_DEVICES"]="1" # Set GPU Index to use
from torch.nn import CrossEntropyLoss, MSELoss

from transformer import BertForSequenceClassification,WEIGHTS_NAME, CONFIG_NAME
from transformer.modeling_quant import BertForSequenceClassification as QuantBertForSequenceClassification
from transformer import BertTokenizer
from transformer import BertAdam
from transformer import BertConfig
from transformer import QuantizeLinear, QuantizeAct, BertSelfAttention, FP_BertSelfAttention, ClipLinear, BertAttention, FP_BertAttention
from utils_glue import *
from bertviz import model_view

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch.nn.functional as F

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0 
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def do_eval(model, task_name, eval_dataloader,
            device, output_mode, eval_labels, num_labels, teacher_model=None):
    eval_loss = 0
    nb_eval_steps = 0
    preds = []

    for batch_ in tqdm(eval_dataloader, desc="Inference"):
        batch_ = tuple(t.to(device) for t in batch_)
        
        with torch.no_grad():
            input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch_

            # teacher attnmap test
            if teacher_model is not None:
                
                # logits, _, teacher_reps, teacher_probs, teacher_values = teacher_model(input_ids, segment_ids, input_mask)
                
                # # logits, _, _, _, _ = model(input_ids, segment_ids, input_mask, teacher_probs=teacher_probs)
                # logits, _, _, _, _ = model(input_ids, segment_ids, input_mask, teacher_probs=(teacher_probs, teacher_values, teacher_reps))
                teacher_logits, teacher_atts, teacher_reps, teacher_probs, teacher_values = teacher_model(input_ids, segment_ids, input_mask)
                logits, student_atts, student_reps, student_probs, student_values  = model(input_ids, segment_ids, input_mask, teacher_outputs=(teacher_probs, teacher_values, teacher_reps, teacher_logits, teacher_atts))
            else:
                logits, _, _, _, _ = model(input_ids, segment_ids, input_mask)
        
        # create eval loss and other metric required by the task
        if output_mode == "classification":
            loss_fct = CrossEntropyLoss()
            tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
        elif output_mode == "regression":
            loss_fct = MSELoss()
            tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))

        eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1
        if len(preds) == 0:
            preds.append(logits.detach().cpu().numpy())
        else:
            preds[0] = np.append(
                preds[0], logits.detach().cpu().numpy(), axis=0)

    eval_loss = eval_loss / nb_eval_steps

    preds = preds[0]
    if output_mode == "classification":
        preds = np.argmax(preds, axis=1)
    elif output_mode == "regression":
        preds = np.squeeze(preds)
    result = compute_metrics(task_name, preds, eval_labels.numpy())
    result['eval_loss'] = eval_loss
    return result

processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mnli-mm": MnliMismatchedProcessor,
    "mrpc": MrpcProcessor,
    "sst-2": Sst2Processor,
    "sts-b": StsbProcessor,
    "qqp": QqpProcessor,
    "qnli": QnliProcessor,
    "rte": RteProcessor   
}

output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification"
}

default_params = {
        "cola": {"max_seq_length": 64,"batch_size":16,"eval_step": 400}, # No Aug : 50 Aug : 400
        "mnli": {"max_seq_length": 128,"batch_size":32,"eval_step":8000},
        "mrpc": {"max_seq_length": 128,"batch_size":32,"eval_step":20},
        "sst-2": {"max_seq_length": 64,"batch_size":32,"eval_step":100},
        "sts-b": {"max_seq_length": 128,"batch_size":32,"eval_step":100},
        "qqp": {"max_seq_length": 128,"batch_size":32,"eval_step":1000},
        "qnli": {"max_seq_length": 128,"batch_size":32,"eval_step":1000},
        "rte": {"max_seq_length": 128,"batch_size":32,"eval_step":100}
    }

def get_tensor_data(output_mode, features):
    if output_mode == "classification":
        all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
    elif output_mode == "regression":
        all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float)


    all_seq_lengths = torch.tensor([f.seq_length for f in features], dtype=torch.long)
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    tensor_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,all_label_ids, all_seq_lengths)
    return tensor_data, all_label_ids


In [3]:
import numpy as np


def group_product(xs, ys):
    """
    the inner product of two lists of variables xs,ys
    :param xs:
    :param ys:
    :return:
    """
    return sum([torch.sum(x * y) for (x, y) in zip(xs, ys)])


def de_variable(v):
    '''
    normalize the vector and detach it from variable
    '''

    s = group_product(v, v)
    s = s**0.5
    s = s.cpu().item() + 1e-6
    v = [vi / s for vi in v]
    return v

def group_add(params, update, alpha=1):
    """
    params = params + update*alpha
    :param params: list of variable
    :param update: list of data
    :return:
    """
    for i, p in enumerate(params):
        params[i].data.add_(update[i] * alpha)
    return params


def normalization(v):
    """
    normalization of a list of vectors
    return: normalized vectors v
    """
    s = group_product(v, v)
    s = s**0.5
    s = s.cpu().item()
    v = [vi / (s + 1e-6) for vi in v]
    # v = [vi / s for vi in v]
    return v


def orthonormal(w, v_list):
    for v in v_list:
        w = group_add(w, v, alpha=-group_product(w, v))
    return normalization(w)


def total_number_parameters(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    return sum([np.prod(p.size()) for p in model_parameters])

from torch.nn import CrossEntropyLoss, MSELoss




In [4]:
task_name = "rte"
bert_size = "large"

if bert_size == "large":
    layer_num = 24
    head_num = 16
else: 
    layer_num = 12
    head_num = 12
    
teacher_model = None
# torch.cuda.empty_cache()
# !nvidia-smi

# DEVICE / DATASET

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_dir = "models"
output_dir = "output"

if bert_size == "large":
    model_dir = os.path.join(model_dir, "BERT_large")
    output_dir = os.path.join(output_dir, "BERT_large")

teacher_model_dir = os.path.join(model_dir,task_name)

# Processor & Task Info
processor = processors[task_name]()
output_mode = output_modes[task_name]
label_list = processor.get_labels()
num_labels = len(label_list)

if task_name in default_params:
    batch_size = default_params[task_name]["batch_size"]
    max_seq_length = default_params[task_name]["max_seq_length"]
    eval_step = default_params[task_name]["eval_step"]
    
# Tokenizer
tokenizer = BertTokenizer.from_pretrained(teacher_model_dir, do_lower_case=True)


# Load Dataset
data_dir = os.path.join("data",task_name)
processed_data_dir = os.path.join(data_dir,'preprocessed')

train_examples = processor.get_train_examples(data_dir)
train_features = convert_examples_to_features(train_examples, label_list,
                                max_seq_length, tokenizer, output_mode)

len_train_data = int(len(train_features) * 1)
train_features = train_features[:len_train_data]

eval_examples = processor.get_dev_examples(data_dir)
eval_features = convert_examples_to_features(eval_examples, label_list, max_seq_length, tokenizer, output_mode)
# dev_file = train_file = os.path.join(processed_data_dir,'dev.pkl') 
# eval_features = pickle.load(open(dev_file,'rb'))

train_data, train_labels = get_tensor_data(output_mode, train_features)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

eval_data, eval_labels = get_tensor_data("classification", eval_features)
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=1)
eval_data, eval_labels = get_tensor_data(output_mode, eval_features)

eval_examples = processor.get_dev_examples(data_dir)

# Sampling Sentence 
i = 0 
# num = 3
num = 1



07/14 06:04:23 PM Writing example 0 of 2490
07/14 06:04:23 PM *** Example ***
07/14 06:04:23 PM guid: train-0
07/14 06:04:23 PM tokens: [CLS] no weapons of mass destruction found in iraq yet . [SEP] weapons of mass destruction found in iraq . [SEP]
07/14 06:04:23 PM input_ids: 101 2053 4255 1997 3742 6215 2179 1999 5712 2664 1012 102 4255 1997 3742 6215 2179 1999 5712 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/14 06:04:23 PM input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/14 06:04:23 PM segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0

In [6]:
# teacher_model = BertForSequenceClassification.from_pretrained(teacher_model_dir, num_labels=num_labels)
# teacher_model.to(device)
# teacher_model.eval()

st_model_name = "1SB_M"
student_model_dir = os.path.join(output_dir, task_name, "exploration", st_model_name)   
student_config = BertConfig.from_pretrained(student_model_dir)   
student_model = QuantBertForSequenceClassification.from_pretrained(student_model_dir, config = student_config, num_labels=num_labels)
student_model.to(device)
print()

07/14 06:04:25 PM loading configuration file output/BERT_large/rte/exploration/1SB_M/config.json
07/14 06:04:30 PM Loading model output/BERT_large/rte/exploration/1SB_M/pytorch_model.bin
07/14 06:04:32 PM loading model...
07/14 06:04:32 PM done!
07/14 06:04:32 PM Weights from pretrained model not used in BertForSequenceClassification: ['bert.embeddings.word_embeddings.qweight', 'bert.encoder.layer.0.attention.self.query.qweight', 'bert.encoder.layer.0.attention.self.key.qweight', 'bert.encoder.layer.0.attention.self.value.qweight', 'bert.encoder.layer.0.attention.output.dense.qweight', 'bert.encoder.layer.0.intermediate.dense.qweight', 'bert.encoder.layer.0.output.dense.qweight', 'bert.encoder.layer.1.attention.self.query.qweight', 'bert.encoder.layer.1.attention.self.key.qweight', 'bert.encoder.layer.1.attention.self.value.qweight', 'bert.encoder.layer.1.attention.output.dense.qweight', 'bert.encoder.layer.1.intermediate.dense.qweight', 'bert.encoder.layer.1.output.dense.qweight', 'be

In [7]:
data_percentage = 0.05

percentage_index = len(train_dataloader.dataset) * data_percentage / batch_size
print(f'percentage_index: {percentage_index}')

student_model.eval()

csv_path = os.path.join("layer_hessian_results", f"{task_name}-{data_percentage}-eigens.csv")
csv_file = open(csv_path, 'w', newline='')
writer = csv.writer(csv_file, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)
writer.writerow(['block', 'iters', 'max_eigenvalue'])

percentage_index: 3.890625


28

In [8]:
loss_fct = CrossEntropyLoss()

for module in student_model.modules():
    for param in module.parameters():
        param.requires_grad = True

block_id = 0

for block_id in range(layer_num):
    logger.info(f'block_id: {block_id}')
    model_block = student_model.bert.encoder.layer[block_id]
    
    v = [
            torch.randn(p.size()).to(device) for p in model_block.parameters()
        ]
    v = de_variable(v)

    lambda_old, lambdas = 0., 1.
    i = 0
    while (abs((lambdas - lambda_old) / lambdas) >= 0.005):

        lambda_old = lambdas

        acc_Hv = [
            torch.zeros(p.size()).cuda() for p in model_block.parameters()
        ]
        for step, batch in enumerate(train_dataloader):
            if step < percentage_index:
                
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids, _ = batch
                
                logits, _, _, _, _ = student_model(input_ids, segment_ids, input_mask, teacher_outputs=None)

                if output_mode == "classification":
                    loss_fct = CrossEntropyLoss()
                    loss = loss_fct(
                        logits.view(-1, num_labels), label_ids.view(-1))
                elif output_mode == "regression":
                    loss_fct = MSELoss()
                    loss = loss_fct(logits.view(-1), label_ids.view(-1))
                    
                loss.backward(create_graph=True)
                grads = [param.grad for param in model_block.parameters()]
                params = model_block.parameters()

                Hv = torch.autograd.grad(
                    grads,
                    params,
                    grad_outputs=v,
                    only_inputs=True,
                    retain_graph=True)
                acc_Hv = [
                    acc_Hv_p + Hv_p for acc_Hv_p, Hv_p in zip(acc_Hv, Hv)
                ]
                student_model.zero_grad()
        # calculate raylay quotients
        lambdas = group_product(acc_Hv, v).item() / percentage_index

        v = de_variable(acc_Hv)
        logger.info(f'block_{block_id}-lambda: {lambdas}')
        if abs((lambdas - lambda_old) / lambdas) < 0.005:
            writer.writerow([f'{block_id}', f'{i}', f'{lambdas}'])
            csv_file.flush()

        i += 1

07/14 06:04:35 PM block_id: 0


  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass


07/14 06:04:42 PM block_0-lambda: 2.3985539663986987e-06
07/14 06:04:47 PM block_0-lambda: 0.002090829203885243
07/14 06:04:53 PM block_0-lambda: 0.07367013065690496
07/14 06:04:59 PM block_0-lambda: 0.5844492471840487
07/14 06:05:04 PM block_0-lambda: 0.5747494372019327
07/14 06:05:10 PM block_0-lambda: 0.787686819053558
07/14 06:05:16 PM block_0-lambda: 0.7871970670769014
07/14 06:05:16 PM block_id: 1
07/14 06:05:22 PM block_1-lambda: -7.455763767050752e-07
07/14 06:05:27 PM block_1-lambda: 0.025249699512159968
07/14 06:05:33 PM block_1-lambda: 0.7545009750917734
07/14 06:05:38 PM block_1-lambda: 0.8833756657489332
07/14 06:05:44 PM block_1-lambda: 0.7087361898766943
07/14 06:05:50 PM block_1-lambda: 0.783837682272057
07/14 06:05:55 PM block_1-lambda: 1.0360744736759537
07/14 06:06:01 PM block_1-lambda: 0.8320200357092432
07/14 06:06:06 PM block_1-lambda: 0.4347141679511013
07/14 06:06:12 PM block_1-lambda: 1.1705257308530999
07/14 06:06:18 PM block_1-lambda: 1.154293810986132
07/14 

07/14 06:18:22 PM block_5-lambda: 1.8589051028332078
07/14 06:18:28 PM block_5-lambda: 1.9751552597107178
07/14 06:18:33 PM block_5-lambda: 2.140727460623745
07/14 06:18:38 PM block_5-lambda: 1.3707215335953187
07/14 06:18:43 PM block_5-lambda: 1.9941282463839733
07/14 06:18:48 PM block_5-lambda: 2.3416007781124497
07/14 06:18:53 PM block_5-lambda: 1.3609137477644955
07/14 06:18:58 PM block_5-lambda: 1.3422852788105548
07/14 06:19:03 PM block_5-lambda: 1.8450848070014432
07/14 06:19:08 PM block_5-lambda: 1.7365223834792294
07/14 06:19:13 PM block_5-lambda: 1.2735859208317646
07/14 06:19:18 PM block_5-lambda: 1.5211440243395458
07/14 06:19:23 PM block_5-lambda: 2.2248782728570533
07/14 06:19:28 PM block_5-lambda: 2.2595700183546685
07/14 06:19:33 PM block_5-lambda: 2.3595867846385543
07/14 06:19:38 PM block_5-lambda: 1.9759598697524472
07/14 06:19:43 PM block_5-lambda: 1.3626204034889557
07/14 06:19:49 PM block_5-lambda: 1.3521136548145707
07/14 06:19:49 PM block_id: 6
07/14 06:19:54 PM

07/14 06:30:40 PM block_9-lambda: 7.166424487010542
07/14 06:30:45 PM block_9-lambda: 8.325072751945282
07/14 06:30:49 PM block_9-lambda: 5.396272100119227
07/14 06:30:54 PM block_9-lambda: 5.445457121454568
07/14 06:30:54 PM block_id: 10
07/14 06:30:58 PM block_10-lambda: -6.256748484559806e-06
07/14 06:31:03 PM block_10-lambda: -0.05772609787293706
07/14 06:31:07 PM block_10-lambda: 0.4737388273798318
07/14 06:31:12 PM block_10-lambda: 12.289848848519076
07/14 06:31:16 PM block_10-lambda: 10.691254274912149
07/14 06:31:20 PM block_10-lambda: 11.076077748493976
07/14 06:31:25 PM block_10-lambda: 8.378625831450803
07/14 06:31:29 PM block_10-lambda: 11.694433005459338
07/14 06:31:34 PM block_10-lambda: 10.024838612261545
07/14 06:31:38 PM block_10-lambda: 10.255416196034137
07/14 06:31:43 PM block_10-lambda: 12.171640664219378
07/14 06:31:47 PM block_10-lambda: 12.773654187060743
07/14 06:31:51 PM block_10-lambda: 5.282108414125251
07/14 06:31:56 PM block_10-lambda: 6.903242363987199
07

07/14 06:41:40 PM block_12-lambda: 12.976513475778113
07/14 06:41:40 PM block_id: 13
07/14 06:41:45 PM block_13-lambda: 6.401476573692747e-06
07/14 06:41:49 PM block_13-lambda: 5.047662819245733
07/14 06:41:53 PM block_13-lambda: 8.585212922000501
07/14 06:41:57 PM block_13-lambda: 12.37183891817269
07/14 06:42:01 PM block_13-lambda: 11.573087270958835
07/14 06:42:05 PM block_13-lambda: 11.988152806538654
07/14 06:42:09 PM block_13-lambda: 10.941034646398093
07/14 06:42:13 PM block_13-lambda: 12.708130373054718
07/14 06:42:17 PM block_13-lambda: 10.657008894954819
07/14 06:42:21 PM block_13-lambda: 15.299167372615463
07/14 06:42:25 PM block_13-lambda: 14.961269884224398
07/14 06:42:29 PM block_13-lambda: 14.037855523657129
07/14 06:42:33 PM block_13-lambda: 13.241739418611948
07/14 06:42:37 PM block_13-lambda: 8.738068484877008
07/14 06:42:41 PM block_13-lambda: 14.585494321034137
07/14 06:42:45 PM block_13-lambda: 12.88169180628765
07/14 06:42:49 PM block_13-lambda: 10.489403904681225

07/14 06:51:06 PM block_19-lambda: 7.783722291509789
07/14 06:51:09 PM block_19-lambda: 7.617292902077058
07/14 06:51:13 PM block_19-lambda: 6.032031446096887
07/14 06:51:16 PM block_19-lambda: 6.420045416039157
07/14 06:51:19 PM block_19-lambda: 7.616696277296687
07/14 06:51:23 PM block_19-lambda: 8.277602597891565
07/14 06:51:26 PM block_19-lambda: 5.749285717087099
07/14 06:51:29 PM block_19-lambda: 7.004094012769829
07/14 06:51:33 PM block_19-lambda: 6.461101240901105
07/14 06:51:36 PM block_19-lambda: 5.430116461941516
07/14 06:51:39 PM block_19-lambda: 6.824895284262048
07/14 06:51:42 PM block_19-lambda: 7.588298996768323
07/14 06:51:46 PM block_19-lambda: 5.7515893652735945
07/14 06:51:49 PM block_19-lambda: 8.520639197414658
07/14 06:51:52 PM block_19-lambda: 9.591476256588855
07/14 06:51:56 PM block_19-lambda: 6.8037398617909135
07/14 06:51:59 PM block_19-lambda: 7.596023743411145
07/14 06:52:02 PM block_19-lambda: 7.695510557856426
07/14 06:52:06 PM block_19-lambda: 6.2467050

07/14 06:59:12 PM block_23-lambda: 2.502325463965236
07/14 06:59:15 PM block_23-lambda: 4.352340024159137
07/14 06:59:18 PM block_23-lambda: 3.8900766640781876
07/14 06:59:21 PM block_23-lambda: 3.817349230908007
07/14 06:59:23 PM block_23-lambda: 2.9209151057354417
07/14 06:59:26 PM block_23-lambda: 3.3446802346103164
07/14 06:59:29 PM block_23-lambda: 4.210061437154869
07/14 06:59:32 PM block_23-lambda: 4.557550514558233
07/14 06:59:35 PM block_23-lambda: 3.5562499019515563
07/14 06:59:37 PM block_23-lambda: 2.4880540227315513
07/14 06:59:40 PM block_23-lambda: 3.9166384777390815
07/14 06:59:43 PM block_23-lambda: 4.068452275900477
07/14 06:59:46 PM block_23-lambda: 4.257618854323544
07/14 06:59:49 PM block_23-lambda: 4.538678150100401
07/14 06:59:52 PM block_23-lambda: 3.552490234375
07/14 06:59:54 PM block_23-lambda: 3.2225981562970634
07/14 06:59:57 PM block_23-lambda: 3.9367879231770835
07/14 07:00:00 PM block_23-lambda: 4.412604029398845
07/14 07:00:03 PM block_23-lambda: 3.2166