In [10]:
## Experiment: turn VGG Image Annotator (VIA) annotations into prompts for SAM (Segment Anything)
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import pandas as pd
import json
from skimage.measure import label, regionprops
import glob
import os
from PIL import Image
import torch

import sys

# Make the directory one level above the notebook's working directory importable
# (the working directory is absolute, so normpath(join(cwd, "..")) is its absolute parent).
parent_dir = os.path.normpath(os.path.join(os.getcwd(), os.pardir))
sys.path.append(parent_dir)



In [12]:
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent RGBA overlay.

    :param mask: array whose last two dimensions are (height, width); it is reshaped
        to (height, width, 1) and multiplied by the overlay colour.
    :param ax: matplotlib-like axes exposing ``imshow``.
    :param random_color: when True, use a random RGB colour (drawn with
        ``np.random.random``) instead of the default blue; alpha is 0.6 either way.
    """
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    else:
        rgba = np.append(np.random.random(3), 0.6)
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
    
def show_points(coords, labels, ax, marker_size=375):
    """Scatter point prompts on ``ax``: label 1 as green stars, label 0 as red stars.

    :param coords: (N, 2) array of (x, y) coordinates.
    :param labels: length-N array of 0/1 labels aligned with ``coords``.
    :param ax: matplotlib-like axes exposing ``scatter``.
    :param marker_size: marker area passed as ``s``.
    """
    # Positive points are drawn first, then negative ones.
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == wanted_label]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    
def show_box(box, ax):
    """Outline ``box`` = (x0, y0, x1, y1) on ``ax`` with a green, unfilled rectangle."""
    left, top, right, bottom = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle((left, top), right - left, bottom - top, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(outline)


def prepare_image(image, transform):
    """Apply ``transform.apply_image`` and return the result as a contiguous CHW tensor.

    :param image: image array in (H, W, C) layout (the permute below assumes 3 dims).
    :param transform: object exposing ``apply_image(image)``.
    :return: ``torch.Tensor`` of shape (C, H', W').
    """
    resized = torch.as_tensor(transform.apply_image(image))
    # HWC -> CHW; permute returns a view, so make the memory layout contiguous.
    return resized.permute(2, 0, 1).contiguous()



In [13]:

import random

def get_background_points(image, bbox, binary_mask, max_points=5):
    """Sample up to ``max_points`` distinct background pixels, used by this notebook as
    negative point prompts for SAM.

    Sampling order:

    1. one background pixel inside ``bbox``;
    2. one background pixel in the bands that extend the box outwards: columns
       ``[x_min - 1, x_max + 1]`` above/below the box, and rows ``[y_min - 1, y_max + 1]``
       left/right of it;
    3. one pixel whose grayscale value is exactly 0 (not filtered against the mask);
    4. further background pixels, drawn without replacement, until ``max_points``.

    :param image: 2-D grayscale image, or a 3-channel image converted with
        ``cv2.COLOR_BGR2GRAY``.
    :param bbox: (x_min, y_min, x_max, y_max), inclusive.
    :param binary_mask: 2-D mask; pixels equal to 0 are background.
    :param max_points: maximum number of points returned (default 5).
    :return: list of ``(row, col)`` tuples of Python ints (note: row first, i.e. (y, x)).
        Fewer than ``max_points`` are returned only when there are not enough
        background pixels.
    """
    x_min, y_min, x_max, y_max = bbox

    # (row, col) coordinates of every background pixel.
    bg_coords = np.argwhere(binary_mask == 0)
    rows, cols = bg_coords[:, 0], bg_coords[:, 1]

    in_x = (cols >= x_min) & (cols <= x_max)
    in_y = (rows >= y_min) & (rows <= y_max)
    inside_bbox_outside_fg = bg_coords[in_x & in_y]

    # Bands extending from the box (one pixel wider than it) above/below and left/right.
    above_below = (cols >= x_min - 1) & (cols <= x_max + 1) & ~in_y
    left_right = (rows >= y_min - 1) & (rows <= y_max + 1) & ~in_x
    outside_bbox = bg_coords[above_below | left_right]

    if image.ndim == 2:
        grayscale_image = image
    else:
        grayscale_image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    zero_pixel_coords = np.argwhere(grayscale_image == 0)

    def _as_point(coord):
        return (int(coord[0]), int(coord[1]))

    random_points = []
    # Test emptiness with len(): `.any()` would skip a lone zero pixel at (0, 0).
    for candidates in (inside_bbox_outside_fg, outside_bbox, zero_pixel_coords):
        if len(candidates):
            random_points.append(_as_point(random.choice(candidates)))

    # Top up with distinct background pixels that were not already chosen.
    chosen = set(random_points)
    missing = max_points - len(random_points)
    if missing > 0 and len(bg_coords):
        # At most len(chosen) of the drawn indices can hit already-chosen points,
        # so this many draws always yields enough new points when they exist.
        n_draw = min(len(bg_coords), missing + len(chosen))
        for idx in random.sample(range(len(bg_coords)), n_draw):
            if len(random_points) >= max_points:
                break
            point = _as_point(bg_coords[idx])
            if point not in chosen:
                random_points.append(point)
                chosen.add(point)

    return random_points[:max_points]

def reduce_bbox(bbox, percentage, center_point):
    """
    Shrink a bounding box to a fraction of its area and re-centre it on a given point.

    Despite its name, ``percentage`` is an *area fraction* (e.g. 0.2 keeps 20 % of the
    area). Width and height are both scaled by ``sqrt(percentage)``, so the aspect ratio
    is preserved. Zero-width or zero-height boxes are handled: the degenerate side
    simply stays 0.

    :param bbox: The bounding box in the format (xmin, ymin, xmax, ymax).
    :param percentage: Non-negative fraction of the original area to keep.
    :param center_point: The point (x, y) around which to center the reduced bounding box.
    :return: The reduced bounding box (xmin, ymin, xmax, ymax); each coordinate is
        converted with ``int()``, i.e. truncated toward zero.
    :raises ValueError: if ``percentage`` is negative.
    """
    if percentage < 0:
        raise ValueError(f"percentage must be non-negative, got {percentage}")

    x_min, y_min, x_max, y_max = bbox
    center_x, center_y = center_point

    # Scaling both sides by sqrt(p) scales the area by p. This replaces the former
    # sqrt(area * w / h) form, which divided by zero on degenerate boxes.
    scale = percentage ** 0.5
    half_width = (x_max - x_min) * scale / 2
    half_height = (y_max - y_min) * scale / 2

    return (
        int(center_x - half_width),
        int(center_y - half_height),
        int(center_x + half_width),
        int(center_y + half_height),
    )

In [135]:
#image = cv2.imread('C:\\Users\\d42684\\Documents\\STAGE\\CODES\\ACtoolbox-main\\Dataset\\Small_ARIS_Mauzac\\TEST\\All_Originals\\2014-11-16_002000_t8_Obj_frame3065.jpg')
#image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
#2014-11-05_184000_t0_Obj_frame315.jpg
#2014-11-16_002000_t0_Obj_frame508.jpg

In [None]:
# plt.figure(figsize=(10,10))
# plt.imshow(image)
# plt.axis('on')
# plt.show()


In [14]:

# Load the SAM ViT-H model once and wrap it in a SamPredictor used by the cells below.
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor

# Override with the SAM_CHECKPOINT environment variable on another machine.
sam_checkpoint = os.environ.get(
    "SAM_CHECKPOINT",
    r"C:\Users\chapi\Documents\STAGE\CODE\segment-anything-main\sam_vit_h_4b8939.pth",
)
model_type = "vit_h"

# `device` used to be hard-coded to "cuda" but never applied (the .to() call was
# commented out), so inference silently ran on the CPU. Use the GPU when available.
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)



In [25]:
# Generate one SAM mask per annotated image from the VIA-style annotation CSV.
# Paths are those of the original run; adjust them for your machine.
IMAGE_DIR = r'C:\Users\chapi\Documents\STAGE\CODE\bgslibrary\SELECTED ORIGINAL IMAGES'
ANNOTATIONS_CSV = r"C:\Users\chapi\Documents\STAGE\CODE\bgslibrary\ACtoolbox-main\SecondVersion_Annotations_Filled.csv"
OUTPUT_DIR = r'C:\Users\chapi\Documents\STAGE\CODE\bgslibrary\ACtoolbox-main\Dataset\Small_ARIS_Mauzac\TEST\All_Originals\NewMasks3'
# Rows before this index belong to videos whose masks were already generated.
FIRST_ANNOTATION_ROW = 11761
# Area fraction handed to reduce_bbox when building the region for background prompts.
BG_BOX_AREA_FRACTION = 0.2


def _parse_regions(regions_df):
    """Split one image's annotation rows into point prompts, boxes and box class names.

    :return: ``(points, boxes, class_names)``. ``points`` is an (N, 2) array of (cx, cy),
        ``boxes`` is an (M, 4) array of (x0, y0, x1, y1), and ``class_names`` holds the
        ``Object`` attribute of each box. All three are in row order.
    """
    points, boxes, class_names = [], [], []
    for shape_json, attrs_json in zip(regions_df['region_shape_attributes'], regions_df['region_attributes']):
        shape = json.loads(shape_json)
        if shape['name'] == 'point':
            points.append((shape['cx'], shape['cy']))
        elif shape['name'] == 'rect':
            x, y = shape['x'], shape['y']
            boxes.append((x, y, int(x) + int(shape['width']), int(y) + int(shape['height'])))
            class_names.append(json.loads(attrs_json)["Object"])
    return np.array(points).reshape(-1, 2), np.array(boxes).reshape(-1, 4), class_names


def _points_in_box(points, box):
    """Return the (x, y) points lying inside ``box`` (edges inclusive), in input order."""
    x_min, y_min, x_max, y_max = box
    return [(x, y) for x, y in points if x_min <= x <= x_max and y_min <= y <= y_max]


def _segment_box(predictor, image, fg_mask, box, fg_point):
    """Run the two-pass SAM prediction for one object and return its boolean mask.

    Pass 1 asks for several candidate masks. The logits of the best-scoring candidate
    are then fed back as ``mask_input`` to obtain a single refined mask in pass 2.
    """
    bg_points = get_background_points(image, reduce_bbox(tuple(box), BG_BOX_AREA_FRACTION, fg_point), fg_mask)
    # get_background_points returns (row, col); SAM point prompts are (x, y).
    input_points = np.array([fg_point] + [(col, row) for row, col in bg_points])
    # Exactly one foreground label, matching the single foreground point above.
    input_label = np.concatenate((np.ones(1), np.zeros(len(bg_points))))
    box = np.asarray(box)

    _, scores, logits = predictor.predict(
        point_coords=input_points,
        point_labels=input_label,
        box=box,
        multimask_output=True,
    )
    best_logits = logits[np.argmax(scores)]
    mask, _, _ = predictor.predict(
        point_coords=input_points,
        point_labels=input_label,
        box=box,
        mask_input=best_logits[None, :, :],
        multimask_output=False,
    )
    # With multimask_output=False the first axis has length 1.
    return mask[0].astype(bool)


# Images and masks are paired by position, so sort both lists (glob order is arbitrary).
ImageList = sorted(glob.glob(os.path.join(IMAGE_DIR, '*.jpg')))
MaskImageList = sorted(glob.glob(os.path.join(IMAGE_DIR, '*.png')))
if len(ImageList) != len(MaskImageList):
    raise ValueError(
        f"Found {len(ImageList)} images but {len(MaskImageList)} masks in {IMAGE_DIR}; "
        "cannot pair them by position"
    )

df = pd.read_csv(ANNOTATIONS_CSV).iloc[FIRST_ANNOTATION_ROW:]
# Images annotated with zero objects have nothing to segment.
filtered_df0 = df[df['file_attributes'] != '{"Object_Count":"0"}']

os.makedirs(OUTPUT_DIR, exist_ok=True)

for image_path, mask_path in zip(ImageList, MaskImageList):
    imageName = os.path.basename(image_path)
    filtered_df = filtered_df0[filtered_df0['filename'] == imageName]
    if filtered_df.empty:
        continue

    print(imageName, "- Object_Count:", json.loads(filtered_df.iloc[0]['file_attributes'])['Object_Count'])

    Points_array, Bbox_array, classname = _parse_regions(filtered_df)
    if len(Bbox_array) == 0:
        print(f"{imageName}: no rectangle annotation, skipping")
        continue

    MaskImage = cv2.imread(mask_path)
    image = cv2.imread(image_path)
    if MaskImage is None or image is None:
        print(f"{imageName}: could not read image or mask, skipping")
        continue
    MaskImage = cv2.cvtColor(MaskImage, cv2.COLOR_BGR2GRAY)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    predictor.set_image(image)

    # Union of all object masks of this image, as a uint8 0/255 image.
    mixed = np.zeros(image.shape[:2], dtype=np.uint8)
    for bbox in Bbox_array:
        # Predictions are made per object because an image may contain several objects
        # of interest (mostly the small-fish cases). The CSV does not say which point
        # belongs to which box, so a point is assigned to a box by lying inside it.
        SingleObject_Points = _points_in_box(Points_array, bbox)
        if not SingleObject_Points:
            print(f"{imageName}: no point prompt inside box {bbox}, skipping this object")
            continue
        object_mask = _segment_box(predictor, image, MaskImage, bbox, SingleObject_Points[0])
        mixed[object_mask] = 255

    plt.imshow(mixed, cmap="gray")
    plt.show()

    # The file is named after the first box's class even when several objects are merged.
    # NOTE(review): the mask keeps the image's .jpg extension. JPEG compression can leave
    # non-binary values along edges; consider .png if downstream code allows it.
    cv2.imwrite(os.path.join(OUTPUT_DIR, 'm_' + str(classname[0]) + "_" + imageName), mixed)

        #cv2.imwrite(os.path.join(r'C:\Users\d42684\Documents\STAGE\CODES\ACtoolbox-main\Dataset\Small_ARIS_SELUNE\2019-05-02_005000.avi\False Negatives\SIL_1\masks_bbox',str('m_'+imageName)), masks*255)
        #r'C:\Users\d42684\Documents\STAGE\CODES\ACtoolbox-main\Dataset\Small_ARIS_SELUNE\2019-05-02_005000.avi\False Negatives\SIL_1\

        #mask_image.save(os.path.join(r'C:\Users\d42684\Documents\STAGE\CODES\ACtoolbox-main\Dataset\Small_ARIS_Mauzac\TEST\All_Originals\masks_bbox',str('m_'+imageName)))
        # for box in Bbox_array:
        #     show_box(box.numpy(), plt.gca())
        #plt.axis('off')


        #plt.tight_layout()
        #plt.show()
        # masks, scores, logits = predictor.predict(
        #     point_coords=input_point,
        #     point_labels=input_label,
        #     multimask_output=True,   
        # )
    #Objects =  Objects + 1
    #print("wait")