In [1]:
# Move two directories up (presumably the repository root) so that the relative
# `config.yaml` path and the `utils` package imports below resolve.
%cd ../..

/home/shapkin/effective-inference


## Import libraries

In [2]:
import os

import yaml
import h5py
import numpy as np
import torch
import json
import seaborn as sns
import matplotlib.pyplot as plt

from utils.dataset_cache import cache_embeddings, get_dataset_for_regression, build_dataset_from_cached, load_cached_dataset
from utils.dataset_cache import build_dict_dataset_from_cached
from utils.prepare_dataset import load_datasets, cut_datasets
from utils.config import ConfigWrapper
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel
from typing import Tuple, List, Dict, Optional, Union
from numpy.random import shuffle
from sklearn.metrics import r2_score

from IPython.display import clear_output
from collections import defaultdict

In [3]:
from utils.attentions.bert.linear import BertWrapperLin, LinearClassifierBertAttention, LinearAttention
from utils.dataset_utils import get_dict_batch, prepare_batches
from utils.train_linear_utils import train_epoch, eval_epoch, plot_history

## Project configuration

In [4]:
# Project configuration; the path is relative to the working directory set by `%cd` above.
config_path = 'config.yaml'

with open(config_path, "r") as f:
    # NOTE(review): FullLoader accepts more than plain YAML; `yaml.safe_load` is the usual
    # choice for config files — consider switching if config.yaml uses only standard tags.
    config = ConfigWrapper(yaml.load(f, Loader=yaml.FullLoader))

In [5]:
# Tokenizer and base encoder for config.model.model_name. `initial_model` is later passed to
# check_results as the "Original" baseline.
# NOTE(review): Hugging Face tokenizers take their truncation limit from `model_max_length`;
# verify that `max_length` passed here actually affects `truncation=True` in encode().
tokenizer = AutoTokenizer.from_pretrained(config.model.model_name, max_length=config.general.max_len)
initial_model = AutoModel.from_pretrained(config.model.model_name).to(config.general.device)

In [6]:
# Load the evaluation datasets listed in the config; the result maps dataset name -> DatasetDict.
# NOTE(review): `cut_size` presumably limits the number of examples — confirm in
# utils.prepare_dataset.load_datasets.
eval_datasets = load_datasets(config.data.eval_datasets, config.data.cut_size)
eval_datasets

{'mrpc': DatasetDict({
     train: Dataset({
         features: ['sentence1', 'sentence2', 'label', 'idx'],
         num_rows: 3668
     })
     validation: Dataset({
         features: ['sentence1', 'sentence2', 'label', 'idx'],
         num_rows: 408
     })
     test: Dataset({
         features: ['sentence1', 'sentence2', 'label', 'idx'],
         num_rows: 1725
     })
 }),
 'wic': DatasetDict({
     train: Dataset({
         features: ['word', 'sentence1', 'sentence2', 'start1', 'start2', 'end1', 'end2', 'idx', 'label'],
         num_rows: 5428
     })
     validation: Dataset({
         features: ['word', 'sentence1', 'sentence2', 'start1', 'start2', 'end1', 'end2', 'idx', 'label'],
         num_rows: 638
     })
     test: Dataset({
         features: ['word', 'sentence1', 'sentence2', 'start1', 'start2', 'end1', 'end2', 'idx', 'label'],
         num_rows: 1400
     })
 })}

## Prepare model

In [7]:
# Wrap the base encoder with LinearClassifierBertAttention for layer_nums=[8, 9, 10, 11]
# (presumably only those layers get the linear attention — see BertWrapperLin).
# NOTE(review): confirm the wrapper does not modify `initial_model` in place; otherwise the
# "Original" baseline in check_results would also run the modified layers.
linear_model = BertWrapperLin(initial_model, LinearClassifierBertAttention, config, layer_nums=[8, 9, 10, 11])

