# Phase A 
Question in -> List of max 10 articles out; List of 10 relevant snippets out.

Components:
-> Question to PubMed? Query 
-> Relevancy ranking
-> Article snippet candidate extractor (Only from Abstract and Title?)
-> Snippet ranking



In [None]:
%pip install Biopython 

## Initial Parameters

In [1]:
# Base URL of the OpenAI REST API; endpoint paths such as "chat/completions"
# are appended to it by the request helpers in this notebook.
base_url = "https://api.openai.com/v1/"

# Read the API key interactively WITHOUT echoing it. A plain input() would
# display the typed secret in the cell output, which is then stored with the
# notebook.
from getpass import getpass

api_key = getpass("OpenAI API key: ")

# Headers sent with every OpenAI request (bearer token + optional organization)
headers = {
    "Authorization": f"Bearer {api_key}",
    "OpenAI-Organization": None # replace with your own
}

# Chat model used by all modules below; uncomment exactly one alternative.
#model_name = "gpt-4-0314"
#model_name = "gpt-3.5-turbo-0301"
model_name = "gpt-3.5-turbo"
#model_name = "gpt-4"

## Query Expansion Module

### Query Expansion V2

In [2]:
# Import requests library
# Query Expansion V2
import requests

def expand_query(question):
    """Ask the chat model to turn ``question`` into an expanded PubMed query.

    The request is posted to ``base_url + "chat/completions"`` with the
    module-level ``headers`` and ``model_name``.

    Args:
        question: The natural-language question to expand.

    Returns:
        The text content of the first completion choice.

    Raises:
        Exception: If the API answers with a status code other than 200
            (the response body is printed first).
    """
    system_prompt = "You are BioASQ-GPT, an AI expert in question answering, research, and information retrieval in the biomedical domain."
    # The backslash continuations are part of the literal: the leading spaces
    # of the continued lines end up inside the prompt text.
    user_prompt = f"Expand this search query: '{question}' for PubMed by incorporating synonyms and additional terms that closely relate to \
            the main topic and help reduce ambiguity. Assume that phrases are not stemmed; therefore, generate useful variations. Return only the query that can \
            directly be used without any explanation text. Focus on maintaining the query's precision and relevance to the original question."

    payload = {
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": 0.0,
        "frequency_penalty": 0.5,
        "presence_penalty": 0.1,
    }

    reply = requests.post(base_url + "chat/completions", headers=headers, json=payload)

    # Fail fast on anything but HTTP 200, showing the raw body for diagnosis.
    if reply.status_code != 200:
        print(reply.text)
        raise Exception(f"Error: {reply.status_code}")

    return reply.json()["choices"][0]["message"]["content"]



## Refinement Module

In [3]:
def refine_query(question, original_query):
    """Ask the chat model for a broader PubMed query after an empty result.

    The request payload is printed before it is sent to
    ``base_url + "chat/completions"`` with the module-level ``headers`` and
    ``model_name``.

    Args:
        question: The original natural-language question.
        original_query: The PubMed query that returned no documents.

    Returns:
        The text content of the first completion choice.

    Raises:
        Exception: If the API answers with a status code other than 200
            (the response body is printed first).
    """
    # The backslash continuations are part of the literals: the leading spaces
    # of the continued lines end up inside the prompt text.
    system_prompt = "You are BioASQ-GPT, an AI expert in question answering, \
             research, and information retrieval in the biomedical domain."
    user_prompt = f"Given that the following search query for PubMed has returned\
            no documents, please generate a broader query that retains the original question's context and relevance.\
            Assume that phrases are not stemmed; therefore, generate useful variations. Return only the query that can\
            directly be used without any explanation text. Focus on maintaining the query's precision and relevance to\
            the original question. Original question: '{question}', Original query: '{original_query}'."

    payload = {
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": 0.0,
        "frequency_penalty": 0.6,
        "presence_penalty": 0.2,
    }
    print(payload)

    reply = requests.post(base_url + "chat/completions", headers=headers, json=payload)

    # Fail fast on anything but HTTP 200, showing the raw body for diagnosis.
    if reply.status_code != 200:
        print(reply.text)
        raise Exception(f"Error: {reply.status_code}")

    return reply.json()["choices"][0]["message"]["content"]

