In [None]:
%pip install Biopython openai "elasticsearch<8" python-dotenv mistralai fireworks-ai sentence_transformers
%pip install --upgrade pandas
%pip install websocket-client wikipedia-api wikipedia
%pip install --upgrade fireworks-ai

In [1]:
import datetime
import json
import os
import pickle
import re
import string
import threading
import traceback
from concurrent.futures import ThreadPoolExecutor, as_completed

import anthropic
import pandas as pd
from dotenv import load_dotenv
from elasticsearch import Elasticsearch
from fireworks.client import Fireworks
from openai import OpenAI


# Module-level API clients used by the LLM helper functions defined further down.
# No credentials are passed explicitly, so each SDK has to locate them on its own.
# NOTE(review): presumably the keys come from environment variables; `load_dotenv`
# is imported above but never called in the visible code - confirm the keys are
# exported before the kernel starts.
client_openai = OpenAI()
client_fireworks = Fireworks()
client_anthropic = anthropic.Anthropic()

## Wiki Retrieval

In [2]:
import wikipediaapi
import re
import pickle
import os

# Pickle file (relative to the working directory) that persists Wikipedia lookups between runs
CACHE_FILE = 'mixtral-wiki-cache.pkl'

def load_cache():
    """Return the object pickled in CACHE_FILE, or an empty dict when the file does not exist.

    NOTE: unpickling can execute arbitrary code, so CACHE_FILE must come from a trusted source.
    """
    if not os.path.exists(CACHE_FILE):
        return {}
    with open(CACHE_FILE, 'rb') as cache_fh:
        return pickle.load(cache_fh)

def save_cache(cache):
    """Pickle `cache` into CACHE_FILE, replacing whatever the file held before."""
    with open(CACHE_FILE, 'wb') as cache_fh:
        cache_fh.write(pickle.dumps(cache))

# In-memory cache, seeded from the pickle file when this cell runs
cache = load_cache()

# update_cache is reached from the worker threads of the ThreadPoolExecutor that runs
# process_question. Without serialization one thread can insert a key while another
# is pickling the dict ("dictionary changed size during iteration") or two threads
# can write the cache file at the same time. The lock makes insert + persist atomic.
_CACHE_LOCK = threading.Lock()

def update_cache(page_name, summary):
    """Store `summary` under the key `page_name` and persist the whole cache to disk.

    Args:
        page_name: Cache key (callers pass the page title plus a "_summary"/"_text" suffix).
        summary: Value to cache; may be None to record that a page does not exist.
    """
    with _CACHE_LOCK:
        cache[page_name] = summary
        save_cache(cache)

def escape_for_json(input_string):
    """Return `input_string` serialized as a JSON literal (for a str: quoted and escaped)."""
    return json.dumps(input_string)


def get_article_summary(page_name: str) -> str:
    """Return the summary of the English Wikipedia page `page_name`, using the cache.

    The result is cached under the key `page_name + "_summary"`. When the page
    does not exist, None is cached and returned.
    """
    cache_key = page_name+"_summary"
    # Serve repeated lookups (including known-missing pages) from the cache
    if cache_key in cache:
        return cache[cache_key]

    # The Wikipedia API client is created with an explicit user agent
    wiki_page = wikipediaapi.Wikipedia(
        language='en',
        user_agent="MySimpleWikiBot/1.0 (https://example.com/)",
    ).page(page_name)

    # A missing page is remembered as None so it is not requested again
    result = wiki_page.summary if wiki_page.exists() else None
    update_cache(cache_key, result)
    return result

def get_article_text(page_name: str) -> str:
    """Return the full text of the English Wikipedia page `page_name`, using the cache.

    The result is cached under the key `page_name + "_text"`. When the page
    does not exist, None is cached and returned.
    """
    cache_key = page_name+"_text"
    # Serve repeated lookups (including known-missing pages) from the cache
    if cache_key in cache:
        return cache[cache_key]

    # The Wikipedia API client is created with an explicit user agent
    wiki_page = wikipediaapi.Wikipedia(
        language='en',
        user_agent="MySimpleWikiBot/1.0 (https://example.com/)",
    ).page(page_name)

    # A missing page is remembered as None so it is not requested again
    result = wiki_page.text if wiki_page.exists() else None
    update_cache(cache_key, result)
    return result

