In [1]:
!pip install -q transformers datasets docs-ranking-metrics xformers
!gdown 1D_Hz-4BK3tJB0zG4-SyEyF8b3S6Sqx0z

Downloading...
From: https://drive.google.com/uc?id=1D_Hz-4BK3tJB0zG4-SyEyF8b3S6Sqx0z
To: /content/data_for_rk_model.tsv
100% 3.92G/3.92G [00:43<00:00, 89.6MB/s]


In [2]:
import torch
import random
import numpy as np
import pandas as pd

from transformers import GPT2LMHeadModel, GPT2Tokenizer, pipeline
# from datasets import load_dataset
from tqdm.notebook import tqdm

from docs_ranking_metrics.ranking_metrics import RankingMetrics, Bm25, MsMarcoCE

In [11]:
SEED = 42
NUM_QUERIES = 10
RANKING_MODELS = [Bm25(), MsMarcoCE()]  # , LaBSE(), MsMarcoST()
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
RANKING_MODEL_NAME = 'MsMarcoCE'
RANKING_METRIC_NAME = 'FDARO@v2'

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed(SEED)
    torch.cuda.manual_seed_all(SEED)
    print(torch.cuda.get_device_name())

Tesla T4


In [43]:
def load_tokenizer_and_model(model_name_or_path):
    return GPT2Tokenizer.from_pretrained(model_name_or_path), GPT2LMHeadModel.from_pretrained(model_name_or_path).to(DEVICE)


def preprocess_generated(text: str) -> str:
  # text = text.replace('\n', ' ').replace('\xa0', ' ')
  return ' '.join(text.split())


def generate(model, text, **kwargs):
  out = model(text, **kwargs)
  return out[0]['generated_text']


def run_experiment(df, **kwargs):
  global model
  rm = RankingMetrics(RANKING_MODELS)
  for row in tqdm(df.itertuples(index=False), total=NUM_QUERIES):
    # passage_text, is_selected, query = row.passages['passage_text'].tolist(), row.passages['is_selected'].tolist(), row.query
    passage_text, is_selected, query = row.body, row.label, row.query

    # generated = generate(model, query + ' ', **kwargs)[len(query):]
    generated = model(query + ' ', pad_token_id=50256, **kwargs)[0]['generated_text'][len(query):]

    passage_text.append(preprocess_generated(generated))
    is_selected.append(RankingMetrics.FAKE_DOC_LABEL)

    rm.update(query, passage_text, is_selected)
  rm.show_metrics()
  return rm.get()

def select_best_param(ranking_model_name, ranking_metric_name, list_of_results):
  return min(list_of_results, key=lambda metrics: metrics[1][f'{ranking_model_name}_{ranking_metric_name}'])[0]

In [5]:
# dataset = load_dataset('ms_marco', 'v1.1', split='test').to_pandas()
dataset = pd.read_csv('data_for_rk_model.tsv', delimiter='\t').groupby('query').agg({'body': list, 'label': list}).reset_index().loc[:NUM_QUERIES]
# tok, model = load_tokenizer_and_model("sberbank-ai/rugpt3medium_based_on_gpt2")
model = pipeline(model="sberbank-ai/rugpt3medium_based_on_gpt2", task='text-generation')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
# res = run_experiment(dataset, max_new_tokens=20)

In [41]:
model('сгенерируй продолжение текста')

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': 'сгенерируй продолжение текста, а то я не могу понять, что ты хочешь сказать.\n'}]

# Parameters that control the length of the output


## `max_length` (int, defaults to 20) 

In [15]:
results = []
for max_length in tqdm([20, 40, 60], desc='max_length', leave=False):
  res = run_experiment(dataset, max_length=max_length)
  results.append((max_length, res))
  print(f'max_length={max_length}')
  print(res)
  print('_'*100)

max_length:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

