In [None]:
# Clone SpaTracker and install its Python dependencies.
# %pip (rather than !pip) installs into the environment of the running kernel.
%cd /content
!git clone https://github.com/henry123-boy/SpaTracker
%cd /content/SpaTracker
%pip install timm==0.6.7 flow_vis
%pip install -r requirements.txt

# Tracker checkpoint plus the demo clips and their segmentation masks.
# -p keeps this cell re-runnable: plain mkdir errors once the directory exists.
!mkdir -p /content/checkpoints
!gdown -O /content/checkpoints/spaT_final.pth 18YlG_rgrHcJ7lIYQWfRz_K669z6FdmUX
!gdown -O /content/SpaTracker/assets/butterfly.mp4 1BDtvfrvbzEFY84XJPp62Dq1PujIpbOK_
!gdown -O /content/SpaTracker/assets/butterfly.png 1hlAGFony7LzpLcEAoGLiNaY3zxfiN_bW
!gdown -O /content/SpaTracker/assets/sintel_bandage.mp4 1iL5Qs5ea8r9nFwgVC6fusFyBfFQDICyo
!gdown -O /content/SpaTracker/assets/sintel_bandage.png 1_cL3m_1bW6aFwhxRPr_vGkdizwJuObbH

# MiDaS / ZoeDepth weights. The run log further down shows ZoeDepth resolving
# ./models/monoD/zoeDepth/ckpts relative to the working directory (/content).
!mkdir -p /content/models/monoD/zoeDepth/ckpts
!wget -O /content/models/monoD/zoeDepth/ckpts/dpt_beit_large_384.pt https://github.com/isl-org/MiDaS/releases/download/v3_1/dpt_beit_large_384.pt
!wget -O /content/models/monoD/zoeDepth/ckpts/ZoeD_M12_K.pt https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_K.pt
!wget -O /content/models/monoD/zoeDepth/ckpts/ZoeD_M12_NK.pt https://github.com/isl-org/ZoeDepth/releases/download/v1.0/ZoeD_M12_NK.pt

%cd /content

In [None]:
# Clone Depth-Anything-V2 and install its Python dependencies.
%cd /content
!git clone https://github.com/DepthAnything/Depth-Anything-V2
%cd /content/Depth-Anything-V2
%pip install -r requirements.txt

# -p: /content/checkpoints may already exist (the SpaTracker setup cell creates
# it too), and plain mkdir reports an error in that case.
!mkdir -p /content/checkpoints
# The URL is quoted so the shell never treats the '?' as a glob character.
!wget -O /content/checkpoints/depth_anything_v2_vitl.pth "https://huggingface.co/depth-anything/Depth-Anything-V2-Large/resolve/main/depth_anything_v2_vitl.pth?download=true"

%cd /content

In [1]:
import os, sys
from easydict import EasyDict as edict
from base64 import b64encode
import importlib
from PIL import Image
import numpy as np
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
from moviepy.editor import ImageSequenceClip


%cd /content/SpaTracker/

from models.spatracker.predictor import SpaTrackerPredictor
from models.spatracker.utils.visualizer import Visualizer
from models.monoD.zoeDepth.models.builder import build_model
from models.monoD.zoeDepth.utils.config import get_config

%cd /content/Depth-Anything-V2

from depth_anything_v2.dpt import DepthAnythingV2

%cd /content

/content/SpaTracker




/content/Depth-Anything-V2
/content


In [None]:
def read_video(path):
    """Decode every frame of a video file into a single RGB array.

    Args:
        path: Path handed to ``cv2.VideoCapture``.

    Returns:
        ``np.ndarray`` with all decoded frames stacked along a new leading
        axis; each frame is converted from BGR to RGB.

    Raises:
        IOError: If the capture cannot be opened.
        ValueError: If the capture opens but yields no frames.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        if not cap.isOpened():
            raise IOError("Error opening video file")

        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
    finally:
        # Always give the decoder handle back, even when decoding fails midway.
        cap.release()

    if not frames:
        # np.stack([]) would raise a ValueError that does not mention the file.
        raise ValueError(f"No frames could be read from video file: {path}")
    return np.stack(frames)

def load_video(vid_path, downsample, fps):
    """Load a video as a float tensor, rescale it, and subsample its frames.

    Args:
        vid_path: Path forwarded to ``read_video``.
        downsample: ``scale_factor`` given to bilinear ``F.interpolate``.
        fps: Step between kept frame indices. Despite its name, it is used as
            the stride of ``torch.arange`` over the frame axis.

    Returns:
        Tensor with a leading batch axis of size 1, followed by the kept
        frames in channel-first layout.
    """
    # Frames arrive channel-last from read_video; move channels ahead of H, W.
    frames = torch.from_numpy(read_video(vid_path)).permute(0, 3, 1, 2).float()
    resized = F.interpolate(frames, scale_factor=downsample, mode='bilinear', align_corners=True)
    batched = resized.unsqueeze(0)
    keep = torch.arange(0, batched.shape[1], fps).long()
    return batched[:, keep]

def load_segmentation_mask(seg_path, H, W):
    """Load a segmentation mask, or fall back to an all-ones mask.

    Args:
        seg_path: Path of the mask image. If it does not exist, a mask of ones
            is used and a notice is printed.
        H: Target height of the returned mask.
        W: Target width of the returned mask.

    Returns:
        The mask resized to ``(H, W)`` with nearest-neighbour interpolation.
        Multi-channel masks are reduced to a 0/1 ``uint8`` map that is 1 where
        the mean of the first three channels is positive. Single-channel
        masks keep their values.
    """
    if os.path.exists(seg_path):
        segm_mask = np.array(Image.open(seg_path))
    else:
        segm_mask = np.ones((H, W), dtype=np.uint8)
        print("No segmentation mask provided. Computing tracks in whole image.")

    if segm_mask.ndim == 3:
        segm_mask = (segm_mask[..., :3].mean(axis=-1) > 0).astype(np.uint8)
    elif segm_mask.dtype == np.bool_:
        # 1-bit images decode to a boolean array, a dtype cv2.resize does not
        # accept; promote it to the same 0/1 uint8 form used above.
        segm_mask = segm_mask.astype(np.uint8)

    return cv2.resize(segm_mask, (W, H), interpolation=cv2.INTER_NEAREST)

In [None]:
def visualize_results(video, pred_tracks, pred_visibility, outdir, vid_name, fps_vis, len_track, point_size):
    """Render the predicted tracks with a ``Visualizer``.

    Args:
        video: Video passed to ``Visualizer.visualize``.
        pred_tracks: Predicted tracks; only the first two entries of the last
            axis are handed to the visualizer.
        pred_visibility: Visibility flags forwarded unchanged.
        outdir: Used as the visualizer's ``save_dir``.
        vid_name: Prefix of the output name ``"<vid_name>_spatracker"``.
        fps_vis: Used as the visualizer's ``fps``.
        len_track: Used as ``tracks_leave_trace``.
        point_size: Used as ``linewidth``.

    Returns:
        Whatever ``Visualizer.visualize`` returns.
    """
    renderer_options = dict(
        save_dir=outdir,
        grayscale=True,
        fps=fps_vis,
        pad_value=0,
        linewidth=point_size,
        tracks_leave_trace=len_track,
    )
    renderer = Visualizer(**renderer_options)

    xy_tracks = pred_tracks[..., :2]
    output_name = f"{vid_name}_spatracker"
    return renderer.visualize(
        video=video,
        tracks=xy_tracks,
        visibility=pred_visibility,
        filename=output_name,
    )

In [None]:
def save_results(video, video_vis, pred_tracks, outdir, vid_name, model_name):
    """Write the reference frame, the raw tracks, and the input clip to disk.

    Files written into ``outdir``:
        * ``<vid_name>_ref_query.png``: first frame of ``video_vis``.
        * ``<vid_name>_<model_name>_tracks.npy``: ``pred_tracks[0]`` as NumPy.
        * ``<vid_name>_vid.mp4``: the frames of ``video`` encoded with libx264.

    Args:
        video: Video tensor; it is split along axis 1 and the first batch
            element of every slice becomes one frame.
        video_vis: Rendered video; only element ``[0, 0]`` is saved.
        pred_tracks: Predicted tracks; only the first batch element is saved.
        outdir: Destination directory.
        vid_name: Base name for all outputs.
        model_name: Tag embedded in the tracks file name.
    """
    # Channel order is reversed for cv2.imwrite.
    first_vis_frame = video_vis[0, 0].permute(1, 2, 0).detach().cpu().numpy()
    ref_query_path = os.path.join(outdir, f'{vid_name}_ref_query.png')
    cv2.imwrite(ref_query_path, first_vis_frame[:, :, ::-1])

    tracks_path = os.path.join(outdir, f'{vid_name}_{model_name}_tracks.npy')
    np.save(tracks_path, pred_tracks[0].detach().cpu().numpy())

    frames = []
    for frame in video.unbind(1):
        frames.append(frame[0].permute(1, 2, 0).cpu().numpy())

    # Note: the clip is assembled at 60 fps but written out at 25 fps.
    clip = ImageSequenceClip(frames, fps=60)
    save_path = os.path.join(outdir, f'{vid_name}_vid.mp4')
    clip.write_videofile(save_path, codec="libx264", fps=25, logger=None)
    print(f"Original Video saved to {save_path}")

def save_3d_trajectories(pred_tracks, video, outdir, vid_name):
    """Save coloured 3D track points to ``<outdir>/<vid_name>_3d.npy``.

    Each track point ``(x, y, depth)`` is multiplied by the inverse of a
    synthetic intrinsics matrix (focal length ``W``, principal point
    ``(W // 2, H // 2)``) with its third component set to 1, after which the
    third component is scaled by the depth. A colour is sampled from
    ``video`` at every 2D track position and appended.

    Args:
        pred_tracks: Tensor whose first batch element has shape ``(T, N, 3)``.
        video: Tensor whose first batch element is a ``(T, C, H, W)`` stack.
        outdir: Destination directory.
        vid_name: Base name of the output file.
    """
    tracks = pred_tracks[0]
    num_frames, _num_points, _ = tracks.shape
    height, width = video[0].shape[-2:]
    tracks_np = tracks.cpu().numpy()

    intrinsics = np.array([[width, 0.0, width // 2],
                           [0.0, width, height // 2],
                           [0.0, 0.0, 1.0]])

    # Homogeneous pixel coordinates -> rays, then restore depth on the z axis.
    homogeneous = tracks_np.copy()
    homogeneous[..., 2] = 1.0
    rays = np.linalg.inv(intrinsics[None, ...]) @ homogeneous.reshape(-1, 3, 1)
    points = rays.reshape(num_frames, -1, 3)
    points[..., 2] *= tracks_np[..., 2]

    # Map pixel coordinates to the [-1, 1] range used by grid_sample.
    sample_grid = tracks[:, :, :2].clone()
    sample_grid[..., 0] = 2*(sample_grid[..., 0] / width - 0.5)
    sample_grid[..., 1] = 2*(sample_grid[..., 1] / height - 0.5)

    sampled = torch.nn.functional.grid_sample(video[0], sample_grid[:, :, None, :], align_corners=True)
    colors = sampled[:, :, :, 0].permute(0, 2, 1).cpu().numpy().astype(np.uint8)

    np.save(f'{outdir}/{vid_name}_3d.npy', np.concatenate([points, colors], axis=-1))
    print(f"3D colored tracks saved to {outdir}/{vid_name}_3d.npy")

In [None]:
def setup_environment(root, vid_name, outdir):
    """Create the output directory and derive the input file paths.

    Args:
        root: Directory containing the input assets.
        vid_name: Base name shared by the video and its mask.
        outdir: Output directory; created if missing.

    Returns:
        Tuple ``(vid_path, seg_path)`` pointing at ``<root>/<vid_name>.mp4``
        and ``<root>/<vid_name>.png``.
    """
    os.makedirs(outdir, exist_ok=True)
    vid_path, seg_path = (
        os.path.join(root, f'{vid_name}{extension}') for extension in ('.mp4', '.png')
    )
    return vid_path, seg_path

def load_and_process_data(vid_path, seg_path, downsample, fps):
    """Load the video and a segmentation mask matching its frame size.

    Args:
        vid_path: Video path forwarded to ``load_video``.
        seg_path: Mask path forwarded to ``load_segmentation_mask``.
        downsample: Forwarded to ``load_video``.
        fps: Forwarded to ``load_video``.

    Returns:
        Tuple ``(video, segm_mask)``. The video must be 5-dimensional; its
        last two dimensions are used as the mask height and width.
    """
    video = load_video(vid_path, downsample, fps)
    # Unpacking all five dimensions also rejects tensors of any other rank.
    _batch, _frames, _channels, height, width = video.shape
    return video, load_segmentation_mask(seg_path, height, width)

def save_initial_images(video, segm_mask, outdir, vid_name):
    """Save the first video frame and the segmentation mask as PNG files.

    Writes ``<vid_name>_ref.png`` (first frame, channel order reversed for
    ``cv2.imwrite``) and ``<vid_name>_seg.png`` (mask drawn as 0/255).

    Args:
        video: Video tensor; element ``[0, 0]`` is a channel-first frame.
        segm_mask: NumPy mask; every non-zero entry is drawn as 255.
        outdir: Destination directory.
        vid_name: Base name of the written files.
    """
    img0 = video[0, 0].permute(1, 2, 0).detach().cpu().numpy()
    cv2.imwrite(os.path.join(outdir, f'{vid_name}_ref.png'), img0[:, :, ::-1])

    # Binarise before scaling: multiplying a uint8 mask that already holds
    # values above 1 (e.g. a 0/255 PNG) by 255 wraps around (255 * 255 -> 1).
    mask_image = (segm_mask > 0).astype(np.uint8) * 255
    cv2.imwrite(os.path.join(outdir, f'{vid_name}_seg.png'), mask_image)

def initialize_models(checkpoint_path, interp_shape, seq_length, device):
    """Build the SpaTracker predictor and the monocular depth estimator.

    Args:
        checkpoint_path: SpaTracker checkpoint given to ``SpaTrackerPredictor``.
        interp_shape: ``interp_shape`` given to ``SpaTrackerPredictor``.
        seq_length: ``seq_length`` given to ``SpaTrackerPredictor``.
        device: Device the predictor is moved to.

    Returns:
        Tuple ``(spatracker_predictor, monodepth_model)`` where
        ``monodepth_model`` is a ``MonoDEst`` instance in eval mode.
    """
    spatracker_predictor = SpaTrackerPredictor(
        checkpoint=checkpoint_path,
        interp_shape=interp_shape,
        seq_length=seq_length
    ).to(device)

    # Return the MonoDEst instance itself rather than its `.model` attribute.
    # `.model` is the bare DepthAnythingV2Wrapper, so using it as the depth
    # predictor would skip MonoDEst.infer entirely: the scale/shift fit against
    # the ZoeDepth prediction, the inversion and the clamping would never run,
    # and the ZoeDepth network built in MonoDEst.__init__ would be dead weight.
    monodepth_model = MonoDEst()
    monodepth_model.eval()
    return spatracker_predictor, monodepth_model

def process_video(spatracker_predictor, monodepth_model, video, segm_mask, query_frame, seq_length, grid_size=40):
    """Run the tracker and keep only the tracks that start at ``query_frame``.

    Args:
        spatracker_predictor: Callable returning
            ``(pred_tracks, pred_visibility, T_Firsts)``.
        monodepth_model: Passed to the predictor as ``depth_predictor``.
        video: Video tensor given to the predictor.
        segm_mask: NumPy mask; converted to a tensor with two leading
            singleton axes before being handed to the predictor.
        query_frame: Passed as ``grid_query_frame`` and used to select tracks.
        seq_length: Passed to the predictor as ``wind_length``.
        grid_size: Passed to the predictor as ``grid_size`` (default 40, the
            value that used to be hard-coded).

    Returns:
        Tuple ``(pred_tracks, pred_visibility)`` restricted along axis 2 to
        the tracks whose ``T_Firsts`` entry equals ``query_frame``.
    """
    pred_tracks, pred_visibility, T_Firsts = spatracker_predictor(
        video, video_depth=None, grid_size=grid_size, backward_tracking=False,
        depth_predictor=monodepth_model, grid_query_frame=query_frame,
        segm_mask=torch.from_numpy(segm_mask).unsqueeze(0).unsqueeze(0),
        wind_length=seq_length
    )

    # Flatten instead of squeeze: with a single track squeeze() collapses the
    # mask to 0-d, and indexing with a 0-d boolean inserts an extra axis
    # instead of selecting tracks.
    msk_query = (T_Firsts == query_frame).reshape(-1)
    pred_tracks = pred_tracks[:, :, msk_query]
    pred_visibility = pred_visibility[:, :, msk_query]
    return pred_tracks, pred_visibility

In [3]:
class DepthAnythingV2Wrapper(nn.Module):
    """Adapter exposing a batched ``infer(rgbs)`` on top of ``DepthAnythingV2``.

    Args:
        encoder: Encoder name forwarded to ``DepthAnythingV2``.
        features: Forwarded to ``DepthAnythingV2``.
        out_channels: Forwarded to ``DepthAnythingV2`` (as a fresh list).
        checkpoint_path: Weights to load. Defaults to
            ``/content/checkpoints/depth_anything_v2_<encoder>.pth``, which for
            the default ``'vitl'`` encoder is the previously hard-coded path.
    """

    def __init__(self, encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024],
                 checkpoint_path=None):
        super().__init__()
        if checkpoint_path is None:
            checkpoint_path = f'/content/checkpoints/depth_anything_v2_{encoder}.pth'
        # list(...) hands the network its own copy of the shared default list.
        self.model = DepthAnythingV2(encoder=encoder, features=features, out_channels=list(out_channels))
        self.model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))
        # NOTE: `device` is the notebook-level global; it has to be defined
        # before the first instance is created.
        self.model = self.model.to(device).eval()

    def infer(self, rgbs):
        """Predict one map per image of a ``(B, C, H, W)`` RGB batch.

        Args:
            rgbs: Float tensor; values are multiplied by 255 before the cast
                to ``uint8``, i.e. it is treated as being in the [0, 1] range.

        Returns:
            Tensor on ``rgbs.device`` holding the outputs of
            ``DepthAnythingV2.infer_image``, one per image, concatenated along
            the batch axis with a singleton channel axis.
        """
        with torch.no_grad():
            batch_size, channels, height, width = rgbs.shape
            depth_maps = []

            for i in range(batch_size):
                img = rgbs[i].permute(1, 2, 0).cpu().numpy()
                # Clip first so values slightly outside [0, 1] cannot wrap
                # around in the uint8 cast.
                img = np.clip(img * 255, 0, 255).astype(np.uint8)
                # infer_image follows the cv2.imread convention (BGR input,
                # converted to RGB internally), so the RGB frame is flipped to
                # BGR here; otherwise red and blue reach the network swapped.
                img = np.ascontiguousarray(img[..., ::-1])
                depth = self.model.infer_image(img)

                depth_maps.append(torch.from_numpy(depth).unsqueeze(0).unsqueeze(0))

            return torch.cat(depth_maps, dim=0).to(rgbs.device)

class MonoDEst(nn.Module):
    """Depth estimator that rescales DepthAnythingV2 output using ZoeDepth.

    ``infer`` fits a line mapping the DepthAnythingV2 prediction to the
    reciprocal of the ZoeDepth prediction, applies it, and returns the
    reciprocal of the result clamped to ``[MIN_DEPTH, MAX_DEPTH]``.
    """

    # Number of leading frames used to fit scale and shift.
    NUM_FIT_FRAMES = 20
    # Bounds applied to the returned depth.
    MIN_DEPTH = 0.01
    MAX_DEPTH = 65

    def __init__(self):
        super().__init__()
        self.model = self._build_model()
        # NOTE: `device` is the notebook-level global; it has to be defined
        # before the first instance is created.
        self.metric3d = build_model(get_config("zoedepth_nk", "infer")).to(device).eval()

    def _build_model(self):
        """Create the DepthAnythingV2 wrapper with the ViT-L configuration."""
        return DepthAnythingV2Wrapper(encoder='vitl', features=256, out_channels=[256, 512, 1024, 1024])

    def infer(self, rgbs, scale=None, shift=None):
        """Predict clamped depth for a batch of frames.

        Args:
            rgbs: Batch of frames forwarded to both networks.
            scale: Optional slope of the linear mapping. When ``scale`` and
                ``shift`` are both given they are used as-is; otherwise both
                are fitted on the first ``NUM_FIT_FRAMES`` frames.
            shift: Optional offset of the linear mapping (see ``scale``).

        Returns:
            Tensor of depths clamped to ``[MIN_DEPTH, MAX_DEPTH]``.
        """
        depth_map = self.model.infer(rgbs)

        if scale is None or shift is None:
            metric_dp = self.metric3d.infer(rgbs[:self.NUM_FIT_FRAMES])
            metric_dp_inv = 1 / metric_dp
            dp_0_rel = depth_map[:self.NUM_FIT_FRAMES]
            scale, shift = np.polyfit(dp_0_rel.reshape(-1).cpu().numpy(),
                                      metric_dp_inv.reshape(-1).cpu().numpy(), 1)

        inv_depth = depth_map * float(scale) + float(shift)
        # Keep the inverse depth positive before inverting it: a zero or
        # negative value would otherwise become inf or a negative depth, and a
        # negative depth would be clamped to MIN_DEPTH (nearest) rather than
        # MAX_DEPTH (farthest).
        inv_depth = inv_depth.clamp(min=1.0 / self.MAX_DEPTH)
        return (1 / inv_depth).clamp(self.MIN_DEPTH, self.MAX_DEPTH)

In [6]:
# End-to-end demo run: load a clip, track a point grid, then save the outputs.
# `device` is also read as a global by DepthAnythingV2Wrapper and MonoDEst, so
# it must be defined before initialize_models() is called.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# --- Configuration -----------------------------------------------------------
root = '/content/SpaTracker/assets'  # directory with <vid_name>.mp4 / <vid_name>.png
vid_name = 'butterfly'
outdir = './vis_results'
downsample = 1  # scale_factor used by load_video when resizing frames
fps = 1  # frame stride in load_video (every fps-th frame is kept)
len_track = 1  # forwarded to Visualizer as tracks_leave_trace
fps_vis = 15  # forwarded to Visualizer as fps
query_frame = 0  # frame whose grid points are tracked (grid_query_frame)
point_size = 3  # forwarded to Visualizer as linewidth
seq_length = 12  # predictor seq_length and wind_length
checkpoint_path = '/content/checkpoints/spaT_final.pth'
interp_shape = (384, 512)

# --- Load inputs and write the reference frame / mask ------------------------
vid_path, seg_path = setup_environment(root, vid_name, outdir)
video, segm_mask = load_and_process_data(vid_path, seg_path, downsample, fps)
save_initial_images(video, segm_mask, outdir, vid_name)

# --- Build the models ---------------------------------------------------------
spatracker_predictor, monodepth_model = initialize_models(checkpoint_path, interp_shape, seq_length, device)

# --- Track ---------------------------------------------------------------------
video = video.to(device)
pred_tracks, pred_visibility = process_video(spatracker_predictor, monodepth_model, video, segm_mask, query_frame, seq_length)

# --- Visualise and save ---------------------------------------------------------
video_vis = visualize_results(video, pred_tracks, pred_visibility, outdir, vid_name, fps_vis, len_track, point_size)
save_results(video, video_vis, pred_tracks, outdir, vid_name, 'spatracker')
save_3d_trajectories(pred_tracks, video, outdir, vid_name)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]



img_size [384, 512]





Params passed to Resize transform:
	width:  512
	height:  384
	resize_target:  True
	keep_aspect_ratio:  True
	ensure_multiple_of:  32
	resize_method:  minimal
Using pretrained resource local::./models/monoD/zoeDepth/ckpts/ZoeD_M12_NK.pt
Loaded successfully





Time taken for inference:  77.01509428024292
Video saved to ./vis_results/butterfly_spatracker_pred_track.mp4
Original Video saved to ./vis_results/butterfly_vid.mp4
3D colored tracks saved to ./vis_results/butterfly_3d.npy
