In [25]:

from __future__ import absolute_import, division, print_function

import pprint
import argparse
import logging
import os
import random
import sys
import pickle
import copy
import collections
import math

import numpy as np
import numpy
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler,TensorDataset
os.environ["CUDA_VISIBLE_DEVICES"]="2" # Set GPU Index to use
from torch.nn import CrossEntropyLoss, MSELoss

from transformer import BertForSequenceClassification,WEIGHTS_NAME, CONFIG_NAME
from transformer.modeling_quant import BertForSequenceClassification as QuantBertForSequenceClassification
from transformer import BertTokenizer
from transformer import BertAdam
from transformer import BertConfig
from transformer import QuantizeLinear, QuantizeAct, BertSelfAttention, FP_BertSelfAttention, ClipLinear, BertAttention, FP_BertAttention
from utils_glue import *
from bertviz import model_view

from tqdm import tqdm
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

import torch.nn.functional as F

class AverageMeter(object):
    """Computes and stores the average and current value"""
    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0 
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def do_eval(model, task_name, eval_dataloader,
            device, output_mode, eval_labels, num_labels, teacher_model=None):
    eval_loss = 0
    nb_eval_steps = 0
    preds = []

    for batch_ in tqdm(eval_dataloader, desc="Inference"):
        batch_ = tuple(t.to(device) for t in batch_)
        
        with torch.no_grad():
            input_ids, input_mask, segment_ids, label_ids, seq_lengths = batch_

            # teacher attnmap test
            if teacher_model is not None:
                
                # logits, _, teacher_reps, teacher_probs, teacher_values = teacher_model(input_ids, segment_ids, input_mask)
                
                # # logits, _, _, _, _ = model(input_ids, segment_ids, input_mask, teacher_probs=teacher_probs)
                # logits, _, _, _, _ = model(input_ids, segment_ids, input_mask, teacher_probs=(teacher_probs, teacher_values, teacher_reps))
                teacher_logits, teacher_atts, teacher_reps, teacher_probs, teacher_values = teacher_model(input_ids, segment_ids, input_mask)
                logits, student_atts, student_reps, student_probs, student_values  = model(input_ids, segment_ids, input_mask, teacher_outputs=(teacher_probs, teacher_values, teacher_reps, teacher_logits, teacher_atts))
            else:
                logits, _, _, _, _ = model(input_ids, segment_ids, input_mask)
        
        # create eval loss and other metric required by the task
        if output_mode == "classification":
            loss_fct = CrossEntropyLoss()
            tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
        elif output_mode == "regression":
            loss_fct = MSELoss()
            tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))

        eval_loss += tmp_eval_loss.mean().item()
        nb_eval_steps += 1
        if len(preds) == 0:
            preds.append(logits.detach().cpu().numpy())
        else:
            preds[0] = np.append(
                preds[0], logits.detach().cpu().numpy(), axis=0)

    eval_loss = eval_loss / nb_eval_steps

    preds = preds[0]
    if output_mode == "classification":
        preds = np.argmax(preds, axis=1)
    elif output_mode == "regression":
        preds = np.squeeze(preds)
    result = compute_metrics(task_name, preds, eval_labels.numpy())
    result['eval_loss'] = eval_loss
    return result

processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mnli-mm": MnliMismatchedProcessor,
    "mrpc": MrpcProcessor,
    "sst-2": Sst2Processor,
    "sts-b": StsbProcessor,
    "qqp": QqpProcessor,
    "qnli": QnliProcessor,
    "rte": RteProcessor   
}

output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification"
}

default_params = {
        "cola": {"max_seq_length": 64,"batch_size":16,"eval_step": 400}, # No Aug : 50 Aug : 400
        "mnli": {"max_seq_length": 128,"batch_size":32,"eval_step":8000},
        "mrpc": {"max_seq_length": 128,"batch_size":32,"eval_step":20},
        "sst-2": {"max_seq_length": 64,"batch_size":32,"eval_step":100},
        "sts-b": {"max_seq_length": 128,"batch_size":32,"eval_step":100},
        "qqp": {"max_seq_length": 128,"batch_size":32,"eval_step":1000},
        "qnli": {"max_seq_length": 128,"batch_size":32,"eval_step":1000},
        "rte": {"max_seq_length": 128,"batch_size":32,"eval_step":100}
    }

def get_tensor_data(output_mode, features):
    if output_mode == "classification":
        all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.long)
    elif output_mode == "regression":
        all_label_ids = torch.tensor([f.label_id for f in features], dtype=torch.float)


    all_seq_lengths = torch.tensor([f.seq_length for f in features], dtype=torch.long)
    all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
    tensor_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,all_label_ids, all_seq_lengths)
    return tensor_data, all_label_ids


In [13]:
import numpy as np


