In [None]:
import os
import time
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as Tr
import torchvision.transforms.functional as TF


from glob import glob
from PIL import Image
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import cv2
import imageio.v2 as imageio

In [None]:
# === Download checkpoint ===
# Path the pretrained weights are expected at; it is copied into Args.checkpoint_path further down.
checkpoint_path = "./checkpoints/track_on_checkpoint.pt"

# Uncomment the line below to download the weights.
# NOTE(review): `wget -O` writes to the given path but does not create the ./checkpoints directory — make sure it exists first.
# !wget -O ./checkpoints/track_on_checkpoint.pt "https://huggingface.co/gaydemir/track_on/resolve/main/track_on_checkpoint.pt"

In [None]:
# === Helper functions ===
def read_video(video_path):
    """Decode a whole video into a float tensor and preview its first frame.

    Args:
        video_path: Path (or any URI accepted by ``imageio.get_reader``) of the video.

    Returns:
        torch.Tensor: Float tensor of shape (T, C, H, W), obtained by stacking the
        decoded (H, W, C) frames and moving the channel axis first. Pixel values are
        only cast to float, not rescaled.

    Raises:
        ValueError: If the reader yields no frames.

    Side effects:
        Prints the number of frames and draws frame 0 on the current matplotlib axes.
    """
    # The context manager releases the underlying decoder even if decoding fails midway.
    with imageio.get_reader(video_path) as reader:
        frames = [np.array(im) for im in reader]

    # Fail with an explicit message instead of np.stack's generic "need at least one array".
    if not frames:
        raise ValueError(f"No frames could be read from {video_path!r}")

    video = np.stack(frames)                                     # (T, H, W, C)
    video = torch.from_numpy(video).permute(0, 3, 1, 2).float()  # (T, C, H, W)

    print(f"{video.shape[0]} frames in video")

    # imshow expects channels last; the integer cast makes it treat values as 0-255.
    plt.imshow(video[0].permute(1, 2, 0).long())

    return video
    
def write_gif(png_dir, out_dir):
    """Assemble the numbered PNG frames of a directory into an animated GIF.

    Args:
        png_dir: Directory containing frames named ``<integer>.png`` (e.g. ``0.png``).
        out_dir: Path of the GIF *file* to write. Despite the name it is passed
            straight to ``imageio.mimsave`` as the output target.

    Frames are ordered by the integer before the first dot of their file name.
    PNG files whose name does not start with an integer are skipped rather than
    aborting the export. The GIF is written at 30 fps.
    """
    def _frame_index(file_name):
        # Integer prefix of "<index>.png", or None when the name is not numeric.
        try:
            return int(file_name.split('.')[0])
        except ValueError:
            return None

    indexed_files = []
    for file_name in os.listdir(png_dir):
        if not file_name.endswith('.png'):
            continue
        index = _frame_index(file_name)
        if index is None:
            # Stray, non-frame PNG (e.g. a manually saved preview): ignore it.
            continue
        indexed_files.append((index, file_name))

    # Numeric sort so that 10.png comes after 9.png (a plain string sort would not).
    indexed_files.sort(key=lambda pair: pair[0])

    images = [imageio.imread(os.path.join(png_dir, file_name)) for _, file_name in indexed_files]

    imageio.mimsave(out_dir, images, fps=30)

In [None]:
# === Read video ====
# Relative path: resolved against the notebook's working directory.
video_path = "media/messi.mp4"
# read_video also prints the frame count and previews frame 0 as a side effect.
video = read_video(video_path)  # (T, 3, H, W) float tensor kept on the CPU
# === === ===

In [None]:
# === Set queries manually ====
# Query points as [x, y] pixel coordinates on a small grid; with these ranges the
# grid is three points on the single row y = 300.
queries = [[x, y] for x in range(1140, 1200, 20) for y in range(300, 350, 50)]

N = len(queries)

# One tab20 colour per query, converted from RGBA floats to '#rrggbb' strings.
distinct_colors = plt.cm.tab20(np.linspace(0, 1, N))
hex_colors = [
    f"#{int(r * 255):02x}{int(g * 255):02x}{int(b * 255):02x}"
    for r, g, b, _ in distinct_colors
]

queries = torch.tensor(queries)

# Overlay the queries on the first frame to check their placement.
plt.imshow(video[0].permute(1, 2, 0).long())
for color, (query_x, query_y) in zip(hex_colors, queries):
    plt.scatter(query_x, query_y, s=20, c=color)

# === === ===

In [None]:
# === Set Model Arguments ===
from utils.train_utils import restart_from_checkpoint_not_dist

