<a href="https://colab.research.google.com/github/Savage-Soccer/miniproj/blob/main/projjjj.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install torch torchvision opencv-python matplotlib pillow numpy --quiet
# The version specifier must be quoted: unquoted, the shell parses ">=4.1.1" as
# "redirect stdout to a file named =4.1.1" and pip installs an unconstrained version.
!pip install "opencv-python-headless>=4.1.1" --quiet
!pip install diffusers transformers accelerate --quiet
!pip install easyocr --quiet
!pip install wget --quiet
!pip install huggingface_hub

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m38.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m33.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m33.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m5.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m14.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m127.9/127.9 MB[0m [31m8.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
%%writefile app.py
import streamlit as st
import os
import cv2
import numpy as np
import torch
import torchvision
from torchvision.models.detection import maskrcnn_resnet50_fpn
from PIL import Image, ImageDraw
import matplotlib.pyplot as plt
import time
import tempfile
import easyocr
from io import BytesIO
import random
import shutil
from diffusers import StableDiffusionInpaintPipeline, DDIMScheduler
from sklearn.cluster import KMeans
import gc
import threading
import traceback
import logging

# Configure logging
# Root logging is set to INFO with a timestamped format; `logger` is the
# module-level logger used by the error handlers throughout this file.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Set page configuration
# NOTE(review): Streamlit documents set_page_config as something to call before
# other Streamlit commands - keep it above st.title()/st.markdown() below.
st.set_page_config(
    page_title="Advanced Image Inpainting Tool",
    page_icon="🖌️",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Define MODEL_CACHE_DIR for model storage
# Resolves to <home>/.cache/huggingface; passed as `cache_dir` when the
# inpainting model is downloaded and loaded.
MODEL_CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "huggingface")

# Create necessary directories
# Both paths are relative, so they are created under the current working
# directory; exist_ok=True makes this safe on every script re-run.
os.makedirs('temp', exist_ok=True)
os.makedirs('output', exist_ok=True)

# Title and intro
st.title("Advanced Object Detection and Inpainting Tool")
st.markdown("""
This tool helps you remove objects or text from images using deep learning techniques.
Upload an image, select what you want to remove, and get a clean result!
""")

# Memory management utility
def clear_gpu_memory():
    """Release unreachable Python objects and, on CUDA, cached GPU memory.

    Garbage collection runs first and unconditionally: tensors kept alive only
    by reference cycles must be freed *before* ``torch.cuda.empty_cache()`` so
    that their blocks can actually be handed back to the driver, and CPU-only
    runs benefit from the collection as well.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Part 1: Object Detection Class
class ObjectDetector:
    """Detect objects with a pretrained torchvision Mask R-CNN and build removal masks.

    The network is loaded lazily on first use. Detection results are plain
    dictionaries with the keys ``image``, ``boxes``, ``labels``, ``scores``,
    ``masks`` and ``class_names``.
    """

    def __init__(self, detection_threshold=0.5):
        """Store settings; the model itself is loaded later by ``load_model``.

        Args:
            detection_threshold: Score a detection must exceed to be kept.
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        st.write(f"Using device: {self.device}")
        self.detection_threshold = detection_threshold
        self.model = None
        # Category names indexed by the label id the model returns;
        # 'N/A' fills ids that have no category.
        self.classes = [
            '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
            'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
            'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
            'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
            'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
            'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
            'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
            'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
            'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
            'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
            'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
            'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
        ]

    @staticmethod
    def _to_rgb(image):
        """Return *image* (PIL image or array) with exactly three colour channels."""
        if isinstance(image, Image.Image):
            return image if image.mode == 'RGB' else image.convert('RGB')
        image_np = np.asarray(image)
        if image_np.ndim == 2:
            # Grayscale: replicate the single channel
            return np.stack([image_np] * 3, axis=-1)
        if image_np.ndim == 3 and image_np.shape[2] == 1:
            return np.repeat(image_np, 3, axis=2)
        if image_np.ndim == 3 and image_np.shape[2] == 4:
            # RGBA: drop the alpha channel
            return np.ascontiguousarray(image_np[:, :, :3])
        return image_np

    @staticmethod
    def _figure_to_png(fig, **savefig_kwargs):
        """Render *fig* as PNG into a rewound in-memory buffer."""
        buf = BytesIO()
        fig.savefig(buf, format="png", **savefig_kwargs)
        buf.seek(0)
        return buf

    def load_model(self):
        """Lazy load the model only when needed.

        Raises:
            Exception: Whatever the model construction raised, after it has
                been reported in the UI and the log.
        """
        if self.model is not None:
            return

        with st.spinner("Loading object detection model..."):
            try:
                try:
                    # Current torchvision selects pretrained weights via `weights`
                    model = maskrcnn_resnet50_fpn(weights="DEFAULT")
                except TypeError:
                    # Older torchvision only understands the legacy flag
                    model = maskrcnn_resnet50_fpn(pretrained=True)
                model.to(self.device)
                model.eval()
                # Publish only a fully configured model so a failed attempt
                # is retried on the next call instead of being reused.
                self.model = model
                st.success("Object detection model loaded successfully!")
            except Exception as e:
                st.error(f"Error loading model: {str(e)}")
                logger.error(f"Model loading error: {str(e)}")
                logger.error(traceback.format_exc())
                raise

    def detect_objects(self, image):
        """Detect objects in an image using Mask R-CNN.

        Args:
            image: PIL image or array; non-RGB inputs (grayscale, RGBA) are
                converted to three channels before inference.

        Returns:
            dict with ``image`` (array), ``boxes`` (int32), ``labels``,
            ``scores``, ``masks`` and ``class_names``, restricted to
            detections scoring above ``detection_threshold``. On failure the
            error is reported and the same keys are returned empty.
        """
        try:
            # Ensure model is loaded
            self.load_model()

            # The network expects three channels
            image = self._to_rgb(image)
            image_tensor = torchvision.transforms.functional.to_tensor(image)
            with torch.no_grad():
                prediction = self.model([image_tensor.to(self.device)])[0]

            # Extract predictions
            boxes = prediction['boxes'].cpu().numpy().astype(np.int32)
            labels = prediction['labels'].cpu().numpy()
            scores = prediction['scores'].cpu().numpy()
            masks = prediction['masks'].cpu().numpy()

            # Filter by detection threshold
            keep_indices = np.where(scores > self.detection_threshold)[0]
            kept_labels = labels[keep_indices]

            return {
                'image': np.array(image),
                'boxes': boxes[keep_indices],
                'labels': kept_labels,
                'scores': scores[keep_indices],
                'masks': masks[keep_indices],
                'class_names': [self.classes[label] for label in kept_labels]
            }
        except Exception as e:
            st.error(f"Object detection failed: {str(e)}")
            logger.error(f"Object detection error: {str(e)}")
            logger.error(traceback.format_exc())
            # Return empty results
            return {
                'image': np.array(image),
                'boxes': np.array([]),
                'labels': np.array([]),
                'scores': np.array([]),
                'masks': np.array([]),
                'class_names': []
            }
        finally:
            # Clean up memory
            clear_gpu_memory()

    def visualize_detection(self, results):
        """Visualize detected objects with bounding boxes and labels.

        Returns:
            BytesIO holding a PNG, or None if rendering failed.
        """
        fig = None
        try:
            fig, ax = plt.subplots(figsize=(12, 8))
            ax.imshow(results['image'])

            # Draw bounding boxes and labels
            for i, box in enumerate(results['boxes']):
                x1, y1, x2, y2 = box
                class_name = self.classes[results['labels'][i]]
                score = results['scores'][i]

                ax.add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1, fill=False,
                                           edgecolor='green', linewidth=2))
                # Draw on this Axes explicitly rather than on pyplot's
                # "current" axes, which is shared global state.
                ax.text(x1, y1-10, f"{class_name}: {score:.2f}", color='green',
                        fontsize=10, backgroundcolor='white')

            ax.axis('off')
            fig.tight_layout()
            return self._figure_to_png(fig, dpi=150)
        except Exception as e:
            st.error(f"Visualization failed: {str(e)}")
            logger.error(f"Visualization error: {str(e)}")
            return None
        finally:
            # Close on every path so failed renders do not leak figures
            if fig is not None:
                plt.close(fig)

    def create_mask_for_object(self, results, object_class=None, instance_idx=None):
        """Create a mask for the selected object(s).

        Args:
            results: Detection dictionary as returned by ``detect_objects``.
            object_class: If given, only detections with this class name.
            instance_idx: If given, only the detection at this index.

        Returns:
            ``(combined_mask, individual_masks)``: uint8 arrays with 255
            inside the object (soft mask > 0.5) and 0 elsewhere; the combined
            mask is the union of the individual ones.
        """
        try:
            height, width = results['image'].shape[:2]
            combined_mask = np.zeros((height, width), dtype=np.uint8)
            individual_masks = []

            for i, class_name in enumerate(results['class_names']):
                # Skip if not the requested class or instance
                if object_class and class_name != object_class:
                    continue
                if instance_idx is not None and i != instance_idx:
                    continue

                # Binarise the soft mask (channel 0 of a 1xHxW prediction)
                mask = (results['masks'][i][0] > 0.5).astype(np.uint8) * 255
                individual_masks.append(mask)
                combined_mask = np.maximum(combined_mask, mask)

                # A specific instance was requested and found; stop scanning
                if instance_idx is not None:
                    break

            return combined_mask, individual_masks
        except Exception as e:
            st.error(f"Mask creation failed: {str(e)}")
            logger.error(f"Mask creation error: {str(e)}")
            logger.error(traceback.format_exc())
            # Return empty masks
            return np.zeros((results['image'].shape[0], results['image'].shape[1]), dtype=np.uint8), []

    def visualize_masks(self, individual_masks, combined_mask):
        """Visualize individual and combined masks.

        Returns:
            BytesIO holding a PNG, or None if rendering failed.
        """
        fig = None
        try:
            # Handle empty masks
            if not individual_masks:
                fig, ax = plt.subplots(figsize=(12, 8))
                ax.imshow(combined_mask, cmap='gray')
                ax.set_title("No objects selected")
                ax.axis('off')
                fig.tight_layout()
                return self._figure_to_png(fig)

            # One panel per mask plus one for the combined mask. There are
            # always at least two panels here; squeeze=False guarantees a
            # 2-D Axes array so the single-mask case indexes correctly.
            panel_count = len(individual_masks) + 1
            fig, axes_grid = plt.subplots(1, panel_count, figsize=(4 * panel_count, 4),
                                          squeeze=False)
            axes = axes_grid[0]

            # Plot individual masks
            for i, mask in enumerate(individual_masks):
                axes[i].imshow(mask, cmap='gray')
                axes[i].set_title(f"Mask {i+1}")
                axes[i].axis('off')

            # Plot combined mask
            axes[-1].imshow(combined_mask, cmap='gray')
            axes[-1].set_title("Combined Mask")
            axes[-1].axis('off')

            fig.tight_layout()
            return self._figure_to_png(fig, dpi=150)
        except Exception as e:
            st.error(f"Mask visualization failed: {str(e)}")
            logger.error(f"Mask visualization error: {str(e)}")
            return None
        finally:
            if fig is not None:
                plt.close(fig)

