In [2]:
import os
import numpy as np
import torch
from safetensors.torch import load_file as load_safetensors
from transformers import AutoTokenizer, RobertaModel
from openvino.runtime import Core

In [3]:
# ─── 1) Your best Safetensors checkpoint ──────────────────────────────────
# Checkpoint directory, given relative to the notebook's working directory.
# The weight-loading step further down reads `model.safetensors` from inside it.
BEST_CKPT: str = "goemotions_multilabel_model/checkpoint-10854"

In [4]:
# ─── 2) Reconstruct tokenizer + model ─────────────────────────────────────
# The tokenizer is fetched under the stock checkpoint id rather than from
# BEST_CKPT, so the id is kept in one named place.
BASE_MODEL_NAME = "roberta-base"
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)

class RobertaForMultiLabel(torch.nn.Module):
    """RoBERTa encoder followed by a linear head for multi-label classification.

    The head reads the encoder's pooled output and produces one raw logit per
    label. No sigmoid is applied here; callers turn logits into probabilities.

    Args:
        num_labels: Output width of the classification head.
    """

    def __init__(self, num_labels=28):
        super().__init__()
        # Attribute names double as state-dict key prefixes, so they (and their
        # creation order) are kept exactly as the saved checkpoint expects.
        self.roberta = RobertaModel.from_pretrained("roberta-base")
        self.dropout = torch.nn.Dropout(p=0.1)
        hidden_size = self.roberta.config.hidden_size
        self.classifier = torch.nn.Linear(hidden_size, num_labels)

    def forward(self, input_ids=None, attention_mask=None, labels=None):
        """Run the encoder and head; also compute the loss when labels are given.

        Args:
            input_ids: Token ids, forwarded unchanged to the encoder.
            attention_mask: Attention mask, forwarded unchanged to the encoder.
            labels: Optional targets with the same shape as the logits; cast to
                float before the loss is computed.

        Returns:
            A dict with keys "loss" and "logits". "loss" is None when `labels`
            is None, otherwise the mean binary cross-entropy on the logits.
        """
        encoded = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        pooled = self.dropout(encoded.pooler_output)
        logits = self.classifier(pooled)

        if labels is None:
            return {"loss": None, "logits": logits}

        # Functional form of BCEWithLogitsLoss with its default (mean) reduction.
        loss = torch.nn.functional.binary_cross_entropy_with_logits(
            logits, labels.float()
        )
        return {"loss": loss, "logits": logits}

# Instantiate the architecture; the fine-tuned weights are loaded on top of it
# in the next step.
model = RobertaForMultiLabel(28)

# ─── 2b) Load Safetensors weights ──────────────────────────────────────────
safetensors_path = os.path.join(BEST_CKPT, "model.safetensors")
state_dict = load_safetensors(safetensors_path, device="cpu")

# Default (strict) loading: a missing or unexpected key raises instead of
# leaving part of the model at its initial values.
model.load_state_dict(state_dict)

# Switch to evaluation mode so dropout is inactive from here on.
model.eval()
print("Model loaded from safetensors checkpoint")

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded from safetensors checkpoint


In [5]:
# ─── 3) Export to ONNX ─────────────────────────────────────────────────────
# Two short sentences, padded/truncated to 128 tokens, act as the example
# input the exporter runs through the model.
dummy = tokenizer(
    ["I am so happy!", "This is bad..."],
    padding="max_length",
    truncation=True,
    max_length=128,
    return_tensors="pt",
)
export_inputs = (dummy["input_ids"], dummy["attention_mask"])

# Both inputs keep symbolic batch and sequence axes; the logits are declared
# dynamic along the batch axis only.
onnx_dynamic_axes = {
    "input_ids":      {0: "batch", 1: "seq"},
    "attention_mask": {0: "batch", 1: "seq"},
    "logits":         {0: "batch"},
}

torch.onnx.export(
    model,
    export_inputs,
    "goemotions_multilabel.onnx",
    input_names=["input_ids", "attention_mask"],
    output_names=["logits"],
    dynamic_axes=onnx_dynamic_axes,
    opset_version=14,
)
print("ONNX export complete")

ONNX export complete


In [7]:
# ─── 4) Compile with OpenVINO ───────────────────────────────────────────────
# Read the ONNX file written by the export step, then compile it for the CPU
# device; `compiled` is the callable used for inference.
core = Core()
ov_model = core.read_model("goemotions_multilabel.onnx")
compiled = core.compile_model(ov_model, "CPU")

In [8]:
# ─── 5) Hard-coded GoEmotions “simplified” label names ─────────────────────
# Entry i names column i of the model output, so the order has to stay exactly
# as in the dataset’s ClassLabel.
emotion_labels = [
    "admiration", "amusement", "anger", "annoyance",
    "approval", "caring", "confusion", "curiosity",
    "desire", "disappointment", "disapproval", "disgust",
    "embarrassment", "excitement", "fear", "gratitude",
    "grief", "joy", "love", "nervousness",
    "optimism", "pride", "realization", "relief",
    "remorse", "sadness", "surprise", "neutral",
]
print(f"Loaded {len(emotion_labels)} emotion labels.")

Loaded 28 emotion labels.


In [10]:
# ─── 6) Inference + thresholding + mapping back to names ──────────────────
texts = ["I am so happy!", "This is bad...", "It's not that great but not that bad either"]
toks  = tokenizer(texts, padding="max_length", truncation=True, max_length=128, return_tensors="np")

# Feed the inputs by tensor name (the names given at ONNX export) instead of
# by position, so the call cannot silently swap ids and mask if the graph's
# input order ever differs.
outs   = compiled({name: toks[name] for name in ("input_ids", "attention_mask")})
logits = outs[compiled.output(0)]

# Numerically stable sigmoid: sigmoid(x) = exp(-log(1 + exp(-x))).
# np.logaddexp evaluates log(1 + exp(-x)) without overflowing, whereas the
# naive 1 / (1 + np.exp(-x)) overflows (RuntimeWarning) for very negative x.
probs  = np.exp(-np.logaddexp(0.0, -logits))

# choose your operating threshold
THR   = 0.3
preds = (probs > THR).astype(int)

print("\nMulti-hot vectors:\n", preds)
print("\nFired emotions per input:")
for i, single in enumerate(preds):
    # Indices of the non-zero entries are the labels that crossed THR.
    fired = [emotion_labels[j] for j in np.flatnonzero(single)]
    print(f"» Input #{i} {texts[i]!r} fired:", fired)


Multi-hot vectors:
 [[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 1 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 0 0 1 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1]]

Fired emotions per input:
» Input #0 'I am so happy!' fired: ['joy']
» Input #1 'This is bad...' fired: ['disappointment', 'disapproval', 'disgust']
» Input #2 "It's not that great but not that bad either" fired: ['disappointment', 'disapproval', 'neutral']
