In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import os
import glob
import json
import shutil

In [2]:
# Work from the segment-anything `notebooks` folder: the SAM checkpoint
# ("sam_vit_h_4b8939.pth"), `sys.path.append("..")` and the webDemoInput /
# webDemoOutput folders used below are all resolved relative to it.
# Set the SAM_NOTEBOOK_DIR environment variable to run on another machine.
NOTEBOOK_DIR = os.environ.get(
    "SAM_NOTEBOOK_DIR",
    '/home/miner/Desktop/work/mine/repository/segment-anything/notebooks',
)
os.chdir(NOTEBOOK_DIR)

In [3]:
def show_anns(anns, alpha=0.35, rng=None):
    """Overlay each annotation's mask on the current matplotlib axes in a random colour.

    Masks are drawn from largest to smallest ``'area'``, so smaller masks are
    painted last and stay visible on top of larger ones.

    Args:
        anns: list of annotation dicts, each with an ``'area'`` sort key and a
            2-D ``'segmentation'`` mask (presumably boolean and image-sized, as
            produced by SAM's automatic mask generator; confirm for other sources).
        alpha: opacity applied to mask pixels (default 0.35).
        rng: optional ``numpy.random.Generator`` for reproducible colours; when
            omitted the global ``np.random`` state is used.

    Returns:
        None. Draws onto ``plt.gca()``; an empty ``anns`` is a no-op.
    """
    if len(anns) == 0:
        return
    random_source = np.random if rng is None else rng
    ax = plt.gca()
    # Freeze the current limits so the overlays don't rescale the axes.
    ax.set_autoscale_on(False)
    for ann in sorted(anns, key=lambda a: a['area'], reverse=True):
        mask = ann['segmentation']
        # RGBA layer: one random colour everywhere, visible only where the mask is set.
        overlay = np.empty((mask.shape[0], mask.shape[1], 4))
        overlay[:, :, :3] = random_source.random(3)
        overlay[:, :, 3] = mask * alpha
        ax.imshow(overlay)

In [4]:
def save_anns_images(anns, image, prefix, crop_key='crop_box'):
    """Save one masked PNG crop and one JSON metadata file per annotation.

    Annotations are processed from largest to smallest ``'area'``; ``idx`` in
    the output names is the position in that order:
    ``{prefix}_{idx}.png`` and ``{prefix}_meta_{idx}.json``.

    Args:
        anns: annotation dicts (SAM automatic-mask-generator records); each needs
            'segmentation', 'area', 'bbox', 'predicted_iou', 'point_coords',
            'stability_score' and 'crop_box'.
        image: the RGB image the masks were generated from (it is converted
            RGB->BGR for OpenCV before writing).
        prefix: path prefix for the output files; its directory is created if
            missing.
        crop_key: annotation key holding the XYWH box to crop to. The default
            'crop_box' keeps the previous behaviour; per SAM's documentation
            that is the image crop used to *generate* the mask (the whole image
            with default settings), so pass 'bbox' to crop tightly around
            each mask instead.

    Raises:
        IOError: if OpenCV fails to write a PNG (``cv2.imwrite`` returns False
            instead of raising).
    """
    if len(anns) == 0:
        return

    def _json_default(obj):
        # Tolerate numpy scalars/arrays in annotation values instead of aborting
        # half-way through the loop with a TypeError.
        if isinstance(obj, np.generic):
            return obj.item()
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')

    out_dir = os.path.dirname(prefix)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    metadata_keys = ('bbox', 'area', 'predicted_iou', 'point_coords', 'stability_score', 'crop_box')
    sorted_anns = sorted(anns, key=lambda a: a['area'], reverse=True)
    for idx, ann in enumerate(sorted_anns):
        mask = ann['segmentation']
        # Zero every pixel outside the mask (mask broadcast over the channels),
        # keeping the image's dtype so OpenCV can encode it.
        masked = (image * mask[:, :, None]).astype(image.dtype, copy=False)

        # XYWH box -> integer slice bounds (float boxes would break slicing).
        x, y, w, h = (int(v) for v in ann[crop_key])
        cropped = masked[y:y + h, x:x + w]

        png_path = f"{prefix}_{idx}.png"
        if not cv2.imwrite(png_path, cv2.cvtColor(cropped, cv2.COLOR_RGB2BGR)):
            raise IOError(f"cv2.imwrite failed for {png_path!r}")

        metadata = {key: ann[key] for key in metadata_keys}
        with open(f"{prefix}_meta_{idx}.json", 'w', encoding='utf-8') as json_file:
            json.dump(metadata, json_file, default=_json_default)

