In [1]:
import numpy as np
import os
import pandas as pd
import urllib.request
import time
from sentence_transformers import SentenceTransformer

데이터 로드 https://wikidocs.net/162007

In [2]:
urllib.request.urlretrieve("https://raw.githubusercontent.com/ukairia777/tensorflow-nlp-tutorial/main/19.%20Topic%20Modeling%20(LDA%2C%20BERT-Based)/dataset/abcnews-date-text.csv", filename="abcnews-date-text.csv")

df = pd.read_csv("abcnews-date-text.csv")
data = df.headline_text.to_list()

In [3]:
# 상위 5개의 샘플 출력
data[:5]

['aba decides against community broadcasting licence',
 'act fire witnesses must be aware of defamation',
 'a g calls for infrastructure protection summit',
 'air nz staff in aust strike for pay rise',
 'air nz strike to affect australian travellers']

In [4]:
print('총 샘플의 개수 :', len(data))

총 샘플의 개수 : 1082168


SBERT 임베딩

In [5]:
model = SentenceTransformer('distilbert-base-nli-mean-tokens')
encoded_data = model.encode(data)
print('임베딩 된 벡터 수 :', len(encoded_data))

임베딩 된 벡터 수 : 1082168


Cosin Similarity 계산

In [6]:
encoded_data.shape

(1082168, 768)

In [7]:
from numpy import dot
from numpy.linalg import norm
def cos_sim(A, B):
  return dot(A, B)/(norm(A)*norm(B))

In [8]:
def return_answer(question):
    t = time.time()
    embedding = model.encode(question)
    sim_scores = list(enumerate(cos_sim(encoded_data, embedding)))
    # 유사도에 따라 기사들을 정렬한다.
    sim_scores = sorted(sim_scores, key=lambda x: x[1], reverse=True)

    # 가장 유사한 5개의 기사를 받아온다.
    sim_scores = sim_scores[1:6]

    # 가장 유사한 5개의 기사의 인덱스를 얻는다.
    resut_indices = [idx[0] for idx in sim_scores]

    print('total time: {}'.format(time.time() - t))
    # 가장 유사한 5개의 기사의 제목을 리턴한다.
    return [data[_id] for _id in resut_indices]

In [9]:
# Underwater Forest Discovered
query = str(input())
results = return_answer(query)

print('results :')
for result in results:
   print('\t', result)

total time: 2.617715835571289
results :
	 thriving underwater antarctic garden discovered
	 baton goes underwater in wa
	 underwater footage shows inside doomed costa
	 underwater uluru found off wa coast
	 amateur diver shares hobarts hidden underwater world


FAISS 예제

In [10]:
import faiss

index = faiss.IndexFlatL2(encoded_data.shape[1]) # 초기화 : 벡터의 크기를 지정
index.add(encoded_data) # 임베딩을 추가

In [11]:
# index 학습 여부 확인. FlatL2 의 경우 기본적으로 trained 상태.
index.is_trained

True

In [12]:
def search(query):
   t = time.time()
   query_vector = model.encode([query])
   # 몇개나 찾을 것인가
   k = 5 
   top_k = index.search(query_vector, k)
   print('total time: {}'.format(time.time() - t))
   return [data[_id] for _id in top_k[1].tolist()[0]]

In [13]:
query = str(input())
print('input : ' + query)

results = search(query)

print('results :')
for result in results:
   print('\t', result)

input : Underwater Forest Discovered
total time: 0.6258132457733154
results :
	 baton goes underwater in wa
	 underwater footage shows inside doomed costa
	 underwater loop
	 thriving underwater antarctic garden discovered
	 shire faces underwater observatory bill blowout


IVF 예제

In [14]:
d = encoded_data.shape[1]
nlist = 50 # 셀 수 정의
# m = 8 # m은 d 의 배수여야함, 벡터 클러스터의 중심
# bits = 8

quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFFlat(quantizer, d, nlist)

In [15]:
# index 의 학습여부 확인.
index.is_trained

False

In [16]:
# index 학습이 되어있지 않으므로 학습 먼저 진행.
index.train(encoded_data)
index.add(encoded_data)
index.ntotal

1082168

In [17]:
def search_IVFFlat(query):
   t = time.time()
   query_vector = model.encode([query])
   k = 5
   top_k = index.search(query_vector, k)
   print('total time: {}'.format(time.time() - t))
   return [data[_id] for _id in top_k[1].tolist()[0]]

In [18]:
query = str(input())
print('input : ' + query)
results = search_IVFFlat(query)
print('results :')
for result in results:
   print('\t', result)

input : Underwater Forest Discovered
total time: 0.03360128402709961
results :
	 baton goes underwater in wa
	 underwater footage shows inside doomed costa
	 underwater loop
	 thriving underwater antarctic garden discovered
	 shire faces underwater observatory bill blowout


In [19]:
# 기본값은 1
print(index.nprobe)
# 주변 셀을 몇개나 검색할 것인가 지정 가능
index.nprobe = 10 

1


In [20]:
query = str(input())

print('input : ' + query)
results = search_IVFFlat(query)

print('results :')
for result in results:
   print('\t', result)

input : Underwater Forest Discovered
total time: 0.15331459045410156
results :
	 baton goes underwater in wa
	 underwater footage shows inside doomed costa
	 underwater loop
	 thriving underwater antarctic garden discovered
	 shire faces underwater observatory bill blowout


IndexIVFPQ 예제

In [21]:
m = 8 # m은 d 의 약수여야함, 벡터 클러스터의 중심
num_bits = 8  # 각 서브 벡터를 양자화하기 위한 비트 수

In [22]:
quantizer = faiss.IndexFlatL2(d)
index = faiss.IndexIVFPQ(quantizer, d, nlist, m, num_bits)

In [23]:
index.train(encoded_data)
index.add(encoded_data)

In [24]:
def search_IVFPQ(query):
   t = time.time()
   query_vector = model.encode([query])
   k = 5
   top_k = index.search(query_vector, k)
   print('total time: {}'.format(time.time() - t))
   return [data[_id] for _id in top_k[1].tolist()[0]]

In [31]:
query = str(input())

print('input : ' + query)
results = search_IVFPQ(query)

print('results :')
for result in results:
   print('\t', result)

input : Underwater Forest Discovered
total time: 0.03170180320739746
results :
	 croc trapped in cairns swimming enclosure
	 thriving underwater antarctic garden discovered
	 man drowns at greenbushes swimming pool
	 berry springs swim hole croc captured
	 croc caught near litchfield swimming hole


In [33]:
index.nprobe = 10

query = str(input())

print('input : ' + query)
results = search_IVFPQ(query)

print('results :')
for result in results:
   print('\t', result)

input : Underwater Forest Discovered
total time: 0.03354001045227051
results :
	 croc trapped in cairns swimming enclosure
	 thriving underwater antarctic garden discovered
	 man drowns at greenbushes swimming pool
	 berry springs swim hole croc captured
	 croc caught near litchfield swimming hole


In [32]:
#index 저장하기
faiss.write_index(index,"sts.index")

#인덱스 불러오기
index = faiss.read_index("./sts.index")