# *Nano Graph-RAG: Fake News Detection*
### *Aman Pawar*

In [2]:
# -*- coding: utf-8 -*-
"""
Fast Graph-RAG Implementation (Adapted for Nano/Speed).

This script implements a Retrieval-Augmented Generation (RAG) approach
for text classification, adapted for faster execution ("Fast-Graph-RAG" or "Nano-Graph-RAG").
Key adaptations include:
- Using DistilBERT as the classifier for faster training/inference.
- Employing FAISS IndexIVFFlat for faster approximate nearest neighbor search.
- Adding optional mixed-precision training (torch.cuda.amp).
- Configuration options tailored for speed vs. accuracy trade-offs.

Core Components:
1. Configuration: Hyperparameters, paths, model names, FAISS settings, AMP flag.
2. Data Loading: Loads and preprocesses data.
3. Custom Dataset (GraphRagNewsDataset): Integrates retrieval.
4. Retriever Setup: SentenceTransformer + FAISS (IndexIVFFlat).
5. Graph-like Retrieval Function: Finds primary/secondary neighbors using FAISS.
6. Classifier Model: DistilBERT for sequence classification.
7. Training Loop: Fine-tunes the classifier with optional AMP.
8. Evaluation: Measures performance.

*** IMPORTANT EXECUTION NOTE ***
This script uses multiprocessing for data loading. Due to how Python's
multiprocessing (especially with 'spawn' start method) works, it's crucial
to run this script as a standalone .py file from your terminal:
  python <your_script_name>.py
Running cells interactively (e.g., in some IDEs or notebooks) might lead to
AttributeErrors or NameErrors related to multiprocessing context.
******************************
"""

import pandas as pd
import numpy as np
import torch
import torch.multiprocessing as mp # Import torch multiprocessing
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
import torch.nn as nn
from transformers import AutoTokenizer, AutoModelForSequenceClassification, get_linear_schedule_with_warmup
from sentence_transformers import SentenceTransformer
import faiss # For efficient similarity search
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from tqdm.auto import tqdm
import os
import time
import gc # Garbage collector
import traceback # For detailed error printing
import logging # Using logging for better output management
from torch.cuda.amp import GradScaler, autocast # For Mixed Precision Training

# --- Environment ---------------------------------------------------------------
# Turn off HuggingFace tokenizers' internal parallelism so it does not emit its
# parallelism warning once DataLoader worker processes are started.
os.environ.update(TOKENIZERS_PARALLELISM="false")

# --- Logging -------------------------------------------------------------------
# INFO level; the process name is included so lines from DataLoader workers can be
# told apart from the main process.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - [%(processName)s] %(message)s',
)

# --- Shared retriever state ----------------------------------------------------
# Populated by the __main__ block and read by retrieve_documents_graph_like().
# NOTE(review): with the 'spawn' start method, worker processes re-import this
# module without executing the __main__ block, so in workers these presumably
# keep the placeholder values below and retrieval returns [] — confirm.
retriever_model = None       # SentenceTransformer encoder for queries/documents
faiss_index = None           # FAISS index built over train_embeddings_cpu
train_embeddings_cpu = None  # float32 numpy array of training-text embeddings
train_texts = []             # training texts; FAISS id i <-> train_texts[i]


# --- 1. Configuration ---
class Config:
    """
    Static settings for the Fast Graph-RAG pipeline.

    Used as a plain namespace: every value is a class attribute and is read as
    ``Config.<name>``.
    """

    # ---- Data files ---------------------------------------------------------
    # Relative paths, resolved against the directory the script is launched from
    # (the data is expected one directory up). Adjust if your files live elsewhere.
    train_file = '../Constraint_English_Train.xlsx'
    val_file = '../Constraint_English_Val.xlsx'
    test_file = '../english_test_with_labels.xlsx'

    # ---- Pretrained models --------------------------------------------------
    retriever_model_name = 'all-MiniLM-L6-v2'          # sentence encoder feeding the FAISS index
    classifier_model_name = 'distilbert-base-uncased'  # lightweight classifier that is fine-tuned

    # ---- Graph-style retrieval (speed vs. context trade-off) -----------------
    num_retrieved_docs_primary = 3    # k: direct neighbours of the query
    num_retrieved_docs_secondary = 2  # m: neighbours of those neighbours
    max_total_retrieved_docs = 5      # cap on unique context documents (primary + secondary)

    # ---- FAISS --------------------------------------------------------------
    faiss_index_type = 'IVFFlat'  # 'IVFFlat' (approximate, faster) or 'FlatL2' (exact)
    # IVF rule of thumb: nlist ~ 4*sqrt(N) .. 16*sqrt(N) for N training vectors
    # (e.g. N=6400 -> 320..1280). 100 is a moderate starting point; more cells can
    # improve accuracy at the cost of slower build/search.
    faiss_nlist = 100
    faiss_nprobe = 10  # cells probed per query: higher = more accurate but slower

    # ---- Optimisation -------------------------------------------------------
    max_seq_length = 512  # token budget for query + retrieved context
    batch_size = 16
    epochs = 3
    learning_rate = 3e-5  # DistilBERT-friendly starting value; may need tuning
    warmup_steps = 100
    gradient_accumulation_steps = 1

    # ---- Runtime / reproducibility -------------------------------------------
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    seed = 42
    # Mixed precision is only enabled when running on CUDA.
    use_amp = device.type == "cuda"

    # ---- Outputs ------------------------------------------------------------
    output_dir = "./fast_graph_rag_classifier_model"

    # ---- Debugging ----------------------------------------------------------
    # True forces num_workers=0 in the DataLoaders (single process), which makes
    # worker-side errors much easier to trace. Set back to False afterwards.
    DEBUG_DATALOADER = False