max_length=20
{'Bm25_AverageLoc': 100.54545454545455, 'MsMarcoCE_AverageLoc': 73.0, 'Bm25_Top@1': 0.0, 'Bm25_Top@3': 0.0, 'Bm25_Top@5': 0.18181818181818182, 'MsMarcoCE_Top@1': 0.18181818181818182, 'MsMarcoCE_Top@3': 1.0, 'MsMarcoCE_Top@5': 1.4545454545454546, 'Bm25_FDARO@v1': 0.09090909090909091, 'Bm25_FDARO@v2': 0.2727272727272727, 'MsMarcoCE_FDARO@v1': 0.36363636363636365, 'MsMarcoCE_FDARO@v2': 0.6363636363636364, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.09090909090909091, 'Bm25_AverageRelLoc': 3.6654121850557644, 'MsMarcoCE_AverageRelLoc': 2.7422546960692333}
____________________________________________________________________________________________________


  0%|          | 0/10 [00:00<?, ?it/s]

max_length=40
{'Bm25_AverageLoc': 127.0, 'MsMarcoCE_AverageLoc': 91.9090909090909, 'Bm25_Top@1': 0.0, 'Bm25_Top@3': 0.0, 'Bm25_Top@5': 0.09090909090909091, 'MsMarcoCE_Top@1': 0.18181818181818182, 'MsMarcoCE_Top@3': 1.0, 'MsMarcoCE_Top@5': 1.6363636363636365, 'Bm25_FDARO@v1': 0.09090909090909091, 'Bm25_FDARO@v2': 0.2727272727272727, 'MsMarcoCE_FDARO@v1': 0.36363636363636365, 'MsMarcoCE_FDARO@v2': 0.6363636363636364, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.09090909090909091, 'Bm25_AverageRelLoc': 4.455484212733425, 'MsMarcoCE_AverageRelLoc': 3.2979416418894187}
____________________________________________________________________________________________________


  0%|          | 0/10 [00:00<?, ?it/s]

max_length=60
{'Bm25_AverageLoc': 154.8181818181818, 'MsMarcoCE_AverageLoc': 108.45454545454545, 'Bm25_Top@1': 0.0, 'Bm25_Top@3': 0.0, 'Bm25_Top@5': 0.09090909090909091, 'MsMarcoCE_Top@1': 0.18181818181818182, 'MsMarcoCE_Top@3': 1.0, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.0, 'Bm25_FDARO@v2': 0.2727272727272727, 'MsMarcoCE_FDARO@v1': 0.36363636363636365, 'MsMarcoCE_FDARO@v2': 0.6363636363636364, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.09090909090909091, 'Bm25_AverageRelLoc': 5.238860708393438, 'MsMarcoCE_AverageRelLoc': 3.7717373826363385}
____________________________________________________________________________________________________


In [16]:
MAX_LENGTH = select_best_param(RANKING_MODEL_NAME, RANKING_METRIC_NAME, results)
MAX_LENGTH

20

## `early_stopping` (bool, defaults to False)

In [17]:
results = []
for early_stopping in tqdm([True, False], desc='early_stopping', leave=False):
  res = run_experiment(dataset, max_length=MAX_LENGTH, early_stopping=True)
  results.append((early_stopping, res))
  print(f'early_stopping={early_stopping}')
  print(res)
  print('_'*100)

early_stopping:   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

early_stopping=True
{'Bm25_AverageLoc': 175.27272727272728, 'MsMarcoCE_AverageLoc': 133.0909090909091, 'Bm25_Top@1': 0.09090909090909091, 'Bm25_Top@3': 0.2727272727272727, 'Bm25_Top@5': 0.45454545454545453, 'MsMarcoCE_Top@1': 0.18181818181818182, 'MsMarcoCE_Top@3': 1.0, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.09090909090909091, 'Bm25_FDARO@v2': 0.2727272727272727, 'MsMarcoCE_FDARO@v1': 0.36363636363636365, 'MsMarcoCE_FDARO@v2': 0.6363636363636364, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.09090909090909091, 'Bm25_AverageRelLoc': 5.715287933561751, 'MsMarcoCE_AverageRelLoc': 4.455075848209962}
____________________________________________________________________________________________________


  0%|          | 0/10 [00:00<?, ?it/s]

