# Head Selection
**_BERT_** is a **_Multi-layer_ _Multi-Head_** Transformer architecture. As discuss in many of the current reseachers, different Attention heads captures different lingustic patterns. For a better deletion of words using Attention mechanism we need to choose a head which **captures pattern useful for classification.**

To do this we are using a Brute force mechanism to seach through all the possible heads. We are deleting TopK words attended by different heads from the sentence and measuring the new classification score. In case of sentiments, removing sentiments related words makes the sentence neutral. The heads are sorted by the amount to which it is able to make the sentences from dev set to Neutral.

In [1]:
import csv
import logging
import os
import random
import sys
import numpy as np
import torch
from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                              TensorDataset)
from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange

from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForSequenceClassification, BertConfig, WEIGHTS_NAME, CONFIG_NAME
#from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear

from bertviz.bertviz import attention, visualization
from bertviz.bertviz.pytorch_pretrained_bert import BertModel, BertTokenizer

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)

In [4]:
logger = logging.getLogger(__name__)
bert_classifier_model_dir = "/zhangpai25/wyc/drg/model_saved/" ## Path of BERT classifier model path
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count()
logger.info("device: {}, n_gpu {}".format(device, n_gpu))

01/12/2023 19:25:19 - INFO - __main__ -   device: cuda, n_gpu 8


In [17]:
## Model for performing Classification
model_cls = BertForSequenceClassification.from_pretrained(bert_classifier_model_dir, num_labels=2)
tokenizer = BertTokenizer.from_pretrained('/zhangpai25/wyc/drg/web_data', do_lower_case=True)
model_cls.to(device)
model_cls.eval()

01/12/2023 19:28:08 - INFO - pytorch_pretrained_bert.modeling -   loading archive file /zhangpai25/wyc/drg/model_saved/
01/12/2023 19:28:08 - INFO - pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "type_vocab_size": 2,
  "vocab_size": 21128
}

01/12/2023 19:28:10 - INFO - bertviz.bertviz.pytorch_pretrained_bert.tokenization -   loading vocabulary file /zhangpai25/wyc/drg/web_data/vocab.txt


BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(21128, 768)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): FusedLayerNorm(torch.Size([768]), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): FusedLayerNorm(torch.Size([768]), e

In [6]:
## Model to get the attention weights of all the heads
model = BertModel.from_pretrained(bert_classifier_model_dir)
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)
model.to(device)
model.eval()

01/12/2023 19:25:42 - INFO - bertviz.bertviz.pytorch_pretrained_bert.modeling -   loading archive file /zhangpai25/wyc/drg/model_saved/
01/12/2023 19:25:42 - INFO - bertviz.bertviz.pytorch_pretrained_bert.modeling -   Model config {
  "attention_probs_dropout_prob": 0.1,
  "directionality": "bidi",
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "max_position_embeddings": 512,
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pooler_fc_size": 768,
  "pooler_num_attention_heads": 12,
  "pooler_num_fc_layers": 3,
  "pooler_size_per_head": 128,
  "pooler_type": "first_token_transform",
  "type_vocab_size": 2,
  "vocab_size": 21128
}

