In [None]:
# Mount Google Drive so the project directory under /content/drive/MyDrive is reachable.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
%cd /content/drive/MyDrive/rag_project

!pip install transformers
!pip install faiss-cpu
!pip install faiss-gpu
!pip install datasets

In [None]:
import datasets
import numpy as np
import faiss
import torch
from transformers import DPRContextEncoder, DPRContextEncoderTokenizer

# Tab-separated passage file and the pretrained DPR context-encoder checkpoint to use.
tsv_file = "psgs_w100.tsv"
CTX_ENCODER_CHECKPOINT = 'facebook/dpr-ctx_encoder-single-nq-base'

# Read the TSV with the generic CSV builder; all rows are exposed under the 'train' split.
dataset_dict = datasets.load_dataset('csv', data_files=tsv_file, delimiter='\t')
dataset = dataset_dict['train']

# Tokenizer and encoder come from the same checkpoint so their vocabularies agree.
ctx_tokenizer = DPRContextEncoderTokenizer.from_pretrained(CTX_ENCODER_CHECKPOINT)
ctx_encoder = DPRContextEncoder.from_pretrained(CTX_ENCODER_CHECKPOINT)

# Run on the GPU when one is present, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
ctx_encoder.to(device)

# 텍스트 임베딩 함수
def compute_embeddings(batch, ctx_tokenizer, ctx_encoder, device=None):
    """Encode a batch of passages into DPR context embeddings.

    Designed for ``datasets.Dataset.map(..., batched=True)``, where ``batch`` maps
    column names to equal-length lists.

    Args:
        batch: Mapping with a ``'text'`` column and, optionally, a ``'title'`` column.
            When titles are present, each passage is encoded as a (title, text) pair,
            which is how DPR context encoders expect passages to be presented.
        ctx_tokenizer: DPR context tokenizer (callable returning a dict of tensors).
        ctx_encoder: DPR context encoder whose output exposes ``pooler_output``.
        device: Device to move the tokenized inputs to. Defaults to the device of the
            encoder's parameters, so inputs and model always live on the same device
            without relying on a module-level global.

    Returns:
        ``{'embeddings': [...]}`` with one list of floats per passage.
    """
    if device is None:
        device = next(ctx_encoder.parameters()).device

    # Empty CSV cells load as None, which the tokenizer rejects; encode them as "".
    texts = [text or "" for text in batch['text']]
    if 'title' in batch:
        titles = [title or "" for title in batch['title']]
        inputs = ctx_tokenizer(titles, texts, truncation=True, padding=True, return_tensors="pt")
    else:
        inputs = ctx_tokenizer(texts, truncation=True, padding=True, return_tensors="pt")

    inputs = {key: val.to(device) for key, val in inputs.items()}
    with torch.no_grad():
        embeddings = ctx_encoder(**inputs).pooler_output
    return {'embeddings': embeddings.cpu().numpy().tolist()}

# Embed every passage in batches of 128; map() appends the returned 'embeddings' column.
embeddings_dataset = dataset.map(
    lambda batch: compute_embeddings(batch, ctx_tokenizer, ctx_encoder),
    batched=True,
    batch_size=128,
)


In [None]:
# Single source of truth for the persisted artifacts: the dataset reloaded below must be
# the one saved here (previously it was saved to one directory and reloaded from another).
DATASET_DIR = '/content/drive/MyDrive/rag_project/wiki_dataset_without_indexes'
FAISS_INDEX_PATH = '/content/drive/MyDrive/rag_project/embeddings.faiss'

# save_to_disk refuses to serialize a dataset with attached indexes, so drop any left
# over from a previous run of this cell before saving.
indexes = embeddings_dataset.list_indexes()
for index in indexes:
    embeddings_dataset.drop_index(index)

embeddings_dataset.save_to_disk(DATASET_DIR)

# DPR scores query/passage relevance with a dot product, so build the index with the
# inner-product metric rather than the default L2 distance.
embeddings_dataset.add_faiss_index(column='embeddings', metric_type=faiss.METRIC_INNER_PRODUCT)

faiss.write_index(embeddings_dataset.get_index('embeddings').faiss_index, FAISS_INDEX_PATH)

print("임베딩 및 FAISS 인덱스가 저장되었습니다.")

# Reload both artifacts to confirm they were persisted and still line up one-to-one.
loaded_embeddings_dataset = datasets.Dataset.load_from_disk(DATASET_DIR)
loaded_faiss_index = faiss.read_index(FAISS_INDEX_PATH)

print("데이터셋의 임베딩 수:", len(loaded_embeddings_dataset))
print("Faiss 인덱스의 임베딩 수:", loaded_faiss_index.ntotal)
assert len(loaded_embeddings_dataset) == loaded_faiss_index.ntotal, "dataset rows and FAISS vectors are out of sync"