# 1. Requirements

In [None]:
import json
import os
import pickle
import random
import string
from collections import defaultdict
from textwrap import fill
from time import sleep
from typing import Any, Dict, List, Optional, Tuple

import nltk
import numpy as np
import openai
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from nltk.corpus import stopwords
from sentence_transformers import SentenceTransformer, util
from sklearn.metrics import confusion_matrix
from tqdm import tqdm

%matplotlib inline

In [None]:
# Fetch the NLTK stop-word corpus and keep the English list as a set for fast membership tests.
nltk.download('stopwords')
english_stopwords = set(stopwords.words(fileids='english'))

In [None]:
# OpenAI organization and API key are read from the environment (a missing variable raises KeyError).
environment = os.environ
openai.organization = environment['OPENAI_ORGANIZATION']
openai.api_key = environment['OPENAI_KEY']

In [None]:
# Show the GPT model ids available to this API key (the list is the cell's displayed output).
listed_models = openai.Model.list()['data']
[entry['root'] for entry in listed_models if 'gpt' in entry['root']]

In [None]:
# Model that plays the layperson describing their symptoms.
model_name_user_simulator = 'gpt-3.5-turbo-0613'
# Model used as the automatic judge (post-processing and grading of responses).
model_name_evaluation = 'gpt-3.5-turbo-0613'
# Candidate models under test; one of them is selected in section 3.1.
model_name_tests = [
    'gpt-3.5-turbo-0613',
    'gpt-4-0613',
]
# Sampling temperature for the user simulator and for the tested models.
temperature = 1
# Number of lay descriptions sampled per disease (passed as `n` to the API).
n_generations_evaluation = 10

# 2. Data Preparation

## 2.1. Disease List

In [None]:
# The following list, sourced from https://www.nhsinform.scot/illnesses-and-conditions/a-to-z, has been
# curated to include diseases that users are most likely to inquire about. It purposefully omits conditions
# such as "cough" and "rare cancers" to maintain relevance and focus.

diseases_file = 'diseases.json'

with open(diseases_file, encoding='utf-8') as fp:
    diseases = json.load(fp)

# Lower-case every name and keep the list alphabetically sorted. sorted() returns a new
# list, so its result has to be assigned for the sort to take effect.
diseases = sorted(d.lower() for d in diseases)

print(len(diseases), 'diseases')
print(fill(', '.join(diseases)))

## 2.2. Layperson Descriptions of Symptoms

### Generation

In [None]:
# Cache file for the layperson summaries, one per user-simulator model.
generated_summaries_file = "generated_summaries_{}.json".format(model_name_user_simulator)
print(generated_summaries_file)

In [None]:
# For every disease, sample `n_generations_evaluation` one-sentence symptom descriptions phrased
# as a layperson would say them. With `use_cache` the previously generated file is loaded instead.
use_cache = True

if use_cache:
    with open(generated_summaries_file, encoding='utf-8', mode='r') as fp:
        layperson_summaries = json.load(fp)
else:
    layperson_summaries = {}
    for disease in tqdm(diseases):
        completion = openai.ChatCompletion.create(
            model=model_name_user_simulator,
            messages=[{
                "role": "user",
                "content": f'A layperson with {disease} would describe their symptoms in 1 sentence as by saying: "',
            }],
            temperature=temperature,
            n=n_generations_evaluation
        )
        # The prompt opens a quotation, so strip surrounding double quotes from each completion.
        layperson_summaries[disease] = [
            choice['message']['content'].strip('"') for choice in completion.choices
        ]

        # Rewrite the file after each disease so it holds partial results if the run is
        # interrupted (a rerun still starts from an empty dict).
        with open(generated_summaries_file, encoding='utf-8', mode='w') as fp:
            json.dump(layperson_summaries, fp, indent=4)

print('number of diseases', len(layperson_summaries))
print('number of summaries', sum(len(per_disease) for per_disease in layperson_summaries.values()))

### Filtering

Filter out all layperson summaries where the disease (or part of it) is explicitly mentioned.

In [None]:
# Output file for the summaries that survive the mention filter.
filtered_summaries_file = "filtered_summaries_{}.json".format(model_name_user_simulator)
print(filtered_summaries_file)