# --- 2. Data Loading Function ---
# Defined at top level (module scope)
def load_data(filepath):
    """
    Load a labelled tweet dataset from an Excel file and clean it.

    The sheet must contain ``tweet`` and ``label`` columns. Cleaning steps:
      * rows with a missing ``tweet`` are dropped and tweets are cast to ``str``;
      * labels are normalised (surrounding whitespace stripped, lower-cased) and
        mapped ``'real' -> 1``, ``'fake' -> 0``, so variants such as ``'Real'`` or
        ``' FAKE '`` are kept; any other label is dropped;
      * ``label`` is cast to ``int``.

    Args:
        filepath: path to the ``.xlsx`` file.

    Returns:
        The cleaned DataFrame (original row index preserved), or ``None`` if the
        file is missing, lacks the required columns, ends up empty, or cannot be
        read. Every failure is logged rather than raised.
    """
    if not os.path.exists(filepath):
        logging.error(f"Data file not found: {filepath}")
        return None
    try:
        df = pd.read_excel(filepath)
        if 'tweet' not in df.columns or 'label' not in df.columns:
            logging.error(f"Excel file {filepath} must contain 'tweet' and 'label' columns.")
            return None

        label_map = {'real': 1, 'fake': 0}

        # Drop rows without text; cast the rest to str (numbers/dates become text).
        initial_count = len(df)
        df = df.dropna(subset=['tweet']).assign(tweet=lambda d: d['tweet'].astype(str))
        dropped_nan_tweet = initial_count - len(df)

        # Normalise before mapping so capitalisation / stray spaces don't discard rows.
        # Missing labels become the string 'nan', which is unmapped and dropped below.
        normalized_labels = df['label'].astype(str).str.strip().str.lower().map(label_map)
        df = df.assign(label=normalized_labels)
        count_before_label_drop = len(df)
        df = df.dropna(subset=['label']).astype({'label': int})
        dropped_bad_label = count_before_label_drop - len(df)

        if df.empty:
            logging.warning(f"No valid data loaded from {filepath} after cleaning.")
            return None

        logging.info(f"Loaded {len(df)} samples from {filepath} (dropped {dropped_nan_tweet} for NaN tweet, {dropped_bad_label} for invalid label).")
        return df
    except Exception as e:
        logging.error(f"Error loading {filepath}: {e}", exc_info=True)
        return None

# --- 3. Custom PyTorch Dataset for Graph-RAG ---
# Defined at top level (module scope) - Crucial for multiprocessing
class GraphRagNewsDataset(Dataset):
    """
    PyTorch Dataset for Graph-RAG classification.

    For each item it reads ``texts[idx]`` / ``labels[idx]``, asks ``retriever_func``
    for context documents, and appends at most ``num_retrieved_total`` non-empty
    documents to the text after the tokenizer's separator token. It then tokenizes
    the result to a fixed length of ``max_len``.

    Error policy:
      * Outside a daemonic process (main process, or ``num_workers=0``), failures
        are logged and re-raised.
      * Inside a daemonic process (PyTorch DataLoader workers are normally started
        as daemons), failures are logged and ``None`` is returned instead.

    NOTE(review): the default DataLoader collate function cannot batch ``None``
    items. Such a return therefore still fails at collation unless a ``collate_fn``
    that drops ``None`` is supplied — confirm against the DataLoader construction.
    """

    def __init__(self, texts, labels, tokenizer, max_len, retriever_func, num_retrieved_total):
        """
        Args:
            texts: iterable of input texts (materialised as a list).
            labels: iterable of integer class labels, aligned with ``texts``.
            tokenizer: tokenizer exposing ``sep_token`` and ``encode_plus``.
            max_len: length every example is padded/truncated to.
            retriever_func: callable taking the query text and returning a list of
                context documents.
            num_retrieved_total: maximum number of non-empty retrieved documents
                appended as context (``None`` means no extra cap).

        Raises:
            ValueError: if ``texts`` and ``labels`` have different lengths.
        """
        self.texts = texts if isinstance(texts, list) else list(texts)
        self.labels = labels if isinstance(labels, list) else list(labels)
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.retriever_func = retriever_func
        self.num_retrieved_total = num_retrieved_total

        if len(self.texts) != len(self.labels):
            raise ValueError(f"Number of texts ({len(self.texts)}) and labels ({len(self.labels)}) must match.")
        logging.info(f"Dataset initialized with {len(self.texts)} samples.")

    def __len__(self):
        return len(self.texts)

    @staticmethod
    def _in_worker_process():
        """Return True if the current process is daemonic (errors are then logged, not raised)."""
        return bool(getattr(mp.current_process(), 'daemon', False))

    def _resolve_index(self, idx):
        """
        Return ``idx`` as an in-range int, or None if it is not a valid integral index.

        Any integral type is accepted: Python ``int``, NumPy integer scalars, and
        single-element integer tensors. Negative indices are rejected.
        """
        import operator

        try:
            position = operator.index(idx)
        except (TypeError, ValueError, RuntimeError):
            return None
        return position if 0 <= position < len(self.texts) else None

    def __getitem__(self, idx):
        """
        Build one tokenized example.

        Returns:
            A dict with 1-D ``input_ids`` and ``attention_mask`` tensors (length
            ``max_len``) and a scalar ``torch.long`` ``labels`` tensor. Returns
            ``None`` instead if a step fails inside a daemonic worker process.

        Raises:
            IndexError: for an invalid index when not in a worker process.
            Exception: any text/label, retrieval or tokenization error when not in
                a worker process (re-raised after logging).
        """
        # --- Fetch text and label ---
        original_text = None
        label = None
        try:
            position = self._resolve_index(idx)
            if position is None:
                logging.error(f"Invalid index type or value: {idx} (type: {type(idx)}). Dataset size: {len(self.texts)}.")
                if not self._in_worker_process():
                    raise IndexError(f"Invalid index: {idx}")
                return None  # worker: avoid crashing here (see class note on collation)

            original_text = str(self.texts[position])
            label = int(self.labels[position])
        except Exception as e:
            logging.error(f"Failed to get text/label at index {idx}. Error: {e}", exc_info=True)
            if not self._in_worker_process():
                raise
            return None

        # --- Retrieval ---
        retrieved_docs = []
        try:
            retrieved_docs = self.retriever_func(original_text)
            if not isinstance(retrieved_docs, list):
                logging.warning(f"Retriever function did not return a list for index {idx}. Query: '{original_text[:50]}...'. Got: {type(retrieved_docs)}. Using empty list.")
                retrieved_docs = []
        except NameError as ne:
            # A NameError here usually means the retriever's module-level components
            # are not available in this process.
            logging.error(f"NameError during RETRIEVAL for index {idx}. This likely means worker process couldn't access global retriever components. Error: {ne}", exc_info=True)
            if not self._in_worker_process():
                raise
            return None
        except Exception as e:
            logging.error(f"Failed during RETRIEVAL for index {idx}. Query: '{original_text[:100]}...'. Error: {e}", exc_info=True)
            if not self._in_worker_process():
                raise
            # Worker: fall back to classifying the text without any context.
            retrieved_docs = []

        # --- Build combined input and tokenize ---
        combined_text = original_text
        try:
            context_docs = [str(doc) for doc in retrieved_docs if doc]  # drop None / empty docs
            if self.num_retrieved_total is not None:
                # Honour the dataset-level cap on how many context documents are appended.
                context_docs = context_docs[:max(int(self.num_retrieved_total), 0)]
            context = " ".join(context_docs)
            sep_token = self.tokenizer.sep_token or "[SEP]"

            if context:
                combined_text = f"{original_text} {sep_token} {context}"

            # The query comes first, so with right-side truncation (HF default) the
            # trailing context is what gets cut when the input exceeds max_len.
            encoding = self.tokenizer.encode_plus(
                combined_text,
                add_special_tokens=True,
                max_length=self.max_len,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                return_tensors='pt',
            )
        except Exception as e:
            logging.error(f"Failed during TOKENIZATION for index {idx}. Combined Text (start): '{combined_text[:100]}...'. Error: {e}", exc_info=True)
            if not self._in_worker_process():
                raise
            return None

        # --- Assemble result ---
        try:
            # flatten() drops the leading batch dimension added by return_tensors='pt'.
            return {
                'input_ids': encoding['input_ids'].flatten(),
                'attention_mask': encoding['attention_mask'].flatten(),
                'labels': torch.tensor(label, dtype=torch.long),
            }
        except Exception as e:
            logging.error(f"Failed during final DICT CREATION for index {idx}. Label: {label}. Error: {e}", exc_info=True)
            if not self._in_worker_process():
                raise
            return None


