diff --git a/tests/test_post.py b/tests/test_post.py index e3848cf4..af9e5832 100644 --- a/tests/test_post.py +++ b/tests/test_post.py @@ -18,48 +18,12 @@ def patch_mock_api_helper(mocker) -> None: ) -class MockFieldData: - def __init__(self, solver_data, field_info): - self._session_data = solver_data - self._request_to_serve = {"surf": [], "scalar": [], "vector": []} - self._field_info = field_info +class MockFieldTransaction: + def __init__(self, session_data, field_request): + self.service = session_data + self.fields_request = field_request - def get_surface_data( - self, - surface_name: str, - data_type: Union[SurfaceDataType, int], - overset_mesh: Optional[bool] = False, - ) -> Dict: - surfaces_info = self._field_info().get_surfaces_info() - surface_ids = surfaces_info[surface_name]["surface_id"] - self._request_to_serve["surf"].append( - ( - surface_ids, - overset_mesh, - data_type == SurfaceDataType.Vertices, - data_type == SurfaceDataType.FacesConnectivity, - data_type == SurfaceDataType.FacesCentroid, - data_type == SurfaceDataType.FacesNormal, - ) - ) - enum_to_field_name = { - SurfaceDataType.FacesConnectivity: "faces", - SurfaceDataType.Vertices: "vertices", - SurfaceDataType.FacesCentroid: "centroid", - SurfaceDataType.FacesNormal: "face-normal", - } - - tag_id = 0 - if overset_mesh: - tag_id = self._payloadTags[FieldDataProtoModule.PayloadTag.OVERSET_MESH] - return { - surface_id: self._session_data["fields"][tag_id][surface_id][ - enum_to_field_name[data_type] - ] - for surface_id in surface_ids - } - - def add_get_surfaces_request( + def add_surfaces_request( self, surface_ids: List[int], overset_mesh: bool = False, @@ -68,7 +32,7 @@ def add_get_surfaces_request( provide_faces_centroid=False, provide_faces_normal=False, ) -> None: - self._request_to_serve["surf"].append( + self.fields_request["surf"].append( ( surface_ids, overset_mesh, @@ -79,27 +43,27 @@ def add_get_surfaces_request( ) ) - def add_get_scalar_fields_request( + def add_scalar_fields_request( self, surface_ids: List[int], field_name: str, node_value: Optional[bool] = True, boundary_value: Optional[bool] = False, ) -> None: - self._request_to_serve["scalar"].append( + self.fields_request["scalar"].append( (surface_ids, field_name, node_value, boundary_value) ) - def add_get_vector_fields_request( + def add_vector_fields_request( self, surface_ids: List[int], - vector_field: Optional[str] = "velocity", + field_name: str, ) -> None: - self._request_to_serve["vector"].append((surface_ids, vector_field)) + self.fields_request["vector"].append((surface_ids, field_name)) def get_fields(self) -> Dict[int, Dict]: fields = {} - for request_type, requests in self._request_to_serve.items(): + for request_type, requests in self.fields_request.items(): for request in requests: if request_type == "surf": tag_id = 0 @@ -118,12 +82,55 @@ def get_fields(self) -> Dict[int, Dict]: surface_requests = field_requests.get(surf_id) if not surface_requests: surface_requests = field_requests[surf_id] = {} - surface_requests.update( - self._session_data["fields"][tag_id][surf_id] - ) + surface_requests.update(self.service["fields"][tag_id][surf_id]) return fields +class MockFieldData: + def __init__(self, solver_data, field_info): + self._session_data = solver_data + self._request_to_serve = {"surf": [], "scalar": [], "vector": []} + self._field_info = field_info + + def new_transaction(self): + return MockFieldTransaction(self._session_data, self._request_to_serve) + + def get_surface_data( + self, + surface_name: str, + data_type: Union[SurfaceDataType, int], + overset_mesh: Optional[bool] = False, + ) -> Dict: + surfaces_info = self._field_info().get_surfaces_info() + surface_ids = surfaces_info[surface_name]["surface_id"] + self._request_to_serve["surf"].append( + ( + surface_ids, + overset_mesh, + data_type == SurfaceDataType.Vertices, + data_type == SurfaceDataType.FacesConnectivity, + data_type == SurfaceDataType.FacesCentroid, + data_type == SurfaceDataType.FacesNormal, + ) + ) + enum_to_field_name = { + SurfaceDataType.FacesConnectivity: "faces", + SurfaceDataType.Vertices: "vertices", + SurfaceDataType.FacesCentroid: "centroid", + SurfaceDataType.FacesNormal: "face-normal", + } + + tag_id = 0 + if overset_mesh: + tag_id = self._payloadTags[FieldDataProtoModule.PayloadTag.OVERSET_MESH] + return { + surface_id: self._session_data["fields"][tag_id][surface_id][ + enum_to_field_name[data_type] + ] + for surface_id in surface_ids + } + + class MockFieldInfo: def __init__(self, solver_data): self._session_data = solver_data @@ -186,16 +193,18 @@ def test_field_api(): # Get vertices vertices_data = field_data.get_surface_data("wall", SurfaceDataType.Vertices) + transaction = field_data.new_transaction() + # Get multiple fields - field_data.add_get_surfaces_request( + transaction.add_surfaces_request( surfaces_id[:1], provide_vertices=True, provide_faces_centroid=True, provide_faces=False, ) - field_data.add_get_scalar_fields_request(surfaces_id[:1], "temperature", True) - field_data.add_get_scalar_fields_request(surfaces_id[:1], "temperature", False) - fields = field_data.get_fields() + transaction.add_scalar_fields_request(surfaces_id[:1], "temperature", True) + transaction.add_scalar_fields_request(surfaces_id[:1], "temperature", False) + fields = transaction.get_fields() surface_tag = 0 vertices = fields[surface_tag][surfaces_id[0]]["vertices"]