In [None]:
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

In [None]:
def load_model(model_name="xlm-roberta-large", num_labels=5):
    """Build a ``(tokenizer, model)`` pair for sequence classification.

    Args:
        model_name: Hugging Face hub id or local path of the checkpoint to load
            for both the tokenizer and the model.
        num_labels: Number of output classes for the classification head.

    Returns:
        A ``(tokenizer, model)`` tuple, in that order.

    Note:
        The default ``"xlm-roberta-large"`` checkpoint is a plain pretrained
        encoder. transformers reports that its classification-head weights are
        newly initialized, so predictions are not meaningful until the model is
        fine-tuned on labelled sentiment data. Alternatively, pass a checkpoint
        that already contains a trained sentiment head as ``model_name``.
    """
    text_tokenizer = AutoTokenizer.from_pretrained(model_name)
    classifier = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
    )
    return text_tokenizer, classifier

In [None]:
def get_sentiment_label(sentiment):
    """Translate a predicted class index into a human-readable sentiment name.

    Indices 0 through 4 map, in order, to "Very Negative", "Negative",
    "Neutral", "Positive" and "Very Positive". Any other value yields "Unknown".

    NOTE(review): this ordering is assumed to match the label ids the
    classifier was trained with; confirm against the fine-tuning data.
    """
    # enumerate() assigns keys 0..4 in the listed order, reproducing the
    # explicit index -> name table.
    names_by_index = dict(enumerate((
        "Very Negative",
        "Negative",
        "Neutral",
        "Positive",
        "Very Positive",
    )))
    return names_by_index.get(sentiment, "Unknown")

In [None]:
def predict_sentiment(text, tokenizer, model):
    """Classify the sentiment of a single piece of text.

    Args:
        text: The input string. Only one text per call is supported, because
            ``.item()`` below requires the batch to hold exactly one prediction.
        tokenizer: Tokenizer callable (e.g. from ``load_model``) that accepts the
            ``return_tensors="pt"``, ``truncation`` and ``padding`` keywords.
        model: Sequence-classification model whose output exposes ``.logits``.
            Inference runs under ``torch.no_grad()``, which does NOT disable
            dropout. If the model was switched to train mode, call
            ``model.eval()`` before predicting.

    Returns:
        A ``(sentiment, sentiment_label, scores)`` tuple:
        - ``sentiment`` is the argmax class index.
        - ``sentiment_label`` is its name from ``get_sentiment_label``.
        - ``scores`` holds the softmax probabilities as a nested list
          (one inner list per batch row, i.e. ``[[p0, p1, ...]]``).
    """
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
    # Tokenized tensors are created on the CPU. Move them to the model's device
    # (e.g. CUDA) so the forward pass does not fail with a device mismatch.
    # Models that expose no ``device`` attribute are called exactly as before.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = {
            name: value.to(device) if isinstance(value, torch.Tensor) else value
            for name, value in inputs.items()
        }
    with torch.no_grad():
        outputs = model(**inputs)
    scores = torch.nn.functional.softmax(outputs.logits, dim=-1)
    sentiment = torch.argmax(scores, dim=-1).item()
    sentiment_label = get_sentiment_label(sentiment)
    # .tolist() copies to host memory, so this also works for GPU tensors.
    return sentiment, sentiment_label, scores.tolist()

In [None]:
if __name__ == "__main__":
    # The default checkpoint has no trained sentiment head. transformers warns
    # that the classifier weights are newly initialized, so the prediction
    # below is effectively random until the model is fine-tuned.
    # Seeding torch's RNG before loading makes that initialization, and
    # therefore the printed result, reproducible across Restart & Run All.
    SEED = 0
    torch.manual_seed(SEED)
    tokenizer, model = load_model()
    text = "I love working on AI projects!"
    sentiment, sentiment_label, scores = predict_sentiment(text, tokenizer, model)
    print(f"Predicted Sentiment: {sentiment} ({sentiment_label})")
    print(f"Confidence Scores: {scores}")

Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at xlm-roberta-large and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Predicted Sentiment: 4 (Very Positive)
Confidence Scores: [[0.13152866065502167, 0.23357628285884857, 0.180240198969841, 0.18806177377700806, 0.2665930688381195]]
