# Baseline Predictions
In this notebook we generate zero-shot baseline predictions (critical questions) for the test set.

In [1]:
# pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
import gc
import json
import logging
import os
import re

import pandas as pd
import torch
import torch
import tqdm
import transformers

In [2]:
################################################################################
#######################   PATH VARIABLES        ################################
################################################################################

# All paths are relative to the notebook's working directory.
test_dataset_path = "Data/Processed/test.json"
model_path_llama = "Models/Meta-Llama-3.1-8B-Instruct"
model_path_qwen = "Models/Qwen2.5-7B-Instruct"
# Trailing slash matters: output file names are appended by plain string concatenation.
results_path = "Evaluation/Results/"

################################################################################
#######################   STATIC VARIABLES      ################################
################################################################################

# Notebook logger: INFO and above go to a local log file.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Attach the file handler only on the first run of this cell; re-running the
# cell would otherwise add a second handler and duplicate every log line.
if len(logger.handlers) == 0:
    file_handler = logging.FileHandler('baseline_predictions.log')
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(
        logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    )
    logger.addHandler(file_handler)

# Pick the accelerator: Apple MPS first, then CUDA, otherwise CPU.
if torch.backends.mps.is_available():
    device_name = "mps"
elif torch.cuda.is_available():
    device_name = "cuda"
else:
    device_name = "cpu"
device = torch.device(device_name)

# Record the start of the run and the chosen device.
logger.info("--------  Start with Baseline Predictions  -------------")
logger.info(f'Device selected: {device}')

## Zero-shot prompting
In this section we generate critical questions with different pretrained, non-fine-tuned ("vanilla") models. These generated questions serve as a baseline against which we compare our results. The following models were used to generate the baseline results:
- Llama 3.1 8B Instruct
- Qwen 2.5 7B Instruct

In [3]:
# Baseline models to run. `name` is passed to `generate_critical_questions`
# as `model_name`; `model_id` is the local checkpoint path; `output_file` is
# where that model's predictions are written as JSON.
models = [
    dict(
        name="llama",
        model_id=model_path_llama,
        output_file=results_path + "results_zeroshot_llama_3.1-8B-instruct.json",
    ),
    dict(
        name="qwen",
        model_id=model_path_qwen,
        output_file=results_path + "results_zeroshot_qwen2.5-7b-instruction.json",
    ),
]

In [4]:
def structure_output(whole_text, n_questions=3):
    """Parse a model's free-text reply into a fixed number of critical questions.

    The reply is split into lines.

    * First pass: a line is kept as a candidate when it ends with ``?``,
      optionally followed by a quote, a space, a parenthetical remark and/or
      trailing letters/punctuation (see the regex below).
    * Second pass: each remaining line is split on ``?"``. Every piece except
      the last gets ``?"`` re-attached and is kept as a candidate. This
      recovers several quoted questions written on one line.

    Candidates from the first pass come before those from the second. Each
    candidate is trimmed to start at its first ASCII uppercase letter, which
    drops list markers such as ``1. `` or ``- ``. Candidates without any
    uppercase letter are discarded.

    Args:
        whole_text: Raw assistant reply text.
        n_questions: Number of questions to return (default 3, matching the
            prompt used in this notebook).

    Returns:
        A list of ``{'id': i, 'cq': question}`` dicts for the first
        ``n_questions`` questions, or the string ``'Missing CQs'`` when fewer
        than ``n_questions`` questions could be extracted.
    """
    valid = []
    not_valid = []
    for cq in whole_text.split('\n'):
        # NOTE(review): inside the character class, `\'-\,` is a *range* from
        # "'" to "," (so it also admits "(", ")", "*", "+"), not three
        # literals. Kept unchanged so parsing matches previous runs.
        if re.match(r'.*\?(\")?( )?(\([a-zA-Z0-9\.\'-\,\? ]*\))?([a-zA-Z \.,\"\']*)?(\")?$', cq):
            valid.append(cq)
        else:
            not_valid.append(cq)

    # Second pass: recover quoted questions from lines that failed the
    # first test. The last split piece (text after the final `?"`, or the
    # whole line when there is no `?"`) is never a question and is dropped.
    for text in not_valid:
        pieces = re.split(r'\?\"', text)
        for cq in pieces[:-1]:
            valid.append(cq + '?"')

    # Trim leading list markers/quotes by starting at the first capital letter.
    final = []
    for cq in valid:
        first_upper = re.search(r'[A-Z]', cq)
        if first_upper:
            final.append(cq[first_upper.start():])

    if len(final) < n_questions:
        return 'Missing CQs'
    return [{'id': i, 'cq': final[i]} for i in range(n_questions)]

