This repository has been archived by the owner on Oct 3, 2023. It is now read-only.

Feat/custom model not supported in stable diffusion deploy #265

Open
wants to merge 6 commits into base: main
3 changes: 2 additions & 1 deletion .github/workflows/ci-testing.yml
@@ -45,7 +45,8 @@ jobs:
SIGNING_SECRET: ${{secrets.SIGNING_SECRET}}
SLACK_CLIENT_ID: ${{secrets.SLACK_CLIENT_ID}}
CLIENT_SECRET: ${{secrets.CLIENT_SECRET}}
run: coverage run --source muse -m pytest muse tests -v
run: |
coverage run --source muse -m pytest muse tests/test_custom_model_logic.py tests/test_app.py -v -s

- name: Statistics
if: success()
1 change: 1 addition & 0 deletions .gitignore
@@ -308,6 +308,7 @@ data/
lightning_logs/
wandb/
.idea/
.vscode/
flagged/
.python-version
github/
11 changes: 11 additions & 0 deletions app.py
@@ -99,6 +99,7 @@ def __init__(
parallel=True,
start_with_flow=False,
)
print(f"Starting work {i}...")
self.add_work(work)

self.slack_bot = MuseSlackCommandBot(command="/muse")
@@ -148,6 +148,16 @@ def get_work(self, index: int):
def run(self): # noqa: C901
if os.environ.get("TESTING_LAI"):
print("⚡ Lightning Dream App! ⚡")
if os.environ.get("CUSTOM_MODEL_TEST"):
work = StableDiffusionServe(
safety_embeddings_drive=self.safety_embeddings_drive,
safety_embeddings_filename=self.safety_checker_embedding_work.safety_embeddings_filename,
cloud_compute=L.CloudCompute(self.gpu_type, disk_size=30),
cache_calls=True,
parallel=True,
start_with_flow=False,
)
work.build_pipeline()

# provision these works early
if not self.load_balancer.is_running:
1 change: 0 additions & 1 deletion dev_install.sh
100644 → 100755
@@ -1,4 +1,3 @@
pip install -r requirements.txt
pip install -r requirements/dev.txt
pip install https://github.com/aniketmaurya/stable_diffusion_inference/archive/refs/tags/v0.0.2.tar.gz
pip install taming-transformers-rom1504 -q
16 changes: 15 additions & 1 deletion docs/configurations.md
@@ -1,3 +1,17 @@
# Muse Configurations

- `SD_VARIANT`: You can select the stable diffusion model version.
List of environment variables:

- `SD_VARIANT`: Selects the stable diffusion model name.
  - If the value is undefined, the default is "sd1".
  - Possible values:
    - Stable diffusion v1.4: sd1.4
    - Stable diffusion v1.5: sd1.5, sd1, sd
    - Stable diffusion v2-base: sd2, sd2_base
    - Stable diffusion v2-high: sd2_high
    - Any other URL for downloading a ckpt file
    - Any other local path
- `SD_VERSION`: Configuration version of the stable diffusion model.
  - If the value is undefined, the default is 1.
  - Possible values: 1 or 2
  - This variable only affects which stable diffusion inference config (v1 or v2 YAML) is used.
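As an illustration (a minimal sketch, not part of the shipped docs; the checkpoint URL is a placeholder), a custom checkpoint could be selected like this before launching the app:

```sh
# Point SD_VARIANT at a checkpoint the stable_diffusion_inference library
# does not know about; any reachable .ckpt URL or local path works.
export SD_VARIANT="https://example.com/checkpoints/my-finetuned-sd.ckpt"
export SD_VERSION="2"   # use the v2 inference config

# then launch Muse as usual, e.g.
lightning run app app.py
```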
47 changes: 42 additions & 5 deletions muse/components/stable_diffusion_serve.py
@@ -8,9 +8,11 @@
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path
from typing import List, Optional
from typing import List, Optional, Any, Dict
from os.path import dirname

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import stable_diffusion_inference
from stable_diffusion_inference.model import SDInference

import lightning as L # noqa: E402
import torch # noqa: E402
@@ -20,6 +22,8 @@
from muse.CONST import IMAGE_SIZE, INFERENCE_REQUEST_TIMEOUT, KEEP_ALIVE_TIMEOUT # noqa: E402
from muse.utility.data_io import Data, DataBatch, TimeoutException # noqa: E402

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"


class SafetyChecker:
def __init__(self, embeddings_path):
@@ -73,14 +77,47 @@ def download_weights(url: str, target_folder: Path):
# extracting file
file.extractall(target_folder)

def is_custom_model(self) -> Dict[str, Any]:
sd_variant = os.environ.get("SD_VARIANT", "sd1")
sd_version = os.environ.get("SD_VERSION", 1)
supported_model_of_stable_diffusion_inference = ("sd1.4", "sd1.5", "sd1", "sd", "sd2_high", "sd2", "sd2_base")

if sd_variant in supported_model_of_stable_diffusion_inference:
return {"is_custom_model": False, "sd_variant": sd_variant, "sd_version": sd_version}
else:
return {"is_custom_model": True, "sd_variant": sd_variant, "sd_version": sd_version}

def build_pipeline(self):
"""The `build_pipeline(...)` method builds a model and trainer."""
from stable_diffusion_inference import create_text2image

print("loading model...")
# model url is loaded from stable_diffusion_inference library
# url: https://pl-public-data.s3.amazonaws.com/dream_stable_diffusion/v1-5-pruned-emaonly.ckpt
self._model = create_text2image(sd_variant=os.environ.get("SD_VARIANT", "sd1"))

model_info = self.is_custom_model()
if not model_info["is_custom_model"]:
# model url is loaded from stable_diffusion_inference library
# url: https://pl-public-data.s3.amazonaws.com/dream_stable_diffusion/v1-5-pruned-emaonly.ckpt
self._model = create_text2image(sd_variant=model_info["sd_variant"])
else:
# If the stable_diffusion_inference library can't support SD_VARIANT, build custom pipeline.
# SD_VERSION: configuration version of stable diffusion, default=1
_ROOT_DIR = dirname(stable_diffusion_inference.__file__)
config_path = f"{_ROOT_DIR}/configs/stable-diffusion/v{model_info['sd_version']}-inference.yaml"
checkpoint_path = model_info["sd_variant"]

self._model = SDInference(
config_path=config_path,
checkpoint_path=checkpoint_path,
version="1.5",
cache_dir=None,
force_download=None,
ckpt_filename=checkpoint_path,
)
print(f"{model_info['sd_variant']} is loaded.")

if os.environ.get("CUSTOM_MODEL_TEST"):
return

self.safety_embeddings_drive.get(self.safety_embeddings_filename)
self._safety_checker = SafetyChecker(self.safety_embeddings_filename)
print("model loaded")
1 change: 1 addition & 0 deletions requirements.txt
@@ -16,3 +16,4 @@ Pillow==9.5.0
python-dotenv==1.0.0
uvloop>=0.16.0, <=0.17.0
lightning@ git+https://github.com/Lightning-AI/lightning@release/stable
sd_inference@ git+https://github.com/aniketmaurya/stable_diffusion_inference
9 changes: 9 additions & 0 deletions tests/README.md
@@ -0,0 +1,9 @@
You can test custom model downloading on your local machine:

```sh
pip install pytest coverage
cd ..

coverage run --source muse -m pytest muse tests/local_tests/test_custom_model_using_local_path.py -v -s
coverage run --source muse -m pytest muse tests/local_tests/test_custom_model_using_http_url.py -v -s
```
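The logic-only tests added in this PR (no checkpoint download) can be run the same way; a minimal sketch of the corresponding CI step:

```sh
coverage run --source muse -m pytest tests/test_custom_model_logic.py -v -s
coverage report
```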
39 changes: 39 additions & 0 deletions tests/local_tests/test_custom_model_using_http_url.py
@@ -0,0 +1,39 @@
import io
import os
from contextlib import redirect_stdout

from lightning.app.testing.testing import LightningTestApp, application_testing


class LightningAppCustomModelTest(LightningTestApp):
def run_once(self) -> bool:
f = io.StringIO()
with redirect_stdout(f):
super().run_once()
out = f.getvalue()
assert "is loaded.\n" == out[-11:]
return True


def test_custom_model_via_url_is_available():
cwd = os.getcwd()
cwd = os.path.join(cwd, "app.py")
envs = {
"SD_VARIANT": "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.ckpt",
"SD_VERSION": "2",
"CUSTOM_MODEL_TEST": "True",
}
for env in envs:
os.environ[env] = envs[env]

command_line = [
cwd,
"--blocking",
"False",
"--open-ui",
"False",
]
result = application_testing(LightningAppCustomModelTest, command_line)
for env in envs:
os.environ[env] = ""
assert result.exit_code == 0
44 changes: 44 additions & 0 deletions tests/local_tests/test_custom_model_using_local_path.py
@@ -0,0 +1,44 @@
import io
import os
from contextlib import redirect_stdout

from lightning.app.testing.testing import LightningTestApp, application_testing


class LightningAppCustomModelTest(LightningTestApp):
def run_once(self) -> bool:
f = io.StringIO()
with redirect_stdout(f):
super().run_once()
out = f.getvalue()
assert "is loaded.\n" == out[-11:]
return True


def test_custom_model_via_local_path_is_available():
import urllib.request

cwd = os.getcwd()
cwd = os.path.join(cwd, "app.py")

ckpt_url = "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.ckpt"
dest = "chkpt_dir"
os.makedirs(dest, exist_ok=True)

envs = {"SD_VARIANT": "./chkpt_dir/v2-1_768-ema-pruned.ckpt", "SD_VERSION": "2", "CUSTOM_MODEL_TEST": "True"}
for env in envs:
os.environ[env] = envs[env]

urllib.request.urlretrieve(ckpt_url, os.environ["SD_VARIANT"])

command_line = [
cwd,
"--blocking",
"False",
"--open-ui",
"False",
]
result = application_testing(LightningAppCustomModelTest, command_line)
for env in envs:
os.environ[env] = ""
assert result.exit_code == 0
60 changes: 60 additions & 0 deletions tests/test_custom_model_logic.py
@@ -0,0 +1,60 @@
import os

from muse.components.stable_diffusion_serve import StableDiffusionServe


def test_custom_logic_check_bool():
sd_variant = "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.ckpt"
sd_version = "2"
custom_model_test = "True"

envs = {
"SD_VARIANT": sd_variant,
"SD_VERSION": sd_version,
"CUSTOM_MODEL_TEST": custom_model_test,
}
for env in envs:
os.environ[env] = envs[env]

model_info = StableDiffusionServe().is_custom_model()
for env in envs:
os.environ[env] = ""
assert model_info["is_custom_model"] is True


def test_custom_logic_check_variant():
sd_variant = "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.ckpt"
sd_version = "2"
custom_model_test = "True"

envs = {
"SD_VARIANT": sd_variant,
"SD_VERSION": sd_version,
"CUSTOM_MODEL_TEST": custom_model_test,
}
for env in envs:
os.environ[env] = envs[env]

model_info = StableDiffusionServe().is_custom_model()
for env in envs:
os.environ[env] = ""
assert model_info["sd_variant"] == sd_variant


def test_custom_logic_check_version():
sd_variant = "https://huggingface.co/stabilityai/stable-diffusion-2-1/resolve/main/v2-1_768-ema-pruned.ckpt"
sd_version = "2"
custom_model_test = "True"

envs = {
"SD_VARIANT": sd_variant,
"SD_VERSION": sd_version,
"CUSTOM_MODEL_TEST": custom_model_test,
}
for env in envs:
os.environ[env] = envs[env]

model_info = StableDiffusionServe().is_custom_model()
for env in envs:
os.environ[env] = ""
assert model_info["sd_version"] == sd_version