# SAM Video Segmentation Finetuning

In [1]:
from pathlib import Path
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from sam2.build_sam import build_sam2_video_predictor
import concurrent.futures
from tqdm import tqdm
from PIL import Image
import torch.optim as optim
from torchvision.ops import masks_to_boxes
import cv2

In [2]:
# - Global Variables

CHECKPOINT = "./sam2_hiera_large.pt"
CONFIG = "sam2_hiera_l.yaml"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # Use GPU if available

sam2_model = build_sam2_video_predictor(CONFIG, CHECKPOINT, device=DEVICE)

log_dir = 'logs'

exp_name = 'sam_video_finetune'

data_dir = Path('./snemi')
raw_image_dir = data_dir / 'image_jpgs'
seg_dir = data_dir / 'seg_jpgs'




In [3]:
# - Utils Function
import logging
import os
import random
import sys
import time
from datetime import datetime

import dateutil.tz
import numpy as np
import torch
from torch.autograd import Function

def set_log_dir(root_dir, exp_name):
    path_dict = {}
    os.makedirs(root_dir, exist_ok=True)

    # set log path
    exp_path = os.path.join(root_dir, exp_name)
    now = datetime.now(dateutil.tz.tzlocal())
    timestamp = now.strftime('%Y_%m_%d_%H_%M_%S')
    prefix = exp_path + '_' + timestamp
    os.makedirs(prefix)
    path_dict['prefix'] = prefix

    # set checkpoint path
    ckpt_path = os.path.join(prefix, 'Model')
    os.makedirs(ckpt_path)
    path_dict['ckpt_path'] = ckpt_path

    log_path = os.path.join(prefix, 'Log')
    os.makedirs(log_path)
    path_dict['log_path'] = log_path

    # set sample image path for fid calculation
    sample_path = os.path.join(prefix, 'Samples')
    os.makedirs(sample_path)
    path_dict['sample_path'] = sample_path

    return path_dict



def create_logger(log_dir, phase='train'):
    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}.log'.format(time_str, phase)
    final_log_file = os.path.join(log_dir, log_file)
    head = '%(asctime)-15s %(message)s'
    logging.basicConfig(filename=str(final_log_file),
                        format=head)
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    console = logging.StreamHandler()
    logging.getLogger('').addHandler(console)

    return logger
path_helper = set_log_dir(log_dir, exp_name)
logger = create_logger(path_helper['log_path'])


In [4]:
# - Print Out Sam2 Model Architecutre
sam2_model

