# This is where I test different llama2 models, prompts and inputs to assess computation time/accuracy before building the full pipeline for the project.
Fine tuning: https://medium.com/@ogbanugot/notes-on-fine-tuning-llama-2-using-qlora-a-detailed-breakdown-370be42ccca1

## Questions/Thoughts
1. Combine children and neonatal or analyze separately?
2. Start with small model then use big?
3. When/if to move to Azure?
4. Prompting - how many classes to allow as potential outputs
5. Validation - InSilicoVA, openVA, etc.
6. PPI correction

In [None]:
import json
import time
import pathlib
import pandas as pd
import numpy as np
import os
import torch
from tqdm import tqdm


from datasets import load_dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    pipeline,
    logging,
)
from peft import LoraConfig, PeftModel
from trl import SFTTrainer
from IPython.display import display, HTML
from llama_cpp import Llama

import warnings
warnings.filterwarnings('ignore')
warnings.simplefilter('ignore')

## Read in data

In [None]:
# Load the tokenized adult PHMRC CSV; later cells read the `site`, `narrative`,
# `gs_text34` and `gs_cod` columns from this frame.
df = pd.read_csv('../../data/phmrc/phmrc_adult_tokenized.csv')

In [None]:
# Distinct values of the `site` column; used below to build per-site file names.
regions = [*df['site'].unique()]

In [None]:
# train
# Read each site's "train_ex_<site>.csv" file into a dictionary.
# The dictionary is keyed by the lower-cased site name so that the keys agree
# with both the lower-cased file names and the lower-case lookups below
# (previously the file name was lower-cased but the key was not).
train_excluded_dict = {}
for region in regions:
    site_key = region.lower()
    file_path = f'../../data/train_test_val/train_ex_{site_key}.csv'
    train_excluded_dict[site_key] = pd.read_csv(file_path)

# assign training data df names
train_ex_ap = train_excluded_dict['ap']
train_ex_dar = train_excluded_dict['dar']
train_ex_pemba = train_excluded_dict['pemba']
train_ex_mexico = train_excluded_dict['mexico']
train_ex_bohol = train_excluded_dict['bohol']
train_ex_up = train_excluded_dict['up']

In [None]:
# test / val

# Per-site test and validation frames, keyed by the site name exactly as it
# appears in `regions`.
test_dict, val_dict = {}, {}

for region in regions:
    # "test_<site>.csv" for this site
    test_file_path = f'../../data/train_test_val/test_{region}.csv'
    test_dict[region] = pd.read_csv(test_file_path)

    # "val_<site>.csv" for this site
    val_file_path = f'../../data/train_test_val/val_{region}.csv'
    val_dict[region] = pd.read_csv(val_file_path)

In [None]:
# Bind each site's test frame to its own module-level name.
test_ap, test_dar, test_pemba, test_mexico, test_bohol, test_up = [
    test_dict[site] for site in ('ap', 'dar', 'pemba', 'mexico', 'bohol', 'up')
]

# Same for the validation frames.
val_ap, val_dar, val_pemba, val_mexico, val_bohol, val_up = [
    val_dict[site] for site in ('ap', 'dar', 'pemba', 'mexico', 'bohol', 'up')
]

In [None]:
# Training frames, one per site.
training_dfs = [
    train_ex_ap, train_ex_dar, train_ex_pemba,
    train_ex_mexico, train_ex_bohol, train_ex_up,
]

# For every site, stack the validation rows underneath the test rows
# (original note: "combine labeled and unlabeled testing data").
testing_dfs = [
    pd.concat([test_part, val_part])
    for test_part, val_part in (
        (test_ap, val_ap),
        (test_dar, val_dar),
        (test_pemba, val_pemba),
        (test_mexico, val_mexico),
        (test_bohol, val_bohol),
        (test_up, val_up),
    )
]

# Rebind the per-site names to the combined frames held in `testing_dfs`.
test_ap, test_dar, test_pemba, test_mexico, test_bohol, test_up = testing_dfs

## Load different models

In [None]:
# Two GGUF builds of llama-2-7b-chat, both opened with n_ctx=2048.
# `small` points at the Q2_K file, `big` at the Q8_0 file.
small = Llama(model_path="../../models/llama-2-7b-chat.Q2_K.gguf", n_ctx=2048)

# medium = Llama(model_path="../models/llama-2-7b-chat.Q4_K_M.gguf", n_ctx=2048)

big = Llama(model_path="../../models/llama-2-7b-chat.Q8_0.gguf", n_ctx=2048)

## Query function for making prompt calls to model

