In [13]:
import warnings
from utils import GoEmotionConfig
from gpt2 import GPTCls
import torch
import tiktoken
warnings.filterwarnings("ignore")
class EmotionDetector():
    """Multi-label emotion classifier over the 28 GoEmotions labels.

    Wraps a ``GPTCls`` model (backbone plus classification head) whose weights are
    restored from a local checkpoint, and tokenizes input with tiktoken's
    ``"gpt2"`` encoding.

    Args:
        config: Model configuration (``GoEmotionConfig`` in this notebook). It is
            passed to ``GPTCls`` and must expose ``device``.
        checkpoint_path: Path to a checkpoint dict holding ``'backbone'`` and
            ``'cls'`` state dicts. Defaults to ``"goemotion_1.pth"``, resolved
            relative to the current working directory.
    """

    def __init__(self, config, checkpoint_path="goemotion_1.pth"):
        self.config = config
        # Label names; index i corresponds to output column i of the classifier head.
        self.emotions = [
            'admiration', 'amusement', 'anger', 'annoyance', 'approval', 'caring',
            'confusion', 'curiosity', 'desire', 'disappointment', 'disapproval',
            'disgust', 'embarrassment', 'excitement', 'fear', 'gratitude', 'grief',
            'joy', 'love', 'nervousness', 'optimism', 'pride', 'realization',
            'relief', 'remorse', 'sadness', 'surprise', 'neutral',
        ]
        self.model = GPTCls(config=self.config).to(self.config.device)
        # NOTE(review): torch.load unpickles the file, which can execute arbitrary
        # code; only load trusted checkpoints (newer torch supports weights_only=True).
        state_dict = torch.load(checkpoint_path, map_location=self.config.device)
        self.model.backbone.load_state_dict(state_dict['backbone'])
        self.model.cls_head.load_state_dict(state_dict['cls'])
        self.tokenizer = tiktoken.get_encoding("gpt2")

    def inference(self, text, threshold=0.4, verbose=True):
        """Predict the emotions expressed in ``text``.

        Each label is scored independently with a sigmoid, so several emotions can
        be selected at once.

        Args:
            text: Input string; it must encode to at least one token.
            threshold: Minimum sigmoid probability for a label to appear in
                ``all_emotions``.
            verbose: If True (default), print the input shape and the predictions.

        Returns:
            Tuple ``(all_emotions, major_emotion, major_prob)``, where:
            ``all_emotions`` is a list of ``(label, prob)`` pairs with
            ``prob >= threshold``, each prob rounded to 4 decimals, sorted by
            descending probability (it may be empty); and ``major_emotion`` /
            ``major_prob`` are the top-scoring label and its unrounded
            probability, returned even when below ``threshold``.

        Raises:
            ValueError: If ``text`` encodes to zero tokens (e.g. ``""``).
        """
        self.model.eval()
        token_ids = self.tokenizer.encode(text)
        if not token_ids:
            # torch.tensor([]) would be an empty *float* tensor, which the model
            # cannot consume; fail early with a clear message instead.
            raise ValueError("text must encode to at least one token")
        encoded_text = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.config.device)
        if verbose:
            print(f'input shape: {encoded_text.shape}')
        # Pure inference: disable autograd so no graph is kept for the large model.
        with torch.no_grad():
            logits = self.model(encoded_text)
        # Presumably logits are (1, num_emotions); drop the batch dimension.
        probs = logits.sigmoid().squeeze(0)
        major_idx = int(torch.argmax(probs))
        major_emotion, major_prob = self.emotions[major_idx], probs[major_idx].item()
        # Compare in torch (same dtype semantics as the model output); nonzero
        # yields indices in label order, so equal probabilities keep label order.
        selected = torch.nonzero(probs >= threshold, as_tuple=True)[0].tolist()
        all_emotions = sorted(
            ((self.emotions[i], round(probs[i].item(), 4)) for i in selected),
            key=lambda pair: pair[1],
            reverse=True,
        )
        if verbose:
            print(f"all emotions: {all_emotions}, major emotion: {major_emotion}, major prob: {major_prob:.4f}")
        return all_emotions, major_emotion, major_prob
# Build the default configuration, then construct the detector: this loads the
# pretrained GPT-2 backbone (see log output) and the fine-tuned checkpoint.
config = GoEmotionConfig()
ed = EmotionDetector(config)

loading weights from pretrained gpt: gpt2-large
forcing vocab_size=50257, block_size=1024, bias=True
overriding dropout rate to 0.0
number of parameters:  774.03M


In [14]:
# Insult: a strongly negative, short input.
ed.inference(text="Fuck you!");

input shape: torch.Size([1, 3])
all emotions: [('anger', 0.6889)], major emotion: anger, major prob: 0.6889


In [15]:
# Idiomatic intensifier ("more than happy") describing a third person.
ed.inference(text="John is more than happy about the result of the exam.");

input shape: torch.Size([1, 12])
all emotions: [('neutral', 0.4791), ('admiration', 0.4376)], major emotion: neutral, major prob: 0.4791


In [16]:
# Skeptical question: checks whether curiosity is detected.
ed.inference(text="Do you really think it is a good idea to go to the party?");

input shape: torch.Size([1, 15])
all emotions: [('curiosity', 0.4611), ('neutral', 0.4444)], major emotion: curiosity, major prob: 0.4611


In [17]:
# Plain agreement.
ed.inference(text="I agree with you.");

input shape: torch.Size([1, 5])
all emotions: [('approval', 0.698)], major emotion: approval, major prob: 0.6980


In [27]:
# Negation: in the saved output the top label is joy, i.e. the model misses
# the speaker's negated happiness.
ed.inference(text="He was so happy but I am not happy at all.");

input shape: torch.Size([1, 12])
all emotions: [('joy', 0.5754), ('neutral', 0.4343), ('admiration', 0.4209)], major emotion: joy, major prob: 0.5754


In [31]:
# Short negative remark: the saved output shows neutral/approval rather than a
# negative emotion.
ed.inference(text="Ah that's bad.");

input shape: torch.Size([1, 5])
all emotions: [('neutral', 0.4183), ('approval', 0.4139)], major emotion: neutral, major prob: 0.4183
