Skip to content

Commit

Permalink
Fix bug.
Browse files Browse the repository at this point in the history
  • Loading branch information
lshqqytiger committed May 6, 2024
1 parent 9514d91 commit e2cbdab
Showing 1 changed file with 29 additions and 26 deletions.
55 changes: 29 additions & 26 deletions modules/launch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,31 +462,6 @@ def prepare_environment():
"TORCH_COMMAND",
f"pip install torch==2.3.0 torchvision --index-url {torch_index_url}",
)
error = None
from modules import zluda_installer
try:
if args.use_zluda_dnn:
if zluda_installer.check_dnn_dependency():
zluda_installer.enable_dnn()
else:
print("Couldn't find the required dependency of ZLUDA DNN.")
zluda_installer.install()
zluda_path = zluda_installer.find()
zluda_installer.make_copy(zluda_path)
except Exception as e:
error = e
print(f'Failed to install ZLUDA: {e}')
if error is None:
try:
zluda_installer.load(zluda_path)
torch_command = os.environ.get('TORCH_COMMAND', 'pip install torch==2.3.0 torchvision --index-url https://download.pytorch.org/whl/cu118')
print(f'Using ZLUDA in {zluda_path}')
except Exception as e:
error = e
print(f'Failed to load ZLUDA: {e}')
if error is not None:
print('Using CPU-only torch')
torch_command = os.environ.get('TORCH_COMMAND', 'pip install torch torchvision')
elif args.use_ipex:
backend = "ipex"
if system == "Windows":
Expand Down Expand Up @@ -523,6 +498,7 @@ def prepare_environment():
f"pip install torch==2.3.0 torchvision --extra-index-url {torch_index_url}",
)
elif system == "Windows" and hip_found: # ZLUDA
args.use_zluda = True
print("ROCm Toolkit was found.")
backend = "cuda"
torch_index_url = os.environ.get(
Expand Down Expand Up @@ -586,7 +562,34 @@ def prepare_environment():
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
startup_timer.record("install torch")

if args.use_ipex or args.use_directml or args.use_cpu_torch:
if args.use_zluda:
error = None
from modules import zluda_installer
try:
if args.use_zluda_dnn:
if zluda_installer.check_dnn_dependency():
zluda_installer.enable_dnn()
else:
print("Couldn't find the required dependency of ZLUDA DNN.")
zluda_installer.install()
zluda_path = zluda_installer.find()
zluda_installer.make_copy(zluda_path)
except Exception as e:
error = e
print(f'Failed to install ZLUDA: {e}')
if error is None:
try:
zluda_installer.load(zluda_path)
torch_command = os.environ.get('TORCH_COMMAND', 'pip install torch==2.3.0 torchvision --index-url https://download.pytorch.org/whl/cu118')
print(f'Using ZLUDA in {zluda_path}')
except Exception as e:
error = e
print(f'Failed to load ZLUDA: {e}')
if error is not None:
print('Using CPU-only torch')
torch_command = os.environ.get('TORCH_COMMAND', 'pip install torch torchvision')

if args.use_ipex or args.use_directml or args.use_zluda or args.use_cpu_torch:
args.skip_torch_cuda_test = True
if not args.skip_torch_cuda_test and not check_run_python("import torch; assert torch.cuda.is_available()"):
raise RuntimeError(
Expand Down

0 comments on commit e2cbdab

Please sign in to comment.