Skip to content

Commit

Permalink
add new approach to resolve output LexRef; add resolve Union types by…
Browse files Browse the repository at this point in the history
… py_type field; migrate BlobRef class to pydantic
  • Loading branch information
MarshalX committed Aug 26, 2023
1 parent b7102a3 commit 5ae9acf
Show file tree
Hide file tree
Showing 61 changed files with 719 additions and 454 deletions.
1 change: 1 addition & 0 deletions atproto/cid/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

# CID = _CID

Check failure on line 5 in atproto/cid/__init__.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

atproto/cid/__init__.py:2:1: I001 Import block is un-sorted or un-formatted

Check failure on line 5 in atproto/cid/__init__.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ERA001)

atproto/cid/__init__.py:5:1: ERA001 Found commented-out code

Check failure on line 5 in atproto/cid/__init__.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (I001)

atproto/cid/__init__.py:2:1: I001 Import block is un-sorted or un-formatted

Check failure on line 5 in atproto/cid/__init__.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ERA001)

atproto/cid/__init__.py:5:1: ERA001 Found commented-out code


class CID(BaseModel):
def encode(self, *_, **__):
...
Expand Down
33 changes: 33 additions & 0 deletions atproto/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import typing as t
from pathlib import Path

from atproto.exceptions import InvalidNsidError
from atproto.nsid import NSID

_DISCLAIMER_LINES = [
Expand Down Expand Up @@ -108,3 +109,35 @@ def capitalize_first_symbol(string: str) -> str:
return ''.join(chars)

return string


def get_def_model_name(method_name: str) -> str:
return f'{capitalize_first_symbol(method_name)}'


def get_model_path(nsid: NSID, method_name: str) -> str:
return f'models.{get_import_path(nsid)}.{get_def_model_name(method_name)}'


def _resolve_nsid_ref(nsid: NSID, ref: str, *, local: bool = False) -> t.Tuple[str, str]:
"""Returns the path to the model and model name"""
if '#' in ref:
ref_nsid_str, def_name = ref.split('#', 1)
def_name = get_def_model_name(def_name)

try:
ref_nsid = NSID.from_str(ref_nsid_str)
return get_model_path(ref_nsid, def_name), def_name
except InvalidNsidError:
if local:
return def_name, def_name
return get_model_path(nsid, def_name), def_name
else:
ref_nsid = NSID.from_str(ref)
def_name = get_def_model_name(nsid.name)

if local:
return def_name, def_name

# FIXME(MarshalX): Is it works well? ;d
return get_model_path(ref_nsid, 'Main'), def_name
9 changes: 0 additions & 9 deletions atproto/codegen/models/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,18 +112,9 @@ def build_record_models() -> BuiltRecordModels:
return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_RECORDS)


BuiltRefsModels = t.Dict[NSID, t.Dict[str, t.Union[models.LexXrpcQuery, models.LexXrpcProcedure]]]


def build_refs_models() -> BuiltRefsModels:
_LEX_DEF_TYPES_FOR_REFS = {models.LexDefinitionType.QUERY, models.LexDefinitionType.PROCEDURE}
return _build_nsid_to_defs_map(lexicon_parse_dir(lexicon_dir.get()), _LEX_DEF_TYPES_FOR_REFS)


if __name__ == '__main__':
build_params_models()
build_data_models()
build_response_models()
build_def_models()
build_record_models()
build_refs_models()
102 changes: 22 additions & 80 deletions atproto/codegen/models/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,18 @@
INPUT_MODEL,
OUTPUT_MODEL,
PARAMS_MODEL,
_resolve_nsid_ref,
append_code,
capitalize_first_symbol,
format_code,
gen_description_by_camel_case_name,
get_def_model_name,
get_file_path_parts,
get_import_path,
join_code,
write_code,
)
from atproto.codegen import get_code_intent as _
from atproto.codegen.models import builder
from atproto.exceptions import InvalidNsidError
from atproto.lexicon import models
from atproto.nsid import NSID

