In [None]:
# inference_roberta_models.ipynb

from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from tqdm import tqdm
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Register tqdm's pandas integration so that `.progress_apply` (used by the
# inference steps below) is available on pandas objects.
tqdm.pandas()

In [None]:
# Load the cleaned DataFrame `clean_df` (created in data_cleaning_for_RoBERTa.ipynb)
clean_df = pd.read_csv("../data/processed/cleaned_for_roberta.csv")

# Pick the compute device: the GPU when CUDA is usable, otherwise the CPU.
has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if has_cuda else "cpu")
print("CUDA available:", has_cuda)

# ======================= HATE SPEECH MODEL =======================

hate_model_name = "cardiffnlp/twitter-roberta-base-hate-multiclass-latest"
hate_tokenizer = AutoTokenizer.from_pretrained(hate_model_name)
hate_model = (
    AutoModelForSequenceClassification
    .from_pretrained(hate_model_name)
    .to(device)
    .eval()  # evaluation mode; `eval()` returns the module itself
)

# Class index -> label name, used to translate the argmax of the logits.
hate_label_map = dict(enumerate((
    "sexism",
    "racism",
    "disability",
    "sexual_orientation",
    "religion",
    "other",
    "not_hate",
)))

def get_hate_labels(text):
    """Classify a single tweet with the hate-speech model.

    Parameters
    ----------
    text : Any
        Tweet text. Values that are not strings, or that are empty /
        whitespace-only, are skipped.

    Returns
    -------
    pd.Series
        Two positional entries: the predicted label name (looked up in
        ``hate_label_map``) and the list of softmax probabilities, one per
        class. Both entries are ``None`` when ``text`` is skipped.
    """
    if not isinstance(text, str) or text.strip() == "":
        return pd.Series([None, None])

    # Pass an explicit max_length, as get_sentiment and
    # get_offensive_prediction do, so truncation does not depend on the
    # tokenizer config declaring a maximum model length.
    inputs = hate_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    ).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = hate_model(**inputs).logits

    probs = F.softmax(logits, dim=1).squeeze().tolist()
    predicted_label = hate_label_map[int(torch.argmax(logits, dim=1))]
    return pd.Series([predicted_label, probs])

# Run the hate-speech classifier on every cleaned tweet. get_hate_labels
# already returns a two-entry Series per row, so it is applied directly and the
# result supplies one new column per entry.
hate_columns = ["predicted_label_hate", "label_probabilities_hate"]
clean_df[hate_columns] = clean_df["cleaned_tweet"].progress_apply(get_hate_labels)

# ======================= SENTIMENT MODEL =======================

sentiment_model_name = "cardiffnlp/twitter-roberta-base-sentiment-latest"
sentiment_tokenizer = AutoTokenizer.from_pretrained(sentiment_model_name)
sentiment_model = (
    AutoModelForSequenceClassification
    .from_pretrained(sentiment_model_name)
    .to(device)
    .eval()  # evaluation mode; `eval()` returns the module itself
)

# Position in this list is the class index taken from the argmax of the logits.
sentiment_labels = [
    'negative',
    'neutral',
    'positive',
]

def get_sentiment(text):
    """Classify the sentiment of a single tweet.

    Parameters
    ----------
    text : Any
        Tweet text. Values that are not strings, or that are empty /
        whitespace-only, are skipped.

    Returns
    -------
    pd.Series
        Three positional entries: the predicted label (from
        ``sentiment_labels``), the list of softmax probabilities, and a bool
        that is True when the text was longer than the 512-token limit and
        therefore had to be truncated. All three entries are ``None`` when
        ``text`` is skipped.
    """
    if not isinstance(text, str) or text.strip() == "":
        return pd.Series([None, None, None])

    max_length = 512
    inputs = sentiment_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    ).to(device)

    # The truncated encoding can never be longer than max_length, so comparing
    # its length against the limit always reported False. Truncation can only
    # have happened when the encoding hit the cap; in that case count the
    # tokens of the untruncated encoding to find out whether anything was cut.
    was_truncated = False
    if inputs["input_ids"].shape[1] >= max_length:
        full_length = len(sentiment_tokenizer(text, truncation=False)["input_ids"])
        was_truncated = full_length > max_length

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = sentiment_model(**inputs).logits

    probs = F.softmax(logits, dim=1).squeeze().tolist()
    predicted_label = sentiment_labels[int(torch.argmax(logits))]
    return pd.Series([predicted_label, probs, was_truncated])

# Run the sentiment classifier on every cleaned tweet. get_sentiment already
# returns a three-entry Series per row, so it is applied directly and the
# result supplies one new column per entry.
sentiment_columns = ["sentiment_label", "sentiment_probs", "was_truncated"]
clean_df[sentiment_columns] = clean_df["cleaned_tweet"].progress_apply(get_sentiment)

# ======================= OFFENSIVE MODEL =======================

offensive_model_name = "cardiffnlp/twitter-roberta-base-offensive"
offensive_tokenizer = AutoTokenizer.from_pretrained(offensive_model_name)
offensive_model = (
    AutoModelForSequenceClassification
    .from_pretrained(offensive_model_name)
    .to(device)
    .eval()  # evaluation mode; `eval()` returns the module itself
)

# Position in this list is the class index taken from the argmax of the logits.
offensive_labels = [
    "not_offensive",
    "offensive",
]

def get_offensive_prediction(text):
    """Classify a single tweet as offensive or not offensive.

    Parameters
    ----------
    text : Any
        Tweet text. Values that are not strings, or that are empty /
        whitespace-only, are skipped.

    Returns
    -------
    pd.Series
        Two positional entries: the predicted label (from
        ``offensive_labels``) and the list of softmax probabilities. Both
        entries are ``None`` when ``text`` is skipped.
    """
    has_content = isinstance(text, str) and text.strip() != ""
    if not has_content:
        return pd.Series([None, None])

    encoded = offensive_tokenizer(
        text, return_tensors="pt", truncation=True, padding=True, max_length=512
    ).to(device)

    # Inference only: no gradients are needed.
    with torch.no_grad():
        logits = offensive_model(**encoded).logits
        class_probs = F.softmax(logits, dim=1).squeeze().tolist()
        best_index = int(torch.argmax(logits))

    return pd.Series([offensive_labels[best_index], class_probs])

# Run the offensive-language classifier on every cleaned tweet.
# get_offensive_prediction already returns a two-entry Series per row, so it is
# applied directly and the result supplies one new column per entry.
offensive_columns = ["offensive_label", "offensive_probs"]
clean_df[offensive_columns] = clean_df["cleaned_tweet"].progress_apply(get_offensive_prediction)

# ======================= SAVE OUTPUT =======================

output_csv = "data/processed_with_classifications.csv"

# NOTE(review): the input is read from "../data/processed/..." but the output
# is written under "data/" relative to the working directory — confirm this is
# the intended location.
# Create the target directory first so that a missing folder cannot make
# `to_csv` fail after all of the inference above has already run.
Path(output_csv).parent.mkdir(parents=True, exist_ok=True)

clean_df.to_csv(output_csv, index=False)
print(f"Classification complete. Results saved to '{output_csv}'")