In [5]:
def save_ocr_results(results, prefix):
    """Write OCR results to ``{prefix}_ocr.json`` as a list of records.

    Args:
        results: iterable of ``(bounding_box, text, confidence)`` triples, the
            shape returned by ``ocr_reader.readtext`` in this notebook
            (bounding_box is a list of [x, y] corner points). Values may be
            numpy scalars or arrays (e.g. np.int64 coordinates).
        prefix: path prefix for the output file.
    """
    filename = f"{prefix}_ocr.json"

    json_result = [
        {"bounding_box": coords, "text": text, "confidence": confidence}
        for coords, text, confidence in results
    ]

    def handle_non_serializable(obj):
        # json cannot encode numpy types natively; accept every numpy integer /
        # float width (not just int64) plus whole arrays.
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')

    # ensure_ascii=False writes non-ASCII text verbatim, so pin UTF-8 instead of
    # relying on the platform's default file encoding.
    with open(filename, "w", encoding="utf-8") as f:
        json.dump(json_result, f, default=handle_non_serializable, ensure_ascii=False, indent=4)

In [6]:


def _list_png_basenames(folder_path):
    """Return the names of the ``*.png`` entries directly inside ``folder_path``.

    Globs on the (escaped) folder path instead of temporarily ``os.chdir``-ing
    into it, so the process-wide working directory is never changed; the old
    approach could leave the notebook in the wrong folder if interrupted.
    Order is whatever the filesystem returns.

    Raises:
        FileNotFoundError: if ``folder_path`` is not an existing directory.
    """
    if not os.path.isdir(folder_path):
        raise FileNotFoundError(f"PNG folder not found: {folder_path!r}")
    # Escape the folder so characters like '[' or '*' in its name match literally.
    pattern = os.path.join(glob.escape(folder_path), "*.png")
    return [os.path.basename(path) for path in glob.glob(pattern)]


def get_png_filepaths(folder_path):
    """Return the bare file names (not paths, despite the name) of the PNGs in ``folder_path``."""
    return _list_png_basenames(folder_path)


def get_png_filenames(folder_path):
    """Return ``folder_path + '/' + name`` for each PNG in ``folder_path`` (paths, despite the name)."""
    return [folder_path + '/' + name for name in _list_png_basenames(folder_path)]

In [7]:
def process_image(image_file, prefix, generator=None, reader=None):
    """Segment an image with SAM, save and display the masks, then OCR it.

    Side effects: prints the mask count and the raw OCR result, shows a
    matplotlib figure with the mask overlays, writes per-mask files via
    ``save_anns_images`` and ``{prefix}_ocr.json`` via ``save_ocr_results``.

    Args:
        image_file: path to the input image.
        prefix: output path prefix passed to the save helpers.
        generator: object with ``generate(rgb_image)`` returning SAM annotation
            dicts; defaults to the notebook-global ``mask_generator``.
        reader: object with ``readtext(path)``; defaults to the notebook-global
            ``ocr_reader``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_file``.
    """
    # Resolve the globals lazily so the function can be defined before the
    # model-loading cells run.
    if generator is None:
        generator = mask_generator
    if reader is None:
        reader = ocr_reader

    # Segment Anything
    bgr_image = cv2.imread(image_file)
    if bgr_image is None:
        # cv2.imread signals failure by returning None rather than raising.
        raise FileNotFoundError(f"Could not read image: {image_file!r}")
    image = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2RGB)

    masks = generator.generate(image)
    print("masks len: " + str(len(masks)))

    plt.figure(figsize=(20, 200))
    plt.imshow(image)
    save_anns_images(masks, image, prefix)
    show_anns(masks)
    plt.show()

    # OCR: the reader is given the file path, not the in-memory array.
    result = reader.readtext(image_file)
    print(result)
    save_ocr_results(result, prefix)



In [8]:
def list_files_in_current_folder():
    """Return the names of every entry (files and sub-folders) in the working directory."""
    return os.listdir(os.getcwd())


files_list = list_files_in_current_folder()
print(files_list)





['predictor_example.ipynb', 'sam_onnx_quantized_example.onnx', 'sam_vit_l_0b3195.pth', 'onnx_model_example.ipynb', '.ipynb_checkpoints', 'sam_vit_h_4b8939.pth', 'webDemoInput', 'images', 'webDemoOutput', 'sam_onnx_example.onnx', 'sam_vit_b_01ec64.pth', 'automatic_mask_generator_example.ipynb', 'web_demo.ipynb']


In [9]:
def clear_folder(folder_path):
    """Delete the regular files directly inside ``folder_path``.

    Sub-folders and their contents are left untouched; use
    ``clear_folder_recursively`` to wipe everything.
    """
    for entry in os.listdir(folder_path):
        entry_path = os.path.join(folder_path, entry)
        if os.path.isfile(entry_path):
            os.remove(entry_path)