def group_product(xs, ys):
    """
    the inner product of two lists of variables xs,ys
    :param xs:
    :param ys:
    :return:
    """
    return sum([torch.sum(x * y) for (x, y) in zip(xs, ys)])


def de_variable(v):
    '''
    normalize the vector and detach it from variable
    '''

    s = group_product(v, v)
    s = s**0.5
    s = s.cpu().item() + 1e-6
    v = [vi / s for vi in v]
    return v

def group_add(params, update, alpha=1):
    """
    params = params + update*alpha
    :param params: list of variable
    :param update: list of data
    :return:
    """
    for i, p in enumerate(params):
        params[i].data.add_(update[i] * alpha)
    return params


def normalization(v):
    """
    normalization of a list of vectors
    return: normalized vectors v
    """
    s = group_product(v, v)
    s = s**0.5
    s = s.cpu().item()
    v = [vi / (s + 1e-6) for vi in v]
    # v = [vi / s for vi in v]
    return v


def orthonormal(w, v_list):
    for v in v_list:
        w = group_add(w, v, alpha=-group_product(w, v))
    return normalization(w)


def total_number_parameters(model):
    model_parameters = filter(lambda p: p.requires_grad, model.parameters())
    return sum([np.prod(p.size()) for p in model_parameters])

from torch.nn import CrossEntropyLoss, MSELoss




In [2]:
task_name = "cola"
bert_size = "base"

if bert_size == "large":
    layer_num = 24
    head_num = 16
else: 
    layer_num = 12
    head_num = 12
    
teacher_model = None
# torch.cuda.empty_cache()
# !nvidia-smi

# DEVICE / DATASET

In [3]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model_dir = "models"
output_dir = "output"

if bert_size == "large":
    model_dir = os.path.join(model_dir, "BERT_large")
    output_dir = os.path.join(output_dir, "BERT_large")

teacher_model_dir = os.path.join(model_dir,task_name)

# Processor & Task Info
processor = processors[task_name]()
output_mode = output_modes[task_name]
label_list = processor.get_labels()
num_labels = len(label_list)

if task_name in default_params:
    batch_size = default_params[task_name]["batch_size"]
    max_seq_length = default_params[task_name]["max_seq_length"]
    eval_step = default_params[task_name]["eval_step"]
    
# Tokenizer
tokenizer = BertTokenizer.from_pretrained(teacher_model_dir, do_lower_case=True)


# Load Dataset
data_dir = os.path.join("data",task_name)
processed_data_dir = os.path.join(data_dir,'preprocessed')

train_examples = processor.get_train_examples(data_dir)
train_features = convert_examples_to_features(train_examples, label_list,
                                max_seq_length, tokenizer, output_mode)

len_train_data = int(len(train_features) * 1)
train_features = train_features[:len_train_data]

eval_examples = processor.get_dev_examples(data_dir)
eval_features = convert_examples_to_features(eval_examples, label_list, max_seq_length, tokenizer, output_mode)
# dev_file = train_file = os.path.join(processed_data_dir,'dev.pkl') 
# eval_features = pickle.load(open(dev_file,'rb'))

train_data, train_labels = get_tensor_data(output_mode, train_features)
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=batch_size)

eval_data, eval_labels = get_tensor_data("classification", eval_features)
eval_sampler = SequentialSampler(eval_data)
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=1)
eval_data, eval_labels = get_tensor_data(output_mode, eval_features)

eval_examples = processor.get_dev_examples(data_dir)

# Sampling Sentence 
i = 0 
# num = 3
num = 1



07/14 04:30:21 PM Writing example 0 of 8551
07/14 04:30:21 PM *** Example ***
07/14 04:30:21 PM guid: train-0
07/14 04:30:21 PM tokens: [CLS] our friends won ' t buy this analysis , let alone the next one we propose . [SEP]
07/14 04:30:21 PM input_ids: 101 2256 2814 2180 1005 1056 4965 2023 4106 1010 2292 2894 1996 2279 2028 2057 16599 1012 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/14 04:30:21 PM input_mask: 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/14 04:30:21 PM segment_ids: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
07/14 04:30:21 PM label: 1
07/14 04:30:21 PM label_id: 1
07/14 04:30:22 PM Writing example 0 of 1043
07/14 04:30:22 PM *** Example ***
07/14 04:30:22 PM guid: dev-0
07/14 04:30:22 PM tokens: [CLS] the sailors rode the breeze clear of the rocks . 

In [4]:
# teacher_model = BertForSequenceClassification.from_pretrained(teacher_model_dir, num_labels=num_labels)
# teacher_model.to(device)
# teacher_model.eval()

st_model_name = "1SB_M"
student_model_dir = os.path.join(output_dir, task_name, "exploration", st_model_name)   
student_config = BertConfig.from_pretrained(student_model_dir)   
student_model = QuantBertForSequenceClassification.from_pretrained(student_model_dir, config = student_config, num_labels=num_labels)
student_model.to(device)
print()

