In [None]:
import os
import pandas as pd
import torch
from tqdm.auto import tqdm # For progress bars
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
)
from torch.utils.data import Dataset, DataLoader
import re # For improved label extraction
import numpy as np # For AE calculation
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
example1_score = 0.1
example2_score = 0.35
example3_score = 1.0

example1_statement = 'Dogs can photosynthesize sunlight to produce their own food.'

example2_statement = 'It’s often said that sleeping less than four hours per night improves creativity.'

example3_statement = 'Mount Everest is the highest mountain above sea level on Earth.'

In [3]:
def get_prediction_prompt(statement: str) -> str:
    """
    Creates a concise prompt for a raw LLM to predict a reliability score
    on a scale of 0.1 to 1.0 based on the statement's content, using few-shot examples.
    """
    
    prompt = f"""Assess the reliability of the following statement and provide a numerical score between 0.1 (completely unreliable) and 1.00 (perfectly reliable). Aim for a precise score.
Here are a few examples of statements and their corresponding reliability scores:
Example 1:
Statement: {example1_statement}
Score: {example1_score:.2f}
Example 2:
Statement: {example2_statement}
Score: {example2_score:.2f}
Example 3:
Statement: {example3_statement}
Score: {example3_score:.2f}
Now, assess the reliability of the following statement according to the instructions and examples above:
Statement: {statement}"""
    return prompt

In [4]:
CSV_PATH = "val_short_granite.csv"  # Update with your new data path
BATCH_SIZE_PREDICTION = 15
MAX_LENGTH_PREDICTION = 2240
# Increase MAX_LENGTH_JUSTIFICATION for the few-shot prompt.
# Base models might handle longer contexts better if this is increased.
MAX_NEW_TOKENS_PREDICTION = 20 # 10

USE_GPUS = list(range(10))

# === Set visible GPUs ===
if USE_GPUS:
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(map(str, USE_GPUS))
    if torch.cuda.is_available():
        print(f"Using GPUs: {os.environ['CUDA_VISIBLE_DEVICES']}")
        DEVICE = "cuda"
    else:
        print("CUDA not available, using CPU.")
        DEVICE = "cpu"
else:
    print("No GPUs specified, using CPU.")
    DEVICE = "cpu"

class PredictionDataset(Dataset):
    def __init__(self, statements, tokenizer, max_length):
        self.statements = statements
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.statements)

    def __getitem__(self, idx):
        statement = str(self.statements[idx])
        # Compose the initial prompt up to the "Reliability Score:" part
        # prompt_prefix = (
        #     "Assess the reliability score of the following statement on a continuous scale from 0.1 (very unreliable) to 1 (perfectly reliable). "
        #     "You should consider factors such as factual accuracy, verifiability, logical coherence, and evidence transparency. Provide only the numerical score:\n"
        #     f"Statement: {statement}\n"
        # )

        prompt_prefix = get_prediction_prompt(statement)

        # Reserve 10 tokens for "Reliability Score:"
        max_prompt_tokens = self.max_length - 10
        # Tokenize the prefix and truncate if necessary
        prefix_ids = self.tokenizer.encode(prompt_prefix, add_special_tokens=False)
        if len(prefix_ids) > max_prompt_tokens:
            # Truncate the prefix to fit
            prefix_ids = prefix_ids[:max_prompt_tokens]
            prompt_prefix = self.tokenizer.decode(prefix_ids, skip_special_tokens=True)
        # Add the "Reliability Score:" at the end
        prompt = prompt_prefix + "Reliability Score:"
        encoding = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt"
        )
        return {
            "input_ids": encoding["input_ids"].squeeze(0),
            "attention_mask": encoding["attention_mask"].squeeze(0)
        }
    
def extract_label(text):
    try:
        match = re.search(r"Reliability Score:(.*?)([\d\.]+)", text, re.IGNORECASE | re.DOTALL)
        if match:
            label_str = match.group(2).strip()
            if label_str.endswith('.'): label_str = label_str[:-1]
            return float(label_str)
        else:
            parts = text.split("Reliability Score:")
            target_part = parts[-1] if len(parts) > 1 else text
            numeric_match = re.search(r"([\d\.]+)", target_part)
            if numeric_match:
                label_str = numeric_match.group(1).strip()
                if label_str.endswith('.'): label_str = label_str[:-1]
                return float(label_str)
            return None
    except ValueError:
        return None

Using GPUs: 0,1,2,3,4,5,6,7,8,9


