In [None]:
import cv2
import numpy as np
import torch
from transformers import AutoImageProcessor, YolosForObjectDetection, CLIPTokenizer
from diffusers import StableDiffusionInpaintPipeline
from PIL import Image
import os
import psutil

# Prefer Apple's MPS backend, then CUDA, and fall back to the CPU.
device = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

# Load YOLO model for object detection
processor = AutoImageProcessor.from_pretrained("hustvl/yolos-tiny")
model = YolosForObjectDetection.from_pretrained("hustvl/yolos-tiny").to(device)

# Load Stable Diffusion Inpainting pipeline.
# Half precision is only used on an accelerator: a float16 pipeline moved to
# the CPU is not expected to run, so the CPU fallback loads float32 weights.
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-inpainting",
    torch_dtype=torch.float32 if device == "cpu" else torch.float16,
)
pipe.to(device)

# Load tokenizer manually
# NOTE(review): this replaces the tokenizer that ships with the inpainting
# checkpoint by the one from "openai/clip-vit-large-patch14" - confirm that
# the override is intentional and compatible with the pipeline's text encoder.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14", use_fast=True)
pipe.tokenizer = tokenizer

# Ensure height and width are divisible by 8 (or another requested multiple)
def adjust_size(width, height, multiple=8):
    """Round ``width`` and ``height`` down to a multiple of ``multiple``.

    Args:
        width: Width in pixels.
        height: Height in pixels.
        multiple: Positive integer both results must be divisible by.

    Returns:
        ``(width, height)`` rounded down to the nearest multiple. A dimension
        smaller than ``multiple`` is raised to ``multiple`` instead of
        collapsing to 0, because a zero-sized image cannot be resized to.

    Raises:
        ValueError: If ``multiple`` is not positive.
    """
    if multiple <= 0:
        raise ValueError("multiple must be a positive integer")

    snapped_width = max(multiple, (width // multiple) * multiple)
    snapped_height = max(multiple, (height // multiple) * multiple)
    return snapped_width, snapped_height

# Load image function
def load_image(image_path):
    """Read an image from disk and return it in both RGB and BGR order.

    Args:
        image_path: Path handed to ``cv2.imread``.

    Returns:
        ``(rgb, bgr)``: the image converted with ``cv2.COLOR_BGR2RGB`` and the
        array exactly as ``cv2.imread`` returned it.

    Raises:
        FileNotFoundError: If ``cv2.imread`` returns ``None``.
    """
    bgr_pixels = cv2.imread(image_path)
    if bgr_pixels is None:
        raise FileNotFoundError("Image not found! Check the file path.")

    print("Image loaded successfully.")
    rgb_pixels = cv2.cvtColor(bgr_pixels, cv2.COLOR_BGR2RGB)
    return rgb_pixels, bgr_pixels

# State shared between yolo_object_detection and the select_object callback.
object_boxes = []         # one [x1, y1, x2, y2] entry per detected object
selected_boxes = []       # indices into object_boxes picked by mouse clicks
image_for_display = None  # copy of the input image with detection boxes drawn

# Mouse click event handler
def select_object(event, x, y, flags, param):
    """OpenCV mouse callback that marks clicked detections for removal.

    On a left-button press, every box in ``object_boxes`` that contains the
    click position ``(x, y)`` is added to ``selected_boxes`` (at most once),
    and the selection window is redrawn with the chosen boxes outlined using
    colour ``(0, 0, 255)``. All other events are ignored.

    Args:
        event: OpenCV mouse event code.
        x: Horizontal click position in pixels.
        y: Vertical click position in pixels.
        flags: Unused; part of the callback signature.
        param: Unused; part of the callback signature.
    """
    if event != cv2.EVENT_LBUTTONDOWN:
        return

    for index, (x1, y1, x2, y2) in enumerate(object_boxes):
        inside = x1 <= x <= x2 and y1 <= y <= y2
        # Skip boxes that are already selected so repeated clicks do not pile
        # up duplicate indices (and duplicate log lines).
        if inside and index not in selected_boxes:
            selected_boxes.append(index)
            print(f"Selected object {index+1} for removal.")

    # Draw on a copy so the shared image keeps only the detection boxes.
    preview = image_for_display.copy()
    for index in selected_boxes:
        x1, y1, x2, y2 = object_boxes[index]
        cv2.rectangle(preview, (x1, y1), (x2, y2), (0, 0, 255), 3)

    cv2.imshow("Select Objects to Remove", preview)
    # NOTE(review): kept from the original flow; presumably a GUI-backend
    # workaround - confirm it is still needed on the target platform.
    cv2.startWindowThread()

# Object detection and selection with YOLO
def yolo_object_detection(image_rgb, image, padding=4):
    """Detect objects, let the user pick some, and return a removal mask.

    Runs the detector on ``image_rgb``, draws every detection on a copy of
    ``image`` and shows it in an OpenCV window. Clicks are handled by
    ``select_object``; pressing Enter ends the selection. The annotated image
    is written to ``detected_objects/<image_name>``, where ``image_name`` is
    read from module scope.

    Args:
        image_rgb: Image array passed to the detector (RGB when it comes from
            ``load_image``).
        image: The same picture as returned by ``cv2.imread``; used for
            display, saving and the mask size.
        padding: Pixels added on every side of each selected box in the mask.

    Returns:
        ``uint8`` array with the height and width of ``image``: 255 inside the
        padded selected boxes and 0 elsewhere.
    """
    global object_boxes, image_for_display

    # Start from an empty selection: indices left over from an earlier call
    # would refer to a different set of boxes. Cleared in place so the mouse
    # callback keeps seeing the same list object.
    selected_boxes.clear()

    inputs = processor(images=image_rgb, return_tensors="pt").to(device)
    # Inference only - no autograd graph is needed.
    with torch.no_grad():
        outputs = model(**inputs)

    target_sizes = torch.tensor([image_rgb.shape[:2]], device=device)
    results = processor.post_process_object_detection(outputs, target_sizes=target_sizes, threshold=0.5)[0]

    object_boxes = [box.int().tolist() for box in results["boxes"]]

    # Draw bounding boxes on a copy of the image
    image_for_display = image.copy()
    for x1, y1, x2, y2 in object_boxes:
        cv2.rectangle(image_for_display, (x1, y1), (x2, y2), (255, 0, 0), 2)

    # Open window for selection
    window_name = "Select Objects to Remove"
    cv2.imshow(window_name, image_for_display)
    cv2.setMouseCallback(window_name, select_object)
    print("\nClick on the objects you want to remove. Press 'Enter' when done.")

    # Poll until Enter (key code 13) is pressed. Only the low byte is compared
    # because some GUI backends set modifier bits above it.
    while True:
        key = cv2.waitKey(1) & 0xFF
        if key == 13:
            break

    # cv2.imwrite returns False instead of raising when the target directory
    # is missing, so make sure it exists first.
    os.makedirs("detected_objects", exist_ok=True)
    detected_image_path = os.path.join("detected_objects/", f"{image_name}")
    # image_for_display is a 3-channel copy of ``image`` and is written as is;
    # no channel conversion is required before saving.
    cv2.imwrite(detected_image_path, image_for_display)
    cv2.destroyAllWindows()
    cv2.waitKey(1)  # give the GUI a chance to process the window teardown

    mask = np.zeros(image.shape[:2], dtype=np.uint8)
    image_height, image_width = image.shape[:2]

    for index in selected_boxes:
        x1, y1, x2, y2 = object_boxes[index]

        # Grow the box by ``padding`` and clamp it to the image bounds.
        left = max(0, x1 - padding)
        top = max(0, y1 - padding)
        right = min(image_width, x2 + padding)
        bottom = min(image_height, y2 + padding)

        mask[top:bottom, left:right] = 255

    return mask

# Inpainting using Stable Diffusion
def stable_diffusion_inpainting(image_rgb, mask, prompt=None, num_inference_steps=100):
    """Inpaint the masked regions of ``image_rgb`` with the diffusion pipeline.

    The image and mask are resized to dimensions accepted by the pipeline
    (see ``adjust_size``), inpainted, resized back to the input size, and the
    generated pixels are copied into the input image only where the mask is
    non-zero.

    Args:
        image_rgb: Image array of shape ``(height, width, channels)``.
        mask: Single-channel ``uint8`` array; non-zero pixels are replaced.
        prompt: Text prompt for the pipeline. ``None`` uses the built-in
            background-fill prompt.
        num_inference_steps: Number of denoising steps passed to the pipeline.

    Returns:
        Copy of ``image_rgb`` with the masked pixels replaced by the
        inpainted result.
    """
    if prompt is None:
        # NOTE(review): the last sentence is phrased as a negation; consider
        # passing such constraints through the pipeline's negative prompt.
        prompt = (
            "Fill the missing background of the image naturally. "
            "Maintain proper background lighting. "
            "Do not introduce any new objects, patterns, or artificial elements."
        )

    original_height, original_width = image_rgb.shape[:2]

    # Ensure dimensions are divisible by 8
    width, height = adjust_size(original_width, original_height)

    # Resize to the adjusted size. Each side is rounded independently, so the
    # aspect ratio may shift slightly.
    image_pil = Image.fromarray(image_rgb).resize((width, height))
    # Nearest-neighbour keeps the mask strictly 0/255; a smoothing filter
    # would introduce grey values along the box edges.
    mask_pil = Image.fromarray(mask).resize((width, height), resample=Image.NEAREST)

    # Run Stable Diffusion inpainting
    inpainted_image = pipe(
        prompt=prompt,
        image=image_pil,
        mask_image=mask_pil,
        width=width,
        height=height,
        num_inference_steps=num_inference_steps,
    ).images[0]

    # Convert result to NumPy array
    inpainted_array = np.array(inpainted_image)

    # Resize inpainted image & mask back to original size
    original_size = (original_width, original_height)
    inpainted_array_resized = cv2.resize(inpainted_array, original_size, interpolation=cv2.INTER_LANCZOS4)
    mask_resized = cv2.resize(mask, original_size, interpolation=cv2.INTER_NEAREST)

    # Copy the inpainted result only into the masked areas
    final_result = image_rgb.copy()
    replace_here = mask_resized > 0
    final_result[replace_here] = inpainted_array_resized[replace_here]

    return final_result

def blend_edges(original, inpainted, mask):
    """Cross-fade ``inpainted`` into ``original`` using a blurred mask.

    The mask is converted to float32 and smoothed with a 15x15 Gaussian
    kernel, scaled by 1/255 and used as a per-pixel weight: ``inpainted`` is
    weighted by it and ``original`` by its complement.

    Args:
        original: Image array indexed as ``[row, column, channel]``.
        inpainted: Image array combined element-wise with ``original``.
        mask: 2-D array; 255 gives full weight to ``inpainted``.

    Returns:
        The weighted sum cast to ``uint8``.
    """
    soft_mask = cv2.GaussianBlur(mask.astype(np.float32), (15, 15), 0)
    # Add a trailing axis so the weight broadcasts over the colour channels.
    weight = soft_mask[:, :, None] / 255
    mixed = original * (1 - weight) + inpainted * weight
    return mixed.astype(np.uint8)

def save_results(mask, output_image, output_image_blended):
    """Write the mask and both result images to their output folders.

    Files are saved under ``masks/``, ``outputs/`` and ``outputs_blended/``,
    each using the module-level ``image_name`` as the file name. Missing
    folders are created first so saving does not fail on a fresh checkout.

    Args:
        mask: Array saved to ``masks/``.
        output_image: Array saved to ``outputs/``.
        output_image_blended: Array saved to ``outputs_blended/``.
    """
    # Convert everything up front so a bad array fails before any file is written.
    pending = (
        ("masks", Image.fromarray(mask)),
        ("outputs", Image.fromarray(output_image)),
        ("outputs_blended", Image.fromarray(output_image_blended)),
    )

    for directory, picture in pending:
        os.makedirs(directory, exist_ok=True)
        picture.save(os.path.join(directory, image_name))

    print("Mask and output image saved successfully.")

# File name of the input image; also used for every output file.
image_name = "field.jpg"

try:
    image_rgb, image = load_image("../input_images/" + image_name)
    mask = yolo_object_detection(image_rgb.copy(), image.copy())
    stable_diffused_result = stable_diffusion_inpainting(image_rgb, mask)
    blended_result = blend_edges(image_rgb, stable_diffused_result, mask)
    save_results(mask, stable_diffused_result, blended_result)

    # Clean up leftover Python helper processes, but only descendants of this
    # process. Killing every Python process on the machine by name would take
    # down unrelated notebooks, services and tools.
    for child in psutil.Process(os.getpid()).children(recursive=True):
        try:
            # name() can be empty for some processes; treat that as no match.
            if "python" in (child.name() or "").lower():
                child.kill()
        except (psutil.NoSuchProcess, psutil.AccessDenied):
            continue

except Exception as e:
    # Top-level boundary of the script: report the failure instead of crashing.
    print(f"Error: {e}")