In [1]:
# Uncomment the line below to install exlib.
# Use the %pip magic rather than !pip: %pip installs into the environment of
# the running kernel, whereas !pip may target a different Python interpreter.
# %pip install exlib

In [2]:
import torch
from transformers import AutoModel, AutoTokenizer
import numpy as np
import tqdm
from tqdm import tqdm
from torch.utils.data import DataLoader
from datasets import load_dataset
import torch.nn as nn
import sentence_transformers

import sys
# Make the local source tree importable (presumably so `import exlib` resolves
# to the repo checkout). The path is relative to the notebook's working
# directory. Guarded so re-running this cell does not keep prepending
# duplicate entries to sys.path.
if "../../src" not in sys.path:
    sys.path.insert(0, "../../src")
import exlib
from exlib.utils.emotion_helper import project_points_onto_axes, load_emotions
from exlib.datasets.emotion import load_data, load_model, EmotionDataset, EmotionClassifier, EmotionFixScore, get_emotion_scores

from exlib.features.text import *

# Prefer the GPU when PyTorch can see one; otherwise run everything on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

  from .autonotebook import tqdm as notebook_tqdm


### Load the test dataset and the pre-trained emotion classifier

In [3]:
# Test split of the emotion dataset, batched without shuffling so batches
# always come out in the dataset's own order.
dataset = EmotionDataset("test")
dataloader = DataLoader(dataset, shuffle=False, batch_size=4)

# Pre-trained emotion classifier, switched to inference mode and moved to `device`.
model = EmotionClassifier()
model = model.eval().to(device)

SamLowe/roberta-base-go_emotions


### Model predictions on the first test batch

In [4]:
# Sanity-check the classifier on the first test batch only. Grab that single
# batch directly: wrapping the whole loader in tqdm and breaking after one
# iteration leaves an unfinished progress bar behind.
batch = next(iter(dataloader))
input_ids = batch['input_ids'].to(device)
attention_mask = batch['attention_mask'].to(device)

# Pure inference: disable autograd so no computation graph is built or kept.
with torch.no_grad():
    output = model(input_ids, attention_mask)

id2label = model.model.config.id2label
# Decode from the CPU copy of the ids; the tokenizer does not need the device tensor.
utterances = [
    dataset.tokenizer.decode(ids, skip_special_tokens=True)
    for ids in batch['input_ids']
]
# Each row of `output.logits` is paired with one utterance (assumed shape
# (batch, num_labels) — TODO confirm); argmax picks the predicted emotion id.
for utterance, logits in zip(utterances, output.logits):
    emotion = id2label[logits.argmax().item()]
    print("Text: {}\nEmotion: {}\n".format(utterance, emotion))

  0%|                             | 0/1357 [00:00<?, ?it/s]

Text: I’m really sorry about your situation :( Although I love the names Sapphira, Cirilla, and Scarlett!
Emotion: remorse

Text: It's wonderful because it's awful. At not with.
Emotion: admiration

Text: Kings fan here, good luck to you guys! Will be an interesting game to watch! 
Emotion: optimism

Text: I didn't know that, thank you for teaching me something today!
Emotion: gratitude






In [5]:
# Score each named baseline with exlib's emotion scorer. The result is keyed
# by baseline name.
# NOTE(review): this is the slow cell. The recorded run took roughly 28 minutes
# in total, and the last progress bar (presumably "archipelago") alone took
# ~19 minutes. Consider caching the result.
scores = get_emotion_scores(
    [
        "identity",
        "random",
        "word",
        "phrase",
        "sentence",
        "clustering",
        "archipelago",
    ]
)

SamLowe/roberta-base-go_emotions


100%|██████████████████| 1357/1357 [00:44<00:00, 30.35it/s]
100%|██████████████████| 1357/1357 [02:01<00:00, 11.19it/s]
100%|██████████████████| 1357/1357 [01:24<00:00, 16.00it/s]
100%|██████████████████| 1357/1357 [01:36<00:00, 14.01it/s]
100%|██████████████████| 1357/1357 [00:50<00:00, 26.92it/s]
100%|██████████████████| 1357/1357 [03:08<00:00,  7.22it/s]
100%|██████████████████| 1357/1357 [18:39<00:00,  1.21it/s]


In [6]:
# Summarise each baseline by the NaN-ignoring mean of its scores, so entries
# whose score is NaN don't turn the whole average into NaN.
for name in scores:
    mean_metric = torch.nanmean(torch.tensor(scores[name]))
    print(f'BASELINE {name} mean score: {mean_metric}')

BASELINE identity mean score: 0.010318498686651098
BASELINE random mean score: 0.02961087006586181
BASELINE word mean score: 0.1353508580950209
BASELINE phrase mean score: 0.021202182059697837
BASELINE sentence mean score: 0.016167678546141324
BASELINE clustering mean score: 0.09731207329481384
BASELINE archipelago mean score: 0.052713106135909224
