In [1]:
from __future__ import absolute_import, division, print_function

import argparse
import csv
import logging
import os
import random
import sys
import itertools

import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from torch.nn import CrossEntropyLoss, MSELoss
from scipy.stats import pearsonr, spearmanr
from sklearn.metrics import matthews_corrcoef, f1_score

from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE, WEIGHTS_NAME, CONFIG_NAME
from pytorch_pretrained_bert.modeling import BertConfig

# from pytorch_pretrained_bert.modeling import BertForSequenceClassification
# from modeling import BertForSequenceClassification_Quant as BertForSequenceClassification

from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, WarmupLinearSchedule

FORMAT = '[%(asctime)-15s %(filename)s:%(lineno)s] %(message)s'
FORMAT_MINIMAL = '%(message)s'

logging.basicConfig(format=FORMAT)
logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)



In [36]:
from run_classifier import *
from FIT_utils import *
import matplotlib.pyplot as plt
from quant_modules import QuantLinear, QuantLinear_Act, QuantEmbedding
from transformers import BertForSequenceClassification
import datasets
from datasets import load_dataset, load_metric

In [3]:
# Get cpu or gpu device for training.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")

Using cuda device


In [4]:
args = type('MyClass', (object,), {'content':{}})()
args.data_dir = '/home/ben/Documents/CERN/rebuttal_iclr/GLUE/SST-2/'
args.bert_model = 'bert-base-uncased'
args.do_lower_case = True
args.train_batch_size = 4
args.max_seq_length = 128
args.task_name = 'SST-2'
args.cache_dir = None
args.config = None
args.config_dir = None
args.local_rank=-1
args.bit_options = [2,3,4,8,32]
args.load_directory = '/home/ben/Documents/CERN/rebuttal_iclr/sst_pretrained/'

In [5]:
processors = {
    "cola": ColaProcessor,
    "mnli": MnliProcessor,
    "mnli-mm": MnliMismatchedProcessor,
    "mrpc": MrpcProcessor,
    "sst-2": Sst2Processor,
    "sts-b": StsbProcessor,
    "qqp": QqpProcessor,
    "qnli": QnliProcessor,
    "rte": RteProcessor,
    "wnli": WnliProcessor,
}

output_modes = {
    "cola": "classification",
    "mnli": "classification",
    "mrpc": "classification",
    "sst-2": "classification",
    "sts-b": "regression",
    "qqp": "classification",
    "qnli": "classification",
    "rte": "classification",
    "wnli": "classification",
}

In [6]:
task_name = args.task_name.lower()

In [7]:
processor = processors[task_name]()
output_mode = output_modes[task_name]

label_list = processor.get_labels()
num_labels = len(label_list)

In [9]:
# model = BertForSequenceClassification.from_pretrained(
#             args.load_directory, num_labels=num_labels, config=args.config)
model = BertForSequenceClassification.from_pretrained(args.bert_model)


NameError: name 'AutoModelForSequenceClassification' is not defined

In [13]:
tokenizer = BertTokenizer.from_pretrained(
    args.load_directory, do_lower_case=args.do_lower_case)
model = BertForSequenceClassification.from_pretrained("doyoungkim/bert-base-uncased-finetuned-sst2")

In [10]:
# tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

In [11]:
# # Prepare model
# cache_dir = os.path.join(
#     str(PYTORCH_PRETRAINED_BERT_CACHE),
#     'distributed_{}'.format(args.local_rank))
# model = BertForSequenceClassification.from_pretrained(
#     args.bert_model,
#     cache_dir=cache_dir,
#     config_dir=args.config_dir,
#     config=args.config,
#     num_labels=num_labels)

In [9]:
train_examples = processor.get_train_examples(args.data_dir)

In [10]:
cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(
            args.local_rank))

In [23]:
model.to(device)

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [14]:
train_features = convert_examples_to_features(
    train_examples, label_list, args.max_seq_length, tokenizer,
    output_mode)
all_input_ids = torch.tensor([f.input_ids for f in train_features],
                             dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in train_features],
                              dtype=torch.long)
all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                               dtype=torch.long)

if output_mode == "classification":
    all_label_ids = torch.tensor([f.label_id for f in train_features],
                                 dtype=torch.long)
elif output_mode == "regression":
    all_label_ids = torch.tensor([f.label_id for f in train_features],
                                 dtype=torch.float)

train_data = TensorDataset(all_input_ids, all_input_mask,
                           all_segment_ids, all_label_ids)
if args.local_rank == -1:
    train_sampler = RandomSampler(train_data)
else:
    train_sampler = DistributedSampler(train_data)
train_dataloader = DataLoader(
    train_data,
    sampler=train_sampler,
    batch_size=args.train_batch_size)

