In [1]:
import sys

sys.path.append('space-model')

In [2]:
import random
import os

import pandas as pd
import numpy as np

import torch
import torch.nn.functional as F

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, jaccard_score
from sklearn.model_selection import train_test_split

from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors

from tqdm import tqdm

import matplotlib.pyplot as plt
import plotly.graph_objects as go

from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification
from transformers import DataCollatorWithPadding

from datasets import load_dataset, Dataset, DatasetDict

from space_model.model import *
from space_model.loss import *

from logger import get_logger
from train import training, eval_results, plot_results, eval, eval_epoch

In [3]:
SEED = 42

In [4]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


seed_everything(seed=SEED)

In [5]:
def on_gpu(f):
    def wrapper(*args):
        if torch.cuda.is_available():
            return f(*args)
        else:
            print('cuda unavailable')

    return wrapper

In [6]:
if torch.cuda.is_available():
    ! pip install pynvml
    from pynvml import *
    from numba import cuda


@on_gpu
def print_gpu_utilization(dev_id):
    try:
        nvmlInit()
        handle = nvmlDeviceGetHandleByIndex(dev_id)
        info = nvmlDeviceGetMemoryInfo(handle)
        print(f"GPU memory occupied: {info.used // 1024 ** 2} MB.")
    except Exception as e:
        print(e)


@on_gpu
def free_gpu_cache(dev_id=0):
    print("Initial GPU Usage")
    print_gpu_utilization(dev_id)

    torch.cuda.empty_cache()

    print("GPU Usage after emptying the cache")
    print_gpu_utilization(dev_id)


def print_summary(result):
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()



In [7]:
device_id = 0

