In [1]:
from classes import RobertaForSequenceClassification2
from transformers import RobertaTokenizer, Trainer
import torch
# Use the GPU when torch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
def load_model(model_path, tokenizer_path, task_labels_map=None):
    """Load the multi-head RoBERTa classifier, its tokenizer and a Trainer wrapper.

    Args:
        model_path: path passed to
            ``RobertaForSequenceClassification2.from_pretrained``.
        tokenizer_path: path passed to ``RobertaTokenizer.from_pretrained``.
        task_labels_map: mapping of task name -> number of labels for that
            classification head. Defaults to ``{"author": 4, "sentiment": 2}``,
            the heads used throughout this notebook.

    Returns:
        ``(model, trainer, tokenizer)``. The model has been moved to the
        module-level ``device`` and switched to eval mode, ready for inference.
    """
    # None sentinel instead of a dict default, so no mutable default is shared
    # between calls.
    if task_labels_map is None:
        task_labels_map = {"author": 4, "sentiment": 2}
    model = RobertaForSequenceClassification2.from_pretrained(
        model_path, task_labels_map=task_labels_map
    )
    # Place the model explicitly rather than relying on Trainer's constructor
    # side effect: pipeline() sends its inputs to `device`, so the model must
    # live there too.
    model.to(device)
    model.eval()
    tokenizer = RobertaTokenizer.from_pretrained(tokenizer_path)
    trainer = Trainer(model=model)
    return model, trainer, tokenizer
    
def pipeline(text, task):
    """Classify ``text`` with the classification head selected by ``task``.

    Relies on the module-level ``model``, ``tokenizer`` and ``device`` set up
    in the cells above.

    Args:
        text: a single input string.
        task: name of the head to use (e.g. ``"sentiment"`` or ``"author"``);
            forwarded to the model as ``task_name``.

    Returns:
        ``(label, probability)``: the index of the most probable class and its
        softmax probability, as a Python ``int`` and ``float``.
    """
    # truncation=True: RoBERTa has a fixed maximum sequence length, and longer
    # text (easy to paste into the GUI) would otherwise overflow the model's
    # position embeddings and raise.
    inputs = tokenizer(text, padding="longest", truncation=True, return_tensors="pt")
    # Move every tensor the tokenizer produced, not just two hard-coded keys.
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
    # Inference only: skip autograd bookkeeping (saves memory and time).
    with torch.no_grad():
        logits = model(**inputs, task_name=task)["logits"]
    probabilities = torch.softmax(logits, dim=1).cpu()
    # torch.max along a dim returns (values, indices) in one pass.
    probability, result = probabilities.max(dim=1)
    return result.item(), probability.item()



In [3]:
# Load the multi-head model, its tokenizer and the Trainer wrapper from disk.
model, trainer, tokenizer = load_model(
    model_path="models/multihead_classification/model",
    tokenizer_path="models/multihead_classification/tokenizer",
)

In [5]:
# Smoke test of the sentiment head on a clearly negative review.
pipeline(
    text="Total dissapointment. I do not understand why would anyone watch this",
    task="sentiment",
)

(0, 0.9991714954376221)

In [7]:
import PySimpleGUI as sg

def gui():
    """Run a small PySimpleGUI front-end for the multi-head classifier.

    The window has a text input, one button per task and an output text area.
    Clicking a button runs ``pipeline`` on the input with that button's task
    and shows the predicted label index and its probability. Blocks until the
    window is closed.
    """
    # Button event -> (task name passed to pipeline, prefix shown in output).
    tasks = {
        'Sentiment': ('sentiment', 'Sentiment'),
        'Author': ('author', 'Author'),
    }

    sg.theme('DARKBLUE4')
    layout = [[sg.Input(key='-IN-')],
              # Only one button should own the Return key; binding it on both
              # made it ambiguous which task Enter would run.
              [sg.Button('Sentiment', bind_return_key=True)],
              [sg.Button('Author')],
              [sg.Text(size=(50, 30), key='-OUTPUT-')]]

    window = sg.Window('Multihead transformer', layout)
    try:
        while True:
            event, values = window.read()

            if event in (sg.WIN_CLOSED, 'Exit'):
                break

            if event in tasks:
                task, prefix = tasks[event]
                label, score = pipeline(values['-IN-'], task)
                resultstring = prefix + ": " + str(label) + ", score: " + str(score)
                window['-OUTPUT-'].update(resultstring)
    finally:
        # Close the window even if pipeline() raises, so none is left orphaned.
        window.close()

In [8]:
# Launch the GUI; this cell blocks until the window is closed.
gui();