class Args:
    """Plain attribute container passed to the model constructors and the checkpoint loader.

    Later cells read only ``val_memory_size`` and ``val_vis_delta`` directly; every
    other field is consumed by project code (model classes, checkpoint loader) that
    is not visible in this notebook, so the per-field notes below are hedged.
    """
    def __init__(self):
        # presumably the (height, width) frames are resized to by the model — TODO confirm
        self.input_size = [384, 512]

        # NOTE: self.N and self.T are model settings; they are unrelated to the
        # notebook-level N (number of queries) and T (number of frames).
        self.N = 384
        self.T = 18
        self.stride = 4
        self.transformer_embedding_dim = 256
        self.cnn_corr = False
        self.linear_visibility = False
        
        # Layer counts for the model's sub-modules (semantics defined by the model code).
        self.num_layers = 3
        self.num_layers_offset_head = 3
        
        self.num_layers_rerank = 3
        self.num_layers_rerank_fusion = 1
        self.top_k_regions = 16

        self.num_layers_spatial_writer = 3
        self.num_layers_spatial_self = 1
        self.num_layers_spatial_cross = 1
        
        self.memory_size = 12
        # Passed as both arguments of model.set_memory_size() in the inference cells.
        self.val_memory_size = 96
        # Assigned to model.visibility_treshold in the inference cells.
        self.val_vis_delta = 0.9
        self.random_memory_mask_drop = 0

        # NOTE(review): the lambda_* values look like loss weights and the group after them
        # like optimizer/training settings; presumably unused at inference — confirm.
        self.lambda_point = 5.0
        self.lambda_vis = 1.0
        self.lambda_offset = 1.0
        self.lambda_uncertainty = 1.0
        self.lambda_top_k = 1.0
        
        self.epoch_num = 4
        self.lr = 1e-3
        self.wd = 1e-4
        self.bs = 1
        self.gradient_acc_steps = 1

        self.validation = False
        # Taken from the notebook-level `checkpoint_path` defined in the download cell,
        # so that cell must have been run first.
        self.checkpoint_path = checkpoint_path
        self.seed = 1234
        self.loss_after_query = True

        # Number of CUDA devices visible to this process (0 on a CPU-only machine).
        self.gpus = torch.cuda.device_count()

# Single shared instance used by both Option 1 and Option 2 below.
args = Args()

# Option 1: Frame inputs (online tracking, one frame per forward call)

In [None]:
from model.track_on_ff import TrackOnFF    # Frame Inputs

# Build the frame-input model from the shared Args and load the pretrained weights.
# NOTE(review): presumably restart_from_checkpoint_not_dist reads args.checkpoint_path and
# loads the weights into `model` in place — confirm in utils.train_utils.
model = TrackOnFF(args)
restart_from_checkpoint_not_dist(args, run_variables={}, model=model)

# Move the model to the GPU and switch it to eval mode (a CUDA device is required;
# assumes TrackOnFF is a torch.nn.Module).
model.cuda().eval()
# The validation memory size is used for both arguments of set_memory_size.
model.set_memory_size(args.val_memory_size, args.val_memory_size)
# NOTE(review): "treshold" is spelled this way on purpose here; presumably it matches the
# attribute name used inside the model class — check before correcting the spelling.
model.visibility_treshold = args.val_vis_delta


In [None]:
# Online (frame-by-frame) tracking: every frame is fed to the model on its own,
# the prediction is drawn on top of the frame and saved as <t>.png, and the
# frames are finally stitched into a GIF.
T = video.shape[0]
N = queries.shape[0]

png_folder = "./out/messi"
Path(png_folder).mkdir(parents=True, exist_ok=True)

# The folder is reused between runs (and by Option 2). Remove numbered frames left
# by an earlier run so that frames with index >= T do not leak into the new GIF.
for stale_png in Path(png_folder).glob("*.png"):
    if stale_png.stem.isdigit():
        stale_png.unlink()

vis_all = []
point_all = []

with torch.no_grad():
    for t in range(T):

        # === For the first frame, initialize the queries and memories ===
        if t == 0:
            model.init_queries_and_memory(queries.cuda(), video[t].unsqueeze(0).cuda())
        # === === ===

        # === Model forward, for each frame ===
        point, vis = model.ff_forward(video[t].unsqueeze(0).cuda())
        # === === ===

        # === Save the predictions frame-by-frame ===
        vis_all.append(vis)
        point_all.append(point)

        # One device-to-host copy per frame instead of two per visible point.
        point_cpu = point.cpu()   # indexed below as (N, 2): x, y
        vis_cpu = vis.cpu()       # indexed below as (N,)

        plt.imshow(video[t].permute(1, 2, 0).long())
        for n in range(N):
            if vis_cpu[n]:
                plt.scatter(point_cpu[n, 0], point_cpu[n, 1], c=hex_colors[n], s=20)

        plt.axis("off")
        plt.savefig(os.path.join(png_folder, f"{t}.png"), bbox_inches='tight')
        # Clear the figure so the next frame starts from an empty canvas.
        plt.clf()
        # === === ===

