In [1]:
import logging
import os
import re
import time

import openai
import pandas as pd
from tqdm import tqdm

from funcs import get_prompt_header, openai_request, safe_api_call

# Initialize OpenAI client.
# The key is read from the OPENAI_API_KEY environment variable so the secret
# does not have to be written into (and saved with) the notebook. The
# placeholder is kept as a fallback, so replacing 'YOUR_KEY' by hand still works
# when the environment variable is not set.
client = openai.OpenAI(api_key=os.environ.get('OPENAI_API_KEY', 'YOUR_KEY'))

# str_mask is a string of six 0/1 digits marking whether each in-context
# learning component is used.
# The six digits correspond to Dc. 1-2, Dm. 1-4.
str_mask = '111111'  # full model is 111111, slim model is 001001
# Fail early on a malformed mask: without this check a wrong length or a digit
# such as '2' would be silently accepted by bool(int(x)).
if len(str_mask) != 6 or set(str_mask) - {'0', '1'}:
    raise ValueError(
        f"str_mask must be exactly six characters, each '0' or '1'; got {str_mask!r}"
    )
mask = [bool(int(x)) for x in str_mask]
prompt = get_prompt_header(mask)

# Rough number of characters kept when a prompt has to be truncated.
cut_length = 16000

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


# Load the question set; the 'Module' column is used to group questions.
df = pd.read_csv('Q&A_dataset.csv')
modules = df['Module'].unique()

# Total progress: every question is asked three times.
total_questions = len(df)
total_attempts = total_questions * 3

https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db=gene&retmax=5&retmode=json&sort=relevance&term=LMP10
https://eutils.ncbi.nlm.nih.gov/entrez/eutils/efetch.fcgi?db=gene&retmax=5&retmode=json&id=19171,5699,8138
https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi?db=snp&retmax=10&retmode=json&id=1217074595
https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esearch.fcgi?db=omim&retmax=20&retmode=json&sort=relevance&term=Meesmann+corneal+dystrophy
https://eutils.ncbi.nlm.nih.gov/entrez/eutils/esummary.fcgi?db=omim&retmax=20&retmode=json&id=618767,601687,300778,148043,122100
https://blast.ncbi.nlm.nih.gov/blast/Blast.cgi?CMD=Put&PROGRAM=blastn&MEGABLAST=on&DATABASE=nt&FORMAT_TYPE=XML&QUERY=ATTCTGCCTTTAGTAATTTGATGACAGAGACTTCTTGGGAACCACAGCCAGGGAGCCACCCTTTACTCCACCAACAGGTGGCTTATATCCAATCTGAGAAAGAAAGAAAAAAAAAAAAGTATTTCTCT&HITLIST_SIZE=5
https://blast.ncbi.nlm.nih.gov/blast/Blast.cgi?CMD=Get&FORMAT_TYPE=Text&RID=87KS399U014


In [None]:
# Strings that, when returned by openai_request, are treated as error responses:
# the attempt stops and the string itself is recorded as the answer.
API_ERROR_SENTINELS = ('timeoutError', 'rateLimitError', 'lengthError')

# A URL the model wants fetched is written in square brackets: [https://...]
URL_PATTERN = re.compile(r'\[(https?://[^\[\]]+)\]')

MAX_API_CALLS = 10          # model calls allowed per attempt before giving up
MAX_RESULT_CHARS = 20000    # cap on the fetched result appended to the prompt
BLAST_WAIT_SECONDS = 30     # pause before fetching a BLAST 'Get' URL