# --- 4. Evaluation Function ---
# Defined at top level (module scope)
def evaluate_model(model, dataloader, device, loss_fn):
    """
    Evaluate a sequence classifier on ``dataloader`` without computing gradients.

    Args:
        model: called as ``model(input_ids=..., attention_mask=...)``. The result
            must expose ``.logits`` of shape (batch, num_classes); two classes are
            assumed, since the metrics use ``average='binary'``.
        dataloader: sized iterable of batch dicts holding ``input_ids``,
            ``attention_mask`` and ``labels`` tensors, or None.
        device: target device for the batch tensors.
        loss_fn: ``(logits, labels) -> scalar tensor``. Pass a non-callable
            (e.g. None) to skip loss computation.

    Returns:
        ``(avg_loss, accuracy, precision, recall, f1)``.
          * ``avg_loss`` is averaged over the batches whose loss was actually
            computed (0 if none).
          * Precision, recall and F1 treat label 1 as the positive class.
          * All zeros are returned when there is nothing to evaluate.

    Side effect: switches ``model`` to eval mode and leaves it there.
    """
    model.eval()
    total_loss = 0.0
    num_loss_batches = 0  # batches that actually contributed to total_loss
    all_preds = []
    all_labels = []
    required_keys = ('input_ids', 'attention_mask', 'labels')

    if dataloader is None or len(dataloader) == 0:
        logging.warning("Evaluation dataloader is empty or None. Skipping evaluation.")
        return 0, 0, 0, 0, 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating", leave=False):
            if not isinstance(batch, dict) or any(key not in batch for key in required_keys):
                logging.warning("Skipping invalid batch during evaluation (not a dict or missing keys).")
                continue
            if batch['input_ids'].numel() == 0 or batch['labels'].numel() == 0:
                logging.warning("Skipping empty batch (zero elements in input_ids or labels) during evaluation.")
                continue

            try:
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)
                labels = batch['labels'].to(device)

                logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
                # Materialise this batch's predictions/labels before touching the
                # accumulators, so a failure cannot leave them misaligned.
                batch_preds = torch.argmax(logits, dim=1).cpu().tolist()
                batch_labels = labels.cpu().tolist()

                if callable(loss_fn):
                    try:
                        total_loss += loss_fn(logits, labels).item()
                        num_loss_batches += 1
                    except Exception as loss_e:
                        logging.warning(f"Could not compute loss for evaluation batch: {loss_e}")

                all_preds.extend(batch_preds)
                all_labels.extend(batch_labels)
            except Exception as e:
                logging.error(f"Error during evaluation batch processing: {e}", exc_info=True)
                continue  # skip the problematic batch

    if not all_labels or not all_preds:
        logging.warning("Evaluation resulted in no valid labels or predictions. Cannot compute metrics.")
        return 0, 0, 0, 0, 0

    # Average only over batches whose loss was computed; skipped/failed batches
    # would otherwise bias the mean towards zero.
    avg_loss = total_loss / num_loss_batches if num_loss_batches else 0

    try:
        accuracy = accuracy_score(all_labels, all_preds)
        precision, recall, f1, _ = precision_recall_fscore_support(
            all_labels, all_preds, average='binary', zero_division=0
        )
    except Exception as e:
        logging.error(f"Error calculating evaluation metrics: {e}", exc_info=True)
        accuracy, precision, recall, f1 = 0, 0, 0, 0

    logging.info(f"Evaluation Results: Loss={avg_loss:.4f}, Acc={accuracy:.4f}, Prec={precision:.4f}, Rec={recall:.4f}, F1={f1:.4f}")
    return avg_loss, accuracy, precision, recall, f1

# --- 5. Graph-like Retrieval Function Definition ---
# Defined at top level (module scope) - Crucial for multiprocessing
def retrieve_documents_graph_like(query_text):
    """
    Retrieve context texts for ``query_text`` via a two-hop nearest-neighbour walk.

    1. Embed the query with ``retriever_model``. Keep its ``k`` nearest training
       vectors from ``faiss_index`` (primary neighbours) in FAISS's ranking order
       (increasing distance).
    2. For each primary neighbour, in that order, search its ``m + 1`` nearest
       vectors. Keep the ones not already selected (secondary neighbours).
    3. Return the texts of primary then secondary neighbours, capped at
       ``Config.max_total_retrieved_docs``.

    k and m come from ``Config.num_retrieved_docs_primary`` and
    ``Config.num_retrieved_docs_secondary``.

    Reads the module-level globals ``retriever_model``, ``faiss_index``,
    ``train_embeddings_cpu`` and ``train_texts`` (FAISS id i <-> ``train_texts[i]``).

    Returns:
        A list of retrieved training texts. Returns ``[]`` if the components are
        not initialised in this process or if any error occurs (errors are logged).

    NOTE(review): under the 'spawn' start method, DataLoader workers re-import this
    module without running the ``__main__`` block. These globals are therefore
    presumably still None there, and every call in a worker returns ``[]`` —
    confirm.

    NOTE(review): for queries drawn from the training set, the nearest neighbour is
    usually the query's own text. Training examples may thus get themselves as
    context, while val/test examples generally do not.
    """
    if retriever_model is None or faiss_index is None or train_embeddings_cpu is None or not train_texts:
        # Globals not populated in this process (see NOTE above).
        logging.warning("Retriever components (module globals) not initialized in this process. Cannot retrieve.")
        return []

    k_primary = Config.num_retrieved_docs_primary
    m_secondary = Config.num_retrieved_docs_secondary
    max_total = Config.max_total_retrieved_docs
    num_train_vectors = len(train_embeddings_cpu)

    try:
        # 1. Embed the query (model is expected to already live on Config.device).
        query_embedding = retriever_model.encode([query_text], convert_to_tensor=True, device=Config.device)
        query_embedding_np = query_embedding.cpu().numpy().astype(np.float32)
        del query_embedding

        # 2. Primary neighbours. FAISS returns hits sorted by increasing distance and
        #    pads with -1; dict.fromkeys de-duplicates while keeping that ranking.
        _, indices_p = faiss_index.search(query_embedding_np, k_primary + m_secondary)
        primary_indices = list(dict.fromkeys(int(i) for i in indices_p[0] if i != -1))[:k_primary]
        if not primary_indices:
            return []

        # 3. Secondary neighbours: neighbours of each primary hit, in ranking order.
        secondary_indices = []
        if m_secondary > 0:
            valid_primary_indices = [i for i in primary_indices if 0 <= i < num_train_vectors]
            if len(valid_primary_indices) != len(primary_indices):
                logging.warning(f"Some primary indices were out of bounds: {primary_indices}. Using only valid ones: {valid_primary_indices}")

            if valid_primary_indices:
                # m + 1 because each vector's nearest hit is normally itself.
                _, indices_s = faiss_index.search(train_embeddings_cpu[valid_primary_indices], m_secondary + 1)
                selected = set(primary_indices)
                for neighbour_row in indices_s:
                    for neighbour in neighbour_row:
                        neighbour = int(neighbour)
                        if neighbour != -1 and neighbour not in selected:
                            selected.add(neighbour)
                            secondary_indices.append(neighbour)

        # 4. Primary first, then secondary, capped at max_total (lists are disjoint).
        final_indices = (primary_indices + secondary_indices)[:max_total]

        # 5. Map FAISS ids back to training texts.
        retrieved_texts = []
        for i in final_indices:
            if 0 <= i < len(train_texts):
                retrieved_texts.append(train_texts[i])
            else:
                logging.warning(f"Retrieved invalid index {i} during text lookup (max: {len(train_texts)-1}). Skipping.")

        return retrieved_texts

    except Exception as e:
        logging.error(f"ERROR in retrieve_documents_graph_like for query '{query_text[:50]}...': {e}", exc_info=True)
        gc.collect()
        if Config.device == torch.device("cuda"): torch.cuda.empty_cache()
        return []


