In [None]:
import yaml
from pprint import pprint
import pandas as pd
from sklearn.model_selection import train_test_split
from med_assist.config import set_project_wd

# NOTE(review): presumably switches the working directory to the project root so the relative
# "resources/..." paths used in later cells resolve — confirm in med_assist.config.
set_project_wd()

In [None]:
# prepare data

df_raw = pd.read_csv("resources/med_assist_training_dataset_0301-1858.csv")

# Work on copies: later cells assign into `outputs` with `.loc[...] = ...`, while
# df_raw['output_output'] must stay untouched because it is the "before" side of
# `is_changed` and the DPO "rejected" column.
inputs = df_raw['input_prompt'].copy()
outputs = df_raw['output_output'].copy()

# Assumes prompts embed the context as "Context: [[...]]" and the question as "Question: ...\n".
# NOTE(review): both patterns are single-line (`.` does not match "\n"), so a multi-line context,
# or a question not followed by a newline, yields NaN.
contexts = inputs.str.extract(r"((?<=Context: \[\[).*(?=\]\]))")[0]
questions = inputs.str.extract(r"((?<=Question: ).*(?=\n))")[0]

In [None]:
# Disclaimers to be removed or replaced, as {regex pattern: replacement}.
# The patterns are applied one after another in insertion order, so order matters:
# each substitution changes the text that the later patterns see.

DISCLAIMERS = {
    # boilerplate sentences that run to the end of the answer
    r"I hope this information (helps|is helpful)\b.{0,120}$": "",
    r"Please note that .* based .* context\b.{0,200}$": "",
    r"(Additionally|Therefore|However|It is\b.*\b(recommended|important|best|essential)|As with any|While|If you\b.*\bconcerns|Remember)\b.*\bconsult\b.*\b(healthcare|medical) professional\b.{0,200}$": "",
    # "based on the (provided) context / information provided" -> "based on the study"
    r"(?<=\b[Bb]ased )on the information provided( in the [a-z]+\b)?,?": "on the study,",
    r"(?<=\b[Bb]ased )on the( provided)? context( provided)?,?": "on the study,",
    # meta-commentary starting on a new line: the newline plus up to 200 chars after the keyword
    r"\n.*(harmful|unethical|illegal)\b.*\bcontent\b.{0,200}": "",
    r"\n.*[Bb]ased\b.*context\b.{0,200}": "",
    r"\n.*constitute medical advice\b.{0,200}": "",
    r"\n.*no prior knowledge\b.{0,200}": "",
    r"\n.*no additional information\b.{0,200}": "",
    r"\n.*(socially unbiased|positive in nature)\b.{0,200}": "",
    r"\n.*\brelevant\b.*\bquestion\b.{0,200}": "",
    r"\n.*\bincludes.*\bdetails\b.*\bcontext\b.{0,200}": "",
    # "Source: Context" citations. This must run BEFORE the generic "context" -> "study" rewrite
    # below; otherwise a lowercase "source: context" has already become "source: study" and
    # is never removed.
    r"\b[Ss]ource: [Cc]ontext\b": "",
    # any remaining mention of the context is called "study" instead
    r"\bcontext\b": "study",
    # leftover fragments at the very end of an answer
    r"\bno information available\b.{0,3}$": "",
    r"\bNote\b.{0,3}$": "",
}

# Inspect the distinct text matched by one DISCLAIMERS pattern before it is applied.
# `i` selects the pattern (negative values count from the end of the dict).
i = -1

disclaimer_pattern = list(DISCLAIMERS)[i]

# extract() yields NaN for outputs without a match (and for missing outputs), so dropna()
# keeps exactly the matching rows without masking on a possibly-NaN boolean series.
distinct_disclaimers = (
    outputs
    .str.extract(f"({disclaimer_pattern})")[0]
    .dropna()
    .drop_duplicates()
)

# A dedicated loop variable keeps `i` pointing at the selected pattern.
for matched_text in distinct_disclaimers:
    pprint(matched_text)

In [None]:
# Apply every DISCLAIMERS substitution to the outputs, following the dict's insertion order.
for pattern in DISCLAIMERS:
    outputs = outputs.str.replace(pattern, DISCLAIMERS[pattern], regex=True)

In [None]:
# Output ids to be replaced (as a whole) with a "no information available" answer.
# The id list was prepared by hand after checking suspicious phrases.

NO_INFO_ANSWER = "No information available on this topic. Please try to rephrase the question."

with open("resources/no_info_outputs.yaml", "r", encoding="utf-8") as file:
    # safe_load returns None for an empty file; fall back to an empty mapping.
    no_info_config = yaml.safe_load(file) or {}

# Fall back to an empty list when the key is missing or empty: later cells test membership
# with `isin(no_info_outputs)` and `idx in no_info_outputs`, which raise on None.
no_info_outputs = no_info_config.get('no_info_outputs') or []

outputs.loc[no_info_outputs] = NO_INFO_ANSWER

In [None]:
# Suspicious answers to check by hand: phrases hinting that the model found no usable information.
# Rows already listed in no_info_outputs are skipped because they have been reviewed.

TO_CHECK_PHRASES = [
    r"^.{0,100}\bno\b.{0,20}\binformation\b.{0,20}\bavailable\b",
    r"\bno\b.*\binformation\b.*\bavailable\b",
    r"source: context",
]

i = 0  # index into TO_CHECK_PHRASES; change it to review a different phrase

