# OpenBook DeBERTaV3-Large with an updated model

This work is based on the great [work](https://www.kaggle.com/code/nlztrk/openbook-debertav3-large-baseline-single-model) of [nlztrk](https://www.kaggle.com/nlztrk).

I trained a model offline using the dataset I shared [here](https://www.kaggle.com/datasets/mgoksu/llm-science-exam-dataset-w-context). I just added my model to the original notebook. The model is available [here](https://www.kaggle.com/datasets/mgoksu/llm-science-run-context-2).

I also worked around the [CSV Not Found at submission](https://www.kaggle.com/competitions/kaggle-llm-science-exam/discussion/434228) problem by clipping each context to its first 1500 characters before prepending it to the prompt:

`test_df["prompt"] = test_df["context"].apply(lambda x: x[:1500]) + " #### " +  test_df["prompt"]`

You can probably raise this limit above 1500 characters without running out of memory (OOM).
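The sketch below applies the clipping to a toy DataFrame. It assumes `test_df` has string `context` and `prompt` columns, as in the line above. `MAX_CONTEXT_CHARS` is a hypothetical name for the 1500-character budget, and `.str[:n]` is the vectorized equivalent of `.apply(lambda x: x[:n])` for string columns. The limit counts characters, not tokens, so it bounds the model's input length only indirectly.

```python
import pandas as pd

MAX_CONTEXT_CHARS = 1500  # hypothetical name for the character budget used above

# Toy stand-in for the competition test set (values are made up)
test_df = pd.DataFrame({
    "context": ["x" * 2000, "short context"],
    "prompt": ["What is A?", "What is B?"],
})

# Keep only the first MAX_CONTEXT_CHARS characters of each context,
# then prepend it to the prompt with the " #### " separator
test_df["prompt"] = (
    test_df["context"].str[:MAX_CONTEXT_CHARS] + " #### " + test_df["prompt"]
)

print(test_df["prompt"].str.len().tolist())  # [1516, 29]
```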

In [1]:
from __future__ import annotations
import os
import gc
import pandas as pd
import numpy as np
import re
from tqdm.auto import tqdm
import blingfire as bf

from collections.abc import Iterable

import faiss
from faiss import write_index, read_index

from sentence_transformers import SentenceTransformer
from torch.cuda.amp import autocast
import torch
import ctypes
libc = ctypes.CDLL("libc.so.6")

from dataclasses import dataclass
from typing import Optional, Union

from datasets import Dataset
from transformers import AutoTokenizer
from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from torch.utils.data import DataLoader



In [2]:
def process_documents(documents: Iterable[str],
                      document_ids: Iterable,
                      split_sentences: bool = True,
                      filter_len: int = 3,
                      disable_progress_bar: bool = False) -> pd.DataFrame:
    """
    Main helper function to turn raw documents into (optionally sentence-level) rows.

    :param documents: Iterable containing documents which are strings
    :param document_ids: Iterable containing document unique identifiers
    :param split_sentences: Flag to determine whether to further split documents into sentences
    :param filter_len: Sentences of `filter_len` characters or fewer are filtered out
    :param disable_progress_bar: Flag to disable tqdm progress bar
    :return: Pandas DataFrame containing the columns `document_id`, `text`, `offset`
    """
    
    df = sectionize_documents(documents, document_ids, disable_progress_bar)

    if split_sentences:
        df = sentencize(df.text.values, 
                        df.document_id.values,
                        df.offset.values, 
                        filter_len, 
                        disable_progress_bar)
    return df


def sectionize_documents(documents: Iterable[str],
                         document_ids: Iterable,
                         disable_progress_bar: bool = False) -> pd.DataFrame:
    """
    Wraps each document as a single section covering its full text, with
    offset (start, end) = (0, len(document)), and sorts the rows by
    `document_id` and `offset`.

    :param documents: Iterable containing documents which are strings
    :param document_ids: Iterable containing document unique identifiers
    :param disable_progress_bar: Flag to disable tqdm progress bar
    :return: Pandas DataFrame containing the columns `document_id`, `text`, `offset`
    """
    processed_documents = []
    for document_id, document in tqdm(zip(document_ids, documents), total=len(documents), disable=disable_progress_bar):
        row = {}
        text, start, end = (document, 0, len(document))
        row['document_id'] = document_id
        row['text'] = text
        row['offset'] = (start, end)

        processed_documents.append(row)

    _df = pd.DataFrame(processed_documents)
    if _df.shape[0] > 0:
        return _df.sort_values(['document_id', 'offset']).reset_index(drop=True)
    else:
        return _df


def sentencize(documents: Iterable[str],
               document_ids: Iterable,
               offsets: Iterable[tuple[int, int]],
               filter_len: int = 3,
               disable_progress_bar: bool = False) -> pd.DataFrame:
    """
    Split documents into sentences. Can be used with `sectionize_documents`
    to further split documents into more manageable pieces. Takes in offsets
    so that, after splitting, each sentence can be matched to its
    location in the original document.

    :param documents: Iterable containing documents which are strings
    :param document_ids: Iterable containing document unique identifiers
    :param offsets: Iterable of (start, end) index tuples locating each document
    :param filter_len: Sentences of `filter_len` characters or fewer are filtered out
    :param disable_progress_bar: Flag to disable tqdm progress bar
    :return: Pandas DataFrame containing the columns `document_id`, `text`, `offset`
    """

    document_sentences = []
    for document, document_id, offset in tqdm(zip(documents, document_ids, offsets), total=len(documents), disable=disable_progress_bar):
        try:
            _, sentence_offsets = bf.text_to_sentences_and_offsets(document)
            for o in sentence_offsets:
                if o[1]-o[0] > filter_len:
                    sentence = document[o[0]:o[1]]
                    abs_offsets = (o[0]+offset[0], o[1]+offset[0])
                    row = {}
                    row['document_id'] = document_id
                    row['text'] = sentence
                    row['offset'] = abs_offsets
                    document_sentences.append(row)
        except Exception:
            # Skip documents that fail to split (for example, missing or non-string text)
            continue
    return pd.DataFrame(document_sentences)
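`sentencize` relies on `bf.text_to_sentences_and_offsets` returning character spans relative to the text it was given. It then shifts each span by the section's start offset. The sketch below reproduces only that bookkeeping. It swaps blingfire for a naive regex splitter so it runs without extra dependencies; that splitter is an assumption and much cruder than blingfire. `naive_sentence_offsets` and `sentencize_sketch` are hypothetical names. In this notebook, `sectionize_documents` keeps each article whole with offset `(0, len(document))`, so the shift is zero in practice.

```python
import re

def naive_sentence_offsets(text):
    """Crude stand-in for blingfire: (start, end) spans of runs ending in . ! or ?"""
    return [(m.start(), m.end()) for m in re.finditer(r"[^.!?\s][^.!?]*[.!?]?", text)]

def sentencize_sketch(document, offset, filter_len=3):
    rows = []
    for start, end in naive_sentence_offsets(document):
        if end - start > filter_len:          # same length filter as sentencize
            rows.append({
                "text": document[start:end],
                # shift the local span by the section start to get absolute offsets
                "offset": (start + offset[0], end + offset[0]),
            })
    return rows

full = "Intro. Shear strength matters. Soil fails."
section_start = 7                 # pretend the section starts at character 7
section = full[section_start:]    # "Shear strength matters. Soil fails."
for row in sentencize_sketch(section, (section_start, len(full))):
    s, e = row["offset"]
    assert full[s:e] == row["text"]   # absolute offsets index the original document
    print(row)
```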

In [3]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
DEVICE = 'cuda'
MAX_LENGTH = 512
BATCH_SIZE = 16
BERT_PATH = "/root/bert_path/sentence-transformers_all-MiniLM-L6-v2"
MODEL_PATH = "./save/recall/recall_epoch100.bin"
WIKI_PATH = "./wiki_data"
wiki_files = os.listdir(WIKI_PATH)

In [4]:
import torch.nn as nn
import torch

from transformers import AutoModel, AutoTokenizer
class RecallModel(nn.Module):
    def __init__(self):
        super(RecallModel, self).__init__()
        self.bert_model = AutoModel.from_pretrained(BERT_PATH)
    
    def mask_mean(self, x, mask=None):
        # Masked mean pooling over the sequence dimension (not used by forward)
        if mask is not None:
            mask_x = x * (mask.unsqueeze(-1))
            x_sum = torch.sum(mask_x, dim=1)
            re_x = torch.div(x_sum, torch.sum(mask, dim=1).unsqueeze(-1))
        else:
            x_sum = torch.sum(x, dim=1)
            re_x = torch.div(x_sum, x.size()[1])
        return re_x

    def forward(self, input_ids):
        # The padding id is 0, so every non-zero id is a real token
        attention_mask = input_ids > 0
        out = self.bert_model(input_ids, attention_mask=attention_mask).last_hidden_state
        # Use the first ([CLS]) token's hidden state as the embedding
        x = out[:, 0, :]
        return x
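`forward` pools with the first ([CLS]) token's hidden state. The unused `mask_mean` helper would instead average the hidden states of the non-padding tokens. The NumPy sketch below shows both strategies on a toy batch; the shapes and values are made up. The padding mask mirrors `attention_mask = input_ids > 0`, which assumes the pad id is 0.

```python
import numpy as np

# Toy batch: 2 sequences, 4 positions, hidden size 3 (made-up values)
input_ids = np.array([[101, 7, 8, 102],
                      [101, 9, 102, 0]])        # second row has one pad token (id 0)
hidden = np.arange(24, dtype=np.float32).reshape(2, 4, 3)

mask = (input_ids > 0).astype(np.float32)        # same rule as RecallModel.forward

# [CLS] pooling, as in RecallModel.forward: take position 0
cls_pooled = hidden[:, 0, :]

# Masked mean pooling, as in RecallModel.mask_mean: ignore padded positions
summed = (hidden * mask[..., None]).sum(axis=1)
mean_pooled = summed / mask.sum(axis=1, keepdims=True)

print(cls_pooled)   # [[0, 1, 2], [12, 13, 14]]
print(mean_pooled)  # [[4.5, 5.5, 6.5], [15, 16, 17]]
```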


# Relevant Title Retrieval

In [5]:
trn = pd.read_csv("./data/train.csv")
trn['prompt_answer'] = trn.apply(lambda row : ' '.join(str(row[x]) for x in ['prompt', 'A', 'B', 'C', 'D', 'E']),axis=1)
trn.head()

| | id | prompt | A | B | C | D | E | answer | prompt_answer |
|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | Which of the following statements accurately d... | MOND is a theory that reduces the observed mis... | MOND is a theory that increases the discrepanc... | MOND is a theory that explains the missing bar... | MOND is a theory that reduces the discrepancy ... | MOND is a theory that eliminates the observed ... | D | Which of the following statements accurately d... |
| 1 | 1 | Which of the following is an accurate definiti... | Dynamic scaling refers to the evolution of sel... | Dynamic scaling refers to the non-evolution of... | Dynamic scaling refers to the evolution of sel... | Dynamic scaling refers to the non-evolution of... | Dynamic scaling refers to the evolution of sel... | A | Which of the following is an accurate definiti... |
| 2 | 2 | Which of the following statements accurately d... | The triskeles symbol was reconstructed as a fe... | The triskeles symbol is a representation of th... | The triskeles symbol is a representation of a ... | The triskeles symbol represents three interloc... | The triskeles symbol is a representation of th... | A | Which of the following statements accurately d... |
| 3 | 3 | What is the significance of regularization in ... | Regularizing the mass-energy of an electron wi... | Regularizing the mass-energy of an electron wi... | Regularizing the mass-energy of an electron wi... | Regularizing the mass-energy of an electron wi... | Regularizing the mass-energy of an electron wi... | C | What is the significance of regularization in ... |
| 4 | 4 | Which of the following statements accurately d... | The angular spacing of features in the diffrac... | The angular spacing of features in the diffrac... | The angular spacing of features in the diffrac... | The angular spacing of features in the diffrac... | The angular spacing of features in the diffrac... | D | Which of the following statements accurately d... |


In [6]:
from functools import partial
from torch.utils.data import DataLoader
dataloader_class = partial(DataLoader, pin_memory=True, num_workers=4)
model= RecallModel()
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'),strict=False)
model.to(DEVICE)
model = torch.nn.parallel.DataParallel(model)
model.eval()

DataParallel(
  (module): RecallModel(
    (bert_model): BertModel(
      (embeddings): BertEmbeddings(
        (word_embeddings): Embedding(30522, 384, padding_idx=0)
        (position_embeddings): Embedding(512, 384)
        (token_type_embeddings): Embedding(2, 384)
        (LayerNorm): LayerNorm((384,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): BertEncoder(
        (layer): ModuleList(
          (0-5): 6 x BertLayer(
            (attention): BertAttention(
              (self): BertSelfAttention(
                (query): Linear(in_features=384, out_features=384, bias=True)
                (key): Linear(in_features=384, out_features=384, bias=True)
                (value): Linear(in_features=384, out_features=384, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): BertSelfOutput(
                (dense): Linear(in_features=384, out_features=384, bias=True)
   

In [7]:
from tqdm.auto import tqdm
class LLMRecallDataSet(torch.utils.data.Dataset):
    def __init__(self, data, col):
        self.tokenizer = AutoTokenizer.from_pretrained(BERT_PATH, use_fast=True)
        self.data = data
        self.col = col
    def __len__(self):
        return len(self.data) 
    
    def __getitem__(self,index):
        inputs = self.data.loc[index, self.col]
        # Cap the raw text at 4000 characters before tokenizing
        if len(inputs) > 4000:
            inputs = inputs[:4000]
        inputs = self.tokenizer.encode(inputs, add_special_tokens=False)
        # Keep at most 510 token ids, then add [CLS] (101) and [SEP] (102): at most 512 ids in total
        inputs = [101] + inputs[:510] + [102]
        return inputs
    
    def collate_fn(self, batch):
        def sequence_padding(inputs, length=None, padding=0):
            """
            NumPy helper that pads sequences to the same length.
            """
            if length is None:
                length = max([len(x) for x in inputs])

            pad_width = [(0, 0) for _ in np.shape(inputs[0])]
            outputs = []
            for x in inputs:
                x = x[:length]
                pad_width[0] = (0, length - len(x))
                x = np.pad(x, pad_width, 'constant', constant_values=padding)
                outputs.append(x)

            return np.array(outputs, dtype='int64')
        batch_ids = torch.tensor(sequence_padding(batch), dtype=torch.long)
        
        return batch_ids

        
class DataLoaderX(torch.utils.data.DataLoader):
    '''
        DataLoader whose iterator prefetches batches in a background thread
        (prefetch_generator.BackgroundGenerator); not used by get_loader
    '''
    def __iter__(self):
        return BackgroundGenerator(super().__iter__())  

    
def get_loader(prompt,col,batch_size,train_mode=True,num_workers=4):
    ds_df = LLMRecallDataSet(prompt,col)
    # loader = DataLoaderX(ds_df, batch_size=batch_size if train_mode else batch_size//2, shuffle=train_mode, num_workers=num_workers,pin_memory=True,
    #                                      collate_fn=ds_df.collate_fn, drop_last=train_mode)
    loader = dataloader_class(ds_df, batch_size=batch_size, shuffle=False,collate_fn=ds_df.collate_fn)
    loader.num = len(ds_df)
    return loader
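`__getitem__` truncates each token list to 510 ids and wraps it in the BERT uncased [CLS] (101) and [SEP] (102) ids. `collate_fn` then right-pads the batch with 0, which is why `RecallModel.forward` can recover the attention mask as `input_ids > 0`. The truncation also explains why the `Token indices sequence length is longer than the specified maximum ...` warnings in the outputs below are harmless: `tokenizer.encode` raises them before the list is cut. The sketch below replays both steps on made-up token ids. `wrap_and_truncate` and `pad_batch` are hypothetical helpers, and `max_tokens` is shrunk from 510 to 4 so the effect is visible.

```python
import numpy as np

CLS_ID, SEP_ID, PAD_ID = 101, 102, 0   # BERT uncased vocab ids used in the notebook

def wrap_and_truncate(token_ids, max_tokens=510):
    """Keep at most max_tokens ids, then add [CLS] ... [SEP] (as in __getitem__)."""
    return [CLS_ID] + token_ids[:max_tokens] + [SEP_ID]

def pad_batch(sequences, pad_id=PAD_ID):
    """Right-pad every sequence to the longest one (as in collate_fn)."""
    length = max(len(s) for s in sequences)
    return np.array([s + [pad_id] * (length - len(s)) for s in sequences], dtype=np.int64)

batch = [wrap_and_truncate(ids, max_tokens=4)
         for ids in ([2023, 2003, 1037, 2146, 6251, 1012], [7592])]
input_ids = pad_batch(batch)
attention_mask = input_ids > 0      # same rule as RecallModel.forward

print(input_ids)
print(attention_mask.astype(int))
```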

In [8]:
sentence_index = read_index("./wiki_index/my_index.bin")
# sentence_index.num_threads = 50

In [9]:
# sentence_index = faiss.index_cpu_to_all_gpus(sentence_index)

In [10]:
from prefetch_generator import BackgroundGenerator
loader = get_loader(trn, 'prompt_answer',512, False)
prompt_embeddings = []
with torch.no_grad():
    for batch in tqdm(loader):
        batch = batch.to(DEVICE)
        with autocast():
            output = model(batch).cpu().detach().numpy()
        faiss.normalize_L2(output)
        prompt_embeddings.append(output)
prompt_embeddings = np.concatenate(prompt_embeddings, axis=0)
# model = SentenceTransformer('./sentence-transformer', device='cuda')
# model.max_seq_length = 512
# model = model.half()
# prompt_embeddings = model.encode(trn.prompt_answer,
#                                     batch_size=BATCH_SIZE,
#                                     device=DEVICE,
#                                     show_progress_bar=True,
#                                     convert_to_tensor=True,
#                                     normalize_embeddings=True)
# prompt_embeddings = prompt_embeddings.detach().cpu().numpy()

  0%|          | 0/1 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (596 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 1/1 [00:01<00:00,  1.65s/it]


In [11]:
_ = gc.collect()

In [12]:
prompt_embeddings.shape

(200, 384)

In [13]:
scores, indexs = [], []
subarrays = np.array_split(prompt_embeddings, 20)
for item in tqdm(subarrays):
    
    search_score, search_index = sentence_index.search(item, 20)
    scores.append(search_score)
    indexs.append(search_index)

100%|██████████| 20/20 [00:55<00:00,  2.76s/it]


In [14]:
search_score = np.concatenate(scores, axis=0)
search_index = np.concatenate(indexs, axis=0)
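Every embedding is L2-normalized with `faiss.normalize_L2` before it is searched. For unit vectors, ranking by inner product, by cosine similarity, or by L2 distance gives the same order. Splitting the queries with `np.array_split` only bounds peak memory; it does not change the results. The sketch below checks this on random toy vectors. Its NumPy brute-force search stands in for `sentence_index.search`; it is not the FAISS implementation, and the type of the notebook's index is not shown. `brute_force_search` is a hypothetical helper.

```python
import numpy as np

def brute_force_search(index_vectors, queries, k):
    """Exact top-k by inner product; returns (scores, indices) like faiss Index.search."""
    sims = queries @ index_vectors.T                     # (n_queries, n_index)
    top = np.argsort(-sims, axis=1)[:, :k]               # highest similarity first
    return np.take_along_axis(sims, top, axis=1), top

rng = np.random.default_rng(0)
corpus = rng.normal(size=(1000, 384)).astype(np.float32)
queries = rng.normal(size=(200, 384)).astype(np.float32)
# L2-normalize so inner product == cosine similarity (what faiss.normalize_L2 does in place)
corpus /= np.linalg.norm(corpus, axis=1, keepdims=True)
queries /= np.linalg.norm(queries, axis=1, keepdims=True)

# Search in 20 chunks, as in the notebook, then stitch the results back together
scores, indexs = [], []
for chunk in np.array_split(queries, 20):
    s, i = brute_force_search(corpus, chunk, 20)
    scores.append(s)
    indexs.append(i)
search_score = np.concatenate(scores, axis=0)
search_index = np.concatenate(indexs, axis=0)

full_score, full_index = brute_force_search(corpus, queries, 20)
print(search_index.shape, np.array_equal(search_index, full_index))
```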

In [15]:
## Save memory - delete sentence_index since it is no longer necessary
del sentence_index
del prompt_embeddings
_ = gc.collect()
libc.malloc_trim(0)
torch.cuda.empty_cache()

In [16]:
search_index[0]

array([3573843, 5094814,  503590,  503595, 1228442, 1297399, 1429137,
       2242104, 1635304, 4656736, 1059536, 1228455, 3422180, 1814394,
       4906500,  503593, 6197925, 3117941, 2925974, 4659517])

# Getting Sentences from the Relevant Titles

In [17]:
df = pd.read_parquet("./wiki_data/my_index.parquet",
                     columns=['id', 'file'])

In [18]:
# df.drop_duplicates(subset='id',keep='first', inplace=True)
# df['id'] = df['id'].apply(lambda x : int(x))

In [19]:
## Get the article and associated file location using the index
wikipedia_file_data = []

for i, (scr, idx) in tqdm(enumerate(zip(search_score, search_index)), total=len(search_score)):
    scr_idx = idx
    _df = df.loc[scr_idx].copy()
    _df['prompt_id'] = i
    wikipedia_file_data.append(_df)
wikipedia_file_data = pd.concat(wikipedia_file_data).reset_index(drop=True)
wikipedia_file_data = wikipedia_file_data[['id', 'prompt_id', 'file']].drop_duplicates().sort_values(['file', 'id']).reset_index(drop=True)

## Save memory - delete df since it is no longer necessary
del df
_ = gc.collect()
libc.malloc_trim(0)

100%|██████████| 200/200 [00:00<00:00, 2141.74it/s]


1
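The loop above turns the `(n_prompts, 20)` matrix of row positions into a long table with one row per (article, prompt) hit. It then sorts the table by `file`, so the next step can open each Wikipedia parquet shard once. The toy version below uses made-up ids and file names. Like the notebook, it assumes `df` has a default RangeIndex, so FAISS row positions double as `.loc` labels.

```python
import numpy as np
import pandas as pd

# Made-up metadata: row position i in the FAISS index <-> df.loc[i]
df = pd.DataFrame({"id": ["11", "22", "33", "44"],
                   "file": ["b.parquet", "a.parquet", "a.parquet", "c.parquet"]})
search_index = np.array([[1, 2],     # hits for prompt 0
                         [2, 3]])    # hits for prompt 1

parts = []
for prompt_id, idx in enumerate(search_index):
    hits = df.loc[idx].copy()        # positional ids work because df has a RangeIndex
    hits["prompt_id"] = prompt_id
    parts.append(hits)
wikipedia_file_data = (pd.concat(parts)[["id", "prompt_id", "file"]]
                       .drop_duplicates()
                       .sort_values(["file", "id"])
                       .reset_index(drop=True))
print(wikipedia_file_data)
```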

In [20]:
wikipedia_file_data

| | id | prompt_id | file |
|---|---|---|---|
| 0 | 1014414 | 189 | a.parquet |
| 1 | 1068478 | 198 | a.parquet |
| 2 | 1141 | 10 | a.parquet |
| 3 | 1141 | 36 | a.parquet |
| 4 | 1141 | 151 | a.parquet |
| ... | ... | ... | ... |
| 3995 | 27554141 | 16 | z.parquet |
| 3996 | 40874497 | 73 | z.parquet |
| 3997 | 4859028 | 165 | z.parquet |
| 3998 | 517682 | 6 | z.parquet |


In [21]:
## Get the full text data
wiki_text_data = []

for file in tqdm(wikipedia_file_data.file.unique(), total=len(wikipedia_file_data.file.unique())):
    _id = [str(i) for i in wikipedia_file_data[wikipedia_file_data['file']==file]['id'].tolist()]
    _df = pd.read_parquet(f"{WIKI_PATH}/{file}", columns=['id', 'text'])

    _df_temp = _df[_df['id'].isin(_id)].copy()
    del _df
    _ = gc.collect()
    libc.malloc_trim(0)
    wiki_text_data.append(_df_temp)
wiki_text_data = pd.concat(wiki_text_data).drop_duplicates().reset_index(drop=True)
_ = gc.collect()

100%|██████████| 28/28 [01:41<00:00,  3.62s/it]
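Grouping the requested ids by shard means each parquet file is read once, however many prompts point into it. The sketch below is a compact `groupby` equivalent of the loop's bookkeeping on a toy frame. The shard contents are made up and held in memory (a hypothetical `shards` dict) instead of being read from `WIKI_PATH` with `pd.read_parquet`.

```python
import pandas as pd

# Made-up (id, file) hits, as produced by the previous step
wikipedia_file_data = pd.DataFrame({"id": ["22", "33", "33", "44"],
                                    "prompt_id": [0, 0, 1, 1],
                                    "file": ["a.parquet", "a.parquet", "a.parquet", "c.parquet"]})
# Hypothetical in-memory stand-ins for the Wikipedia shards under WIKI_PATH
shards = {"a.parquet": pd.DataFrame({"id": ["22", "33", "99"], "text": ["t22", "t33", "t99"]}),
          "c.parquet": pd.DataFrame({"id": ["44", "55"], "text": ["t44", "t55"]})}

wiki_text_data = []
for file, group in wikipedia_file_data.groupby("file"):
    wanted = set(group["id"].astype(str))            # ids to keep from this shard
    shard = shards[file]                             # the notebook uses pd.read_parquet here
    wiki_text_data.append(shard[shard["id"].isin(wanted)])
wiki_text_data = pd.concat(wiki_text_data).drop_duplicates().reset_index(drop=True)
print(wiki_text_data)
```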


In [22]:
## Parse documents into sentences
processed_wiki_text_data = process_documents(wiki_text_data.text.values, wiki_text_data.id.values)

100%|██████████| 3463/3463 [00:00<00:00, 838086.36it/s]
100%|██████████| 3463/3463 [00:11<00:00, 312.80it/s]


In [23]:
processed_wiki_text_data

| | document_id | text | offset |
|---|---|---|---|
| 0 | 10004409 | Shear strength is a term used in soil mechanic... | (0, 118) |
| 1 | 10004409 | The shear resistance of soil is a result of fr... | (119, 260) |
| 2 | 10004409 | Due to interlocking, particulate material may ... | (261, 370) |
| 3 | 10004409 | If soil expands its volume, the density of par... | (371, 551) |
| 4 | 10004409 | The stress-strain relationship levels off when... | (552, 684) |
| ... | ... | ... | ... |
| 170404 | 9992916 | RingGo is a pay by phone parking service, base... | (0, 126) |
| 170405 | 9992916 | The system is used by local authorities for on... | (127, 210) |
| 170406 | 9992916 | The smartphone application requires that users... | (211, 362) |
| 170407 | 9992916 | It has also been suggested that councils that ... | (363, 491) |


In [24]:
model= RecallModel()
model.load_state_dict(torch.load('./save/recall_sentence/recall_epoch100.bin', map_location='cpu'),strict=False)
model.to(DEVICE)
model = torch.nn.parallel.DataParallel(model)
model.eval()
loader = get_loader(processed_wiki_text_data, 'text',32, False)
wiki_data_embeddings = []
with torch.no_grad():
    for batch in tqdm(loader):
        batch = batch.to(DEVICE)
        with autocast():
            output = model(batch).cpu().detach().numpy()
        faiss.normalize_L2(output)
        wiki_data_embeddings.append(output)
wiki_data_embeddings = np.concatenate(wiki_data_embeddings, axis=0)
# model = SentenceTransformer('./sentence-transformer', device='cuda')
# model.max_seq_length = 512
# model = model.half()
# wiki_data_embeddings = model.encode(processed_wiki_text_data.text,
#                                     batch_size=BATCH_SIZE,
#                                     device=DEVICE,
#                                     show_progress_bar=True,
#                                     convert_to_tensor=True,
#                                     normalize_embeddings=True)
# wiki_data_embeddings = wiki_data_embeddings.detach().cpu().numpy()
# wiki_data_embeddings = []
# for i in range(8):
#     wiki_data_embeddings.append(np.load(f'./tmp/{i}.npy'))
# wiki_data_embeddings = np.concatenate(wiki_data_embeddings,axis=0)

  0%|          | 0/5326 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (593 > 512). Running this sequence through the model will result in indexing errors
  0%|          | 23/5326 [00:00<02:11, 40.46it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1216 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1971 > 512). Running this sequence through the model will result in indexing errors
  1%|          | 62/5326 [00:01<00:58, 90.46it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (703 > 512). Running this sequence

In [25]:
_ = gc.collect()

In [26]:
## Combine all answers
trn['answer_all'] = trn.apply(lambda x: " ".join([str(x['A']), str(x['B']), str(x['C']), str(x['D']), str(x['E'])]), axis=1)


## Search using the prompt and answers to guide the search
trn['prompt_answer_stem'] = trn['prompt'] + " " + trn['answer_all']
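When the prompt is a string, `prompt_answer_stem` holds the same text as the `prompt_answer` column built earlier: the prompt followed by options A to E, separated by spaces. What changes in this step is the encoder. The question embeddings below come from the sentence-level recall model loaded from `./save/recall_sentence/recall_epoch100.bin`. The one-row toy frame below, with made-up values, checks that the two constructions agree.

```python
import pandas as pd

# One made-up row with the competition's column layout
trn = pd.DataFrame({"prompt": ["Q?"], "A": ["a"], "B": ["b"], "C": ["c"], "D": ["d"], "E": ["e"]})

trn["prompt_answer"] = trn.apply(
    lambda row: " ".join(str(row[x]) for x in ["prompt", "A", "B", "C", "D", "E"]), axis=1)
trn["answer_all"] = trn.apply(
    lambda x: " ".join([str(x["A"]), str(x["B"]), str(x["C"]), str(x["D"]), str(x["E"])]), axis=1)
trn["prompt_answer_stem"] = trn["prompt"] + " " + trn["answer_all"]

print(trn.loc[0, "prompt_answer_stem"])                              # Q? a b c d e
print((trn["prompt_answer"] == trn["prompt_answer_stem"]).all())     # True
```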

In [27]:
trn.loc[0,'prompt_answer_stem']

'Which of the following statements accurately describes the impact of Modified Newtonian Dynamics (MOND) on the observed "missing baryonic mass" discrepancy in galaxy clusters? MOND is a theory that reduces the observed missing baryonic mass in galaxy clusters by postulating the existence of a new form of matter called "fuzzy dark matter." MOND is a theory that increases the discrepancy between the observed missing baryonic mass in galaxy clusters and the measured velocity dispersions from a factor of around 10 to a factor of about 20. MOND is a theory that explains the missing baryonic mass in galaxy clusters that was previously considered dark matter by demonstrating that the mass is in the form of neutrinos and axions. MOND is a theory that reduces the discrepancy between the observed missing baryonic mass in galaxy clusters and the measured velocity dispersions from a factor of around 10 to a factor of about 2. MOND is a theory that eliminates the observed missing baryonic mass in 

In [28]:
loader = get_loader(trn, 'prompt_answer_stem',32, False)
question_embeddings = []
with torch.no_grad():
    for batch in tqdm(loader):
        batch = batch.to(DEVICE)
        with autocast():
            output = model(batch).cpu().detach().numpy()
        faiss.normalize_L2(output)
        question_embeddings.append(output)
question_embeddings = np.concatenate(question_embeddings, axis=0)

# question_embeddings = model.encode(trn.prompt_answer_stem.values, batch_size=BATCH_SIZE, device=DEVICE, show_progress_bar=True, convert_to_tensor=True, normalize_embeddings=True)
# question_embeddings = question_embeddings.detach().cpu().numpy()

  0%|          | 0/7 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (596 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 7/7 [00:00<00:00,  8.25it/s]


# Extracting Matching Prompt-Sentence Pairs

In [29]:
trn

| | id | prompt | A | B | C | D | E | answer | prompt_answer | answer_all | prompt_answer_stem |
|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0 | Which of the following statements accurately d... | MOND is a theory that reduces the observed mis... | MOND is a theory that increases the discrepanc... | MOND is a theory that explains the missing bar... | MOND is a theory that reduces the discrepancy ... | MOND is a theory that eliminates the observed ... | D | Which of the following statements accurately d... | MOND is a theory that reduces the observed mis... | Which of the following statements accurately d... |
| 1 | 1 | Which of the following is an accurate definiti... | Dynamic scaling refers to the evolution of sel... | Dynamic scaling refers to the non-evolution of... | Dynamic scaling refers to the evolution of sel... | Dynamic scaling refers to the non-evolution of... | Dynamic scaling refers to the evolution of sel... | A | Which of the following is an accurate definiti... | Dynamic scaling refers to the evolution of sel... | Which of the following is an accurate definiti... |
| 2 | 2 | Which of the following statements accurately d... | The triskeles symbol was reconstructed as a fe... | The triskeles symbol is a representation of th... | The triskeles symbol is a representation of a ... | The triskeles symbol represents three interloc... | The triskeles symbol is a representation of th... | A | Which of the following statements accurately d... | The triskeles symbol was reconstructed as a fe... | Which of the following statements accurately d... |
| 3 | 3 | What is the significance of regularization in ... | Regularizing the mass-energy of an electron wi... | Regularizing the mass-energy of an electron wi... | Regularizing the mass-energy of an electron wi... | Regularizing the mass-energy of an electron wi... | Regularizing the mass-energy of an electron wi... | C | What is the significance of regularization in ... | Regularizing the mass-energy of an electron wi... | What is the significance of regularization in ... |
| 4 | 4 | Which of the following statements accurately d... | The angular spacing of features in the diffrac... | The angular spacing of features in the diffrac... | The angular spacing of features in the diffrac... | The angular spacing of features in the diffrac... | The angular spacing of features in the diffrac... | D | Which of the following statements accurately d... | The angular spacing of features in the diffrac... | Which of the following statements accurately d... |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 195 | 195 | What is the relation between the three moment ... | The three moment theorem expresses the relatio... | The three moment theorem is used to calculate ... | The three moment theorem describes the relatio... | The three moment theorem is used to calculate ... | The three moment theorem is used to derive the... | C | What is the relation between the three moment ... | The three moment theorem expresses the relatio... | What is the relation between the three moment ... |
| 196 | 196 | What is the throttling process, and why is it ... | The throttling process is a steady flow of a f... | The throttling process is a steady adiabatic f... | The throttling process is a steady adiabatic f... | The throttling process is a steady flow of a f... | The throttling process is a steady adiabatic f... | B | What is the throttling process, and why is it ... | The throttling process is a steady flow of a f... | What is the throttling process, and why is it ... |
| 197 | 197 | What happens to excess base metal as a solutio... | The excess base metal will often solidify, bec... | The excess base metal will often crystallize-o... | The excess base metal will often dissolve, bec... | The excess base metal will often liquefy, beco... | The excess base metal will often evaporate, be... | B | What happens to excess base metal as a solutio... | The excess base metal will often solidify, bec... | What happens to excess base metal as a solutio... |
| 198 | 198 | What is the relationship between mass, force, ... | Mass is a property that determines the weight ... | Mass is an inertial property that determines a... | Mass is an inertial property that determines a... | Mass is an inertial property that determines a... | Mass is a property that determines the size of... | D | What is the relationship between mass, force, ... | Mass is a property that determines the weight ... | What is the relationship between mass, force, ... |


In [30]:
# def save(data, path):
#     import pickle as pkl
#     with open(path, 'wb') as f:
#         pkl.dump(data, f)
# save(processed_wiki_text_data, './tmp/processed_wiki_text_data.pkl')
# save(wikipedia_file_data, './tmp/wikipedia_file_data.pkl')
# save(question_embeddings, './tmp/question_embeddings.pkl')
# save(wiki_data_embeddings,'./tmp/wiki_data_embeddings.pkl')

In [31]:
# subs = np.array_split(trn, 10)
# for idx, sub in enumerate(subs):
#     sub = sub.reset_index(drop=True)
#     sub.to_parquet(f'./tmp/{idx}.parquet')

In [32]:
# def load(path):
#     import pickle as pkl
#     with open(path, 'rb') as f:
#         res = pkl.load(f)
#     return res

In [33]:
contexts = []
NUM_SENTENCES_INCLUDE = 5
os.environ["TOKENIZERS_PARALLELISM"] = "True"
for r in tqdm(trn.itertuples(), total=len(trn)):
    prompt_id = r.Index  # assumes trn keeps its default RangeIndex, aligned with question_embeddings
    # Rows of the sentences that belong to the articles retrieved for this prompt
    prompt_indices = processed_wiki_text_data[processed_wiki_text_data['document_id'].isin(wikipedia_file_data[wikipedia_file_data['prompt_id']==prompt_id]['id'].values)].index.values

    context = ""  # reset per prompt so a prompt with no matches never inherits the previous context
    if prompt_indices.shape[0] > 0:
        prompt_index = faiss.IndexFlatIP(wiki_data_embeddings.shape[1])
        prompt_index.add(wiki_data_embeddings[prompt_indices])
        # prompt_index = faiss.index_cpu_to_all_gpus(prompt_index)
        prompt_texts = processed_wiki_text_data.loc[prompt_indices, 'text']

        ## Get the top matches for this prompt's question only
        ss, ii = prompt_index.search(question_embeddings[prompt_id:prompt_id + 1], NUM_SENTENCES_INCLUDE)
        for _s, _i in zip(ss[0], ii[0]):
            if _i < 0:  # faiss returns -1 when the index holds fewer than k vectors
                continue
            context += prompt_texts.iloc[_i] + " "

    contexts.append(context)

100%|██████████| 200/200 [00:02<00:00, 83.87it/s]
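For each prompt, the loop restricts the candidate pool to sentences from that prompt's retrieved articles. It ranks them by inner product with the prompt's own embedding and concatenates the top `NUM_SENTENCES_INCLUDE` sentences. The NumPy sketch below uses made-up sentences and 2-d unit vectors. `build_context` is a hypothetical helper, and a brute-force `argsort` stands in for the per-prompt `faiss.IndexFlatIP`. Slicing with `[:k]` also covers prompts whose pool holds fewer than k sentences, where FAISS would pad the result with index -1.

```python
import numpy as np

def build_context(question_emb, sentence_embs, sentence_texts, candidate_rows, k):
    """Join the k candidate sentences with the highest inner product to the question."""
    if len(candidate_rows) == 0:
        return ""                                     # no retrieved article -> empty context
    sims = sentence_embs[candidate_rows] @ question_emb
    best = np.argsort(-sims)[:k]                      # positions within candidate_rows
    return "".join(sentence_texts[candidate_rows[j]] + " " for j in best)

texts = ["Sky is blue.", "Grass is green.", "MOND modifies gravity.", "Baryons are missing."]
embs = np.array([[1.0, 0.0], [0.8, 0.6], [0.0, 1.0], [0.6, 0.8]])   # unit-length rows
question = np.array([0.0, 1.0])

# Only sentences 1-3 belong to this prompt's retrieved articles
print(repr(build_context(question, embs, texts, np.array([1, 2, 3]), k=2)))
```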


In [34]:
# def process(idx):
#     contexts = []
#     NUM_SENTENCES_INCLUDE = 22
#     os.environ["TOKENIZERS_PARALLELISM"] = "True"
#     processed_wiki_text_data = load('./tmp/processed_wiki_text_data.pkl')
#     wikipedia_file_data = load('./tmp/wikipedia_file_data.pkl')
#     question_embeddings = load('./tmp/question_embeddings.pkl')
#     wiki_data_embeddings = []
#     for i in range(8):
#         wiki_data_embeddings.append(np.load(f'./tmp/{i}.npy'))
#     wiki_data_embeddings = np.concatenate(wiki_data_embeddings,axis=0)
#     trn = pd.read_parquet(f'./tmp/{idx}.parquet')
#     if idx == 0:
#         for r in tqdm(trn.itertuples(), total=len(trn)):
#             prompt_id = r.Index
#             prompt_indices = processed_wiki_text_data[processed_wiki_text_data['document_id'].isin(wikipedia_file_data[wikipedia_file_data['prompt_id']==prompt_id]['id'].values)].index.values

#             if prompt_indices.shape[0] > 0:
#                 prompt_index = faiss.IndexFlatIP(wiki_data_embeddings.shape[1])
#                 prompt_index.add(wiki_data_embeddings[prompt_indices])
#                 # prompt_index = faiss.index_cpu_to_all_gpus(prompt_index)
#                 context = ""

#                 ## Get the top matches
#                 ss, ii = prompt_index.search(question_embeddings, NUM_SENTENCES_INCLUDE)
#                 for _s, _i in zip(ss[prompt_id], ii[prompt_id]):
#                     context += processed_wiki_text_data.loc[prompt_indices]['text'].iloc[_i] + " "

#             contexts.append(context)
#     else:
#         for r in trn.itertuples():
#             prompt_id = r.Index
#             prompt_indices = processed_wiki_text_data[processed_wiki_text_data['document_id'].isin(wikipedia_file_data[wikipedia_file_data['prompt_id']==prompt_id]['id'].values)].index.values

#             if prompt_indices.shape[0] > 0:
#                 prompt_index = faiss.IndexFlatIP(wiki_data_embeddings.shape[1])
#                 prompt_index.add(wiki_data_embeddings[prompt_indices])
#                 # prompt_index = faiss.index_cpu_to_all_gpus(prompt_index)
#                 context = ""

#                 ## Get the top matches
#                 ss, ii = prompt_index.search(question_embeddings, NUM_SENTENCES_INCLUDE)
#                 for _s, _i in zip(ss[prompt_id], ii[prompt_id]):
#                     context += processed_wiki_text_data.loc[prompt_indices]['text'].iloc[_i] + " "

#             contexts.append(context)
#     import pickle as pkl
#     with open(f'./tmp/context_{idx}.pkl', 'wb') as f:
#         pkl.dump(contexts, f)
#     return idx
# import multiprocessing
# pool = multiprocessing.Pool(processes=40)
# results = []
# for idx in range(40):
#     result = pool.apply_async(process,args=(idx, ))
#     results.append(result)
# for result in results:
#     print(result.get())

In [35]:
# pool.close()
# pool.join()

In [36]:
trn['context'] = contexts

In [36]:
trn[["prompt", "context", "A", "B", "C", "D", "E"]].to_csv("./data/train_recall2round_3_5.csv", index=False)

In [37]:
trn.loc[0,'context']

'A 2021 article postulated that approximately 50% of all baryonic matter is outside dark matter haloes, filling the space between galaxies, and that this would explain the missing baryons not accounted for in the 2017 paper. == Current state == Currently, many groups have observed the intergalactic medium and circum-galactic medium to obtain more measurements and observations of baryons to support the leading observations. In cosmology, the missing baryon problem is an observed discrepancy between the amount of baryonic matter detected from shortly after the Big Bang and from more recent epochs. At the same time, a census of baryons in the recent observable universe has found that observed baryonic matter accounts for less than half of that amount. Agreement with observed abundances requires that baryonic matter makes up between 4–5% of the universe\'s critical density. Only a small proportion of the dark matter in the universe is likely to be baryonic. ==Characteristics== As "dark mat