In [8]:
device = torch.device(f'cuda:{device_id}' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [9]:
if torch.cuda.is_available():
    torch.cuda.set_device(device)

In [10]:
MODEL_NAME = 'distilbert-base-cased'
DATASET_NAME = 'go-emotions'

NUM_LABELS = 28
N_LATENT = 128

BATCH_SIZE = 256
MAX_SEQ_LEN = 512

In [11]:
emotions_1_df = pd.read_csv(f'data/goemotions_1.csv')
emotions_2_df = pd.read_csv(f'data/goemotions_2.csv')
emotions_3_df = pd.read_csv(f'data/goemotions_3.csv')

emotions_df = pd.concat([
    emotions_1_df,
    emotions_2_df,
    emotions_3_df
], ignore_index=True, axis=0)
emotions_df

Unnamed: 0,text,id,author,subreddit,link_id,parent_id,created_utc,rater_id,example_very_unclear,admiration,...,love,nervousness,optimism,pride,realization,relief,remorse,sadness,surprise,neutral
0,That game hurt.,eew5j0j,Brdd9,nrl,t3_ajis4z,t1_eew18eq,1.548381e+09,1,False,0,...,0,0,0,0,0,0,0,1,0,0
1,>sexuality shouldn’t be a grouping category I...,eemcysk,TheGreen888,unpopularopinion,t3_ai4q37,t3_ai4q37,1.548084e+09,37,True,0,...,0,0,0,0,0,0,0,0,0,0
2,"You do right, if you don't care then fuck 'em!",ed2mah1,Labalool,confessions,t3_abru74,t1_ed2m7g7,1.546428e+09,37,False,0,...,0,0,0,0,0,0,0,0,0,1
3,Man I love reddit.,eeibobj,MrsRobertshaw,facepalm,t3_ahulml,t3_ahulml,1.547965e+09,18,False,0,...,1,0,0,0,0,0,0,0,0,0
4,"[NAME] was nowhere near them, he was by the Fa...",eda6yn6,American_Fascist713,starwarsspeculation,t3_ackt2f,t1_eda65q2,1.546669e+09,2,False,0,...,0,0,0,0,0,0,0,0,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
211220,Everyone likes [NAME].,ee6pagw,Senshado,heroesofthestorm,t3_agjf24,t3_agjf24,1.547634e+09,16,False,0,...,1,0,0,0,0,0,0,0,0,0
211221,Well when you’ve imported about a gazillion of...,ef28nod,5inchloser,nottheonion,t3_ak26t3,t3_ak26t3,1.548553e+09,15,False,0,...,0,0,0,0,0,0,0,0,0,0
211222,That looks amazing,ee8hse1,springt1me,shittyfoodporn,t3_agrnqb,t3_agrnqb,1.547684e+09,70,False,1,...,0,0,0,0,0,0,0,0,0,0
211223,The FDA has plenty to criticize. But like here...,edrhoxh,enamedata,medicine,t3_aejqzd,t1_edrgdtx,1.547169e+09,4,False,0,...,0,0,0,0,0,0,0,0,0,0


In [12]:
labels = [
    'admiration',
    'amusement',
    'anger',
    'annoyance',
    'approval',
    'caring',
    'confusion',
    'curiosity',
    'desire',
    'disappointment',
    'disapproval',
    'disgust',
    'embarrassment',
    'excitement',
    'fear',
    'gratitude',
    'grief',
    'joy',
    'love',
    'nervousness',
    'optimism',
    'pride',
    'realization',
    'relief',
    'remorse',
    'sadness',
    'surprise',
    'neutral'
]

In [13]:
emotions_df['label'] = emotions_df[labels].apply(lambda x: x.to_list(), axis=1)
emotions_df

Unnamed: 0,text,id,author,subreddit,link_id,parent_id,created_utc,rater_id,example_very_unclear,admiration,...,nervousness,optimism,pride,realization,relief,remorse,sadness,surprise,neutral,label
0,That game hurt.,eew5j0j,Brdd9,nrl,t3_ajis4z,t1_eew18eq,1.548381e+09,1,False,0,...,0,0,0,0,0,0,1,0,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,>sexuality shouldn’t be a grouping category I...,eemcysk,TheGreen888,unpopularopinion,t3_ai4q37,t3_ai4q37,1.548084e+09,37,True,0,...,0,0,0,0,0,0,0,0,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,"You do right, if you don't care then fuck 'em!",ed2mah1,Labalool,confessions,t3_abru74,t1_ed2m7g7,1.546428e+09,37,False,0,...,0,0,0,0,0,0,0,0,1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,Man I love reddit.,eeibobj,MrsRobertshaw,facepalm,t3_ahulml,t3_ahulml,1.547965e+09,18,False,0,...,0,0,0,0,0,0,0,0,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,"[NAME] was nowhere near them, he was by the Fa...",eda6yn6,American_Fascist713,starwarsspeculation,t3_ackt2f,t1_eda65q2,1.546669e+09,2,False,0,...,0,0,0,0,0,0,0,0,1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
211220,Everyone likes [NAME].,ee6pagw,Senshado,heroesofthestorm,t3_agjf24,t3_agjf24,1.547634e+09,16,False,0,...,0,0,0,0,0,0,0,0,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
211221,Well when you’ve imported about a gazillion of...,ef28nod,5inchloser,nottheonion,t3_ak26t3,t3_ak26t3,1.548553e+09,15,False,0,...,0,0,0,0,0,0,0,0,0,"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
211222,That looks amazing,ee8hse1,springt1me,shittyfoodporn,t3_agrnqb,t3_agrnqb,1.547684e+09,70,False,1,...,0,0,0,0,0,0,0,0,0,"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
211223,The FDA has plenty to criticize. But like here...,edrhoxh,enamedata,medicine,t3_aejqzd,t1_edrgdtx,1.547169e+09,4,False,0,...,0,0,0,0,0,0,0,0,0,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [14]:
train_split, test_split = train_test_split(emotions_df, test_size=0.2, random_state=SEED)
test_split, val_split = train_test_split(test_split, test_size=0.5, random_state=SEED)

In [15]:
dataset = DatasetDict({
    'emotions_train': Dataset.from_pandas(train_split[['text', 'label']]),
    'emotions_val': Dataset.from_pandas(test_split[['text', 'label']]),
    'emotions_test': Dataset.from_pandas(val_split[['text', 'label']]),
})
dataset

DatasetDict({
    emotions_train: Dataset({
        features: ['text', 'label', '__index_level_0__'],
        num_rows: 168980
    })
    emotions_val: Dataset({
        features: ['text', 'label', '__index_level_0__'],
        num_rows: 21122
    })
    emotions_test: Dataset({
        features: ['text', 'label', '__index_level_0__'],
        num_rows: 21123
    })
})

In [16]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer

DistilBertTokenizerFast(name_or_path='distilbert-base-cased', vocab_size=28996, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [17]:
tokenized_dataset = dataset
tokenized_dataset = tokenized_dataset.map(
    lambda x: tokenizer(x['text'], truncation=True, padding='max_length', max_length=MAX_SEQ_LEN,
                        return_tensors='pt'), batched=True)
tokenized_dataset.set_format('torch', device=device)
tokenized_dataset

Map:   0%|          | 0/168980 [00:00<?, ? examples/s]

Map:   0%|          | 0/21122 [00:00<?, ? examples/s]

Map:   0%|          | 0/21123 [00:00<?, ? examples/s]

DatasetDict({
    emotions_train: Dataset({
        features: ['text', 'label', '__index_level_0__', 'input_ids', 'attention_mask'],
        num_rows: 168980
    })
    emotions_val: Dataset({
        features: ['text', 'label', '__index_level_0__', 'input_ids', 'attention_mask'],
        num_rows: 21122
    })
    emotions_test: Dataset({
        features: ['text', 'label', '__index_level_0__', 'input_ids', 'attention_mask'],
        num_rows: 21123
    })
})

In [18]:
# source:
# https://emotiontypology.com/positive_emotion/positivesurprise/

In [19]:
defintions = [
    {'emotion': 'admiration',
     'definition': 'The feeling when you look up to someone who has excellent abilities or has accomplished impressive things. You have the urge to also achieve such things and be more like this person.'},
    {'emotion': 'amusement',
     'definition': 'The feeling when you encounter something silly, ironic, witty, or absurd, which makes you laugh. You have the urge to be playful and share the joke with others.'},
    {'emotion': 'anger',
     'definition': 'The feeling when someone did something bad that harmed or offended you. You want to go against this person to stop them or prevent them from doing it again.'},
    {'emotion': 'annoyance',
     'definition': 'The feeling when something is happening that bothers you. You have the urge to say or do something to change it or make it stop.'},
    {'emotion': 'approval',
     'definition': 'The feeling when you agree with or accept something. You have the urge to support or encourage it.'},
    {'emotion': 'caring',
     'definition': 'The feeling when you are concerned about someone or something. You have the urge to help or protect them.'},
    {'emotion': 'confusion',
     'definition': 'The feeling when you get information that does not make sense to you, leaving you uncertain what to do with it.'},
    {'emotion': 'curiosity',
     'definition': 'The feeling when you want to know more about something. You have the urge to explore and learn.'},
    {'emotion': 'desire', 'definition': 'The feeling when you want something. You have the urge to get it.'},
    {'emotion': 'disappointment',
     'definition': 'The feeling when something you hoped for did not happen. You have the urge to express your sadness and frustration.'},
    {'emotion': 'disapproval',
     'definition': 'The feeling when you disagree with or dislike something. You have the urge to criticize or oppose it.'},
    {'emotion': 'disgust',
     'definition': 'The feeling when you encounter something that you don’t want to get into contact with in any way (neither see, hear, feel, smell, or taste it), because you expect it is bad for you. You want to get it away from you.'},
    {'emotion': 'embarrassment',
     'definition': 'The feeling when people suddenly focus unwanted attention on you in a situation that is not in your control. You have the urge to get away from the attention.'},
    {'emotion': 'excitement',
     'definition': 'The feeling when you expect something good or nice will happen to you. You cannot wait for it to happen.'},
    {'emotion': 'fear',
     'definition': 'The feeling when you encounter or think about a thing or person that can harm you. You have the urge to avoid or get away from the threat.'},
    {'emotion': 'gratitude',
     'definition': 'The feeling when you think that someone has gone out of their way to do something good or nice for you. You have the urge to do something back and get closer to this person.'},
    {'emotion': 'grief',
     'definition': 'The feeling when you have lost something or someone that was important to you. You have the urge to express your sadness and cry.'},
    {'emotion': 'joy',
     'definition': 'The feeling when you are happy. You have the urge to smile and be friendly to others.'},
    {'emotion': 'love',
     'definition': 'The feeling when you care deeply about someone or something. You have the urge to get closer to this person or thing.'},
    {'emotion': 'nervousness',
     'definition': 'The feeling when you have to do something, but you think that something might go wrong that prevents you from succeeding. You don’t feel in control of the situation.'},
    {'emotion': 'optimism',
     'definition': 'The feeling when you think that something good or nice will happen to you. You have the urge to be positive and look forward to it.'},
    {'emotion': 'pride',
     'definition': 'The feeling when you possess or have accomplished something that other people find praiseworthy. You feel vigorous and have the urge to show off to others.'},
    {'emotion': 'realization',
     'definition': 'The feeling when you suddenly understand something that you did not understand before. You have the urge to act on this new understanding.'},
    {'emotion': 'relief',
     'definition': 'The feeling when an unpleasant experience is finally over, or when you find out that something you had dreaded has not happened (or will not happen). You can finally take your mind off it.'},
    {'emotion': 'remorse',
     'definition': 'The feeling when you have done something wrong and you feel sorry about it. You have the urge to apologize and make amends.'},
    {'emotion': 'sadness',
     'definition': 'The feeling when you lost something that was important to you. You have the urge to withdraw and to seek comfort.'},
    {'emotion': 'surprise',
     'definition': 'The feeling when something unexpected happens. You have the urge to pay attention to it and to find out more about it.'},
    {'emotion': 'neutral', 'definition': 'The feeling when you don’t feel any particular emotion.'}
]

In [20]:
emotions_defitions_df = pd.DataFrame([{**d, 'label': labels.index(d['emotion'])} for d in defintions])
emotions_defitions_df

Unnamed: 0,emotion,definition,label
0,admiration,The feeling when you look up to someone who ha...,0
1,amusement,The feeling when you encounter something silly...,1
2,anger,The feeling when someone did something bad tha...,2
3,annoyance,The feeling when something is happening that b...,3
4,approval,The feeling when you agree with or accept some...,4
5,caring,The feeling when you are concerned about someo...,5
6,confusion,The feeling when you get information that does...,6
7,curiosity,The feeling when you want to know more about s...,7
8,desire,The feeling when you want something. You have ...,8
9,disappointment,The feeling when something you hoped for did n...,9


In [21]:
emotions_defitions_df['text'] = emotions_defitions_df['emotion'] + ' - ' + emotions_defitions_df['definition']

definitions_dataset = DatasetDict({
    'definitions': Dataset.from_pandas(emotions_defitions_df[['text', 'label']]),
})
definitions_dataset

DatasetDict({
    definitions: Dataset({
        features: ['text', 'label'],
        num_rows: 28
    })
})

In [22]:
tokenized_definitions_dataset = definitions_dataset.map(
    lambda x: tokenizer(x['text'], truncation=True, padding='max_length', max_length=MAX_SEQ_LEN,
                        return_tensors='pt'),
    batched=True)

tokenized_definitions_dataset.set_format('torch', device=device)

tokenized_definitions_dataset

Map:   0%|          | 0/28 [00:00<?, ? examples/s]

DatasetDict({
    definitions: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 28
    })
})

