## Import libraries

In [1]:
# Change the working directory to the parent folder (presumably the project root,
# so that the project-local `utils` package imported in the next cell resolves).
# NOTE(review): `%cd ..` is not idempotent — re-running this cell moves up one more level.
%cd ..

import os
# Expose only GPU index 5 to CUDA. This is set before `torch` is imported in the
# next cell so the restriction is already in place when CUDA is initialised.
# NOTE(review): the index is hard-coded for one machine — adjust per host.
os.environ['CUDA_VISIBLE_DEVICES']='5'

/home/shapkin/Projects/effective-inference


In [2]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, GPT2LMHeadModel
from utils.attention_patterns.bert_modules import BertWrapper, WindowBert2Attention
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
from sklearn.metrics import classification_report
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
import torch
#from progressbar import progressbar
from tqdm.auto import tqdm
from utils.prepare_dataset import load_datasets, cut_datasets
from collections import defaultdict

## Define hyperparameters

In [3]:
# Define datasets
# Each mapping goes from a task name to the dataset columns that hold its input
# text (one column for single-sentence tasks, two for sentence-pair tasks).
#['mrpc', 'sst2', 'cola', 'rte', 'qnli']
glue_classification = {'mrpc': ['sentence1', 'sentence2'], 'sst2':  ['sentence'], 'qnli' : ['question', 'sentence']}
superglue_classification = {'wic': ['sentence1', 'sentence2']}
all_classification = {'glue': glue_classification, 'superglue': superglue_classification}

model_name = 'bert-base-uncased'

# NOTE(review): max_length=1024 looks larger than the position limit of
# bert-base-uncased (believed to be 512) — confirm this value is intended.
tokenizer = AutoTokenizer.from_pretrained(model_name, max_length=1024)
model = AutoModel.from_pretrained(model_name)
model.eval()

# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Compare the device *type*: a torch.device never compares equal to a plain
# string, so the previous `device == 'cpu'` check could not trigger the MPS fallback.
if device.type == 'cpu':
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")


# Wrap the pretrained encoder with the project's attention wrapper and move it to the device.
# NOTE(review): eval() was called before wrapping; if BertWrapper creates new
# sub-modules they may start in training mode — confirm in utils.attention_patterns.
model = BertWrapper(model, WindowBert2Attention).to(device)
DEBUG_FLAG = False
# In debug mode the datasets are cut to 100 examples; None means "use everything".
CUT_SIZE = 100 if DEBUG_FLAG else None

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


## Load datasets

In [4]:
# Load every configured task of each benchmark. CUT_SIZE is forwarded to
# load_datasets (presumably a size cap used for debugging — confirm in utils.prepare_dataset).
glue_datasets = load_datasets('glue', list(glue_classification), CUT_SIZE)
superglue_datasets = load_datasets('super_glue', list(superglue_classification), CUT_SIZE)

# Keyed like `all_classification` ('glue' / 'superglue') so both dicts can be
# indexed with the same benchmark name.
all_datasets = {'glue': glue_datasets, 'superglue': superglue_datasets}