01/12/2023 19:25:47 - ERROR - bertviz.bertviz.pytorch_pretrained_bert.tokenization -   Model name 'bert-base-uncased' was not found in model name list (bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-b

BertModel(
  (embeddings): BertEmbeddings(
    (word_embeddings): Embedding(21128, 768)
    (position_embeddings): Embedding(512, 768)
    (token_type_embeddings): Embedding(2, 768)
    (LayerNorm): FusedLayerNorm(torch.Size([768]), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): BertEncoder(
    (layer): ModuleList(
      (0): BertLayer(
        (attention): BertAttention(
          (self): BertSelfAttention(
            (query): Linear(in_features=768, out_features=768, bias=True)
            (key): Linear(in_features=768, out_features=768, bias=True)
            (value): Linear(in_features=768, out_features=768, bias=True)
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (output): BertSelfOutput(
            (dense): Linear(in_features=768, out_features=768, bias=True)
            (LayerNorm): FusedLayerNorm(torch.Size([768]), eps=1e-12, elementwise_affine=True)
            (dropout): Dropout(p=0.1, inplace=

In [7]:
max_seq_len=70 # Maximum sequence length 
sm = torch.nn.Softmax(dim=-1) ## Softmax over the batch

In [8]:
def run_multiple_examples(input_sentences, bs=32):
    """
    This fucntion returns classification predictions for batch of sentences.
    input_sentences: list of strings
    bs : batch_size : int
    """
    
    ## Prepare data for classification
    ids = []
    segment_ids = []
    input_masks = []
    pred_lt = []
    for sen in input_sentences:
        text_tokens = tokenizer.tokenize(sen)
        tokens = ["[CLS]"] + text_tokens + ["[SEP]"]
        temp_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(temp_ids)
        segment_id = [0] * len(temp_ids)
        padding = [0] * (max_seq_len - len(temp_ids))

        temp_ids += padding
        input_mask += padding
        segment_id += padding
        
        ids.append(temp_ids)
        input_masks.append(input_mask)
        segment_ids.append(segment_id)
    
    ## Convert input lists to Torch Tensors
    ids = torch.tensor(ids)
    segment_ids = torch.tensor(segment_ids)
    input_masks = torch.tensor(input_masks)
    
    steps = len(ids) // bs
    
    for i in range(steps+1):
        if i == steps:
            temp_ids = ids[i * bs : len(ids)]
            temp_segment_ids = segment_ids[i * bs: len(ids)]
            temp_input_masks = input_masks[i * bs: len(ids)]
        else:
            temp_ids = ids[i * bs : i * bs + bs]
            temp_segment_ids = segment_ids[i * bs: i * bs + bs]
            temp_input_masks = input_masks[i * bs: i * bs + bs]
        
        temp_ids = temp_ids.to(device)
        temp_segment_ids = temp_segment_ids.to(device)
        temp_input_masks = temp_input_masks.to(device)
        
        with torch.no_grad():
            preds = sm(model_cls(temp_ids, temp_segment_ids, temp_input_masks))
        pred_lt.extend(preds.tolist())
    
    return pred_lt

In [9]:
def read_file(path,size):
    with open(path) as fp:
        data = fp.read().splitlines()[:size]
    return data

In [10]:
def get_attention_for_batch(input_sentences, bs=32):
    """
    This function calculates attention weights of all the heads and
    returns it along with the encoded sentence for further processing.
    
    input sentence: list of strings
    bs : batch_size
    """
    
    ## Preprocessing for BERT 
    ids = []
    segment_ids = []
    input_masks = []
    pred_lt = []
    ids_for_decoding = []
    for sen in input_sentences:
        text_tokens = tokenizer.tokenize(sen)
        tokens = ["[CLS]"] + text_tokens + ["[SEP]"]
        temp_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [1] * len(temp_ids)
        segment_id = [0] * len(temp_ids)
        padding = [0] * (max_seq_len - len(temp_ids))
        
        ids_for_decoding.append(tokenizer.convert_tokens_to_ids(tokens))
        temp_ids += padding
        input_mask += padding
        segment_id += padding
        
        ids.append(temp_ids)
        input_masks.append(input_mask)
        segment_ids.append(segment_id)
    ## Convert the list of int ids to Torch Tensors
    ids = torch.tensor(ids)
    segment_ids = torch.tensor(segment_ids)
    input_masks = torch.tensor(input_masks)
    
    steps = len(ids) // bs
    
    for i in trange(steps+1):
        if i == steps:
            temp_ids = ids[i * bs : len(ids)]
            temp_segment_ids = segment_ids[i * bs: len(ids)]
            temp_input_masks = input_masks[i * bs: len(ids)]
        else:
            temp_ids = ids[i * bs : i * bs + bs]
            temp_segment_ids = segment_ids[i * bs: i * bs + bs]
            temp_input_masks = input_masks[i * bs: i * bs + bs]
        
        temp_ids = temp_ids.to(device)
        temp_segment_ids = temp_segment_ids.to(device)
        temp_input_masks = temp_input_masks.to(device)
        
        with torch.no_grad():
            _, _, attn = model(temp_ids, temp_segment_ids, temp_input_masks)
        
        # Add all the Attention Weights to CPU memory
        # Attention weights for each layer is stored in a dict 'attn_prob'
        for k in range(12):
            attn[k]['attn_probs'] = attn[k]['attn_probs'].to('cpu')
        
        '''
        attention weights are stored in this way:
        att_lt[layer]['attn_probs']['input_sentence']['head']['length_of_sentence']
        '''
        # Concate Attention weights for all the examples in the list att_lt[layer_no]['attn_probs']
        
        if i == 0:
            att_lt = attn
            heads = len(att_lt)
        else:
            for j in range(heads):
                att_lt[j]['attn_probs'] = torch.cat((att_lt[j]['attn_probs'],attn[j]['attn_probs']),0)
        
    
    return att_lt, ids_for_decoding

In [11]:
def process_sentences(input_sentences, att, decoding_ids, threshold=0.25):
    """
    This function processes each input sentence by removing the top tokens defined threshold value.
    Each sentence is processed for each head.
    
    input_ids: list of strings
    decoding_ids: indexed input_sentnces thus len(input_sentences) == len(decoding_ids)
    threshold: Percentage of the top indexes to be removed
    """
    # List of None of num_of_layers * num_of_heads to save the results of each head for input_sentences
    
    lt = [None for x in range(len(att) * len(att[0]['attn_probs'][0]))]
    #print(len(lt))
    
    inx = 0
    for i in trange(len(att)): #  For all the layers
        for j in range(len(att[i]['attn_probs'][0])): # For all the heads in the ith Layer
            processed_sen = [None for q in decoding_ids] # List of len(decoding_ids)
            for k in range(len(input_sentences)): # For all the senteces 
                _, topi = att[i]['attn_probs'][k][j][0].topk(len(decoding_ids[k])) # Get top attended ids
                topi = topi.tolist()
                topi = topi[:int(len(topi) * threshold)] 
                ## Decode the sentece after removing the topk indexes
                final_indexes = []
                count = 0
                count1 = 0
                tokens = ["[CLS]"] + tokenizer.tokenize(input_sentences[k]) + ["[SEP]"]
                while count < len(decoding_ids[k]):
                    if count in topi: # Remove index if present in topk
                        while (count + count1 + 1) < len(decoding_ids[k]):
                            if "##" in tokens[count + count1 + 1]:
                                count1 += 1
                            else:
                                break
                        count += count1
                        count1 = 0
                    else: # Else add to the decoded sentence
                        final_indexes.append(decoding_ids[k][count])
                    count += 1
                tmp = tokenizer.convert_ids_to_tokens(final_indexes) # Convert ids to token
                # Convert toknes to sentence
                processed_sen[k] = " ".join(tmp).replace(" ##", "").replace("[CLS]","").replace("[SEP]","").strip()
            lt[inx] = processed_sen # Store sentences for inxth head
            inx += 1
    
    return lt

In [12]:
def get_block_head(processed_sentence_list, lmbd = 0.1):
    """
    This function calculate classification scores for sentences generated by each head
    and sort them from best to worst.
    score = min(pred) + lmbd / max(pred) + lmbd, lmbd is smoothing param
    pred is list of probability score for each class, for best case pred = [0.5, 0.5] ==> score = 1
    
    it returns sorted list of (Layer, Head, Score)
    """
    scores = {}
    #scores_1 = {}
    for i in trange(len(processed_sentence_list)): # sentences by each head
        pred = np.array(run_multiple_examples(processed_sentence_list[i]))
        scores[i] = np.mean([(min(x[0], x[1])+lmbd)/(max(x[0], x[1])+lmbd) for x in pred])
        #scores_1[i] = np.mean([abs(max(x[0],x[1]) - min(x[0],x[1])) for x in pred])
    temp = sorted(scores.items(), key=lambda kv: kv[1], reverse=True)
    #temp1 = sorted(scores_1.items(), key=lambda kv: kv[1], reverse=False)
    score_lt = [(x // 12, x - (12 * (x // 12)),y) for x,y in temp]
    #score1_lt = [(x // 12, x - (12 * (x // 12)),y) for x,y in temp1]
    return score_lt  #score1_lt

In [13]:
pos_examples_file = "/zhangpai25/wyc/drg/drg_data/hlm/reference_0.txt"
neg_examples_file = "/zhangpai25/wyc/drg/drg_data/hlm/reference_1.txt"

In [14]:
'''
100 examples from each class worked good, the bottlenack is the run_multiple_examples() function,
with higher memory (either with cpu of gpu) one can reduce the processing time by incresing batch_size.
With batch_size of 32 it takes around 24 mins for 100 example on cpu.
'''
pos_data = read_file(pos_examples_file,100)
neg_data = read_file(neg_examples_file,100)
data = pos_data + neg_data

In [15]:
print(len(pos_data), len(neg_data), len(data))

100 100 200


In [18]:
att, decoding_ids = get_attention_for_batch(data)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 7/7 [00:00<00:00, 23.22it/s]


In [19]:
sen_list = process_sentences(data, att, decoding_ids)

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:02<00:00,  4.83it/s]


In [20]:
scores = get_block_head(sen_list)

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 144/144 [00:12<00:00, 11.22it/s]


In [21]:
scores

[(4, 11, 0.13576604941764295),
 (4, 4, 0.1342009114774394),
 (5, 6, 0.1322614181549304),
 (4, 7, 0.1303878578495984),
 (5, 3, 0.12994912915742785),
 (3, 7, 0.12970192821305265),
 (4, 0, 0.1293392241836173),
 (4, 3, 0.12899575082299147),
 (4, 10, 0.12871369201218674),
 (7, 3, 0.1280811496248101),
 (4, 6, 0.12799601352967502),
 (5, 4, 0.12758984178543456),
 (5, 11, 0.12757678446367696),
 (5, 7, 0.12719810859956454),
 (3, 1, 0.12716852106071155),
 (8, 9, 0.12672235722904823),
 (6, 8, 0.1264020630697278),
 (5, 9, 0.1261641761470317),
 (3, 10, 0.12585808807884907),
 (3, 8, 0.12569181593631668),
 (4, 2, 0.12562312295554184),
 (6, 6, 0.1255452621669313),
 (3, 6, 0.12546960379367877),
 (8, 7, 0.12515905359474128),
 (9, 11, 0.12290716410798502),
 (7, 7, 0.12230272382909614),
 (11, 6, 0.12230167864477255),
 (8, 2, 0.12225773257389459),
 (5, 1, 0.12204067126011722),
 (5, 5, 0.12197730297828423),
 (3, 4, 0.12172767287582419),
 (4, 5, 0.12150371487623042),
 (7, 4, 0.12145217192819689),
 (10, 8, 0.1