early_stopping=False
{'Bm25_AverageLoc': 201.0, 'MsMarcoCE_AverageLoc': 158.72727272727272, 'Bm25_Top@1': 0.09090909090909091, 'Bm25_Top@3': 0.2727272727272727, 'Bm25_Top@5': 0.45454545454545453, 'MsMarcoCE_Top@1': 0.18181818181818182, 'MsMarcoCE_Top@3': 1.0, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.09090909090909091, 'Bm25_FDARO@v2': 0.2727272727272727, 'MsMarcoCE_FDARO@v1': 0.36363636363636365, 'MsMarcoCE_FDARO@v2': 0.6363636363636364, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.09090909090909091, 'Bm25_AverageRelLoc': 6.350292560225925, 'MsMarcoCE_AverageRelLoc': 5.126456187196887}
____________________________________________________________________________________________________


In [18]:
EARLY_STOPPING = select_best_param(RANKING_MODEL_NAME, RANKING_METRIC_NAME, results)
EARLY_STOPPING

True

# Parameters that control the generation strategy used

## `do_sample` (bool, defaults to False)

In [19]:
results = []
for do_sample in tqdm([True, False], desc='do_sample', leave=False):
  res = run_experiment(dataset, max_length=MAX_LENGTH, early_stopping=EARLY_STOPPING, do_sample=do_sample)
  results.append((do_sample, res))
  print(f'do_sample={do_sample}')
  print(res)
  print('_'*100)

do_sample:   0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

do_sample=True
{'Bm25_AverageLoc': 233.1818181818182, 'MsMarcoCE_AverageLoc': 185.45454545454547, 'Bm25_Top@1': 0.09090909090909091, 'Bm25_Top@3': 0.2727272727272727, 'Bm25_Top@5': 0.45454545454545453, 'MsMarcoCE_Top@1': 0.18181818181818182, 'MsMarcoCE_Top@3': 1.0, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.09090909090909091, 'Bm25_FDARO@v2': 0.2727272727272727, 'MsMarcoCE_FDARO@v1': 0.36363636363636365, 'MsMarcoCE_FDARO@v2': 0.6363636363636364, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.09090909090909091, 'Bm25_AverageRelLoc': 7.146054232657438, 'MsMarcoCE_AverageRelLoc': 5.794863732792941}
____________________________________________________________________________________________________


  0%|          | 0/10 [00:00<?, ?it/s]

do_sample=False
{'Bm25_AverageLoc': 261.72727272727275, 'MsMarcoCE_AverageLoc': 213.0909090909091, 'Bm25_Top@1': 0.09090909090909091, 'Bm25_Top@3': 0.2727272727272727, 'Bm25_Top@5': 0.45454545454545453, 'MsMarcoCE_Top@1': 0.18181818181818182, 'MsMarcoCE_Top@3': 1.0, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.09090909090909091, 'Bm25_FDARO@v2': 0.2727272727272727, 'MsMarcoCE_FDARO@v1': 0.36363636363636365, 'MsMarcoCE_FDARO@v2': 0.6363636363636364, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.0, 'Bm25_AverageRelLoc': 7.774391789288405, 'MsMarcoCE_AverageRelLoc': 6.445265198460686}
____________________________________________________________________________________________________


In [20]:
DO_SAMPLE = select_best_param(RANKING_MODEL_NAME, RANKING_METRIC_NAME, results)
DO_SAMPLE

True

## `num_beams` (int, defaults to 1)

In [21]:
results = []
for num_beams in tqdm([1, 3, 5, 7], desc='num_beams', leave=False):
  res = run_experiment(dataset, max_length=MAX_LENGTH, early_stopping=EARLY_STOPPING, do_sample=DO_SAMPLE, num_beams=num_beams)
  results.append((num_beams, res))
  print(f'num_beams={num_beams}')
  print(res)
  print('_'*100)

num_beams:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

