SAMTrack inference notebook by [Alex Spirin](https://github.com/Sxela)

The main difference from the [official repo](https://github.com/z-x-yang/Segment-and-Track-Anything) is that it exports masks separately for easier compositing.

License: [AGPL](https://github.com/Sxela/Segment-and-Track-Anything-CLI/blob/main/LICENSE.txt)
## Local installation prerequisites
Prebuilt GroundingDINO and spatial-correlation-sampler binaries for local Windows 11 / CUDA 11.8 / Python 3.10 installs are downloaded automatically.

If the prebuilt binaries don't work for you:

- Get the [MSVC Build Tools](https://aka.ms/vs/17/release/vs_BuildTools.exe) and install the local C++ development tools.
- Get the [latest NVIDIA CUDA Toolkit](https://developer.nvidia.com/cuda-downloads?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_local), or at least [11.8+](https://developer.nvidia.com/cuda-11-8-0-download-archive?target_os=Windows&target_arch=x86_64&target_version=11&target_type=exe_local), and install it. Don't forget to remove older CUDA versions.

In [None]:
#@title Install SAMTrack-CLI
#@markdown originally from https://github.com/z-x-yang/Segment-and-Track-Anything \
#@markdown Restart the notebook after install.
#https://stackoverflow.com/questions/64261546/how-to-solve-error-microsoft-visual-c-14-0-or-greater-is-required-when-inst
import os, platform
try:
  # Re-use root_dir if an earlier cell defined it, and cd there.
  os.chdir(root_dir)
except (NameError, OSError):
  # root_dir is undefined (fresh kernel) or not a usable directory:
  # fall back to the current working directory.
  root_dir = os.getcwd()

# Clone the CLI fork and switch into it. `!` lines run in a subshell; IPython
# does not raise on a non-zero exit, so on re-runs git just reports that the
# destination already exists and the cell keeps going.
!git clone https://github.com/Sxela/Segment-and-Track-Anything-CLI
os.chdir(os.path.join(root_dir,'Segment-and-Track-Anything-CLI'))

# NOTE(review): `!python` runs whichever python is first on PATH, which may not
# be this kernel's interpreter — confirm on local installs.
!python -m pip install -e ./sam
if platform.system() == 'Linux':
  # Linux: build and install GroundingDINO from source.
  !python -m pip install -e git+https://github.com/IDEA-Research/GroundingDINO.git@main#egg=GroundingDINO
else:
  # Other platforms: fetch the sources and Python requirements only; the
  # compiled package comes from the prebuilt site-packages.zip that is
  # extracted at the end of this cell.
  os.makedirs('./src', exist_ok=True)
  !git clone https://github.com/IDEA-Research/GroundingDINO "{os.path.join(root_dir,'Segment-and-Track-Anything-CLI')}/src/GroundingDINO"
  !python -m pip install -r "{os.path.join(root_dir,'Segment-and-Track-Anything-CLI')}/src/GroundingDINO/requirements.txt"
!python -m pip install numpy opencv-python pycocotools matplotlib Pillow scikit-image
!python -m pip install gdown
!python -m pip install wget
# Correlation extension: built from source on Linux; elsewhere only its
# requirements are installed and the binary comes from site-packages.zip.
!git clone https://github.com/ClementPinard/Pytorch-Correlation-extension.git
if platform.system() == 'Linux':
  !python -m pip install -e ./Pytorch-Correlation-extension
else:
  !python -m pip install -r ./Pytorch-Correlation-extension/requirements.txt

# Make sure we are at the repo root and that the checkpoint folder exists.
os.chdir(os.path.join(root_dir,'Segment-and-Track-Anything-CLI'))
os.makedirs(os.path.join(root_dir,'Segment-and-Track-Anything-CLI', 'ckpt'), exist_ok=True)

import gdown
import wget

# Model weights are fetched into ./ckpt (relative to the repo root, which is
# the working directory here); each file is downloaded only if missing.

# DeAOT tracker weights, hosted on Google Drive.
if not os.path.exists('./ckpt/R50_DeAOTL_PRE_YTB_DAV.pth'):
  gdown.download(
      id='1QoChMkTVxdYZ_eBlZhK2acq9KMQZccPJ',
      output='./ckpt/R50_DeAOTL_PRE_YTB_DAV.pth',
  )

# SAM ViT-B weights; saved into ckpt/ under the URL's file name, which is
# what the existence check looks for.
if not os.path.exists('./ckpt/sam_vit_b_01ec64.pth'):
  wget.download(
      "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
      "ckpt/",
  )

# GroundingDINO Swin-T weights, same convention as above.
if not os.path.exists('./ckpt/groundingdino_swint_ogc.pth'):
  wget.download(
      "https://huggingface.co/ShilongLiu/GroundingDINO/resolve/main/groundingdino_swint_ogc.pth",
      "ckpt/",
  )

import zipfile

if platform.system() != 'Linux':
  # Prebuilt GroundingDINO / spatial-correlation-sampler binaries for
  # CUDA 11.8, torch 2, Python 3.10 on Windows 11, downloaded once.
  # (`wget` is imported earlier in this cell, before the checkpoint downloads.)
  if not os.path.exists('./site-packages.zip'):
    wget.download("https://raw.githubusercontent.com/Sxela/Segment-and-Track-Anything-CLI/main/site-packages.zip",
                "./site-packages.zip")

  # NOTE(review): assumes a virtualenv at <root_dir>/env with the Windows
  # layout (env/Lib/site-packages) — confirm for your install.
  try:
    with zipfile.ZipFile("./site-packages.zip", 'r') as zip_ref:
      zip_ref.extractall(f'{root_dir}/env/Lib/')
  except zipfile.BadZipFile:
    # A corrupt/truncated archive would otherwise be reused on every rerun;
    # delete it so the next run downloads a fresh copy.
    os.remove("./site-packages.zip")
    raise

In [None]:
#@title Detection setup
#@markdown Use this cell to tweak detection settings, that will be later used on the whole video.
#@markdown Run this cell to get detection preview.\
#@markdown Code mostly taken from https://github.com/z-x-yang/Segment-and-Track-Anything/blob/main/demo_instseg.ipynb
import os, pathlib, shutil, sys, subprocess
from glob import glob
try:
  # Re-use root_dir if an earlier cell defined it, and cd there.
  os.chdir(root_dir)
except (NameError, OSError):
  # root_dir is undefined (fresh kernel) or not a usable directory:
  # fall back to the current working directory.
  root_dir = os.getcwd()

# All relative paths below (ckpt/, assets/, debug/) are resolved from the repo root.
os.chdir(os.path.join(root_dir,'Segment-and-Track-Anything-CLI'))

#(c) Alex Spirin 2023

import hashlib
# We use input file hashes to automate video extraction
#
def generate_file_hash(input_file):
    """Return a SHA-256 hex digest that identifies *input_file* by metadata.

    The file contents are never read. The digest covers, in order, the base
    file name, the size in bytes and ``os.path.getctime`` of the file, each
    fed to the hasher as UTF-8 text. This keeps it cheap for large videos
    while still changing whenever any of those three values changes.
    """
    fingerprint = (
        os.path.basename(input_file),
        os.path.getsize(input_file),
        os.path.getctime(input_file),
    )
    digest = hashlib.sha256()
    for part in fingerprint:
        digest.update(str(part).encode('utf-8'))
    return digest.hexdigest()

def createPath(filepath):
    """Create the directory *filepath*, including missing parents; no-op if it already exists."""
    os.makedirs(name=filepath, exist_ok=True)


def extractFrames(video_path, output_path, nth_frame, start_frame, end_frame):
  """Extract frames of *video_path* into *output_path* as ``%06d.jpg`` using ffmpeg.

  Existing ``*.jpg`` files in *output_path* are deleted first, so frames from a
  previous extraction never mix with the new ones.

  Args:
    video_path: input video file.
    output_path: destination folder (created if missing).
    nth_frame: keep one out of every ``nth_frame`` frames of the selected range.
    start_frame, end_frame: inclusive source-frame range passed to ffmpeg's
      ``between(n, start, end)``.

  Raises:
    SystemExit: if *video_path* does not exist.
  """
  os.makedirs(output_path, exist_ok=True)
  print(f"Exporting Video Frames (1 every {nth_frame})...")
  # Delete stale frames one by one so a single failure doesn't abort the cleanup.
  for f in glob(os.path.join(output_path, '*.jpg')):
    try:
      pathlib.Path(f).unlink()
    except OSError:
      print('error deleting frame ', f)
  # First select keeps frames in [start_frame, end_frame]; the second keeps every nth of those.
  vf = f'select=between(n\\,{start_frame}\\,{end_frame}) , select=not(mod(n\\,{nth_frame}))'
  if not os.path.exists(video_path):
    sys.exit(f'\nERROR!\n\nVideo not found: {video_path}.\nPlease check your video path.\n')
  ffmpeg_args = ['-i', f'{video_path}', '-vf', f'{vf}', '-vsync', 'vfr', '-q:v', '2',
                 '-loglevel', 'error', '-stats', f'{output_path}/%06d.jpg']
  try:
    subprocess.run(['ffmpeg', *ffmpeg_args], stdout=subprocess.PIPE)
  except OSError:
    # `ffmpeg` is not on PATH (FileNotFoundError): fall back to an ffmpeg.exe
    # placed in root_dir.
    subprocess.run([f'{root_dir}/ffmpeg.exe', *ffmpeg_args], stdout=subprocess.PIPE)


class FrameDataset():
  """Resolve a frame source into a sorted list of image file paths.

  ``source_path`` may be:

  * a glob pattern (any string that is not an existing path),
  * a single image file, used as a one-frame dataset,
  * any other file, treated as a video: frames are extracted with
    ``extractFrames`` into ``videoframes_root/<outdir_prefix>_<hash10>``,
  * a directory of image files.

  Every frame must carry one of ``image_extenstions`` (compared
  case-insensitively), and all frames must share a single extension.
  Indexing past the end returns the last frame.
  """
  def __init__(self, source_path, outdir_prefix, videoframes_root):
    self.frame_paths = None
    image_extenstions = ['jpeg', 'jpg', 'png', 'tiff', 'bmp', 'webp']

    if not os.path.exists(source_path):
      # Not an existing path: interpret it as a glob pattern.
      matches = glob(source_path)
      if len(matches) == 0:
        raise Exception(f'Frame source for {outdir_prefix} not found at {source_path}\nPlease specify an existing source path.')
      self.frame_paths = matches
    elif os.path.isfile(source_path):
      if os.path.splitext(source_path)[1][1:].lower() in image_extenstions:
        # A still image is used as-is; it must not go through ffmpeg extraction.
        self.frame_paths = [source_path]
      else:
        # Video: extract into a folder keyed by the file's metadata hash.
        file_hash = generate_file_hash(source_path)[:10]
        out_path = os.path.join(videoframes_root, outdir_prefix+'_'+file_hash)
        extractFrames(source_path, out_path,
                      nth_frame=1, start_frame=0, end_frame=999999999)
        self.frame_paths = glob(os.path.join(out_path, '*.*'))
        if len(self.frame_paths)<1:
          raise Exception(f'Couldn`t extract frames from {source_path}\nPlease specify an existing source path.')
    elif os.path.isdir(source_path):
      self.frame_paths = glob(os.path.join(source_path, '*.*'))
      if len(self.frame_paths)<1:
        raise Exception(f'Found 0 frames in {source_path}\nPlease specify an existing source path.')

    if self.frame_paths is None:
      # Exists but is neither a regular file nor a directory.
      raise Exception(f'Frame source for {outdir_prefix} not found at {source_path}\nPlease specify an existing source path.')

    extensions = []
    for f in self.frame_paths:
      ext = os.path.splitext(f)[1][1:]
      # Case-insensitive, consistent with the single-image check above
      # (so e.g. "IMG_0001.JPG" is accepted).
      ext_key = ext.lower()
      if ext_key not in image_extenstions:
        raise Exception(f'Found non-image file extension: {ext} in {source_path}. Please provide a folder with image files of the same extension, or specify a glob pattern.')
      if ext_key not in extensions:
        extensions.append(ext_key)
      if len(extensions)>1:
        raise Exception(f'Found multiple file extensions: {extensions} in {source_path}. Please provide a folder with image files of the same extension, or specify a glob pattern.')

    self.frame_paths = sorted(self.frame_paths)
    print(f'Found {len(self.frame_paths)} frames at {source_path}')

  def __getitem__(self, idx):
    """Return the path of frame *idx*; indices past the end clamp to the last frame."""
    idx = min(idx, len(self.frame_paths)-1)
    return self.frame_paths[idx]

  def __len__(self):
    """Return the number of frames."""
    return len(self.frame_paths)

# mostly taken from https://github.com/z-x-yang/Segment-and-Track-Anything/blob/main/demo_instseg.ipynb

import os
import cv2
from SegTracker import SegTracker
from model_args import aot_args,sam_args,segtracker_args
from PIL import Image
from aot_tracker import _palette
import numpy as np
import torch
import imageio
import matplotlib.pyplot as plt
from scipy.ndimage import binary_dilation

import gc
def save_prediction(pred_mask,output_dir,file_name):
    """Write *pred_mask* to ``os.path.join(output_dir, file_name)`` as a palette-mode image.

    The mask is cast to uint8, converted to mode 'P', and given the
    ``_palette`` color table. PIL picks the file format from *file_name*'s
    extension. Assumes *pred_mask* is a 2-D integer id map (TODO confirm
    for all callers).
    """
    indexed = Image.fromarray(pred_mask.astype(np.uint8)).convert(mode='P')
    indexed.putpalette(_palette)
    target = os.path.join(output_dir, file_name)
    indexed.save(target)
def colorize_mask(pred_mask):
    """Render *pred_mask* through the ``_palette`` colors and return it as an RGB numpy array."""
    indexed = Image.fromarray(pred_mask.astype(np.uint8)).convert(mode='P')
    indexed.putpalette(_palette)
    rgb = indexed.convert(mode='RGB')
    return np.array(rgb)
def draw_mask(img, mask, alpha=0.7, id_countour=False):
    """Blend object colors from *mask* over *img* and outline the objects in black.

    Args:
        img: HxWx3 image array (RGB in this notebook). It is NOT modified.
        mask: HxW array of object ids; 0 is background.
        alpha: weight of the overlay color (0 = image only, 1 = color only).
        id_countour: if True, color and outline each object id separately
            (slow, roughly 1 s per image); otherwise color all objects in one
            pass via ``colorize_mask`` and outline their combined foreground.

    Returns:
        A new array with the same dtype as *img*.
    """
    # Draw on a copy: assigning `img_mask = img` aliased the input, so the
    # caller's frame was modified in place.
    img_mask = img.copy()
    if id_countour:
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[obj_ids!=0]

        for obj_id in obj_ids:
            # Ids beyond the palette range are drawn black.
            if obj_id <= 255:
                color = _palette[obj_id*3:obj_id*3+3]
            else:
                color = [0,0,0]
            binary_mask = (mask == obj_id)
            # Blend only this object's pixels (same values as blending the
            # whole frame and then selecting them).
            img_mask[binary_mask] = img[binary_mask] * (1-alpha) + alpha * np.array(color)

            countours = binary_dilation(binary_mask,iterations=1) ^ binary_mask
            img_mask[countours, :] = 0
    else:
        binary_mask = (mask!=0)
        countours = binary_dilation(binary_mask,iterations=1) ^ binary_mask
        foreground = img*(1-alpha)+colorize_mask(mask)*alpha
        img_mask[binary_mask] = foreground[binary_mask]
        img_mask[countours,:] = 0

    return img_mask.astype(img.dtype)

video_path = '/content/chess girl.mov' #@param {'type':'string'}
video_name = video_path.replace('\\','/').split('/')[-1]
io_args = {
    'input_video': video_path,
    'output_mask_dir': f'./assets/{video_name}_masks', # save pred masks
    'output_video': f'./assets/{video_name}_seg.mp4', # mask+frame vizualization, mp4 or avi, else the same as input video
    'output_gif': f'./assets/{video_name}_seg.gif', # mask visualization
}
prefix = ''
try:
  # batchFolder may be provided by a host notebook (it is not defined in this
  # file); otherwise extracted frames go under root_dir.
  videoframes_root = f'{batchFolder}/videoFrames'
except NameError:
  videoframes_root = f'{root_dir}/videoFrames'

frames = FrameDataset(video_path, outdir_prefix=prefix, videoframes_root=videoframes_root)

# choose good parameters in sam_args based on the first frame segmentation result
# other arguments can be modified in model_args.py
# note the object number limit is 255 by default, which requires < 10GB GPU memory with amp
sam_args['generator_args'] = {
        'points_per_side': 60,
        'pred_iou_thresh': 0.8,
        'stability_score_thresh': 0.9,
        'crop_n_layers': 1,
        'crop_n_points_downscale_factor': 2,
        'min_mask_region_area': 200,
    }

# Text args:
#   grounding_caption: text prompt to detect objects in key-frames
#   box_threshold: threshold for box
#   text_threshold: threshold for label (text)
#   box_size_threshold: boxes whose size ratio to the frame exceeds this are
#       ignored (filters out large boxes)
#   reset_image: reset the image embeddings for SAM
frame_number = 0  #@param {'type':'number'}
frame_number = int(frame_number)
#@markdown Text prompt to detect objects in key-frames
grounding_caption = "person" #@param {'type':'string'}
#@markdown Box detection confidence threshold
box_threshold = 0.3 #@param {'type':'number'}
#@markdown Text confidence threshold
text_threshold = 0.3 #@param {'type':'number'}
#@markdown Box to Image ratio threshold (with box_size_threshold = 0.8 detections over 80% of the image will be ignored)
box_size_threshold = 1 #@param {'type':'number'}

reset_image = True

# Index of the previewed frame (FrameDataset clamps indices past the end to the last frame).
frame_idx = frame_number
segtracker = SegTracker(segtracker_args,sam_args,aot_args)
segtracker.restart_tracker()

try:
    with torch.cuda.amp.autocast():
        frame = cv2.imread(frames[frame_number])
        if frame is None:
            # cv2.imread signals failure by returning None rather than raising.
            raise FileNotFoundError(f'Could not read frame image: {frames[frame_number]}')
        frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
        pred_mask, annotated_frame = segtracker.detect_and_seg(frame, grounding_caption, box_threshold, text_threshold,
                                                               box_size_threshold, reset_image=reset_image)
        torch.cuda.empty_cache()
        obj_ids = np.unique(pred_mask)
        obj_ids = obj_ids[obj_ids!=0]
        print("processed frame {}, obj_num {}".format(frame_idx,len(obj_ids)),end='\n')
        init_res = draw_mask(annotated_frame, pred_mask,id_countour=False)
        plt.figure(figsize=(10,10))
        plt.axis('off')
        plt.imshow(init_res)
        plt.show()
        plt.figure(figsize=(10,10))
        plt.axis('off')
        plt.imshow(colorize_mask(pred_mask))
        plt.show()
finally:
    # Free the preview tracker's GPU memory even if detection failed, before
    # the whole-video cell builds its own tracker.
    del segtracker
    torch.cuda.empty_cache()
    gc.collect()

In [2]:
#@title Mask whole video.
use_cli = False #@param {'type':'boolean'}
import subprocess
import sys
start_frame = 0 #@param {'type':'number'}
end_frame = 50 #@param {'type':'number'}
#@markdown The interval to run SAM to segment new objects
sam_gap = 50 #@param {'type':'number'}
#@markdown minimal mask area to add a new mask as a new object
min_area = 200  #@param {'type':'number'}
#@markdown maximal object number to track in a video
max_obj_num = 255 #@param {'type':'number'}
#@markdown the area of a new object in the background should > 80%
min_new_obj_iou = 0.8 #@param {'type':'number'}
save_separate_masks = True
save_joint_mask = False #@param {'type':'boolean'}
save_mask = save_joint_mask
save_video = False #@param {'type':'boolean'}
save_gif = False #@param {'type':'boolean'}
# Also uses names from the detection cell: grounding_caption, box_threshold,
# text_threshold, box_size_threshold, video_path, frames, io_args, videoframes_root.
output_multimask_dir = os.path.join(videoframes_root, f'{generate_file_hash(video_path)[:10]}_masks')
if use_cli:
  def run_command(cmd, cwd='./'):
    """Run *cmd* in *cwd*, echo its merged stdout/stderr live, and return its exit code."""
    with subprocess.Popen(cmd, cwd=cwd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
                          bufsize=1, universal_newlines=True) as p:
      for line in p.stdout:
        print(line, end='')
    # Popen.__exit__ waits for the child, so returncode is populated here.
    return p.returncode

  # Use this kernel's interpreter so run.py sees the same installed packages.
  cmd = [sys.executable, 'run.py','--video_path', video_path, '--save_separate_masks', '--outdir', output_multimask_dir,
        '--caption', grounding_caption, '--box_threshold', box_threshold, '--text_threshold', text_threshold, '--box_size_threshold', box_size_threshold,
        '--sam_gap', sam_gap, '--min_area', min_area, '--max_obj_num', max_obj_num, '--min_new_obj_iou',min_new_obj_iou]
  cmd = [str(o) for o in cmd]
  returncode = run_command(cmd, cwd=os.path.join(root_dir,'Segment-and-Track-Anything-CLI'))
  if returncode != 0:
    raise RuntimeError(returncode)
  print(f"The video is ready and saved to {output_multimask_dir}")
else:
  from multiprocessing.pool import ThreadPool as Pool
  from functools import partial
  from tqdm.notebook import tqdm, trange

  os.makedirs('./debug/seg_result', exist_ok=True)
  os.makedirs('./debug/aot_result', exist_ok=True)
  segtracker_args = {
    'sam_gap': sam_gap,
    'min_area': min_area,
    'max_obj_num': max_obj_num,
    'min_new_obj_iou': min_new_obj_iou
  }

  if save_mask:
    output_dir = io_args['output_mask_dir']
    os.makedirs(output_dir, exist_ok=True)
  pred_list = []

  segtracker = SegTracker(segtracker_args, sam_args, aot_args)
  segtracker.restart_tracker()
  if start_frame == 0 and end_frame == 0:
    frame_range = trange(len(frames))
  else:
    # FrameDataset clamps out-of-range indices to its last frame, so an
    # end_frame past the end would otherwise keep re-tracking that frame.
    frame_range = trange(start_frame, min(end_frame, len(frames)-1)+1)
  for frame_idx in frame_range:
    frame = cv2.imread(frames[frame_idx])
    if frame is None:
      raise FileNotFoundError(f'Could not read frame image: {frames[frame_idx]}')
    frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
    if frame_idx == start_frame:
      # First frame: segment from the text prompt and seed the tracker.
      pred_mask, _ = segtracker.detect_and_seg(frame, grounding_caption, box_threshold, text_threshold, box_size_threshold, reset_image)
      torch.cuda.empty_cache()
      gc.collect()
      segtracker.add_reference(frame, pred_mask)
    elif ((frame_idx-start_frame) % sam_gap) == 0:
      # Every sam_gap frames: re-detect, then add objects the tracker doesn't cover yet.
      seg_mask, _ = segtracker.detect_and_seg(frame, grounding_caption, box_threshold, text_threshold,
                                              box_size_threshold, reset_image)
      torch.cuda.empty_cache()
      gc.collect()
      track_mask = segtracker.track(frame)

      new_obj_mask = segtracker.find_new_objs(track_mask, seg_mask)
      # Discard the "new objects" when they cover more than 40% of the frame.
      if np.sum(new_obj_mask > 0) >  frame.shape[0] * frame.shape[1] * 0.4:
        new_obj_mask = np.zeros_like(new_obj_mask)
      if save_mask: save_prediction(new_obj_mask,output_dir,str(frame_idx)+'_new.png')
      pred_mask = track_mask + new_obj_mask
      segtracker.add_reference(frame, pred_mask)
    else:
      pred_mask = segtracker.track(frame,update_memory=True)
    torch.cuda.empty_cache()
    gc.collect()

    if save_mask: save_prediction(pred_mask,output_dir,str(frame_idx)+'.png')

    pred_list.append(pred_mask)

    print("processed frame {}, obj_num {}".format(frame_idx,segtracker.get_obj_num()),end='\r')

  if save_video or save_gif:
    cap = cv2.VideoCapture(io_args['input_video'])
    # Both the overlay video and the gif use the source frame rate.
    fps = cap.get(cv2.CAP_PROP_FPS)
    try:
      if save_video:
        # draw pred mask on frame and save as a video
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))

        if io_args['input_video'][-3:]=='mp4':
          fourcc =  cv2.VideoWriter_fourcc(*"mp4v")
        elif io_args['input_video'][-3:] == 'avi':
          fourcc =  cv2.VideoWriter_fourcc(*"MJPG")
        else:
          fourcc = int(cap.get(cv2.CAP_PROP_FOURCC))
        out = cv2.VideoWriter(io_args['output_video'], fourcc, fps, (width, height))

        # pred_list[0] belongs to source frame start_frame: skip the frames before it.
        for _ in range(start_frame):
          if not cap.read()[0]:
            break

        progress_bar = tqdm(total=len(pred_list))
        progress_bar.set_description("Processing frames...")
        for frame_idx, pred_mask in enumerate(pred_list):
          ret, frame = cap.read()
          if not ret:
            break
          frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
          masked_frame = draw_mask(frame,pred_mask)
          masked_frame = cv2.cvtColor(masked_frame,cv2.COLOR_RGB2BGR)
          out.write(masked_frame)
          print('frame {} written'.format(frame_idx),end='\r')
          progress_bar.update(1)
        progress_bar.close()
        out.release()
        print("\n{} saved".format(io_args['output_video']))
        print('\nfinished')

      if save_gif:
        # Save colorized masks as a gif; raw id maps would render as near-black grayscale.
        imageio.mimsave(io_args['output_gif'], [colorize_mask(m) for m in pred_list], fps=fps)
        print("{} saved".format(io_args['output_gif']))
    finally:
      cap.release()

  threads = 12

  def write_masks_frame(frame_num,  predicted_masks, output_folder, max_ids=255):
    """Save one binary (0/255) JPEG per object id 0..max_ids for pred_list entry *frame_num*.

    Id 0 is the background, so mask000 holds the background alpha. File names use
    the index into predicted_masks, not the source frame number.
    """
    predicted_masks_frame = predicted_masks[frame_num]
    for i in range(max_ids+1):
      img_out = Image.fromarray(((predicted_masks_frame==i)*255).astype('uint8'))
      img_out.save(os.path.join(output_folder, f'mask{i:03}', f'alpha_{frame_num:06}.jpg'))

  def write_masks_frame_multi(predicted_masks, output_folder, max_ids):
    """Create one mask<id> folder per id and write all frames' masks with a thread pool."""
    for i in range(max_ids+1):
      os.makedirs(os.path.join(output_folder, f'mask{i:03}'), exist_ok=True)

    with Pool(threads) as p:
      fn = partial(write_masks_frame, predicted_masks=predicted_masks, output_folder=output_folder, max_ids=max_ids)
      # Consume the iterator so every frame is written before the pool closes.
      for _ in tqdm(p.imap(fn, range(len(predicted_masks))), total=len(predicted_masks)):
        pass

  if save_separate_masks:
    print('Saving Separate masks')
    write_masks_frame_multi(predicted_masks=pred_list, output_folder=output_multimask_dir, max_ids=segtracker.get_obj_num())
    print(f'Saved masks to {output_multimask_dir}')



final text_encoder_type: bert-base-uncased
Model loaded from ./ckpt/groundingdino_swint_ogc.pth 
 => _IncompatibleKeys(missing_keys=[], unexpected_keys=['label_enc.weight', 'bert.embeddings.position_ids'])
SegTracker has been initialized


  0%|          | 0/51 [00:00<?, ?it/s]

Saving Separate masks


  0%|          | 0/51 [00:00<?, ?it/s]

Saved masks to /content/videoFrames/4a2baa1a78_masks
