# SuRe: Summarizing Retrievals using Answer Candidates for Open-domain QA of LLMs (ICLR 2024)

- You need to either (1) prepare your own dataset or (2) download the datasets from [Google drive](https://drive.google.com/drive/folders/1trPfSK37CJIFRY3ef-YNb6VKGdRKZm74?usp=sharing). The loading cell below reads `./datasets/BM25/{data_type}-bm25.json`.

In [1]:
import pprint
import json
import copy
import numpy as np
from tqdm import tqdm
import time
from datetime import timedelta, datetime
# Pretty-printer configured with a 4-space indent, available to later cells.
pp = pprint.PrettyPrinter(indent=4)

## Loading dataset
- Available: ['nq-test', 'wq-test', 'hotpotqa', '2wikimultihopqa']

In [2]:
# Name of the benchmark to run; it is interpolated into the dataset file name
# by the loading cell, so it must match a downloaded `<name>-bm25.json` file.
data_type = '2wikimultihopqa'

In [None]:
# Load the BM25 retrieval file for the selected benchmark. The context manager
# closes the file handle as soon as parsing finishes (the previous
# `json.load(open(...))` form left it open until garbage collection).
with open(f'./datasets/BM25/{data_type}-bm25.json') as dataset_file:
    dataset = json.load(dataset_file)

FileNotFoundError: [Errno 2] No such file or directory: './datasets/2wikimultihopqa-bm25.json'

## Setup OpenAI
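- Caution: a valid OpenAI API key is required. The cell below reads it from the `OPENAI_API_KEY` environment variable (which may be supplied through a `.env` file).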

In [None]:
import openai
from dotenv import load_dotenv
import os
# Read key/value pairs from a local .env file (if one exists) into the environment.
load_dotenv()
# os.getenv returns None when OPENAI_API_KEY is unset, so a missing key is not
# reported here; it only surfaces when the first API call is made.
openai.api_key = os.getenv("OPENAI_API_KEY")

# Model name passed to every API helper in the cells below.
# Alternative: model = "gpt-3.5-turbo"
model = "gpt-4.1-mini"

In [None]:
from functions import api_query

Validity check for API Call

In [None]:
# Send one trivial prompt to confirm that the key and model are usable; the
# cell output is whatever api_query returns.
# NOTE(review): the meaning of the positional arguments 0 and 1 is defined in
# functions.api_query — confirm there before changing them.
query = "Hi, how are you?"
api_query(model, query, 0, 1)

## Baseline Prompting

In [None]:
from functions import use_api_base

In [None]:
# Baseline run with n_articles=0 — presumably no retrieved passages are placed
# in the prompt; confirm in functions.use_api_base.
base_retrieval0 = use_api_base(model, dataset, iters=1, n_articles=0)

In [None]:
# Baseline run with n_articles=10 — presumably the first 10 retrieved passages
# per question are placed in the prompt; confirm in functions.use_api_base.
base_retrieval10 = use_api_base(model, dataset, iters=1, n_articles=10)

# SuRe

## Step #1. Candidate Generation

In [None]:
from functions import post_process_candidate, separation, use_api_candidate

In [None]:
# SuRe step 1: generate answer candidates for the dataset with n_articles=10.
sure_candidate = use_api_candidate(model, dataset, iters=1, n_articles=10)

## Divide Candidates

In [None]:
# Post-process the raw candidate generations; this result is the input to
# separation and to every later summarization, verification and ranking call.
sure_candidate_post = post_process_candidate(sure_candidate)

In [None]:
# Split the post-processed candidates into two collections — presumably the
# first and second candidate for each question; verify in functions.separation.
sure_candidate1, sure_candidate2 = separation(sure_candidate_post)

## Step #2. Conditional Summarization for Each Candidate

In [None]:
from functions import use_api_summary

In [None]:
# Conditional summarization for the candidate selected by pred=0 (presumably
# the first candidate — confirm in functions.use_api_summary).
summary_candidate1 = use_api_summary(model, dataset, sure_candidate_post, pred=0, n_articles=10)

In [None]:
# Conditional summarization for the candidate selected by pred=1 (presumably
# the second candidate — confirm in functions.use_api_summary).
summary_candidate2 = use_api_summary(model, dataset, sure_candidate_post, pred=1, n_articles=10)

## Step #3. Selection via Self-verification and Ranking 

In [None]:
from functions import use_api_verif, use_api_rank

In [None]:
# Self-verification of the summaries generated for candidate index 0.
correctness_summary1 = use_api_verif(model, dataset, sure_candidate_post, summary_candidate1, pred_idx=0)

In [None]:
# Self-verification of the summaries generated for candidate index 1.
correctness_summary2 = use_api_verif(model, dataset, sure_candidate_post, summary_candidate2, pred_idx=1)

## Pair-wise Ranking for ArgMax

In [None]:
# Pair-wise ranking that compares the two candidates' summaries against each other.
ranking_summary12 = use_api_rank(model, dataset, sure_candidate_post, summary_candidate1, summary_candidate2)

### Get Final Preds

In [None]:
from functions import get_final_pred_sure

In [None]:
# Combine candidates, their summaries, the verification results and the
# pair-wise ranking into final predictions. Only the first of the three
# returned values is kept.
sure_fin_preds, _, _ = get_final_pred_sure(
    sure_candidate1, sure_candidate2,
    summary_candidate1, summary_candidate2,
    correctness_summary1, correctness_summary2,
    ranking_summary12,
)

# Evaluation

In [None]:
from data_utils import get_em_f1

In [None]:
# EM and F1 of the baseline run with n_articles=0; report the mean of each.
em_base_ret0, f1_base_ret0 = get_em_f1(dataset, base_retrieval0)
print(f"EM: {em_base_ret0.mean()}, F1: {f1_base_ret0.mean()}")

In [None]:
# EM and F1 of the baseline run with n_articles=10; report the mean of each.
em_base_ret10, f1_base_ret10 = get_em_f1(dataset, base_retrieval10)
print(f"EM: {em_base_ret10.mean()}, F1: {f1_base_ret10.mean()}")

In [None]:
# EM and F1 of the final SuRe predictions; report the mean of each.
em_sure_10, f1_sure_10 = get_em_f1(dataset, sure_fin_preds)
print(f"EM: {em_sure_10.mean()}, F1: {f1_sure_10.mean()}")