Skip to content

Commit

Permalink
Fix ONNX/Olive.
Browse files Browse the repository at this point in the history
  • Loading branch information
lshqqytiger committed May 5, 2024
1 parent 7fd77f2 commit 9514d91
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 7 deletions.
8 changes: 4 additions & 4 deletions modules/onnx_impl/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,17 +379,17 @@ def preprocess(self, p: StableDiffusionProcessing):
}
in_dir = out_dir

if shared.opts.cuda_compile_backend == "olive-ai":
if shared.opts.olive_enable:
submodels_for_olive = []

if "Text Encoder" in shared.opts.cuda_compile:
if "Text Encoder" in shared.opts.olive_submodels:
if not self.is_refiner:
submodels_for_olive.append("text_encoder")
if self._is_sdxl:
submodels_for_olive.append("text_encoder_2")
if "Model" in shared.opts.cuda_compile:
if "Model" in shared.opts.olive_submodels:
submodels_for_olive.append("unet")
if "VAE" in shared.opts.cuda_compile:
if "VAE" in shared.opts.olive_submodels:
submodels_for_olive.append("vae_encoder")
submodels_for_olive.append("vae_decoder")

Expand Down
2 changes: 1 addition & 1 deletion modules/onnx_impl/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def create_ui():
ep_log = gr.HTML("")
ep_install.click(fn=install_execution_provider, inputs=[ep_checkbox], outputs=[ep_log])

if opts.cuda_compile_backend == "olive-ai":
if opts.olive_enable:
import olive.passes as olive_passes
from olive.hardware.accelerator import AcceleratorSpec, Device
accelerator = AcceleratorSpec(accelerator_type=Device.GPU, execution_provider=opts.onnx_execution_provider)
Expand Down
3 changes: 1 addition & 2 deletions modules/onnx_impl/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,7 @@ def move_inference_session(session: ort.InferenceSession, device: torch.device):


def check_diffusers_cache(path: os.PathLike):
from modules.shared import opts
return opts.diffusers_dir in os.path.abspath(path)
return False


def check_pipeline_sdxl(cls: Type[diffusers.DiffusionPipeline]) -> bool:
Expand Down

0 comments on commit 9514d91

Please sign in to comment.