### Evaluate the model performance

In [None]:
import os
import pandas as pd
import torch
from tqdm.auto import tqdm # For progress bars
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
from torch.utils.data import Dataset, DataLoader
import re # For improved label extraction
import numpy as np # For AE calculation

# === CONFIG (User should verify these paths and settings) ===
# --- Paths ---
# Directory holding the fine-tuned checkpoints (reliability scoring model).
CHECKPOINTS_DIR = "/root/Fine-Tuning_Truth/granite-finetuned-articles"
# Original base model, used to write the justifications.
BASE_MODEL_PATH = "/root/Fine-Tuning_Truth/granite-3.1-1b-a400m-base"
# Evaluation data; point this at your own CSV.
CSV_PATH = "val_articles_fine_tuning.csv"

# --- Reliability prediction ---
BATCH_SIZE_PREDICTION = 20
MAX_LENGTH_PREDICTION = 1660       # token budget for the tokenized scoring prompt
MAX_NEW_TOKENS_PREDICTION = 10     # generation budget for the score

# --- Justification generation ---
# The batch size may need lowering for a larger base model or longer prompts.
BATCH_SIZE_JUSTIFICATION = 20
# Token budget for the justification prompt, which carries the statement plus
# the few-shot examples; raise it if prompts end up truncated.
MAX_LENGTH_JUSTIFICATION = 768
MAX_NEW_TOKENS_JUSTIFICATION = 300  # desired length of the justification text

USE_GPUS = [*range(8)]

# === Set visible GPUs ===
# Start from the CPU default and only switch to CUDA once GPUs were requested
# and torch reports that CUDA is usable.
DEVICE = "cpu"
if not USE_GPUS:
    print("No GPUs specified, using CPU.")
else:
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(gpu_index) for gpu_index in USE_GPUS)
    if not torch.cuda.is_available():
        print("CUDA not available, using CPU.")
    else:
        print(f"Using GPUs: {os.environ['CUDA_VISIBLE_DEVICES']}")
        DEVICE = "cuda"

# === Helper to find latest checkpoint ===
def get_latest_checkpoint(checkpoints_dir):
    """Return the path of the highest-numbered ``checkpoint-<step>`` directory.

    Only sub-directories whose name is ``checkpoint-`` followed by an integer
    step (optionally followed by ``-suffix``) are considered. Entries such as
    ``checkpoint-best`` or stray files are ignored instead of aborting the
    lookup with a ``ValueError`` from ``int()``.

    Args:
        checkpoints_dir: Directory expected to contain the checkpoint folders.

    Returns:
        The joined path of the latest checkpoint, or ``None`` (after printing a
        message) when the directory is missing or holds no numbered checkpoint.
    """
    # isdir (rather than exists) also rejects a plain file, which os.listdir
    # below could not handle.
    if not os.path.isdir(checkpoints_dir):
        print(f"Error: Checkpoints directory not found at {checkpoints_dir}")
        return None

    step_pattern = re.compile(r"checkpoint-(\d+)(?:-|$)")
    numbered_checkpoints = []
    for entry in os.listdir(checkpoints_dir):
        step_match = step_pattern.match(entry)
        if step_match and os.path.isdir(os.path.join(checkpoints_dir, entry)):
            numbered_checkpoints.append((int(step_match.group(1)), entry))

    if not numbered_checkpoints:
        print(f"No checkpoints found in {checkpoints_dir}")
        return None

    # Compare by integer step so that checkpoint-10 ranks above checkpoint-2.
    _, latest_checkpoint_name = max(numbered_checkpoints)
    print(f"Found latest checkpoint for scoring: {latest_checkpoint_name}")
    return os.path.join(checkpoints_dir, latest_checkpoint_name)

