### Description:
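Test environment for removing trees from a panoramic image: a SegFormer segmentation model builds a tree mask, and a Stable Diffusion inpainting pipeline fills the masked regions.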

### Functions:

In [4]:
from diffusers.utils import load_image, make_image_grid
import torch
from diffusers import AutoPipelineForInpainting



def remove_trees(
    path_to_image,
    path_to_mask,
    prompt="photorealistic buildings and sky, highly detailed, 8k",
    seed=92,
):
    """Inpaint the masked regions of an image with Stable Diffusion.

    Loads the ``runwayml/stable-diffusion-inpainting`` pipeline in fp16,
    runs it on the image/mask pair, and resizes the result back to the
    size of the input image.

    Args:
        path_to_image: Path (or anything ``load_image`` accepts) of the
            image to inpaint.
        path_to_mask: Path of the mask image marking the regions to replace.
            NOTE(review): presumably white pixels mark the area to repaint,
            as produced by the tree-mask step -- confirm against the
            diffusers inpainting docs.
        prompt: Text prompt describing what should fill the masked regions.
        seed: Seed for the CUDA random generator, so runs are repeatable.

    Returns:
        The inpainted image resized to the input image's ``(width, height)``,
        or ``None`` when CUDA is not available (a message is printed).
    """
    if not torch.cuda.is_available():
        print("Please ensure cuda is available before running.")
        return None

    pipeline = AutoPipelineForInpainting.from_pretrained(
        "runwayml/stable-diffusion-inpainting",
        torch_dtype=torch.float16,
        variant="fp16",
    )
    # Keep sub-models on the CPU until they are needed to lower GPU memory use.
    pipeline.enable_model_cpu_offload()

    init_image = load_image(path_to_image)
    mask_image = load_image(path_to_mask)
    # Remember the input size: the result is resized back to it below.
    original_size = init_image.size

    print("create generator")
    generator = torch.Generator("cuda").manual_seed(seed)
    print("Starting inpainting")
    inpainted = pipeline(
        prompt=prompt,
        image=init_image,
        mask_image=mask_image,
        generator=generator,
    ).images[0]
    print(" Impainting done, resizing image")
    return inpainted.resize(original_size)

### Main:

In [5]:

from transformers import SegformerFeatureExtractor, SegformerForSemanticSegmentation
import torch
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import numpy as np



def _segment_class_mask(image, class_id):
    """Return a uint8 mask (255 where ``class_id`` is predicted, else 0).

    Runs the Cityscapes-finetuned SegFormer-B5 checkpoint on ``image`` and
    upsamples the per-pixel class prediction to the image's resolution.
    """
    checkpoint = "nvidia/segformer-b5-finetuned-cityscapes-1024-1024"
    feature_extractor = SegformerFeatureExtractor.from_pretrained(checkpoint)
    model = SegformerForSemanticSegmentation.from_pretrained(checkpoint)
    print("Feature and model loaded")

    inputs = feature_extractor(images=image, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**inputs).logits

    print("about to convert logits to class predictions")
    # convert logits to class predictions
    predicted_class = torch.argmax(logits, dim=1)  # shape (batch_size, height/4, width/4)

    # upsample to match input image size; PIL's size is (width, height),
    # interpolate wants (height, width), hence the reversal.
    predicted_class = torch.nn.functional.interpolate(
        predicted_class.unsqueeze(1).float(),  # add channel dimension
        size=image.size[::-1],
        mode="nearest",  # nearest keeps the values valid class ids
    ).squeeze(1).to(torch.int32)

    print("generating segmentation map...")
    segmentation_map = predicted_class[0].cpu().numpy()
    return (segmentation_map == class_id).astype(np.uint8) * 255


def get_remove_trees_panoramic(image_path):
    """Segment the trees in a panoramic image, inpaint them away, and plot.

    Side effects: writes the tree mask to ``tree_mask.jpg`` and, when
    inpainting succeeds, the result to ``inpainted_img.jpg`` (both in the
    current working directory), then shows a three-panel matplotlib figure
    (original, mask, inpainted).

    Args:
        image_path: path to original panoramic image

    Returns:
        None. If ``remove_trees`` returns ``None`` (it does so when CUDA is
        unavailable), a message is printed and the save/plot steps are
        skipped instead of failing on ``None.save``.
    """
    # NOTE(review): presumably sample panorama ids used while testing:
    # -W-u6oxxcZfhXSxvCHtomQ
    # 2gzizC8uRTlcEU-9GuZhmQ
    # NOTE(review): JPEG is lossy, so the stored mask may not be strictly
    # binary; the file name is kept so existing consumers still find it.
    mask_path = 'tree_mask.jpg'
    # Class id 8 is "vegetation" in the Cityscapes label set used by the
    # checkpoint (see the model card's id2label mapping).
    vegetation_class_id = 8
    print("[get_remove_trees_panoramic] starting...")

    # Close the file handle promptly and normalise to 3-channel RGB so
    # palette, grayscale, or RGBA inputs behave like ordinary photos.
    with Image.open(image_path) as source:
        image = source.convert("RGB")

    tree_mask = _segment_class_mask(image, vegetation_class_id)
    Image.fromarray(tree_mask).save(mask_path)
    print('saved mask image')

    print('inpainting image... this may take a minute')
    inpainted_img = remove_trees(image_path, mask_path)
    if inpainted_img is None:
        print("[get_remove_trees_panoramic] inpainting unavailable, skipping save and plot")
        return
    inpainted_img.save('inpainted_img.jpg')

    fig, ax = plt.subplots(1, 3, figsize=(20, 5))
    panels = (
        (image, "Original Image", {}),
        (tree_mask, "Tree Mask", {"cmap": 'gray'}),
        (inpainted_img, "Inpainted Image", {}),
    )
    for axis, (picture, title, imshow_kwargs) in zip(ax, panels):
        axis.imshow(picture, **imshow_kwargs)
        axis.set_title(title)
        axis.axis("off")

    plt.tight_layout()
    plt.show()


