# SAM2 Finetuning

In [8]:
# - Packages
import numpy as np
import torch
import cv2
import os

from PIL import Image
from pathlib import Path
import matplotlib.pyplot as plt
import sys
import os
import numpy as np
import torch
import matplotlib.pyplot as plt

import torchvision.transforms.functional as F
from torchvision.ops import masks_to_boxes
import os

os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
import numpy as np
import torch
import matplotlib.pyplot as plt
from PIL import Image

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


## SAM2 Image Segmentation

In [9]:
# - Global Variables

data_dir = Path("./snemi/")
raw_image_dir = data_dir / 'image_pngs'  # input EM images: image{NNNN}.png
seg_image_dir = data_dir / 'seg_pngs'    # label maps:      seg{NNNN}.png

sam2_checkpoint = "./sam2_hiera_large.pt"
model_cfg = "./sam2_hiera_l.yaml"
itrs = 10000  # number of training iterations

checkpoint_dir = './checkpoints/all'
# exist_ok=True avoids the exists()/makedirs() race and reuses checkpoint_dir
# instead of repeating the path literal.
os.makedirs(checkpoint_dir, exist_ok=True)


### Data Preparation and Data Reading

In [10]:
# - Prepare dataset
# Pair every image{NNNN}.png with the seg{NNNN}.png carrying the same suffix.
# Pairs are derived from the actual file names (not from the position in a
# directory listing), so stray files such as .DS_Store or gaps in the numbering
# can no longer produce paths that do not exist.
data = [
    {
        'image': img_path,
        'annotation': seg_image_dir / f"seg{img_path.stem[len('image'):]}.png",
    }
    for img_path in sorted(raw_image_dir.glob('image*.png'))
]
# - split train dataset and validation dataset (first n_train entries train)
n_train = 80
valid_data = data[n_train:]
data = data[:n_train]

In [None]:
# - Read Data

def read_batch(data, verbose=False):
    """Draw one random sample from ``data`` and build SAM2 prompts for it.

    The image is resized so its longer side is 1024 px (aspect ratio kept).
    The label map is resized to the same size with nearest-neighbour
    interpolation, so no new label values are invented. For every non-zero
    label (0 is treated as background), the function produces a binary mask,
    one random interior point and a tight bounding box. All three are
    computed from the *resized* label map, so they share one coordinate frame
    and always have the same count.

    Args:
        data: Sequence of dicts with ``'image'`` and ``'annotation'`` paths.
        verbose: If True, print the shapes of the loaded image and label map.

    Returns:
        Tuple ``(image, masks, points, boxes, labels)``:
          image  -- resized RGB image, shape (H, W, 3).
          masks  -- uint8 binary masks, shape (N, H, W, 3); each mask is
                    replicated over 3 channels (empty array if no objects).
          points -- one (x, y) point per mask, shape (N, 1, 2).
          boxes  -- float32 (x_min, y_min, x_max, y_max), inclusive, shape (N, 4).
          labels -- ones of shape (N, 1): every point is a positive prompt.

    Raises:
        FileNotFoundError: if the image file cannot be read.
    """
    entry = data[np.random.randint(len(data))]  # choose a random entry

    img = cv2.imread(str(entry["image"]))
    if img is None:  # cv2.imread reports failure by returning None, not raising
        raise FileNotFoundError(f"could not read image {entry['image']}")
    # cv2 loads BGR; SAM2ImagePredictor.set_image expects RGB.
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    label_map = np.array(Image.open(entry["annotation"]))
    if verbose:
        print(f'image shape: {img.shape}')
        print(f'annotation shape: {label_map.shape}')

    # Scale so the longer side becomes 1024 px; cv2 takes dsize as (W, H).
    scale = min(1024 / img.shape[1], 1024 / img.shape[0])
    new_size = (int(img.shape[1] * scale), int(img.shape[0] * scale))
    img = cv2.resize(img, new_size)
    label_map = cv2.resize(label_map, new_size, interpolation=cv2.INTER_NEAREST)

    object_ids = np.unique(label_map)
    object_ids = object_ids[object_ids != 0]  # drop background explicitly, not "the smallest label"

    masks, points, boxes = [], [], []
    for obj_id in object_ids:
        mask2d = (label_map == obj_id).astype(np.uint8)
        coords = np.argwhere(mask2d)  # (row, col) of every pixel of this object
        y, x = coords[np.random.randint(len(coords))]  # random interior point
        points.append([[x, y]])
        (y_min, x_min), (y_max, x_max) = coords.min(axis=0), coords.max(axis=0)
        boxes.append([x_min, y_min, x_max, y_max])
        masks.append(np.stack((mask2d,) * 3, axis=-1))

    return (
        img,
        np.array(masks),
        np.array(points),
        np.array(boxes, dtype=np.float32).reshape(-1, 4),
        np.ones([len(masks), 1]),
    )
# Smoke test: draw one random training sample and report its mask-stack shape.
img, mask_arr, point_arr, input_boxes, one_arr = read_batch(data)
print('mask_arr shape: ' + str(mask_arr.shape))