def clear_folder_recursively(folder_path):
    """Leave ``folder_path`` as an empty, existing directory.

    Everything inside it is deleted. If the folder does not exist yet it is
    created, so callers can always rely on an empty folder afterwards.
    """
    if os.path.exists(folder_path):
        shutil.rmtree(folder_path)
    os.makedirs(folder_path, exist_ok=True)

In [10]:
# Load the Segment Anything (SAM) model and build the automatic mask generators.
# `segment_anything` is imported from the parent of the working directory
# (the segment-anything repository that contains this notebooks folder).

import sys
if ".." not in sys.path:  # avoid piling up duplicate entries when the cell is re-run
    sys.path.append("..")
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"

# Fall back to the CPU when no CUDA device is available instead of failing in sam.to().
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Default settings; process_image uses this `mask_generator` global.
mask_generator = SamAutomaticMaskGenerator(sam)

# Alternative generator with custom settings (not used by process_image).
mask_generator_2 = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.92,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,  # Requires open-cv to run post-processing
)

In [11]:
# Load the EasyOCR text reader. This needs to run only once: constructing the
# Reader loads the model into memory, and later cells reuse `ocr_reader`.
import easyocr

OCR_LANGUAGES = ['en']  # language codes passed to easyocr.Reader
ocr_reader = easyocr.Reader(OCR_LANGUAGES)

    There is an imbalance between your GPUs. You may want to exclude GPU 1 which
    has less than 75% of the memory or cores of GPU 0. You can do so by setting
    the device_ids argument to DataParallel, or by setting the CUDA_VISIBLE_DEVICES
    environment variable.


In [12]:
# Input screenshots and output folder, both relative to the working directory.
input_path = "webDemoInput"
output_path = "webDemoOutput"

png_files_array = get_png_filenames(input_path)
print(png_files_array)

# Wipe the output folder left by a previous run.
clear_folder_recursively(output_path)

['webDemoInput/www.babiesrus.ca_en_Graco-Pack--n-Play-Day2Dream-Bassinet-Playard--Rainier_9405689D.html(Pixel 5).png', 'webDemoInput/www.babiesrus.ca_en_babiesrus_Category_Baby-Gear_Bassinets(Pixel 5) (1).png', 'webDemoInput/www.babiesrus.ca_en_home(Pixel 5) (1).png', 'webDemoInput/www.babiesrus.ca_en_cart(Pixel 5).png', 'webDemoInput/www.babiesrus.ca_en_home(Pixel 5).png', 'webDemoInput/www.babiesrus.ca_en_babiesrus_Category_Baby-Gear_Bassinets(Pixel 5).png', 'webDemoInput/www.babiesrus.ca_en_Safety-1st-Amherst-Bassinet_54095B36.html(Pixel 5).png']


In [13]:
# Main logic: segment + OCR the selected screenshots.
# Indices into png_files_array (order as returned by get_png_filenames);
# use range(len(png_files_array)) to process every screenshot.
selected_indices = [1]

for image_index in selected_indices:
    image_path = png_files_array[image_index]
    # Entries already start with input_path, so join on the bare file name to
    # avoid a doubled "webDemoInput/webDemoInput" output folder.
    image_output_dir = os.path.join(output_path, input_path, os.path.basename(image_path))
    # exist_ok keeps the cell re-runnable without clearing the output first.
    os.makedirs(image_output_dir, exist_ok=True)
    output_prefix = os.path.join(image_output_dir, 'anns')
    process_image(image_path, output_prefix)

[([[941, 250], [1017, 250], [1017, 338], [941, 338]], 'P', 0.246849482314361), ([[276, 558], [542, 558], [542, 619], [276, 619]], 'something fur', 0.9812687630681373), ([[334, 833], [715, 833], [715, 971], [334, 971]], 'FREE', 0.9492940306663513), ([[254, 946], [871, 946], [871, 1114], [254, 1114]], 'HIPPING', 0.4682771073035096), ([[262, 1120], [786, 1120], [786, 1196], [262, 1196]], 'ONBABY GEAR', 0.9910269048260235), ([[419, 1225], [643, 1225], [643, 1270], [419, 1270]], '1 LEARN MORE', 0.45099549830079916), ([[31, 1593], [403, 1593], [403, 1683], [31, 1683]], 'Bassinets', 0.8037108668533746), ([[33, 1710], [581, 1710], [581, 1771], [33, 1771]], 'Showing 24 of 36 products', 0.8971317783250488), ([[187, 1940], [325, 1940], [325, 2001], [187, 2001]], 'Selling', 0.8492228303310392), ([[457, 2109], [617, 2109], [617, 2153], [457, 2153]], 'FILTER', 0.9999835018602554), ([[61, 2902], [195, 2902], [195, 2947], [61, 2947]], 'Graco', 0.9999827268895326), ([[765, 2902], [937, 2902], [937, 294