feat(integrations): Add support for diffusion pipelines not in the list of supported pipelines #7450

Open
wants to merge 12 commits into main

Conversation

soumik12345
Contributor

Description

The diffusers integration currently supports only the pipelines listed in SUPPORTED_MULTIMODAL_PIPELINES. This PR updates the integration to track pipelines outside that list, regardless of whether the pipeline is part of the diffusers library at all.
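
In outline, the new API lets the user register an arbitrary pipeline class with the autologger. A minimal sketch using the keyword names from the examples below; the module and pipeline class here are hypothetical placeholders:

from wandb.integration.diffusers import autolog

# "my_package.MyCustomPipeline" is a hypothetical stand-in for any pipeline
# class that is not in SUPPORTED_MULTIMODAL_PIPELINES.
from my_package import MyCustomPipeline

autolog(
    init=dict(project="diffusers_logging"),
    api_module="my_package",                      # module the pipeline is imported from
    pipeline=MyCustomPipeline,                    # pipeline class to track
    kwarg_logging=["prompt", "negative_prompt"],  # call kwargs to record with the run
)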

Example 1: Tracking the Pixart Sigma pipeline from the official codebase

  1. First, clone the repository from https://github.com/PixArt-alpha/PixArt-sigma.
  2. Next, rename the cloned directory to PixArt_sigma so that it can be imported as a Python module.
  3. Install diffusers from source using pip install git+https://github.com/huggingface/diffusers.
  4. Finally, run the following code with the autologger:
import torch
from diffusers import Transformer2DModel
from PixArt_sigma.scripts.diffusers_patches import (
    pixart_sigma_init_patched_inputs,
    PixArtSigmaPipeline,
)
from wandb.integration.diffusers import autolog

# We tell the autolog exactly which pipeline to track
pipeline_log_config = (
    dict(
        api_module="PixArt_sigma.scripts.diffusers_patches",
        pipeline=PixArtSigmaPipeline,
        kwarg_logging=["prompt", "negative_prompt"],
    )
    if not autolog.check_pipeline_support(PixArtSigmaPipeline)
    else dict()
)
autolog(init=dict(project="diffusers_logging", job_type="test"), **pipeline_log_config)

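# PixArt-Sigma relies on a patched Transformer2DModel. Verify that this
# diffusers build exposes the patch hook before monkey-patching it in below.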
assert getattr(
    Transformer2DModel, "_init_patched_inputs", False
), "Need to upgrade diffusers: pip install git+https://github.com/huggingface/diffusers"
setattr(Transformer2DModel, "_init_patched_inputs", pixart_sigma_init_patched_inputs)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
weight_dtype = torch.float16

transformer = Transformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",
    subfolder="transformer",
    torch_dtype=weight_dtype,
    use_safetensors=True,
)
pipe = PixArtSigmaPipeline.from_pretrained(
    "PixArt-alpha/pixart_sigma_sdxlvae_T5_diffusers",
    transformer=transformer,
    torch_dtype=weight_dtype,
    use_safetensors=True,
)
pipe.to(device)

prompt = "A small cactus with a happy face in the Sahara desert."
image = pipe(prompt).images[0]
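
As a quick local sanity check, independent of the logging, the generated output is a standard PIL image and can be saved to disk:

# Save the generated image; the filename is arbitrary.
image.save("cactus.png")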

Sample Run: https://wandb.ai/geekyrakshit/diffusers_logging/runs/31f4mpf1

Example 2: Tracking the StableCascade Pipeline

The StableCascadeCombinedPipeline is currently not part of the SUPPORTED_MULTIMODAL_PIPELINES.

import torch
from diffusers import StableCascadeCombinedPipeline

from wandb.integration.diffusers import autolog


# We tell the autolog exactly which pipeline to track
pipeline_log_config = (
    dict(
        api_module="diffusers",
        pipeline=StableCascadeCombinedPipeline,
        kwarg_logging=["prompt", "negative_prompt"],
    )
    if not autolog.check_pipeline_support(StableCascadeCombinedPipeline)
    else dict()
)
autolog(init=dict(project="diffusers_logging", job_type="test"), **pipeline_log_config)

pipe = StableCascadeCombinedPipeline.from_pretrained(
    "stabilityai/stable-cascade", torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
prompt = "an image of a shiba inu, donning a spacesuit and helmet"
images = pipe(prompt=prompt).images  # the pipeline output exposes PIL images via .images
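
As with the first example, you can save the output for inspection and end the run explicitly when running as a script. wandb.finish() is the standard wandb run-teardown call, not something added by this PR:

import wandb

# Persist the first generated image and flush the active W&B run.
images[0].save("shiba_inu.png")
wandb.finish()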

Sample Run: https://wandb.ai/geekyrakshit/diffusers_logging/runs/31f4mpf1

  • I updated CHANGELOG.md, or it's not applicable

codecov bot commented Apr 22, 2024

Codecov Report

Attention: Patch coverage is 0% with 152 lines in your changes missing coverage. Please review.

Project coverage is 74.50%. Comparing base (cfe348d) to head (8826b65).
Report is 1501 commits behind head on main.

Files with missing lines                                Patch %   Lines
wandb/integration/diffusers/pipeline_resolver.py          0.00%   102 missing ⚠️
...ndb/integration/diffusers/diffusers_autolog_api.py     0.00%    44 missing ⚠️
wandb/integration/diffusers/autologger.py                 0.00%     4 missing ⚠️
wandb/integration/diffusers/__init__.py                   0.00%     2 missing ⚠️
Additional details and impacted files


@@            Coverage Diff             @@
##             main    #7450      +/-   ##
==========================================
+ Coverage   72.53%   74.50%   +1.97%     
==========================================
  Files         502      501       -1     
  Lines       54197    52242    -1955     
==========================================
- Hits        39310    38922     -388     
+ Misses      14389    12821    -1568     
- Partials      498      499       +1     
Flag     Coverage Δ
func     44.66% <0.00%> (+0.01%) ⬆️
system   61.96% <0.00%> (-0.05%) ⬇️
unit     56.69% <0.00%> (+1.31%) ⬆️

Flags with carried forward coverage won't be shown.

Files with missing lines                                Coverage Δ
wandb/integration/diffusers/utils.py                      0.00% <ø> (ø)
wandb/integration/diffusers/__init__.py                   0.00% <0.00%> (ø)
wandb/integration/diffusers/autologger.py                 0.00% <0.00%> (ø)
...ndb/integration/diffusers/diffusers_autolog_api.py     0.00% <0.00%> (ø)
wandb/integration/diffusers/pipeline_resolver.py          0.00% <0.00%> (ø)

... and 125 files with indirect coverage changes


@soumik12345 requested a review from a team June 10, 2024 13:02
plugin:
- wandb
tag:
shard: standalone-gpu

Member

we do not run this shard in CI

@kptkin removed the cc-feat label Sep 1, 2024