In [29]:
def resize_masks_opencv(mask, output_size=(256, 256)):
    """
    Keep only the first channel of a mask stack and resize every mask with
    nearest-neighbour interpolation.

    Args:
    - mask (np.ndarray): Mask stack of shape (N, H, W, C). Only channel 0 is
      used; the channels are assumed to be redundant copies.
    - output_size (tuple): Desired output size as (H, W). Default is (256, 256).

    Returns:
    - resized_masks (np.ndarray): Array of shape
      (N, 1, output_size[0], output_size[1]) with the same dtype as ``mask``.
    """
    out_h, out_w = output_size
    # (N, H, W, C) -> (N, 1, H, W): keep channel 0, add a singleton channel axis.
    # ascontiguousarray gives OpenCV a plain row-major buffer for each mask.
    reshaped_masks = np.expand_dims(np.ascontiguousarray(mask[..., 0]), axis=1)

    resized_masks = np.zeros(
        (reshaped_masks.shape[0], 1, out_h, out_w), dtype=reshaped_masks.dtype
    )
    for i, single_mask in enumerate(reshaped_masks[:, 0]):
        # cv2.resize takes dsize as (width, height); passing (H, W) directly
        # produces a transposed shape and fails for non-square output sizes.
        resized_masks[i, 0] = cv2.resize(
            single_mask, (out_w, out_h), interpolation=cv2.INTER_NEAREST
        )

    return resized_masks

### Load and Fine-tune the Model

In [13]:
# - Force cuda operations to be synchronized.
# A notebook `!export ...` runs in a throwaway subshell and never reaches this
# Python process, so set the variable in os.environ instead. It is set before
# the model is built because it only takes effect if set before CUDA is
# initialised in this process.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
# - Build Model
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device='cpu')
predictor = SAM2ImagePredictor(sam2_model)
# - Set Training Parameters: put decoder, prompt encoder and image encoder in train mode
predictor.model.sam_mask_decoder.train(True)
predictor.model.sam_prompt_encoder.train(True)
predictor.model.image_encoder.train(True)
# - Set Optimizer and Scaler
optimizer = torch.optim.AdamW(params=predictor.model.parameters(), lr=1e-5, weight_decay=4e-5)
# NOTE(review): GradScaler() targets CUDA by default; with CUDA unavailable
# PyTorch disables it (with a warning), making scale/step/update pass-throughs
# for this CPU run. Confirm this is intended.
scaler = torch.amp.GradScaler()
# - Define Devices
# NOTE(review): the model above is built on CPU; device0/device1 are not used
# by any visible cell.
device0 = torch.device('cuda:0')
device1 = torch.device('cuda:1')
device_cpu = torch.device('cpu')


In [None]:
# - Finetuning by Using Points, Boxes, Mask Logits
# NOTE(review): in upstream SAM2, SAM2ImagePredictor.set_image runs under
# torch.no_grad(). If so, the image encoder gets no gradients here and only the
# prompt encoder / mask decoder are actually updated. Confirm whether the
# encoder is meant to be fine-tuned.
mean_iou = 0.0  # initialised up front so a `continue` at itr 0 cannot leave it undefined
for itr in range(itrs):
    image, mask, input_point, input_boxes, input_label = read_batch(data)
    if mask.shape[0] == 0:  # sample has no foreground objects
        continue

    # Autocast wraps only the forward pass and loss computation; backward and
    # optimizer step run outside it, as the PyTorch AMP docs recommend.
    with torch.amp.autocast(device_type='cpu'):
        predictor.set_image(image)

        # - prompt encoding
        mask_input, unnorm_coords, labels, unnorm_box = predictor._prep_prompts(
            input_point, input_label, box=input_boxes, mask_logits=None, normalize_coords=True
        )
        sparse_embeddings, dense_embeddings = predictor.model.sam_prompt_encoder(
            points=(unnorm_coords, labels), boxes=unnorm_box, masks=mask_input
        )

        # - mask decoder: one image embedding, repeated when there are several prompts
        batched_mode = unnorm_coords.shape[0] > 1
        high_res_features = [
            feat_level[-1].unsqueeze(0) for feat_level in predictor._features["high_res_feats"]
        ]
        low_res_masks, prd_scores, _, _ = predictor.model.sam_mask_decoder(
            image_embeddings=predictor._features["image_embed"][-1].unsqueeze(0),
            image_pe=predictor.model.sam_prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=batched_mode,
            high_res_features=high_res_features,
        )
        prd_masks = predictor._transforms.postprocess_masks(low_res_masks, predictor._orig_hw[-1])

        # - segmentation loss: per-pixel binary cross-entropy on the first predicted mask
        gt_mask = torch.tensor(mask.astype(np.float32))[:, :, :, 0]
        # Cast to float32 so log(1 - p + eps) is not lost to a low-precision autocast dtype.
        prd_mask = torch.sigmoid(prd_masks[:, 0].float())
        seg_loss = (
            -gt_mask * torch.log(prd_mask + 0.00001)
            - (1 - gt_mask) * torch.log((1 - prd_mask) + 0.00001)
        ).mean()

        # - score loss: predicted score vs. actual IoU of the thresholded mask
        #   (thresholding is non-differentiable, so iou acts as a fixed target)
        prd_binary = prd_mask > 0.5
        inter = (gt_mask * prd_binary).sum(1).sum(1)
        iou = inter / (gt_mask.sum(1).sum(1) + prd_binary.sum(1).sum(1) - inter)
        score_loss = torch.abs(prd_scores[:, 0] - iou).mean()

        loss = seg_loss + score_loss * 0.05

    predictor.model.zero_grad()  # empty gradient
    scaler.scale(loss).backward()  # Backpropagate
    scaler.step(optimizer)
    scaler.update()  # Mix precision

    if (itr + 1) % 1000 == 0:
        torch.save(
            predictor.model.state_dict(),
            os.path.join(checkpoint_dir, f"large_model_all_{itr}.torch"),
        )
        print("save model")

    # Exponential moving average of the batch-mean IoU, for logging only.
    mean_iou = mean_iou * 0.99 + 0.01 * iou.detach().mean().item()
    print("step", itr, "Accuracy(IOU)=", mean_iou)
