# SAM2 Finetuning

In [1]:
# - Packages
import numpy as np
import torch
import cv2
import os

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import sys
import os
import numpy as np
import torch
import matplotlib.pyplot as plt

import torchvision.transforms.functional as F
from torchvision.ops import masks_to_boxes
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


## SAM2 Image Segmentation

In [2]:
# - Global Variables

data_dir= Path("./snemi/" )
raw_image_dir = data_dir / 'image_pngs'
seg_image_dir = data_dir / 'seg_pngs'
raw_image_slice_dir = data_dir / 'image_slice_pngs'
seg_image_slice_dir = data_dir / 'seg_slice_pngs'
test_iamge_slice_dir = data_dir / "image_slice_test_pngs"

sam2_checkpoint = "./sam2_hiera_large.pt"
model_cfg = "./sam2_hiera_l.yaml"
itrs = 10000
log_dir="./logs"
val_num = 10

checkpoint_dir = './checkpoints/all'
if not os.path.exists(checkpoint_dir):
    os.makedirs('./checkpoints/all')


### Data Preparation and Data Reading

In [3]:
# - Prepare dataset
data = []
for ff, name in enumerate(os.listdir(raw_image_dir)):
    data.append({'image': raw_image_dir / f'image{ff:04d}.png', 'annotation': seg_image_dir / f'seg{ff:04d}.png'})
# - split train dataset and validation dataset
valid_data = data[80:]
data = data[:80]

In [4]:
# - slice image and segmentation for florence sequence length limit
def create_slices(image_path, slice_image_dir):
    img = Image.open(image_path)
     # Get image dimensions
    width, height = img.size
    
    # Calculate the midpoint
    mid_x, mid_y = width // 2, height // 2
    
    # Define the four slices (left, upper, right, lower)
    slices = {
        'top_left': (0, 0, mid_x, mid_y),
        'top_right': (mid_x, 0, width, mid_y),
        'bottom_left': (0, mid_y, mid_x, height),
        'bottom_right': (mid_x, mid_y, width, height)
    }
    
    # Loop through the slices, crop, and save them
    all_slices = []
    for key, coords in slices.items():
        slice_img = img.crop(coords)
        # Format the name: base name + coordinates
        slice_filename = f"{image_path.stem}_{coords[0]}_{coords[1]}_{coords[2]}_{coords[3]}.png"
        slice_img.save( slice_image_dir / slice_filename)
        all_slices.append( slice_image_dir / slice_filename)

    return all_slices

def slice_all_image_seg(data, raw_image_slice_dir, seg_image_slice_dir):
    new_data = []
    for element in data:
        image_path = element['image']
        seg_path = element['annotation']
        image_lst = create_slices(image_path, raw_image_slice_dir)
        seg_lst = create_slices(seg_path, seg_image_slice_dir)

        for i in range(len(image_lst)):
            new_data.append({'image': image_lst[i], 'annotation': seg_lst[i]})
        
    return new_data

data = slice_all_image_seg(data, raw_image_slice_dir, seg_image_slice_dir)
valid_data = slice_all_image_seg(valid_data, raw_image_slice_dir, seg_image_slice_dir)

In [5]:
# - Read Data

def read_batch(data):
     #  select image
     ent  = data[np.random.randint(len(data))] # choose random entry
     Img = cv2.imread(str(ent["image"])) # read image
     ann_map_grayscale = np.array(Image.open(ent['annotation']))
     ann_map = np.stack((ann_map_grayscale, ) * 3, axis = -1)
     
     if Img.shape[0] > 1024 or Img.shape[1] > 1024:
          # Calculate scaling factor
          r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]])  # Scaling factor to fit within 1024x1024
          # Resize the image
          Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))
          # Resize the annotation map (with nearest neighbor interpolation)
          ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)), interpolation=cv2.INTER_NEAREST)

     # - get bounding box
     inds = np.unique(ann_map_grayscale)[1:] # load all indices

     masks = [] 
     for ind in inds:
        masks.append(ann_map_grayscale == ind)
     masks = np.array(masks)
     masks_tensor = torch.from_numpy(masks)

     boxes = masks_to_boxes(masks_tensor)
     input_boxes = boxes.numpy()



     # Get binary masks and points
     mat_map = ann_map
     inds = np.unique(mat_map)[1:] # load all indices
     points= []
     masks = [] 
     for ind in inds:
          mask=(mat_map == ind).astype(np.uint8) # make binary mask
          masks.append(mask)
          coords = np.argwhere(mask > 0) # get all coordinates in mask
          yx = np.array(coords[np.random.randint(len(coords))]) # choose random point/coordinate
          points.append([[yx[1], yx[0]]])
     return Img,np.array(masks),np.array(points), input_boxes, np.ones([len(masks),1])