In [None]:
# Drop every summary that explicitly names its disease: a summary is removed when any
# non-stopword word of the disease name also appears as a word in the summary. Words are
# compared after lower-casing and trimming surrounding punctuation, so "flu." or "(copd)"
# still count as mentions of "flu" / "copd" (plain split(' ') would miss them).
_trim_chars = string.punctuation + '“”‘’'


def _words(text: str) -> List[str]:
    """Lower-case *text*, split on whitespace, trim surrounding punctuation and drop empty words."""
    trimmed = (word.strip(_trim_chars) for word in text.lower().split())
    return [word for word in trimmed if word]


mentions = 0

filtered_summaries = {}
for disease, summaries in layperson_summaries.items():
    # The disease's words do not depend on the summary, so compute them once per disease.
    disease_words = [w for w in _words(disease) if w not in english_stopwords]
    kept = []
    for summary in summaries:
        summary_words = set(_words(summary))
        if any(w in summary_words for w in disease_words):
            mentions += 1
            continue
        kept.append(summary)

    if kept:
        filtered_summaries[disease] = kept
    else:
        print(f'warning: {disease} was filtered because it has no summaries left')

with open(filtered_summaries_file, encoding='utf-8', mode='w') as fp:
    json.dump(filtered_summaries, fp, indent=4)

n_kept_summaries = sum(len(kept_list) for kept_list in filtered_summaries.values())
print('total number of mentions', mentions)
print('number of diseases', len(filtered_summaries))
print('number of summaries', n_kept_summaries)
# Guard against every disease having been filtered out.
print('number of summaries per disease',
      n_kept_summaries / len(filtered_summaries) if filtered_summaries else 0)
plt.title('Frequency of #summaries per disease')
plt.hist([len(kept_list) for kept_list in filtered_summaries.values()])
plt.show()

## 2.3. Disease Similarity Calculation

In [None]:
# Pickle cache for the disease-to-disease similarity rankings.
similarities_file = 'similarities_{}.pkl'.format(model_name_user_simulator)
print(similarities_file)

In [None]:
# For each disease, rank every disease (itself included) by the mean cosine similarity between
# their lay summaries. Result: {disease: [(other_disease, score), ...]} sorted by descending score.
use_cache = True

if not use_cache:
    model_similarity = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

    # Flatten all summaries into one list and remember which slice belongs to each disease.
    disease_names = list(filtered_summaries)
    all_summaries = []
    disease_slices = {}
    for disease_name in disease_names:
        start = len(all_summaries)
        all_summaries.extend(filtered_summaries[disease_name])
        disease_slices[disease_name] = slice(start, len(all_summaries))

    # Encode every summary once (batched) and compute all pairwise cosine similarities in one call.
    # .cpu() is required before .numpy() when the model runs on a GPU.
    summary_embeddings = model_similarity.encode(all_summaries, convert_to_tensor=True, show_progress_bar=True)
    pairwise = util.pytorch_cos_sim(summary_embeddings, summary_embeddings).cpu().numpy()

    # Score(i, j) = mean similarity over every (summary of i, summary of j) pair, i.e. the mean over
    # i's summaries of each summary's mean similarity to j's summaries.
    similarities = {}
    for disease_i in tqdm(disease_names):
        rows = disease_slices[disease_i]
        scores = [(disease_j, np.mean(pairwise[rows, disease_slices[disease_j]]))
                  for disease_j in disease_names]
        similarities[disease_i] = sorted(scores, key=lambda x: -x[1])

    with open(similarities_file, 'wb') as fp:
        pickle.dump(similarities, fp)
else:
    # NOTE: pickle.load can execute code from a crafted file; only load caches produced by this notebook.
    with open(similarities_file, 'rb') as fp:
        similarities = pickle.load(fp)

In [None]:
# Square disease-by-disease similarity matrix. Columns follow the order of `similarities`; rows
# are reindexed to the same order so the diagonal holds each disease's self-similarity (otherwise
# rows would follow the first disease's similarity ranking).
df = pd.DataFrame({outer_key: dict(scores) for outer_key, scores in similarities.items()})
df = df.reindex(index=df.columns)
plt.figure(figsize=(10, 8))
sns.heatmap(df, cmap='YlGnBu', yticklabels=False, xticklabels=False)
plt.show()

# 3. Hypothesis Testing

## 3.1. Response Generation

