In [None]:
import os
import yaml
import torch
import torchvision
from tqdm import tqdm

# The project packages (`inference`, `core`, `train`) and the relative
# `configs/...` / `models/` paths used later are expected to resolve from the
# parent of the directory this notebook starts in. Remember the starting
# directory once, so re-running this cell always changes into the same parent
# instead of climbing one more level on every execution (a bare
# `os.chdir('..')` is not idempotent).
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR", os.getcwd())
os.chdir(os.path.dirname(NOTEBOOK_DIR))

# NOTE(review): star import — presumably the source of resize_image,
# download_image, show_images and calculate_latent_sizes used below.
from inference.utils import *
from core.utils import load_or_fail
from train import ControlNetCore, WurstCoreB

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

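### Choose your ControlNet

Run the download and config cells for **one** ControlNet only. Each config cell below reassigns `config_file`, so running all of them leaves the last one (Super Resolution) selected.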

#### Inpainting / Outpainting

In [None]:
!wget https://huggingface.co/stabilityai/stable-cascade/resolve/main/controlnet/inpainting.safetensors -P models -q --show-progress

In [None]:
# Select the inpainting / outpainting ControlNet config (read by the "Load Config" cell).
config_file = "configs/inference/controlnet_c_3b_inpainting.yaml"

#### Face Identity (Not available yet)

In [None]:
# Face Identity ControlNet is not available yet; once it is, uncomment the line
# below to select its config.
# config_file = 'configs/inference/controlnet_c_3b_identity.yaml'

#### Canny

In [None]:
!wget https://huggingface.co/stabilityai/stable-cascade/resolve/main/controlnet/canny.safetensors -P models -q --show-progress

In [None]:
# Select the Canny-edge ControlNet config (read by the "Load Config" cell).
config_file = "configs/inference/controlnet_c_3b_canny.yaml"

#### Super Resolution

In [None]:
!wget https://huggingface.co/stabilityai/stable-cascade/resolve/main/controlnet/super_resolution.safetensors -P models -q --show-progress

In [None]:
# Select the super-resolution ControlNet config (read by the "Load Config" cell).
config_file = "configs/inference/controlnet_c_3b_sr.yaml"

## Load Config & Models

### Load Config

In [None]:
# SETUP STAGE C
# Parse the ControlNet (Stage C) YAML chosen above and build its core in
# inference mode.
with open(config_file, "r", encoding="utf-8") as file:
    loaded_config = yaml.safe_load(file)

core = ControlNetCore(config_dict=loaded_config, device=device, training=False)

# SETUP STAGE B
# Same pattern as Stage C: `config_file_b` stays the YAML path and the parsed
# dict gets its own name, instead of rebinding the path variable to a dict.
config_file_b = 'configs/inference/stage_b_3b.yaml'
with open(config_file_b, "r", encoding="utf-8") as file:
    loaded_config_b = yaml.safe_load(file)

core_b = WurstCoreB(config_dict=loaded_config_b, device=device, training=False)

### Load Extras & Models

In [None]:
# Stage C (ControlNet): build extras and models, then freeze the generator
# for inference.
extras = core.setup_extras_pre()
models = core.setup_models(extras)
models.generator.eval().requires_grad_(False)
print("CONTROLNET READY")

# Stage B: built with `skip_clip=True`; Stage C's tokenizer and text model are
# injected into its model container instead (overriding any existing entries).
extras_b = core_b.setup_extras_pre()
models_b = core_b.setup_models(extras_b, skip_clip=True)
shared_text_encoder = {'tokenizer': models.tokenizer, 'text_model': models.text_model}
models_b = WurstCoreB.Models(**dict(models_b.to_dict(), **shared_text_encoder))
models_b.generator.eval().requires_grad_(False)
print("STAGE B READY")

### Inpainting / Outpainting

**Note**: You can define your own masks with the `mask` parameter. For demonstration purposes, you can instead use what we do during training to generate masks: a tiny saliency model predicts the area of "interesting content", like an animal, a person, an object etc. This results in masks that closely mimic how humans actually inpaint, and they can be calculated extremely fast with just a few lines of code. You have two parameters to control the masks: `threshold` and `outpaint`. The former determines how much area will be masked, and `outpaint` simply flips the predicted mask. Play around with the parameters to get a feeling for them (`threshold` should be between 0.0 and 0.4). If you wish to load your own masks, uncomment the `mask` line and replace it with your own.

In [None]:
batch_size = 4
# NOTE(review): Discord attachment URLs carry an `ex=` parameter and appear to be
# expiring links — swap in your own image if the download fails.
url = "https://cdn.discordapp.com/attachments/1121232062708457508/1204787053892603914/cat_dog.png?ex=65d60061&is=65c38b61&hm=37c3d179a39b1eca4b8894e3c239930cedcbb965da00ae2209cca45f883f86f4&"

# Resize the downloaded image once, then repeat it `batch_size` times along a
# new leading batch dimension (`expand` returns a view, not a copy).
single_image = resize_image(download_image(url))
images = single_image.unsqueeze(0).expand(batch_size, -1, -1, -1)
batch = {'images': images}

# Mask controls: supply your own boolean mask, or keep `mask = None` and steer
# the predicted mask with `threshold` (masked area) and `outpaint` (flip it).
mask = None
# mask = torch.ones(batch_size, 1, images.size(2), images.size(3)).bool()
outpaint = False
threshold = 0.2

mask_kwargs = dict(mask=mask, outpaint=outpaint, threshold=threshold)
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
    cnet, cnet_input = core.get_cnet(batch, models, extras, **mask_kwargs)
    # The unconditional branch reuses the very same ControlNet features.
    cnet_uncond = cnet

show_images(batch['images'])
show_images(cnet_input)

### Face Identity

**Note**: This ControlNet lets you generate images based on faces in a given image. Simply load an image or enter the `url`.

In [None]:
batch_size = 4
# NOTE(review): Discord attachment URLs carry an `ex=` parameter and appear to be
# expiring links — swap in your own image if the download fails.
url = "https://cdn.discordapp.com/attachments/1039261364935462942/1200109692978999317/three_people.png?ex=65c4fc3f&is=65b2873f&hm=064a8cebea5560b74e7088be9d1399a5fe48863d1581e65ea9d6734725f4c8d3&"

# Resize once, then repeat along a new leading batch dimension.
single_image = resize_image(download_image(url))
images = single_image.unsqueeze(0).expand(batch_size, -1, -1, -1)
batch = {'images': images}

with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
    cnet, cnet_input = core.get_cnet(batch, models, extras)
    # Unconditional features are computed from an all-zero batch of the same shape;
    # only the first return value (the features) is kept.
    blank_batch = {'images': torch.zeros_like(batch['images'])}
    cnet_uncond = core.get_cnet(blank_batch, models, extras)[0]

show_images(batch['images'])
show_images(cnet_input)

### Canny

**Note**: This is a typical ControlNet for Canny edge detection. You can also use it for *sketch-to-image*: set `sketch = True` and provide a sketch as the image.

In [None]:
batch_size = 4
# NOTE(review): Discord attachment URLs carry an `ex=` parameter and appear to be
# expiring links — swap in your own image if the download fails.
url = "https://media.discordapp.net/attachments/1177378292765036716/1205484279405219861/image.png?ex=65d889b9&is=65c614b9&hm=0722ab9707b48d677316c0b4de5e51702b43eac1e27b76c268a069ec67ff6d15&=&format=webp&quality=lossless&width=861&height=859"
single_image = resize_image(download_image(url))
images = single_image.unsqueeze(0).expand(batch_size, -1, -1, -1)
# Set to True to use the input image as a sketch (sketch-to-image).
sketch = False

batch = {'images': images}

# Sketch mode passes an inverted channel-mean of the image as the ControlNet input;
# otherwise `None` is passed and `get_cnet` presumably computes the input itself.
cnet_input = (1 - images.mean(dim=1, keepdim=True)) if sketch else None

with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
    cnet, cnet_input = core.get_cnet(batch, models, extras, cnet_input=cnet_input)
    # The unconditional branch reuses the same ControlNet features.
    cnet_uncond = cnet

show_images(batch['images'])
show_images(cnet_input)

### Super Resolution

In [None]:
batch_size = 4
# Alternative demo images — uncomment one to use it instead of `url` below.
# NOTE(review): these Discord links carry an `ex=` parameter and appear to expire.
# url = "https://media.discordapp.net/attachments/1121232062708457508/1205134173053132810/image.png?ex=65d743a9&is=65c4cea9&hm=48dc4901514caada29271f48d76431f3a648940f2fda9e643a6bb693c906cc09&=&format=webp&quality=lossless&width=862&height=857"
# url = "https://cdn.discordapp.com/attachments/1121232062708457508/1204787053892603914/cat_dog.png?ex=65d60061&is=65c38b61&hm=37c3d179a39b1eca4b8894e3c239930cedcbb965da00ae2209cca45f883f86f4&"
url = "https://cdn.discordapp.com/attachments/1121232062708457508/1205110687538479145/A_photograph_of_a_sunflower_with_sunglasses_on_in__3.jpg?ex=65d72dc9&is=65c4b8c9&hm=72172e774ce6cda618503b3778b844de05cd1208b61e185d8418db512fb2858a&"
images = resize_image(download_image(url)).unsqueeze(0).expand(batch_size, -1, -1, -1)

batch = {'images': images}

with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
    # Instead of `core.get_cnet`, encode the image, upsample its latents 2x and
    # feed them straight to the ControlNet; the generation cell derives the
    # output size from these upsampled features.
    effnet_latents = core.encode_latents(batch, models, extras)
    effnet_latents_up = torch.nn.functional.interpolate(effnet_latents, scale_factor=2, mode="nearest")
    cnet = models.controlnet(effnet_latents_up)
    cnet_uncond = cnet
    # Only displayed: the input image upsampled 2x to match the target size.
    cnet_input = torch.nn.functional.interpolate(images, scale_factor=2, mode="nearest")

show_images(batch['images'])
show_images(cnet_input)

### Optional: Compile Stage C and Stage B

**Note**: This increases inference speed by about 2x, but the initial compilation takes a few minutes. Moreover, `torch.compile` currently works for a single image resolution only, e.g. 1024 x 1024; using a different size triggers a recompile. See more [here](https://pytorch.org/tutorials/intermediate/torch_compile_tutorial.html).

In [None]:
# Replace each stage's generator with a torch.compile-wrapped version; every
# other entry in the model containers is carried over unchanged.
compiled_generator_c = torch.compile(models.generator, mode="reduce-overhead", fullgraph=True)
models = ControlNetCore.Models(**dict(models.to_dict(), generator=compiled_generator_c))

compiled_generator_b = torch.compile(models_b.generator, mode="reduce-overhead", fullgraph=True)
models_b = WurstCoreB.Models(**dict(models_b.to_dict(), generator=compiled_generator_b))

## ControlNet Generation

In [None]:
caption = "An oil painting"
cnet_multiplier = 1.0  # scales the ControlNet features; e.g. 0.8 or 0.3 for a weaker effect

# The super-resolution ControlNet derives the output size from its (2x upsampled)
# features; all other ControlNets generate at 1024 x 1024.
if "controlnet_c_3b_sr" in config_file:
    height, width = int(cnet[0].size(-2)*32*4/3), int(cnet[0].size(-1)*32*4/3)
else:
    height, width = 1024, 1024
stage_c_latent_shape, stage_b_latent_shape = calculate_latent_sizes(height, width, batch_size=batch_size)

# Stage C Parameters
extras.sampling_configs['cfg'] = 1
extras.sampling_configs['shift'] = 2
extras.sampling_configs['timesteps'] = 20
extras.sampling_configs['t_start'] = 1.0

# Stage B Parameters
extras_b.sampling_configs['cfg'] = 1.1
extras_b.sampling_configs['shift'] = 1
extras_b.sampling_configs['timesteps'] = 10
extras_b.sampling_configs['t_start'] = 1.0

# PREPARE CONDITIONS
batch['captions'] = [caption] * batch_size
conditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=False, eval_image_embeds=False)
unconditions = core.get_conditions(batch, models, extras, is_eval=True, is_unconditional=True, eval_image_embeds=False)
# Multiplying already produces new tensors, so the cached `cnet` / `cnet_uncond`
# are never modified and no explicit `.clone()` is needed.
conditions['cnet'] = [c * cnet_multiplier if c is not None else c for c in cnet]
unconditions['cnet'] = [c * cnet_multiplier if c is not None else c for c in cnet_uncond]
conditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=False)
unconditions_b = core_b.get_conditions(batch, models_b, extras_b, is_eval=True, is_unconditional=True)

with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.bfloat16):
    # torch.manual_seed(42)  # uncomment for reproducible samples

    # `gdf.sample` is iterated step by step; exhausting it leaves the final
    # Stage C latent in `sampled_c`.
    sampling_c = extras.gdf.sample(
        models.generator, conditions, stage_c_latent_shape,
        unconditions, device=device, **extras.sampling_configs,
    )
    for (sampled_c, _, _) in tqdm(sampling_c, total=extras.sampling_configs['timesteps']):
        pass

    # preview_c = models.previewer(sampled_c).float()
    # show_images(preview_c)

    # Stage B is conditioned on the Stage C latent; its unconditional branch
    # gets an all-zero latent of the same shape.
    conditions_b['effnet'] = sampled_c
    unconditions_b['effnet'] = torch.zeros_like(sampled_c)

    sampling_b = extras_b.gdf.sample(
        models_b.generator, conditions_b, stage_b_latent_shape,
        unconditions_b, device=device, **extras_b.sampling_configs
    )
    for (sampled_b, _, _) in tqdm(sampling_b, total=extras_b.sampling_configs['timesteps']):
        pass
    # Decode the final Stage B latent to images with Stage A.
    sampled = models_b.stage_a.decode(sampled_b).float()

show_images(cnet_input)
show_images(sampled)