In [23]:
definitions_dataloader = torch.utils.data.DataLoader(tokenized_definitions_dataset['definitions'],
                                                     batch_size=BATCH_SIZE, shuffle=False)

In [24]:
def cosine_similarity(a, b):
    return np.dot(a, b) / ((np.dot(a, a) ** 0.5) * (np.dot(b, b) ** 0.5))


def euclid_similarity(a, b):
    return np.linalg.norm(np.array(a) - np.array(b))


def jaccard_similarity(a, b):
    return (np.array(a) & np.array(b)).sum() / (np.array(a) | np.array(b)).sum()

In [25]:
def eval_metrics(knowledge_df, explained_df):
    metrics_dict = {
        'jaccard': [],
        'mean_cosine': [],
        'min_cosine': [],
        'mean_euclid': [],
        'max_euclid': []
    }

    # explained - is reactions with their predicted neighbors
    # knowledge - is definitions

    for explained_record in explained_df.to_dict('records'):
        if len(explained_record['neigh_labels']) == 0:
            print('No neighbors for this tweet')
            continue

        cosines = []
        euclids = []
        for neigh_embed in explained_record['neigh_embeds']:
            cosines.append(cosine_similarity(explained_record['embeds'], neigh_embed))
            euclids.append(euclid_similarity(explained_record['embeds'], neigh_embed))

        metrics_dict['mean_cosine'].append(np.mean(cosines))
        metrics_dict['min_cosine'].append(np.min(cosines))
        metrics_dict['mean_euclid'].append(np.mean(euclids))
        metrics_dict['max_euclid'].append(np.max(euclids))

        metrics_dict['jaccard'].append(
            jaccard_similarity(explained_record['labels'], explained_record['neigh_labels_oh']))

    return metrics_dict

