## 1. 문서 vector DB화

In [3]:
!pip install -U FlagEmbedding
!conda install -c pytorch faiss-gpu
!pip install peft

Collecting FlagEmbedding
  Using cached FlagEmbedding-1.2.11-py3-none-any.whl
Installing collected packages: FlagEmbedding
  Attempting uninstall: FlagEmbedding
    Found existing installation: FlagEmbedding 1.2.10
    Uninstalling FlagEmbedding-1.2.10:
      Successfully uninstalled FlagEmbedding-1.2.10
Successfully installed FlagEmbedding-1.2.11
^C
Collecting peft
  Downloading peft-0.13.0-py3-none-any.whl.metadata (13 kB)
Downloading peft-0.13.0-py3-none-any.whl (322 kB)
   ---------------------------------------- 322.5/322.5 kB 4.0 MB/s eta 0:00:00
Installing collected packages: peft
Successfully installed peft-0.13.0


In [1]:
import os
import faiss
import datasets
import numpy as np
from tqdm import tqdm
from FlagEmbedding import FlagModel
import json

  from .autonotebook import tqdm as notebook_tqdm



Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin c:\Users\tyflow\Anaconda3\envs\lawsuitLLM_testServer\lib\site-packages\bitsandbytes\libbitsandbytes_cuda122.dll
CUDA SETUP: CUDA runtime path found: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\bin\cudart64_12.dll
CUDA SETUP: Highest compute capability among GPUs detected: 8.6
CUDA SETUP: Detected CUDA version 122
CUDA SETUP: Loading binary c:\Users\tyflow\Anaconda3\envs\lawsuitLLM_testServer\lib\site-packages\bitsandbytes\libbitsandbytes_cuda122.dll...


  warn(msg)
  warn(msg)


documents를 vectorDB에 저장하는 함수

In [4]:
def jsonl_copus(path: str = 'documents.jsonl'):
    """Load the document corpus from a JSON Lines file.

    Every non-blank line must be a JSON object with at least the keys
    ``docid`` and ``content``; any other keys are ignored.

    Args:
        path: Location of the JSONL file. Defaults to ``documents.jsonl`` in
            the current working directory.

    Returns:
        A ``datasets.Dataset`` with one record per document, each record
        holding ``id`` (taken from ``docid``) and ``content``.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``docid`` or ``content``.
    """
    corpus_list = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            # Blank lines (e.g. a trailing newline at the end of the file)
            # are not records; json.loads would raise on them.
            if not line.strip():
                continue
            data = json.loads(line)
            corpus_list.append({
                'id': data['docid'],
                'content': data['content']
            })

    return datasets.Dataset.from_list(corpus_list)

def generate_index(model: FlagModel, corpus: datasets.Dataset, max_passage_length: int=512, batch_size: int=256):
    """Embed every passage of ``corpus`` and store the vectors in a FAISS index.

    Args:
        model: Embedding model; its ``encode_corpus`` method is used.
        corpus: Dataset exposing a ``content`` column (text to embed) and an
            ``id`` column (document identifiers).
        max_passage_length: Forwarded to ``encode_corpus`` as ``max_length``.
        batch_size: Forwarded to ``encode_corpus``.

    Returns:
        A tuple ``(faiss_index, ids)``: a "Flat" inner-product index holding
        the float32 embeddings, and the ``id`` column as a plain list.
    """
    passages = corpus["content"]
    embeddings = model.encode_corpus(
        passages,
        batch_size=batch_size,
        max_length=max_passage_length,
    )

    # The index dimension is the size of the last embedding axis.
    embedding_dim = embeddings.shape[-1]
    vectors = embeddings.astype(np.float32)

    flat_index = faiss.index_factory(embedding_dim, "Flat", faiss.METRIC_INNER_PRODUCT)
    flat_index.train(vectors)
    flat_index.add(vectors)

    doc_ids = list(corpus["id"])
    return flat_index, doc_ids

