In [None]:
import json
import os
import time  # To add delay between retries if needed

import pandas as pd
from openai import OpenAI
from tqdm.std import tqdm

In [None]:
# Read the key from the environment instead of hard-coding it in the notebook,
# where it would end up in version control / shared copies of the file.
# If OPENAI_API_KEY is unset this is None; the OpenAI client then falls back to
# the same environment variable and raises an error if it is missing.
API_KEY = os.environ.get("OPENAI_API_KEY")

In [None]:
# Shared OpenAI client for the whole notebook. get_response() looks it up as the
# module-level global `Client`, so the (unconventional) capitalised name is kept.
Client = OpenAI(api_key=API_KEY)

In [None]:
# System prompt sent with every request in get_response(). It asks for a graded
# 1-5 relevance label and passes the original binary score as a hint
# (1.0 -> always 5; 0.0 -> model decides).
# NOTE(review): the "balance the distribution of your predicted labels" requirement
# cannot really be honoured: each request carries a single (query, passage) pair and
# no record of earlier predictions. Confirm whether this line is intended.
# NOTE(review): the expected output shape is shown only through the examples; the
# JSON schema passed in get_response() is what actually fixes the "Relevance Label" key.
# Changing this text changes the labels, so keep it fixed for the whole labelling run
# so that all output batches are comparable.
system_prompt = """ ## Instruction: You are an excellent Artificial Intelligence tasked with determining the relevance label between a given query and a passage.

## The relevance label should be selected based on the following criteria:
    5: 100% Relevant;
    4: 75% Relevant;
    3: 50% Relevant;
    2: 25% Relevant;
    1: 0% Relevant. 

## The input data consists of three components:
    query:
        The query to evaluate.
    passage:
        The candidate passage.
    binary relevance score:
        If 1.0, this means the query and passage are completely relevant. In this case, the relevance label should always be 5.
        If 0.0, you need to analyze the semantic relationship and contextual information between the query and the passage to determine the most appropriate relevance label (5, 4, 3, 2, or 1).

## Requirement:
    When predicting relevance labels, in addition to considering the semantic relevance between the query and the passage, you should also balance the distribution of your predicted labels. This means ensuring that each relevance label (5, 4, 3, 2, 1) is predicted as evenly as possible.

## Example 1:
    Input:
        query: "How to use Python for data analysis?"
        passage: "Python is a commonly used programming language for data analysis, data processing, and machine learning."
        binary relevance score: 1.0
    Output:
        {"Relevance Label": 5}

## Example 2:
    Input:
        query: "How to optimize website performance?"
        passage: "Website performance can be optimized by reducing the number of HTTP requests, and optimizing CSS and JavaScript code."
        binary relevance score: 0.0
    Output:
        {"Relevance Label": 5}
## Example 3:
    Input:
        query: "What are the benefits of regular exercise?"  
        passage: "Regular exercise can improve cardiovascular health, enhance mood, and boost energy levels. However, specific routines vary based on individual goals."  
        binary relevance score: 0.0
    Output:
        {"Relevance Label": 4}

## Input Format:
    query: <Insert query here>
    passage: <Insert passage here>
    binary relevance score: <1.0 or 0.0>
"""

In [None]:
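## Output Format (reference only: this is NOT part of system_prompt; the JSON schema in get_response enforces the key):
#    {"Relevance Label": <predicted relevance label (5, 4, 3, 2, or 1)>}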

In [None]:
def get_response(user_prompt, model="gpt-4o-2024-08-06"):
    """Request a graded relevance label for one formatted query/passage prompt.

    Sends the module-level ``system_prompt`` plus ``user_prompt`` through the
    shared ``Client`` using a strict JSON-schema response format, so the reply is
    a JSON object with exactly one key, "Relevance Label".

    Args:
        user_prompt: Text produced by ``get_user_prompt``.
        model: Chat model name. Defaults to the snapshot used so far; e.g.
            "gpt-4o-mini-2024-07-18" can be passed instead.

    Returns:
        The raw JSON string returned by the model, e.g. '{"Relevance Label": "4"}'.
        The label is a string restricted to "1".."5"; callers convert it with int().

    Raises:
        ValueError: if the model returned no content (e.g. a refusal).
        Any exception raised by the OpenAI client (network, rate limit, ...)
        propagates unchanged.
    """
    completion = Client.chat.completions.create(
        model=model,
        response_format={
            "type": "json_schema",
            "json_schema": {
                "name": "Relevance",
                "strict": True,
                "schema": {
                    "type": "object",
                    "properties": {
                        "Relevance Label": {
                            # Still a string for compatibility with existing callers,
                            # but restricted to the five valid labels. Without the enum,
                            # free text such as "five" or "4.5" breaks int() in the caller.
                            "type": "string",
                            "enum": ["1", "2", "3", "4", "5"],
                        }
                    },
                    "required": ["Relevance Label"],
                    "additionalProperties": False,
                },
            },
        },
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
    )
    message = completion.choices[0].message
    # With structured outputs a refusal comes back with `content` set to None;
    # report it explicitly instead of handing None to json.loads() in the caller.
    if message.content is None:
        refusal = getattr(message, "refusal", None)
        raise ValueError(f"Model returned no content (refusal: {refusal!r})")
    return message.content

In [None]:
def get_user_prompt(query, passage, rel_score):
    """Format one (query, passage, binary score) triple as the user message.

    The three lines follow the "Input Format" section of ``system_prompt``;
    the first two lines keep a single space before their newline.
    """
    parts = (
        f"query: {query} \n",
        f"passage: {passage} \n",
        f"binary relevance score: {rel_score}",
    )
    return "".join(parts)

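## Load data

Each row is expected to provide a `query`, a `passage` and a binary `relevance_score` (1.0 / 0.0); the labelling loop below reads exactly these three columns.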

In [None]:
# Query/passage pairs to label, read from a tab-separated file.
data = pd.read_csv("document/filtered_passage_reranking_data.tsv", delimiter="\t")

In [None]:
# Preview only the first rows rather than rendering the whole frame.
data.head()

In [None]:
# Debug helper (disabled): inspect the last 50 rows for a single query id.
# data[data.qid == 20432].tail(50)

In [None]:
# Label the data in checkpointed batches of `batch_size` rows. Each batch is written
# to its own TSV numbered by `Index` (1-based), so an interrupted run can be resumed.
batch_size = 100
Index = 6501  # number of the first batch to process; Index=1 starts at row 0
SKIP_EXISTING = True  # skip batches whose output file already exists (resume-friendly)
MAX_RETRIES = 6  # attempts per row before giving up
RETRY_BASE_DELAY = 0.5  # seconds; doubled after every failed attempt
OUTPUT_TEMPLATE = "document/temp/temp_passage_ranking_{}.tsv"
VALID_LABELS = {1, 2, 3, 4, 5}


def _label_pair(query, passage, rel_score):
    """Return the graded 1-5 relevance label the model assigns to one pair.

    Every failure (API error, malformed JSON, missing or out-of-range label) is
    retried with exponential backoff. After MAX_RETRIES failed attempts a
    RuntimeError is raised instead of retrying forever, so a permanently failing
    row stops the run rather than hanging it. Batches already written stay on disk.
    """
    # The prompt is deterministic, so build it once rather than on every retry.
    user_prompt = get_user_prompt(query, passage, rel_score)
    last_error = None
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            response = get_response(user_prompt)
            label = int(json.loads(response)["Relevance Label"])
            if label not in VALID_LABELS:
                raise ValueError(f"label {label} is outside 1-5")
            return label
        except json.JSONDecodeError as e:
            last_error = e
            # tqdm.write keeps the progress bar intact, unlike print().
            tqdm.write(f"JSON decoding error (attempt {attempt}/{MAX_RETRIES}). Re-generating response...")
        except Exception as e:
            last_error = e
            tqdm.write(f"Unexpected error: {e} (attempt {attempt}/{MAX_RETRIES}). Re-generating response...")
        if attempt < MAX_RETRIES:
            # Back off to avoid overwhelming the API after transient failures.
            time.sleep(RETRY_BASE_DELAY * 2 ** (attempt - 1))
    raise RuntimeError(
        f"No valid relevance label after {MAX_RETRIES} attempts for query {query!r}"
    ) from last_error


# Row offset of the first batch, derived from Index so the two cannot drift apart
# (650000 for Index=6501, batch_size=100).
start_row = (Index - 1) * batch_size

for i in range(start_row, len(data), batch_size):
    out_path = OUTPUT_TEMPLATE.format(Index)
    if SKIP_EXISTING and os.path.exists(out_path):
        Index += 1
        continue
    # .iloc[...].copy() gives an independent frame, so the column assignment below
    # does not hit SettingWithCopyWarning on a slice of `data`.
    new_data = data.iloc[i:i + batch_size].copy()
    relevance_label = [
        _label_pair(row.query, row.passage, row.relevance_score)
        for row in tqdm(new_data.itertuples(), total=len(new_data), desc=f"batch {Index}")
    ]
    # As before, the binary score column is overwritten with the graded label.
    new_data["relevance_score"] = relevance_label
    new_data.to_csv(out_path, sep="\t", index=False)
    Index += 1