In [136]:
# pip install segments-ai
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.backends.backend_agg import FigureCanvasAgg as FigureCanvas
from matplotlib.figure import Figure
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor, SamAutomaticMaskGenerator
from segments import SegmentsClient, SegmentsDataset
from segments.utils import get_semantic_bitmap
from tqdm import tqdm

# Create the Segments.ai API client. The key is read from the environment instead
# of being hard-coded in the notebook (notebooks and their outputs get shared/committed).
# NOTE: the key previously hard-coded here should be revoked and rotated.
SEGMENTS_API_KEY = os.environ.get('SEGMENTS_API_KEY')
if not SEGMENTS_API_KEY:
    raise RuntimeError(
        "Set the SEGMENTS_API_KEY environment variable to your Segments.ai API key."
    )
client = SegmentsClient(SEGMENTS_API_KEY)

# Load SAM (ViT-H) once; fall back to CPU when no GPU is present so the notebook
# still runs (much more slowly) instead of failing on `.to('cuda')`.
SAM_CHECKPOINT = 'checkpoints/sam_vit_h_4b8939.pth'
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
sam = sam_model_registry["vit_h"](checkpoint=SAM_CHECKPOINT).to(device=DEVICE)
predictor = SamPredictor(sam)

### Utility Functions

In [137]:
def bitmap2image(semantic_bitmap) -> "Image.Image":
    """Render a label/boolean bitmap as a black-and-white RGB PIL image.

    Every non-zero pixel of ``semantic_bitmap`` becomes white (255, 255, 255)
    and every zero pixel becomes black (0, 0, 0).

    Args:
        semantic_bitmap: 2-D array-like of integer labels or booleans (e.g. the
            result of ``get_semantic_bitmap`` or a SAM mask).

    Returns:
        A PIL image in ``'RGB'`` mode with the same height and width.
    """
    # Binarize on the original dtype before scaling. Casting to uint8 and then
    # multiplying by 255 wraps around: label 2 became 254 and label 256 became 0.
    foreground = np.asarray(semantic_bitmap) != 0
    gray = np.where(foreground, 255, 0).astype(np.uint8)
    rgb = np.stack([gray, gray, gray], axis=-1)
    return Image.fromarray(rgb, 'RGB')

def image2bitmap(image: "Image.Image", dtype=np.uint8):
    """Extract a single-channel bitmap from an image.

    Masks produced by ``bitmap2image`` have identical R, G and B channels, so
    the first channel carries the whole mask.

    Args:
        image: PIL image or array-like, either ``(H, W, C)`` or already ``(H, W)``.
        dtype: dtype used when converting ``image`` to an array; pass ``np.bool_``
            to get a boolean mask (any non-zero pixel becomes ``True``).

    Returns:
        ``(H, W)`` numpy array of ``dtype``: channel 0 for multi-channel input,
        the array itself for 2-D input.
    """
    pixels = np.asarray(image, dtype=dtype)
    # Grayscale ('L' / '1') images are already 2-D; `[:, :, 0]` would raise.
    if pixels.ndim == 2:
        return pixels
    return pixels[:, :, 0]

def show_mask(mask, ax=None, random_color=False):
    """Build a translucent RGBA overlay for a binary mask.

    Despite the name, nothing is drawn here: the caller passes the returned
    array to ``ax.imshow``. ``ax`` is accepted only for call compatibility and
    is ignored (it is now optional).

    Args:
        mask: boolean / 0-1 array of shape ``(H, W)``; leading singleton
            dimensions such as ``(1, H, W)`` are also accepted.
        ax: unused.
        random_color: if True, use a random RGB colour drawn from NumPy's global
            RNG instead of the default blue. Alpha is 0.6 in both cases.

    Returns:
        ``(H, W, 4)`` float array equal to the RGBA colour where ``mask`` is 1/True
        and zeros elsewhere.
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    # Use the last two axes so (1, H, W) masks work as well as (H, W).
    h, w = mask.shape[-2:]
    return mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

def combine(image: "Image.Image", semantic_mask: "Image.Image") -> "Image.Image":
    """Render ``image`` with ``semantic_mask`` overlaid in translucent blue.

    The figure is drawn off-screen on an Agg canvas created directly from a
    ``Figure`` (not through pyplot) and rasterised back into a PIL image.

    Args:
        image: background image (anything ``Axes.imshow`` accepts).
        semantic_mask: mask image; for multi-channel masks, any non-zero pixel
            in the first channel is treated as foreground.

    Returns:
        RGB PIL image of the rendered figure with axes hidden. Its size is the
        figure's pixel size, not the size of ``image``.
    """
    fig = Figure()
    canvas = FigureCanvas(fig)
    ax = fig.gca()
    ax.imshow(image)
    overlay = show_mask(image2bitmap(semantic_mask, dtype=np.bool_), ax)
    ax.imshow(overlay)
    ax.axis('off')
    canvas.draw()
    # FigureCanvasAgg.tostring_rgb() was deprecated in Matplotlib 3.8 and removed
    # in 3.10; buffer_rgba() is the supported accessor. Taking the shape from the
    # buffer itself also avoids off-by-one errors from int()-truncating figsize * dpi.
    rgba = np.asarray(canvas.buffer_rgba())
    rgb = np.ascontiguousarray(rgba[:, :, :3])
    return Image.fromarray(rgb, 'RGB')


### Get ground truth for dirt paths

In [138]:
release2 = client.get_release('mcummins/Hike2', 'v0.2')
dataset2 = SegmentsDataset(release2, labelset='ground-truth', filter_by=['reviewed'])
# Names of the Hike2 samples that receive a labelled mask here; the SAM cell
# below skips any sample whose name is in this list.
road_dir = []

# PIL's save() does not create parent folders; make sure they exist.
os.makedirs('GT_masks', exist_ok=True)
os.makedirs('GT', exist_ok=True)

for sample in tqdm(dataset2):
    name = sample['name']
    semantic_bitmap = get_semantic_bitmap(sample['segmentation_bitmap'], sample['annotations'])
    road_dir.append(name)
    mask = bitmap2image(semantic_bitmap)
    # Files keep the sample's own name while PNG encoding is forced via `format`
    # (presumably so masks can be matched back to images by name — confirm).
    mask.save(os.path.join('GT_masks', name), format='png')
    illustration = combine(sample['image'], mask)
    illustration.save(os.path.join('GT', name), format='png')


Initializing dataset...
Preloading all samples. This may take a while...


100%|██████████| 82/82 [00:00<00:00, 325.62it/s]


Initialized dataset with 82 images.


100%|██████████| 82/82 [00:32<00:00,  2.50it/s]


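### Get ground truth for forest floors and concrete roads using SAM (prompted with each sample's annotated points; samples already covered by the Hike2 dirt-path masks are skipped)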

In [140]:
# Output folders must exist before listing/saving (this cell may be run on a
# fresh checkout where the cell above has not created them).
os.makedirs('GT_masks', exist_ok=True)
os.makedirs('GT', exist_ok=True)
gt = os.listdir('GT_masks')  # NOTE(review): not used in this cell — confirm later cells need it

release = client.get_release('mcummins/Hike', 'V0.5')
dataset = SegmentsDataset(release, labelset='ground-truth', filter_by=['reviewed'])

# Hidden-state dependency: `road_dir` is built by the Hike2 cell above, which
# must run first. A set makes the per-sample membership test O(1).
road_names = set(road_dir)

for sample in tqdm(dataset):
    name = sample['name']
    if name in road_names:
        # Already has a labelled dirt-path mask from Hike2; don't overwrite it.
        continue

    image = sample['image']
    # SAM takes an (N, 2) array of (x, y) point prompts. Flatten every
    # annotation's points rather than assuming exactly one point per annotation.
    points = np.asarray(
        [pt for ann in sample['annotations'] for pt in ann['points']],
        dtype=np.float32,
    ).reshape(-1, 2)
    if len(points) == 0:
        tqdm.write(f"Skipping {name}: no point annotations to prompt SAM with")
        continue

    # SAM expects an HxWx3 uint8 RGB array; convert in case the source is RGBA
    # or grayscale (assumes sample['image'] is a PIL image — verify).
    predictor.set_image(np.asarray(image.convert('RGB')))
    input_point = points
    input_label = np.ones(len(points))  # label 1 = foreground prompt

    masks, _, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False,
    )

    mask = bitmap2image(masks[0])
    mask.save(os.path.join('GT_masks', name), format='png')
    illustration = combine(image, mask)
    illustration.save(os.path.join('GT', name), format='png')


Initializing dataset...
Preloading all samples. This may take a while...


100%|██████████| 205/205 [00:11<00:00, 17.58it/s]


Initialized dataset with 205 images.


100%|██████████| 205/205 [01:40<00:00,  2.03it/s]