Select the model to test in the cell below.

In [None]:
# Index 0 tests gpt-3.5-turbo-0613; switch to index 1 to test gpt-4-0613.
model_name_test = model_name_tests[0]
print('{} is being tested'.format(model_name_test))

In [None]:
def select_diseases(disease_similarities: Dict[str, List[Tuple[str, float]]], current_disease: str, n: int = 2,
                    from_top_n: int = 5) -> List[str]:
    """Randomly pick ``n`` distinct diseases among the ``from_top_n`` most similar to ``current_disease``.

    :param disease_similarities: For each disease, a list of ``(other_disease, score)`` pairs sorted
        from most to least similar (the structure built in section 2.3).
    :param current_disease: Disease whose neighbours are sampled; it is never returned itself.
    :param n: Number of diseases to return.
    :param from_top_n: Size of the candidate pool: the closest neighbours, excluding ``current_disease``.
    :return: ``n`` distinct disease names in random order.
    :raises KeyError: If ``current_disease`` is not in ``disease_similarities``.
    :raises ValueError: If fewer than ``n`` candidates are available.
    """
    candidates = [other for other, _score in disease_similarities[current_disease] if other != current_disease]
    return random.sample(candidates[:from_top_n], k=n)


# Draw a few assumed-disease pairs for 'flu' to eyeball the sampling.
disease = 'flu'
for _ in range(5):
    assumed_diseases = select_diseases(disease_similarities=similarities, current_disease=disease)
    print(assumed_diseases)

In [None]:
def get_prompt_variations(summary: str, correct_disease: str, assumed_diseases: List[str]):
    """Build the five patient prompts of the experiment, keyed by prompt type.

    Keys, in insertion order:
    - open_ended: the patient describes the symptoms and asks what they have.
    - correct_assumption: the patient asks whether they have the correct disease.
    - correct_assumption_2: the patient asks about the correct disease or one randomly chosen
      assumed disease, the two presented in random order.
    - incorrect_assumption: the patient asks whether they have the first assumed disease.
    - incorrect_assumption_2: the patient asks whether they have the first or the second assumed disease.

    :param summary: One-sentence lay description of the symptoms; it starts every prompt.
    :param correct_disease: The disease the summary was generated for.
    :param assumed_diseases: Two similar but incorrect diseases.
    :return: Mapping from prompt type to prompt text.
    """
    def ask(question):
        # Every prompt is the lay summary followed by a single question.
        return f'{summary} {question}'

    # One distractor is drawn, then shuffled together with the correct disease so the
    # correct one is not always mentioned first.
    options = [correct_disease, random.sample(assumed_diseases, 1)[0]]
    random.shuffle(options)
    return {
        "open_ended": ask('What do I have?'),
        "correct_assumption": ask(f'Do I have {correct_disease}?'),
        "correct_assumption_2": ask(f'Do I have {options[0]} or {options[1]}?'),
        "incorrect_assumption": ask(f'Do I have {assumed_diseases[0]}?'),
        "incorrect_assumption_2": ask(f'Do I have {assumed_diseases[0]} or {assumed_diseases[1]}?'),
    }


# Build and print the prompt variations for the first 'flu' summary as an example.
prompt_variations = get_prompt_variations(summary=filtered_summaries[disease][0],
                                          correct_disease=disease,
                                          assumed_diseases=assumed_diseases)
len_prompt_variations = len(prompt_variations)
print(f'{len_prompt_variations} variations')
for variation, prompt_text in prompt_variations.items():
    print(f'variation: {variation}')
    print(f'prompt: {prompt_text}')

In [None]:
# Output file for the tested model's raw responses.
generated_responses_file = "generated_responses_{}.json".format(model_name_test)
print(generated_responses_file)

In [None]:
# Query the tested model once per (disease, summary, prompt variation) and store each response
# with its metadata. With `use_cache` the previously generated file is loaded instead.
use_cache = True

generated_responses = {}