## Reranking Module

In [6]:
import requests

def _parse_rerank_indices(generated_message, article_count):
    """Extract usable 0-based article indices from the model's reply.

    Every run of digits in ``generated_message`` is treated as a 1-based
    article number. Numbers outside ``1..article_count`` and repeated numbers
    are dropped; the order of first appearance is kept.

    Args:
        generated_message: Raw text returned by the chat model.
        article_count: Number of articles that were offered for reranking.

    Returns:
        A list of distinct indices, each valid for a list of
        ``article_count`` items. Empty if the reply contains no usable number.
    """
    import re  # function-scope import: the surrounding cell only imports requests

    picked = []
    seen = set()
    for token in re.findall(r"\d+", generated_message):
        idx = int(token) - 1  # the prompt numbers the articles starting at 1
        # Reject 0 (would become -1 and silently select the last article) and
        # anything beyond the end of the list.
        if 0 <= idx < article_count and idx not in seen:
            seen.add(idx)
            picked.append(idx)
    return picked


def rerank_top_articles(question, articles):
    """Let the chat model pick the articles most relevant to ``question``.

    The article titles are sent as a numbered list and the model is asked for
    a comma separated list of the most relevant numbers. The reply is parsed
    leniently (see ``_parse_rerank_indices``) so that stray punctuation,
    explanatory text, duplicate or out-of-range numbers do not abort the run.

    Args:
        question: The natural-language question.
        articles: List of article dicts; each must provide a ``'title'`` key.

    Returns:
        At most 10 articles in the order chosen by the model. If the reply
        contains no usable index, the first 10 articles in their incoming
        order are returned. An empty ``articles`` list yields ``[]`` without
        calling the API.

    Raises:
        Exception: If the API answers with a status code other than 200
            (the response body is printed first).
    """
    if not articles:
        return []

    # Prepare the list of articles as a single string for the message
    articles_str = "\n".join([f"{idx + 1} - {article['title']}" for idx, article in enumerate(articles)])
    nr_of_articles = min(10, len(articles))

    # Define parameters for chat completion. The backslash continuations are
    # part of the literals: the leading spaces of the continued lines end up
    # inside the prompt text.
    params_chat = {
        "model": model_name, # model name
        "messages": [
            {"role": "system", "content": "You are BioASQ-GPT, an AI expert in question answering, \
            research, and information retrieval in the biomedical domain."},
            {"role": "user", "content": f"{articles_str} \n\n Given these articles and the question: '{question}'. \
             Rerank the articles based on their relevance to the question and return the top {nr_of_articles} most relevant articles as a comma seperated list of their index ids. Don't explain your answer, return only this list, for example: '1, 2, 3, 4' "}
        ],
        "temperature": 0.0,
        "frequency_penalty": 0.3,
        "presence_penalty": 0.1,
    }
    # Make request to chat completion
    response_chat = requests.post(base_url + "chat/completions", headers=headers, json=params_chat)

    if response_chat.status_code != 200:
        print(response_chat.text)
        raise Exception(f"Error: {response_chat.status_code} {response_chat.reason} {response_chat.text}")

    # Get the generated message from the chat completion result
    generated_message = response_chat.json()["choices"][0]["message"]["content"]
    print("generated rerank response: ")
    print(generated_message)

    reranked_indices = _parse_rerank_indices(generated_message, len(articles))
    if not reranked_indices:
        # Nothing usable came back: keep the incoming order instead of failing.
        print("rerank response contained no valid indices; keeping original order")
        return articles[:10]

    # Return the top 10 reranked articles
    return [articles[idx] for idx in reranked_indices][:10]


## Snippet Generation Module

In [5]:
import requests
import json