## KNN Model

In [26]:
@eval
def create_space_knowledge_embeddings(model, dataloader):
    knowledge_embeds, knowledge_texts, knowledge_labels = [], [], []
    with torch.no_grad():
        for step, batch in enumerate(tqdm(dataloader, total=len(dataloader))):
            ids = batch['input_ids'].to(device, dtype=torch.long)
            mask = batch['attention_mask'].to(device, dtype=torch.long)

            embed = model.base_model(input_ids=ids, attention_mask=mask).last_hidden_state  # (B, seq_len, 768)

            projected = model.space_model(embed)  # (B, n_concept_spaces * n_latent)

            knowledge_embeds += projected.logits.detach().cpu().tolist()
            knowledge_texts += batch['text']
            # we need this for further evaluation
            knowledge_labels += [d.item() for d in batch['label']]
    return {'embeds': knowledge_embeds, 'texts': knowledge_texts, 'labels': knowledge_labels}


@eval
def create_base_knowledge_embeddings(model, dataloader):
    knowledge_embeds, knowledge_texts, knowledge_labels = [], [], []
    with torch.no_grad():
        for step, batch in enumerate(tqdm(dataloader, total=len(dataloader))):
            ids = batch['input_ids'].to(device, dtype=torch.long)
            mask = batch['attention_mask'].to(device, dtype=torch.long)

            projected = model.bert(input_ids=ids, attention_mask=mask)  # (B, seq_len, 768)

            # knowledge_embeds += projected.pooler_output.detach().cpu().tolist()
            knowledge_embeds += projected.last_hidden_state[:, 0, :].detach().cpu().tolist()
            knowledge_texts += batch['text']
            # we need this for further evaluation
            knowledge_labels += [d.item() for d in batch['label']]
    return {'embeds': knowledge_embeds, 'texts': knowledge_texts, 'labels': knowledge_labels}