In [5]:
df_full = pd.read_csv(CSV_PATH)

In [6]:
def evaluate_model(CHECKPOINTS_DIR, padding_side='left'):
    df = df_full.copy()
    print('Use model: ', CHECKPOINTS_DIR)
    tokenizer_scoring = AutoTokenizer.from_pretrained(CHECKPOINTS_DIR, padding_side=padding_side)
    if tokenizer_scoring.pad_token is None:
        tokenizer_scoring.pad_token = tokenizer_scoring.eos_token
    model_scoring = AutoModelForCausalLM.from_pretrained(
        CHECKPOINTS_DIR,
        torch_dtype=torch.bfloat16 if DEVICE == "cuda" and torch.cuda.is_bf16_supported() else torch.float16,
        device_map="auto" # Let Hugging Face handle device mapping
    )
    # --- 1. Reliability Prediction (using Fine-tuned Model) ---
    statements_for_prediction = df['statement'].tolist()
    # Use tokenizer_scoring for the prediction dataset
    prediction_dataset = PredictionDataset(statements_for_prediction, tokenizer_scoring, MAX_LENGTH_PREDICTION)
    prediction_dataloader = DataLoader(prediction_dataset, batch_size=BATCH_SIZE_PREDICTION, shuffle=False)
    predicted_scores_raw = []
    model_scoring.eval()
    with torch.no_grad():
        for batch in tqdm(prediction_dataloader, desc="Predicting Reliability", unit="batch"):
            inputs = {
                "input_ids": batch["input_ids"].to(model_scoring.device), # Ensure tensors go to the correct model's device
                "attention_mask": batch["attention_mask"].to(model_scoring.device)
            }
            outputs = model_scoring.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS_PREDICTION,
                do_sample=False,
                pad_token_id=tokenizer_scoring.pad_token_id
            )
            # Use tokenizer_scoring to decode
            decoded_outputs = tokenizer_scoring.batch_decode(outputs, skip_special_tokens=True)
            for text in decoded_outputs:
                predicted_scores_raw.append(text)

    # Extract the part after 'Reliability Score:' for each item in predicted_scores_raw
    # predicted_scores_raw = [
    #     text.split('Reliability Score:', 1)[1].strip() if 'Reliability Score:' in text else text
    #     for text in predicted_scores_raw
    # ]
    extracted_labels = []
    problematic_extractions_indices = []
    for i, text in enumerate(predicted_scores_raw):
        label = extract_label(text)
        if label is None:
            print(f"Warning: Could not extract a valid numeric label from raw output for statement index {i}: '{text}'")
            problematic_extractions_indices.append(i)
        extracted_labels.append(label)

    df['predicted_label_raw'] = predicted_scores_raw
    df['predicted_label'] = extracted_labels
    df['predicted_label'] = pd.to_numeric(df['predicted_label'], errors='coerce').fillna(-1.0).astype(float)
    df['predicted_label'] = df['predicted_label'].apply(lambda x: min(max(x, 0.0), 1.0) if x != -1.0 else -1.0)

    df1 = df[df['predicted_label'] != -1.0].copy()
    mae = mean_absolute_error(df1['labels'], df1['predicted_label'])
    mse = mean_squared_error(df1['labels'], df1['predicted_label'])
    rmse = float(np.sqrt(mse))
    r2 = r2_score(df1['labels'], df1['predicted_label'])
    st_dev = float(df1['predicted_label'].std())
    num_failed_predictions_ratio = (len(df) - len(df1)) / len(df)
    majority_prediction_ratio = float(df1.predicted_label.value_counts().max()) / len(df1)
    most_common = float(df1.predicted_label.value_counts().keys()[0])
    return df1, mae, rmse, r2, st_dev, num_failed_predictions_ratio, majority_prediction_ratio, most_common

### Granite

In [None]:
CHECKPOINTS_DIR = "/root/Fine-Tuning_Truth/granite-3.1-1b-a400m-base"
df_granite, mae_granite, rmse_granite, r2_granite, st_dev_granite, num_failed_predictions_ratio_granite, majority_prediction_ratio_granite, most_common_granite = evaluate_model(CHECKPOINTS_DIR)

In [8]:
num_failed_predictions_ratio_granite, majority_prediction_ratio_granite, most_common_granite, mae_granite, rmse_granite, r2_granite, st_dev_granite