img, mask_arr,  point_arr, input_boxes, one_arr= read_batch(data)
print(f'mask_arr shape: {mask_arr.shape}')

mask_arr shape: (64, 512, 512, 3)


In [6]:
def mask_to_logits(mask, epsilon=1e-6):
    """
    Convert binary mask to mask logits.

    Args:
        mask (torch.Tensor or np.ndarray): A binary mask tensor with values in {0, 1}.
        epsilon (float): A small value to prevent division by zero.
    
    Returns:
        torch.Tensor: The corresponding logits.
    """
    # Ensure the mask is in float32 and has values in range [0, 1]
    mask = mask.astype(np.float32)
    
    # Apply the logit function: log(p / (1 - p))
    logits = np.log(mask + epsilon) - np.log(1 - mask + epsilon)
    
    return logits


def resize_masks_opencv(mask, output_size=(256, 256)):
    """
    Reshapes the input NumPy array by selecting the first of the 3 redundant channels,
    then resizes each mask to the given output size using nearest-neighbor interpolation.
    After resizing, the binary masks (0 and 1) are converted to logits.

    Args:
    - mask (np.ndarray): Array of shape (223, 1024, 1024, 3).
    - output_size (tuple): Desired output size (H, W) for the mask. Default is (256, 256).

    Returns:
    - resized_mask_logits: Resized mask logits of shape (223, 1, 256, 256).
    """
    # Select the first channel (shape becomes (223, 1024, 1024))
    masks_single_channel = mask[..., 0]
    
    # Reshape to (223, 1, 1024, 1024) to add the single channel back using np.expand_dims
    reshaped_masks = np.expand_dims(masks_single_channel, axis=1)

    # Initialize the array to store resized masks (223, 1, 256, 256)
    resized_masks = np.zeros((reshaped_masks.shape[0], 1, output_size[0], output_size[1]), dtype=reshaped_masks.dtype)
    
    # Loop over each mask and resize using cv2.resize with nearest-neighbor interpolation
    for i in range(reshaped_masks.shape[0]):
        resized_masks[i, 0, :, :] = cv2.resize(reshaped_masks[i, 0, :, :], output_size, interpolation=cv2.INTER_NEAREST)

    # Convert the resized binary masks (0 and 1) to mask logits
    resized_mask_logits = np.zeros_like(resized_masks, dtype=np.float32)
    for i in range(resized_masks.shape[0]):
        resized_mask_logits[i, 0, :, :] = mask_to_logits(resized_masks[i, 0, :, :])
    
    return resized_mask_logits

### Load Model and Finetuning Model

In [7]:
# - Build Model
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device='cpu')
predictor = SAM2ImagePredictor(sam2_model)
# - Set Training Parameters
predictor.model.sam_mask_decoder.train(True)
predictor.model.sam_prompt_encoder.train(True)
predictor.model.image_encoder.train(True)
# - Set Optimizer and Scaler
optimizer=torch.optim.AdamW(params=predictor.model.parameters(),lr=1e-5,weight_decay=4e-5)
scaler = torch.amp.GradScaler()
# - Force cuda operations to be synchronized
!export CUDA_LAUNCH_BLOCKING=1
# - Define Devices
device0 = torch.device('cuda:0')
device1 = torch.device('cuda:1')
device_cpu = torch.device('cpu')