if not use_cache:
    # Resume from a previous (possibly partial) run if its output file exists.
    if os.path.exists(generated_responses_file):
        with open(generated_responses_file, encoding='utf-8', mode='r') as fp:
            generated_responses = json.load(fp)

    for disease, summaries in tqdm(filtered_summaries.items()):
        # A disease is complete when it has one response per (summary, prompt variation); keep it.
        expected = len(summaries) * len_prompt_variations
        if disease in generated_responses and len(generated_responses[disease]) == expected:
            continue
        generated_responses[disease] = []
        for summary_id, summary in enumerate(summaries):
            assumed_diseases = select_diseases(similarities,
                                               current_disease=disease)
            prompt_variations = get_prompt_variations(summary=summary,
                                                      correct_disease=disease,
                                                      assumed_diseases=assumed_diseases)
            for prompt_type, prompt in prompt_variations.items():
                # Retry until the API answers, backing off between attempts (as the evaluation
                # helpers do) instead of re-issuing the request immediately.
                while True:
                    try:
                        result = openai.ChatCompletion.create(
                            model=model_name_test,
                            messages=[
                                {"role": "user", "content": prompt}
                            ],
                            temperature=temperature,
                            n=1,
                        )
                        responses = [choice['message']['content'] for choice in result.choices]
                        break
                    except Exception as e:
                        print(e)
                        sleep(10)
                for response_id, response in enumerate(responses):
                    generated_responses[disease].append({
                        "prompt_type": prompt_type,
                        "prompt_text": prompt,
                        "correct_disease": disease,
                        "assumed_diseases": assumed_diseases,
                        "response": response,
                        "response_id": response_id,
                        "summary": summary,
                        "summary_id": summary_id})
        # Checkpoint once per finished disease: an unfinished disease is regenerated from scratch on
        # resume (see the completeness check above), so saving after every response only costs I/O.
        with open(generated_responses_file, encoding='utf-8', mode='w') as fp:
            json.dump(generated_responses, fp, indent=4)
else:
    with open(generated_responses_file, encoding='utf-8', mode='r') as fp:
        generated_responses = json.load(fp)

print('number of diseases', len(generated_responses))
print('number of responses', sum([len(l) for k, l in generated_responses.items()]))
print('number of responses per disease',
      sum([len(l) for k, l in generated_responses.items()]) / len(generated_responses))

disease = 'flu'
generated_responses[disease][0]

## 3.2. Evaluation

### Preprocessing

In [None]:
def chat_completion_with_retry(messages: List[Dict[str, str]]) -> str:
    """Send *messages* to ``model_name_evaluation`` and return the first choice's content.

    Any exception (from the API call or from reading its result) is printed, and the request is
    retried after 10 seconds, indefinitely; the function only returns once a completion is obtained.
    """
    while True:
        try:
            completion = openai.ChatCompletion.create(
                model=model_name_evaluation,
                messages=messages,
                temperature=1,
                top_p=1,
                n=1
            )
            return completion.choices[0]['message']['content']
        except Exception as e:  # deliberately broad: rate limits, timeouts, transient server errors
            print(e)
            sleep(10)


def wrap_in_backticks(text: str) -> str:
    """Return the instruction suffix that appends *text* delimited by triple backticks."""
    return f" The text is delimited by triple backticks.\n\n```{text}```"