In [None]:
# def query(model, question):
#     model_name = pathlib.Path(model.model_path).name
#     time_start = time.time()
#     prompt = f"Q: {question} A:"
#     output = model(prompt=prompt, max_tokens=0) # if max tokens is zero, depends on n_ctx
#     response = output["choices"][0]["text"]
#     time_elapsed = time.time() - time_start
#     display(HTML(f'<code>{model_name} response time: {time_elapsed:.02f} sec</code>'))
#     display(HTML(f'<strong>Question:</strong> {question}'))
#     display(HTML(f'<strong>Answer:</strong> {response}'))
#     print(json.dumps(output, indent=2))

In [None]:
def query_tostring(model, question):
    """Send ``question`` to ``model`` unchanged and return the generated text.

    Args:
        model: Callable that accepts ``prompt`` and ``max_tokens`` keyword
            arguments and returns a mapping shaped like
            ``{"choices": [{"text": ...}]}`` (this is how the ``Llama``
            objects in this notebook are used).
        question: The complete prompt string; no wrapper text is added.

    Returns:
        The ``"text"`` field of the first entry in ``output["choices"]``.

    Side effects:
        Prints the elapsed time of the model call in seconds.
    """
    # perf_counter is a monotonic clock, so the interval cannot be skewed by
    # system clock adjustments the way time.time() can.
    time_start = time.perf_counter()
    # if max tokens is zero, depends on n_ctx
    output = model(prompt=question, max_tokens=0)
    time_elapsed = time.perf_counter() - time_start
    print(time_elapsed)
    return output["choices"][0]["text"]

## Create Prompts

In [None]:
# Broad cause-of-death labels mapped to integer codes 0..4, in this order.
label_to_score = {
    label: code
    for code, label in enumerate(
        ('aids-tb', 'communicable', 'external', 'maternal', 'non-communicable')
    )
}

# Reverse lookup (integer code -> label), derived from the mapping above.
score_to_label = {code: label for label, code in label_to_score.items()}

In [None]:
def inspect_narrative(row):
    """Print one record of the module-level ``df`` for manual inspection.

    Prints the ``narrative`` and ``gs_text34`` columns, the ``gs_cod`` column,
    and the integer that ``label_to_score`` assigns to that ``gs_cod`` value.
    ``row`` is used as the index in ``df[column][row]``.
    """
    for heading, column in (('Narrative: ', 'narrative'), ('True Label: ', 'gs_text34')):
        print(heading + df[column][row])

    broad_category = df['gs_cod'][row]
    print('Broad Category: ' + broad_category)
    print('Embedding Representation: ' + str(label_to_score[broad_category]))

In [None]:
inspect_narrative(4)

In [None]:
def create_prompt(narrative: str) -> str:
    '''
    Build the full classification prompt for one narrative.

    The prompt wraps the narrative in <narrative> tags, describes the five
    broad labels inside <labels>, repeats them inside <options>, and ends with
    the instruction to answer with exactly one of the options.

    narrative: text to classify; interpolated into the prompt unchanged.
    returns: the complete prompt as a string (every line keeps the four-space
    indentation of the template below).
    '''
    # Two wording slips in the prompt text were corrected here:
    # "applies applies" -> "applies" and "including as accidents" -> "including accidents".
    # NOTE(review): malaria is listed under non-communicable below — confirm
    # this matches how the gs_cod categories group it.
    result = f"""
    <narrative>
    {narrative}
    </narrative>

    <labels>
    aids-tb: Patient died resulting from HIV-AIDs or Tuberculosis.
    communicable: Patient died from a communicable disease which is defined as 
    illnesses that spread from one human to another such as pneumonia, diarrhea 
    or dysentery.
    external: Patient died from external causes including accidents like fires,
    drowning, road traffic, falls, poisonous animals and violence like suicide, 
    homicide, or other injuries.
    maternal: Patient died from complications related to pregnancy or childbirth 
    including from severe bleeding, sepsis, pre-eclampsia and eclampsia.
    non-communicable: Patient died from a non-communicable disease which is defined
    as illnesses that cannot be transmitted from one human to another such as cirrhosis,
    epilepsy, acute myocardial infarction, copd, renal failure, cancer, diabetes,
    stroke, malaria, asthma, or other non-communicable diseases.
    </labels>

    <options>
    aids-tb, communicable, external, maternal, non-communicable
    </options>


    Which label best applies to the narrative (aids-tb, communicable, external, maternal, non-communicable)?
    Limit your response to one of the options exactly as it appears in the list.
    """
    return result

In [None]:
inspect_narrative(4)

In [None]:
create_prompt(df['narrative'][4])

In [None]:
query_tostring(small, create_prompt(df['narrative'][4]))

In [None]:
query_tostring(big, create_prompt(df['narrative'][4]))

