# Baseline Predictions
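This notebook generates the baseline predictions: an instruction-tuned language model is prompted to suggest three critical questions for each intervention in the sample data.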

In [1]:
# pip3 install --pre torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/nightly/cpu
import torch
import transformers
import pandas as pd
import json
import logging
import tqdm
import re
import torch

In [3]:
################################################################################
#######################   PATH VARIABLES        ################################
################################################################################

# Input data file and the model locations used further down in the notebook.
sample_path = "Data/Raw/sharedTask/sample.json"
model_path_llama = "Models/Meta-Llama-3-8B-Instruct"
# NOTE(review): "Meta-Qwen" looks unusual for a Qwen checkpoint name - confirm
# that this directory really exists under this name.
model_path_qwen = "Models/Meta-Qwen-3-8B-Instruct"

################################################################################
#######################   STATIC VARIABLES      ################################
################################################################################
# Root logger with the default handler configuration.
logging.basicConfig()
logger = logging.getLogger()

# Backend preference order: Apple MPS, then CUDA, then plain CPU.
if torch.backends.mps.is_available():
    _device_name = "mps"
elif torch.cuda.is_available():
    _device_name = "cuda"
else:
    _device_name = "cpu"
device = torch.device(_device_name)


## Llama 3 8B Instruct

In [4]:
# A line counts as a well-formed question when it contains a '?' that is
# followed only by an optional closing quote, an optional space, an optional
# parenthesised remark and an optional short trailing phrase.
# Raw strings are used so that regex escapes such as '\?' and '\(' are not
# treated as (invalid) Python string escapes.
_CQ_LINE_RE = re.compile(
    r'''.*\?(")?( )?(\([a-zA-Z0-9\.'\-,\? ]*\))?([a-zA-Z \.,"']*)?(")?$'''
)
# Separator used to break a line holding several quoted questions: '?"'.
_QUOTED_QUESTION_END_RE = re.compile(r'\?"')
# First ASCII capital letter; everything before it is dropped from a question.
_FIRST_CAPITAL_RE = re.compile(r'[A-Z]')


def structure_output(whole_text):
    """Extract the first three critical questions from raw model output.

    The text is split into lines. Lines that already look like a question are
    kept as they are; every other line is split on the '?"' separator so that
    several quoted questions on one line are recovered individually. Each
    candidate is then trimmed to start at its first ASCII capital letter
    (candidates without one are discarded).

    Args:
        whole_text: The generated text, with candidates separated by newlines.

    Returns:
        A list of three dicts ``{'id': 0..2, 'cq': <question>}`` when at least
        three questions were found; otherwise the string ``'Missing CQs'``
        (a warning is logged in that case).
    """
    valid = []
    not_valid = []
    for cq in whole_text.split('\n'):
        if _CQ_LINE_RE.match(cq):
            valid.append(cq)
        else:
            not_valid.append(cq)

    # Second chance for rejected lines: split on '?"'. The 'end' sentinel
    # guarantees a non-empty last piece, so a line that *ends* in '?"' still
    # yields its final question in the pieces before the last one.
    for text in not_valid:
        pieces = _QUOTED_QUESTION_END_RE.split(text + 'end')
        for piece in pieces[:-1]:
            valid.append(piece + '?"')

    # Strip leading numbering, bullets and quotes by starting each question at
    # its first capital letter.
    final = []
    for cq in valid:
        occurrence = _FIRST_CAPITAL_RE.search(cq)
        if occurrence:
            final.append(cq[occurrence.start():])

    if len(final) < 3:
        logger.warning('Missing CQs')
        return 'Missing CQs'

    return [{'id': i, 'cq': final[i]} for i in range(3)]

def generate_answer(
    pipe,
    terminator,
    input_text,
    system_prompt="You are a helpful assistant to rephrase given text. But you always keep in mind that the meaning must not change.",
    max_new_tokens=256,
    temperature=0.6,
    top_p=0.9,
):
    """Run one chat-style generation and return only the newly generated text.

    Args:
        pipe: Callable text-generation pipeline exposing a ``tokenizer`` with
            ``apply_chat_template``.
        terminator: Value forwarded as ``eos_token_id`` to the pipeline call.
        input_text: Content of the user message.
        system_prompt: Content of the system message. Defaults to the
            rephrasing instruction this function originally hard-coded; pass a
            task-specific prompt to override it.
        max_new_tokens: Forwarded to the pipeline call.
        temperature: Sampling temperature forwarded to the pipeline call.
        top_p: Nucleus-sampling threshold forwarded to the pipeline call.

    Returns:
        The generated continuation as a string, without the prompt prefix.
    """
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": input_text},
    ]

    # Render the conversation to a plain string prompt (not token ids) that
    # ends with the generation prompt for the assistant turn.
    prompt = pipe.tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    outputs = pipe(
        prompt,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminator,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
    )

    generated_text = outputs[0]["generated_text"]
    # Only strip the prompt when the output really starts with it; otherwise
    # slicing by length would silently cut off part of the answer.
    if generated_text.startswith(prompt):
        return generated_text[len(prompt):]
    return generated_text

  if re.match('.*\?(\")?( )?(\([a-zA-Z0-9\.\'\-,\? ]*\))?([a-zA-Z \.,\"\']*)?(\")?$', cq):
  new_cqs = re.split("\?\"", text+'end')


In [6]:
# Prompt template; the "{intervention}" placeholder is filled in with
# str.format for every sample.
critical_question_prompt = """Suggest 3 critical questions that should be raised before accepting the arguments in this text:

                "{intervention}"

                Give one question per line. Make the questions simple, and do not give any explanation regarding why the question is relevant."""
prefixes = [critical_question_prompt]

# Model used for this run.
model_path = model_path_llama

# Text-generation pipeline with bfloat16 weights on the selected device.
pipeline = transformers.pipeline(
    "text-generation",
    model=model_path,
    device=device,
    model_kwargs={"torch_dtype": torch.bfloat16},
)

# Generation stops at either the tokenizer's EOS token or the "<|eot_id|>" token.
eot_token_id = pipeline.tokenizer.convert_tokens_to_ids("<|eot_id|>")
terminators = [pipeline.tokenizer.eos_token_id, eot_token_id]

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Device set to use mps


In [7]:
# Load the sample interventions. The encoding is given explicitly so that
# reading the JSON file does not depend on the platform's locale default.
with open(sample_path, encoding="utf-8") as f:
    data = json.load(f)

# Generate critical questions for every sample with every prompt template and
# print the raw model output (it is not passed through structure_output here).
for key, line in tqdm.tqdm(data.items()):
    # Looked up once per sample; it does not change across prompt templates.
    intervention = line['intervention']
    for prefix in prefixes:
        instruction = prefix.format(intervention=intervention)
        cqs = generate_answer(pipeline, terminators, instruction)
        print("Input text: ", instruction)
        print("Critical Questions: ", cqs)
        print("-----------")


  0%|          | 0/5 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.
 20%|██        | 1/5 [00:08<00:32,  8.19s/it]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


Input text:  Suggest 3 critical questions that should be raised before accepting the arguments in this text:

                "CLINTON: "The central question in this election is really what kind of country we want to be and what kind of future we 'll build together
Today is my granddaughter 's second birthday
I think about this a lot
we have to build an economy that works for everyone , not just those at the top
we need new jobs , good jobs , with rising incomes
I want us to invest in you
I want us to invest in your future
jobs in infrastructure , in advanced manufacturing , innovation and technology , clean , renewable energy , and small business
most of the new jobs will come from small business
We also have to make the economy fairer
That starts with raising the national minimum wage and also guarantee , finally , equal pay for women 's work
I also want to see more companies do profit-sharing""

                Give one question per line. Make the questions simple, and do not give a

 40%|████      | 2/5 [00:14<00:21,  7.04s/it]Setting `pad_token_id` to `eos_token_id`:128009 for open-end generation.


Input text:  Suggest 3 critical questions that should be raised before accepting the arguments in this text:

                "Javier: "I have no problem requiring the airlines to give notification of any "known" delays within 30 minutes of when they become aware of it.
BUT, being a frequent traveler, I see all sorts of problems which are unavoidable and for which the airlines will be blamed by giving such notice with the intend that some flyers may be able to delay thier trip to the airport or even the departure gate.
I know that many times the airlines can't also know exactly when a weather hold or a maintenance issue will be rectified and that the flight is then ready to go.
Many time it can be surprisingly faster than expected.
The problem is that some flyers may then  wait before going to the airport,
but then find that the problem was rectified sooner than expected and the flight departed.
Of course the flyer and the flyers rights organization will then crucify the airlines for s

 40%|████      | 2/5 [00:22<00:34, 11.47s/it]
  if re.match('.*\?(\")?( )?(\([a-zA-Z0-9\.\'\-,\? ]*\))?([a-zA-Z \.,\"\']*)?(\")?$', cq):
  new_cqs = re.split("\?\"", text+'end')


KeyboardInterrupt: 

In [None]:
# Disabled cell: everything below is a single string literal, so none of it is
# executed. The code inside refers to df_filtered, system_prompt and
# augmented_data_path, which are not defined in the visible part of this
# notebook.
# NOTE(review): looks like a leftover from a data-augmentation notebook -
# confirm whether it can be removed.
'''
new_entries = []

for _, row in df_filtered.iterrows():
    input_text = system_prompt + row['modified_premises']
    rephrased_premise = generate_answer(pipeline, terminators, input_text)

    new_entry = {
        "modified_premises": rephrased_premise,
        "read_premises" : row['read_premises'],
        "scheme": row['scheme'],
        "modified_cqs": row['modified_cqs']
    }

    new_entries.append(new_entry)

# Combine old and new entries
full_data = df_filtered.to_dict(orient='records') + new_entries

# Save to new file
with open(augmented_data_path, "w") as f:
    json.dump(full_data, f, indent=2)

print(f"Extended dataset saved to {augmented_data_path}")'''