In [1]:
from collections import defaultdict
import itertools
import re
import subprocess

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm
from transformers import BertTokenizer, BertModel, BertForSequenceClassification

from tda4atd.stats_count import *
from tda4atd.grab_weights import grab_attention_weights, text_preprocessing

In [2]:
import warnings

# Globally suppress every Python warning for the rest of the kernel session.
# NOTE(review): this also hides deprecation notices from the libraries used
# below, so re-enable warnings when upgrading dependencies.
warnings.filterwarnings('ignore')

In [3]:
!env | grep CUDA_VISIBLE

In [4]:
# Seeds NumPy's legacy global RNG only; no other random generator is seeded
# in the visible cells.
np.random.seed(42) # For reproducibility.

In [5]:
# Notebook configuration: everything a reader might want to tune lives here.

max_tokens_amount  = 512 # The number of tokens to which the tokenized text is truncated / padded.
stats_cap          = 500 # Max value that a feature can take. Does NOT apply to Betti numbers.

layers_of_interest = list(range(12))  # Layers for which attention matrices and the features on them are
                                      # calculated. To calculate features on all layers, leave it as
                                      # list(range(12)).
stats_name = "s_e_v_c_b0b1" # Underscore-separated set of topological features to compute (see explanation below)

thresholds_array = [0.025, 0.05, 0.1, 0.25, 0.5, 0.75] # The set of thresholds
thrs = len(thresholds_array)                           # ("t" in the paper)

model_path = "bert-base-uncased"
tokenizer_path = "bert-base-uncased"

# You can use either the standard or a fine-tuned BERT. To use a BERT fine-tuned on your current task, save the
# model and the tokenizer into the same directory with tokenizer.save_pretrained(output_dir) and
# bert_classifier.save_pretrained(output_dir), then insert the path to that directory here.

In [6]:
from datasets import load_dataset

# Load the IMDB dataset via Hugging Face `datasets` and collect the raw text of
# every training example into a 1-D NumPy array of strings.
imdb = load_dataset("imdb")
sentences = np.array([elem["text"] for elem in imdb["train"]])
print(sentences.shape)

# Optional word-count diagnostics, left disabled.
# NOTE(review): the pattern '\w' replaces every *word* character with a space,
# so these counts are not word counts; '\W' (non-word characters) was probably
# intended - confirm before re-enabling.
# print("Average amount of words in example:", \
#       np.mean(list(map(len, map(lambda x: re.sub('\w', ' ', x).split(" "), sentences)))))
# print("Max. amount of words in example:", \
#       np.max(list(map(len, map(lambda x: re.sub('\w', ' ', x).split(" "), sentences)))))
# print("Min. amount of words in example:", \
#       np.min(list(map(len, map(lambda x: re.sub('\w', ' ', x).split(" "), sentences)))))

Found cached dataset imdb (/home/sha43/.cache/huggingface/datasets/imdb/plain_text/1.0.0/d613c88cf8fa3bab83b4ded3713f1f74830d1100e171db75bbddb80b3345c9c0)


  0%|          | 0/3 [00:00<?, ?it/s]

(25000,)


In [7]:
# Keep only the first 2048 texts and hold them in a single-column DataFrame.
data = pd.DataFrame({"sentence": sentences[:2048]})
# From here on `sentences` is rebound to that DataFrame column (a pandas
# Series) instead of the full array built in the previous cell.
sentences = data['sentence']


In [8]:
def get_token_length(batch_texts):
    """Return the number of non-padding tokens of every text in ``batch_texts``.

    Each text is tokenized with the module-level ``tokenizer`` (special tokens
    included), truncated / padded to ``MAX_LEN`` tokens, and its length is taken
    as the index of the first padding token. A text with no padding token at
    all is reported as ``MAX_LEN``.

    Relies on the globals ``tokenizer`` and ``MAX_LEN`` being defined before the
    function is called.

    Args:
        batch_texts: sequence of strings to tokenize.

    Returns:
        list with one token count per input text, in input order.
    """
    inputs = tokenizer.batch_encode_plus(batch_texts,
       return_tensors='pt',
       add_special_tokens=True,
       max_length=MAX_LEN,             # Max length to truncate/pad
       padding='max_length',           # Pad every text to max_length (replaces deprecated pad_to_max_length)
       truncation=True
    )
    inputs = inputs['input_ids'].numpy()
    pad_token_id = tokenizer.pad_token_id  # hoisted: constant across rows
    n_tokens = []
    print("Counting lens")
    for i in tqdm(range(inputs.shape[0])):
        # Positions of padding tokens in this row; the first one marks the end of the text.
        ids = np.argwhere(inputs[i] == pad_token_id)
        if not len(ids):
            n_tokens.append(MAX_LEN)
        else:
            n_tokens.append(ids[0, 0])
    return n_tokens

