diff --git a/CosmoTech_Acceleration_Library/Modelops/core/common/graph_handler.py b/CosmoTech_Acceleration_Library/Modelops/core/common/graph_handler.py index 41237712..a383f596 100644 --- a/CosmoTech_Acceleration_Library/Modelops/core/common/graph_handler.py +++ b/CosmoTech_Acceleration_Library/Modelops/core/common/graph_handler.py @@ -1,10 +1,9 @@ # Copyright (c) Cosmo Tech corporation. # Licensed under the MIT license. import logging +import functools from CosmoTech_Acceleration_Library.Modelops.core.common.redis_handler import RedisHandler -from CosmoTech_Acceleration_Library.Modelops.core.io.model_metadata import ModelMetadata -from CosmoTech_Acceleration_Library.Modelops.core.utils.model_util import ModelUtil logger = logging.getLogger(__name__) @@ -14,91 +13,44 @@ class GraphHandler(RedisHandler): Class that handle Graph Redis information """ - def __init__(self, host: str, port: int, name: str, password: str = None, source_url: str = "", graph_rotation: int = 3): + def __init__(self, host: str, port: int, name: str, password: str = None): super().__init__(host=host, port=port, name=name, password=password) logger.debug("GraphHandler init") - self.graph = self.r.graph(name) self.name = name - self.m_metadata = ModelMetadata(host=host, port=port, name=name, password=password) - current_metadata = self.m_metadata.get_metadata() - if not current_metadata: - if graph_rotation is None: - graph_rotation = 3 - logger.debug("Create metadata key") - self.m_metadata.set_metadata(last_graph_version=0, graph_source_url=source_url, - graph_rotation=graph_rotation) - - -class VersionedGraphHandler(GraphHandler): - """ - Class that handle Versioned Graph Redis information - """ - - def __init__(self, host: str, port: int, name: str, version: int = None, password: str = None, source_url: str = "", - graph_rotation: int = None): - super().__init__(host=host, port=port, name=name, password=password, source_url=source_url, - graph_rotation=graph_rotation) - logger.debug("VersionedGraphHandler init") - self.version = version - if version is None: - self.version = self.m_metadata.get_last_graph_version() - self.versioned_name = ModelUtil.build_graph_version_name(self.name, self.version) - self.graph = self.r.graph(self.versioned_name) - + self.graph = self.r.graph(name) -class RotatedGraphHandler(VersionedGraphHandler): - """ - Class that handle Rotated Graph Redis information - """ + def do_if_graph_exist(function): + """ + Function decorator that run the function annotated if graph exists + :param function: the function annotated + """ - def __init__(self, host: str, port: int, name: str, password: str = None, version: int = None, source_url: str = "", - graph_rotation: int = None): - super().__init__(host=host, port=port, name=name, password=password, source_url=source_url, version=version, - graph_rotation=graph_rotation) - logger.debug("RotatedGraphHandler init") - self.graph_rotation = self.m_metadata.get_graph_rotation() + @functools.wraps(function) + def wrapper(self, *args, **kwargs): + if self.r.exists(self.name) != 0: + function(self, *args, **kwargs) + else: + raise Exception(f"{self.name} does not exist!") - def get_all_versions(self): - matching_graph_keys = self.r.keys(ModelUtil.build_graph_key_pattern(self.name)) - versions = [] - for graph_key in matching_graph_keys: - versions.append(graph_key.split(":")[-1]) - return versions + return wrapper - def handle_graph_rotation(func): + def handle_graph_replace(func): """ Decorator to do stuff then handle graph rotation (delete the oldest graph if the amount of graph is greater than graph rotation) """ def handle(self, *args, **kwargs): - graph_versions = self.get_all_versions() - - if len(graph_versions) > 0: - max_version = max([int(x) for x in graph_versions if x.isnumeric()]) - else: - max_version = 0 - # upgrade current graph to max_version+1 - self.version = max_version + 1 - self.version_name = ModelUtil.build_graph_version_name(self.name, self.version) - self.graph = self.r.graph(self.version_name) - logger.debug(f'Using graph updated version {self.version_name}') + self.graph = self.r.graph(f'{self.name}_tmp') + logger.debug(f'Using graph {self.name}_tmp for copy') # do function on new graph func(self, *args, **kwargs) - # get max version to manage case func not using (hence creating) graph - graph_versions = [int(v) for v in self.get_all_versions()] - graph_versions.sort() - graph_versions.reverse() - to_remove = graph_versions[int(self.graph_rotation):] - - # remove all older versions - for v in to_remove: - oldest_graph_version_to_delete = ModelUtil.build_graph_version_name(self.name, v) - self.r.delete(oldest_graph_version_to_delete) - logger.debug(f"Graph {oldest_graph_version_to_delete} deleted") - - # upgrade metadata last version to +1 after function execution - self.m_metadata.set_last_graph_version(self.version) + # action complete on graph_tmp with no error replacing graph by graph_tmp + self.r.copy(f'{self.name}_tmp', self.name) + # remove tmp graph + self.r.delete(f'{self.name}_tmp') + # set back the graph + self.graph = self.r.graph(self.name) return handle diff --git a/CosmoTech_Acceleration_Library/Modelops/core/common/redis_handler.py b/CosmoTech_Acceleration_Library/Modelops/core/common/redis_handler.py index 0cc648ff..ad5313ad 100644 --- a/CosmoTech_Acceleration_Library/Modelops/core/common/redis_handler.py +++ b/CosmoTech_Acceleration_Library/Modelops/core/common/redis_handler.py @@ -19,4 +19,3 @@ def __init__(self, host: str, port: int, name: str, password: str = None): self.name = name self.password = password self.r = redis.Redis(host=host, port=port, password=password, decode_responses=True) - self.metadata_key = name + "MetaData" diff --git a/CosmoTech_Acceleration_Library/Modelops/core/common/writer/CsvWriter.py b/CosmoTech_Acceleration_Library/Modelops/core/common/writer/CsvWriter.py index 9fbaa296..0b89f98d 100644 --- a/CosmoTech_Acceleration_Library/Modelops/core/common/writer/CsvWriter.py +++ b/CosmoTech_Acceleration_Library/Modelops/core/common/writer/CsvWriter.py @@ -40,29 +40,43 @@ def _to_cosmo_key(val: any) -> str: return val @staticmethod - def write_twin_data(export_dir: str, file_name: str, query_result: QueryResult, - delimiter: str = ',', quote_char: str = '\"') -> None: + def write_twin_data(export_dir: str, + file_name: str, + query_result: QueryResult, + delimiter: str = ',', + quote_char: str = '\"') -> None: headers = set() rows = [] for raw_data in query_result.result_set: row = {} # read all graph link properties for i in range(len(raw_data)): # TODO for the moment its only a len 1 list with the node - row.update({CsvWriter._to_cosmo_key(k): CsvWriter._to_csv_format(v) for k, v in raw_data[i].properties.items()}) + row.update({ + CsvWriter._to_cosmo_key(k): CsvWriter._to_csv_format(v) + for k, v in raw_data[i].properties.items() + }) headers.update(row.keys()) rows.append(row) output_file_name = f'{export_dir}/{file_name}.csv' logger.debug(f"Writing CSV file {output_file_name}") with open(output_file_name, 'w') as csvfile: - csv_writer = csv.DictWriter(csvfile, fieldnames=headers, delimiter=delimiter, quotechar=quote_char, quoting=csv.QUOTE_MINIMAL) + csv_writer = csv.DictWriter(csvfile, + fieldnames=headers, + delimiter=delimiter, + quotechar=quote_char, + quoting=csv.QUOTE_MINIMAL) csv_writer.writeheader() csv_writer.writerows(rows) logger.debug(f"... CSV file {output_file_name} has been written") @staticmethod - def write_relationship_data(export_dir: str, file_name: str, query_result: QueryResult, headers: list = [], - delimiter: str = ',', quote_char: str = '\"') -> None: + def write_relationship_data(export_dir: str, + file_name: str, + query_result: QueryResult, + headers: list = [], + delimiter: str = ',', + quote_char: str = '\"') -> None: headers = {'source', 'target'} rows = [] for raw_data in query_result.result_set: @@ -71,16 +85,24 @@ def write_relationship_data(export_dir: str, file_name: str, query_result: Query headers.update(row.keys()) rows.append(row) - output_file_name = export_dir + file_name + '.csv' + output_file_name = f'{export_dir}/{file_name}.csv' logger.debug(f"Writing CSV file {output_file_name}") with open(output_file_name, 'w') as csvfile: - csv_writer = csv.DictWriter(csvfile, fieldnames=headers, delimiter=delimiter, quotechar=quote_char, quoting=csv.QUOTE_MINIMAL) + csv_writer = csv.DictWriter(csvfile, + fieldnames=headers, + delimiter=delimiter, + quotechar=quote_char, + quoting=csv.QUOTE_MINIMAL) csv_writer.writeheader() csv_writer.writerows(rows) logger.debug(f"... CSV file {output_file_name} has been written") @staticmethod - def write_data(export_dir: str, file_name: str, input_rows: dict, delimiter: str = ',', quote_char: str = '\"') -> None: + def write_data(export_dir: str, + file_name: str, + input_rows: dict, + delimiter: str = ',', + quote_char: str = '\"') -> None: output_file_name = export_dir + file_name + '.csv' write_header = False if not os.path.exists(output_file_name): @@ -94,7 +116,11 @@ def write_data(export_dir: str, file_name: str, input_rows: dict, delimiter: str logger.info(f"Writing file {output_file_name} ...") with open(output_file_name, 'a') as csvfile: - csv_writer = csv.DictWriter(csvfile, fieldnames=headers, delimiter=delimiter, quotechar=quote_char, quoting=csv.QUOTE_MINIMAL) + csv_writer = csv.DictWriter(csvfile, + fieldnames=headers, + delimiter=delimiter, + quotechar=quote_char, + quoting=csv.QUOTE_MINIMAL) if write_header: csv_writer.writeheader() csv_writer.writerows(output_rows) diff --git a/CosmoTech_Acceleration_Library/Modelops/core/decorators/__init__.py b/CosmoTech_Acceleration_Library/Modelops/core/decorators/__init__.py deleted file mode 100644 index eacef447..00000000 --- a/CosmoTech_Acceleration_Library/Modelops/core/decorators/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -# Copyright (c) Cosmo Tech corporation. -# Licensed under the MIT license. diff --git a/CosmoTech_Acceleration_Library/Modelops/core/decorators/model_decorators.py b/CosmoTech_Acceleration_Library/Modelops/core/decorators/model_decorators.py deleted file mode 100644 index 9fc7a8a9..00000000 --- a/CosmoTech_Acceleration_Library/Modelops/core/decorators/model_decorators.py +++ /dev/null @@ -1,61 +0,0 @@ -# Copyright (c) Cosmo Tech corporation. -# Licensed under the MIT license. -import functools - -from CosmoTech_Acceleration_Library.Modelops.core.common.graph_handler import GraphHandler, VersionedGraphHandler - - -def update_last_version(function): - """ - Function decorator that update metadata last version after calling the function annotated - :param function: the function annotated - """ - - @functools.wraps(function) - def wrapper(*args, **kwargs): - self = args[0] - if isinstance(self, GraphHandler): - function(*args, **kwargs) - self.m_metadata.update_last_version() - else: - function(*args, **kwargs) - return wrapper - - -def update_last_modified_date(function): - """ - Function decorator that update metadata last modified date after calling the function annotated - :param function: the function annotated - """ - - @functools.wraps(function) - def wrapper(*args, **kwargs): - self = args[0] - if isinstance(self, GraphHandler): - function(*args, **kwargs) - self.m_metadata.update_last_modified_date() - else: - function(*args, **kwargs) - - return wrapper - - -def do_if_graph_exist(function): - """ - Function decorator that run the function annotated if versioned graph exists - :param function: the function annotated - """ - - @functools.wraps(function) - def wrapper(*args, **kwargs): - self = args[0] - version_graph_name = self.versioned_name - if isinstance(self, VersionedGraphHandler): - key_count = self.r.exists(version_graph_name) - if key_count != 0: - function(*args, **kwargs) - else: - raise Exception(f"{version_graph_name} does not exist!") - else: - function(*args, **kwargs) - return wrapper diff --git a/CosmoTech_Acceleration_Library/Modelops/core/io/model_exporter.py b/CosmoTech_Acceleration_Library/Modelops/core/io/model_exporter.py index aff4656d..fe71234a 100644 --- a/CosmoTech_Acceleration_Library/Modelops/core/io/model_exporter.py +++ b/CosmoTech_Acceleration_Library/Modelops/core/io/model_exporter.py @@ -6,30 +6,30 @@ import redis from functools import lru_cache -from CosmoTech_Acceleration_Library.Modelops.core.common.graph_handler import VersionedGraphHandler +from CosmoTech_Acceleration_Library.Modelops.core.common.graph_handler import GraphHandler from CosmoTech_Acceleration_Library.Modelops.core.common.writer.CsvWriter import CsvWriter -from CosmoTech_Acceleration_Library.Modelops.core.decorators.model_decorators import do_if_graph_exist from CosmoTech_Acceleration_Library.Modelops.core.io.model_reader import ModelReader logger = logging.getLogger(__name__) -class ModelExporter(VersionedGraphHandler): +class ModelExporter(GraphHandler): """ Model Exporter for cached data """ - def __init__(self, host: str, port: int, name: str, version: int, password: str = None, export_dir: str = "/"): - super().__init__(host=host, port=port, name=name, version=version, password=password) + def __init__(self, host: str, port: int, name: str, password: str = None, export_dir: str = "/"): + super().__init__(host=host, port=port, name=name, password=password) Path(export_dir).mkdir(parents=True, exist_ok=True) self.export_dir = export_dir - self.mr = ModelReader(host=host, port=port, name=name, password=password, version=version) + + self.mr = ModelReader(host=host, port=port, name=name, password=password) self.labels = [label[0] for label in self.graph.labels()] self.relationships = [relation[0] for relation in self.graph.relationship_types()] self.already_exported_nodes = {} self.already_exported_edges = [] - @do_if_graph_exist + @GraphHandler.do_if_graph_exist def export_all_twins(self): """ Export all twins @@ -58,7 +58,7 @@ def export_all_twins(self): logger.debug(f"Twins exported :{twin_name}") logger.debug("... End exporting twins") - @do_if_graph_exist + @GraphHandler.do_if_graph_exist def export_all_relationships(self): """ Export all relationships @@ -88,7 +88,7 @@ def export_all_relationships(self): logger.debug(f"Relationships exported :{relationship_name}") logger.debug("... End exporting relationships") - @do_if_graph_exist + @GraphHandler.do_if_graph_exist def export_all_data(self): """ Export all data @@ -97,7 +97,7 @@ def export_all_data(self): self.export_all_twins() self.export_all_relationships() - @do_if_graph_exist + @GraphHandler.do_if_graph_exist def export_from_queries(self, queries: list): """ Export data from queries diff --git a/CosmoTech_Acceleration_Library/Modelops/core/io/model_importer.py b/CosmoTech_Acceleration_Library/Modelops/core/io/model_importer.py index 45e45295..d9954d7d 100644 --- a/CosmoTech_Acceleration_Library/Modelops/core/io/model_importer.py +++ b/CosmoTech_Acceleration_Library/Modelops/core/io/model_importer.py @@ -4,18 +4,17 @@ from redisgraph_bulk_loader.bulk_insert import bulk_insert -from CosmoTech_Acceleration_Library.Modelops.core.common.graph_handler import RotatedGraphHandler -from CosmoTech_Acceleration_Library.Modelops.core.utils.model_util import ModelUtil +from CosmoTech_Acceleration_Library.Modelops.core.common.graph_handler import GraphHandler logger = logging.getLogger(__name__) -class ModelImporter(RotatedGraphHandler): +class ModelImporter(GraphHandler): """ Model Exporter for cached data """ - @RotatedGraphHandler.handle_graph_rotation + @GraphHandler.handle_graph_replace def bulk_import(self, twin_file_paths: list = [], relationship_file_paths: list = [], enforce_schema: bool = False): """ Import all csv data @@ -44,7 +43,7 @@ def bulk_import(self, twin_file_paths: list = [], relationship_file_paths: list command_parameters.append('--relations') command_parameters.append(relationship_file_path) - command_parameters.append(ModelUtil.build_graph_version_name(self.name, self.version)) + command_parameters.append(self.name) logger.debug(command_parameters) # TODO: Think about use '--index Label:Property' command parameters to create indexes on default id properties try: diff --git a/CosmoTech_Acceleration_Library/Modelops/core/io/model_metadata.py b/CosmoTech_Acceleration_Library/Modelops/core/io/model_metadata.py deleted file mode 100644 index 13974df1..00000000 --- a/CosmoTech_Acceleration_Library/Modelops/core/io/model_metadata.py +++ /dev/null @@ -1,152 +0,0 @@ -# Copyright (c) Cosmo Tech corporation. -# Licensed under the MIT license. -import logging -from datetime import datetime - -from CosmoTech_Acceleration_Library.Modelops.core.common.redis_handler import RedisHandler -from CosmoTech_Acceleration_Library.Modelops.core.utils.model_util import ModelUtil - -logger = logging.getLogger(__name__) - - -class ModelMetadata(RedisHandler): - """ - Model Metadata management class for cached data - """ - - last_modified_date_key = "lastModifiedDate" - last_version_key = "lastVersion" - source_url_key = "adtUrl" - graph_name_key = "graphName" - graph_rotation_key = "graphRotation" - - def get_metadata(self) -> dict: - """ - Get the metadata of the graph - :return: the dict containing all graph metadata - """ - return self.r.hgetall(self.metadata_key) - - def get_last_graph_version(self) -> str: - """ - Get the current last version of the graph - :return: the graph last version - """ - return self.get_metadata()[self.last_version_key] - - def get_graph_name(self) -> str: - """ - Get the graph's name - :return: the graph's name - """ - return self.name - - def get_graph_source_url(self) -> str: - """ - Get the datasource of the graph - :return: the datasource of the graph - """ - return self.get_metadata()[self.source_url_key] - - def get_graph_rotation(self) -> str: - """ - Get the graph rotation of the graph - :return: the graph rotation of the graph - """ - return self.get_metadata()[self.graph_rotation_key] - - def get_last_modified_date(self) -> datetime: - """ - Get the last modified date of the graph - :return: the last modified date of the graph - """ - metadata_last_version = self.get_metadata()[self.last_modified_date_key] - return ModelUtil.convert_str_to_datetime(metadata_last_version) - - def set_all_metadata(self, metadata: dict): - """ - Set the metadata of the graph - :param metadata the metadata to set - :raise Exception if the current version is greater than the new one - """ - current_metadata = self.get_metadata() - if self.last_version_key in current_metadata: - current_version = int(self.get_last_graph_version()) - new_version = int(metadata[self.last_version_key]) - if new_version > current_version: - logger.debug(f"Metatadata to set : {metadata}") - self.r.hmset(self.metadata_key, metadata) - else: - raise Exception(f"The current version {current_version} is equal or greater than the version to set: {new_version}") - else: - logger.debug(f"Metatadata to set : {metadata}") - self.r.hmset(self.metadata_key, metadata) - - def set_metadata(self, - last_graph_version: int, - graph_source_url: str, - graph_rotation: int) -> dict: - """ - Set the metadata of the graph - :param last_graph_version the new version - :param graph_source_url the source url - :param graph_rotation the graph rotation - :return the metadata set - :raise Exception if the current version is greater than the new one - """ - metadata = { - self.last_version_key: str(last_graph_version), - self.graph_name_key: self.name, - self.source_url_key: graph_source_url, - self.graph_rotation_key: str(graph_rotation), - self.last_modified_date_key: ModelUtil.convert_datetime_to_str(datetime.utcnow()) - } - logger.debug(f"Metatadata to set : {metadata}") - self.set_all_metadata(metadata=metadata) - - def set_last_graph_version(self, last_graph_version: int): - """ - Set the current last version of the graph - :param last_graph_version the new version - """ - self.r.hset(self.metadata_key, self.last_version_key, str(last_graph_version)) - logger.debug(f"Graph last_graph_version to set : {str(last_graph_version)}") - self.update_last_modified_date() - - def set_graph_source_url(self, graph_source_url: str): - """ - Set the datasource of the graph - :param graph_source_url the source url - """ - self.r.hset(self.metadata_key, self.source_url_key, graph_source_url) - logger.debug(f"Graph source_url to set : {str(graph_source_url)}") - self.update_last_modified_date() - - def set_graph_rotation(self, graph_rotation: int): - """ - Set the graph rotation of the graph - :param graph_rotation the graph rotation - """ - self.r.hset(self.metadata_key, self.graph_rotation_key, str(graph_rotation)) - logger.debug(f"Graph graph_rotation to set : {str(graph_rotation)}") - self.update_last_modified_date() - - def update_last_modified_date(self): - """ - Update the last modified date of the graph - """ - self.r.hset(self.metadata_key, self.last_modified_date_key, ModelUtil.convert_datetime_to_str(datetime.utcnow())) - - def update_last_version(self): - """ - Update the last version of the graph - """ - current_metadata = self.get_metadata() - if self.last_version_key in current_metadata: - current_version = int(self.get_last_graph_version()) - new_version = current_version + 1 - self.set_last_graph_version(str(new_version)) - self.update_last_modified_date() - else: - self.set_last_graph_version("0") - self.update_last_modified_date() diff --git a/CosmoTech_Acceleration_Library/Modelops/core/io/model_reader.py b/CosmoTech_Acceleration_Library/Modelops/core/io/model_reader.py index e36a037a..212b2cd9 100644 --- a/CosmoTech_Acceleration_Library/Modelops/core/io/model_reader.py +++ b/CosmoTech_Acceleration_Library/Modelops/core/io/model_reader.py @@ -2,14 +2,14 @@ # Licensed under the MIT license. import logging -from CosmoTech_Acceleration_Library.Modelops.core.common.graph_handler import VersionedGraphHandler +from CosmoTech_Acceleration_Library.Modelops.core.common.graph_handler import GraphHandler from CosmoTech_Acceleration_Library.Modelops.core.utils.model_util import ModelUtil from redis.commands.graph.query_result import QueryResult logger = logging.getLogger(__name__) -class ModelReader(VersionedGraphHandler): +class ModelReader(GraphHandler): """ Model Reader for cached data """ diff --git a/CosmoTech_Acceleration_Library/Modelops/core/io/model_writer.py b/CosmoTech_Acceleration_Library/Modelops/core/io/model_writer.py index 98a5d6ef..6cf52db0 100644 --- a/CosmoTech_Acceleration_Library/Modelops/core/io/model_writer.py +++ b/CosmoTech_Acceleration_Library/Modelops/core/io/model_writer.py @@ -2,19 +2,17 @@ # Licensed under the MIT license. import logging -from CosmoTech_Acceleration_Library.Modelops.core.common.graph_handler import VersionedGraphHandler -from CosmoTech_Acceleration_Library.Modelops.core.decorators.model_decorators import update_last_modified_date +from CosmoTech_Acceleration_Library.Modelops.core.common.graph_handler import GraphHandler from CosmoTech_Acceleration_Library.Modelops.core.utils.model_util import ModelUtil logger = logging.getLogger(__name__) -class ModelWriter(VersionedGraphHandler): +class ModelWriter(GraphHandler): """ Model Writer for cached data """ - @update_last_modified_date def create_twin(self, twin_type: str, properties: dict): """ Create a twin @@ -25,7 +23,6 @@ def create_twin(self, twin_type: str, properties: dict): logger.debug(f"Query: {create_query}") self.graph.query(create_query) - @update_last_modified_date def create_relationship(self, relationship_type: str, properties: dict): """ Create a relationship diff --git a/CosmoTech_Acceleration_Library/Modelops/core/tests/__init__.py b/CosmoTech_Acceleration_Library/Modelops/core/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/CosmoTech_Acceleration_Library/Modelops/core/tests/redis_test.py b/CosmoTech_Acceleration_Library/Modelops/core/tests/redis_test.py new file mode 100644 index 00000000..58a52c68 --- /dev/null +++ b/CosmoTech_Acceleration_Library/Modelops/core/tests/redis_test.py @@ -0,0 +1,181 @@ +import pytest +import redis +import csv +import os + +from redis.commands.graph import Node +from redis.commands.graph import Edge + +from ..io.model_reader import ModelReader +from ..io.model_writer import ModelWriter +from ..io.model_importer import ModelImporter +from ..io.model_exporter import ModelExporter + +GRAPH_NAME = 'test_graph' + + +def ping_redis(host, port): + r = redis.Redis(host=host, port=port) + return r.ping() + + +@pytest.fixture(scope='session') +def redis_service(docker_ip, docker_services): + """ensure redis is up and running""" + + host = docker_ip + port = docker_services.port_for("redis", 6379) + redis_client = redis.Redis(host=host, port=port) + + docker_services.wait_until_responsive(timeout=5, pause=0.2, check=redis_client.ping) + return {"host": host, "port": port} + + +@pytest.fixture +def redis_client(redis_service): + return redis.Redis(redis_service["host"], redis_service["port"]) + + +def test_redis(redis_client): + r = redis_client + assert r.ping() + + +@pytest.fixture +def redis_graph_setup(redis_client): + graphs = [] + + def _redis_graph_setup(name): + g = redis_client.graph(f'{name}') + graphs.append(g) + return g + + yield _redis_graph_setup + + graphs[0].delete() + + +def test_io_model_reader(redis_graph_setup, redis_service): + + g = redis_graph_setup(GRAPH_NAME) + node1 = Node(label="node", properties={"prop": "val"}) + g.add_node(node1) + node2 = Node(label="node", properties={"prop": "val"}) + g.add_node(node2) + rel1 = Edge(node1, "rel", node2, properties={"rel_prop": "val"}) + g.add_edge(rel1) + g.flush() + + mr = ModelReader(redis_service['host'], redis_service['port'], GRAPH_NAME) + assert mr.exists(g.name) + + # twin test + assert ['node'] == mr.get_twin_types() + + result = mr.get_twins_by_type('node').result_set + assert 2 == len(result) + assert node1.label == result[0][0].label + assert node1.properties == result[0][0].properties + assert node2.label == result[1][0].label + assert node2.properties == result[1][0].properties + + result = mr.get_twin_properties_by_type('node') + assert ['prop'] == result + + # rel test + assert ['rel'] == mr.get_relationship_types() + + result = mr.get_relationships_by_type('rel').result_set + assert 1 == len(result) + assert rel1.relation == result[0][2].relation + assert rel1.properties == result[0][2].properties + + result = mr.get_relationship_properties_by_type('rel') + assert ['source', 'target', 'rel_prop'] == result + + +def test_io_model_writer(redis_graph_setup, redis_service): + + mw = ModelWriter(redis_service['host'], redis_service['port'], GRAPH_NAME) + mw.create_twin('node', {'id': 'node_id1', 'prop': 'val'}) + mw.create_twin('node', {'id': 'node_id2', 'prop': 'val'}) + mw.create_relationship('rel', {'src': 'node_id1', 'dest': 'node_id2', 'prop': 'val'}) + + g = redis_graph_setup(GRAPH_NAME) + assert [['node']] == g.labels() + + result = g.query("MATCH (n:node) return n").result_set + assert 2 == len(result) + assert 'node' == result[0][0].label + assert {'id': 'node_id1', 'prop': 'val'} == result[0][0].properties + assert 'node' == result[1][0].label + assert {'id': 'node_id2', 'prop': 'val'} == result[1][0].properties + + result = g.query("MATCH ()-[r:rel]->() return r").result_set + assert 1 == len(result) + assert 'rel' == result[0][0].relation + assert {'src': 'node_id1', 'dest': 'node_id2', 'prop': 'val'} == result[0][0].properties + + +def test_io_model_importer(redis_client, redis_graph_setup, redis_service, tmpdir): + + # create csv for import + path_nodes = os.path.join(tmpdir, 'nodes.csv') + with open(path_nodes, 'w') as f: + csvw = csv.DictWriter(f, ['id', 'prop']) + csvw.writeheader() + csvw.writerow({'id': 'node_id1', 'prop': 'val'}) + csvw.writerow({'id': 'node_id2', 'prop': 'val'}) + + path_edges = os.path.join(tmpdir, 'edges.csv') + with open(path_edges, 'w') as f: + csvw = csv.DictWriter(f, ['src', 'dest', 'prop']) + csvw.writeheader() + csvw.writerow({'src': 'node_id1', 'dest': 'node_id2', 'prop': 'val'}) + + mi = ModelImporter(redis_service['host'], redis_service['port'], GRAPH_NAME) + mi.bulk_import([path_nodes], [path_edges]) + # double call to validate replacement management + mi.bulk_import([path_nodes], [path_edges]) + + g = redis_graph_setup(GRAPH_NAME) + result = g.query("MATCH (n:nodes) return n").result_set + assert 2 == len(result) + assert 'nodes' == result[0][0].label + assert {'id': 'node_id1', 'prop': 'val'} == result[0][0].properties + assert 'nodes' == result[1][0].label + assert {'id': 'node_id2', 'prop': 'val'} == result[1][0].properties + + result = g.query("MATCH ()-[r:edges]->() return r").result_set + assert 1 == len(result) + assert 'edges' == result[0][0].relation + assert {'prop': 'val'} == result[0][0].properties + + +def test_io_model_exporter(redis_graph_setup, redis_service, tmpdir): + + g = redis_graph_setup(GRAPH_NAME) + node1 = Node(label="node", properties={"id": "node1", "prop": "val"}) + g.add_node(node1) + node2 = Node(label="node", properties={"id": "node2", "prop": "val"}) + g.add_node(node2) + rel1 = Edge(node1, "rel", node2, properties={"rel_prop": "val"}) + g.add_edge(rel1) + g.flush() + + me = ModelExporter(redis_service['host'], redis_service['port'], GRAPH_NAME, export_dir=tmpdir) + me.export_all_data() + + assert ['rel.csv', 'node.csv'] == os.listdir(tmpdir) + + with open(os.path.join(tmpdir, 'node.csv')) as f: + csvr = csv.DictReader(f) + assert set(['id', 'prop']) == set(csvr.fieldnames) + rows = list(csvr) + assert 2 == len(rows) + + with open(os.path.join(tmpdir, 'rel.csv')) as f: + csvr = csv.DictReader(f) + assert set(['source', 'target', 'rel_prop']) == set(csvr.fieldnames) + rows = list(csvr) + assert 1 == len(rows) diff --git a/CosmoTech_Acceleration_Library/Modelops/core/utils/model_util.py b/CosmoTech_Acceleration_Library/Modelops/core/utils/model_util.py index 9ccd386b..1238a28a 100644 --- a/CosmoTech_Acceleration_Library/Modelops/core/utils/model_util.py +++ b/CosmoTech_Acceleration_Library/Modelops/core/utils/model_util.py @@ -69,8 +69,7 @@ def create_twin_query(twin_type: str, properties: dict) -> str: if ModelUtil.dt_id_key in properties: cypher_params = ModelUtil.dict_to_cypher_parameters(properties) return f"CREATE (:{twin_type} {cypher_params})" - raise Exception( - f"When you create a twin, you should define at least {ModelUtil.dt_id_key} properties ") + raise Exception(f"When you create a twin, you should define at least {ModelUtil.dt_id_key} properties ") @staticmethod def create_relationship_query(relationship_type: str, properties: dict) -> str: @@ -87,7 +86,8 @@ def create_relationship_query(relationship_type: str, properties: dict) -> str: f"AND m.{ModelUtil.dt_id_key} = '{properties.get(ModelUtil.dest_key)}' " \ f"CREATE (n)-[r:{relationship_type} {cypher_params}]->(m) RETURN r" raise Exception( - f"When you create a relationship, you should define at least {ModelUtil.src_key} and {ModelUtil.dest_key} properties ") + f"When you create a relationship, you should define at least {ModelUtil.src_key} and {ModelUtil.dest_key} properties " + ) @staticmethod def dict_to_json(obj: dict) -> str: @@ -148,16 +148,6 @@ def convert_str_to_datetime(date_str: str) -> datetime: date_time_obj = datetime.strptime(date_str, '%Y/%m/%d - %H:%M:%S') return date_time_obj - @staticmethod - def build_graph_version_name(graph_name: str, version: int) -> str: - """ - Build versioned graph name - :param graph_name: the graph name - :param version: the version - :return: the versioned graph name - """ - return graph_name + ":" + str(version) - @staticmethod def build_graph_key_pattern(graph_name: str) -> str: return graph_name + ":*" diff --git a/CosmoTech_Acceleration_Library/Modelops/core/utils/tests/model_util_test.py b/CosmoTech_Acceleration_Library/Modelops/core/utils/tests/model_util_test.py index 1990be4a..5f1c3445 100644 --- a/CosmoTech_Acceleration_Library/Modelops/core/utils/tests/model_util_test.py +++ b/CosmoTech_Acceleration_Library/Modelops/core/utils/tests/model_util_test.py @@ -81,8 +81,7 @@ def test_create_twin_query(self): def test_create_twin_query_Exception(self): twin_name = 'Twin_name' - self.assertRaises(Exception, - self.model_util.create_twin_query, twin_name, self.expected_simple_parameters) + self.assertRaises(Exception, self.model_util.create_twin_query, twin_name, self.expected_simple_parameters) def test_create_relationship_query(self): source_id = 'Node1' @@ -94,16 +93,8 @@ def test_create_relationship_query(self): def test_create_relationship_query_Exception(self): relation_name = 'Relation_Name' - self.assertRaises(Exception, - self.model_util.create_relationship_query, relation_name, self.expected_simple_parameters) - - def test_unjsonify_without_jsonstring(self): - new_value = self.model_util.unjsonify(self.relationship_simple_parameters) - self.assertEqual(self.relationship_simple_parameters, new_value) - - def test_unjsonify_with_jsonstring(self): - new_value = self.model_util.unjsonify(self.dict_with_simple_json_string) - self.assertEqual(self.relationship_simple_parameters, new_value) + self.assertRaises(Exception, self.model_util.create_relationship_query, relation_name, + self.expected_simple_parameters) if __name__ == '__main__': diff --git a/CosmoTech_Acceleration_Library/__init__.py b/CosmoTech_Acceleration_Library/__init__.py index 8a661a2c..a42e334e 100644 --- a/CosmoTech_Acceleration_Library/__init__.py +++ b/CosmoTech_Acceleration_Library/__init__.py @@ -1,4 +1,4 @@ # Copyright (c) Cosmo Tech corporation. # Licensed under the MIT license. -__version__ = '0.3.1' +__version__ = '0.4.0' diff --git a/tests/docker-compose.yml b/tests/docker-compose.yml new file mode 100644 index 00000000..772c71e3 --- /dev/null +++ b/tests/docker-compose.yml @@ -0,0 +1,5 @@ +services: + redis: + image: redis/redis-stack-server + ports: + - "6379:6379"