In [None]:
# Run the big model on the first five narratives and collect the raw responses.
predictions_llama = []
for text in tqdm(df['narrative'][:5]):
    # Build the prompt from the narrative of the current iteration; the loop
    # previously re-sent df['narrative'][4] on every pass.
    predictions_llama.append(query_tostring(big, create_prompt(text)))

In [None]:
predictions_llama

## fuzzy match to extract labels

In [None]:
def find_exact_match(dictionary, long_strings, default=3):
    """Map each response string to the key whose dictionary value it contains.

    Only the first 30 characters of each string are searched. Exactly one
    result is produced per input string, so the output stays aligned with the
    input. When several dictionary values occur in the same prefix (for
    example 'communicable' inside 'non-communicable'), the key of the longest
    value is used; ties go to the earlier dictionary entry.

    Args:
        dictionary: Mapping of result key -> substring to look for.
        long_strings: Iterable of strings. A single ``str`` is treated as a
            one-element list rather than being iterated character by character.
        default: Value recorded when no dictionary value is found. Defaults to
            3, the fallback that was previously hard-coded.
            NOTE(review): 3 is also the code for 'maternal' in
            ``score_to_label``; consider passing a sentinel such as ``None``.

    Returns:
        A list with one entry per input string.
    """
    if isinstance(long_strings, str):
        long_strings = [long_strings]

    result_list = []
    for long_string in long_strings:
        # Extract the first 30 characters from the string
        short_string = long_string[:30]

        # Every (key, value) pair whose value occurs in the prefix
        candidates = [
            (key, value) for key, value in dictionary.items() if value in short_string
        ]

        if candidates:
            # Prefer the most specific (longest) matching value.
            best_key, _ = max(candidates, key=lambda pair: len(pair[1]))
            result_list.append(best_key)
        else:
            result_list.append(default)

    return result_list


In [None]:
# find_exact_match takes a collection of response strings as its second
# argument, so the single response is passed inside a list.
# NOTE(review): `cod_dict` is not defined in the visible cells of this
# notebook — confirm where it comes from.
find_exact_match(cod_dict, ['aids'])

In [None]:
# Write the collected predictions to a CSV file without the index column.
# NOTE(review): `predictions_llama_up` is not assigned in the visible cells of
# this notebook (only `predictions_llama` is) — confirm where it is created.
pd.Series(predictions_llama_up).to_csv('text_predictions_llama2_up.csv', index=False)

In [None]:
def fuzzymatch(text):
    '''
    takes in text that needs to be matched
    returns constrained label from dict

    NOTE(review): this function is unfinished as written. It computes the
    first 30 characters of ``text`` and defines a fuzzy-matching helper, but
    never calls that helper and has no return statement, so it always returns
    None. The helper also needs a dictionary that is not supplied here.
    '''
    
    # Helper: keep only the first 30 characters of a string.
    def get_first_30_characters(input_string):
        return input_string[:30]
    
    # Computed but not used anywhere below.
    first = get_first_30_characters(text)
    
    # Helper: return dictionary[best key] when the best match scores >= 80, else None.
    # NOTE(review): `process` is not imported in the visible import cell —
    # presumably a fuzzy-matching library module; confirm and import it.
    def fuzzy_match_and_get_value(input_string, dictionary):
        # Get the best match and its score
        match, score = process.extractOne(input_string, dictionary.keys())

        # You can adjust the threshold for the fuzzy matching score
        # For example, consider matches with a score of at least 80
        if score >= 80:
            return dictionary[match]
        else:
            return None  # No satisfactory match found
    
    
    
    

In [None]:
# Example narratives keyed by a cause-of-death label; the prompt experiment
# cells below index this dict by 'Sepsis', 'Fires' and 'Road Traffic'.
labeled_data = {'Sepsis': 'According to respondent child had severe pain in back from last 15 days which was unbearable. Doctor told that may be child got tumor…child received treatment for few days and also got relief for few days but again child had the same condition.Then child was taken to Lucknow where after all treatment nothing could be diagnosed but the pain was increasing day by day. Child received treatment and got some relief but suddenly died.',
                'Fires': 'When my son was playing with a kite, its thread was caught up on an electric pole. He climbed the electric pole to take the kite but he got the electric shock. Then immediately we took him to the Siddipet hospital. They told us to take him to the Gandhi hospital. Then we went to the Gandhi hospital. As he was under the treatment there, he died.',
                'Road Traffic': 'My nice was studying in a hostel at Mulugu. One day she was met an accident with a car when she was crossing the road at the school. We have admitted in the Gandhi hospital, she died while the treatment was going on.'}

In [None]:
# Candidate causes of death that prompt2, prompt3 and prompt4 list as the
# allowed answers (joined with ', ').
common_causes = [
    'pneumonia', 'diarrhea', 'malaria',
    'road traffic', 'drowning', 'cardiovascular disease',
    'fires', 'meningitis', 'venomous animal',
    'falls', 'encephalitis', 'sepsis',
    'measles', 'aids', 'tuberculosis',
]

In [None]:
prompt1 = 'Given a text narrative about a death, attribute the most likely cause of death. Respond only with the cause of death, or "other" if you are not sure. Narrative: '

In [None]:
prompt1

In [None]:
# Variant 2: spell out the allowed output classes, still offering "other".
prompt2 = (
    'Given a text narrative about a death, attribute the most likely cause of death from this list: '
    '{causes}. Respond only with the cause of death from the list, or "other" if you are not sure. Narrative: '
).format(causes=', '.join(common_causes))

In [None]:
prompt2

In [None]:
# Variant 3: spell out the allowed output classes and drop the "other" escape hatch.
prompt3 = (
    'Given a text narrative about a death, attribute the most likely cause of death from this list: '
    '{causes}. Respond only with the one word cause of death from the list. Narrative: '
).format(causes=', '.join(common_causes))

In [None]:
prompt3

In [None]:
# Variant 4: state outright that the answer has to be one of the listed options.
prompt4 = (
    'Given a text narrative about a death, attribute the most likely cause of death from this list: '
    '{causes}. Your response must match exactly to one of the options from this list. Narrative: '
).format(causes=', '.join(common_causes))

In [None]:
prompt4

## Few shot fine tuning

## Run models and prompt 1

In [None]:
# NOTE(review): `query` exists only as commented-out code in the query-function
# cell above, and `medium` is commented out in the model-loading cell, so this
# and the following `query(...)` cells depend on definitions that are not
# active in the visible notebook — restore them or switch to `query_tostring`.
query(small, prompt1 + labeled_data['Road Traffic'])

In [None]:
query(medium, prompt1 + labeled_data['Road Traffic'])

In [None]:
query(big, prompt1 + labeled_data['Road Traffic'])

In [None]:
query(small, prompt1 + labeled_data['Sepsis'])

In [None]:
query(small, prompt1 + labeled_data['Fires'])

## Run models and prompt 2

In [None]:
query(small, prompt2 + labeled_data['Road Traffic'])

In [None]:
query(medium, prompt2 + labeled_data['Road Traffic'])

In [None]:
query(big, prompt2 + labeled_data['Road Traffic'])

In [None]:
query(small, prompt2 + labeled_data['Sepsis'])

In [None]:
query(small, prompt2 + labeled_data['Fires'])

## Run models and prompt 3

In [None]:
query(small, prompt3 + labeled_data['Road Traffic'])

In [None]:
query(medium, prompt3 + labeled_data['Road Traffic'])

In [None]:
query(big, prompt3 + labeled_data['Road Traffic'])

In [None]:
query(small, prompt3 + labeled_data['Sepsis'])

In [None]:
query(small, prompt3 + labeled_data['Fires'])

In [None]:
query(big, prompt3 + labeled_data['Sepsis'])

## Run models and prompt 4

In [None]:
query(small, prompt4 + labeled_data['Road Traffic'])

In [None]:
query(medium, prompt4 + labeled_data['Road Traffic'])

In [None]:
query(big, prompt4 + labeled_data['Road Traffic'])

In [None]:
query(small, prompt4 + labeled_data['Sepsis'])

In [None]:
query(small, prompt4 + labeled_data['Fires'])

In [None]:
query(big, prompt4 + labeled_data['Sepsis'])

### Small model prompt - need to return only one or two words.

In [None]:
query(small, prompt4 + labeled_data['Sepsis'] + ' LIMIT YOUR RESPONSE TO ONE OR TWO WORDS ONLY.')

In [None]:
query(small, prompt4 + labeled_data['Fires'] + ' LIMIT YOUR RESPONSE TO ONE OR TWO WORDS ONLY.')

In [None]:
query(small, prompt4 + labeled_data['Road Traffic'] + ' LIMIT YOUR RESPONSE TO ONE OR TWO WORDS ONLY.')

### Big model prompting

In [None]:
query(big, prompt4 + labeled_data['Road Traffic'] + ' LIMIT YOUR RESPONSE TO ONE OR TWO WORDS ONLY, EXACTLY AS THEY APPEAR IN THE LIST.')

In [None]:
query(big, prompt4 + labeled_data['Sepsis'] + ' LIMIT YOUR RESPONSE TO ONE OR TWO WORDS ONLY, EXACTLY AS THEY APPEAR IN THE LIST.')

In [None]:
query(big, prompt4 + labeled_data['Fires'] + ' LIMIT YOUR RESPONSE TO ONE OR TWO WORDS ONLY, EXACTLY AS THEY APPEAR IN THE LIST.')

### Update query function to return only the text response instead of printing Q, A, and token info.