# === Custom Dataset for Reliability Prediction ===
class PredictionDataset(Dataset):
    """Serve one tokenized reliability-scoring prompt per statement.

    Every item wraps a statement in the fixed scoring instruction, ending with
    the ``Label:`` cue, and tokenizes it with truncation and padding to
    ``max_length``.

    NOTE(review): truncation is applied to the assembled prompt, so a statement
    long enough to exceed ``max_length`` may lose the trailing ``Label:`` cue,
    depending on the tokenizer's truncation side - verify against the tokenizer.
    """

    def __init__(self, statements, tokenizer, max_length):
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.statements = statements

    def __len__(self):
        return len(self.statements)

    def __getitem__(self, idx):
        """Return ``input_ids`` and ``attention_mask`` for the prompt at ``idx``."""
        instruction = (
            "Assess the reliability of this statement (article from the Internet) "
            "on the scale from 0 to 1, where 0 is completely unreliable and 1 is "
            "completely reliable. Do not provide any explanation. Just the number:"
        )
        article_text = str(self.statements[idx])
        prompt = "".join((instruction, "\nStatement: ", article_text, "\nLabel:"))

        encoded = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        # The tokenizer output carries a leading batch axis of size 1; drop it
        # so that the DataLoader can stack the items itself.
        return {
            field: encoded[field].squeeze(0)
            for field in ("input_ids", "attention_mask")
        }

# === Custom Dataset for Justification Generation (Using Few-Shot) ===
class JustificationDataset(Dataset):
    """Serve one tokenized few-shot justification prompt per statement/score pair.

    ``statements`` and ``scores`` are read in parallel by index. The tokenizer
    passed in is the one belonging to the BASE model.
    """

    def __init__(self, statements, scores, tokenizer, max_length):
        self.statements, self.scores = statements, scores
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.statements)

    def get_justification_prompt(self, statement, score):
        """Build the few-shot prompt asking for a bullet-point justification.

        The statement is cut to 700 characters (plus ``...``) before being
        embedded. A score of ``-1.0`` or a missing value selects the variant
        that asks for an overall impression instead of justifying a score.
        """
        char_limit = 700
        shown_statement = str(statement)
        if len(shown_statement) > char_limit:
            shown_statement = shown_statement[:char_limit] + "..."

        # (statement, score, justification) triples rendered as "Example N".
        few_shot_examples = (
            (
                "The Eiffel Tower is located in Paris, France. It is a famous landmark.",
                0.95,
                "- The statement is factually accurate (Eiffel Tower is in Paris).\n"
                "- It describes a well-known fact, easily verifiable.\n"
                "- Writing quality is good and consistent.",
            ),
            (
                "The moon is made of green cheese and visited by cows weekly.",
                0.05,
                "- The statement contains obvious factual inaccuracies (moon not cheese, cows don't visit).\n"
                "- It presents scientifically implausible claims.\n"
                "- Lacks any supporting evidence or credibility.",
            ),
        )

        pieces = [
            "You are an expert analyst. Your task is to provide a concise, bullet-point justification for a given reliability score of a statement.\n"
            "The reliability score is on a scale from 0 (completely unreliable) to 1 (completely reliable).\n"
            "Your justification should ONLY be the bullet points explaining the score. Do NOT repeat the statement or the score in your response.\n"
            "\n"
            "Here are some examples of how to format your justification:\n"
            "\n"
        ]
        for number, (ex_statement, ex_score, ex_justification) in enumerate(few_shot_examples, start=1):
            pieces.append(
                f"Example {number}:\n"
                f'Statement context: "{ex_statement}"\n'
                f"Assigned reliability score: {ex_score:.2f}\n"
                "Correct Justification:\n"
                f"{ex_justification}\n"
                "\n"
            )
        pieces.append("---\nNow, provide the justification for the following:\n")
        # Blank line separating the instructions from the actual task.
        pieces.append("\n")

        score_missing = score == -1.0 or pd.isna(score)
        if score_missing:
            pieces.append(
                f'Statement context: "{shown_statement}"\n'
                "The reliability score for this statement could not be determined.\n"
                "Provide your overall impression of this statement's potential reliability using concise bullet points.\n"
                "Consider factors like text consistency, apparent factual accuracy, and potential AI generation.\n"
                "Justification:"
            )
        else:
            pieces.append(
                f'Statement context: "{shown_statement}"\n'
                f"Assigned reliability score: {score:.2f}\n"
                "Provide a concise bullet-point justification for THIS SCORE.\n"
                "Justification:"
            )
        return "".join(pieces)

    def __getitem__(self, idx):
        """Return ``input_ids`` and ``attention_mask`` for the prompt at ``idx``."""
        prompt = self.get_justification_prompt(self.statements[idx], self.scores[idx])
        encoded = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt",
        )
        # Drop the leading size-1 batch axis; batching is left to the DataLoader.
        return {
            field: encoded[field].squeeze(0)
            for field in ("input_ids", "attention_mask")
        }

