In [None]:
import sys
import json 
import joblib
import gc
from tqdm import tqdm
import os
from typing import List, Dict, Tuple
import yaml

# Repository root relative to this notebook — TO CHANGE if the notebook is moved.
BASEDIR = "../../../"
# Put the repository root first on the import path so the `src.*` imports resolve to this checkout.
sys.path[:0] = [BASEDIR]

from src.pipelines.memorize import MemPipelineConfig, MemPipeline, LLMExtractorConfig, LLMUpdatorConfig
from src.kg_model import KnowledgeGraphModel, EmbeddingsModelConfig, GraphModelConfig, EmbedderModelConfig
from src.db_drivers.graph_driver import GraphDBConnectionConfig, GraphDriverConfig
from src.db_drivers.vector_driver import VectorDBConnectionConfig, VectorDriverConfig

# API credentials (GigaChat credentials, OpenAI API key) must NOT be stored in the notebook.
# Provide them at runtime instead, e.g. via environment variables such as
# GIGACHAT_CREDS / OPENAI_API_KEY, or getpass.getpass().
# NOTE(review): real keys were previously committed here in plain text; treat them as
# compromised and rotate them — deleting them from the file does not remove them from history.

gc.collect()

### Setting knowledge-graph configuration

In [2]:
# Read YAML file with the run's hyper-parameters. The encoding is pinned to UTF-8
# (as for the JSON files this notebook reads and writes) so non-ASCII values load the
# same way regardless of the platform's default locale encoding.
with open("params.yaml", 'r', encoding='utf-8') as stream:
    HYPER_PARAMS = yaml.safe_load(stream)

In [8]:
# Output layout of the knowledge graph. BASE_PATH is derived from BASEDIR so moving the
# notebook only requires updating BASEDIR (previously "../../../" was duplicated here).
BASE_PATH = os.path.join(BASEDIR, "data/knowledge_graphs/")
DATASET_PATH = BASE_PATH + f"{HYPER_PARAMS['DATASET_NAME']}/"
KG_PATH = DATASET_PATH + f"{HYPER_PARAMS['KNOWLEDGE_GRAPH_NAME']}/"

# Files written into KG_PATH: run hyper-parameters (JSON), extracted triplets and
# pickled config objects (joblib).
HYPER_PARAMS_PATH = KG_PATH + 'hyperparameters.json'
EXTRACTED_TRIPLETS_PATH = KG_PATH + "extracted_triplets"
GRAPH_DRIVER_CONFIG_PATH = KG_PATH + "graph_config"
EMBEDDINGS_DRIVER_CONFIG_PATH = KG_PATH + "embeddings_config"
MEM_PIPELINE_CONFIG_PATH = KG_PATH + "mem_pipeline_config"

# Sub-directories created for the KG parts (VECTORIZED_DB_PATH is the Chroma store path).
VECTORIZED_DB_PATH = KG_PATH + "embeddings_part/"
GRAPH_DB_PATH = KG_PATH + "graph_part/"
CACHE_DB_PATH = KG_PATH + "cache_part/"

In [None]:
# The parent directories must already exist, while the KG directory itself must be new,
# so an existing knowledge graph is never silently reused or overwritten.
for required_dir in (BASE_PATH, DATASET_PATH):
    if not os.path.exists(required_dir):
        raise ValueError(f"Директории не существует: {required_dir}")
if os.path.exists(KG_PATH):
    raise ValueError(f"Директория существует: {KG_PATH}")

# Create the KG root first, then one sub-directory per storage part.
for new_dir in (KG_PATH, VECTORIZED_DB_PATH, GRAPH_DB_PATH, CACHE_DB_PATH):
    os.mkdir(new_dir)

In [9]:
# Show where each storage part of the KG lives.
for part_dir in (VECTORIZED_DB_PATH, GRAPH_DB_PATH, CACHE_DB_PATH):
    print(part_dir)

../../../data/knowledge_graphs/diaasqa/gigachat_filtered/embeddings_part/
../../../data/knowledge_graphs/diaasqa/gigachat_filtered/graph_part/
../../../data/knowledge_graphs/diaasqa/gigachat_filtered/cache_part/


In [10]:
# Setting knowledge graph

# Neo4j connection settings. They can be overridden through environment variables so the
# credentials need not live in the notebook; the defaults reproduce the values that were
# previously hard-coded here.
NEO4J_URI = os.environ.get("NEO4J_URI", "bolt://personalai_mmenschikov_neo4j:7687")
NEO4J_USER = os.environ.get("NEO4J_USER", "neo4j")
NEO4J_PASSWORD = os.environ.get("NEO4J_PASSWORD", "password")

graph_config = GraphModelConfig(
    driver_config=GraphDriverConfig(
        db_vendor='neo4j',
        db_config=GraphDBConnectionConfig(
            uri=NEO4J_URI, params={'user': NEO4J_USER, 'pwd': NEO4J_PASSWORD},
            need_to_clear=HYPER_PARAMS['need_to_clear'])))


