diff --git a/pyproject.toml b/pyproject.toml index b23646af3..480b944fd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,6 +53,7 @@ dependencies = [ "pyarrow>=21.0", "dill>=0.3.8", # datasets requirements "pypcre>=0.2.4", + "torchao>=0.14.0", # fix bad transformers 4.57.1 breaking torchao compat # "cython>=3.1.4", # required by hf-xet/hf-transfer # "flash-attn>=2.8.3", <-- install for lower vram usage ] diff --git a/requirements.txt b/requirements.txt index c2cd10489..697b032cb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -19,4 +19,4 @@ datasets>=3.6.0 pyarrow>=21.0 dill>=0.3.8 pypcre>=0.2.4 - +torchao>=0.14.0