In [9]:
# Globals read by get_token_length: the padding / truncation length and the
# tokenizer loaded from `tokenizer_path` (with lower-casing enabled).
MAX_LEN = max_tokens_amount
tokenizer = BertTokenizer.from_pretrained(tokenizer_path, do_lower_case=True)

In [10]:
# Token count of every text (index of its first padding token, or MAX_LEN when
# the text fills the whole window), stored as a new DataFrame column.
data['tokenizer_length'] = get_token_length(data['sentence'].values)
# The same lengths as a NumPy array; sliced per dump in the feature-extraction loop.
ntokens_array = data['tokenizer_length'].values

Counting lens


  0%|          | 0/2048 [00:00<?, ?it/s]

In [11]:
from math import ceil

batch_size = 2 # batch size
DUMP_SIZE = 64 # number of batches to be dumped

# Cut the texts into consecutive batches of exactly `batch_size` items (only the
# last batch may be shorter). np.array_split would instead produce near-equal
# batches (e.g. 3, 2, 2 for 7 items with batch_size 3), which breaks code that
# assumes a fixed batch size, and it raises on empty input.
sentence_values = data['sentence'].values
batched_sentences = [
    sentence_values[start:start + batch_size]
    for start in range(0, len(sentence_values), batch_size)
]
number_of_batches = len(batched_sentences)
adj_matricies = []
adj_filenames = []
assert number_of_batches == ceil(len(sentence_values) / batch_size) # sanity check


In [12]:
import torch

# Run on the GPU when one is available and fall back to the CPU otherwise, so
# the notebook does not crash on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# `tokenizer` and `MAX_LEN` are already defined in cell In [9]; they are not
# re-created here to avoid a redundant tokenizer reload.

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.bias', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

In [13]:
def function_for_v(list_of_v_degrees_of_graph):
    """Compute the "f(v)" feature from the vertex degrees of a graph.

    The current definition adds up ``sqrt(d * d)`` over all degrees ``d``.
    Any other function of the vertex-degree array can be substituted here.
    """
    total = 0
    for degree in list_of_v_degrees_of_graph:
        total = total + np.sqrt(degree * degree)
    return total

def split_matricies_and_lengths(adj_matricies, ntokens_array, num_of_workers):
    """Split matrices and their token counts into aligned per-worker chunks.

    Both inputs are cut along their first axis into ``num_of_workers`` pieces
    with ``np.array_split``. The chunk sizes of both splits are printed, and an
    ``AssertionError`` ("Split is not valid!") is raised if they disagree.

    Returns:
        a ``zip`` iterator of ``(matrix_chunk, ntokens_chunk)`` pairs.
    """
    matrix_chunks = np.array_split(adj_matricies, num_of_workers)
    length_chunks = np.array_split(ntokens_array, num_of_workers)

    matrix_sizes = list(map(len, matrix_chunks))
    length_sizes = list(map(len, length_chunks))
    print(matrix_sizes)
    print(length_sizes)

    # Both lists hold exactly num_of_workers entries, so comparing them
    # element-wise is the same as comparing the lists.
    assert matrix_sizes == length_sizes, "Split is not valid!"
    return zip(matrix_chunks, length_chunks)

In [14]:
import os
from multiprocessing import Pool
from tqdm import tqdm  # already imported in the first cell; repeated here

# Number of worker processes; the feature-extraction loop also uses it as the
# number of chunks every dump is split into.
num_of_workers = 32
# Process pool shared by every model in the loop below.
# NOTE(review): no pool.close() / pool.join() is visible in this part of the
# notebook - confirm the pool is released somewhere after its last use.
pool = Pool(num_of_workers)

In [15]:
# Models to analyse: the stock checkpoint followed by the fine-tuning
# checkpoints saved every 80 steps (80, 160, ..., 39920).
models = ["bert-base-uncased"] + [
    f"./finetuned-scrumbled-wikitext2-3e-4/checkpoint-{i}" for i in range(80, 40000, 80)
]

stats_to_count = stats_name.split("_")  # loop-invariant list of feature names
os.makedirs("attention_features", exist_ok=True)  # np.save below needs the directory to exist