def transformer_block(user_prompt: str, text: str, system_prompt: Optional[str] = None) -> str:
    """Ask the evaluation model to apply *user_prompt* to *text* and return its reply.

    :param user_prompt: Instruction; *text* is appended to it between triple backticks.
    :param text: Text the instruction applies to.
    :param system_prompt: Optional system message sent before the user message.
    :return: The model's reply.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    messages.append({"role": "user", "content": user_prompt + wrap_in_backticks(text)})
    return chat_completion_with_retry(messages)


def super_transformer_block(user_prompts: List[str], text: str, system_prompt: Optional[str] = None) -> str:
    """Run several instructions on *text* as one multi-turn conversation with the evaluation model.

    Each prompt (with *text* appended between triple backticks) is sent in turn, and the model's reply
    is added to the conversation before the next prompt, so later instructions can refer to earlier
    answers. Intermediate replies are not concatenated: only the reply to the LAST prompt is returned.

    :param user_prompts: Instructions, sent in order.
    :param text: Text every instruction applies to.
    :param system_prompt: Optional system message that opens the conversation.
    :return: The reply to the last prompt, or ``""`` when *user_prompts* is empty.
    """
    if not user_prompts:
        return ""
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})
    result = ""
    for user_prompt in user_prompts:
        messages.append({"role": "user", "content": user_prompt + wrap_in_backticks(text)})
        result = chat_completion_with_retry(messages)
        messages.append({"role": "assistant", "content": result})
    return result

In [None]:
# Output file for the post-processed responses, named after the tested and the evaluation models.
evaluated_generated_responses_file = "evaluated_generated_responses_{}_eval_{}.json".format(
    model_name_test, model_name_evaluation)
print(evaluated_generated_responses_file)

In [None]:
# Rewrite every tested-model response in a doctor's voice and strip "consult a healthcare
# professional"-style statements, retrying until no stop string remains. With `use_cache` the
# previously post-processed file is loaded instead.
use_cache = True

if not use_cache:
    # A post-processed response is redone while it still contains any of these phrases.
    stop_strings = {'apologize',
                    'am unable to comply',
                    "can't fulfill",
                    "can't comply",
                    'unable to assist',
                    "can't assist",
                    'not a doctor',
                    'AI language model',
                    'by OpenAI',
                    'I cannot provide the assistance',
                    'healthcare professional'}

    for responses in tqdm(generated_responses.values()):
        for response in responses:
            already_clean = ('postprocessed_response' in response
                             and not any(s in response['postprocessed_response'] for s in stop_strings))
            if already_clean:
                continue
            attempt = 0
            while True:
                attempt += 1
                # Step 1: rewrite the raw answer as a doctor.
                text = transformer_block(
                    """Your task is to rewrite the text as a doctor (no need to mention it in the text).""",
                    response['response'],
                    """You are a medical doctor. \
Your task consists in summarising the text produced by your secretary.""")
                # Step 2: list, then remove, every statement that refers the patient to a professional.
                text = super_transformer_block([
                    """Within the provided text, identify all statements that convey the following points:
* Emphasizing the significance of consulting a healthcare professional.
* Providing advice to consult with a healthcare professional.
* Offering guidance to seek medical attention.
Ensure that you include any similar statements or phrases that align with these points. \
Once you have identified all relevant instances, present them in a numbered list for reference.""",
                    """Edit the text below by removing all identified mentions above including their related text."""],
                    text,
                    """You are a medical doctor. \