# Part 2: Text Detection Class
class TextDetector:
    """Detect text regions with EasyOCR and build rectangular removal masks.

    The OCR reader is created lazily on first use. Detection results are plain
    dictionaries with the keys ``image``, ``boxes``, ``texts`` and ``scores``.
    """

    def __init__(self):
        """Defer creation of the EasyOCR reader until ``load_model`` is called."""
        self.reader = None

    @staticmethod
    def _quad_to_box(bbox):
        """Return the integer ``[x1, y1, x2, y2]`` box enclosing a 4-point quad.

        All four corners take part, so rotated or skewed quads stay fully
        covered, and every coordinate is converted to a plain ``int``.
        """
        xs = [point[0] for point in bbox]
        ys = [point[1] for point in bbox]
        return [int(min(xs)), int(min(ys)), int(max(xs)), int(max(ys))]

    @staticmethod
    def _figure_to_png(fig, **savefig_kwargs):
        """Render *fig* as PNG into a rewound in-memory buffer."""
        buf = BytesIO()
        fig.savefig(buf, format="png", **savefig_kwargs)
        buf.seek(0)
        return buf

    def load_model(self):
        """Lazy load the EasyOCR model only when needed.

        Raises:
            Exception: Whatever the reader construction raised, after it has
                been reported in the UI and the log.
        """
        if self.reader is not None:
            return

        with st.spinner("Setting up text detection..."):
            try:
                self.reader = easyocr.Reader(['en'])
                st.success("Text detection model loaded successfully!")
            except Exception as e:
                st.error(f"Error loading text detection model: {str(e)}")
                logger.error(f"Text detection model loading error: {str(e)}")
                raise

    def detect_text(self, image):
        """Detect text in an image using EasyOCR.

        Args:
            image: PIL image or array (grayscale, RGB or RGBA).

        Returns:
            dict with ``image`` (RGB array), ``boxes`` (``[x1, y1, x2, y2]``
            rows), ``texts`` and ``scores``. On failure the error is reported
            and the same keys are returned empty.
        """
        temp_path = None
        try:
            # Ensure model is loaded
            self.load_model()

            # Convert image to the right format
            image_np = np.array(image) if isinstance(image, Image.Image) else image

            # Handle different color spaces
            if len(image_np.shape) == 2:
                image_rgb = cv2.cvtColor(image_np, cv2.COLOR_GRAY2RGB)
            elif image_np.shape[2] == 4:
                image_rgb = cv2.cvtColor(image_np, cv2.COLOR_RGBA2RGB)
            elif image_np.shape[2] == 3:
                image_rgb = image_np
            else:
                raise ValueError("Unsupported image format")

            # Reserve a temp file name, then let OpenCV write to it once our
            # own handle is closed.
            with tempfile.NamedTemporaryFile(suffix='.png', delete=False) as temp_file:
                temp_path = temp_file.name
            if not cv2.imwrite(temp_path, cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)):
                raise IOError(f"Could not write temporary image: {temp_path}")

            # Run OCR
            with st.spinner("Detecting text in the image..."):
                ocr_results = self.reader.readtext(temp_path)

            # Process results
            boxes, texts, scores = [], [], []
            for bbox, text, prob in ocr_results:
                boxes.append(self._quad_to_box(bbox))
                texts.append(text)
                scores.append(prob)

            return {
                'image': image_rgb,
                'boxes': np.array(boxes) if boxes else np.array([]),
                'texts': texts,
                'scores': np.array(scores) if scores else np.array([])
            }
        except Exception as e:
            st.error(f"Text detection failed: {str(e)}")
            logger.error(f"Text detection error: {str(e)}")
            logger.error(traceback.format_exc())
            # Return empty results
            fallback_image = np.array(image) if isinstance(image, Image.Image) else image
            return {
                'image': fallback_image,
                'boxes': np.array([]),
                'texts': [],
                'scores': np.array([])
            }
        finally:
            # Remove the temp file on every path, including OCR failures
            if temp_path is not None:
                try:
                    os.unlink(temp_path)
                except FileNotFoundError:
                    pass
                except OSError as cleanup_error:
                    logger.warning(f"Could not remove temp file {temp_path}: {cleanup_error}")
            # Clean up memory
            clear_gpu_memory()

    def visualize_detection(self, results):
        """Visualize detected text with bounding boxes.

        Returns:
            BytesIO holding a PNG, or None if rendering failed.
        """
        fig = None
        try:
            fig, ax = plt.subplots(figsize=(12, 8))
            ax.imshow(results['image'])

            # Draw bounding boxes and labels for each detected text
            for i, box in enumerate(results['boxes']):
                x1, y1, x2, y2 = box
                text = results['texts'][i]
                score = results['scores'][i]

                ax.add_patch(plt.Rectangle((x1, y1), x2-x1, y2-y1, fill=False,
                                           edgecolor='blue', linewidth=2))

                # Add text label (truncate if too long)
                display_text = text[:20] + "..." if len(text) > 20 else text
                # Draw on this Axes explicitly rather than on pyplot's
                # "current" axes, which is shared global state.
                ax.text(x1, y1-10, f"{display_text}: {score:.2f}", color='blue',
                        fontsize=10, backgroundcolor='white')

            ax.axis('off')
            fig.tight_layout()
            return self._figure_to_png(fig, dpi=150)
        except Exception as e:
            st.error(f"Text visualization failed: {str(e)}")
            logger.error(f"Text visualization error: {str(e)}")
            return None
        finally:
            # Close on every path so failed renders do not leak figures
            if fig is not None:
                plt.close(fig)

    def create_mask_for_text(self, results, text_indices=None):
        """Create a mask for the selected text regions.

        Args:
            results: Detection dictionary as returned by ``detect_text``.
            text_indices: Indices into ``results['boxes']`` to mask; all boxes
                when None. Indices outside ``0 .. len(boxes) - 1`` are skipped.

        Returns:
            ``(combined_mask, individual_masks)``: uint8 arrays with 255
            inside each box (grown by 10 px and clamped to the image).
        """
        try:
            height, width = results['image'].shape[:2]
            combined_mask = np.zeros((height, width), dtype=np.uint8)
            individual_masks = []

            box_count = len(results['boxes'])
            if box_count == 0:
                return combined_mask, individual_masks

            # Process selected text boxes or all if none specified
            indices = text_indices if text_indices is not None else range(box_count)
            padding = 10

            for idx in indices:
                # Reject negatives too: they would silently wrap around and
                # mask a box the caller never selected.
                if not 0 <= idx < box_count:
                    continue

                x1, y1, x2, y2 = [int(coord) for coord in results['boxes'][idx]]

                # Add padding around text, clamped to the image bounds
                x1, y1 = max(0, x1 - padding), max(0, y1 - padding)
                x2, y2 = min(width, x2 + padding), min(height, y2 + padding)

                mask = np.zeros((height, width), dtype=np.uint8)
                cv2.rectangle(mask, (x1, y1), (x2, y2), 255, -1)
                individual_masks.append(mask)
                combined_mask = np.maximum(combined_mask, mask)

            return combined_mask, individual_masks
        except Exception as e:
            st.error(f"Text mask creation failed: {str(e)}")
            logger.error(f"Text mask creation error: {str(e)}")
            # Return empty masks
            return np.zeros((results['image'].shape[0], results['image'].shape[1]), dtype=np.uint8), []

    def visualize_masks(self, individual_masks, combined_mask):
        """Visualize individual and combined text masks.

        Returns:
            BytesIO holding a PNG, or None if rendering failed.
        """
        fig = None
        try:
            # Handle case with no masks
            if not individual_masks:
                fig, ax = plt.subplots(figsize=(12, 8))
                ax.imshow(combined_mask, cmap='gray')
                ax.set_title("Combined Mask (No Text Detected)")
                ax.axis('off')
                fig.tight_layout()
                return self._figure_to_png(fig)

            # One panel per mask plus one for the combined mask. There are
            # always at least two panels here; squeeze=False guarantees a
            # 2-D Axes array so the single-mask case indexes correctly.
            panel_count = len(individual_masks) + 1
            fig, axes_grid = plt.subplots(1, panel_count, figsize=(4 * panel_count, 4),
                                          squeeze=False)
            axes = axes_grid[0]

            # Plot individual masks
            for i, mask in enumerate(individual_masks):
                axes[i].imshow(mask, cmap='gray')
                axes[i].set_title(f"Text Mask {i+1}")
                axes[i].axis('off')

            # Plot combined mask
            axes[-1].imshow(combined_mask, cmap='gray')
            axes[-1].set_title("Combined Mask")
            axes[-1].axis('off')

            fig.tight_layout()
            return self._figure_to_png(fig, dpi=150)
        except Exception as e:
            st.error(f"Mask visualization failed: {str(e)}")
            logger.error(f"Mask visualization error: {str(e)}")
            return None
        finally:
            if fig is not None:
                plt.close(fig)

