# SAM Video Segmentation Finetuning

In [9]:
from pathlib import Path
import os
import matplotlib.pyplot as plt
import numpy as np
import torch
from sam2.build_sam import build_sam2_video_predictor
import concurrent.futures
from tqdm import tqdm
from PIL import Image
import torch.optim as optim

In [10]:
# - Global Variables
# Checkpoint weights and the matching model config file (the file names indicate the
# Hiera-large SAM 2 variant).
CHECKPOINT = "./sam2_hiera_large.pt"
CONFIG = "sam2_hiera_l.yaml"

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the SAM 2 video predictor from CONFIG and CHECKPOINT, targeting DEVICE.
sam2_model = build_sam2_video_predictor(
    CONFIG,
    CHECKPOINT,
    device=DEVICE,
)


In [20]:
# - Utils Function
import logging
import os
import random
import sys
import time
from datetime import datetime

import dateutil.tz
import numpy as np
import torch
from torch.autograd import Function

def set_log_dir(root_dir, exp_name):
    """Create a fresh, timestamped experiment directory tree under ``root_dir``.

    The run directory is ``<root_dir>/<exp_name>_<YYYY_MM_DD_HH_MM_SS>`` (local time).
    It receives three sub-directories: ``Model``, ``Log`` and ``Samples``.

    Args:
        root_dir: parent directory; created if it does not exist yet.
        exp_name: experiment name used as the run-directory prefix.

    Returns:
        dict with keys ``'prefix'`` (the run directory), ``'ckpt_path'`` (``Model``),
        ``'log_path'`` (``Log``) and ``'sample_path'`` (``Samples``).

    Raises:
        FileExistsError: if the run directory already exists, e.g. when this is called
            twice for the same ``exp_name`` within the same second.
    """
    os.makedirs(root_dir, exist_ok=True)

    # Local wall-clock timestamp; the stdlib astimezone() replaces the former
    # dateutil.tz.tzlocal() dependency and yields the same formatted string.
    stamp = datetime.now().astimezone().strftime('%Y_%m_%d_%H_%M_%S')
    prefix = os.path.join(root_dir, exp_name) + '_' + stamp
    # No exist_ok here: a collision means another run already owns this directory.
    os.makedirs(prefix)

    paths = {'prefix': prefix}
    for key, sub_dir in (('ckpt_path', 'Model'),
                         ('log_path', 'Log'),
                         ('sample_path', 'Samples')):
        target = os.path.join(prefix, sub_dir)
        os.makedirs(target)
        paths[key] = target
    return paths



def create_logger(log_dir, phase='train'):
    """Configure the root logger to write INFO+ records to a timestamped file and the console.

    The log file is ``<log_dir>/<YYYY-MM-DD-HH-MM>_<phase>.log``, opened in append mode
    and formatted as ``'%(asctime)-15s %(message)s'``. Console records show the bare
    message.

    Calling this again (re-running the cell, or switching ``phase``/``log_dir``) replaces
    the handlers installed by the previous call instead of stacking duplicates. Unlike
    ``logging.basicConfig``, the file handler is attached even when the root logger
    already has other handlers.

    Args:
        log_dir: existing directory in which to create the log file.
        phase: tag appended to the log file name (default ``'train'``).

    Returns:
        The root ``logging.Logger``.
    """
    time_str = time.strftime('%Y-%m-%d-%H-%M')
    log_file = '{}_{}.log'.format(time_str, phase)
    final_log_file = os.path.join(log_dir, log_file)
    head = '%(asctime)-15s %(message)s'

    # Names let a later call recognise (and replace) the handlers this function installed.
    file_handler_name = 'create_logger.file'
    console_handler_name = 'create_logger.console'

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Remove handlers from a previous call so each record is emitted once, to one file.
    for handler in list(logger.handlers):
        if handler.get_name() in (file_handler_name, console_handler_name):
            logger.removeHandler(handler)
            handler.close()

    file_handler = logging.FileHandler(str(final_log_file))
    file_handler.setFormatter(logging.Formatter(head))
    file_handler.set_name(file_handler_name)
    logger.addHandler(file_handler)

    console = logging.StreamHandler()
    console.set_name(console_handler_name)
    logger.addHandler(console)

    return logger

In [11]:
# - Print Out SAM 2 Model Architecture
# A bare trailing expression makes Jupyter render the module's repr (its submodule tree).
sam2_model

SAM2VideoPredictor(
  (image_encoder): ImageEncoder(
    (trunk): Hiera(
      (patch_embed): PatchEmbed(
        (proj): Conv2d(3, 144, kernel_size=(7, 7), stride=(4, 4), padding=(3, 3))
      )
      (blocks): ModuleList(
        (0-1): 2 x MultiScaleBlock(
          (norm1): LayerNorm((144,), eps=1e-06, elementwise_affine=True)
          (attn): MultiScaleAttention(
            (qkv): Linear(in_features=144, out_features=432, bias=True)
            (proj): Linear(in_features=144, out_features=144, bias=True)
          )
          (drop_path): Identity()
          (norm2): LayerNorm((144,), eps=1e-06, elementwise_affine=True)
          (mlp): MLP(
            (layers): ModuleList(
              (0): Linear(in_features=144, out_features=576, bias=True)
              (1): Linear(in_features=576, out_features=144, bias=True)
            )
            (act): GELU(approximate='none')
          )
        )
        (2): MultiScaleBlock(
          (norm1): LayerNorm((144,), eps=1e-06, elemen

In [12]:
# - Load Layer Parameters
# SAM prompt encoder + mask decoder parameters.
# NOTE: sam2_model.image_encoder is deliberately not collected here, so the optimizers
# built from these lists never update it. NOTE(review): its parameters may still have
# requires_grad=True — confirm whether it should also be frozen explicitly.
sam_layers = [
    param
    for module in (sam2_model.sam_prompt_encoder, sam2_model.sam_mask_decoder)
    for param in module.parameters()
]

# Parameters of the memory-related modules.
mem_layers = [
    param
    for module in (
        sam2_model.obj_ptr_proj,
        sam2_model.memory_encoder,
        sam2_model.memory_attention,
        sam2_model.mask_downsample,
    )
    for param in module.parameters()
]

In [13]:
# - Define Optimizers
# One Adam optimizer per parameter group. An empty group gets None instead of an
# optimizer, because torch.optim does not accept an empty parameter list.
optimizer1 = (
    optim.Adam(sam_layers, lr=1e-4, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
    if sam_layers
    else None
)
# NOTE(review): lr=1e-8 is four orders of magnitude below optimizer1's rate, so the
# memory modules barely move — confirm this near-freeze is intended.
optimizer2 = (
    optim.Adam(mem_layers, lr=1e-8, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)
    if mem_layers
    else None
)

In [18]:
# Mixed precision: run eligible CUDA ops in bfloat16 and allow TF32 matmuls/convolutions.
# This is autocast, not quantization — parameters keep their stored dtype.
# Guarded on CUDA so the notebook still runs on the CPU fallback chosen for DEVICE.
if torch.cuda.is_available():
    # The context is entered and never exited, so bf16 autocast (CUDA ops only) stays in
    # effect for the rest of the kernel session on this thread.
    torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()
    # Compute capability 8.x (Ampere) and newer support TensorFloat-32 (TF32); see
    # https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
    # get_device_capability() with no argument queries the current CUDA device.
    if torch.cuda.get_device_capability()[0] >= 8:
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True