diff --git a/cid/cursor.py b/cid/cursor.py index 41e3de9..de6fbc4 100644 --- a/cid/cursor.py +++ b/cid/cursor.py @@ -1,6 +1,11 @@ +from django.conf import settings from .locals import get_cid +def _base_comment_formatter(cid): + return 'cid: {}'.format(cid) + + class CidCursorWrapper(object): """ A cursor wrapper that attempts to add a cid comment to each query @@ -24,10 +29,13 @@ def __exit__(self, type, value, traceback): self.close() def add_comment(self, sql): + cid_sql_formatter = getattr( + settings, 'CID_SQL_COMMENT_FORMATTER', _base_comment_formatter + ) cid = get_cid() if cid: cid = cid.replace('/*', '\/\*').replace('*/', '\*\/') - return "/* cid: {} */\n{}".format(cid, sql) + return "/* {} */\n{}".format(cid_sql_formatter(cid), sql) return sql # The following methods cannot be implemented in __getattr__, because the diff --git a/docs/settings.rst b/docs/settings.rst index 5bfb0f2..230c18c 100644 --- a/docs/settings.rst +++ b/docs/settings.rst @@ -16,3 +16,7 @@ Settings ``CID_GENERATE`` Tell the cid middleware to generate a correlation id if it doesn't already exist. Default value: ``False``. + + ``CID_SQL_COMMENT_FORMATTER`` + Function taking a cid as argument and returning the str that will be + added as a SQL comment. Default returned value: ``cid: MY_CID``. diff --git a/tests/test_cursor.py b/tests/test_cursor.py index d3187bb..256bea4 100644 --- a/tests/test_cursor.py +++ b/tests/test_cursor.py @@ -1,4 +1,5 @@ from django.test import TestCase +from django.test.utils import override_settings from mock import Mock, patch from cid.cursor import CidCursorWrapper @@ -21,6 +22,16 @@ def test_adds_comment(self, get_cid): self.cursor_wrapper.add_comment("SELECT 1;") ) + @override_settings(CID_SQL_COMMENT_FORMATTER=lambda cid: 'correlation_id={}'.format(cid)) + @patch('cid.cursor.get_cid') + def test_adds_comment_setting_overriden(self, get_cid): + get_cid.return_value = 'testing-cursor' + expected = "/* correlation_id=testing-cursor */\nSELECT 1;" + self.assertEqual( + expected, + self.cursor_wrapper.add_comment("SELECT 1;") + ) + @patch('cid.cursor.get_cid') def test_no_comment_when_cid_is_none(self, get_cid): get_cid.return_value = None