def _chroma_driver_config(table: str) -> VectorDriverConfig:
    """Build a Chroma vector-driver config stored under VECTORIZED_DB_PATH.

    Args:
        table: Value for the 'table' key of db_info (e.g. "vectorized_nodes").

    Returns:
        A VectorDriverConfig for database 'personalaidb', honouring HYPER_PARAMS['need_to_clear'].
    """
    return VectorDriverConfig(
        db_vendor='chroma',
        db_config=VectorDBConnectionConfig(
            path=VECTORIZED_DB_PATH, db_info={'db': 'personalaidb', 'table': table},
            need_to_clear=HYPER_PARAMS['need_to_clear']))


# Node and triplet embeddings share one Chroma store but use separate tables.
embed_config = EmbeddingsModelConfig(
    nodesdb_driver_config=_chroma_driver_config("vectorized_nodes"),
    tripletsdb_driver_config=_chroma_driver_config("vectorized_triplets"),
    embedder_config=EmbedderModelConfig(model_name_or_path=HYPER_PARAMS['EMBEDDER_MODEL_PATH']))

In [None]:
# Assemble the knowledge-graph model from the graph-store and embedding-store configs.
kg_model = KnowledgeGraphModel(graph_config=graph_config, embeddings_config=embed_config)

In [6]:
# Sanity check: item counts of the node / triplet vector collections, then the graph-store counts.
for collection_name in ('nodes', 'triplets'):
    print(kg_model.embeddings_struct.vectordbs[collection_name].count_items())
print(kg_model.graph_struct.db_conn.count_items())

0
0
{'triplets': 182838, 'nodes': 49597}


In [7]:
# Setting Memorization Pipeline
# The LLM extractor uses its default settings; whether the updator deletes obsolete
# information is taken from params.yaml (DELETE_OBSOLETE_INFO).
mem_config = MemPipelineConfig(
    extractor_config=LLMExtractorConfig(),
    updator_config=LLMUpdatorConfig(delete_obsolete_info=HYPER_PARAMS['DELETE_OBSOLETE_INFO']),
)

# The pipeline writes into the kg_model built above.
mem_pipeline = MemPipeline(kg_model, mem_config)

In [8]:
# Persist this run's settings next to the knowledge graph: hyper-parameters as
# human-readable JSON, the config objects via joblib.
with open(HYPER_PARAMS_PATH, 'w', encoding='utf-8') as params_fd:
    json.dump(HYPER_PARAMS, params_fd, ensure_ascii=False, indent=1)

# NOTE(review): graph_config is built with the Neo4j user/password in its params, so they
# presumably end up in the pickled file too — confirm before sharing the KG directory.
for config_obj, config_path in ((graph_config, GRAPH_DRIVER_CONFIG_PATH),
                                (embed_config, EMBEDDINGS_DRIVER_CONFIG_PATH)):
    joblib.dump(config_obj, config_path)
# Left as the cell's last expression so its return value is displayed as before.
joblib.dump(mem_config, MEM_PIPELINE_CONFIG_PATH)

['../../../data/knowledge_graphs/diaasqa/gigachat_full/mem_pipeline_config']

### Loading dataset

In [9]:
def custom_diaasqa_load(dataset_path: str) -> List[Tuple[str, Dict[str, str]]]:
    """Load the DiaASQA dataset as (dialog_text, properties) pairs.

    The file is expected to be a JSON object whose ``data`` list holds items with
    ``text_dialog`` and ``time`` fields. Only the part of ``time`` before the first
    comma is kept.

    Args:
        dataset_path: Path to the UTF-8 encoded JSON file.

    Returns:
        One ``(text_dialog, {'time': <time prefix>})`` tuple per item, in file order.
    """
    with open(dataset_path, 'r', encoding='utf-8') as fd:
        data = json.load(fd)

    return [
        (item['text_dialog'], {'time': item['time'].split(',')[0]})
        for item in data['data']
    ]

In [10]:
# Registry of dataset-specific loaders, keyed by HYPER_PARAMS['DATASET_NAME'].
CUSTOM_LOAD_FUNCS = dict(diaasqa=custom_diaasqa_load)

In [11]:
# Pick the loader registered for the configured dataset and apply it to the dataset file.
load_dataset = CUSTOM_LOAD_FUNCS[HYPER_PARAMS['DATASET_NAME']]
dataset = load_dataset(HYPER_PARAMS['DATASET_PATH'])

In [12]:
# Number of (dialog, properties) pairs to be processed.
n_dialogs = len(dataset)
print(n_dialogs)

3483


### Creating knowledge graph

In [None]:
# Accumulator for the triplets extracted per dialog. Re-running this cell discards
# everything collected by the extraction loop so far.
saved_triplets = list()

In [None]:
# Resume point for the extraction loop: start index 2106 | knowledge graph: gigachat_full | dataset: diaasqa