def save_result(index: faiss.Index, docid: list, index_save_dir: str):
    """Persist a FAISS index and its document ids into ``index_save_dir``.

    Two files are written: ``docid`` (one identifier per line, UTF-8) and
    ``index`` (the serialized FAISS index).

    Args:
        index: The FAISS index to serialize.
        docid: Document identifiers; each is converted with ``str``.
        index_save_dir: Target directory. It is created if it does not exist.
    """
    # Create the directory on demand so the function also works when the
    # caller has not prepared it. An empty string means "current directory"
    # for os.path.join, and os.makedirs('') would raise, hence the guard.
    if index_save_dir:
        os.makedirs(index_save_dir, exist_ok=True)

    docid_save_path = os.path.join(index_save_dir, 'docid')
    index_save_path = os.path.join(index_save_dir, 'index')

    with open(docid_save_path, 'w', encoding='utf-8') as f:
        f.writelines(str(_id) + '\n' for _id in docid)
    faiss.write_index(index, index_save_path)

실행 코드

In [6]:
# Build the dense index for the whole corpus and write it to INDEX_DIR.
INDEX_DIR = 'vectorDB'

embedding_model = FlagModel(
    'BAAI/bge-m3',
    pooling_method='cls',
    normalize_embeddings=True,
    use_fp16=True,
)

# exist_ok=True replaces the check-then-create pattern and keeps the cell
# safe to re-run.
os.makedirs(INDEX_DIR, exist_ok=True)

corpus = jsonl_copus()

index, docid = generate_index(
    model=embedding_model,
    corpus=corpus,
    max_passage_length=8192,
    batch_size=4,
)

# Every indexed vector must have exactly one document id, otherwise search
# results would be mapped to the wrong documents.
assert index.ntotal == len(docid), "index size and docid list are out of sync"

save_result(index, docid, INDEX_DIR)

Inference Embeddings: 100%|██████████| 1068/1068 [00:46<00:00, 22.97it/s]


## 2. 문서 검색

In [7]:
!conda install -c conda-forge nmslib -y
!conda install -c conda-forge openjdk -y
!conda install -c conda-forge openjdk=11 -y
!pip install pyserini==0.22.1

Collecting package metadata (current_repodata.json): done
Solving environment: done


  current version: 23.9.0
  latest version: 24.7.1

Please update conda by running

    $ conda update -n base -c defaults conda

Or to minimize the number of packages updated during conda update use

     conda install conda=24.7.1



## Package Plan ##

  environment location: /root/.conda

  added / updated specs:
    - nmslib


The following packages will be downloaded:

    package                    |            build
    ---------------------------|-----------------
    libblas-3.9.0              |       8_openblas          11 KB  conda-forge
    libcblas-3.9.0             |       8_openblas          11 KB  conda-forge
    libgfortran-ng-7.5.0       |      h14aa051_20          23 KB  conda-forge
    libgfortran4-7.5.0         |      h14aa051_20         1.2 MB  conda-forge
    liblapack-3.9.0            |       8_openblas          11 KB  conda-forge
    libopenblas-0.3.12         |pthreads_hb3c2

In [1]:
import os
import pandas as pd
from tqdm import tqdm
from FlagEmbedding import BGEM3FlagModel
import torch
from pyserini.search.faiss import FaissSearcher, AutoQueryEncoder
import json

  from .autonotebook import tqdm as notebook_tqdm



Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin c:\Users\tyflow\Anaconda3\envs\lawsuitLLM_testServer\lib\site-packages\bitsandbytes\libbitsandbytes_cuda122.dll
CUDA SETUP: CUDA runtime path found: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.2\bin\cudart64_12.dll
CUDA SETUP: Highest compute capability among GPUs detected: 8.6
CUDA SETUP: Detected CUDA version 122
CUDA SETUP: Loading binary c:\Users\tyflow\Anaconda3\envs\lawsuitLLM_testServer\lib\site-packages\bitsandbytes\libbitsandbytes_cuda122.dll...


  warn(msg)
  warn(msg)


FaissSearcher를 생성하여 효율적인 유사도 기반 검색을 가능하게 합니다. 