def _fetch_api_result(url):
    """Fetch ``url`` with ``safe_api_call`` and return the value to splice into the prompt.

    For BLAST URLs containing 'Get' the call is delayed by BLAST_WAIT_SECONDS.
    For BLAST URLs containing 'Put' whose response is bytes, only the request
    id (the text after ``RID = ``) is returned, or a ``BLAST_ERROR`` message if
    it cannot be extracted. Any other bytes response is decoded as UTF-8
    (undecodable bytes dropped), and results longer than MAX_RESULT_CHARS are
    truncated.
    """
    is_blast = 'blast' in url

    # Wait for BLAST operations
    if is_blast and 'Get' in url:
        time.sleep(BLAST_WAIT_SECONDS)

    call = safe_api_call(url)

    # Handle BLAST RID extraction
    if is_blast and 'Put' in url and isinstance(call, bytes):
        try:
            call = re.search('RID = (.*)\n', call.decode('utf-8')).group(1)
        except (AttributeError, UnicodeDecodeError):
            # AttributeError: no 'RID = ' line in the response (search gave None).
            # UnicodeDecodeError: the response is not valid UTF-8.
            # Anything else (e.g. KeyboardInterrupt) is allowed to propagate.
            call = "BLAST_ERROR: Could not extract RID"

    # Handle bytes response
    if isinstance(call, bytes):
        call = call.decode('utf-8', errors='ignore')

    # Limit response length
    if len(str(call)) > MAX_RESULT_CHARS:
        call = str(call)[:MAX_RESULT_CHARS]

    return call


def _answer_question(question):
    """Run one attempt at ``question`` and return ``(final_answer, prompts)``.

    The model is queried repeatedly. Whenever its reply contains a bracketed
    URL, the first such URL is fetched and the reply plus the fetched result
    are appended to the prompt for the next round. The attempt ends when:

    * the reply is one of the error strings -> that string is the answer;
    * the reply contains no URL -> the reply is the answer;
    * MAX_API_CALLS replies have been processed -> the answer is 'numError'.

    ``prompts`` is the list of ``[prompt_sent, reply]`` pairs for this attempt.
    """
    q_prompt = prompt + f'Question: {question}\n'
    prompts = []
    num_calls = 0

    while True:
        if len(q_prompt) > cut_length:
            # Truncate from the start, keeping the last cut_length characters.
            q_prompt = q_prompt[len(q_prompt) - cut_length:]

        text = openai_request(q_prompt, logger, client)

        # Handle error responses
        if text in API_ERROR_SENTINELS or text.startswith('unexpectedError'):
            return text, prompts

        num_calls += 1
        prompts.append([q_prompt, text])

        # Look for URLs in the response; no URL means this is the final answer.
        matches = URL_PATTERN.findall(text)
        if not matches:
            return text, prompts

        call = _fetch_api_result(matches[0])
        q_prompt = f'{q_prompt}{text}->[{call}]\n'

        # Prevent infinite loops
        if num_calls >= MAX_API_CALLS:
            return 'numError', prompts


# Main progress bar
with tqdm(total=total_attempts, desc="Processing all questions", unit="attempts") as pbar:

    # Process each module separately
    for module in modules:
        # Filter dataframe for current module
        module_df = df[df['Module'] == module].copy()

        # Results accumulated for this module; rewritten to disk after each question
        module_results = []

        # Process each question in the current module
        for _, row in module_df.iterrows():
            question = row['Question']
            model = row['Model']
            goldstandard = row['Goldstandard']

            # Ask the same question 3 times
            answers = []

            for attempt in range(3):
                pbar.set_description(f"Module: {module} | Q{len(module_results)+1}/{len(module_df)} | Attempt {attempt+1}/3")

                # prompts holds the prompting log of the most recent attempt
                final_answer, prompts = _answer_question(question)

                answers.append(final_answer)
                pbar.update(1)

                # Add a delay between attempts
                time.sleep(1)

            # Add results for this question
            module_results.append({
                'Model': model,
                'Module': module,
                'Question': question,
                'Goldstandard': goldstandard,
                'Answer1': answers[0],
                'Answer2': answers[1],
                'Answer3': answers[2]
            })

            # Save progress after each question (in case of interruption)
            module_results_df = pd.DataFrame(module_results)
            filename = f"{module}_results.csv"
            module_results_df.to_csv(filename, index=False)
            logger.info(f"Saved progress: {filename}")