In [1]:
import sys
import os


# Make the three vendored project checkouts importable by adding them to the
# module search path. Each directory is appended at most once, so re-running
# this cell does not keep growing sys.path with duplicate entries.
_project_root = os.path.join("/home/pat", 'diffusion_augmentation')

zero123_dir = os.path.join(_project_root, 'zero123')
controlnet_dir = os.path.join(_project_root, 'controlnet')
# Kept in its own variable so it does not overwrite `controlnet_dir`.
color_controlnet_dir = os.path.join(_project_root, 'color_controlnet')

for _repo_dir in (zero123_dir, controlnet_dir, color_controlnet_dir):
    if _repo_dir not in sys.path:
        sys.path.append(_repo_dir)


In [2]:
# Device placement: the ControlNet models go on GPU 0, while Zero123 and the
# colour-control pipeline share GPU 1.
control_net_device = 'cuda:0'
zero123_device = color_control_device = 'cuda:1'

In [3]:
# Sanity check: show the module search path so the appended project
# directories can be verified before the heavy imports run.
print(sys.path)

['/home/pat/miniconda3/envs/auto_aug/lib/python310.zip', '/home/pat/miniconda3/envs/auto_aug/lib/python3.10', '/home/pat/miniconda3/envs/auto_aug/lib/python3.10/lib-dynload', '', '/home/pat/miniconda3/envs/auto_aug/lib/python3.10/site-packages', '/home/pat/diffusion_augmentation/zero123/src/taming-transformers', '/home/pat/diffusion_augmentation/zero123/src/clip', '/home/pat/miniconda3/envs/auto_aug/lib/python3.10/site-packages/setuptools/_vendor', '/home/pat/diffusion_augmentation/zero123', '/home/pat/diffusion_augmentation/controlnet', '/home/pat/diffusion_augmentation/color_controlnet']


In [4]:
import torch
import numpy as np
import cv2
from PIL import Image
import random
import einops
from torchvision.transforms.functional import to_pil_image
from transformers import AutoProcessor, LlavaForConditionalGeneration, SamModel, AutoImageProcessor, DPTForDepthEstimation
from transformers import pipeline
from controlnet.annotator.util import resize_image, HWC3
from controlnet.annotator.canny import CannyDetector
from controlnet.annotator.uniformer import UniformerDetector
from controlnet.annotator.midas import MidasDetector
from omegaconf import OmegaConf
from controlnet.cldm.model import create_model, load_state_dict
from controlnet.cldm.ddim_hacked import DDIMSampler
from pytorch_lightning import seed_everything

# Zero123 imports
from zero123.nerf import load_model_from_config, generate_angles
from zero123.ldm.util import create_carvekit_interface

# Color Control imports
from color_controlnet.diffusers import ControlNetModel, LineartDetector, StableDiffusionImg2ImgControlNetPalettePipeline
from color_controlnet.diffusers import UniPCMultistepScheduler
from color_controlnet.infer_palette import get_cond_color, show_anns, image_grid, HWC3, resize_in_buckets, SAMImageAnnotator
from color_controlnet.infer_palette_img2img import control_color_augment
import sys
import os

  from .autonotebook import tqdm as notebook_tqdm
  @torch.library.impl_abstract("xformers_flash::flash_fwd")
  @torch.library.impl_abstract("xformers_flash::flash_bwd")


In [5]:
def _strip_prefix_items(state, prefix):
    """Return the entries of `state` whose key starts with `prefix`, with only that leading prefix removed."""
    return {key[len(prefix):]: value for key, value in state.items() if key.startswith(prefix)}


