# Install dependencies and download the SAM checkpoint

In [None]:
import os

# Absolute Colab location of the cloned repo; the cwd is switched into it below.
REPO_DIR = '/content/segment-anything-eo'
SAM_CHECKPOINT_URL = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'

# Clone only once: after os.chdir below, a re-run of a bare `git clone` would
# nest a second copy of the repository inside the first one.
if not os.path.isdir(REPO_DIR):
    !git clone https://github.com/aliaksandr960/segment-anything-eo.git {REPO_DIR}
os.chdir(REPO_DIR)

# %pip targets the kernel's environment; version pinned to the one this notebook ran with.
%pip install -q rasterio==1.3.6

# -nc (no-clobber) skips the large download when the checkpoint already exists;
# plain wget would save a duplicate (sam_vit_h_4b8939.pth.1) on every re-run.
!wget -nc {SAM_CHECKPOINT_URL}

Cloning into 'segment-anything-eo'...
remote: Enumerating objects: 180, done.[K
remote: Counting objects: 100% (180/180), done.[K
remote: Compressing objects: 100% (143/143), done.[K
remote: Total 180 (delta 43), reused 133 (delta 13), pack-reused 0[K
Receiving objects: 100% (180/180), 27.95 MiB | 35.17 MiB/s, done.
Resolving deltas: 100% (43/43), done.
Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting rasterio
  Downloading rasterio-1.3.6-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (20.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m20.0/20.0 MB[0m [31m81.4 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting affine (from rasterio)
  Downloading affine-2.4.0-py3-none-any.whl (15 kB)
Collecting cligj>=0.5 (from rasterio)
  Downloading cligj-0.7.2-py3-none-any.whl (7.1 kB)
Collecting snuggs>=1.4.1 (from rasterio)
  Downloading snuggs-1.4.7-py3-none-any.whl (5.4 kB)
Collecting click-plugins (

# Unzip the images to run SAM prediction on

In [None]:
# -o overwrites without prompting (a prompt would hang the non-interactive shell on re-run);
# -q suppresses the per-file listing.
!unzip -o -q pics.zip

Archive:  /content/drive/MyDrive/datasets/u256_1.zip
   creating: u256_1/
  inflating: u256_1/1046.tif         
  inflating: u256_1/1054.tif         
  inflating: u256_1/1058.tif         
  inflating: u256_1/1059.tif         
  inflating: u256_1/1079.tif         
  inflating: u256_1/1095.tif         
  inflating: u256_1/1101.tif         
  inflating: u256_1/1171.tif         
  inflating: u256_1/1172.tif         
  inflating: u256_1/1179.tif         
  inflating: u256_1/1184.tif         
  inflating: u256_1/1201.tif         
  inflating: u256_1/1254.tif         
  inflating: u256_1/1258.tif         
  inflating: u256_1/1265.tif         
  inflating: u256_1/1284.tif         
  inflating: u256_1/1288.tif         
  inflating: u256_1/1324.tif         
  inflating: u256_1/1345.tif         
  inflating: u256_1/1348.tif         
  inflating: u256_1/1388.tif         
  inflating: u256_1/1398.tif         
  inflating: u256_1/1407.tif         
  inflating: u256_1/1430.tif         
  inflating: u

In [None]:
def list_items(path, ext=None):
    """List the entries of directory ``path`` as sorted full paths.

    Args:
        path: Directory whose entries are listed (non-recursive).
        ext: Optional suffix; when given, only paths ending with it are kept.
            The match is a plain, case-sensitive ``str.endswith`` check, and
            subdirectories are not filtered out.

    Returns:
        A sorted list of ``os.path.join(path, entry)`` strings.
    """
    candidates = (os.path.join(path, entry) for entry in os.listdir(path))
    return sorted(
        candidate for candidate in candidates
        if ext is None or candidate.endswith(ext)
    )

In [None]:
# Folder with the unzipped input tiles, and the folder that receives one
# sub-folder of mask files per input tile.
input_path, output_path = 'pics', 'pics_inf'

In [None]:
input_ext = '.tif'
input_path_list = list_items(input_path, ext=input_ext)
# Fail fast instead of silently "processing" zero images later on.
assert input_path_list, f'no {input_ext} files found in {input_path!r}'
# splitext on the basename strips only the final extension; splitting the whole
# path on '.' broke on dotted directories (e.g. './pics') or dotted file names.
input_name_list = [os.path.splitext(os.path.basename(p))[0] for p in input_path_list]
print(len(input_path_list), len(input_name_list))

256 256


# Define the SAM prediction wrapper

In [None]:
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator


class SamEO:
    """Thin wrapper around SAM's automatic mask generator.

    Builds a SAM model from a checkpoint once; calling the instance on an image
    returns the list of ``'segmentation'`` masks produced by
    ``SamAutomaticMaskGenerator.generate``.

    Args:
        checkpoint: Path to the SAM weights file.
        model_type: Key into ``sam_model_registry`` (e.g. ``'vit_h'``).
        device: Device passed to ``self.sam.to`` (e.g. ``'cpu'`` or ``'cuda'``).
        sam_kwargs: Optional keyword arguments forwarded to
            ``SamAutomaticMaskGenerator``; ``None`` means no extra arguments.
        mask_preprocessor: Optional callable applied to every returned mask
            (``mask -> mask``). ``None`` returns the masks unchanged.
    """

    def __init__(self, checkpoint="sam_vit_h_4b8939.pth",
                 model_type='vit_h',
                 device='cpu',
                 sam_kwargs=None,
                 mask_preprocessor=None):
        self.checkpoint = checkpoint
        self.model_type = model_type
        self.device = device
        self.sam_kwargs = sam_kwargs
        # Kept so __call__ can apply it to each mask (it used to be dropped silently).
        self.mask_preprocessor = mask_preprocessor
        self.reinit_sam()

    def reinit_sam(self):
        """(Re)build the SAM model and mask generator from the stored settings."""
        self.sam = sam_model_registry[self.model_type](checkpoint=self.checkpoint)
        self.sam.to(device=self.device)
        sam_kwargs = self.sam_kwargs if self.sam_kwargs is not None else {}
        self.mask_generator = SamAutomaticMaskGenerator(self.sam, **sam_kwargs)

    def __call__(self, image):
        """Generate masks for ``image`` and return their ``'segmentation'`` entries.

        NOTE(review): SAM's generator presumably expects an HWC uint8 RGB array —
        confirm against the segment_anything documentation.
        """
        masks = [record['segmentation'] for record in self.mask_generator.generate(image)]
        preprocess = getattr(self, 'mask_preprocessor', None)
        if preprocess is not None:
            masks = [preprocess(mask) for mask in masks]
        return masks

import torch

# Use the GPU when one is attached; otherwise fall back to CPU (much slower for
# vit_h, but it runs instead of crashing on a CPU-only runtime).
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'SAM device: {device}')

sam_eo = SamEO(checkpoint="sam_vit_h_4b8939.pth",
               model_type='vit_h',
               device=device,
               sam_kwargs=None)

# Run prediction

In [None]:
import rasterio
import cv2
from tqdm import tqdm, trange
import numpy as np

output_ext = '.tif'
# Binary masks are stored as 0/255 so they are visible in ordinary image viewers.
mask_multiplier = 255

for path in tqdm(input_path_list):
    # splitext on the basename strips only the final extension; splitting the
    # whole path on '.' broke on dotted directories or dotted file names.
    name = os.path.splitext(os.path.basename(path))[0]
    output_subpath = os.path.join(output_path, name)
    os.makedirs(output_subpath, exist_ok=True)

    image = cv2.imread(path)
    if image is None:
        # cv2.imread signals an undecodable file by returning None, not by raising.
        print(f'Skipping unreadable image: {path}')
        continue
    # OpenCV loads pixels in BGR order; SAM is fed RGB in its reference examples.
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    masks_list = sam_eo(image)

    # Reuse the source raster's profile (georeferencing, size) for every mask,
    # but as a single uint8 band.
    with rasterio.open(path) as src:
        profile = src.profile
    profile.update(count=1, dtype='uint8')

    # Mask files are numbered from 1 in generator order.
    for n, mask in enumerate(masks_list, start=1):
        dst_fp = os.path.join(output_subpath, f'{n}{output_ext}')
        with rasterio.open(dst_fp, 'w', **profile) as dst:
            dst.write((mask > 0).astype(np.uint8) * mask_multiplier, 1)


100%|██████████| 256/256 [33:16<00:00,  7.80s/it]


# Zip results

In [None]:
# -q: without it zip prints one line per mask file (thousands of lines, truncated in the output).
!zip -q -r pics_inf.zip ./pics_inf

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
  adding: u256_1_inf/3213/73.tif (deflated 59%)
  adding: u256_1_inf/3213/111.tif (deflated 71%)
  adding: u256_1_inf/3213/157.tif (deflated 64%)
  adding: u256_1_inf/3213/172.tif (deflated 68%)
  adding: u256_1_inf/3213/102.tif (deflated 70%)
  adding: u256_1_inf/3213/21.tif (deflated 59%)
  adding: u256_1_inf/3213/65.tif (deflated 67%)
  adding: u256_1_inf/3213/26.tif (deflated 50%)
  adding: u256_1_inf/3213/190.tif (deflated 72%)
  adding: u256_1_inf/3213/90.tif (deflated 73%)
  adding: u256_1_inf/3213/68.tif (deflated 68%)
  adding: u256_1_inf/3213/69.tif (deflated 67%)
  adding: u256_1_inf/3213/32.tif (deflated 65%)
  adding: u256_1_inf/3213/128.tif (deflated 71%)
  adding: u256_1_inf/3213/166.tif (deflated 70%)
  adding: u256_1_inf/3213/154.tif (deflated 73%)
  adding: u256_1_inf/3213/180.tif (deflated 72%)
  adding: u256_1_inf/3213/170.tif (deflated 72%)
  adding: u256_1_inf/3213/167.tif (deflated 71%)
  adding: u2