def snippet_call_with_article_str(question, articles_str):
    """Ask the chat model for relevant snippets and parse its JSON reply.

    The request is posted to ``base_url + "chat/completions"`` with the
    module-level ``headers`` and ``model_name``. The raw reply is printed,
    non-breaking spaces (both the character and the literal ``\\xa0`` escape
    text) are replaced by regular spaces, and the result is parsed as JSON.

    Args:
        question: The natural-language question.
        articles_str: Pre-rendered listing of articles (id, title, abstract).

    Returns:
        Whatever ``json.loads`` yields for the cleaned reply (the prompt asks
        for a JSON list of snippet objects).

    Raises:
        Exception: If the API answers with a status code other than 200
            (the response body is printed first).
        json.JSONDecodeError: If the cleaned reply is not valid JSON.
    """
    # The backslash continuations are part of the literals: the leading spaces
    # of the continued lines end up inside the prompt text.
    system_prompt = "You are BioASQ-GPT, an AI expert system in question answering, \
             research, and information retrieval in the biomedical domain. You only answer in JSON format."
    user_prompt = f"{articles_str} \n Given above list of articles and the question: '{question}'. \
             Extract the most relevant snippets for the question from title or abstract and return them in as a JSON list of JSON objects with the following structure.\
             Each snippet should become a json object containing the fields: document (the unique document url of the article from which the snippet is taken),\
             beginSection (either title or abstract), endSection (same as beginSection), \
             offsetInBeginSection (the character offset where the snippet starts), offsetInEndSection (the character offset where the snippet ends), \
             and text (the text of the snippet). Each snippet should be at most 3 sentences long, extract at most 3 snippets per article. The answer should be a valid JSON list. \
             \n\n"

    payload = {
        "model": model_name,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": 0.0,
        "frequency_penalty": 0.3,
        "presence_penalty": 0.1,
    }

    reply = requests.post(base_url + "chat/completions", headers=headers, json=payload)

    # Fail fast on anything but HTTP 200, showing the raw body for diagnosis.
    if reply.status_code != 200:
        print(reply.text)
        raise Exception(f"Error: {reply.status_code}")

    raw_text = reply.json()["choices"][0]["message"]["content"]
    print("gpt response:")
    print(raw_text)

    # Normalise non-breaking spaces before handing the text to the JSON parser.
    normalised = raw_text.replace("\\xa0", " ")
    normalised = normalised.replace("\xa0", " ")
    return json.loads(normalised)

def extract_relevant_snippets(question, articles):
    """Collect snippets for ``articles`` by querying the model in batches.

    Articles are rendered as numbered lines and sent in up to three batches:
    articles 1-3, articles 4-6 and all remaining articles from 7 onwards.
    Each batch listing is printed before it is sent.

    Args:
        question: The natural-language question.
        articles: List of article dicts with ``'id'``, ``'title'`` and
            ``'abstract'`` keys.

    Returns:
        The concatenated results of ``snippet_call_with_article_str`` for all
        batches, or ``[]`` when ``articles`` is empty.
    """
    if len(articles) == 0:
        return []

    def render(batch, first_number):
        # One line per article, numbered consecutively across batches.
        return "\n".join(
            f"{number}. {entry['id']} - title:'{entry['title']}' abstract:'{entry['abstract']}'"
            for number, entry in enumerate(batch, start=first_number)
        )

    # The first two listings are rendered before any request is made.
    head_listing = render(articles[:3], 1)
    middle_listing = render(articles[3:6], 4) if len(articles) > 3 else ""

    print("First 3 articles or all if less than 3:")
    print(head_listing)
    collected = snippet_call_with_article_str(question, head_listing)

    if not middle_listing:
        return collected

    print("\nNext 3 articles:")
    print(middle_listing)
    collected.extend(snippet_call_with_article_str(question, middle_listing))

    # Everything beyond the sixth article goes into one final batch.
    if len(articles) > 6:
        tail_listing = render(articles[6:], 7)
        print("\nLast articles:")
        print(tail_listing)
        collected.extend(snippet_call_with_article_str(question, tail_listing))

    return collected


## Retrieval

In [5]:
from Bio import Entrez
from urllib.error import HTTPError


Entrez.email = None #replace with your own
Entrez.api_key = None #replace with your own

def create_article_dict(pmid, title, abstract):
    """Build the article record used throughout the pipeline.

    Args:
        pmid: PubMed identifier; appended to the PubMed base URL.
        title: Article title, stored unchanged.
        abstract: Abstract text. A list is flattened into one string by
            joining the ``str()`` of its items with single spaces; any other
            value (including ``None``) is stored unchanged.

    Returns:
        A dict with the keys ``"id"`` (PubMed URL), ``"title"`` and
        ``"abstract"``.
    """
    abstract_text = " ".join(map(str, abstract)) if isinstance(abstract, list) else abstract
    return dict(
        id="http://www.ncbi.nlm.nih.gov/pubmed/" + pmid,
        title=title,
        abstract=abstract_text,
    )

