diff --git a/CHANGELOG.md b/CHANGELOG.md index c4643b24..ccfc4bb6 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 0.7.5 + +* Improved packaging + ## 0.7.4 * Dynamic beam search size has been implemented for Chipper, the decoding process starts with a size = 1 and changes to size = 3 if repetitions appear. diff --git a/MANIFEST.in b/MANIFEST.in index a69668c6..117ef8ae 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1 +1,2 @@ include requirements/base.in +include requirements/sg.in diff --git a/Makefile b/Makefile index 463eedb3..db158c38 100644 --- a/Makefile +++ b/Makefile @@ -19,10 +19,10 @@ install-base: install-base-pip-packages ## install: installs all test, dev, and experimental requirements .PHONY: install -install: install-base-pip-packages install-dev install-detectron2 +install: install-base-pip-packages install-dev install-sg install-detectron2 .PHONY: install-ci -install-ci: install-base-pip-packages install-test install-paddleocr +install-ci: install-base-pip-packages install-test install-sg install-paddleocr .PHONY: install-base-pip-packages install-base-pip-packages: @@ -46,6 +46,10 @@ install-test: install-base install-dev: install-test pip install -r requirements/dev.txt +.PHONY: install-sg +install-sg: install-base + pip install -r requirements/sg.txt + ## pip-compile: compiles all base/dev/test requirements .PHONY: pip-compile pip-compile: @@ -56,6 +60,7 @@ pip-compile: sed 's/^detectron2 @/# detectron2 @/g' requirements/base.txt pip-compile --upgrade requirements/test.in pip-compile --upgrade requirements/dev.in + pip-compile --upgrade requirements/sg.in ################# diff --git a/requirements/base.in b/requirements/base.in index b301d982..c06e7ff6 100644 --- a/requirements/base.in +++ b/requirements/base.in @@ -1,9 +1,8 @@ +-c constraints.in layoutparser[layoutmodels,tesseract] python-multipart huggingface-hub opencv-python!=4.7.0.68 -super-gradients -supervision # NOTE(benjamin): Pinned because onnxruntime changed the way quantization is done, and we need to update our code to support it onnxruntime<1.16 # NOTE(alan): Pinned because this is when the most recent module we import appeared diff --git a/requirements/base.txt b/requirements/base.txt index eb7f7819..deac0afb 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -4,30 +4,8 @@ # # pip-compile requirements/base.in # -absl-py==2.0.0 - # via tensorboard -alabaster==0.7.13 - # via sphinx antlr4-python3-runtime==4.9.3 - # via - # hydra-core - # omegaconf -attrs==23.1.0 - # via - # jsonschema - # referencing -babel==2.13.0 - # via sphinx -boto3==1.28.62 - # via super-gradients -botocore==1.31.62 - # via - # boto3 - # s3transfer -build==1.0.3 - # via pip-tools -cachetools==5.3.1 - # via google-auth + # via omegaconf certifi==2023.7.22 # via requests cffi==1.16.0 @@ -36,28 +14,16 @@ charset-normalizer==3.3.0 # via # pdfminer-six # requests -click==8.1.7 - # via pip-tools coloredlogs==15.0.1 # via onnxruntime contourpy==1.1.1 # via matplotlib -coverage==5.3.1 - # via super-gradients cryptography==41.0.4 # via pdfminer-six cycler==0.12.1 # via matplotlib -deprecated==1.2.14 - # via super-gradients -docutils==0.17.1 - # via - # sphinx - # sphinx-rtd-theme effdet==0.4.1 # via layoutparser -einops==0.3.2 - # via super-gradients filelock==3.12.4 # via # huggingface-hub @@ -71,16 +37,6 @@ fsspec==2023.9.2 # via # huggingface-hub # torch -future==0.18.3 - # via treelib -google-auth==2.23.3 - # via - # google-auth-oauthlib - # tensorboard -google-auth-oauthlib==1.0.0 - # via tensorboard -grpcio==1.59.0 - # via tensorboard huggingface-hub==0.17.3 # via # -r requirements/base.in @@ -89,114 +45,53 @@ huggingface-hub==0.17.3 # transformers humanfriendly==10.0 # via coloredlogs -hydra-core==1.3.2 - # via super-gradients idna==3.4 # via requests -imagesize==1.4.1 - # via sphinx -importlib-metadata==6.8.0 - # via - # build - # markdown importlib-resources==6.1.0 - # via - # hydra-core - # jsonschema - # jsonschema-specifications - # matplotlib + # via matplotlib iopath==0.1.10 # via layoutparser jinja2==3.1.2 - # via - # sphinx - # torch -jmespath==1.0.1 - # via - # boto3 - # botocore -json-tricks==3.16.1 - # via super-gradients -jsonschema==4.19.1 - # via super-gradients -jsonschema-specifications==2023.7.1 - # via jsonschema + # via torch kiwisolver==1.4.5 # via matplotlib layoutparser[layoutmodels,tesseract]==0.3.4 # via -r requirements/base.in -markdown==3.5 - # via tensorboard -markdown-it-py==3.0.0 - # via rich markupsafe==2.1.3 - # via - # jinja2 - # werkzeug + # via jinja2 matplotlib==3.7.3 - # via - # pycocotools - # super-gradients - # supervision -mdurl==0.1.2 - # via markdown-it-py + # via pycocotools mpmath==1.3.0 # via sympy networkx==3.1 # via torch numpy==1.23.0 # via + # -c requirements/constraints.in # contourpy # layoutparser # matplotlib - # onnx # onnxruntime # opencv-python - # opencv-python-headless # pandas # pycocotools # scipy - # super-gradients - # supervision - # tensorboard - # torchmetrics # torchvision # transformers -oauthlib==3.2.2 - # via requests-oauthlib omegaconf==2.3.0 - # via - # effdet - # hydra-core - # super-gradients -onnx==1.13.0 - # via - # onnx-simplifier - # super-gradients -onnx-simplifier==0.4.33 - # via super-gradients -onnxruntime==1.13.1 - # via - # -r requirements/base.in - # super-gradients + # via effdet +onnxruntime==1.15.1 + # via -r requirements/base.in opencv-python==4.8.1.78 # via # -r requirements/base.in # layoutparser - # super-gradients -opencv-python-headless==4.8.1.78 - # via supervision packaging==23.2 # via - # build # huggingface-hub - # hydra-core # matplotlib # onnxruntime # pytesseract - # sphinx - # super-gradients - # torchmetrics # transformers pandas==2.0.3 # via layoutparser @@ -213,213 +108,87 @@ pillow==10.0.1 # pdf2image # pdfplumber # pytesseract - # super-gradients - # supervision # torchvision -pip-tools==7.3.0 - # via super-gradients -pkgutil-resolve-name==1.3.10 - # via jsonschema portalocker==2.8.2 # via iopath -protobuf==3.20.3 - # via - # onnx - # onnxruntime - # super-gradients - # tensorboard -psutil==5.9.5 - # via super-gradients -pyasn1==0.5.0 - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 - # via google-auth -pycocotools==2.0.6 - # via - # effdet - # super-gradients +protobuf==4.24.4 + # via onnxruntime +pycocotools==2.0.7 + # via effdet pycparser==2.21 # via cffi -pydeprecate==0.3.2 - # via torchmetrics -pygments==2.16.1 - # via - # rich - # sphinx - # super-gradients -pyparsing==2.4.5 - # via - # matplotlib - # super-gradients -pypdfium2==4.20.0 +pyparsing==3.1.1 + # via matplotlib +pypdfium2==4.21.0 # via pdfplumber -pyproject-hooks==1.0.0 - # via build pytesseract==0.3.10 # via layoutparser python-dateutil==2.8.2 # via - # botocore # matplotlib # pandas python-multipart==0.0.6 # via -r requirements/base.in pytz==2023.3.post1 - # via - # babel - # pandas + # via pandas pyyaml==6.0.1 # via # huggingface-hub # layoutparser # omegaconf - # supervision # timm # transformers rapidfuzz==3.4.0 - # via - # -r requirements/base.in - # super-gradients -referencing==0.30.2 - # via - # jsonschema - # jsonschema-specifications + # via -r requirements/base.in regex==2023.10.3 # via transformers requests==2.31.0 # via # huggingface-hub - # requests-oauthlib - # sphinx - # tensorboard # torchvision # transformers -requests-oauthlib==1.3.1 - # via google-auth-oauthlib -rich==13.6.0 - # via onnx-simplifier -rpds-py==0.10.4 - # via - # jsonschema - # referencing -rsa==4.9 - # via google-auth -s3transfer==0.7.0 - # via boto3 safetensors==0.4.0 # via # timm # transformers scipy==1.10.1 - # via - # layoutparser - # super-gradients - # supervision + # via layoutparser six==1.16.0 # via python-dateutil -snowballstemmer==2.2.0 - # via sphinx -sphinx==4.0.3 - # via - # sphinx-rtd-theme - # sphinxcontrib-jquery - # super-gradients -sphinx-rtd-theme==1.3.0 - # via super-gradients -sphinxcontrib-applehelp==1.0.4 - # via sphinx -sphinxcontrib-devhelp==1.0.2 - # via sphinx -sphinxcontrib-htmlhelp==2.0.1 - # via sphinx -sphinxcontrib-jquery==4.1 - # via sphinx-rtd-theme -sphinxcontrib-jsmath==1.0.1 - # via sphinx -sphinxcontrib-qthelp==1.0.3 - # via sphinx -sphinxcontrib-serializinghtml==1.1.5 - # via sphinx -stringcase==1.2.0 - # via super-gradients -super-gradients==3.2.1 - # via -r requirements/base.in -supervision==0.15.0 - # via -r requirements/base.in sympy==1.12 # via # onnxruntime # torch -tensorboard==2.14.0 - # via super-gradients -tensorboard-data-server==0.7.1 - # via tensorboard -termcolor==1.1.0 - # via super-gradients timm==0.9.7 # via effdet tokenizers==0.14.1 # via transformers -tomli==2.0.1 - # via - # build - # pip-tools - # pyproject-hooks torch==2.1.0 # via # effdet # layoutparser - # super-gradients # timm - # torchmetrics # torchvision -torchmetrics==0.8.0 - # via super-gradients torchvision==0.16.0 # via # effdet # layoutparser - # super-gradients # timm tqdm==4.66.1 # via # huggingface-hub # iopath - # super-gradients # transformers transformers==4.34.0 # via -r requirements/base.in -treelib==1.6.1 - # via super-gradients typing-extensions==4.8.0 # via # huggingface-hub # iopath - # onnx - # rich # torch tzdata==2023.3 # via pandas -urllib3==1.26.17 - # via - # botocore - # requests -werkzeug==3.0.0 - # via tensorboard -wheel==0.41.2 - # via - # pip-tools - # super-gradients - # tensorboard -wrapt==1.15.0 - # via deprecated +urllib3==2.0.6 + # via requests zipp==3.17.0 - # via - # importlib-metadata - # importlib-resources - -# The following packages are considered to be unsafe in a requirements file: -# pip -# setuptools + # via importlib-resources diff --git a/requirements/constraints.in b/requirements/constraints.in new file mode 100644 index 00000000..a819fb2e --- /dev/null +++ b/requirements/constraints.in @@ -0,0 +1 @@ +numpy<=1.23 diff --git a/requirements/dev.in b/requirements/dev.in index f40c982b..9fdba403 100644 --- a/requirements/dev.in +++ b/requirements/dev.in @@ -1,3 +1,4 @@ +-c constraints.in -c base.txt -c test.txt jupyter diff --git a/requirements/dev.txt b/requirements/dev.txt index 8abd3139..bac4c5a0 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -24,13 +24,10 @@ async-lru==2.0.4 # via jupyterlab attrs==23.1.0 # via - # -c requirements/base.txt # jsonschema # referencing babel==2.13.0 - # via - # -c requirements/base.txt - # jupyterlab-server + # via jupyterlab-server backcall==0.2.0 # via ipython beautifulsoup4==4.12.2 @@ -38,9 +35,7 @@ beautifulsoup4==4.12.2 bleach==6.1.0 # via nbconvert build==1.0.3 - # via - # -c requirements/base.txt - # pip-tools + # via pip-tools certifi==2023.7.22 # via # -c requirements/base.txt @@ -57,7 +52,6 @@ charset-normalizer==3.3.0 # requests click==8.1.7 # via - # -c requirements/base.txt # -c requirements/test.txt # pip-tools comm==0.1.4 @@ -101,7 +95,6 @@ idna==3.4 # requests importlib-metadata==6.8.0 # via - # -c requirements/base.txt # build # jupyter-client # jupyter-lsp @@ -148,17 +141,14 @@ jsonpointer==2.4 # via jsonschema jsonschema[format-nongpl]==4.19.1 # via - # -c requirements/base.txt # jupyter-events # jupyterlab-server # nbformat jsonschema-specifications==2023.7.1 - # via - # -c requirements/base.txt - # jsonschema + # via jsonschema jupyter==1.0.0 # via -r requirements/dev.in -jupyter-client==8.3.1 +jupyter-client==8.4.0 # via # ipykernel # jupyter-console @@ -191,7 +181,7 @@ jupyter-server==2.7.3 # notebook-shim jupyter-server-terminals==0.4.4 # via jupyter-server -jupyterlab==4.0.6 +jupyterlab==4.0.7 # via notebook jupyterlab-pygments==0.2.2 # via nbconvert @@ -233,7 +223,7 @@ nbformat==5.9.2 # nbconvert nest-asyncio==1.5.8 # via ipykernel -notebook==7.0.4 +notebook==7.0.5 # via jupyter notebook-shim==0.2.3 # via @@ -242,6 +232,7 @@ notebook-shim==0.2.3 numpy==1.23.0 # via # -c requirements/base.txt + # -c requirements/constraints.in # contourpy # matplotlib overrides==7.4.0 @@ -273,13 +264,9 @@ pillow==10.0.1 # -c requirements/test.txt # matplotlib pip-tools==7.3.0 - # via - # -c requirements/base.txt - # -r requirements/dev.in + # via -r requirements/dev.in pkgutil-resolve-name==1.3.10 - # via - # -c requirements/base.txt - # jsonschema + # via jsonschema platformdirs==3.11.0 # via # -c requirements/test.txt @@ -291,9 +278,7 @@ prompt-toolkit==3.0.39 # ipython # jupyter-console psutil==5.9.5 - # via - # -c requirements/base.txt - # ipykernel + # via ipykernel ptyprocess==0.7.0 # via # pexpect @@ -306,19 +291,16 @@ pycparser==2.21 # cffi pygments==2.16.1 # via - # -c requirements/base.txt # ipython # jupyter-console # nbconvert # qtconsole -pyparsing==2.4.5 +pyparsing==3.1.1 # via # -c requirements/base.txt # matplotlib pyproject-hooks==1.0.0 - # via - # -c requirements/base.txt - # build + # via build python-dateutil==2.8.2 # via # -c requirements/base.txt @@ -349,7 +331,6 @@ qtpy==2.4.0 # via qtconsole referencing==0.30.2 # via - # -c requirements/base.txt # jsonschema # jsonschema-specifications # jupyter-events @@ -366,9 +347,8 @@ rfc3986-validator==0.1.1 # via # jsonschema # jupyter-events -rpds-py==0.10.4 +rpds-py==0.10.6 # via - # -c requirements/base.txt # jsonschema # referencing send2trash==1.8.2 @@ -396,7 +376,6 @@ tinycss2==1.2.1 # via nbconvert tomli==2.0.1 # via - # -c requirements/base.txt # -c requirements/test.txt # build # jupyterlab @@ -437,7 +416,7 @@ typing-extensions==4.8.0 # ipython uri-template==1.3.0 # via jsonschema -urllib3==1.26.17 +urllib3==2.0.6 # via # -c requirements/base.txt # -c requirements/test.txt @@ -453,9 +432,7 @@ webencodings==0.5.1 websocket-client==1.6.4 # via jupyter-server wheel==0.41.2 - # via - # -c requirements/base.txt - # pip-tools + # via pip-tools widgetsnbextension==4.0.9 # via ipywidgets zipp==3.17.0 diff --git a/requirements/sg.in b/requirements/sg.in new file mode 100644 index 00000000..b166b40a --- /dev/null +++ b/requirements/sg.in @@ -0,0 +1,4 @@ +-c constraints.in +-c base.in +super-gradients +supervision diff --git a/requirements/sg.txt b/requirements/sg.txt new file mode 100644 index 00000000..17dec764 --- /dev/null +++ b/requirements/sg.txt @@ -0,0 +1,334 @@ +# +# This file is autogenerated by pip-compile with Python 3.8 +# by the following command: +# +# pip-compile requirements/sg.in +# +absl-py==2.0.0 + # via tensorboard +alabaster==0.7.13 + # via sphinx +antlr4-python3-runtime==4.9.3 + # via + # hydra-core + # omegaconf +attrs==23.1.0 + # via + # jsonschema + # referencing +babel==2.13.0 + # via sphinx +boto3==1.28.63 + # via super-gradients +botocore==1.31.63 + # via + # boto3 + # s3transfer +build==1.0.3 + # via pip-tools +cachetools==5.3.1 + # via google-auth +certifi==2023.7.22 + # via requests +charset-normalizer==3.3.0 + # via requests +click==8.1.7 + # via pip-tools +coloredlogs==15.0.1 + # via onnxruntime +contourpy==1.1.1 + # via matplotlib +coverage==5.3.1 + # via super-gradients +cycler==0.12.1 + # via matplotlib +deprecated==1.2.14 + # via super-gradients +docutils==0.17.1 + # via + # sphinx + # sphinx-rtd-theme +einops==0.3.2 + # via super-gradients +filelock==3.12.4 + # via torch +flatbuffers==23.5.26 + # via onnxruntime +fonttools==4.43.1 + # via matplotlib +fsspec==2023.9.2 + # via torch +future==0.18.3 + # via treelib +google-auth==2.23.3 + # via + # google-auth-oauthlib + # tensorboard +google-auth-oauthlib==1.0.0 + # via tensorboard +grpcio==1.59.0 + # via tensorboard +humanfriendly==10.0 + # via coloredlogs +hydra-core==1.3.2 + # via super-gradients +idna==3.4 + # via requests +imagesize==1.4.1 + # via sphinx +importlib-metadata==6.8.0 + # via + # build + # markdown +importlib-resources==6.1.0 + # via + # hydra-core + # jsonschema + # jsonschema-specifications + # matplotlib +jinja2==3.1.2 + # via + # sphinx + # torch +jmespath==1.0.1 + # via + # boto3 + # botocore +json-tricks==3.16.1 + # via super-gradients +jsonschema==4.19.1 + # via super-gradients +jsonschema-specifications==2023.7.1 + # via jsonschema +kiwisolver==1.4.5 + # via matplotlib +markdown==3.5 + # via tensorboard +markdown-it-py==3.0.0 + # via rich +markupsafe==2.1.3 + # via + # jinja2 + # werkzeug +matplotlib==3.7.3 + # via + # pycocotools + # super-gradients + # supervision +mdurl==0.1.2 + # via markdown-it-py +mpmath==1.3.0 + # via sympy +networkx==3.1 + # via torch +numpy==1.23.0 + # via + # -c requirements/constraints.in + # contourpy + # matplotlib + # onnx + # onnxruntime + # opencv-python + # opencv-python-headless + # pycocotools + # scipy + # super-gradients + # supervision + # tensorboard + # torchmetrics + # torchvision +oauthlib==3.2.2 + # via requests-oauthlib +omegaconf==2.3.0 + # via + # hydra-core + # super-gradients +onnx==1.13.0 + # via + # onnx-simplifier + # super-gradients +onnx-simplifier==0.4.33 + # via super-gradients +onnxruntime==1.13.1 + # via + # -c requirements/base.in + # super-gradients +opencv-python==4.8.1.78 + # via + # -c requirements/base.in + # super-gradients +opencv-python-headless==4.8.1.78 + # via supervision +packaging==23.2 + # via + # build + # hydra-core + # matplotlib + # onnxruntime + # sphinx + # super-gradients + # torchmetrics +pillow==10.0.1 + # via + # matplotlib + # super-gradients + # supervision + # torchvision +pip-tools==7.3.0 + # via super-gradients +pkgutil-resolve-name==1.3.10 + # via jsonschema +protobuf==3.20.3 + # via + # onnx + # onnxruntime + # super-gradients + # tensorboard +psutil==5.9.5 + # via super-gradients +pyasn1==0.5.0 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth +pycocotools==2.0.6 + # via super-gradients +pydeprecate==0.3.2 + # via torchmetrics +pygments==2.16.1 + # via + # rich + # sphinx + # super-gradients +pyparsing==2.4.5 + # via + # matplotlib + # super-gradients +pyproject-hooks==1.0.0 + # via build +python-dateutil==2.8.2 + # via + # botocore + # matplotlib +pytz==2023.3.post1 + # via babel +pyyaml==6.0.1 + # via + # omegaconf + # supervision +rapidfuzz==3.4.0 + # via + # -c requirements/base.in + # super-gradients +referencing==0.30.2 + # via + # jsonschema + # jsonschema-specifications +requests==2.31.0 + # via + # requests-oauthlib + # sphinx + # tensorboard + # torchvision +requests-oauthlib==1.3.1 + # via google-auth-oauthlib +rich==13.6.0 + # via onnx-simplifier +rpds-py==0.10.6 + # via + # jsonschema + # referencing +rsa==4.9 + # via google-auth +s3transfer==0.7.0 + # via boto3 +scipy==1.10.1 + # via + # super-gradients + # supervision +six==1.16.0 + # via python-dateutil +snowballstemmer==2.2.0 + # via sphinx +sphinx==4.0.3 + # via + # sphinx-rtd-theme + # sphinxcontrib-jquery + # super-gradients +sphinx-rtd-theme==1.3.0 + # via super-gradients +sphinxcontrib-applehelp==1.0.4 + # via sphinx +sphinxcontrib-devhelp==1.0.2 + # via sphinx +sphinxcontrib-htmlhelp==2.0.1 + # via sphinx +sphinxcontrib-jquery==4.1 + # via sphinx-rtd-theme +sphinxcontrib-jsmath==1.0.1 + # via sphinx +sphinxcontrib-qthelp==1.0.3 + # via sphinx +sphinxcontrib-serializinghtml==1.1.5 + # via sphinx +stringcase==1.2.0 + # via super-gradients +super-gradients==3.2.1 + # via -r requirements/sg.in +supervision==0.15.0 + # via -r requirements/sg.in +sympy==1.12 + # via + # onnxruntime + # torch +tensorboard==2.14.0 + # via super-gradients +tensorboard-data-server==0.7.1 + # via tensorboard +termcolor==1.1.0 + # via super-gradients +tomli==2.0.1 + # via + # build + # pip-tools + # pyproject-hooks +torch==2.1.0 + # via + # super-gradients + # torchmetrics + # torchvision +torchmetrics==0.8.0 + # via super-gradients +torchvision==0.16.0 + # via super-gradients +tqdm==4.66.1 + # via super-gradients +treelib==1.6.1 + # via super-gradients +typing-extensions==4.8.0 + # via + # onnx + # rich + # torch +urllib3==1.26.17 + # via + # botocore + # requests +werkzeug==3.0.0 + # via tensorboard +wheel==0.41.2 + # via + # pip-tools + # super-gradients + # tensorboard +wrapt==1.15.0 + # via deprecated +zipp==3.17.0 + # via + # importlib-metadata + # importlib-resources + +# The following packages are considered to be unsafe in a requirements file: +# pip +# setuptools diff --git a/requirements/test.in b/requirements/test.in index bcabec95..26329781 100644 --- a/requirements/test.in +++ b/requirements/test.in @@ -1,3 +1,4 @@ +-c constraints.in -c base.txt black>=22.3.0 coverage diff --git a/requirements/test.txt b/requirements/test.txt index 956285e3..a51f610b 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -20,12 +20,10 @@ charset-normalizer==3.3.0 # requests click==8.1.7 # via - # -c requirements/base.txt # -r requirements/test.in # black -coverage[toml]==5.3.1 +coverage[toml]==7.3.2 # via - # -c requirements/base.txt # -r requirements/test.in # pytest-cov exceptiongroup==1.1.3 @@ -92,7 +90,7 @@ platformdirs==3.11.0 # via black pluggy==1.3.0 # via pytest -pycodestyle==2.11.0 +pycodestyle==2.11.1 # via flake8 pydocstyle==6.3.0 # via flake8-docstrings @@ -122,15 +120,11 @@ sniffio==1.3.0 # httpcore # httpx snowballstemmer==2.2.0 - # via - # -c requirements/base.txt - # pydocstyle -toml==0.10.2 - # via coverage + # via pydocstyle tomli==2.0.1 # via - # -c requirements/base.txt # black + # coverage # mypy # pytest tqdm==4.66.1 @@ -145,7 +139,7 @@ typing-extensions==4.8.0 # black # huggingface-hub # mypy -urllib3==1.26.17 +urllib3==2.0.6 # via # -c requirements/base.txt # requests diff --git a/setup.py b/setup.py index 968316b1..8a4864bf 100644 --- a/setup.py +++ b/setup.py @@ -74,4 +74,5 @@ def load_text_from_file(filename: str): version=__version__, entry_points={}, install_requires=load_requirements(), + extras_require={"supergradients": load_requirements("requirements/sg.in")}, ) diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index 63ee408a..66a9fc3b 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "0.7.4" # pragma: no cover +__version__ = "0.7.5" # pragma: no cover diff --git a/unstructured_inference/models/super_gradients.py b/unstructured_inference/models/super_gradients.py index 99e8bad6..bdb939a2 100644 --- a/unstructured_inference/models/super_gradients.py +++ b/unstructured_inference/models/super_gradients.py @@ -1,19 +1,32 @@ import os -from typing import Callable, List, cast +from types import ModuleType +from typing import TYPE_CHECKING, Callable, List, cast import numpy as np -import supervision as sv import yaml from PIL import Image -from super_gradients.training import models from unstructured_inference.constants import Source from unstructured_inference.inference.layoutelement import LayoutElement from unstructured_inference.logger import logger from unstructured_inference.models.unstructuredmodel import UnstructuredObjectDetectionModel +if TYPE_CHECKING: + import supervision as sv + from super_gradients.training import models as _sgmodels + +sgmodels = None + class UnstructuredSuperGradients(UnstructuredObjectDetectionModel): + def __init__(self): + super().__init__() + global sgmodels + if sgmodels is None: + from super_gradients.training import models as _sgmodels + + sgmodels = _sgmodels + def predict(self, x: Image): """Predict using Super-Gradients model.""" super().predict(x) @@ -24,9 +37,10 @@ def initialize( model_arch: str, model_path: str, dataset_yaml_path: str, - callback: Callable[[np.ndarray, models.sg_module.SgModule], sv.Detections], + callback: Callable[[np.ndarray, "_sgmodels.sg_module.SgModule"], "sv.Detections"], ): """Start inference session for SuperGradients model.""" + if not os.path.exists(model_path): logger.info("Super Gradients Model Path Does Not Exist!") self.model_path = model_path @@ -34,7 +48,7 @@ def initialize( with open(dataset_yaml_path) as file: dataset_yaml = yaml.safe_load(file) - self.model = models.get( + self.model = cast(ModuleType, sgmodels).get( model_name=model_arch, num_classes=len(dataset_yaml["names"]), checkpoint_path=model_path,