def concatenate_article_summaries(page_names):
    """Join the summaries of the given Wikipedia pages into one string.

    Each available summary is followed by a blank line ("\\n\\n"); pages for which
    get_article_summary returns None are skipped.

    Args:
        page_names: Iterable of Wikipedia page titles.

    Returns:
        The concatenated summaries, or "" when no page yielded a summary.
    """
    # The previous version called escape_for_json() here and threw the result away,
    # so the text was never actually escaped; the dead call is removed and the
    # returned value is unchanged.
    summaries = [get_article_summary(name) for name in page_names]
    return "".join(summary + "\n\n" for summary in summaries if summary is not None)

def concatenate_article_text(page_names):
    """Join the full texts of the given Wikipedia pages into one string.

    Each available text is followed by a blank line ("\\n\\n"); pages for which
    get_article_text returns None are skipped.

    Args:
        page_names: Iterable of Wikipedia page titles.

    Returns:
        The concatenated article texts, or "" when no page yielded any text.
    """
    # The previous version called escape_for_json() here and threw the result away,
    # so the text was never actually escaped; the dead call is removed and the
    # returned value is unchanged.
    texts = [get_article_text(name) for name in page_names]
    return "".join(text + "\n\n" for text in texts if text is not None)

def extract_hash_wrapped_strings(text):
    pattern = r'#([^#]+)#'
    matches = re.findall(pattern, text)
    return matches

def summarize_wiki_context(question: str, wikipedia_context: str, model: str) -> str:
    """Ask an LLM to condense `wikipedia_context` into the parts relevant to `question`.

    The backend is chosen from the model name: names containing "accounts" go to
    Fireworks, names containing "claude" go to Anthropic, everything else to OpenAI.
    The prompt and the raw model reply are printed; the reply is returned unchanged.
    """
    system_text = "You are an AI expert in extracting and summarizing relevant information from Wikipedia articles in the biomedical domain."
    user_text = f"""
    Given the biomedical question: "{question}"
    
    And the following relevant Wikipedia articles:
    {wikipedia_context}
    
    Your task is to extract and summarize the most relevant information from these articles to help answer the question. Follow these steps:
    
    1. Carefully read through the provided Wikipedia articles.
    2. Identify the key information that directly relates to the biomedical question.
    3. Extract the relevant passages, focusing on the most important details and context.
    4. Summarize the extracted information concisely, maintaining the essential meaning and context.
    5. Organize the summarized information in a clear and coherent manner.
    
    Provide your summary below, formatted as follows:
    SUMMARY: [Your concise summary of the relevant information from the Wikipedia articles]
    """

    chat = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]

    print("\nWiki summary prompt:")
    print(chat)

    if "accounts" in model:
        # Fireworks-hosted model
        response = client_fireworks.chat.completions.create(
            model=model,
            messages=chat,
            max_tokens=4096,
            prompt_truncate_len=27000,
            temperature=0.0,
        )
        reply = response.choices[0].message.content
    elif "claude" in model:
        # Anthropic receives the system prompt as a separate argument
        response = client_anthropic.messages.create(
            model=model,
            system=system_text,
            messages=chat[1:],
            max_tokens=4096,
            temperature=0.0,
        )
        reply = response.content[0].text
    else:
        # OpenAI; a seed is passed alongside temperature 0
        response = client_openai.chat.completions.create(
            model=model,
            messages=chat,
            temperature=0.0,
            seed=90128538,
        )
        reply = response.choices[0].message.content

    print("\nWiki summary answer:")
    print(reply)
    return reply