In [8]:
def _load_linear_state_dicts(config, linear_model, per_head):
    """Load saved weights into every nested ``linear_model*`` submodule of the wrapped encoder.

    Walks ``linear_model.bert_model.encoder.layer``. For each nested submodule whose own (last)
    name contains ``linear_model``, it loads
    ``{config.data.data_path}/linear_models/{pattern}/model.pth``. The ``pattern`` is:

    * ``{config.data.model_save_pattern}_{layer}_{head}`` when ``per_head`` is true. The head
      number is parsed from the trailing ``_<int>`` of the submodule name.
    * ``{config.data.model_save_pattern}_{layer}`` otherwise. All matching submodules of a
      layer then share one checkpoint.

    Loading stays non-strict (``strict=False``). A checkpoint whose keys do not match the
    module is reported with a warning rather than silently ignored, because the module would
    otherwise keep its initial weights.
    """
    import warnings

    models_dir = f'{config.data.data_path}/linear_models'
    for layer_num, bert_att in enumerate(linear_model.bert_model.encoder.layer):
        for module_name, module in bert_att.named_modules():
            # Only nested submodules whose own name mentions `linear_model`
            # (e.g. `<parent>.linear_model_3`); their children are skipped.
            if '.' not in module_name or 'linear_model' not in module_name.split('.')[-1]:
                continue
            save_pattern = f"{config.data.model_save_pattern}_{layer_num}"
            if per_head:
                head_num = int(module_name.split('_')[-1])
                save_pattern = f"{save_pattern}_{head_num}"
            weights_path = f'{models_dir}/{save_pattern}/model.pth'
            # map_location='cpu' lets GPU-saved checkpoints load anywhere;
            # load_state_dict copies the tensors onto the module's own device.
            state_dict = torch.load(weights_path, map_location='cpu')
            incompatible = module.load_state_dict(state_dict, strict=False)
            if incompatible.missing_keys or incompatible.unexpected_keys:
                warnings.warn(
                    f"{weights_path} -> layer {layer_num} '{module_name}': "
                    f"missing keys {list(incompatible.missing_keys)}, "
                    f"unexpected keys {list(incompatible.unexpected_keys)}"
                )


def init_linear_modules(config, linear_model):
    """Load one checkpoint per (layer, head): ``{model_save_pattern}_{layer}_{head}/model.pth``."""
    _load_linear_state_dicts(config, linear_model, per_head=True)


def init_linear_modules2(config, linear_model):
    """Load one checkpoint per layer, ``{model_save_pattern}_{layer}/model.pth``, shared by its linear modules."""
    _load_linear_state_dicts(config, linear_model, per_head=False)

if not config.attention_config.split_heads and not config.attention_config.model_for_each_head:
    # Neither per-head option is enabled: one saved linear model per layer.
    init_linear_modules2(config, linear_model)
else:
    # Per-head mode: checkpoints are stored per (layer, head).
    init_linear_modules(config, linear_model)

In [9]:
def tqdm_pbar(x, y):
    """Wrap the sized iterable ``x`` in a tqdm progress bar labelled ``y``.

    ``len(x)`` is used as the bar total, so ``x`` must support ``len()``.
    """
    return tqdm(x, leave=True, position=0, total=len(x), desc=f'{y}')
def get_cls_embeddings_for_dataset(dataset_idx, dataset_name, dataset, config, tokenizer, model,
                                   pbar_func=tqdm_pbar):
    """Encode every example of every split of ``dataset`` and collect its [CLS] embedding.

    Args:
        dataset_idx: Index into ``config.data.eval_datasets_fields``. It selects the
            ``(field1, field2)`` text columns; ``field2 == ''`` means single-text input.
        dataset_name: Used only in the progress-bar label.
        dataset: Mapping ``split name -> iterable of examples`` (e.g. a ``DatasetDict``).
            Examples are indexed by the field names.
        config: Project config; inputs are moved to ``config.general.device``.
        tokenizer: Provides ``encode(..., truncation=True, return_tensors='pt')``.
        model: Called as ``model(input_ids)``; its output must expose ``last_hidden_state``.
        pbar_func: ``(iterable, label) -> iterable`` progress-bar factory, or ``None`` to
            iterate without a progress bar.

    Returns:
        ``defaultdict(list)`` mapping each split name to one tensor per example: the
        position-0 slice ``last_hidden_state[:, 0, :]``, of shape ``(1, hidden_size)`` for
        these single-example inputs.
    """
    collected_embeddings = defaultdict(list)
    # The text columns depend only on the dataset, not on the individual example.
    field1, field2 = config.data.eval_datasets_fields[dataset_idx]

    # Iterate the dataset that was passed in rather than a notebook global.
    for split, data in dataset.items():
        indexed_examples = list(enumerate(data))
        pbar = (pbar_func(indexed_examples, f"{split} {dataset_name}")
                if pbar_func is not None else indexed_examples)
        for _, ex in pbar:
            # Sentence-pair datasets pass both texts; single-text datasets use field2 == ''.
            texts = (ex[field1], ex[field2]) if field2 != '' else (ex[field1],)
            # NOTE(review): `tokenizer.encode` returns input ids only, so no token_type_ids or
            # attention_mask reach the model (segment ids presumably default to zeros). Consider
            # `tokenizer(...)` + `model(**inputs)` if the wrapped model accepts keyword inputs.
            encoded_inputs = tokenizer.encode(
                *texts,
                truncation=True,
                return_tensors='pt'
            ).to(config.general.device)

            with torch.no_grad():
                outputs = model(encoded_inputs)

            # Embedding of the first ([CLS]) token.
            collected_embeddings[split].append(outputs.last_hidden_state[:, 0, :])
    return collected_embeddings

def train_linear(X_train, y_train):
    """Fit a logistic-regression probe (L-BFGS, up to 3000 iterations) and return it."""
    probe = LogisticRegression(max_iter=3000, solver='lbfgs')
    # sklearn's fit() returns the fitted estimator itself.
    return probe.fit(X_train, y_train)