Expand All @@ -34,14 +34,6 @@ class ModelType(Enum):
RECORD = 'Record'


def get_def_model_name(method_name: str) -> str:
return f'{capitalize_first_symbol(method_name)}'


def get_model_path(nsid: NSID, method_name: str) -> str:
return f'models.{get_import_path(nsid)}.{get_def_model_name(method_name)}'


def save_code(nsid: NSID, code: str) -> None:
path_to_file = _MODELS_OUTPUT_DIR.joinpath(*get_file_path_parts(nsid))
write_code(_MODELS_OUTPUT_DIR.joinpath(path_to_file), code)
Expand All @@ -58,6 +50,8 @@ def _get_model_imports() -> str:
'import typing as t',
'',
'import typing_extensions as te',
'from pydantic import Field',
'',
'if t.TYPE_CHECKING:',
f'{_(1)}from atproto.xrpc_client import models',
f'{_(1)}from atproto.xrpc_client.models.blob_ref import BlobRef',
Expand Down Expand Up @@ -107,9 +101,10 @@ def _get_model_class_def(name: str, model_type: ModelType) -> str:
}


def _get_optional_typehint(type_hint, *, optional: bool) -> str:
def _get_optional_typehint(type_hint, *, optional: bool, with_value: bool = True) -> str:
value = ' = None' if with_value else ''
if optional:
return f't.Optional[{type_hint}] = None'
return f't.Optional[{type_hint}]{value}'
return type_hint


Expand All @@ -118,30 +113,6 @@ def _get_ref_typehint(nsid: NSID, field_type_def, *, optional: bool) -> str:
return _get_optional_typehint(f"'{model_path}'", optional=optional)


def _resolve_nsid_ref(nsid: NSID, ref: str, *, local: bool = False) -> t.Tuple[str, str]:
"""Returns path to the model and model name"""
if '#' in ref:
ref_nsid_str, def_name = ref.split('#', 1)
def_name = get_def_model_name(def_name)

try:
ref_nsid = NSID.from_str(ref_nsid_str)
return get_model_path(ref_nsid, def_name), def_name
except InvalidNsidError:
if local:
return def_name, def_name
return get_model_path(nsid, def_name), def_name
else:
ref_nsid = NSID.from_str(ref)
def_name = get_def_model_name(nsid.name)

if local:
return def_name, def_name

# FIXME(MarshalX): Is it works well? ;d
return get_model_path(ref_nsid, 'Main'), def_name


def _get_ref_union_typehint(nsid: NSID, field_type_def, *, optional: bool) -> str:
def_names = []
for ref in field_type_def.refs:
Expand All @@ -154,18 +125,21 @@ def _get_ref_union_typehint(nsid: NSID, field_type_def, *, optional: bool) -> st
# ref: https://github.com/bluesky-social/atproto/blob/b01e47b61730d05a780f7a42667b91ccaa192e8e/packages/lex-cli/src/codegen/lex-gen.ts#L325
# grep by "{$type: string; [k: string]: unknown}" string
# TODO(MarshalX): use 'base.UnknownDict' and convert to DotDict
def_names.append('t.Dict[str, t.Any]')
# def_names.append('t.Dict[str, t.Any]') # FIXME(MarshalX): support pydantic

Check failure on line 128 in atproto/codegen/models/generator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ERA001)

atproto/codegen/models/generator.py:128:5: ERA001 Found commented-out code

Check failure on line 128 in atproto/codegen/models/generator.py

View workflow job for this annotation

GitHub Actions / ruff

Ruff (ERA001)

atproto/codegen/models/generator.py:128:5: ERA001 Found commented-out code

def_names = ', '.join([f"'{name}'" for name in def_names])
return _get_optional_typehint(f't.Union[{def_names}]', optional=optional)
def_field_meta = 'Field(default=None, discriminator="py_type")' if optional else 'Field(discriminator="py_type")'

annotated_union = f'te.Annotated[t.Union[{def_names}], {def_field_meta}]'
return _get_optional_typehint(annotated_union, optional=optional)