In [27]:
@eval
def eval_space_knowledge_embeds(model, knn, dataloader, knowledge_dict):
    explained_embeds, explained_texts, explained_labels = [], [], []
    neigh_explained_texts, neigh_explained_labels = [], []
    neigh_explained_embeds = []
    with torch.no_grad():
        for step, batch in enumerate(tqdm(dataloader, total=len(dataloader))):
            ids = batch['input_ids'].to(device, dtype=torch.long)
            mask = batch['attention_mask'].to(device, dtype=torch.long)

            embed = model.base_model(input_ids=ids, attention_mask=mask).last_hidden_state  # (B, seq_len, 768)

            projected = model.space_model(embed)  # (B, n_concept_spaces * n_latent)

            raw_embeds = projected.logits.detach().cpu().tolist()  # (B, 768)
            neighbors_ids = knn.kneighbors(raw_embeds, return_distance=False)  # (B, k), k neighbors ids for each sample

            k_neigh_texts = [[knowledge_dict['texts'][q] for q in neigh] for neigh in neighbors_ids]  # (B, k)
            k_neigh_embeds = [[knowledge_dict['embeds'][q] for q in neigh] for neigh in neighbors_ids]  # (B, k, 768)
            k_neigh_labels = [[knowledge_dict['labels'][q] for q in neigh] for neigh in neighbors_ids]  # (B, k)

            explained_embeds += raw_embeds  # (t, 768)
            explained_texts += batch['text']  # (t)
            explained_labels += batch['label'].cpu().tolist()  # (t)
            neigh_explained_texts += k_neigh_texts  # (t, k)
            neigh_explained_embeds += k_neigh_embeds  # (t, k, 768)
            neigh_explained_labels += k_neigh_labels  # (t, k)

    return {
        'embeds': explained_embeds,
        'texts': explained_texts,
        'labels': explained_labels,
        'neigh_texts': neigh_explained_texts,
        'neigh_embeds': neigh_explained_embeds,
        'neigh_labels': neigh_explained_labels
    }


