diff --git a/mfr/core/utils.py b/mfr/core/utils.py index b248b11cf..b6d74d0b4 100644 --- a/mfr/core/utils.py +++ b/mfr/core/utils.py @@ -1,3 +1,4 @@ +import pkg_resources from stevedore import driver from mfr.core import exceptions @@ -97,6 +98,46 @@ def make_renderer(name, metadata, file_path, url, assets_url, export_url): } ) +def get_renderer_name(name): + """ Return the name of the renderer used for a certain file extension. + + :param str name: The name of the extension to get the renderer name for. (.jpg, .docx, etc) + + :rtype : `str` + """ + + # This can give back empty tuples + try: + entry_attrs = pkg_resources.iter_entry_points(group='mfr.renderers', name=name) + + # ep.attrs is a tuple of attributes. There should only ever be one or `None`. + # None case occurs when trying to render an unsupported file type + # entry_attrs is an iterable object, so we turn into a list to index it + return list(entry_attrs)[0].attrs[0] + + # This means the file type is not supported. Just return the blank string so `make_renderers` can + # log a real exception with all the variables and names it has + except IndexError: + return '' + +def get_exporter_name(name): + """ Return the name of the exporter used for a certain file extension. + + :param str name: The name of the extension to get the exporter name for. (.jpg, .docx, etc) + + :rtype : `str` + """ + + # `make_renderer` should have already caught if an extension doesn't exist. + + # should be a list of length one, since we don't have multiple entrypoints per group + entry_attrs = pkg_resources.iter_entry_points(group='mfr.exporters', name=name) + + # ep.attrs is a tuple of attributes. There should only ever be one or `None`. + # For our case however there shouldn't be `None` + return list(entry_attrs)[0].attrs[0] + + def sizeof_fmt(num, suffix='B'): for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: if abs(num) < 1000.0: diff --git a/mfr/server/handlers/export.py b/mfr/server/handlers/export.py index a9ba84f3a..25dffe629 100644 --- a/mfr/server/handlers/export.py +++ b/mfr/server/handlers/export.py @@ -29,11 +29,12 @@ async def prepare(self): " appropriate extension") # TODO: do we need to catch exceptions for decoding? self.format = format[0].decode('utf-8') + self.exporter_name = utils.get_exporter_name(self.metadata.ext) self.cache_file_id = '{}.{}'.format(self.metadata.unique_key, self.format) self.cache_file_path = await self.cache_provider.validate_path( - '/export/{}'.format(self.cache_file_id) + '/export/{}.{}'.format(self.cache_file_id, self.exporter_name) ) self.source_file_path = await self.local_cache_provider.validate_path( '/export/{}'.format(self.source_file_id) diff --git a/mfr/server/handlers/render.py b/mfr/server/handlers/render.py index caf82770f..abdef44eb 100644 --- a/mfr/server/handlers/render.py +++ b/mfr/server/handlers/render.py @@ -24,9 +24,11 @@ async def prepare(self): await super().prepare() + self.renderer_name = utils.get_renderer_name(self.metadata.ext) + self.cache_file_id = self.metadata.unique_key self.cache_file_path = await self.cache_provider.validate_path( - '/render/{}'.format(self.cache_file_id) + '/render/{}.{}'.format(self.cache_file_id, self.renderer_name) ) self.source_file_path = await self.local_cache_provider.validate_path( '/render/{}'.format(self.source_file_id) diff --git a/tests/core/test_utils.py b/tests/core/test_utils.py new file mode 100644 index 000000000..957de273f --- /dev/null +++ b/tests/core/test_utils.py @@ -0,0 +1,39 @@ +import pytest +import pkg_resources + +from mfr.core import utils as mfr_utils + + +class TestGetRendererName: + + def test_get_renderer_name_explicit_assertions(self): + assert mfr_utils.get_renderer_name('.jpg') == 'ImageRenderer' + assert mfr_utils.get_renderer_name('.txt') == 'CodePygmentsRenderer' + assert mfr_utils.get_renderer_name('.xlsx') == 'TabularRenderer' + assert mfr_utils.get_renderer_name('.odt') == 'UnoconvRenderer' + assert mfr_utils.get_renderer_name('.pdf') == 'PdfRenderer' + + def test_get_renderer_name(self): + entry_points = pkg_resources.iter_entry_points(group='mfr.renderers') + for ep in entry_points: + expected = ep.attrs[0] + assert mfr_utils.get_renderer_name(ep.name) == expected + + def test_get_renderer_name_no_entry_point(self): + assert mfr_utils.get_renderer_name('jpg') == '' + +class TestGetExporterName: + + def test_get_exporter_name_explicit_assertions(self): + assert mfr_utils.get_exporter_name('.jpg') == 'ImageExporter' + assert mfr_utils.get_exporter_name('.odt') == 'UnoconvExporter' + + def test_get_exporter_name(self): + entry_points = pkg_resources.iter_entry_points(group='mfr.exporters') + for ep in entry_points: + expected = ep.attrs[0] + assert mfr_utils.get_exporter_name(ep.name) == expected + + def test_get_exporter_name_no_entry_point(self): + with pytest.raises(IndexError): + mfr_utils.get_exporter_name('jpg')