def _get_model_field_typehint(nsid: NSID, field_name: str, field_type_def, *, optional: bool) -> str:
field_type = type(field_type_def)

if field_type == models.LexUnknown:
# unknown type is a generic response with records or any not described type in the lexicon. for example didDoc
return _get_optional_typehint("'base.UnknownDict'", optional=optional)
# unknown type is a generic response with records or any not described type in the lexicon. for example, didDoc
return _get_optional_typehint("'unknown_type.UnknownRecordTypePydantic'", optional=optional)

type_hint = _LEXICON_TYPE_TO_PRIMITIVE_TYPEHINT.get(field_type)
if type_hint:
Expand Down Expand Up @@ -258,21 +232,6 @@ def _get_model(nsid: NSID, lex_object: t.Union[models.LexObject, models.LexXrpcP
return join_code(fields)


def _get_model_ref(nsid: NSID, ref: models.LexRef) -> str:
# FIXME(MarshalX): "local=True" Is it works well? ;d
ref_class, _ = _resolve_nsid_ref(nsid, ref.ref, local=True)

# "Ref" suffix required to fix name collisions from different namespaces
lines = [
f'#: {OUTPUT_MODEL} reference to :obj:`{ref_class}` model.',
f'{OUTPUT_MODEL}Ref = "{ref_class}"', # FIXME(MarshalX): pydantic support
'',
'',
]

return join_code(lines)


def _get_model_raw_data(name: str) -> str:
lines = [f'#: {name} raw data type.', f'{name}: te.TypeAlias = bytes\n\n']
return join_code(lines)
Expand Down Expand Up @@ -327,8 +286,7 @@ def _generate_def_model(nsid: NSID, def_name: str, def_model: models.LexObject,
if def_name == 'main':
def_type = str(nsid)

lines.append(f"{_(1)}_type: str = '{def_type}'")

lines.append(f"{_(1)}py_type: te.Literal['{def_type}'] = Field(default='{def_type}', alias='$type')")
lines.append('')

return join_code(lines)
Expand Down Expand Up @@ -431,11 +389,13 @@ def _generate_record_type_database(lex_db: builder.BuiltRecordModels) -> None:
unknown_record_type_hint_lines = [
'import typing as t',
'import typing_extensions as te',
'from pydantic import Field',
'if t.TYPE_CHECKING:',
f'{_(4)}from atproto.xrpc_client import models',
'',
'UnknownRecordType: te.TypeAlias = t.Union[',
]
unknown_record_type_pydantic_lines = ['UnknownRecordTypePydantic = te.Annotated[t.Union[']

for nsid, defs in lex_db.items():
_save_code_import_if_not_exist(nsid)
Expand All @@ -451,34 +411,19 @@ def _generate_record_type_database(lex_db: builder.BuiltRecordModels) -> None:

type_conversion_lines.append(f"'{record_type}': {path_to_class},")
unknown_record_type_hint_lines.append(f"{_(4)}'{path_to_class}',")
unknown_record_type_pydantic_lines.append(f"{_(4)}'{path_to_class}',")

type_conversion_lines.append('}')

unknown_record_type_hint_lines.append(']')
unknown_record_type_pydantic_lines.append('], Field(discriminator="py_type")]')

unknown_record_type_hint_lines.extend(unknown_record_type_pydantic_lines)

write_code(_MODELS_OUTPUT_DIR.joinpath('type_conversion.py'), join_code(type_conversion_lines))
write_code(_MODELS_OUTPUT_DIR.joinpath('unknown_type.py'), join_code(unknown_record_type_hint_lines))


def _generate_ref_models(lex_db: builder.BuiltRefsModels) -> None:
for nsid, defs in lex_db.items():
definition = defs['main']
if (
hasattr(definition, 'input')
and definition.input
and definition.input.schema
and isinstance(definition.input.schema, models.LexRef)
):
save_code_part(nsid, _get_model_ref(nsid, definition.input.schema))

if (
hasattr(definition, 'output')
and definition.output
and definition.output.schema
and isinstance(definition.output.schema, models.LexRef)
):
save_code_part(nsid, _get_model_ref(nsid, definition.output.schema))


def _generate_init_files(root_package_path: Path) -> None:
# One of the ways that I tried. Doesn't work well due to circular imports
for root, dirs, files in os.walk(root_package_path):
Expand Down Expand Up @@ -606,9 +551,6 @@ def generate_models(lexicon_dir: t.Optional[Path] = None, output_dir: t.Optional
_generate_record_models(builder.build_record_models())
_generate_record_type_database(builder.build_record_models())

# refs should be generated at the end!
_generate_ref_models(builder.build_refs_models())

_generate_empty_init_files(_MODELS_OUTPUT_DIR)
_generate_import_aliases(_MODELS_OUTPUT_DIR)

Expand Down
10 changes: 4 additions & 6 deletions atproto/codegen/namespaces/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
INPUT_MODEL,
OUTPUT_MODEL,
PARAMS_MODEL,
_resolve_nsid_ref,
convert_camel_case_to_snake_case,
format_code,
gen_description_by_camel_case_name,
Expand Down Expand Up @@ -248,19 +249,16 @@ def is_optional_arg(lex_obj) -> bool:


def _get_namespace_method_return_type(method_info: MethodInfo) -> t.Tuple[str, bool]:
model_name_suffix = ''
if method_info.definition.output and isinstance(method_info.definition.output.schema, LexRef):
# fix collisions with type aliases
# example of collisions: com.atproto.admin.getRepo, com.atproto.sync.getRepo
# could be solved by separating models into different folders using segments of NSID
model_name_suffix = 'Ref'
ref_class, _ = _resolve_nsid_ref(method_info.nsid, method_info.definition.output.schema.ref)
return ref_class, True

is_model = False
return_type = 'bool' # return success of response
if method_info.definition.output:
# example of methods without response: app.bsky.graph.muteActor, app.bsky.graph.muteActor
is_model = True
return_type = f'models.{get_import_path(method_info.nsid)}.{OUTPUT_MODEL}{model_name_suffix}'
return_type = f'models.{get_import_path(method_info.nsid)}.{OUTPUT_MODEL}'

return return_type, is_model

Expand Down
4 changes: 2 additions & 2 deletions atproto/xrpc_client/client/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ async def _invoke(self, invoke_type: 'InvokeType', **kwargs) -> 'Response':

async def _get_and_set_session(self, login: str, password: str) -> models.ComAtprotoServerCreateSession.Response:
session = await self.com.atproto.server.create_session(
models.ComAtprotoServerCreateSession.Data(login, password)
models.ComAtprotoServerCreateSession.Data(identifier=login, password=password)
)
self._set_session(session)

Expand Down Expand Up @@ -72,7 +72,7 @@ async def login(self, login: str, password: str) -> models.AppBskyActorDefs.Prof
"""

session = await self._get_and_set_session(login, password)
self.me = await self.bsky.actor.get_profile(models.AppBskyActorGetProfile.Params(session.handle))
self.me = await self.bsky.actor.get_profile(models.AppBskyActorGetProfile.Params(actor=session.handle))

return self.me

Expand Down
4 changes: 3 additions & 1 deletion atproto/xrpc_client/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ def _invoke(self, invoke_type: 'InvokeType', **kwargs) -> 'Response':
return super()._invoke(invoke_type, **kwargs)

def _get_and_set_session(self, login: str, password: str) -> models.ComAtprotoServerCreateSession.Response:
session = self.com.atproto.server.create_session(models.ComAtprotoServerCreateSession.Data(identifier=login, password=password))
session = self.com.atproto.server.create_session(
models.ComAtprotoServerCreateSession.Data(identifier=login, password=password)
)
self._set_session(session)

return session
Expand Down
Loading

0 comments on commit 5ae9acf

Please sign in to comment.