num_beams=1
{'Bm25_AverageLoc': 294.3636363636364, 'MsMarcoCE_AverageLoc': 242.45454545454547, 'Bm25_Top@1': 0.09090909090909091, 'Bm25_Top@3': 0.36363636363636365, 'Bm25_Top@5': 0.45454545454545453, 'MsMarcoCE_Top@1': 0.18181818181818182, 'MsMarcoCE_Top@3': 1.0, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.09090909090909091, 'Bm25_FDARO@v2': 0.36363636363636365, 'MsMarcoCE_FDARO@v1': 0.36363636363636365, 'MsMarcoCE_FDARO@v2': 0.7272727272727273, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.0, 'Bm25_AverageRelLoc': 8.48360663628992, 'MsMarcoCE_AverageRelLoc': 7.101290741041473}
____________________________________________________________________________________________________


  0%|          | 0/10 [00:00<?, ?it/s]

num_beams=3
{'Bm25_AverageLoc': 334.0, 'MsMarcoCE_AverageLoc': 274.45454545454544, 'Bm25_Top@1': 0.0, 'Bm25_Top@3': 0.36363636363636365, 'Bm25_Top@5': 0.36363636363636365, 'MsMarcoCE_Top@1': 0.18181818181818182, 'MsMarcoCE_Top@3': 1.0, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.09090909090909091, 'Bm25_FDARO@v2': 0.45454545454545453, 'MsMarcoCE_FDARO@v1': 0.36363636363636365, 'MsMarcoCE_FDARO@v2': 0.7272727272727273, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.0, 'Bm25_AverageRelLoc': 9.364671644754045, 'MsMarcoCE_AverageRelLoc': 7.7937474867065}
____________________________________________________________________________________________________


  0%|          | 0/10 [00:00<?, ?it/s]

num_beams=5
{'Bm25_AverageLoc': 367.90909090909093, 'MsMarcoCE_AverageLoc': 308.1818181818182, 'Bm25_Top@1': 0.09090909090909091, 'Bm25_Top@3': 0.36363636363636365, 'Bm25_Top@5': 0.36363636363636365, 'MsMarcoCE_Top@1': 0.18181818181818182, 'MsMarcoCE_Top@3': 1.0, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.09090909090909091, 'Bm25_FDARO@v2': 0.45454545454545453, 'MsMarcoCE_FDARO@v1': 0.36363636363636365, 'MsMarcoCE_FDARO@v2': 0.7272727272727273, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.0, 'Bm25_AverageRelLoc': 10.037666744073555, 'MsMarcoCE_AverageRelLoc': 8.509942017241583}
____________________________________________________________________________________________________


  0%|          | 0/10 [00:00<?, ?it/s]

num_beams=7
{'Bm25_AverageLoc': 402.54545454545456, 'MsMarcoCE_AverageLoc': 340.27272727272725, 'Bm25_Top@1': 0.18181818181818182, 'Bm25_Top@3': 0.45454545454545453, 'Bm25_Top@5': 0.45454545454545453, 'MsMarcoCE_Top@1': 0.2727272727272727, 'MsMarcoCE_Top@3': 1.0, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.18181818181818182, 'Bm25_FDARO@v2': 0.45454545454545453, 'MsMarcoCE_FDARO@v1': 0.45454545454545453, 'MsMarcoCE_FDARO@v2': 0.7272727272727273, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.0, 'Bm25_AverageRelLoc': 10.681467530234581, 'MsMarcoCE_AverageRelLoc': 9.135087116778069}
____________________________________________________________________________________________________


In [27]:
NUM_BEAMS = select_best_param(RANKING_MODEL_NAME, RANKING_METRIC_NAME, results)
NUM_BEAMS

3

## Parameters for manipulation of the model output logits

## `temperature` (float, defaults to 1.0)

In [28]:
results = []
for temperature in tqdm([1., .95, .9], desc='temperature', leave=False):
  res = run_experiment(dataset, max_length=MAX_LENGTH, early_stopping=EARLY_STOPPING, do_sample=DO_SAMPLE, num_beams=NUM_BEAMS,
                       temperature=temperature)
  results.append((temperature, res))
  print(f'temperature={temperature}')
  print(res)
  print('_'*100)

temperature:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

