In [1]:
import pandas as pd
import torch
from torch.nn.functional import softmax
from transformers import XLMRobertaTokenizer, XLMRobertaForSequenceClassification

In [2]:
def predict_class(test_file_path, language):

    # Load XLM-RoBERTa tokenizer and model
    model_path = "../fine_tuned_models/fine_tuned_xlm_roberta_model_" + language
    tokenizer = XLMRobertaTokenizer.from_pretrained('xlm-roberta-base')
    model = XLMRobertaForSequenceClassification.from_pretrained(model_path)
    
    # Load test data
    test_data = pd.read_csv(test_file_path)
    
    # Prepare inputs for the model
    inputs = tokenizer(test_data['tweet_text'].tolist(), padding=True, truncation=True, return_tensors='pt')
    
    # Make predictions
    with torch.no_grad():
        outputs = model(**inputs)
    
    # Debug: Print raw logits to examine their distribution
    print("Raw logits:", outputs.logits)
    
    # Apply softmax to get predicted probabilities
    probabilities = softmax(outputs.logits, dim=1)
    print("Softmax probabilities:", probabilities)
    
    # Get predicted class labels using argmax
    predicted_labels = ['Yes' if idx == 1 else 'No' for idx in probabilities.argmax(dim=1)]
    
    # Debug: Check a few sample predictions
    print("Sample predictions:", predicted_labels[:10])
    
    # Add predicted labels to the test data
    test_data['predicted_label'] = predicted_labels
    
    # Save results to CSV
    output_file_path = "../output_data/transformer/output_" + language + ".csv"
    test_data[['tweet_id', 'class_label']].to_csv(output_file_path, index=False)
    print(f"Results saved to {output_file_path}")

In [3]:

eng_test_file_path = "../data/processed_data/CT24_checkworthy_english/CT24_checkworthy_english_dev-test_preprocessed.csv"
predict_class(eng_test_file_path, "english")

arabic_test_file_path = "../data/processed_data/CT24_checkworthy_arabic/CT24_checkworthy_arabic_dev-test_preprocessed.csv"
predict_class(arabic_test_file_path, "arabic")

dutch_test_file_path = "../data/processed_data/CT24_checkworthy_dutch/CT24_checkworthy_dutch_dev-test_preprocessed.csv"
predict_class(dutch_test_file_path, "dutch")

Raw logits: tensor([[-0.1800,  0.1801],
        [-0.1715,  0.1947],
        [-0.1679,  0.1871],
        [-0.1727,  0.1985],
        [-0.1634,  0.1957],
        [-0.1673,  0.1950],
        [-0.1598,  0.1974],
        [-0.1649,  0.1906],
        [-0.1500,  0.1567],
        [-0.1691,  0.1964],
        [-0.1712,  0.2054],
        [-0.1652,  0.1946],
        [-0.1587,  0.1828],
        [-0.1675,  0.1951],
        [-0.1635,  0.1998],
        [-0.1636,  0.1913],
        [-0.1618,  0.2003],
        [-0.1644,  0.1850],
        [-0.1675,  0.1998],
        [-0.1711,  0.1998],
        [-0.1728,  0.1980],
        [-0.1657,  0.1983],
        [-0.1665,  0.1957],
        [-0.1592,  0.1840],
        [-0.1534,  0.1834],
        [-0.1647,  0.1956],
        [-0.1650,  0.1921],
        [-0.1654,  0.1996],
        [-0.1671,  0.1974],
        [-0.1636,  0.1973],
        [-0.1630,  0.1914],
        [-0.1700,  0.2075],
        [-0.1676,  0.1988],
        [-0.1636,  0.1869],
        [-0.1503,  0.1886],
        