# === Improved Label Extraction Function ===
# Case-insensitive cue that precedes the score in the scoring prompt.
_LABEL_MARKER_RE = re.compile(r"Label:", re.IGNORECASE)
# A well-formed unsigned decimal: "1", "0.85" or ".5". Bare dots or an ellipsis
# never match, so they cannot shadow a real number that follows them.
_LABEL_NUMBER_RE = re.compile(r"\d+(?:\.\d+)?|\.\d+")


def extract_label(text):
    """Extract the numeric label from decoded model output.

    If ``text`` contains a ``Label:`` marker (case-insensitive), the first
    well-formed number after the first marker is returned. Without a marker,
    the first number anywhere in ``text`` is used, so the function also works
    on a bare continuation such as ``" 0.85"``.

    Args:
        text: Decoded model output.

    Returns:
        The label as ``float``, or ``None`` when ``text`` is not a string or no
        number can be found (including a marker with no number after it).
    """
    if not isinstance(text, str):
        return None

    marker = _LABEL_MARKER_RE.search(text)
    search_region = text[marker.end():] if marker else text

    number = _LABEL_NUMBER_RE.search(search_region)
    if number is None:
        return None
    return float(number.group())

In [None]:
print("--- Starting Evaluation Script ---")

# Half precision is only selected on GPU (bfloat16 where supported, float16
# otherwise). On CPU fall back to float32: float16 there is at best slow, and
# some operators may be unavailable depending on the torch build.
if DEVICE == "cuda":
    model_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    model_dtype = torch.float32

# --- Load Fine-tuned Model for Reliability Scoring ---
latest_checkpoint_scoring = get_latest_checkpoint(CHECKPOINTS_DIR)
if not latest_checkpoint_scoring:
    raise ValueError("Exiting due to missing fine-tuned model checkpoint for scoring.")

print(f"Loading fine-tuned scoring model and its tokenizer from: {latest_checkpoint_scoring}")
try:
    tokenizer_scoring = AutoTokenizer.from_pretrained(latest_checkpoint_scoring)
    if tokenizer_scoring.pad_token is None:
        tokenizer_scoring.pad_token = tokenizer_scoring.eos_token
        print("Scoring tokenizer pad_token set to eos_token.")
    # Decoder-only models continue from the last position of each row, so the
    # padding has to sit on the left for batched generation; right padding
    # would make the model continue after pad tokens.
    tokenizer_scoring.padding_side = "left"

    model_scoring = AutoModelForCausalLM.from_pretrained(
        latest_checkpoint_scoring,
        torch_dtype=model_dtype,
        device_map="auto"  # Let Hugging Face handle device mapping
    )
    print("Fine-tuned scoring model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading fine-tuned scoring model or tokenizer: {e}")
    raise

# --- Load Base Model for Justification ---
if not os.path.exists(BASE_MODEL_PATH):
    raise ValueError(f"Base model for justification not found at {BASE_MODEL_PATH}")

print(f"Loading base model for justification and its tokenizer from: {BASE_MODEL_PATH}")
try:
    # Load the base model's own tokenizer, even if it is expected to be the
    # same as the fine-tuned one.
    tokenizer_justification = AutoTokenizer.from_pretrained(BASE_MODEL_PATH)
    if tokenizer_justification.pad_token is None:
        tokenizer_justification.pad_token = tokenizer_justification.eos_token
        print("Justification tokenizer pad_token set to eos_token.")
    # Same reason as for the scoring tokenizer: left padding for generation.
    tokenizer_justification.padding_side = "left"

    model_justification = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL_PATH,
        torch_dtype=model_dtype,
        device_map="auto"  # Let Hugging Face handle device mapping
    )
    print("Base justification model and tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading base justification model or tokenizer: {e}")
    raise

# === Load data ===
try:
    df_full = pd.read_csv(CSV_PATH)