def get_wiki_context(question: str, model) -> str:
    """Build a Wikipedia-based background summary for `question`.

    Steps: ask the LLM for relevant article titles wrapped in '#', extract those
    titles from the reply, fetch and concatenate the article texts, then have
    summarize_wiki_context condense them. Prompts and replies are printed.
    The backend is chosen from the model name ("accounts" -> Fireworks,
    "claude" -> Anthropic, otherwise OpenAI).
    """
    system_text = "You are BioASQ-GPT, an AI expert in question answering, research, and information retrieval in the biomedical domain."
    user_text = f"""
    Given the question "{question}", identify existing Wikipedia articles that offer helpful background information to answer this question. 
    Ensure that the titles listed are of real articles on Wikipedia as of your last training cut-off. Wrap the confirmed article titles in hashtags (e.g., #Article Title#). 
    Provide a step-by-step reasoning for your selections, ensuring relevance to the main components of the question.

    Step 1: Confirm the Existence of Articles
    Before listing any articles, briefly verify their existence by ensuring they are well-known topics generally covered by Wikipedia.

    Step 2: List Relevant Wikipedia Articles
    After confirming, list the articles, wrapping the titles in hashtags and explaining how each article is relevant to the question.
    """
    chat = [
        {"role": "system", "content": system_text},
        {"role": "user", "content": user_text},
    ]
    print("\nWiki articles prompt:")
    print(chat)

    if "accounts" in model:
        # Fireworks-hosted model
        response = client_fireworks.chat.completions.create(
            model=model,
            messages=chat,
            max_tokens=4096,
            prompt_truncate_len=27000,
            temperature=0.0,
        )
        reply = response.choices[0].message.content
    elif "claude" in model:
        # Anthropic receives the system prompt as a separate argument
        response = client_anthropic.messages.create(
            model=model,
            system=system_text,
            messages=chat[1:],
            max_tokens=4096,
            temperature=0.0,
        )
        reply = response.content[0].text
    else:
        # OpenAI; a seed is passed alongside temperature 0
        response = client_openai.chat.completions.create(
            model=model,
            messages=chat,
            temperature=0.0,
            seed=90128538,
        )
        reply = response.choices[0].message.content

    print("\nWiki articles answer:")
    print(reply)

    # Titles proposed by the model -> article texts -> question-focused summary
    titles = extract_hash_wrapped_strings(reply)
    article_texts = concatenate_article_text(titles)
    final_context = summarize_wiki_context(question, article_texts, model)
    print(f"Wiki Context for question: {question}")
    print(final_context)
    return final_context

## Run

In [None]:
def generate_n_shot_examples_extraction(examples, n):
    """Flatten the messages of the first `n` examples into one list, dropping 'system' messages."""
    return [
        turn
        for sample in examples[:n]
        for turn in sample['messages']
        if turn['role'] != 'system'
    ]

def read_jsonl_file(file_path):
    """Read a UTF-8 JSONL file and return its records as a list.

    Args:
        file_path: Path of a file holding one JSON document per line.

    Returns:
        A list with one parsed object per non-blank line, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the file cannot be opened.
    """
    with open(file_path, 'r', encoding='utf-8') as file:
        # Blank or whitespace-only lines (e.g. an extra empty line at the end of the
        # file) carry no record; json.loads would raise on them, so they are skipped.
        return [json.loads(line) for line in file if line.strip()]

# Few-shot example files, one per BioASQ question type
yesno_examples_file = "07_QA_YesNo_11B1-3-4_62.jsonl"
factoid_examples_file = "04_QA_Factoid_11B1-3-4_76.jsonl"
ideal_examples_file = "05_QA_Ideal_11B1-3-4_255.jsonl"
list_examples_file = "06_QA_List_11B1-3-4_54.jsonl"

# Parsed examples; the prompt builders below take their n-shot messages from these lists
yesno_examples = read_jsonl_file(yesno_examples_file)
factoid_examples = read_jsonl_file(factoid_examples_file)
ideal_examples = read_jsonl_file(ideal_examples_file)
list_examples = read_jsonl_file(list_examples_file)

def remove_punctuation_and_lowercase(text):
    """Normalize a model reply to a short yes/no token.

    The text is lowercased, ASCII punctuation is removed, surrounding whitespace
    is stripped, and the first three characters are kept (enough for "yes";
    "no" followed by a space is stripped back to "no").

    Args:
        text: Raw reply from the model.

    Returns:
        At most three lowercase characters without surrounding whitespace.
    """
    cleaned = text.lower().translate(str.maketrans("", "", string.punctuation))
    # Strip BEFORE truncating: previously a reply such as " Yes" or "\nyes" was cut
    # to " ye" first and ended up as "ye". The second strip removes the space that
    # can follow "no" inside the three-character window.
    return cleaned.strip()[:3].strip()