temperature=1.0
{'Bm25_AverageLoc': 553.1818181818181, 'MsMarcoCE_AverageLoc': 482.27272727272725, 'Bm25_Top@1': 0.2727272727272727, 'Bm25_Top@3': 0.6363636363636364, 'Bm25_Top@5': 0.8181818181818182, 'MsMarcoCE_Top@1': 0.36363636363636365, 'MsMarcoCE_Top@3': 1.0, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.2727272727272727, 'Bm25_FDARO@v2': 0.45454545454545453, 'MsMarcoCE_FDARO@v1': 0.5454545454545454, 'MsMarcoCE_FDARO@v2': 0.8181818181818182, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.0, 'Bm25_AverageRelLoc': 13.242305421223596, 'MsMarcoCE_AverageRelLoc': 11.650491329120335}
____________________________________________________________________________________________________


  0%|          | 0/10 [00:00<?, ?it/s]

temperature=0.95
{'Bm25_AverageLoc': 595.3636363636364, 'MsMarcoCE_AverageLoc': 523.1818181818181, 'Bm25_Top@1': 0.2727272727272727, 'Bm25_Top@3': 0.5454545454545454, 'Bm25_Top@5': 0.7272727272727273, 'MsMarcoCE_Top@1': 0.36363636363636365, 'MsMarcoCE_Top@3': 1.0909090909090908, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.2727272727272727, 'Bm25_FDARO@v2': 0.45454545454545453, 'MsMarcoCE_FDARO@v1': 0.5454545454545454, 'MsMarcoCE_FDARO@v2': 0.8181818181818182, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.0, 'Bm25_AverageRelLoc': 13.916278590581664, 'MsMarcoCE_AverageRelLoc': 12.3286702490998}
____________________________________________________________________________________________________


  0%|          | 0/10 [00:00<?, ?it/s]

temperature=0.9
{'Bm25_AverageLoc': 635.9090909090909, 'MsMarcoCE_AverageLoc': 562.0, 'Bm25_Top@1': 0.2727272727272727, 'Bm25_Top@3': 0.6363636363636364, 'Bm25_Top@5': 0.9090909090909091, 'MsMarcoCE_Top@1': 0.36363636363636365, 'MsMarcoCE_Top@3': 1.0909090909090908, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.2727272727272727, 'Bm25_FDARO@v2': 0.5454545454545454, 'MsMarcoCE_FDARO@v1': 0.5454545454545454, 'MsMarcoCE_FDARO@v2': 0.8181818181818182, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.0, 'Bm25_AverageRelLoc': 14.527960184766169, 'MsMarcoCE_AverageRelLoc': 12.941765699267812}
____________________________________________________________________________________________________


In [29]:
TEMPERATURE = select_best_param(RANKING_MODEL_NAME, RANKING_METRIC_NAME, results)
TEMPERATURE

1.0

## `top_k` (int, defaults to 50)

In [30]:
results = []
for top_k in tqdm([5, 10, 50, 100], desc='top_k', leave=False):
  res = run_experiment(dataset, max_length=MAX_LENGTH, early_stopping=EARLY_STOPPING, do_sample=DO_SAMPLE, num_beams=NUM_BEAMS, temperature=TEMPERATURE,
                       top_k=top_k)
  results.append((top_k, res))
  print(f'top_k={top_k}')
  print(res)
  print('_'*100)

top_k:   0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

top_k=5
{'Bm25_AverageLoc': 680.7272727272727, 'MsMarcoCE_AverageLoc': 602.0909090909091, 'Bm25_Top@1': 0.2727272727272727, 'Bm25_Top@3': 0.6363636363636364, 'Bm25_Top@5': 0.7272727272727273, 'MsMarcoCE_Top@1': 0.36363636363636365, 'MsMarcoCE_Top@3': 1.1818181818181819, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.2727272727272727, 'Bm25_FDARO@v2': 0.6363636363636364, 'MsMarcoCE_FDARO@v1': 0.5454545454545454, 'MsMarcoCE_FDARO@v2': 0.8181818181818182, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.0, 'Bm25_AverageRelLoc': 15.20998184298911, 'MsMarcoCE_AverageRelLoc': 13.550280687086705}
____________________________________________________________________________________________________


  0%|          | 0/10 [00:00<?, ?it/s]

