<a href="https://colab.research.google.com/github/Rajeshj4all/roominterior/blob/feature%2Fdev/RoomTransferServer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install diffusers transformers torch torchvision torchaudio
!pip install scikit-image
!pip install gradio

In [None]:
import torch
import gc
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, StableDiffusionImg2ImgPipeline
import gradio as gr
from PIL import Image, ImageEnhance
import numpy as np
from skimage import feature, color
import os

# Every stage below saves an intermediate image under ./debug, so make sure
# the folder is present (a no-op when it already exists).
os.makedirs(name="debug", exist_ok=True)

def preprocess_image(room_image):
    """Prepare the room photo for generation and derive its Canny edge map.

    Args:
        room_image: PIL image of the room to transform.

    Returns:
        A tuple ``(image, edges_image)``: ``image`` is the RGB photo resized
        to 768x768 and ``edges_image`` is a single-channel PIL image holding
        the Canny edges of that resized photo as 0/255 values.

    Side effects:
        Writes intermediate images into the ``debug`` directory and prints the
        number of distinct colors found in the input.
    """
    # NOTE: the "reference_*" debug file names below hold the *room* photo;
    # the names are kept unchanged so existing debug output stays comparable.
    room_image.save("debug/reference_original.png")

    # Convert to RGB
    image = room_image.convert("RGB")
    image.save("debug/reference_rgb.png")

    # Count distinct colors to detect blank / corrupted uploads. getcolors()
    # does the counting inside Pillow instead of materialising one Python
    # tuple per pixel; maxcolors is set to the pixel count so it cannot
    # overflow and return None.
    color_histogram = image.getcolors(maxcolors=max(image.width * image.height, 1))
    unique_pixels = len(color_histogram or ())
    print(f"Unique pixel count: {unique_pixels}")

    if unique_pixels < 100:
        print("WARNING: Room image has very few unique colors")

    # Resize to 768x768 for better quality (the aspect ratio is not preserved)
    image = image.resize((768, 768), Image.LANCZOS)
    image.save("debug/reference_sample.png")

    # Convert to numpy array for edge detection
    image_np = np.array(image)

    # Canny runs on the grayscale version; the boolean mask it returns is
    # scaled to 0/255 so it can be stored as an 8-bit image.
    edges = feature.canny(
        color.rgb2gray(image_np),
        sigma=1.5,
        low_threshold=0.1,
        high_threshold=0.2
    ).astype(np.uint8) * 255

    edges_image = Image.fromarray(edges)
    edges_image.save("debug/edges.png")

    return image, edges_image

def analyze_reference_image(reference_image):
    """Normalise the style reference image to the working format.

    The image is converted to RGB, a copy is saved to
    ``debug/reference_image.png``, and the 768x768 resized version is
    returned. No feature extraction happens here yet; color palette, texture
    or style analysis would be natural additions.
    """
    working_size = (768, 768)

    rgb_reference = reference_image.convert("RGB")
    rgb_reference.save("debug/reference_image.png")

    # Match the resolution used for the room image.
    return rgb_reference.resize(working_size, Image.LANCZOS)

def enhance_output(image):
    """Apply a light sharpen / contrast / color boost to the generated image.

    The three adjustments are applied in sequence, each on the result of the
    previous one, and the final image is returned.
    """
    # (enhancer class, factor) pairs, in application order.
    adjustments = (
        (ImageEnhance.Sharpness, 1.2),
        (ImageEnhance.Contrast, 1.1),
        (ImageEnhance.Color, 1.1),
    )

    for enhancer_cls, factor in adjustments:
        image = enhancer_cls(image).enhance(factor)

    return image

# Loaded pipelines keyed by (device, dtype). Instantiating the ControlNet and
# both Stable Diffusion pipelines is by far the slowest part of a request, so
# they are built once and reused by later calls.
_PIPELINE_CACHE = {}