def get_completion(messages, model):
    """Send `messages` to the backend implied by `model` and return the reply text.

    Backend selection is by substring of the model name: "accounts" -> Fireworks,
    "claude" -> Anthropic, anything else -> OpenAI. The raw completion object is
    printed. A leading "ASSISTANT: " in the reply is removed.

    Args:
        messages: Chat messages as dicts with "role" and "content"; for the
            Anthropic branch the first entry is used as the system prompt.
        model: Model identifier.

    Returns:
        The text of the first choice / content block of the completion.
    """
    if "accounts" in model:
        completion = client_fireworks.chat.completions.create(
            model=model,
            messages=messages,
            max_tokens=4096,
            prompt_truncate_len=27000,
            temperature=0.0 # randomness of completion
        )
    elif "claude" in model:
        # Anthropic takes the system prompt as a separate argument. The list is
        # sliced rather than pop(0)-ed so the caller's `messages` is left intact
        # (the old pop silently removed the system message from the caller's list).
        system_message_content = messages[0]['content']
        completion = client_anthropic.messages.create(
            model=model,
            system=system_message_content,
            messages=messages[1:],
            max_tokens=4096,
            temperature=0.0
        )
    else:
        completion = client_openai.chat.completions.create(
            model=model,
            messages=messages,
            temperature=0.0, # randomness of completion
            seed=90128538
        )
    print("\n Completion:")
    print(completion)
    print("\n")
    # Responses exposing `choices` are read via choices[0].message.content;
    # otherwise the text is taken from content[0].text (the Anthropic branch).
    if hasattr(completion, 'choices'):
        completion_text = completion.choices[0].message.content
    else:
        completion_text = completion.content[0].text
    prefix = "ASSISTANT: " # bug from fireworks fine-tuning

    if completion_text.startswith(prefix):
        completion_text = completion_text[len(prefix):]
    return completion_text

def generate_exact_answer(question, snippets, n_shots, wiki_context):
    """Produce the BioASQ "exact answer" for a yesno, factoid or list question.

    The conversation consists of the system prompt, `n_shots` few-shot examples
    of the matching type, and a user message built from `wiki_context`,
    `snippets` and the question body.

    Returns:
        yesno   -> the reply normalized by remove_punctuation_and_lowercase
        factoid -> [[entity], ...] parsed from the reply's JSON "entities" array
        list    -> [[entity], ...] parsed from the reply's JSON "entities" array
        other   -> []
    """
    conversation = [{"role": "system", "content": "You are BioASQ-GPT, an AI expert in question answering, research, and information retrieval in the biomedical domain."}]
    question_type = question["type"]

    if question_type == "yesno":
        conversation += generate_n_shot_examples_extraction(yesno_examples, n_shots)
        conversation.append({"role": "user", "content": f"""
                {wiki_context}\n\n
                {snippets}\n\n
                '{question['body']}'. 
                You *must answer* only with lowercase 'yes' or 'no' even if you are not sure about the answer."""})
        print(conversation)
        reply = get_completion(conversation, model_yesno)
        print("\ngpt response yesno:")
        print(reply)
        return remove_punctuation_and_lowercase(reply)

    if question_type == "factoid":
        conversation += generate_n_shot_examples_extraction(factoid_examples, n_shots)
        conversation.append({"role": "user", "content": f"""
                {wiki_context}\n\n 
                {snippets}\n\n
                 '{question['body']}'. 
                 Answer this question by returning a JSON string array called 'entities of entity names, numbers, or similar short expressions that are an answer to the question, 
                 ordered by decreasing confidence. The array should contain at max 5 elements but can contain less. If you don't know any answer return an empty array. 
                 Return only this array, it must not contain phrases and **must be valid JSON**. Example: {{"entities": ["entity1", "entity2"]}}"""})
        print(conversation)
        reply = get_completion(conversation, model_factoid)
        print("\ngpt response factoid:")
        print(reply)
        # Each entity is wrapped in its own single-element list
        return [[entity] for entity in json.loads(reply)['entities']]

    if question_type == "list":
        conversation += generate_n_shot_examples_extraction(list_examples, n_shots)
        conversation.append({"role": "user", "content": f"""
                {wiki_context}\n\n 
                {snippets}\n\n
                 '{question['body']}'. 
                 Answer this question by only returning a JSON string array called 'entities of entity names, numbers, or similar short expressions that are an answer to the question 
                 (e.g., the most common symptoms of a disease). The returned array will have to contain no more than 100 entries of no more than 100 characters each. If you don't know any answer return an empty array. 
                 Return only this array, it must not contain phrases and **must be valid JSON**. Example: {{"entities": ["entity1", "entity2"]}}"""})
        print(conversation)
        reply = get_completion(conversation, model_list)
        print("\ngpt response list:")
        print(reply)
        # Each entity is wrapped in its own single-element list
        return [[entity] for entity in json.loads(reply)['entities']]

    # Any other question type has no exact answer
    return []

