In [1]:
from __future__ import absolute_import, division, print_function
import itertools
import argparse
import collections
import json
import logging
import math
import os
import random
import sys
from io import open
import torch.nn as nn

import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
# from pytorch_pretrained_bert.modeling import BertForQuestionAnswering, BertConfig
from modeling import BertForQuestionAnswering_Quant as BertForQuestionAnswering
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule
from pytorch_pretrained_bert.tokenization import (
    BasicTokenizer, BertTokenizer, whitespace_tokenize)

if sys.version_info[0] == 2:
    import cPickle as pickle
else:
    import pickle

In [2]:
from FIT_utils import *
from run_squad import read_squad_examples, convert_examples_to_features
import matplotlib.pyplot as plt
from quant_modules import QuantLinear, QuantLinear_Act, QuantEmbedding

In [3]:
# Select the compute device: the GPU if PyTorch can see one, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using {device} device")

Using cuda device


In [4]:
# Experiment configuration, exposed with the same attribute-style access as the
# argparse-based run_squad script. The SQuAD training file can be overridden with the
# SQUAD_TRAIN_FILE environment variable instead of editing a machine-specific path.
args = argparse.Namespace(
    train_file=os.environ.get(
        'SQUAD_TRAIN_FILE',
        '/home/ben/Documents/CERN/rebuttal_iclr/SQUAD/train-v1.1.json'),
    bert_model='bert-base-uncased',
    do_lower_case=True,
    train_batch_size=12,
    max_seq_length=384,
    doc_stride=128,
    max_query_length=64,
    version_2_with_negative=False,
    config=None,
    config_dir=None,
    local_rank=-1,  # -1 selects the non-distributed RandomSampler path below
    bit_options=[4, 8, 32],  # candidate bit-widths for the mixed-precision search
)

In [5]:
# Tokenizer for the checkpoint named in args.bert_model; lower-casing follows args.do_lower_case.
tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                          do_lower_case=args.do_lower_case)

In [6]:
# Parse the SQuAD training JSON into examples (training mode).
train_examples = read_squad_examples(
    is_training=True,
    input_file=args.train_file,
    version_2_with_negative=args.version_2_with_negative,
)

In [7]:
# Prepare model: load the pretrained weights into the _Quant QA model
# (imported as BertForQuestionAnswering), caching downloads per local rank.
cache_dir = os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                         f'distributed_{args.local_rank}')
model = BertForQuestionAnswering.from_pretrained(
    args.bert_model, cache_dir=cache_dir, config=args.config, config_dir=args.config_dir)

