In [None]:
# Uncomment the line below to install exlib into this kernel's environment.
# %pip install exlib

In [1]:
import torch
from transformers import AutoModel, AutoTokenizer
import numpy as np
import tqdm
from tqdm import tqdm
from torch.utils.data import DataLoader
from datasets import load_dataset
import torch.nn as nn
import sentence_transformers

import sys
# NOTE(review): hard-coded, machine-specific absolute path to a local exlib
# source checkout. It is appended (not inserted at the front), so an exlib that
# is already importable from an earlier sys.path entry takes precedence.
# TODO: make this configurable or rely on the installed package.
sys.path.append('/shared_data0/chaenyk/exlib/src')
import exlib
from exlib.utils.emotion_helper import project_points_onto_axes, load_emotions
from exlib.datasets.emotion import load_data, load_model, EmotionDataset, EmotionClassifier, Metric, get_emotion_scores

from exlib.features.text.identity import IdentityGroups
from exlib.features.text.random import RandomGroups
from exlib.features.text.word import WordGroups
from exlib.features.text.phrase import PhraseGroups
from exlib.features.text.sentence import SentenceGroups
from exlib.features.text.clustering import ClusteringGroups
from exlib.features.text.archipelago import WrappedModel, ArchipelagoGroups

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Load datasets and pre-trained models

In [2]:
# Construct the dataset wrapper with the argument "train" (presumably the split
# name — confirm in exlib.datasets.emotion) and the classifier with its defaults.
# NOTE(review): either constructor may download or load pre-trained resources;
# verify before running offline.
dataset = EmotionDataset("train")
model = EmotionClassifier()

### Model prediction

In [3]:
# Place the classifier on the selected device before any forward pass.
model.to(device)
# Batches of four examples, served in the dataset's own order (no shuffling).
dataloader = DataLoader(dataset=dataset, shuffle=False, batch_size=4)

In [4]:
# Preview the classifier on the first batch only: decode every input back to
# text and print it next to the index of its highest-scoring logit.
model.eval()
# Pure inference: disable autograd so no computation graph is recorded or kept
# alive for the forward pass.
with torch.no_grad():
    for batch in tqdm(dataloader):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        output = model(input_ids, attention_mask)
        utterances = [
            dataset.tokenizer.decode(input_id, skip_special_tokens=True)
            for input_id in input_ids
        ]
        # Each row of output.logits pairs with the utterance at the same position.
        for utterance, logits in zip(utterances, output.logits):
            print("Text: {}\nEmotion: {}\n".format(utterance, logits.argmax()))
        break  # a single batch is enough for this sanity check

  0% 0/10853 [00:00<?, ?it/s]

Text: My favourite food is anything I didn't have to cook myself.
Emotion: 18

Text: Now if he does off himself, everyone will think hes having a laugh screwing with people instead of actually dead
Emotion: 27

Text: WHY THE FUCK IS BAYLESS ISOING
Emotion: 14

Text: To make her feel threatened
Emotion: 18






### Baselines
- Identity
- Random
- Words
- Phrases
- Sentences
- Clustering
- Archipelago

In [None]:
# Obtain the baseline scores from exlib's helper, called with its default arguments.
# The following cell iterates `scores` by name and indexes it, i.e. it is used
# as a name -> scores mapping.
# NOTE(review): presumably this evaluates every baseline listed above and may
# take a while — confirm in exlib.datasets.emotion.
scores = get_emotion_scores()

In [None]:
# Report, for every baseline in `scores`, the mean of its scores with NaN
# entries left out of the average.
for name in scores:
    metric = torch.tensor(scores[name])
    mean_metric = torch.nanmean(metric)
    print(f'BASELINE {name} mean score: {mean_metric}')