# Zero-Shot Segmentation with CLIPSeg

------------------------------------------

## Load Model

In [None]:
! python.exe -m pip install --upgrade pip

In [None]:
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation

# LOAD PRETRAINED MODEL
# The processor and model must come from the same checkpoint; keeping the id in
# one constant prevents them from silently drifting apart.
CLIPSEG_CHECKPOINT = "CIDAS/clipseg-rd64-refined"
processor = CLIPSegProcessor.from_pretrained(CLIPSEG_CHECKPOINT)
model = CLIPSegForImageSegmentation.from_pretrained(CLIPSEG_CHECKPOINT)

## Common Utils

In [None]:
import sys
# NOTE(review): the project helpers used below (create_directory, open_image,
# get_blacklist_files, ...) are presumably provided by these wildcard imports;
# confirm which module defines each and consider explicit imports.
from common import *

# Make modules under ./data importable. This must run before the
# `processing` import below, so keep the statement order.
sys.path.append("data")
from processing import *

## Model Function

Define a helper that runs CLIPSeg on one image for a list of text prompts, then plots and saves the predicted masks.



In [None]:
import os
import re
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image


def _plot_predictions(image, prompts, preds):
    """Plot ``image`` next to the sigmoid of each prompt's mask logits; return the figure."""
    num_plots = len(prompts) + 1
    fig, axes = plt.subplots(1, num_plots, figsize=(3 * num_plots, 4))
    for axis in axes.flatten():
        axis.axis('off')

    axes[0].imshow(image)
    for axis, prompt, pred in zip(axes[1:], prompts, preds):
        axis.imshow(torch.sigmoid(pred[0]))
        # Caption each mask with its prompt, placed just above the image.
        axis.text(0, -15, prompt)
    return fig


