In [None]:
from pathlib import Path

import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch

from PIL import Image

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

from utils import plot_mask_and_label

In [None]:
# Configuration: where images are read from, where results are written, and
# how image / annotation files are named.

# Root folders of the input images and of the generated output.
input_dir = Path("../../data/20240813_data/VID_05_GP_50/")
output_dir = Path("../../data/20240813_data_output-sam2/VID_05_GP_50/")

# File naming conventions.
annotation_folder = "Annotations"  # sub-folder, next to the images, holding the annotation files
image_suffix = ".jpg"  # Change this to match your image suffix
annotation_suffix = ".txt"  # Change this to match your annotation suffix

# Sanity check: prints True when the input directory can be found.
print(input_dir.exists())

# Alternative input: one CSV annotation file (currently disabled).
#annotation_file = Path(
#    "../../data/training_data/1/Annotations/14717-training-images-1.csv"
#)
#print(annotation_file.exists())

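The cell below loads a **SAM 2** checkpoint (`sam2_hiera_large.pt`, expected in `../models/`). Download the large SAM 2 model's weights from here:
https://dl.fbaipublicfiles.com/segment_anything_2/072824/sam2_hiera_large.pt

For all available SAM 2 models see here: https://github.com/facebookresearch/sam2

Note: the weights of the original (version 1) large SAM model, https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth, and the version 1 model list at https://github.com/facebookresearch/segment-anything?tab=readme-ov-file#model-checkpoints are a different model family and are not what the cell below loads.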

In [None]:
from hydra import core, initialize

# large sam2: works on gpu > 8g
config_dir = "../models/"
model_cfg = "sam2_hiera_l.yaml"
sam2_checkpoint = "../models/sam2_hiera_large.pt"

# base sam2: smaller version
#sam2_checkpoint = "../../SAM2_models/checkpoints/sam2_hiera_base_plus.pt"
#model_cfg = "sam2_hiera_b+.yaml"

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Reset Hydra's global instance before initialising it again, so that this
# cell can be executed more than once in the same kernel.
core.global_hydra.GlobalHydra.instance().clear()

# Build the model with Hydra looking for `model_cfg` inside `config_dir`.
with initialize(version_base=None, config_path=config_dir):
    sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)

# Image predictor wrapping the model; used by process_file below.
predictor = SAM2ImagePredictor(sam2_model)

In [None]:
def process_file(dir_path, filename, annotation_tag="_3840x2160"):
    """Segment every annotated point of one image with SAM 2 and plot each mask.

    The annotation file is looked up at
    ``input_dir/dir_path/annotation_folder/<image stem><annotation_tag><annotation_suffix>``
    and read as a tab-separated table without a header row; its columns are
    named ``x``, ``y`` and ``label_name``. Every row is used as one positive
    point prompt. The resulting mask is handed to ``plot_mask_and_label``
    (called with ``saveit=True``) together with the target path
    ``output_dir/dir_path/<image stem>_<row number><image_suffix>``.

    Relies on the notebook-level names ``input_dir``, ``output_dir``,
    ``image_suffix``, ``annotation_suffix``, ``annotation_folder``,
    ``predictor`` and ``plot_mask_and_label``.

    Args:
        dir_path: Folder of the image relative to ``input_dir`` ("" for the root).
        filename: File name of the image inside that folder.
        annotation_tag: Text placed between the image stem and
            ``annotation_suffix`` in the annotation file name. Defaults to
            the previously hard-coded ``"_3840x2160"``.

    Returns:
        None. An image whose annotation file is missing or contains no rows
        is skipped with a printed message instead of aborting the run.
    """
    in_file = os.path.join(input_dir, dir_path, filename)
    out_folder = os.path.join(output_dir, dir_path)

    # Strip the suffix from the END of the name only; splitting on the suffix
    # would cut the name at its first occurrence (e.g. "a.jpg_crop.jpg" -> "a").
    basename = os.path.basename(in_file)
    if image_suffix and basename.endswith(image_suffix):
        basename = basename[: -len(image_suffix)]

    annotation_file = os.path.join(
        input_dir,
        dir_path,
        annotation_folder,
        basename + annotation_tag + annotation_suffix,
    )
    print(annotation_file)
    print(in_file)

    # Load the annotation file related to the image. A missing or empty file
    # must not stop the traversal of the remaining images.
    try:
        df_annotations = pd.read_csv(
            annotation_file,
            delimiter="\t",
            header=None,
            names=["x", "y", "label_name"],
        )
    except FileNotFoundError:
        print(f"no annotation file found for {in_file}, skipping")
        return
    except pd.errors.EmptyDataError:
        print(f"annotation file {annotation_file} is empty, skipping")
        return

    if df_annotations.empty:
        # Nothing to segment: avoid computing the image embedding for nothing.
        print(f"annotation file {annotation_file} contains no points, skipping")
        return

    # If output folder does not exist, create one
    os.makedirs(out_folder, exist_ok=True)

    # Load the image. The context manager closes the file handle, and the
    # conversion guarantees a 3-channel RGB array even for grayscale, CMYK or
    # RGBA inputs.
    with Image.open(in_file) as pil_image:
        test_image = np.array(pil_image.convert("RGB"))

    # One (x, y) point prompt per annotation row, all of them positive.
    point_prompts = df_annotations[["x", "y"]].to_numpy(dtype=np.float32)
    prompt_labels = np.ones(len(point_prompts))  # positive prompt
    species = np.array(df_annotations["label_name"].tolist())

    # get the SAM predictions for each point
    print(f"getting predictions for {len(point_prompts)} point prompts...")
    predictor.set_image(test_image)
    for i in range(len(point_prompts)):
        out_file = os.path.join(out_folder, basename + "_" + str(i) + image_suffix)

        masks, _scores, _logits = predictor.predict(
            point_coords=point_prompts[i : i + 1],
            point_labels=prompt_labels[i : i + 1],
            multimask_output=False,
        )

        # Plot right away instead of collecting every mask first, so that only
        # one mask is held at a time.
        # NOTE(review): plot_mask_and_label lives in utils; confirm that it
        # closes its figure after saving, otherwise figures pile up here.
        plot_mask_and_label(
            masks,
            species[i : i + 1],
            out_file,
            input_point=point_prompts[i],
            input_label=prompt_labels[i],
            image=test_image,
            saveit=True,
        )

Recursively walk through the folder tree, starting at `input_dir`. Every file whose name ends with `image_suffix` (".jpg" in the configuration above) is passed to `process_file`.

In [None]:
def process_folder_aux():
    """Start the recursive traversal at the root of ``input_dir``."""
    # The empty relative path makes process_folder begin in input_dir itself.
    root_relative_path = ""
    process_folder(root_relative_path)

def process_folder(dir_path):
    """Recursively process every image below ``input_dir/dir_path``.

    Sub-folders are mirrored under ``output_dir`` and then visited in turn.
    Each file whose name ends with ``image_suffix`` is passed to
    ``process_file``. Entries are visited in sorted name order so that runs
    are reproducible. Folders named ``annotation_folder`` are not descended
    into and get no mirror directory in the output tree; ``process_file``
    reads annotation files from them directly.

    Relies on the notebook-level names ``input_dir``, ``output_dir``,
    ``image_suffix`` and ``annotation_folder``.

    Args:
        dir_path: Folder to process, relative to ``input_dir`` ("" for the root).
    """
    full_dir = os.path.join(input_dir, dir_path)
    full_out_dir = os.path.join(output_dir, dir_path)

    # os.listdir returns entries in arbitrary order; sort for determinism.
    for item in sorted(os.listdir(full_dir)):
        item_path = os.path.join(full_dir, item)

        if os.path.isdir(item_path):
            if item == annotation_folder:
                # Skip: otherwise an empty copy of it appears in the output.
                continue
            os.makedirs(os.path.join(full_out_dir, item), exist_ok=True)
            process_folder(os.path.join(dir_path, item))
        elif item.endswith(image_suffix):
            process_file(dir_path, item)


# Run the process: start the recursive traversal at the root of input_dir.
process_folder_aux()