Skip to content

Commit

Permalink
Merge pull request #2225 from SCIInstitute/2224-light-the-torch
Browse files Browse the repository at this point in the history
Resolve #2224 - Switch to light-the-torch for pytorch installation.
  • Loading branch information
akenmorris committed Mar 30, 2024
2 parents 81e32c9 + 48173f7 commit 6dec614
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 37 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/gha_conda.sh
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ else

# run install
source ./install_shapeworks.sh --developer
conda clean -p -t
conda clean -p -t -y

echo "Create and store cache"
cd /
Expand Down
6 changes: 4 additions & 2 deletions install_shapeworks.bat
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@ REM reactivate shapeworks environment
call conda activate base
call conda activate %CONDAENV%

call pip install torch==1.11.0+cpu torchvision==0.12.0+cpu torchaudio===0.11.0 -f https://download.pytorch.org/whl/torch_stable.html

call pip install -r python_requirements.txt

REM install ptorch using light-the-torch
call ltt install torch==1.11.0

REM different versions of open3d for different OSes, so we install it manually here
call pip install open3d==0.17.0

call pip install Python/DatasetUtilsPackage.tar.gz
Expand Down
36 changes: 5 additions & 31 deletions install_shapeworks.sh
Original file line number Diff line number Diff line change
Expand Up @@ -52,34 +52,6 @@ fi
echo "Creating new conda environment for ShapeWorks called \"$CONDAENV\"..."


# PyTorch installation
function install_pytorch() {
echo "installing pytorch"
if [[ "$(uname)" == "Darwin" ]]; then
pip install torch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0
elif ! [ -x "$(command -v nvidia-smi)" ]; then
echo 'Could not find nvidia-smi, using cpu-only PyTorch'
pip install torch==1.11.0+cpu torchvision==0.12.0+cpu torchaudio==0.11.0 -f https://download.pytorch.org/whl/torch_stable.html
else
CUDA=`nvidia-smi | grep CUDA | sed -e "s/.*CUDA Version: //" -e "s/ .*//"`
echo "Found CUDA Version: ${CUDA}"
if [[ "$CUDA" == "9.2" ]]; then
pip install torch==1.11.0+cu92 torchvision==0.12.0+cu92 torchaudio==0.11.0 -f https://download.pytorch.org/whl/torch_stable.html
elif [[ "$CUDA" == "10.1" ]]; then
pip install torch==1.11.0+cu101 torchvision==0.12.0+cu101 torchaudio==0.11.0 -f https://download.pytorch.org/whl/torch_stable.html
elif [[ "$CUDA" == "10.2" ]]; then
pip install torch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 -f https://download.pytorch.org/whl/torch_stable.html
elif [[ "$CUDA" == "11.0" || "$CUDA" == "11.1" || "$CUDA" == "11.2" ]]; then
pip install torch==1.11.0+cu110 torchvision==0.12.0+cu110 torchaudio==0.11.0 -f https://download.pytorch.org/whl/torch_stable.html
elif [[ "$CUDA" == "11.7" || "$CUDA" == "11.8" || "$CUDA" == "12.0" || "$CUDA" == "12.1" ]]; then
pip install torch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 -f https://download.pytorch.org/whl/cu118
else
echo "CUDA version not compatible, using cpu-only"
pip install torch==1.11.0+cpu torchvision==0.12.0+cpu torchaudio==0.11.0 -f https://download.pytorch.org/whl/torch_stable.html
fi
fi
}

function install_conda() {
if ! command -v conda 2>/dev/null 1>&2; then
echo "Installing Miniconda..."
Expand Down Expand Up @@ -158,10 +130,13 @@ function install_conda() {
# install conda into the shell
conda init

if ! pip install -r python_requirements.txt; then return 1; fi

if ! python -m pip install -r python_requirements.txt; then return 1; fi

# install pytorch using light-the-torch
if ! ltt install torch==1.13.1 torchaudio==0.13.1 torchvision==0.14.1; then return 1; fi

# for network analysis
# open3d needs to be installed differently on each platform so it's not part of python_requirements.txt
if [[ "$(uname)" == "Linux" ]]; then
if ! pip install open3d-cpu==0.17.0; then return 1; fi
elif [[ "$(uname)" == "Darwin" ]]; then
Expand Down Expand Up @@ -233,7 +208,6 @@ function install_conda() {
}

if install_conda; then
install_pytorch

echo "Conda info:"
conda info
Expand Down
4 changes: 1 addition & 3 deletions python_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.8
jupytext==1.14.7
kiwisolver==1.4.4
light-the-torch==0.7.5
lxml==4.9.3
Markdown==3.3.7
markdown-it-py==3.0.0
Expand Down Expand Up @@ -170,9 +171,6 @@ terminado==0.17.1
threadpoolctl==3.1.0
tinycss2==1.2.1
toml==0.10.2
torch==1.13.1
torchaudio==0.13.1
torchvision==0.14.1
tornado==6.3.3
tqdm==4.65.0
traitlets==5.9.0
Expand Down

0 comments on commit 6dec614

Please sign in to comment.