In [None]:
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
import random

from algorithms.feature_extraction_loading import FeatureDataset, extract_diffusion_features, feature_collate_fn, concatenate_video_features
from evaluation.visualization import safe_heatmap_as_gif, place_marker_in_frames

from algorithms.heatmap_generator import HeatmapGenerator
from algorithms.zero_shot_tracker import ZeroShotTracker

# Build the two pipeline components once, up front: the heatmap generator first, then the zero-shot tracker.
heatmap_generator, zero_shot_tracker = HeatmapGenerator(), ZeroShotTracker()

In [None]:
#extract_diffusion_features(input_dataset_paths={'davis': '../tapvid_davis/tapvid_davis.pkl'}, diffusion_model_path='../text-to-video-ms-1.7b/')

In [None]:
feature_dataset = FeatureDataset()
feature_loader = DataLoader(feature_dataset, batch_size=1, collate_fn=feature_collate_fn)

# Only the first NUM_UP_BLOCKS entries of sample['features']['up_block'] are concatenated per video.
NUM_UP_BLOCKS = 3
OUTPUT_DIR = 'output'
# Dedicated seeded RNG so the randomly chosen query point per video is reproducible across runs
# (and independent of any other use of the global `random` state).
QUERY_SEED = 0
query_rng = random.Random(QUERY_SEED)

video_idx = 0

for batch in feature_loader:
    for sample in batch:
        video_features = concatenate_video_features(
            {'up_block': sample['features']['up_block'][0:NUM_UP_BLOCKS]}
        )

        # Pick one query point uniformly at random. randint(a, b) is inclusive of b, so
        # randint(0, n) could return n and index out of range; randrange(n) yields 0..n-1.
        # NOTE(review): the leading [0] presumably strips a batch dimension — confirm.
        query_points = sample['query_points'][0]
        idx = query_rng.randrange(len(query_points))
        query_point = query_points[idx]

        print(query_point)

        folder_path = os.path.join(OUTPUT_DIR, 'video_' + str(video_idx))
        # exist_ok=True so re-running the cell does not crash on folders left by a previous run.
        os.makedirs(folder_path, exist_ok=True)
        query_point_file_name = os.path.join(folder_path, 'query_point.txt')
        with open(query_point_file_name, 'w') as query_point_file:
            query_point_file.write(str(query_point))

        # Move axis 1 to the last position (presumably channels-last for the heatmap generator — confirm).
        video_features = video_features.permute(0, 2, 3, 1).float()
        # The generator receives query_point[1], query_point[2], then int(query_point[0]);
        # index 0 is presumably the frame index and 1/2 the spatial coordinates — TODO confirm order.
        heatmaps = heatmap_generator.generate(video_features, (query_point[1], query_point[2], int(query_point[0])))
        tracks = zero_shot_tracker.track(heatmaps)

        # Ground truth for the chosen query with its first two coordinate columns swapped,
        # presumably to match the coordinate order of `tracks` — TODO confirm convention.
        target_track = sample['target_points'][0][idx]
        gt_tracks = np.zeros_like(target_track)
        gt_tracks[:, 0] = target_track[:, 1]
        gt_tracks[:, 1] = target_track[:, 0]

        video = sample['video'].squeeze()
        place_marker_in_frames(video, tracks, ground_truth_tracks=gt_tracks, folder_path=folder_path)
        safe_heatmap_as_gif(heatmaps, True, video, folder_path=folder_path)

        video_idx += 1