In [1]:
from os import path
import sys

import datasets
import torch
from torch.optim import SGD

from core.language.preprocessing import Codec, Tokenizer
from core.language.model import Model, train, eval
from core.language import utils

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the "emotion" dataset once and keep separate handles to the two splits
# that the rest of the notebook works with.
dataset = datasets.load_dataset("emotion")
train_dataset = dataset['train']
validation_dataset = dataset['validation']

In [3]:
# Tokenize every training text (requested as one flattened sequence) and hand
# the result to the Codec constructor.
tokenizer = Tokenizer(lazy=True)
training_tokens = tokenizer.tokenize_all(train_dataset["text"], flatten=True)
codec = Codec(training_tokens)

In [None]:
# Integer class index -> emotion name. The position of each name in the tuple
# is its class index, so the order here must not be changed.
emotion_by_label = dict(enumerate((
    'sadness',
    'joy',
    'love',
    'anger',
    'fear',
    'surprise',
)))

In [None]:
# Seed torch's global RNG before anything random happens in this cell, so that
# re-running it (or Restart & Run All) starts from the same random state.
seed = 42
torch.manual_seed(seed)

# NOTE(review): this rebinds `tokenizer`, which an earlier cell created as
# `Tokenizer(lazy=True)` — confirm which constructor form is the intended one.
tokenizer = Tokenizer(dataset['train'])
n_classes = len(emotion_by_label)
device = utils.get_available_device()
model = Model(tokenizer.vocab_size, n_classes).to(device)
optimizer = SGD(model.parameters(), lr=0.01)

print(f"Model initalized, starting training on '{device}'...\n")
epoch = 0
batch_size = 128
# Smallest drop in total validation loss that still counts as an improvement.
stopping_criterion = 1e-3
# Number of consecutive non-improving epochs tolerated before training stops.
patience = 3
min_val_loss = float('inf')
iterations_without_improvement = 0

# Early stopping: keep training one epoch at a time until the total validation
# loss has failed to beat the best value seen so far `patience` times in a row.
while iterations_without_improvement < patience:
    with utils.Timer() as epoch_timer:
        epoch += 1
        total_train_loss, avg_train_loss = train(model, train_dataset, tokenizer, optimizer, device, batch_size)
        total_val_loss, avg_val_loss = eval(model, validation_dataset, tokenizer, device, batch_size)
        if total_val_loss < min_val_loss - stopping_criterion:
            # New best: remember it and reset the patience counter.
            min_val_loss = total_val_loss
            iterations_without_improvement = 0
        else:
            iterations_without_improvement += 1
    # Reported after the timer context has closed so `interval` is available.
    print("Epoch #{:0>3} [{:.2f}s] :: Train loss: '{:.4f}' Validation loss: '{:.4f}'".format(
        epoch, epoch_timer.interval, avg_train_loss, avg_val_loss))

In [None]:
def predict_emotion(document: str) -> str:
    """Predict the emotion name for a single piece of text.

    The document is tokenized with the notebook-level ``tokenizer``, encoded as
    a single-row vector of width ``tokenizer.vocab_size`` with ones at the
    token positions, and passed through the notebook-level ``model``. The index
    of the highest-scoring class is mapped to its name via ``emotion_by_label``.

    Args:
        document: Raw text to classify.

    Returns:
        The emotion name associated with the arg-max class index.

    Raises:
        KeyError: If the predicted class index is not in ``emotion_by_label``.
    """
    tokens = tokenizer.tokenize(document)
    # Create the input on the same device as the model's parameters; feeding a
    # CPU tensor to a model that was moved to an accelerator raises a
    # device-mismatch error.
    model_device = next(model.parameters()).device
    token_tensor = torch.zeros((1, tokenizer.vocab_size), device=model_device)
    # presumably `tokens` are vocabulary indices — verify against Tokenizer.tokenize
    token_tensor[:, tokens] = 1
    # Inference only: no gradients are needed, so skip autograd bookkeeping.
    with torch.no_grad():
        predicted_label = model(token_tensor).argmax(dim=1).item()
    return emotion_by_label[predicted_label]

In [None]:
# Smoke-test the classifier on a few hand-written sentences, printing one
# predicted emotion per line in the order listed.
sample_documents = (
    "That's too much",
    "I love you",
    "I hate you",
    "I'm sad",
    "I'm happy",
    "I'm scared",
)
for sample_document in sample_documents:
    print(predict_emotion(sample_document))