In [1]:
!pip install opencv-python



In [2]:
import os
import numpy as np
import cv2
import matplotlib.pyplot as plt
import torch
import ipywidgets as widgets
from IPython.display import display

In [3]:
!pip install git+https://github.com/facebookresearch/segment-anything.git
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth


Collecting git+https://github.com/facebookresearch/segment-anything.git
  Cloning https://github.com/facebookresearch/segment-anything.git to /tmp/pip-req-build-8g90aena
  Running command git clone --filter=blob:none --quiet https://github.com/facebookresearch/segment-anything.git /tmp/pip-req-build-8g90aena
  Resolved https://github.com/facebookresearch/segment-anything.git to commit dca509fe793f601edb92606367a655c15ac00fdf
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: segment_anything
  Building wheel for segment_anything (setup.py) ... [?25l[?25hdone
  Created wheel for segment_anything: filename=segment_anything-1.0-py3-none-any.whl size=36592 sha256=c636fcddd103c5943a2a42999ff492b077102a807003c39c5ee7c03a93ce87fd
  Stored in directory: /tmp/pip-ephem-wheel-cache-axtkw7en/wheels/15/d7/bd/05f5f23b7dcbe70cbc6783b06f12143b0cf1a5da5c7b52dcc5
Successfully built segment_anything
Installing collected packages: segment_anything
Successfully 

In [4]:
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

In [5]:
class WallRecoloringTool:
    """Recolor the walls of a room photo using Segment Anything (SAM) masks.

    Typical flow: ``load_image`` -> ``detect_walls`` -> ``apply_color`` ->
    ``display_results`` / ``save_result``. Wall detection is heuristic: it keeps the
    largest SAM segments that touch the image border, are fairly convex and are light
    colored, falling back to the single largest segment if none qualify.
    """

    def __init__(self, model_type="vit_h", checkpoint="sam_vit_h_4b8939.pth"):
        """Load SAM and build the automatic mask generator and the point predictor.

        Args:
            model_type: key into ``sam_model_registry`` (default ``"vit_h"``).
            checkpoint: path to the matching SAM checkpoint. The default is the file the
                setup cell downloads into the working directory.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        print(f"Using device: {self.device}")

        self.sam = sam_model_registry[model_type](checkpoint=checkpoint)
        self.sam.to(device=self.device)

        self.mask_generator = SamAutomaticMaskGenerator(
            model=self.sam,
            points_per_side=32,
            pred_iou_thresh=0.9,
            stability_score_thresh=0.92,
            crop_n_layers=1,
            crop_n_points_downscale_factor=2,
            min_mask_region_area=100
        )

        self.predictor = SamPredictor(self.sam)

        # Per-image state; load_image resets it whenever a different image is loaded.
        self.original_image = None
        self.segments = None
        self.selected_segments = []
        self.wall_mask = None

    def load_image(self, image_path):
        """Read an image from disk, convert it to RGB and compute its SAM embedding.

        Reloading an image whose pixels are identical to the one already loaded is a
        cheap no-op, so cached segments survive repeated clicks in the UI. Loading a
        different image discards the segments and wall mask computed for the old one.

        Returns:
            The loaded image as an RGB array.

        Raises:
            ValueError: if OpenCV cannot read ``image_path``.
        """
        image = cv2.imread(image_path)
        if image is None:
            raise ValueError(f"Could not load image from {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        previous = getattr(self, 'original_image', None)
        if previous is not None and np.array_equal(previous, image):
            # Same pixels: the predictor embedding and cached segments are still valid.
            return self.original_image

        self.original_image = image
        self.height, self.width = image.shape[:2]
        self.predictor.set_image(image)

        # Everything derived from the previous image is now stale.
        self.segments = None
        self.selected_segments = []
        self.wall_mask = None

        return self.original_image

    def _require_image(self):
        """Raise a clear error if no image has been loaded yet."""
        if getattr(self, 'original_image', None) is None:
            raise RuntimeError("No image loaded; call load_image() first")

    def generate_segments(self):
        """Run SAM's automatic mask generator on the loaded image and cache the result.

        Returns:
            The list of mask records from ``SamAutomaticMaskGenerator.generate``; this
            class reads their ``'segmentation'`` (per-pixel mask) and ``'area'`` keys.

        Raises:
            RuntimeError: if no image has been loaded.
        """
        self._require_image()
        self.segments = self.mask_generator.generate(self.original_image)
        print(f"Generated {len(self.segments)} segments")
        return self.segments

    @staticmethod
    def _touches_boundary(mask):
        """Return True if any pixel of ``mask`` lies on the outer border of the image."""
        return bool(
            np.any(mask[0, :]) or
            np.any(mask[-1, :]) or
            np.any(mask[:, 0]) or
            np.any(mask[:, -1])
        )

    @staticmethod
    def _solidity(mask, area):
        """Return ``area`` / convex-hull area of ``mask`` (0 when the hull is empty).

        All external contours are merged before taking the hull, so a mask made of
        several disconnected pieces is measured as a whole rather than by its first piece.
        """
        contours, _ = cv2.findContours(mask.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if len(contours) == 0:
            return 0
        hull = cv2.convexHull(np.vstack(contours))
        hull_area = cv2.contourArea(hull)
        return area / hull_area if hull_area > 0 else 0

    def _mean_brightness(self, mask):
        """Mean value over all RGB channels of the image pixels selected by ``mask``."""
        return np.mean(self.original_image[mask])

    def select_wall_segments(self, candidate_fraction=0.3, min_solidity=0.7, min_brightness=150):
        """Pick the SAM segments that look like walls and merge them into one mask.

        Among the largest ``candidate_fraction`` of segments (by area, at least one), a
        segment is kept when it touches the image border, its solidity (area / convex
        hull area) exceeds ``min_solidity`` and its mean RGB value exceeds
        ``min_brightness``. If nothing qualifies, the single largest segment is used.
        Segments are generated first if they have not been computed for this image.

        Returns:
            ``uint8`` mask of shape (height, width): 1 on selected pixels, 0 elsewhere.
            The chosen segment records are stored in ``self.selected_segments``.
        """
        if getattr(self, 'segments', None) is None:
            self.generate_segments()

        self.selected_segments = []
        wall_mask = np.zeros((self.height, self.width), dtype=bool)

        sorted_segments = sorted(self.segments, key=lambda seg: seg['area'], reverse=True)
        num_candidates = max(1, int(len(sorted_segments) * candidate_fraction))

        for segment in sorted_segments[:num_candidates]:
            mask = segment['segmentation']
            # Cheapest test first; the contour-based solidity only runs when needed.
            is_wall = (
                self._touches_boundary(mask)
                and self._solidity(mask, segment['area']) > min_solidity
                and self._mean_brightness(mask) > min_brightness  # walls are usually light
            )
            if is_wall:
                self.selected_segments.append(segment)
                np.logical_or(wall_mask, mask, out=wall_mask)

        # Fallback: with no qualifying segment, treat the largest one as the wall.
        if not self.selected_segments and sorted_segments:
            largest = sorted_segments[0]
            self.selected_segments.append(largest)
            np.logical_or(wall_mask, largest['segmentation'], out=wall_mask)

        self.wall_mask = wall_mask.astype(np.uint8)
        return self.wall_mask

    def detect_walls(self, **kwargs):
        """Compute the wall mask; keyword arguments are forwarded to ``select_wall_segments``."""
        return self.select_wall_segments(**kwargs)

    def apply_color(self, color_rgb):
        """Return a copy of the loaded image with the wall pixels tinted to ``color_rgb``.

        Works in OpenCV's Lab space: wall pixels keep their original lightness (L), so
        shading and texture survive, while their a/b chroma is replaced by the target
        color's. Walls are detected first if no mask exists yet.

        Args:
            color_rgb: (r, g, b) triple of ints in 0..255.

        Returns:
            A new RGB image; ``self.original_image`` is not modified.
        """
        if getattr(self, 'wall_mask', None) is None:
            self.detect_walls()

        lab_image = cv2.cvtColor(self.original_image, cv2.COLOR_RGB2LAB)
        color_lab = cv2.cvtColor(np.uint8([[color_rgb]]), cv2.COLOR_RGB2LAB)[0, 0]

        wall = self.wall_mask > 0
        lab_image[wall, 1] = color_lab[1]  # a channel (green-red)
        lab_image[wall, 2] = color_lab[2]  # b channel (blue-yellow)

        return cv2.cvtColor(lab_image, cv2.COLOR_LAB2RGB)

    def display_results(self, original, mask, result):
        """Show the original image, the wall mask and the recolored image side by side."""
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))
        panels = [
            (original, 'Original Image', None),
            (mask, 'Wall Mask', 'gray'),
            (result, 'Recolored Walls', None),
        ]
        for ax, (image, title, cmap) in zip(axes, panels):
            ax.imshow(image, cmap=cmap)
            ax.set_title(title)
            ax.axis('off')
        fig.tight_layout()
        plt.show()

    def save_result(self, result_image, output_path):
        """Write an RGB image to ``output_path`` (OpenCV picks the format from the extension).

        Raises:
            IOError: if OpenCV reports that the file could not be written.
        """
        result_bgr = cv2.cvtColor(result_image, cv2.COLOR_RGB2BGR)
        if not cv2.imwrite(output_path, result_bgr):
            raise IOError(f"Could not write image to {output_path}")
        print(f"Result saved to {output_path}")

In [6]:
def hex_to_rgb(color):
    """Convert a ``'#rrggbb'`` or ``'#rgb'`` hex color string to an ``(r, g, b)`` int tuple.

    The leading ``#`` is optional. Anything else (e.g. a CSS color name typed into the
    picker's text box) raises ``ValueError`` with a readable message.
    """
    digits = color.strip().lstrip('#')
    if len(digits) == 3:
        digits = ''.join(ch * 2 for ch in digits)
    if len(digits) != 6 or not set(digits) <= set('0123456789abcdefABCDEF'):
        raise ValueError(f"Unsupported color value {color!r}; expected '#rrggbb' or '#rgb'")
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))


def create_interactive_ui(tool, image_path="/content/room5.png", output_path="/content/recolored_room.jpg"):
    """Display a color picker and a button that recolors the walls in ``image_path``.

    Each click parses the picked color, (re)loads the image, detects walls with ``tool``,
    applies the color, plots the three panels and saves the result to ``output_path``.
    Errors are printed with a traceback inside the widget's output area instead of
    being raised.

    Args:
        tool: object exposing the ``WallRecoloringTool`` methods used below.
        image_path: image to recolor.
        output_path: where the recolored image is written.
    """
    import traceback
    from html import escape as html_escape

    color_picker = widgets.ColorPicker(
        description='Wall Color',
        value='#94C2E5',
        concise=False
    )

    process_button = widgets.Button(description='Recolor Walls')
    output = widgets.Output()

    def on_button_click(_button):
        with output:
            output.clear_output()

            try:
                # Parse the color first so bad input fails before the expensive SAM work.
                color_rgb = hex_to_rgb(color_picker.value)

                original_image = tool.load_image(image_path)
                print("Image loaded successfully!")
                print("Detecting walls...")
                wall_mask = tool.detect_walls()

                print("Applying color...")
                result_image = tool.apply_color(color_rgb)
                tool.display_results(original_image, wall_mask, result_image)
                tool.save_result(result_image, output_path)

            except Exception as e:
                print(f"Error: {e}")
                traceback.print_exc()

    process_button.on_click(on_button_click)

    display(widgets.HTML("<h3>Room Wall Recoloring Tool with SAM</h3>"))
    # The path is user-controlled text rendered as HTML, so escape it.
    display(widgets.HTML("<p>Using image at: {}</p>".format(html_escape(image_path))))
    display(widgets.HTML("<p>Select a color for the walls:</p>"))
    display(color_picker)
    display(process_button)
    display(output)

In [7]:
def main():
    """Load the SAM-backed recoloring tool and show its interactive widgets."""
    recolor_tool = WallRecoloringTool()
    create_interactive_ui(recolor_tool)


In [8]:
# Entry point: inside the notebook kernel __name__ is "__main__", so running this cell
# loads SAM and displays the recoloring UI.
if __name__ == "__main__":
    main()

Using device: cuda


HTML(value='<h3>Room Wall Recoloring Tool with SAM</h3>')

HTML(value='<p>Using image at: /content/room5.png</p>')

HTML(value='<p>Select a color for the walls:</p>')

ColorPicker(value='#94C2E5', description='Wall Color')

Button(description='Recolor Walls', style=ButtonStyle())

Output()