except FileNotFoundError:
    print(f"Error: Data CSV not found at {CSV_PATH}")
    raise
# Raise a real exception here: a bare `raise` outside an `except` block only
# produces "RuntimeError: No active exception to reraise".
if 'statement' not in df_full.columns or 'labels' not in df_full.columns:
    raise ValueError(
        f"Error: CSV must contain 'statement' and 'labels' columns. Found: {list(df_full.columns)}"
    )

# Evaluate on a reproducible sample of at most 100 rows; capping at the frame
# size keeps sample() from failing on CSVs with fewer rows.
df = df_full.sample(n=min(100, len(df_full)), random_state=42).reset_index(drop=True)
# df = df_full.copy() # For full evaluation
print(f"Loaded data with {len(df)} rows.")
df

In [None]:
# --- 1. Reliability Prediction (using Fine-tuned Model) ---
# Generates a score continuation for every statement, parses it into a float
# and reports the absolute error against the 'labels' column.
print("\n--- Generating Reliability Predictions ---")
statements_for_prediction = df['statement'].tolist()
# Use tokenizer_scoring for the prediction dataset
prediction_dataset = PredictionDataset(statements_for_prediction, tokenizer_scoring, MAX_LENGTH_PREDICTION)
prediction_dataloader = DataLoader(prediction_dataset, batch_size=BATCH_SIZE_PREDICTION, shuffle=False)

predicted_scores_raw = []
model_scoring.eval()
with torch.no_grad():
    for batch in tqdm(prediction_dataloader, desc="Predicting Reliability", unit="batch"):
        inputs = {
            "input_ids": batch["input_ids"].to(model_scoring.device),  # Ensure tensors go to the correct model's device
            "attention_mask": batch["attention_mask"].to(model_scoring.device)
        }
        outputs = model_scoring.generate(
            **inputs,
            max_new_tokens=MAX_NEW_TOKENS_PREDICTION,
            do_sample=False,
            pad_token_id=tokenizer_scoring.pad_token_id
        )
        # generate() returns the prompt followed by the new tokens. Decode only
        # the new tokens: the prompt itself contains numbers ("from 0 to 1")
        # and the statement text, which must never be parsed as the prediction,
        # e.g. when truncation has removed the trailing "Label:" cue.
        prompt_length = inputs["input_ids"].shape[1]
        generated_tokens = outputs[:, prompt_length:]
        predicted_scores_raw.extend(
            tokenizer_scoring.batch_decode(generated_tokens, skip_special_tokens=True)
        )

extracted_labels = []
problematic_extractions_indices = []
for i, text in enumerate(predicted_scores_raw):
    label = extract_label(text)
    if label is None:
        print(f"Warning: Could not extract a valid numeric label from raw output for statement index {i}: '{text}'")
        problematic_extractions_indices.append(i)
    extracted_labels.append(label)

#df['predicted_label_raw'] = predicted_scores_raw
df['predicted_label'] = extracted_labels
# Failed extractions (None) become the sentinel -1.0; every other value is
# clamped into the valid [0, 1] range.
df['predicted_label'] = pd.to_numeric(df['predicted_label'], errors='coerce').fillna(-1.0).astype(float)
df['predicted_label'] = df['predicted_label'].apply(lambda x: min(max(x, 0.0), 1.0) if x != -1.0 else -1.0)

num_failed_predictions = (df['predicted_label'] == -1.0).sum()
print(f"Number of statements where label extraction failed (marked as -1.0): {num_failed_predictions} out of {len(df)}")
if problematic_extractions_indices:
    print(f"Indices of statements with problematic extractions: {problematic_extractions_indices[:20]} (showing first 20 if many)")

# Error metrics are computed on the rows with a usable prediction only.
df1 = df[df['predicted_label'] != -1.0].copy()
if not df1.empty:
    df1['AE'] = np.abs(df1['predicted_label'] - df1['labels'])
    df1['AE'] = df1['AE'].round(3)
    print("\nAbsolute Error (AE) calculated for valid predictions:")
    print(df1[['statement', 'labels', 'predicted_label', 'AE']].head())
    print("\nValue counts for AE (on valid predictions):")
    print(df1['AE'].value_counts().sort_index())
    mae = df1['AE'].mean()
    print(f"\nMean Absolute Error (MAE) on valid predictions: {mae:.4f}")