SAM2VideoPredictor(
  (image_encoder): ImageEncoder(
    (trunk): Hiera(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 144, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
      )
      (blocks): ModuleList(
        (0-1): 2 x MultiScaleBlock(
          (norm1): LayerNorm((144,), eps=1e-06, elementwise_affine=True)
          (attn): MultiScaleAttention(
            (qkv): Linear(in_features=144, out_features=432, bias=True)
            (proj): Linear(in_features=144, out_features=144, bias=True)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((144,), eps=1e-06, elementwise_affine=True)
          (mlp): MLP(
            (layers): ModuleList(
              (0): Linear(in_features=144, out_features=576, bias=True)
              (1): Linear(in_features=576, out_features=144, bias=True)
            )
            (act): GELU(approximate='none')
          )
        )
        (2): MultiScaleBlock(
          (norm1): LayerNorm((144,), eps=1e-06, elemen

In [5]:
# - Load Layer Parameters
sam_layers = ([]
            # + list(sam2_model.image_encoder.parameters())
            + list(sam2_model.sam_prompt_encoder.parameters())
            + list(sam2_model.sam_mask_decoder.parameters())
            )
mem_layers = ([]
            + list(sam2_model.obj_ptr_proj.parameters())
            + list(sam2_model.memory_encoder.parameters())
            + list(sam2_model.memory_attention.parameters())
            + list(sam2_model.mask_downsample.parameters())
            )

In [6]:
# - Define Optimizers
if len(sam_layers) == 0:
    optimizer1 = None
else:
    optimizer1 = optim.Adam(sam_layers, lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
if len(mem_layers) == 0:
    optimizer2 = None
else:
    optimizer2 = optim.Adam(mem_layers, lr=1e-8, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

In [7]:
# Float Quantization
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

In [36]:
# - Load Data Mannually
img_file_path_lst = sorted([raw_image_dir / img_file for img_file in os.listdir(raw_image_dir)])
seg_file_path_lst = sorted([seg_dir / img_file for img_file in os.listdir(seg_dir)])
data = []
for i in range(len(img_file_path_lst)):
    data.append({'image': img_file_path_lst[i], 'annotation': seg_file_path_lst[i]})
# - split train dataset and validation dataset
valid_data = data[80:]
data = data[:80]
# - read batch for jpg files 
def read_frame(data, idx):
     #  select image
     ent  = data[idx] # choose random entry
     Img = cv2.imread(str(ent["image"])) # read image
     ann_map_grayscale = np.array(Image.open(ent['annotation']))
     ann_map = ann_map_grayscale
     ann_map_grayscale = ann_map[...,0]
     
     if Img.shape[0] > 1024 or Img.shape[1] > 1024:
          # Calculate scaling factor
          r = np.min([1024 / Img.shape[1], 1024 / Img.shape[0]])  # Scaling factor to fit within 1024x1024
          # Resize the image
          Img = cv2.resize(Img, (int(Img.shape[1] * r), int(Img.shape[0] * r)))
          # Resize the annotation map (with nearest neighbor interpolation)
          ann_map = cv2.resize(ann_map, (int(ann_map.shape[1] * r), int(ann_map.shape[0] * r)), interpolation=cv2.INTER_NEAREST)

     # - get bounding box
     inds = np.unique(ann_map_grayscale)[1:] # load all indices

     masks = [] 
     for ind in inds:
        masks.append(ann_map_grayscale == ind)
     masks = np.array(masks)
     masks_tensor = torch.from_numpy(masks)

     boxes = masks_to_boxes(masks_tensor)
     input_boxes = boxes.numpy()



     # Get binary masks and points
     mat_map = ann_map
     inds = np.unique(mat_map)[1:] # load all indices
     points= []
     masks = [] 
     for ind in inds:
          mask=(mat_map == ind).astype(np.uint8) # make binary mask
          masks.append(mask)
          coords = np.argwhere(mask > 0) # get all coordinates in mask
          yx = np.array(coords[np.random.randint(len(coords))]) # choose random point/coordinate
          points.append([[yx[1], yx[0]]])
     return Img,np.array(masks),np.array(points), input_boxes, np.ones([len(masks),1])
img, mask_arr,  point_arr, input_boxes, one_arr= read_frame(data, 0)
print(f'img shape: {img.shape}')
print(f'mask_arr shape: {mask_arr.shape}')
print(f'point_arr shape: {point_arr.shape}')
print(f'input_boxes shape: {input_boxes.shape}')
print(f'point_labels shape: {one_arr.shape}')


img shape: (1024, 1024, 3)
mask_arr shape: (255, 1024, 1024, 3)
point_arr shape: (255, 1, 2)
input_boxes shape: (255, 4)
point_labels shape: (255, 1)


In [None]:
 # - Load Data only for bbox, point labels; return one frame info. If there are multiple videos, we need to modify the dataloader
# from torch.utils.data import Dataset
# class SnemiDataset(Dataset):
#     def __init__(self, img_files, ann_files):
#         self.img_files = img_files
#         self.ann_files = ann_files

#     def __len__(self):
#         return len(self.img_files)
    
#     def __getitem__(self, idx):
#         img_file_name = str(img_files[idx])
#         ann_file_name = str(img_files[idx])
#         mask_dict = {}
#         point_label_dict = {}
#         pt_dict = {}
#         image_meta_dict = {}

#         # - read image and segmentation
#         img = cv2.imread(img_file_name) # read image
#         img_tensor = torch.tensor(img).permutae(2, 0, 1)
#         ann_map_grayscale = np.array(Image.open(ann_file_name))
#         ann_map = np.stack((ann_map_grayscale, ) * 3, axis = -1)

#         ids = np.unique(ann_map)[1:]

#         for id in ids:
#             mask = ann_map_grayscale == id
#             mask_dict[id] = torch.from_numpy(mask)



#         return {
#                 'image':img_tensor,
#                 'label': mask_dict,
#                 'p_label':point_label_dict,
#                 'pt':pt_dict,
#                 'image_meta_dict':image_meta_dict,
#             }


In [37]:
# - Begin Training
best_acc = 0.0
best_tol = 1e4
best_dice = 0.0
epochs = 100

In [None]:
# - Start Training

for epoch in range(epochs):
    sam2_model.train()
    time_start = time.time()
    loss, prompt_loss, non_prompt_loss = function.train_sam(args, net, optimizer1, optimizer2, nice_train_loader, epoch)
    logger.info(f'Train loss: {loss}, {prompt_loss}, {non_prompt_loss} || @ epoch {epoch}.')
    time_end = time.time()
    print('time_for_training ', time_end - time_start)