top_k=10
{'Bm25_AverageLoc': 720.0, 'MsMarcoCE_AverageLoc': 645.0909090909091, 'Bm25_Top@1': 0.2727272727272727, 'Bm25_Top@3': 0.7272727272727273, 'Bm25_Top@5': 1.0, 'MsMarcoCE_Top@1': 0.36363636363636365, 'MsMarcoCE_Top@3': 1.1818181818181819, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.2727272727272727, 'Bm25_FDARO@v2': 0.6363636363636364, 'MsMarcoCE_FDARO@v1': 0.5454545454545454, 'MsMarcoCE_FDARO@v2': 0.8181818181818182, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.0, 'Bm25_AverageRelLoc': 15.739342058026592, 'MsMarcoCE_AverageRelLoc': 14.19660036755531}
____________________________________________________________________________________________________


  0%|          | 0/10 [00:00<?, ?it/s]

top_k=50
{'Bm25_AverageLoc': 768.2727272727273, 'MsMarcoCE_AverageLoc': 690.6363636363636, 'Bm25_Top@1': 0.2727272727272727, 'Bm25_Top@3': 0.6363636363636364, 'Bm25_Top@5': 0.9090909090909091, 'MsMarcoCE_Top@1': 0.36363636363636365, 'MsMarcoCE_Top@3': 1.1818181818181819, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.2727272727272727, 'Bm25_FDARO@v2': 0.6363636363636364, 'MsMarcoCE_FDARO@v1': 0.5454545454545454, 'MsMarcoCE_FDARO@v2': 0.8181818181818182, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.0, 'Bm25_AverageRelLoc': 16.439894925366115, 'MsMarcoCE_AverageRelLoc': 14.865312102809208}
____________________________________________________________________________________________________


  0%|          | 0/10 [00:00<?, ?it/s]

top_k=100
{'Bm25_AverageLoc': 816.7272727272727, 'MsMarcoCE_AverageLoc': 732.4545454545455, 'Bm25_Top@1': 0.18181818181818182, 'Bm25_Top@3': 0.5454545454545454, 'Bm25_Top@5': 0.8181818181818182, 'MsMarcoCE_Top@1': 0.36363636363636365, 'MsMarcoCE_Top@3': 1.1818181818181819, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.2727272727272727, 'Bm25_FDARO@v2': 0.6363636363636364, 'MsMarcoCE_FDARO@v1': 0.5454545454545454, 'MsMarcoCE_FDARO@v2': 0.8181818181818182, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.0, 'Bm25_AverageRelLoc': 17.109833229390237, 'MsMarcoCE_AverageRelLoc': 15.43094366280264}
____________________________________________________________________________________________________


In [31]:
TOP_K = select_best_param(RANKING_MODEL_NAME, RANKING_METRIC_NAME, results)
TOP_K

5

## `top_p` (float, defaults to 1.0)

In [32]:
results = []
for top_p in tqdm([1., .95, .9], desc='top_p', leave=False):
  res = run_experiment(dataset, max_length=MAX_LENGTH, early_stopping=EARLY_STOPPING, do_sample=DO_SAMPLE, num_beams=NUM_BEAMS, temperature=TEMPERATURE, top_k=TOP_K,
                       top_p=top_p)
  results.append((top_p, res))
  print(f'top_p={top_p}')
  print(res)
  print('_'*100)

top_p:   0%|          | 0/3 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

