diff --git a/mcpgateway/main.py b/mcpgateway/main.py index 4bde2f993..43679835a 100644 --- a/mcpgateway/main.py +++ b/mcpgateway/main.py @@ -3628,7 +3628,17 @@ async def delete_gateway(gateway_id: str, db: Session = Depends(get_db), user=De logger.debug(f"User '{user}' requested deletion of gateway {gateway_id}") try: user_email = user.get("email") if isinstance(user, dict) else str(user) + current = await gateway_service.get_gateway(db, gateway_id) + has_resources = bool(current.capabilities.get("resources")) await gateway_service.delete_gateway(db, gateway_id, user_email=user_email) + + # If the gateway had resources and was successfully deleted, invalidate + # the whole resource cache. This is needed since the cache holds both + # individual resources and the full listing which will also need to be + # invalidated. + if has_resources: + await invalidate_resource_cache() + return {"status": "success", "message": f"Gateway {gateway_id} deleted"} except PermissionError as e: raise HTTPException(status_code=403, detail=str(e)) diff --git a/tests/unit/mcpgateway/test_main.py b/tests/unit/mcpgateway/test_main.py index 3a6d8a295..af1b6ac7d 100644 --- a/tests/unit/mcpgateway/test_main.py +++ b/tests/unit/mcpgateway/test_main.py @@ -147,7 +147,6 @@ def camel_to_snake_tool(d: dict) -> dict: "url": "http://example.com", "description": "A test gateway", "transport": "SSE", - "auth_type": "none", "created_at": "2023-01-01T00:00:00+00:00", "updated_at": "2023-01-01T00:00:00+00:00", "enabled": True, @@ -947,14 +946,30 @@ def test_update_gateway_endpoint(self, mock_update, test_client, auth_headers): mock_update.assert_called_once() @patch("mcpgateway.main.gateway_service.delete_gateway") - def test_delete_gateway_endpoint(self, mock_delete, test_client, auth_headers): - """Test deleting a gateway.""" + @patch("mcpgateway.main.gateway_service.get_gateway") + def test_delete_gateway_endpoint_no_resources(self, mock_get, mock_delete, test_client, auth_headers): + """Test deleting a gateway that doesn't have resources.""" mock_delete.return_value = None + mock_get.return_value.capabilities = {} response = test_client.delete("/gateways/1", headers=auth_headers) assert response.status_code == 200 assert response.json()["status"] == "success" mock_delete.assert_called_once() + @patch("mcpgateway.main.gateway_service.delete_gateway") + @patch("mcpgateway.main.gateway_service.get_gateway") + @patch("mcpgateway.main.invalidate_resource_cache") + def test_delete_gateway_endpoint_with_resources(self, mock_invalidate_cache, mock_get, mock_delete, test_client, auth_headers): + """Test deleting a gateway that does have resources.""" + mock_delete.return_value = None + mock_get.return_value = MagicMock() + mock_get.return_value.capabilities = {"resources": {"some": "thing"}} + response = test_client.delete("/gateways/1", headers=auth_headers) + assert response.status_code == 200 + assert response.json()["status"] == "success" + mock_delete.assert_called_once() + mock_invalidate_cache.assert_called_once() + @patch("mcpgateway.main.gateway_service.toggle_gateway_status") def test_toggle_gateway_status(self, mock_toggle, test_client, auth_headers): """Test toggling gateway active/inactive status.""" @@ -1006,9 +1021,11 @@ def test_update_gateway_endpoint(self, mock_update, test_client, auth_headers): mock_update.assert_called_once() @patch("mcpgateway.main.gateway_service.delete_gateway") - def test_delete_gateway_endpoint(self, mock_delete, test_client, auth_headers): + @patch("mcpgateway.main.gateway_service.get_gateway") + def test_delete_gateway_endpoint(self, mock_get, mock_delete, test_client, auth_headers): """Test deleting a gateway.""" mock_delete.return_value = None + mock_get.return_value.capabilities = {} response = test_client.delete("/gateways/1", headers=auth_headers) assert response.status_code == 200 assert response.json()["status"] == "success" @@ -1220,7 +1237,7 @@ def test_sse_endpoint(self, mock_transport_class, mock_respond, mock_add_session mock_transport.create_sse_response.return_value = MagicMock() mock_transport_class.return_value = mock_transport - response = test_client.get("/sse", headers=auth_headers) + test_client.get("/sse", headers=auth_headers) # Note: This test may need adjustment based on actual SSE implementation # The exact assertion will depend on how SSE responses are structured