From 64bbc08334a35f37ce95d57e4906b682a9fd706c Mon Sep 17 00:00:00 2001 From: BlueGlassBlock Date: Wed, 23 Mar 2022 23:04:54 +0800 Subject: [PATCH] :art: *mostly* type safe See #55 Co-authored-by: Elaina --- poetry.lock | 144 +++++++----- pyproject.toml | 8 + src/graia/ariadne/__init__.py | 8 +- src/graia/ariadne/adapter/__init__.py | 12 +- src/graia/ariadne/app.py | 218 +++++++++--------- src/graia/ariadne/config.py | 32 +++ src/graia/ariadne/connection.py | 130 +++++++++++ src/graia/ariadne/console/__init__.py | 13 +- src/graia/ariadne/context.py | 22 +- src/graia/ariadne/event/mirai.py | 43 ++-- src/graia/ariadne/io/__init__.py | 51 ++++ src/graia/ariadne/message/chain.py | 170 +++++--------- .../ariadne/message/commander/__init__.py | 22 +- src/graia/ariadne/message/element.py | 80 +++---- src/graia/ariadne/message/formatter.py | 10 +- src/graia/ariadne/message/parser/base.py | 31 +-- src/graia/ariadne/message/parser/twilight.py | 71 +++--- src/graia/ariadne/message/parser/util.py | 9 +- src/graia/ariadne/model.py | 63 ++--- src/graia/ariadne/service.py | 164 +++++++++++++ src/graia/ariadne/typing.py | 35 ++- src/graia/ariadne/util/__init__.py | 71 ++++-- src/graia/ariadne/util/async_exec.py | 14 +- src/graia/ariadne/util/cooldown.py | 29 +-- src/graia/ariadne/util/send.py | 30 ++- 25 files changed, 968 insertions(+), 512 deletions(-) create mode 100644 src/graia/ariadne/config.py create mode 100644 src/graia/ariadne/connection.py create mode 100644 src/graia/ariadne/io/__init__.py create mode 100644 src/graia/ariadne/service.py diff --git a/poetry.lock b/poetry.lock index 588d5b16..b7ad374e 100644 --- a/poetry.lock +++ b/poetry.lock @@ -63,7 +63,7 @@ reference = "tuna-tsinghua" [[package]] name = "arclet-alconna" -version = "0.7.4.3" +version = "0.7.5" description = "A Fast Command Analyser based on Dict" category = "main" optional = true @@ -597,7 +597,7 @@ reference = "tuna-tsinghua" [[package]] name = "importlib-metadata" -version = "4.11.2" +version = "4.11.3" description = "Read metadata from Python packages" category = "dev" optional = false @@ -708,7 +708,7 @@ reference = "tuna-tsinghua" [[package]] name = "markupsafe" -version = "2.1.0" +version = "2.1.1" description = "Safely add untrusted strings to HTML/XML markup." category = "dev" optional = false @@ -1073,6 +1073,22 @@ type = "legacy" url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" reference = "tuna-tsinghua" +[[package]] +name = "prompt-toolkit" +version = "3.0.28" +description = "Library for building powerful interactive command lines in Python" +category = "main" +optional = true +python-versions = ">=3.6.2" + +[package.dependencies] +wcwidth = "*" + +[package.source] +type = "legacy" +url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" +reference = "tuna-tsinghua" + [[package]] name = "py" version = "1.11.0" @@ -1525,11 +1541,27 @@ reference = "tuna-tsinghua" [[package]] name = "watchgod" -version = "0.7" +version = "0.8" description = "Simple, modern file watching and code reload in python." category = "main" optional = true -python-versions = ">=3.5" +python-versions = ">=3.7" + +[package.dependencies] +anyio = ">=3.0.0,<4" + +[package.source] +type = "legacy" +url = "https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple" +reference = "tuna-tsinghua" + +[[package]] +name = "wcwidth" +version = "0.2.5" +description = "Measures the displayed width of unicode strings in a terminal" +category = "main" +optional = true +python-versions = "*" [package.source] type = "legacy" @@ -1617,7 +1649,7 @@ reference = "tuna-tsinghua" [extras] alconna = ["arclet-alconna"] -full = ["graia-saya", "graia-scheduler", "arclet-alconna", "uvicorn", "fastapi"] +full = ["graia-saya", "graia-scheduler", "arclet-alconna", "uvicorn", "fastapi", "prompt-toolkit"] graia = ["graia-saya", "graia-scheduler"] server = ["uvicorn", "fastapi"] standard = ["graia-saya", "graia-scheduler", "arclet-alconna"] @@ -1625,7 +1657,7 @@ standard = ["graia-saya", "graia-scheduler", "arclet-alconna"] [metadata] lock-version = "1.1" python-versions = "^3.8" -content-hash = "c491f37604389f0b2322d1807dfa05377ab17b4225342a1021534281d8df5447" +content-hash = "08c26d433213aeb511f77d5ec901677aa27374b65817443bb2e1346f0b456c1d" [metadata.files] aiohttp = [ @@ -1711,7 +1743,7 @@ anyio = [ {file = "anyio-3.5.0.tar.gz", hash = "sha256:a0aeffe2fb1fdf374a8e4b471444f0f3ac4fb9f5a5b542b48824475e0042a5a6"}, ] arclet-alconna = [ - {file = "arclet-alconna-0.7.4.3.tar.gz", hash = "sha256:861876a4547a9a5731e57b756a1f9109b74f4fac8f50d93d98159096d743cfef"}, + {file = "arclet-alconna-0.7.5.tar.gz", hash = "sha256:921c3b4c5c193ad14407c5b182a69178d322c70874756f32b2d9378a7cc4a6e7"}, ] asgiref = [ {file = "asgiref-3.5.0-py3-none-any.whl", hash = "sha256:88d59c13d634dcffe0510be048210188edd79aeccb6a6c9028cdad6f31d730a9"}, @@ -1991,8 +2023,8 @@ idna = [ {file = "idna-3.3.tar.gz", hash = "sha256:9d643ff0a55b762d5cdb124b8eaa99c66322e2157b69160bc32796e824360e6d"}, ] importlib-metadata = [ - {file = "importlib_metadata-4.11.2-py3-none-any.whl", hash = "sha256:d16e8c1deb60de41b8e8ed21c1a7b947b0bc62fab7e1d470bcdf331cea2e6735"}, - {file = "importlib_metadata-4.11.2.tar.gz", hash = "sha256:b36ffa925fe3139b2f6ff11d6925ffd4fa7bc47870165e3ac260ac7b4f91e6ac"}, + {file = "importlib_metadata-4.11.3-py3-none-any.whl", hash = "sha256:1208431ca90a8cca1a6b8af391bb53c1a2db74e5d1cef6ddced95d4b2062edc6"}, + {file = "importlib_metadata-4.11.3.tar.gz", hash = "sha256:ea4c597ebf37142f827b8f39299579e31685c31d3a438b59f469406afd0f2539"}, ] iniconfig = [ {file = "iniconfig-1.1.1-py2.py3-none-any.whl", hash = "sha256:011e24c64b7f47f6ebd835bb12a743f2fbe9a26d4cecaa7f53bc4f35ee9da8b3"}, @@ -2015,46 +2047,46 @@ markdown = [ {file = "Markdown-3.3.6.tar.gz", hash = "sha256:76df8ae32294ec39dcf89340382882dfa12975f87f45c3ed1ecdb1e8cefc7006"}, ] markupsafe = [ - {file = "MarkupSafe-2.1.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3028252424c72b2602a323f70fbf50aa80a5d3aa616ea6add4ba21ae9cc9da4c"}, - {file = "MarkupSafe-2.1.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:290b02bab3c9e216da57c1d11d2ba73a9f73a614bbdcc027d299a60cdfabb11a"}, - {file = "MarkupSafe-2.1.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6e104c0c2b4cd765b4e83909cde7ec61a1e313f8a75775897db321450e928cce"}, - {file = "MarkupSafe-2.1.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:24c3be29abb6b34052fd26fc7a8e0a49b1ee9d282e3665e8ad09a0a68faee5b3"}, - {file = "MarkupSafe-2.1.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:204730fd5fe2fe3b1e9ccadb2bd18ba8712b111dcabce185af0b3b5285a7c989"}, - {file = "MarkupSafe-2.1.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d3b64c65328cb4cd252c94f83e66e3d7acf8891e60ebf588d7b493a55a1dbf26"}, - {file = "MarkupSafe-2.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:96de1932237abe0a13ba68b63e94113678c379dca45afa040a17b6e1ad7ed076"}, - {file = "MarkupSafe-2.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:75bb36f134883fdbe13d8e63b8675f5f12b80bb6627f7714c7d6c5becf22719f"}, - {file = "MarkupSafe-2.1.0-cp310-cp310-win32.whl", hash = "sha256:4056f752015dfa9828dce3140dbadd543b555afb3252507348c493def166d454"}, - {file = "MarkupSafe-2.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:d4e702eea4a2903441f2735799d217f4ac1b55f7d8ad96ab7d4e25417cb0827c"}, - {file = "MarkupSafe-2.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:f0eddfcabd6936558ec020130f932d479930581171368fd728efcfb6ef0dd357"}, - {file = "MarkupSafe-2.1.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5ddea4c352a488b5e1069069f2f501006b1a4362cb906bee9a193ef1245a7a61"}, - {file = "MarkupSafe-2.1.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:09c86c9643cceb1d87ca08cdc30160d1b7ab49a8a21564868921959bd16441b8"}, - {file = "MarkupSafe-2.1.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a0a0abef2ca47b33fb615b491ce31b055ef2430de52c5b3fb19a4042dbc5cadb"}, - {file = "MarkupSafe-2.1.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:736895a020e31b428b3382a7887bfea96102c529530299f426bf2e636aacec9e"}, - {file = "MarkupSafe-2.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:679cbb78914ab212c49c67ba2c7396dc599a8479de51b9a87b174700abd9ea49"}, - {file = "MarkupSafe-2.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:84ad5e29bf8bab3ad70fd707d3c05524862bddc54dc040982b0dbcff36481de7"}, - {file = "MarkupSafe-2.1.0-cp37-cp37m-win32.whl", hash = "sha256:8da5924cb1f9064589767b0f3fc39d03e3d0fb5aa29e0cb21d43106519bd624a"}, - {file = "MarkupSafe-2.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:454ffc1cbb75227d15667c09f164a0099159da0c1f3d2636aa648f12675491ad"}, - {file = "MarkupSafe-2.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:142119fb14a1ef6d758912b25c4e803c3ff66920635c44078666fe7cc3f8f759"}, - {file = "MarkupSafe-2.1.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b2a5a856019d2833c56a3dcac1b80fe795c95f401818ea963594b345929dffa7"}, - {file = "MarkupSafe-2.1.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d1fb9b2eec3c9714dd936860850300b51dbaa37404209c8d4cb66547884b7ed"}, - {file = "MarkupSafe-2.1.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:62c0285e91414f5c8f621a17b69fc0088394ccdaa961ef469e833dbff64bd5ea"}, - {file = "MarkupSafe-2.1.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fc3150f85e2dbcf99e65238c842d1cfe69d3e7649b19864c1cc043213d9cd730"}, - {file = "MarkupSafe-2.1.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:f02cf7221d5cd915d7fa58ab64f7ee6dd0f6cddbb48683debf5d04ae9b1c2cc1"}, - {file = "MarkupSafe-2.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:d5653619b3eb5cbd35bfba3c12d575db2a74d15e0e1c08bf1db788069d410ce8"}, - {file = "MarkupSafe-2.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:7d2f5d97fcbd004c03df8d8fe2b973fe2b14e7bfeb2cfa012eaa8759ce9a762f"}, - {file = "MarkupSafe-2.1.0-cp38-cp38-win32.whl", hash = "sha256:3cace1837bc84e63b3fd2dfce37f08f8c18aeb81ef5cf6bb9b51f625cb4e6cd8"}, - {file = "MarkupSafe-2.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:fabbe18087c3d33c5824cb145ffca52eccd053061df1d79d4b66dafa5ad2a5ea"}, - {file = "MarkupSafe-2.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:023af8c54fe63530545f70dd2a2a7eed18d07a9a77b94e8bf1e2ff7f252db9a3"}, - {file = "MarkupSafe-2.1.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:d66624f04de4af8bbf1c7f21cc06649c1c69a7f84109179add573ce35e46d448"}, - {file = "MarkupSafe-2.1.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c532d5ab79be0199fa2658e24a02fce8542df196e60665dd322409a03db6a52c"}, - {file = "MarkupSafe-2.1.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e67ec74fada3841b8c5f4c4f197bea916025cb9aa3fe5abf7d52b655d042f956"}, - {file = "MarkupSafe-2.1.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:30c653fde75a6e5eb814d2a0a89378f83d1d3f502ab710904ee585c38888816c"}, - {file = "MarkupSafe-2.1.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:961eb86e5be7d0973789f30ebcf6caab60b844203f4396ece27310295a6082c7"}, - {file = "MarkupSafe-2.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:598b65d74615c021423bd45c2bc5e9b59539c875a9bdb7e5f2a6b92dfcfc268d"}, - {file = "MarkupSafe-2.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:599941da468f2cf22bf90a84f6e2a65524e87be2fce844f96f2dd9a6c9d1e635"}, - {file = "MarkupSafe-2.1.0-cp39-cp39-win32.whl", hash = "sha256:e6f7f3f41faffaea6596da86ecc2389672fa949bd035251eab26dc6697451d05"}, - {file = "MarkupSafe-2.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:b8811d48078d1cf2a6863dafb896e68406c5f513048451cd2ded0473133473c7"}, - {file = "MarkupSafe-2.1.0.tar.gz", hash = "sha256:80beaf63ddfbc64a0452b841d8036ca0611e049650e20afcb882f5d3c266d65f"}, + {file = "MarkupSafe-2.1.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:86b1f75c4e7c2ac2ccdaec2b9022845dbb81880ca318bb7a0a01fbf7813e3812"}, + {file = "MarkupSafe-2.1.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f121a1420d4e173a5d96e47e9a0c0dcff965afdf1626d28de1460815f7c4ee7a"}, + {file = "MarkupSafe-2.1.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a49907dd8420c5685cfa064a1335b6754b74541bbb3706c259c02ed65b644b3e"}, + {file = "MarkupSafe-2.1.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10c1bfff05d95783da83491be968e8fe789263689c02724e0c691933c52994f5"}, + {file = "MarkupSafe-2.1.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b7bd98b796e2b6553da7225aeb61f447f80a1ca64f41d83612e6139ca5213aa4"}, + {file = "MarkupSafe-2.1.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:b09bf97215625a311f669476f44b8b318b075847b49316d3e28c08e41a7a573f"}, + {file = "MarkupSafe-2.1.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:694deca8d702d5db21ec83983ce0bb4b26a578e71fbdbd4fdcd387daa90e4d5e"}, + {file = "MarkupSafe-2.1.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:efc1913fd2ca4f334418481c7e595c00aad186563bbc1ec76067848c7ca0a933"}, + {file = "MarkupSafe-2.1.1-cp310-cp310-win32.whl", hash = "sha256:4a33dea2b688b3190ee12bd7cfa29d39c9ed176bda40bfa11099a3ce5d3a7ac6"}, + {file = "MarkupSafe-2.1.1-cp310-cp310-win_amd64.whl", hash = "sha256:dda30ba7e87fbbb7eab1ec9f58678558fd9a6b8b853530e176eabd064da81417"}, + {file = "MarkupSafe-2.1.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:671cd1187ed5e62818414afe79ed29da836dde67166a9fac6d435873c44fdd02"}, + {file = "MarkupSafe-2.1.1-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3799351e2336dc91ea70b034983ee71cf2f9533cdff7c14c90ea126bfd95d65a"}, + {file = "MarkupSafe-2.1.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e72591e9ecd94d7feb70c1cbd7be7b3ebea3f548870aa91e2732960fa4d57a37"}, + {file = "MarkupSafe-2.1.1-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6fbf47b5d3728c6aea2abb0589b5d30459e369baa772e0f37a0320185e87c980"}, + {file = "MarkupSafe-2.1.1-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d5ee4f386140395a2c818d149221149c54849dfcfcb9f1debfe07a8b8bd63f9a"}, + {file = "MarkupSafe-2.1.1-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:bcb3ed405ed3222f9904899563d6fc492ff75cce56cba05e32eff40e6acbeaa3"}, + {file = "MarkupSafe-2.1.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:e1c0b87e09fa55a220f058d1d49d3fb8df88fbfab58558f1198e08c1e1de842a"}, + {file = "MarkupSafe-2.1.1-cp37-cp37m-win32.whl", hash = "sha256:8dc1c72a69aa7e082593c4a203dcf94ddb74bb5c8a731e4e1eb68d031e8498ff"}, + {file = "MarkupSafe-2.1.1-cp37-cp37m-win_amd64.whl", hash = "sha256:97a68e6ada378df82bc9f16b800ab77cbf4b2fada0081794318520138c088e4a"}, + {file = "MarkupSafe-2.1.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:e8c843bbcda3a2f1e3c2ab25913c80a3c5376cd00c6e8c4a86a89a28c8dc5452"}, + {file = "MarkupSafe-2.1.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0212a68688482dc52b2d45013df70d169f542b7394fc744c02a57374a4207003"}, + {file = "MarkupSafe-2.1.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8e576a51ad59e4bfaac456023a78f6b5e6e7651dcd383bcc3e18d06f9b55d6d1"}, + {file = "MarkupSafe-2.1.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4b9fe39a2ccc108a4accc2676e77da025ce383c108593d65cc909add5c3bd601"}, + {file = "MarkupSafe-2.1.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:96e37a3dc86e80bf81758c152fe66dbf60ed5eca3d26305edf01892257049925"}, + {file = "MarkupSafe-2.1.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6d0072fea50feec76a4c418096652f2c3238eaa014b2f94aeb1d56a66b41403f"}, + {file = "MarkupSafe-2.1.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:089cf3dbf0cd6c100f02945abeb18484bd1ee57a079aefd52cffd17fba910b88"}, + {file = "MarkupSafe-2.1.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:6a074d34ee7a5ce3effbc526b7083ec9731bb3cbf921bbe1d3005d4d2bdb3a63"}, + {file = "MarkupSafe-2.1.1-cp38-cp38-win32.whl", hash = "sha256:421be9fbf0ffe9ffd7a378aafebbf6f4602d564d34be190fc19a193232fd12b1"}, + {file = "MarkupSafe-2.1.1-cp38-cp38-win_amd64.whl", hash = "sha256:fc7b548b17d238737688817ab67deebb30e8073c95749d55538ed473130ec0c7"}, + {file = "MarkupSafe-2.1.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e04e26803c9c3851c931eac40c695602c6295b8d432cbe78609649ad9bd2da8a"}, + {file = "MarkupSafe-2.1.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b87db4360013327109564f0e591bd2a3b318547bcef31b468a92ee504d07ae4f"}, + {file = "MarkupSafe-2.1.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:99a2a507ed3ac881b975a2976d59f38c19386d128e7a9a18b7df6fff1fd4c1d6"}, + {file = "MarkupSafe-2.1.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:56442863ed2b06d19c37f94d999035e15ee982988920e12a5b4ba29b62ad1f77"}, + {file = "MarkupSafe-2.1.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3ce11ee3f23f79dbd06fb3d63e2f6af7b12db1d46932fe7bd8afa259a5996603"}, + {file = "MarkupSafe-2.1.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:33b74d289bd2f5e527beadcaa3f401e0df0a89927c1559c8566c066fa4248ab7"}, + {file = "MarkupSafe-2.1.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:43093fb83d8343aac0b1baa75516da6092f58f41200907ef92448ecab8825135"}, + {file = "MarkupSafe-2.1.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8e3dcf21f367459434c18e71b2a9532d96547aef8a871872a5bd69a715c15f96"}, + {file = "MarkupSafe-2.1.1-cp39-cp39-win32.whl", hash = "sha256:d4306c36ca495956b6d568d276ac11fdd9c30a36f1b6eb928070dc5360b22e1c"}, + {file = "MarkupSafe-2.1.1-cp39-cp39-win_amd64.whl", hash = "sha256:46d00d6cfecdde84d40e572d63735ef81423ad31184100411e6e3388d405e247"}, + {file = "MarkupSafe-2.1.1.tar.gz", hash = "sha256:7f91197cc9e48f989d12e4e6fbc46495c446636dfc81b9ccf50bb0ec74b91d4b"}, ] mccabe = [ {file = "mccabe-0.6.1-py2.py3-none-any.whl", hash = "sha256:ab8a6258860da4b6677da4bd2fe5dc2c659cff31b3ee4f7f5d64e79735b80d42"}, @@ -2197,6 +2229,10 @@ pre-commit = [ {file = "pre_commit-2.17.0-py2.py3-none-any.whl", hash = "sha256:725fa7459782d7bec5ead072810e47351de01709be838c2ce1726b9591dad616"}, {file = "pre_commit-2.17.0.tar.gz", hash = "sha256:c1a8040ff15ad3d648c70cc3e55b93e4d2d5b687320955505587fd79bbaed06a"}, ] +prompt-toolkit = [ + {file = "prompt_toolkit-3.0.28-py3-none-any.whl", hash = "sha256:30129d870dcb0b3b6a53efdc9d0a83ea96162ffd28ffe077e94215b233dc670c"}, + {file = "prompt_toolkit-3.0.28.tar.gz", hash = "sha256:9f1cd16b1e86c2968f2519d7fb31dd9d669916f515612c269d14e9ed52b51650"}, +] py = [ {file = "py-1.11.0-py2.py3-none-any.whl", hash = "sha256:607c53218732647dff4acdfcd50cb62615cedf612e72d1724fb1a0cc6405b378"}, {file = "py-1.11.0.tar.gz", hash = "sha256:51c75c4126074b472f746a24399ad32f6053d1b34b68d2fa41e558e6f4a98719"}, @@ -2405,8 +2441,12 @@ watchdog = [ {file = "watchdog-2.1.6.tar.gz", hash = "sha256:a36e75df6c767cbf46f61a91c70b3ba71811dfa0aca4a324d9407a06a8b7a2e7"}, ] watchgod = [ - {file = "watchgod-0.7-py3-none-any.whl", hash = "sha256:d6c1ea21df37847ac0537ca0d6c2f4cdf513562e95f77bb93abbcf05573407b7"}, - {file = "watchgod-0.7.tar.gz", hash = "sha256:48140d62b0ebe9dd9cf8381337f06351e1f2e70b2203fa9c6eff4e572ca84f29"}, + {file = "watchgod-0.8-py3-none-any.whl", hash = "sha256:339c2cfede1ccc1e277bbf5e82e42886f3c80801b01f45ab10d9461c4118b5eb"}, + {file = "watchgod-0.8.tar.gz", hash = "sha256:29a1d8f25e1721ddb73981652ca318c47387ffb12ec4171ddd7b9d01540033b1"}, +] +wcwidth = [ + {file = "wcwidth-0.2.5-py2.py3-none-any.whl", hash = "sha256:beb4802a9cebb9144e99086eff703a642a13d6a0052920003a230f3294bbe784"}, + {file = "wcwidth-0.2.5.tar.gz", hash = "sha256:c4d647b99872929fdb7bdcaa4fbe7f01413ed3d98077df798530e5b04f116c83"}, ] websockets = [ {file = "websockets-10.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:d5396710f86a306cf52f87fd8ea594a0e894ba0cc5a36059eaca3a477dc332aa"}, diff --git a/pyproject.toml b/pyproject.toml index 8f9cb817..e5bf5ca3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,3 +115,11 @@ exclude_lines = [ # Don't complain overload method / functions "@(typing\\.)?overload" ] + +[tool.pyright] +ignore = [ + "docs/**", + "**/site-packages/**/*.py", + "**/test*/**/*.py", + "**/adapter/**" +] diff --git a/src/graia/ariadne/__init__.py b/src/graia/ariadne/__init__.py index f7d70176..2d431fbe 100644 --- a/src/graia/ariadne/__init__.py +++ b/src/graia/ariadne/__init__.py @@ -29,8 +29,10 @@ def get_running(type: Type[T] = Ariadne) -> T: from .context import context_map if type in {Adapter, Ariadne, Broadcast, AbstractEventLoop}: - if val := context_map.get(type.__name__).get(None): - return val + if ctx := context_map.get(type.__name__): + if val := ctx.get(None): + return val for ariadne_inst in Ariadne.running: if type in ariadne_inst.info: - return ariadne_inst.info[type] + return ariadne_inst.info[type] # type: ignore + raise ValueError(f"{type.__name__} is not running") diff --git a/src/graia/ariadne/adapter/__init__.py b/src/graia/ariadne/adapter/__init__.py index ca01b920..cdb11463 100644 --- a/src/graia/ariadne/adapter/__init__.py +++ b/src/graia/ariadne/adapter/__init__.py @@ -1,6 +1,7 @@ """Ariadne 的适配器""" import abc import asyncio +import contextlib from asyncio import Queue, Task from typing import Any, Dict, FrozenSet, Optional, Union @@ -42,7 +43,7 @@ async def call_api( action: str, method: CallMethod, data: Optional[Union[Dict[str, Any], str, FormData]] = None, - ) -> Union[dict, list]: + ) -> Dict[Any, Any]: """调用 API Args: @@ -51,7 +52,7 @@ async def call_api( data (Optional[Union[Dict[str, Any], str, FormData]], optional): 调用数据. Defaults to None. Returns: - Union[dict, list]: API 返回的数据, 为 json 兼容格式 + dict: API 返回的数据, 为 json 兼容格式 """ ... @@ -86,8 +87,7 @@ def build_event(self, data: dict) -> MiraiEvent: logger.error("An event is not recognized! Please report with your log to help us diagnose.") raise ValueError(f"Unable to find event: {event_type}", data) data = {k: v for k, v in data.items() if k != "type"} - event = event_class.parse_obj(data) - return event + return event_class.parse_obj(data) @abc.abstractmethod async def fetch_cycle(self): @@ -118,7 +118,7 @@ def build_event(self, data: dict) -> MiraiEvent: return event -try: +with contextlib.suppress(ImportError): from .reverse import ( # noqa: F401, E402 ComposeReverseWebsocketAdapter as ComposeReverseWebsocketAdapter, ) @@ -129,5 +129,3 @@ def build_event(self, data: dict) -> MiraiEvent: from .reverse import ( # noqa: F401, E402 ReverseWebsocketAdapter as ReverseWebsocketAdapter, ) -except ImportError: - pass diff --git a/src/graia/ariadne/app.py b/src/graia/ariadne/app.py index 46df0b70..74b611da 100644 --- a/src/graia/ariadne/app.py +++ b/src/graia/ariadne/app.py @@ -1,7 +1,9 @@ """Ariadne 实例 """ + import asyncio import base64 +import contextlib import importlib.metadata import inspect import io @@ -14,10 +16,10 @@ from typing import ( TYPE_CHECKING, Any, + AsyncGenerator, ClassVar, Coroutine, Dict, - Generator, Iterable, List, Literal, @@ -61,7 +63,7 @@ Stranger, UploadMethod, ) -from .typing import SendMessageAction, SendMessageDict, SendMessageException +from .typing import SendMessageActionProtocol, SendMessageDict, SendMessageException from .util import ( app_ctx_manager, await_predicate, @@ -73,7 +75,7 @@ if TYPE_CHECKING: from .message.element import Image, Voice - from .typing import R, T + from .typing import T ARIADNE_ASCII_LOGO = "\n".join( ( @@ -105,7 +107,7 @@ def session_key(self) -> Optional[str]: class MessageMixin(AriadneMixin): """用于发送, 撤回, 获取消息的 Mixin 类.""" - default_send_action: SendMessageAction + default_send_action: SendMessageActionProtocol @app_ctx_manager async def getMessageFromId(self, messageId: int) -> MessageEvent: @@ -122,7 +124,7 @@ async def getMessageFromId(self, messageId: int) -> MessageEvent: CallMethod.GET, {"sessionKey": self.session_key, "id": messageId}, ) - return self.adapter.build_event(result) + return cast(MessageEvent, self.adapter.build_event(result)) @app_ctx_manager async def sendFriendMessage( @@ -159,7 +161,7 @@ async def sendFriendMessage( ) event: ActiveFriendMessage = ActiveFriendMessage( messageChain=MessageChain([Source(id=result["messageId"], time=datetime.now())]) + message, - subject=(await RelationshipMixin.getFriend(self, int(target))), + subject=(await RelationshipMixin.getFriend(self, int(target))), # type: ignore ) with enter_context(self, event): self.broadcast.postEvent(event) @@ -203,7 +205,7 @@ async def sendGroupMessage( ) event: ActiveGroupMessage = ActiveGroupMessage( messageChain=MessageChain([Source(id=result["messageId"], time=datetime.now())]) + message, - subject=(await RelationshipMixin.getGroup(self, int(target))), + subject=(await RelationshipMixin.getGroup(self, int(target))), # type: ignore ) with enter_context(self, event): self.broadcast.postEvent(event) @@ -250,7 +252,7 @@ async def sendTempMessage( ) event: ActiveTempMessage = ActiveTempMessage( messageChain=MessageChain([Source(id=result["messageId"], time=datetime.now())]) + message, - subject=(await RelationshipMixin.getMember(self, int(group), int(target))), + subject=(await RelationshipMixin.getMember(self, int(group), int(target))), # type: ignore ) with enter_context(self, event): self.broadcast.postEvent(event) @@ -263,8 +265,8 @@ async def sendMessage( message: MessageChain, *, quote: Union[bool, int, Source, MessageChain] = False, - action: SendMessageAction["T", "R"] = ..., - ) -> Union["T", "R"]: + action: SendMessageActionProtocol["T"] = ..., + ) -> "T": """ 依据传入的 `target` 自动发送消息. 请注意发送给群成员时会自动作为临时消息发送. @@ -282,7 +284,7 @@ async def sendMessage( Union[T, R]: 默认实现为 BotMessage """ action = action if action is not ... else self.default_send_action - data = {"message": message} + data: Dict[Any, Any] = {"message": message} # quote if isinstance(quote, bool) and quote and isinstance(target, MessageEvent): data["quote"] = target.messageChain.getFirst(Source) @@ -299,7 +301,7 @@ async def sendMessage( data["target"] = target send_data: SendMessageDict = SendMessageDict(**data) # send message - data = await action.param(send_data) + data = await action.param(send_data) # type: ignore try: if isinstance(data["target"], Friend): @@ -311,7 +313,7 @@ async def sendMessage( else: raise NotImplementedError(f"Unable to send message with {target} as target.") except Exception as e: - e.send_data = send_data + e.send_data = send_data # type: ignore return await action.exception(cast(SendMessageException, e)) else: return await action.result(val) @@ -815,10 +817,13 @@ async def modifyMemberInfo( Returns: None: 没有返回. """ - if not group and not isinstance(member, Member): - raise TypeError("you should give a Member instance if you cannot give a Group instance to me.") - if isinstance(member, Member) and not group: - group: Group = member.group + if group is None: + if isinstance(member, Member): + group = member.group + else: + raise TypeError( + "you should give a Member instance if you cannot give a Group instance to me." + ) await self.adapter.call_api( "memberInfo", CallMethod.RESTPOST, @@ -852,11 +857,13 @@ async def modifyMemberAdmin( Returns: None: 没有返回. """ - if not group and not isinstance(member, Member): - raise TypeError("you should give a Member instance if you cannot give a Group instance to me.") - if isinstance(member, Member) and not group: - group: Group = member.group - + if group is None: + if isinstance(member, Member): + group = member.group + else: + raise TypeError( + "you should give a Member instance if you cannot give a Group instance to me." + ) await self.adapter.call_api( "memberAdmin", CallMethod.POST, @@ -914,9 +921,9 @@ class AnnouncementMixin(AriadneMixin): async def getAnnouncementIterator( self, target: Union[Group, int], - offset: Optional[int] = 0, - size: Optional[int] = 10, - ) -> Generator[Announcement, None, None]: + offset: int = 0, + size: int = 10, + ) -> AsyncGenerator[Announcement, None]: """ 获取群公告列表. @@ -926,15 +933,14 @@ async def getAnnouncementIterator( size (Optional[int], optional): 列表大小. 默认为 10. Returns: - Generator[Announcement, None, None]: 列出群组下所有的公告. + AsyncGenerator[Announcement, None]: 列出群组下所有的公告. """ target = int(target) current_offset = offset - cache: List[FileInfo] = [] + cache: List[Announcement] = [] while True: - while cache: - yield cache[0] - cache.pop(0) + for announcement in cache: + yield announcement cache = await self.getAnnouncementList(target, current_offset, size) current_offset += len(cache) if not cache: @@ -1046,10 +1052,10 @@ async def getFileIterator( self, target: Union[Group, int], id: str = "", - offset: Optional[int] = 0, - size: Optional[int] = 1, + offset: int = 0, + size: int = 1, with_download_info: bool = False, - ) -> Generator[FileInfo, None, None]: + ) -> AsyncGenerator[FileInfo, None]: """ 以生成器形式列出指定文件夹下的所有文件. @@ -1062,15 +1068,14 @@ async def getFileIterator( with_download_info (bool): 是否携带下载信息, 无必要不要携带 Returns: - Generator[FileInfo, None, None]: 文件信息生成器. + AsyncGenerator[FileInfo, None]: 文件信息生成器. """ target = int(target) current_offset = offset cache: List[FileInfo] = [] while True: - while cache: - yield cache[0] - cache.pop(0) + for file_info in cache: + yield file_info cache = await self.getFileList(target, id, current_offset, size, with_download_info) current_offset += len(cache) if not cache: @@ -1300,7 +1305,7 @@ async def renameFile( async def uploadFile( self, data: Union[bytes, io.IOBase, os.PathLike], - method: Union[str, UploadMethod] = None, + method: Union[str, UploadMethod, None] = None, target: Union[Friend, Group, int] = -1, path: str = "", name: str = "", @@ -1351,7 +1356,7 @@ class MultimediaMixin(AriadneMixin): @app_ctx_manager async def uploadImage( - self, data: Union[bytes, io.IOBase, os.PathLike], method: Union[str, UploadMethod] = None + self, data: Union[bytes, io.IOBase, os.PathLike], method: Union[None, str, UploadMethod] = None ) -> "Image": """上传一张图片到远端服务器, 需要提供: 图片的原始数据(bytes), 图片的上传类型. Args: @@ -1382,7 +1387,7 @@ async def uploadImage( @app_ctx_manager async def uploadVoice( - self, data: Union[bytes, io.IOBase, os.PathLike], method: Union[str, UploadMethod] = None + self, data: Union[bytes, io.IOBase, os.PathLike], method: Union[None, str, UploadMethod] = None ) -> "Voice": """上传语音到远端服务器, 需要提供: 语音的原始数据(bytes), 语音的上传类型. Args: @@ -1518,16 +1523,15 @@ def create(self, cls: Type["T"], reuse: bool = True) -> "T": Returns: T: 创建的类. """ - self.info: Dict[Type["T"], "T"] - if cls in self.info.keys(): - return self.info[cls] + if cls in self.info: + return self.info[cls] # type: ignore call_args: list = [] call_kwargs: Dict[str, Any] = {} init_sig = inspect.signature(cls) for name, param in init_sig.parameters.items(): - if param.annotation in self.info.keys() and param.kind not in ( + if param.annotation in self.info and param.kind not in ( param.VAR_KEYWORD, param.VAR_POSITIONAL, ): @@ -1562,6 +1566,7 @@ async def daemon(self, retry_interval: float = 5.0): await asyncio.wait_for(self.adapter.start(), timeout=retry_interval) logger.success("daemon: adapter started") self.broadcast.postEvent(AdapterLaunched(self)) + assert self.adapter.event_queue is not None, "No event queue found for Adapter" async for event in yield_with_timeout( self.adapter.event_queue.get, lambda: ( @@ -1570,9 +1575,8 @@ async def daemon(self, retry_interval: float = 5.0): ): with enter_context(self, event): sys.audit("AriadnePostRemoteEvent", event) - if isinstance(event, MessageEvent): - if event.messageChain.onlyContains(Source): # Contains unsupported type - event.messageChain.append("") + if isinstance(event, MessageEvent) and event.messageChain.onlyContains(Source): + event.messageChain.append("") if isinstance(event, FriendEvent): with enter_message_send_context(UploadMethod.Friend): self.broadcast.postEvent(event) @@ -1609,7 +1613,7 @@ async def daemon(self, retry_interval: float = 5.0): for t in asyncio.all_tasks(self.loop): if t is asyncio.current_task(self.loop): continue - coro: Coroutine = t.get_coro() + coro: Coroutine = t.get_coro() # type: ignore try: if coro.__qualname__ in ("Broadcast.Executor", "print_track_async..wrapper"): t.cancel() @@ -1635,65 +1639,65 @@ async def daemon(self, retry_interval: float = 5.0): async def launch(self): """启动 Ariadne.""" - if self.status is AriadneStatus.STOP: - self.status = AriadneStatus.LAUNCH - - # Logo - if not self.disable_logo: - logger.opt(colors=True, raw=True).info(f"{ARIADNE_ASCII_LOGO}") - - # Telemetry - if not self.disable_telemetry: - official: List[Tuple[str, str]] = [] - community: List[str] = [] - for dist in importlib.metadata.distributions(): - name: str = dist.metadata["Name"] - version: str = dist.version - if name.startswith("graia-"): - official.append((" ".join(name.split("-")[1:]).title(), version)) - elif name.startswith("graiax-"): - community.append((" ".join(name.split("-")).title(), version)) - - for name, version in official: - logger.opt(colors=True, raw=True).info( - f"{name} version: {version}\n" - ) - for name, version in community: - logger.opt(colors=True, raw=True).info(f"{name} version: {version}\n") - - logger.info("Launching app...") - start_time = time.time() - - if self.chat_log_cfg.enabled: - self.chat_log_cfg.initialize(self) - - if ContextDispatcher not in self.broadcast.finale_dispatchers: - self.broadcast.finale_dispatchers.append(ContextDispatcher) - - self.daemon_task = self.loop.create_task(self.daemon(), name="ariadne_daemon") - await await_predicate(lambda: self.adapter.running, 0.0001) - - self.running.add(self) - - if "reverse" not in self.adapter.tags: - await await_predicate( - lambda: self.adapter.mirai_session.session_key or self.adapter.mirai_session.single_mode, - 0.0001, - ) + if self.status is not AriadneStatus.STOP: + return + self.status = AriadneStatus.LAUNCH + + # Logo + if not self.disable_logo: + logger.opt(colors=True, raw=True).info(f"{ARIADNE_ASCII_LOGO}") + + # Telemetry + if not self.disable_telemetry: + official: List[Tuple[str, str]] = [] + community: List[Tuple[str, str]] = [] + for dist in importlib.metadata.distributions(): + name: str = dist.metadata["Name"] + version: str = dist.version + if name.startswith("graia-"): + official.append((" ".join(name.split("-")[1:]).title(), version)) + elif name.startswith("graiax-"): + community.append((" ".join(name.split("-")).title(), version)) + + for name, version in official: + logger.opt(colors=True, raw=True).info(f"{name} version: {version}\n") + for name, version in community: + logger.opt(colors=True, raw=True).info(f"{name} version: {version}\n") + + logger.info("Launching app...") + start_time = time.time() + + if self.chat_log_cfg.enabled: + self.chat_log_cfg.initialize(self) + + if ContextDispatcher not in self.broadcast.finale_dispatchers: + self.broadcast.finale_dispatchers.append(ContextDispatcher) - self.status = AriadneStatus.RUNNING + self.daemon_task = self.loop.create_task(self.daemon(), name="ariadne_daemon") + await await_predicate(lambda: self.adapter.running, 0.0001) - self.remote_version = await self.getVersion() - logger.success(f"Remote version: {self.remote_version}") - if not self.remote_version.startswith("2"): - raise RuntimeError(f"You are using an unsupported version: {self.remote_version}!") - logger.success(f"Application launched with {time.time() - start_time:.2}s") + self.running.add(self) - await self.broadcast.layered_scheduler( - listener_generator=self.broadcast.default_listener_generator(ApplicationLaunched), - event=ApplicationLaunched(self), + if "reverse" not in self.adapter.tags: + await await_predicate( + lambda: self.adapter.mirai_session.session_key is not None + or self.adapter.mirai_session.single_mode, + 0.0001, ) + self.status = AriadneStatus.RUNNING + + self.remote_version = await self.getVersion() + logger.success(f"Remote version: {self.remote_version}") + if not self.remote_version.startswith("2"): + raise RuntimeError(f"You are using an unsupported version: {self.remote_version}!") + logger.success(f"Application launched with {time.time() - start_time:.2}s") + + await self.broadcast.layered_scheduler( + listener_generator=self.broadcast.default_listener_generator(ApplicationLaunched), + event=ApplicationLaunched(self), + ) + async def stop(self): """请求停止 Ariadne.""" if self.status is AriadneStatus.RUNNING: @@ -1707,7 +1711,8 @@ async def join(self): if self.status in {AriadneStatus.RUNNING, AriadneStatus.LAUNCH}: await self.stop() await await_predicate(lambda: self.status is AriadneStatus.STOP) - await self.daemon_task + if self.daemon_task: + await self.daemon_task async def lifecycle(self): """以 async 阻塞方式启动 Ariadne 并等待其停止.""" @@ -1718,14 +1723,13 @@ def sig_handler(*_): signal_handler(sig_handler) await self.launch() - await self.daemon_task + if self.daemon_task: + await self.daemon_task def launch_blocking(self): """以阻塞方式启动 Ariadne 并等待其停止.""" - try: + with contextlib.suppress(KeyboardInterrupt): self.loop.run_until_complete(self.lifecycle()) - except KeyboardInterrupt: - pass self.loop.run_until_complete(self.join()) @app_ctx_manager diff --git a/src/graia/ariadne/config.py b/src/graia/ariadne/config.py new file mode 100644 index 00000000..6f43c0c7 --- /dev/null +++ b/src/graia/ariadne/config.py @@ -0,0 +1,32 @@ +from typing import Union + +from .model import AriadneBaseModel + + +class ElizabethConnectionConfig(AriadneBaseModel): + account: int + timeout: float = 5.0 + + +class ElizabethHttpClientConfig(ElizabethConnectionConfig): + url: str + + +class ElizabethWebsocketClientConfig(ElizabethConnectionConfig): + url: str + + +class ElizabethHttpServerConfig(ElizabethConnectionConfig): + api_root: str = "/api" + + +class ElizabethWebsocketServerConfig(ElizabethConnectionConfig): + api_root: str = "/api" + + +ConnectionConfig = Union[ + ElizabethHttpClientConfig, + ElizabethHttpServerConfig, + ElizabethWebsocketClientConfig, + ElizabethWebsocketServerConfig, +] diff --git a/src/graia/ariadne/connection.py b/src/graia/ariadne/connection.py new file mode 100644 index 00000000..23237e2e --- /dev/null +++ b/src/graia/ariadne/connection.py @@ -0,0 +1,130 @@ +import abc +import asyncio +from typing import TYPE_CHECKING, Optional, TypedDict + +from .config import ( + ElizabethHttpClientConfig, + ElizabethHttpServerConfig, + ElizabethWebsocketClientConfig, + ElizabethWebsocketServerConfig, +) +from .io import AiohttpClient + +if TYPE_CHECKING: + from graia.amnesia.builtins.starlette import StarletteServer + + from .service import ElizabethService + + +class ConnectionInfo(TypedDict, total=False): + session_key: str + + +class ElizabethConnection: + service: "ElizabethService" + connected: asyncio.Event + info: ConnectionInfo + + def __init__(self, info: ConnectionInfo, service: "ElizabethService") -> None: + self.service = service + self.info = info + self.connected = asyncio.Event() + + @abc.abstractmethod + async def maintask(self): + ... + + @abc.abstractmethod + async def action(self, action: str, data: dict, timeout: Optional[float] = None) -> dict: + await self.connected.wait() + + +# 等一下我发现你写的 HttpClient/Server 怎么都是空的?? +# 我不是说过你可以直接调用 aiohttp.ClientSession 了吗.... +# 不过封装一下, 我觉得你就得用 get_interface 获取了 -- 因为这是 v5 的方式 +# 嗯...职权替代. + + +class HttpClientConnection(ElizabethConnection): + config: ElizabethHttpClientConfig + client: AiohttpClient + + def __init__( + self, + client: AiohttpClient, + config: ElizabethHttpClientConfig, + info: ConnectionInfo, + service: "ElizabethService", + ) -> None: + super().__init__(info, service) + self.client = client + self.config = config + + async def maintask(self): + ... + + +class HttpServerConnection(ElizabethConnection): + config: ElizabethHttpServerConfig + server: "StarletteServer" + + def __init__( + self, + server: "StarletteServer", + config: ElizabethHttpServerConfig, + info: ConnectionInfo, + service: "ElizabethService", + ) -> None: + super().__init__(info, service) + self.server = server + self.config = config + + async def maintask(self): + ... + + async def action(self, action: str, data: dict, timeout: Optional[float] = None) -> dict: + await self.connected.wait() + + +class WebsocketClientConnection(ElizabethConnection): + config: ElizabethWebsocketClientConfig + client: AiohttpClient + + def __init__( + self, + client: AiohttpClient, + config: ElizabethWebsocketClientConfig, + info: ConnectionInfo, + service: "ElizabethService", + ) -> None: + super().__init__(info, service) + self.client = client + self.config = config + + async def maintask(self): + ... + + async def action(self, action: str, data: dict, timeout: Optional[float] = None) -> dict: + await self.connected.wait() + + +class WebsocketServerConnection(ElizabethConnection): + config: ElizabethWebsocketServerConfig + server: "StarletteServer" + + def __init__( + self, + server: "StarletteServer", + config: ElizabethWebsocketServerConfig, + info: ConnectionInfo, + service: "ElizabethService", + ) -> None: + super().__init__(info, service) + self.server = server + self.config = config + + async def maintask(self): + ... + + async def action(self, action: str, data: dict, timeout: Optional[float] = None) -> dict: + await self.connected.wait() diff --git a/src/graia/ariadne/console/__init__.py b/src/graia/ariadne/console/__init__.py index 3ac872e9..537fb948 100644 --- a/src/graia/ariadne/console/__init__.py +++ b/src/graia/ariadne/console/__init__.py @@ -1,5 +1,7 @@ """Ariadne 控制台 注意, 本实现并不 robust, 但是可以使用""" + +import contextlib import importlib.metadata import sys from asyncio.events import AbstractEventLoop @@ -197,11 +199,8 @@ def start(self): self.running = True if self.replace_logger: - try: + with contextlib.suppress(ValueError): logger.remove(0) - except ValueError: - pass - self.handler_id = logger.add(StdoutProxy(raw=True)) # type: ignore self.task = self.broadcast.loop.create_task(self.loop()) @@ -224,7 +223,11 @@ async def join(self): await self.task self.task = None - def register(self, dispatchers: List[BaseDispatcher] = None, decorators: List[Decorator] = None): + def register( + self, + dispatchers: Optional[List[BaseDispatcher]] = None, + decorators: Optional[List[Decorator]] = None, + ): """注册命令处理函数 Args: diff --git a/src/graia/ariadne/context.py b/src/graia/ariadne/context.py index 982a4a14..d0ec4640 100644 --- a/src/graia/ariadne/context.py +++ b/src/graia/ariadne/context.py @@ -1,7 +1,8 @@ """本模块创建了 Ariadne 中的上下文变量""" -from contextlib import contextmanager + +from contextlib import contextmanager, suppress from contextvars import ContextVar -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Dict, Optional if TYPE_CHECKING: from asyncio.events import AbstractEventLoop @@ -10,10 +11,10 @@ from graia.broadcast.entities.event import Dispatchable from .adapter import Adapter - from .app import Ariadne + from .app import AriadneMixin from .model import UploadMethod - ariadne_ctx: ContextVar[Ariadne] = ContextVar("ariadne") + ariadne_ctx: ContextVar[AriadneMixin] = ContextVar("ariadne") adapter_ctx: ContextVar[Adapter] = ContextVar("adapter") event_ctx: ContextVar[Dispatchable] = ContextVar("event") event_loop_ctx: ContextVar[AbstractEventLoop] = ContextVar("event_loop") @@ -50,30 +51,27 @@ def enter_message_send_context(method: "UploadMethod"): @contextmanager -def enter_context(app: "Ariadne" = None, event: "Dispatchable" = None): +def enter_context(app: Optional["AriadneMixin"] = None, event: Optional["Dispatchable"] = None): """进入事件上下文 Args: app (Ariadne, optional): Ariadne 实例. event (Dispatchable, optional): 当前事件 """ - token_app = None - token_event = None token_loop = None token_bcc = None token_adapter = None + token_app = None if app: token_app = ariadne_ctx.set(app) token_loop = event_loop_ctx.set(app.broadcast.loop) token_bcc = broadcast_ctx.set(app.broadcast) token_adapter = adapter_ctx.set(app.adapter) - if event: - token_event = event_ctx.set(event) - + token_event = event_ctx.set(event) if event else None yield - try: + with suppress(ValueError): if token_app: ariadne_ctx.reset(token_app) if token_adapter: @@ -84,5 +82,3 @@ def enter_context(app: "Ariadne" = None, event: "Dispatchable" = None): event_loop_ctx.reset(token_loop) if token_bcc: broadcast_ctx.reset(token_bcc) - except ValueError: - pass diff --git a/src/graia/ariadne/event/mirai.py b/src/graia/ariadne/event/mirai.py index ec2a9a5e..fa724821 100644 --- a/src/graia/ariadne/event/mirai.py +++ b/src/graia/ariadne/event/mirai.py @@ -115,9 +115,10 @@ class FriendInputStatusChangedEvent(FriendEvent): class Dispatcher(BaseDispatcher): @staticmethod async def catch(interface: DispatcherInterface): - if isinstance(interface.event, FriendInputStatusChangedEvent): - if generic_issubclass(Friend, interface.annotation): - return interface.event.friend + if isinstance(interface.event, FriendInputStatusChangedEvent) and generic_issubclass( + Friend, interface.annotation + ): + return interface.event.friend class FriendNickChangedEvent(FriendEvent): @@ -138,9 +139,10 @@ class FriendNickChangedEvent(FriendEvent): class Dispatcher(BaseDispatcher): @staticmethod async def catch(interface: DispatcherInterface): - if isinstance(interface.event, FriendNickChangedEvent): - if generic_issubclass(Friend, interface.annotation): - return interface.event.friend + if isinstance(interface.event, FriendNickChangedEvent) and generic_issubclass( + Friend, interface.annotation + ): + return interface.event.friend class BotGroupPermissionChangeEvent(GroupEvent, BotEvent): @@ -161,9 +163,10 @@ class BotGroupPermissionChangeEvent(GroupEvent, BotEvent): class Dispatcher(BaseDispatcher): @staticmethod async def catch(interface: DispatcherInterface): - if isinstance(interface.event, BotGroupPermissionChangeEvent): - if generic_issubclass(Group, interface.annotation): - return interface.event.group + if isinstance(interface.event, BotGroupPermissionChangeEvent) and generic_issubclass( + Group, interface.annotation + ): + return interface.event.group class BotMuteEvent(GroupEvent, BotEvent): @@ -173,13 +176,13 @@ class BotMuteEvent(GroupEvent, BotEvent): 提供的额外注解支持: Ariadne (annotation): 发布事件的应用实例 - Member (annotation, optional = None): 执行禁言操作的管理员/群主, 若为 None 则为 Bot 账号操作 + Member (annotation): 执行禁言操作的管理员/群主, 若为 None 则为 Bot 账号操作 Group (annotation): 发生该事件的群组 """ type = "BotMuteEvent" durationSeconds: int - operator: Optional[Member] + operator: Member class Dispatcher(BaseDispatcher): @staticmethod @@ -198,12 +201,12 @@ class BotUnmuteEvent(GroupEvent, BotEvent): 提供的额外注解支持: Ariadne (annotation): 发布事件的应用实例 - Member (annotation, optional = None): 执行解除禁言操作的管理员/群主, 若为 None 则为 Bot 账号操作 + Member (annotation): 执行解除禁言操作的管理员/群主, 若为 None 则为 Bot 账号操作 Group (annotation): 发生该事件的群组 """ type = "BotUnmuteEvent" - operator: Optional[Member] + operator: Member class Dispatcher(BaseDispatcher): @staticmethod @@ -256,9 +259,10 @@ class BotLeaveEventActive(GroupEvent, BotEvent): class Dispatcher(BaseDispatcher): @staticmethod async def catch(interface: DispatcherInterface): - if isinstance(interface.event, BotLeaveEventActive): - if generic_issubclass(Group, interface.annotation): - return interface.event.group + if isinstance(interface.event, BotLeaveEventActive) and generic_issubclass( + Group, interface.annotation + ): + return interface.event.group class BotLeaveEventKick(GroupEvent, BotEvent): @@ -277,9 +281,10 @@ class BotLeaveEventKick(GroupEvent, BotEvent): class Dispatcher(BaseDispatcher): @staticmethod async def catch(interface: DispatcherInterface): - if isinstance(interface.event, BotLeaveEventKick): - if generic_issubclass(Group, interface.annotation): - return interface.event.group + if isinstance(interface.event, BotLeaveEventKick) and generic_issubclass( + Group, interface.annotation + ): + return interface.event.group class GroupRecallEvent(GroupEvent): diff --git a/src/graia/ariadne/io/__init__.py b/src/graia/ariadne/io/__init__.py new file mode 100644 index 00000000..c0e1fc27 --- /dev/null +++ b/src/graia/ariadne/io/__init__.py @@ -0,0 +1,51 @@ +# For Ariadne, +# "http.universal_server" stands for "StarletteService" +# "http.universal_client" stands for "AiohttpService" + +import asyncio +from typing import Optional, Type + +from aiohttp import ClientSession +from graia.amnesia.interface import ExportInterface +from graia.amnesia.launch import LaunchComponent +from graia.amnesia.service import Service + + +class AiohttpClient(ExportInterface): + session: ClientSession + + def __init__(self, session: ClientSession, service: "AiohttpService") -> None: + self.session = session + self.service = service + + +class AiohttpService(Service): + supported_interface_types = {AiohttpClient} + + client_session: Optional[ClientSession] + + def __init__(self, client_session: Optional[ClientSession] = None) -> None: + self.client_session = client_session + + def get_interface(self, interface_type: Type[AiohttpClient]) -> AiohttpClient: + if issubclass(interface_type, AiohttpClient): + assert self.client_session + return AiohttpClient(self.client_session, self) + raise ValueError(f"{interface_type} is not supported") + + async def prepare(self, _): + if not self.client_session: + self.client_session = ClientSession(loop=asyncio.get_running_loop()) + + async def cleanup(self, _): + assert self.client_session + await self.client_session.close() + + @property + def launch_component(self) -> LaunchComponent: + return LaunchComponent( + "http.universal_client", + set(), + prepare=self.prepare, + cleanup=self.cleanup, + ) diff --git a/src/graia/ariadne/message/chain.py b/src/graia/ariadne/message/chain.py index ddf7bb03..ab013e5d 100644 --- a/src/graia/ariadne/message/chain.py +++ b/src/graia/ariadne/message/chain.py @@ -8,7 +8,6 @@ Iterable, Iterator, List, - Optional, Tuple, Type, TypeVar, @@ -16,6 +15,8 @@ overload, ) +from typing_extensions import Self + from ..model import AriadneBaseModel from ..util import gen_subclass from .element import ( @@ -32,7 +33,7 @@ ) if TYPE_CHECKING: - from ..typing import MessageIndex, ReprArgs, Slice + from ..typing import ReprArgs Element_T = TypeVar("Element_T", bound=Element) @@ -51,11 +52,11 @@ class MessageChain(AriadneBaseModel): """底层元素列表""" @staticmethod - def build_chain(obj: List[Union[dict, Element, str]]) -> List[Element]: + def build_chain(obj: Iterable[Union[dict, Element, str]]) -> List[Element]: """内部接口, 会自动反序列化对象并生成. Args: - obj (List[T]): 需要反序列化的对象 + obj (Iterable[Union[dict, Element, str]]): 需要反序列化的对象 Returns: List[Element]: 内部承载有尽量有效的消息元素的列表 @@ -83,7 +84,7 @@ def parse_obj(cls: Type["MessageChain"], obj: List[Union[dict, Element]]) -> "Me Returns: MessageChain: 内部承载有尽量有效的消息元素的消息链 """ - return cls(__root__=cls.build_chain(obj)) + return cls(__root__=cls.build_chain(obj)) # type: ignore def __init__( self, @@ -91,9 +92,9 @@ def __init__( inline: bool = False, ) -> None: if not inline: - super().__init__(__root__=self.build_chain(__root__)) + super().__init__(__root__=self.build_chain(__root__)) # type: ignore else: - super().__init__(__root__=__root__) + super().__init__(__root__=__root__) # type: ignore @classmethod def create(cls, *elements: Union[Iterable[Element], Element, str]) -> "MessageChain": @@ -240,16 +241,14 @@ def __getitem__(self, item: Type[Element_T]) -> List[Element_T]: ... @overload - def __getitem__(self, item: int) -> Element_T: + def __getitem__(self, item: int) -> Element: ... @overload def __getitem__(self, item: slice) -> "MessageChain": ... - def __getitem__( - self, item: Union[Type[Element_T], slice, Tuple[Type[Element_T], int], int] - ) -> Union[List[Element_T], "MessageChain", Element]: + def __getitem__(self, item): """ 可通过切片取出子消息链, 或元素. @@ -264,74 +263,15 @@ def __getitem__( Returns: 索引结果. """ - if isinstance(item, slice): - return self.subchain(item) if isinstance(item, type) and issubclass(item, Element): return self.get(item) if isinstance(item, tuple): return self.get(*item) if isinstance(item, int): return self.__root__[item] - raise NotImplementedError(f"{item} is not allowed for item getting") - - def subchain( - self, - item: "Slice[Optional[MessageIndex], Optional[MessageIndex]]", - ignore_text_index: bool = False, - ) -> "MessageChain": - """对消息链执行分片操作 - - Args: - item (slice): 这个分片的 `start` 和 `end` 的 Type Annotation 都是 `Optional[MessageIndex]` - ignore_text_index (bool, optional): 在 TextIndex 取到错误位置时是否引发错误. - - Raises: - ValueError: TextIndex 取到了错误的位置 - - Returns: - MessageChain: 分片后得到的新消息链, 绝对是原消息链的子集. - """ - if isinstance(item.start, int) and isinstance(item.stop, int): + if isinstance(item, slice): return MessageChain(self.__root__[item], inline=True) - result = self.merge(copy=True).__root__[ - item.start[0] if item.start else None : item.stop[0] if item.stop else None - ] - if len(result) == 1: - text_start: Optional[int] = ( - item.start[1] - if item.start and len(item.start) >= 2 and isinstance(item.start[1], int) - else None - ) - text_stop: Optional[int] = ( - item.stop[1] if item.stop and len(item.stop) >= 2 and isinstance(item.stop[1], int) else None - ) - elem = result[0] - if not isinstance(elem, Plain): - if not ignore_text_index: - raise ValueError(f"the sliced chain does not starts with a Plain: {elem}") - else: - elem.text = elem.text[text_start:text_stop] - elif len(result) >= 2: - if item.start: - first_element = result[0] - if len(item.start) >= 2 and item.start[1] is not None: # text slice - if not isinstance(first_element, Plain): - if not ignore_text_index: - raise ValueError( - f"the sliced chain does not starts with a Plain: {first_element}" - ) - else: - first_element.text = first_element.text[item.start[1] :] - if item.stop: - last_element = result[-1] - if len(item.stop) >= 2 and item.stop[1] is not None: # text slice - if not isinstance(last_element, Plain): - if not ignore_text_index: - raise ValueError(f"the sliced chain does not ends with a Plain: {last_element}") - else: - last_element.text = last_element.text[: item.stop[1]] - - return MessageChain(result, inline=True) + raise NotImplementedError(f"{item} is not allowed for item getting") def findSubChain(self, subchain: Union["MessageChain", List[Element]]) -> List[int]: """判断消息链是否含有子链. 使用 KMP 算法. @@ -342,7 +282,11 @@ def findSubChain(self, subchain: Union["MessageChain", List[Element]]) -> List[i Returns: List[int]: 所有找到的下标. """ - pattern: List[Union[str, Element]] = subchain.unzip() + pattern: List[Union[str, Element]] = ( + subchain.unzip() + if isinstance(subchain, MessageChain) + else MessageChain(subchain, inline=True).unzip() + ) match_target: List[Union[str, Element]] = self.unzip() @@ -493,10 +437,10 @@ def merge(self, copy: bool = False) -> "MessageChain": result.append(Plain("".join(plain))) plain.clear() - if copy: - return MessageChain(result, inline=True) - self.__root__ = result - return self + if not copy: + self.__root__ = result + return self + return MessageChain(result, inline=True) def append(self, element: Union[Element, str], copy: bool = False) -> "MessageChain": """ @@ -569,10 +513,10 @@ def index(self, element_type: Type[Element_T]) -> Union[int, None]: Optional[int]: 元素下标, 若未找到则为 None. """ - for i, e in enumerate(self.__root__): - if isinstance(e, element_type): - return i - return None + return next( + (i for i, e in enumerate(self.__root__) if isinstance(e, element_type)), + None, + ) def count(self, element: Union[Type[Element_T], Element_T]) -> int: """ @@ -609,7 +553,7 @@ def __add__(self, content: Union["MessageChain", List[Element], Element, str]) - if isinstance(content, Element): content = [content] if isinstance(content, MessageChain): - content: List[Element] = content.__root__ + content = content.__root__ return MessageChain(self.__root__ + content, inline=True) def __radd__(self, content: Union["MessageChain", List[Element], Element, str]) -> "MessageChain": @@ -618,7 +562,7 @@ def __radd__(self, content: Union["MessageChain", List[Element], Element, str]) if isinstance(content, Element): content = [content] if isinstance(content, MessageChain): - content: List[Element] = content.__root__ + content = content.__root__ return MessageChain(content + self.__root__, inline=True) def __iadd__(self, content: Union["MessageChain", List[Element], Element, str]) -> "MessageChain": @@ -627,7 +571,7 @@ def __iadd__(self, content: Union["MessageChain", List[Element], Element, str]) if isinstance(content, Element): content = [content] if isinstance(content, MessageChain): - content: List[Element] = content.__root__ + content = content.__root__ self.__root__.extend(content) return self @@ -648,15 +592,15 @@ def asPersistentString( self, *, binary: bool = True, - include: Optional[Iterable[Type[Element]]] = (), - exclude: Optional[Iterable[Type[Element]]] = (), + include: Iterable[Type[Element]] = (), + exclude: Iterable[Type[Element]] = (), ) -> str: """转换为持久化字符串. Args: binary (bool, optional): 是否附带图片或声音的二进制. 默认为 True. - include (Optional[Iterable[Type[Element]]], optional): 筛选, 只包含本参数提供的元素类型. - exclude (Optional[Iterable[Type[Element]]], optional): 筛选, 排除本参数提供的元素类型. + include (Iterable[Type[Element]], optional): 筛选, 只包含本参数提供的元素类型. + exclude (Iterable[Type[Element]], optional): 筛选, 排除本参数提供的元素类型. Raises: ValueError: 同时提供 include 与 exclude @@ -683,7 +627,7 @@ def asPersistentString( string_list.append(i.asNoBinaryPersistentString()) return "".join(string_list) - async def download_binary(self) -> None: + async def download_binary(self) -> Self: """下载消息中所有的二进制数据并保存在元素实例内""" for elem in self.__root__: if isinstance(elem, MultimediaElement): @@ -699,8 +643,7 @@ def fromPersistentString(cls, string: str) -> "MessageChain": """ result = [] for match in re.split(r"(\[mirai:.+?\])", string): - mirai = re.fullmatch(r"\[mirai:(.+?)(:(.+?))\]", match) - if mirai: + if mirai := re.fullmatch(r"\[mirai:(.+?)(:(.+?))\]", match): j_string = mirai.group(3) element_cls = ELEMENT_MAPPING[mirai.group(1)] result.append(element_cls.parse_obj(json.loads(j_string))) @@ -725,7 +668,7 @@ def _to_mapping_str( Returns: Tuple[str, Dict[str, Element]]: 生成的映射字符串与映射字典的元组 """ - elem_mapping: Dict[int, Element] = {} + elem_mapping: Dict[str, Element] = {} elem_str_list: List[str] = [] for i, elem in enumerate(self.__root__): if not isinstance(elem, Plain): @@ -735,18 +678,17 @@ def _to_mapping_str( continue elem_mapping[str(i)] = elem elem_str_list.append(f"\x02{i}_{elem.type}\x03") + elif ( + remove_extra_space + and i # not first element + and isinstance( + self.__root__[i - 1], (Quote, At, AtAll) + ) # following elements which have an dumb trailing space + and elem.text.startswith(" ") # extra space (count >= 2) + ): + elem_str_list.append(elem.text[1:]) else: - if ( - remove_extra_space - and i # not first element - and isinstance( - self.__root__[i - 1], (Quote, At, AtAll) - ) # following elements which have an dumb trailing space - and elem.text.startswith(" ") # extra space (count >= 2) - ): - elem_str_list.append(elem.text[1:]) - else: - elem_str_list.append(elem.text) + elem_str_list.append(elem.text) return "".join(elem_str_list), elem_mapping @classmethod @@ -799,7 +741,7 @@ def removeprefix(self, prefix: str, *, copy: bool = True, skip_header: bool = Tr header = deepcopy(header) elements = deepcopy(elements) if not elements or not isinstance(elements[0], Plain): - return self if not copy else self.copy() + return self.copy() if copy else self if elements[0].text.startswith(prefix): elements[0].text = elements[0].text[len(prefix) :] if copy: @@ -817,12 +759,9 @@ def removesuffix(self, suffix: str, *, copy: bool = True) -> "MessageChain": Returns: MessageChain: 修改后的消息链, 若未移除则原样返回. """ - if copy: - elements = deepcopy(self.__root__) - else: - elements = self.__root__ + elements = deepcopy(self.__root__) if copy else self.__root__ if not elements or not isinstance(elements[-1], Plain): - return self if not copy else self.copy() + return self.copy() if copy else self last_elem: Plain = elements[-1] if last_elem.text.endswith(suffix): last_elem.text = last_elem.text[: -len(suffix)] @@ -831,7 +770,9 @@ def removesuffix(self, suffix: str, *, copy: bool = True) -> "MessageChain": self.__root__ = elements return self - def join(self, chains: Iterable["MessageChain"], merge: bool = True) -> "MessageChain": + def join( + self, *chains: Union["MessageChain", Iterable["MessageChain"]], merge: bool = True + ) -> "MessageChain": """将多个消息链连接起来, 并在其中插入自身. Args: @@ -842,11 +783,18 @@ def join(self, chains: Iterable["MessageChain"], merge: bool = True) -> "Message MessageChain: 连接后的消息链. """ result: List[Element] = [] + list_chains: List[MessageChain] = [] for chain in chains: + if isinstance(chain, MessageChain): + list_chains.append(chain) + else: + list_chains.extend(chain) + + for chain in list_chains: if chain is not chains[0]: result.extend(deepcopy(self.__root__)) result.extend(deepcopy(chain.__root__)) - return MessageChain(result, inline=True) if not merge else MessageChain(result, inline=True).merge() + return MessageChain(result, inline=True).merge() if merge else MessageChain(result, inline=True) def replace( self, diff --git a/src/graia/ariadne/message/commander/__init__.py b/src/graia/ariadne/message/commander/__init__.py index f4e6bb68..11a65d4d 100644 --- a/src/graia/ariadne/message/commander/__init__.py +++ b/src/graia/ariadne/message/commander/__init__.py @@ -6,6 +6,7 @@ from contextvars import ContextVar from dataclasses import dataclass from typing import ( + TYPE_CHECKING, Any, Callable, Dict, @@ -112,7 +113,7 @@ def __init__( self.default = default self.default_factory = default_factory self.param_name: str = "" - self.model: Optional[BaseModel] = None + self.model: Optional[Type[BaseModel]] = None def gen_model(self, validators: Iterable[Callable]) -> None: if self.model or self.type is _raw: @@ -250,15 +251,15 @@ def __init__( self, record: CommandPattern, callable: Callable, - dispatchers: Sequence[BaseDispatcher] = None, - decorators: Sequence[Decorator] = None, + dispatchers: Sequence[BaseDispatcher] = (), + decorators: Sequence[Decorator] = (), ): super().__init__( callable, [ ConstantDispatcher(commander_data_ctx), ContextDispatcher(), - *resolve_dispatchers_mixin(dispatchers or []), + *resolve_dispatchers_mixin(dispatchers), ], list(decorators), ) @@ -282,6 +283,9 @@ def get_data( param_result: Dict[str, Any] = {} for arg in set(self.pattern.arg_map.values()): value = arg.default_factory() + if TYPE_CHECKING: + assert arg.param_name + assert arg.model if arg.nargs: for param in arg.match_patterns: if param in arg_data: @@ -307,6 +311,8 @@ def get_data( param_result[arg.param_name] = arg.model(val=value).__dict__["val"] for ind, slot in self.pattern.slot_map.items(): + if TYPE_CHECKING: + assert slot.model if slot.param_name != self.pattern.wildcard: value = slot_data.get(ind, None) or slot.default_factory() param_result[slot.param_name] = slot.model(val=value).__dict__["val"] @@ -357,7 +363,7 @@ def add_type_cast(self, *caster: Callable): def command( self, command: str, - setting: Dict[str, Union[Slot, Arg]] = None, + setting: Optional[Dict[str, Union[Slot, Arg]]] = None, dispatchers: Sequence[BaseDispatcher] = (), decorators: Sequence[Decorator] = (), ) -> Callable[[T_Callable], T_Callable]: @@ -425,7 +431,8 @@ def command( eval(default or "...", *eval_ctx(1)), ) parsed_slot.param_name = name # assuming that param_name is consistent - slot_map[name] = parsed_slot | slot_map.get(name, {}) # parsed slot < provided slot + if name in slot_map: + slot_map[name] = parsed_slot | slot_map[name] # parsed slot < provided slot if wildcard: wildcard_slot_name = name token_list.append([name]) @@ -454,7 +461,8 @@ def __translate_obj(obj): if name in placeholder_set: parsed_slot = Slot(name, annotation, default) parsed_slot.param_name = name # assuming that param_name is consistent - slot_map[name] = parsed_slot | slot_map.get(name, {}) # parsed slot < provided slot + if name in slot_map: + slot_map[name] = parsed_slot | slot_map[name] # parsed slot < provided slot if default is not ...: assert all( [ diff --git a/src/graia/ariadne/message/element.py b/src/graia/ariadne/message/element.py index 4ba785af..38563fd1 100644 --- a/src/graia/ariadne/message/element.py +++ b/src/graia/ariadne/message/element.py @@ -5,7 +5,6 @@ from enum import Enum from io import BytesIO from json import dumps as j_dump -from os import PathLike from pathlib import Path from typing import TYPE_CHECKING, Iterable, List, NoReturn, Optional, Union @@ -34,7 +33,7 @@ class Element(AriadneBaseModel, abc.ABC): type (str): 元素类型 """ - type: str + type: str = "Unknown" """元素类型""" def __hash__(self): @@ -90,7 +89,7 @@ def __add__(self, content: Union["MessageChain", List["Element"], "Element", str if isinstance(content, Element): content = [content] if isinstance(content, MessageChain): - content: List[Element] = content.__root__ + content = content.__root__ return MessageChain(content + [self], inline=True) def __radd__(self, content: Union["MessageChain", List["Element"], "Element", str]) -> "MessageChain": @@ -101,7 +100,7 @@ def __radd__(self, content: Union["MessageChain", List["Element"], "Element", st if isinstance(content, Element): content = [content] if isinstance(content, MessageChain): - content: List[Element] = content.__root__ + content = content.__root__ return MessageChain([self] + content, inline=True) @@ -119,7 +118,7 @@ def __init__(self, text: str, **kwargs) -> None: Args: text (str): 元素所包含的文字 """ - super().__init__(text=text, **kwargs) + super().__init__(text=text, **kwargs) # type: ignore def asDisplay(self) -> str: return self.text @@ -151,9 +150,10 @@ async def fetchOriginal(self) -> "MessageChain": Returns: MessageChain: 原来的消息链. """ - from ..context import ariadne_ctx + from .. import get_running + from ..app import Ariadne - ariadne = ariadne_ctx.get() + ariadne = get_running(Ariadne) return (await ariadne.getMessageFromId(self.id)).messageChain @@ -219,13 +219,10 @@ def __eq__(self, other: "At"): return isinstance(other, At) and self.target == other.target def prepare(self) -> None: - try: - if upload_method_ctx.get() != UploadMethod.Group: - raise InvalidArgument( - f"you cannot use this element in this method: {upload_method_ctx.get().value}" - ) - except LookupError: - pass + if upload_method_ctx.get(None) != UploadMethod.Group: + raise InvalidArgument( + f"you cannot use this element in this method: {upload_method_ctx.get().value}" + ) def asDisplay(self) -> str: return f"@{self.display}" if self.display else f"@{self.target}" @@ -243,13 +240,10 @@ def asDisplay(self) -> str: return "@全体成员" def prepare(self) -> None: - try: - if upload_method_ctx.get() != UploadMethod.Group: - raise InvalidArgument( - f"you cannot use this element in this method: {upload_method_ctx.get().value}" - ) - except LookupError: - pass + if upload_method_ctx.get(None) != UploadMethod.Group: + raise InvalidArgument( + f"you cannot use this element in this method: {upload_method_ctx.get().value}" + ) class Face(Element): @@ -271,7 +265,7 @@ def __init__(self, id: int = ..., name: str = ..., **data) -> None: super().__init__(**data) def asDisplay(self) -> str: - return f"[表情: {self.name if self.name else self.faceId}]" + return f"[表情: {self.name or self.faceId}]" def __eq__(self, other) -> bool: return isinstance(other, Face) and (self.faceId == other.faceId or self.name == other.name) @@ -292,7 +286,7 @@ class MarketFace(Element): """QQ 表情名称""" def asDisplay(self) -> str: - return f"[商城表情: {self.name if self.name else self.faceId}]" + return f"[商城表情: {self.name or self.faceId}]" def __eq__(self, other) -> bool: return isinstance(other, MarketFace) and (self.faceId == other.faceId or self.name == other.name) @@ -307,7 +301,7 @@ class Xml(Element): """XML文本""" def __init__(self, xml: str, **_) -> None: - super().__init__(xml=xml) + super().__init__(xml=xml) # type: ignore def asDisplay(self) -> str: return "[XML消息]" @@ -324,7 +318,7 @@ class Json(Element): def __init__(self, json: Union[dict, str], **kwargs) -> None: if isinstance(json, dict): json = j_dump(json) - super().__init__(json=json, **kwargs) + super().__init__(json=json, **kwargs) # type: ignore def dict(self, *args, **kwargs): return super().dict(*args, **({**kwargs, "by_alias": True})) @@ -342,7 +336,7 @@ class App(Element): """App 内容""" def __init__(self, content: str, **_) -> None: - super().__init__(content=content) + super().__init__(content=content) # type: ignore def asDisplay(self) -> str: return "[APP消息]" @@ -409,7 +403,7 @@ class Poke(Element): """戳一戳使用的方法""" def __init__(self, name: PokeMethods, *_, **__) -> None: - super().__init__(name=name) + super().__init__(name=name) # type: ignore def asDisplay(self) -> str: return f"[戳一戳:{self.name}]" @@ -424,7 +418,7 @@ class Dice(Element): """骰子值""" def __init__(self, value: int, *_, **__) -> None: - super().__init__(value=value) + super().__init__(value=value) # type: ignore def asDisplay(self) -> str: return f"[骰子:{self.value}]" @@ -487,13 +481,13 @@ def __init__( **__, ) -> None: super().__init__( - kind=kind, - title=title, - summary=summary, - jumpUrl=jumpUrl, - pictureUrl=pictureUrl, - musicUrl=musicUrl, - brief=brief, + kind=kind, # type: ignore + title=title, # type: ignore + summary=summary, # type: ignore + jumpUrl=jumpUrl, # type: ignore + pictureUrl=pictureUrl, # type: ignore + musicUrl=musicUrl, # type: ignore + brief=brief, # type: ignore ) def asDisplay(self) -> str: @@ -647,21 +641,17 @@ def __init__( id: Optional[str] = None, url: Optional[str] = None, *, - path: Optional[Union[PathLike, str]] = None, + path: Optional[Union[Path, str]] = None, base64: Optional[str] = None, data_bytes: Union[None, bytes, BytesIO] = None, **kwargs, ) -> None: - data = {} - - for key, value in kwargs.items(): - if key.lower().endswith("id"): - data["id"] = value + data = {"id": value for key, value in kwargs.items() if key.lower().endswith("id")} if sum([bool(url), bool(path), bool(base64)]) > 1: raise ValueError("Too many binary initializers!") # Web initializer - data["id"] = data["id"] if "id" in data else id + data["id"] = data.get("id", id) data["url"] = url # Binary initializer if path: @@ -723,9 +713,7 @@ def asNoBinaryPersistentString(self) -> str: @property def uuid(self): """多媒体元素的 uuid, 即元素在 mirai 内部的标识""" - if self.id: - return self.id.split(".")[0].strip("/{}").lower() - return "" + return self.id.split(".")[0].strip("/{}").lower() if self.id else "" def __eq__(self, other: "MultimediaElement"): if self.__class__ is not other.__class__: @@ -809,7 +797,7 @@ def asDisplay(self) -> str: def _update_forward_refs(): """ - Inner function. + Internal function. Update the forward references. """ from ..model import BotMessage diff --git a/src/graia/ariadne/message/formatter.py b/src/graia/ariadne/message/formatter.py index 3e7f002c..3e0d8c20 100644 --- a/src/graia/ariadne/message/formatter.py +++ b/src/graia/ariadne/message/formatter.py @@ -15,7 +15,7 @@ def __init__(self, format_string: str) -> None: self.format_string = format_string def format( - self, *args: Union[Element, MessageChain, str], **kwargs: Union[Element, MessageChain, str] + self, *o_args: Union[Element, MessageChain, str], **o_kwargs: Union[Element, MessageChain, str] ) -> MessageChain: """通过初始化时传入的格式字符串 格式化消息链 @@ -26,17 +26,17 @@ def format( Returns: MessageChain: 格式化后的消息链 """ - args: List[MessageChain] = [MessageChain.create(e) for e in args] - kwargs: Dict[str, MessageChain] = {k: MessageChain.create(e) for k, e in kwargs.items()} + args: List[MessageChain] = [MessageChain.create(e) for e in o_args] + kwargs: Dict[str, MessageChain] = {k: MessageChain.create(e) for k, e in o_kwargs.items()} args_mapping: Dict[str, MessageChain] = { f"\x02{index}\x02": chain for index, chain in enumerate(args) } kwargs_mapping: Dict[str, MessageChain] = {f"\x03{key}\x03": chain for key, chain in kwargs.items()} - result = self.format_string.format(*args_mapping, **{k: f"\x03{k}\x03" for k in kwargs.keys()}) + result = self.format_string.format(*args_mapping, **{k: f"\x03{k}\x03" for k in kwargs}) - chain_list: List[MessageChain] = [] + chain_list: List[Union[MessageChain, Plain]] = [] for i in re.split("([\x02\x03][\\d\\w]+[\x02\x03])", result): if match := re.fullmatch("(?P
[\x02\x03])(?P\\w+)(?P=header)", i): diff --git a/src/graia/ariadne/message/parser/base.py b/src/graia/ariadne/message/parser/base.py index 6892093c..e5c071a2 100644 --- a/src/graia/ariadne/message/parser/base.py +++ b/src/graia/ariadne/message/parser/base.py @@ -7,6 +7,7 @@ from graia.broadcast.entities.decorator import Decorator from graia.broadcast.exceptions import ExecutionStop from graia.broadcast.interfaces.decorator import DecoratorInterface +from loguru import logger from ... import get_running from ...event.message import GroupMessage @@ -97,19 +98,20 @@ class MentionMe(ChainDecorator): async def decorate(self, chain: MessageChain, interface: DecoratorInterface) -> Optional[MessageChain]: ariadne = get_running() if isinstance(interface.event, GroupMessage): - name = (await ariadne.getMember(ariadne.account, interface.event.sender.group)).name + if not ariadne.account: + logger.warning("Unable to detect Ariadne's name because account is not set") + raise ExecutionStop + name = (await ariadne.getMember(interface.event.sender.group, ariadne.account)).name else: name = (await ariadne.getBotProfile()).nickname header = chain.include(Quote, Source) rest: MessageChain = chain.exclude(Quote, Source) first: Element = rest[0] result: Optional[MessageChain] = None - if rest and isinstance(first, Plain): - if first.asDisplay().startswith(name): - result = header + rest.removeprefix(name).removeprefix(" ") - if rest and isinstance(first, At): - if first.target == ariadne.account: - result = header + MessageChain(rest.__root__[1:], inline=True).removeprefix(" ") + if rest and isinstance(first, Plain) and first.asDisplay().startswith(name): + result = header + rest.removeprefix(name).removeprefix(" ") + if rest and isinstance(first, At) and first.target == ariadne.account: + result = header + MessageChain(rest.__root__[1:], inline=True).removeprefix(" ") if result is None: raise ExecutionStop @@ -130,12 +132,15 @@ async def decorate(self, chain: MessageChain, interface: DecoratorInterface) -> rest: MessageChain = chain.exclude(Quote, Source) first: Element = rest[0] result: Optional[MessageChain] = None - if rest and isinstance(first, Plain): - if isinstance(self.person, str) and first.asDisplay().startswith(self.person): - result = header + rest.removeprefix(self.person).removeprefix(" ") - if rest and isinstance(first, At): - if isinstance(self.person, int) and first.target == self.person: - result = header + MessageChain(rest.__root__[1:], inline=True).removeprefix(" ") + if ( + rest + and isinstance(first, Plain) + and isinstance(self.person, str) + and first.asDisplay().startswith(self.person) + ): + result = header + rest.removeprefix(self.person).removeprefix(" ") + if rest and isinstance(first, At) and isinstance(self.person, int) and first.target == self.person: + result = header + MessageChain(rest.__root__[1:], inline=True).removeprefix(" ") if result is None: raise ExecutionStop diff --git a/src/graia/ariadne/message/parser/twilight.py b/src/graia/ariadne/message/parser/twilight.py index 150d041d..fbb00d8b 100644 --- a/src/graia/ariadne/message/parser/twilight.py +++ b/src/graia/ariadne/message/parser/twilight.py @@ -83,7 +83,7 @@ def help(self, value: str) -> Self: self._help = value return self - def param(self, target: str) -> Self: + def param(self, target: Union[int, str]) -> Self: """设置匹配项的分派位置.""" self.dest = target return self @@ -305,7 +305,7 @@ class ArgumentMatch(Match, Generic[T]): if TYPE_CHECKING: @overload - def __init__( + def __init__( # type: ignore self, *pattern: str, action: Union[str, Type[Action]] = ..., @@ -408,23 +408,24 @@ def __init__(self, match_result: Dict[Union[int, str], MatchResult]): self.res = match_result @overload - def __getitem__(self, key: Union[int, str]) -> MatchResult: + def __getitem__(self, item: Union[int, str]) -> MatchResult: ... @overload - def __getitem__(self, key: Type[int]) -> List[MatchResult]: + def __getitem__(self, item: Type[int]) -> List[MatchResult]: ... @overload - def __getitem__(self, key: Type[str]) -> Dict[str, MatchResult]: + def __getitem__(self, item: Type[str]) -> Dict[str, MatchResult]: ... def __getitem__(self, item: Union[int, str, Type[int], Type[str]]): + if not isinstance(item, type): + return self.get(item) if item is int: return [v for k, v in self.res.items() if isinstance(k, int)] elif item is str: return {k: v for k, v in self.res.items() if isinstance(k, str)} - return self.get(item) def get(self, item: Union[int, str]) -> MatchResult: return self.res[item] @@ -443,7 +444,7 @@ def __init__(self, *root: Union[Iterable[Match], Match]): self._parser = TwilightParser(prog="", add_help=False) self._dest_map: Dict[str, ArgumentMatch] = {} self._group_map: Dict[int, RegexMatch] = {} - self.dispatch_ref: Dict[str, Match] = {} + self.dispatch_ref: Dict[Union[int, str], Match] = {} self.match_ref: DefaultDict[Type[Match], List[Match]] = DefaultDict(list) regex_str_list: List[str] = [] @@ -462,9 +463,12 @@ def __init__(self, *root: Union[Iterable[Match], Match]): elif isinstance(m, ArgumentMatch): self.match_ref[ArgumentMatch].append(m) - if "action" in m.arg_data and "type" in m.arg_data: - if not self._parser.accept_type(m.arg_data["action"]): - del m.arg_data["type"] + if ( + "action" in m.arg_data + and "type" in m.arg_data + and not self._parser.accept_type(m.arg_data["action"]) + ): + del m.arg_data["type"] action = self._parser.add_argument(*m.pattern, **m.arg_data) if m.dest: self._dest_map[action.dest] = m @@ -476,36 +480,39 @@ def __init__(self, *root: Union[Iterable[Match], Match]): self._regex_pattern: re.Pattern = re.compile("".join(regex_str_list)) - def match(self, arguments: List[str], elem_mapping: Dict[str, Element]) -> Dict[str, MatchResult]: + def match( + self, arguments: List[str], elem_mapping: Dict[str, Element] + ) -> Dict[Union[int, str], MatchResult]: """匹配参数 Args: arguments (List[str]): 参数列表 elem_mapping (Dict[str, Element]): 元素映射 Returns: - Dict[str, MatchResult]: 匹配结果 + Dict[Union[int, str], MatchResult]: 匹配结果 """ - result: Dict[str, MatchResult] = {} + result: Dict[Union[int, str], MatchResult] = {} if self._dest_map: namespace, arguments = self._parser.parse_known_args(arguments) nbsp_dict: Dict[str, Any] = namespace.__dict__ for k, v in self._dest_map.items(): res = nbsp_dict.get(k, Unmatched) result[v.dest] = MatchResult(res is not Unmatched, v, res) - if total_match := self._regex_pattern.fullmatch(" ".join(arguments)): - for index, match in self._group_map.items(): - group: Optional[str] = total_match.group(index) - if group is not None: - if isinstance(match, ElementMatch): - res = elem_mapping[group[1:-1].split("_")[0]] - else: - res = MessageChain._from_mapping_string(group, elem_mapping) - else: - res = None - if match.dest: - result[match.dest] = MatchResult(group is not None, match, res) - else: + if not (total_match := self._regex_pattern.fullmatch(" ".join(arguments))): raise ValueError(f"{' '.join(arguments)} not matching {self._regex_pattern.pattern}") + for index, match in self._group_map.items(): + group: Optional[str] = total_match.group(index) + if group is None: + res = None + else: + res = ( + elem_mapping[group[1:-1].split("_")[0]] + if isinstance(match, ElementMatch) + else MessageChain._from_mapping_string(group, elem_mapping) + ) + + if match.dest: + result[match.dest] = MatchResult(group is not None, match, res) return result def get_help( @@ -566,6 +573,7 @@ def __str__(self) -> str: class _TwilightLocalStorage(TypedDict): result: Sparkle + twilight: "Twilight" class Twilight(Generic[T_Sparkle], BaseDispatcher): @@ -620,7 +628,7 @@ def from_command( # ANCHOR: Sparkle: From command Returns: Twilight: 生成的 Twilight. """ - extra_args = extra_args or {} + extra_args = extra_args or [] match: List[RegexMatch] = [] for t_type, token_list in tokenize_command(command): @@ -636,9 +644,7 @@ def from_command( # ANCHOR: Sparkle: From command if match: match[-1].space_policy = SpacePolicy.NOSPACE - if isinstance(extra_args, dict): - return cls(match, extra_args) - return cls(match + extra_args) + return cls(*match, *extra_args) def get_help( self, @@ -703,6 +709,7 @@ class ResultValue(Decorator): pre = True + @staticmethod async def target(i: DecoratorInterface): sparkle: Sparkle = i.local_storage["result"] res = sparkle.res.get(i.name, None) @@ -719,14 +726,14 @@ class Help(Decorator): if TYPE_CHECKING: @overload - def __init__( + def __init__( # type: ignore self, usage: str = "", description: str = "", epilog: str = "", dest: bool = True, sep: str = " -> ", - ) -> str: + ) -> None: ... def __init__(self, *args, **kwargs) -> None: diff --git a/src/graia/ariadne/message/parser/util.py b/src/graia/ariadne/message/parser/util.py index ccf6da51..103bb687 100644 --- a/src/graia/ariadne/message/parser/util.py +++ b/src/graia/ariadne/message/parser/util.py @@ -4,7 +4,7 @@ import inspect import re from contextvars import ContextVar -from typing import Dict, List, Literal, NoReturn, Tuple, Type, Union +from typing import Any, Dict, List, NoReturn, Tuple, Type, Union from graia.ariadne.message.element import Element @@ -63,10 +63,7 @@ class CommandToken(enum.Enum): ANNOTATED = "ANNOTATED" -CommandTokenTuple = Union[ - Tuple[Literal[CommandToken.PARAM], List[Union[int, str]]], - Tuple[Literal[CommandToken.ANNOTATED, CommandToken.CHOICE, CommandToken.PARAM], List[str]], -] +CommandTokenTuple = Tuple[CommandToken, list[Any]] def tokenize_command(string: str) -> List[CommandTokenTuple]: @@ -244,7 +241,7 @@ class ElementType: def __init__(self, pattern: Type[Element_T]): self.regex = re.compile(f"\x02(\\d+)_{pattern.__fields__['type'].default}\x03") - def __call__(self, string: str) -> MessageChain: + def __call__(self, string: str) -> Element: if not self.regex.fullmatch(string): raise ValueError(f"{string} not matching {self.regex.pattern}") return MessageChain._from_mapping_string(string, elem_mapping_ctx.get())[0] diff --git a/src/graia/ariadne/model.py b/src/graia/ariadne/model.py index c660ec38..c7fa832c 100644 --- a/src/graia/ariadne/model.py +++ b/src/graia/ariadne/model.py @@ -8,8 +8,7 @@ from graia.broadcast.entities.listener import Listener from loguru import logger -from pydantic import BaseModel, Field, validator -from pydantic.main import BaseConfig, Extra +from pydantic import BaseConfig, BaseModel, Extra, Field, validator from pydantic.networks import AnyHttpUrl from typing_extensions import Literal from yarl import URL @@ -38,9 +37,7 @@ class DatetimeEncoder(json.JSONEncoder): """可以编码 datetime 的 JSONEncoder""" def default(self, o): - if isinstance(o, datetime): - return int(o.timestamp()) - return super().default(o) + return int(o.timestamp()) if isinstance(o, datetime) else super().default(o) class AriadneBaseModel(BaseModel): @@ -48,23 +45,23 @@ class AriadneBaseModel(BaseModel): Ariadne 一切数据模型的基类. """ - def dict( + def dict( # type: ignore self, *, - include: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, - exclude: Union["AbstractSetIntStr", "MappingIntStrAny"] = None, + include: Union[None, "AbstractSetIntStr", "MappingIntStrAny"] = None, + exclude: Union[None, "AbstractSetIntStr", "MappingIntStrAny"] = None, by_alias: bool = False, - skip_defaults: bool = None, + skip_defaults: bool = False, exclude_unset: bool = False, exclude_defaults: bool = False, exclude_none: bool = False, ) -> "DictStrAny": - _, _ = by_alias, exclude_none + _, *_ = by_alias, exclude_none, skip_defaults return super().dict( - include=include, - exclude=exclude, + include=include, # type: ignore + exclude=exclude, # type: ignore by_alias=True, - skip_defaults=skip_defaults, + skip_defaults=False, exclude_unset=exclude_unset, exclude_defaults=exclude_defaults, exclude_none=True, @@ -77,6 +74,7 @@ class Config(BaseConfig): json_encoders = { datetime: datetime_encoder, } + arbitrary_types_allowed = True @dataclass @@ -245,7 +243,12 @@ def __init__( *, single_mode: bool = False, ) -> None: - super().__init__(host=host, account=account, verify_key=verify_key, single_mode=single_mode) + super().__init__( + host=host, # type: ignore + account=account, # type: ignore + verify_key=verify_key, # type: ignore + single_mode=single_mode, # type: ignore + ) def url_gen(self, route: str) -> str: """生成 route 对应的 API URI @@ -256,6 +259,8 @@ def url_gen(self, route: str) -> str: Returns: str: 对应的 API URI """ + if self.host is None: + raise ValueError("Remote host is unset") return str(URL(self.host) / route) @@ -332,9 +337,10 @@ async def getAvatar(self, cover: Optional[int] = None) -> bytes: from . import get_running cover = (cover or 0) + 1 - return await ( - await get_running().adapter.session.get(f"https://p.qlogo.cn/gh/{self.id}/{self.id}_{cover}/") - ).content.read() + session = get_running().adapter.session + if not session: + raise RuntimeError("No running ClientSession") + return await (await session.get(f"https://p.qlogo.cn/gh/{self.id}/{self.id}_{cover}/")).content.read() class Member(AriadneBaseModel): @@ -427,9 +433,10 @@ async def getAvatar(self, size: Literal[640, 140] = 640) -> bytes: """ from . import get_running - return await ( - await get_running().adapter.session.get(f"https://q.qlogo.cn/g?b=qq&nk={self.id}&s={size}") - ).content.read() + session = get_running().adapter.session + if not session: + raise RuntimeError("No running ClientSession") + return await (await session.get(f"https://q.qlogo.cn/g?b=qq&nk={self.id}&s={size}")).content.read() class Friend(AriadneBaseModel): @@ -471,9 +478,10 @@ async def getAvatar(self, size: Literal[640, 140] = 640) -> bytes: """ from . import get_running - return await ( - await get_running().adapter.session.get(f"https://q.qlogo.cn/g?b=qq&nk={self.id}&s={size}") - ).content.read() + session = get_running().adapter.session + if not session: + raise RuntimeError("No running ClientSession") + return await (await session.get(f"https://q.qlogo.cn/g?b=qq&nk={self.id}&s={size}")).content.read() class Stranger(AriadneBaseModel): @@ -505,9 +513,10 @@ async def getAvatar(self, size: Literal[640, 140] = 640) -> bytes: """ from . import get_running - return await ( - await get_running().adapter.session.get(f"https://q.qlogo.cn/g?b=qq&nk={self.id}&s={size}") - ).content.read() + session = get_running().adapter.session + if not session: + raise RuntimeError("No running ClientSession") + return await (await session.get(f"https://q.qlogo.cn/g?b=qq&nk={self.id}&s={size}")).content.read() class GroupConfig(AriadneBaseModel): @@ -538,7 +547,7 @@ class MemberInfo(AriadneBaseModel): name: str = "" """昵称, 与 nickname不同""" - specialTitle: str = "" + specialTitle: Optional[str] = "" """特殊头衔""" diff --git a/src/graia/ariadne/service.py b/src/graia/ariadne/service.py new file mode 100644 index 00000000..8ff5da3d --- /dev/null +++ b/src/graia/ariadne/service.py @@ -0,0 +1,164 @@ +import asyncio +from typing import TYPE_CHECKING, Dict, List, Optional, Type, TypedDict, overload + +from graia.amnesia.interface import ExportInterface +from graia.amnesia.launch import LaunchComponent, LaunchManager +from graia.amnesia.service import Service +from graia.amnesia.status import Status +from graia.broadcast import Broadcast + +from .config import ( + ConnectionConfig, + ElizabethHttpClientConfig, + ElizabethHttpServerConfig, + ElizabethWebsocketClientConfig, + ElizabethWebsocketServerConfig, +) +from .connection import ( + ConnectionInfo, + ElizabethConnection, + HttpClientConnection, + HttpServerConnection, + WebsocketClientConnection, + WebsocketServerConnection, +) +from .io import AiohttpClient + +if TYPE_CHECKING: + from graia.amnesia.builtins.starlette import StarletteServer +else: + try: + from graia.amnesia.builtins.starlette import StarletteServer + except ImportError: + + class StarletteServer(ExportInterface): + pass + + +class ConnDict(TypedDict, total=False): + action: ElizabethConnection + event: ElizabethConnection + info: ConnectionInfo + + +class ElizabethInterface(ExportInterface): + service: "ElizabethService" + + def __init__(self, service: "ElizabethService") -> None: + self.service = service + + +class ElizabethService(Service): + supported_interface_types = {ElizabethInterface} + + connection_dict: Dict[int, ConnDict] + status: Dict[int, Status] + + def __init__(self, broadcast: Broadcast, conn_configs: List[ConnectionConfig]) -> None: + self.broadcast: Broadcast = broadcast + self.conn_configs: List[ConnectionConfig] = conn_configs + if not conn_configs: + raise ValueError("No accounts configured") + self.connection_dict = {} + + def set_status(self, account: int, available: bool, stage: str) -> None: + status = self.status.setdefault(account, Status(available, stage)) + status.available = available + status.stage = stage + + if TYPE_CHECKING: + + @overload + def get_status(self) -> Dict[int, Status]: + ... + + @overload + def get_status(self, account: int) -> Status: + ... + + def get_status(self, account: Optional[int] = None): + if not account: + return self.status + if account not in self.status: + raise ValueError(f"Account {account} not found") + return self.status[account].frame() + + def get_interface(self, interface_type: Type[ElizabethInterface]) -> ElizabethInterface: + return interface_type(self) + + def retrieve_maintask(self) -> Optional[asyncio.Task]: + for task in asyncio.all_tasks(): + if task.get_name() == "elizabeth.miraijvm-httpapi.service" and not task.done(): + return task + + async def connect(self, connection: ElizabethConnection) -> None: + while self.retrieve_maintask(): + await connection.maintask() + + async def prepare(self, mgr: LaunchManager) -> None: + for conf in self.conn_configs: + self.set_status(conf.account, False, "non-configured") + conn = self.connection_dict.setdefault(conf.account, {}) + if "event" in conn and "action" in conn: + raise ValueError(f"Account {conf.account} already fully configured") + if ( + not isinstance(conf, ElizabethHttpClientConfig) + and "event" in conn + and not isinstance(conn["event"], HttpClientConnection) + ): + raise ValueError( + f"Account {conf.account} already configured event method: {type(conn['event'])}" + ) + conn.setdefault("info", {}) + assert "info" in conn + if isinstance(conf, ElizabethWebsocketClientConfig): # top priority for event + ws_client = mgr.get_interface(AiohttpClient) + ws_client_conn = WebsocketClientConnection(ws_client, conf, conn["info"], self) + conn["event"] = ws_client_conn + if "action" not in conn: + conn["action"] = ws_client_conn + elif isinstance(conf, ElizabethWebsocketServerConfig): # top priority for event + ws_server = mgr.get_interface(StarletteServer) + ws_server_conn = WebsocketServerConnection(ws_server, conf, conn["info"], self) + conn["event"] = ws_server_conn + if "action" not in conn: + conn["action"] = ws_server_conn + elif isinstance(conf, ElizabethHttpClientConfig): # top priority for action + http_client = mgr.get_interface(AiohttpClient) + http_client_conn = HttpClientConnection(http_client, conf, conn["info"], self) + if "action" in conn and isinstance(conn["action"], HttpClientConnection): + raise ValueError( + f"Account {conf.account} already configured action method: {type(conn['action'])}" + ) + conn["event"] = http_client_conn + if "action" not in conn: + conn["action"] = http_client_conn + elif isinstance(conf, ElizabethHttpServerConfig): # top priority for event + http_server = mgr.get_interface(StarletteServer) + http_server_conn = HttpServerConnection(http_server, conf, conn["info"], self) + conn["event"] = http_server_conn + else: + raise ValueError(f"Unsupported connection config type: {type(conf)}") + + async def maintask(self, _) -> None: + conn_tasks: List[asyncio.Task[None]] = [] + for account, conn in self.connection_dict.items(): + if "event" not in conn and "action" not in conn: + raise ValueError(f"Account {account} not configured event and action method") + if "event" not in conn: + raise ValueError(f"Account {account} not configured event method") + if "action" not in conn: + raise ValueError(f"Account {account} not configured action method") + self.set_status(account, False, "not-connected") + conn["action"].connected = conn["event"].connected # Ensure event connection + conn_tasks.append(asyncio.create_task(self.connect(conn["event"]))) + await asyncio.gather(*conn_tasks) + + @property + def launch_component(self) -> LaunchComponent: + return LaunchComponent( + "elizabeth.miraijvm-httpapi.service", + {"http.universal_client"}, + mainline=self.maintask, + prepare=self.prepare, + ) diff --git a/src/graia/ariadne/typing.py b/src/graia/ariadne/typing.py index db74f6d8..04543fcb 100644 --- a/src/graia/ariadne/typing.py +++ b/src/graia/ariadne/typing.py @@ -1,4 +1,6 @@ """Ariadne 的类型标注""" + +import contextlib import typing from typing import ( TYPE_CHECKING, @@ -15,7 +17,7 @@ Union, ) -from typing_extensions import Annotated, ParamSpec, TypeGuard +from typing_extensions import ParamSpec, Protocol, runtime_checkable if TYPE_CHECKING: from .message.chain import MessageChain @@ -24,6 +26,7 @@ P = ParamSpec("P") R = TypeVar("R") T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) T_start = TypeVar("T_start") T_stop = TypeVar("T_stop") @@ -71,7 +74,7 @@ class SendMessageAction(Generic[T, R]): """表示 SendMessage 的 action""" @staticmethod - async def param(item: SendMessageDict, /) -> SendMessageDict: + async def param(item: SendMessageDict) -> SendMessageDict: """传入 SendMessageDict 作为参数, 传出 SendMessageDict 作为结果 Args: @@ -83,7 +86,7 @@ async def param(item: SendMessageDict, /) -> SendMessageDict: return item @staticmethod - async def result(item: "BotMessage", /) -> R: + async def result(item: "BotMessage") -> R: """处理返回结果 Args: @@ -92,10 +95,10 @@ async def result(item: "BotMessage", /) -> R: Returns: R: 要实际由 SendMessage 返回的数据 """ - return item + return item # type: ignore @staticmethod - async def exception(item: SendMessageException, /) -> T: + async def exception(item: SendMessageException) -> Optional[T]: """发生异常时进行处理,可以选择不返回而是直接引发异常 Args: @@ -107,7 +110,19 @@ async def exception(item: SendMessageException, /) -> T: raise item -def generic_issubclass(cls: type, par: Annotated[T, Union[type, Any, Tuple[type, ...]]]) -> TypeGuard[T]: +@runtime_checkable +class SendMessageActionProtocol(Protocol, Generic[T_co]): + async def param(self, item: SendMessageDict) -> SendMessageDict: + ... + + async def result(self, item: "BotMessage") -> T_co: + ... + + async def exception(self, item: SendMessageException) -> Any: + ... + + +def generic_issubclass(cls: type, par: Union[type, Any, Tuple[type, ...]]) -> bool: """检查 cls 是否是 args 中的一个子类, 支持泛型, Any, Union Args: @@ -119,7 +134,7 @@ def generic_issubclass(cls: type, par: Annotated[T, Union[type, Any, Tuple[type, """ if par is Any: return True - try: + with contextlib.suppress(TypeError): if isinstance(par, type): return issubclass(cls, par) if isinstance(par, tuple): @@ -131,8 +146,6 @@ def generic_issubclass(cls: type, par: Annotated[T, Union[type, Any, Tuple[type, return any(generic_issubclass(cls, p) for p in par.__constraints__) if par.__bound__: return generic_issubclass(cls, par.__bound__) - except TypeError: - pass return False @@ -148,7 +161,7 @@ def generic_isinstance(obj: Any, par: Union[type, Any, Tuple[type, ...]]) -> boo """ if par is Any: return True - try: + with contextlib.suppress(TypeError): if isinstance(par, type): return isinstance(obj, par) if isinstance(par, tuple): @@ -160,6 +173,4 @@ def generic_isinstance(obj: Any, par: Union[type, Any, Tuple[type, ...]]) -> boo return any(generic_isinstance(obj, p) for p in par.__constraints__) if par.__bound__: return generic_isinstance(obj, par.__bound__) - except TypeError: - pass return False diff --git a/src/graia/ariadne/util/__init__.py b/src/graia/ariadne/util/__init__.py index 13851e4e..9465c66e 100644 --- a/src/graia/ariadne/util/__init__.py +++ b/src/graia/ariadne/util/__init__.py @@ -3,6 +3,7 @@ # Utility Layout import asyncio +import collections import functools import inspect import logging @@ -18,9 +19,13 @@ AsyncIterator, Callable, Coroutine, + Deque, Dict, Generator, + Generic, + Iterable, List, + Literal, Optional, Tuple, Type, @@ -75,14 +80,14 @@ def emit(self, record): # Find caller from where originated the logged message frame, depth = logging.currentframe(), 2 - while frame.f_code.co_filename == logging.__file__: + while frame and frame.f_code.co_filename == logging.__file__: frame = frame.f_back depth += 1 logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage()) -def inject_loguru_traceback(loop: AbstractEventLoop = None): +def inject_loguru_traceback(loop: Optional[AbstractEventLoop] = None): """使用 loguru 模块替换默认的 traceback.print_exception 与 sys.excepthook""" traceback.print_exception = loguru_excepthook sys.excepthook = loguru_excepthook @@ -105,8 +110,8 @@ def __init__( callable: Callable, namespace: Namespace, listening_events: List[Type[Dispatchable]], - inline_dispatchers: List[T_Dispatcher] = None, - decorators: List[Decorator] = None, + inline_dispatchers: Optional[List[T_Dispatcher]] = None, + decorators: Optional[List[Decorator]] = None, priority: int = 16, ) -> None: events = [] @@ -117,8 +122,8 @@ def __init__( callable, namespace, events, - inline_dispatchers=inline_dispatchers, - decorators=decorators, + inline_dispatchers=inline_dispatchers or [], + decorators=decorators or [], priority=priority, ) @@ -134,6 +139,40 @@ def __init__( pass +class AsyncSignal(Generic[T]): + """模拟 asyncio.Event, 但是支持 sig.wait(Hashable)""" + + def __init__(self, value: T = None) -> None: + self._waiters: Dict[T, Deque[asyncio.Future]] = {} + self._value: T = value + self._loop = asyncio.get_running_loop() + + def __repr__(self) -> str: + waiter_str = f", waiters: {len(self._waiters)}" if self._waiters else "" + return f"" + + def value(self) -> T: + return self._value + + def set(self, value: T) -> None: + self._value = value + + waiter_deque = self._waiters.setdefault(value, collections.deque()) + + for fut in waiter_deque: + if not fut.done(): + fut.set_result(True) + waiter_deque.clear() + + async def wait(self, value: T) -> Literal[True]: + if self._value == value: + return True + + fut = self._loop.create_future() + self._waiters.setdefault(value, collections.deque()).append(fut) + return await fut + + def app_ctx_manager(func: Callable[P, R]) -> Callable[P, R]: """包装声明需要在 Ariadne Context 中执行的函数 @@ -145,13 +184,13 @@ def app_ctx_manager(func: Callable[P, R]) -> Callable[P, R]: """ @functools.wraps(func) - async def wrapper(*args: P.args, **kwargs: P.kwargs): + def wrapper(*args: P.args, **kwargs: P.kwargs): from ..context import enter_context sys.audit("CallAriadneAPI", func.__name__, args, kwargs) - with enter_context(app=args[0]): - return await func(*args, **kwargs) + with enter_context(app=args[0]): # type: ignore + return func(*args, **kwargs) return wrapper @@ -245,8 +284,7 @@ async def yield_with_timeout( if not done: continue for t in done: - res = await t - yield res + yield await t if last_tsk: for tsk in last_tsk: tsk.cancel() @@ -275,11 +313,11 @@ def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: return out_wrapper -def resolve_dispatchers_mixin(dispatchers: List[T_Dispatcher]) -> List[T_Dispatcher]: +def resolve_dispatchers_mixin(dispatchers: Iterable[T_Dispatcher]) -> List[T_Dispatcher]: """解析 dispatcher list 的 mixin Args: - dispatchers (List[T_Dispatcher]): dispatcher 列表 + dispatchers (Iterable[T_Dispatcher]): dispatcher 列表 Returns: List[T_Dispatcher]: 解析后的 dispatcher 列表 @@ -339,7 +377,7 @@ def signal_handler(callback: Callable[[], None], one_time: bool = True) -> None: handler = signal.getsignal(sig) def handler_wrapper(sig_num, frame): - if handler: + if callable(handler): handler(sig_num, frame) callback() if one_time: @@ -351,9 +389,8 @@ def handler_wrapper(sig_num, frame): def get_cls(obj) -> Optional[type]: if cls := typing.get_origin(obj): return cls - else: - if isinstance(obj, type): - return obj + if isinstance(obj, type): + return obj # Import layout diff --git a/src/graia/ariadne/util/async_exec.py b/src/graia/ariadne/util/async_exec.py index c63b937a..dacc8e17 100644 --- a/src/graia/ariadne/util/async_exec.py +++ b/src/graia/ariadne/util/async_exec.py @@ -6,7 +6,7 @@ import multiprocessing from asyncio.events import AbstractEventLoop from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor -from typing import Any, Awaitable, Callable, ClassVar, Dict, Tuple +from typing import Any, Awaitable, Callable, ClassVar, Dict, Optional, Tuple from ..typing import P, R @@ -38,9 +38,9 @@ class ParallelExecutor: def __init__( self, - loop: AbstractEventLoop = None, - max_thread: int = None, - max_process: int = None, + loop: Optional[AbstractEventLoop] = None, + max_thread: Optional[int] = None, + max_process: Optional[int] = None, ): """初始化并行执行器. @@ -58,7 +58,7 @@ def __init__( self.bind_loop(loop or asyncio.get_running_loop()) @classmethod - def get(cls, loop: AbstractEventLoop = None) -> "ParallelExecutor": + def get(cls, loop: Optional[AbstractEventLoop] = None) -> "ParallelExecutor": """获取 ParallelExecutor 实例 Args: @@ -134,7 +134,7 @@ def to_thread(self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Returns: Future[R]: 返回结果. 需要被异步等待. """ - return asyncio.get_running_loop().run_in_executor( + return asyncio.get_running_loop().run_in_executor( # type: ignore self.thread_exec, ParallelExecutor.run_func_static, func, @@ -153,7 +153,7 @@ def to_process(self, func: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Returns: Future[R]: 返回结果. 需要被异步等待. """ - return asyncio.get_running_loop().run_in_executor( + return asyncio.get_running_loop().run_in_executor( # type: ignore self.proc_exec, ParallelExecutor.run_func_static, func, diff --git a/src/graia/ariadne/util/cooldown.py b/src/graia/ariadne/util/cooldown.py index 71ea27e5..49b70256 100644 --- a/src/graia/ariadne/util/cooldown.py +++ b/src/graia/ariadne/util/cooldown.py @@ -6,6 +6,7 @@ from types import TracebackType from typing import ( Any, + AsyncGenerator, Awaitable, Callable, Dict, @@ -67,13 +68,14 @@ async def get(self, target: int, type: Type[T_Time]) -> Tuple[Optional[T_Time], if builtins.type(None) in typing.get_args(type) and delta.total_seconds() <= 0: return None, satisfied if generic_issubclass(datetime, type): - return next_exec_time, satisfied + return next_exec_time, satisfied # type: ignore if generic_issubclass(timedelta, type): - return delta, satisfied + return delta, satisfied # type: ignore if generic_issubclass(float, type): - return delta.total_seconds(), satisfied + return delta.total_seconds(), satisfied # type: ignore if generic_issubclass(int, type): - return int(delta.total_seconds()), satisfied + return int(delta.total_seconds()), satisfied # type: ignore + return None, satisfied async def set(self, target: int) -> None: self.source[target] = datetime.now() + self.interval @@ -85,14 +87,13 @@ async def beforeExecution(self, interface: DispatcherInterface[MessageEvent]): next_exec_time: datetime = self.source.get(sender_id, current_time) delta: timedelta = next_exec_time - current_time satisfied: bool = delta <= timedelta(seconds=0) - if not satisfied: - if self.stop_on_cooldown: - param_dict: Dict[str, Any] = {} - for name, anno, _ in self.override_signature: - param_dict[name] = await interface.lookup_param(name, anno, None) - res = self.override_condition(**param_dict) - if not ((await res) if inspect.isawaitable(res) else res): - raise ExecutionStop + if not satisfied and self.stop_on_cooldown: + param_dict: Dict[str, Any] = {} + for name, anno, _ in self.override_signature: + param_dict[name] = await interface.lookup_param(name, anno, None) + res = self.override_condition(**param_dict) + if not ((await res) if inspect.isawaitable(res) else res): + raise ExecutionStop interface.local_storage["next_exec_time"] = next_exec_time interface.local_storage["delta"] = delta @@ -123,7 +124,9 @@ async def afterDispatch( await self.set(sender_id) @contextlib.asynccontextmanager - async def trigger(self, target: int, type: Type[T_Time] = datetime) -> Tuple[Optional[T_Time], bool]: + async def trigger( + self, target: int, type: Type[T_Time] = datetime + ) -> AsyncGenerator[Tuple[Optional[T_Time], bool], None]: value, satisfied = await self.get(target, type) try: yield value, satisfied diff --git a/src/graia/ariadne/util/send.py b/src/graia/ariadne/util/send.py index 7efcff53..417d3a08 100644 --- a/src/graia/ariadne/util/send.py +++ b/src/graia/ariadne/util/send.py @@ -17,7 +17,7 @@ class Bypass(SendMessageAction): """ @staticmethod - async def exception(item: Exc_T, /) -> Exc_T: + async def exception(item: Exc_T) -> Exc_T: return item @@ -33,8 +33,8 @@ class Ignore(SendMessageAction): """忽略错误的 SendMessage action (发生 Exception 时 返回 None)""" @staticmethod - async def exception(_: Exc_T, /) -> Exc_T: - return None + async def exception(_) -> None: + return # ANCHOR: safe @@ -55,15 +55,15 @@ def __init__(self, ignore: bool = False) -> None: @overload @staticmethod - async def exception(item: Exc_T, /) -> BotMessage: + async def exception(item) -> BotMessage: ... @overload - async def exception(self, item: Exc_T, /) -> BotMessage: + async def exception(self, item) -> BotMessage: ... @staticmethod - async def _handle(item: Exc_T, ignore: bool): + async def _handle(item: SendMessageException, ignore: bool): from ..message.chain import MessageChain from ..message.element import At, AtAll, Forward, MultimediaElement, Plain, Poke @@ -77,14 +77,24 @@ def convert(msg_chain: MessageChain, type) -> None: for type in [AtAll, At, Poke, Forward, MultimediaElement]: convert(chain, type) - val = await ariadne.sendMessage(**item.send_data, action=Ignore) + val = await ariadne.sendMessage(**item.send_data, action=Ignore) # type: ignore if val is not None: return val if not ignore: raise item - async def exception(s: Union["Safe", Exc_T], i: Optional[Exc_T] = None): - if isinstance(s, Safe): + @overload + @staticmethod + async def exception(s, i): + ... + + @overload + async def exception(s, i): + ... + + async def exception(s: Union["Safe", Exc_T], i: Optional[Exc_T] = None): # type: ignore + if not isinstance(s, Safe): + return await Safe._handle(s, True) + if i: return await Safe._handle(i, s.ignore) - return await Safe._handle(s, True)