Skip to content

Commit

Permalink
Feat: import external serializer when they were previously defined
Browse files Browse the repository at this point in the history
  • Loading branch information
biolds committed Feb 5, 2024
1 parent ab3656e commit d2a7ac2
Showing 1 changed file with 34 additions and 6 deletions.
40 changes: 34 additions & 6 deletions django_typomatic/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
__field_mappings = dict()
# Custom field_name to TS Type overrides
__mapping_overrides = dict()
# TS Type imports import
__imports = dict()


def ts_field(ts_type: str, context='default'):
Expand Down Expand Up @@ -302,11 +304,20 @@ def __get_nested_serializer_field(context, enum_choices, enum_values, enum_keys,
is_serializer_type = True

if is_external_serializer and return_type not in __serializers.get(context, []):
# TODO import external interface, not duplicate
# Include external Interface
ts_interface(context=context)(return_type)
# For duplicate interface, set not exported
setattr(return_type, '__exported__', False)
# Import the serializer if it was previously generated
for _context, serializers in __serializers.items():
if return_type in serializers:
imports = __imports.get(context, {})
type_imports = imports.get(_context, set())
type_imports.add(return_type)
imports[_context] = type_imports
__imports[context] = imports
break
else:
# Include external Interface
ts_interface(context=context)(return_type)
# For duplicate interface, set not exported
setattr(return_type, '__exported__', False)

if is_serializer_type:
ts_type = __get_trimmed_name(return_type.__name__, trim_serializer_output)
Expand Down Expand Up @@ -423,6 +434,22 @@ def __get_ts_interface(serializer, context, trim_serializer_output, camelize, en
return f'{"export " if exported else ""}interface {name} {{\n{collapsed_fields}\n}}\n\n'


def __generate_imports(context, trim_serializer_output):
imports_str = ''
if context in __imports:
for package, serializers in __imports[context].items():
names = []
for serializer in serializers:
name = __get_trimmed_name(
serializer.__name__, trim_serializer_output)
names.append(name)

imports_str += "import type { %s } from '../%s';\n" % (
', '.join(names), package)
imports_str += '\n'
return imports_str


def __generate_interfaces(context, trim_serializer_output, camelize, enum_choices, enum_values,
enum_keys, annotations):
if context not in __serializers:
Expand Down Expand Up @@ -552,13 +579,14 @@ def generate_ts(output_path, context='default', trim_serializer_output=False, ca
output_path.parent.mkdir(exist_ok=True, parents=True)

with open(output_path, 'w') as output_file:
imports = __generate_imports(context, trim_serializer_output)
interfaces = __generate_interfaces(context, trim_serializer_output, camelize, enum_choices,
enum_values, enum_keys, annotations)
enums = []
if enum_choices or enum_values or enum_keys:
enums = __generate_enums(context, enum_choices, enum_values, enum_keys)
enums_string = __remove_duplicate_enums(enums)
output_file.write(enums_string + ''.join(interfaces))
output_file.write(imports + enums_string + ''.join(interfaces))


def get_ts(context='default', trim_serializer_output=False, camelize=False, enum_choices=False,
Expand Down

0 comments on commit d2a7ac2

Please sign in to comment.