@eval
def eval_base_knowledge_embeds(model, knn, dataloader, knowledge_dict):
    explained_embeds, explained_texts, explained_labels = [], [], []
    neigh_explained_texts, neigh_explained_labels = [], []
    neigh_explained_embeds = []
    with torch.no_grad():
        for step, batch in enumerate(tqdm(dataloader, total=len(dataloader))):
            ids = batch['input_ids'].to(device, dtype=torch.long)
            mask = batch['attention_mask'].to(device, dtype=torch.long)

            projected = model.bert(input_ids=ids, attention_mask=mask)  # (B, seq_len, 768)

            # raw_embeds = projected.pooler_output.detach().cpu().tolist()  # (B, 768)
            raw_embeds = projected.last_hidden_state[:, 0, :].detach().cpu().tolist()  # (B, 768)
            neighbors_ids = knn.kneighbors(raw_embeds, return_distance=False)  # (B, k), k neighbors ids for each sample

            k_neigh_texts = [[knowledge_dict['texts'][q] for q in neigh] for neigh in neighbors_ids]  # (B, k)
            k_neigh_embeds = [[knowledge_dict['embeds'][q] for q in neigh] for neigh in neighbors_ids]  # (B, k, 768)
            k_neigh_labels = [[knowledge_dict['labels'][q] for q in neigh] for neigh in neighbors_ids]  # (B, k)

            explained_embeds += raw_embeds  # (t, 768)
            explained_texts += batch['text']  # (t)
            explained_labels += batch['label'].cpu().tolist()  # (t)
            neigh_explained_texts += k_neigh_texts  # (t, k)
            neigh_explained_embeds += k_neigh_embeds  # (t, k, 768)
            neigh_explained_labels += k_neigh_labels  # (t, k)

    return {
        'embeds': explained_embeds,
        'texts': explained_texts,
        'labels': explained_labels,
        'neigh_texts': neigh_explained_texts,
        'neigh_embeds': neigh_explained_embeds,
        'neigh_labels': neigh_explained_labels
    }

In [28]:
class BertForMultilabelOutput:
    def __init__(self, loss, logits):
        self.loss = loss
        self.logits = logits


class BertForMultilabelClassification(torch.nn.Module):
    def __init__(self, model, num_labels):
        super(BertForMultilabelClassification, self).__init__()
        self.num_labels = num_labels

        self.bert = model
        self.device = model.device

        self.dropout = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(768, num_labels)

    def to(self, device):
        self.device = device
        super().to(device)
        return self

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # pooled_output = outputs.pooler_output
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = F.binary_cross_entropy_with_logits(logits.view(-1, self.num_labels),
                                                      labels.view(-1, self.num_labels).float())

        return BertForMultilabelOutput(loss, logits)