In [8]:
# - Finetuning by Using Points, Boxes, Mask Logits
train_loss_file = open(os.path.join(log_dir, "sam_train_loss.txt"), "w")
val_loss_file = open(os.path.join(log_dir, "sam_val_loss.txt"), "w")
for itr in range(itrs):
    with torch.amp.autocast(device_type='cpu'):
        image, mask, input_point, input_boxes, input_label = read_batch(data)
        if mask.shape[0] == 0: continue
        predictor.set_image(image)

        reshaped_masks = resize_masks_opencv(mask)


        # - prompt encoding
        mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(input_point, input_label, box=input_boxes, mask_logits=reshaped_masks, normalize_coords=True)
        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(points=(unnorm_coords, labels), boxes=unnorm_box, masks=mask_input)



        # - mask decoder
        batched_mode = unnorm_coords.shape[0] > 1
        high_res_features = [feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]]
        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=batched_mode,
            high_res_features=[feat for feat in high_res_features],
        )


        prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

        # - segmentation loss calculation on CPU
        gt_mask = torch.tensor(mask.astype(np.float32))[:, :, :, 0]
        prd_mask = torch.sigmoid(prd_masks[:, 0])
        seg_loss = (-gt_mask * torch.log(prd_mask + 0.00001) - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)).mean()

        # - score loss calculation (intersection over union) IOU on CPU
        inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
        iou = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)
        score_loss = torch.abs(prd_scores[:, 0] - iou).mean()

        # Move the total loss to GPU for backpropagation
        loss = (seg_loss + score_loss * 0.05)
        predictor.model.zero_grad()  # empty gradient
        scaler.scale(loss).backward()  # Backpropagate
        scaler.step(optimizer)
        scaler.update()  # Mix precision

        if (itr + 1) % 1000 == 0:
            torch.save(predictor.model.state_dict(), f"./checkpoints/all/large_model_slice_{itr + 1}.torch")
            print("save model")

        if itr == 0:
            mean_iou = 0
        mean_iou = mean_iou * 0.99 + 0.01 * np.mean(iou.cpu().detach().numpy())
        print(f"Training IOU at iteration {itr + 1}: {mean_iou}")
        # Save training loss for this epoch
        train_loss_file.write(f"{itr + 1},{mean_iou}\n")
        train_loss_file.flush()
    
    # - Evaluation Step
    # Evaluation step on validation data after each iteration

    total_iou = 0
    for i in range(val_num):
        with torch.no_grad():  # Disable gradient calculation for inference
            img, mask, input_points, input_boxes, input_labels = read_batch(valid_data)

            predictor.set_image(img)  # Set image in the predictor (Image Encoder)

            # Prompt Encoder + Mask Decoder
            masks, scores, logits = predictor.predict(
                point_coords=input_points,
                point_labels=input_labels,
                box=input_boxes,
                multimask_output=False
            )

            prd_mask = torch.sigmoid(torch.tensor(masks[:, 0], dtype=torch.float32))
            gt_mask = torch.tensor(mask.astype(np.float32))[:, :, :, 0]

            # Calculate IOU for validation data
            inter = (gt_mask * (prd_mask > 0.5)).sum(1).sum(1)
            iou_val = inter / (gt_mask.sum(1).sum(1) + (prd_mask > 0.5).sum(1).sum(1) - inter)

            total_iou += iou_val.mean().cpu().numpy()
            print(total_iou)
    
    
    avg_iou = total_iou / val_num
    print(f"Validation IOU at iteration {itr}: {avg_iou}")

    # Save validation loss for this epoch
    val_loss_file.write(f"{itr + 1},{avg_iou}\n")
    val_loss_file.flush()
# Close the log files
train_loss_file.close()
val_loss_file.close()


Training IOU at iteration 1: 0.008890880346298218
tensor([0.2999, 0.6136, 0.9041, 0.8043, 0.8168, 0.9331, 0.9613, 0.2648, 0.9659,
        0.8611, 0.9293, 0.8414, 0.8593, 0.9043, 0.9238, 0.9344, 0.7657, 0.8494,
        0.8800, 0.8904, 0.8756, 0.7221, 0.8314, 0.9111, 0.9258, 0.9310, 0.8420,
        0.8553, 0.8519, 0.8424, 0.8743, 0.8981, 0.9190, 0.7528, 0.8654, 0.8700,
        0.9080, 0.8211, 0.9126, 0.8986, 0.8508, 0.7543, 0.8369, 0.9035, 0.6735,
        0.2425, 0.8741, 0.8169, 0.7463, 0.8405, 0.8963, 0.8719, 0.8676, 0.1012,
        0.3604, 0.9195, 0.7761, 0.8594, 0.9032, 0.8063, 0.8484, 0.8423, 0.8312,
        0.9064, 0.8433, 0.8223, 0.7826, 0.6420, 0.7427, 0.9013])
tensor([0.9381, 0.8797, 0.8792, 0.5185, 0.5893, 0.8863, 0.8253, 0.7427, 0.9185,
        0.9401, 0.9062, 0.7980, 0.8295, 0.9360, 0.8190, 0.8059, 0.8022, 0.8744,
        0.8075, 0.8983, 0.9080, 0.8226, 0.7326, 0.6758, 0.8886, 0.8895, 0.0500,
        0.9188, 0.7467, 0.8704, 0.8053, 0.8687, 0.9171, 0.8401, 0.8514, 0.9018,
     