In [2]:
import os

# Dataset root folders, relative to the notebook's working directory.
ESCA_ROOT = 'ESCA_dataset'
CWFID_ROOT = 'CWFID_dataset'

# ESCA dataset: one folder per class ('esca', 'healthy'), each holding the
# sub-folders listed below.
ESCA_dataset = {
    'esca': {
        'folder': os.path.join(ESCA_ROOT, 'esca'),
        'esca_foliage_over_healthy_bg': os.path.join(ESCA_ROOT, 'esca', 'esca_foliage_over_healthy_bg'),
        'masks': os.path.join(ESCA_ROOT, 'esca', 'masks'),
        'pictures': os.path.join(ESCA_ROOT, 'esca', 'pictures'),
        'SAM_masks': os.path.join(ESCA_ROOT, 'esca', 'SAM_masks')
    },
    'healthy': {
        'folder': os.path.join(ESCA_ROOT, 'healthy'),
        # Kept under the 'healthy' class folder like every other 'healthy' entry
        # (was copy-pasted as ESCA_dataset/esca/...).
        # NOTE(review): confirm this matches the on-disk layout.
        'healthy_foliage_over_esca_bg': os.path.join(ESCA_ROOT, 'healthy', 'healthy_foliage_over_esca_bg'),
        'masks': os.path.join(ESCA_ROOT, 'healthy', 'masks'),
        'pictures': os.path.join(ESCA_ROOT, 'healthy', 'pictures'),
        'SAM_masks': os.path.join(ESCA_ROOT, 'healthy', 'SAM_masks')
    }
}

# CWFID dataset: flat layout with one folder per artefact type.
CWFID_dataset = {
    'annotations': os.path.join(CWFID_ROOT, 'annotations'),
    'images': os.path.join(CWFID_ROOT, 'images'),
    'masks': os.path.join(CWFID_ROOT, 'masks'),
    'SAM_masks': os.path.join(CWFID_ROOT, 'SAM_masks')
}

In [7]:
# First experiment: run batch_segmentation.py on the small sample folder, on a CPU-only machine.
import subprocess
import sys

prms = {
    'checkpoint': 'sam_vit_h_4b8939.pth',
    'model-type': 'vit_h',
    'input': os.path.join('batch_segmentation_experiment', 'images'),
    'output': os.path.join('batch_segmentation_experiment', 'generated_masks'),
    'device': 'cpu'
}

# Build argv as a list: each path is one argument, so spaces in paths are not word-split.
cli_args = []
for flag in ('checkpoint', 'model-type', 'input', 'output', 'device'):
    cli_args += [f'--{flag}', prms[flag]]
# Space-joined form (same text as the old shell string); for display only, not shell-safe.
arg_string = " ".join(cli_args)

# sys.executable is the kernel's own interpreter; a bare `python` on PATH may be another env.
# check=True raises CalledProcessError on a non-zero exit instead of failing silently.
subprocess.run([sys.executable, 'batch_segmentation.py', *cli_args], check=True)

Loading model...
Processing 'batch_segmentation_experiment/images/esca_000_cam1.jpg'...
Processing 'batch_segmentation_experiment/images/esca_001_cam1.jpg'...
Done!


In [None]:
# Full run over the ESCA 'esca' class pictures; writes SAM masks to its SAM_masks folder.
# Uses device 'cuda', so a CUDA-capable GPU is required.
import subprocess
import sys

prms = {
    'checkpoint': 'sam_vit_h_4b8939.pth',
    'model-type': 'vit_h',
    'input': ESCA_dataset['esca']['pictures'],
    'output': ESCA_dataset['esca']['SAM_masks'],
    'device': 'cuda'
}

# Build argv as a list: each path is one argument, so spaces in paths are not word-split.
cli_args = []
for flag in ('checkpoint', 'model-type', 'input', 'output', 'device'):
    cli_args += [f'--{flag}', prms[flag]]
# Space-joined form (same text as the old shell string); for display only, not shell-safe.
arg_string = " ".join(cli_args)

# Run with the kernel's interpreter; check=True surfaces a failed run as an exception.
subprocess.run([sys.executable, 'batch_segmentation.py', *cli_args], check=True)


In [None]:
# Timed run over the CWFID images on CPU; writes SAM masks to CWFID_dataset['SAM_masks'].
import subprocess
import sys
import time

prms = {
    'checkpoint': 'sam_vit_h_4b8939.pth',
    'model-type': 'vit_h',
    'input': CWFID_dataset['images'],
    'output': CWFID_dataset['SAM_masks'],
    'device': 'cpu'
}

# Build argv as a list: each path is one argument, so spaces in paths are not word-split.
cli_args = []
for flag in ('checkpoint', 'model-type', 'input', 'output', 'device'):
    cli_args += [f'--{flag}', prms[flag]]