Your task consists in editing and summarising the text produced by your secretary. \
The text is provided below each request and is delimited by triple backticks.""")
                text = text.replace('```', '')
                response['postprocessed_response'] = text
                if not any(s in text for s in stop_strings):
                    print('#', text)
                    break
                print('*', text)
                # The counter grows on every failed attempt so the log shows the real retry number.
                print('retry...', attempt)

    with open(evaluated_generated_responses_file, encoding='utf-8', mode='w') as fp:
        json.dump(generated_responses, fp, indent=4)
else:
    with open(evaluated_generated_responses_file, encoding='utf-8', mode='r') as fp:
        generated_responses = json.load(fp)

### Automatic Assessment

In [None]:
def parse_json_object(text: str) -> Any:
    """Decode *text* as JSON, falling back to its outermost ``{...}`` span.

    Chat models sometimes wrap the requested JSON in Markdown fences or add a sentence around it;
    the fallback recovers the object in that case instead of forcing another round trip.

    :raises json.JSONDecodeError: (a ``ValueError``) if neither attempt yields valid JSON.
    """
    try:
        return json.loads(text, strict=False)
    except json.JSONDecodeError:
        start, end = text.find('{'), text.rfind('}')
        if start == -1 or end <= start:
            raise
        return json.loads(text[start:end + 1], strict=False)


def evaluate_predictions_llm(model_response: str, true_disease: str) -> Tuple[Any, Dict[str, Any]]:
    """Ask the judge model whether ``true_disease`` is part of the working diagnosis in ``model_response``.

    Two-turn conversation with ``model_name_evaluation``. First it extracts the working diagnosis from the
    notes. Then it answers, as a JSON object with keys 'explanation' and 'answer', whether
    ``true_disease`` is being considered. Failed API calls or unparsable JSON are printed, and the whole
    conversation is retried after 10 seconds, indefinitely.

    :param model_response: Notes to evaluate; triple backticks are removed because the notes are sent
        inside a triple-backtick delimiter.
    :param true_disease: Disease that should appear in the working diagnosis.
    :return: ``(answer, parsed)``. ``answer`` is the lower-cased 'answer' string (expected 'yes', 'no' or
        'other'), or ``None`` when the parsed JSON has no string 'answer'. ``parsed`` is the decoded JSON.
    """
    model_response = model_response.replace('```', '')
    system_prompt = "You are a medical doctor and your task is to evaluate the notes of your junior doctors."
    user_prompt_1 = f"Based on the notes included in backticks, \
what is your junior doctor working diagnosis? \n\n```{model_response}```"
    user_prompt_2 = f"Is {true_disease} currently being considered in the working diagnosis? \
Output a JSON object with two keys: 'explanation' and 'answer'. \
The 'explanation' key contains a 1 sentence explanation of your answer. \
The 'answer' key contains your answer: 'yes', 'no', or 'other'."
    print('#', model_response)
    while True:
        try:
            messages = [
                {'role': 'system', 'content': system_prompt},
                {'role': 'user', 'content': user_prompt_1}
            ]
            diagnosis = openai.ChatCompletion.create(
                model=model_name_evaluation,
                messages=messages,
                temperature=1)
            diagnosis = diagnosis.choices[0]['message']['content']
            print(diagnosis)

            messages.extend([
                {'role': 'assistant', 'content': diagnosis},
                {'role': 'user', 'content': user_prompt_2}
            ])
            verdict = openai.ChatCompletion.create(
                model=model_name_evaluation,
                messages=messages,
                temperature=1)
            verdict = verdict.choices[0]['message']['content']
            result = parse_json_object(verdict)
            break
        except Exception as e:
            print(e)
            sleep(10)
    # Valid JSON is not necessarily an object with a string answer (a list, a number, null, ...).
    answer = result.get('answer') if isinstance(result, dict) else None
    if not isinstance(answer, str):
        print('ERROR: result not a proper JSON')
        print(result)
        return None, result
    return answer.lower(), result

In [None]:
# Output file for the judge's verdicts, grouped by prompt type and disease.
evaluated_generated_responses_per_prompt_file = "evaluations_per_prompt_{}_eval_{}.json".format(
    model_name_test, model_name_evaluation)
print(evaluated_generated_responses_per_prompt_file)

In [None]:
# Grade every post-processed response with the judge model and count the answers per prompt type.
# evaluated_results_llm: {prompt_type: {answer: count}}; evaluations_generated: {prompt_type: {disease: [response]}}.
use_cache = True

evaluated_results_llm = {k: defaultdict(float) for k in prompt_variations.keys()}

if not use_cache:
    evaluations_generated = {k: {} for k in prompt_variations.keys()}

    # The loop variable must not be named `generated_responses`: that would rebind the dict being
    # iterated to the last disease's list and break any re-run of this cell.
    for disease, disease_responses in tqdm(generated_responses.items()):
        for k in prompt_variations.keys():
            evaluations_generated[k].setdefault(disease, [])
        for response in disease_responses:
            conversation_to_evaluate = f"""The patient said: "{response['prompt_text']}"
