diff --git a/poetry.lock b/poetry.lock index 8d7cb30..48ee5b9 100644 --- a/poetry.lock +++ b/poetry.lock @@ -440,10 +440,6 @@ files = [ {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a37b8f0391212d29b3a91a799c8e4a2855e0576911cdfb2515487e30e322253d"}, {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e84799f09591700a4154154cab9787452925578841a94321d5ee8fb9a9a328f0"}, {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f66b5337fa213f1da0d9000bc8dc0cb5b896b726eefd9c6046f699b169c41b9e"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5dab0844f2cf82be357a0eb11a9087f70c5430b2c241493fc122bb6f2bb0917c"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e4fe605b917c70283db7dfe5ada75e04561479075761a0b3866c081d035b01c1"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1e9a65b5736232e7a7f91ff3d02277f11d339bf34099a56cdab6a8b3410a02b2"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:58d4b711689366d4a03ac7957ab8c28890415e267f9b6589969e74b6e42225ec"}, {file = "Brotli-1.1.0-cp310-cp310-win32.whl", hash = "sha256:be36e3d172dc816333f33520154d708a2657ea63762ec16b62ece02ab5e4daf2"}, {file = "Brotli-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:0c6244521dda65ea562d5a69b9a26120769b7a9fb3db2fe9545935ed6735b128"}, {file = "Brotli-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a3daabb76a78f829cafc365531c972016e4aa8d5b4bf60660ad8ecee19df7ccc"}, @@ -456,14 +452,8 @@ files = [ {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:19c116e796420b0cee3da1ccec3b764ed2952ccfcc298b55a10e5610ad7885f9"}, {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:510b5b1bfbe20e1a7b3baf5fed9e9451873559a976c1a78eebaa3b86c57b4265"}, {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a1fd8a29719ccce974d523580987b7f8229aeace506952fa9ce1d53a033873c8"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c247dd99d39e0338a604f8c2b3bc7061d5c2e9e2ac7ba9cc1be5a69cb6cd832f"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1b2c248cd517c222d89e74669a4adfa5577e06ab68771a529060cf5a156e9757"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:2a24c50840d89ded6c9a8fdc7b6ed3692ed4e86f1c4a4a938e1e92def92933e0"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f31859074d57b4639318523d6ffdca586ace54271a73ad23ad021acd807eb14b"}, {file = "Brotli-1.1.0-cp311-cp311-win32.whl", hash = "sha256:39da8adedf6942d76dc3e46653e52df937a3c4d6d18fdc94a7c29d263b1f5b50"}, {file = "Brotli-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:aac0411d20e345dc0920bdec5548e438e999ff68d77564d5e9463a7ca9d3e7b1"}, - {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:32d95b80260d79926f5fab3c41701dbb818fde1c9da590e77e571eefd14abe28"}, - {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:b760c65308ff1e462f65d69c12e4ae085cff3b332d894637f6273a12a482d09f"}, {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:316cc9b17edf613ac76b1f1f305d2a748f1b976b033b049a6ecdfd5612c70409"}, {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:caf9ee9a5775f3111642d33b86237b05808dafcd6268faa492250e9b78046eb2"}, {file = "Brotli-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70051525001750221daa10907c77830bc889cb6d865cc0b813d9db7fefc21451"}, @@ -474,24 +464,8 @@ files = [ {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4093c631e96fdd49e0377a9c167bfd75b6d0bad2ace734c6eb20b348bc3ea180"}, {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e4c4629ddad63006efa0ef968c8e4751c5868ff0b1c5c40f76524e894c50248"}, {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:861bf317735688269936f755fa136a99d1ed526883859f86e41a5d43c61d8966"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:87a3044c3a35055527ac75e419dfa9f4f3667a1e887ee80360589eb8c90aabb9"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c5529b34c1c9d937168297f2c1fde7ebe9ebdd5e121297ff9c043bdb2ae3d6fb"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ca63e1890ede90b2e4454f9a65135a4d387a4585ff8282bb72964fab893f2111"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e79e6520141d792237c70bcd7a3b122d00f2613769ae0cb61c52e89fd3443839"}, {file = "Brotli-1.1.0-cp312-cp312-win32.whl", hash = "sha256:5f4d5ea15c9382135076d2fb28dde923352fe02951e66935a9efaac8f10e81b0"}, {file = "Brotli-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:906bc3a79de8c4ae5b86d3d75a8b77e44404b0f4261714306e3ad248d8ab0951"}, - {file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8bf32b98b75c13ec7cf774164172683d6e7891088f6316e54425fde1efc276d5"}, - {file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7bc37c4d6b87fb1017ea28c9508b36bbcb0c3d18b4260fcdf08b200c74a6aee8"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c0ef38c7a7014ffac184db9e04debe495d317cc9c6fb10071f7fefd93100a4f"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91d7cc2a76b5567591d12c01f019dd7afce6ba8cba6571187e21e2fc418ae648"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a93dde851926f4f2678e704fadeb39e16c35d8baebd5252c9fd94ce8ce68c4a0"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f0db75f47be8b8abc8d9e31bc7aad0547ca26f24a54e6fd10231d623f183d089"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6967ced6730aed543b8673008b5a391c3b1076d834ca438bbd70635c73775368"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7eedaa5d036d9336c95915035fb57422054014ebdeb6f3b42eac809928e40d0c"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d487f5432bf35b60ed625d7e1b448e2dc855422e87469e3f450aa5552b0eb284"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:832436e59afb93e1836081a20f324cb185836c617659b07b129141a8426973c7"}, - {file = "Brotli-1.1.0-cp313-cp313-win32.whl", hash = "sha256:43395e90523f9c23a3d5bdf004733246fba087f2948f87ab28015f12359ca6a0"}, - {file = "Brotli-1.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:9011560a466d2eb3f5a6e4929cf4a09be405c64154e12df0dd72713f6500e32b"}, {file = "Brotli-1.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a090ca607cbb6a34b0391776f0cb48062081f5f60ddcce5d11838e67a01928d1"}, {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de9d02f5bda03d27ede52e8cfe7b865b066fa49258cbab568720aa5be80a47d"}, {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2333e30a5e00fe0fe55903c8832e08ee9c3b1382aacf4db26664a16528d51b4b"}, @@ -501,10 +475,6 @@ files = [ {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:fd5f17ff8f14003595ab414e45fce13d073e0762394f957182e69035c9f3d7c2"}, {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:069a121ac97412d1fe506da790b3e69f52254b9df4eb665cd42460c837193354"}, {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:e93dfc1a1165e385cc8239fab7c036fb2cd8093728cbd85097b284d7b99249a2"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:aea440a510e14e818e67bfc4027880e2fb500c2ccb20ab21c7a7c8b5b4703d75"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_i686.whl", hash = "sha256:6974f52a02321b36847cd19d1b8e381bf39939c21efd6ee2fc13a28b0d99348c"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_ppc64le.whl", hash = "sha256:a7e53012d2853a07a4a79c00643832161a910674a893d296c9f1259859a289d2"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:d7702622a8b40c49bffb46e1e3ba2e81268d5c04a34f460978c6b5517a34dd52"}, {file = "Brotli-1.1.0-cp36-cp36m-win32.whl", hash = "sha256:a599669fd7c47233438a56936988a2478685e74854088ef5293802123b5b2460"}, {file = "Brotli-1.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:d143fd47fad1db3d7c27a1b1d66162e855b5d50a89666af46e1679c496e8e579"}, {file = "Brotli-1.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:11d00ed0a83fa22d29bc6b64ef636c4552ebafcef57154b4ddd132f5638fbd1c"}, @@ -516,10 +486,6 @@ files = [ {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:919e32f147ae93a09fe064d77d5ebf4e35502a8df75c29fb05788528e330fe74"}, {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:23032ae55523cc7bccb4f6a0bf368cd25ad9bcdcc1990b64a647e7bbcce9cb5b"}, {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:224e57f6eac61cc449f498cc5f0e1725ba2071a3d4f48d5d9dffba42db196438"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:cb1dac1770878ade83f2ccdf7d25e494f05c9165f5246b46a621cc849341dc01"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:3ee8a80d67a4334482d9712b8e83ca6b1d9bc7e351931252ebef5d8f7335a547"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:5e55da2c8724191e5b557f8e18943b1b4839b8efc3ef60d65985bcf6f587dd38"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:d342778ef319e1026af243ed0a07c97acf3bad33b9f29e7ae6a1f68fd083e90c"}, {file = "Brotli-1.1.0-cp37-cp37m-win32.whl", hash = "sha256:587ca6d3cef6e4e868102672d3bd9dc9698c309ba56d41c2b9c85bbb903cdb95"}, {file = "Brotli-1.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2954c1c23f81c2eaf0b0717d9380bd348578a94161a65b3a2afc62c86467dd68"}, {file = "Brotli-1.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:efa8b278894b14d6da122a72fefcebc28445f2d3f880ac59d46c90f4c13be9a3"}, @@ -532,10 +498,6 @@ files = [ {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ab4fbee0b2d9098c74f3057b2bc055a8bd92ccf02f65944a241b4349229185a"}, {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:141bd4d93984070e097521ed07e2575b46f817d08f9fa42b16b9b5f27b5ac088"}, {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fce1473f3ccc4187f75b4690cfc922628aed4d3dd013d047f95a9b3919a86596"}, - {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d2b35ca2c7f81d173d2fadc2f4f31e88cc5f7a39ae5b6db5513cf3383b0e0ec7"}, - {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:af6fa6817889314555aede9a919612b23739395ce767fe7fcbea9a80bf140fe5"}, - {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:2feb1d960f760a575dbc5ab3b1c00504b24caaf6986e2dc2b01c09c87866a943"}, - {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4410f84b33374409552ac9b6903507cdb31cd30d2501fc5ca13d18f73548444a"}, {file = "Brotli-1.1.0-cp38-cp38-win32.whl", hash = "sha256:db85ecf4e609a48f4b29055f1e144231b90edc90af7481aa731ba2d059226b1b"}, {file = "Brotli-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3d7954194c36e304e1523f55d7042c59dc53ec20dd4e9ea9d151f1b62b4415c0"}, {file = "Brotli-1.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5fb2ce4b8045c78ebbc7b8f3c15062e435d47e7393cc57c25115cfd49883747a"}, @@ -548,10 +510,6 @@ files = [ {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:949f3b7c29912693cee0afcf09acd6ebc04c57af949d9bf77d6101ebb61e388c"}, {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:89f4988c7203739d48c6f806f1e87a1d96e0806d44f0fba61dba81392c9e474d"}, {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:de6551e370ef19f8de1807d0a9aa2cdfdce2e85ce88b122fe9f6b2b076837e59"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0737ddb3068957cf1b054899b0883830bb1fec522ec76b1098f9b6e0f02d9419"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4f3607b129417e111e30637af1b56f24f7a49e64763253bbc275c75fa887d4b2"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:6c6e0c425f22c1c719c42670d561ad682f7bfeeef918edea971a79ac5252437f"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:494994f807ba0b92092a163a0a283961369a65f6cbe01e8891132b7a320e61eb"}, {file = "Brotli-1.1.0-cp39-cp39-win32.whl", hash = "sha256:f0d8a7a6b5983c2496e364b969f0e526647a06b075d034f3297dc66f3b360c64"}, {file = "Brotli-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdad5b9014d83ca68c25d2e9444e28e967ef16e80f6b436918c700c117a85467"}, {file = "Brotli-1.1.0.tar.gz", hash = "sha256:81de08ac11bcb85841e440c13611c00b67d3bf82698314928d0b676362546724"}, @@ -1050,19 +1008,19 @@ test = ["pytest (>=6)"] [[package]] name = "fastapi" -version = "0.115.8" +version = "0.115.12" description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" optional = false python-versions = ">=3.8" groups = ["main"] files = [ - {file = "fastapi-0.115.8-py3-none-any.whl", hash = "sha256:753a96dd7e036b34eeef8babdfcfe3f28ff79648f86551eb36bfc1b0bf4a8cbf"}, - {file = "fastapi-0.115.8.tar.gz", hash = "sha256:0ce9111231720190473e222cdf0f07f7206ad7e53ea02beb1d2dc36e2f0741e9"}, + {file = "fastapi-0.115.12-py3-none-any.whl", hash = "sha256:e94613d6c05e27be7ffebdd6ea5f388112e5e430c8f7d6494a9d1d88d43e814d"}, + {file = "fastapi-0.115.12.tar.gz", hash = "sha256:1e2c2a2646905f9e83d32f04a3f86aff4a286669c6c950ca95b5fd68c2602681"}, ] [package.dependencies] pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.1.0 || >2.1.0,<3.0.0" -starlette = ">=0.40.0,<0.46.0" +starlette = ">=0.40.0,<0.47.0" typing-extensions = ">=4.8.0" [package.extras] @@ -3469,7 +3427,6 @@ files = [ {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:bb89f0a835bcfc1d42ccd5f41f04870c1b936d8507c6df12b7737febc40f0909"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:f0c2d907a1e102526dd2986df638343388b94c33860ff3bbe1384130828714b1"}, {file = "psycopg2_binary-2.9.10-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:f8157bed2f51db683f31306aa497311b560f2265998122abe1dce6428bd86567"}, - {file = "psycopg2_binary-2.9.10-cp313-cp313-win_amd64.whl", hash = "sha256:27422aa5f11fbcd9b18da48373eb67081243662f9b46e6fd07c3eb46e4535142"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-macosx_12_0_x86_64.whl", hash = "sha256:eb09aa7f9cecb45027683bb55aebaaf45a0df8bf6de68801a6afdc7947bb09d4"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b73d6d7f0ccdad7bc43e6d34273f70d587ef62f824d7261c4ae9b8b1b6af90e8"}, {file = "psycopg2_binary-2.9.10-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ce5ab4bf46a211a8e924d307c1b1fcda82368586a19d0a24f8ae166f5c784864"}, @@ -4895,14 +4852,14 @@ sqlcipher = ["sqlcipher3_binary"] [[package]] name = "starlette" -version = "0.45.3" +version = "0.46.2" description = "The little ASGI library that shines." optional = false python-versions = ">=3.9" groups = ["main"] files = [ - {file = "starlette-0.45.3-py3-none-any.whl", hash = "sha256:dfb6d332576f136ec740296c7e8bb8c8a7125044e7c6da30744718880cdd059d"}, - {file = "starlette-0.45.3.tar.gz", hash = "sha256:2cbcba2a75806f8a41c722141486f37c28e30a0921c5f6fe4346cb0dcee1302f"}, + {file = "starlette-0.46.2-py3-none-any.whl", hash = "sha256:595633ce89f8ffa71a015caed34a5b2dc1c0cdb3f0f1fbd1e69339cf2abeec35"}, + {file = "starlette-0.46.2.tar.gz", hash = "sha256:7f7361f34eed179294600af672f565727419830b54b7b084efe44bb82d2fccd5"}, ] [package.dependencies] diff --git a/src/app/api/api_v1/endpoints/search.py b/src/app/api/api_v1/endpoints/search.py index dd5a085..fe13d74 100644 --- a/src/app/api/api_v1/endpoints/search.py +++ b/src/app/api/api_v1/endpoints/search.py @@ -1,18 +1,22 @@ -from typing import List, Optional, Union - from fastapi import APIRouter, Depends, Response +from qdrant_client.models import ScoredPoint from sqlalchemy.sql import select from src.app.models.db_models import CorpusEmbedding -from src.app.models.documents import Collection_schema, Document -from src.app.models.search import EnhancedSearchQuery, SearchFilter, SearchQuery -from src.app.services.exceptions import EmptyQueryError, bad_request -from src.app.services.search import SearchService -from src.app.services.search_helpers import ( - search_all_base, - search_items_base, - search_multi_inputs, +from src.app.models.documents import Collection_schema +from src.app.models.search import ( + EnhancedSearchQuery, + SDGFilter, + SearchMethods, + SearchQuery, +) +from src.app.services.exceptions import ( + CollectionNotFoundError, + EmptyQueryError, + bad_request, ) +from src.app.services.search import SearchService +from src.app.services.search_helpers import search_multi_inputs from src.app.services.sql_db import session_maker from src.app.utils.logger import logger as logger_utils @@ -25,7 +29,7 @@ def get_params( body: SearchQuery, nb_results: int = 30, - subject: Optional[str] = None, + subject: str | None = None, influence_factor: float = 2, relevance_factor: float = 1, concatenate: bool = True, @@ -52,7 +56,7 @@ def get_params( "/collections", summary="get all collections", description="Get all collections available in the database", - response_model=List[Collection_schema], + response_model=list[Collection_schema], ) async def get_corpus(): statement = select( @@ -73,59 +77,72 @@ async def get_corpus(): @router.post( - "/collections/{collection_query}", - summary="search items in a specific collection", - description="Search items in a specific collection", - response_model=Union[List[Document], None], + "/collections/{collection}", + summary="search documents in a specific collection", + description="Search documents in a specific collection", + response_model=list[ScoredPoint] | str | None, ) -async def search_items( - query: Optional[str] = None, - collection_query: str = "conversation", +async def search_doc_by_collection( + response: Response, + query: str, + collection: str = "conversation", nb_results: int = 10, - sdg_filter: Optional[SearchFilter] = None, + sdg_filter: SDGFilter | None = None, ): if not query: e = EmptyQueryError() return bad_request(message=e.message, msg_code=e.msg_code) - return await search_items_base( + qp = EnhancedSearchQuery( query=query, - collection_query=collection_query, nb_results=nb_results, - sdg_filter=sdg_filter, - search_func=sp.search_group_by_document, + corpora=(collection,), + sdg_filter=sdg_filter.sdg_filter if sdg_filter else None, ) + try: + res = await sp.search_handler(qp=qp, method=SearchMethods.BY_DOCUMENT) + + if not res: + response.status_code = 206 + return [] + + return res + except CollectionNotFoundError as e: + response.status_code = 404 + return e.message + @router.post( "/by_slices", summary="search all slices", description="Search slices in all collections or in collections specified", - response_model=Union[List[Document], None], + response_model=list[ScoredPoint] | None | str, ) async def search_all_slices_by_lang( response: Response, qp: EnhancedSearchQuery = Depends(get_params), ): - res = await search_all_base( - response=response, - qp=qp, - search_func=sp.search, - ) + try: - if not res: - logger.error("No results found") - response.status_code = 404 - return None + res = await sp.search_handler(qp=qp, method=SearchMethods.BY_SLICES) - return res + if not res: + logger.debug("No results found") + response.status_code = 404 + return [] + + return res + except CollectionNotFoundError as e: + response.status_code = 404 + return e.message @router.post( "/multiple_by_slices", summary="search all slices", description="Search slices in all collections or in collections specified", - response_model=Union[List[Document], None], + response_model=list[ScoredPoint] | None, ) async def multi_search_all_slices_by_lang( response: Response, @@ -135,17 +152,14 @@ async def multi_search_all_slices_by_lang( qp.query = [qp.query] results = await search_multi_inputs( - response=response, - nb_results=qp.nb_results, - sdg_filter=qp.sdg_filter, - collections=qp.corpora, - inputs=qp.query, - callback_function=sp.search, + qp=qp, + callback_function=sp.search_handler, ) if not results: logger.error("No results found") + # todo switch to 204 no content response.status_code = 404 - return None + return [] return results @@ -153,22 +167,22 @@ async def multi_search_all_slices_by_lang( @router.post( "/by_document", summary="search all documents", - description="Search documents in all collections or in collections specified", - response_model=Union[List[Document], None], + description="Search by documents, returns only one result by document id", + response_model=list[ScoredPoint] | None | str, ) async def search_all( response: Response, qp: EnhancedSearchQuery = Depends(get_params), ): - res = await search_all_base( - response=response, - qp=qp, - search_func=sp.search_group_by_document, - ) - - if not res: - logger.error("No results found") + try: + res = await sp.search_handler(qp=qp, method=SearchMethods.BY_DOCUMENT) + + if not res: + logger.error("No results found") + response.status_code = 404 + return [] + except CollectionNotFoundError as e: response.status_code = 404 - return None + return e.message return res diff --git a/src/app/api/api_v1/endpoints/tutor.py b/src/app/api/api_v1/endpoints/tutor.py index a2dd8a5..7779e63 100644 --- a/src/app/api/api_v1/endpoints/tutor.py +++ b/src/app/api/api_v1/endpoints/tutor.py @@ -3,6 +3,7 @@ from fastapi import APIRouter, File, HTTPException, Response, UploadFile from src.app.api.dependencies import get_settings +from src.app.models.search import EnhancedSearchQuery from src.app.services.abst_chat import AbstractChat, ChatFactory from src.app.services.exceptions import NoResultsError from src.app.services.search import SearchService @@ -90,13 +91,17 @@ async def tutor_search( inputs = [doc.summary for doc in themes_extracted.extracts] # type: ignore try: - search_results = await search_multi_inputs( - response=response, - inputs=inputs, + qp = EnhancedSearchQuery( + query=inputs, nb_results=5, sdg_filter=None, - collections=None, - callback_function=sp.search, + corpora=None, + ) + + search_results = await search_multi_inputs( + response=response, + qp=qp, + callback_function=sp.search_handler, ) except NoResultsError as e: response.status_code = 404 @@ -120,8 +125,6 @@ async def tutor_search( documents=search_results, ) - # TODO: handle duplicates - return resp diff --git a/src/app/models/collections.py b/src/app/models/collections.py index 2ec66f2..bd3ea4c 100644 --- a/src/app/models/collections.py +++ b/src/app/models/collections.py @@ -11,7 +11,6 @@ class Collection_schema(BaseModel): class Collection(NamedTuple): - name: str lang: str model: str - alias: str + name: str diff --git a/src/app/models/search.py b/src/app/models/search.py index 2a1e546..2bef81d 100644 --- a/src/app/models/search.py +++ b/src/app/models/search.py @@ -1,7 +1,14 @@ +from enum import StrEnum + from pydantic import BaseModel, Field +from qdrant_client.models import FieldCondition, Filter, MatchAny + +from src.app.utils.logger import logger as logger_utils +logger = logger_utils(__name__) -class SearchFilter(BaseModel): + +class SDGFilter(BaseModel): sdg_filter: list[int] | None = Field( None, max_length=17, @@ -11,12 +18,12 @@ class SearchFilter(BaseModel): ) -class SearchQuery(SearchFilter): +class SearchQuery(SDGFilter): query: str | list[str] | None corpora: list[str] | None = None -class EnhancedSearchQuery(SearchFilter): +class EnhancedSearchQuery(SDGFilter): query: str | list[str] corpora: tuple[str, ...] | None = None nb_results: int = 30 @@ -24,3 +31,38 @@ class EnhancedSearchQuery(SearchFilter): influence_factor: float = 2 relevance_factor: float = 1 concatenate: bool = True + + +class SearchFilters(BaseModel): + slice_sdg: list[int] | None + document_corpus: tuple[str, ...] | list[str] | None + + def build_filters(self) -> Filter | None: + if not self.slice_sdg and not self.document_corpus: + return None + + filters = { + "slice_sdg": self.slice_sdg, + "document_corpus": self.document_corpus, + } + + qdrant_filter = [] + for key, values in filters.items(): + if not values: + continue + + qdrant_filter.append( + FieldCondition( + key=key, + match=MatchAny(any=values), + ) + ) + + logger.debug("build_filters=%s", qdrant_filter) + + return Filter(must=qdrant_filter) + + +class SearchMethods(StrEnum): + BY_SLICES = "by_slices" + BY_DOCUMENT = "by_document" diff --git a/src/app/services/exceptions.py b/src/app/services/exceptions.py index 459db34..ce10ffb 100644 --- a/src/app/services/exceptions.py +++ b/src/app/services/exceptions.py @@ -1,5 +1,3 @@ -from typing import Optional - from fastapi import HTTPException, Response, status from src.app.utils.logger import logger as logger_utils @@ -152,7 +150,7 @@ def __init__( super().__init__(self.message, self.msg_code) -def handle_error(response: Optional[Response], exc: Exception) -> None: +def handle_error(exc: Exception, response: Response | None = None) -> None: if isinstance(exc, PartialResponseResultError): if response: response.status_code = status.HTTP_206_PARTIAL_CONTENT diff --git a/src/app/services/search.py b/src/app/services/search.py index c32e9e6..2071ab5 100644 --- a/src/app/services/search.py +++ b/src/app/services/search.py @@ -1,8 +1,7 @@ -import asyncio import json import time from functools import cache -from typing import Any, Callable, Dict, List, Optional, Tuple, cast +from typing import Tuple, cast import numpy as np from qdrant_client import AsyncQdrantClient @@ -14,11 +13,9 @@ from src.app.api.dependencies import get_settings from src.app.models.collections import Collection -from src.app.services.exceptions import ( - CollectionNotFoundError, - ModelNotFoundError, - PartialResponseResultError, -) +from src.app.models.search import EnhancedSearchQuery, SearchFilters, SearchMethods +from src.app.services.exceptions import CollectionNotFoundError, ModelNotFoundError +from src.app.services.helpers import detect_language_from_entry from src.app.utils.decorators import log_time_and_error from src.app.utils.logger import logger as logger_utils @@ -65,98 +62,67 @@ def __init__(self): "document_details.readability", "document_details.source", ] + self.col_prefix = "collection_welearn_" @log_time_and_error async def get_collections(self) -> Tuple[str, ...]: - aliases = await self.client.get_aliases() - self.collections = tuple(alias.alias_name for alias in aliases.aliases) + collections = await self.client.get_collections() + self.collections = tuple( + collection.name for collection in collections.collections + ) logger.info("collections=%s", self.collections) return self.collections @log_time_and_error - async def get_collections_aliases_by_language( - self, lang: str, collections: Optional[Tuple[str, ...]] = None - ) -> List[str]: - col_to_iter = self.collections or await self.get_collections() - if collections is None: - return [ - collection for collection in col_to_iter if f"_{lang}_" in collection - ] - - cols = [ - collection - for cur_col in collections - for collection in col_to_iter - if collection.startswith(cur_col) and f"_{lang}_" in collection - ] - if not cols: - raise CollectionNotFoundError( - message=f"No collection found for this language {lang} and collections {collections}" - ) - return cols - - @log_time_and_error - async def get_collection_alias(self, collection_name: str, lang: str) -> str: - col_to_iter = self.collections or await self.get_collections() - if len(collection_name.split("_")) == 1: - collection_name = f"{collection_name}_{lang}" + async def get_collection_by_language(self, lang: str) -> Collection: + collections = self.collections or await self.get_collections() collection = next( ( - c - for c in col_to_iter - if c.startswith(collection_name) or c == collection_name + collection + for collection in collections + if collection.startswith(f"{self.col_prefix}{lang}") ), None, ) + if not collection: raise CollectionNotFoundError( - message=f"Collection {collection_name} not found" + message=f"No collection found for this language {lang}" ) - logger.debug( - "method=get_collection_alias collection_name=%s collection_alias=%s", - collection_name, - collection, - ) - return collection + return self._get_info_from_collection_name(collection) - def _get_info_from_collection_alias(self, collection_alias: str) -> Collection: - name, lang, model = collection_alias.split("_") - corpus = Collection(name=name, lang=lang, model=model, alias=collection_alias) - logger.debug( - "info_from_collection collection=%s name=%s lang=%s model=%s", - collection_alias, - name, - lang, - model, - ) - return corpus + def _get_info_from_collection_name(self, collection_name: str) -> Collection: + lang, model = collection_name.replace(self.col_prefix, "").split("_") + return Collection(lang=lang, model=model, name=collection_name) - def get_collection_dict_with_embed( + def get_query_embed( self, - collection_alias: str, + model: str, query: str, - subject_vector: Optional[List[float]] = None, + subject_vector: list[float] | None = None, subject_influence_factor: float = 1.0, - ) -> Dict[str, Any]: - col_info = self._get_info_from_collection_alias(collection_alias)._asdict() - col_info["embed"] = self.embed_query(query, col_info["model"]) + ) -> np.ndarray: + embedding = self._embed_query(query, model) + if subject_vector: + embedding = embedding + [ + subject_influence_factor * vec for vec in subject_vector + ] + logger.debug( - "Adding subject vector collection=%s influence_factor=%s", - collection_alias, + "Adding subject vector influence_factor=%s", subject_influence_factor, ) - col_info["embed"] = col_info["embed"] + [ - subject_influence_factor * vec for vec in subject_vector - ] - return col_info + + return embedding @cache - def get_model(self, curr_model: str) -> SentenceTransformer: + def _get_model(self, curr_model: str) -> SentenceTransformer: try: time_start = time.time() + # TODO: path should be an env variable model = SentenceTransformer(f"../models/embedding/{curr_model}/") time_end = time.time() logger.info( @@ -170,10 +136,10 @@ def get_model(self, curr_model: str) -> SentenceTransformer: return model @cache - def embed_query(self, search_input: str, curr_model: str) -> np.ndarray: + def _embed_query(self, search_input: str, curr_model: str) -> np.ndarray: logger.debug("Creating embeddings model=%s", curr_model) time_start = time.time() - model = self.get_model(curr_model) + model = self._get_model(curr_model) try: embeddings = model.encode(sentences=search_input) except Exception as ex: @@ -188,32 +154,61 @@ def embed_query(self, search_input: str, curr_model: str) -> np.ndarray: ) return cast(np.ndarray, embeddings) - def build_filters( - self, filters: Optional[List[int]] = None - ) -> Optional[qdrant_models.Filter]: - if filters is None: - return None + async def search_handler( + self, qp: EnhancedSearchQuery, method: SearchMethods = SearchMethods.BY_SLICES + ) -> list[http_models.ScoredPoint]: + assert isinstance(qp.query, str) + + lang = detect_language_from_entry(qp.query) + collection = await self.get_collection_by_language(lang) + subject_vector = get_subject_vector(qp.subject) + embedding = self.get_query_embed( + model=collection.model, + subject_vector=subject_vector, + query=qp.query, + subject_influence_factor=qp.influence_factor, + ) - qdrant_filter: List[qdrant_models.Condition] = [ - qdrant_models.FieldCondition( - key="document_sdg", match=qdrant_models.MatchValue(value=filt) + filters = SearchFilters( + slice_sdg=qp.sdg_filter, document_corpus=qp.corpora + ).build_filters() + data = [] + if method == "by_slices": + data = await self.search( + collection_info=collection.name, + embedding=embedding, + filters=filters, + nb_results=qp.nb_results, ) - for filt in filters - ] - return qdrant_models.Filter(should=qdrant_filter) + elif method == "by_document": + data = await self.search_group_by_document( + collection_info=collection.name, + embedding=embedding, + filters=filters, + nb_results=qp.nb_results, + ) + else: + raise ValueError(f"Unknown search method: {method}") + + sorted_data = sort_slices_using_mmr(data, theta=qp.relevance_factor) + + if qp.concatenate: + sorted_data = concatenate_same_doc_id_slices(sorted_data) + + return sorted_data @log_time_and_error async def search_group_by_document( self, collection_info: str, embedding: np.ndarray, - filters: Optional[List[int]] = None, + filters: qdrant_models.Filter | None = None, nb_results: int = 100, - ) -> Optional[List[http_models.ScoredPoint]]: + ) -> list[http_models.ScoredPoint]: logger.debug("search_group_by_document collection=%s", collection_info) try: resp = await self.client.search_groups( - query_filter=self.build_filters(filters), + query_filter=filters, collection_name=collection_info, query_vector=embedding, limit=nb_results, @@ -234,12 +229,12 @@ async def search( self, collection_info: str, embedding: np.ndarray, - filters: Optional[List[int]] = None, + filters: qdrant_models.Filter | None = None, nb_results: int = 100, - ) -> List[http_models.ScoredPoint]: + ) -> list[http_models.ScoredPoint]: try: resp = await self.client.search( - query_filter=self.build_filters(filters), + query_filter=filters, collection_name=collection_info, query_vector=embedding, limit=nb_results, @@ -255,47 +250,13 @@ async def search( return resp -def search_items_method( - callback_function: Callable, - nb_results: int, - embedding: np.ndarray, - collection: str, - filters: Optional[List[int]] = None, -) -> Optional[List[http_models.ScoredPoint]]: - return callback_function( - collection_info=collection, - embedding=embedding, - filters=filters, - nb_results=nb_results, - ) - - -@log_time_and_error -async def parallel_search( - callback_function: Callable, - nb_results: int, - collections: List[Dict[str, Any]], - sdg_filter: Optional[List[int]] = None, -) -> List[http_models.ScoredPoint]: - tasks = [ - callback_function( - collection_info=col["alias"], - embedding=col["embed"], - nb_results=nb_results, - filters=sdg_filter, - ) - for col in collections - ] - data = await asyncio.gather(*tasks) - if len(data) < len(collections): - raise PartialResponseResultError() - return [doc for source in data for doc in source] - - def sort_slices_using_mmr( - qdrant_results: List[http_models.ScoredPoint], + qdrant_results: list[http_models.ScoredPoint], theta: float = 1.0, -) -> List[http_models.ScoredPoint]: +) -> list[http_models.ScoredPoint]: + if not qdrant_results: + return [] + logger.debug("sort_slices_using_mmr=start") reward = [r.score for r in qdrant_results] sim = cosine_similarity(np.array([r.vector for r in qdrant_results])) @@ -315,8 +276,8 @@ def sort_slices_using_mmr( def concatenate_same_doc_id_slices( - qdrant_results: List[http_models.ScoredPoint], -) -> List[http_models.ScoredPoint]: + qdrant_results: list[http_models.ScoredPoint], +) -> list[http_models.ScoredPoint]: """ Concatenate slices on the same document ID and remove duplicates. @@ -356,7 +317,7 @@ def concatenate_same_doc_id_slices( return new_results -def get_subject_vector(subject: str | None) -> List[float] | None: +def get_subject_vector(subject: str | None) -> list[float] | None: if not subject: return None with open("src/app/services/subject_vectors.json") as f: diff --git a/src/app/services/search_helpers.py b/src/app/services/search_helpers.py index 51759ed..cc5563b 100644 --- a/src/app/services/search_helpers.py +++ b/src/app/services/search_helpers.py @@ -1,166 +1,41 @@ import asyncio -from typing import Awaitable, Callable, List, Optional +from typing import Awaitable, Callable -from fastapi import Response from qdrant_client.http.models import ScoredPoint -from src.app.models.documents import Document -from src.app.models.search import EnhancedSearchQuery, SearchFilter -from src.app.services.exceptions import ( - CollectionNotFoundError, - NoResultsError, - handle_error, -) -from src.app.services.helpers import detect_language_from_entry -from src.app.services.search import ( - SearchService, - concatenate_same_doc_id_slices, - get_subject_vector, - parallel_search, - sort_slices_using_mmr, -) +from src.app.models.search import EnhancedSearchQuery, SearchMethods +from src.app.services.exceptions import handle_error from src.app.utils.logger import logger as logger_utils logger = logger_utils(__name__) -sp = SearchService() -async def search_items_base( - query: str, - collection_query: str, - nb_results: int, - sdg_filter: Optional[SearchFilter], - search_func: Callable[..., Awaitable[List[Document]]], -) -> Optional[List[Document]]: - logger.info("search_query=%s searched_collection=%s", query, collection_query) - - try: - lang = detect_language_from_entry(query) - collection_alias = await sp.get_collection_alias( - collection_name=collection_query, lang=lang - ) - col = sp._get_info_from_collection_alias(collection_alias=collection_alias) - model_embedding = sp.embed_query(search_input=query, curr_model=col.model) - - data = await search_func( - collection_info=col.alias, - embedding=model_embedding, - filters=sdg_filter.sdg_filter if sdg_filter else None, - nb_results=nb_results, - ) - - if not data: - raise NoResultsError() - - return data - except Exception as e: - return handle_error(response=None, exc=e) - - -async def search_all_base( - response: Response, +async def search_multi_inputs( qp: EnhancedSearchQuery, - search_func: Callable[..., Awaitable[List[ScoredPoint]]], -) -> Optional[List[ScoredPoint]]: + callback_function: Callable[..., Awaitable[list[ScoredPoint]]], +) -> list[ScoredPoint] | None: try: - lang = detect_language_from_entry(qp.query) - subject_vector = get_subject_vector(qp.subject) - - try: - collections = await sp.get_collections_aliases_by_language( - lang=lang, collections=qp.corpora - ) - except CollectionNotFoundError as e: - logger.error(e.message) - raise CollectionNotFoundError() - - collections_to_search = [ - sp.get_collection_dict_with_embed( - collection_alias=col, - query=qp.query, - subject_vector=subject_vector, - subject_influence_factor=qp.influence_factor, - ) - for col in collections - ] - - logger.info( - "Found %s collections to search: %s", - len(collections_to_search), - collections, - ) - - data = await parallel_search( - callback_function=search_func, - nb_results=qp.nb_results, - collections=collections_to_search, - sdg_filter=qp.sdg_filter, - ) + qps: list[EnhancedSearchQuery] = [] + for query in qp.query: + temp_qp = qp.model_copy() + temp_qp.query = query + qps.append(temp_qp) - if not data: - return [] - - sorted_data = sorted(data, key=lambda x: x.score, reverse=True) - sorted_data = sort_slices_using_mmr(sorted_data, theta=qp.relevance_factor) - - if qp.concatenate: - sorted_data = concatenate_same_doc_id_slices(sorted_data) - - return sorted_data - except Exception as e: - handle_error(response=response, exc=e) - return None - - -async def search_multi_inputs( - response: Response, - inputs: List[str], - nb_results: int, - sdg_filter: list[int] | None, - callback_function: Callable[..., Awaitable[List[ScoredPoint]]], - collections: tuple[str, ...] | None, -): - try: - qps: list[EnhancedSearchQuery] = [ - EnhancedSearchQuery( - nb_results=nb_results, - sdg_filter=sdg_filter, - corpora=collections, - query=input, - ) - for input in inputs - ] tasks = [ - search_all_base( - response=response, - search_func=callback_function, + callback_function( qp=qp, + method=SearchMethods.BY_SLICES, ) for qp in qps ] + all_data: list[ScoredPoint] = [] for coroutine in asyncio.as_completed(tasks): - try: - all_data = await coroutine - except CollectionNotFoundError as e: - logger.error(e.message) - response.status_code = 206 - - if not all_data: - response.status_code = 404 - raise NoResultsError() - - doc: list[Document] = [ - Document( - score=d.score, - payload=d.payload, - ) - for d in all_data - ] - - sorted_data = sorted(doc, key=lambda x: x.score, reverse=True) + data = await coroutine + if data: + all_data.extend(data) - return sorted_data + return all_data except Exception as e: - handle_error(response=response, exc=e) - return None + handle_error(exc=e) + return None diff --git a/src/app/services/tutor/agents.py b/src/app/services/tutor/agents.py index 87f806b..46cd64e 100644 --- a/src/app/services/tutor/agents.py +++ b/src/app/services/tutor/agents.py @@ -173,7 +173,6 @@ async def handle_syllabus( ctx.cancellation_token, ) except Exception as e: - print("Error in SDGExpertAgent:", e) raise e end_time = time.time() response = llm_result.chat_message.content diff --git a/src/app/services/tutor/models.py b/src/app/services/tutor/models.py index ebc12d3..dd8dc58 100644 --- a/src/app/services/tutor/models.py +++ b/src/app/services/tutor/models.py @@ -2,8 +2,7 @@ from typing import Dict, List from pydantic import BaseModel - -from src.app.models.documents import Document +from qdrant_client.models import ScoredPoint class ExtractorOutput(BaseModel): @@ -19,7 +18,7 @@ class ExtractorOuputList(BaseModel): class TutorSearchResponse(BaseModel): extracts: list[ExtractorOutput] nb_results: int - documents: list[Document] + documents: list[ScoredPoint] class SyllabusResponseAgent(BaseModel): @@ -29,7 +28,7 @@ class SyllabusResponseAgent(BaseModel): class SyllabusResponse(BaseModel): syllabus: list[SyllabusResponseAgent] - documents: list[Document] + documents: list[ScoredPoint] class MessageWithAnalysis(BaseModel): diff --git a/src/app/services/tutor/utils.py b/src/app/services/tutor/utils.py index abdf3f4..89fc310 100644 --- a/src/app/services/tutor/utils.py +++ b/src/app/services/tutor/utils.py @@ -1,7 +1,6 @@ from fastapi import UploadFile from pypdf import PdfReader - -from src.app.models.documents import Document +from qdrant_client.models import ScoredPoint def build_system_message( @@ -21,7 +20,7 @@ def build_system_message( return message -def extract_doc_info(documents: list[Document]) -> list[dict]: +def extract_doc_info(documents: list[ScoredPoint]) -> list[dict]: """ Extracts the document information from a list of documents. Args: @@ -31,11 +30,12 @@ def extract_doc_info(documents: list[Document]) -> list[dict]: """ return [ { - "title": doc.payload.document_title, - "url": doc.payload.document_url, - "content": doc.payload.slice_content, + "title": doc.payload.document_title, # type: ignore + "url": doc.payload.document_url, # type: ignore + "content": doc.payload.slice_content, # type: ignore } for doc in documents + if doc.payload is not None ] diff --git a/src/app/tests/api/api_v1/test_search.py b/src/app/tests/api/api_v1/test_search.py index e70173c..6a112cf 100644 --- a/src/app/tests/api/api_v1/test_search.py +++ b/src/app/tests/api/api_v1/test_search.py @@ -5,25 +5,24 @@ from qdrant_client.http import models from src.app.core.config import settings -from src.app.models import collections, documents +from src.app.models import collections +from src.app.models.search import EnhancedSearchQuery from src.app.services.exceptions import ( CollectionNotFoundError, LanguageNotSupportedError, ModelNotFoundError, ) -from src.app.services.search import sort_slices_using_mmr +from src.app.services.search import SearchService, sort_slices_using_mmr from src.main import app client = TestClient(app) search_pipeline_path = "src.app.services.search.SearchService" -parallel_search_path = "src.app.api.api_v1.endpoints.search.parallel_search" mocked_collection = collections.Collection( - name="collection", lang="fr", model="model", - alias="collection_fr_model", + name="collection_welearn_fr_model", ) mocked_scored_points = [ models.ScoredPoint( @@ -31,131 +30,83 @@ version=1, score=0.9, vector=[0.1, 0.2], + payload={ + "document_corpus": "corpus", + "document_desc": "desc", + "document_details": {}, + "document_id": "1", + "document_lang": "fr", + "document_sdg": [1], + "document_title": "title", + "document_url": "url", + "slice_content": "content", + "slice_sdg": 1, + }, ), models.ScoredPoint( id="2", version=1, score=0.89, vector=[0.11, 0.21], + payload={ + "document_corpus": "corpus", + "document_desc": "desc", + "document_details": {}, + "document_id": "1", + "document_lang": "fr", + "document_sdg": [1], + "document_title": "title", + "document_url": "url", + "slice_content": "content", + "slice_sdg": 1, + }, ), models.ScoredPoint( id="3", version=1, score=0.88, vector=[0.3, 0.4], + payload={ + "document_corpus": "corpus", + "document_desc": "desc", + "document_details": {}, + "document_id": "1", + "document_lang": "fr", + "document_sdg": [1], + "document_title": "title", + "document_url": "url", + "slice_content": "content", + "slice_sdg": 1, + }, ), ] -mocked_document = documents.Document( - score=0.9, - payload=documents.DocumentPayloadModel( - document_corpus="corpus", - document_desc="desc", - document_details={}, - document_id="1", - document_lang="fr", - document_sdg=[1], - document_title="title", - document_url="url", - slice_content="content", - slice_sdg=1, - ), -) +long_query = "français with a very long sentence to test what you are saying and if the issue is the size of the string" # noqa: E501 @patch("src.app.services.sql_db.session_maker") @patch("src.app.services.security.check_api_key", new=mock.MagicMock(return_value=True)) -@patch(search_pipeline_path, new=mock.MagicMock()) @patch( f"{search_pipeline_path}.get_collections", - new=mock.AsyncMock(return_value=("collection_fr_model", "collection_en_model")), -) -@patch( - f"{search_pipeline_path}._get_info_from_collection_alias", - new=mock.MagicMock(return_value=mocked_collection), + new=mock.AsyncMock( + return_value=("collection_welearn_fr_model", "collection_en_model") + ), ) class SearchTests(IsolatedAsyncioTestCase): - @patch( - "src.app.api.api_v1.endpoints.search.search_items_base", - new=mock.AsyncMock(return_value=[mocked_document]), - ) - @patch( - "src.app.services.search.SearchService.search_group_by_document", - new=mock.AsyncMock(return_value=["document_1", "document_2"]), - ) - def test_search_items_success(self, *mocks): - """Test successful search_items response""" - - response = client.post( - f"{settings.API_V1_STR}/search/collections/collection_fr_model?query=français&nb_results=10", - headers={"X-API-Key": "test"}, # noqa: E501 - ) - - # Assert: Check that the response is as expected - self.assertEqual(response.status_code, 200) - self.assertEqual( - response.json(), - [ - { - "score": 0.9, - "payload": { - "document_corpus": "corpus", - "document_desc": "desc", - "document_details": {}, - "document_id": "1", - "document_lang": "fr", - "document_sdg": [1], - "document_title": "title", - "document_url": "url", - "slice_content": "content", - "slice_sdg": 1, - }, - } - ], - ) - - @patch( - "src.app.api.api_v1.endpoints.search.search_items_base", - new=mock.AsyncMock(return_value=[mocked_document]), - ) - @patch(f"{search_pipeline_path}.search_group_by_document") def test_search_items_no_query(self, *mocks): """Test search_items when no query is provided""" - # Act: Make a test request to the /collections/{collection_query} endpoint without query response = client.post( - f"{settings.API_V1_STR}/search/collections/collection_fr_model", # noqa: E501 + f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model", # noqa: E501 json={"nb_results": 10}, headers={"X-API-Key": "test"}, ) - # Assert: Check that the response is as expected - self.assertEqual(response.status_code, 400) - self.assertEqual( - response.json(), - {"detail": {"code": "EMPTY_QUERY", "message": "Empty query"}}, - ) - - @patch( - f"{search_pipeline_path}.get_collections", - new=mock.AsyncMock(return_value=("NOTTOTO_fr_model")), - ) - async def test_search_collection_not_found(self, *mocks): - with self.assertRaises(CollectionNotFoundError): - response = client.post( - f"{settings.API_V1_STR}/search/collections/toto?query=français&nb_results=10", - headers={"X-API-Key": "test"}, - ) - - self.assertEqual(response.status_code, 404) - self.assertEqual( - response.json().get("detail")["code"], - "COLL_NOT_FOUND", - ) + self.assertEqual(response.status_code, 422) @patch( - f"{search_pipeline_path}.get_model", + f"{search_pipeline_path}._get_model", new=mock.MagicMock( side_effect=ModelNotFoundError("Model not found", "MODEL_NOT_FOUND") ), @@ -163,7 +114,7 @@ async def test_search_collection_not_found(self, *mocks): async def test_search_model_not_found(self, *mocks): with self.assertRaises(ModelNotFoundError): response = client.post( - f"{settings.API_V1_STR}/search/collections/collection_fr_model?query=français&nb_results=10", # noqa: E501 + f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model?query=français&nb_results=10", # noqa: E501 headers={"X-API-Key": "test"}, ) @@ -178,6 +129,52 @@ async def test_search_model_not_found(self, *mocks): }, ) + @patch( + f"{search_pipeline_path}.search_handler", + new=mock.AsyncMock(return_value=mocked_scored_points), + ) + async def test_search_items_success(self, *mocks): + """Test successful search_items response""" + + response = client.post( + f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model?query={long_query}&nb_results=10", + headers={"X-API-Key": "test"}, # noqa: E501 + ) + + self.assertEqual(response.status_code, 200) + + @patch( + f"{search_pipeline_path}.search_handler", + new=mock.AsyncMock(return_value=[]), + ) + async def test_search_items_no_result(self, *mocks): + response = client.post( + f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model?query={long_query}&nb_results=10", + headers={"X-API-Key": "test"}, # noqa: E501 + ) + + self.assertEqual(response.status_code, 206) + self.assertEqual(response.json(), []) + + @patch( + f"{search_pipeline_path}.get_collection_by_language", + new=mock.AsyncMock( + side_effect=CollectionNotFoundError( + "Collection not found", "COLL_NOT_FOUND" + ) + ), + ) + async def test_search_all_slices_no_collections(self, *mocks): + response = client.post( + f"{settings.API_V1_STR}/search/collections/collection_welearn_fr_model?query={long_query}&nb_results=10", + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 404) + self.assertEqual( + response.json(), + "Collection not found", + ) + @patch("src.app.services.sql_db.session_maker") @patch("src.app.services.security.check_api_key", new=mock.MagicMock(return_value=True)) @@ -202,9 +199,8 @@ async def test_search_all_slices_lang_not_supported(self, *mocks): }, ) - # patch should raise @patch( - f"{search_pipeline_path}.get_collections_aliases_by_language", + f"{search_pipeline_path}.get_collection_by_language", new=mock.AsyncMock( side_effect=CollectionNotFoundError( "Collection not found", "COLL_NOT_FOUND" @@ -212,56 +208,26 @@ async def test_search_all_slices_lang_not_supported(self, *mocks): ), ) async def test_search_all_slices_no_collections(self, *mocks): - with self.assertRaises(CollectionNotFoundError): - response = client.post( - f"{settings.API_V1_STR}/search/by_slices?nb_results=10", - json={ - "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" - }, # noqa: E501 - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 404) - self.assertEqual( - response.json().get("detail")["code"], - "COLL_NOT_FOUND", - ) + response = client.post( + f"{settings.API_V1_STR}/search/by_slices?nb_results=10", + json={ + "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" + }, # noqa: E501 + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 404) + self.assertEqual( + response.json(), + "Collection not found", + ) - @patch( - "src.app.api.api_v1.endpoints.search.search_all_base", - return_value=[ - { - "score": 0.7513504, - "payload": { - "document_corpus": "conversation", - "document_desc": "More and more evidence has accumulated which shows that changes in global and regional climate over the last 50 years are almost entirely due to human influence.", - "document_details": { - "authors": [ - {"misc": "", "name": "Mark New"}, - {"misc": "", "name": "University of Cape Town"}, - ], - "duration": "316", - "readability": "56.49", - "source": "africa", - }, - "document_id": "af2ca7b2-1011-4b4d-828e-9678597fd255", - "document_lang": "en", - "document_sdg": [13], - "document_title": "Climate explained: how much of climate change is natural? How much is man-made?", - "document_url": "https://theconversation.com/climate-explained-how-much-of-climate-change-is-natural-how-much-is-man-made-123604", - "slice_content": "The Intergovernmental Panel on Climate Change defines climate change as: a change in the state of the climate that can be identified by changes in the mean and/or the variability of its properties and that persists for an extended period, typically decades or longer. The causes of climate change can be any combination of: Internal variability in the climate system, when various components of the climate system – like the atmosphere and ocean – vary on their own to cause fluctuations in climatic conditions, such as temperature or rainfall. These internally-driven changes generally happen over decades or longer; shorter variations such as those related to El Niño fall in the bracket of climate variability, not climate change. Natural external causes such as increases or decreases in volcanic activity or solar radiation. For example, every 11 years or so, the Sun’s magnetic field completely flips and this can cause small fluctuations in global temperature, up to about 0.2 degrees. On longer time scales – tens to hundreds of millions of years – geological processes can drive changes in the climate, due to shifting continents and mountain building. Human influence through greenhouse gases (gases that trap heat in the atmosphere such as carbon dioxide and methane), other particles released into the air (which absorb or reflect sunlight such as soot and aerosols) and land-use change (which affects how much sunlight is absorbed on land surfaces and also how much carbon dioxide and methane is absorbed and released by vegetation and soils). What changes have been detected?\n\nThe Intergovernmental Panel on Climate Change’s recent report showed that, on average, the global surface air temperature has risen by 1°C since the beginning of significant industrialisation (which roughly started in the 1850s). And it is increasing at ever faster rates, currently 0.2°C per decade, because the concentrations of greenhouse gases in the atmosphere have themselves been increasing ever faster. The oceans are warming as well. In fact, about 90% of the extra heat trapped in the atmosphere by greenhouse gases is being absorbed by the oceans. A warmer atmosphere and oceans are causing dramatic changes, including steep decreases in Arctic summer sea ice which is profoundly impacting arctic marine ecosystems, increasing sea level rise which is inundating low lying coastal areas such as Pacific island atolls, and an increasing frequency of many climate extremes such as drought and heavy rain, as well as disasters where climate is an important driver, such as wildfire, flooding and landslides. Multiple lines of evidence, using different methods, show that human influence is the only plausible explanation for the patterns and magnitude of changes that have been detected. This human influence is largely due to our activities that release greenhouse gases, such as carbon dioxide and methane, as well sunlight absorbing soot. The main sources of these warming gases and particles are fossil fuel burning, cement production, land cover change (especially deforestation) and agriculture. Weather attribution Most of us will struggle to pick up slow changes in the climate.", - "slice_sdg": 13, - }, - } - ], - ) + @patch(f"{search_pipeline_path}.search_handler", return_value=mocked_scored_points) async def test_search_all_slices_ok(self, *mocks): response = client.post( f"{settings.API_V1_STR}/search/by_slices", json={ "query": "Comment est-ce que les gouvernements font pour suivre ces conseils et les mettre en place ?", "relevance_factor": 0.75, - "sdg_filter": [], - "corpora": [], }, headers={"X-API-Key": "test"}, ) @@ -281,12 +247,7 @@ async def test_search_all_slices_no_query(self, *mocks): ) @patch( - f"{search_pipeline_path}.get_collections_aliases_by_language", - return_value=("collection_fr_model"), - ) - @patch(f"{search_pipeline_path}.get_collection_dict_with_embed") - @patch( - "src.app.services.search_helpers.parallel_search", + f"{search_pipeline_path}.search_handler", return_value=[], ) async def test_search_all_slices_no_result(self, *mocks): @@ -324,7 +285,7 @@ async def test_search_all_lang_not_supported(self, *mocks): ) @patch( - f"{search_pipeline_path}.get_collections_aliases_by_language", + f"{search_pipeline_path}.get_collection_by_language", new=mock.AsyncMock( side_effect=CollectionNotFoundError( "Collection not found", "COLL_NOT_FOUND" @@ -332,37 +293,24 @@ async def test_search_all_lang_not_supported(self, *mocks): ), ) @patch( - f"{search_pipeline_path}._get_info_from_collection_alias", + f"{search_pipeline_path}._get_info_from_collection_name", new=mock.MagicMock(return_value=mocked_collection), ) async def test_search_all_no_collections(self, *mocks): - with self.assertRaises(CollectionNotFoundError): - response = client.post( - f"{settings.API_V1_STR}/search/by_document?nb_results=10", - json={ - "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" - }, - headers={"X-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 404) - self.assertEqual( - response.json().get("detail")["code"], - "COLL_NOT_FOUND", - ) + response = client.post( + f"{settings.API_V1_STR}/search/by_document?nb_results=10", + json={ + "query": "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne" + }, + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 404) + self.assertEqual( + response.json(), + "Collection not found", + ) - @patch( - f"{search_pipeline_path}.get_collections_aliases_by_language", - return_value=["collection_fr_model"], - ) - @patch( - f"{search_pipeline_path}._get_info_from_collection_alias", - new=mock.MagicMock(return_value=mocked_collection), - ) - @patch(f"{search_pipeline_path}.embed_query") - @patch( - "src.app.services.search_helpers.parallel_search", - return_value=[], - ) + @patch(f"{search_pipeline_path}.search_handler", return_value=[]) async def test_search_all_no_result(self, *mocks): response = client.post( f"{settings.API_V1_STR}/search/by_document?nb_results=10", @@ -425,3 +373,46 @@ async def test_search_multi_lang_not_supported(self, *mocks): } }, ) + + @patch( + f"{search_pipeline_path}.search_handler", + return_value=[], + ) + async def test_search_multi_no_result(self, *mocks): + response = client.post( + f"{settings.API_V1_STR}/search/multiple_by_slices?nb_results=10", + json={ + "query": [ + "une phrase plus longue pour tester la recherche en français. et voir ce que cela donne", + "another long sentence to test the search in english and see what happens", + ] + }, + headers={"X-API-Key": "test"}, + ) + self.assertEqual(response.status_code, 404) + + async def test_search_multi_single_query(self, *mocks): + with mock.patch( + "src.app.api.api_v1.endpoints.search.search_multi_inputs", + ) as search_multi, mock.patch.object( + SearchService, "search_handler", return_value=mocked_scored_points + ) as search_handler: + client.post( + f"{settings.API_V1_STR}/search/multiple_by_slices?nb_results=10", + json={ + "query": long_query, + }, + headers={"X-API-Key": "test"}, + ) + search_multi.assert_called_once_with( + qp=EnhancedSearchQuery( + query=[long_query], + sdg_filter=None, + corpora=None, + subject=None, + nb_results=10, + influence_factor=2.0, + relevance_factor=1.0, + ), + callback_function=search_handler, # noqa: E501 + ) diff --git a/src/app/tests/services/test_search.py b/src/app/tests/services/test_search.py index 76f8689..ebf69bb 100644 --- a/src/app/tests/services/test_search.py +++ b/src/app/tests/services/test_search.py @@ -1,31 +1,20 @@ import os +from typing import List from unittest import IsolatedAsyncioTestCase, mock from qdrant_client import AsyncQdrantClient +from qdrant_client.models import CollectionDescription, CollectionsResponse, ScoredPoint +from src.app.models.collections import Collection from src.app.services.search import ( DBClientSingleton, SearchService, concatenate_same_doc_id_slices, - search_items_method, ) os.environ["USE_CACHED_SETTINGS"] = "False" -class AliasItem: - def __init__(self, alias_name): - self.alias_name = alias_name - - -class Aliases: - def __init__(self): - self.aliases = [ - AliasItem(alias_name="collection_fr_exists"), - AliasItem(alias_name="collection_en_exists"), - ] - - class alternate_mock_method(object): def __init__(self, url: str, timeout: int, port: int, **kwargs): return @@ -36,13 +25,20 @@ async def get_aliases(self, *args, **kwargs): USER_QUERY = "query1" +collections = CollectionsResponse( + collections=[ + CollectionDescription(name="collection_welearn_fr_exists"), + CollectionDescription(name="collection_welearn_en_exists"), + ] +) + def fake_callback_function(embedding, nb_results, filters, collection_info): return f"{embedding}, {nb_results}, {filters}, {collection_info}" @mock.patch("qdrant_client.AsyncQdrantClient", alternate_mock_method) -@mock.patch.object(AsyncQdrantClient, "get_aliases", return_value=Aliases()) +@mock.patch.object(AsyncQdrantClient, "get_collections", return_value=collections) class SearchServiceTests(IsolatedAsyncioTestCase): def setUp(self): self.sp = SearchService() @@ -53,78 +49,76 @@ def test_db_singleton(self, *mocks): self.assertEqual(db_sing1, db_sing2) - async def test_search_pipeline_collection(self, *mocks): - collection = await self.sp.get_collection_alias("collection", "fr") - - self.assertEqual(collection, "collection_fr_exists") - - async def test_get_collections_aliases_by_language(self, *mocks): - collections = await self.sp.get_collections_aliases_by_language("fr") - - self.assertEqual(collections, ["collection_fr_exists"]) - - async def test_get_collections_aliases_by_language_without_sel_collection( - self, *mocks - ): - await self.sp.get_collections_aliases_by_language("fr") - assert mocks[0].called - - async def test_get_collection_alias(self, *mocks): - collection = await self.sp.get_collection_alias("collection", "fr") - self.assertEqual(collection, "collection_fr_exists") + async def test_get_collection_by_language(self, *mocks): + collections = await self.sp.get_collection_by_language("fr") - async def test_get_collection_alias_without_sel_collection(self, *mocks): - await self.sp.get_collection_alias("collection", "fr") - assert mocks[0].called + self.assertEqual(collections.name, "collection_welearn_fr_exists") - def test_get_info_from_collection_alias(self, *mocks): - collection = self.sp._get_info_from_collection_alias("collection_fr_exists") + def test_get_info_from_collection_name(self, *mocks): + collection = self.sp._get_info_from_collection_name( + "collection_welearn_fr_exists" + ) - self.assertEqual(collection.alias, "collection_fr_exists") + self.assertEqual(collection.name, "collection_welearn_fr_exists") self.assertEqual(collection.lang, "fr") self.assertEqual(collection.model, "exists") - self.assertEqual(collection.name, "collection") - async def test_get_collections_aliases_by_language_with_collection(self, *mocks): + async def test_get_collection_by_language_with_collection(self, *mocks): with mock.patch.object( SearchService, "get_collections", return_value=( - "conversation_en_exists", - "conversation_fr_exists", + "collection_welearn_en_exists", + "collection_welearn_fr_exists", "wiki_fr_exists", ), ): - collections = await self.sp.get_collections_aliases_by_language( - "fr", ("wiki") + collection = await self.sp.get_collection_by_language("fr") + exp_collection = Collection( + name="collection_welearn_fr_exists", + lang="fr", + model="exists", ) - self.assertEqual(collections, ["wiki_fr_exists"]) + self.assertEqual(collection.name, "collection_welearn_fr_exists") + self.assertEqual(collection, exp_collection) def test_concatenate_same_doc_id_slices(self, *mocks): - class FakeQdrantDoc: - def __init__(self, id, payload) -> None: - self.id = id - self.payload = payload - - qdrant_docs = [ - FakeQdrantDoc(1, {"document_id": "1", "slice_content": "content1"}), - FakeQdrantDoc(2, {"document_id": "1", "slice_content": "content2"}), - FakeQdrantDoc(3, {"document_id": "2", "slice_content": "content3"}), + + qdrant_docs: List[ScoredPoint] = [ + ScoredPoint( + id=1, + version=1, + score=0.5, + payload={"document_id": "1", "slice_content": "content1"}, + ), + ScoredPoint( + id=1, + version=1, + score=0.8, + payload={"document_id": "1", "slice_content": "content2"}, + ), + ScoredPoint( + id=2, + version=1, + score=0.8, + payload={"document_id": "2", "slice_content": "content3"}, + ), ] results = concatenate_same_doc_id_slices(qdrant_results=qdrant_docs) + expected_result = [ + ScoredPoint( + id=1, + version=1, + score=0.5, + payload={"document_id": "1", "slice_content": "content1\n\ncontent2"}, + ), + ScoredPoint( + id=2, + version=1, + score=0.8, + payload={"document_id": "2", "slice_content": "content3"}, + ), + ] self.assertEqual(len(results), 2) - self.assertEqual( - results[0].payload.get("slice_content"), "content1\n\ncontent2" - ) - self.assertEqual(results[1].payload.get("slice_content"), "content3") - - def test_search_itmes_method(self, *mocks): - results = search_items_method( - callback_function=fake_callback_function, - embedding=USER_QUERY, - nb_results=1, - collection="collection", - filters=None, - ) - - self.assertEqual(results, "query1, 1, None, collection") + self.assertEqual(results[0].payload, expected_result[0].payload) + self.assertEqual(results[1].payload, expected_result[1].payload)