In [1]:
import pandas as pd
from tqdm.auto import tqdm
import os
import sys


# Resolve the project root (parent of the notebook's directory) exactly once.
# Guarding on ROOT_DIR makes this cell idempotent: without the guard, re-running
# it after os.chdir() would resolve '..' from the new cwd and climb one more level.
if 'ROOT_DIR' not in globals():
    ROOT_DIR = os.path.abspath(os.path.join(os.getcwd(), '..'))

# Make project packages (e.g. `src.*`) importable; avoid duplicate sys.path entries on re-run.
if ROOT_DIR not in sys.path:
    sys.path.append(ROOT_DIR)

# Set the current working directory to the project root so relative paths below resolve.
os.chdir(ROOT_DIR)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from src.inference.narrative_predictor import NarrativePredictor

# Inference configuration. Paths are relative to the project root (the cwd set in cell 1).
# Dev-set language folder name; change it to run on another language split.
LANGUAGE = 'EN'

MODEL_PATH = 'models/phase0_xlmr_best_model.bin'
TOKENIZER_NAME = 'xlm-roberta-base'
TEST_ARTICLES_PATH = f'devset/{LANGUAGE}/subtask-2-documents/'
# NOTE(review): not referenced by any of the cells shown; predictions are never written here yet.
OUTPUT_FILE = f'devset/{LANGUAGE.lower()}_predictions.txt'
# Decision threshold passed to predictor.set_threshold().
# NOTE(review): presumably tuned on a validation split; record where 0.86 came from.
OPTIMAL_THRESHOLD = 0.86

In [3]:
def load_articles(folder_path):
    """Load every ``.txt`` file in ``folder_path`` into a DataFrame.

    Args:
        folder_path: Directory containing the article files.

    Returns:
        A DataFrame with columns ``article_id`` (the file name) and ``text``
        (the file contents decoded as UTF-8). Rows are ordered by file name.
        If no ``.txt`` files are found, the frame is empty but still has
        both columns, so ``df['text']`` works.
    """
    articles = []
    # os.listdir order is arbitrary and platform-dependent; sort so row order
    # (and any output derived from it) is reproducible.
    for filename in sorted(os.listdir(folder_path)):
        file_path = os.path.join(folder_path, filename)
        # Skip anything that is not a regular .txt file (e.g. a directory named "x.txt").
        if filename.endswith(".txt") and os.path.isfile(file_path):
            with open(file_path, 'r', encoding='utf-8') as f:
                articles.append({'article_id': filename, 'text': f.read()})
    # Explicit columns keep the schema stable even when `articles` is empty.
    return pd.DataFrame(articles, columns=['article_id', 'text'])

In [4]:
from src.data_management.label_parser import get_label_mappings

label_to_id, id_to_label, narrative_to_subnarrative_ids = get_label_mappings()

# Invert narrative -> [sub-narrative IDs] into sub-narrative ID -> parent narrative ID.
# Insertion order follows the source mapping. If a sub-narrative ID appeared under
# several narratives, the last one would win.
sub_to_narr_id_map = {
    child_id: parent_id
    for parent_id, child_ids in narrative_to_subnarrative_ids.items()
    for child_id in child_ids
}

# NOTE(review): despite the name, each tuple is (sub-narrative ID, parent narrative ID),
# i.e. (child, parent). Confirm this is the order NarrativePredictor expects.
parent_child_pairs = [*sub_to_narr_id_map.items()]

In [5]:
# --- 1. Bundle the label metadata the predictor needs ---
label_maps = {
    "id2label": id_to_label,
    "label2id": label_to_id,
    "parent_child_pairs": parent_child_pairs,
}

# --- 2. Initialize the Predictor ---
# This loads the model and tokenizer only once.
# NOTE(review): the transformers warning about "newly initialized" classifier weights
# is only harmless if NarrativePredictor then loads the fine-tuned weights from
# MODEL_PATH. Confirm it does.
predictor = NarrativePredictor(MODEL_PATH, TOKENIZER_NAME, label_maps)

# --- 3. Set the Optimal Threshold ---
predictor.set_threshold(OPTIMAL_THRESHOLD)

initializing the Narrative Predictor...
Using device: cuda


Some weights of XLMRobertaForSequenceClassification were not initialized from the model checkpoint at xlm-roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Predictor initialized and ready.
Threshold set to: 0.86


In [6]:
print(f"Loading articles from {TEST_ARTICLES_PATH}...")
df_test = load_articles(TEST_ARTICLES_PATH)
# Fail loudly on a wrong/empty path instead of a confusing KeyError or an empty prediction run.
if df_test.empty:
    raise FileNotFoundError(f"No .txt articles found in {TEST_ARTICLES_PATH}")
texts_to_predict = df_test['text'].tolist()
print(f"Loaded {len(texts_to_predict)} articles.")

Loading articles from devset/EN/subtask-2-documents/...


In [7]:
# Sanity check: the label map handed to the predictor must cover exactly the
# model's output units, otherwise logits are decoded to the wrong labels.
print("--- Verifying id2label map integrity ---")
num_labels_in_map = len(predictor.id2label)
num_labels_in_model = predictor.model.config.num_labels

print(f"Number of labels in the provided id2label map: {num_labels_in_map}")
print(f"Number of labels the model was configured with: {num_labels_in_model}")

if num_labels_in_map != num_labels_in_model:
    # Stop execution: continuing would silently produce mislabeled predictions.
    raise RuntimeError(
        "!! FATAL ERROR: Mismatch between map size and model's number of labels. "
        f"(map: {num_labels_in_map}, model: {num_labels_in_model})"
    )

print(predictor.label2id)

--- Verifying id2label map integrity ---
Number of labels in the provided id2label map: 117
Number of labels the model was configured with: 117
{'CC: Amplifying Climate Fears': 0, 'CC: Climate change is beneficial': 1, 'CC: Controversy about green technologies': 2, 'CC: Criticism of climate movement': 3, 'CC: Criticism of climate policies': 4, 'CC: Criticism of institutions and authorities': 5, 'CC: Downplaying climate change': 6, 'CC: Green policies are geopolitical instruments': 7, 'CC: Hidden plots by secret schemes of powerful groups': 8, 'CC: Questioning the measurements and science': 9, 'Other': 10, 'URW: Amplifying war-related fears': 11, 'URW: Blaming the war on others rather than the invader': 12, 'URW: Discrediting Ukraine': 13, 'URW: Discrediting the West, Diplomacy': 14, 'URW: Distrust towards Media': 15, 'URW: Hidden plots by secret schemes of powerful groups': 16, 'URW: Negative Consequences for the West': 17, 'URW: Overpraising the West': 18, 'URW: Praise of Russia': 19,

In [9]:
# Smoke-test the predictor on the first loaded article; the returned dict is displayed.
first_article_text = texts_to_predict[0]
predictor.predict(first_article_text)

{'narratives': ['CC: Amplifying Climate Fears',
  'CC: Criticism of climate movement',
  'CC: Criticism of climate policies',
  'CC: Criticism of institutions and authorities',
  'CC: Hidden plots by secret schemes of powerful groups'],
 'subnarratives': ['CC: Amplifying Climate Fears: Other',
  'CC: Criticism of climate movement: Other',
  'CC: Criticism of climate policies: Other',
  'CC: Criticism of institutions and authorities: Criticism of national governments',
  'CC: Criticism of institutions and authorities: Criticism of political organizations and figures',
  'CC: Hidden plots by secret schemes of powerful groups: Other']}