In [1]:
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
import random
import shutil

from algorithms.feature_extraction_loading import FeatureDataset, extract_diffusion_features, feature_collate_fn, concatenate_video_features
from evaluation.visualization import safe_heatmap_as_gif, place_marker_in_frames

from evaluation.evaluation_datasets import compute_tapvid_metrics

from algorithms.heatmap_generator import HeatmapGenerator
from algorithms.zero_shot_tracker import ZeroShotTracker

heatmap_generator = HeatmapGenerator()
zero_shot_tracker = ZeroShotTracker()

  from .autonotebook import tqdm as notebook_tqdm
2024-05-27 21:03:31.221124: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-05-27 21:03:31.326748: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
extract_diffusion_features(input_dataset_paths={'davis': '../tapvid_davis/tapvid_davis.pkl'}, diffusion_model_path='../text-to-video-ms-1.7b/', restrict_frame_size=True, max_frame_size=2**18)

Loading pipeline components...:   0%|          | 0/5 [00:00<?, ?it/s]

Loading pipeline components...: 100%|██████████| 5/5 [00:00<00:00,  5.01it/s]
  return F.conv3d(


(90, 1280, 8, 8)
12.8
(90, 1280, 16, 16)
3.2
(90, 640, 32, 32)
1.6
(90, 320, 32, 32)
3.2
(90, 320, 16, 16)
12.8
(90, 640, 8, 8)
25.6
(90, 1280, 4, 4)
51.2
(90, 1280, 4, 4)
51.2
(90, 1280, 4, 4)
51.2
(90, 512, 64, 64)
0.5
(368640, 512)
(368640, 256)
torch.Size([90, 256, 64, 64])
(90, 512, 128, 128)
0.125
(1474560, 512)
(1474560, 64)
torch.Size([90, 64, 128, 128])
(90, 256, 256, 256)
0.0625
(5898240, 256)
(5898240, 16)
torch.Size([90, 16, 256, 256])
(90, 128, 256, 256)
0.125
(5898240, 128)
(5898240, 16)
torch.Size([90, 16, 256, 256])


  return F.conv3d(
  return F.conv2d(input, weight, bias, self.stride,


(75, 1280, 8, 8)
12.8
(75, 1280, 16, 16)
3.2
(75, 640, 32, 32)
1.6
(75, 320, 32, 32)
3.2
(75, 320, 16, 16)
12.8
(75, 640, 8, 8)
25.6
(75, 1280, 4, 4)
51.2
(75, 1280, 4, 4)
51.2
(75, 1280, 4, 4)
51.2
(75, 512, 64, 64)
0.5
(307200, 512)
(307200, 256)
torch.Size([75, 256, 64, 64])
(75, 512, 128, 128)
0.125
(1228800, 512)
(1228800, 64)
torch.Size([75, 64, 128, 128])
(75, 256, 256, 256)
0.0625
(4915200, 256)
(4915200, 16)
torch.Size([75, 16, 256, 256])
(75, 128, 256, 256)
0.125
(4915200, 128)
(4915200, 16)
torch.Size([75, 16, 256, 256])


  return F.conv3d(
  return F.conv2d(input, weight, bias, self.stride,


(40, 1280, 8, 8)
12.8
(40, 1280, 16, 16)
3.2
(40, 640, 32, 32)
1.6
(40, 320, 32, 32)
3.2
(40, 320, 16, 16)
12.8
(40, 640, 8, 8)
25.6
(40, 1280, 4, 4)
51.2
(40, 1280, 4, 4)
51.2
(40, 1280, 4, 4)
51.2
(40, 512, 64, 64)
0.5
(163840, 512)
(163840, 256)
torch.Size([40, 256, 64, 64])
(40, 512, 128, 128)
0.125
(655360, 512)
(655360, 64)
torch.Size([40, 64, 128, 128])
(40, 256, 256, 256)
0.0625
(2621440, 256)
(2621440, 16)
torch.Size([40, 16, 256, 256])
(40, 128, 256, 256)
0.125
(2621440, 128)
(2621440, 16)
torch.Size([40, 16, 256, 256])


  return F.conv3d(
  return F.conv2d(input, weight, bias, self.stride,


(84, 1280, 8, 8)
12.8
(84, 1280, 16, 16)
3.2
(84, 640, 32, 32)
1.6
(84, 320, 32, 32)
3.2
(84, 320, 16, 16)
12.8
(84, 640, 8, 8)
25.6
(84, 1280, 4, 4)
51.2
(84, 1280, 4, 4)
51.2
(84, 1280, 4, 4)
51.2
(84, 512, 64, 64)
0.5
(344064, 512)
(344064, 256)
torch.Size([84, 256, 64, 64])
(84, 512, 128, 128)
0.125
(1376256, 512)
(1376256, 64)
torch.Size([84, 64, 128, 128])
(84, 256, 256, 256)
0.0625
(5505024, 256)


: 

In [None]:
feature_dataset = FeatureDataset()
feature_loader = DataLoader(feature_dataset, batch_size=2, collate_fn=feature_collate_fn)

video_idx = 0

for batch in feature_loader:
    query_points = []
    gt_occluded = []
    gt_tracks = []
    pred_tracks = []

    for sample in batch:
        video_features = concatenate_video_features({'up_block': sample['features']['up_block'][0:3]})

        idx = random.randint(0, len(sample['query_points'][0]) - 1)
        query_point = sample['query_points'][0][idx]
        query_points.append(query_point[None, :])

        occluded = sample['occluded'][0, idx]
        gt_occluded.append(occluded[None])

        gt_track = sample['target_points'][0, idx]
        gt_tracks.append(gt_track[None])

        folder_path = os.path.join('output', 'video_' + str(video_idx))
        if os.path.exists(folder_path):
            shutil.rmtree(folder_path)
        os.makedirs(folder_path)
        query_point_file_name = os.path.join(folder_path, 'query_point.txt')
        with open(query_point_file_name, 'w') as query_point_file:
            query_point_file.write(str(query_point))

        video_features = video_features.permute(0, 2, 3, 1).float()
        heatmaps = heatmap_generator.generate(video_features, (query_point[1], query_point[2], int(query_point[0])))
        pred_track = zero_shot_tracker.track(heatmaps)

        pred_tracks.append(pred_track.numpy()[None])

        gt_track_switched = np.zeros_like(gt_track)
        gt_track_switched[:, 1] = gt_track[:, 0]
        gt_track_switched[:, 0] = gt_track[:, 1]

        place_marker_in_frames(sample['video'].squeeze(), pred_track, ground_truth_tracks=gt_track_switched, folder_path=folder_path)
        safe_heatmap_as_gif(heatmaps, True, sample['video'].squeeze(), folder_path=folder_path)

        video_idx += 1

    print(gt_occluded)

    metrics = compute_tapvid_metrics(query_points=np.array(query_points), gt_occluded=np.array(gt_occluded), gt_tracks=np.array(gt_tracks), pred_occluded=np.array(gt_occluded), pred_tracks=np.array(pred_tracks), query_mode='strided')

    print(metrics)

[array([[False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False,
        False, False, False, False, False, False, False, False, False]])]
{'occlusion_accuracy': array([1.]), 'pts_within_1': array([0.]), 'jaccard_1': array([0.]), 'pts_within_2': array([0.01123596]), 'jaccard_2': array([0.00564972]), 'pts_within_4': array([0.02247191]), 'jaccard_4': array([0.01136364]), 'pts_within_8': array([0.08988764]), 'jaccard_8': arr

KeyboardInterrupt: 