In [8]:
from utils import Config
from gpt2 import GPTCls
import torch
import tiktoken
class EmotionDetector:
    """Multi-label emotion classifier built on a fine-tuned ``GPTCls`` model.

    Every label in ``emotions`` is scored independently with a sigmoid, so a
    single text can be assigned several emotions at once (or none).
    """

    # The 28 label names, in the order of the model's output logits:
    # logit i is reported under EMOTIONS[i].
    EMOTIONS = (
        'admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring',
        'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval',
        'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief',
        'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization',
        'relief', 'remorse', 'sadness', 'surprise', 'neutral',
    )

    def __init__(self, config, checkpoint_path="3.pth"):
        """Create the model, load the fine-tuned weights and the GPT-2 tokenizer.

        Args:
            config: project ``Config`` passed to ``GPTCls``; ``config.device``
                is where the model and its inputs are placed.
            checkpoint_path: file holding the fine-tuned ``state_dict``.
                Defaults to ``"3.pth"`` (the previously hard-coded file),
                resolved against the current working directory.
        """
        self.config = config
        # Per-instance list copy: editing one detector's labels cannot leak
        # into the class-level constant or into other instances.
        self.emotions = list(self.EMOTIONS)
        # load_gpt2=False: presumably skips fetching pretrained GPT-2 weights
        # because the checkpoint below overwrites them — TODO confirm in gpt2.py.
        self.model = GPTCls(config=self.config, load_gpt2=False).to(self.config.device)
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered checkpoint cannot run arbitrary code on load.
        state_dict = torch.load(checkpoint_path, weights_only=True, map_location=self.config.device)
        self.model.load_state_dict(state_dict)
        self.tokenizer = tiktoken.get_encoding("gpt2")

    def inference(self, text, threshold=0.3, verbose=True):
        """Return the emotions whose sigmoid probability reaches ``threshold``.

        Args:
            text: input string, encoded with the GPT-2 tokenizer.
            threshold: minimum probability (inclusive) for a label to be kept.
            verbose: when True (default, matching the previous behavior), also
                print the result.

        Returns:
            list of ``(emotion, probability)`` tuples with probabilities rounded
            to 4 decimals, sorted from most to least probable (ties keep label
            order).

        Raises:
            ValueError: if ``text`` encodes to no tokens, or the model output
                does not contain exactly one logit per label.
        """
        token_ids = self.tokenizer.encode(text)
        if not token_ids:
            # Fail early with a clear message rather than passing an empty
            # tensor to the model.
            raise ValueError("text must encode to at least one token")
        # NOTE(review): inputs longer than the model's context window are not
        # truncated here — confirm how GPTCls handles them.

        self.model.eval()
        input_ids = torch.tensor(token_ids, dtype=torch.long, device=self.config.device).unsqueeze(0)
        # Pure inference: no autograd graph is needed, which saves memory/time.
        with torch.no_grad():
            logits = self.model(input_ids)
        probs = logits.sigmoid().squeeze(0)
        if probs.dim() != 1 or probs.numel() != len(self.emotions):
            raise ValueError(
                f"expected {len(self.emotions)} logits for a single text, "
                f"got output of shape {tuple(logits.shape)}"
            )

        # Threshold in tensor space so the comparison matches the old masking.
        kept = (probs >= threshold).nonzero(as_tuple=True)[0].tolist()
        all_emotions = [(self.emotions[i], round(probs[i].item(), 4)) for i in kept]
        all_emotions.sort(key=lambda pair: pair[1], reverse=True)
        if verbose:
            print(f"all emotions: {all_emotions}")
        return all_emotions
# Build the runtime configuration once and load the fine-tuned detector;
# every demo cell below reuses `ed`.
config = Config()
ed = EmotionDetector(config)

number of parameters:  124.44M


In [9]:
# inference() prints its own result; the trailing ';' hides the duplicate repr.
ed.inference(text="I am very sad.");

all emotions: [('sadness', 0.989)]


In [10]:
ed.inference(text="John is more than happy about the result of the exam.");

all emotions: [('joy', 0.8425)]


In [11]:
# NOTE(review): the question has no trailing '?'; adding one changes the tokens
# and would make the recorded output below stale, so it is left as-is.
ed.inference(text="Do you really think it is a good idea to go to the party");

all emotions: [('curiosity', 0.3292)]


In [12]:
ed.inference(text="I agree with you");

all emotions: [('approval', 0.9936)]


In [13]:
ed.inference(text="lol");

all emotions: [('amusement', 0.8112), ('love', 0.4416), ('joy', 0.3582), ('excitement', 0.3419), ('optimism', 0.3076)]


In [14]:
ed.inference(text="The Japanese team defeated the Chinese team.");

all emotions: [('admiration', 0.629), ('neutral', 0.4327)]