def get_pubmed_article_details(article_ids):
    """Fetch title and abstract for the given PubMed IDs via ``Entrez.efetch``.

    Both regular articles (``PubmedArticle``) and book entries
    (``PubmedBookArticle``) of the parsed result are converted with
    ``create_article_dict``; articles come first, then book entries.

    Args:
        article_ids: Iterable of PubMed ID strings.

    Returns:
        A list of article dicts. Empty when ``article_ids`` is empty (no
        request is made in that case).

    Raises:
        urllib.error.HTTPError: If the fetch request fails. The error and, if
            available, its response body are printed before it is re-raised.
    """
    if not article_ids:
        return []

    try:
        handle = Entrez.efetch(db="pubmed", id=",".join(article_ids), rettype="medline", retmode="xml")
    except HTTPError as e:
        print(e)
        # urllib's HTTPError has no ``response`` attribute; the body, when
        # present, is read from the error object itself.
        if e.fp is not None:
            print(e.read().decode("utf-8", errors="replace"))
        # Without a handle there is nothing to parse, so let the caller decide
        # (the run loop retries on any exception).
        raise

    # Close the handle even if parsing the XML fails.
    try:
        records = Entrez.read(handle)
    finally:
        handle.close()

    articles = []
    for record in records.get('PubmedArticle', []):
        citation = record['MedlineCitation']
        title = citation['Article']['ArticleTitle']
        abstract = citation['Article'].get('Abstract', {}).get('AbstractText', None)
        pmid = citation['PMID']
        articles.append(create_article_dict(pmid, title, abstract))

    for record in records.get('PubmedBookArticle', []):
        book_document = record['BookDocument']
        title = book_document['Book']['BookTitle']
        abstract = book_document.get('Abstract', {}).get('AbstractText', None)
        pmid = book_document['PMID']
        articles.append(create_article_dict(pmid, title, abstract))

    return articles


def search_pubmed_and_bioasq(query, articles_per_page=50):
    """Search PubMed for ``query`` and return details of the matching articles.

    The search is sorted by relevance and passes a fixed publication-date
    upper bound of 2022/12/09. The list of matching IDs is printed.

    Args:
        query: PubMed search term.
        articles_per_page: Maximum number of IDs to request (``retmax``).

    Returns:
        A list of article dicts from ``get_pubmed_article_details``; empty
        when the search matches nothing.
    """
    # NOTE(review): NCBI's ESearch documentation says mindate and maxdate must
    # be used together; only maxdate is supplied here — confirm the cutoff is
    # actually applied.
    handle_esearch = Entrez.esearch(db="pubmed", term=query, retmax=articles_per_page, sort="relevance", datetype="pdat", maxdate="2022/12/09")
    # Close the handle even if parsing fails, mirroring the efetch handling.
    try:
        data_esearch = Entrez.read(handle_esearch)
    finally:
        handle_esearch.close()

    id_list = data_esearch["IdList"]
    print(id_list)

    if not id_list:
        return []
    return get_pubmed_article_details(id_list)


## PhaseA Run

In [None]:
import json
import datetime
import time  # Import the time module

def append_to_logfile(logfile_name, text):
    """Append ``text`` followed by a newline to the UTF-8 file ``logfile_name``.

    The file is opened in append mode, so it is created on first use and
    earlier content is kept.
    """
    with open(logfile_name, mode="a", encoding="utf-8") as sink:
        sink.write("".join((text, "\n")))

# Sortable start-of-run timestamp (year first, down to seconds); it prefixes
# every file written by this run.
timestamp = f"{datetime.datetime.now():%Y-%m-%d_%H-%M-%S}"

# Per-question result log (one JSON document per line)
logfile_name = f"{timestamp}_{model_name}_PhaseA_No_Expansion_log_file.json"
# Debug log with the intermediate retrieval/reranking data
debug_logfile = f"{timestamp}_{model_name}_PhaseA_No_Expansion_Debug_log_file.json"

