In [1]:
import chromadb
from chromadb import Documents, EmbeddingFunction, Embeddings
import pandas as pd
import json
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_core.embeddings import Embeddings
import pickle
import ast
from time import time

import os

# Project root (absolute). Override with the PERSONAL_AI_BASE_DIR environment
# variable instead of editing the notebook on each machine.
BASE_DIR = os.environ.get("PERSONAL_AI_BASE_DIR", "/home/dzigen/Desktop/PersonalAI/Personal-AI")

# Neo4j connection settings, read from the environment so credentials need not
# be committed with the notebook. The defaults reproduce the values that were
# previously hardcoded here.
# NOTE(review): the settings formerly labelled "local" pointed at the same remote
# host as the (commented-out) "remote" settings; export NEO4J_URL explicitly
# (e.g. bolt://localhost:7687) when targeting a local instance.
NEO4J_URL = os.environ.get("NEO4J_URL", "bolt://31.207.47.254:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PWD = os.environ.get("NEO4J_PWD", "password")

import sys 
sys.path.insert(0, "../")

from src.neo4j_functions import Neo4jConnection

In [2]:
class MyEmbeddingFunction(EmbeddingFunction):
    """Adapter that lets Chroma use a LangChain embedder.

    Chroma invokes the instance with a batch of documents; the call is
    delegated unchanged to ``embedder.embed_documents``.
    """

    def __init__(self, embedder):
        # Any object exposing ``embed_documents(list_of_str)``.
        self.embedder = embedder

    # The return type is spelled ``chromadb.Embeddings`` on purpose: the bare
    # name ``Embeddings`` is rebound to LangChain's ``Embeddings`` class by the
    # ``from langchain_core.embeddings import Embeddings`` import, so it would
    # otherwise annotate the wrong type.
    def __call__(self, input: Documents) -> chromadb.Embeddings:
        # Parameter name ``input`` follows Chroma's EmbeddingFunction protocol.
        return self.embedder.embed_documents(input)

In [4]:
# !!! BELOW TO CHANGE !!!
# --- dataset / database identifiers ---
DATA_NAME = 'vectorized_triplets'
DB_VERSION = 'v4'
GRAPH_DB_NAME = 'testdb'

# --- embedding model ---
EMBEDDING_MODEL_PATH = BASE_DIR + '/models/intfloat/multilingual-e5-small'
MODEL_KWARGS = dict(device='cuda')
# Prefixes prepended to model inputs; stored documents use the passage prefix.
ENCODE_PROMPTS = {"query": "query: ", "passage": "passage: "}
ENCODE_KWARGS = dict(normalize_embeddings=True, prompt=ENCODE_PROMPTS['passage'])

# --- vector store / triplet formatting ---
# Inner-product distance; embeddings are normalized via ENCODE_KWARGS.
CHROMA_KWARGS = {"hnsw:space": "ip"}
TRIPLET_STRINGIFY_VERSION = 'v2'
# !!! ABOVE TO CHANGE !!!

# Derived output locations (relative to the notebook's working directory).
SAVE_DIR = "../data/{}/{}".format(DATA_NAME, DB_VERSION)
DENSE_DB_SAVE_PATH = SAVE_DIR + '/densedb'
DB_LOG_PATH = SAVE_DIR + '/operation_info.json'

In [20]:
def stringify_v1(triplete):
    """Render a (node, relation, node) record as ``"<node> | <relation> | <node>"``.

    Each node is rendered as ``"<label>: <name>"``. The two node strings are
    put in lexicographic order, so the result does not depend on which node is
    ``'a'`` and which is ``'b'``. The relation is rendered as its type. When the
    relation carries properties, it is instead rendered as a comma-separated
    ``"key: value"`` list of those properties, and the type is omitted.

    Args:
        triplete: mapping with keys ``'a'``, ``'rel'`` and ``'b'``. Nodes expose
            ``.labels`` and ``['name']``. The relation exposes ``.type``,
            ``.keys()`` and ``.items()`` (presumably a neo4j Record; verify
            against the caller).

    Returns:
        The formatted string.

    Raises:
        IndexError: if a node has no labels.
    """
    node1_str = _stringify_v1_node(triplete['a'])
    node2_str = _stringify_v1_node(triplete['b'])
    if node2_str < node1_str:
        node1_str, node2_str = node2_str, node1_str

    rel = triplete['rel']
    relation_str = rel.type
    if len(rel.keys()) > 0:
        relation_str = ', '.join(f"{k}: {v}" for k, v in rel.items())
    return ' | '.join([node1_str, relation_str, node2_str])


def _stringify_v1_node(node) -> str:
    """Format a node as ``"<label>: <name>"`` using its alphabetically first label."""
    # ``labels`` is a set, so ``list(labels)[0]`` picked an arbitrary label for
    # multi-label nodes; sorting makes the choice deterministic across runs.
    return f"{sorted(node.labels)[0]}: {node['name']}"

def stringify_v2(triplet) -> str:
    """Render a (node, relation, node) record as a time-prefixed sentence.

    - ``episodic`` / ``hyper`` relations: ``"<rel['time']>: <b['name']>"``
    - ``simple`` relations: ``"<rel['time']>: <a['name']> <rel['name']> <b['name']>"``

    Args:
        triplet: mapping with keys ``'a'``, ``'rel'`` and ``'b'``. Nodes and the
            relation support ``['prop']`` access, and the relation exposes ``.type``.

    Returns:
        The formatted string.

    Raises:
        KeyError: if the relation type is none of the above. A missing
            property may also surface as KeyError, depending on the container.
    """
    rel = triplet['rel']
    rel_type = rel.type
    if rel_type in ('episodic', 'hyper'):
        # Properties must be read by key: node objects do not expose them as
        # attributes, so the former ``triplet['b'].name`` raised AttributeError.
        body = triplet['b']['name']
    elif rel_type == 'simple':
        body = " ".join([triplet['a']['name'], rel['name'], triplet['b']['name']])
    else:
        raise KeyError(f"unsupported relation type for stringify_v2: {rel_type!r}")
    return f"{rel['time']}: {body}"

# Dispatch table: TRIPLET_STRINGIFY_VERSION -> triplet formatter.
stringify_func_map = dict(
    v1=stringify_v1,
    v2=stringify_v2,
)

In [5]:
# Sentence-embedding model configured from the constants above.
embeddings = HuggingFaceEmbeddings(
    model_name=EMBEDDING_MODEL_PATH,
    model_kwargs=MODEL_KWARGS,
    encode_kwargs=ENCODE_KWARGS,
)
# Wrap the model so Chroma can call it directly on batches of documents.
ef = MyEmbeddingFunction(embedder=embeddings)

No sentence-transformers model found with name /home/dzigen/Desktop/PersonalAI/Personal-AI/models/intfloat/multilingual-e5-small. Creating a new one with MEAN pooling.


In [6]:
# On-disk Chroma store; an existing collection with this name is reopened.
client = chromadb.PersistentClient(path=DENSE_DB_SAVE_PATH)
collection = client.get_or_create_collection(
    name=DATA_NAME,
    metadata=CHROMA_KWARGS,
    embedding_function=ef,
)

In [7]:
# Graph-database connection built from the NEO4J_* settings.
conn = Neo4jConnection(user=NEO4J_USER, pwd=NEO4J_PWD, uri=NEO4J_URL)

In [11]:
from tqdm import tqdm

In [21]:
# Fetch every relationship together with its endpoint nodes. The undirected
# pattern can match an edge from either side; the WHERE clause keeps only the
# rows where ``a`` is the relationship's start node.
raw_triplets = set(conn.execute_query(
    "MATCH (a)-[rel]-(b) WHERE startNode(rel) = a RETURN a, rel, b", db=GRAPH_DB_NAME))

# Keep one record per relationship element id. A set gives O(1) membership
# checks; the previous list lookup made this loop quadratic in the edge count.
seen_relation_ids = set()
unique_triplets = []
for triplet in tqdm(raw_triplets):
    rel_id = triplet['rel'].element_id
    if rel_id not in seen_relation_ids:
        seen_relation_ids.add(rel_id)
        unique_triplets.append(triplet)

# Resolve the formatter once, then build documents and their metadata in the
# same order so that index i of both lists refers to the same relationship.
stringify = stringify_func_map[TRIPLET_STRINGIFY_VERSION]
formated_triplets = [stringify(triplet) for triplet in unique_triplets]
metadata = [{'triplet_id': triplet['rel'].element_id} for triplet in unique_triplets]

100%|██████████| 10355/10355 [00:00<00:00, 12957.91it/s]


KeyError: 

In [13]:
# Row count returned by the query, then the count after per-relationship dedup.
print(len(raw_triplets), len(unique_triplets), sep='\n')

20748
10374


In [14]:
# Release the Neo4j connection; the remaining cells do not use `conn`.
conn.close()

#### Vectorizing 

In [15]:
# Embed and store every formatted triplet, keyed by its relationship element id.
# ``upsert`` is used instead of ``add`` so that re-running the cell updates
# entries whose ids already exist in the persistent collection, rather than
# having them treated as duplicates.
vectorize_t_start = time()

collection.upsert(
    documents=formated_triplets,
    metadatas=metadata,
    ids=[item['triplet_id'] for item in metadata],
)

# Wall-clock seconds for embedding + storage; recorded in the operation log.
VECTORIZE_ELAPSED_TIME = round(time() - vectorize_t_start, 5)

In [16]:
# Sanity check: number of entries now stored in the Chroma collection.
collection.count()

10374

#### Saving Log

In [17]:
# Record how this vector DB was built (data source, model, encoding and
# index settings, timing) next to the DB itself.
operation_info = {
    "data_name": DATA_NAME, "graphdb_name": GRAPH_DB_NAME,
    "db_version": DB_VERSION, "model_name": EMBEDDING_MODEL_PATH,
    "encode_kwargs": ENCODE_KWARGS, "chroma_kwargs": CHROMA_KWARGS,
    "encode_prompts": ENCODE_PROMPTS,
    "stringify_version": TRIPLET_STRINGIFY_VERSION,
    "vectorize_elapsed_sec_time": VECTORIZE_ELAPSED_TIME,
}
with open(DB_LOG_PATH, 'w', encoding='utf-8') as fd:
    json.dump(operation_info, fd, indent=1)