In [2]:
! rm -rf space-model
! git clone https://github.com/StepanTita/space-model.git

Cloning into 'space-model'...
remote: Enumerating objects: 176, done.[K
remote: Counting objects: 100% (176/176), done.[K
remote: Compressing objects: 100% (145/145), done.[K
remote: Total 176 (delta 54), reused 147 (delta 28), pack-reused 0[K
Receiving objects: 100% (176/176), 32.11 MiB | 26.84 MiB/s, done.
Resolving deltas: 100% (54/54), done.


In [2]:
! pip install transformers plotly datasets



In [1]:
import sys

sys.path.append('space-model')

In [2]:
import math
import json
from collections import Counter
import random
import os

import pandas as pd
import numpy as np

import torch
import torch.nn.functional as F

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, jaccard_score
from sklearn.model_selection import train_test_split

from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors

from tqdm import tqdm

import matplotlib.pyplot as plt
import plotly.graph_objects as go

from transformers import AutoModel, AutoTokenizer, AutoModelForSequenceClassification
from transformers import DataCollatorWithPadding

from datasets import load_dataset, Dataset, DatasetDict

from space_model.model import *
from space_model.loss import *

from logger import get_logger
from train import training, eval_results, plot_results, eval, eval_epoch

In [3]:
SEED = 42

In [4]:
def seed_everything(seed=42):
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True


seed_everything(seed=SEED)

In [5]:
def on_gpu(f):
    def wrapper(*args):
        if torch.cuda.is_available():
            return f(*args)
        else:
            print('cuda unavailable')

    return wrapper

In [6]:
if torch.cuda.is_available():
    ! pip install pynvml
    from pynvml import *
    from numba import cuda


@on_gpu
def print_gpu_utilization(dev_id):
    try:
        nvmlInit()
        handle = nvmlDeviceGetHandleByIndex(dev_id)
        info = nvmlDeviceGetMemoryInfo(handle)
        print(f"GPU memory occupied: {info.used // 1024 ** 2} MB.")
    except Exception as e:
        print(e)


@on_gpu
def free_gpu_cache(dev_id=0):
    print("Initial GPU Usage")
    print_gpu_utilization(dev_id)

    torch.cuda.empty_cache()

    print("GPU Usage after emptying the cache")
    print_gpu_utilization(dev_id)


def print_summary(result):
    print(f"Time: {result.metrics['train_runtime']:.2f}")
    print(f"Samples/second: {result.metrics['train_samples_per_second']:.2f}")
    print_gpu_utilization()



In [7]:
device_id = 0

In [8]:
device = torch.device(f'cuda:{device_id}' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda', index=0)

In [9]:
if torch.cuda.is_available():
    torch.cuda.set_device(device)

In [10]:
MODEL_NAME = 'distilbert-base-cased'

DATASET_NAME = 'go-emotions'

NUM_LABELS = 28
N_LATENT = 64

NUM_EPOCHS = 50
BATCH_SIZE = 256
MAX_SEQ_LEN = 512
LEARNING_RATE = 2e-4
MAX_GRAD_NORM = 1000

In [11]:
emotions_1_df = pd.read_csv(f'data/goemotions_1.csv')
emotions_2_df = pd.read_csv(f'data/goemotions_2.csv')
emotions_3_df = pd.read_csv(f'data/goemotions_3.csv')

emotions_df = pd.concat([
    emotions_1_df,
    emotions_2_df,
    emotions_3_df
], ignore_index=True, axis=0)
emotions_df

Unnamed: 0,text,id,author,subreddit,link_id,parent_id,created_utc,rater_id,example_very_unclear,admiration,...,love,nervousness,optimism,pride,realization,relief,remorse,sadness,surprise,neutral
0,That game hurt.,eew5j0j,Brdd9,nrl,t3_ajis4z,t1_eew18eq,1.548381e+09,1,False,0,...,0,0,0,0,0,0,0,1,0,0
1,>sexuality shouldn’t be a grouping category I...,eemcysk,TheGreen888,unpopularopinion,t3_ai4q37,t3_ai4q37,1.548084e+09,37,True,0,...,0,0,0,0,0,0,0,0,0,0
2,"You do right, if you don't care then fuck 'em!",ed2mah1,Labalool,confessions,t3_abru74,t1_ed2m7g7,1.546428e+09,37,False,0,...,0,0,0,0,0,0,0,0,0,1
3,Man I love reddit.,eeibobj,MrsRobertshaw,facepalm,t3_ahulml,t3_ahulml,1.547965e+09,18,False,0,...,1,0,0,0,0,0,0,0,0,0
4,"[NAME] was nowhere near them, he was by the Fa...",eda6yn6,American_Fascist713,starwarsspeculation,t3_ackt2f,t1_eda65q2,1.546669e+09,2,False,0,...,0,0,0,0,0,0,0,0,0,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
211220,Everyone likes [NAME].,ee6pagw,Senshado,heroesofthestorm,t3_agjf24,t3_agjf24,1.547634e+09,16,False,0,...,1,0,0,0,0,0,0,0,0,0
211221,Well when you’ve imported about a gazillion of...,ef28nod,5inchloser,nottheonion,t3_ak26t3,t3_ak26t3,1.548553e+09,15,False,0,...,0,0,0,0,0,0,0,0,0,0
211222,That looks amazing,ee8hse1,springt1me,shittyfoodporn,t3_agrnqb,t3_agrnqb,1.547684e+09,70,False,1,...,0,0,0,0,0,0,0,0,0,0
211223,The FDA has plenty to criticize. But like here...,edrhoxh,enamedata,medicine,t3_aejqzd,t1_edrgdtx,1.547169e+09,4,False,0,...,0,0,0,0,0,0,0,0,0,0


In [12]:
labels = [
    'admiration',
    'amusement',
    'anger',
    'annoyance',
    'approval',
    'caring',
    'confusion',
    'curiosity',
    'desire',
    'disappointment',
    'disapproval',
    'disgust',
    'embarrassment',
    'excitement',
    'fear',
    'gratitude',
    'grief',
    'joy',
    'love',
    'nervousness',
    'optimism',
    'pride',
    'realization',
    'relief',
    'remorse',
    'sadness',
    'surprise',
    'neutral'
]

In [13]:
emotions_df['label'] = emotions_df[labels].apply(lambda x: x.to_list(), axis=1)
emotions_df

Unnamed: 0,text,id,author,subreddit,link_id,parent_id,created_utc,rater_id,example_very_unclear,admiration,...,nervousness,optimism,pride,realization,relief,remorse,sadness,surprise,neutral,label
0,That game hurt.,eew5j0j,Brdd9,nrl,t3_ajis4z,t1_eew18eq,1.548381e+09,1,False,0,...,0,0,0,0,0,0,1,0,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,>sexuality shouldn’t be a grouping category I...,eemcysk,TheGreen888,unpopularopinion,t3_ai4q37,t3_ai4q37,1.548084e+09,37,True,0,...,0,0,0,0,0,0,0,0,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,"You do right, if you don't care then fuck 'em!",ed2mah1,Labalool,confessions,t3_abru74,t1_ed2m7g7,1.546428e+09,37,False,0,...,0,0,0,0,0,0,0,0,1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,Man I love reddit.,eeibobj,MrsRobertshaw,facepalm,t3_ahulml,t3_ahulml,1.547965e+09,18,False,0,...,0,0,0,0,0,0,0,0,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,"[NAME] was nowhere near them, he was by the Fa...",eda6yn6,American_Fascist713,starwarsspeculation,t3_ackt2f,t1_eda65q2,1.546669e+09,2,False,0,...,0,0,0,0,0,0,0,0,1,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
211220,Everyone likes [NAME].,ee6pagw,Senshado,heroesofthestorm,t3_agjf24,t3_agjf24,1.547634e+09,16,False,0,...,0,0,0,0,0,0,0,0,0,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
211221,Well when you’ve imported about a gazillion of...,ef28nod,5inchloser,nottheonion,t3_ak26t3,t3_ak26t3,1.548553e+09,15,False,0,...,0,0,0,0,0,0,0,0,0,"[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
211222,That looks amazing,ee8hse1,springt1me,shittyfoodporn,t3_agrnqb,t3_agrnqb,1.547684e+09,70,False,1,...,0,0,0,0,0,0,0,0,0,"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
211223,The FDA has plenty to criticize. But like here...,edrhoxh,enamedata,medicine,t3_aejqzd,t1_edrgdtx,1.547169e+09,4,False,0,...,0,0,0,0,0,0,0,0,0,"[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [14]:
train_split, test_split = train_test_split(emotions_df, test_size=0.2, random_state=SEED)
test_split, val_split = train_test_split(test_split, test_size=0.5, random_state=SEED)

In [15]:
dataset = DatasetDict({
    'emotions_train': Dataset.from_pandas(train_split[['text', 'label']]),
    'emotions_val': Dataset.from_pandas(test_split[['text', 'label']]),
    'emotions_test': Dataset.from_pandas(val_split[['text', 'label']]),
})
dataset

DatasetDict({
    emotions_train: Dataset({
        features: ['text', 'label', '__index_level_0__'],
        num_rows: 168980
    })
    emotions_val: Dataset({
        features: ['text', 'label', '__index_level_0__'],
        num_rows: 21122
    })
    emotions_test: Dataset({
        features: ['text', 'label', '__index_level_0__'],
        num_rows: 21123
    })
})

In [16]:
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
tokenizer

DistilBertTokenizerFast(name_or_path='distilbert-base-cased', vocab_size=28996, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True),  added_tokens_decoder={
	0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	100: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	101: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	102: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
	103: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True),
}

In [17]:
tokenized_dataset = dataset
tokenized_dataset = tokenized_dataset.map(
    lambda x: tokenizer(x['text'], truncation=True, padding='max_length', max_length=MAX_SEQ_LEN,
                        return_tensors='pt'), batched=True)
tokenized_dataset.set_format('torch', device=device)
tokenized_dataset

Map:   0%|          | 0/168980 [00:00<?, ? examples/s]

Map:   0%|          | 0/21122 [00:00<?, ? examples/s]

Map:   0%|          | 0/21123 [00:00<?, ? examples/s]

DatasetDict({
    emotions_train: Dataset({
        features: ['text', 'label', '__index_level_0__', 'input_ids', 'attention_mask'],
        num_rows: 168980
    })
    emotions_val: Dataset({
        features: ['text', 'label', '__index_level_0__', 'input_ids', 'attention_mask'],
        num_rows: 21122
    })
    emotions_test: Dataset({
        features: ['text', 'label', '__index_level_0__', 'input_ids', 'attention_mask'],
        num_rows: 21123
    })
})

In [18]:
raw_model_base = AutoModel.from_pretrained(MODEL_NAME).to(device)
raw_model_base

DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Li

In [19]:
class BertForMultilabelOutput:
    def __init__(self, loss, logits):
        self.loss = loss
        self.logits = logits


class BertForMultilabelClassification(torch.nn.Module):
    def __init__(self, model, num_labels):
        super(BertForMultilabelClassification, self).__init__()
        self.num_labels = num_labels

        self.bert = model
        self.device = model.device

        self.dropout = torch.nn.Dropout(0.1)
        self.classifier = torch.nn.Linear(768, num_labels)

    def to(self, device):
        self.device = device
        super().to(device)
        return self

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # pooled_output = outputs.pooler_output
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss = F.binary_cross_entropy_with_logits(logits.view(-1, self.num_labels),
                                                      labels.view(-1, self.num_labels).float())

        return BertForMultilabelOutput(loss, logits)

In [20]:
raw_model = BertForMultilabelClassification(raw_model_base, NUM_LABELS).to(device)

In [21]:
raw_model.load_state_dict(torch.load(f'models/{DATASET_NAME}_{MODEL_NAME}_{NUM_EPOCHS}.bin'))

<All keys matched successfully>

In [18]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

In [23]:
for p in raw_model.bert.parameters():
    p.requires_grad = False
# for p in raw_model.bert.parameters():
#     p.requires_grad = False

In [24]:
count_parameters(raw_model)

21532

In [19]:
def get_preds_from_logits(outputs):
    return outputs.loss, (outputs.logits >= 0).long(), outputs.logits

In [20]:
config = {
    'experiment_name': 'default',
    'log_terminal': True,

    'dataset_name': DATASET_NAME,
    'model_name': MODEL_NAME,

    'num_labels': NUM_LABELS,
    'num_epochs': NUM_EPOCHS,
    'iterations': 1,

    'max_seq_len': MAX_SEQ_LEN,
    'batch_size': BATCH_SIZE,
    'lr': LEARNING_RATE,
    'fp16': False,
    'max_grad_norm': MAX_GRAD_NORM,
    'weight_decay': 0.01,
    'num_warmup_steps': 0,
    'gradient_accumulation_steps': 1,

    # funcs:
    'preds_from_logits_func': get_preds_from_logits
}

In [21]:
base_name = f'{DATASET_NAME}-{MODEL_NAME}-{NUM_EPOCHS}'

In [28]:
log = get_logger(f'logs/{config["experiment_name"]}', base_name)

In [None]:
# base_history = training(raw_model, tokenized_dataset['emotions_train'], tokenized_dataset['emotions_val'], log, config)

[90m2024-02-17 22:31:13,992 - default.terminal - DEBUG - Train steps: 33003[0m[0m
[90m2024-02-17 22:31:13,994 - default.terminal - DEBUG - Steps per epoch: 660.078125[0m[0m
[36m2024-02-17 22:31:13,995 - default.terminal - INFO - Epoch: 1[0m[0m
100%|██████████| 661/661 [13:15<00:00,  1.20s/it]
100%|██████████| 83/83 [01:33<00:00,  1.13s/it]
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
[36m2024-02-17 22:46:13,754 - default.terminal - INFO - [0m[0m
[36m2024-02-17 22:46:13,755 - default.terminal - INFO - Train loss: 0.18553323121727325 | Val loss: 0.1526179112583758[0m[0m
[36m2024-02-17 22:46:13,756 - default.terminal - INFO - Train acc: 0.015700082850041423 | Val acc: 0.016144304516617743[0m[0m
[36m2024-02-17 22:46:13,757 - default.terminal - INFO - Train f1: 0.006716014946861905 | Val f1: 1.2810002049600328e-05[0m[0m
[36m2024-02-17 22:46:13,757 - default.terminal - INFO - Train precision: 0.03822812250

In [None]:
plot_results(log, base_history, plot_name='base_history')

In [22]:
def save_model(model, config, name):
    if not os.path.exists(f'models/{config["experiment_name"]}'):
        os.makedirs(f'models/{config["experiment_name"]}', exist_ok=True)
    full_model_path = f'models/{config["experiment_name"]}/{name}.bin'
    torch.save(model.state_dict(), full_model_path)
    return full_model_path

In [33]:
# save_model(raw_model, config, base_name)

'models/go-emotions_distilbert-base-cased_50/go-emotions_distilbert-base-cased_50.bin'

In [23]:
test_dataloader = torch.utils.data.DataLoader(tokenized_dataset['emotions_test'], batch_size=config['batch_size'])

In [31]:
eval_results(log, base_name, raw_model, test_dataloader, config)

100%|██████████| 83/83 [01:33<00:00,  1.13s/it]
  _warn_prf(average, modifier, msg_start, len(result))
[36m2024-02-18 18:33:52,030 - default.terminal - INFO - Val loss: 0.1306661101708929[0m[0m
[36m2024-02-18 18:33:52,031 - default.terminal - INFO - Val acc: 0.070681247928798[0m[0m
[36m2024-02-18 18:33:52,032 - default.terminal - INFO - Val f1: 0.059824949199795176[0m[0m
[36m2024-02-18 18:33:52,033 - default.terminal - INFO - Val precision: 0.32909000303780583[0m[0m
[36m2024-02-18 18:33:52,033 - default.terminal - INFO - Val recall: 0.038825854239911625[0m[0m


In [None]:
# source:
# https://emotiontypology.com/positive_emotion/positivesurprise/

In [24]:
defintions = [
    {'emotion': 'admiration',
     'definition': 'The feeling when you look up to someone who has excellent abilities or has accomplished impressive things. You have the urge to also achieve such things and be more like this person.'},
    {'emotion': 'amusement',
     'definition': 'The feeling when you encounter something silly, ironic, witty, or absurd, which makes you laugh. You have the urge to be playful and share the joke with others.'},
    {'emotion': 'anger',
     'definition': 'The feeling when someone did something bad that harmed or offended you. You want to go against this person to stop them or prevent them from doing it again.'},
    {'emotion': 'annoyance',
     'definition': 'The feeling when something is happening that bothers you. You have the urge to say or do something to change it or make it stop.'},
    {'emotion': 'approval',
     'definition': 'The feeling when you agree with or accept something. You have the urge to support or encourage it.'},
    {'emotion': 'caring',
     'definition': 'The feeling when you are concerned about someone or something. You have the urge to help or protect them.'},
    {'emotion': 'confusion',
     'definition': 'The feeling when you get information that does not make sense to you, leaving you uncertain what to do with it.'},
    {'emotion': 'curiosity',
     'definition': 'The feeling when you want to know more about something. You have the urge to explore and learn.'},
    {'emotion': 'desire', 'definition': 'The feeling when you want something. You have the urge to get it.'},
    {'emotion': 'disappointment',
     'definition': 'The feeling when something you hoped for did not happen. You have the urge to express your sadness and frustration.'},
    {'emotion': 'disapproval',
     'definition': 'The feeling when you disagree with or dislike something. You have the urge to criticize or oppose it.'},
    {'emotion': 'disgust',
     'definition': 'The feeling when you encounter something that you don’t want to get into contact with in any way (neither see, hear, feel, smell, or taste it), because you expect it is bad for you. You want to get it away from you.'},
    {'emotion': 'embarrassment',
     'definition': 'The feeling when people suddenly focus unwanted attention on you in a situation that is not in your control. You have the urge to get away from the attention.'},
    {'emotion': 'excitement',
     'definition': 'The feeling when you expect something good or nice will happen to you. You cannot wait for it to happen.'},
    {'emotion': 'fear',
     'definition': 'The feeling when you encounter or think about a thing or person that can harm you. You have the urge to avoid or get away from the threat.'},
    {'emotion': 'gratitude',
     'definition': 'The feeling when you think that someone has gone out of their way to do something good or nice for you. You have the urge to do something back and get closer to this person.'},
    {'emotion': 'grief',
     'definition': 'The feeling when you have lost something or someone that was important to you. You have the urge to express your sadness and cry.'},
    {'emotion': 'joy',
     'definition': 'The feeling when you are happy. You have the urge to smile and be friendly to others.'},
    {'emotion': 'love',
     'definition': 'The feeling when you care deeply about someone or something. You have the urge to get closer to this person or thing.'},
    {'emotion': 'nervousness',
     'definition': 'The feeling when you have to do something, but you think that something might go wrong that prevents you from succeeding. You don’t feel in control of the situation.'},
    {'emotion': 'optimism',
     'definition': 'The feeling when you think that something good or nice will happen to you. You have the urge to be positive and look forward to it.'},
    {'emotion': 'pride',
     'definition': 'The feeling when you possess or have accomplished something that other people find praiseworthy. You feel vigorous and have the urge to show off to others.'},
    {'emotion': 'realization',
     'definition': 'The feeling when you suddenly understand something that you did not understand before. You have the urge to act on this new understanding.'},
    {'emotion': 'relief',
     'definition': 'The feeling when an unpleasant experience is finally over, or when you find out that something you had dreaded has not happened (or will not happen). You can finally take your mind off it.'},
    {'emotion': 'remorse',
     'definition': 'The feeling when you have done something wrong and you feel sorry about it. You have the urge to apologize and make amends.'},
    {'emotion': 'sadness',
     'definition': 'The feeling when you lost something that was important to you. You have the urge to withdraw and to seek comfort.'},
    {'emotion': 'surprise',
     'definition': 'The feeling when something unexpected happens. You have the urge to pay attention to it and to find out more about it.'},
    {'emotion': 'neutral', 'definition': 'The feeling when you don’t feel any particular emotion.'}
]

In [25]:
emotions_defitions_df = pd.DataFrame([{**d, 'label': labels.index(d['emotion'])} for d in defintions])
emotions_defitions_df

Unnamed: 0,emotion,definition,label
0,admiration,The feeling when you look up to someone who ha...,0
1,amusement,The feeling when you encounter something silly...,1
2,anger,The feeling when someone did something bad tha...,2
3,annoyance,The feeling when something is happening that b...,3
4,approval,The feeling when you agree with or accept some...,4
5,caring,The feeling when you are concerned about someo...,5
6,confusion,The feeling when you get information that does...,6
7,curiosity,The feeling when you want to know more about s...,7
8,desire,The feeling when you want something. You have ...,8
9,disappointment,The feeling when something you hoped for did n...,9


In [26]:
emotions_defitions_df['text'] = emotions_defitions_df['emotion'] + ' - ' + emotions_defitions_df['definition']

definitions_dataset = DatasetDict({
    'definitions': Dataset.from_pandas(emotions_defitions_df[['text', 'label']]),
})
definitions_dataset

DatasetDict({
    definitions: Dataset({
        features: ['text', 'label'],
        num_rows: 28
    })
})

In [27]:
tokenized_definitions_dataset = definitions_dataset.map(
    lambda x: tokenizer(x['text'], truncation=True, padding='max_length', max_length=MAX_SEQ_LEN,
                        return_tensors='pt'),
    batched=True)

tokenized_definitions_dataset.set_format('torch', device=device)

tokenized_definitions_dataset

Map:   0%|          | 0/28 [00:00<?, ? examples/s]

DatasetDict({
    definitions: Dataset({
        features: ['text', 'label', 'input_ids', 'attention_mask'],
        num_rows: 28
    })
})

In [28]:
@eval
def create_base_knowledge_embeddings(model, dataloader):
    knowledge_embeds, knowledge_texts, knowledge_labels = [], [], []
    with torch.no_grad():
        for step, batch in enumerate(tqdm(dataloader, total=len(dataloader))):
            ids = batch['input_ids'].to(device, dtype=torch.long)
            mask = batch['attention_mask'].to(device, dtype=torch.long)

            projected = model.bert(input_ids=ids, attention_mask=mask)  # (B, seq_len, 768)

            # knowledge_embeds += projected.pooler_output.detach().cpu().tolist()
            knowledge_embeds += projected.last_hidden_state[:, 0, :].detach().cpu().tolist()
            knowledge_texts += batch['text']
            # we need this for further evaluation
            knowledge_labels += [d.item() for d in batch['label']]
    return {'embeds': knowledge_embeds, 'texts': knowledge_texts, 'labels': knowledge_labels}

In [29]:
definitions_dataloader = torch.utils.data.DataLoader(tokenized_definitions_dataset['definitions'],
                                                     batch_size=BATCH_SIZE, shuffle=False)

In [38]:
knowledge_dict = create_base_knowledge_embeddings(raw_model, definitions_dataloader)

100%|██████████| 1/1 [00:00<00:00,  7.03it/s]


In [39]:
len(knowledge_dict['embeds'][0])

768

In [40]:
knn = NearestNeighbors(n_neighbors=1, metric='euclidean')
knn.fit(knowledge_dict['embeds'])

In [30]:
@eval
def eval_base_knowledge_embeds(model, knn, dataloader, knowledge_dict):
    explained_embeds, explained_texts, explained_labels = [], [], []
    neigh_explained_texts, neigh_explained_labels = [], []
    neigh_explained_embeds = []
    with torch.no_grad():
        for step, batch in enumerate(tqdm(dataloader, total=len(dataloader))):
            ids = batch['input_ids'].to(device, dtype=torch.long)
            mask = batch['attention_mask'].to(device, dtype=torch.long)

            projected = model.bert(input_ids=ids, attention_mask=mask)  # (B, seq_len, 768)

            # raw_embeds = projected.pooler_output.detach().cpu().tolist()  # (B, 768)
            raw_embeds = projected.last_hidden_state[:, 0, :].detach().cpu().tolist()  # (B, 768)
            neighbors_ids = knn.kneighbors(raw_embeds, return_distance=False)  # (B, k), k neighbors ids for each sample

            k_neigh_texts = [[knowledge_dict['texts'][q] for q in neigh] for neigh in neighbors_ids]  # (B, k)
            k_neigh_embeds = [[knowledge_dict['embeds'][q] for q in neigh] for neigh in neighbors_ids]  # (B, k, 768)
            k_neigh_labels = [[knowledge_dict['labels'][q] for q in neigh] for neigh in neighbors_ids]  # (B, k)

            explained_embeds += raw_embeds  # (t, 768)
            explained_texts += batch['text']  # (t)
            explained_labels += batch['label'].cpu().tolist()  # (t)
            neigh_explained_texts += k_neigh_texts  # (t, k)
            neigh_explained_embeds += k_neigh_embeds  # (t, k, 768)
            neigh_explained_labels += k_neigh_labels  # (t, k)

    return {
        'embeds': explained_embeds,
        'texts': explained_texts,
        'labels': explained_labels,
        'neigh_texts': neigh_explained_texts,
        'neigh_embeds': neigh_explained_embeds,
        'neigh_labels': neigh_explained_labels
    }

In [42]:
explained_dict = eval_base_knowledge_embeds(raw_model, knn, test_dataloader, knowledge_dict)

100%|██████████| 83/83 [01:35<00:00,  1.16s/it]


In [31]:
def get_indices(lst):
    return [i for i in range(len(lst)) if lst[i] == 1]

In [44]:
explained_dict['texts'][0], get_indices(explained_dict['labels'][0]), explained_dict['neigh_texts'][0], \
explained_dict['neigh_labels'][0]

('Nice try Lumi',
 [4],
 ['excitement - The feeling when you expect something good or nice will happen to you. You cannot wait for it to happen.'],
 [13])

In [45]:
len(explained_dict['embeds'][0])

768

In [46]:
knowledge_df = pd.DataFrame(knowledge_dict)
explained_df = pd.DataFrame(explained_dict)
knowledge_df

Unnamed: 0,embeds,texts,labels
0,"[0.34494632482528687, -0.03852308914065361, -0...",admiration - The feeling when you look up to s...,0
1,"[0.37065282464027405, -0.011883205734193325, 0...",amusement - The feeling when you encounter som...,1
2,"[0.4170636534690857, 0.003493282478302717, -0....",anger - The feeling when someone did something...,2
3,"[0.39650318026542664, 0.07422290742397308, -0....",annoyance - The feeling when something is happ...,3
4,"[0.3548925817012787, -0.027174873277544975, -0...",approval - The feeling when you agree with or ...,4
5,"[0.3770696222782135, 0.010515465401113033, -0....",caring - The feeling when you are concerned ab...,5
6,"[0.3423830270767212, -0.05301979184150696, 0.0...",confusion - The feeling when you get informati...,6
7,"[0.380428284406662, 0.04202545806765556, -0.02...",curiosity - The feeling when you want to know ...,7
8,"[0.39525580406188965, 0.09177958220243454, -0....",desire - The feeling when you want something. ...,8
9,"[0.36203494668006897, 0.09390724450349808, -0....",disappointment - The feeling when something yo...,9


In [47]:
explained_df

Unnamed: 0,embeds,texts,labels,neigh_texts,neigh_embeds,neigh_labels
0,"[0.3434976041316986, -0.03572804480791092, -0....",Nice try Lumi,"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[excitement - The feeling when you expect some...,"[[0.28399282693862915, 0.13156133890151978, -0...",[13]
1,"[0.32336699962615967, 0.19512392580509186, -0....","Thank you, you're awesome","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[joy - The feeling when you are happy. You hav...,"[[0.29260754585266113, 0.035176899284124374, -...",[17]
2,"[0.2534746527671814, 0.05391553044319153, -0.0...",lol at them both getting downvotes,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[excitement - The feeling when you expect some...,"[[0.28399282693862915, 0.13156133890151978, -0...",[13]
3,"[0.4159287214279175, 0.09216415882110596, 0.05...","> the fact that it is illegal Watch out, we've...","[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[excitement - The feeling when you expect some...,"[[0.28399282693862915, 0.13156133890151978, -0...",[13]
4,"[0.4190342128276825, 0.017816435545682907, -0....","I meant we as in the group, sorry if that was ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[neutral - The feeling when you don’t feel any...,"[[0.35577085614204407, -0.10203522443771362, 0...",[27]
...,...,...,...,...,...,...
21118,"[0.4421524107456207, 0.029359780251979828, -0....",You. I like you.,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[joy - The feeling when you are happy. You hav...,"[[0.29260754585266113, 0.035176899284124374, -...",[17]
21119,"[0.4373226761817932, 0.0724048912525177, -0.08...",Well except for keeping it around as it is onl...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[neutral - The feeling when you don’t feel any...,"[[0.35577085614204407, -0.10203522443771362, 0...",[27]
21120,"[0.390005886554718, 0.038477808237075806, -0.1...",Thanks 💙 ice skating is also one of my favouri...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[joy - The feeling when you are happy. You hav...,"[[0.29260754585266113, 0.035176899284124374, -...",[17]
21121,"[0.3210469186306, 0.12851789593696594, -0.0660...",I'm not to proud of them right now. I'm mad at...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[excitement - The feeling when you expect some...,"[[0.28399282693862915, 0.13156133890151978, -0...",[13]


In [48]:
explained_df['neigh_labels_oh'] = explained_df['neigh_labels'].apply(
    lambda x: [1 if i in x else 0 for i in range(NUM_LABELS)])
explained_df

Unnamed: 0,embeds,texts,labels,neigh_texts,neigh_embeds,neigh_labels,neigh_labels_oh
0,"[0.3434976041316986, -0.03572804480791092, -0....",Nice try Lumi,"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[excitement - The feeling when you expect some...,"[[0.28399282693862915, 0.13156133890151978, -0...",[13],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."
1,"[0.32336699962615967, 0.19512392580509186, -0....","Thank you, you're awesome","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[joy - The feeling when you are happy. You hav...,"[[0.29260754585266113, 0.035176899284124374, -...",[17],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,"[0.2534746527671814, 0.05391553044319153, -0.0...",lol at them both getting downvotes,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[excitement - The feeling when you expect some...,"[[0.28399282693862915, 0.13156133890151978, -0...",[13],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."
3,"[0.4159287214279175, 0.09216415882110596, 0.05...","> the fact that it is illegal Watch out, we've...","[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[excitement - The feeling when you expect some...,"[[0.28399282693862915, 0.13156133890151978, -0...",[13],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."
4,"[0.4190342128276825, 0.017816435545682907, -0....","I meant we as in the group, sorry if that was ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[neutral - The feeling when you don’t feel any...,"[[0.35577085614204407, -0.10203522443771362, 0...",[27],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...,...
21118,"[0.4421524107456207, 0.029359780251979828, -0....",You. I like you.,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[joy - The feeling when you are happy. You hav...,"[[0.29260754585266113, 0.035176899284124374, -...",[17],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
21119,"[0.4373226761817932, 0.0724048912525177, -0.08...",Well except for keeping it around as it is onl...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[neutral - The feeling when you don’t feel any...,"[[0.35577085614204407, -0.10203522443771362, 0...",[27],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
21120,"[0.390005886554718, 0.038477808237075806, -0.1...",Thanks 💙 ice skating is also one of my favouri...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[joy - The feeling when you are happy. You hav...,"[[0.29260754585266113, 0.035176899284124374, -...",[17],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
21121,"[0.3210469186306, 0.12851789593696594, -0.0660...",I'm not to proud of them right now. I'm mad at...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[excitement - The feeling when you expect some...,"[[0.28399282693862915, 0.13156133890151978, -0...",[13],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, ..."


In [32]:
def cosine_similarity(a, b):
    return np.dot(a, b) / ((np.dot(a, a) ** 0.5) * (np.dot(b, b) ** 0.5))


def euclid_similarity(a, b):
    return np.linalg.norm(np.array(a) - np.array(b))


def jaccard_similarity(a, b):
    return (np.array(a) & np.array(b)).sum() / (np.array(a) | np.array(b)).sum()

In [33]:
def eval_metrics(knowledge_df, explained_df):
    metrics_dict = {
        'jaccard': [],
        'mean_cosine': [],
        'min_cosine': [],
        'mean_euclid': [],
        'max_euclid': []
    }

    # explained - is reactions with their predicted neighbors
    # knowledge - is definitions

    for explained_record in explained_df.to_dict('records'):
        if len(explained_record['neigh_labels']) == 0:
            print('No neighbors for this tweet')
            continue

        cosines = []
        euclids = []
        for neigh_embed in explained_record['neigh_embeds']:
            cosines.append(cosine_similarity(explained_record['embeds'], neigh_embed))
            euclids.append(euclid_similarity(explained_record['embeds'], neigh_embed))

        metrics_dict['mean_cosine'].append(np.mean(cosines))
        metrics_dict['min_cosine'].append(np.min(cosines))
        metrics_dict['mean_euclid'].append(np.mean(euclids))
        metrics_dict['max_euclid'].append(np.max(euclids))

        metrics_dict['jaccard'].append(
            jaccard_similarity(explained_record['labels'], explained_record['neigh_labels_oh']))

    return metrics_dict

In [51]:
metrics_dict = eval_metrics(knowledge_df, explained_df)
metrics_dict

{'jaccard': [0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.3333333333333333,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.3333333333333333,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  1.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.5,
  0.0,
  0.0,
  0.0,
  0.5,
  0.0,
  0.0,
  0.0,
  0.5,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,
  0.0,


In [52]:
cum_metrics_dict = {k: np.mean(v) for k, v in metrics_dict.items()}
cum_metrics_dict

{'jaccard': 0.08174310106564048,
 'mean_cosine': 0.9782368599900808,
 'min_cosine': 0.9782368599900808,
 'mean_euclid': 3.1602739775690765,
 'max_euclid': 3.1602739775690765}

In [53]:
cum_metrics_dict['accuracy'] = f1_score(explained_df['labels'].to_list(), explained_df['neigh_labels_oh'].to_list(),
                                        average='micro')
cum_metrics_dict['f1_score'] = f1_score(explained_df['labels'].tolist(), explained_df['neigh_labels_oh'].tolist(),
                                        average='macro')
cum_metrics_dict['precision'] = precision_score(explained_df['labels'].tolist(),
                                                explained_df['neigh_labels_oh'].tolist(), average='macro')
cum_metrics_dict['recall'] = recall_score(explained_df['labels'].tolist(), explained_df['neigh_labels_oh'].tolist(),
                                          average='macro')
cum_metrics_dict

  _warn_prf(average, modifier, msg_start, len(result))


{'jaccard': 0.08174310106564048,
 'mean_cosine': 0.9782368599900808,
 'min_cosine': 0.9782368599900808,
 'mean_euclid': 3.1602739775690765,
 'max_euclid': 3.1602739775690765,
 'accuracy': 0.08708676116260393,
 'f1_score': 0.04574756077724424,
 'precision': 0.1232083127510257,
 'recall': 0.07994180806351178}

In [56]:
# TOP 1
# {'jaccard': 0.08174310106564048,
#  'mean_cosine': 0.9782368599900808,
#  'min_cosine': 0.9782368599900808,
#  'mean_euclid': 3.1602739775690765,
#  'max_euclid': 3.1602739775690765,
#  'accuracy': 0.08708676116260393,
#  'f1_score': 0.04574756077724424,
#  'precision': 0.1232083127510257,
#  'recall': 0.07994180806351178}

# TOP 3
# {'jaccard': 0.0752102199588352,
#  'mean_cosine': 0.9773891830242699,
#  'min_cosine': 0.9766105941804047,
#  'mean_euclid': 3.2252378127506516,
#  'max_euclid': 3.2796130020449348,
#  'accuracy': 0.11535979165487177,
#  'f1_score': 0.08723256943361214,
#  'precision': 0.11160074421435333,
#  'recall': 0.20535381544324613}

# TOP 5
# {'jaccard': 0.07056713925766618,
#  'mean_cosine': 0.9768433215065202,
#  'min_cosine': 0.9757742301112715,
#  'mean_euclid': 3.2648585658317386,
#  'max_euclid': 3.337138323055487,
#  'accuracy': 0.11884100152419176,
#  'f1_score': 0.10087488701994642,
#  'precision': 0.09559821659645454,
#  'recall': 0.3068529642976469}

## KNN Model

In [25]:
# concept_spaces = [ # (n=2, B=2, seq_len=3, n_latent=4)
#     torch.tensor([
#         [
#             [0.9, 0.9, 0.9, 0.9],
#             [0.9, 0.9, 0.9, 0.9],
#             [0.9, 0.9, 0.9, 0.9]
#         ],
#         [
#             [-1.0, -1.0, -1.0, -1.0],
#             [-1.0, -1.0, -1.0, -1.0],
#             [-1.0, -1.0, -1.0, -1.0]
#         ],
#     ]),
#     torch.tensor([
#         [
#             [-0.5, -0.5, -0.5, -0.5],
#             [-0.5, -0.5, -0.5, -0.5],
#             [-0.5, -0.5, -0.5, -0.5]
#         ],
#         [
#             [1.0, 1.0, 1.0, 1.0],
#             [1.0, 1.0, 1.0, 1.0],
#             [1.0, 1.0, 1.0, 1.0]
#         ],
#     ])
# ]
#
# labels = torch.tensor([
#     [0, 1],
#     [1, 0]
# ])

In [26]:
# def inter_space_loss_old(concept_spaces, labels: torch.Tensor, m1:float = 0.5, m2: float = 0.5):
#     """
#     :param concept_spaces: pytorch tensors of shape: (n_concept_spaces, B, seq_len, n_latent)
#     :param embed: LM embeddings of shape: (B, seq_len, n_embed)
#     :param labels: labels of shape (B)
#     :param m1: weight of match loss
#     :param m2: weight of miss match loss
#     :return: loss (scalar)
#     """
#
#     loss = 0.0
#
#     for k in range(len(concept_spaces)):
#         # match loss
#         loss += m1 * torch.nan_to_num((1 - concept_spaces[k][labels == k]).mean())  # (B', n_latent, seq_len) * (B', seq_len, n_embed)
#
#         # mismatch loss
#         loss += m2 * torch.nan_to_num((1 + concept_spaces[k][labels != k]).mean())
#     return loss

In [27]:
# def inter_space_loss(concept_spaces: List[torch.Tensor], labels: torch.Tensor, m1: float = 0.5, m2: float = 0.5):
#     """
#     :param concept_spaces: pytorch tensors of shape: (n_concept_spaces, B, seq_len, n_latent)
#     :param labels: labels of shape (B) or of shape (B, labels_dim) (for multi-label classification)
#     :param m1: weight of match loss
#     :param m2: weight of miss match loss
#     :return: loss (scalar)
#     """
#
#     concept_spaces = torch.stack(concept_spaces, dim=0)  # (n_concept_spaces, B, seq_len, n_latent)
#
#     loss = 0.0
#
#     if len(labels.shape) == 1:
#         for k in range(len(concept_spaces)):
#             # match loss
#             loss += m1 * torch.nan_to_num((1 - concept_spaces[k][labels == k]).mean())  # (B', n_latent, seq_len) * (B', seq_len, n_embed)
#
#             # mismatch loss
#             loss += m2 * torch.nan_to_num((1 + concept_spaces[k][labels != k]).mean())
#         return loss
#
#     if len(labels.shape) > 2:
#         raise ValueError("labels must be a tensor or an integer: of shape (B) or of shape (B, labels_dim) (for multi-label classification)")
#
#     for idx, lb in enumerate(labels):
#         loss += m1 * torch.nan_to_num((1 - concept_spaces[torch.argwhere(lb > 0), idx])).mean(2).mean(2) # (n', 1, seq_len, n_latent)
#
#         loss += m2 * torch.nan_to_num((1 + concept_spaces[torch.argwhere(lb == 0), idx])).mean(2).mean(2)
#     return loss.mean()

In [28]:
# def inter_space_loss_optimized(concept_spaces: List[torch.Tensor], labels: torch.Tensor, m1: float = 0.5, m2: float = 0.5):
#     """
#     :param concept_spaces: pytorch tensors of shape: (n_concept_spaces, B, seq_len, n_latent)
#     :param labels: labels of shape (B) or of shape (B, labels_dim) (for multi-label classification)
#     :param m1: weight of match loss
#     :param m2: weight of miss match loss
#     :return: loss (scalar)
#     """
#
#     concept_spaces = torch.stack(concept_spaces, dim=0)  # (n_concept_spaces, B, seq_len, n_latent)
#
#     if len(labels.shape) == 1:
#         # match loss
#         match_loss = (1 - concept_spaces[labels == torch.arange(len(concept_spaces))]).mean()
#         # mismatch loss
#         mismatch_loss = (1 + concept_spaces[labels != torch.arange(len(concept_spaces))]).mean()
#     else:
#         if len(labels.shape) > 2:
#             raise ValueError("labels must be a tensor or an integer: of shape (B) or of shape (B, labels_dim) (for multi-label classification)")
#
#         match_loss = (1 - concept_spaces[labels.T > 0]).mean()
#         mismatch_loss = (1 + concept_spaces[labels.T == 0]).mean()
#
#     loss = m1 * match_loss + m2 * mismatch_loss
#
#     return loss.mean()

In [34]:
base_model = AutoModel.from_pretrained(MODEL_NAME)
base_model

DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(28996, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): MultiHeadSelfAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): Li

In [35]:
# class SpaceModelForMultiLabelOutput:
#     def __init__(self, loss, logits, concept_spaces):
#         self.loss = loss
#         self.logits = logits
#         self.concept_spaces = concept_spaces
#
#
# class SpaceModelForMultiLabelClassification(torch.nn.Module):
#     def __init__(self, base_model, n_embed, n_latent, n_concept_spaces, l1, l2, ce_w, fine_tune=False):
#         super(SpaceModelForMultiLabelClassification, self).__init__()
#         self.device = base_model.device
#         self.n_concept_spaces = n_concept_spaces
#
#         self.base_model = base_model
#
#         if fine_tune:
#             for p in base_model.parameters():
#                 p.requires_grad_(False)
#
#         self.space_model = SpaceModel(n_embed, n_latent, n_concept_spaces, output_concept_spaces=True)
#
#         self.classifier = torch.nn.Linear(n_concept_spaces * n_latent, n_concept_spaces)
#
#         self.l1 = l1
#         self.l2 = l2
#         self.ce_w = ce_w
#
#     def to(self, device):
#         self.device = device
#         super().to(device)
#         return self
#
#     def forward(self, input_ids, attention_mask, labels=None):
#         embed = self.base_model(input_ids, attention_mask).last_hidden_state  # (B, max_seq_len, 768)
#
#         # SpaceModelOutput(logits=(B, n_concept_spaces * n_latent), concept_spaces=(n_concept_spaces, B, max_seq_len, n_latent))
#         projected = self.space_model(embed)
#
#         concept_hidden = projected.logits
#
#         logits = self.classifier(concept_hidden)
#
#         loss = None
#         if labels is not None:
#             loss = self.ce_w * F.binary_cross_entropy_with_logits(
#                 logits.view(-1, self.n_concept_spaces),
#                 labels.view(-1, self.n_concept_spaces).float()
#             )
#
#             loss += self.l1 * inter_space_loss_optimized(projected.concept_spaces, labels) + \
#                     self.l2 * intra_space_loss(projected.concept_spaces)
#
#         return SpaceModelForMultiLabelOutput(loss, logits, projected.concept_spaces)

In [36]:
config = {
    'experiment_name': 'default',
    'log_terminal': True,

    'dataset_name': DATASET_NAME,
    'model_name': MODEL_NAME,

    'num_labels': NUM_LABELS,
    'num_epochs': NUM_EPOCHS,
    'iterations': 1,

    'max_seq_len': MAX_SEQ_LEN,
    'batch_size': BATCH_SIZE,
    'lr': LEARNING_RATE,
    'fp16': False,
    'max_grad_norm': MAX_GRAD_NORM,
    'weight_decay': 0.01,
    'num_warmup_steps': 0,
    'gradient_accumulation_steps': 1,

    'n_latent': N_LATENT,
    'cross_entropy_weight': 1.0,
    'l1': 0.1,
    'l2': 1e-5,
    'ce_w': 1.0,

    # funcs:
    'preds_from_logits_func': get_preds_from_logits
}

In [37]:
space_model = SpaceModelForMultiLabelClassification(
    base_model,
    n_embed=768,
    n_latent=N_LATENT,
    n_concept_spaces=NUM_LABELS,
    l1=config['l1'],
    l2=config['l2'],
    ce_w=config['ce_w'],
    fine_tune=True
).to(device)

In [38]:
# space_model.load_state_dict(
#     torch.load(f'models/{config["experiment_name"]}/{DATASET_NAME}_space-{MODEL_NAME}-({N_LATENT})_{NUM_EPOCHS}.bin', map_location=device))

In [39]:
count_parameters(space_model)

1426460

In [40]:
space_name = f'{DATASET_NAME}_space-{MODEL_NAME}-({N_LATENT})_{NUM_EPOCHS}'

In [41]:
log = get_logger(f'logs/{config["experiment_name"]}', space_name)

In [42]:
space_history = training(space_model, tokenized_dataset['emotions_train'], tokenized_dataset['emotions_val'], log, config)

[90m2024-02-20 02:45:13,810 - default.terminal - DEBUG - Train steps: 33003[0m[0m
[90m2024-02-20 02:45:13,811 - default.terminal - DEBUG - Steps per epoch: 660.078125[0m[0m
[36m2024-02-20 02:45:13,812 - default.terminal - INFO - Epoch: 1[0m[0m
  0%|          | 1/661 [00:02<22:15,  2.02s/it]


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.00 GiB (GPU 0; 15.77 GiB total capacity; 9.41 GiB already allocated; 2.51 GiB free; 12.29 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
plot_results(log, space_history, plot_name='space_history')

In [None]:
# save_model(space_model, config, space_name)

In [46]:
eval_results(log, space_name, space_model, test_dataloader, config)

100%|██████████| 83/83 [01:34<00:00,  1.14s/it]
  _warn_prf(average, modifier, msg_start, len(result))
[36m2024-02-20 02:27:42,850 - default.terminal - INFO - Val loss: 0.19250612032700734[0m[0m
[36m2024-02-20 02:27:42,851 - default.terminal - INFO - Val acc: 0.09099086304028783[0m[0m
[36m2024-02-20 02:27:42,852 - default.terminal - INFO - CS Val acc: 0.27486625952752924[0m[0m
[36m2024-02-20 02:27:42,853 - default.terminal - INFO - Val f1: 0.0811773450468776[0m[0m
[36m2024-02-20 02:27:42,854 - default.terminal - INFO - CS Val f1: 0.056084605794717614[0m[0m
[36m2024-02-20 02:27:42,854 - default.terminal - INFO - Val precision: 0.3806738435431122[0m[0m
[36m2024-02-20 02:27:42,855 - default.terminal - INFO - Val recall: 0.05573656721582062[0m[0m


In [47]:
@eval
def create_space_knowledge_embeddings(model, dataloader):
    knowledge_embeds, knowledge_texts, knowledge_labels = [], [], []
    with torch.no_grad():
        for step, batch in enumerate(tqdm(dataloader, total=len(dataloader))):
            ids = batch['input_ids'].to(device, dtype=torch.long)
            mask = batch['attention_mask'].to(device, dtype=torch.long)

            embed = model.base_model(input_ids=ids, attention_mask=mask).last_hidden_state  # (B, seq_len, 768)

            projected = model.space_model(embed)  # (B, n_concept_spaces * n_latent)

            knowledge_embeds += projected.logits.detach().cpu().tolist()
            knowledge_texts += batch['text']
            # we need this for further evaluation
            knowledge_labels += [d.item() for d in batch['label']]
    return {'embeds': knowledge_embeds, 'texts': knowledge_texts, 'labels': knowledge_labels}

In [48]:
space_knowledge_dict = create_space_knowledge_embeddings(space_model, definitions_dataloader)

100%|██████████| 1/1 [00:00<00:00,  7.10it/s]


In [49]:
len(space_knowledge_dict['embeds'][0])

84

In [54]:
@eval
def eval_space_knowledge_embeds(model, knn, dataloader, knowledge_dict):
    explained_embeds, explained_texts, explained_labels = [], [], []
    neigh_explained_texts, neigh_explained_labels = [], []
    neigh_explained_embeds = []
    with torch.no_grad():
        for step, batch in enumerate(tqdm(dataloader, total=len(dataloader))):
            ids = batch['input_ids'].to(device, dtype=torch.long)
            mask = batch['attention_mask'].to(device, dtype=torch.long)

            embed = model.base_model(input_ids=ids, attention_mask=mask).last_hidden_state  # (B, seq_len, 768)

            projected = model.space_model(embed)  # (B, n_concept_spaces * n_latent)

            raw_embeds = projected.logits.detach().cpu().tolist()  # (B, 768)
            neighbors_ids = knn.kneighbors(raw_embeds, return_distance=False)  # (B, k), k neighbors ids for each sample

            k_neigh_texts = [[knowledge_dict['texts'][q] for q in neigh] for neigh in neighbors_ids]  # (B, k)
            k_neigh_embeds = [[knowledge_dict['embeds'][q] for q in neigh] for neigh in neighbors_ids]  # (B, k, 768)
            k_neigh_labels = [[knowledge_dict['labels'][q] for q in neigh] for neigh in neighbors_ids]  # (B, k)

            explained_embeds += raw_embeds  # (t, 768)
            explained_texts += batch['text']  # (t)
            explained_labels += batch['label'].cpu().tolist()  # (t)
            neigh_explained_texts += k_neigh_texts  # (t, k)
            neigh_explained_embeds += k_neigh_embeds  # (t, k, 768)
            neigh_explained_labels += k_neigh_labels  # (t, k)

    return {
        'embeds': explained_embeds,
        'texts': explained_texts,
        'labels': explained_labels,
        'neigh_texts': neigh_explained_texts,
        'neigh_embeds': neigh_explained_embeds,
        'neigh_labels': neigh_explained_labels
    }

In [69]:
space_knn = NearestNeighbors(n_neighbors=1, metric='euclidean')
space_knn.fit(space_knowledge_dict['embeds'])

In [70]:
space_explained_dict = eval_space_knowledge_embeds(space_model, space_knn, test_dataloader, space_knowledge_dict)

100%|██████████| 83/83 [01:34<00:00,  1.14s/it]


In [71]:
space_knowledge_df = pd.DataFrame(space_knowledge_dict)
space_explained_df = pd.DataFrame(space_explained_dict)
space_explained_df

Unnamed: 0,embeds,texts,labels,neigh_texts,neigh_embeds,neigh_labels
0,"[0.18983136117458344, 0.8249680995941162, 0.96...",Nice try Lumi,"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[amusement - The feeling when you encounter so...,"[[0.24588538706302643, 0.5915502905845642, 0.6...",[1]
1,"[0.948785126209259, 0.9889746904373169, 0.9892...","Thank you, you're awesome","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[admiration - The feeling when you look up to ...,"[[0.9180653691291809, 0.9442663192749023, 0.96...",[0]
2,"[-0.22700846195220947, -0.5483300685882568, -0...",lol at them both getting downvotes,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[amusement - The feeling when you encounter so...,"[[0.24588538706302643, 0.5915502905845642, 0.6...",[1]
3,"[0.10304953902959824, 0.4127613604068756, 0.41...","> the fact that it is illegal Watch out, we've...","[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[realization - The feeling when you suddenly u...,"[[-0.08924122154712677, 0.5145065784454346, 0....",[22]
4,"[-0.5465888381004333, -0.8726741075515747, -0....","I meant we as in the group, sorry if that was ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[neutral - The feeling when you don’t feel any...,"[[-0.5819293260574341, -0.6192928552627563, -0...",[27]
...,...,...,...,...,...,...
21118,"[0.635576605796814, 0.9259747862815857, 0.9566...",You. I like you.,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[approval - The feeling when you agree with or...,"[[0.5311691761016846, 0.8413762450218201, 0.48...",[4]
21119,"[0.11893484741449356, 0.42488181591033936, -0....",Well except for keeping it around as it is onl...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[neutral - The feeling when you don’t feel any...,"[[-0.5819293260574341, -0.6192928552627563, -0...",[27]
21120,"[0.7046402096748352, 0.9438740015029907, 0.941...",Thanks 💙 ice skating is also one of my favouri...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[admiration - The feeling when you look up to ...,"[[0.9180653691291809, 0.9442663192749023, 0.96...",[0]
21121,"[0.36170586943626404, 0.5032520890235901, 0.61...",I'm not to proud of them right now. I'm mad at...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[amusement - The feeling when you encounter so...,"[[0.24588538706302643, 0.5915502905845642, 0.6...",[1]


In [72]:
space_explained_df['neigh_labels_oh'] = space_explained_df['neigh_labels'].apply(
    lambda x: [1 if i in x else 0 for i in range(NUM_LABELS)])
space_explained_df

Unnamed: 0,embeds,texts,labels,neigh_texts,neigh_embeds,neigh_labels,neigh_labels_oh
0,"[0.18983136117458344, 0.8249680995941162, 0.96...",Nice try Lumi,"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[amusement - The feeling when you encounter so...,"[[0.24588538706302643, 0.5915502905845642, 0.6...",[1],"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
1,"[0.948785126209259, 0.9889746904373169, 0.9892...","Thank you, you're awesome","[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[admiration - The feeling when you look up to ...,"[[0.9180653691291809, 0.9442663192749023, 0.96...",[0],"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
2,"[-0.22700846195220947, -0.5483300685882568, -0...",lol at them both getting downvotes,"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[amusement - The feeling when you encounter so...,"[[0.24588538706302643, 0.5915502905845642, 0.6...",[1],"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
3,"[0.10304953902959824, 0.4127613604068756, 0.41...","> the fact that it is illegal Watch out, we've...","[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[realization - The feeling when you suddenly u...,"[[-0.08924122154712677, 0.5145065784454346, 0....",[22],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
4,"[-0.5465888381004333, -0.8726741075515747, -0....","I meant we as in the group, sorry if that was ...","[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[neutral - The feeling when you don’t feel any...,"[[-0.5819293260574341, -0.6192928552627563, -0...",[27],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
...,...,...,...,...,...,...,...
21118,"[0.635576605796814, 0.9259747862815857, 0.9566...",You. I like you.,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[approval - The feeling when you agree with or...,"[[0.5311691761016846, 0.8413762450218201, 0.48...",[4],"[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
21119,"[0.11893484741449356, 0.42488181591033936, -0....",Well except for keeping it around as it is onl...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[neutral - The feeling when you don’t feel any...,"[[-0.5819293260574341, -0.6192928552627563, -0...",[27],"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
21120,"[0.7046402096748352, 0.9438740015029907, 0.941...",Thanks 💙 ice skating is also one of my favouri...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[admiration - The feeling when you look up to ...,"[[0.9180653691291809, 0.9442663192749023, 0.96...",[0],"[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."
21121,"[0.36170586943626404, 0.5032520890235901, 0.61...",I'm not to proud of them right now. I'm mad at...,"[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ...",[amusement - The feeling when you encounter so...,"[[0.24588538706302643, 0.5915502905845642, 0.6...",[1],"[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ..."


In [73]:
space_metrics_dict = eval_metrics(space_knowledge_df, space_explained_df)

In [74]:
space_cum_metrics_dict = {k: np.mean(v) for k, v in space_metrics_dict.items()}
space_cum_metrics_dict

{'jaccard': 0.15188938710455538,
 'mean_cosine': 0.8230332105397384,
 'min_cosine': 0.8230332105397384,
 'mean_euclid': 3.77350620053899,
 'max_euclid': 3.77350620053899}

In [75]:
space_cum_metrics_dict['accuracy'] = f1_score(space_explained_df['labels'].to_list(), space_explained_df['neigh_labels_oh'].to_list(),
                                        average='micro')
space_cum_metrics_dict['f1_score'] = f1_score(space_explained_df['labels'].tolist(), space_explained_df['neigh_labels_oh'].tolist(),
                                        average='macro')
space_cum_metrics_dict['precision'] = precision_score(space_explained_df['labels'].tolist(),
                                                space_explained_df['neigh_labels_oh'].tolist(), average='macro')
space_cum_metrics_dict['recall'] = recall_score(space_explained_df['labels'].tolist(), space_explained_df['neigh_labels_oh'].tolist(),
                                          average='macro')
space_cum_metrics_dict

{'jaccard': 0.15188938710455538,
 'mean_cosine': 0.8230332105397384,
 'min_cosine': 0.8230332105397384,
 'mean_euclid': 3.77350620053899,
 'max_euclid': 3.77350620053899,
 'accuracy': 0.1599340120254401,
 'f1_score': 0.09668061013933564,
 'precision': 0.17857667555470375,
 'recall': 0.09767180932841606}

In [69]:
def concept_space_to_embeds(concept_spaces, targets):
    tensor_concept_spaces = torch.cat([cs.unsqueeze(0) for cs in concept_spaces], dim=0)
    concept_space_dist = tensor_concept_spaces.permute(1, 0, 2, 3)  # (B, n, seq_len, n_latent)

    (B, n, seq_len, n_latent) = concept_space_dist.shape
    return concept_space_dist[torch.arange(B), targets, :, :]