def evaluate_classifier(classifier, X, y=None):
    """Return ``classifier.predict(X)``; ``y`` is accepted but ignored."""
    return classifier.predict(X)

def get_metrics_report(y_true, y_pred):
    """Print weighted F1 and accuracy of ``y_pred`` against ``y_true``, and return them.

    Returns:
        dict with keys ``'weighted_f1'`` and ``'accuracy'``. Results can then be collected
        programmatically; callers that ignore the return value are unaffected.
    """
    accuracy = accuracy_score(y_true, y_pred)
    f1 = f1_score(y_true, y_pred, average='weighted')
    print('Weighted F1', f1)
    print('Accuracy', accuracy)
    print('-------------------------------')
    return {'weighted_f1': f1, 'accuracy': accuracy}

In [10]:
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, roc_auc_score
from sklearn.metrics import classification_report

def _probe_and_report(model, dataset_idx, dataset_name, dataset, config, train_labels, valid_labels):
    """Extract [CLS] embeddings with ``model``, fit a probe on train, print validation metrics."""
    embeddings = get_cls_embeddings_for_dataset(
        dataset_idx,
        dataset_name,
        dataset,
        config,
        tokenizer,
        model)

    # Only train and validation are scored; the test split is not needed here.
    train_embeddings = torch.cat(embeddings['train'], dim=0).cpu()
    valid_embeddings = torch.cat(embeddings['validation'], dim=0).cpu()

    classif = train_linear(train_embeddings, train_labels)
    valid_preds = evaluate_classifier(classif, valid_embeddings)
    print('Validation evaluation:\n')
    get_metrics_report(valid_labels, valid_preds)


def check_results(custom_model, initial_model, datasets, config):
    """Compare linear probes on [CLS] embeddings of the original and the custom encoder.

    For every dataset in ``datasets`` (mapping ``name -> {split: examples}``):

    * Extract [CLS] embeddings with each model.
    * Train a logistic-regression probe on the ``train`` split.
    * Print weighted F1 and accuracy on the ``validation`` split.

    The ``test`` split is never scored, so its embeddings are not concatenated.
    ``get_cls_embeddings_for_dataset`` still walks every split. Uses the notebook-global
    ``tokenizer``.

    NOTE(review): confirm that building ``custom_model`` did not modify ``initial_model`` in
    place; otherwise the "Original" numbers would also reflect the modified layers.
    """
    for dataset_idx, (dataset_name, dataset) in enumerate(datasets.items()):
        print(f"{dataset_name}\n")
        # The labels are identical for both models, so build them once per dataset.
        train_labels = [el['label'] for el in dataset['train']]
        valid_labels = [el['label'] for el in dataset['validation']]

        print('Original')
        _probe_and_report(initial_model, dataset_idx, dataset_name, dataset, config,
                          train_labels, valid_labels)

        print('\nLinear:')
        _probe_and_report(custom_model, dataset_idx, dataset_name, dataset, config,
                          train_labels, valid_labels)

In [11]:
# Make sure both encoders are on the configured device before evaluation
# (initial_model was already moved there when it was loaded).
initial_model = initial_model.to(config.general.device)
linear_model = linear_model.to(config.general.device)

In [12]:
# Probe both encoders on every evaluation dataset and print validation metrics.
check_results(linear_model, initial_model, eval_datasets, config)

mrpc

Original


train mrpc:   0%|          | 0/3668 [00:00<?, ?it/s]

validation mrpc:   0%|          | 0/408 [00:00<?, ?it/s]

test mrpc:   0%|          | 0/1725 [00:00<?, ?it/s]

Validation evaluation:

Weighted F1 0.688181620007579
Accuracy 0.6936274509803921
-------------------------------

Linear:


train mrpc:   0%|          | 0/3668 [00:00<?, ?it/s]

  result += namespace['cur_result'].T


validation mrpc:   0%|          | 0/408 [00:00<?, ?it/s]

test mrpc:   0%|          | 0/1725 [00:00<?, ?it/s]

Validation evaluation:

Weighted F1 0.6775067493060573
Accuracy 0.6936274509803921
-------------------------------
wic

Original


train wic:   0%|          | 0/5428 [00:00<?, ?it/s]

validation wic:   0%|          | 0/638 [00:00<?, ?it/s]

test wic:   0%|          | 0/1400 [00:00<?, ?it/s]

Validation evaluation:

Weighted F1 0.6009986605735895
Accuracy 0.6018808777429467
-------------------------------

Linear:


train wic:   0%|          | 0/5428 [00:00<?, ?it/s]

validation wic:   0%|          | 0/638 [00:00<?, ?it/s]

test wic:   0%|          | 0/1400 [00:00<?, ?it/s]

Validation evaluation:

Weighted F1 0.5829245514868518
Accuracy 0.5830721003134797
-------------------------------