[2022-11-10 20:25:15,700 run_classifier.py:431] Writing example 0 of 67349
[2022-11-10 20:25:15,702 run_classifier.py:496] *** Example ***
[2022-11-10 20:25:15,702 run_classifier.py:497] guid: train-1
[2022-11-10 20:25:15,703 run_classifier.py:498] tokens: [CLS] hide new secret ##ions from the parental units [SEP]
[2022-11-10 20:25:15,703 run_classifier.py:499] input_ids: 101 5342 2047 3595 8496 2013 1996 18643 3197 102 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
[2022-11-10 20:25:15,703 run_classifier.py:501] input_mask: 1 1 1 1 1 1 1 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
[2022-11-10 20:25:15,704

In [15]:
model.eval()

BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, element

In [16]:
## Define useful layer hooks:
def linear_flops_counter_hook(module, input, output):
    input = input[0]
    # pytorch checks dimensions, so here we don't care much
    output_last_dim = output.shape[-1]
    bias_flops = output_last_dim if module.bias is not None else 0
    module.__flops__ += int(np.prod(input.shape) * output_last_dim + bias_flops)

In [17]:
MODULES_MAPPING = {
    nn.Linear: linear_flops_counter_hook,
    QuantLinear: linear_flops_counter_hook,
}

In [18]:
layers = []
names = []
for name, module in model.named_modules():
    if type(module) in MODULES_MAPPING:
        names.append(name)
        layers.append(module)

In [19]:
param_nums = []
params = []
names = []
for name, module in model.named_modules():
    if (isinstance(module, nn.Linear) or (isinstance(module, nn.MultiheadAttention))):
        for n, p in list(module.named_parameters()):
            if n.endswith('weight'):
                names.append(name)
                p.collect = True
                layers.append(module)
                param_nums.append(p.numel())
                params.append(p)
            else:
                p.collect = False
        continue
    for p in list(module.parameters()):
        if p.requires_grad:
            p.collect = False
for i, n in enumerate(names):
    print(i, n)

0 bert.encoder.layer.0.attention.self.query
1 bert.encoder.layer.0.attention.self.key
2 bert.encoder.layer.0.attention.self.value
3 bert.encoder.layer.0.attention.output.dense
4 bert.encoder.layer.0.intermediate.dense
5 bert.encoder.layer.0.output.dense
6 bert.encoder.layer.1.attention.self.query
7 bert.encoder.layer.1.attention.self.key
8 bert.encoder.layer.1.attention.self.value
9 bert.encoder.layer.1.attention.output.dense
10 bert.encoder.layer.1.intermediate.dense
11 bert.encoder.layer.1.output.dense
12 bert.encoder.layer.2.attention.self.query
13 bert.encoder.layer.2.attention.self.key
14 bert.encoder.layer.2.attention.self.value
15 bert.encoder.layer.2.attention.output.dense
16 bert.encoder.layer.2.intermediate.dense
17 bert.encoder.layer.2.output.dense
18 bert.encoder.layer.3.attention.self.query
19 bert.encoder.layer.3.attention.self.key
20 bert.encoder.layer.3.attention.self.value
21 bert.encoder.layer.3.attention.output.dense
22 bert.encoder.layer.3.intermediate.dense
23 bert

In [55]:
def benchmarking_Fish(model, device, params, metric, data_loader,iterations):
    ''' Used to generate the convergence statistics for the Empirical Fisher
    Args:
        model
        device
        params - list of accumulated model parameters to generate placeholders
        criterion - loss function used to compute the EF
        data_loader
        iterations - total number of Estimator iterations to sum over
    Returns:
        F_average - computed EF trace
        estimator_accumulation - accumulation of each individual EF estimator
        estimator_mean_accumulation - accumulation of the EF trace estimator over iterations
    '''
    model.eval()
    
    # accumulate hutchinson estimator
    estimator_accumulation = []

    estimator_mean_accumulation = []
    
    iteration = 0
    total_processed = 0
    batches = 0
    
    while(iteration < iterations):

        TFv = [torch.zeros(p.size()).to(device) for p in params]  # accumulate iteration up to datapoints_per_iteration
        
        for i, data in enumerate(data_loader, 1):
            batch = tuple(t.to(device) for t in data)
            input_ids, input_mask, segment_ids, label_ids = batch

            # define a new function to compute loss values for both output_modes
            logits = model(input_ids, segment_ids, input_mask, labels=None)
            loss = criterion(logits.view(-1, num_labels), label_ids.view(-1))
            
            loss.backward(create_graph=True)
            
            loss.backward()

            paramsH = []
            gradsH = []
            for paramH in model.parameters():
                if not paramH.collect:
                    continue
                paramsH.append(paramH)
                gradsH.append(0. if paramH.grad is None else paramH.grad + 0.)
            
            # Fisher Accumulation
            G2 = []
            for g in gradsH:
                G2.append(batch_size*g*g)
                
            TFv = [TFv_ + G2_ + 0. for TFv_, G2_ in zip(TFv, G2)]
            
            total_processed += 1
            
            TFv_normed = [TFv_ / float(total_processed) for TFv_ in TFv]

            vFv = [torch.sum(x) for x in TFv_normed]
            
            indiv = np.array([torch.sum(x).detach().cpu().numpy() for x in G2])
            estimator_accumulation.append(indiv)

            vFv_c = np.array([i.detach().cpu().numpy() for i in vFv])
            
            F_average = vFv_c

            print(f'Iteration {iteration}')
                    
            estimator_mean_accumulation.append(F_average)
            
            iteration += 1
            
            if iteration >= iterations:
                break
                
#         eval_metric = metric.compute()
#         print(f"Sanity check: {task_name}, {eval_metric}")

    return F_average, estimator_accumulation, estimator_mean_accumulation

In [66]:
def benchmarking_Hess(model, device, params, metric, data_loader, iterations, datapoints_per_iteration):
    ''' Used to generate the convergence statistics for the Hessian
    Args:
        model
        device
        params - list of accumulated model parameters to generate placeholders
        criterion - loss function used to compute the Hessian
        data_loader
        iterations - total number of Estimator iterations to sum over
        datapoints_per_iteration - min(batch size, datapoints_per_iteration) used to compute each estimate
    Returns:
        H_average - computed Hessian trace
        estimator_accumulation - accumulation of each individual hutchinson estimator
        estimator_mean_accumulation - accumulation of the Hutchinson trace estimator over iterations
    '''
    
    ## Defines the Rademacher generation
    def rademacher():
        v = [torch.randint_like(p, high=2, device=device) for p in params]
        for v_i in v:
            v_i[v_i == 0] = -1
        return v
    
    model.eval()
    
    # accumulate hutchinson estimator
    estimator_accumulation = []
    estimator_mean_accumulation = []

    iteration = 0
    iteration_batch = 0
    
    v = rademacher()
    
    while(iteration < iterations):

        THv = [torch.zeros(p.size()).to(device) for p in params]
        
        for i, data in enumerate(data_loader, 1):
            batch = tuple(t.to(device) for t in data)
            input_ids, input_mask, segment_ids, label_ids = batch

            # define a new function to compute loss values for both output_modes
            logits = model(input_ids, segment_ids, input_mask, labels=None).logits
#             loss = logits.loss
            loss = criterion(logits.view(-1, num_labels), label_ids.view(-1))
            
            loss.backward(create_graph=True, retain_graph=True)

            paramsH = []
            gradsH = []
            for paramH in model.parameters():
                if not paramH.collect:
                    continue
                paramsH.append(paramH)
                gradsH.append(0. if paramH.grad is None else paramH.grad + 0.)
            
            Hv = torch.autograd.grad(gradsH, paramsH, grad_outputs=v,only_inputs=False,retain_graph=True)
            
            THv = [THv_ + Hv_ + 0. for THv_, Hv_ in zip(THv, Hv)]
            
            iteration_batch += 1
            
            if iteration_batch*batch_size >= datapoints_per_iteration:
                
                THv = [THv_ / float(iteration_batch) for THv_ in THv] # normalise to the number of batches
                
                vHv = [torch.sum(x * y) for (x, y) in zip(THv, v)] # compute the Hutchinson estimator
                
                vHv_c = np.array([i.cpu().numpy() for i in vHv])
                
                estimator_accumulation.append(vHv_c) # accumulate the estimator
                
                H_average = np.mean(estimator_accumulation, axis=0)

                estimator_mean_accumulation.append(H_average)
                
                print(f'Iteration {iteration}')
                
                # Reset the hutchinson estimator variables
                v = rademacher()
                THv = [torch.zeros(p.size()).to(device) for p in params]  # accumulate result
                iteration_batch = 0
                iteration += 1
                
                if iteration >= iterations:
                    break
                    
    return H_average, estimator_accumulation, estimator_mean_accumulation

In [47]:
metric = load_metric("glue", 'sst2')

In [24]:
batch = next(iter(train_dataloader))
batch = tuple(t.to(device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch

# define a new function to compute loss values for both output_modes
_ = model(input_ids, segment_ids, input_mask, labels=None)

In [25]:
criterion = CrossEntropyLoss().to(device)

In [57]:
F, Fa, Fma = benchmarking_Fish(model, device, params, criterion, train_dataloader, iterations = 200)

AttributeError: 'SequenceClassifierOutput' object has no attribute 'view'

In [67]:
H, Ha, Hma = benchmarking_Hess(model, device, params, metric, train_dataloader, 
                      iterations = 200,  
                      datapoints_per_iteration = args.train_batch_size)

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

In [None]:
mean_normed = (Fa - np.mean(Fa, axis=0))/np.mean(Fa, axis=0)
# plt.yscale('log')
print(f'EF variance: {np.var(mean_normed)}')

In [None]:
mean_normed = (Ha - np.mean(Ha, axis=0))/np.mean(Ha, axis=0)
# plt.yscale('log')
print(f'Hessian variance: {np.var(mean_normed)}')