for model_name in models:
    model = BertForSequenceClassification.from_pretrained(model_name, output_attentions=True)
    model = model.to(device)

    stats_tuple_lists_array = []
    adj_matricies = []
    n_processed = 0  # number of sentences whose features have already been computed
    for i in tqdm(range(number_of_batches), desc="Weights calc"):
        attention_w = grab_attention_weights(model, tokenizer, batched_sentences[i], max_tokens_amount, device)
        # presumably layer X sample X head X n_token X n_token, given the axis=1
        # concatenation and the axis swap below - confirm against grab_attention_weights
        adj_matricies.append(attention_w)
        if (i+1) % DUMP_SIZE == 0: # dumping
            adj_matricies = np.concatenate(adj_matricies, axis=1)
            adj_matricies = np.swapaxes(adj_matricies, axis1=0, axis2=1) # sample X layer X head X n_token X n_token
            # Take the token counts that belong to exactly the sentences in this dump.
            # A running offset stays correct for any DUMP_SIZE and for batches that
            # are not all `batch_size` long, unlike an index derived from `i`.
            n_dumped = adj_matricies.shape[0]
            ntokens = ntokens_array[n_processed : n_processed + n_dumped]
            n_processed += n_dumped
            splitted = split_matricies_and_lengths(adj_matricies, ntokens, num_of_workers)
            args = [
                (matrices_chunk, thresholds_array, ntokens_chunk, stats_to_count, stats_cap)
                for matrices_chunk, ntokens_chunk in splitted
            ]
            stats_tuple_lists_array_part = pool.starmap(
                count_top_stats, args
            )
            # assumes axis 3 of count_top_stats' output indexes sentences - TODO confirm
            stats_tuple_lists_array.append(np.concatenate(stats_tuple_lists_array_part, axis=3))
            adj_matricies = []
    # Batches collected after the last full dump are dropped; report how many
    # sentences (not batches) that is.
    n_ignored = sum(weights.shape[1] for weights in adj_matricies)
    print(f"Ignoring {n_ignored} sentences")
    if not stats_tuple_lists_array:
        raise ValueError(
            f"No features were computed for {model_name}: number_of_batches "
            f"({number_of_batches}) is smaller than DUMP_SIZE ({DUMP_SIZE})"
        )
    stats_tuple_lists_array = np.concatenate(stats_tuple_lists_array, axis=3)
    with open(f"attention_features/{model_name.replace('./', '').replace('/', '_')}", "wb") as fout:
        np.save(fout, stats_tuple_lists_array)

Weights calc:   5%|▌         | 7/128 [00:17<04:36,  2.29s/it]

(128,)
(128, 12, 12, 512, 512)
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]


Weights calc:  12%|█▏        | 15/128 [05:08<16:43,  8.88s/it]  

(128,)
(128, 12, 12, 512, 512)
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]


Weights calc:  18%|█▊        | 23/128 [09:44<15:49,  9.05s/it]  

(128,)
(128, 12, 12, 512, 512)
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]


Weights calc:  24%|██▍       | 31/128 [14:22<14:44,  9.12s/it]  

(128,)
(128, 12, 12, 512, 512)
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]


Weights calc:  30%|███       | 39/128 [19:42<15:04, 10.16s/it]  

(128,)
(128, 12, 12, 512, 512)
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]


Weights calc:  37%|███▋      | 47/128 [24:44<13:12,  9.79s/it]  

(128,)
(128, 12, 12, 512, 512)
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]


Weights calc:  43%|████▎     | 55/128 [30:20<12:38, 10.38s/it]   

(128,)
(128, 12, 12, 512, 512)
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]


Weights calc:  49%|████▉     | 63/128 [35:22<10:36,  9.80s/it]  

(128,)
(128, 12, 12, 512, 512)
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]


Weights calc:  55%|█████▌    | 71/128 [39:53<08:22,  8.82s/it]  

(128,)
(128, 12, 12, 512, 512)
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]


Weights calc:  62%|██████▏   | 79/128 [44:43<07:33,  9.26s/it]  

(128,)
(128, 12, 12, 512, 512)
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]


Weights calc:  68%|██████▊   | 87/128 [49:51<06:44,  9.86s/it]  

(128,)
(128, 12, 12, 512, 512)
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]


Weights calc:  74%|███████▍  | 95/128 [54:49<05:19,  9.67s/it]  

(128,)
(128, 12, 12, 512, 512)
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]


Weights calc:  80%|████████  | 103/128 [1:00:11<04:11, 10.07s/it]

(128,)
(128, 12, 12, 512, 512)
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]


Weights calc:  87%|████████▋ | 111/128 [1:04:49<02:32,  8.99s/it]

(128,)
(128, 12, 12, 512, 512)
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]


Weights calc:  93%|█████████▎| 119/128 [1:09:44<01:24,  9.41s/it]

(128,)
(128, 12, 12, 512, 512)
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]


Weights calc:  99%|█████████▉| 127/128 [1:15:07<00:10, 10.25s/it]

(128,)
(128, 12, 12, 512, 512)
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]
[4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4]


Weights calc: 100%|██████████| 128/128 [1:20:18<00:00, 37.64s/it] 

Ignoring 0 sentences