In [2]:
# Use the GPU when one is visible; otherwise fall back to CPU instead of
# failing when the model is moved to a non-existent CUDA device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pooling and L2 normalisation mirror the settings used when the corpus was
# embedded (pooling_method='cls', normalize_embeddings=True), so query and
# passage vectors are comparable.
query_encoder = AutoQueryEncoder(
    encoder_dir='BAAI/bge-m3',
    device=device,
    pooling='cls',
    l2_norm=True,
)

# 'vectorDB' is the directory holding the `index` and `docid` files written
# by the indexing section.
searcher = FaissSearcher(
    index_dir='vectorDB',
    query_encoder=query_encoder,
)


  return self.fget.__get__(instance, owner)()


moe 미적용


dense 검색

In [3]:
# Load the evaluation set: one JSON object per line.
# Blank lines (e.g. a trailing newline) are skipped because json.loads would
# raise on them.
with open('eval.jsonl', 'r', encoding='utf-8') as f:
    evals = [json.loads(line) for line in f if line.strip()]

In [7]:
# Inspect the first evaluation record.
# NOTE(review): the saved output already shows a 'retrievers' key, which is only
# added by the search loop further down, so this cell was last executed out of
# order. On a fresh top-to-bottom run only the keys read from eval.jsonl appear.
evals[0]

{'eval_id': 78,
 'msg': [{'role': 'user', 'content': '나무의 분류에 대해 조사해 보기 위한 방법은?'}],
 'retrievers': [DenseSearchResult(docid='c63b9e3a-716f-423a-9c9b-0bcaa1b9f35d', score=0.67541975),
  DenseSearchResult(docid='9712bdf6-9419-4953-a8f1-8a4015dee986', score=0.56418586),
  DenseSearchResult(docid='29f939e1-a784-40fc-a31b-139fdaceec66', score=0.5622638),
  DenseSearchResult(docid='b730a81a-3903-42ca-9633-88b0ebb9eb42', score=0.55245775),
  DenseSearchResult(docid='6788c97f-3460-4b93-953a-ea6cbed0c2d2', score=0.552361),
  DenseSearchResult(docid='ed2aff04-ed0b-452f-9ea0-7b6b935b39c1', score=0.55055916),
  DenseSearchResult(docid='e227a022-da3b-4810-9882-a2b27c76cc79', score=0.5383136),
  DenseSearchResult(docid='d42ced41-7d0c-4346-bc0a-11454f5b6121', score=0.53474164),
  DenseSearchResult(docid='bbd9e1c7-59a9-44ae-ad75-54eb8f150a25', score=0.5261199),
  DenseSearchResult(docid='a2147bab-f37b-4afe-b2a4-5a6fc01f1024', score=0.525736)]}

In [10]:
# Dense retrieval: for each evaluation item, concatenate the content of all
# user turns into one query string and fetch the 10 best-scoring documents.
for row in tqdm(evals, total=len(evals), desc="Searching Docs"):
    user_turns = [turn['content'] for turn in row['msg'] if turn['role'] == 'user']
    question = ''.join(user_turns)

    row['question'] = question
    row['retrievers'] = searcher.search(query=question, k=10)

Searching Docs: 100%|██████████| 220/220 [00:06<00:00, 32.94it/s]


In [8]:
# Model used to re-score the retrieved (query, passage) pairs.
reranker_options = {
    'model_name_or_path': 'BAAI/bge-m3',
    'pooling_method': 'cls',
    'normalize_embeddings': True,
    'device': device,
}
reranker = BGEM3FlagModel(**reranker_options)


Fetching 30 files: 100%|██████████| 30/30 [00:00<?, ?it/s]


corpus 생성

In [12]:
# Map docid -> document text so retrieved hits can be resolved to passages.
corpus_dict = {}
with open('documents.jsonl', 'r', encoding='utf-8') as f:
    for line in f:
        # Blank lines (e.g. a trailing newline) are not records; json.loads
        # would raise on them.
        if not line.strip():
            continue
        data = json.loads(line)
        corpus_dict[data['docid']] = data['content']