def do_zero_shot_segmentation(image_path, prompts, save_results: bool = True):
    """Segment one image with CLIPSeg, producing one mask per text prompt.

    Uses the module-level ``processor`` and ``model``. If the image is not
    entirely black (per ``is_image_total_black``), the prompts are run through
    the model and the original image is plotted next to the sigmoid of each
    prompt's logits. For an all-black image the model is skipped, and the image
    itself (resized to 352x352) stands in for every prompt's prediction.

    Args:
        image_path: Path to the input image. The first ``<digits>_<digits>``
            substring of the path names the output files; if there is none,
            the file stem is used instead.
        prompts: Non-empty sequence of text prompts.
        save_results: When True, write the mask for the k-th prompt (k starting
            at 1) to ``predictions/<image_name>_<k>.png``.

    Raises:
        ValueError: If ``prompts`` is empty.
    """
    if not prompts:
        raise ValueError("prompts must contain at least one text prompt")

    match = re.search(r'\d+_\d+', image_path)
    # Fall back to the file stem rather than crashing on ``None.group``.
    image_name = (match.group(0) if match
                  else os.path.splitext(os.path.basename(image_path))[0])

    # Copy the pixels into memory so the file handle is closed right away.
    with Image.open(image_path) as opened:
        image = opened.copy()

    is_total_black = is_image_total_black(image)
    fig = None

    if is_total_black:
        print('Total Black Image')
        # The model is skipped. 352x352 presumably matches CLIPSeg's output
        # mask size, so the stand-in masks line up with real predictions.
        image = image.resize((352, 352))
    else:
        inputs = processor(text=prompts, images=[image] * len(prompts),
                           padding="max_length", return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        # Force shape (num_prompts, 1, H, W). reshape (rather than unsqueeze)
        # also covers a single-prompt batch in case the logits come back
        # without their batch dimension.
        preds = logits.reshape(len(prompts), 1, *logits.shape[-2:])
        fig = _plot_predictions(image, prompts, preds)

    if save_results:
        directory_name = "predictions"
        create_directory(directory_name)

        if is_total_black:
            # One all-black stand-in per prompt, so file numbering matches the
            # non-black case for any number of prompts.
            black_tensor = transforms.ToTensor()(image)
            preds = [black_tensor] * len(prompts)

        for index, pred in enumerate(preds, start=1):
            plt.imsave(f'{directory_name}/{image_name}_{index}.png',
                       torch.sigmoid(pred[0]), format='png')

    # Display the figure whether or not results are saved, then release it so
    # figures do not accumulate when this runs in a loop over many images.
    plt.show()
    if fig is not None:
        plt.close(fig)

## Image Pre-processing

In [None]:
import matplotlib.pyplot as plt
from PIL import Image, ImageFilter


def image_pre_process(image_path=None,
                      number_splits=2,
                      do_sharp_pil=True,
                      sharpen_strength=2,
                      do_enlarge=True,
                      save_result=True):
    """Prepare one frame for segmentation.

    Opens ``image_path``. When ``number_splits > 1``, it is split with
    ``split_image``. Each resulting piece is then optionally enlarged
    (``enlarge_image``) and/or sharpened (``sharp_image_pil`` with
    ``sharpen_strength``). Every helper receives the piece name
    ``<image_name>_<index>`` and the ``save_result`` flag.

    Returns:
        The sequence of pieces, with each element replaced by its processed
        version. Without a split, this is a one-element list holding the
        opened image.
    """
    create_directory('pre_process')
    source = open_image(image_path)
    base_name = extract_image_name(image_path)

    pieces = (split_image(source, number_splits, base_name, save_result)
              if number_splits > 1 else [source])

    for position, piece in enumerate(pieces):
        piece_name = f'{base_name}_{position}'
        if do_enlarge:
            pieces[position] = enlarge_image(piece, piece_name, save_result)
        if do_sharp_pil:
            pieces[position] = sharp_image_pil(
                pieces[position], piece_name, save_result, sharpen_strength
            )
    return pieces

## Image Post-processing

In [None]:
import os


def image_post_process(image_path,
                       threshold_bw=128,
                       save_results: bool = True):
    """Binarize one predicted mask image.

    Ensures the ``post_process`` directory exists, opens ``image_path``, and
    passes it to ``convert_to_black_and_white`` together with the patch name
    derived from the path, ``save_results`` and the ``threshold_bw`` cut-off.

    Returns:
        Whatever ``convert_to_black_and_white`` returns.
    """
    create_directory('post_process')
    mask = open_image(image_path)
    patch_name = extract_patch_name(image_path)
    return convert_to_black_and_white(
        mask, patch_name, save_results, threshold_bw
    )

## Experiments

In [None]:
# Candidate prompt pairs for CLIPSeg, keyed by experiment id. The prediction
# cell selects one pair and passes it as ``prompts`` to
# do_zero_shot_segmentation, which saves the mask for the k-th prompt as
# "<image>_<k>.png". The post-processing cell keeps only the "_2.png" masks,
# i.e. the second prompt of the pair.
prompts = {
    0: ['black', 'white'],
    1: ['land', 'plants'],
    2: ['green', 'brown'],
    3: ['sugar beet', 'soil'],
    4: ['soil', 'rows'],
    5: ['soil', 'row of plants'],
    6: ['land', 'cultivations'],
    7: ['ground', 'vegetation'],
    8: ['background', 'vegetation'],
    9: ['soil', 'plants'],
    10: ['soil', 'plantations'],
    11: ['soil', 'root'],
    12: ['soil', 'stems'],
    13: ['soil', 'radish'],
    14: ['soil', 'photosynthetic organism'],
    15: ['flora growth', 'bare ground'],
    16: ['plant matter', 'uncovered soil'],
    17: ['foliage density', 'land surface'],
    18: ['ground', 'plants'],
    19: ['ground', 'trees'],
    20: ['substrate', 'plants'],
    21: ['ground', 'photosynthetic organism'],
    22: ['substrate', 'photosynthetic organism'],
}

In [None]:
# PRE PROCESS
# Run image_pre_process (2 splits, enlarge, sharpen with strength 4) on every
# frame in ``frames/``.
directory_path = 'frames'
create_directory(directory_path)
black_list = ['.ipynb_checkpoints']
# Honour both the shared blacklist and the local one, and build the lookup once
# instead of calling get_blacklist_files() for every file.
skipped_files = set(get_blacklist_files()) | set(black_list)
# Sort so frames are processed (and logged) in a reproducible order.
files = sorted(os.listdir(directory_path))

for file_name in files:
    if file_name in skipped_files:
        continue
    print(f'PROCESSING: {file_name}')
    image_path = f'{directory_path}/{file_name}'
    image_pre_process(image_path=image_path,
                      number_splits=2,
                      do_sharp_pil=True,
                      sharpen_strength=4,
                      do_enlarge=True,
                      save_result=True)
    print("\n\n")

In [None]:
# PREDICTION
# Segment every sharpened patch with the prompt pair of experiment 9.
directory_path = 'pre_process/sharp'
# Look the pair up by its key rather than by position in the dict's values.
prompt = prompts[9]
skipped_files = set(get_blacklist_files())
# Sort so patches are processed (and logged) in a reproducible order.
files = sorted(os.listdir(directory_path))

for file_name in files:
    if file_name in skipped_files:
        continue
    print(f'Processing: {file_name}')
    image_path = f'{directory_path}/{file_name}'
    do_zero_shot_segmentation(image_path, prompt, save_results=True)

In [None]:
# POST PROCESS
# Binarize only the masks for the second prompt of each pair (files ending in
# "_2.png" in ``predictions/``).
directory_path = 'predictions'
skipped_files = set(get_blacklist_files())
# Sort so masks are processed (and logged) in a reproducible order.
files = sorted(os.listdir(directory_path))

for file_name in files:
    if file_name in skipped_files or not file_name.endswith('_2.png'):
        continue
    print(f'PROCESSING: {file_name}')
    image_path = f'{directory_path}/{file_name}'
    image_post_process(image_path)
    print("\n")

In [None]:
# Presumably stitches the binarized patches written under
# post_process/black_and_white back together.
# NOTE(review): assemble_patch is defined outside this notebook; confirm where
# it writes its output.
assemble_patch("post_process/black_and_white")

In [None]:
# NOTE(review): unlike the relative paths used in the other cells, this is an
# absolute "/content/..." path (Colab-style). Confirm it points at the intended
# prediction tiles before running outside that environment.
reassemble_orthomosaic('/content/frames_predictions')