07/14 04:30:27 PM loading configuration file output/cola/exploration/1SB_M/config.json
07/14 04:30:29 PM Loading model output/cola/exploration/1SB_M/pytorch_model.bin
07/14 04:30:29 PM loading model...
07/14 04:30:29 PM done!
07/14 04:30:29 PM Weights from pretrained model not used in BertForSequenceClassification: ['bert.embeddings.word_embeddings.qweight', 'bert.encoder.layer.0.attention.self.query.qweight', 'bert.encoder.layer.0.attention.self.key.qweight', 'bert.encoder.layer.0.attention.self.value.qweight', 'bert.encoder.layer.0.attention.output.dense.qweight', 'bert.encoder.layer.0.intermediate.dense.qweight', 'bert.encoder.layer.0.output.dense.qweight', 'bert.encoder.layer.1.attention.self.query.qweight', 'bert.encoder.layer.1.attention.self.key.qweight', 'bert.encoder.layer.1.attention.self.value.qweight', 'bert.encoder.layer.1.attention.output.dense.qweight', 'bert.encoder.layer.1.intermediate.dense.qweight', 'bert.encoder.layer.1.output.dense.qweight', 'bert.encoder.layer.2.a

In [20]:
data_percentage = 0.01

percentage_index = len(train_dataloader.dataset) * data_percentage / batch_size
print(f'percentage_index: {percentage_index}')

student_model.eval()

csv_path = os.path.join(f"{task_name}-{data_percentage}-eigens.csv")
csv_file = open(csv_path, 'w', newline='')
writer = csv.writer(csv_file, delimiter=',', quotechar='|', quoting=csv.QUOTE_MINIMAL)
writer.writerow(['block', 'iters', 'max_eigenvalue'])

percentage_index: 5.344375


28

In [24]:
loss_fct = CrossEntropyLoss()

for module in student_model.modules():
    for param in module.parameters():
        param.requires_grad = True

block_id = 0

for block_id in range(layer_num):
    logger.info(f'block_id: {block_id}')
    model_block = student_model.bert.encoder.layer[block_id]
    
    v = [
            torch.randn(p.size()).to(device) for p in model_block.parameters()
        ]
    v = de_variable(v)

    lambda_old, lambdas = 0., 1.
    i = 0
    while (abs((lambdas - lambda_old) / lambdas) >= 0.01):

        lambda_old = lambdas

        acc_Hv = [
            torch.zeros(p.size()).cuda() for p in model_block.parameters()
        ]
        for step, batch in enumerate(train_dataloader):
            if step < percentage_index:
                
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids, _ = batch
                
                logits, _, _, _, _ = student_model(input_ids, segment_ids, input_mask, teacher_outputs=None)

                if output_mode == "classification":
                    loss_fct = CrossEntropyLoss()
                    loss = loss_fct(
                        logits.view(-1, num_labels), label_ids.view(-1))
                elif output_mode == "regression":
                    loss_fct = MSELoss()
                    loss = loss_fct(logits.view(-1), label_ids.view(-1))
                    
                loss.backward(create_graph=True)
                grads = [param.grad for param in model_block.parameters()]
                params = model_block.parameters()

                Hv = torch.autograd.grad(
                    grads,
                    params,
                    grad_outputs=v,
                    only_inputs=True,
                    retain_graph=True)
                acc_Hv = [
                    acc_Hv_p + Hv_p for acc_Hv_p, Hv_p in zip(acc_Hv, Hv)
                ]
                student_model.zero_grad()
        # calculate raylay quotients
        lambdas = group_product(acc_Hv, v).item() / percentage_index

        v = de_variable(acc_Hv)
        logger.info(f'block_{block_id}-lambda: {lambdas}')
        writer.writerow([f'{block_id}', f'{i}', f'{lambdas}'])
        csv_file.flush()

        i += 1

07/14 04:50:40 PM block_id: 0
07/14 04:50:42 PM block_0-lambda: -3.2217998176102134e-06
07/14 04:50:44 PM block_0-lambda: 0.0024439981563399928
07/14 04:50:46 PM block_0-lambda: -0.0018867165104657804
07/14 04:50:48 PM block_0-lambda: 0.1748031934734367
07/14 04:50:49 PM block_0-lambda: 2.047457364327363
07/14 04:50:51 PM block_0-lambda: 0.32649800284203556
07/14 04:50:53 PM block_0-lambda: 0.7467861664028923
07/14 04:50:55 PM block_0-lambda: 0.5070455812345596
07/14 04:50:57 PM block_0-lambda: 0.10828334481350066
07/14 04:50:59 PM block_0-lambda: 0.2158531792999471
07/14 04:51:01 PM block_0-lambda: 0.12523935295910657
07/14 04:51:03 PM block_0-lambda: 0.2529545341119642
07/14 04:51:05 PM block_0-lambda: 0.36685967247411666
07/14 04:51:07 PM block_0-lambda: 0.28622555545917966
07/14 04:51:08 PM block_0-lambda: 0.4682571006063801
07/14 04:51:10 PM block_0-lambda: 0.6783965093715053
07/14 04:51:12 PM block_0-lambda: 0.21550128674900276
07/14 04:51:14 PM block_0-lambda: 0.1348629762883521

07/14 04:55:28 PM block_2-lambda: 0.840804066912861
07/14 04:55:30 PM block_2-lambda: 0.7001804784088438
07/14 04:55:31 PM block_2-lambda: 0.8907316776900636
07/14 04:55:33 PM block_2-lambda: 0.49316915709043074
07/14 04:55:35 PM block_2-lambda: 1.0273064943576133
07/14 04:55:37 PM block_2-lambda: 0.19774853507059664
07/14 04:55:38 PM block_2-lambda: 0.15136005864228688
07/14 04:55:40 PM block_2-lambda: 0.6295960432687652
07/14 04:55:42 PM block_2-lambda: 0.29299549436920524
07/14 04:55:44 PM block_2-lambda: 0.6498577915120718
07/14 04:55:46 PM block_2-lambda: 0.06491550993827362
07/14 04:55:47 PM block_2-lambda: 0.10134486507468858
07/14 04:55:49 PM block_2-lambda: 0.42918239500268607
07/14 04:55:51 PM block_2-lambda: 0.8522931267233456
07/14 04:55:53 PM block_2-lambda: 0.10508586476245353
07/14 04:55:54 PM block_2-lambda: 0.6211146654796633
07/14 04:55:56 PM block_2-lambda: 0.5331880571465508
07/14 04:55:58 PM block_2-lambda: 0.7981291543306794
07/14 04:56:00 PM block_2-lambda: 1.877

07/14 04:59:52 PM block_4-lambda: 2.484659768868407
07/14 04:59:54 PM block_4-lambda: 2.0688036090100024
07/14 04:59:55 PM block_4-lambda: 3.3194995068650592
07/14 04:59:57 PM block_4-lambda: 3.1239853645818103
07/14 04:59:59 PM block_4-lambda: 3.032537912460896
07/14 05:00:00 PM block_4-lambda: 2.9746640789454446
07/14 05:00:02 PM block_4-lambda: 2.797765980518543
07/14 05:00:03 PM block_4-lambda: 3.29172497757024
07/14 05:00:05 PM block_4-lambda: 3.2816714145275405
07/14 05:00:05 PM block_id: 5
07/14 05:00:07 PM block_5-lambda: -3.770498815697824e-05
07/14 05:00:08 PM block_5-lambda: 0.37657312030276974
07/14 05:00:10 PM block_5-lambda: 4.594871959077337
07/14 05:00:11 PM block_5-lambda: 2.691821162010017
07/14 05:00:13 PM block_5-lambda: 2.6442498216126036
07/14 05:00:15 PM block_5-lambda: 3.7548384155087855
07/14 05:00:16 PM block_5-lambda: 3.487251254079384
07/14 05:00:18 PM block_5-lambda: 7.107801262279265
07/14 05:00:19 PM block_5-lambda: 1.6374033273360753
07/14 05:00:21 PM bl

07/14 05:03:46 PM block_8-lambda: 4.797620241327769
07/14 05:03:47 PM block_8-lambda: 7.375163940529616
07/14 05:03:49 PM block_8-lambda: 3.6267834314197533
07/14 05:03:50 PM block_8-lambda: 4.537779355576102
07/14 05:03:52 PM block_8-lambda: 4.096660272738569
07/14 05:03:53 PM block_8-lambda: 1.4401225845181667
07/14 05:03:54 PM block_8-lambda: 3.7027868690887757
07/14 05:03:56 PM block_8-lambda: 4.286588388442324
07/14 05:03:57 PM block_8-lambda: 2.95772452115204
07/14 05:03:59 PM block_8-lambda: 3.337845028861756
07/14 05:04:00 PM block_8-lambda: 4.58247684732597
07/14 05:04:01 PM block_8-lambda: 4.364462281344104
07/14 05:04:03 PM block_8-lambda: 5.475176857056192
07/14 05:04:04 PM block_8-lambda: 3.665441290675067
07/14 05:04:06 PM block_8-lambda: 4.0032313729068605
07/14 05:04:07 PM block_8-lambda: 5.58618396754126
07/14 05:04:08 PM block_8-lambda: 7.347249510748012
07/14 05:04:10 PM block_8-lambda: 4.925340463537707
07/14 05:04:11 PM block_8-lambda: 3.0802479042336057
07/14 05:0

07/14 05:07:02 PM block_11-lambda: -1.1126199168442927