[2022-11-09 02:36:04,273 modeling.py:372] loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at /home/ben/.pytorch_pretrained_bert/distributed_-1/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba
[2022-11-09 02:36:04,274 modeling.py:380] extracting archive file /home/ben/.pytorch_pretrained_bert/distributed_-1/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /tmp/tmps2916wsy
[2022-11-09 02:36:07,691 modeling.py:452] Weights of BertForQuestionAnswering_Quant not initialized from pretrained model: ['bert.embeddings.word_embeddings.x_min', 'bert.embeddings.word_embeddings.x_max', 'bert.embeddings.position_embeddings.x_min', 'bert.embeddings.position_embeddings.x_max', 'bert.embeddings.token_type_embeddings.x_min', 'bert.embeddings.token_type_embeddings.x_max', '

In [8]:
# Move the model's parameters to the selected device; as the cell's last expression,
# the returned module is displayed.
model.to(device)

BertForQuestionAnswering_Quant(
  (bert): BertModel_Quant(
    (embeddings): BertEmbeddings_Quant(
      (word_embeddings): QuantEmbedding(30522, 768, padding_idx=0)
      (position_embeddings): QuantEmbedding(512, 768)
      (token_type_embeddings): QuantEmbedding(2, 768)
      (LayerNorm): BertLayerNorm()
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder_Quant(
      (layer): ModuleList(
        (0): BertLayer_Quant(
          (attention): BertAttention_Quant(
            (self): BertSelfAttention_Quant(
              (query): QuantLinear(in_features=768, out_features=768, bias=True)
              (key): QuantLinear(in_features=768, out_features=768, bias=True)
              (value): QuantLinear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput_Quant(
              (dense): QuantLinear(in_features=768, out_features=768, bias=True)
              (LayerNor

In [9]:
# Convert the SQuAD examples into fixed-length features and wrap them in a DataLoader.
# Features are cached next to the training file, keyed by model name, max_seq_length,
# doc_stride and max_query_length, so the slow conversion only runs once.
cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}'.format(
    list(filter(None, args.bert_model.split('/'))).pop(),
    str(args.max_seq_length), str(args.doc_stride),
    str(args.max_query_length))
try:
    # NOTE: pickle.load can execute arbitrary code -- only load a cache this notebook wrote.
    with open(cached_train_features_file, "rb") as reader:
        train_features = pickle.load(reader)
except Exception as exc:
    # Missing, corrupt or stale cache: rebuild it. Catching Exception rather than
    # BaseException lets KeyboardInterrupt/SystemExit propagate instead of silently
    # starting a long feature conversion.
    print(f"Feature cache not usable ({exc!r}); rebuilding {cached_train_features_file}")
    train_features = convert_examples_to_features(
        examples=train_examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        max_query_length=args.max_query_length,
        is_training=True)
    # Only one process writes the cache when running distributed.
    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        with open(cached_train_features_file, "wb") as writer:
            pickle.dump(train_features, writer)

all_input_ids = torch.tensor([f.input_ids for f in train_features],
                             dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features],
                              dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                               dtype=torch.long)
all_start_positions = torch.tensor(
    [f.start_position for f in train_features], dtype=torch.long)
all_end_positions = torch.tensor(
    [f.end_position for f in train_features], dtype=torch.long)
train_data = TensorDataset(all_input_ids, all_input_mask,
                           all_segment_ids, all_start_positions,
                           all_end_positions)
if args.local_rank == -1:
    train_sampler = RandomSampler(train_data)
else:
    train_sampler = DistributedSampler(train_data)
train_dataloader = DataLoader(
    train_data,
    sampler=train_sampler,
    batch_size=args.train_batch_size)

In [10]:
# Build the FIT helper used for the weight-sensitivity pass. The last argument holds
# layer-name substrings ('pooler', 'qa') -- presumably modules to skip; confirm in FIT_utils.
# NOTE(review): the recorded run of this cell listed the layers and then raised
# "ValueError: too many values to unpack (expected 4)"; the cause is inside FIT, not here.
fit_computerw = FIT(model, device, train_dataloader, ['pooler', 'qa'])

72
84934656
0 bert.encoder.layer.0.attention.self.query 589824
1 bert.encoder.layer.0.attention.self.key 589824
2 bert.encoder.layer.0.attention.self.value 589824
3 bert.encoder.layer.0.attention.output.dense 589824
4 bert.encoder.layer.0.intermediate.dense 2359296
5 bert.encoder.layer.0.output.dense 2359296
6 bert.encoder.layer.1.attention.self.query 589824
7 bert.encoder.layer.1.attention.self.key 589824
8 bert.encoder.layer.1.attention.self.value 589824
9 bert.encoder.layer.1.attention.output.dense 589824
10 bert.encoder.layer.1.intermediate.dense 2359296
11 bert.encoder.layer.1.output.dense 2359296
12 bert.encoder.layer.2.attention.self.query 589824
13 bert.encoder.layer.2.attention.self.key 589824
14 bert.encoder.layer.2.attention.self.value 589824
15 bert.encoder.layer.2.attention.output.dense 589824
16 bert.encoder.layer.2.intermediate.dense 2359296
17 bert.encoder.layer.2.output.dense 2359296
18 bert.encoder.layer.3.attention.self.query 589824
19 bert.encoder.layer.3.attention.

ValueError: too many values to unpack (expected 4)

In [None]:

# Run FIT's EF estimation over the training data for the weight pass. min_iterations and
# max_iterations are both 200, so (assuming they bound the loop as named) exactly 200
# iterations run and tol=1e-2 has no effect on when it stops.
EFw, EFa, fap, faa, param_ranges, act_ranges = fit_computerw.EF(
    model, train_dataloader, None,
    tol=1e-2, min_iterations=200, max_iterations=200,
)

In [None]:
# Prepare model: reload a fresh copy of the pretrained _Quant QA model for the
# activation pass, caching downloads per local rank.
cache_dir = os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE),
                         f'distributed_{args.local_rank}')
model = BertForQuestionAnswering.from_pretrained(
    args.bert_model, cache_dir=cache_dir, config=args.config, config_dir=args.config_dir)

In [None]:
# Move the reloaded model to the selected device (the returned module is displayed).
model.to(device)

In [None]:
# FIT helper for the activation-sensitivity pass; the name substrings differ from the
# weight pass ('output', 'pooler') -- presumably modules to skip; confirm in FIT_utils.
fit_computera = FIT(model, device, train_dataloader, ['output', 'pooler'])

In [None]:
# Activation pass: keep only the activation outputs (EFa, faa, act_ranges) of EF;
# the weight-related outputs come from the weight pass above.
_, EFa, _, faa, _, act_ranges = fit_computera.EF(
    model, train_dataloader, None,
    tol=1e-2, min_iterations=200, max_iterations=200,
)

In [None]:
# Weight EF traces normalised by fit_computerw.param_nums (presumably per-layer
# parameter counts), on a log scale.
ax = plt.gca()
ax.set_title('W Trace')
ax.plot(EFw / fit_computerw.param_nums, 'o-', label='EF')
ax.grid(True, which='both')
ax.legend()
ax.set_yscale('log')

In [None]:
# Activation EF traces normalised by fit_computera.act_nums (presumably per-site
# activation counts), on a log scale.
ax = plt.gca()
ax.set_title('A Trace')
ax.plot(EFa / fit_computera.act_nums, 'o-', label='EF')
ax.grid(True, which='both')
ax.legend()
ax.set_yscale('log')

In [None]:
# fap as returned by the weight EF pass, log scale.
ax = plt.gca()
ax.plot(fap)
ax.set_yscale('log')

In [None]:
# faa as returned by the activation EF pass, log scale.
ax = plt.gca()
ax.plot(faa)
ax.set_yscale('log')

In [None]:
# Give the weight-pass FIT object the activation ranges and EF traces from the activation
# pass, so its FIT(w_config, a_config) calls below can score both together
# (presumably -- FIT internals live in FIT_utils).
fit_computerw.Ra = act_ranges
fit_computerw.EFa = EFa

In [None]:
# Number of activation EF entries; NOTE(review) presumably must equal the length of the
# activation configs sampled later (48) -- confirm.
print(len(EFa))

In [None]:
# Define useful layer hooks:
def linear_flops_counter_hook(module, input, output):
    """Forward hook adding an estimated FLOP count for a linear layer call.

    Adds (number of input elements) * (output features) for the matmul, plus one term
    per output feature when the module has a bias, to ``module.__flops__`` (which the
    caller must initialise).
    NOTE(review): the bias is counted once per output feature rather than once per
    output element, so bias work is undercounted (small next to the matmul term).
    """
    x = input[0]  # forward hooks receive the positional inputs as a tuple
    out_features = output.shape[-1]
    matmul_flops = np.prod(x.shape) * out_features
    bias_flops = 0 if module.bias is None else out_features
    module.__flops__ += int(matmul_flops + bias_flops)
    
def multihead_attention_counter_hook(multihead_attention_module, input, output):
    """Forward hook adding an estimated FLOP count for an ``nn.MultiheadAttention`` call.

    Counts the query scaling, the Q/K/V input projections (each mapping its own width to
    ``embed_dim``) and their biases, the per-head QK^T / softmax / AV work, and the
    ``embed_dim -> embed_dim`` output projection, multiplied by the batch size. The total
    is added to ``multihead_attention_module.__flops__``, which the caller must initialise.

    Args:
        multihead_attention_module: the attention module; reads ``embed_dim``,
            ``num_heads``, ``kdim``, ``vdim``, ``in_proj_bias`` and, if present,
            ``batch_first``.
        input: tuple of positional forward arguments. The first three are taken as
            (query, key, value); further positional arguments (e.g. a mask) are ignored.
            Forward hooks do not receive keyword arguments, so q/k/v must be positional.
        output: unused.
    """
    q, k, v = input[:3]

    batch_first = getattr(multihead_attention_module, 'batch_first', False)
    len_idx, batch_idx = (1, 0) if batch_first else (0, 1)
    batch_size = q.shape[batch_idx]

    dim_idx = 2
    qdim, kdim, vdim = q.shape[dim_idx], k.shape[dim_idx], v.shape[dim_idx]
    qlen, klen, vlen = q.shape[len_idx], k.shape[len_idx], v.shape[len_idx]

    num_heads = multihead_attention_module.num_heads
    embed_dim = multihead_attention_module.embed_dim
    assert qdim == embed_dim

    if multihead_attention_module.kdim is None:
        assert kdim == qdim
    if multihead_attention_module.vdim is None:
        assert vdim == qdim

    flops = 0

    # Q scaling
    flops += qlen * qdim

    # Input projections: Q, K and V are each projected from their own width to embed_dim.
    flops += (qlen * qdim + klen * kdim + vlen * vdim) * embed_dim

    if multihead_attention_module.in_proj_bias is not None:
        flops += (qlen + klen + vlen) * embed_dim

    # Attention heads: after projection every head works on embed_dim // num_heads features.
    head_dim = embed_dim // num_heads
    head_flops = (
        (qlen * klen * head_dim)  # QK^T
        + (qlen * klen)  # softmax
        + (qlen * klen * head_dim)  # AV
    )
    flops += num_heads * head_flops

    # Output projection embed_dim -> embed_dim; bias is always enabled.
    flops += qlen * embed_dim * (embed_dim + 1)

    flops *= batch_size
    multihead_attention_module.__flops__ += int(flops)

In [None]:
# Forward hook used to count FLOPs for each supported module type. Callers look up
# type(module) exactly, so subclasses are not matched automatically.
# NOTE(review): QuantLinear_Act is imported but has no entry here -- confirm that is intended.
MODULES_MAPPING = dict([
    (nn.Linear, linear_flops_counter_hook),
    (QuantLinear, linear_flops_counter_hook),
    (nn.MultiheadAttention, multihead_attention_counter_hook),
])

In [None]:
def bert_input_constructor(input_shape, tokenizer):
    """Build a dummy tokenised batch for FLOP counting.

    ``input_shape`` is (batch_size, seq_len). Each text is ``seq_len - 2`` pad tokens,
    leaving room for the two special tokens [CLS] and [SEP]. The tokenizer is called on
    ``batch_size`` copies with padding/truncation and PyTorch tensors requested, and a
    ``labels`` tensor of ones (one per example) is added to the returned dict.

    NOTE(review): expects a callable tokenizer with a ``pad_token`` attribute (Hugging
    Face ``transformers`` style); the pytorch_pretrained_bert BertTokenizer loaded in this
    notebook does not appear to offer that interface, and this helper is not called here.
    """
    batch_size, seq_len = input_shape[0], input_shape[1]
    fake_text = tokenizer.pad_token * (seq_len - 2)
    encoded = dict(tokenizer([fake_text] * batch_size, padding=True, truncation=True,
                             return_tensors="pt"))
    encoded["labels"] = torch.tensor([1] * batch_size)
    return encoded

def remove(layers):
    """Detach each layer's FLOP hook and delete the hook handle and counter attributes."""
    for layer in layers:
        handle = layer.__flops_handle__
        handle.remove()
        delattr(layer, '__flops_handle__')
        delattr(layer, '__flops__')

In [None]:
# Collect every module whose exact type has a FLOP hook, skipping the pooler.
names, layers = [], []
for name, module in model.named_modules():
    if type(module) not in MODULES_MAPPING or 'pooler' in name:
        continue
    names.append(name)
    layers.append(module)

In [None]:
# Reset each layer's counter and attach the matching FLOP hook, keeping the handle for removal.
for layer in layers:
    layer.__flops__ = 0
    hook_fn = MODULES_MAPPING[type(layer)]
    layer.__flops_handle__ = layer.register_forward_hook(hook_fn)

In [None]:
# Push one training batch through the model so the forward hooks record per-layer FLOPs.
data = next(iter(train_dataloader))
batch = tuple(t.to(device) for t in data)
input_ids, input_mask, segment_ids, start_positions, end_positions = batch
# Number of examples in the batch (len(batch) would be the number of tensors, i.e. 5).
batch_size = input_ids.size(0)
_ = model(input_ids, segment_ids, input_mask, start_positions, end_positions)

In [None]:
# Snapshot each hooked layer's accumulated FLOP count, in collection order.
bert_fp32_bops = list(map(lambda layer: layer.__flops__, layers))

In [None]:
# Detach the FLOP hooks and delete their counters so later forward passes are not counted.
remove(layers)

In [None]:
# One line per hooked layer: module name followed by its FP32 FLOP count.
for pair in zip(names, bert_fp32_bops):
    print(*pair)

In [None]:
# Number of layers that received a FLOP count.
print(len(bert_fp32_bops))

In [None]:
total_bops = np.sum(bert_fp32_bops)
print(total_bops)
# Scale the FP32 total by (w_bits / 32) * (a_bits / 32) for uniform W8A8, W4A4 and W8A4.
for wb, ab in ((8, 8), (4, 4), (8, 4)):
    print(total_bops * (wb / 32) * (ab / 32))

In [None]:
# Free memory for the analysis: neither the model nor the data loader is used below.
del model, train_dataloader
torch.cuda.empty_cache()

In [None]:
# Normalised weight sensitivities followed by normalised activation sensitivities.
w_sensitivity = EFw / fit_computerw.param_nums
a_sensitivity = EFa / fit_computera.act_nums
total_sensitivity = np.concatenate([w_sensitivity, a_sensitivity])

In [None]:
def _compositions(total, parts):
    """Yield every tuple of ``parts`` non-negative ints summing to ``total``.

    Tuples come out in lexicographic order -- the same order as filtering
    itertools.product(range(total + 1), repeat=parts) by sum -- but without
    enumerating all (total + 1) ** parts candidates.
    """
    if parts == 0:
        if total == 0:
            yield ()
        return
    if parts == 1:
        yield (total,)
        return
    for first in range(total + 1):
        for rest in _compositions(total - first, parts - 1):
            yield (first,) + rest


# Each config gives how many layers (in sensitivity order) receive each entry of
# args.bit_options; the counts add up to the number of layers. Entries are Python ints
# (previously NumPy ints from np.arange).
configs = list(_compositions(len(total_sensitivity), len(args.bit_options)))

In [None]:
# Number of candidate bit-count allocations.
print(len(configs))

In [None]:
# Layer indices sorted from least to most sensitive (ascending total_sensitivity).
order = np.argsort(total_sensitivity)

# State filled in by the random search: every sampled configuration's BOPs, FIT value
# and bit allocations, plus the lowest-FIT configuration found so far (`best`).
# The commented-out exhaustive search over `configs` that used to sit here was dead
# code and has been removed.
bops_acc = []
fit_values_acc = []
w_configs_acc = []
a_configs_acc = []
min_FIT = np.inf
best = None

In [None]:
# Random search over mixed-precision configurations. Each sample draws weight and
# activation bit-widths from args.bit_options[:-1] ([4, 8] with the default options),
# scores the pair with FIT, estimates its BOPs, and appends everything to the
# accumulators. The lowest-FIT sample under the `min_criterion` BOP budget is kept in
# `best`. Re-running the cell adds further samples.
min_criterion = 13030593057.0  # BOP budget; TODO document how this value was chosen
N_SAMPLES = 2000
N_W_LAYERS = 72  # weight bit-widths per sample (FIT reported 72 layers)
N_A_LAYERS = 48  # activation bit-widths per sample -- NOTE(review) presumably len(EFa)
SEARCH_SEED = 0

# Seeded once so the search is reproducible; re-running the cell continues the same
# stream rather than replaying identical samples.
if 'search_rng' not in globals():
    search_rng = np.random.default_rng(SEARCH_SEED)


def _estimate_bops(w_config, a_config):
    """Estimate the BOPs of one configuration from the FP32 per-layer FLOP counts.

    Each layer's FP32 count is scaled by (w_bits / 32) * (a_bits / 32). Layers listed in
    fit_computera.names take their activation bit-width from a_config, in order; all
    other layers use 8-bit activations. The last entry of bert_fp32_bops is not used.
    """
    bops = 0
    a_indx = 0
    for bo, name, wb in zip(bert_fp32_bops[:-1], fit_computerw.names, w_config):
        if name in fit_computera.names:
            bops += bo * a_config[a_indx] * wb * (1 / 32) ** 2
            a_indx += 1
        else:
            bops += bo * wb * 8 * (1 / 32) ** 2
    return bops


for _ in range(N_SAMPLES):
    w_config = list(search_rng.choice(args.bit_options[:-1], N_W_LAYERS, p=[0.9, 0.1]))
    a_config = list(search_rng.choice(args.bit_options[:-1], N_A_LAYERS, p=[0.7, 0.3]))

    fit_value = fit_computerw.FIT(np.array(w_config), np.array(a_config))
    bops = _estimate_bops(w_config, a_config)

    bops_acc.append(bops)
    fit_values_acc.append(fit_value)
    w_configs_acc.append(w_config)
    a_configs_acc.append(a_config)

    if bops < min_criterion and fit_value < min_FIT:
        min_FIT = fit_value
        best = (bops, fit_value, w_config, a_config)
        print('updated')

print(best)

In [None]:
# BOPs of the most recently sampled configuration.
print(bops_acc[-1])

In [None]:
# BOPs vs. FIT for every sampled configuration, marking the lowest-FIT sample under
# `highlight_criterion`. The marked point is computed here so the cell runs on a fresh
# kernel (it previously relied on `filtered_info`/`idx` from a later cell).
highlight_criterion = 10030593057.0  # BOP budget used for the highlighted point
fig, ax = plt.subplots()
ax.scatter(bops_acc, fit_values_acc, s=2)
ax.set_xscale('log')
ax.set_xlabel('BOPs')
ax.set_ylabel('FIT')
under_budget = [k for k, b in enumerate(bops_acc) if b < highlight_criterion]
if under_budget:
    # min() keeps the first of equal minima, like np.argmin.
    k_best = min(under_budget, key=lambda k: fit_values_acc[k])
    ax.scatter(bops_acc[k_best], fit_values_acc[k_best], s=20, marker='v', c='black')

In [None]:
# Weight bit-width per layer in the best configuration.
fig = plt.figure()
w_bits = best[2]
plt.plot(range(len(w_bits)), w_bits)


In [None]:
# Activation bit-width per quantised activation site in the best configuration.
a_bits = best[3]
plt.plot(range(len(a_bits)), a_bits)

In [None]:
# Reference point: all weights at 4 bits and all activations at 8 bits, scored with FIT
# and costed with the same BOP estimate as the random search. Both configs are built
# before FIT is called so FIT scores this configuration. Activation bits line up with
# fit_computera.names exactly as in the search (no extra leading entry).
test_w_config = [4 for _ in range(72)]
test_a_config = [8 for _ in range(48)]
test_fit = fit_computerw.FIT(np.array(test_w_config), np.array(test_a_config))

test_bops = 0
a_indx = 0
for bo, name, wb in zip(bert_fp32_bops[:-1], fit_computerw.names, test_w_config):
    if name in fit_computera.names:
        test_bops += bo * test_a_config[a_indx] * wb * (1 / 32) ** 2
        a_indx += 1
    else:
        # Layers without a searched activation bit-width are costed at 8-bit activations.
        test_bops += bo * wb * 8 * (1 / 32) ** 2



In [None]:
# BOPs of the reference configuration computed in the previous cell.
print(test_bops)

In [1]:
# Among all sampled configurations under the BOP budget `criterion`, pick the one with
# the lowest FIT value.
criterion = 10030593057.0
filtered = []
filtered_info = []
for bps, fit, wconf, aconf in zip(bops_acc, fit_values_acc, w_configs_acc, a_configs_acc):
    if bps < criterion:
        filtered.append(fit)
        filtered_info.append((bps, fit, wconf, aconf))
if not filtered:
    # np.argmin on an empty list would fail with an unhelpful message.
    raise ValueError(f"No sampled configuration is under the BOP budget {criterion:.4g}; "
                     "raise `criterion` or draw more samples.")
idx = int(np.argmin(filtered))
print(filtered_info[idx])

NameError: name 'bops_acc' is not defined

In [None]:
import json

In [None]:
# Map each hooked layer name, minus its 'bert.encoder.' prefix (13 characters), to the
# weight bit-width chosen in `best`. names[:-1] drops the last collected layer
# (presumably the QA output head -- confirm). Values are cast to int because the search
# produces NumPy integers, which json cannot serialise.
ENCODER_PREFIX = 'bert.encoder.'
layer_bits = {}
for l, b in zip(names[:-1], best[2]):
    layer_bits[l[len(ENCODER_PREFIX):]] = int(b)

In [None]:
# Chosen weight bit-width per layer.
print(layer_bits)

In [None]:
# Map each activation-pass layer name, with its first 13 characters stripped
# (presumably the 'bert.encoder.' prefix -- confirm), to the activation bit-width chosen
# in `best`. Values are cast to int so the dict is JSON-serialisable.
ENCODER_PREFIX_LEN = len('bert.encoder.')
activation_bits = {}
for l, b in zip(fit_computera.names, best[3]):
    activation_bits[l[ENCODER_PREFIX_LEN:]] = int(b)

In [None]:
# Chosen activation bit-width per quantised activation site.
print(activation_bits)