def _load_controlnet_bundle():
    """Load the ControlNet models onto `control_net_device` and create their condition detectors."""
    print('Loading ControlNet models...')
    model_names = ['control_v11p_sd15_canny', 'control_v11f1p_sd15_depth', 'control_v11p_sd15_seg']

    # The base checkpoint is the same for every ControlNet variant, so it is
    # read from disk once instead of once per model. It stays on the CPU (the
    # freshly created models are on the CPU too) so it does not occupy GPU
    # memory for the whole loop; load_state_dict copies values into the model.
    base_state = load_state_dict('./models/v1-5-pruned.ckpt', location='cpu')

    models = {}
    for name in model_names:
        model = create_model(f'./models/{name}.yaml').cpu()
        # strict=False: presumably each checkpoint only covers part of the
        # combined model's parameters -- TODO confirm.
        model.load_state_dict(base_state, strict=False)
        model.load_state_dict(load_state_dict(f'./models/{name}.pth', location=control_net_device), strict=False)
        models[name] = model.to(control_net_device)

    # Detectors that turn an input image into the matching control signal.
    detectors = {
        'Canny': CannyDetector(),
        'Depth': MidasDetector(),
        'Segmentation': UniformerDetector(),
    }
    return {'models': models, 'detectors': detectors}


def _load_llava_bundle():
    """Load the LLaVA processor and model from the Hugging Face hub id."""
    print('Loading LLAVA model...')
    model_id = "llava-hf/llava-1.5-7b-hf"
    # No device or dtype is requested here; the model is returned exactly as
    # `from_pretrained` creates it.
    processor = AutoProcessor.from_pretrained(model_id)
    llava_model = LlavaForConditionalGeneration.from_pretrained(model_id)
    return {'processor': processor, 'model': llava_model}


def _load_zero123_bundle():
    """Load the Zero123 model onto `zero123_device` and create the Carvekit interface."""
    print('Loading Zero123 models...')
    config = OmegaConf.load('./models/sd-objaverse-finetune-c_concat-256.yaml')
    model = load_model_from_config(config, "./models/105000.ckpt", zero123_device)
    model = model.to(zero123_device)
    carvekit_interface = create_carvekit_interface()
    return {'model': model, 'carvekit_interface': carvekit_interface}


def _load_color_control_bundle():
    """Build the colour-palette img2img pipeline on `color_control_device` plus its SAM annotator."""
    print('Loading Color Control model...')

    # Both networks are created from the same config file, in half precision.
    controlnet = ControlNetModel.from_config("./model_configs/controlnet_config.json").half()
    adapter = ControlNetModel.from_config("./model_configs/controlnet_config.json").half()

    sam_annotator = SAMImageAnnotator()

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code -- only load checkpoints from a trusted source.
    model_sd = torch.load("./model_configs/color_img2img_palette.pt", map_location="cpu")["module"]

    # The checkpoint stores both networks in one dict; split it by key prefix.
    # Only the leading "<name>." is stripped, so a key that happens to contain
    # the same text further along is left intact.
    controlnet_sd = _strip_prefix_items(model_sd, "controlnet.")
    adapter_sd = _strip_prefix_items(model_sd, "adapter.")

    msg_control = controlnet.load_state_dict(controlnet_sd, strict=True)
    print(f"msg_control: {msg_control} ")
    msg_adapter = adapter.load_state_dict(adapter_sd, strict=False)
    print(f"msg_adapter: {msg_adapter} ")

    # Define the inference pipeline.
    pipe = StableDiffusionImg2ImgControlNetPalettePipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        adapter=adapter,
        torch_dtype=torch.float16,
        safety_checker=None,
    ).to(color_control_device)
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

    return {'pipe': pipe, 'sam_annotator': sam_annotator, 'adapter': adapter}


def initialize_models():
    """
    Initialize all models and detectors used in the pipeline.

    Device placement is taken from the module-level names
    `control_net_device`, `zero123_device` and `color_control_device`.
    Model files are read from the relative `./models` and `./model_configs`
    directories, so the working directory matters.

    Returns:
        control_net: dict with keys 'models' (ControlNet models keyed by
            checkpoint name) and 'detectors' (keyed by 'Canny', 'Depth',
            'Segmentation')
        llava: dict with keys 'processor' and 'model'
        zero123: dict with keys 'model' and 'carvekit_interface'
        color_control: dict with keys 'pipe', 'sam_annotator' and 'adapter'
    """
    control_net = _load_controlnet_bundle()
    llava = _load_llava_bundle()
    zero123 = _load_zero123_bundle()
    color_control = _load_color_control_bundle()

    return control_net, llava, zero123, color_control