def _load_pipelines(device, dtype):
    """Return the (ControlNet pipeline, img2img pipeline) pair for device/dtype.

    The pair is created on first use and cached in ``_PIPELINE_CACHE``;
    nothing is cached if any of the loads raises.
    """
    cache_key = (device, str(dtype))
    if cache_key in _PIPELINE_CACHE:
        return _PIPELINE_CACHE[cache_key]

    # Load ControlNet model
    controlnet = ControlNetModel.from_pretrained(
        "lllyasviel/sd-controlnet-canny",
        torch_dtype=dtype
    ).to(device)

    # Load Stable Diffusion pipeline with ControlNet
    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        controlnet=controlnet,
        torch_dtype=dtype
    ).to(device)

    # Load img2img pipeline for refinement
    img2img_pipeline = StableDiffusionImg2ImgPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=dtype,
        safety_checker=None
    ).to(device)

    # Enable optimizations for CUDA
    if device == "cuda":
        pipeline.enable_attention_slicing()
        pipeline.enable_vae_tiling()
        img2img_pipeline.enable_attention_slicing()
        img2img_pipeline.enable_vae_tiling()

    _PIPELINE_CACHE[cache_key] = (pipeline, img2img_pipeline)
    return _PIPELINE_CACHE[cache_key]


def transform_room(room_image, reference_image, theme, room_type, specific_items, include_reference=False):
    """Main function to transform a room based on inputs

    The room photo is reduced to a Canny edge map that conditions a
    ControlNet Stable Diffusion pass; the result is then refined with an
    img2img pass and lightly post-processed. Intermediate images are written
    to the ``debug`` directory.

    NOTE: the reference image does not condition the generation. It is only
    normalised (and optionally returned); the style is driven by the text
    prompt built from ``theme``, ``room_type`` and ``specific_items``.

    Args:
        room_image: The input room image to transform
        reference_image: The reference image with the desired style
        theme: Text description of the theme (e.g., "modern", "vintage")
        room_type: Type of room (e.g., "living room", "bedroom")
        specific_items: Description of specific items to include
        include_reference: Whether to include the reference image in the output

    Returns:
        The generated PIL Image, or the list ``[result, reference_copy]`` when
        include_reference is True (``reference_copy`` is the reference image
        resized to the result's size).
    """
    # Clear GPU memory before starting
    torch.cuda.empty_cache()
    gc.collect()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.float16 if device == "cuda" else torch.float32  # Use float16 on GPU to reduce memory

    print(f"Using device: {device} with dtype: {dtype}")

    # Memory optimization settings
    torch.backends.cudnn.benchmark = True
    if torch.cuda.is_available():
        torch.cuda.set_per_process_memory_fraction(0.7)  # Use only 70% of GPU memory

    # Reuse already-loaded models instead of re-instantiating them per request
    pipeline, img2img_pipeline = _load_pipelines(device, dtype)

    # Preprocess images
    room_image, edge_image = preprocess_image(room_image)
    reference_image = analyze_reference_image(reference_image)

    # Craft detailed prompt
    prompt = f"""a {theme.lower()} {room_type.lower()},
    8k photorealistic, incorporating {specific_items},
    perfect symmetry, masterful interior design,
    professional photography, ultra-detailed"""

    negative_prompt = """poor quality, low resolution, blurry,
    bad composition, noisy, grainy, artifacts"""

    print("Generating initial image with ControlNet...")

    # Generate initial image with ControlNet, conditioned on the edge map
    output = pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=edge_image,
        num_inference_steps=30,
        guidance_scale=7.5,
        controlnet_conditioning_scale=1.0
    )
    initial_result = output.images[0]
    initial_result.save("debug/initial_result.png")

    print("Refining image with reference style...")

    # Second pass: img2img over the first result. Only the prompt text changes;
    # the reference image itself is not an input to this pipeline call.
    refine_prompt = prompt + ", style transfer from reference image"
    refined_output = img2img_pipeline(
        prompt=refine_prompt,
        negative_prompt=negative_prompt,
        image=initial_result,
        strength=0.55,
        guidance_scale=8.5,
        num_inference_steps=50
    )

    # Get refined result
    result = refined_output.images[0]
    result.save("debug/refined_result.png")

    # Drop the per-request outputs, then return freed blocks to the GPU:
    # collecting first lets empty_cache() release what the collector freed.
    del output, refined_output
    gc.collect()
    torch.cuda.empty_cache()

    # Apply post-processing enhancements
    result = enhance_output(result)
    result.save("debug/final_enhanced_result.png")

    # If include_reference is True, return both the reference image and the result
    if include_reference:
        # Create a copy of the reference image at the same size as the result
        reference_copy = reference_image.resize(result.size, Image.LANCZOS)
        reference_copy.save("debug/reference_resized.png")

        # Return both images
        return [result, reference_copy]

    # Otherwise return just the result
    return result