Found cached dataset glue (/home/shapkin/.cache/huggingface/datasets/glue/mrpc/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset glue (/home/shapkin/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset glue (/home/shapkin/.cache/huggingface/datasets/glue/qnli/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad)


  0%|          | 0/3 [00:00<?, ?it/s]

Found cached dataset super_glue (/home/shapkin/.cache/huggingface/datasets/super_glue/wic/1.0.3/bb9675f958ebfee0d5d6dc5476fafe38c79123727a7258d515c450873dbdbbed)


  0%|          | 0/3 [00:00<?, ?it/s]

In [5]:
# Display the task names configured for each benchmark.
list(glue_classification), list(superglue_classification)

(['mrpc', 'sst2', 'qnli'], ['wic'])

In [6]:
def tqdm_pbar(x, y):
    """Wrap the iterable ``x`` in a tqdm progress bar described by ``y``.

    Args:
        x: Iterable to iterate over.
        y: Label for the bar; it is formatted into the ``desc`` string.

    Returns:
        The tqdm iterator wrapping ``x``.
    """
    # Only pass a total when the iterable can report its length, so that
    # generators and other unsized iterables do not raise a TypeError.
    total = len(x) if hasattr(x, '__len__') else None
    return tqdm(x, leave=True, position=0, total=total, desc=f'{y}')
def get_cls_embeddings_for_dataset(dataset_name, dataset, 
                                   feature_names, model, tokenizer, 
                                   pbar_func=tqdm_pbar, device=device, CUT_SIZE=CUT_SIZE):
    """Collect the first-position ([CLS]) embedding of every example, per split.

    Args:
        dataset_name: Name used only in the progress-bar description.
        dataset: Mapping from split name to an iterable of examples; each
            example is indexed by the entries of ``feature_names``.
        feature_names: Keys of the text fields passed, in order, to
            ``tokenizer.encode``.
        model: Callable applied to the encoded ids; its result must expose
            ``last_hidden_state``.
        tokenizer: Object providing ``encode``.
        pbar_func: Optional ``(iterable, description)`` wrapper used to show
            progress; ``None`` disables it.
        device: Device the encoded ids are moved to before the forward pass.
        CUT_SIZE: Accepted for interface compatibility; not used in the body.

    Returns:
        A ``defaultdict(list)`` mapping each split name to the list of
        per-example embedding tensors, in iteration order.
    """
    embeddings_by_split = defaultdict(list)

    for split_name, split_rows in dataset.items():
        if pbar_func is None:
            rows = split_rows
        else:
            rows = pbar_func(split_rows, f"{split_name} {dataset_name}")

        for row in rows:
            # Gather the text fields of this example in the configured order.
            texts = [row[name] for name in feature_names]

            # NOTE(review): `encode` yields only the input ids, so the model is
            # called without segment ids or an attention mask — confirm this is
            # acceptable for the sentence-pair tasks.
            input_ids = tokenizer.encode(*texts, truncation=True, return_tensors='pt')
            input_ids = input_ids.to(device)

            # Inference only: no gradients are needed.
            with torch.no_grad():
                model_output = model(input_ids)

            # Keep the hidden state at sequence position 0 (the [CLS] token).
            first_token_state = model_output.last_hidden_state[:, 0, :]
            embeddings_by_split[split_name].append(first_token_state)

    return embeddings_by_split

In [7]:
def train_linear(X_train, y_train):
    """Fit a logistic-regression probe on the given features and labels.

    Args:
        X_train: Feature matrix passed to ``LogisticRegression.fit``.
        y_train: Target labels aligned with the rows of ``X_train``.

    Returns:
        The fitted ``LogisticRegression`` (lbfgs solver, at most 3000 iterations).
    """
    probe = LogisticRegression(solver='lbfgs', max_iter=3000)
    # `fit` returns the estimator itself, so the fitted probe is returned directly.
    return probe.fit(X_train, y_train)

def evaluate_classifier(classifier, X, y=None):
    """Return the classifier's predictions for ``X``.

    Args:
        classifier: Object exposing a ``predict`` method.
        X: Inputs forwarded to ``classifier.predict``.
        y: Unused; kept so callers may pass labels without error.

    Returns:
        Whatever ``classifier.predict(X)`` returns.
    """
    return classifier.predict(X)

def get_metrics_report(y_true, y_pred):
    """Print scikit-learn's classification report for the given predictions.

    Args:
        y_true: Ground-truth labels.
        y_pred: Predicted labels, aligned with ``y_true``.

    Returns:
        None; the report is written to standard output.
    """
    report = classification_report(y_true, y_pred)
    print(report)

In [8]:
# Linear probe for every configured task: embed each split, fit a classifier on
# the train embeddings and report metrics on the validation split.
for dn, datasets in all_datasets.items():
    for dataset_name, dataset in datasets.items():
        print(f"{dn.upper()} / {dataset_name}\n")

        # Text columns of this task, looked up under the same benchmark key.
        feature_names = all_classification[dn][dataset_name]
        dataset_embeddings = get_cls_embeddings_for_dataset(
            dataset_name, dataset, feature_names, model, tokenizer
        )

        # Concatenate the per-example embeddings of each split along dim 0 and
        # move them to the CPU before handing them to scikit-learn.
        train_dataset_embeddings = torch.cat(dataset_embeddings['train'], dim=0).cpu()
        valid_dataset_embeddings = torch.cat(dataset_embeddings['validation'], dim=0).cpu()
        # NOTE(review): the test-split matrix is built but not used in this cell.
        test_dataset_embeddings = torch.cat(dataset_embeddings['test'], dim=0).cpu()

        train_labels = [el['label'] for el in dataset['train']]
        classif = train_linear(train_dataset_embeddings, train_labels)
        valid_preds = evaluate_classifier(classif, valid_dataset_embeddings)

        print('Validation evaluation:\n')
        valid_labels = [el['label'] for el in dataset['validation']]
        get_metrics_report(valid_labels, valid_preds)

GLUE / mrpc



train mrpc:   0%|          | 0/3668 [00:00<?, ?it/s]

validation mrpc:   0%|          | 0/408 [00:00<?, ?it/s]

test mrpc:   0%|          | 0/1725 [00:00<?, ?it/s]

Validation evaluation:

              precision    recall  f1-score   support

           0       0.41      0.32      0.36       129
           1       0.72      0.79      0.75       279

    accuracy                           0.64       408
   macro avg       0.56      0.55      0.56       408
weighted avg       0.62      0.64      0.63       408

GLUE / sst2



train sst2:   0%|          | 0/67349 [00:00<?, ?it/s]

validation sst2:   0%|          | 0/872 [00:00<?, ?it/s]

test sst2:   0%|          | 0/1821 [00:00<?, ?it/s]

Validation evaluation:

              precision    recall  f1-score   support

           0       0.75      0.72      0.73       428
           1       0.74      0.77      0.75       444

    accuracy                           0.74       872
   macro avg       0.74      0.74      0.74       872
weighted avg       0.74      0.74      0.74       872

GLUE / qnli



train qnli:   0%|          | 0/104743 [00:00<?, ?it/s]