# SAM Postprocessing for DFU2024 Challenge

If running locally with Jupyter, first install `segment-anything-2` in your environment using the [installation instructions](https://github.com/facebookresearch/segment-anything-2#installation) in the repository.

## Set-up

Necessary imports for loading images, running SAM 2, and handling the resulting masks.

In [2]:
import torch
import os
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

In [3]:
# use bfloat16 for the entire notebook: the autocast context is entered by hand
# and never exited, so it stays active for every cell that runs afterwards
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# Inspect CUDA device 0; a compute-capability major version of 8 or more enables TF32 below.
if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

## Data preparation: bounding boxes from `skimage.measure.regionprops` saved in a .csv file

In [None]:
# Custom converter for the tuple
def tuple_converter(s):
    """Parse one bounding-box cell of the CSV into an integer array.

    The cell is expected to look like ``"(a, b, c, d)"``. The four values are
    returned in the order ``[b, a, d, c]``, i.e. each pair is swapped.
    Presumably the CSV stores skimage's ``(min_row, min_col, max_row, max_col)``
    and the swap yields the ``(x_min, y_min, x_max, y_max)`` order used for
    SAM 2 boxes -- confirm against the script that wrote the CSV.

    Parameters
    ----------
    s : bytes or str
        Raw cell content. ``np.genfromtxt`` passes ``bytes`` to converters on
        NumPy < 2.0 (default ``encoding='bytes'``) and ``str`` on NumPy >= 2.0,
        so both are accepted.

    Returns
    -------
    numpy.ndarray
        Integer array of length 4 holding the reordered coordinates.

    Raises
    ------
    IndexError
        If the cell holds fewer than four comma-separated values.
    ValueError
        If a value cannot be converted to ``int``.
    """
    # Only decode when we actually received bytes; calling .decode on a str
    # would raise AttributeError.
    if isinstance(s, (bytes, bytearray)):
        s = s.decode("utf-8")
    # Strip surrounding whitespace, quotes and parentheses in any order
    s = s.strip('"() \t\r\n')
    # Split by commas; int() below tolerates the spaces left around each value
    split_values = s.split(',')
    # Reorder the columns: 2nd, 1st, 4th, 3rd to match SAM2 format
    reordered_values = [split_values[1], split_values[0], split_values[3], split_values[2]]
    # Convert the reordered list to integers and then to a numpy array
    return np.array(list(map(int, reordered_values)))
    
# Load the data: one record per CSV row, with a filename field and a box field.
# Fields are separated by ';', and the first line of the file is skipped as a header.
data_bbox = np.genfromtxt(
    '.../DFU/bbox_coords_test.csv', 
    delimiter=';', 
    skip_header=1,
    dtype=[('filename', 'U50'), ('values', 'O')],  # U50: max 50 char string, O: object
    converters={1: tuple_converter}  # Apply the tuple converter to the second column
)

In [None]:
# Merging all of the bounding boxes for each of the image files to match SAM2 format
import numpy as np

def merge_values_by_filename(data):
    """Group the per-row values by filename into one record per file.

    Parameters
    ----------
    data : iterable
        Items that support ``item['filename']`` and ``item['values']``
        (for example the rows of a structured NumPy array).

    Returns
    -------
    numpy.ndarray
        Structured array with fields ``filename`` (``U50``, so names longer
        than 50 characters are truncated) and ``values`` (object). Each
        ``values`` entry is ``np.array`` of all values collected for that
        filename, in the order they appeared; filenames keep the order of
        their first appearance.
    """
    boxes_by_file = {}
    for row in data:
        # setdefault creates the list the first time a filename is seen
        boxes_by_file.setdefault(row['filename'], []).append(row['values'])

    record_dtype = [('filename', 'U50'), ('values', 'O')]
    records = []
    for name, boxes in boxes_by_file.items():
        # Stack the collected values for this file into a single array
        records.append((name, np.array(boxes)))
    return np.array(records, dtype=record_dtype)

# One record per image file, holding every box loaded for that file
bbox_merged = merge_values_by_filename(data_bbox)

## Selecting objects with SAM 2

First, load the SAM 2 model and predictor. Change the path below to point to the SAM 2 checkpoint. Running on CUDA and using the default model are recommended for best results.

In [9]:
# Relative path to the SAM 2 checkpoint file; change it to wherever the checkpoint was downloaded.
sam2_checkpoint = "../segment-anything-2/checkpoints/sam2_hiera_large.pt"
# Config name passed to build_sam2.
# NOTE(review): presumably this must match the checkpoint variant -- confirm.
model_cfg = "sam2_hiera_l.yaml"

# Build the model on the CUDA device from the config and checkpoint above
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device="cuda")

# Wrap the model in the image predictor used by the mask-generation loop
predictor = SAM2ImagePredictor(sam2_model)

## Mask generation loop

In [25]:
# For every image, prompt SAM 2 with all of its boxes, merge the predicted
# masks into a single binary mask and save it as an 8-bit PNG (0 / 255).
output_folder = "test-set"
dateset_folder_name = '../DFU/DFUC2024_test_release/'

# exist_ok=True replaces the separate existence check, so an already
# existing folder is never an error.
os.makedirs(output_folder, exist_ok=True)

for entry in bbox_merged:
    # Plain str so path helpers never see a NumPy string scalar
    filename = str(entry['filename'])

    # Open the image and force three RGB channels regardless of the file's mode
    with Image.open(os.path.join(dateset_folder_name, filename)) as img:
        image = np.array(img.convert("RGB"))

    # Process the bounding boxes: all boxes of this image are passed in one call
    input_boxes = np.array(entry['values'])
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_boxes,
        multimask_output=False,
    )

    if masks.ndim == 4:
        # 4-D result: drop the singleton axis 1, then take the union over axis 0.
        # NOTE(review): presumably axis 0 indexes the input boxes -- confirm.
        merged_mask = np.any(masks.squeeze(1), axis=0)
    else:
        # 3-D result: drop the singleton leading axis to get a 2-D mask
        merged_mask = masks.squeeze(0)
    merged_mask_image = (merged_mask * 255).astype(np.uint8)
    merged_mask_pil = Image.fromarray(merged_mask_image)

    # splitext removes only the final extension, so names containing extra
    # dots (e.g. "img.v2.jpg") keep their full stem and cannot collide.
    stem = os.path.splitext(filename)[0]
    merged_mask_pil.save(os.path.join(output_folder, stem + ".png"))