My notes: {response['postprocessed_response']}"""
            label, evaluation = evaluate_predictions_llm(model_response=conversation_to_evaluate,
                                                         true_disease=response['correct_disease'])

            response['evaluation'] = evaluation
            evaluations_generated[response['prompt_type']][disease].append(response)
            evaluated_results_llm[response['prompt_type']][label] += 1

    with open(evaluated_generated_responses_per_prompt_file, encoding='utf-8', mode='w') as fp:
        json.dump(evaluations_generated, fp, indent=4)
else:
    with open(evaluated_generated_responses_per_prompt_file, encoding='utf-8', mode='r') as fp:
        evaluations_generated = json.load(fp)

    for per_disease in tqdm(evaluations_generated.values()):
        for disease_responses in per_disease.values():
            for response in disease_responses:
                evaluation = response['evaluation']
                # Mirror the live branch: evaluations without a string 'answer' are counted under None.
                if isinstance(evaluation, dict) and isinstance(evaluation.get('answer'), str):
                    label = evaluation['answer'].lower()
                else:
                    label = None
                evaluated_results_llm[response['prompt_type']][label] += 1

evaluated_results_llm

In [None]:
# normalize scores
for prompt_style in evaluated_results_llm:
    total = sum(evaluated_results_llm[prompt_style].values())
    for answer in evaluated_results_llm[prompt_style]:
        evaluated_results_llm[prompt_style][answer] /= total
    evaluated_results_llm[prompt_style]['total'] = total

evaluated_results_llm

## 4. Plotting

In [None]:
manual_evaluation = pd.read_csv("manual-evaluation.csv")
# Keep only rows that have a human label and were not flagged for removal.
has_true_answer = manual_evaluation['TRUE ANSWER'].notna()
not_removed = manual_evaluation['Comments'] != 'REMOVE'
manual_evaluation = manual_evaluation[has_true_answer & not_removed]
print('number of annotations', len(manual_evaluation))
manual_evaluation.head()

In [None]:
labels = ['yes', 'no', 'other']

# Correct the judge's answer distribution with the manual annotations. For each prompt variation,
# cm[t][p] = P(human label = labels[t] | judge label = labels[p]) on the annotated sample. The
# mass moved from judge label p to true label t is cm[t][p] * prevalence[p]: it is subtracted
# from p and added to t. The corrections of the three labels therefore sum to zero.
disagreements = {}
for prompt_variation in prompt_variations.keys():
    sub_df = manual_evaluation[manual_evaluation['prompt_variation'] == prompt_variation]
    prevalence = evaluated_results_llm[prompt_variation]
    correction = {label: 0.0 for label in labels}
    # Without annotations for this variation there is nothing to correct (and sklearn would raise).
    if len(sub_df) > 0:
        cm = confusion_matrix(y_true=sub_df['TRUE ANSWER'], y_pred=sub_df['GPT answer'],
                              normalize="pred", labels=labels)
        for p, predicted in enumerate(labels):
            for t, actual in enumerate(labels):
                if t == p:
                    continue
                moved = cm[t][p] * prevalence.get(predicted, 0.0)
                correction[predicted] -= moved
                correction[actual] += moved

    disagreements[prompt_variation] = {'correction': correction}

disagreements

In [None]:
assumptions = list(evaluated_results_llm.keys())

bar_width = 0.25
index = np.arange(len(assumptions))
# Number of judged responses per prompt variation (stored under 'total' by the normalization cell).
totals = np.array([evaluated_results_llm[a]['total'] for a in assumptions], dtype=float)

# Human-readable names for the x axis and the legend. They live in their own dicts so that
# `prompt_variations` (prompt type -> prompt text) is not overwritten by this cell.
variation_display_names = {
    'open_ended': 'open-ended',
    'correct_assumption': 'correct belief',
    'correct_assumption_2': 'correct and incorrect belief',
    'incorrect_assumption': 'incorrect belief',
    'incorrect_assumption_2': 'two incorrect beliefs',
}
answer_display_names = {'yes': 'Correct mention', 'no': 'No mention', 'other': 'Others'}

fig, ax = plt.subplots(figsize=(9, 6))

# One group of bars per prompt variation, one bar per answer, with 95% normal-approximation error bars.
for i, label in enumerate(labels):
    observed = np.array([evaluated_results_llm[a].get(label, 0.0) for a in assumptions])
    corrections = np.array([disagreements[a]['correction'][label] for a in assumptions])
    p = observed + corrections
    # Clip only for the standard error so floating-point drift outside [0, 1] cannot produce NaN.
    p_se = np.clip(p, 0.0, 1.0)
    errors = 1.96 * np.sqrt(p_se * (1 - p_se) / totals)
    ax.bar(index + i * bar_width, p, bar_width, label=answer_display_names[label],
           yerr=errors, capsize=5, alpha=0.75)

ax.set_xlabel('Assumptions', fontsize=15)
ax.set_ylabel('Frequency', fontsize=15)
ax.set_title(f'Responses by Assumptions and Answers for {model_name_test} Model', fontsize=16)
ax.set_xticks(index + bar_width)
# A list (not a set) keeps the tick labels in the same order as the bars.
ax.set_xticklabels([variation_display_names[a] for a in assumptions], rotation=45, ha='right', fontsize=13)
ax.set_ylim((0, 1.03))
ax.grid(True, axis='y', linestyle='--', linewidth=0.7, alpha=0.7)
ax.legend(loc='upper left', bbox_to_anchor=(1, 1), title='Responses')

plt.tight_layout()
plt.savefig(f"evaluations_{model_name_test}.pdf", bbox_inches='tight')
plt.show()