def generate_ideal_answer(question, snippets, n_shots, wiki_context):
    """Produce the free-text "ideal answer" for a question.

    The conversation consists of the system prompt, `n_shots` few-shot examples
    from ideal_examples, and a user message built from `wiki_context`,
    `snippets` and the question body. The reply of model_ideal is printed and
    returned unchanged.
    """
    conversation = [{"role": "system", "content": "You are BioASQ-GPT, an AI expert in question answering, research, and information retrieval in the biomedical domain."}]
    conversation += generate_n_shot_examples_extraction(ideal_examples, n_shots)
    conversation.append({"role": "user", "content": f"""
            {wiki_context}\n\n 
            {snippets}\n\n
             '{question['body']}'.
             You are a biomedical expert, write a concise and clear answer to the above question.
             It is very important that the answer is correct.
             The maximum allowed length of the answer is 200 words, but try to keep it short and concise."""})
    print(conversation)
    reply = get_completion(conversation, model_ideal)
    print("\ngpt response ideal:")
    print(reply)
    return reply

##TODO load the corresponding input files for all models!
# Parse the question set for this run from its JSON input file
with open('./UR-IW-5.json', encoding='utf-8') as input_file:
    data = json.load(input_file)

# Run configurations. Only the assignments that sit outside the triple-quoted
# blocks are executed; the other presets are parked inside bare string literals,
# which Python evaluates and discards. To switch runs, move the quotes so that a
# different preset is live.
# Each preset sets: model_name (wiki retrieval and output/checkpoint file names),
# one model per question type, the number of few-shot examples, and whether
# Wikipedia context is added to the prompts.
"""
## UR-IW-1 -> Claude Opus + 10 shot 
model_name = "claude-3-opus-20240229"
model_yesno =  "claude-3-opus-20240229"
model_ideal =  "claude-3-opus-20240229"
model_list =  "claude-3-opus-20240229"
model_factoid =  "claude-3-opus-20240229"
n_shots = 10
use_wiki = True


## UR-IW-2 -> Opus 10 shot
model_name = "claude-3-opus-20240229"
model_yesno =  "claude-3-opus-20240229"
model_ideal =  "claude-3-opus-20240229"
model_list =  "claude-3-opus-20240229"
model_factoid =  "claude-3-opus-20240229"
n_shots = 10
use_wiki = False

## UR-IW-3 -> Mixtral 7B 10-Shot + wiki
model_name = "accounts/fireworks/models/mixtral-8x7b-instruct"
model_yesno = "accounts/fireworks/models/mixtral-8x7b-instruct"
model_list = "accounts/fireworks/models/mixtral-8x7b-instruct"
model_ideal = "accounts/fireworks/models/mixtral-8x7b-instruct"
model_factoid = "accounts/fireworks/models/mixtral-8x7b-instruct"
n_shots = 10
use_wiki = True
"""
## UR-IW-4 -> Mixtral 8x7B 10-Shot, no Wikipedia context (the ACTIVE preset)
model_name = "accounts/fireworks/models/mixtral-8x7b-instruct"
model_yesno = "accounts/fireworks/models/mixtral-8x7b-instruct"
model_list = "accounts/fireworks/models/mixtral-8x7b-instruct"
model_ideal = "accounts/fireworks/models/mixtral-8x7b-instruct"
model_factoid = "accounts/fireworks/models/mixtral-8x7b-instruct"
n_shots = 10
use_wiki = False