# Space-joined form (same text as the old shell string); for display only, not shell-safe.
arg_string = " ".join(cli_args)

st = time.perf_counter()  # monotonic high-resolution clock, suited to measuring durations
try:
    # Run with the kernel's interpreter; check=True surfaces a failed run as an exception.
    subprocess.run([sys.executable, 'batch_segmentation.py', *cli_args], check=True)
finally:
    # Report the elapsed time even if the run fails or is interrupted (Kernel > Interrupt).
    et = time.perf_counter()
    elapsed_time = et - st
    print('Execution time:', elapsed_time, 'seconds')

Loading model...
Processing 'CWFID_dataset/images/059_image.png'...
Processing 'CWFID_dataset/images/023_image.png'...
Processing 'CWFID_dataset/images/037_image.png'...
Processing 'CWFID_dataset/images/016_image.png'...
Processing 'CWFID_dataset/images/002_image.png'...
Processing 'CWFID_dataset/images/054_image.png'...
Processing 'CWFID_dataset/images/040_image.png'...
Processing 'CWFID_dataset/images/010_image.png'...
Processing 'CWFID_dataset/images/004_image.png'...
Processing 'CWFID_dataset/images/025_image.png'...
Processing 'CWFID_dataset/images/031_image.png'...


In [7]:
import json
import numpy as np
import matplotlib.pyplot as plt
import torch
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import os

# Single-image inspection: run SAM's automatic mask generator on one picture, print the
# metadata of the first few masks, and count how many exceed a stability threshold.

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device = "cpu"  # torch device string, e.g. "cpu", "cuda", "mps"

# Image to inspect, set explicitly. prms['input'] from the earlier cells is a *directory*,
# which plt.imread cannot open, so this cell no longer depends on that hidden state.
img_file = os.path.join('batch_segmentation_experiment', 'images', 'esca_000_cam1.jpg')
MAX = 3                    # number of masks whose metadata is printed in full
STABILITY_THRESHOLD = 0.8  # masks with stability_score above this are counted as stable

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)

img = plt.imread(img_file)
# SamAutomaticMaskGenerator.generate expects an HWC uint8 RGB array. plt.imread returns
# uint8 for JPEG, but float in [0, 1] (possibly RGBA) for PNG, so normalise first.
if np.issubdtype(img.dtype, np.floating):
    img = (np.clip(img, 0.0, 1.0) * 255).round().astype(np.uint8)
if img.ndim == 2:
    img = np.stack([img] * 3, axis=-1)
elif img.shape[-1] == 4:
    img = img[..., :3]
masks = mask_generator.generate(img)

print(f"SAM detected {len(masks)} masks on image {img_file}")
# Mask record keys: 'segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box'.
# 'segmentation' is omitted here: it is the full mask array and not JSON-serialisable.
REPORTED_KEYS = ('area', 'bbox', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box')
for count, mask in enumerate(masks[:MAX]):
    # Only the first MAX masks get a header and a JSON dump; the rest are just counted below.
    print(f" -------- mask {count} ---------")
    mask_relevant_data = {key: mask[key] for key in REPORTED_KEYS}
    print(json.dumps(mask_relevant_data, indent=2))

stable_masks = sum(mask['stability_score'] > STABILITY_THRESHOLD for mask in masks)
print(f"SAM detected {stable_masks} stable masks on image {img_file}")

SAM detected 173 masks on image batch_segmentation_experiment/images/esca_000_cam1.jpg
 -------- mask 0 ---------
{
  "area": 2738,
  "bbox": [
    0,
    0,
    102,
    36
  ],
  "predicted_iou": 1.0017647743225098,
  "point_coords": [
    [
      100.0,
      11.25
    ]
  ],
  "stability_score": 0.984425961971283,
  "crop_box": [
    0,
    0,
    1280,
    720
  ]
}
 -------- mask 1 ---------
{
  "area": 25221,
  "bbox": [
    967,
    591,
    290,
    127
  ],
  "predicted_iou": 0.9952616691589355,
  "point_coords": [
    [
      1220.0,
      663.75
    ]
  ],
  "stability_score": 0.9756488800048828,
  "crop_box": [
    0,
    0,
    1280,
    720
  ]
}
 -------- mask 2 ---------
{
  "area": 22142,
  "bbox": [
    779,
    101,
    219,
    142
  ],
  "predicted_iou": 0.990822970867157,
  "point_coords": [
    [
      860.0,
      101.25
    ]
  ],
  "stability_score": 0.9734296202659607,
  "crop_box": [
    0,
    0,
    1280,
    720
  ]
}
 -------- mask 3 ---------
 --------