# Part 3: Manual Mask Creation Function
def create_manual_mask(image, box_coords=None):
    """Create a rectangular mask manually by specifying coordinates.

    Args:
        image: Array or PIL image; only its width and height are used.
        box_coords: Rectangle in any form accepted by
            ``ImageDraw.rectangle``. When omitted (or empty) the left third
            of the image is masked.

    Returns:
        ``(mask, [mask])``: a uint8 array with 255 inside the rectangle and 0
        elsewhere, plus a one-element list holding that same mask. If drawing
        fails the error is reported and an all-zero mask is returned instead.
    """
    height = width = None
    try:
        # Only the dimensions are needed. The image itself is deliberately not
        # converted to PIL: that conversion was unused and fails for array
        # dtypes PIL cannot represent.
        if isinstance(image, np.ndarray):
            height, width = image.shape[:2]
        else:
            width, height = image.size

        # Create blank mask
        mask = Image.new('L', (width, height), 0)
        draw = ImageDraw.Draw(mask)

        # Draw rectangle on mask
        if box_coords:
            draw.rectangle(box_coords, fill=255)
        else:
            # Default mask covers left third of image
            draw.rectangle([(0, 0), (width // 3, height)], fill=255)

        mask_np = np.array(mask)
        return mask_np, [mask_np]
    except Exception as e:
        st.error(f"Manual mask creation failed: {str(e)}")
        logger.error(f"Manual mask creation error: {str(e)}")
        if height is None or width is None:
            # The size could not be determined, so no empty mask can be built
            raise
        empty_mask = np.zeros((height, width), dtype=np.uint8)
        return empty_mask, [empty_mask]

# Part 4: Background Analysis
def analyze_background(image, mask):
    """Describe the unmasked part of an image for use in an inpainting prompt.

    Args:
        image: Array or PIL image; grayscale, RGB and RGBA are accepted
            (alpha is ignored, grayscale is treated as equal R, G and B).
        mask: Array or PIL image aligned with ``image``. For numeric masks,
            values below 127 count as background; for boolean masks, False
            counts as background.

    Returns:
        A string of the form ``"<brightness>, <texture>, <color> background"``,
        or ``"neutral background"`` when nothing is unmasked or the analysis
        fails.
    """
    try:
        # np.asarray accepts both arrays and PIL images
        image_np = np.asarray(image)
        mask_np = np.asarray(mask)

        # Equivalent to (255 - mask) > 128 for 0..255 masks
        if mask_np.dtype == np.bool_:
            is_background = ~mask_np
        else:
            is_background = mask_np < 127

        # Handle edge case of a mask covering the whole image
        if not is_background.any():
            return "neutral background"

        # Collect the background pixels as an (N, 3) array
        pixels = image_np[is_background]
        if pixels.ndim == 1:
            pixels = pixels[:, np.newaxis]
        if pixels.shape[1] < 3:
            # Grayscale (optionally with alpha): replicate the luminance
            pixels = np.repeat(pixels[:, :1], 3, axis=1)
        else:
            # Drop alpha or any further channels
            pixels = pixels[:, :3]

        # Calculate statistics of background pixels
        bg_mean = np.mean(pixels, axis=0)
        bg_std = np.std(pixels, axis=0)

        # Determine brightness level from the summed channel means (0..765)
        total_mean = np.sum(bg_mean)
        if total_mean < 200:
            brightness = "dark"
        elif total_mean > 600:
            brightness = "bright"
        else:
            brightness = "medium-toned"

        # Determine texture from the average per-channel spread
        spread = np.mean(bg_std)
        if spread < 30:
            texture = "smooth"
        elif spread > 60:
            texture = "textured"
        else:
            texture = "moderately textured"

        # Determine dominant color; the order of these checks matters
        r, g, b = bg_mean
        if r > 200 and g > 180 and b < 100:
            color = "yellow"
        elif r > g + 20 and r > b + 20:
            color = "reddish"
        elif g > r + 20 and g > b + 20:
            color = "greenish"
        elif b > r + 20 and b > g + 20:
            color = "bluish"
        elif r > 200 and g > 200 and b > 200:
            color = "white"
        elif r < 50 and g < 50 and b < 50:
            color = "black"
        elif abs(r - g) < 20 and abs(r - b) < 20 and abs(g - b) < 20:
            color = "gray"
        else:
            color = "neutral"

        return f"{brightness}, {texture}, {color} background"
    except Exception as e:
        st.error(f"Background analysis failed: {str(e)}")
        logger.error(f"Background analysis error: {str(e)}")
        return "neutral background"

# Part 5: Stable Diffusion Inpainting
class StableDiffusionInpainter:
    def __init__(self):
        """Record the execution device; the pipeline is built later on demand."""
        device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device_name)
        # No pipeline yet, so the setup cannot have succeeded
        self.pipe = None
        self.setup_successful = False
        st.write(f"Using device: {self.device}")

    def load_model(self):
        """Lazy load the inpainting model only when needed"""
        if self.pipe is not None:
            return True

        try:
            # Clear existing cache if necessary
            cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
            model_cache = os.path.join(cache_dir, "models--runwayml--stable-diffusion-inpainting")
            if os.path.exists(model_cache):
                shutil.rmtree(model_cache)
                st.info("Cleared existing model cache.")
        except Exception as e:
            st.warning(f"Cache management error: {str(e)}")
            logger.warning(f"Cache management error: {str(e)}")

        try:
            with st.spinner("Loading Stable Diffusion Inpainting model... This may take a few minutes."):
                import huggingface_hub
                # Set up model parameters
                model_id = "runwayml/stable-diffusion-inpainting"
                revision = "fp16" if torch.cuda.is_available() else "main"

                # Download model
                huggingface_hub.snapshot_download(
                    repo_id=model_id,
                    revision=revision,
                    force_download=False,  # Changed to False to allow using cached version if available
                    cache_dir=MODEL_CACHE_DIR,
                    resume_download=True
                )

                # Set up dtype based on hardware
                dtype = torch.float16 if torch.cuda.is_available() else torch.float32

                # Load model
                self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                    model_id,
                    torch_dtype=dtype,
                    cache_dir=MODEL_CACHE_DIR,
                    safety_checker=None,
                    variant="fp16" if torch.cuda.is_available() else None
                ).to(self.device)

                # Optimize scheduler for better quality
                self.pipe.scheduler = DDIMScheduler.from_config(self.pipe.scheduler.config)

                # Enable optimizations if on GPU
                if torch.cuda.is_available():
                    self.pipe.enable_model_cpu_offload()
                    self.pipe.enable_attention_slicing()

                st.success("Stable Diffusion Inpainting model loaded successfully!")
                self.setup_successful = True
                return True
        except Exception as e:
            st.error(f"Failed to load Stable Diffusion Inpainting model: {str(e)}")
            logger.error(f"Model loading error: {str(e)}")
            logger.error(traceback.format_exc())

            # Try alternative model
            try:
                with st.spinner("Trying alternative Stable Diffusion model..."):
                    alternative_model_id = "stabilityai/stable-diffusion-2-inpainting"
                    self.pipe = StableDiffusionInpaintPipeline.from_pretrained(
                        alternative_model_id,
                        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                        safety_checker=None
                    ).to(self.device)

                    # Enable optimizations if on GPU
                    if torch.cuda.is_available():
                        self.pipe.enable_model_cpu_offload()
                        self.pipe.enable_attention_slicing()

                    st.success("Alternative Stable Diffusion Inpainting model loaded successfully!")
                    self.setup_successful = True
                    return True
            except Exception as alt_e:
                st.error(f"Failed to load alternative model: {str(alt_e)}")
                logger.error(f"Alternative model loading error: {str(alt_e)}")
                self.setup_successful = False
                return False

    def preprocess_mask(self, mask, blur_radius=15, dilate_iterations=2):
        """Preprocess the mask for better inpainting results.

        The mask is dilated with a 5x5 kernel, Gaussian-blurred, and then
        faded out towards the left and right image borders.

        Args:
            mask: PIL image or array.
            blur_radius: Gaussian kernel size in pixels. Even values are
                bumped to the next odd value because OpenCV requires an odd
                kernel size; values of 0 or less skip the blur.
            dilate_iterations: Number of dilation passes.

        Returns:
            The processed mask as a uint8 array, or the original ``mask``
            unchanged if processing failed.
        """
        try:
            # Work on a copy so the caller's mask is never modified
            mask_np = np.array(mask) if isinstance(mask, Image.Image) else mask.copy()

            # Dilate mask to expand coverage
            kernel = np.ones((5, 5), np.uint8)
            mask_dilated = cv2.dilate(mask_np, kernel, iterations=dilate_iterations)

            # Apply blur for smoother edges
            if blur_radius is not None and blur_radius > 0:
                kernel_size = int(blur_radius) | 1  # force an odd size
                mask_blurred = cv2.GaussianBlur(mask_dilated, (kernel_size, kernel_size), 0)
            else:
                mask_blurred = mask_dilated

            # Fade out over the outer 10% of the width on each side. The
            # distance to the right border is (width - 1 - x), which keeps
            # the ramp symmetric.
            # NOTE(review): only the horizontal borders are faded; confirm
            # that attenuating masks near the image sides is intended.
            width = mask_blurred.shape[1]
            columns = np.arange(width)
            border_distance = np.minimum(columns, width - 1 - columns)
            edge_weight = np.clip(border_distance / (width / 10), 0, 1)[np.newaxis, :]
            if mask_blurred.ndim == 3:
                # Broadcast across channels for multi-channel masks
                edge_weight = edge_weight[..., np.newaxis]

            return (mask_blurred * edge_weight).astype(np.uint8)
        except Exception as e:
            st.error(f"Mask preprocessing failed: {str(e)}")
            logger.error(f"Mask preprocessing error: {str(e)}")
            # Return original mask
            return mask

    def analyze_image_for_prompting(self, image, mask):
        """Analyze the image context around the mask to generate better prompts"""
        try:
            # Convert inputs to appropriate format
            if isinstance(image, Image.Image):
                image_np = np.array(image)
            else:
                image_np = image.copy()

            if isinstance(mask, Image.Image):
                mask_np = np.array(mask)
            else:
                mask_np = mask.copy()

            # Ensure proper types
            mask_np = mask_np.astype(np.uint8)
            image_np = image_np.astype(np.uint8)

            # Convert to HSV for better color analysis
            hsv = cv2.cvtColor(image_np, cv2.COLOR_RGB2HSV)
            _, s, _ = cv2.split(hsv)
            saturation = np.mean(s)

            # Create a border around the mask
            kernel = np.ones((5, 5), np.uint8)
            mask_dilated = cv2.dilate(mask_np, kernel, iterations=2)
            mask_border = mask_dilated - mask_np
            mask_border_bool = mask_border > 0

            # Handle edge case of empty border
            if np.sum(mask_border_bool) == 0:
                return {
                    "color": "neutral",
                    "saturation": "normal",
                    "lighting": "normal",
                    "texture": "smooth",
                    "direction": "uniform"
                }

            # Analyze pixels at the border of the masked region
            border_pixels = image_np[mask_border_bool]
            border_mean = np.mean(border_pixels, axis=0)
            border_std = np.std(border_pixels, axis=0)
            high_detail = np.mean(border_std) > 40

            # Detect edges for structural analysis
            mask_border_uint8 = mask_border.astype(np.uint8)
            masked_image = image_np * np.expand_dims(mask_border_uint8/255, axis=2)
            gray_border = cv2.cvtColor(masked_image.astype(np.uint8), cv2.COLOR_RGB2GRAY)
            edges = cv2.Canny(gray_border, 50, 150)
            edge_density = np.sum(edges > 0) / np.sum(mask_border_bool)

            # Analyze color
            r, g, b = border_mean
            if max(r, g, b) < 50:
                color = "dark"
            elif min(r, g, b) > 200:
                color = "light"
            elif r > g + 20 and r > b + 20:
                color = "reddish"
            elif g > r + 20 and g > b + 20:
                color = "greenish"
            elif b > r + 20 and b > g + 20:
                color = "bluish"
            else:
                color = "neutral"

            # Analyze saturation
            saturation_desc = "vibrant" if saturation > 100 else "muted" if saturation < 50 else "normal"

            # Analyze lighting
            avg_brightness = np.mean(border_pixels)
            lighting = "bright" if avg_brightness > 180 else "dark" if avg_brightness < 80 else "normal"

            # Analyze texture
            texture = "textured" if high_detail else "smooth"

            # Analyze direction
            sobel_x = cv2.Sobel(gray_border, cv2.CV_64F, 1, 0, ksize=5)
            sobel_y = cv2.Sobel(gray_border, cv2.CV_64F, 0, 1, ksize=5)
            abs_sobel_x = np.abs(sobel_x)
            abs_sobel_y = np.abs(sobel_y)
            x_to_y_ratio = np.sum(abs_sobel_x) / (np.sum(abs_sobel_y) + 1e-10)

            if x_to_y_ratio > 1.5:
                direction = "horizontal"
            elif x_to_y_ratio < 0.67:
                direction = "vertical"
            else:
                direction = "uniform"

            return {
                "color": color,
                "saturation": saturation_desc,
                "lighting": lighting,
                "texture": texture,
                "direction": direction,
                "edge_density": edge_density
            }
        except Exception as e:
            st.error(f"Image analysis failed: {str(e)}")
            logger.error(f"Image analysis error: {str(e)}")
            # Return default analysis
            return {
                "color": "neutral",
                "saturation": "normal",
                "lighting": "normal",
                "texture": "smooth",
                "direction": "uniform",
                "edge_density": 0.1
            }

    def generate_smart_prompt(self, image, mask, user_prompt=""):
        """Compose an inpainting prompt from the image's surroundings.

        The background description (from ``analyze_background``) and the
        border analysis (from ``analyze_image_for_prompting``) are turned into
        a comma-separated list of descriptors. That list is appended to
        ``user_prompt`` when one is given, otherwise to a generated
        "Clean <background>" lead-in.

        Returns:
            The assembled prompt string. If anything raises along the way the
            error is logged and the bare ``user_prompt`` (or a generic
            fallback prompt when it is empty) is returned instead.
        """
        try:
            # Describe what surrounds the masked region
            background = analyze_background(image, mask)
            border_info = self.analyze_image_for_prompting(image, mask)

            # Descriptors that are always present
            descriptors = [
                f"{border_info['lighting']} lighting",
                f"{border_info['texture']} texture",
                f"{background}",
            ]

            # Busy borders need extra blending guidance
            if border_info['edge_density'] > 0.3:
                descriptors.append("seamless integration")

            # Only mention a pattern when the gradients have a dominant direction
            if border_info['direction'] != "uniform":
                descriptors.append(f"{border_info['direction']} pattern")

            lead_in = user_prompt if user_prompt else f"Clean {background}"
            return f"{lead_in}, {', '.join(descriptors)}"
        except Exception as e:
            logger.error(f"Smart prompt generation error: {str(e)}")
            # Fall back to the caller's prompt, or a generic one if there is none
            return user_prompt if user_prompt else "clean background, seamless integration"

    def inpaint_image(self, image, mask, prompt="", negative_prompt="", strength=1.0, guidance_scale=7.5, num_inference_steps=50):
        """Use Stable Diffusion to inpaint the masked region.

        Args:
            image: PIL image or numpy array to inpaint. Arrays are cast to
                uint8 before being converted to a PIL image.
            mask: PIL image or numpy array marking the region to repaint.
                For arrays with more than two dimensions only the first
                channel is used. The mask is converted to mode "L".
            prompt: Text prompt. When empty, one is generated with
                ``generate_smart_prompt`` and shown to the user.
            negative_prompt: What the model should avoid. When empty, a
                built-in quality-oriented default is used.
            strength: Passed to the pipeline unchanged.
            guidance_scale: Passed to the pipeline unchanged.
            num_inference_steps: Passed to the pipeline unchanged.

        Returns:
            The inpainted PIL image, resized to the input image's original
            size, or ``None`` when the model could not be loaded or the
            inpainting raised (the error is reported via Streamlit and the
            logger). GPU memory is cleared in every case.
        """
        try:
            # Ensure model is loaded
            if not self.load_model():
                st.error("Failed to load inpainting model. Cannot proceed.")
                return None

            # Process inputs
            if isinstance(image, np.ndarray):
                image_pil = Image.fromarray(image.astype('uint8'))
            else:
                image_pil = image

            if isinstance(mask, np.ndarray):
                # Ensure mask is grayscale
                if len(mask.shape) > 2:
                    mask_pil = Image.fromarray(mask[:, :, 0])
                else:
                    mask_pil = Image.fromarray(mask)
            else:
                mask_pil = mask

            # Keep the mask pixel-aligned with the image: both are handed
            # together to the prompt analysis and to the pipeline below.
            if mask_pil.size != image_pil.size:
                mask_pil = mask_pil.resize(image_pil.size, Image.Resampling.NEAREST)

            # Resize if too large (SD has a limit)
            orig_size = image_pil.size
            max_size = 1024
            if max(orig_size) > max_size:
                ratio = max_size / max(orig_size)
                new_size = (int(orig_size[0] * ratio), int(orig_size[1] * ratio))
                image_pil = image_pil.resize(new_size, Image.Resampling.LANCZOS)
                # NEAREST keeps the mask strictly two-valued
                mask_pil = mask_pil.resize(new_size, Image.Resampling.NEAREST)
                st.info(f"Image resized from {orig_size} to {new_size} for inpainting.")

            # Process mask - make sure it's black and white
            mask_pil = mask_pil.convert("L")

            # Generate smart prompt if not provided
            if not prompt:
                image_np = np.array(image_pil)
                mask_np = np.array(mask_pil)
                prompt = self.generate_smart_prompt(image_np, mask_np)
                st.info(f"Generated prompt: {prompt}")

            # Default negative prompts to help quality
            if not negative_prompt:
                negative_prompt = "poor quality, blurry, distorted, deformed, disfigured, cropped, lowres, low resolution, ugly, duplicate, mutilated, mutation, mutated, out of frame, tiling, poorly drawn, multiple, extra, cross-eyed"

            # Run inpainting with progress bar
            with st.spinner(f"Inpainting image... (steps: {num_inference_steps})"):
                # Print model parameters
                logger.info(f"Inpainting with: prompt='{prompt}', negative_prompt='{negative_prompt}', strength={strength}, guidance_scale={guidance_scale}, steps={num_inference_steps}")

                # Generate result
                result = self.pipe(
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    image=image_pil,
                    mask_image=mask_pil,
                    strength=strength,
                    guidance_scale=guidance_scale,
                    num_inference_steps=num_inference_steps
                ).images[0]

                # Always hand back an image at the caller's original size.
                # Previously this only happened when the input had been
                # downscaled above, so any other size difference leaked out.
                # NOTE(review): per the diffusers docs the inpaint pipeline
                # falls back to the UNet's native resolution when
                # height/width are not passed - confirm against the installed
                # version; the size check below is correct either way.
                if result.size != orig_size:
                    result = result.resize(orig_size, Image.Resampling.LANCZOS)

                st.success("Inpainting completed successfully!")
                return result
        except Exception as e:
            st.error(f"Inpainting failed: {str(e)}")
            logger.error(f"Inpainting error: {str(e)}")
            logger.error(traceback.format_exc())
            return None
        finally:
            # Clean up memory
            clear_gpu_memory()

# Part 6: Custom Mask Editor
class MaskEditor:
    """Hold an image together with an editable selection mask.

    The mask is kept as a single-channel uint8 array with the same height
    and width as the image; brush strokes paint the value 255.
    """

    def __init__(self, image, initial_mask=None):
        """Initialize the mask editor with an image and optional initial mask.

        Args:
            image: PIL image or numpy array. PIL images are converted to
                numpy arrays.
            initial_mask: Optional PIL image or numpy array. Arrays are
                copied, so the caller's array is never modified. When
                omitted, an all-zero mask is created.
        """
        self.image = np.array(image) if isinstance(image, Image.Image) else image

        # Initialize mask
        self.height, self.width = self.image.shape[:2]
        if initial_mask is None:
            self.mask = np.zeros((self.height, self.width), dtype=np.uint8)
        elif isinstance(initial_mask, Image.Image):
            self.mask = np.array(initial_mask)
        else:
            self.mask = initial_mask.copy()

        # Normalize to one channel: the brush mask built in update_mask and
        # the red-channel assignment in visualize_mask_overlay are both 2-D,
        # so a multi-channel mask would fail to combine with them.
        if self.mask.ndim > 2:
            self.mask = np.ascontiguousarray(self.mask[:, :, 0])

        # Boolean masks (e.g. from a mode "1" PIL image) become 0/255 so they
        # combine correctly with the 0/255 uint8 brush mask.
        if self.mask.dtype == np.bool_:
            self.mask = self.mask.astype(np.uint8) * 255

        # Ensure mask is the right size
        if self.mask.shape[:2] != (self.height, self.width):
            # NEAREST avoids introducing intermediate grey values
            self.mask = cv2.resize(self.mask, (self.width, self.height), interpolation=cv2.INTER_NEAREST)

        # Initialize brush parameters
        self.brush_size = 20
        self.brush_mode = "add"  # can be "add" or "subtract"

    def update_mask(self, clicked_point, brush_size=None, mode=None):
        """Update the mask with a new circular brush stroke.

        Args:
            clicked_point: ``(x, y)`` centre of the stroke; values are
                truncated to integers.
            brush_size: Optional new brush radius; remembered for later calls.
            mode: Optional new brush mode; remembered for later calls.
                ``"add"`` paints into the mask, any other value erases.

        Returns:
            The current mask. If the stroke fails, the error is reported and
            the mask as it stood at that point is returned.
        """
        try:
            # Update parameters if provided
            if brush_size is not None:
                self.brush_size = brush_size
            if mode is not None:
                self.brush_mode = mode

            # Get coordinates
            x, y = clicked_point

            # Create brush mask (filled circle; radius cast to int because
            # cv2.circle only accepts integer radii)
            brush_mask = np.zeros((self.height, self.width), dtype=np.uint8)
            cv2.circle(brush_mask, (int(x), int(y)), int(self.brush_size), 255, -1)

            # Apply brush to mask
            if self.brush_mode == "add":
                self.mask = np.maximum(self.mask, brush_mask)
            else:  # subtract: clamp the stroke area down to 0
                self.mask = np.minimum(self.mask, 255 - brush_mask)

            return self.mask
        except Exception as e:
            st.error(f"Mask update failed: {str(e)}")
            logger.error(f"Mask update error: {str(e)}")
            return self.mask

    def visualize_mask_overlay(self):
        """Create a visualization of the mask overlaid on the image.

        Returns:
            A PIL image with the mask blended into the first (red) channel
            at half weight, or the unmodified image if blending fails.
        """
        try:
            # Create RGB mask overlay
            mask_rgb = np.zeros_like(self.image)
            mask_rgb[:, :, 0] = self.mask  # Red channel

            # Create alpha overlay
            alpha = 0.5
            overlay = cv2.addWeighted(self.image, 1, mask_rgb, alpha, 0)

            # Convert to PIL for Streamlit
            return Image.fromarray(overlay)
        except Exception as e:
            st.error(f"Mask visualization failed: {str(e)}")
            logger.error(f"Mask visualization error: {str(e)}")
            # Return original image as fallback
            return Image.fromarray(self.image)

    def get_mask(self):
        """Return the current mask."""
        return self.mask

# Part 7: Main Application
def main():
    """Render the Streamlit page: upload, mask selection, inpainting, download.

    Flow for an uploaded image:
      1. Reset the per-image session state when the upload differs from the
         stored original.
      2. Offer three ways of building a mask (object detection, text
         detection, manual selection) and store the result in
         ``st.session_state.current_mask``.
      3. When a mask exists, show the inpainting parameters and run the
         inpainter on request.
      4. Show the result next to the original, save it under ``output/`` and
         offer it for download.

    Any exception raised while handling the image is reported in the UI and
    logged with its traceback; nothing is re-raised.
    """
    st.sidebar.header("Settings")

    # Initialize session state for storing data between reruns
    session_defaults = {
        'original_image': None,
        'current_mask': None,
        'object_detection_results': None,
        'text_detection_results': None,
        'mask_editor': None,
        'inpainted_result': None,
        'processing_done': False,
        'inpainted_path': None,  # where the current result has been saved
    }
    for key, default in session_defaults.items():
        if key not in st.session_state:
            st.session_state[key] = default

    # File uploader for image
    uploaded_file = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    # Process uploaded image
    try:
        # Read image
        image = Image.open(uploaded_file).convert('RGB')

        # Reset session state if new image
        if st.session_state.original_image is None or not np.array_equal(np.array(st.session_state.original_image), np.array(image)):
            for key, default in session_defaults.items():
                st.session_state[key] = default
            st.session_state.original_image = image

        st.sidebar.image(image, caption="Original Image", use_column_width=True)

        # Display image
        st.image(image, caption="Uploaded Image", use_column_width=True)

        # Create columns for detection options
        # NOTE(review): st.button is True only during the rerun triggered by
        # its click, so the widgets nested under the three branches below
        # disappear on the next interaction. Persisting the chosen mode in
        # st.session_state would fix that but depends on how the detector
        # classes can re-render stored results - left unchanged here.
        col1, col2, col3 = st.columns(3)
        with col1:
            detect_objects = st.button("Detect Objects")
        with col2:
            detect_text = st.button("Detect Text")
        with col3:
            manual_mode = st.button("Manual Selection")

        # Object Detection
        if detect_objects:
            with st.spinner("Detecting objects..."):
                detector = ObjectDetector(detection_threshold=0.5)
                st.session_state.object_detection_results = detector.detect_objects(image)
                st.session_state.text_detection_results = None  # Reset text detection

            results = st.session_state.object_detection_results

            # Show detection results if any objects found
            # (len() rather than truthiness: the container type is not
            # visible here and array truthiness would be ambiguous)
            if len(results['boxes']) > 0:
                # Visualize detections
                detection_vis = detector.visualize_detection(results)
                st.image(detection_vis, caption="Detected Objects", use_column_width=True)

                # Create object selection
                object_classes = sorted(set(results['class_names']))
                selected_class = st.selectbox("Select object type to remove:",
                                              ["All"] + object_classes)

                # Create instance selection if multiple of same class
                instance_idx = None
                if selected_class != "All":
                    instances = [i for i, c in enumerate(results['class_names'])
                                 if c == selected_class]
                    if len(instances) > 1:
                        instance_options = [f"{selected_class} {i+1}" for i in range(len(instances))]
                        selected_instance = st.selectbox("Select specific instance:",
                                                         ["All Instances"] + instance_options)
                        if selected_instance != "All Instances":
                            instance_num = int(selected_instance.split()[-1]) - 1
                            instance_idx = instances[instance_num]

                # Generate mask for selected object(s)
                if selected_class == "All":
                    # Create mask for all objects. The per-object masks are
                    # collected as well: the visualization below needs them,
                    # and leaving them unset made this (default) branch fail
                    # with a NameError.
                    # Assumes visualize_masks accepts a list of uint8 masks
                    # like the one create_mask_for_object returns - TODO confirm.
                    combined_mask = np.zeros((image.height, image.width), dtype=np.uint8)
                    individual_masks = []
                    for i in range(len(results['boxes'])):
                        mask = results['masks'][i][0]
                        mask = (mask > 0.5).astype(np.uint8) * 255
                        individual_masks.append(mask)
                        combined_mask = np.maximum(combined_mask, mask)
                else:
                    combined_mask, individual_masks = detector.create_mask_for_object(
                        results,
                        object_class=selected_class,
                        instance_idx=instance_idx
                    )
                st.session_state.current_mask = combined_mask

                # Show mask
                if st.session_state.current_mask is not None:
                    mask_vis = detector.visualize_masks(individual_masks, combined_mask)
                    st.image(mask_vis, caption="Object Mask", use_column_width=True)

                    # Initialize mask editor
                    st.session_state.mask_editor = MaskEditor(image, st.session_state.current_mask)
            else:
                st.warning("No objects detected in the image.")

        # Text Detection
        if detect_text:
            with st.spinner("Detecting text..."):
                text_detector = TextDetector()
                st.session_state.text_detection_results = text_detector.detect_text(image)
                st.session_state.object_detection_results = None  # Reset object detection

            text_results = st.session_state.text_detection_results

            # Show detection results if any text found
            if len(text_results['boxes']) > 0:
                # Visualize detections
                text_vis = text_detector.visualize_detection(text_results)
                st.image(text_vis, caption="Detected Text", use_column_width=True)

                # Create text selection
                text_options = [f"{text[:20]}..." if len(text) > 20 else text
                                for text in text_results['texts']]

                # Select by index and only format as text: identical strings
                # (or strings sharing their first 20 characters) would
                # otherwise all resolve to the first matching detection.
                option_indices = list(range(len(text_options)))
                selected_indices = st.multiselect("Select text to remove:",
                                                  options=option_indices,
                                                  default=option_indices,
                                                  format_func=lambda i: text_options[i])

                # Generate mask for selected text
                if selected_indices:
                    combined_mask, individual_masks = text_detector.create_mask_for_text(
                        text_results,
                        text_indices=selected_indices
                    )
                    st.session_state.current_mask = combined_mask

                    # Show mask
                    mask_vis = text_detector.visualize_masks(individual_masks, combined_mask)
                    st.image(mask_vis, caption="Text Mask", use_column_width=True)

                    # Initialize mask editor
                    st.session_state.mask_editor = MaskEditor(image, st.session_state.current_mask)
            else:
                st.warning("No text detected in the image.")

        # Manual Selection
        if manual_mode:
            st.info("Draw a rectangle to select the area to remove.")

            # Reset detection results
            st.session_state.object_detection_results = None
            st.session_state.text_detection_results = None

            # Create selection options
            selection_method = st.radio("Selection Method:", ["Rectangle", "Free Draw"])

            if selection_method == "Rectangle":
                # Create sliders for rectangle coordinates
                col1, col2 = st.columns(2)
                with col1:
                    x1 = st.slider("Left (X1)", 0, image.width, int(image.width * 0.25))
                    y1 = st.slider("Top (Y1)", 0, image.height, int(image.height * 0.25))
                with col2:
                    x2 = st.slider("Right (X2)", 0, image.width, int(image.width * 0.75))
                    y2 = st.slider("Bottom (Y2)", 0, image.height, int(image.height * 0.75))

                # Create mask
                mask, _ = create_manual_mask(image, [x1, y1, x2, y2])
                st.session_state.current_mask = mask

                # Initialize mask editor
                st.session_state.mask_editor = MaskEditor(image, mask)

                # Show mask overlay
                overlay = st.session_state.mask_editor.visualize_mask_overlay()
                st.image(overlay, caption="Selection Overlay", use_column_width=True)

            elif selection_method == "Free Draw":
                # Initialize mask editor if not exists
                if st.session_state.mask_editor is None:
                    empty_mask = np.zeros((image.height, image.width), dtype=np.uint8)
                    st.session_state.mask_editor = MaskEditor(image, empty_mask)
                    st.session_state.current_mask = empty_mask

                # Add drawing controls
                col1, col2 = st.columns(2)
                with col1:
                    brush_size = st.slider("Brush Size", 5, 100, 20)
                with col2:
                    brush_mode = st.radio("Brush Mode", ["Add", "Subtract"])

                # Show current mask
                overlay = st.session_state.mask_editor.visualize_mask_overlay()
                st.image(overlay, caption="Drawing Overlay", use_column_width=True)

                # Add click handling (simplified for Streamlit): each button
                # stamps the brush at a fixed fraction of the image size.
                st.write("Click on areas to add to or subtract from mask:")
                brush_presets = [
                    ("Top Left", 0.25, 0.25),
                    ("Center", 0.5, 0.5),
                    ("Bottom Right", 0.75, 0.75),
                ]
                for column, (label, frac_x, frac_y) in zip(st.columns(3), brush_presets):
                    with column:
                        if st.button(label):
                            point = (image.width * frac_x, image.height * frac_y)
                            st.session_state.current_mask = st.session_state.mask_editor.update_mask(
                                point, brush_size, brush_mode.lower())

                # Update overlay after changes
                overlay = st.session_state.mask_editor.visualize_mask_overlay()
                st.image(overlay, caption="Updated Drawing", use_column_width=True)

        # Inpainting options if mask is selected
        if st.session_state.current_mask is not None:
            st.subheader("Inpainting Options")

            # Create inpainting parameters
            prompt = st.text_area("Inpainting Prompt (leave empty for automatic):",
                                  "")
            negative_prompt = st.text_area("Negative Prompt (what to avoid):",
                                           "poor quality, distorted, deformed, blurry")

            col1, col2, col3 = st.columns(3)
            with col1:
                steps = st.slider("Inference Steps", 10, 80, 30)
            with col2:
                guidance_scale = st.slider("Guidance Scale", 1.0, 20.0, 7.5)
            with col3:
                strength = st.slider("Strength", 0.1, 1.0, 1.0)

            # Run inpainting
            if st.button("Generate Inpainted Result"):
                with st.spinner("Preparing inpainting model..."):
                    inpainter = StableDiffusionInpainter()

                # Process image
                st.session_state.inpainted_result = inpainter.inpaint_image(
                    image,
                    st.session_state.current_mask,
                    prompt=prompt,
                    negative_prompt=negative_prompt,
                    strength=strength,
                    guidance_scale=guidance_scale,
                    num_inference_steps=steps
                )

                st.session_state.processing_done = True
                # A fresh result must be written to a fresh file
                st.session_state.inpainted_path = None

        # Show results
        if st.session_state.processing_done and st.session_state.inpainted_result is not None:
            st.subheader("Inpainting Result")

            # Create side-by-side comparison
            col1, col2 = st.columns(2)
            with col1:
                st.image(image, caption="Original Image", use_column_width=True)
            with col2:
                st.image(st.session_state.inpainted_result, caption="Inpainted Result", use_column_width=True)

            # Save the result once per generation. Every widget interaction
            # reruns this function, and writing a new timestamped file each
            # time would steadily fill the disk.
            output_path = st.session_state.inpainted_path
            if output_path is None or not os.path.exists(output_path):
                # The directory is not created anywhere in this function
                os.makedirs("output", exist_ok=True)
                output_path = f"output/inpainted_{int(time.time())}.png"
                st.session_state.inpainted_result.save(output_path)
                st.session_state.inpainted_path = output_path

            with open(output_path, "rb") as file:
                st.download_button(
                    label="Download Result",
                    data=file,
                    file_name=os.path.basename(output_path),
                    mime="image/png"
                )
            st.success(f"Result saved to {output_path}")

    except Exception as e:
        st.error(f"Error processing image: {str(e)}")
        logger.error(f"Image processing error: {str(e)}")
        logger.error(traceback.format_exc())

# Entry point: build the UI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()

Writing app.py


In [None]:
# Install required packages
!pip install pyngrok
!pip install streamlit

# Set up ngrok authentication
import os
NGROK_AUTH_TOKEN = "2w2o8loCYzgd4N9nLhVRzG4cyeb_3oJQM8pZ9ztuAaeSjrAUb"  # Replace this with your actual token from ngrok dashboard

# Download and install ngrok binary if needed
!wget https://bin.equinox.io/c/bNyj1mQVY4c/ngrok-v3-stable-linux-amd64.tgz
!tar -xvf ngrok-v3-stable-linux-amd64.tgz
!./ngrok authtoken $NGROK_AUTH_TOKEN

# Set up the Streamlit tunnel
from pyngrok import ngrok
import subprocess

# Kill any existing Streamlit processes
!pkill -f streamlit || true

# Start Streamlit in the background
!nohup streamlit run app.py &

# Wait a moment for Streamlit to start
import time
time.sleep(5)

# Create a tunnel to port 8501 where Streamlit runs
public_url = ngrok.connect(8501)
print(f"\n\n🚀 Your Streamlit app is available at: {public_url}\n\n")

Collecting pyngrok
  Downloading pyngrok-7.2.5-py3-none-any.whl.metadata (8.9 kB)
Downloading pyngrok-7.2.5-py3-none-any.whl (23 kB)
Installing collected packages: pyngrok
Successfully installed pyngrok-7.2.5
Collecting streamlit
  Downloading streamlit-1.45.0-py3-none-any.whl.metadata (8.9 kB)
Collecting watchdog<7,>=2.1.5 (from streamlit)
  Downloading watchdog-6.0.0-py3-none-manylinux2014_x86_64.whl.metadata (44 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.3/44.3 kB[0m [31m3.0 MB/s[0m eta [36m0:00:00[0m
Collecting pydeck<1,>=0.8.0b4 (from streamlit)
  Downloading pydeck-0.9.1-py2.py3-none-any.whl.metadata (4.1 kB)
Downloading streamlit-1.45.0-py3-none-any.whl (9.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.9/9.9 MB[0m [31m92.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pydeck-0.9.1-py2.py3-none-any.whl (6.9 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.9/6.9 MB[0m [31m107.1 MB/s[0m eta [36m0:00: