In [None]:
!pip install git+https://github.com/facebookresearch/segment-anything.git >> None

In [None]:
!pip install opencv-python pycocotools matplotlib onnxruntime onnx >> None

In [174]:
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
import seaborn as sns
from tqdm import tqdm
from typing import List


In [4]:
import os
# Directory the notebook was started from; the weights folder and the
# checkpoint path in the next cell are built relative to it.
HOME = os.getcwd()


In [5]:
%cd {HOME}
!mkdir {HOME}/weights
%cd {HOME}/weights

!wget -q https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

CHECKPOINT_PATH = os.path.join(HOME, "weights", "sam_vit_h_4b8939.pth")

/kaggle/working
/kaggle/working/weights


In [178]:
# Path to the input video -- replace with YOUR VIDEO.
# (The value used to sit behind the comment marker, which left the assignment
# without a right-hand side and made the cell a SyntaxError.)
VIDEO_PATH = "/kaggle/input/sam-da/diff/traffic_(720p).mp4"

In [7]:
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')
DEVICE

device(type='cuda', index=0)

In [188]:
class SamVideoMasker:
    """Overlay Segment Anything (SAM) automatic masks on every frame of a video.

    The public entry point is :meth:`create_video_with_mask`; the underscore
    methods are the individual pipeline stages (read, segment, blend).

    Attributes
    ----------
    mask_generator: SamAutomaticMaskGenerator
        generator built around the loaded ViT-H SAM model
    frames_masks_: List[List[dict]]
        masks of the most recently processed video, one list per frame
    """

    def __init__(
        self,
        checkpoint_path: str = CHECKPOINT_PATH,
        device: torch.device = DEVICE,
        *args, **kwargs
    ):
        """Init SamAutomaticMaskGenerator model, load checkpoint, and put it on device

        Parameters
        ----------
        checkpoint_path: str
            path to model checkpoint
        device: torch.device
            device
        args, kwargs:
            model parameters for tuning, forwarded to SamAutomaticMaskGenerator
        """
        sam = sam_model_registry["vit_h"](checkpoint=checkpoint_path)
        sam.to(device=device)

        # The model is passed positionally on purpose: `model=sam` combined with
        # *args raised "got multiple values for argument 'model'" as soon as any
        # positional tuning argument was supplied.
        self.mask_generator = SamAutomaticMaskGenerator(sam, *args, **kwargs)
        self.frames_masks_ = []

    @staticmethod
    def _read_video(video_path: str):
        """Capture video and read it frame by frame

        Parameters
        ----------
        video_path : str
            path to video

        Returns
        -------
        frames: List[np.ndarray]
            list of video frames exactly as decoded by OpenCV (BGR channel order)

        Raises
        ------
        Exception
            if the video cannot be opened
        """
        cap = cv2.VideoCapture(video_path)
        try:
            if not cap.isOpened():
                raise Exception("Error opening video stream or file")

            frames = []
            while cap.isOpened():
                ret, frame = cap.read()
                if not ret:
                    # End of stream (or a decoding failure): stop reading.
                    break
                frames.append(frame)
        finally:
            # Release the capture handle even when opening/reading fails.
            cap.release()

        return frames

    def _generate_image_masks(self, image: np.ndarray):
        """Generate masks for input image

        Parameters
        ----------
        image : np.ndarray
            input image; it is handed to the mask generator unchanged, so it
            should already be in the RGB channel order SAM works with

        Returns
        -------
        image_masks: List[dict]
            List over masks, where each mask is a dictionary containing various data about the mask.
            These keys are:
                segmentation : the mask
                area : the area of the mask in pixels
                bbox : the boundary box of the mask in XYWH format
                predicted_iou : the model's own prediction for the quality of the mask
                point_coords : the sampled input point that generated this mask
                stability_score : an additional measure of mask quality
                crop_box : the crop of the image used to generate this mask in XYWH format
        """
        return self.mask_generator.generate(image)

    def _generate_video_frames_masks(self, frames: List[np.ndarray], show_progress: bool=True):
        """Generate masks for every frame of a video

        Parameters
        ----------
        frames : List[np.ndarray]
            list of video frames in OpenCV's BGR channel order
            (as returned by ``_read_video``)
        show_progress: bool
            show progress bar, by default = True

        Returns
        -------
        frames_masks_: List[List[dict]]
            List with frames masks (also stored on ``self.frames_masks_``)
        """
        self.frames_masks_ = []
        for frame in tqdm(frames, desc='masks generation', disable=not show_progress):
            # OpenCV decodes to BGR while SAM treats its input as RGB; without
            # the conversion the model sees red and blue swapped.
            rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            self.frames_masks_.append(self._generate_image_masks(rgb_frame))

        return self.frames_masks_

    @staticmethod
    def _get_masked_image(
        image: np.ndarray,
        image_masks: List[dict],
        alpha: float,
        beta: float,
        gamma: float,
        palette_size: int,
    ):
        """Blending original image and its masks

        Parameters
        ----------
        image: np.ndarray
            original frame
        image_masks: List[dict]
            masks of the frame; the 'segmentation' and 'area' keys are used
        alpha: float
            weight of the original image in the blend
        beta: float
            weight of the coloured mask layer in the blend
        gamma: float
            scalar added to the weighted sum
        palette_size: int
            number of colours in the HUSL palette

        Returns
        -------
        masked_image: np.ndarray
            masked image
        """
        palette = tuple(sns.husl_palette(palette_size))
        pixel_count = image.shape[0] * image.shape[1]

        # Start from a copy of the frame so that pixels not covered by any mask
        # keep their original value. (Back-filling with `layer == 0` compared
        # per channel and therefore also overwrote legitimate zero-valued
        # colour channels inside masks.)
        overlay = image.copy()

        # Paint the largest masks first so smaller ones stay visible on top.
        for mask in sorted(image_masks, key=lambda item: item['area'], reverse=True):
            area = mask['area']
            if area <= 0:
                # Nothing to paint, and log(pixel_count / 0) is not a valid index.
                continue
            # The colour depends on the mask's size relative to the frame
            # (log scale), wrapped around the palette.
            color_idx = int(np.log(pixel_count / area) * palette_size / 10) % palette_size
            overlay[mask['segmentation']] = np.uint8(np.array(palette[color_idx]) * 255)

        return cv2.addWeighted(image, alpha, overlay, beta, gamma)

    def _get_masked_frames(
        self,
        frames: List[np.ndarray],
        frames_masks: List[List[dict]],
        alpha: float,
        beta: float,
        gamma: float,
        palette_size: int,
        show_progress: bool=True
    ):
        """Blending every frame with its masks

        Parameters
        ----------
        frames: List[np.ndarray]
            original video frames
        frames_masks: List[List[dict]]
            masks of every frame, in the same order as ``frames``
        alpha, beta, gamma, palette_size:
            blending parameters, see ``_get_masked_image``
        show_progress: bool
            show progress bar, by default = True

        Returns
        -------
        masked_frames: List[np.ndarray]
            List of masked frames
        """
        progress = tqdm(
            zip(frames, frames_masks),
            total=len(frames),
            desc='masked frames processing',
            disable=not show_progress
        )
        return [
            self._get_masked_image(frame, frame_masks, alpha, beta, gamma, palette_size)
            for frame, frame_masks in progress
        ]

    def create_video_with_mask(
        self,
        orig_video_path: str,
        out_filename: str="masked_video.avi",
        fps: int=15,
        alpha: float=0.65,
        beta: float=0.35,
        gamma: float=0,
        palette_size: int=100,
        show_progress: bool=True,
    ):
        """Create masked video

        Parameters
        ----------
        orig_video_path: str
            path to original video
        out_filename: str
            name of the output video file
        fps: int
            framerate of the created video stream, by default=15
        alpha: float
            image blending coef [0.0-1.0], by default 0.65
        beta: float
            mask blending coef equal 1 - alpha, by default 0.35
        gamma: float
            scalar added to each blended pixel, by default 0
        palette_size: int
            palette size, by default 100
        show_progress: bool
            show progress bar, by default = True

        Raises
        ------
        ValueError
            if no frame could be read from the video
        Exception
            if the input video or the output file cannot be opened
        """
        frames = self._read_video(orig_video_path)
        if not frames:
            # Fail with a clear message instead of an IndexError on frames[0].
            raise ValueError("No frames could be read from {!r}".format(orig_video_path))

        frames_masks = self._generate_video_frames_masks(frames, show_progress=show_progress)
        masked_frames = self._get_masked_frames(
            frames, frames_masks, alpha, beta, gamma, palette_size, show_progress
        )

        fourcc = cv2.VideoWriter_fourcc(*'XVID')
        # VideoWriter expects (width, height); ndarray shape is (height, width, ...).
        height, width = frames[0].shape[:2]
        out = cv2.VideoWriter(out_filename, fourcc, fps, (width, height))
        try:
            if not out.isOpened():
                # An unopened writer drops every frame silently and leaves no file.
                raise Exception("Error opening output video file {!r}".format(out_filename))
            for frame in masked_frames:
                out.write(frame)
        finally:
            out.release()
        

In [189]:
%%time
# Build the masker with its defaults (CHECKPOINT_PATH, DEVICE): loads the
# ViT-H checkpoint and moves the model to the selected device.
sam_video_masker = SamVideoMasker()

CPU times: user 5.32 s, sys: 871 ms, total: 6.19 s
Wall time: 5.89 s


In [190]:
%%time
# Full pipeline on VIDEO_PATH: read frames, generate masks, blend, and write the
# result to the default out_filename "masked_video.avi" (a relative path, so it
# lands in the current working directory).
sam_video_masker.create_video_with_mask(VIDEO_PATH)

masks generation: 100%|██████████| 20/20 [03:25<00:00, 10.27s/it]
masked frames processing: 100%|██████████| 20/20 [00:06<00:00,  3.12it/s]

CPU times: user 3min 33s, sys: 162 ms, total: 3min 33s
Wall time: 3min 32s



