In [1]:
import os

# Move to the directory that holds `config.yaml` (it is opened below through a
# relative path, and `utils.*` is imported relative to it). Guarded so that
# re-running this cell does not keep climbing two more levels each time.
if not os.path.exists('config.yaml'):
    os.chdir(os.path.join('..', '..'))
print(os.getcwd())

/home/shapkin/effective-inference


## Import libs

In [2]:
import os

import yaml
import h5py
import numpy as np
import torch
import json
import seaborn as sns
import matplotlib.pyplot as plt

from utils.dataset_cache import cache_embeddings, get_dataset_for_regression, build_dataset_from_cached, load_cached_dataset
from utils.dataset_cache import build_dict_dataset_from_cached
from utils.prepare_dataset import load_datasets, cut_datasets
from utils.config import ConfigWrapper
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel
from typing import Tuple, List, Dict, Optional, Union
from numpy.random import shuffle
from sklearn.metrics import r2_score

from IPython.display import clear_output
from collections import defaultdict

In [3]:
from utils.attentions.bert.linear import BertWrapperLin, LinearClassifierBertAttention, LinearAttention
from utils.dataset_utils import get_dict_batch, prepare_batches
from utils.train_linear_utils import train_epoch, eval_epoch, plot_history

## Project configuration

In [4]:
config_path = 'config.yaml'

# Parse the project YAML and wrap it for attribute access
# (config.model.*, config.data.*, config.general.* are used below).
with open(config_path) as config_file:
    raw_config = yaml.load(config_file, Loader=yaml.FullLoader)
config = ConfigWrapper(raw_config)

In [5]:
# Reference (unmodified) encoder, placed on the configured device.
initial_model = AutoModel.from_pretrained(config.model.model_name)
initial_model = initial_model.to(config.general.device)

# NOTE(review): `max_length` is passed to `from_pretrained`; as far as I know the
# tokenizer's truncation limit is controlled by `model_max_length`, so the
# `truncation=True` calls below may cap at the model default instead of
# `config.general.max_len` — confirm which limit is intended.
tokenizer = AutoTokenizer.from_pretrained(
    config.model.model_name,
    max_length=config.general.max_len,
)

KeyboardInterrupt: 

In [None]:
# Evaluation datasets named in the config (`cut_size` presumably limits split sizes — confirm in utils.prepare_dataset).
eval_datasets = load_datasets(
    config.data.eval_datasets,
    config.data.cut_size,
)
eval_datasets

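## MRPC — per-layer update norms

For every train-split token, the L2 norm of the change in its hidden state across each encoder layer, $\lVert h_{\ell+1} - h_{\ell} \rVert_2$; the mean and median are reported per layer.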

In [27]:
dataset_idx, dataset_name, dataset = 0, 'mrpc', eval_datasets['mrpc']
# Text columns for this dataset; field2 == '' means single-text input.
# Loop-invariant, so looked up once instead of per example.
field1, field2 = config.data.eval_datasets_fields[dataset_idx]

# layer -> list of per-token L2 norms ||h[layer + 1] - h[layer]|| over the train split.
layer2norm = {}
for ex in tqdm(dataset['train'], total=len(dataset['train']), position=0, leave=True):
    if field2 != '':
        encoded_inputs = tokenizer.encode(ex[field1], ex[field2], truncation=True, return_tensors='pt')
    else:
        encoded_inputs = tokenizer.encode(ex[field1], truncation=True, return_tensors='pt')
    encoded_inputs = encoded_inputs.to(config.general.device)

    with torch.no_grad():
        # Only hidden states are used here, so attention maps are not requested.
        hidden_states = initial_model(encoded_inputs, output_hidden_states=True).hidden_states

    for layer in range(len(hidden_states) - 1):
        # One vectorised norm over the whole sequence instead of a Python loop per token.
        token_norms = torch.norm(hidden_states[layer + 1][0] - hidden_states[layer][0], dim=-1)
        layer2norm.setdefault(layer, []).extend(token_norms.tolist())

  0%|          | 0/3668 [00:00<?, ?it/s]

In [28]:
# Summarise the per-token update norms collected for each layer.
for layer_idx, layer_norms in layer2norm.items():
    mean_norm = np.mean(layer_norms)
    median_norm = np.median(layer_norms)
    print(f'{layer_idx} mean and median:', mean_norm, '(', median_norm, ')')

0 mean and median: 9.121970471242427 ( 9.223304271697998 )
1 mean and median: 8.349919297009167 ( 8.686412811279297 )
2 mean and median: 8.354943535188614 ( 8.188782691955566 )
3 mean and median: 8.308292512355528 ( 8.24538278579712 )
4 mean and median: 8.266356978578083 ( 8.429078578948975 )
5 mean and median: 8.149031351913937 ( 8.166177749633789 )
6 mean and median: 8.319523116100328 ( 8.416105270385742 )
7 mean and median: 8.159478774294886 ( 8.258448600769043 )
8 mean and median: 7.920676039774069 ( 7.981983184814453 )
9 mean and median: 8.853833924702794 ( 8.73741102218628 )
10 mean and median: 7.909920914453066 ( 7.964186668395996 )
11 mean and median: 11.588137592496066 ( 11.739653587341309 )


## WiC — per-layer update norms

In [29]:
dataset_idx, dataset_name, dataset = 1, 'wic', eval_datasets['wic']
# Text columns for this dataset; field2 == '' means single-text input.
# Loop-invariant, so looked up once instead of per example.
field1, field2 = config.data.eval_datasets_fields[dataset_idx]

# layer -> list of per-token L2 norms ||h[layer + 1] - h[layer]|| over the train split.
layer2norm = {}
for ex in tqdm(dataset['train'], total=len(dataset['train']), position=0, leave=True):
    if field2 != '':
        encoded_inputs = tokenizer.encode(ex[field1], ex[field2], truncation=True, return_tensors='pt')
    else:
        encoded_inputs = tokenizer.encode(ex[field1], truncation=True, return_tensors='pt')
    encoded_inputs = encoded_inputs.to(config.general.device)

    with torch.no_grad():
        # Only hidden states are used here, so attention maps are not requested.
        hidden_states = initial_model(encoded_inputs, output_hidden_states=True).hidden_states

    for layer in range(len(hidden_states) - 1):
        # One vectorised norm over the whole sequence instead of a Python loop per token.
        token_norms = torch.norm(hidden_states[layer + 1][0] - hidden_states[layer][0], dim=-1)
        layer2norm.setdefault(layer, []).extend(token_norms.tolist())

  0%|          | 0/5428 [00:00<?, ?it/s]

In [30]:
# Summarise the per-token update norms collected for each layer.
for layer_idx, layer_norms in layer2norm.items():
    mean_norm = np.mean(layer_norms)
    median_norm = np.median(layer_norms)
    print(f'{layer_idx} mean and median:', mean_norm, '(', median_norm, ')')

0 mean and median: 8.827095547197262 ( 8.848946571350098 )
1 mean and median: 7.738413257428731 ( 7.8639092445373535 )
2 mean and median: 8.21751114782759 ( 7.867585182189941 )
3 mean and median: 7.696125160249789 ( 7.674619197845459 )
4 mean and median: 7.172252207408471 ( 7.515017509460449 )
5 mean and median: 7.48639344526253 ( 7.649768829345703 )
6 mean and median: 7.841299281363784 ( 8.230294227600098 )
7 mean and median: 7.827368092935812 ( 8.151829719543457 )
8 mean and median: 7.453597766049063 ( 7.750354290008545 )
9 mean and median: 8.123992402132705 ( 8.030667304992676 )
10 mean and median: 7.14757240141516 ( 7.269410133361816 )
11 mean and median: 10.794832546460642 ( 10.961954116821289 )


## IMDB — per-layer update norms (train-side dataset)

In [32]:
eval_datasets = load_datasets(config.data.train_datasets, config.data.cut_size)

In [33]:
dataset_idx, dataset_name, dataset = 0, 'imdb', eval_datasets['imdb']
layer2norm = {}
for ex in tqdm(dataset['train'], total=len(dataset['train']), position=True, leave=True):
    field1, field2 = config.data.train_datasets_fields[dataset_idx]
    if field2 != '':
        encoded_inputs = tokenizer.encode(
                        ex[field1],
                        ex[field2],
                        truncation=True,
                        return_tensors='pt'
                    ).to(config.general.device)
    else:
        encoded_inputs = tokenizer.encode(
                        ex[field1],
                        truncation=True,
                        return_tensors='pt'
                    ).to(config.general.device)

    with torch.no_grad():
        outputs = initial_model(encoded_inputs, output_hidden_states=True, output_attentions=True)

    for layer in range(len(outputs.hidden_states) - 1):
        if layer not in layer2norm:
            layer2norm[layer] = []
        layer_output = outputs.hidden_states[layer][0]
        for seq_idx in range(outputs.hidden_states[layer][0].shape[0]):
            curr_emb = outputs.hidden_states[layer+1][0][seq_idx]
            prev_emb = outputs.hidden_states[layer][0][seq_idx]
            
            layer2norm[layer].append(torch.norm(curr_emb-prev_emb).item())

  0%|          | 0/25000 [00:00<?, ?it/s]

In [34]:
# Summarise the per-token update norms collected for each layer.
for layer_idx, layer_norms in layer2norm.items():
    mean_norm = np.mean(layer_norms)
    median_norm = np.median(layer_norms)
    print(f'{layer_idx} mean and median:', mean_norm, '(', median_norm, ')')

0 mean and median: 8.561740641097886 ( 8.603800296783447 )
1 mean and median: 8.607935406028636 ( 8.691346168518066 )
2 mean and median: 8.554685527824564 ( 8.430185317993164 )
3 mean and median: 8.627527011758934 ( 8.421025276184082 )
4 mean and median: 8.417464881664007 ( 8.290224552154541 )
5 mean and median: 8.634323292065046 ( 8.527114868164062 )
6 mean and median: 8.781677480413382 ( 8.685733795166016 )
7 mean and median: 8.8651576286237 ( 8.83482313156128 )
8 mean and median: 8.730750874322444 ( 8.740266799926758 )
9 mean and median: 9.197371562321866 ( 9.180856704711914 )
10 mean and median: 8.137803444909565 ( 8.145401954650879 )
11 mean and median: 11.946552854364535 ( 12.052977561950684 )


In [None]:
# NOTE(review): scratch cell. `ex` and `dataset_idx` are whatever the last loop
# above left behind, while the field names always come from
# `eval_datasets_fields`. After the IMDB loop (which reads
# `train_datasets_fields`) this pairs an IMDB example with eval field names and
# may raise KeyError — confirm intent or remove the cell.
field1, field2 = config.data.eval_datasets_fields[dataset_idx]
encoded_inputs = (
    tokenizer.encode(ex[field1], ex[field2], truncation=True, return_tensors='pt')
    if field2 != ''
    else tokenizer.encode(ex[field1], truncation=True, return_tensors='pt')
).to(config.general.device)

In [30]:
def tqdm_pbar(x, y):
    """Wrap the sized iterable ``x`` in a persistent tqdm bar described by ``y``."""
    return tqdm(x, leave=True, position=0, total=len(x), desc=f'{y}')


def get_cls_embeddings_for_dataset(dataset_idx, dataset_name, dataset, config, tokenizer, model,
                                   pbar_func=tqdm_pbar, fields=None):
    """Collect the final-layer first-token ([CLS]) embedding of every example in every split.

    Args:
        dataset_idx: index into ``config.data.eval_datasets_fields`` selecting the
            ``(field1, field2)`` text columns; ignored when ``fields`` is given.
        dataset_name: only used in the progress-bar description.
        dataset: mapping ``split name -> iterable of examples`` (dict-like rows).
            This argument is now actually used; previously the function read the
            global ``eval_datasets[dataset_name]`` instead.
        config: project config; reads ``data.eval_datasets_fields`` and ``general.device``.
        tokenizer: tokenizer whose ``encode`` accepts one or two texts.
        model: callable returning an object with ``last_hidden_state``.
        pbar_func: ``(list of (idx, example), desc) -> iterable of (idx, example)``;
            ``None`` disables the progress bar.
        fields: optional explicit ``(field1, field2)`` pair; ``field2 == ''`` means
            single-text input.

    Returns:
        ``defaultdict(list)`` mapping split name -> list of ``last_hidden_state[:, 0, :]``
        tensors, one per example.
    """
    field1, field2 = fields if fields is not None else config.data.eval_datasets_fields[dataset_idx]
    collected_embeddings = defaultdict(list)

    for split, data in dataset.items():
        indexed = list(enumerate(data))
        # Both branches yield (idx, example) pairs, so unpacking below is always valid.
        iterator = pbar_func(indexed, f"{split} {dataset_name}") if pbar_func is not None else indexed
        for _, ex in iterator:
            texts = (ex[field1], ex[field2]) if field2 != '' else (ex[field1],)
            encoded_inputs = tokenizer.encode(
                *texts,
                truncation=True,
                return_tensors='pt',
            ).to(config.general.device)

            with torch.no_grad():
                outputs = model(encoded_inputs)

            # Embedding at sequence position 0 (the [CLS] token for BERT-style tokenizers).
            collected_embeddings[split].append(outputs.last_hidden_state[:, 0, :])
    return collected_embeddings

def train_linear(X_train, y_train):
    """Fit an L-BFGS logistic-regression probe on ``(X_train, y_train)`` and return it."""
    # sklearn's `fit` returns the fitted estimator itself.
    return LogisticRegression(solver='lbfgs', max_iter=3000).fit(X_train, y_train)

def evaluate_classifier(classifier, X, y=None):
    """Return ``classifier.predict(X)``; ``y`` is accepted for API symmetry but unused."""
    return classifier.predict(X)

def get_metrics_report(y_true, y_pred):
    """Print weighted F1 and accuracy of ``y_pred`` against ``y_true``.

    Returns:
        dict with keys ``'weighted_f1'`` and ``'accuracy'``, so callers can
        collect results across datasets/models; the printed report is unchanged.

    Relies on ``accuracy_score`` / ``f1_score`` from ``sklearn.metrics``, which
    this notebook imports in the following cell.
    """
    accuracy = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average='weighted')
    print('Weighted F1', f1)
    print('Accuracy', accuracy)
    print('-------------------------------')
    return {'weighted_f1': f1, 'accuracy': accuracy}

In [31]:
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
from sklearn.metrics import classification_report

def _probe_cls_embeddings(model_label, model, dataset_idx, dataset_name, dataset, config):
    """Fit a linear probe on ``model``'s train [CLS] embeddings and print validation metrics.

    Uses the module-level ``tokenizer`` together with ``get_cls_embeddings_for_dataset``,
    ``train_linear``, ``evaluate_classifier`` and ``get_metrics_report`` defined earlier.
    """
    print(model_label)
    embeddings = get_cls_embeddings_for_dataset(
        dataset_idx,
        dataset_name,
        dataset,
        config,
        tokenizer,
        model)

    # Only train and validation are needed. Any other split the helper returns is
    # ignored; concatenating it would also fail (torch.cat([])) when it is absent.
    train_features = torch.cat(embeddings['train'], dim=0).cpu().numpy()
    valid_features = torch.cat(embeddings['validation'], dim=0).cpu().numpy()

    classifier = train_linear(train_features, [el['label'] for el in dataset['train']])
    valid_preds = evaluate_classifier(classifier, valid_features)
    print('Validation evaluation:\n')
    get_metrics_report([el['label'] for el in dataset['validation']], valid_preds)


def check_results(custom_model, initial_model, datasets, config):
    """Compare linear-probe quality of [CLS] embeddings from two models.

    For each dataset in ``datasets`` (mapping name -> splits), fit a logistic
    regression on train-split embeddings of ``initial_model`` ("Original") and
    of ``custom_model`` ("Linear"), and print validation weighted F1 / accuracy.

    NOTE(review): ``dataset_idx`` is the position of the dataset in ``datasets``
    and is used to index ``config.data.eval_datasets_fields`` — this assumes the
    dict order matches the config list.
    """
    for dataset_idx, (dataset_name, dataset) in enumerate(datasets.items()):
        print(f"{dataset_name}\n")
        for model_label, model in (('Original', initial_model), ('\nLinear:', custom_model)):
            _probe_cls_embeddings(model_label, model, dataset_idx, dataset_name, dataset, config)

In [32]:
# `linear_model` is not constructed in any cell shown in this notebook, so this
# cell only works on a kernel where it was created earlier (hidden state).
# Presumably it is built with `BertWrapperLin` (imported above) — TODO confirm
# and add that construction cell. Until then, fail with an actionable message.
if 'linear_model' not in globals():
    raise NameError(
        "`linear_model` is not defined: construct the linearised model "
        "before running this cell."
    )
initial_model = initial_model.to(config.general.device)
linear_model = linear_model.to(config.general.device)

In [33]:
# Linear probes on [CLS] embeddings: original encoder vs. linear-attention variant.
check_results(
    custom_model=linear_model,
    initial_model=initial_model,
    datasets=eval_datasets,
    config=config,
)

mrpc

Original


train mrpc:   0%|          | 0/3668 [00:00<?, ?it/s]

validation mrpc:   0%|          | 0/408 [00:00<?, ?it/s]

test mrpc:   0%|          | 0/1725 [00:00<?, ?it/s]

Validation evaluation:

Weighted F1 0.688181620007579
Accuracy 0.6936274509803921
-------------------------------

Linear:


train mrpc:   0%|          | 0/3668 [00:00<?, ?it/s]

validation mrpc:   0%|          | 0/408 [00:00<?, ?it/s]

test mrpc:   0%|          | 0/1725 [00:00<?, ?it/s]

Validation evaluation:

Weighted F1 0.6226437943598493
Accuracy 0.6372549019607843
-------------------------------
wic

Original


train wic:   0%|          | 0/5428 [00:00<?, ?it/s]

validation wic:   0%|          | 0/638 [00:00<?, ?it/s]

test wic:   0%|          | 0/1400 [00:00<?, ?it/s]

Validation evaluation:

Weighted F1 0.6009986605735895
Accuracy 0.6018808777429467
-------------------------------

Linear:


train wic:   0%|          | 0/5428 [00:00<?, ?it/s]

validation wic:   0%|          | 0/638 [00:00<?, ?it/s]

test wic:   0%|          | 0/1400 [00:00<?, ?it/s]

Validation evaluation:

Weighted F1 0.48395963171933964
Accuracy 0.4843260188087774
-------------------------------