# --- Main Execution Block ---
# Use this guard to ensure the following code only runs when the script is executed directly
if __name__ == "__main__":

    # --- Multiprocessing Setup ---
    try:
        current_start_method = mp.get_start_method(allow_none=True)
        if current_start_method is None or current_start_method != 'spawn':
            if hasattr(mp, 'set_start_method'):
                mp.set_start_method('spawn', force=True)
                logging.info(f"Multiprocessing start method set to 'spawn' (was {current_start_method}).")
            else:
                logging.warning("mp.set_start_method not available. Using default method.")
        else:
             logging.info("Multiprocessing start method already 'spawn'.")
    except (RuntimeError, ValueError) as e:
        logging.warning(f"Could not force multiprocessing start method to 'spawn': {e}. Using default: {mp.get_start_method(allow_none=True)}")

    # --- Seed and Device Setup ---
    np.random.seed(Config.seed)
    torch.manual_seed(Config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(Config.seed)
        logging.info(f"CUDA available. Using device: {Config.device}")
        logging.info(f"Mixed Precision Training (AMP) enabled: {Config.use_amp}")
    else:
        logging.info(f"CUDA not available. Using device: {Config.device}")
        Config.use_amp = False # Ensure AMP is disabled if no CUDA
        logging.info(f"Mixed Precision Training (AMP) disabled (CUDA not available).")


    # --- Retriever Setup ---
    # Assign values to the MODULE-LEVEL global variables
    logging.info("--- Setting up Retriever ---")
    train_df = None
    # *** REMOVED global declaration from here ***
    try:
        train_df = load_data(Config.train_file)
        if train_df is None or train_df.empty:
            raise ValueError("Failed to load training data for retriever. Cannot proceed.")

        # Assign directly to module-level variables declared at the top
        train_texts = train_df['tweet'].tolist()
        logging.info(f"Loading retriever model: {Config.retriever_model_name}...")

        retriever_model = SentenceTransformer(Config.retriever_model_name)
        retriever_model.to(Config.device)
        logging.info(f"Retriever model '{Config.retriever_model_name}' loaded and moved to {Config.device}.")

        # --- Embedding ---
        logging.info(f"Embedding {len(train_texts)} training documents for FAISS index...")
        batch_size_embed = 128
        train_embeddings = retriever_model.encode(
            train_texts,
            batch_size=batch_size_embed,
            convert_to_tensor=True,
            show_progress_bar=True,
            device=Config.device
        )
        logging.info(f"Embeddings generated with shape: {train_embeddings.shape}")
        train_embeddings_cpu = train_embeddings.cpu().numpy().astype(np.float32) # Assign to module-level global

        del train_embeddings; gc.collect()
        if Config.device == torch.device("cuda"): torch.cuda.empty_cache()

        # --- FAISS Indexing ---
        embedding_dim = train_embeddings_cpu.shape[1]
        logging.info(f"Building FAISS index ({Config.faiss_index_type}) with dimension={embedding_dim}...")

        if Config.faiss_index_type == 'FlatL2':
            faiss_index = faiss.IndexFlatL2(embedding_dim) # Assign to module-level global
            faiss_index.add(train_embeddings_cpu)
        elif Config.faiss_index_type == 'IVFFlat':
            quantizer = faiss.IndexFlatL2(embedding_dim)
            metric = faiss.METRIC_L2
            actual_nlist = min(Config.faiss_nlist, len(train_texts))
            if actual_nlist != Config.faiss_nlist:
                 logging.warning(f"Adjusted faiss_nlist from {Config.faiss_nlist} to {actual_nlist}")
            if actual_nlist == 0:
                 raise ValueError("Cannot build IVFFlat index with 0 training vectors or nlist=0.")

            faiss_index = faiss.IndexIVFFlat(quantizer, embedding_dim, actual_nlist, metric) # Assign to module-level global

            logging.info(f"Training FAISS {Config.faiss_index_type} index with nlist={actual_nlist}...")
            if train_embeddings_cpu.shape[0] < actual_nlist:
                 logging.warning(f"Number of training vectors ({train_embeddings_cpu.shape[0]}) is less than nlist ({actual_nlist}).")
            start_train_time = time.time()
            faiss_index.train(train_embeddings_cpu)
            logging.info(f"FAISS index training finished in {time.time() - start_train_time:.2f}s.")

            faiss_index.add(train_embeddings_cpu)
            faiss_index.nprobe = min(Config.faiss_nprobe, actual_nlist)
            logging.info(f"Set FAISS nprobe to {faiss_index.nprobe}")
        else:
            raise ValueError(f"Unsupported faiss_index_type: {Config.faiss_index_type}")

        logging.info(f"FAISS index built successfully. Type: {Config.faiss_index_type}, Total vectors: {faiss_index.ntotal}")

    except Exception as e:
        logging.error(f"FATAL: Error during retriever setup: {e}", exc_info=True); exit(1)


    # --- Load Data & Create DataLoaders ---
    # This happens *after* the module-level global retriever components are initialized
    logging.info("--- Loading Data and Creating DataLoaders ---")
    val_df, test_df = None, None
    train_dataset, val_dataset, test_dataset = None, None, None
    train_dataloader, val_dataloader, test_dataloader = None, None, None
    classifier_tokenizer = None

    try:
        val_df = load_data(Config.val_file)
        test_df = load_data(Config.test_file)
        if val_df is None or val_df.empty: logging.warning("Validation data loading failed or resulted in empty dataframe.")
        if test_df is None or test_df.empty: logging.warning("Test data loading failed or resulted in empty dataframe.")

        logging.info(f"Loading classifier tokenizer: {Config.classifier_model_name}...")
        classifier_tokenizer = AutoTokenizer.from_pretrained(Config.classifier_model_name)
        logging.info("Classifier tokenizer loaded.")

        # --- Create Datasets ---
        logging.info("Creating datasets...")
        # Check if module-level train_texts is populated before creating dataset
        if not train_texts:
             logging.error("Cannot create train_dataset as global train_texts is empty (retriever setup likely failed).")
             raise ValueError("Training text data is required.")

        # Use train_df for labels if available, otherwise could adapt if labels are separate
        if train_df is not None and not train_df.empty:
             train_labels = train_df['label'].tolist()
             if len(train_texts) != len(train_labels):
                  # This case should ideally not happen if train_texts came from train_df
                  logging.error(f"Mismatch between global train_texts ({len(train_texts)}) and train_df labels ({len(train_labels)}).")
                  raise ValueError("Text and label count mismatch for training data.")

             train_dataset = GraphRagNewsDataset(
                texts=train_texts, # Use module-level global texts
                labels=train_labels,
                tokenizer=classifier_tokenizer, max_len=Config.max_seq_length,
                retriever_func=retrieve_documents_graph_like, # Pass the retrieval function
                num_retrieved_total=Config.max_total_retrieved_docs
             )
        else:
             logging.error("Cannot create train_dataset as train_df (needed for labels) is invalid or empty.")
             raise ValueError("Training data (train_df with labels) is required.")


        if val_df is not None and not val_df.empty:
            val_dataset = GraphRagNewsDataset(
                texts=val_df['tweet'].tolist(), labels=val_df['label'].tolist(),
                tokenizer=classifier_tokenizer, max_len=Config.max_seq_length,
                retriever_func=retrieve_documents_graph_like,
                num_retrieved_total=Config.max_total_retrieved_docs
            )

        if test_df is not None and not test_df.empty:
            test_dataset = GraphRagNewsDataset(
                texts=test_df['tweet'].tolist(), labels=test_df['label'].tolist(),
                tokenizer=classifier_tokenizer, max_len=Config.max_seq_length,
                retriever_func=retrieve_documents_graph_like,
                num_retrieved_total=Config.max_total_retrieved_docs
            )
        logging.info("Datasets created.")

        # --- Determine Number of Workers ---
        num_workers = 0
        if Config.DEBUG_DATALOADER:
            logging.warning("DEBUG MODE ACTIVE: Setting num_workers = 0 for DataLoader.")
        else:
            if Config.device == torch.device("cuda"):
                current_start_method = mp.get_start_method(allow_none=True)
                if current_start_method == 'spawn':
                    cpu_count = os.cpu_count()
                    num_workers = min(4, cpu_count // 2 if cpu_count else 1) if cpu_count else 0
                    if num_workers > 0:
                        logging.info(f"Using {num_workers} dataloader workers (CUDA + spawn method detected).")
                    else:
                        logging.info("Calculated num_workers is 0. Using main process for data loading.")
                else:
                    logging.warning(f"CUDA available but multiprocessing start method is '{current_start_method}' (expected 'spawn'). Using 0 workers for safety.")
            else:
                logging.info("Not using CUDA. Setting num_workers = 0.")

        # Page-locked (pinned) host buffers are only requested when worker processes
        # feed a CUDA device; in-process loading leaves them off.
        on_cuda = Config.device == torch.device("cuda")
        pin_memory = num_workers > 0 and on_cuda
        logging.info(f"DataLoader pin_memory set to: {pin_memory}")

        def _make_eval_dataloader(dataset):
            """Build an unshuffled validation/test DataLoader capped at 2 workers."""
            eval_workers = min(num_workers, 2)
            return DataLoader(
                dataset,
                batch_size=Config.batch_size,
                shuffle=False,
                num_workers=eval_workers,
                pin_memory=eval_workers > 0 and on_cuda,
            )

        # --- Create DataLoaders ---
        if train_dataset:
            # drop_last discards the final incomplete batch, so every training batch is full-sized.
            # (persistent_workers=(num_workers > 0) is worth trying once worker startup is stable.)
            train_dataloader = DataLoader(
                train_dataset,
                batch_size=Config.batch_size,
                shuffle=True,
                num_workers=num_workers,
                pin_memory=pin_memory,
                drop_last=True,
            )
        if val_dataset:
            val_dataloader = _make_eval_dataloader(val_dataset)
        if test_dataset:
            test_dataloader = _make_eval_dataloader(test_dataset)
        logging.info("DataLoaders created.")

        # A usable training loader is mandatory; evaluation loaders are optional.
        if not train_dataloader:
            logging.error("Train dataloader creation failed. Training cannot proceed.")
            exit(1)
        logging.info(f"  Train DataLoader: {len(train_dataloader)} batches.")
        for split_name, split_loader in (("Validation", val_dataloader), ("Test", test_dataloader)):
            if split_loader:
                logging.info(f"  {split_name} DataLoader: {len(split_loader)} batches.")
            else:
                logging.warning(f"{split_name} dataloader not created.")

    except Exception as e:
        logging.error(f"FATAL: Error during DataLoader creation phase: {e}", exc_info=True); exit(1)

    # --- Model Definition ---
    logging.info("--- Initializing Classifier Model ---")
    # Named distinctly so the trained instance is never confused with a module-level model object.
    classifier_model_instance = None
    try:
        logging.info(f"Loading classifier model: {Config.classifier_model_name}...")
        # num_labels=2 -> two-logit head for the binary real-vs-fake task.
        classifier_model_instance = AutoModelForSequenceClassification.from_pretrained(
            Config.classifier_model_name, num_labels=2
        )
        classifier_model_instance.to(Config.device)
        logging.info("Classifier model loaded and moved to device.")
    except Exception as e:
        logging.error(f"FATAL: Failed to initialize classifier model: {e}", exc_info=True)
        exit(1)

    # --- Training Setup ---
    # Builds the optimizer, LR scheduler, loss function (used by evaluate_model; training
    # itself uses the model's built-in `outputs.loss`) and the AMP gradient scaler.
    logging.info("--- Setting up Training Components ---")
    optimizer = None; scheduler = None; loss_fn = None; scaler = None
    if train_dataloader and len(train_dataloader) > 0:
        try:
            # Pass the specific model instance to the optimizer
            optimizer = AdamW(classifier_model_instance.parameters(), lr=Config.learning_rate, eps=1e-8)
            # The training loop steps the optimizer at the end of every accumulation group
            # AND on the last batch of each epoch, so a trailing partial group still counts
            # as an update: updates per epoch = ceil(batches / accumulation_steps).
            # Floor division under-counts total_steps, which makes the linear schedule reach
            # LR 0 before training finishes (and yields 0 when batches < accumulation steps).
            grad_accum_steps = Config.gradient_accumulation_steps
            num_update_steps_per_epoch = (len(train_dataloader) + grad_accum_steps - 1) // grad_accum_steps
            total_steps = num_update_steps_per_epoch * Config.epochs
            # Linear warmup to Config.learning_rate, then linear decay to 0 at total_steps.
            scheduler = get_linear_schedule_with_warmup(
                optimizer, num_warmup_steps=Config.warmup_steps, num_training_steps=total_steps
            )
            loss_fn = nn.CrossEntropyLoss()
            # With enabled=False the scaler's scale/unscale_/step/update calls are pass-throughs.
            scaler = GradScaler(enabled=Config.use_amp)

            logging.info(f"Optimizer, Scheduler, Loss Function, and GradScaler (AMP enabled: {Config.use_amp}) initialized.")
            logging.info(f"Total training steps (considering grad accum): {total_steps}")
            logging.info(f"Warmup steps: {Config.warmup_steps}")

        except Exception as e:
            logging.error(f"FATAL: Error during training setup: {e}", exc_info=True); exit(1)
    else:
        logging.error("FATAL: Training dataloader is invalid or empty. Cannot setup training components."); exit(1)

    # --- Training Loop ---
    # Fine-tunes the classifier for Config.epochs epochs with optional AMP and gradient
    # accumulation, validates after every epoch, and checkpoints the best-F1 model.
    logging.info("--- Starting Training ---")
    logging.info("Tip: For further speed optimization, consider profiling the code using tools like cProfile or torch.profiler.")

    best_val_f1 = -1.0  # below any real F1, so the first validation result always triggers a save
    global_step = 0  # optimizer updates performed so far, across all epochs
    training_start_time = time.time()

    for epoch in range(Config.epochs):
        logging.info(f"===== Epoch {epoch + 1}/{Config.epochs} =====")
        epoch_start_time = time.time()
        # Put the model (back) into training mode at the start of every epoch.
        classifier_model_instance.train()
        total_train_loss = 0
        processed_batches = 0
        optimizer.zero_grad()

        progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1} Training", leave=False)
        for step, batch in enumerate(progress_bar):
            # Defensive checks: skip malformed or empty batches instead of crashing.
            if not isinstance(batch, dict) or 'input_ids' not in batch or 'attention_mask' not in batch or 'labels' not in batch:
                logging.warning(f"Skipping invalid batch (not dict or missing keys) at step {step} in epoch {epoch + 1}.")
                continue
            if batch['input_ids'].numel() == 0 or batch['labels'].numel() == 0:
                logging.warning(f"Skipping empty batch (zero elements) at step {step} in epoch {epoch + 1}.")
                continue

            try:
                input_ids = batch['input_ids'].to(Config.device)
                attention_mask = batch['attention_mask'].to(Config.device)
                labels = batch['labels'].to(Config.device)

                # Forward pass under autocast (disabled when Config.use_amp is False); passing
                # `labels` makes the model return its own loss in `outputs.loss`.
                with autocast(enabled=Config.use_amp):
                    outputs = classifier_model_instance(input_ids=input_ids,
                                                        attention_mask=attention_mask,
                                                        labels=labels)
                    loss = outputs.loss

                # Scale down so gradients summed over an accumulation group average out.
                if Config.gradient_accumulation_steps > 1:
                    loss = loss / Config.gradient_accumulation_steps

                scaler.scale(loss).backward()

                # Undo the accumulation scaling so the tracked loss is per batch.
                total_train_loss += loss.item() * Config.gradient_accumulation_steps
                processed_batches += 1

                # Update at the end of each accumulation group and on the epoch's last batch,
                # so a trailing partial group is not dropped.
                if (step + 1) % Config.gradient_accumulation_steps == 0 or (step + 1) == len(train_dataloader):
                    # Unscale first so gradient clipping sees true gradient magnitudes.
                    scaler.unscale_(optimizer)
                    # Clip gradients for the specific model instance
                    torch.nn.utils.clip_grad_norm_(classifier_model_instance.parameters(), 1.0)
                    scaler.step(optimizer)  # skips the update if inf/NaN gradients were found
                    scaler.update()
                    optimizer.zero_grad()
                    if scheduler: scheduler.step()
                    global_step += 1

                    current_lr = scheduler.get_last_lr()[0] if scheduler else Config.learning_rate
                    progress_bar.set_postfix({
                        'loss': f"{loss.item() * Config.gradient_accumulation_steps:.4f}",
                        'lr': f"{current_lr:.2e}",
                        'scale': f"{scaler.get_scale():.1f}",
                        'step': global_step
                    })

            except Exception as e:
                logging.error(f"Error during training step {step} in epoch {epoch + 1}: {e}", exc_info=True)
                logging.warning("Attempting to skip problematic batch and continue training...")
                optimizer.zero_grad()
                # Release this batch's tensors so their memory can be reclaimed. Rebinding to
                # None (instead of `del`) cannot raise NameError when the failure happened
                # before some of these names were assigned -- e.g. inside the first
                # `.to(Config.device)` call, or right after a previous failed batch -- which
                # previously crashed training from inside this recovery path.
                batch = input_ids = attention_mask = labels = outputs = loss = None
                gc.collect()
                if Config.device == torch.device("cuda"): torch.cuda.empty_cache()
                continue

        progress_bar.close()
        epoch_duration = time.time() - epoch_start_time
        avg_train_loss = total_train_loss / processed_batches if processed_batches > 0 else 0
        current_lr = scheduler.get_last_lr()[0] if scheduler else Config.learning_rate
        logging.info(f"Epoch {epoch + 1} completed in {epoch_duration:.2f}s.")
        logging.info(f"  Average Training Loss: {avg_train_loss:.4f}")
        logging.info(f"  Current Learning Rate: {current_lr:.2e}")
        logging.info(f"  AMP Scale Factor: {scaler.get_scale():.1f}")

        # --- Validation Step ---
        if val_dataloader and len(val_dataloader) > 0:
            logging.info("--- Evaluating on Validation Set ---")
            # Pass the specific model instance for evaluation
            val_loss, val_acc, val_prec, val_rec, val_f1 = evaluate_model(
                classifier_model_instance, val_dataloader, Config.device, loss_fn
            )

            # Checkpoint only on strict F1 improvement.
            if val_f1 > best_val_f1:
                logging.info(f"Validation F1 improved ({best_val_f1:.4f} --> {val_f1:.4f}). Saving model...")
                best_val_f1 = val_f1
                os.makedirs(Config.output_dir, exist_ok=True)
                # Save the specific model instance
                classifier_model_instance.save_pretrained(Config.output_dir)
                if classifier_tokenizer: classifier_tokenizer.save_pretrained(Config.output_dir)
                logging.info(f"Model saved to {Config.output_dir}")
            else:
                logging.info(f"Validation F1 did not improve ({val_f1:.4f}). Best F1: {best_val_f1:.4f}")
        else:
            logging.warning("Skipping validation: Validation dataloader not available or empty.")

        gc.collect()
        if Config.device == torch.device("cuda"): torch.cuda.empty_cache()

    # --- End of Training Loop ---
    training_duration = time.time() - training_start_time
    logging.info(f"--- Training Finished --- (Total Duration: {training_duration:.2f} seconds)")
    logging.info(f"Best Validation F1 achieved: {best_val_f1:.4f}")

    # --- Final Evaluation on Test Set ---
    # best_val_f1 only leaves its -1.0 starting value when the training loop writes a
    # checkpoint, so it tells us whether *this run* saved a best model. Without this
    # check, a checkpoint left in Config.output_dir by an earlier run would be loaded
    # and reported as this run's result.
    best_model_saved_this_run = best_val_f1 > -1.0
    if test_dataloader and len(test_dataloader) > 0:
        logging.info("--- Evaluating on Test Set using the Best Model ---")
        try:
            best_model_path = Config.output_dir
            model_file_path_bin = os.path.join(best_model_path, "pytorch_model.bin")
            model_file_path_safe = os.path.join(best_model_path, "model.safetensors")
            config_file_path = os.path.join(best_model_path, "config.json")
            tokenizer_config_path = os.path.join(best_model_path, "tokenizer_config.json")

            # Weights may be saved in either the legacy .bin or the safetensors format.
            essential_files_exist = (
                os.path.exists(model_file_path_bin) or os.path.exists(model_file_path_safe)
            ) and os.path.exists(config_file_path) and os.path.exists(tokenizer_config_path)

            if best_model_saved_this_run and essential_files_exist:
                logging.info(f"Loading best model from {best_model_path} for final test evaluation...")
                best_model_instance = AutoModelForSequenceClassification.from_pretrained(best_model_path)
                best_tokenizer = AutoTokenizer.from_pretrained(best_model_path)
                best_model_instance.to(Config.device)
                logging.info("Best model and tokenizer loaded successfully.")

                logging.info("Re-creating test dataset and dataloader with loaded best tokenizer...")
                if test_df is not None and not test_df.empty:
                    test_dataset_final = GraphRagNewsDataset(
                        texts=test_df['tweet'].tolist(), labels=test_df['label'].tolist(),
                        tokenizer=best_tokenizer,
                        max_len=Config.max_seq_length,
                        retriever_func=retrieve_documents_graph_like,
                        num_retrieved_total=Config.max_total_retrieved_docs
                    )
                    # Load in the main process for this one-off evaluation pass.
                    test_dataloader_final = DataLoader(
                        test_dataset_final, batch_size=Config.batch_size, shuffle=False,
                        num_workers=0, pin_memory=False
                    )
                    logging.info(f"Final test dataloader created with {len(test_dataloader_final)} batches.")

                    logging.info("Evaluating best model on the final test set...")
                    final_loss_fn = loss_fn
                    if final_loss_fn is None: logging.warning("Loss function not available for final test evaluation.")
                    # Evaluate the loaded best model instance
                    evaluate_model(best_model_instance, test_dataloader_final, Config.device, final_loss_fn)
                else:
                    logging.error("Cannot perform final evaluation: Test data (test_df) is unavailable or empty.")

            # Fallback: Use the model instance from the end of training if no best model was saved
            elif classifier_model_instance is not None:
                logging.warning("No saved best model found or essential files missing.")
                logging.warning("Evaluating using the model's final state after training...")
                final_loss_fn = loss_fn
                if final_loss_fn is None: logging.warning("Loss function not available for final test evaluation.")
                # Evaluate the final state model instance
                evaluate_model(classifier_model_instance, test_dataloader, Config.device, final_loss_fn)
            else:
                logging.error("Skipping final evaluation: No model available.")

        except Exception as e:
            logging.error(f"Error during final test evaluation phase: {e}", exc_info=True)
    else:
        logging.warning("Skipping final evaluation: Test dataloader not available or empty.")

    logging.info("--- Script Finished ---")


2025-04-11 00:45:44,418 - INFO - [MainProcess] Multiprocessing start method set to 'spawn' (was None).
2025-04-11 00:45:44,421 - INFO - [MainProcess] CUDA available. Using device: cuda
2025-04-11 00:45:44,421 - INFO - [MainProcess] Mixed Precision Training (AMP) enabled: True
2025-04-11 00:45:44,422 - INFO - [MainProcess] --- Setting up Retriever ---
2025-04-11 00:45:44,883 - INFO - [MainProcess] Loaded 6420 samples from ../Constraint_English_Train.xlsx (dropped 0 for NaN tweet, 0 for invalid label).
2025-04-11 00:45:44,884 - INFO - [MainProcess] Loading retriever model: all-MiniLM-L6-v2...
2025-04-11 00:45:44,885 - INFO - [MainProcess] Use pytorch device_name: cuda:0
2025-04-11 00:45:45,067 - INFO - [MainProcess] Load pretrained SentenceTransformer: all-MiniLM-L6-v2
2025-04-11 00:45:49,542 - INFO - [MainProcess] Retriever model 'all-MiniLM-L6-v2' loaded and moved to cuda.
2025-04-11 00:45:49,544 - INFO - [MainProcess] Embedding 6420 training documents for FAISS index...


Batches:   0%|          | 0/51 [00:00<?, ?it/s]

2025-04-11 00:45:51,717 - INFO - [MainProcess] Embeddings generated with shape: torch.Size([6420, 384])
2025-04-11 00:45:51,924 - INFO - [MainProcess] Building FAISS index (IVFFlat) with dimension=384...
2025-04-11 00:45:51,925 - INFO - [MainProcess] Training FAISS IVFFlat index with nlist=100...
2025-04-11 00:45:52,200 - INFO - [MainProcess] FAISS index training finished in 0.27s.
2025-04-11 00:45:52,258 - INFO - [MainProcess] Set FAISS nprobe to 10
2025-04-11 00:45:52,268 - INFO - [MainProcess] FAISS index built successfully. Type: IVFFlat, Total vectors: 6420
2025-04-11 00:45:52,269 - INFO - [MainProcess] --- Loading Data and Creating DataLoaders ---
2025-04-11 00:45:52,406 - INFO - [MainProcess] Loaded 2140 samples from ../Constraint_English_Val.xlsx (dropped 0 for NaN tweet, 0 for invalid label).
2025-04-11 00:45:52,542 - INFO - [MainProcess] Loaded 2140 samples from ../english_test_with_labels.xlsx (dropped 0 for NaN tweet, 0 for invalid label).
2025-04-11 00:45:52,542 - INFO - [

KeyboardInterrupt: 

In [1]:
import warnings
warnings.filterwarnings("ignore")