# Gradio callback: validates the uploads, runs transform_room and reports errors
def process_images(room_image, reference_image, theme, room_type, specific_items, include_reference):
    """Process function for the Gradio interface with error handling

    The interface declares two image outputs, so this function always returns
    exactly two values. Failures are reported by raising ``gr.Error`` (shown
    to the user as an error message) rather than by returning a string, since
    an image output component cannot display text.

    Args:
        room_image: Uploaded room photo (PIL image) or None if nothing was uploaded
        reference_image: Uploaded style reference (PIL image) or None
        theme: Theme text forwarded to transform_room
        room_type: Room type text forwarded to transform_room
        specific_items: Item description forwarded to transform_room
        include_reference: Whether the reference image should be shown next to the result

    Returns:
        A tuple ``(generated_image, reference_image_or_None)``.

    Raises:
        gr.Error: If an upload is missing or the transformation fails.
    """
    if room_image is None:
        raise gr.Error("Please upload a room image")
    if reference_image is None:
        raise gr.Error("Please upload a reference image")

    print(f"Processing with theme: {theme}, room type: {room_type}, include reference: {include_reference}")

    try:
        result = transform_room(room_image, reference_image, theme, room_type, specific_items, include_reference)
    except RuntimeError as e:
        # Release whatever the failed run left behind before reporting
        torch.cuda.empty_cache()
        gc.collect()
        print(f"Error: {str(e)}")
        if "out of memory" in str(e):
            error_msg = "GPU memory error: Please try with a smaller image or wait a moment before trying again."
        else:
            error_msg = f"An error occurred: {str(e)}"
        raise gr.Error(error_msg) from e
    except Exception as e:
        print(f"Unexpected error: {str(e)}")
        raise gr.Error(f"An unexpected error occurred: {str(e)}") from e

    # transform_room returns a single image, or [image, reference] when the
    # reference was requested; normalise both shapes to the two outputs.
    if isinstance(result, (list, tuple)):
        generated, reference = result
    else:
        generated, reference = result, None
    return generated, reference

# --- Gradio UI -------------------------------------------------------------
# Input widgets, in the order process_images receives its arguments.
room_input = gr.Image(type="pil", label="Room Image (Your current room)")
reference_input = gr.Image(type="pil", label="Reference Image (Style inspiration)")
theme_input = gr.Textbox(label="Theme (e.g., modern, vintage, minimalist)", value="modern")
room_type_input = gr.Textbox(label="Room Type (e.g., living room, bedroom)", value="living room")
items_input = gr.Textbox(label="Specific Items Description", value="two bamboo plants, a coffee table")
include_reference_input = gr.Checkbox(
    label="Include Reference Image in Results",
    value=False,
    info="Display the original reference image alongside the generated result",
)

# Output widgets: the generated room first, then the reference image slot.
generated_output = gr.Image(type="pil", label="Generated Room")
reference_output = gr.Image(type="pil", label="Reference Image (Original Style)", visible=True)

interface = gr.Interface(
    fn=process_images,
    inputs=[
        room_input,
        reference_input,
        theme_input,
        room_type_input,
        items_input,
        include_reference_input,
    ],
    outputs=[generated_output, reference_output],
    title="AI Room Transformer",
    description="""Transform your room with AI-powered interior design.
    Upload a photo of your room and a reference image for style inspiration.
    The AI will generate a new design based on your inputs.
    Toggle 'Include Reference Image' to see your style inspiration alongside the result.
    Check the 'debug' folder for intermediate results.""",
)

# Turn on request queuing, then start the app with sharing and debug output on.
interface.queue()
interface.launch(debug=True, share=True)