In [1]:
from typing import Tuple, Union, List
import os
import cv2
import numpy as np
from PIL import Image
import torch
from diffusers import ControlNetModel
from diffusers.pipelines.controlnet import StableDiffusionControlNetInpaintPipeline, StableDiffusionControlNetPipeline
from diffusers import ControlNetModel, UniPCMultistepScheduler
from transformers import AutoImageProcessor, UperNetForSemanticSegmentation, CLIPTextModel, CLIPTokenizer, DataCollatorWithPadding, pipeline
from diffusers import (
    AutoencoderKL,
    DDPMScheduler,
    StableDiffusionInpaintPipeline,
    UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor

from models.colors import ade_palette
from models.utils import map_colors_rgb
from diffusers.utils import load_image
from diffusers import StableDiffusionDepth2ImgPipeline

  torch.utils._pytree._register_pytree_node(
  torch.utils._pytree._register_pytree_node(


CLIP MODELS

In [2]:
# CLIP tokenizer + text encoder taken from the SD inpainting checkpoint; they are
# used to turn prompts into embeddings for the inpainting pipeline.
tokenizer = CLIPTokenizer.from_pretrained("runwayml/stable-diffusion-inpainting", subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained("runwayml/stable-diffusion-inpainting", subfolder="text_encoder")

# Pads a list of tokenized prompts to a common length so they can be batched.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

FUNCTIONS

In [3]:
def filter_items(
    colors_list: Union[List, np.ndarray],
    items_list: Union[List, np.ndarray],
    items_to_retain: Union[List, np.ndarray]
) -> Tuple[List, List]:
    """
    Keep only the items (and their paired colors) that appear in ``items_to_retain``.

    Args:
        colors_list: A list or numpy array of colors, paired position-wise with
            ``items_list``.
        items_list: A list or numpy array of item labels, one per color.
        items_to_retain: Labels to keep; every other item (and its color) is
            dropped.

    Returns:
        A tuple ``(filtered_colors, filtered_items)`` of lists, in input order.
        Pairing stops at the shorter of the two input sequences (``zip``).
    """
    filtered_colors = []
    filtered_items = []
    for color, item in zip(colors_list, items_list):
        if item in items_to_retain:
            filtered_colors.append(color)
            filtered_items.append(item)
    return filtered_colors, filtered_items

def filter_items_mask(
    colors_list: Union[List, np.ndarray],
    items_list: Union[List, np.ndarray],
    items_to_mask: Union[List, np.ndarray]
) -> Tuple[List, List]:
    """
    Drop the items listed in ``items_to_mask`` (and their paired colors).

    Args:
        colors_list: A list or numpy array of colors, paired position-wise with
            ``items_list``.
        items_list: A list or numpy array of item labels, one per color.
        items_to_mask: Labels to remove; every other item is kept.

    Returns:
        A tuple ``(filtered_colors, filtered_items)`` of lists holding the
        remaining pairs, in input order. Pairing stops at the shorter of the two
        input sequences (``zip``).
    """
    filtered_colors = []
    filtered_items = []
    for color, item in zip(colors_list, items_list):
        if item not in items_to_mask:
            filtered_colors.append(color)
            filtered_items.append(item)
    return filtered_colors, filtered_items

def filter_items_retain(
    colors_list: Union[List, np.ndarray],
    items_list: Union[List, np.ndarray],
    items_to_retain: Union[List, np.ndarray]
) -> Tuple[List, List]:
    """
    Keep only the items (and their paired colors) that appear in ``items_to_retain``.

    Behaves exactly like ``filter_items``.

    Args:
        colors_list: A list or numpy array of colors, paired position-wise with
            ``items_list``.
        items_list: A list or numpy array of item labels, one per color.
        items_to_retain: Labels to keep; every other item is dropped.

    Returns:
        A tuple ``(filtered_colors, filtered_items)`` of lists, in input order.
        Pairing stops at the shorter of the two input sequences (``zip``).
    """
    filtered_colors = []
    filtered_items = []
    for color, item in zip(colors_list, items_list):
        if item in items_to_retain:
            filtered_colors.append(color)
            filtered_items.append(item)
    return filtered_colors, filtered_items

def get_segmentation_pipeline(
    model_name: str = "openmmlab/upernet-convnext-small",
) -> Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]:
    """Load a semantic-segmentation processor/model pair from one checkpoint.

    Args:
        model_name: Hugging Face checkpoint used for both the image processor
            and the UperNet model (defaults to the ConvNeXt-small UperNet).

    Returns:
        Tuple[AutoImageProcessor, UperNetForSemanticSegmentation]: the
        processor that prepares images, and the segmentation model (left on
        the default device; it is not moved here).
    """
    image_processor = AutoImageProcessor.from_pretrained(model_name)
    image_segmentor = UperNetForSemanticSegmentation.from_pretrained(model_name)
    return image_processor, image_segmentor


@torch.inference_mode()
@torch.autocast('cuda')
def segment_image(
        image: "Image.Image",
        image_processor: AutoImageProcessor,
        image_segmentor: UperNetForSemanticSegmentation
) -> "Image.Image":
    """
    Segments an image using a semantic segmentation model.

    Args:
        image (Image.Image): The input image to be segmented.
        image_processor (AutoImageProcessor): The processor to prepare the
            image for segmentation.
        image_segmentor (UperNetForSemanticSegmentation): The semantic
            segmentation model used to identify different segments in the image.

    Returns:
        Image.Image: An RGB image of the same size as ``image`` in which each
            pixel is colored with ``ade_palette()[label]`` for its predicted
            class label (labels outside the palette stay black).
    """
    # inference_mode (decorator) already disables autograd; no extra no_grad needed.
    pixel_values = image_processor(image, return_tensors="pt").pixel_values
    # Follow the model's device so a GPU-resident segmentor also works.
    outputs = image_segmentor(pixel_values.to(image_segmentor.device))

    seg = image_processor.post_process_semantic_segmentation(
        outputs, target_sizes=[image.size[::-1]])[0]
    # Bring the label map to host memory once: numpy indexing needs a CPU array.
    seg = seg.cpu().numpy()

    color_seg = np.zeros((seg.shape[0], seg.shape[1], 3), dtype=np.uint8)
    palette = np.array(ade_palette())
    for label, color in enumerate(palette):
        color_seg[seg == label, :] = color
    return Image.fromarray(color_seg).convert('RGB')

def resize_dimensions(dimensions, target_size):
    """
    Compute a (width, height) whose longer side is ``target_size``, keeping the
    aspect ratio.

    If both sides are already smaller than ``target_size`` the input is
    returned unchanged, so sizes are never scaled up. The shorter side is
    truncated to an int and clamped to at least 1 pixel so that extreme aspect
    ratios never produce a zero-sized dimension.

    Args:
        dimensions: ``(width, height)`` tuple, e.g. ``PIL.Image.Image.size``.
        target_size: Desired length of the longer side, in pixels.

    Returns:
        ``(width, height)`` tuple.
    """
    width, height = dimensions

    # Already small enough: leave as is.
    if width < target_size and height < target_size:
        return dimensions

    if width > height:
        # Landscape: width becomes target_size.
        aspect_ratio = height / width
        return (target_size, max(1, int(target_size * aspect_ratio)))

    # Portrait or square: height becomes target_size.
    aspect_ratio = width / height
    return (max(1, int(target_size * aspect_ratio)), target_size)

def tokenize_function(caption):
    """Tokenize ``caption`` with the module-level CLIP ``tokenizer``.

    Truncation is disabled, so prompts longer than the tokenizer's maximum
    length are kept whole.
    """
    encoded = tokenizer(caption, truncation=False)
    return encoded

def do_encode(inputs, text_encoder, device, max_seq_len=75):
    """Encode a (possibly over-long) tokenized prompt with a CLIP text encoder.

    Token and position embeddings are computed for the whole sequence. Position
    ids are clamped to the last index the position-embedding table has, so
    every token past that index shares one position embedding. The sequence
    then runs through the encoder's transformer layers in a single pass and
    the encoder's ``final_layer_norm`` is applied.

    Args:
        inputs: Mapping with ``'input_ids'`` and ``'attention_mask'`` tensors;
            presumably (batch, seq_len) as produced by a HF tokenizer with
            PyTorch tensors — TODO confirm with callers.
        text_encoder: A transformers ``CLIPTextModel`` (its ``text_model``
            embeddings, encoder and final layer norm are used directly).
        device: Device to run the computation on.
        max_seq_len: Kept for backward compatibility. The original chunked the
            embedding lookup by this size. Embeddings are per-token and the
            positions were contiguous across chunks, so the chunking never
            affected the result.

    Returns:
        Tensor of final hidden states, shape (batch, seq_len, hidden_size).
    """
    text_encoder = text_encoder.to(device)
    text_model = text_encoder.text_model
    embeddings = text_model.embeddings

    tokens = inputs['input_ids'].to(device)
    attention_mask = inputs['attention_mask'].to(device)
    seq_len = tokens.size(1)

    max_position = embeddings.position_embedding.num_embeddings - 1
    position_ids = torch.arange(seq_len, dtype=torch.long, device=device).clamp(max=max_position).unsqueeze(0)
    hidden_states = embeddings.token_embedding(tokens) + embeddings.position_embedding(position_ids)

    # CLIP's attention ADDS the mask to the attention scores, so it must be
    # additive (0 = attend, dtype-min = blocked). A raw 0/1 mask blocks nothing.
    # Combine the padding mask with the causal mask the CLIP text model uses.
    key_is_real = attention_mask.bool()[:, None, None, :]                               # (B, 1, 1, L)
    causal = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device).tril()      # (L, L)
    allowed = key_is_real & causal                                                      # (B, 1, L, L)
    additive_mask = torch.zeros(allowed.shape, dtype=hidden_states.dtype, device=device)
    # finfo.min rather than -inf keeps fully-masked rows finite (no NaN in softmax).
    additive_mask = additive_mask.masked_fill(~allowed, torch.finfo(hidden_states.dtype).min)

    encoder_outputs = text_model.encoder(hidden_states, attention_mask=additive_mask)
    # Match CLIPTextModel's own output, which is layer-normed after the encoder.
    return text_model.final_layer_norm(encoder_outputs.last_hidden_state)

def get_pipeline_embeds_mod(input_ids, negative_ids):
    """Encode positive/negative prompt ids that may exceed CLIP's token limit.

    Both id tensors are cut into windows of ``tokenizer.model_max_length``
    tokens. Each window is encoded independently by the module-level
    ``text_encoder``, and the per-window hidden states are concatenated along
    the sequence axis.

    Args:
        input_ids: Positive prompt token ids, presumably shape (1, seq_len).
        negative_ids: Negative prompt token ids. Expected to have the same
            length as ``input_ids`` (the caller pads both together). If it is
            shorter, its trailing windows are empty — NOTE(review): behaviour
            for that case is unverified.

    Returns:
        ``(prompt_embeds, negative_prompt_embeds)`` tensors.
    """
    max_length = tokenizer.model_max_length
    shape_max_length = max(input_ids.shape[-1], negative_ids.shape[-1])
    input_ids = input_ids.to(text_encoder.device)
    negative_ids = negative_ids.to(text_encoder.device)

    concat_embeds = []
    neg_embeds = []
    # Embeddings are only consumed by the diffusion pipeline; don't build an
    # autograd graph over the text encoder.
    with torch.no_grad():
        for i in range(0, shape_max_length, max_length):
            concat_embeds.append(text_encoder(input_ids[:, i: i + max_length])[0])
            neg_embeds.append(text_encoder(negative_ids[:, i: i + max_length])[0])
    return torch.cat(concat_embeds, dim=1), torch.cat(neg_embeds, dim=1)

MAIN MODEL CLASS

In [7]:
class ControlNetDesignModel_wall_mask:
    """Redesign an empty-room photo according to a text prompt.

    Two stages (see ``generate_design``):
      1. Stable Diffusion 2 depth2img restyles the whole room from the prompt.
      2. The result is semantically segmented. A ControlNet(segmentation)
         inpainting pipeline then repaints everything except the listed
         structural items (stairs, windows, doors, ...) and the top band of
         each wall column. Its UNet carries fine-tuned attention-processor
         weights loaded from a local file.
    """

    def __init__(self):
        """Load every model onto CUDA and set the generation defaults.

        Loads the depth2img pipeline, the segmentation ControlNet, the
        inpainting pipeline (UniPC scheduler, xformers attention) and the
        segmentation model. Also loads attention-processor weights from a
        local ``.safetensors`` path. Requires a CUDA device.
        """
        # NOTE(review): huggingface_hub may read HF_HUB_OFFLINE when it is first
        # imported; setting it here, after the imports, may have no effect — confirm.
        os.environ['HF_HUB_OFFLINE'] = "False"

        # Stage 1: depth-conditioned image-to-image, in fp16.
        self.depth2img_pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
            "stabilityai/stable-diffusion-2-depth",
            torch_dtype=torch.float16,
        )
        self.depth2img_pipe.to("cuda")

        # Stage 2: segmentation-conditioned inpainting.
        controlnet_seg = ControlNetModel.from_pretrained(
            "BertChristiaens/controlnet-seg-room", torch_dtype=torch.float32)

        self.pipe = StableDiffusionControlNetInpaintPipeline.from_pretrained(
            "runwayml/stable-diffusion-inpainting",
            controlnet=controlnet_seg,
            safety_checker=None,
            torch_dtype=torch.float32
        )
        self.pipe.scheduler = UniPCMultistepScheduler.from_config(self.pipe.scheduler.config)
        self.pipe.enable_xformers_memory_efficient_attention()
        self.pipe = self.pipe.to("cuda")

        # Fine-tuned attention-processor weights for the inpainting UNet.
        unet_weight_path = "/models/unet_fine_tuned_weights/root/stable/output/pytorch_lora_weights.safetensors"
        self.pipe.unet.load_attn_procs(unet_weight_path, use_safetensors=True)

        self.seg_image_processor, self.image_segmentor = get_segmentation_pipeline()

        self.seed = 2
        self.neg_prompt = "lowres, watermark, banner, logo, contactinfo, text, deformed, blurry, blur, \
        out of focus, out of frame, surreal, ugly, distortion, low-res, poor quality, "
        self.additional_quality_suffix = "interior design, 4K, high resolution"
        # control_items / control_items_retain are not read by generate_design.
        self.control_items = ["floor;flooring", "rug;carpet;carpeting", "wall", "ceiling"]
        # Segment labels that are preserved (kept out of the inpainting mask).
        self.control_items_mask = ["stairs;steps", "step;stair", "stairway;staircase", "radiator", "screen;door;screen", "windowpane;window", "door;double;door", "countertop", "fireplace;hearth;open;fireplace","column;pillar"]
        self.control_items_retain = ["floor;flooring", "rug;carpet;carpeting", "wall", "ceiling"]
        # When True, generate_design prints and shows the intermediate depth2img image.
        self.debug = False

    def generate_design(self, empty_room_image: "Image.Image", prompt: str) -> "Image.Image":
        """
        Given an image of an empty room and a prompt
        generate the designed room according to the prompt
        Inputs -
            empty_room_image - An RGB PIL Image of the empty room
            prompt - Text describing the target design elements of the room
        Returns -
            design_image - PIL Image of the same size as the empty room image
                           If the size is not the same the submission will fail.
        """
        # Capture the target size BEFORE depth2img: that pipeline may return an
        # image whose size differs from its input, and the output must match
        # the original photo.
        orig_w, orig_h = empty_room_image.size

        pos_prompt = prompt + f', {self.additional_quality_suffix},'
        pos_prompt_embed, neg_prompt_embed = self._encode_prompts(pos_prompt)

        # Stage 1: depth-conditioned restyle of the whole room.
        depth2img = self.depth2img_pipe(prompt=pos_prompt,
            negative_prompt=self.neg_prompt, image=empty_room_image, strength=0.75).images[0]
        image = depth2img.convert("RGB")
        # Cap the longer side at 768 px for the inpainting stage.
        image = image.resize(resize_dimensions(image.size, 768))
        if getattr(self, "debug", False):
            print("Image after depth2img")
            image.show()

        # Stage 2: segment the restyled room and build the inpainting mask
        # (255 = repaint, 0 = keep).
        real_seg = np.array(segment_image(image,
                                          self.seg_image_processor,
                                          self.image_segmentor))
        unique_colors = [tuple(color) for color in np.unique(real_seg.reshape(-1, real_seg.shape[2]), axis=0)]
        segment_items = [map_colors_rgb(color) for color in unique_colors]

        # Repaint everything except the structural items listed in control_items_mask.
        repaint_colors, _ = filter_items_mask(
            colors_list=unique_colors,
            items_list=segment_items,
            items_to_mask=self.control_items_mask
        )
        repaint_mask = self._colors_mask(real_seg, repaint_colors)

        # Additionally keep the top band of every wall column.
        non_wall_colors, _ = filter_items_mask(
            colors_list=unique_colors,
            items_list=segment_items,
            items_to_mask=["wall"],
        )
        wall_keep_mask = self._wall_top_keep_mask(self._colors_mask(real_seg, non_wall_colors))

        final_mask_image = Image.fromarray(np.bitwise_and(repaint_mask, wall_keep_mask)).convert("RGB")
        segmentation_cond_image = Image.fromarray(real_seg).convert("RGB")

        generated_image = self.pipe(
            prompt_embeds=pos_prompt_embed,
            negative_prompt_embeds=neg_prompt_embed,
            num_inference_steps=50,
            strength=1.0,
            guidance_scale=7.0,
            generator=[torch.Generator(device="cuda").manual_seed(self.seed)],
            image=image,
            mask_image=final_mask_image,
            control_image=segmentation_cond_image,
        ).images[0]

        return generated_image.resize(
            (orig_w, orig_h), Image.Resampling.LANCZOS
        )

    def _encode_prompts(self, pos_prompt):
        """Return (positive, negative) prompt embeddings for the inpainting pipeline.

        Both prompts are tokenized without truncation and padded together, so
        they have the same token length. They are then encoded window by window
        by ``get_pipeline_embeds_mod``.
        """
        token_dicts = [tokenize_function(text) for text in (pos_prompt, self.neg_prompt)]
        prompt_ids = data_collator(token_dicts)['input_ids']
        # Row 0 is the positive prompt, row 1 the negative; keep the batch axis.
        return get_pipeline_embeds_mod(prompt_ids[0:1], prompt_ids[1:2])

    @staticmethod
    def _colors_mask(seg, colors):
        """Return an (H, W) uint8 mask: 255 where the RGB pixel of ``seg`` is one of ``colors``, else 0.

        Args:
            seg: (H, W, 3) array of segmentation colors.
            colors: iterable of RGB triples.
        """
        mask = np.zeros(seg.shape[:2], dtype=np.uint8)
        for color in colors:
            mask[(seg == color).all(axis=2)] = 255
        return mask

    @staticmethod
    def _wall_top_keep_mask(non_wall_mask, fraction=0.15):
        """Build a mask that protects the top band of the wall in every column.

        Args:
            non_wall_mask: (H, W) uint8 array where 0 marks wall pixels.
            fraction: share of each column's wall span (last wall row minus
                first wall row) to protect, starting at the first wall row;
                truncated to whole rows.

        Returns:
            (H, W) uint8 array, 255 everywhere except the protected band (0).
            Columns with no wall pixels are left untouched.
        """
        is_wall = np.asarray(non_wall_mask) == 0
        height = is_wall.shape[0]
        keep = np.full(is_wall.shape, 255, dtype=np.uint8)

        has_wall = is_wall.any(axis=0)
        top = is_wall.argmax(axis=0)                               # first wall row per column
        bottom = height - 1 - is_wall[::-1].argmax(axis=0)         # last wall row per column
        band = (fraction * (bottom - top)).astype(int)

        rows = np.arange(height)[:, None]
        keep[has_wall & (rows >= top) & (rows < top + band)] = 0
        return keep

In [None]:
# Initialise the design model (loads all of its pipelines onto CUDA).
ctl = ControlNetDesignModel_wall_mask()

Loading pipeline components...:   0%|          | 0/7 [00:00<?, ?it/s]

In [None]:
# Smoke test on a single image: show the input, then display the generated
# design (the cell's last expression).
img = load_image('image_0.jpg')
img.show()
prompt = "A Bauhaus-inspired living room with a sleek black leather sofa, a tubular steel coffee table exemplifying modernist design, and a geometric patterned rug adding a touch of artistic flair."
ctl.generate_design(img,prompt)

In [None]:
# Batch run: design every (image, prompt) pair listed in the development CSV.
DATA_DIR = '/development_data-v0.2/'
captions_csv = DATA_DIR + 'all_input_list.csv'

# Read the listing up front so the file is closed before the long GPU loop.
with open(captions_csv, 'r') as captions_file:
    lines = captions_file.readlines()

for line in lines[1:]:  # line 0 is the CSV header
    line = line.strip()
    if not line:  # tolerate blank / trailing lines
        continue
    # Rows look like "<image name>.jpg<separator><prompt>". Split only at the
    # first ".jpg" so a prompt that mentions ".jpg" is not cut short.
    # NOTE(review): if the separator is a comma, the prompt keeps a leading ','; confirm the CSV format.
    sep_line = line.split('.jpg', 1)
    img_name = DATA_DIR + sep_line[0] + '.jpg'
    prompt = sep_line[1].strip()
    print(f"Processing Image:: {img_name}")
    print(f"Prompt: {prompt}")
    img = load_image(img_name)
    img.show()
    designed_img = ctl.generate_design(img, prompt)
    designed_img.show()

print('Processed ALL images.')