top_p=1.0
{'Bm25_AverageLoc': 847.0, 'MsMarcoCE_AverageLoc': 775.3636363636364, 'Bm25_Top@1': 0.2727272727272727, 'Bm25_Top@3': 0.7272727272727273, 'Bm25_Top@5': 1.0909090909090908, 'MsMarcoCE_Top@1': 0.36363636363636365, 'MsMarcoCE_Top@3': 1.1818181818181819, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.2727272727272727, 'Bm25_FDARO@v2': 0.6363636363636364, 'MsMarcoCE_FDARO@v1': 0.5454545454545454, 'MsMarcoCE_FDARO@v2': 0.8181818181818182, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.0, 'Bm25_AverageRelLoc': 17.36743016451619, 'MsMarcoCE_AverageRelLoc': 15.996789016315345}
____________________________________________________________________________________________________


  0%|          | 0/10 [00:00<?, ?it/s]

top_p=0.95
{'Bm25_AverageLoc': 909.7272727272727, 'MsMarcoCE_AverageLoc': 819.0909090909091, 'Bm25_Top@1': 0.18181818181818182, 'Bm25_Top@3': 0.5454545454545454, 'Bm25_Top@5': 0.8181818181818182, 'MsMarcoCE_Top@1': 0.36363636363636365, 'MsMarcoCE_Top@3': 1.1818181818181819, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.2727272727272727, 'Bm25_FDARO@v2': 0.7272727272727273, 'MsMarcoCE_FDARO@v1': 0.5454545454545454, 'MsMarcoCE_FDARO@v2': 0.8181818181818182, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.0, 'Bm25_AverageRelLoc': 18.289910710607675, 'MsMarcoCE_AverageRelLoc': 16.555582219990665}
____________________________________________________________________________________________________


  0%|          | 0/10 [00:00<?, ?it/s]

top_p=0.9
{'Bm25_AverageLoc': 945.8181818181819, 'MsMarcoCE_AverageLoc': 865.0, 'Bm25_Top@1': 0.2727272727272727, 'Bm25_Top@3': 0.7272727272727273, 'Bm25_Top@5': 1.0, 'MsMarcoCE_Top@1': 0.36363636363636365, 'MsMarcoCE_Top@3': 1.1818181818181819, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.2727272727272727, 'Bm25_FDARO@v2': 0.5454545454545454, 'MsMarcoCE_FDARO@v1': 0.5454545454545454, 'MsMarcoCE_FDARO@v2': 0.8181818181818182, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.0, 'Bm25_AverageRelLoc': 18.630024986344495, 'MsMarcoCE_AverageRelLoc': 17.134636697898856}
____________________________________________________________________________________________________


In [33]:
TOP_P = select_best_param(RANKING_MODEL_NAME, RANKING_METRIC_NAME, results)
TOP_P

1.0

## `no_repeat_ngram_size` (int, defaults to 0)

In [34]:
results = []
for no_repeat_ngram_size in tqdm(range(4), leave=False):
  res = run_experiment(dataset, max_length=MAX_LENGTH, early_stopping=EARLY_STOPPING, do_sample=DO_SAMPLE, num_beams=NUM_BEAMS, temperature=TEMPERATURE, top_k=TOP_K, top_p=TOP_P,
                       no_repeat_ngram_size=no_repeat_ngram_size)
  results.append((no_repeat_ngram_size, res))
  print(f'no_repeat_ngram_size={no_repeat_ngram_size}')
  print(res)
  print('_'*100)

  0%|          | 0/4 [00:00<?, ?it/s]

  0%|          | 0/10 [00:00<?, ?it/s]

no_repeat_ngram_size=0
{'Bm25_AverageLoc': 992.9090909090909, 'MsMarcoCE_AverageLoc': 911.1818181818181, 'Bm25_Top@1': 0.2727272727272727, 'Bm25_Top@3': 0.7272727272727273, 'Bm25_Top@5': 1.0909090909090908, 'MsMarcoCE_Top@1': 0.36363636363636365, 'MsMarcoCE_Top@3': 1.1818181818181819, 'MsMarcoCE_Top@5': 2.0, 'Bm25_FDARO@v1': 0.2727272727272727, 'Bm25_FDARO@v2': 0.6363636363636364, 'MsMarcoCE_FDARO@v1': 0.5454545454545454, 'MsMarcoCE_FDARO@v2': 0.8181818181818182, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.0, 'Bm25_AverageRelLoc': 19.180547692582966, 'MsMarcoCE_AverageRelLoc': 17.698009149869645}
____________________________________________________________________________________________________


  0%|          | 0/10 [00:00<?, ?it/s]

