Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
# Copyright (c) Cosmo Tech corporation.
# Licensed under the MIT license.
import logging
import functools

from CosmoTech_Acceleration_Library.Modelops.core.common.redis_handler import RedisHandler
from CosmoTech_Acceleration_Library.Modelops.core.io.model_metadata import ModelMetadata
from CosmoTech_Acceleration_Library.Modelops.core.utils.model_util import ModelUtil

logger = logging.getLogger(__name__)

Expand All @@ -14,91 +13,44 @@ class GraphHandler(RedisHandler):
Class that handles Graph Redis information
"""

def __init__(self, host: str, port: int, name: str, password: str = None, source_url: str = "", graph_rotation: int = 3):
def __init__(self, host: str, port: int, name: str, password: str = None):
super().__init__(host=host, port=port, name=name, password=password)
logger.debug("GraphHandler init")
self.graph = self.r.graph(name)
self.name = name
self.m_metadata = ModelMetadata(host=host, port=port, name=name, password=password)
current_metadata = self.m_metadata.get_metadata()
if not current_metadata:
if graph_rotation is None:
graph_rotation = 3
logger.debug("Create metadata key")
self.m_metadata.set_metadata(last_graph_version=0, graph_source_url=source_url,
graph_rotation=graph_rotation)


class VersionedGraphHandler(GraphHandler):
"""
Class that handle Versioned Graph Redis information
"""

def __init__(self, host: str, port: int, name: str, version: int = None, password: str = None, source_url: str = "",
graph_rotation: int = None):
super().__init__(host=host, port=port, name=name, password=password, source_url=source_url,
graph_rotation=graph_rotation)
logger.debug("VersionedGraphHandler init")
self.version = version
if version is None:
self.version = self.m_metadata.get_last_graph_version()
self.versioned_name = ModelUtil.build_graph_version_name(self.name, self.version)
self.graph = self.r.graph(self.versioned_name)

self.graph = self.r.graph(name)

class RotatedGraphHandler(VersionedGraphHandler):
"""
Class that handle Rotated Graph Redis information
"""
def do_if_graph_exist(function):
"""
Function decorator that runs the annotated function only if the graph exists, and raises an Exception otherwise
:param function: the function annotated
"""

def __init__(self, host: str, port: int, name: str, password: str = None, version: int = None, source_url: str = "",
graph_rotation: int = None):
super().__init__(host=host, port=port, name=name, password=password, source_url=source_url, version=version,
graph_rotation=graph_rotation)
logger.debug("RotatedGraphHandler init")
self.graph_rotation = self.m_metadata.get_graph_rotation()
@functools.wraps(function)
def wrapper(self, *args, **kwargs):
if self.r.exists(self.name) != 0:
function(self, *args, **kwargs)
else:
raise Exception(f"{self.name} does not exist!")

def get_all_versions(self):
matching_graph_keys = self.r.keys(ModelUtil.build_graph_key_pattern(self.name))
versions = []
for graph_key in matching_graph_keys:
versions.append(graph_key.split(":")[-1])
return versions
return wrapper

def handle_graph_rotation(func):
def handle_graph_replace(func):
"""
Decorator that runs the decorated function against a temporary graph (<name>_tmp), then, if no error occurred, copies the temporary graph to the original graph key and deletes the temporary graph
"""

def handle(self, *args, **kwargs):
graph_versions = self.get_all_versions()

if len(graph_versions) > 0:
max_version = max([int(x) for x in graph_versions if x.isnumeric()])
else:
max_version = 0
# upgrade current graph to max_version+1
self.version = max_version + 1
self.version_name = ModelUtil.build_graph_version_name(self.name, self.version)
self.graph = self.r.graph(self.version_name)
logger.debug(f'Using graph updated version {self.version_name}')
self.graph = self.r.graph(f'{self.name}_tmp')
logger.debug(f'Using graph {self.name}_tmp for copy')

# do function on new graph
func(self, *args, **kwargs)

# get max version to manage case func not using (hence creating) graph
graph_versions = [int(v) for v in self.get_all_versions()]
graph_versions.sort()
graph_versions.reverse()
to_remove = graph_versions[int(self.graph_rotation):]

# remove all older versions
for v in to_remove:
oldest_graph_version_to_delete = ModelUtil.build_graph_version_name(self.name, v)
self.r.delete(oldest_graph_version_to_delete)
logger.debug(f"Graph {oldest_graph_version_to_delete} deleted")

# upgrade metadata last version to +1 after function execution
self.m_metadata.set_last_graph_version(self.version)
# action completed on graph_tmp with no error: replace graph with graph_tmp
self.r.copy(f'{self.name}_tmp', self.name)
# remove tmp graph
self.r.delete(f'{self.name}_tmp')
# set back the graph
self.graph = self.r.graph(self.name)

return handle
Original file line number Diff line number Diff line change
Expand Up @@ -19,4 +19,3 @@ def __init__(self, host: str, port: int, name: str, password: str = None):
self.name = name
self.password = password
self.r = redis.Redis(host=host, port=port, password=password, decode_responses=True)
self.metadata_key = name + "MetaData"
Original file line number Diff line number Diff line change
Expand Up @@ -40,29 +40,43 @@ def _to_cosmo_key(val: any) -> str:
return val

@staticmethod
def write_twin_data(export_dir: str, file_name: str, query_result: QueryResult,
delimiter: str = ',', quote_char: str = '\"') -> None:
def write_twin_data(export_dir: str,
file_name: str,
query_result: QueryResult,
delimiter: str = ',',
quote_char: str = '\"') -> None:
headers = set()
rows = []
for raw_data in query_result.result_set:
row = {}
# read all graph link properties
for i in range(len(raw_data)): # TODO for the moment its only a len 1 list with the node
row.update({CsvWriter._to_cosmo_key(k): CsvWriter._to_csv_format(v) for k, v in raw_data[i].properties.items()})
row.update({
CsvWriter._to_cosmo_key(k): CsvWriter._to_csv_format(v)
for k, v in raw_data[i].properties.items()
})
headers.update(row.keys())
rows.append(row)

output_file_name = f'{export_dir}/{file_name}.csv'
logger.debug(f"Writing CSV file {output_file_name}")
with open(output_file_name, 'w') as csvfile:
csv_writer = csv.DictWriter(csvfile, fieldnames=headers, delimiter=delimiter, quotechar=quote_char, quoting=csv.QUOTE_MINIMAL)
csv_writer = csv.DictWriter(csvfile,
fieldnames=headers,
delimiter=delimiter,
quotechar=quote_char,
quoting=csv.QUOTE_MINIMAL)
csv_writer.writeheader()
csv_writer.writerows(rows)
logger.debug(f"... CSV file {output_file_name} has been written")

@staticmethod
def write_relationship_data(export_dir: str, file_name: str, query_result: QueryResult, headers: list = [],
delimiter: str = ',', quote_char: str = '\"') -> None:
def write_relationship_data(export_dir: str,
file_name: str,
query_result: QueryResult,
headers: list = [],
delimiter: str = ',',
quote_char: str = '\"') -> None:
headers = {'source', 'target'}
rows = []
for raw_data in query_result.result_set:
Expand All @@ -71,16 +85,24 @@ def write_relationship_data(export_dir: str, file_name: str, query_result: Query
headers.update(row.keys())
rows.append(row)

output_file_name = export_dir + file_name + '.csv'
output_file_name = f'{export_dir}/{file_name}.csv'
logger.debug(f"Writing CSV file {output_file_name}")
with open(output_file_name, 'w') as csvfile:
csv_writer = csv.DictWriter(csvfile, fieldnames=headers, delimiter=delimiter, quotechar=quote_char, quoting=csv.QUOTE_MINIMAL)
csv_writer = csv.DictWriter(csvfile,
fieldnames=headers,
delimiter=delimiter,
quotechar=quote_char,
quoting=csv.QUOTE_MINIMAL)
csv_writer.writeheader()
csv_writer.writerows(rows)
logger.debug(f"... CSV file {output_file_name} has been written")

@staticmethod
def write_data(export_dir: str, file_name: str, input_rows: dict, delimiter: str = ',', quote_char: str = '\"') -> None:
def write_data(export_dir: str,
file_name: str,
input_rows: dict,
delimiter: str = ',',
quote_char: str = '\"') -> None:
output_file_name = export_dir + file_name + '.csv'
write_header = False
if not os.path.exists(output_file_name):
Expand All @@ -94,7 +116,11 @@ def write_data(export_dir: str, file_name: str, input_rows: dict, delimiter: str

logger.info(f"Writing file {output_file_name} ...")
with open(output_file_name, 'a') as csvfile:
csv_writer = csv.DictWriter(csvfile, fieldnames=headers, delimiter=delimiter, quotechar=quote_char, quoting=csv.QUOTE_MINIMAL)
csv_writer = csv.DictWriter(csvfile,
fieldnames=headers,
delimiter=delimiter,
quotechar=quote_char,
quoting=csv.QUOTE_MINIMAL)
if write_header:
csv_writer.writeheader()
csv_writer.writerows(output_rows)
Expand Down

This file was deleted.

This file was deleted.

20 changes: 10 additions & 10 deletions CosmoTech_Acceleration_Library/Modelops/core/io/model_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,30 @@
import redis
from functools import lru_cache

from CosmoTech_Acceleration_Library.Modelops.core.common.graph_handler import VersionedGraphHandler
from CosmoTech_Acceleration_Library.Modelops.core.common.graph_handler import GraphHandler
from CosmoTech_Acceleration_Library.Modelops.core.common.writer.CsvWriter import CsvWriter
from CosmoTech_Acceleration_Library.Modelops.core.decorators.model_decorators import do_if_graph_exist
from CosmoTech_Acceleration_Library.Modelops.core.io.model_reader import ModelReader

logger = logging.getLogger(__name__)


class ModelExporter(VersionedGraphHandler):
class ModelExporter(GraphHandler):
"""
Model Exporter for cached data
"""

def __init__(self, host: str, port: int, name: str, version: int, password: str = None, export_dir: str = "/"):
super().__init__(host=host, port=port, name=name, version=version, password=password)
def __init__(self, host: str, port: int, name: str, password: str = None, export_dir: str = "/"):
super().__init__(host=host, port=port, name=name, password=password)
Path(export_dir).mkdir(parents=True, exist_ok=True)
self.export_dir = export_dir
self.mr = ModelReader(host=host, port=port, name=name, password=password, version=version)

self.mr = ModelReader(host=host, port=port, name=name, password=password)
self.labels = [label[0] for label in self.graph.labels()]
self.relationships = [relation[0] for relation in self.graph.relationship_types()]
self.already_exported_nodes = {}
self.already_exported_edges = []

@do_if_graph_exist
@GraphHandler.do_if_graph_exist
def export_all_twins(self):
"""
Export all twins
Expand Down Expand Up @@ -58,7 +58,7 @@ def export_all_twins(self):
logger.debug(f"Twins exported :{twin_name}")
logger.debug("... End exporting twins")