else:
    print("No valid predictions to calculate AE.")
df1

In [None]:
# --- 2. Justification Generation (using Base Model) ---
print("\n--- Generating Justifications ---")

# Fragments of the few-shot prompt. When one shows up in the generated text,
# the base model has moved on from the answer to inventing further prompt
# material. "Statement context:" is matched anywhere, the others only at the
# start of a line.
_PROMPT_ECHO_RE = re.compile(
    r"Statement context:|^\s*(?:-{3,}\s*$|Example\s+\d+:|Now, provide the justification)",
    re.IGNORECASE | re.MULTILINE,
)
_JUSTIFICATION_HEADER_RE = re.compile(r"Correct Justification:|Justification:", re.IGNORECASE)


def _clean_justification(text):
    """Reduce a raw continuation to the justification for the requested statement.

    A leading "Justification:" is removed. If the model then continues the
    few-shot pattern (new "Example N:", "Statement context:", a "---" line...),
    only the text before that continuation is kept: it is the answer to the
    actual prompt, whereas anything after belongs to an invented example. If
    the output starts with such a fragment, so that nothing precedes it, the
    text after the last justification header is used instead.
    """
    cleaned = text.strip()
    if cleaned.lower().startswith("justification:"):
        cleaned = cleaned[len("justification:"):].strip()

    echo = _PROMPT_ECHO_RE.search(cleaned)
    if echo is None:
        return cleaned

    answer = cleaned[:echo.start()].strip()
    if answer:
        return answer

    # Nothing precedes the echoed fragment: keep what follows the last header.
    parts = _JUSTIFICATION_HEADER_RE.split(cleaned)
    return parts[-1].strip() if len(parts) > 1 else cleaned


statements_for_justification = df['statement'].tolist()
scores_for_justification = df['predicted_label'].tolist()

# Use tokenizer_justification for the justification dataset
justification_dataset = JustificationDataset(statements_for_justification, scores_for_justification, tokenizer_justification, MAX_LENGTH_JUSTIFICATION)
justification_dataloader = DataLoader(justification_dataset, batch_size=BATCH_SIZE_JUSTIFICATION, shuffle=False)

generated_justifications_clean = []
model_justification.eval()
with torch.no_grad():
    for batch in tqdm(justification_dataloader, desc="Generating Justifications", unit="batch"):
        input_ids_batch = batch["input_ids"].to(model_justification.device)  # Ensure tensors go to correct model's device
        attention_mask_batch = batch["attention_mask"].to(model_justification.device)

        outputs = model_justification.generate(
            input_ids=input_ids_batch,
            attention_mask=attention_mask_batch,
            max_new_tokens=MAX_NEW_TOKENS_JUSTIFICATION,
            do_sample=True,
            temperature=0.6,  # Raise if the output is too bland, lower if it is too random
            top_p=0.9,
            repetition_penalty=1.2,
            pad_token_id=tokenizer_justification.pad_token_id
        )

        # Every prompt in the batch is padded to the same length, so one slice
        # removes the prompt from all rows and leaves the generated tokens.
        prompt_tokens_count = input_ids_batch.shape[1]
        decoded_justifications = tokenizer_justification.batch_decode(
            outputs[:, prompt_tokens_count:], skip_special_tokens=True
        )
        generated_justifications_clean.extend(
            _clean_justification(raw_text) for raw_text in decoded_justifications
        )

df['predicted_justification'] = generated_justifications_clean
print("\n--- Sample of Statements with Predictions and Cleaned Justifications ---")
pd.set_option('display.max_colwidth', 200)
print(df[['statement', 'labels', 'predicted_label', 'predicted_justification']].head())

first_justification = df['predicted_justification'].iloc[0] if len(df) > 0 else ""
if first_justification:
    print("\nSample justification (first item):")
    print(first_justification)
    print(f"Length of first justification: {len(tokenizer_justification.encode(first_justification))} tokens (approx)")
else:
    print("\nFirst justification is empty or not available.")