# RAG Evaluation Retrieval

In [2]:
import minsearch
import json
import pandas as pd


from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer
from elasticsearch import Elasticsearch
from sklearn.model_selection import train_test_split

from hyperopt import fmin, tpe, hp, STATUS_OK, Trials
from hyperopt.pyll import scope

## Ingestion

In [3]:
with open('notebooks/documents-with-ids.json', 'r') as file:
    documents = json.load(file)

## Index and mapping Minsearch

In [4]:
index = minsearch.Index(
    text_fields=["Question", "Answer", "Category"],
    keyword_fields=["Section", "id"]
)

In [5]:
index.fit(documents)

def min_search(query, section):
    boost = {}

    results = index.search(
        query=query,
        filter_dict={},
        boost_dict=boost,
        num_results=5
    )

    return results

## RAG flow

In [6]:
def build_prompt(query, search_results):
    prompt_template = """
You are an expert in United Kingdom Benefit Claims and Medical Negligence Claims. Answer the QUESTION based on the CONTEXT from the FAQ database. 
Use only the facts from the CONTEXT when answering the QUESTION.

QUESTION: {question}

CONTEXT: 
{context}
""".strip()

    context = ""
    
    for doc in search_results:
        context = context + f"section: {doc['section']}\nquestion: {doc['question']}\nanswer: {doc['answer']}\n\n"
    
    prompt = prompt_template.format(question=query, context=context).strip()
    return prompt

In [7]:
def llm(prompt):
    response = client.chat.completions.create(
        model='gpt-4o',
        messages=[{"role": "user", "content": prompt}]
    )
    
    return response.choices[0].message.content

In [8]:
query = "Can I get sick pay if I'm self-isolating?"

def rag(query):
    search_results = search(query)
    prompt = build_prompt(query, search_results)
    answer = llm(prompt)
    return answer

# Evaluating Retrieval
- hit_rate
- Mean Reciprocal Rank

## Minsearch without tuning

In [9]:
def hit_rate(relevance_total):
    cnt = 0

    for line in relevance_total:
        if True in line:
            cnt = cnt + 1

    return cnt / len(relevance_total)

In [10]:
def mrr(relevance_total):
    total_score = 0.0

    for line in relevance_total:
        for rank in range(len(line)):
            if line[rank] == True:
                total_score = total_score + 1 / (rank + 1)

    return total_score / len(relevance_total)

In [11]:
df_ground_truth = pd.read_csv('notebooks/ground-truth-data.csv')
ground_truth = df_ground_truth.to_dict(orient='records')

In [12]:
def evaluate(ground_truth, search_function):
    relevance_total = []

    for q in tqdm(ground_truth):
        doc_id = q['document']
        results = search_function(q)
        relevance = [d['id'] == doc_id for d in results]
        relevance_total.append(relevance)

    return {
        'hit_rate': hit_rate(relevance_total),
        'mrr': mrr(relevance_total),
    }

evaluate(ground_truth, lambda q: min_search(q['question'], q['claims_type']))

100%|████████████████████████████████████████████████████████████████████████████████████████████████████| 2055/2055 [00:04<00:00, 427.99it/s]


{'hit_rate': 0.7965936739659367, 'mrr': 0.6443309002433094}

## Hyperparameter tuning

In [13]:
space = {
    'num_results': scope.int(hp.quniform('num_results', 1, 10, 1)),
    'boost_factor': hp.uniform('boost_factor', 0.1, 2.0),
    'section': hp.choice('section', ['nhs claim benefits', 'general claim benefits'])
}

In [14]:
def objective(params):
    num_results = params['num_results']
    boost_factor = params['boost_factor']
    section = params['section']
    
    def tuned_min_search(query, section):
        boost = {
            "Question": boost_factor,
            "Answer": boost_factor
        }

        filter_dict = {
            'section': section  # Pass the optimized section value
        }
        
        results = index.search(
            query=query,
            filter_dict=filter_dict,
            boost_dict=boost,
            num_results=num_results
        )
        return results

    metrics = evaluate(ground_truth, lambda q: tuned_min_search(q['question'], q['claims_type']))
    return {'loss': -metrics['hit_rate'], 'status': STATUS_OK}

In [15]:
trials = Trials()
best = fmin(
    fn=objective,
    space=space,
    algo=tpe.suggest,
    max_evals=50,
    trials=trials
)

  0%|                                                                                                  | 0/50 [00:00<?, ?trial/s, best loss=?]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|1         | 37/2055 [00:00<00:05, 362.43it/s]
