### Description:
Testing Environment for removing trees from a panoramic image

### Functions:

### Main:

In [None]:
from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
import torch
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
from PIL import Image
import cv2
import numpy as np
import torch
from diffusers import AutoPipelineForInpainting
from diffusers.utils import load_image, make_image_grid


# testing:
def main(image_path="data/panoramic_imgs/_HveufZbNlDXqHIEDRNFzg.jpg"):
    """Segment a panoramic image, mask its trees, and inpaint them away.

    Runs the Cityscapes-finetuned SegFormer-B5 model on the image, builds a
    binary mask of the pixels predicted as class 8 (treated as trees here),
    fills those pixels with OpenCV's Telea inpainting, and shows the original
    image, the mask, and the inpainted result side by side.

    Args:
        image_path: Path of the image to process. Defaults to the sample
            panorama used while testing.

    Side effects:
        Writes the tree mask to ``tree_mask.jpg`` in the working directory
        and opens a matplotlib figure.
    """
    ##### test out gen ai method
    # The pipeline is kept commented out together with the experiment that
    # uses it: nothing else in this function needs it, and loading it takes
    # several minutes.
    #
    # pipeline = AutoPipelineForInpainting.from_pretrained(
    #     "kandinsky-community/kandinsky-2-2-decoder-inpaint", torch_dtype=torch.float16
    # )
    # pipeline.enable_model_cpu_offload()
    #
    # init_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint.png")
    # mask_image = load_image("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/inpaint_mask.png")
    #
    # prompt = "a black cat with glowing eyes, cute, adorable, disney, pixar, highly detailed, 8k"
    # negative_prompt = "bad anatomy, deformed, ugly, disfigured"
    # image = pipeline(prompt=prompt, negative_prompt=negative_prompt, image=init_image, mask_image=mask_image).images[0]
    # make_image_grid([init_image, mask_image, image], rows=1, cols=3)

    ################

    # Force three channels so the pixel array handed to OpenCV below is
    # always height x width x 3, whatever mode the file was stored in.
    image = Image.open(image_path).convert("RGB")

    feature_extractor = SegformerFeatureExtractor.from_pretrained("nvidia/segformer-b5-finetuned-cityscapes-1024-1024")
    model = SegformerForSemanticSegmentation.from_pretrained("nvidia/segformer-b5-finetuned-cityscapes-1024-1024")

    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only, so there is no need to record gradients.
    with torch.no_grad():
        logits = model(**inputs).logits

    print("about to convert logits to class predictions")
    # convert logits to class predictions
    predicted_class = torch.argmax(logits, dim=1)  # shape (batch_size, height/4, width/4)

    # upsample to match input image size; "nearest" keeps the values valid
    # class ids instead of blending neighbouring ids together
    predicted_class = torch.nn.functional.interpolate(
        predicted_class.unsqueeze(1).float(),  # Add channel dimension
        size=image.size[::-1],  # PIL size is (width, height); interpolate wants (height, width)
        mode="nearest",
    ).squeeze(1).to(torch.int32)

    print("about to visualize")
    segmentation_map = predicted_class[0].cpu().numpy()

    # 255 where the model predicted class 8, 0 everywhere else.
    class_8_mask = (segmentation_map == 8).astype(np.uint8) * 255
    # Saved for inspection only. The in-memory mask is the one used for
    # inpainting, because reading the JPEG back would bring compression
    # artifacts (stray non-zero pixels) into the mask.
    cv2.imwrite('tree_mask.jpg', class_8_mask)

    # Inpaint the same pixels that were segmented, so the mask and the image
    # are guaranteed to have matching dimensions. The array is already RGB,
    # which is what matplotlib expects.
    rgb_pixels = np.array(image)
    inpainted_image = cv2.inpaint(rgb_pixels, class_8_mask, inpaintRadius=3, flags=cv2.INPAINT_TELEA)
    # NOTE: the masked pixels used to be overwritten from a `sky_resized`
    # image that was never defined (NameError). Composite a sky texture here
    # once one is available.

    fig, ax = plt.subplots(1, 3, figsize=(20, 5))

    # Original image
    ax[0].imshow(image)
    ax[0].set_title("Original Image")
    ax[0].axis("off")

    # Mask
    ax[1].imshow(class_8_mask, cmap='gray')
    ax[1].set_title("Tree Mask")
    ax[1].axis("off")

    # Inpainted image
    ax[2].imshow(inpainted_image)
    ax[2].set_title("Inpainted Image")
    ax[2].axis("off")

    plt.tight_layout()
    plt.show()

    # # save the image (needs `cmap` and `num_classes` to be defined first)
    # plt.imsave("data/image_processing/test_segmented_image.jpg", segmentation_map, cmap=cmap, vmin=0, vmax=num_classes - 1)

main()


Loading pipeline components...: 100%|██████████| 3/3 [06:43<00:00, 134.34s/it]
Loading pipeline components...: 100%|██████████| 6/6 [06:12<00:00, 62.14s/it] 


### TODO Remove the trees (class 8) and replace them with either sky (class 10) or buildings (class 2)

In [None]:
# Panoramic image examples can be found in the data/panoramic_imgs folder

# Plan: write a function that takes an image and returns the same image with the trees removed.
# The trees can be removed either from the original image or from the segmentation map.