to_check_idx = (
    outputs
    .str.replace(r"\n", " ", regex=True)  # flatten lines so ".*" can span the whole answer
    .str.contains(TO_CHECK_PHRASES[i], regex=True, case=False)
    .loc[lambda hits: hits == True]  # noqa: E712 -- also discards NaN results for missing outputs
    .loc[lambda hits: ~hits.index.isin(no_info_outputs)]
    .index
)

for idx in to_check_idx:
    pprint(idx)
    pprint(questions[idx])
    pprint(outputs[idx])

In [None]:
# Standardise list indicators: outputs that use "* " bullets (and have fewer than three
# "N. " numbered items already) get their bullets rewritten as a numbered list "1. ", "2. ", ...


def _number_bullets(text: str) -> str:
    """Replace every "* " marker in ``text`` with "1. ", "2. ", ... in reading order.

    str.split finds the same non-overlapping, left-to-right occurrences of "* " that a regex
    search does. The inserted "k. " never contains "* ", so no new markers are created.
    NOTE(review): any "* " is converted, including mid-sentence ones (e.g. the closing "**"
    of markdown bold followed by a space).
    """
    head, *items = text.split("* ")
    return head + "".join(f"{k}. {item}" for k, item in enumerate(items, start=1))


bullet_counts = outputs.str.count(r"[\*] ")
has_bullet = (
    (outputs.str.count(r"[0-9]\. ") < 3)
    & (bullet_counts >= 1)
)

# One pass per selected row, instead of re-scanning the whole series once per bullet.
outputs.loc[has_bullet] = outputs.loc[has_bullet].map(_number_bullets)

# Number of bullets in the longest converted list (0 when nothing was converted).
num_replace = int(bullet_counts[has_bullet].max()) if has_bullet.any() else 0
print(f"Numbered bullets in {int(has_bullet.sum())} outputs (longest list: {num_replace} items)")

In [None]:
# cleanup: normalise how every answer ends

# drop trailing spaces and newlines
outputs = outputs.str.rstrip("\n ")

# Collapse any trailing run of punctuation / leftover symbols into a single ".".
# Word characters (letters in any script, digits), "%", ")" and "]" count as content and are kept.
# A letters-only class such as [^A-Za-z] would turn "reduced by 30%" into "reduced by." and
# "in 2019" into "in.".
outputs = outputs.str.replace(r"[^\w%)\]]+$", ".", regex=True)

In [None]:
# Adjust the prompt instruction so it asks for a *numbered* list, matching the "1. ", "2. "
# list format the outputs were converted to.
# The pattern is a plain literal (no regex metacharacters), so regex=False is spelled out.

inputs = inputs.str.replace(
    "Always return a concise list of facts regarding the question based on the provided context",
    "Always return a concise numbered list of facts regarding the question based on the provided context",
    regex=False,
)

In [None]:
# Browsing outputs: (re)start an iterator over row positions 0..n-1.
# Re-run this cell to go back to the first output; the next cell shows one output per run.

i_iter = iter(range(outputs.size))

In [None]:
# Show the next output with its question; run this cell repeatedly to page through the data.
# Using a default for next() avoids a StopIteration traceback once the iterator is exhausted.
i = next(i_iter, None)

if i is None:
    print("No more outputs to browse - re-run the previous cell to start over.")
else:
    pprint(i)
    pprint(questions[i])
    pprint(outputs[i])

In [None]:
# Flags for stratified sampling: each row's stratum is its (has_no_info, is_changed) pair.

# Vectorised membership test (the list lookup per row was O(rows * ids)).
has_no_info = pd.Series(
    outputs.index.isin(no_info_outputs),
    index=outputs.index,
    name="has_no_info",
)

# A row counts as changed unless the cleaned and raw outputs are equal or both missing.
# A plain `!=` reports NaN-vs-NaN rows as changed, which would push NaN/NaN pairs into the DPO set.
is_changed = (
    ~(
        (outputs == df_raw['output_output'])
        | (outputs.isna() & df_raw['output_output'].isna())
    )
).rename("is_changed")

stratas = pd.concat([has_no_info, is_changed], axis=1)

In [None]:
# Split into train and test (written out as validation) sets for SFT.
# Stratifying on `stratas` keeps the has_no_info / is_changed mix the same in both splits.

train_size = 2000  # number of training rows; the remaining rows form the test split

df_output_sft = pd.DataFrame(
    list(zip(inputs, outputs)),
    columns=["prompt", "output"],
)

df_train_sft, df_test_sft = train_test_split(
    df_output_sft,
    shuffle=True,
    random_state=420,  # fixed seed so the split is reproducible
    stratify=stratas,
    train_size=train_size,
    test_size=len(questions) - train_size,
)

In [None]:
# Write the human-reviewed SFT splits; the index is not written, so each CSV holds only prompt/output.

df_train_sft.to_csv(path_or_buf="resources/training_dataset_sft.csv", index=False)
df_test_sft.to_csv(path_or_buf="resources/validation_dataset_sft.csv", index=False)

In [None]:
# Reshape for DPO: keep only rows whose output was edited, pairing the cleaned answer ("chosen")
# with the raw answer from the CSV ("rejected").
# NOTE(review): this takes every edited row, including rows that landed in the SFT test split.
# If the DPO model is evaluated on validation_dataset_sft.csv, consider restricting to train rows.

df_output_dpo = pd.DataFrame(
    list(zip(inputs, outputs, df_raw['output_output'])),
    columns=["prompt", "chosen", "rejected"],
)[is_changed]

In [None]:
# Write the DPO pairs. index=False matches the SFT files, so the CSV holds only
# prompt/chosen/rejected and no unnamed index column.
df_output_dpo.to_csv("resources/training_dataset_dpo.csv", index=False)