[A
  4%|3         | 81/2055 [00:00<00:04, 407.01it/s]
[A
  6%|6         | 127/2055 [00:00<00:04, 431.07it/s]
[A
  8%|8         | 174/2055 [00:00<00:04, 445.74it/s]
[A
 11%|#         | 220/2055 [00:00<00:04, 450.30it/s]
[A
 13%|#2        | 266/2055 [00:00<00:04, 428.26it/s]
[A
 15%|#5        | 310/2055 [00:00<00:05, 343.88it/s]
[A
 17%|#6        | 347/2055 [00:00<00:04, 342.69it/s]
[A
 19%|#8        | 383/2055 [00:01<00:04, 345.91it/s]
[A
 21%|##        | 430/2055 [00:01<00:04, 378.28it/s]
[A
 23%|##2       | 469/2055 [00:01<00:04, 356.91it/s]
[A
 25%|##4       | 509/2055 [00:01<00:04, 367.25it/s]
[A
 27%|##7       | 555/2055 [00:01<00:03, 392.68it/s]
[A
 29%|##9       | 601/2055 [00:01<00:03, 410.52it/s]
[A
 31%|###1      | 646/2055 [00:01<00:03, 421.41it/s]
[A
 34%|###3      | 692/2055 [00:01<00:03, 431.40it/s]
[A
 36%|###5      | 739/2055 [00:01<00:02, 440.47it/s]
[A
 38%|

  2%|█▍                                                                     | 1/50 [00:04<03:57,  4.85s/trial, best loss: -0.6165450121654501]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 44/2055 [00:00<00:04, 435.88it/s]
[A
  4%|4         | 91/2055 [00:00<00:04, 450.92it/s]
[A
  7%|6         | 137/2055 [00:00<00:04, 454.70it/s]
[A
  9%|8         | 183/2055 [00:00<00:04, 453.10it/s]
[A
 11%|#1        | 230/2055 [00:00<00:04, 455.82it/s]
[A
 13%|#3        | 276/2055 [00:00<00:03, 454.79it/s]
[A
 16%|#5        | 322/2055 [00:00<00:03, 450.66it/s]
[A
 18%|#7        | 368/2055 [00:00<00:03, 452.33it/s]
[A
 20%|##        | 415/2055 [00:00<00:03, 454.80it/s]
[A
 22%|##2       | 461/2055 [00:01<00:03, 455.73it/s]
[A
 25%|##4       | 507/2055 [00:01<00:03, 447.25it/s]
[A
 27%|##6       | 554/2055 [00:01<00:03, 452.15it/s]
[A
 29%|##9       | 600/2055 [00:01<00:03, 454.26it/s]
[A
 31%|###1      | 646/2055 [00:01<00:03, 376.70it/s]
[A
 33%|###3      | 686/2055 [00:01<00:03, 381.01it/s]
[A
 35%|###5      | 726/2055 [00:01<00:03, 383.91it/s]
[A
 38%|###7      | 771/2055 [00:01<00:03, 401.16it/s]
[A
 40%|

  4%|██▉                                                                     | 2/50 [00:09<03:49,  4.79s/trial, best loss: -0.683698296836983]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 466.52it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 460.52it/s]
[A
  7%|6         | 141/2055 [00:00<00:04, 460.60it/s]
[A
  9%|9         | 188/2055 [00:00<00:04, 449.95it/s]
[A
 11%|#1        | 234/2055 [00:00<00:04, 437.12it/s]
[A
 14%|#3        | 280/2055 [00:00<00:04, 443.49it/s]
[A
 16%|#5        | 325/2055 [00:00<00:04, 413.19it/s]
[A
 18%|#8        | 371/2055 [00:00<00:03, 426.11it/s]
[A
 20%|##        | 417/2055 [00:00<00:03, 435.75it/s]
[A
 23%|##2       | 463/2055 [00:01<00:03, 440.61it/s]
[A
 25%|##4       | 510/2055 [00:01<00:03, 446.62it/s]
[A
 27%|##7       | 556/2055 [00:01<00:03, 450.10it/s]
[A
 29%|##9       | 602/2055 [00:01<00:03, 369.55it/s]
[A
 31%|###1      | 642/2055 [00:01<00:03, 361.84it/s]
[A
 33%|###3      | 688/2055 [00:01<00:03, 386.44it/s]
[A
 36%|###5      | 734/2055 [00:01<00:03, 404.80it/s]
[A
 38%|###7      | 776/2055 [00:01<00:03, 381.65it/s]
[A
 40%|

  6%|████▎                                                                   | 3/50 [00:14<03:47,  4.84s/trial, best loss: -0.683698296836983]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 467.98it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 458.35it/s]
[A
  7%|6         | 141/2055 [00:00<00:04, 459.95it/s]
[A
  9%|9         | 188/2055 [00:00<00:04, 456.33it/s]
[A
 11%|#1        | 234/2055 [00:00<00:04, 441.81it/s]
[A
 14%|#3        | 279/2055 [00:00<00:04, 438.01it/s]
[A
 16%|#5        | 326/2055 [00:00<00:03, 445.64it/s]
[A
 18%|#8        | 372/2055 [00:00<00:03, 449.06it/s]
[A
 20%|##        | 419/2055 [00:00<00:03, 453.19it/s]
[A
 23%|##2       | 465/2055 [00:01<00:03, 454.24it/s]
[A
 25%|##4       | 512/2055 [00:01<00:03, 456.09it/s]
[A
 27%|##7       | 558/2055 [00:01<00:03, 456.63it/s]
[A
 29%|##9       | 605/2055 [00:01<00:03, 458.01it/s]
[A
 32%|###1      | 651/2055 [00:01<00:03, 458.03it/s]
[A
 34%|###3      | 698/2055 [00:01<00:02, 458.82it/s]
[A
 36%|###6      | 744/2055 [00:01<00:02, 450.28it/s]
[A
 38%|###8      | 790/2055 [00:01<00:02, 452.70it/s]
[A
 41%|

  8%|█████▊                                                                  | 4/50 [00:19<03:41,  4.82s/trial, best loss: -0.683698296836983]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 463.56it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 458.55it/s]
[A
  7%|6         | 140/2055 [00:00<00:04, 458.63it/s]
[A
  9%|9         | 186/2055 [00:00<00:04, 458.50it/s]
[A
 11%|#1        | 233/2055 [00:00<00:03, 461.02it/s]
[A
 14%|#3        | 280/2055 [00:00<00:03, 460.10it/s]
[A
 16%|#5        | 327/2055 [00:00<00:03, 460.37it/s]
[A
 18%|#8        | 374/2055 [00:00<00:03, 457.95it/s]
[A
 20%|##        | 420/2055 [00:00<00:03, 450.15it/s]
[A
 23%|##2       | 466/2055 [00:01<00:03, 448.02it/s]
[A
 25%|##4       | 513/2055 [00:01<00:03, 451.89it/s]
[A
 27%|##7       | 559/2055 [00:01<00:03, 453.21it/s]
[A
 29%|##9       | 605/2055 [00:01<00:03, 454.19it/s]
[A
 32%|###1      | 651/2055 [00:01<00:03, 431.59it/s]
[A
 34%|###3      | 698/2055 [00:01<00:03, 440.76it/s]
[A
 36%|###6      | 745/2055 [00:01<00:02, 446.38it/s]
[A
 38%|###8      | 790/2055 [00:01<00:02, 447.41it/s]
[A
 41%|

 10%|███████                                                                | 5/50 [00:24<03:37,  4.83s/trial, best loss: -0.8038929440389294]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 44/2055 [00:00<00:04, 432.37it/s]
[A
  4%|4         | 90/2055 [00:00<00:04, 444.26it/s]
[A
  7%|6         | 137/2055 [00:00<00:04, 452.65it/s]
[A
  9%|8         | 183/2055 [00:00<00:04, 452.99it/s]
[A
 11%|#1        | 229/2055 [00:00<00:04, 454.62it/s]
[A
 13%|#3        | 275/2055 [00:00<00:04, 427.44it/s]
[A
 16%|#5        | 321/2055 [00:00<00:03, 435.51it/s]
[A
 18%|#7        | 367/2055 [00:00<00:03, 442.65it/s]
[A
 20%|##        | 414/2055 [00:00<00:03, 448.59it/s]
[A
 22%|##2       | 460/2055 [00:01<00:03, 450.16it/s]
[A
 25%|##4       | 506/2055 [00:01<00:03, 441.05it/s]
[A
 27%|##6       | 552/2055 [00:01<00:03, 445.04it/s]
[A
 29%|##9       | 599/2055 [00:01<00:03, 449.48it/s]
[A
 31%|###1      | 645/2055 [00:01<00:03, 450.26it/s]
[A
 34%|###3      | 692/2055 [00:01<00:03, 453.24it/s]
[A
 36%|###5      | 738/2055 [00:01<00:02, 453.87it/s]
[A
 38%|###8      | 784/2055 [00:01<00:02, 455.62it/s]
[A
 40%|

 12%|████████▌                                                              | 6/50 [00:28<03:31,  4.81s/trial, best loss: -0.8038929440389294]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 466.38it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 459.10it/s]
[A
  7%|6         | 141/2055 [00:00<00:04, 460.36it/s]
[A
  9%|9         | 188/2055 [00:00<00:04, 448.14it/s]
[A
 11%|#1        | 234/2055 [00:00<00:04, 449.29it/s]
[A
 14%|#3        | 280/2055 [00:00<00:03, 450.61it/s]
[A
 16%|#5        | 326/2055 [00:00<00:03, 452.22it/s]
[A
 18%|#8        | 373/2055 [00:00<00:03, 454.44it/s]
[A
 20%|##        | 420/2055 [00:00<00:03, 456.62it/s]
[A
 23%|##2       | 467/2055 [00:01<00:03, 457.81it/s]
[A
 25%|##4       | 513/2055 [00:01<00:03, 457.65it/s]
[A
 27%|##7       | 559/2055 [00:01<00:03, 456.22it/s]
[A
 29%|##9       | 605/2055 [00:01<00:03, 456.57it/s]
[A
 32%|###1      | 651/2055 [00:01<00:03, 444.04it/s]
[A
 34%|###3      | 697/2055 [00:01<00:03, 448.48it/s]
[A
 36%|###6      | 743/2055 [00:01<00:02, 449.90it/s]
[A
 38%|###8      | 789/2055 [00:01<00:02, 452.50it/s]
[A
 41%|

 14%|█████████▉                                                             | 7/50 [00:33<03:26,  4.80s/trial, best loss: -0.8038929440389294]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 43/2055 [00:00<00:04, 417.62it/s]
[A
  4%|4         | 85/2055 [00:00<00:05, 392.98it/s]
[A
  6%|6         | 127/2055 [00:00<00:04, 402.16it/s]
[A
  8%|8         | 171/2055 [00:00<00:04, 414.80it/s]
[A
 10%|#         | 213/2055 [00:00<00:04, 405.03it/s]
[A
 12%|#2        | 254/2055 [00:00<00:04, 370.77it/s]
[A
 15%|#4        | 300/2055 [00:00<00:04, 397.09it/s]
[A
 17%|#6        | 346/2055 [00:00<00:04, 415.39it/s]
[A
 19%|#9        | 392/2055 [00:00<00:03, 427.15it/s]
[A
 21%|##1       | 439/2055 [00:01<00:03, 437.67it/s]
[A
 24%|##3       | 485/2055 [00:01<00:03, 443.96it/s]
[A
 26%|##5       | 532/2055 [00:01<00:03, 449.38it/s]
[A
 28%|##8       | 578/2055 [00:01<00:03, 426.43it/s]
[A
 30%|###       | 624/2055 [00:01<00:03, 435.25it/s]
[A
 33%|###2      | 671/2055 [00:01<00:03, 442.67it/s]
[A
 35%|###4      | 716/2055 [00:01<00:03, 439.08it/s]
[A
 37%|###7      | 762/2055 [00:01<00:02, 443.25it/s]
[A
 39%|

 16%|███████████▎                                                           | 8/50 [00:38<03:20,  4.78s/trial, best loss: -0.8038929440389294]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 44/2055 [00:00<00:04, 431.38it/s]
[A
  4%|4         | 90/2055 [00:00<00:04, 443.10it/s]
[A
  7%|6         | 137/2055 [00:00<00:04, 450.84it/s]
[A
  9%|8         | 184/2055 [00:00<00:04, 453.90it/s]
[A
 11%|#1        | 230/2055 [00:00<00:04, 444.41it/s]
[A
 13%|#3        | 277/2055 [00:00<00:03, 450.34it/s]
[A
 16%|#5        | 323/2055 [00:00<00:04, 425.52it/s]
[A
 18%|#7        | 369/2055 [00:00<00:03, 435.02it/s]
[A
 20%|##        | 413/2055 [00:00<00:04, 397.79it/s]
[A
 22%|##2       | 460/2055 [00:01<00:03, 415.85it/s]
[A
 24%|##4       | 503/2055 [00:01<00:03, 407.24it/s]
[A
 27%|##6       | 549/2055 [00:01<00:03, 421.70it/s]
[A
 29%|##8       | 592/2055 [00:01<00:03, 397.44it/s]
[A
 31%|###1      | 639/2055 [00:01<00:03, 415.37it/s]
[A
 33%|###3      | 685/2055 [00:01<00:03, 427.89it/s]
[A
 35%|###5      | 729/2055 [00:01<00:03, 397.41it/s]
[A
 38%|###7      | 773/2055 [00:01<00:03, 408.66it/s]
[A
 40%|

 18%|████████████▊                                                          | 9/50 [00:43<03:17,  4.81s/trial, best loss: -0.8038929440389294]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 466.77it/s]
[A
  5%|4         | 94/2055 [00:00<00:05, 368.24it/s]
[A
  7%|6         | 136/2055 [00:00<00:04, 387.01it/s]
[A
  9%|8         | 182/2055 [00:00<00:04, 413.11it/s]
[A
 11%|#1        | 229/2055 [00:00<00:04, 429.54it/s]
[A
 13%|#3        | 275/2055 [00:00<00:04, 439.15it/s]
[A
 16%|#5        | 322/2055 [00:00<00:03, 447.54it/s]
[A
 18%|#7        | 369/2055 [00:00<00:03, 452.11it/s]
[A
 20%|##        | 415/2055 [00:00<00:03, 453.91it/s]
[A
 22%|##2       | 461/2055 [00:01<00:03, 452.37it/s]
[A
 25%|##4       | 507/2055 [00:01<00:03, 445.82it/s]
[A
 27%|##6       | 553/2055 [00:01<00:03, 449.48it/s]
[A
 29%|##9       | 599/2055 [00:01<00:03, 452.30it/s]
[A
 31%|###1      | 645/2055 [00:01<00:03, 453.28it/s]
[A
 34%|###3      | 691/2055 [00:01<00:03, 416.53it/s]
[A
 36%|###5      | 737/2055 [00:01<00:03, 427.51it/s]
[A
 38%|###8      | 781/2055 [00:01<00:03, 414.09it/s]
[A
 40%|

 20%|██████████████                                                        | 10/50 [00:48<03:11,  4.80s/trial, best loss: -0.8038929440389294]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 463.37it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 458.47it/s]
[A
  7%|6         | 140/2055 [00:00<00:04, 455.24it/s]
[A
  9%|9         | 186/2055 [00:00<00:04, 456.95it/s]
[A
 11%|#1        | 232/2055 [00:00<00:04, 420.81it/s]
[A
 13%|#3        | 275/2055 [00:00<00:04, 392.16it/s]
[A
 16%|#5        | 321/2055 [00:00<00:04, 412.14it/s]
[A
 18%|#7        | 368/2055 [00:00<00:03, 426.58it/s]
[A
 20%|##        | 412/2055 [00:00<00:03, 424.78it/s]
[A
 22%|##2       | 459/2055 [00:01<00:03, 436.16it/s]
[A
 25%|##4       | 505/2055 [00:01<00:03, 442.96it/s]
[A
 27%|##6       | 551/2055 [00:01<00:03, 447.49it/s]
[A
 29%|##9       | 596/2055 [00:01<00:03, 424.63it/s]
[A
 31%|###1      | 639/2055 [00:01<00:03, 424.40it/s]
[A
 33%|###3      | 684/2055 [00:01<00:03, 430.02it/s]
[A
 36%|###5      | 730/2055 [00:01<00:03, 438.55it/s]
[A
 38%|###7      | 776/2055 [00:01<00:02, 444.75it/s]
[A
 40%|

 22%|███████████████▍                                                      | 11/50 [00:52<03:07,  4.82s/trial, best loss: -0.8374695863746958]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 461.82it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 459.12it/s]
[A
  7%|6         | 140/2055 [00:00<00:04, 446.41it/s]
[A
  9%|9         | 186/2055 [00:00<00:04, 450.54it/s]
[A
 11%|#1        | 233/2055 [00:00<00:04, 454.53it/s]
[A
 14%|#3        | 279/2055 [00:00<00:04, 418.30it/s]
[A
 16%|#5        | 325/2055 [00:00<00:04, 429.77it/s]
[A
 18%|#7        | 369/2055 [00:00<00:04, 382.31it/s]
[A
 20%|##        | 415/2055 [00:00<00:04, 403.43it/s]
[A
 22%|##2       | 462/2055 [00:01<00:03, 420.50it/s]
[A
 25%|##4       | 508/2055 [00:01<00:03, 431.02it/s]
[A
 27%|##6       | 554/2055 [00:01<00:03, 439.42it/s]
[A
 29%|##9       | 600/2055 [00:01<00:03, 445.43it/s]
[A
 31%|###1      | 647/2055 [00:01<00:03, 450.34it/s]
[A
 34%|###3      | 694/2055 [00:01<00:03, 453.54it/s]
[A
 36%|###6      | 740/2055 [00:01<00:02, 443.22it/s]
[A
 38%|###8      | 786/2055 [00:01<00:02, 447.58it/s]
[A
 41%|

 24%|████████████████▊                                                     | 12/50 [00:57<03:02,  4.81s/trial, best loss: -0.8374695863746958]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 468.43it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 460.87it/s]
[A
  7%|6         | 141/2055 [00:00<00:04, 462.69it/s]
[A
  9%|9         | 188/2055 [00:00<00:04, 461.82it/s]
[A
 11%|#1        | 235/2055 [00:00<00:03, 461.19it/s]
[A
 14%|#3        | 282/2055 [00:00<00:03, 460.44it/s]
[A
 16%|#6        | 329/2055 [00:00<00:03, 454.76it/s]
[A
 18%|#8        | 376/2055 [00:00<00:03, 456.57it/s]
[A
 21%|##        | 422/2055 [00:00<00:03, 436.62it/s]
[A
 23%|##2       | 467/2055 [00:01<00:03, 439.19it/s]
[A
 25%|##4       | 512/2055 [00:01<00:03, 389.19it/s]
[A
 27%|##7       | 555/2055 [00:01<00:03, 398.07it/s]
[A
 29%|##9       | 601/2055 [00:01<00:03, 414.27it/s]
[A
 31%|###1      | 647/2055 [00:01<00:03, 427.06it/s]
[A
 34%|###3      | 693/2055 [00:01<00:03, 435.31it/s]
[A
 36%|###5      | 739/2055 [00:01<00:02, 441.82it/s]
[A
 38%|###8      | 785/2055 [00:01<00:02, 447.00it/s]
[A
 40%|

 26%|██████████████████▏                                                   | 13/50 [01:02<02:58,  4.83s/trial, best loss: -0.8374695863746958]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|1         | 32/2055 [00:00<00:06, 312.45it/s]
[A
  4%|3         | 77/2055 [00:00<00:05, 390.77it/s]
[A
  6%|6         | 124/2055 [00:00<00:04, 423.41it/s]
[A
  8%|8         | 170/2055 [00:00<00:04, 437.12it/s]
[A
 11%|#         | 217/2055 [00:00<00:04, 446.86it/s]
[A
 13%|#2        | 263/2055 [00:00<00:03, 450.02it/s]
[A
 15%|#5        | 310/2055 [00:00<00:03, 453.51it/s]
[A
 17%|#7        | 356/2055 [00:00<00:03, 454.15it/s]
[A
 20%|#9        | 403/2055 [00:00<00:03, 456.60it/s]
[A
 22%|##1       | 449/2055 [00:01<00:03, 456.68it/s]
[A
 24%|##4       | 495/2055 [00:01<00:03, 437.72it/s]
[A
 26%|##6       | 539/2055 [00:01<00:03, 418.55it/s]
[A
 28%|##8       | 582/2055 [00:01<00:03, 401.01it/s]
[A
 30%|###       | 623/2055 [00:01<00:03, 361.06it/s]
[A
 33%|###2      | 670/2055 [00:01<00:03, 387.69it/s]
[A
 35%|###4      | 715/2055 [00:01<00:03, 403.46it/s]
[A
 37%|###7      | 761/2055 [00:01<00:03, 418.84it/s]
[A
 39%|

 28%|███████████████████▌                                                  | 14/50 [01:07<02:53,  4.81s/trial, best loss: -0.8686131386861314]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 468.56it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 461.91it/s]
[A
  7%|6         | 141/2055 [00:00<00:04, 414.79it/s]
[A
  9%|8         | 183/2055 [00:00<00:04, 404.29it/s]
[A
 11%|#         | 224/2055 [00:00<00:04, 400.33it/s]
[A
 13%|#2        | 267/2055 [00:00<00:04, 409.56it/s]
[A
 15%|#5        | 314/2055 [00:00<00:04, 426.10it/s]
[A
 18%|#7        | 360/2055 [00:00<00:03, 435.26it/s]
[A
 20%|#9        | 407/2055 [00:00<00:03, 443.61it/s]
[A
 22%|##2       | 453/2055 [00:01<00:03, 448.32it/s]
[A
 24%|##4       | 500/2055 [00:01<00:03, 452.41it/s]
[A
 27%|##6       | 546/2055 [00:01<00:03, 454.35it/s]
[A
 29%|##8       | 592/2055 [00:01<00:03, 453.51it/s]
[A
 31%|###1      | 638/2055 [00:01<00:03, 444.90it/s]
[A
 33%|###3      | 684/2055 [00:01<00:03, 448.48it/s]
[A
 36%|###5      | 730/2055 [00:01<00:02, 450.72it/s]
[A
 38%|###7      | 776/2055 [00:01<00:03, 396.40it/s]
[A
 40%|

 30%|█████████████████████                                                 | 15/50 [01:12<02:47,  4.78s/trial, best loss: -0.8686131386861314]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 464.62it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 433.21it/s]
[A
  7%|6         | 138/2055 [00:00<00:05, 335.81it/s]
[A
  8%|8         | 174/2055 [00:00<00:06, 303.90it/s]
[A
 11%|#         | 219/2055 [00:00<00:05, 346.26it/s]
[A
 12%|#2        | 256/2055 [00:00<00:05, 325.20it/s]
[A
 14%|#4        | 290/2055 [00:00<00:05, 320.66it/s]
[A
 16%|#6        | 336/2055 [00:00<00:04, 359.17it/s]
[A
 19%|#8        | 382/2055 [00:01<00:04, 387.71it/s]
[A
 21%|##        | 427/2055 [00:01<00:04, 405.57it/s]
[A
 23%|##2       | 472/2055 [00:01<00:03, 418.45it/s]
[A
 25%|##5       | 518/2055 [00:01<00:03, 429.96it/s]
[A
 27%|##7       | 565/2055 [00:01<00:03, 439.24it/s]
[A
 30%|##9       | 611/2055 [00:01<00:03, 444.48it/s]
[A
 32%|###1      | 657/2055 [00:01<00:03, 448.74it/s]
[A
 34%|###4      | 703/2055 [00:01<00:03, 449.35it/s]
[A
 36%|###6      | 749/2055 [00:01<00:02, 447.07it/s]
[A
 39%|

 32%|██████████████████████▍                                               | 16/50 [01:16<02:43,  4.80s/trial, best loss: -0.8686131386861314]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 462.95it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 457.70it/s]
[A
  7%|6         | 141/2055 [00:00<00:04, 461.33it/s]
[A
  9%|9         | 188/2055 [00:00<00:04, 455.99it/s]
[A
 11%|#1        | 234/2055 [00:00<00:04, 454.58it/s]
[A
 14%|#3        | 280/2055 [00:00<00:03, 454.84it/s]
[A
 16%|#5        | 326/2055 [00:00<00:03, 453.54it/s]
[A
 18%|#8        | 372/2055 [00:00<00:03, 455.31it/s]
[A
 20%|##        | 418/2055 [00:00<00:03, 417.91it/s]
[A
 22%|##2       | 461/2055 [00:01<00:04, 382.92it/s]
[A
 24%|##4       | 501/2055 [00:01<00:04, 379.88it/s]
[A
 26%|##6       | 540/2055 [00:01<00:04, 346.96it/s]
[A
 28%|##8       | 580/2055 [00:01<00:04, 360.01it/s]
[A
 31%|###       | 627/2055 [00:01<00:03, 387.54it/s]
[A
 33%|###2      | 673/2055 [00:01<00:03, 406.95it/s]
[A
 35%|###5      | 720/2055 [00:01<00:03, 422.18it/s]
[A
 37%|###7      | 766/2055 [00:01<00:02, 432.61it/s]
[A
 40%|

 34%|███████████████████████▊                                              | 17/50 [01:21<02:38,  4.81s/trial, best loss: -0.8875912408759125]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 465.81it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 444.99it/s]
[A
  7%|6         | 141/2055 [00:00<00:04, 452.09it/s]
[A
  9%|9         | 188/2055 [00:00<00:04, 455.60it/s]
[A
 11%|#1        | 234/2055 [00:00<00:04, 422.35it/s]
[A
 14%|#3        | 280/2055 [00:00<00:04, 433.05it/s]
[A
 16%|#5        | 327/2055 [00:00<00:03, 441.53it/s]
[A
 18%|#8        | 372/2055 [00:00<00:03, 427.48it/s]
[A
 20%|##        | 415/2055 [00:01<00:04, 364.29it/s]
[A
 22%|##2       | 453/2055 [00:01<00:04, 368.37it/s]
[A
 24%|##4       | 497/2055 [00:01<00:04, 387.83it/s]
[A
 26%|##6       | 537/2055 [00:01<00:03, 383.85it/s]
[A
 28%|##8       | 577/2055 [00:01<00:04, 351.21it/s]
[A
 30%|##9       | 614/2055 [00:01<00:04, 347.02it/s]
[A
 32%|###1      | 657/2055 [00:01<00:03, 367.69it/s]
[A
 34%|###3      | 696/2055 [00:01<00:03, 372.70it/s]
[A
 36%|###6      | 741/2055 [00:01<00:03, 393.38it/s]
[A
 38%|

 36%|█████████████████████████▏                                            | 18/50 [01:26<02:35,  4.87s/trial, best loss: -0.8875912408759125]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 44/2055 [00:00<00:04, 439.58it/s]
[A
  4%|4         | 89/2055 [00:00<00:04, 441.76it/s]
[A
  7%|6         | 134/2055 [00:00<00:06, 302.27it/s]
[A
  8%|8         | 169/2055 [00:00<00:07, 248.22it/s]
[A
 10%|9         | 197/2055 [00:00<00:09, 204.87it/s]
[A
 12%|#1        | 238/2055 [00:00<00:07, 250.05it/s]
[A
 14%|#3        | 285/2055 [00:00<00:05, 302.14it/s]
[A
 16%|#6        | 330/2055 [00:01<00:05, 339.05it/s]
[A
 18%|#8        | 376/2055 [00:01<00:04, 370.69it/s]
[A
 20%|##        | 419/2055 [00:01<00:04, 386.21it/s]
[A
 23%|##2       | 466/2055 [00:01<00:03, 407.70it/s]
[A
 25%|##4       | 509/2055 [00:01<00:03, 400.83it/s]
[A
 27%|##6       | 551/2055 [00:01<00:03, 397.24it/s]
[A
 29%|##8       | 592/2055 [00:01<00:03, 380.51it/s]
[A
 31%|###1      | 638/2055 [00:01<00:03, 401.97it/s]
[A
 33%|###3      | 685/2055 [00:01<00:03, 418.68it/s]
[A
 35%|###5      | 728/2055 [00:02<00:03, 393.93it/s]
[A
 38%|

 38%|██████████████████████████▌                                           | 19/50 [01:31<02:33,  4.96s/trial, best loss: -0.8875912408759125]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 46/2055 [00:00<00:04, 457.78it/s]
[A
  4%|4         | 92/2055 [00:00<00:04, 436.70it/s]
[A
  7%|6         | 138/2055 [00:00<00:04, 444.49it/s]
[A
  9%|8         | 183/2055 [00:00<00:04, 424.71it/s]
[A
 11%|#1        | 229/2055 [00:00<00:04, 434.95it/s]
[A
 13%|#3        | 273/2055 [00:00<00:04, 413.67it/s]
[A
 16%|#5        | 319/2055 [00:00<00:04, 426.75it/s]
[A
 18%|#7        | 366/2055 [00:00<00:03, 436.96it/s]
[A
 20%|#9        | 410/2055 [00:00<00:04, 381.05it/s]
[A
 22%|##1       | 450/2055 [00:01<00:04, 356.64it/s]
[A
 24%|##3       | 493/2055 [00:01<00:04, 375.76it/s]
[A
 26%|##6       | 539/2055 [00:01<00:03, 398.23it/s]
[A
 28%|##8       | 580/2055 [00:01<00:04, 368.41it/s]
[A
 30%|###       | 626/2055 [00:01<00:03, 391.28it/s]
[A
 33%|###2      | 672/2055 [00:01<00:03, 410.00it/s]
[A
 35%|###4      | 718/2055 [00:01<00:03, 423.59it/s]
[A
 37%|###7      | 764/2055 [00:01<00:02, 431.55it/s]
[A
 39%|

 40%|████████████████████████████                                          | 20/50 [01:36<02:28,  4.95s/trial, best loss: -0.8875912408759125]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 467.25it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 456.92it/s]
[A
  7%|6         | 140/2055 [00:00<00:04, 445.50it/s]
[A
  9%|9         | 186/2055 [00:00<00:04, 448.71it/s]
[A
 11%|#1        | 232/2055 [00:00<00:04, 449.26it/s]
[A
 14%|#3        | 278/2055 [00:00<00:03, 451.40it/s]
[A
 16%|#5        | 324/2055 [00:00<00:03, 453.48it/s]
[A
 18%|#8        | 370/2055 [00:00<00:03, 451.67it/s]
[A
 20%|##        | 417/2055 [00:00<00:03, 454.57it/s]
[A
 23%|##2       | 463/2055 [00:01<00:03, 455.71it/s]
[A
 25%|##4       | 509/2055 [00:01<00:03, 454.19it/s]
[A
 27%|##7       | 555/2055 [00:01<00:03, 454.52it/s]
[A
 29%|##9       | 601/2055 [00:01<00:03, 447.49it/s]
[A
 31%|###1      | 647/2055 [00:01<00:03, 450.25it/s]
[A
 34%|###3      | 693/2055 [00:01<00:03, 451.68it/s]
[A
 36%|###5      | 739/2055 [00:01<00:03, 398.95it/s]
[A
 38%|###8      | 785/2055 [00:01<00:03, 414.73it/s]
[A
 40%|

 42%|█████████████████████████████▍                                        | 21/50 [01:41<02:22,  4.92s/trial, best loss: -0.8875912408759125]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 466.87it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 458.31it/s]
[A
  7%|6         | 140/2055 [00:00<00:04, 455.65it/s]
[A
  9%|9         | 186/2055 [00:00<00:04, 455.39it/s]
[A
 11%|#1        | 233/2055 [00:00<00:03, 457.32it/s]
[A
 14%|#3        | 279/2055 [00:00<00:04, 443.66it/s]
[A
 16%|#5        | 326/2055 [00:00<00:03, 449.36it/s]
[A
 18%|#8        | 373/2055 [00:00<00:03, 452.62it/s]
[A
 20%|##        | 419/2055 [00:00<00:03, 454.54it/s]
[A
 23%|##2       | 465/2055 [00:01<00:03, 451.27it/s]
[A
 25%|##4       | 511/2055 [00:01<00:03, 427.61it/s]
[A
 27%|##7       | 557/2055 [00:01<00:03, 436.49it/s]
[A
 29%|##9       | 601/2055 [00:01<00:03, 430.68it/s]
[A
 31%|###1      | 645/2055 [00:01<00:03, 371.49it/s]
[A
 33%|###3      | 684/2055 [00:01<00:03, 352.88it/s]
[A
 35%|###5      | 729/2055 [00:01<00:03, 378.05it/s]
[A
 38%|###7      | 772/2055 [00:01<00:03, 389.70it/s]
[A
 40%|

 44%|██████████████████████████████▊                                       | 22/50 [01:46<02:18,  4.94s/trial, best loss: -0.8875912408759125]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 466.16it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 410.06it/s]
[A
  7%|6         | 140/2055 [00:00<00:04, 430.86it/s]
[A
  9%|9         | 186/2055 [00:00<00:04, 439.77it/s]
[A
 11%|#1        | 232/2055 [00:00<00:04, 446.80it/s]
[A
 13%|#3        | 277/2055 [00:00<00:03, 446.82it/s]
[A
 16%|#5        | 322/2055 [00:00<00:03, 440.75it/s]
[A
 18%|#7        | 368/2055 [00:00<00:03, 445.48it/s]
[A
 20%|##        | 413/2055 [00:00<00:03, 437.79it/s]
[A
 22%|##2       | 459/2055 [00:01<00:03, 443.33it/s]
[A
 25%|##4       | 505/2055 [00:01<00:03, 448.26it/s]
[A
 27%|##6       | 552/2055 [00:01<00:03, 451.47it/s]
[A
 29%|##9       | 598/2055 [00:01<00:03, 452.72it/s]
[A
 31%|###1      | 644/2055 [00:01<00:03, 451.07it/s]
[A
 34%|###3      | 690/2055 [00:01<00:03, 453.22it/s]
[A
 36%|###5      | 736/2055 [00:01<00:02, 454.73it/s]
[A
 38%|###8      | 782/2055 [00:01<00:02, 448.56it/s]
[A
 40%|

 46%|████████████████████████████████▏                                     | 23/50 [01:51<02:12,  4.90s/trial, best loss: -0.8875912408759125]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|1         | 40/2055 [00:00<00:05, 397.30it/s]
[A
  4%|4         | 86/2055 [00:00<00:04, 428.88it/s]
[A
  6%|6         | 133/2055 [00:00<00:04, 443.60it/s]
[A
  9%|8         | 179/2055 [00:00<00:04, 448.62it/s]
[A
 11%|#         | 226/2055 [00:00<00:04, 453.55it/s]
[A
 13%|#3        | 272/2055 [00:00<00:03, 453.58it/s]
[A
 15%|#5        | 318/2055 [00:00<00:03, 452.64it/s]
[A
 18%|#7        | 364/2055 [00:00<00:03, 450.98it/s]
[A
 20%|#9        | 410/2055 [00:00<00:03, 444.64it/s]
[A
 22%|##2       | 456/2055 [00:01<00:03, 447.27it/s]
[A
 24%|##4       | 501/2055 [00:01<00:03, 434.79it/s]
[A
 27%|##6       | 547/2055 [00:01<00:03, 439.86it/s]
[A
 29%|##8       | 594/2055 [00:01<00:03, 446.11it/s]
[A
 31%|###1      | 641/2055 [00:01<00:03, 450.70it/s]
[A
 33%|###3      | 687/2055 [00:01<00:03, 439.33it/s]
[A
 36%|###5      | 732/2055 [00:01<00:03, 431.34it/s]
[A
 38%|###7      | 778/2055 [00:01<00:02, 438.58it/s]
[A
 40%|

 48%|█████████████████████████████████▌                                    | 24/50 [01:56<02:06,  4.86s/trial, best loss: -0.8939172749391727]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|1         | 33/2055 [00:00<00:06, 324.53it/s]
[A
  4%|3         | 76/2055 [00:00<00:05, 383.30it/s]
[A
  6%|5         | 115/2055 [00:00<00:05, 368.19it/s]
[A
  8%|7         | 161/2055 [00:00<00:04, 403.08it/s]
[A
 10%|9         | 202/2055 [00:00<00:04, 378.81it/s]
[A
 12%|#2        | 248/2055 [00:00<00:04, 403.97it/s]
[A
 14%|#4        | 295/2055 [00:00<00:04, 422.37it/s]
[A
 17%|#6        | 341/2055 [00:00<00:03, 432.90it/s]
[A
 19%|#8        | 385/2055 [00:00<00:04, 401.12it/s]
[A
 21%|##        | 430/2055 [00:01<00:03, 413.04it/s]
[A
 23%|##3       | 476/2055 [00:01<00:03, 425.81it/s]
[A
 25%|##5       | 519/2055 [00:01<00:03, 426.04it/s]
[A
 27%|##7       | 565/2055 [00:01<00:03, 435.47it/s]
[A
 30%|##9       | 611/2055 [00:01<00:03, 442.27it/s]
[A
 32%|###1      | 657/2055 [00:01<00:03, 444.62it/s]
[A
 34%|###4      | 703/2055 [00:01<00:03, 448.47it/s]
[A
 36%|###6      | 749/2055 [00:01<00:02, 451.42it/s]
[A
 39%|

 50%|███████████████████████████████████                                   | 25/50 [02:01<02:00,  4.84s/trial, best loss: -0.8968369829683698]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 469.65it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 441.48it/s]
[A
  7%|6         | 139/2055 [00:00<00:04, 425.70it/s]
[A
  9%|9         | 185/2055 [00:00<00:04, 437.22it/s]
[A
 11%|#1        | 229/2055 [00:00<00:04, 432.17it/s]
[A
 13%|#3        | 275/2055 [00:00<00:04, 440.69it/s]
[A
 16%|#5        | 321/2055 [00:00<00:03, 442.22it/s]
[A
 18%|#7        | 366/2055 [00:00<00:04, 415.01it/s]
[A
 20%|##        | 411/2055 [00:00<00:03, 422.86it/s]
[A
 22%|##2       | 454/2055 [00:01<00:03, 412.30it/s]
[A
 24%|##4       | 496/2055 [00:01<00:03, 396.61it/s]
[A
 26%|##6       | 536/2055 [00:01<00:03, 385.90it/s]
[A
 28%|##8       | 582/2055 [00:01<00:03, 405.95it/s]
[A
 31%|###       | 628/2055 [00:01<00:03, 421.33it/s]
[A
 33%|###2      | 671/2055 [00:01<00:03, 423.39it/s]
[A
 35%|###4      | 717/2055 [00:01<00:03, 432.70it/s]
[A
 37%|###7      | 764/2055 [00:01<00:02, 441.28it/s]
[A
 39%|

 52%|████████████████████████████████████▍                                 | 26/50 [02:05<01:56,  4.85s/trial, best loss: -0.8968369829683698]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 468.09it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 458.12it/s]
[A
  7%|6         | 140/2055 [00:00<00:04, 457.64it/s]
[A
  9%|9         | 186/2055 [00:00<00:04, 429.11it/s]
[A
 11%|#1        | 230/2055 [00:00<00:04, 385.34it/s]
[A
 13%|#3        | 276/2055 [00:00<00:04, 407.19it/s]
[A
 15%|#5        | 318/2055 [00:00<00:04, 403.35it/s]
[A
 18%|#7        | 364/2055 [00:00<00:04, 420.09it/s]
[A
 20%|#9        | 410/2055 [00:00<00:03, 431.51it/s]
[A
 22%|##2       | 457/2055 [00:01<00:03, 440.41it/s]
[A
 24%|##4       | 503/2055 [00:01<00:03, 445.73it/s]
[A
 27%|##6       | 550/2055 [00:01<00:03, 450.68it/s]
[A
 29%|##9       | 596/2055 [00:01<00:03, 412.67it/s]
[A
 31%|###1      | 638/2055 [00:01<00:03, 399.64it/s]
[A
 33%|###3      | 679/2055 [00:01<00:03, 398.63it/s]
[A
 35%|###5      | 722/2055 [00:01<00:03, 406.80it/s]
[A
 37%|###7      | 763/2055 [00:01<00:03, 369.53it/s]
[A
 39%|

 54%|█████████████████████████████████████▊                                | 27/50 [02:10<01:50,  4.82s/trial, best loss: -0.8968369829683698]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  1%|1         | 28/2055 [00:00<00:07, 276.49it/s]
[A
  3%|3         | 67/2055 [00:00<00:05, 339.58it/s]
[A
  6%|5         | 114/2055 [00:00<00:04, 396.04it/s]
[A
  8%|7         | 161/2055 [00:00<00:04, 421.71it/s]
[A
 10%|#         | 207/2055 [00:00<00:04, 435.24it/s]
[A
 12%|#2        | 253/2055 [00:00<00:04, 442.95it/s]
[A
 15%|#4        | 300/2055 [00:00<00:03, 448.70it/s]
[A
 17%|#6        | 345/2055 [00:00<00:03, 438.43it/s]
[A
 19%|#8        | 389/2055 [00:00<00:03, 428.01it/s]
[A
 21%|##1       | 432/2055 [00:01<00:03, 427.83it/s]
[A
 23%|##3       | 479/2055 [00:01<00:03, 437.79it/s]
[A
 26%|##5       | 525/2055 [00:01<00:03, 441.94it/s]
[A
 28%|##7       | 571/2055 [00:01<00:03, 446.12it/s]
[A
 30%|###       | 617/2055 [00:01<00:03, 449.42it/s]
[A
 32%|###2      | 664/2055 [00:01<00:03, 452.74it/s]
[A
 35%|###4      | 710/2055 [00:01<00:02, 449.01it/s]
[A
 37%|###6      | 757/2055 [00:01<00:02, 452.77it/s]
[A
 39%|

 56%|███████████████████████████████████████▏                              | 28/50 [02:15<01:46,  4.83s/trial, best loss: -0.8968369829683698]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|1         | 36/2055 [00:00<00:05, 354.87it/s]
[A
  4%|3         | 79/2055 [00:00<00:04, 395.97it/s]
[A
  6%|5         | 119/2055 [00:00<00:04, 389.61it/s]
[A
  8%|7         | 158/2055 [00:00<00:05, 356.11it/s]
[A
 10%|9         | 204/2055 [00:00<00:04, 391.01it/s]
[A
 12%|#2        | 250/2055 [00:00<00:04, 413.20it/s]
[A
 14%|#4        | 296/2055 [00:00<00:04, 426.59it/s]
[A
 17%|#6        | 342/2055 [00:00<00:03, 434.03it/s]
[A
 19%|#8        | 388/2055 [00:00<00:03, 438.90it/s]
[A
 21%|##1       | 434/2055 [00:01<00:03, 444.22it/s]
[A
 23%|##3       | 479/2055 [00:01<00:03, 435.30it/s]
[A
 25%|##5       | 523/2055 [00:01<00:03, 433.76it/s]
[A
 28%|##7       | 567/2055 [00:01<00:03, 435.15it/s]
[A
 30%|##9       | 613/2055 [00:01<00:03, 442.44it/s]
[A
 32%|###2      | 659/2055 [00:01<00:03, 447.13it/s]
[A
 34%|###4      | 705/2055 [00:01<00:02, 450.95it/s]
[A
 37%|###6      | 751/2055 [00:01<00:02, 450.30it/s]
[A
 39%|

 58%|████████████████████████████████████████▌                             | 29/50 [02:20<01:41,  4.83s/trial, best loss: -0.8968369829683698]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 46/2055 [00:00<00:04, 459.87it/s]
[A
  4%|4         | 92/2055 [00:00<00:04, 456.69it/s]
[A
  7%|6         | 139/2055 [00:00<00:04, 460.91it/s]
[A
  9%|9         | 186/2055 [00:00<00:04, 461.97it/s]
[A
 11%|#1        | 233/2055 [00:00<00:04, 443.50it/s]
[A
 14%|#3        | 278/2055 [00:00<00:04, 413.67it/s]
[A
 16%|#5        | 320/2055 [00:00<00:04, 389.24it/s]
[A
 18%|#7        | 366/2055 [00:00<00:04, 408.29it/s]
[A
 20%|##        | 412/2055 [00:00<00:03, 421.14it/s]
[A
 22%|##2       | 458/2055 [00:01<00:03, 431.24it/s]
[A
 25%|##4       | 504/2055 [00:01<00:03, 436.75it/s]
[A
 27%|##6       | 550/2055 [00:01<00:03, 442.38it/s]
[A
 29%|##8       | 595/2055 [00:01<00:03, 441.42it/s]
[A
 31%|###1      | 640/2055 [00:01<00:03, 417.51it/s]
[A
 33%|###3      | 686/2055 [00:01<00:03, 428.89it/s]
[A
 36%|###5      | 733/2055 [00:01<00:03, 437.89it/s]
[A
 38%|###7      | 780/2055 [00:01<00:02, 444.88it/s]
[A
 40%|

 60%|██████████████████████████████████████████                            | 30/50 [02:25<01:36,  4.84s/trial, best loss: -0.8968369829683698]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 461.81it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 459.03it/s]
[A
  7%|6         | 140/2055 [00:00<00:04, 458.72it/s]
[A
  9%|9         | 187/2055 [00:00<00:04, 459.75it/s]
[A
 11%|#1        | 234/2055 [00:00<00:03, 460.44it/s]
[A
 14%|#3        | 281/2055 [00:00<00:03, 459.52it/s]
[A
 16%|#5        | 327/2055 [00:00<00:04, 421.29it/s]
[A
 18%|#8        | 373/2055 [00:00<00:03, 431.09it/s]
[A
 20%|##        | 417/2055 [00:00<00:04, 375.30it/s]
[A
 22%|##2       | 460/2055 [00:01<00:04, 388.84it/s]
[A
 25%|##4       | 506/2055 [00:01<00:03, 407.10it/s]
[A
 27%|##6       | 552/2055 [00:01<00:03, 420.41it/s]
[A
 29%|##9       | 599/2055 [00:01<00:03, 432.79it/s]
[A
 31%|###1      | 643/2055 [00:01<00:03, 426.87it/s]
[A
 34%|###3      | 689/2055 [00:01<00:03, 434.89it/s]
[A
 36%|###5      | 733/2055 [00:01<00:03, 431.29it/s]
[A
 38%|###7      | 779/2055 [00:01<00:02, 439.26it/s]
[A
 40%|

 62%|███████████████████████████████████████████▍                          | 31/50 [02:30<01:31,  4.82s/trial, best loss: -0.8968369829683698]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 44/2055 [00:00<00:04, 435.75it/s]
[A
  4%|4         | 89/2055 [00:00<00:04, 442.91it/s]
[A
  7%|6         | 135/2055 [00:00<00:04, 449.82it/s]
[A
  9%|8         | 181/2055 [00:00<00:04, 450.14it/s]
[A
 11%|#1        | 227/2055 [00:00<00:04, 450.99it/s]
[A
 13%|#3        | 273/2055 [00:00<00:03, 451.90it/s]
[A
 16%|#5        | 319/2055 [00:00<00:03, 454.50it/s]
[A
 18%|#7        | 365/2055 [00:00<00:03, 452.84it/s]
[A
 20%|##        | 411/2055 [00:00<00:03, 450.17it/s]
[A
 22%|##2       | 457/2055 [00:01<00:03, 444.73it/s]
[A
 24%|##4       | 503/2055 [00:01<00:03, 447.85it/s]
[A
 27%|##6       | 549/2055 [00:01<00:03, 450.59it/s]
[A
 29%|##8       | 595/2055 [00:01<00:03, 383.53it/s]
[A
 31%|###       | 637/2055 [00:01<00:03, 392.24it/s]
[A
 33%|###3      | 683/2055 [00:01<00:03, 409.71it/s]
[A
 35%|###5      | 729/2055 [00:01<00:03, 423.31it/s]
[A
 38%|###7      | 776/2055 [00:01<00:02, 433.92it/s]
[A
 40%|

 64%|████████████████████████████████████████████▊                         | 32/50 [02:34<01:26,  4.82s/trial, best loss: -0.8968369829683698]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|1         | 41/2055 [00:00<00:04, 406.37it/s]
[A
  4%|3         | 82/2055 [00:00<00:05, 386.29it/s]
[A
  6%|6         | 126/2055 [00:00<00:04, 407.54it/s]
[A
  8%|8         | 167/2055 [00:00<00:04, 397.53it/s]
[A
 10%|#         | 207/2055 [00:00<00:04, 379.73it/s]
[A
 12%|#2        | 252/2055 [00:00<00:04, 400.58it/s]
[A
 15%|#4        | 298/2055 [00:00<00:04, 418.66it/s]
[A
 17%|#6        | 344/2055 [00:00<00:03, 430.51it/s]
[A
 19%|#9        | 391/2055 [00:00<00:03, 440.11it/s]
[A
 21%|##1       | 437/2055 [00:01<00:03, 444.05it/s]
[A
 24%|##3       | 483/2055 [00:01<00:03, 447.88it/s]
[A
 26%|##5       | 528/2055 [00:01<00:03, 440.74it/s]
[A
 28%|##7       | 573/2055 [00:01<00:03, 420.65it/s]
[A
 30%|###       | 619/2055 [00:01<00:03, 430.92it/s]
[A
 32%|###2      | 665/2055 [00:01<00:03, 438.82it/s]
[A
 35%|###4      | 710/2055 [00:01<00:03, 385.47it/s]
[A
 37%|###6      | 756/2055 [00:01<00:03, 403.16it/s]
[A
 39%|

 66%|██████████████████████████████████████████████▏                       | 33/50 [02:39<01:21,  4.79s/trial, best loss: -0.8968369829683698]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 468.06it/s]
[A
  5%|4         | 94/2055 [00:00<00:05, 387.18it/s]
[A
  7%|6         | 134/2055 [00:00<00:05, 362.50it/s]
[A
  9%|8         | 180/2055 [00:00<00:04, 395.61it/s]
[A
 11%|#         | 221/2055 [00:00<00:05, 362.84it/s]
[A
 13%|#2        | 264/2055 [00:00<00:04, 381.26it/s]
[A
 15%|#4        | 303/2055 [00:00<00:04, 352.19it/s]
[A
 17%|#6        | 346/2055 [00:00<00:04, 373.38it/s]
[A
 19%|#8        | 386/2055 [00:01<00:04, 377.97it/s]
[A
 21%|##        | 429/2055 [00:01<00:04, 392.80it/s]
[A
 23%|##2       | 469/2055 [00:01<00:04, 380.85it/s]
[A
 25%|##4       | 513/2055 [00:01<00:03, 397.40it/s]
[A
 27%|##7       | 558/2055 [00:01<00:03, 410.95it/s]
[A
 29%|##9       | 602/2055 [00:01<00:03, 417.35it/s]
[A
 32%|###1      | 649/2055 [00:01<00:03, 430.07it/s]
[A
 34%|###3      | 695/2055 [00:01<00:03, 438.81it/s]
[A
 36%|###6      | 741/2055 [00:01<00:02, 443.05it/s]
[A
 38%|

 68%|███████████████████████████████████████████████▌                      | 34/50 [02:44<01:16,  4.81s/trial, best loss: -0.8968369829683698]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 464.98it/s]
[A
  5%|4         | 94/2055 [00:00<00:05, 347.19it/s]
[A
  6%|6         | 131/2055 [00:00<00:05, 334.63it/s]
[A
  9%|8         | 177/2055 [00:00<00:05, 374.90it/s]
[A
 11%|#         | 216/2055 [00:00<00:04, 374.12it/s]
[A
 13%|#2        | 262/2055 [00:00<00:04, 399.22it/s]
[A
 15%|#4        | 306/2055 [00:00<00:04, 409.78it/s]
[A
 17%|#7        | 352/2055 [00:00<00:04, 425.02it/s]
[A
 19%|#9        | 398/2055 [00:00<00:03, 434.77it/s]
[A
 22%|##1       | 444/2055 [00:01<00:03, 441.66it/s]
[A
 24%|##3       | 489/2055 [00:01<00:03, 443.63it/s]
[A
 26%|##6       | 535/2055 [00:01<00:03, 448.16it/s]
[A
 28%|##8       | 582/2055 [00:01<00:03, 451.97it/s]
[A
 31%|###       | 628/2055 [00:01<00:03, 415.91it/s]
[A
 33%|###2      | 675/2055 [00:01<00:03, 428.79it/s]
[A
 35%|###4      | 719/2055 [00:01<00:03, 411.54it/s]
[A
 37%|###7      | 761/2055 [00:01<00:03, 411.02it/s]
[A
 39%|

 70%|█████████████████████████████████████████████████                     | 35/50 [02:49<01:12,  4.82s/trial, best loss: -0.8968369829683698]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 464.32it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 460.02it/s]
[A
  7%|6         | 141/2055 [00:00<00:04, 457.10it/s]
[A
  9%|9         | 187/2055 [00:00<00:04, 458.00it/s]
[A
 11%|#1        | 233/2055 [00:00<00:04, 454.18it/s]
[A
 14%|#3        | 280/2055 [00:00<00:03, 456.64it/s]
[A
 16%|#5        | 326/2055 [00:00<00:04, 425.53it/s]
[A
 18%|#7        | 369/2055 [00:00<00:04, 420.20it/s]
[A
 20%|##        | 412/2055 [00:00<00:03, 422.92it/s]
[A
 22%|##2       | 459/2055 [00:01<00:03, 434.19it/s]
[A
 25%|##4       | 505/2055 [00:01<00:03, 440.79it/s]
[A
 27%|##6       | 550/2055 [00:01<00:03, 438.96it/s]
[A
 29%|##8       | 594/2055 [00:01<00:03, 425.32it/s]
[A
 31%|###1      | 639/2055 [00:01<00:03, 431.84it/s]
[A
 33%|###3      | 686/2055 [00:01<00:03, 440.40it/s]
[A
 36%|###5      | 732/2055 [00:01<00:02, 445.61it/s]
[A
 38%|###7      | 777/2055 [00:01<00:02, 446.55it/s]
[A
 40%|

 72%|██████████████████████████████████████████████████▍                   | 36/50 [02:54<01:07,  4.80s/trial, best loss: -0.8968369829683698]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 46/2055 [00:00<00:04, 459.43it/s]
[A
  4%|4         | 92/2055 [00:00<00:04, 453.78it/s]
[A
  7%|6         | 138/2055 [00:00<00:04, 443.42it/s]
[A
  9%|8         | 184/2055 [00:00<00:04, 449.01it/s]
[A
 11%|#1        | 230/2055 [00:00<00:04, 451.95it/s]
[A
 13%|#3        | 276/2055 [00:00<00:04, 424.56it/s]
[A
 16%|#5        | 319/2055 [00:00<00:04, 412.17it/s]
[A
 18%|#7        | 361/2055 [00:00<00:04, 367.30it/s]
[A
 19%|#9        | 399/2055 [00:01<00:05, 322.16it/s]
[A
 22%|##1       | 444/2055 [00:01<00:04, 353.68it/s]
[A
 24%|##3       | 488/2055 [00:01<00:04, 374.55it/s]
[A
 26%|##5       | 534/2055 [00:01<00:03, 396.68it/s]
[A
 28%|##8       | 579/2055 [00:01<00:03, 410.84it/s]
[A
 30%|###       | 626/2055 [00:01<00:03, 425.09it/s]
[A
 33%|###2      | 672/2055 [00:01<00:03, 435.09it/s]
[A
 35%|###4      | 718/2055 [00:01<00:03, 441.36it/s]
[A
 37%|###7      | 763/2055 [00:01<00:02, 443.43it/s]
[A
 39%|

 74%|███████████████████████████████████████████████████▊                  | 37/50 [02:58<01:02,  4.83s/trial, best loss: -0.9051094890510949]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 463.80it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 458.53it/s]
[A
  7%|6         | 140/2055 [00:00<00:04, 459.01it/s]
[A
  9%|9         | 186/2055 [00:00<00:04, 459.14it/s]
[A
 11%|#1        | 232/2055 [00:00<00:04, 447.50it/s]
[A
 14%|#3        | 278/2055 [00:00<00:03, 449.03it/s]
[A
 16%|#5        | 323/2055 [00:00<00:03, 448.27it/s]
[A
 18%|#7        | 369/2055 [00:00<00:03, 449.89it/s]
[A
 20%|##        | 415/2055 [00:00<00:03, 451.84it/s]
[A
 22%|##2       | 462/2055 [00:01<00:03, 454.19it/s]
[A
 25%|##4       | 508/2055 [00:01<00:03, 453.09it/s]
[A
 27%|##6       | 554/2055 [00:01<00:03, 417.59it/s]
[A
 29%|##9       | 597/2055 [00:01<00:03, 371.76it/s]
[A
 31%|###       | 637/2055 [00:01<00:03, 378.15it/s]
[A
 33%|###3      | 683/2055 [00:01<00:03, 399.81it/s]
[A
 35%|###5      | 729/2055 [00:01<00:03, 416.51it/s]
[A
 38%|###7      | 775/2055 [00:01<00:03, 426.39it/s]
[A
 40%|

 76%|█████████████████████████████████████████████████████▏                | 38/50 [03:03<00:58,  4.84s/trial, best loss: -0.9051094890510949]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 465.53it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 458.29it/s]
[A
  7%|6         | 141/2055 [00:00<00:04, 459.59it/s]
[A
  9%|9         | 187/2055 [00:00<00:04, 458.96it/s]
[A
 11%|#1        | 233/2055 [00:00<00:03, 456.42it/s]
[A
 14%|#3        | 279/2055 [00:00<00:03, 456.14it/s]
[A
 16%|#5        | 325/2055 [00:00<00:03, 446.72it/s]
[A
 18%|#8        | 371/2055 [00:00<00:03, 449.34it/s]
[A
 20%|##        | 417/2055 [00:00<00:03, 451.71it/s]
[A
 23%|##2       | 463/2055 [00:01<00:03, 453.68it/s]
[A
 25%|##4       | 509/2055 [00:01<00:03, 447.53it/s]
[A
 27%|##6       | 554/2055 [00:01<00:03, 431.32it/s]
[A
 29%|##9       | 598/2055 [00:01<00:03, 367.04it/s]
[A
 31%|###       | 637/2055 [00:01<00:04, 349.99it/s]
[A
 33%|###2      | 674/2055 [00:01<00:04, 330.67it/s]
[A
 35%|###4      | 716/2055 [00:01<00:03, 353.38it/s]
[A
 37%|###7      | 763/2055 [00:01<00:03, 382.81it/s]
[A
 39%|

 78%|██████████████████████████████████████████████████████▌               | 39/50 [03:08<00:53,  4.86s/trial, best loss: -0.9051094890510949]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 466.77it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 461.95it/s]
[A
  7%|6         | 141/2055 [00:00<00:04, 461.20it/s]
[A
  9%|9         | 188/2055 [00:00<00:04, 424.88it/s]
[A
 11%|#1        | 234/2055 [00:00<00:04, 435.86it/s]
[A
 14%|#3        | 281/2055 [00:00<00:03, 444.09it/s]
[A
 16%|#5        | 328/2055 [00:00<00:03, 449.77it/s]
[A
 18%|#8        | 374/2055 [00:00<00:03, 442.20it/s]
[A
 20%|##        | 420/2055 [00:00<00:03, 445.61it/s]
[A
 23%|##2       | 466/2055 [00:01<00:03, 447.97it/s]
[A
 25%|##4       | 512/2055 [00:01<00:03, 450.85it/s]
[A
 27%|##7       | 558/2055 [00:01<00:03, 450.92it/s]
[A
 29%|##9       | 604/2055 [00:01<00:03, 452.79it/s]
[A
 32%|###1      | 651/2055 [00:01<00:03, 455.19it/s]
[A
 34%|###3      | 697/2055 [00:01<00:02, 455.49it/s]
[A
 36%|###6      | 743/2055 [00:01<00:02, 455.47it/s]
[A
 38%|###8      | 789/2055 [00:01<00:02, 433.64it/s]
[A
 41%|

 80%|████████████████████████████████████████████████████████              | 40/50 [03:13<00:48,  4.83s/trial, best loss: -0.9051094890510949]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|1         | 34/2055 [00:00<00:06, 336.83it/s]
[A
  4%|3         | 80/2055 [00:00<00:04, 406.10it/s]
[A
  6%|5         | 121/2055 [00:00<00:04, 389.40it/s]
[A
  8%|8         | 166/2055 [00:00<00:04, 412.46it/s]
[A
 10%|#         | 208/2055 [00:00<00:04, 397.22it/s]
[A
 12%|#2        | 253/2055 [00:00<00:04, 414.16it/s]
[A
 15%|#4        | 299/2055 [00:00<00:04, 428.54it/s]
[A
 17%|#6        | 345/2055 [00:00<00:03, 437.85it/s]
[A
 19%|#9        | 391/2055 [00:00<00:03, 442.82it/s]
[A
 21%|##1       | 437/2055 [00:01<00:03, 446.26it/s]
[A
 23%|##3       | 482/2055 [00:01<00:03, 439.50it/s]
[A
 26%|##5       | 528/2055 [00:01<00:03, 443.62it/s]
[A
 28%|##7       | 574/2055 [00:01<00:03, 447.88it/s]
[A
 30%|###       | 621/2055 [00:01<00:03, 451.45it/s]
[A
 32%|###2      | 667/2055 [00:01<00:03, 441.16it/s]
[A
 35%|###4      | 713/2055 [00:01<00:03, 445.57it/s]
[A
 37%|###6      | 758/2055 [00:01<00:03, 424.49it/s]
[A
 39%|

 82%|█████████████████████████████████████████████████████████▍            | 41/50 [03:18<00:43,  4.84s/trial, best loss: -0.9060827250608272]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 464.97it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 461.67it/s]
[A
  7%|6         | 141/2055 [00:00<00:04, 448.86it/s]
[A
  9%|9         | 187/2055 [00:00<00:04, 451.91it/s]
[A
 11%|#1        | 233/2055 [00:00<00:04, 454.07it/s]
[A
 14%|#3        | 279/2055 [00:00<00:04, 413.52it/s]
[A
 16%|#5        | 325/2055 [00:00<00:04, 426.58it/s]
[A
 18%|#7        | 369/2055 [00:00<00:04, 414.32it/s]
[A
 20%|##        | 413/2055 [00:00<00:03, 417.71it/s]
[A
 22%|##2       | 456/2055 [00:01<00:04, 348.65it/s]
[A
 24%|##4       | 501/2055 [00:01<00:04, 374.53it/s]
[A
 26%|##6       | 544/2055 [00:01<00:03, 389.09it/s]
[A
 29%|##8       | 591/2055 [00:01<00:03, 409.49it/s]
[A
 31%|###1      | 638/2055 [00:01<00:03, 424.76it/s]
[A
 33%|###3      | 684/2055 [00:01<00:03, 434.76it/s]
[A
 36%|###5      | 730/2055 [00:01<00:03, 441.66it/s]
[A
 38%|###7      | 775/2055 [00:01<00:02, 443.26it/s]
[A
 40%|

 84%|██████████████████████████████████████████████████████████▊           | 42/50 [03:23<00:38,  4.83s/trial, best loss: -0.9060827250608272]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 462.73it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 462.33it/s]
[A
  7%|6         | 141/2055 [00:00<00:04, 421.91it/s]
[A
  9%|9         | 187/2055 [00:00<00:04, 434.39it/s]
[A
 11%|#1        | 234/2055 [00:00<00:04, 443.93it/s]
[A
 14%|#3        | 279/2055 [00:00<00:04, 439.14it/s]
[A
 16%|#5        | 325/2055 [00:00<00:03, 445.21it/s]
[A
 18%|#8        | 371/2055 [00:00<00:03, 447.58it/s]
[A
 20%|##        | 417/2055 [00:00<00:03, 450.21it/s]
[A
 23%|##2       | 463/2055 [00:01<00:03, 452.54it/s]
[A
 25%|##4       | 509/2055 [00:01<00:03, 443.51it/s]
[A
 27%|##7       | 555/2055 [00:01<00:03, 446.44it/s]
[A
 29%|##9       | 600/2055 [00:01<00:03, 411.89it/s]
[A
 31%|###1      | 646/2055 [00:01<00:03, 424.29it/s]
[A
 34%|###3      | 689/2055 [00:01<00:03, 408.48it/s]
[A
 36%|###5      | 731/2055 [00:01<00:04, 304.60it/s]
[A
 38%|###7      | 777/2055 [00:01<00:03, 339.13it/s]
[A
 40%|

 86%|████████████████████████████████████████████████████████████▏         | 43/50 [03:28<00:34,  4.92s/trial, best loss: -0.9060827250608272]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 462.77it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 453.81it/s]
[A
  7%|6         | 140/2055 [00:00<00:04, 450.79it/s]
[A
  9%|9         | 186/2055 [00:00<00:04, 437.95it/s]
[A
 11%|#1        | 230/2055 [00:00<00:04, 432.96it/s]
[A
 13%|#3        | 276/2055 [00:00<00:04, 440.92it/s]
[A
 16%|#5        | 322/2055 [00:00<00:03, 445.11it/s]
[A
 18%|#7        | 368/2055 [00:00<00:03, 448.20it/s]
[A
 20%|##        | 414/2055 [00:00<00:03, 449.67it/s]
[A
 22%|##2       | 460/2055 [00:01<00:03, 451.39it/s]
[A
 25%|##4       | 506/2055 [00:01<00:03, 451.38it/s]
[A
 27%|##6       | 552/2055 [00:01<00:03, 450.93it/s]
[A
 29%|##9       | 598/2055 [00:01<00:03, 426.07it/s]
[A
 31%|###1      | 644/2055 [00:01<00:03, 433.38it/s]
[A
 33%|###3      | 688/2055 [00:01<00:03, 432.07it/s]
[A
 36%|###5      | 732/2055 [00:01<00:03, 406.15it/s]
[A
 38%|###7      | 778/2055 [00:01<00:03, 419.69it/s]
[A
 40%|

 88%|█████████████████████████████████████████████████████████████▌        | 44/50 [03:33<00:29,  4.89s/trial, best loss: -0.9060827250608272]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|1         | 32/2055 [00:00<00:06, 319.75it/s]
[A
  4%|3         | 78/2055 [00:00<00:04, 398.65it/s]
[A
  6%|6         | 124/2055 [00:00<00:04, 423.12it/s]
[A
  8%|8         | 170/2055 [00:00<00:04, 436.09it/s]
[A
 11%|#         | 216/2055 [00:00<00:04, 442.24it/s]
[A
 13%|#2        | 261/2055 [00:00<00:04, 407.17it/s]
[A
 15%|#4        | 306/2055 [00:00<00:04, 417.62it/s]
[A
 17%|#6        | 349/2055 [00:00<00:04, 420.25it/s]
[A
 19%|#9        | 395/2055 [00:00<00:03, 430.88it/s]
[A
 21%|##1       | 441/2055 [00:01<00:03, 437.96it/s]
[A
 24%|##3       | 487/2055 [00:01<00:03, 442.64it/s]
[A
 26%|##5       | 533/2055 [00:01<00:03, 446.12it/s]
[A
 28%|##8       | 579/2055 [00:01<00:03, 448.35it/s]
[A
 30%|###       | 625/2055 [00:01<00:03, 449.22it/s]
[A
 33%|###2      | 671/2055 [00:01<00:03, 450.20it/s]
[A
 35%|###4      | 717/2055 [00:01<00:02, 447.58it/s]
[A
 37%|###7      | 763/2055 [00:01<00:02, 448.86it/s]
[A
 39%|

 90%|███████████████████████████████████████████████████████████████       | 45/50 [03:37<00:24,  4.90s/trial, best loss: -0.9060827250608272]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 462.85it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 396.72it/s]
[A
  7%|6         | 135/2055 [00:00<00:05, 373.86it/s]
[A
  9%|8         | 181/2055 [00:00<00:04, 401.54it/s]
[A
 11%|#1        | 227/2055 [00:00<00:04, 418.70it/s]
[A
 13%|#3        | 273/2055 [00:00<00:04, 429.91it/s]
[A
 16%|#5        | 319/2055 [00:00<00:03, 438.68it/s]
[A
 18%|#7        | 365/2055 [00:00<00:03, 443.63it/s]
[A
 20%|#9        | 410/2055 [00:00<00:03, 437.17it/s]
[A
 22%|##2       | 456/2055 [00:01<00:03, 442.36it/s]
[A
 24%|##4       | 502/2055 [00:01<00:03, 446.95it/s]
[A
 27%|##6       | 548/2055 [00:01<00:03, 450.21it/s]
[A
 29%|##8       | 594/2055 [00:01<00:03, 450.78it/s]
[A
 31%|###1      | 640/2055 [00:01<00:03, 451.40it/s]
[A
 33%|###3      | 686/2055 [00:01<00:03, 452.72it/s]
[A
 36%|###5      | 732/2055 [00:01<00:02, 451.57it/s]
[A
 38%|###7      | 778/2055 [00:01<00:02, 452.37it/s]
[A
 40%|

 92%|████████████████████████████████████████████████████████████████▍     | 46/50 [03:42<00:19,  4.87s/trial, best loss: -0.9060827250608272]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 462.44it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 441.10it/s]
[A
  7%|6         | 139/2055 [00:00<00:04, 444.46it/s]
[A
  9%|9         | 185/2055 [00:00<00:04, 448.24it/s]
[A
 11%|#1        | 230/2055 [00:00<00:04, 430.29it/s]
[A
 13%|#3        | 274/2055 [00:00<00:04, 377.36it/s]
[A
 16%|#5        | 320/2055 [00:00<00:04, 400.96it/s]
[A
 18%|#7        | 366/2055 [00:00<00:04, 415.84it/s]
[A
 20%|##        | 412/2055 [00:00<00:03, 427.92it/s]
[A
 22%|##2       | 458/2055 [00:01<00:03, 436.10it/s]
[A
 24%|##4       | 503/2055 [00:01<00:03, 408.81it/s]
[A
 27%|##6       | 549/2055 [00:01<00:03, 422.16it/s]
[A
 29%|##8       | 595/2055 [00:01<00:03, 432.79it/s]
[A
 31%|###1      | 642/2055 [00:01<00:03, 440.72it/s]
[A
 33%|###3      | 688/2055 [00:01<00:03, 444.23it/s]
[A
 36%|###5      | 734/2055 [00:01<00:02, 446.89it/s]
[A
 38%|###7      | 780/2055 [00:01<00:02, 448.85it/s]
[A
 40%|

 94%|█████████████████████████████████████████████████████████████████▊    | 47/50 [03:47<00:14,  4.88s/trial, best loss: -0.9060827250608272]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 465.49it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 459.30it/s]
[A
  7%|6         | 140/2055 [00:00<00:04, 456.28it/s]
[A
  9%|9         | 186/2055 [00:00<00:04, 417.30it/s]
[A
 11%|#1        | 232/2055 [00:00<00:04, 430.45it/s]
[A
 14%|#3        | 279/2055 [00:00<00:04, 440.12it/s]
[A
 16%|#5        | 325/2055 [00:00<00:03, 445.20it/s]
[A
 18%|#8        | 370/2055 [00:00<00:04, 399.89it/s]
[A
 20%|##        | 411/2055 [00:01<00:04, 367.46it/s]
[A
 22%|##2       | 453/2055 [00:01<00:04, 379.77it/s]
[A
 24%|##4       | 499/2055 [00:01<00:03, 400.40it/s]
[A
 27%|##6       | 545/2055 [00:01<00:03, 416.30it/s]
[A
 29%|##8       | 588/2055 [00:01<00:03, 420.10it/s]
[A
 31%|###       | 633/2055 [00:01<00:03, 428.13it/s]
[A
 33%|###3      | 679/2055 [00:01<00:03, 437.12it/s]
[A
 35%|###5      | 725/2055 [00:01<00:03, 441.89it/s]
[A
 38%|###7      | 771/2055 [00:01<00:02, 446.74it/s]
[A
 40%|

 96%|███████████████████████████████████████████████████████████████████▏  | 48/50 [03:52<00:09,  4.87s/trial, best loss: -0.9060827250608272]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 47/2055 [00:00<00:04, 460.81it/s]
[A
  5%|4         | 94/2055 [00:00<00:04, 450.31it/s]
[A
  7%|6         | 140/2055 [00:00<00:04, 453.54it/s]
[A
  9%|9         | 186/2055 [00:00<00:04, 453.67it/s]
[A
 11%|#1        | 232/2055 [00:00<00:04, 454.29it/s]
[A
 14%|#3        | 278/2055 [00:00<00:03, 445.47it/s]
[A
 16%|#5        | 324/2055 [00:00<00:03, 449.07it/s]
[A
 18%|#8        | 370/2055 [00:00<00:03, 452.08it/s]
[A
 20%|##        | 416/2055 [00:00<00:03, 452.62it/s]
[A
 22%|##2       | 462/2055 [00:01<00:03, 452.07it/s]
[A
 25%|##4       | 508/2055 [00:01<00:03, 452.17it/s]
[A
 27%|##6       | 554/2055 [00:01<00:03, 410.96it/s]
[A
 29%|##9       | 596/2055 [00:01<00:03, 390.24it/s]
[A
 31%|###1      | 642/2055 [00:01<00:03, 408.40it/s]
[A
 33%|###3      | 686/2055 [00:01<00:03, 413.59it/s]
[A
 35%|###5      | 728/2055 [00:01<00:03, 400.34it/s]
[A
 38%|###7      | 774/2055 [00:01<00:03, 416.28it/s]
[A
 40%|

 98%|████████████████████████████████████████████████████████████████████▌ | 49/50 [03:57<00:04,  4.88s/trial, best loss: -0.9094890510948905]

  0%|          | 0/2055 [00:00<?, ?it/s]
[A
  2%|2         | 44/2055 [00:00<00:04, 436.45it/s]
[A
  4%|4         | 88/2055 [00:00<00:05, 381.92it/s]
[A
  6%|6         | 130/2055 [00:00<00:04, 393.50it/s]
[A
  8%|8         | 170/2055 [00:00<00:05, 363.53it/s]
[A
 11%|#         | 216/2055 [00:00<00:04, 393.28it/s]
[A
 13%|#2        | 262/2055 [00:00<00:04, 412.52it/s]
[A
 15%|#4        | 305/2055 [00:00<00:04, 416.59it/s]
[A
 17%|#6        | 347/2055 [00:00<00:04, 407.39it/s]
[A
 19%|#8        | 390/2055 [00:00<00:04, 411.72it/s]
[A
 21%|##1       | 436/2055 [00:01<00:03, 425.37it/s]
[A
 23%|##3       | 479/2055 [00:01<00:03, 421.84it/s]
[A
 26%|##5       | 525/2055 [00:01<00:03, 430.73it/s]
[A
 28%|##7       | 571/2055 [00:01<00:03, 437.04it/s]
[A
 30%|##9       | 615/2055 [00:01<00:03, 409.59it/s]
[A
 32%|###1      | 657/2055 [00:01<00:03, 398.11it/s]
[A
 34%|###4      | 703/2055 [00:01<00:03, 413.23it/s]
[A
 36%|###6      | 746/2055 [00:01<00:03, 417.19it/s]
[A
 39%|

100%|██████████████████████████████████████████████████████████████████████| 50/50 [04:02<00:00,  4.85s/trial, best loss: -0.9094890510948905]


In [16]:
df_full_train, df_test = train_test_split(ground_truth, test_size=0.2, random_state=1)
df_train, df_val = train_test_split(df_full_train, test_size=0.25, random_state=1)

In [17]:
best_num_results = int(best['num_results'])
best_boost_factor = best['boost_factor']
best_section_index = best['section']


section_options = ['nhs claim benefits', 'general claim benefits']
 
best_filter_dict = {'section': section_options[best_section_index]}

In [18]:
def best_min_search(query, section):
    boost = {
        "Question": best_boost_factor,
        "Answer": best_boost_factor
    }
    
    results = index.search(
        query=query,
        filter_dict=best_filter_dict,
        boost_dict=boost,
        num_results=best_num_results
    )
    return results

# Evaluate using the best parameters
metrics = evaluate(df_val, lambda q: best_min_search(q['question'], q['claims_type']))
print(metrics)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 411/411 [00:01<00:00, 402.91it/s]

{'hit_rate': 0.9051094890510949, 'mrr': 0.712078554049357}





In [19]:
print(best_num_results)
print(best_boost_factor)
print(best_filter_dict)

10
1.9376603614188619
{'section': 'general claim benefits'}


In [20]:
def min_search(query, section):
    boost = {
        "Question": 1.76740328091659, 
        "Answer": 1.76740328091659     
    }

    results = index.search(
        query=query,
        filter_dict={'section': section},  
        boost_dict=boost,                  
        num_results=10                     
    )

    return results

In [21]:
def evaluate(ground_truth, search_function):
    relevance_total = []

    for q in tqdm(ground_truth):
        doc_id = q['document']  # Ground truth document ID
        results = search_function(q['question'], q['claims_type'])  # Pass query and section
        relevance = [d['id'] == doc_id for d in results]  # Check if doc ID matches in results
        relevance_total.append(relevance)

    return {
        'hit_rate': hit_rate(relevance_total),
        'mrr': mrr(relevance_total),
    }

# Evaluate on the test dataset
metrics = evaluate(df_test, lambda question, claims_type: min_search(question, claims_type))

print(metrics)

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 411/411 [00:00<00:00, 424.67it/s]

{'hit_rate': 0.8880778588807786, 'mrr': 0.7030423280423281}





In [None]:
Minsearch metrics without tuning {'hit_rate': 0.7965936739659367, 'mrr': 0.6443309002433094}

In [46]:
import json


from tqdm.auto import tqdm
from sentence_transformers import SentenceTransformer

In [2]:
with open('notebooks/documents-with-ids.json', 'r') as file:
    documents = json.load(file)

documents[10]

{'Category': 'Temporarily unable to work',
 'Question': "Can I get sick pay if I'm self-isolating?",
 'Answer': "Yes Statutory Sick Pay is available if you're self-isolating.",
 'Section': 'general claim benefits',
 'id': '1de35e0b-f233-554c-84ef-fc30494e0ea0'}

In [3]:
from elasticsearch import Elasticsearch

es_client = Elasticsearch('http://localhost:9200') 

index_settings = {
    "settings": {
        "number_of_shards": 1,
        "number_of_replicas": 0
    },
    "mappings": {
        "properties": {
            "Answer": {"type": "text"},
            "Category": {"type": "text"},
            "Question": {"type": "text"},
            "Section": {"type": "keyword"},
            "id": {"type": "keyword"},
        }
    }
}

index_name = "benefit-claims"

es_client.indices.delete(index=index_name, ignore_unavailable=True)
es_client.indices.create(index=index_name, body=index_settings)

ObjectApiResponse({'acknowledged': True, 'shards_acknowledged': True, 'index': 'benefit-claims'})

In [4]:
for doc in tqdm(documents):
    es_client.index(index=index_name, document=doc)

  0%|          | 0/425 [00:00<?, ?it/s]

In [5]:
def elastic_search(query, section):
    search_query = {
        "size": 5,
        "query": {
            "bool": {
                "must": {
                    "multi_match": {
                        "query": query,
                        "fields": ["Question^3", "Answer", "Category"],
                        "type": "best_fields"
                    }
                },
                "filter": {
                    "term": {
                        "Section": section
                    }
                }
            }
        }
    }

    response = es_client.search(index=index_name, body=search_query)
    
    result_docs = []
    
    for hit in response['hits']['hits']:
        result_docs.append(hit['_source'])
    
    return result_docs

In [6]:
elastic_search(
    query="Can I get sick pay if I'm self-isolating?",
    section="general claim benefits"
)

[{'Category': 'Temporarily unable to work',
  'Question': "Can I get sick pay if I'm self-isolating?",
  'Answer': "Yes Statutory Sick Pay is available if you're self-isolating.",
  'Section': 'general claim benefits',
  'id': '1de35e0b-f233-554c-84ef-fc30494e0ea0'},
 {'Category': 'Temporarily unable to work',
  'Question': 'How do I apply for sick pay?',
  'Answer': 'You need to provide a fit note from your doctor to apply for sick pay.',
  'Section': 'general claim benefits',
  'id': 'ecc42084-d88e-5bf0-8070-a18552c283bb'},
 {'Category': 'Temporarily unable to work',
  'Question': 'What is statutory sick pay?',
  'Answer': "Statutory Sick Pay is a legal requirement for employers to pay you if you're ill.",
  'Section': 'general claim benefits',
  'id': 'ee6a77ec-d4ef-50f7-9762-04cc228b3a48'},
 {'Category': 'Disabled or health condition',
  'Question': 'Can I get help with housing if I’m disabled?',
  'Answer': 'You may be eligible for a Disabled Facilities Grant to adapt your home to

In [47]:
import pandas as pd

In [7]:
df_ground_truth = pd.read_csv('notebooks/ground-truth-data.csv')

In [8]:
ground_truth = df_ground_truth.to_dict(orient='records')

In [9]:
ground_truth[10]

{'question': 'Is it possible to appeal?',
 'claims_type': 'general claim benefits',
 'document': '8d000ade-6c2b-571c-aa61-5d38eb463cf8'}

In [10]:
elastic_search(
    query="Can I get sick pay if I'm self-isolating?",
    section="general claim benefits"
)

[{'Category': 'Temporarily unable to work',
  'Question': "Can I get sick pay if I'm self-isolating?",
  'Answer': "Yes Statutory Sick Pay is available if you're self-isolating.",
  'Section': 'general claim benefits',
  'id': '1de35e0b-f233-554c-84ef-fc30494e0ea0'},
 {'Category': 'Temporarily unable to work',
  'Question': 'How do I apply for sick pay?',
  'Answer': 'You need to provide a fit note from your doctor to apply for sick pay.',
  'Section': 'general claim benefits',
  'id': 'ecc42084-d88e-5bf0-8070-a18552c283bb'},
 {'Category': 'Temporarily unable to work',
  'Question': 'What is statutory sick pay?',
  'Answer': "Statutory Sick Pay is a legal requirement for employers to pay you if you're ill.",
  'Section': 'general claim benefits',
  'id': 'ee6a77ec-d4ef-50f7-9762-04cc228b3a48'},
 {'Category': 'Disabled or health condition',
  'Question': 'Can I get help with housing if I’m disabled?',
  'Answer': 'You may be eligible for a Disabled Facilities Grant to adapt your home to

In [11]:
relevance_total = []

for q in tqdm(ground_truth):
    doc_id = q['document']
    results = elastic_search(query=q['question'], section=q['claims_type'])
    relevance = [d['id'] == doc_id for d in results]
    relevance_total.append(relevance)

  0%|          | 0/2055 [00:00<?, ?it/s]

In [12]:
def hit_rate(relevance_total):
    cnt = 0

    for line in relevance_total:
        if True in line:
            cnt = cnt + 1

    return cnt / len(relevance_total)

In [13]:
def mrr(relevance_total):
    total_score = 0.0

    for line in relevance_total:
        for rank in range(len(line)):
            if line[rank] == True:
                total_score = total_score + 1 / (rank + 1)

    return total_score / len(relevance_total)

In [14]:
hit_rate(relevance_total), mrr(relevance_total)

(0.7732360097323601, 0.6168369829683708)

In [15]:
documents[10]

{'Category': 'Temporarily unable to work',
 'Question': "Can I get sick pay if I'm self-isolating?",
 'Answer': "Yes Statutory Sick Pay is available if you're self-isolating.",
 'Section': 'general claim benefits',
 'id': '1de35e0b-f233-554c-84ef-fc30494e0ea0'}

## Minsearch

In [16]:
import minsearch

index = minsearch.Index(
    text_fields=["Question", "Answer", "Category"],
    keyword_fields=["Section", "id"]
)

index.fit(documents)

<minsearch.Index at 0x7b56b0afc620>

In [37]:
def min_search(query, section):
    boost = {'Question': 3.0, 'Category': 0.5}

    results = index.search(
        query=query,
        filter_dict={'Section': section},
        boost_dict=boost,
        num_results=5
    )

    return results

In [38]:
ground_truth[0]

{'question': 'How can I change my existing benefit details?',
 'claims_type': 'general claim benefits',
 'document': '30eada08-5708-5c5c-9df8-0f7d5d4dc131'}

In [39]:
relevance_total = []

for q in tqdm(ground_truth):
    doc_id = q['document']
    results = min_search(query=q['question'], section=q['claims_type'])
    relevance = [d['id'] == doc_id for d in results]
    relevance_total.append(relevance)

  0%|          | 0/2055 [00:00<?, ?it/s]

In [40]:
hit_rate(relevance_total), mrr(relevance_total)

(0.8306569343065694, 0.6697891321978915)

`ES- Hitrate:0.7732360097323601, MRR:0.6168369829683708)`

In [41]:
def evaluate(ground_truth, search_function):
    relevance_total = []

    for q in tqdm(ground_truth):
        doc_id = q['document']
        results = search_function(q)
        relevance = [d['id'] == doc_id for d in results]
        relevance_total.append(relevance)

    return {
        'hit_rate': hit_rate(relevance_total),
        'mrr': mrr(relevance_total),
    }

In [43]:
evaluate(ground_truth, lambda q: elastic_search(q['question'], q['claims_type']))

  0%|          | 0/2055 [00:00<?, ?it/s]

{'hit_rate': 0.7732360097323601, 'mrr': 0.616999188969993}

In [45]:
evaluate(ground_truth, lambda q: min_search(q['question'], q['claims_type']))

  0%|          | 0/2055 [00:00<?, ?it/s]

{'hit_rate': 0.8306569343065694, 'mrr': 0.6697891321978915}

## Vector Search

In [48]:
model_name = 'multi-qa-MiniLM-L6-cos-v1'
model = SentenceTransformer(model_name)



In [49]:
v = model.encode("Can I get sick pay if I'm self-isolating?")

In [50]:
len(v)

384

In [85]:
index_settings = {
    "settings": {
        "number_of_shards": 1,
        "number_of_replicas": 0
    },
    "mappings": {
        "properties": {
            "Answer": {"type": "text"},
            "category": {"type": "text"},
            "Question": {"type": "text"},
            "Section": {"type": "keyword"},
            "id": {"type": "keyword"},
            "question_vector": {
                "type": "dense_vector",
                "dims": 384,
                "index": True,
                "similarity": "cosine"
            },
            "answer_vector": {
                "type": "dense_vector",
                "dims": 384,
                "index": True,
                "similarity": "cosine"
            },
            "question_answer_vector": {
                "type": "dense_vector",
                "dims": 384,
                "index": True,
                "similarity": "cosine"
            },
        }
    }
}

index_name = "benefit-claims"

es_client.indices.delete(index=index_name, ignore_unavailable=True)
es_client.indices.create(index=index_name, body=index_settings)

ObjectApiResponse({'acknowledged': True, 'shards_acknowledged': True, 'index': 'benefit-claims'})

In [86]:
documents[0]

{'Category': 'Manage existing benefit',
 'Question': 'How do I update my benefit information?',
 'Answer': 'You can update your benefit information online through your account.',
 'Section': 'general claim benefits',
 'id': '30eada08-5708-5c5c-9df8-0f7d5d4dc131',
 'question_vector': array([-1.80060193e-02,  5.96722253e-02,  1.26943989e-02, -3.90117168e-02,
         2.84865871e-02,  9.22173411e-02,  2.27035414e-02,  2.67567299e-02,
        -3.86952274e-02, -2.83248979e-03, -2.74528917e-02,  2.94101797e-03,
        -3.46847787e-03, -1.18925475e-01,  1.39166079e-02,  3.98957506e-02,
        -1.79248862e-02,  7.15748966e-02, -4.80940826e-02, -3.22361998e-02,
        -1.01363212e-01, -2.29153130e-02, -4.96070758e-02,  4.44727466e-02,
         2.36339159e-02,  6.05530553e-02, -2.41196901e-02,  3.91690135e-02,
        -5.07083908e-03, -1.72626209e-02,  5.26820458e-02, -4.72456887e-02,
         1.85332298e-02, -2.26278715e-02,  3.78295816e-02, -8.19180743e-04,
        -8.93227831e-02,  2.79583

In [87]:
for doc in tqdm(documents):
    question = doc['Question']
    answer = doc['Answer']
    qt = question + ' ' + answer

    doc['question_vector'] = model.encode(question)
    doc['answer_vector'] = model.encode(answer)
    doc['question_answer_vector'] = model.encode(qt)

  0%|          | 0/425 [00:00<?, ?it/s]

In [88]:
for doc in tqdm(documents):
    es_client.index(index=index_name, document=doc)

  0%|          | 0/425 [00:00<?, ?it/s]

In [89]:
query = "Can I get sick pay if I'm self-isolating?"
v_q = model.encode(query)

In [90]:
def elastic_search_knn(field, vector, section):
    knn = {
        "field": field,
        "query_vector": vector,
        "k": 5,
        "num_candidates": 10000,
        "filter": {
            "term": {
                "Section": section
            }
        }
    }

    search_query = {
        "knn": knn,
        "_source": ["Answer", "Section", "Question", "Category", "id"]
    }

    es_results = es_client.search(
        index=index_name,
        body=search_query
    )
    
    result_docs = []
    
    for hit in es_results['hits']['hits']:
        result_docs.append(hit['_source'])

    return result_docs

In [97]:
elastic_search_knn('question_vector', v_q, 'general claim benefits')

[{'Answer': "Yes Statutory Sick Pay is available if you're self-isolating.",
  'Category': 'Temporarily unable to work',
  'Question': "Can I get sick pay if I'm self-isolating?",
  'id': '1de35e0b-f233-554c-84ef-fc30494e0ea0',
  'Section': 'general claim benefits'},
 {'Answer': 'You need to provide a fit note from your doctor to apply for sick pay.',
  'Category': 'Temporarily unable to work',
  'Question': 'How do I apply for sick pay?',
  'id': 'ecc42084-d88e-5bf0-8070-a18552c283bb',
  'Section': 'general claim benefits'},
 {'Answer': "Statutory Sick Pay is a legal requirement for employers to pay you if you're ill.",
  'Category': 'Temporarily unable to work',
  'Question': 'What is statutory sick pay?',
  'id': 'ee6a77ec-d4ef-50f7-9762-04cc228b3a48',
  'Section': 'general claim benefits'},
 {'Answer': 'You may be eligible for certain benefits even if you’re working depending on your income and circumstances.',
  'Category': 'Manage existing benefit',
  'Question': 'Can I receive b

In [98]:
ground_truth[10]

{'question': 'Is it possible to appeal?',
 'claims_type': 'general claim benefits',
 'document': '8d000ade-6c2b-571c-aa61-5d38eb463cf8'}

In [99]:
def question_vector_knn(q):
    question = q['question']
    section = q['claims_type']

    v_q = model.encode(question)

    return elastic_search_knn('question_vector', v_q, section)

In [100]:
def hit_rate(relevance_total):
    cnt = 0

    for line in relevance_total:
        if True in line:
            cnt = cnt + 1

    return cnt / len(relevance_total)

In [101]:
def mrr(relevance_total):
    total_score = 0.0

    for line in relevance_total:
        for rank in range(len(line)):
            if line[rank] == True:
                total_score = total_score + 1 / (rank + 1)

    return total_score / len(relevance_total)

In [102]:
def evaluate(ground_truth, search_function):
    relevance_total = []

    for q in tqdm(ground_truth):
        doc_id = q['document']
        results = search_function(q)
        relevance = [d['id'] == doc_id for d in results]
        relevance_total.append(relevance)

    return {
        'hit_rate': hit_rate(relevance_total),
        'mrr': mrr(relevance_total),
    }

In [103]:
evaluate(ground_truth, question_vector_knn)

  0%|          | 0/2055 [00:00<?, ?it/s]

{'hit_rate': 0.8637469586374696, 'mrr': 0.7441443633414435}

Compared to Minsearch: {'hit_rate': 0.8306569343065694, 'mrr': 0.6697891321978915}

Compared to Elasticsearch Text only: {'hit_rate': 0.7732360097323601, 'mrr':0.6168369829683708}

In [106]:
def answer_vector_knn(q):
    question = q['question']
    section = q['claims_type']

    v_q = model.encode(question)

    return elastic_search_knn('answer_vector', v_q, section)
evaluate(ground_truth, answer_vector_knn)

  0%|          | 0/2055 [00:00<?, ?it/s]

{'hit_rate': 0.8272506082725061, 'mrr': 0.6879399837793996}

In [107]:
def question_answer_vector_knn(q):
    question = q['question']
    section = q['claims_type']

    v_q = model.encode(question)

    return elastic_search_knn('question_answer_vector', v_q, section)

evaluate(ground_truth, question_answer_vector_knn)

  0%|          | 0/2055 [00:00<?, ?it/s]

{'hit_rate': 0.9304136253041363, 'mrr': 0.8123276561232763}

In [109]:
def elastic_search_knn_combined(vector, section):
    search_query = {
        "size": 5,
        "query": {
            "bool": {
                "must": [
                    {
                        "script_score": {
                            "query": {
                                "term": {
                                    "Section": section
                                }
                            },
                            "script": {
                                "source": """
                                    cosineSimilarity(params.query_vector, 'question_vector') + 
                                    cosineSimilarity(params.query_vector, 'answer_vector') + 
                                    cosineSimilarity(params.query_vector, 'question_answer_vector') + 
                                    1
                                """,
                                "params": {
                                    "query_vector": vector
                                }
                            }
                        }
                    }
                ],
                "filter": {
                    "term": {
                        "Section": section
                    }
                }
            }
        },
        "_source": ["Answer", "Section", "Question", "Category", "id"]
    }

    es_results = es_client.search(
        index=index_name,
        body=search_query
    )
    
    result_docs = []
    
    for hit in es_results['hits']['hits']:
        result_docs.append(hit['_source'])

    return result_docs

In [110]:
def vector_combined_knn(q):
    question = q['question']
    section = q['claims_type']

    v_q = model.encode(question)

    return elastic_search_knn_combined(v_q, section)

evaluate(ground_truth, vector_combined_knn)

  0%|          | 0/2055 [00:00<?, ?it/s]

{'hit_rate': 0.9211678832116789, 'mrr': 0.8021573398215729}

question_answer_vector : {'hit_rate': 0.9304136253041363, 'mrr': 0.8123276561232763}