<a href="https://colab.research.google.com/github/NAGADEEP92/AI-ML-DS/blob/main/KERASCV_Keras_Core_KerasCV.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# KerasCV Stable Diffusion Multi-backend demo

### Author: Suvaditya Mukherjee

## `pip` package installation

In [None]:
# `%pip` (unlike `!pip`) installs into the environment of the running kernel.
# NOTE(review): pin versions (e.g. keras_cv==X.Y.Z) so the benchmark is reproducible.
%pip install keras_core keras_cv wandb -qqq

     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 728.0/728.0 kB 20.2 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 756.2/756.2 kB 64.4 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 82.3 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 188.5/188.5 kB 18.0 MB/s eta 0:00:00
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 214.7/214.7 kB 25.5 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 62.7/62.7 kB 7.9 MB/s eta 0:00:00
  Building wheel for pathtools (setup.py) ... done


## Stable Diffusion with `torch` backend

### Import statements

In [None]:
import os

# keras_core picks its backend from KERAS_BACKEND only when it is first
# imported, so the variable must be set before the import below.
os.environ['KERAS_BACKEND'] = 'torch'

import time

import keras_core as keras
import keras_cv
import wandb

# Re-running this cell in a kernel that already imported keras_core keeps the
# previously selected backend without any error; fail loudly instead so the
# benchmark below is not silently attributed to the wrong backend.
if keras.backend.backend() != 'torch':
    raise RuntimeError(
        f"keras_core is using the '{keras.backend.backend()}' backend, not 'torch'. "
        "Restart the runtime and run this section again."
    )

Using PyTorch backend.


### Define and warm up the model

In [None]:
keras.mixed_precision.set_global_policy("mixed_float16")

# Pass the config to wandb.init so it is stored with the run. Assigning a plain
# dict to `wandb.config` only rebinds the module attribute; the run's config
# object is left untouched and nothing gets logged.
run_config = {
    "img_height": 512,
    "img_width": 512,
}
wandb.init(
    project="io-extended-sd-keras_core",
    config=run_config,
)

model = keras_cv.models.StableDiffusion(
    img_width=run_config["img_width"],
    img_height=run_config["img_height"],
    jit_compile=True,
)

# The first call pays the one-off tracing/compilation cost, so run it here to
# keep that cost out of the timed generation below.
model.text_to_image("warm-up the model", batch_size=3)
print("Warm up done")

### Perform generation

In [None]:
# perf_counter is a monotonic, high-resolution clock meant for measuring
# durations; time.time() follows the wall clock and can jump if it is adjusted.
start = time.perf_counter()
images = model.text_to_image(
    "wolverine trying to pick a lock with his claws",
    batch_size=3,
)
generation_time = time.perf_counter() - start

# Keep the raw image arrays in `images`; the W&B wrappers get their own name.
wandb_images = [
    wandb.Image(image, caption=f"XLA, Mixed-Precision, in PyTorch - Example {idx}")
    for idx, image in enumerate(images, start=1)
]

wandb.log({"Examples": wandb_images, "Generation Time": generation_time})



### Release memory

In [None]:
import gc

# Drop the model and reset Keras' global state before the next section.
del model
keras.backend.clear_session()
# Model objects may sit in reference cycles; collect them now instead of
# whenever the cyclic garbage collector next happens to run.
gc.collect()

wandb.finish()

VBox(children=(Label(value='1.327 MB of 1.327 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
Generation Time,▁

0,1
Generation Time,105.79374


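## Stable Diffusion with `JAX` backend

> **Restart the runtime before running this section.** `keras_core` selects its backend from `KERAS_BACKEND` only when it is first imported, so changing the variable in an already-running kernel has no effect.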

In [None]:
import os

# keras_core picks its backend from KERAS_BACKEND only when it is first
# imported, so the variable must be set before the import below.
os.environ['KERAS_BACKEND'] = 'jax'

import time

import keras_core as keras
import keras_cv
import wandb

# Re-running this cell in a kernel that already imported keras_core keeps the
# previously selected backend without any error; fail loudly instead so the
# benchmark below is not silently attributed to the wrong backend.
if keras.backend.backend() != 'jax':
    raise RuntimeError(
        f"keras_core is using the '{keras.backend.backend()}' backend, not 'jax'. "
        "Restart the runtime and run this section again."
    )

Using JAX backend.


### Define and warm up the model

In [None]:
keras.mixed_precision.set_global_policy("mixed_float16")

# Pass the config to wandb.init so it is stored with the run. Assigning a plain
# dict to `wandb.config` only rebinds the module attribute; the run's config
# object is left untouched and nothing gets logged.
run_config = {
    "img_height": 512,
    "img_width": 512,
}
wandb.init(
    project="io-extended-sd-keras_core",
    config=run_config,
)

model = keras_cv.models.StableDiffusion(
    img_width=run_config["img_width"],
    img_height=run_config["img_height"],
    jit_compile=True,
)

# The first call pays the one-off tracing/compilation cost, so run it here to
# keep that cost out of the timed generation below. Ending the cell with a
# print (instead of the call itself) keeps the returned image array from being
# dumped into the notebook output.
model.text_to_image("warm-up the model", batch_size=3)
print("Warm up done")

[34m[1mwandb[0m: Currently logged in as: [33msuvadityamuk[0m. Use [1m`wandb login --relogin`[0m to force relogin


By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE


array([[[[ 62,  64,  63],
         [ 42,  43,  43],
         [ 61,  57,  56],
         ...,
         [ 47,  47,  50],
         [ 54,  54,  50],
         [ 49,  47,  44]],

        [[ 53,  56,  56],
         [ 23,  17,  23],
         [ 43,  41,  44],
         ...,
         [ 53,  51,  53],
         [ 30,  29,  25],
         [ 44,  38,  38]],

        [[ 59,  62,  59],
         [ 49,  41,  45],
         [ 60,  56,  59],
         ...,
         [ 49,  46,  48],
         [ 33,  25,  29],
         [ 46,  43,  44]],

        ...,

        [[ 30,  31,  30],
         [ 52,  46,  48],
         [ 68,  66,  67],
         ...,
         [ 58,  58,  57],
         [ 54,  49,  51],
         [ 54,  50,  49]],

        [[ 33,  34,  34],
         [ 47,  40,  46],
         [ 54,  53,  55],
         ...,
         [ 53,  52,  53],
         [ 49,  45,  49],
         [ 62,  61,  59]],

        [[ 51,  49,  51],
         [ 59,  58,  58],
         [ 51,  50,  51],
         ...,
         [ 50,  49,  54],
        

### Perform generation

In [None]:
# perf_counter is a monotonic, high-resolution clock meant for measuring
# durations; time.time() follows the wall clock and can jump if it is adjusted.
start = time.perf_counter()
images = model.text_to_image(
    "wolverine trying to pick a lock with his claws",
    batch_size=3,
)
generation_time = time.perf_counter() - start

# Keep the raw image arrays in `images`; the W&B wrappers get their own name.
wandb_images = [
    wandb.Image(image, caption=f"XLA, Mixed-Precision, in JAX - Example {idx}")
    for idx, image in enumerate(images, start=1)
]

wandb.log({"Examples": wandb_images, "Generation Time": generation_time})



### Release memory

In [None]:
import gc

# Drop the model and reset Keras' global state before the next section.
del model
keras.backend.clear_session()
# Model objects may sit in reference cycles; collect them now instead of
# whenever the cyclic garbage collector next happens to run.
gc.collect()

wandb.finish()

VBox(children=(Label(value='1.208 MB of 1.216 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.992704…

0,1
Generation Time,▁

0,1
Generation Time,106.40664


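## Stable Diffusion with `TensorFlow` backend

> **Restart the runtime before running this section.** `keras_core` selects its backend from `KERAS_BACKEND` only when it is first imported, so changing the variable in an already-running kernel has no effect. The output recorded under the import cell below reports `Using JAX backend.`. Treat the TensorFlow timing saved in this notebook with caution, and re-run it from a fresh runtime before comparing backends.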

In [None]:
import os

# keras_core picks its backend from KERAS_BACKEND only when it is first
# imported, so the variable must be set before the import below.
os.environ['KERAS_BACKEND'] = 'tensorflow'

import time

import keras_core as keras
import keras_cv
import wandb

# Re-running this cell in a kernel that already imported keras_core keeps the
# previously selected backend without any error; fail loudly instead so the
# benchmark below is not silently attributed to the wrong backend.
if keras.backend.backend() != 'tensorflow':
    raise RuntimeError(
        f"keras_core is using the '{keras.backend.backend()}' backend, not 'tensorflow'. "
        "Restart the runtime and run this section again."
    )

Using JAX backend.


### Define and warm up the model

In [None]:
keras.mixed_precision.set_global_policy("mixed_float16")

# Pass the config to wandb.init so it is stored with the run. Assigning a plain
# dict to `wandb.config` only rebinds the module attribute; the run's config
# object is left untouched and nothing gets logged.
run_config = {
    "img_height": 512,
    "img_width": 512,
}
wandb.init(
    project="io-extended-sd-keras_core",
    config=run_config,
)

model = keras_cv.models.StableDiffusion(
    img_width=run_config["img_width"],
    img_height=run_config["img_height"],
    jit_compile=True,
)

# The first call pays the one-off tracing/compilation cost, so run it here to
# keep that cost out of the timed generation below. Ending the cell with a
# print (instead of the call itself) keeps the returned image array from being
# dumped into the notebook output.
model.text_to_image("warm-up the model", batch_size=3)
print("Warm up done")

### Perform generation

In [None]:
# perf_counter is a monotonic, high-resolution clock meant for measuring
# durations; time.time() follows the wall clock and can jump if it is adjusted.
start = time.perf_counter()
images = model.text_to_image(
    "wolverine trying to pick a lock with his claws",
    batch_size=3,
)
generation_time = time.perf_counter() - start

# Keep the raw image arrays in `images`; the W&B wrappers get their own name.
wandb_images = [
    wandb.Image(image, caption=f"XLA, Mixed-Precision, in TensorFlow - Example {idx}")
    for idx, image in enumerate(images, start=1)
]

wandb.log({"Examples": wandb_images, "Generation Time": generation_time})

### Release memory

In [None]:
import gc

# Drop the model and reset Keras' global state at the end of the benchmark.
del model
keras.backend.clear_session()
# Model objects may sit in reference cycles; collect them now instead of
# whenever the cyclic garbage collector next happens to run.
gc.collect()

wandb.finish()

[34m[1mwandb[0m: Currently logged in as: [33msuvadityamuk[0m. Use [1m`wandb login --relogin`[0m to force relogin


By using this model checkpoint, you acknowledge that its usage is subject to the terms of the CreativeML Open RAIL-M license at https://raw.githubusercontent.com/CompVis/stable-diffusion/main/LICENSE


0,1
Generation Time,▁

0,1
Generation Time,99.67967


## That's all!

## धन्यवाद, नई दिल्ली!

Suvaditya Mukherjee  
[GitHub](http://github.com/suvadityamuk) | [Twitter](https://www.twitter.com/halcyonrayes)