diff --git a/django_typomatic/__init__.py b/django_typomatic/__init__.py index 70ddea9..cd6e8ae 100644 --- a/django_typomatic/__init__.py +++ b/django_typomatic/__init__.py @@ -5,6 +5,7 @@ from rest_framework.serializers import BaseSerializer from rest_framework.fields import empty +from django.db.models import OneToOneField from django.db.models.enums import Choices import inspect @@ -208,6 +209,18 @@ def __map_choices_to_enum_keys_by_values(enum_name, choices): return choices_enum +def __is_known_serializer_type(serializer_type, context): + for _context, serializers in __serializers.items(): + if serializer_type in serializers: + imports = __imports.get(context, {}) + type_imports = imports.get(_context, set()) + type_imports.add(serializer_type) + imports[_context] = type_imports + __imports[context] = imports + return True + return False + + def __process_field(field_name, field, context, serializer, trim_serializer_output, camelize, enum_choices, enum_values, enum_keys): ''' @@ -215,9 +228,15 @@ def __process_field(field_name, field, context, serializer, trim_serializer_outp ''' # For PrimaryKeyRelatedField, set field_type to the type of the primary key # on the related model - if isinstance(field, serializers.PrimaryKeyRelatedField): + if isinstance(field, serializers.PrimaryKeyRelatedField) and field.queryset: is_many = False - field_type = type(field.queryset.model._meta.pk) + + target_field = field.queryset.model._meta.pk + while isinstance(target_field, OneToOneField): + # Recurse into the parent model the field is inheriting from + target_field = target_field.model._meta.pk.target_field + + field_type = type(target_field) elif hasattr(field, 'child'): is_many = True field_type = type(field.child) @@ -228,7 +247,7 @@ def __process_field(field_name, field, context, serializer, trim_serializer_outp is_many = False field_type = type(field) - if field_type in __serializers[context]: + if field_type in __serializers[context] or __is_known_serializer_type(field_type, context): ts_type = __get_trimmed_name( field_type.__name__, trim_serializer_output) elif field_type in __field_mappings[context]: @@ -237,6 +256,8 @@ def __process_field(field_name, field, context, serializer, trim_serializer_outp and field_name in __mapping_overrides[context][serializer]: ts_type = __mapping_overrides[context][serializer].get( field_name, 'any') + elif field_type == serializers.PrimaryKeyRelatedField: + ts_type = "number" elif hasattr(field, 'choice_strings_to_values') and enum_choices: ts_type = f"{''.join(x.title() for x in field_name.split('_'))}ChoiceEnum" elif hasattr(field, 'choice_strings_to_values') and enum_choices and enum_values \ @@ -308,15 +329,7 @@ def __get_nested_serializer_field(context, enum_choices, enum_values, enum_keys, if is_external_serializer and return_type not in __serializers.get(context, []): # Import the serializer if it was previously generated - for _context, serializers in __serializers.items(): - if return_type in serializers: - imports = __imports.get(context, {}) - type_imports = imports.get(_context, set()) - type_imports.add(return_type) - imports[_context] = type_imports - __imports[context] = imports - break - else: + if not __is_known_serializer_type(return_type, context): # Include external Interface ts_interface(context=context)(return_type) # For duplicate interface, set not exported @@ -582,13 +595,13 @@ def generate_ts(output_path, context='default', trim_serializer_output=False, ca output_path.parent.mkdir(exist_ok=True, parents=True) with open(output_path, 'w') as output_file: - imports = __generate_imports(context, trim_serializer_output) interfaces = __generate_interfaces(context, trim_serializer_output, camelize, enum_choices, enum_values, enum_keys, annotations) enums = [] if enum_choices or enum_values or enum_keys: enums = __generate_enums(context, enum_choices, enum_values, enum_keys) enums_string = __remove_duplicate_enums(enums) + imports = __generate_imports(context, trim_serializer_output) output_file.write(imports + enums_string + ''.join(interfaces)) diff --git a/django_typomatic/management/commands/generate_ts.py b/django_typomatic/management/commands/generate_ts.py index 159ab49..af781fb 100644 --- a/django_typomatic/management/commands/generate_ts.py +++ b/django_typomatic/management/commands/generate_ts.py @@ -1,4 +1,5 @@ import inspect +from importlib import import_module from pathlib import Path from django.apps import apps @@ -96,11 +97,11 @@ def add_arguments(self, parser): @staticmethod def _get_app_serializers(app_name): serializers = [] - modules = sys.modules.get(app_name, None) + modules = import_module(app_name) possibly_modules = filter(lambda name: not name.startswith('_'), dir(modules)) for module_name in possibly_modules: - module = sys.modules.get(f'{app_name}.{module_name}', None) + module = import_module(f'{app_name}.{module_name}') possibly_serializers = filter(lambda name: not name.startswith('_'), dir(module)) for serializer_class_name in possibly_serializers: @@ -119,7 +120,7 @@ def _get_app_serializers(app_name): return serializers def _generate_ts(self, module_name, serializer_name, output, **options): - module = sys.modules.get(module_name, None) + module = import_module(module_name) if not module: self.stdout.write(f'Module {module_name} not found, skip', self.style.WARNING) diff --git a/django_typomatic/mappings.py b/django_typomatic/mappings.py index ffed54c..127349b 100644 --- a/django_typomatic/mappings.py +++ b/django_typomatic/mappings.py @@ -2,9 +2,9 @@ from rest_framework import serializers mappings = { - fields.AutoField: 'string', - fields.BigAutoField: 'string', - fields.BigIntegerField : 'string', + fields.AutoField: 'number', + fields.BigAutoField: 'number', + fields.BigIntegerField : 'number', fields.BinaryField: 'string', fields.BooleanField: 'boolean', fields.CharField: 'string',