In [14]:
# Flatten the retrieval results into (query, passage) pairs for scoring.
# Pairs are appended in the order of `evals` and, within each item, in the
# order of its 'retrievers' list.
qid_list = []
sentence_pairs = []
for row in tqdm(evals, total=len(evals), desc="Making sentence pairs"):
    qid_list.append(row['eval_id'])
    query = row['question']
    sentence_pairs.extend(
        (query, corpus_dict[hit.docid]) for hit in row['retrievers']
    )

Making sentence pairs: 100%|██████████| 220/220 [00:00<00:00, 73262.95it/s]


Hybrid 검색

In [15]:
# Weights passed to compute_score, in the order dense, sparse, colbert.
dense_weight = 0.15
sparse_weight = 0.35
colbert_weight = 0.5
mode_weights = [dense_weight, sparse_weight, colbert_weight]

# Score every (query, passage) pair. The result is a dict keyed by scoring
# mode (the progress output lists colbert, sparse, dense, sparse+dense and
# colbert+sparse+dense).
scores_dict = reranker.compute_score(
    sentence_pairs,
    batch_size=4,
    max_query_length=512,
    max_passage_length=512,
    weights_for_different_modes=mode_weights,
)

Compute Scores: 100%|██████████| 550/550 [00:44<00:00, 12.49it/s]


In [17]:
# For every scoring mode, attach to each evaluation item its candidates as
# (docid, score, passage) tuples sorted by descending score.
# The scores form one flat sequence in the order the sentence pairs were
# built, so `offset` tracks where the current item's scores begin.
for mode, mode_scores in scores_dict.items():
    offset = 0
    for row in tqdm(evals, total=len(evals), desc=f"{mode}"):
        hits = row['retrievers']
        ranked = [
            (hit.docid, mode_scores[offset + pos], corpus_dict[hit.docid])
            for pos, hit in enumerate(hits)
        ]
        offset += len(hits)

        row[mode] = sorted(ranked, key=lambda entry: entry[1], reverse=True)
                
    


colbert: 100%|██████████| 220/220 [00:00<00:00, 73268.77it/s]
sparse: 100%|██████████| 220/220 [00:00<00:00, 73268.77it/s]
dense: 100%|██████████| 220/220 [00:00<00:00, 73274.59it/s]
sparse+dense: 100%|██████████| 220/220 [00:00<00:00, 54958.12it/s]
colbert+sparse+dense: 100%|██████████| 220/220 [00:00<00:00, 73262.95it/s]


검색결과 저장

In [31]:
# Drop the raw search-result objects before the items are serialized to JSON.
# pop() with a default keeps the cell idempotent: `del` raised KeyError when
# the cell was executed a second time.
for eval_item in evals:
    eval_item.pop('retrievers', None)

In [32]:
# Write all evaluation items, including the per-mode rankings, as one
# pretty-printed JSON document; non-ASCII text is stored unescaped.
with open('search_result.json', 'w', encoding='utf-8') as f_out:
    f_out.write(json.dumps(evals, ensure_ascii=False, indent=4))

제출

In [39]:
# An item only gets documents when its best fused score exceeds this value;
# otherwise an empty topk is submitted.
# NOTE(review): presumably this filters out queries that need no retrieval - confirm.
SCORE_THRESHOLD = 0.32
# Number of documents submitted per item.
TOP_K = 3

submission = []
for eval_item in evals:
    # Candidates are (docid, score, passage) tuples, best score first.
    ranked = eval_item['colbert+sparse+dense']

    # `ranked` may be empty when the search returned nothing; indexing [0]
    # unconditionally would raise IndexError.
    if ranked and ranked[0][1] > SCORE_THRESHOLD:
        topk = [candidate[0] for candidate in ranked[:TOP_K]]
    else:
        topk = []

    submission.append({
        'eval_id': eval_item['eval_id'],
        'standalone_query': eval_item['question'],
        'topk': topk,
    })

# Despite the .csv extension the file holds one JSON object per line.
with open('submission.csv', 'w', encoding='utf-8') as f_out:
    for sub in submission:
        f_out.write(json.dumps(sub, ensure_ascii=False) + '\n')
