diff --git a/install.bash b/install.bash index a2f5d66..c6da2d2 100644 --- a/install.bash +++ b/install.bash @@ -34,9 +34,15 @@ if (( cuda_major_version >= 12 )) || (( cuda_major_version == 11 && cuda_minor_v echo "install torch 2.0.1+cu118" pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --extra-index-url https://download.pytorch.org/whl/cu118 pip install xformers==0.0.21 -elif (( cuda_major_version == 11 && cuda_minor_version == 6 )); then +elif (( cuda_major_version == 11 && cuda_minor_version >= 6 )); then echo "install torch 1.12.1+cu116" pip install torch==1.12.1+cu116 torchvision==0.13.1+cu116 --extra-index-url https://download.pytorch.org/whl/cu116 + # for RTX3090+cu113/cu116 xformers, we need to install this version from source. You can also try xformers==0.0.18 + pip install --upgrade git+https://github.com/facebookresearch/xformers.git@0bad001ddd56c080524d37c84ff58d9cd030ebfd + pip install triton==2.0.0.dev20221202 +elif (( cuda_major_version == 11 && cuda_minor_version >= 2 )); then + echo "install torch 1.12.1+cu113" + pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu116 pip install --upgrade git+https://github.com/facebookresearch/xformers.git@0bad001ddd56c080524d37c84ff58d9cd030ebfd pip install triton==2.0.0.dev20221202 else