In [7]:
%cd ..

/home/sasha/effective-inference


## Import libs

In [8]:
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModel, GPT2LMHeadModel
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
from sklearn.metrics import classification_report
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
import torch
from tqdm.auto import tqdm
from utils.prepare_dataset import load_datasets, cut_datasets
from collections import defaultdict

## Define hyperparams

In [28]:
# Define datasets
#['mrpc', 'sst2', 'cola', 'rte', 'qnli']
glue_classification = {'mrpc': ['sentence1', 'sentence2'], 'sst2':  ['sentence'], 'qnli' : ['question', 'sentence']}
superglue_classification = {'wic': ['sentence1', 'sentence2']}
all_classification = {'glue': glue_classification, 'superglue': superglue_classification}

model_name = 'bert-base-uncased'

tokenizer = AutoTokenizer.from_pretrained(model_name, max_length=1024)
model = AutoModel.from_pretrained(model_name)
model.eval()

# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device == 'cpu':
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

model.to(device)

DEBUG_FLAG = False
CUT_SIZE = None if not DEBUG_FLAG else 100

## Load datasets

In [29]:
glue_datasets = load_datasets('glue', list(glue_classification), CUT_SIZE)
superglue_datasets = load_datasets('super_glue', list(superglue_classification), CUT_SIZE)

all_datasets = {'glue': glue_datasets, 'superglue': superglue_datasets}

In [30]:
list(glue_classification), list(superglue_classification)

(['mrpc', 'sst2', 'qnli'], ['wic'])

In [34]:
tqdm_pbar = lambda x, y: tqdm(x, leave=True, position=0, total=len(x), desc=f'{y}')
def get_cls_embeddings_for_dataset(dataset_name, dataset, 
                                   feature_names, model, tokenizer, 
                                   pbar_func=tqdm_pbar, device=device, CUT_SIZE=CUT_SIZE):
    collected_embeddings = defaultdict(list)
    
    for split, data in dataset.items():
        
        pbar = pbar_func(data, f"{split} {dataset_name}") if pbar_func is not None else data
        for example in pbar:
            # Encode the input sentences
            encoded_inputs = tokenizer.encode(*list(map(lambda x: example[x] , feature_names)), 
                                              truncation=True, 
                                              return_tensors='pt').to(device)

            encoded_inputs = encoded_inputs.to(device)

            # Forward pass through the model
            with torch.no_grad():
                outputs = model(encoded_inputs)

            # Get the embedding of the [CLS] token
            cls_embedding = outputs.last_hidden_state[:, 0, :]

            # Append the [CLS] embedding to the list
            collected_embeddings[split].append(cls_embedding)
         
    return collected_embeddings

In [32]:
def train_linear(X_train, y_train):
    classifier = LogisticRegression(solver='lbfgs', max_iter=3000)
    classifier.fit(X_train, y_train)
    return classifier

def evaluate_classifier(classifier, X, y=None):
    predictions = classifier.predict(X)
    return predictions

def get_metrics_report(y_true, y_pred):
    #accuracy = accuracy_score(y_true, y_pred)
    #f1 = f1_score(y_true, y_pred,  average='weighted')
    #precision = precision_score(y_true, y_pred,  average='weighted')
    #recall = recall_score(y_true, y_pred,  average='weighted', zero_division='warn')
    #roc_auc = roc_auc_score(y_true, y_pred)
    print(classification_report(y_true, y_pred))

In [35]:
for dn, datasets in all_datasets.items():
    for dataset_name, dataset in datasets.items():
        print(f"{dn.upper()} / {dataset_name}\n")
        dataset_embeddings = get_cls_embeddings_for_dataset(
            dataset_name,
            dataset, 
            all_classification[dn][dataset_name], 
            model, 
            tokenizer)
        
        train_dataset_embeddings = torch.cat(dataset_embeddings['train'], dim=0).cpu()
        valid_dataset_embeddings = torch.cat(dataset_embeddings['validation'], dim=0).cpu()
        test_dataset_embeddings = torch.cat(dataset_embeddings['test'], dim=0).cpu()
         
        classif = train_linear(train_dataset_embeddings, [el['label'] for el in dataset['train']])
        valid_preds = evaluate_classifier(classif, valid_dataset_embeddings)
        print('Validation evaluation:\n')
        get_metrics_report([el['label'] for el in dataset['validation']], valid_preds)
        # print(train_dataset_embeddings.shape)

GLUE / mrpc



train mrpc:   0%|          | 0/3668 [00:00<?, ?it/s]

validation mrpc:   0%|          | 0/408 [00:00<?, ?it/s]

test mrpc:   0%|          | 0/1725 [00:00<?, ?it/s]

Validation evaluation:

              precision    recall  f1-score   support

           0       0.52      0.46      0.49       129
           1       0.76      0.80      0.78       279

    accuracy                           0.69       408
   macro avg       0.64      0.63      0.63       408
weighted avg       0.68      0.69      0.69       408

GLUE / sst2



train sst2:   0%|          | 0/67349 [00:00<?, ?it/s]

validation sst2:   0%|          | 0/872 [00:00<?, ?it/s]

test sst2:   0%|          | 0/1821 [00:00<?, ?it/s]

Validation evaluation:

              precision    recall  f1-score   support

           0       0.87      0.84      0.86       428
           1       0.85      0.88      0.87       444

    accuracy                           0.86       872
   macro avg       0.86      0.86      0.86       872
weighted avg       0.86      0.86      0.86       872

GLUE / qnli



train qnli:   0%|          | 0/104743 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 14.58 GiB total capacity; 13.57 GiB already allocated; 1.56 MiB free; 14.45 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF