In [3]:
# Perform my imports
import torch
from transformers import AutoTokenizer, BitsAndBytesConfig, AutoModelForCausalLM
import heapq
import json
import time
import random
import os
import re
import math
import pandas as pd
import sys
import numpy as np
from typing import Optional, List
from transformers import AutoTokenizer, LlamaForCausalLM, BitsAndBytesConfig, TrainingArguments
import json, csv
from sklearn.model_selection import train_test_split
from trl import SFTTrainer
from datasets import Dataset
from peft import LoraConfig

# Make the parent directory importable just long enough to pull in the CeRTS
# modules. try/finally guarantees ".." is removed again even if an import
# fails, so re-running this cell never leaves stray entries on sys.path.
# NOTE(review): these star imports hide where gen_prompt / CeRTS_output_dist /
# top_2_delta (used below) are defined — consider explicit imports.
sys.path.append("..")
try:
    from CeRTS_beam_multi import *
    from CeRTS_utils import *
finally:
    sys.path.remove("..")

# Add some fine-tuning specific functions
from functions.generation_functions import *
from functions.dataset_creation_functions import *

# Device handed to CeRTS_output_dist; the model itself is placed via device_map="auto".
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialization

In [4]:
# Fine-tuned (SFT, merged) checkpoints to evaluate with CeRTS.
model_eval_list = [
    'models/llama-3.2-3b-instruct-SFT-val_test_10_15-merged',
]

# Extraction schema passed to gen_prompt: each feature maps to its
# pipe-separated allowed answers. Key order is the feature order used when
# pairing model outputs with the LLM_* columns of the results file.
json_template = {
    "asthma": "yes|no",
    "smoking": "yes|no",
    "pneu": "yes|no",
    "common_cold": "yes|no",
    "pain": "yes|no",
    "fever": "high|low|no",
    "antibiotics": "yes|no",
}

# Derived from the template so the feature list cannot drift out of sync.
variables = list(json_template)
features = json_template.keys()

# Load Validation Data and Create the Test Split

In [5]:
# Load the clinical validation set (this cell loads data, not the model).
df_val = pd.read_csv("clinical_data/validation.csv")

# Hold out 20% of the validation data (fixed seed for reproducibility) as the
# test split, and renumber both splits 0..n-1.
train_df, test_df = train_test_split(df_val, test_size=0.2, random_state=42)
train_df = train_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

# Evaluate on a small subset to keep the CeRTS run short.
# NOTE: the "20" in the evaluation cell's results filename is hard-coded —
# keep it in sync with this value.
N_TEST_ROWS = 20
test_df = test_df.head(N_TEST_ROWS)

# Perform Actual Test Evaluation with CeRTS

In [7]:
# Evaluate each model with CeRTS. Results are appended row by row to a CSV so an
# interrupted run can be resumed: rows already present in the file are skipped
# instead of being recomputed and appended a second time.
import gc

for model_path in model_eval_list:

    # Results file is named after the model directory (last path component).
    # NOTE: "20" is hard-coded; keep it in sync with the test_df slice above.
    model_name = os.path.basename(os.path.normpath(model_path))
    OUT_PATH = f'results_data/CeRTS_SimSUM_20_test_val_{model_name}.csv'

    # Header mirrors the row layout written below: note, gold labels, then an
    # (LLM answer, confidence) pair per feature, all in json_template key order.
    out_columns = (
        ["compact_note"]
        + list(features)
        + [col for feature in features for col in (f"LLM_{feature}", f"{feature}_conf")]
    )

    # Reuse an existing results file (resume) or create it with only the header.
    # to_csv writes pandas' unnamed index column first; each data row fills that
    # slot with its 1-based position in test_df.
    if os.path.exists(OUT_PATH):
        out_df = pd.read_csv(OUT_PATH)
    else:
        os.makedirs(os.path.dirname(OUT_PATH), exist_ok=True)
        out_df = pd.DataFrame(columns=out_columns)
        out_df.to_csv(OUT_PATH)
    n_done = len(out_df)

    print('using file', OUT_PATH)
    if n_done >= len(test_df):
        print('all', len(test_df), 'rows already written; skipping', model_path)
        continue

    # Drop the previously loaded model before loading the next checkpoint so
    # several models do not accumulate in (GPU) memory.
    model = None
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")

    t1 = time.time()
    x = n_done  # row counter; continues after the rows already on disk

    with open(OUT_PATH, "a", newline="", encoding="utf-8") as f:
        writer = csv.writer(f)
        for _, row in test_df.iloc[n_done:].iterrows():
            x += 1
            if x % 3 == 0:
                print(x, round((time.time() - t1) / 60, 2), 'mins')

            compact_note = row['advanced_text']
            gold_labels = [row[feature] for feature in features]

            messages = gen_prompt(compact_note, features, json_template)
            answer_distributions = CeRTS_output_dist(messages, features, model, tokenizer, device, beam_width=5, max_steps=100)

            row_extract_conf = []
            for feature, answer_dist in zip(features, answer_distributions):
                # answer_dist is indexed as (answer, score) pairs; entry 0 is
                # reported as the answer (presumably the top-scoring one —
                # confirm in CeRTS_output_dist).
                response = answer_dist[0][0]
                confidence = top_2_delta([answer_score[1] for answer_score in answer_dist])
                print('Feature:', feature, 'Response:', response, 'Confidence:', confidence)
                row_extract_conf.extend([response, confidence])
            print("----------------------")

            # zip() silently truncates; refuse to write a short row that would
            # shift every following column out of alignment with the header.
            if len(row_extract_conf) != 2 * len(features):
                raise ValueError(
                    f"row {x}: expected {len(features)} answer distributions, "
                    f"got {len(row_extract_conf) // 2}"
                )

            writer.writerow([x, compact_note] + gold_labels + row_extract_conf)
            # Flush per row so a crash keeps completed rows and the run can resume.
            f.flush()

Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

using file results_data/CeRTS_SimSUM_20_test_val_llama-3.2-3b-instruct-SFT-val_test_10_15-merged.csv
Feature: asthma Response: no Confidence: 0.9499186162960178
Feature: smoking Response: no Confidence: 0.17943028487493867
Feature: pneu Response: no Confidence: 0.820056607160318
Feature: common_cold Response: no Confidence: 0.35100935368582575
Feature: pain Response: no Confidence: 0.8190578968786132
Feature: fever Response: no Confidence: 0.693214919588693
Feature: antibiotics Response: no Confidence: 0.22623851967778685
----------------------
Feature: asthma Response: no Confidence: 0.9516556080930427
Feature: smoking Response: no Confidence: 0.7675254541782468
Feature: pneu Response: no Confidence: 0.7934590399754273
Feature: common_cold Response: no Confidence: 0.8096149388717437
Feature: pain Response: no Confidence: 0.6057759298623747
Feature: fever Response: no Confidence: 0.8305814542210684
Feature: antibiotics Response: no Confidence: 0.7034691055931392
----------------------