In [29]:
def eval_explain_models(log, models_names, def_dataloader, test_dataloader, config):
    explained_dicts = {}
    for model_params in models_names:
        model_name = f'{config["dataset_name"]}-{model_params["name"].replace("/", "_")}-{config["num_epochs"]}'
        if model_params['space']:
            model_name = f'{config["dataset_name"]}_space-{model_params["name"].replace("/", "_")}-({model_params["n_latent"]})_{config["num_epochs"]}'

        log.debug(f'Loading model {model_name}', terminal=True)
        if model_params['space']:
            base_model = AutoModel.from_pretrained(model_params["name"]).to(device)
            model = SpaceModelForMultiLabelClassification(
                base_model,
                n_embed=768,
                n_latent=model_params['n_latent'],
                n_concept_spaces=config['num_labels'],
                l1=config['l1'],
                l2=config['l2'],
                ce_w=config['ce_w'],
                fine_tune=True
            ).to(device)
        else:
            raw_model_base = AutoModel.from_pretrained(model_params['name']).to(device)
            model = BertForMultilabelClassification(raw_model_base, config['num_labels']).to(device)

        model.load_state_dict(torch.load(f'models/{config["experiment_name"]}/{model_name}.bin'))

        knowledge_dict = model_params['create_knowledge_embeddings'](model, def_dataloader)

        explained_dicts[model_name] = {}

        for k in [1, 3, 5]:
            log.debug(f'Calculating Top-{k} for {model_name}', terminal=True)

            knn = NearestNeighbors(n_neighbors=k, metric=config['dist'])
            knn.fit(knowledge_dict['embeds'])

            explained_dict = model_params['eval_knowledge_embeds'](model, knn, test_dataloader, knowledge_dict)
            knowledge_df = pd.DataFrame(knowledge_dict)
            explained_df = pd.DataFrame(explained_dict)

            explained_df['neigh_labels_oh'] = explained_df['neigh_labels'].apply(
                lambda x: [1 if i in x else 0 for i in range(NUM_LABELS)]
            )

            metrics_dict = eval_metrics(knowledge_df, explained_df)

            cum_metrics_dict = {k: np.mean(v) for k, v in metrics_dict.items()}
            cum_metrics_dict
            cum_metrics_dict['accuracy'] = f1_score(explained_df['labels'].tolist(),
                                                    explained_df['neigh_labels_oh'].tolist(),
                                                    average='micro')
            cum_metrics_dict['f1_score'] = f1_score(explained_df['labels'].tolist(),
                                                    explained_df['neigh_labels_oh'].tolist(),
                                                    average='macro')
            cum_metrics_dict['precision'] = precision_score(explained_df['labels'].tolist(),
                                                            explained_df['neigh_labels_oh'].tolist(), average='macro')
            cum_metrics_dict['recall'] = recall_score(explained_df['labels'].tolist(),
                                                      explained_df['neigh_labels_oh'].tolist(),
                                                      average='macro')

            explained_dicts[model_name][f'Top-{k}'] = {
                'knowledge': knowledge_df,
                'explained': explained_df,
                'metrics': cum_metrics_dict
            }

            log.info(f'{model_name} - Top-{k} - {cum_metrics_dict}', terminal=True)
            for k, v in cum_metrics_dict.items():
                log.info(f'{k} - {v}', terminal=True)

    return explained_dicts

In [30]:
config = {
    'experiment_name': 'default',
    'log_terminal': True,

    'dataset_name': DATASET_NAME,
    'model_name': MODEL_NAME,

    'num_labels': NUM_LABELS,
    'iterations': 1,
    'num_epochs': 50,

    'max_seq_len': MAX_SEQ_LEN,
    'batch_size': BATCH_SIZE,
    'fp16': False,
    'weight_decay': 0.01,
    'num_warmup_steps': 0,
    'gradient_accumulation_steps': 1,

    'cross_entropy_weight': 1.0,
    'l1': 0.1,
    'l2': 1e-5,
    'ce_w': 1.0,

    'dist': 'cosine',
}

In [31]:
log = get_logger(f'logs/{config["experiment_name"]}', 'eval-explain')

In [32]:
test_dataloader = torch.utils.data.DataLoader(tokenized_dataset['emotions_test'], batch_size=config['batch_size'])