def get_relevant_articles(question):
    """Retrieve PubMed articles for ``question`` and rerank them.

    The question is used verbatim as the PubMed query. When more than one
    article is found, ``rerank_top_articles`` selects the final list;
    otherwise the (at most one) retrieved article is returned as is.

    A debug record with the question, the query and whatever intermediate
    results were produced is appended to ``debug_logfile`` as a single JSON
    object on one line, even when retrieval or reranking raises.

    Args:
        question: The natural-language question text.

    Returns:
        A list of article dicts.
    """
    print("Question: "+question)
    # This is the "no expansion" variant of the run: expand_query() is not
    # used, and therefore refine_query() is not needed for empty results.
    query = question
    print("Query: "+query)

    debug_record = {"question": question, "query": query}
    try:
        relevant_articles = search_pubmed_and_bioasq(query)
        debug_record["retrieved_articles"] = relevant_articles

        if len(relevant_articles) > 1:
            top_articles = rerank_top_articles(question, relevant_articles)
        else:
            top_articles = relevant_articles[:10]
        debug_record["reranked_articles"] = top_articles
        return top_articles
    finally:
        # json.dumps takes care of quoting and escaping; default=str keeps the
        # debug log from ever aborting the run on an unexpected value type.
        append_to_logfile(debug_logfile, json.dumps(debug_record, default=str))

def get_relevant_snippets(articles, question):
    """Extract snippets for ``question`` from ``articles`` and print them.

    Thin wrapper around ``extract_relevant_snippets`` (note the swapped
    argument order) that echoes the result before returning it.
    """
    snippets = extract_relevant_snippets(question, articles)
    print("relevant snippets: ", snippets, sep="\n")
    return snippets

import os

# Load the input file in JSON format (explicit encoding so the result does not
# depend on the platform default)
with open('./BioASQ-task11bPhaseA-testset4.json', encoding='utf-8') as input_file:
    data = json.load(input_file)

# Create an empty list to store the results
results = []

append_to_logfile(debug_logfile, "[")


# Iterate over all questions; raise `offset` to skip questions that were
# already processed in an earlier, interrupted run.
offset = 0
max_retries = 2  # number of attempts per question
for idx, question in enumerate(data["questions"]):
    print(idx)
    if idx < offset:
        continue

    retry_count = 0
    while retry_count < max_retries:
        try:
            relevant_articles = get_relevant_articles(question['body'])
            relevant_articles_ids = [article['id'] for article in relevant_articles]
            print(relevant_articles_ids)

            # Snippet extraction is switched off for this run; use
            # get_relevant_snippets(relevant_articles, question['body'])
            # to enable it.
            relevant_snippets = []

            # Create a dictionary to store the results for this question
            question_results = {
                "body": question["body"],
                "id": question["id"],
                "type": question["type"],
                "documents": relevant_articles_ids,
                "snippets": relevant_snippets
            }

            # Log each result immediately so finished work survives a crash
            append_to_logfile(logfile_name, json.dumps(question_results))

            # Add the results for this question to the list of all results
            results.append(question_results)
            break  # If no exception is thrown, stop retrying
        except Exception as e:
            print(f"Error processing question {idx}: {e}")
            retry_count += 1
            time.sleep(5)  # Sleep for 5 seconds before retrying
    else:
        # Reached only when every attempt failed (no break): make the gap in
        # the output visible instead of dropping the question silently.
        print(f"Giving up on question {idx} after {max_retries} attempts; it is missing from the output.")

# Create a dictionary to store the results for all questions
output = {
    "questions": results
}

append_to_logfile(debug_logfile, "]")


# Prefix the output file name with the timestamp
output_file_name = f"./Result/{timestamp}_{model_name}_PhaseA_NoExpansion_output_file.json"

# Make sure the target directory exists so the finished run is not lost to a
# FileNotFoundError on the final write.
os.makedirs(os.path.dirname(output_file_name), exist_ok=True)

# Save the output to a file in pretty-formatted JSON format
with open(output_file_name, "w", encoding="utf-8") as f:
    json.dump(output, f, indent=4)