no_repeat_ngram_size=1
{'Bm25_AverageLoc': 1045.090909090909, 'MsMarcoCE_AverageLoc': 958.8181818181819, 'Bm25_Top@1': 0.2727272727272727, 'Bm25_Top@3': 0.7272727272727273, 'Bm25_Top@5': 1.0909090909090908, 'MsMarcoCE_Top@1': 0.36363636363636365, 'MsMarcoCE_Top@3': 1.2727272727272727, 'MsMarcoCE_Top@5': 2.090909090909091, 'Bm25_FDARO@v1': 0.2727272727272727, 'Bm25_FDARO@v2': 0.6363636363636364, 'MsMarcoCE_FDARO@v1': 0.5454545454545454, 'MsMarcoCE_FDARO@v2': 0.8181818181818182, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.0, 'Bm25_AverageRelLoc': 19.80887118464417, 'MsMarcoCE_AverageRelLoc': 18.2712580716067}
____________________________________________________________________________________________________


  0%|          | 0/10 [00:00<?, ?it/s]

no_repeat_ngram_size=2
{'Bm25_AverageLoc': 1096.090909090909, 'MsMarcoCE_AverageLoc': 1006.7272727272727, 'Bm25_Top@1': 0.2727272727272727, 'Bm25_Top@3': 0.7272727272727273, 'Bm25_Top@5': 1.0909090909090908, 'MsMarcoCE_Top@1': 0.36363636363636365, 'MsMarcoCE_Top@3': 1.2727272727272727, 'MsMarcoCE_Top@5': 2.090909090909091, 'Bm25_FDARO@v1': 0.2727272727272727, 'Bm25_FDARO@v2': 0.6363636363636364, 'MsMarcoCE_FDARO@v1': 0.5454545454545454, 'MsMarcoCE_FDARO@v2': 0.8181818181818182, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.0, 'Bm25_AverageRelLoc': 20.389469319143583, 'MsMarcoCE_AverageRelLoc': 18.82476432815309}
____________________________________________________________________________________________________


  0%|          | 0/10 [00:00<?, ?it/s]

no_repeat_ngram_size=3
{'Bm25_AverageLoc': 1148.3636363636363, 'MsMarcoCE_AverageLoc': 1055.8181818181818, 'Bm25_Top@1': 0.2727272727272727, 'Bm25_Top@3': 0.7272727272727273, 'Bm25_Top@5': 1.0909090909090908, 'MsMarcoCE_Top@1': 0.36363636363636365, 'MsMarcoCE_Top@3': 1.2727272727272727, 'MsMarcoCE_Top@5': 2.090909090909091, 'Bm25_FDARO@v1': 0.2727272727272727, 'Bm25_FDARO@v2': 0.6363636363636364, 'MsMarcoCE_FDARO@v1': 0.5454545454545454, 'MsMarcoCE_FDARO@v2': 0.8181818181818182, 'Bm25_UpQuartile': 0.0, 'MsMarcoCE_UpQuartile': 0.0, 'Bm25_AverageRelLoc': 20.971917823009825, 'MsMarcoCE_AverageRelLoc': 19.379597668049296}
____________________________________________________________________________________________________


In [35]:
NO_REPEAT_NGRAM_SIZE = select_best_param(RANKING_MODEL_NAME, RANKING_METRIC_NAME, results)
NO_REPEAT_NGRAM_SIZE

0

In [36]:
print(
f'''max_length={MAX_LENGTH}, early_stopping={EARLY_STOPPING}, do_sample={DO_SAMPLE}, num_beams={NUM_BEAMS}, temperature={TEMPERATURE},
  top_k={TOP_K}, top_p={TOP_P}, no_repeat_ngram_size={NO_REPEAT_NGRAM_SIZE}
''')

max_length=20, early_stopping=True, do_sample=True, num_beams=3, temperature=1.0,
  top_k=5, top_p=1.0, no_repeat_ngram_size=0