In [33]:
eval_explain_models(log, [
    {'space': True, 'name': 'distilbert-base-cased', 'n_latent': 64,
     'create_knowledge_embeddings': create_space_knowledge_embeddings,
     'eval_knowledge_embeds': eval_space_knowledge_embeds},
    {'space': True, 'name': 'distilbert-base-cased', 'n_latent': 128,
     'create_knowledge_embeddings': create_space_knowledge_embeddings,
     'eval_knowledge_embeds': eval_space_knowledge_embeds},
    {'space': False, 'name': 'FacebookAI/roberta-base',
     'create_knowledge_embeddings': create_base_knowledge_embeddings,
     'eval_knowledge_embeds': eval_base_knowledge_embeds},
    {'space': True, 'name': 'FacebookAI/roberta-base', 'n_latent': 3,
     'create_knowledge_embeddings': create_space_knowledge_embeddings,
     'eval_knowledge_embeds': eval_space_knowledge_embeds},
    {'space': True, 'name': 'FacebookAI/roberta-base', 'n_latent': 64,
     'create_knowledge_embeddings': create_space_knowledge_embeddings,
     'eval_knowledge_embeds': eval_space_knowledge_embeds},
    {'space': True, 'name': 'FacebookAI/roberta-base', 'n_latent': 128,
     'create_knowledge_embeddings': create_space_knowledge_embeddings,
     'eval_knowledge_embeds': eval_space_knowledge_embeds},
], definitions_dataloader, test_dataloader, config)

[90m2024-03-05 16:24:24,346 - default.terminal - DEBUG - Loading model go-emotions_space-distilbert-base-cased-(64)_50[0m[0m
100%|██████████| 1/1 [00:01<00:00,  1.18s/it]
[90m2024-03-05 16:24:31,570 - default.terminal - DEBUG - Calculating Top-1 for go-emotions_space-distilbert-base-cased-(64)_50[0m[0m
100%|██████████| 83/83 [01:19<00:00,  1.05it/s]
[36m2024-03-05 16:26:09,900 - default.terminal - INFO - go-emotions_space-distilbert-base-cased-(64)_50 - Top-1 - {'jaccard': 0.04692210927830867, 'mean_cosine': 0.7972908612630358, 'min_cosine': 0.7972908612630358, 'mean_euclid': 10.064698518810802, 'max_euclid': 10.064698518810802, 'accuracy': 0.055699060105493935, 'f1_score': 0.036377239060797206, 'precision': 0.13520769920135128, 'recall': 0.05165395492749796}[0m[0m
[36m2024-03-05 16:26:09,902 - default.terminal - INFO - jaccard - 0.04692210927830867[0m[0m
[36m2024-03-05 16:26:09,902 - default.terminal - INFO - mean_cosine - 0.7972908612630358[0m[0m
[36m2024-03-05 16:26:

{'go-emotions_space-distilbert-base-cased-(64)_50': {'Top-1': {'knowledge':                                                embeds  \
   0   [0.9300746917724609, 0.6807864904403687, 0.685...   
   1   [0.4744398593902588, 0.11751697957515717, 0.21...   
   2   [-0.07783995568752289, -0.28507333993911743, -...   
   3   [-0.15948240458965302, -0.020682241767644882, ...   
   4   [0.26443642377853394, 0.08171886950731277, 0.1...   
   5   [-0.1002228707075119, 0.11318329721689224, 0.1...   
   6   [-0.2496984899044037, -0.2144278734922409, -0....   
   7   [-0.019394587725400925, 0.2393021285533905, 0....   
   8   [0.10110965371131897, 0.19305704534053802, 0.2...   
   9   [-0.10128438472747803, 0.03504534065723419, 0....   
   10  [0.06136074662208557, -0.25591760873794556, -0...   
   11  [0.1158335879445076, -0.335039883852005, -0.27...   
   12  [0.119273841381073, -0.21713218092918396, -0.1...   
   13  [0.17276746034622192, 0.21402201056480408, 0.2...   
   14  [-0.2301114201545715