"""
## UR-IW-5 -> Mixtral 22B 10-Shot + wiki
model_name = "accounts/fireworks/models/mixtral-8x22b-instruct"
model_yesno = "accounts/fireworks/models/mixtral-8x22b-instruct"
model_ideal = "accounts/fireworks/models/mixtral-8x22b-instruct"
model_list = "accounts/fireworks/models/mixtral-8x22b-instruct"
model_factoid = "accounts/fireworks/models/mixtral-8x22b-instruct"
n_shots = 10
use_wiki = True
"""



# Sortable timestamp of this run; used later as the prefix of the results file name
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

# Model name with '/' and ':' replaced by '-' (unchanged when it contains neither)
pickl_name = model_name.replace('/', '-').replace(':', '-')
# Checkpoint file holding intermediate results for this model / shot count
pickl_file = f'{pickl_name}-{n_shots}-shot.pkl'



def save_state(data, file_path=pickl_file):
    """Pickle `data` to `file_path` (default: the checkpoint file), overwriting any existing file."""
    with open(file_path, 'wb') as state_fh:
        state_fh.write(pickle.dumps(data))

def load_state(file_path=pickl_file):
    """Return the object pickled at `file_path`, or None when the file is missing or empty."""
    if not os.path.exists(file_path):
        return None
    try:
        with open(file_path, 'rb') as state_fh:
            return pickle.load(state_fh)
    except EOFError:
        # Raised by pickle.load for an empty file: treat it as "no saved state"
        return None

# Columns of the results table (one row per answered question)
columns = ['id', 'body', 'type', 'documents', 'snippets', "ideal_answer", "exact_answer"]

# Start from an empty results table ...
questions_df = pd.DataFrame(columns=columns)

# ... unless a checkpoint from an interrupted run exists
saved_df = load_state(pickl_file)

if saved_df is not None and not saved_df.empty:
    # The checkpoint is a pickled questions_df, whose identifier column is 'id'
    # (see `columns` above and the dict returned by process_question). The former
    # lookup of 'question_id' raised KeyError and made resuming impossible.
    processed_ids = set(saved_df['id'])
    questions_df = saved_df
else:
    processed_ids = set()

# Only questions that are not already in the checkpoint still need answers
questions_to_process = [q for q in data["questions"] if q["id"] not in processed_ids]
#questions_to_process = questions_to_process[:2]


def process_question(question):
    """Answer one question and return its result row.

    Optionally builds Wikipedia context (when `use_wiki` is set), then generates
    the exact and the ideal answer. Any exception raised while doing so is
    printed with its traceback and both answers fall back to an empty list, so
    one failing question does not abort the run.

    Args:
        question: Dict with the keys "id", "type", "body", "documents" and "snippets".

    Returns:
        Dict with the keys "id", "type", "body", "documents", "snippets",
        "ideal_answer" and "exact_answer".
    """
    question_type = question["type"]
    print(f"{question['body']}\n")

    # The snippets supplied with the question serve as the evidence in the prompts
    relevant_snippets = question["snippets"]

    # Generate the exact answer and ideal answer
    try:
        if use_wiki:
            wiki_context = get_wiki_context(question['body'], model_name)
        else:
            wiki_context = ""
        exact_answer = generate_exact_answer(question, relevant_snippets, n_shots, wiki_context)
        ideal_answer = generate_ideal_answer(question, relevant_snippets, n_shots, wiki_context)
    except Exception as e:
        # Single quotes inside the double-quoted f-string: reusing the outer quote
        # character is a SyntaxError on Python versions before 3.12.
        print(f"Error processing question {question['id']}: {e}")
        traceback.print_exc()
        exact_answer = []
        ideal_answer = []

    # Result row for this question
    return {
        "id": question["id"],
        "type": question_type,
        "body": question["body"],
        "documents": question["documents"],
        "snippets": question["snippets"],
        "ideal_answer": ideal_answer,
        "exact_answer": exact_answer,
    }