In [6]:
# Load everything once; each returned value is a dict bundling one model
# family (see the initialize_models docstring for what each dict contains).
control_net, llava, zero123, color_control = initialize_models()

Loading ControlNet models...
ControlLDM: Running in eps-prediction mode
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 768 and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 768 and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 768 and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 768 and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 8 heads.
Setting up MemoryEfficientCrossAttention



Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 768 and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 768 and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 768 and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 768 and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is None and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1280, context_dim is 768 and using 8 heads.
Setting up Me

  state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))


Loaded state_dict from [./models/v1-5-pruned.ckpt]
Loaded state_dict from [./models/control_v11p_sd15_canny.pth]
ControlLDM: Running in eps-prediction mode
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 768 and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is None and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 320, context_dim is 768 and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 768 and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is None and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 640, context_dim is 768 and using 8 heads.
Setting up MemoryEfficientCrossAttention. Query dim is 1

  model = create_fn(
  parameters = torch.load(path, map_location=torch.device('cpu'))


Use Checkpoint: False
Checkpoint Number: [0, 0, 0, 0]
Use global window for all blocks in stage3
load checkpoint from local path: /home/pat/diffusion_augmentation/controlnet/annotator/ckpts/upernet_global_small.pth
Loading LLAVA model...


  checkpoint = torch.load(filename, map_location=map_location)
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Loading checkpoint shards: 100%|██████████| 3/3 [00:03<00:00,  1.06s/it]
  pl_sd = torch.load(ckpt, map_location='cpu')


Loading Zero123 models...
Loading model from ./models/105000.ckpt
LatentDiffusion: Running in eps-prediction mode
DiffusionWrapper has 859.53 M params.
Keeping EMAs of 688.
making attention of type 'vanilla' with 512 in_channels
Working with z of shape (1, 4, 32, 32) = 4096 dimensions.
making attention of type 'vanilla' with 512 in_channels


  torch.load(model_path, map_location=self.device), strict=False
  self.load_state_dict(torch.load(fba_pretrained(), map_location=self.device))
  deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)


Loading Color Control model...


  state_dict = torch.load(f)
  model_sd = torch.load(model_ckpt, map_location="cpu")["module"]


msg_control: <All keys matched successfully> 
msg_adapter: <All keys matched successfully> 


Fetching 15 files: 100%|██████████| 15/15 [00:00<00:00, 180270.95it/s]
You have disabled the safety checker for <class 'color_controlnet.diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img_controlnet_palette.StableDiffusionImg2ImgControlNetPalettePipeline'> by passing `safety_checker=None`. Ensure that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered results in services or applications open to the public. Both the diffusers team and Hugging Face strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling it only for use-cases that involve analyzing network behavior or auditing its results. For more information, please have a look at https://github.com/huggingface/diffusers/pull/254 .
The config attributes {'skip_prk_steps': True, 'set_alpha_to_one': False, 'steps_offset': 1, 'clip_sample': False} were passed to UniPCMultistepScheduler, but are not expected and will be ignored. Please verif

In [9]:
# Smoke test of the colour-control augmentation on a single sample image.
caption = "a painting of a tent with a forest in the background"
example_image = Image.open("./test_images/original.png")

color_augmented = control_color_augment(
    example_image,
    color_control['adapter'],
    color_control['pipe'],
    caption,
    color_control['sam_annotator'],
    1,  # presumably the number of augmented images to generate -- TODO confirm
    color_control_device,
)

# Only the first returned image is written to disk.
color_augmented[0].save("./test_images/color_augmented.png")

  model.load_state_dict(torch.load(modelpath, map_location=torch.device('cpu')))
100%|██████████| 22/22 [00:02<00:00,  9.03it/s]