In [5]:
def generate_critical_questions(pipe, model_name, intervention_text):
    """Ask a chat model for 3 critical questions about ``intervention_text``.

    Args:
        pipe: A ``transformers`` text-generation pipeline. It is called with a
            list of chat messages plus sampling kwargs.
        model_name: ``"llama"`` or ``"qwen"``. Any other value raises.
        intervention_text: The argumentative text to question.

    Returns:
        The result of ``structure_output`` on the assistant's reply: a list of
        ``{'id', 'cq'}`` dicts, or ``'Missing CQs'``.

    Raises:
        ValueError: If ``model_name`` is not a supported model. This is checked
            before any generation, so no compute is wasted.
    """
    if model_name not in ("llama", "qwen"):
        raise ValueError(f"Unsupported model name: {model_name}")

    prompt = f"""Suggest 3 critical questions that should be raised before accepting the arguments in this text:\n\n\"{intervention_text}\"\n\nGive one question per line. Make the questions simple, and do not give any explanation regarding why the question is relevant."""

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]

    outputs = pipe(
        messages,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.6,
        top_p=0.9,
    )

    generated = outputs[0]["generated_text"]
    if isinstance(generated, list):
        # With chat-message input the pipeline returns the whole conversation
        # (system, user, then the new assistant turn). Only the final turn is
        # the model's answer. Joining all turns, as the old Qwen branch did,
        # fed the prompt and the intervention text (which may itself contain
        # questions) into the parser.
        assistant_response = generated[-1]["content"]
    else:
        # Plain-string output: already just the generated text.
        assistant_response = generated

    return structure_output(assistant_response)

In [6]:
# Load the test set once. It is iterated as item_id -> item, and each item
# must provide an "intervention" text field.
with open(test_dataset_path, 'r') as f:
    data = json.load(f)

# Generation samples (do_sample=True); fix the seed so runs are reproducible.
SEED = 42

for model in models:
    print(f"Loading model: {model['model_id']}")
    logger.info(f"Loading model: {model['model_id']}")

    # The tokenizer is only used to give the pipeline an explicit pad token
    # id (its EOS token).
    tokenizer = transformers.AutoTokenizer.from_pretrained(model["model_id"])
    pipe = transformers.pipeline(
        "text-generation",
        model=model["model_id"],
        model_kwargs={"torch_dtype": torch.bfloat16},
        device=device,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Re-seed per model so each model's samples do not depend on how much
    # randomness the previous model consumed.
    transformers.set_seed(SEED)

    # Make sure the results folder exists *before* the long generation run,
    # so the output is not lost at write time.
    os.makedirs(os.path.dirname(model["output_file"]) or ".", exist_ok=True)

    output_data = []
    for item_id, item in tqdm.tqdm(data.items(), desc=model["name"]):
        intervention_text = item["intervention"]
        questions = generate_critical_questions(pipe, model["name"], intervention_text)

        output_entry = {
            "id": item_id,
            "input_text": intervention_text,
            "cqs": questions
        }
        logger.info(f"Generated {item_id}: {questions}")
        output_data.append(output_entry)

    with open(model["output_file"], 'w') as f:
        json.dump(output_data, f, indent=2)

    logger.info(f"Output saved to {model['output_file']}")

    # Free this model before loading the next one: two 7-8B models in
    # bfloat16 may not fit in device memory at the same time.
    del pipe, tokenizer
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
    elif device.type == "mps":
        torch.mps.empty_cache()

Loading model: Models/Meta-Llama-3.1-8B-Instruct


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Device set to use mps
Sliding Window Attention is enabled but not implemented for `sdpa`; unexpected results may be encountered.


Loading model: Models/Qwen2.5-7B-Instruct


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Device set to use mps