# Answer the remaining questions on a pool of 4 worker threads. Results are
# collected in completion order, so rows are not in input order.
with ThreadPoolExecutor(max_workers=4) as executor:
    # Map each submitted future back to its question (needed for error messages)
    future_to_question = {executor.submit(process_question, q): q for q in questions_to_process}
    
    for future in as_completed(future_to_question):
        question = future_to_question[future]
        try:
            result = future.result()
            if result:
                # Append the row and checkpoint immediately, so an interrupted run
                # can resume from the pickle file (this runs in the main thread only)
                result_df = pd.DataFrame([result])
                questions_df = pd.concat([questions_df, result_df], ignore_index=True)
                save_state(questions_df, pickl_file)
        except Exception as e:
            # process_question handles answer-generation errors itself; this catches
            # anything raised outside its try block or while appending/saving the row
            print(f"Error processing question {question['id']}: {e}")
            traceback.print_exc()

# Short model name: the part after the last '/' (the whole name when there is none)
model_name_pretty = model_name.split("/")[-1]
# Results file name, prefixed with the run timestamp
output_file_name = f"./Results/{timestamp}_{model_name_pretty}_11B3-{n_shots}-QA-UR-IW-4.csv"

# Create the results directory on demand, then write all answers as CSV
os.makedirs(os.path.dirname(output_file_name), exist_ok=True)
questions_df.to_csv(output_file_name, index=False)

# The run is complete and exported, so the resume checkpoint is removed (best effort)
try:
    if os.path.exists(pickl_file):
        os.remove(pickl_file)
        print("Intermediate state pickle file deleted successfully.")
except Exception as e:
    print(f"Error deleting pickle file: {e}")
    traceback.print_exc()

## Run File Generation

In [None]:
import pandas as pd
import json

def csv_to_json(csv_filepath, json_filepath):
    """Convert a results CSV into the JSON structure {"questions": [...]}.

    List-valued CSV cells ("documents", "snippets", and "exact_answer" for
    factoid/list questions) hold Python literals and are parsed with
    ast.literal_eval. They were previously passed to eval(), which would execute
    arbitrary code found in the file.

    Limits applied per question: 10 documents, 10 snippets, 5 factoid answers,
    100 list answers. A yesno answer other than 'yes'/'no' is printed and
    replaced by 'no'.

    Args:
        csv_filepath: Path of the CSV file to read.
        json_filepath: Path of the JSON file to write (UTF-8, indent 4).

    Raises:
        ValueError, SyntaxError: If a list-valued cell is not a Python literal.
    """
    import ast  # local import so this cell stays runnable on its own

    # One dict per CSV row / question
    questions_list = pd.read_csv(csv_filepath).to_dict(orient='records')

    json_structure = {"questions": []}

    for item in questions_list:
        question_dict = {
            "documents": ast.literal_eval(item["documents"])[:10],
            "snippets": ast.literal_eval(item["snippets"])[:10],
            "body": item["body"],
            "type": item["type"],
            "id": item["id"],
            "ideal_answer": item["ideal_answer"],
        }
        if item["type"] == "yesno":
            yesno_answer = item["exact_answer"]
            if yesno_answer not in ['yes', 'no']:
                # Show the unusable answer, then fall back to 'no'
                print(yesno_answer)
                yesno_answer = 'no'
            question_dict["exact_answer"] = yesno_answer
        elif item["type"] == "factoid":
            question_dict["exact_answer"] = ast.literal_eval(item["exact_answer"])[:5]
        elif item["type"] == "list":
            question_dict["exact_answer"] = ast.literal_eval(item["exact_answer"])[:100]

        json_structure["questions"].append(question_dict)

    with open(json_filepath, 'w', encoding='utf-8') as json_file:
        json.dump(json_structure, json_file, ensure_ascii=False, indent=4)

# Example usage: convert one results CSV (written by the run cell) into a run file.
# Both paths are hard-coded for a specific earlier run and must be edited by hand.
csv_filepath = './Results/2024-04-25_09-43-27_mixtral-8x7b-instruct_11B3-10-QA-UR-IW-4.csv'  # Update this path to your actual CSV file path
json_filepath = './Results/UR-IW-4.json'  # Update this path to where you want to save the JSON file
csv_to_json(csv_filepath, json_filepath)