(0.0,
 0.7604630454140695,
 0.5,
 0.2813143365983972,
 0.32455678937290755,
 -0.08982214832630864,
 0.12386311616333191)

In [9]:
hist_filename = 'histograms/granite_raw_histogram_p3.png'
plt.figure(figsize=(10, 6))
sns.histplot(df_granite['predicted_label'], bins=12, kde=True, color='blue')
plt.title(f'Histogram of Scores for Granite-1B Model (Prompt v3)')
plt.xlabel('Reliability Score')
plt.ylabel('Frequency')
plt.grid(axis='y', alpha=0.75)
plt.savefig(hist_filename)
plt.close()

### LLaMA

In [None]:
CHECKPOINTS_DIR = "/root/Fine-Tuning_Truth/local_models/TinyLlama-1.1B-Chat-v1.0"
df_llama, mae_llama, rmse_llama, r2_llama, st_dev_llama, num_failed_predictions_ratio_llama, majority_prediction_ratio_llama, most_common_llama = evaluate_model(CHECKPOINTS_DIR)

In [11]:
num_failed_predictions_ratio_llama, majority_prediction_ratio_llama, most_common_llama, mae_llama, rmse_llama, r2_llama, st_dev_llama

(0.05253784505788068,
 0.581766917293233,
 0.1,
 0.2894708646616542,
 0.413841808295432,
 -0.7610839368108193,
 0.277327794035006)

In [12]:
hist_filename = 'histograms/llama_histogram_p3.png'
plt.figure(figsize=(10, 6))
sns.histplot(df_llama['predicted_label'], bins=12, kde=True, color='blue')
plt.title(f'Histogram of Scores for TinyLlama-1.1B Model (Prompt v3)')
plt.xlabel('Reliability Score')
plt.ylabel('Frequency')
plt.grid(axis='y', alpha=0.75)
plt.savefig(hist_filename)
plt.close()

### Olmo

In [None]:
CHECKPOINTS_DIR = "/root/Fine-Tuning_Truth/local_models/olmo-2-0425-1b-instruct"
df_olmo, mae_olmo, rmse_olmo, r2_olmo, st_dev_olmo, num_failed_predictions_ratio_olmo, majority_prediction_ratio_olmo, most_common_olmo = evaluate_model(CHECKPOINTS_DIR)

In [14]:
num_failed_predictions_ratio_olmo, majority_prediction_ratio_olmo, most_common_olmo, mae_olmo, rmse_olmo, r2_olmo, st_dev_olmo

(0.0,
 0.3686553873552983,
 1.0,
 0.4674158504007124,
 0.5781094052311733,
 -2.4577556736853428,
 0.2083923604679034)

In [15]:
hist_filename = 'histograms/olmo_histogram_p3.png'
plt.figure(figsize=(10, 6))
sns.histplot(df_olmo['predicted_label'], bins=12, kde=True, color='blue')
plt.title(f'Histogram of Scores for Olmo-1B Model (Prompt v3)')
plt.xlabel('Reliability Score')
plt.ylabel('Frequency')
plt.grid(axis='y', alpha=0.75)
plt.savefig(hist_filename)
plt.close()

### Granite Fine-Tuned

In [None]:
CHECKPOINTS_DIR = "/root/Fine-Tuning_Truth/granite-V2-articles"
df_granite_ft, mae_granite_ft, rmse_granite_ft, r2_granite_ft, st_dev_granite_ft, num_failed_predictions_ratio_granite_ft, majority_prediction_ratio_granite_ft, most_common_granite_ft = evaluate_model(CHECKPOINTS_DIR)

In [17]:
num_failed_predictions_ratio_granite_ft, majority_prediction_ratio_granite_ft, most_common_granite_ft, mae_granite_ft, rmse_granite_ft, r2_granite_ft, st_dev_granite_ft

(0.0017809439002671415,
 0.36128456735057984,
 0.15,
 0.14088581623550403,
 0.20701247554488625,
 0.5565352843057769,
 0.2194935836808187)

In [18]:
hist_filename = 'histograms/granite_fine_tuned_histogram_p3.png'
plt.figure(figsize=(10, 6))
sns.histplot(df_granite_ft['predicted_label'], bins=12, kde=True, color='green')
plt.title(f'Histogram of Scores for TrueGL Model (Prompt v3)')
plt.xlabel('Reliability Score')
plt.ylabel('Frequency')
plt.grid(axis='y', alpha=0.75)
plt.savefig(hist_filename)
plt.close()