write_gif(png_folder, os.path.join(png_folder, "out.gif"))

# Option 2: Video inputs (the whole clip is passed to the model in one call)

In [None]:
from model.track_on import TrackOn    # Video Inputs

# Rebinds `model` to the video-input variant; setup mirrors the frame-input cell above.
# NOTE(review): presumably restart_from_checkpoint_not_dist loads args.checkpoint_path
# into `model` in place — confirm in utils.train_utils.
model = TrackOn(args)
restart_from_checkpoint_not_dist(args, run_variables={}, model=model)

# Move to the GPU and switch to eval mode (a CUDA device is required).
model.cuda().eval()
model.set_memory_size(args.val_memory_size, args.val_memory_size)
model.visibility_treshold = args.val_vis_delta


T = video.shape[0]
N = queries.shape[0]

# The whole clip is uploaded to the GPU at once, so GPU memory use grows with T * H * W.
video_tmp = video.unsqueeze(0).cuda()                                                # (1, T, 3, H, W)
# `queries` holds integers while the zero column is float; this relies on torch.cat's
# type promotion to produce a float tensor.
queries_tmp = torch.cat([torch.zeros(N, 1, device=queries.device), queries], dim=1)  # (N, 3), to (t, x, y) format, with all t = 0
queries_tmp = queries_tmp.unsqueeze(0).cuda()                                        # (1, N, 3)

# Inference only: no gradients are needed.
with torch.no_grad():
    out = model.inference(video_tmp, queries_tmp)
    
# NOTE: these names were Python lists after Option 1; from here on they are batched tensors.
vis_all = out["visibility"]   # (1, T, N)
point_all = out["points"]     # (1, T, N, 2)

In [None]:
# === Save the predictions by looping them all ===
# Draws every frame with its visible tracked points, saves it as <t>.png and
# stitches the frames into out_vi.gif.
png_folder = "./out/messi"
Path(png_folder).mkdir(parents=True, exist_ok=True)

# The folder is shared with Option 1 and with earlier runs. Remove old numbered
# frames so that frames with index >= T do not end up in the new GIF.
for stale_png in Path(png_folder).glob("*.png"):
    if stale_png.stem.isdigit():
        stale_png.unlink()

# Copy the predictions to the host once instead of once per visible point.
points_cpu = point_all[0].cpu()   # (T, N, 2): x, y
visible_cpu = vis_all[0].cpu()    # (T, N)

for t in range(T):
    plt.imshow(video_tmp[0, t].permute(1, 2, 0).long().cpu())
    for n in range(N):
        if visible_cpu[t, n]:
            plt.scatter(points_cpu[t, n, 0], points_cpu[t, n, 1], c=hex_colors[n], s=20)

    plt.axis("off")
    plt.savefig(os.path.join(png_folder, f"{t}.png"), bbox_inches='tight')
    # Clear the figure so the next frame starts from an empty canvas.
    plt.clf()

write_gif(png_folder, os.path.join(png_folder, "out_vi.gif"))

# Test: live camera tracking script

In [None]:
# NOTE(review): the recorded output below does not match this command: its traceback names
# track_with_camera.py, and it reports keypoints parsed from command-line arguments, which
# are not passed here. Re-run this cell (or restore the original command) so that the saved
# output corresponds to what is executed.
!python track_camera.py --checkpoint_path checkpoints/track_on_checkpoint.pt

  from torch.distributed.optim import \
xFormers not available
xFormers not available
正在加载模型...
  checkpoint = torch.load(checkpoint_path, map_location="cpu")
=> loaded 'model' from checkpoint with msg <All keys matched successfully>
Model loaded on cuda
模型加载完成！

摄像头已打开
无显示模式：追踪结果将输出到控制台
按 Ctrl+C 退出追踪

从命令行参数解析关键点: 320,240 100,100
已解析 2 个关键点

正在初始化追踪，关键点数量: 2
追踪初始化完成！
Frame 30: 2/2 points visible
Frame 60: 2/2 points visible
Frame 90: 2/2 points visible
Frame 120: 2/2 points visible
Frame 150: 2/2 points visible
Frame 180: 2/2 points visible
Frame 210: 2/2 points visible
Frame 240: 2/2 points visible

用户中断
^C
Traceback (most recent call last):
  File "track_with_camera.py", line 354, in <module>
    main()
  File "track_with_camera.py", line 345, in main
    cap.release()
KeyboardInterrupt