In [14]:
# Extract triplets dialog by dialog and write them into the KG via the memorization
# pipeline. RESUME_FROM lets an interrupted run continue where it stopped without
# editing a magic number inside the slice.
RESUME_FROM = 2106

for text, properties in tqdm(dataset[RESUME_FROM:]):
    # NOTE(review): the recorded run printed "AUTHENTICATION ERROR" roughly every 31 minutes
    # (possibly an expiring LLM API token — TODO confirm). It is not visible here whether the
    # dialog being processed at that moment was skipped or retried.
    extracted_triplets, _ = mem_pipeline.remember(text, properties)
    saved_triplets.append(extracted_triplets)

  7%|▋         | 102/1377 [31:00<7:17:29, 20.59s/it]AUTHENTICATION ERROR
 15%|█▍        | 204/1377 [1:02:19<6:35:04, 20.21s/it]AUTHENTICATION ERROR
 22%|██▏       | 307/1377 [1:33:30<5:31:51, 18.61s/it]AUTHENTICATION ERROR
 29%|██▉       | 404/1377 [2:04:37<5:24:09, 19.99s/it]AUTHENTICATION ERROR
 36%|███▌      | 494/1377 [2:35:59<5:08:53, 20.99s/it]AUTHENTICATION ERROR
 43%|████▎     | 590/1377 [3:07:13<4:54:21, 22.44s/it]AUTHENTICATION ERROR
 49%|████▉     | 674/1377 [3:38:19<4:38:00, 23.73s/it] AUTHENTICATION ERROR
 55%|█████▌    | 758/1377 [4:09:40<3:53:33, 22.64s/it]AUTHENTICATION ERROR
 61%|██████▏   | 846/1377 [4:40:57<2:59:33, 20.29s/it]AUTHENTICATION ERROR
 68%|██████▊   | 934/1377 [5:12:04<2:27:03, 19.92s/it]AUTHENTICATION ERROR
 74%|███████▍  | 1017/1377 [5:43:09<2:02:05, 20.35s/it]AUTHENTICATION ERROR
 80%|████████  | 1103/1377 [6:14:19<1:43:56, 22.76s/it]AUTHENTICATION ERROR
 86%|████████▋ | 1188/1377 [6:45:32<1:07:38, 21.48s/it]AUTHENTICATION ERROR
 92%|█████████▏| 1265/1

In [None]:
# Item counts after extraction: node / triplet vector collections, then the graph store.
for collection_name in ('nodes', 'triplets'):
    print(kg_model.embeddings_struct.vectordbs[collection_name].count_items())
print(kg_model.graph_struct.db_conn.count_items())

### Saving log information

In [16]:
# Persist the per-dialog extraction results.
# NOTE(review): saved_triplets only holds what was accumulated since the accumulator cell
# last ran; if the loop was resumed from a non-zero index in a fresh kernel, this
# overwrites any earlier dump with a partial list — confirm before relying on it.
joblib.dump(value=saved_triplets, filename=EXTRACTED_TRIPLETS_PATH)

['../../data/knowledge_graphs/diaasqa/gigachat_full/mem_pipeline_config']

### Build graph on extracted triplets

In [8]:
from functools import reduce

In [9]:
# Reload the per-dialog extraction results. joblib.load unpickles the file, so only load
# dumps produced by this notebook (never untrusted files).
saved_triplets = joblib.load(filename=EXTRACTED_TRIPLETS_PATH)

In [10]:
# Concatenate the per-dialog triplet lists into one flat list. The comprehension is linear
# in the total number of triplets; reduce(lambda acc, v: acc + v, ...) re-copied the
# growing accumulator on every step, which is quadratic.
flattened_triplets = [triplet for dialog_triplets in saved_triplets for triplet in dialog_triplets]

In [None]:
# Write all flattened triplets into the graph store (status_bar enabled).
# NOTE(review): the graph store already reports triplets in the counts printed earlier;
# confirm create_triplets does not duplicate existing ones.
graph_store = mem_pipeline.updator.kg_model.graph_struct
output = graph_store.create_triplets(flattened_triplets, status_bar=True)

In [None]:
# Write all flattened triplets into the vector (embeddings) store (status_bar enabled).
embeddings_store = mem_pipeline.updator.kg_model.embeddings_struct
output = embeddings_store.create_triplets(flattened_triplets, status_bar=True)

In [13]:
# Final item counts: node / triplet vector collections, then the graph store.
# NOTE(review): the recorded output shows 44328 vector triplets vs 182838 graph triplets;
# confirm whether that gap is expected (e.g. deduplication in the vector store).
for collection_name in ('nodes', 'triplets'):
    print(kg_model.embeddings_struct.vectordbs[collection_name].count_items())
print(kg_model.graph_struct.db_conn.count_items())

49597
44328
{'triplets': 182838, 'nodes': 49597}
