In [1]:
import gc

import torch
from torch.utils.data import DataLoader

from src import llm, utils, config

# Start from the standardized prompt template defined in the project's config module.
prompt_template = config.standardized_prompt_template 

In [2]:
# Remove the title slot from the template: the CoAID frame built further down
# only carries a blank placeholder in its 'title' column.
prompt_template = prompt_template.replace("News Title:\n{TITLE}","")

In [3]:
# Display the template to confirm the title slot is gone.
prompt_template

'Given the title and content of a healthcare news article, analyze whether the claims align with plausible scenarios and whether the article maintains internal consistency. Check for misleading or unclear statements, and conclude whether the news is real or fake.\n\nNews Content:\n{NEWS}\nConclusion: '

In [4]:
import pandas as pd


def map_labels(text):
    """Translate a class value into its textual label.

    A value that compares equal to ``0`` becomes ``'fake'``; any other
    value becomes ``'real'``.
    """
    return 'fake' if text == 0 else 'real'

# Keep only the cleaned text and the class column, renamed to 'news' / 'claim'
# (presumably the column names llm.MISSINFODataset expects -- confirm).
column_renames = {'text_clean': 'news', 'class': 'claim'}
data = pd.read_csv("CoAID.csv")[list(column_renames)].rename(columns=column_renames)

# Turn the class values into the 'fake' / 'real' labels.
data['claim'] = data['claim'].apply(map_labels)
# No title is read from the CSV, so every row gets the same blank placeholder.
data['title'] = "  "
# Show the class balance of the evaluation set.
data['claim'].value_counts()

claim
real    1643
fake     177
Name: count, dtype: int64

In [7]:
# Generation budget for every model evaluated in this cell; it is forwarded
# explicitly to standard_prompting below so changing it here takes effect.
max_new_tokens = 10


def standard_prompting(test, tokenizer, model, prompt_template, batch_size=1, max_new_tokens=10):
    """Generate a prediction for every row of ``test`` and score the results.

    Args:
        test: DataFrame handed to ``llm.MISSINFODataset`` as ``df``.
        tokenizer: Tokenizer used to build the dataset and to generate.
        model: Model passed to ``llm.make_the_generation``.
        prompt_template: Template string forwarded to the dataset.
        batch_size: DataLoader batch size; rows are iterated in order (no shuffling).
        max_new_tokens: Value forwarded to ``llm.make_the_generation``.

    Returns:
        Tuple ``(predicts, processed_predicts, clf_report)``: the generations
        returned by ``llm.make_the_generation``, the output of
        ``utils.output_processor`` applied to them, and the report produced by
        ``utils.evaluation_report`` against the labels returned alongside the
        generations.
    """
    test_data = llm.MISSINFODataset(tokenizer=tokenizer, df=test, prompt_template=prompt_template)
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    predicts, labels = llm.make_the_generation(model=model, tokenizer=tokenizer, max_length=512,
                                               data_loader=test_dataloader, max_new_tokens=max_new_tokens)
    processed_predicts = utils.output_processor(predicts)
    clf_report = utils.evaluation_report(y_true=labels, y_pred=processed_predicts)
    return predicts, processed_predicts, clf_report


# Each entry: [SFT checkpoint path, loader callback, training dataset name,
# model name, per-device train batch size]. Only the loader, dataset name and
# model name are used in this cell. The FakeHealth-trained checkpoints are
# evaluated in the next cell.
metadata = [
    ["assets/sft-qwen-recovery", llm.load_qwen_llm, 'ReCOVery', "Qwen", 4],
    ["assets/sft-llama-recovery", llm.load_llama_llm, 'ReCOVery', "Llama3", 2],
    ["assets/sft-falcon-recovery", llm.load_falcon_llm, 'ReCOVery', "Falcon", 2],
    ["assets/sft-phi-recovery", llm.load_phi_llm, 'ReCOVery', "Phi", 1],
]
for model_path, load_callback, dataset_name, model_name, per_device_train_batch_size in metadata:
    for model_loss in ['sigmoid_bco']:
        # Directory of the RLHF (CPO) checkpoint to load for this model/dataset/loss.
        output_dir = f"assets/rlhf-{model_name.lower()}-{dataset_name.lower()}-cpo-{model_loss.replace('_','-')}"
        tokenizer, model, llm_path = load_callback(llm_path=output_dir)

        predicts, processed_predicts, clf_report = standard_prompting(test=data,
                                                                     tokenizer=tokenizer,
                                                                     model=model,
                                                                     prompt_template=prompt_template,
                                                                     max_new_tokens=max_new_tokens)
        report_dict = {
            "llm": llm_path,
            "clf_report": clf_report,
            "generations": predicts,
            "processed_generations": processed_predicts,
        }

        utils.write_json(data=report_dict, path=f"results/{model_name}-CoAID-rlhf-cpo-{model_loss.replace('_','-')}.json")

        # Drop the references, then collect garbage and release cached GPU
        # memory so the next checkpoint does not load on top of this one.
        del model
        del tokenizer
        gc.collect()
        torch.cuda.empty_cache()

In [8]:
# Generation budget for every model evaluated in this cell; it is forwarded
# explicitly to standard_prompting below so changing it here takes effect.
max_new_tokens = 10


def standard_prompting(test, tokenizer, model, prompt_template, batch_size=1, max_new_tokens=10):
    """Generate a prediction for every row of ``test`` and score the results.

    Args:
        test: DataFrame handed to ``llm.MISSINFODataset`` as ``df``.
        tokenizer: Tokenizer used to build the dataset and to generate.
        model: Model passed to ``llm.make_the_generation``.
        prompt_template: Template string forwarded to the dataset.
        batch_size: DataLoader batch size; rows are iterated in order (no shuffling).
        max_new_tokens: Value forwarded to ``llm.make_the_generation``.

    Returns:
        Tuple ``(predicts, processed_predicts, clf_report)``: the generations
        returned by ``llm.make_the_generation``, the output of
        ``utils.output_processor`` applied to them, and the report produced by
        ``utils.evaluation_report`` against the labels returned alongside the
        generations.
    """
    test_data = llm.MISSINFODataset(tokenizer=tokenizer, df=test, prompt_template=prompt_template)
    test_dataloader = DataLoader(test_data, batch_size=batch_size, shuffle=False)
    predicts, labels = llm.make_the_generation(model=model, tokenizer=tokenizer, max_length=512,
                                               data_loader=test_dataloader, max_new_tokens=max_new_tokens)
    processed_predicts = utils.output_processor(predicts)
    clf_report = utils.evaluation_report(y_true=labels, y_pred=processed_predicts)
    return predicts, processed_predicts, clf_report


# Each entry: [SFT checkpoint path, loader callback, training dataset name,
# model name, per-device train batch size]. Only the loader, dataset name and
# model name are used in this cell. The ReCOVery-trained checkpoints are
# evaluated in the preceding cell.
metadata = [
    ["assets/sft-qwen-fakehealth", llm.load_qwen_llm, 'FakeHealth', "Qwen", 4],
    ["assets/sft-llama-fakehealth", llm.load_llama_llm, 'FakeHealth', "Llama3", 4],
    ["assets/sft-falcon-fakehealth", llm.load_falcon_llm, 'FakeHealth', "Falcon", 2],
    ["assets/sft-phi-fakehealth", llm.load_phi_llm, 'FakeHealth', "Phi", 1],
]
for model_path, load_callback, dataset_name, model_name, per_device_train_batch_size in metadata:
    for model_loss in ['sigmoid_bco']:
        # Directory of the RLHF (CPO) checkpoint to load for this model/dataset/loss.
        output_dir = f"assets/rlhf-{model_name.lower()}-{dataset_name.lower()}-cpo-{model_loss.replace('_','-')}"
        tokenizer, model, llm_path = load_callback(llm_path=output_dir)

        predicts, processed_predicts, clf_report = standard_prompting(test=data,
                                                                     tokenizer=tokenizer,
                                                                     model=model,
                                                                     prompt_template=prompt_template,
                                                                     max_new_tokens=max_new_tokens)
        report_dict = {
            "llm": llm_path,
            "clf_report": clf_report,
            "generations": predicts,
            "processed_generations": processed_predicts,
        }

        utils.write_json(data=report_dict, path=f"results/{model_name}-FH-CoAID-rlhf-cpo-{model_loss.replace('_','-')}.json")

        # Drop the references, then collect garbage and release cached GPU
        # memory so the next checkpoint does not load on top of this one.
        del model
        del tokenizer
        gc.collect()
        torch.cuda.empty_cache()