@do_if_graph_exist
@GraphHandler.do_if_graph_exist
def export_all_relationships(self):
"""
Export all relationships
Expand Down Expand Up @@ -88,7 +88,7 @@ def export_all_relationships(self):
logger.debug(f"Relationships exported :{relationship_name}")
logger.debug("... End exporting relationships")

@do_if_graph_exist
@GraphHandler.do_if_graph_exist
def export_all_data(self):
"""
Export all data
Expand All @@ -97,7 +97,7 @@ def export_all_data(self):
self.export_all_twins()
self.export_all_relationships()

@do_if_graph_exist
@GraphHandler.do_if_graph_exist
def export_from_queries(self, queries: list):
"""
Export data from queries
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,17 @@

from redisgraph_bulk_loader.bulk_insert import bulk_insert

from CosmoTech_Acceleration_Library.Modelops.core.common.graph_handler import RotatedGraphHandler
from CosmoTech_Acceleration_Library.Modelops.core.utils.model_util import ModelUtil
from CosmoTech_Acceleration_Library.Modelops.core.common.graph_handler import GraphHandler

logger = logging.getLogger(__name__)


class ModelImporter(RotatedGraphHandler):
class ModelImporter(GraphHandler):
"""
Model Importer for cached data
"""

@RotatedGraphHandler.handle_graph_rotation
@GraphHandler.handle_graph_replace
def bulk_import(self, twin_file_paths: list = [], relationship_file_paths: list = [], enforce_schema: bool = False):
"""
Import all csv data
Expand Down Expand Up @@ -44,7 +43,7 @@ def bulk_import(self, twin_file_paths: list = [], relationship_file_paths: list
command_parameters.append('--relations')
command_parameters.append(relationship_file_path)

command_parameters.append(ModelUtil.build_graph_version_name(self.name, self.version))
command_parameters.append(self.name)
logger.debug(command_parameters)
# TODO: Think about using the '--index Label:Property' command parameter to create indexes on default id properties
try:
Expand Down
Loading