In [None]:
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
import random
import shutil

from algorithms.feature_extraction_loading import FeatureDataset, extract_diffusion_features, feature_collate_fn, concatenate_video_features
from evaluation.visualization import safe_heatmap_as_gif, place_marker_in_frames

from evaluation.evaluation_datasets import compute_tapvid_metrics

from algorithms.heatmap_generator import HeatmapGenerator
from algorithms.zero_shot_tracker import ZeroShotTracker

# Build the heatmap generator and the tracker once, with their default
# constructor arguments; the evaluation loop further down reuses these two
# objects for every video.
heatmap_generator = HeatmapGenerator()
zero_shot_tracker = ZeroShotTracker()

In [None]:
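# Optional one-off step: uncomment to run diffusion feature extraction on TAP-Vid DAVIS.
# NOTE(review): presumably this produces the feature files that the next cell loads from
# 'output/features/davis' -- confirm the output location before relying on it.
#extract_diffusion_features(input_dataset_paths={'davis': '../tapvid_davis/tapvid_davis.pkl'}, diffusion_model_path='../text-to-video-ms-1.7b/', restrict_frame_size=False, max_frame_size=2**18)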

In [None]:
# Zero-shot tracking evaluation over the pre-extracted feature dataset.
#
# For every sample: pick one query point, generate heatmaps from the video
# features, track the point, write the query point and visualisations into a
# per-video output folder, and finally print the TAP-Vid metrics per batch.

# Tunable settings, gathered here instead of being scattered through the loop.
FEATURE_DATASET_PATH = 'output/features/davis'
OUTPUT_ROOT = 'output'
N_PCA_COMPONENTS = 10
QUERY_POINT_SEED = 0

# A dedicated, seeded generator makes the query-point choice reproducible on
# every run of this cell without touching the global `random` state.
query_point_rng = random.Random(QUERY_POINT_SEED)

feature_dataset = FeatureDataset(feature_dataset_path=FEATURE_DATASET_PATH)
feature_loader = DataLoader(feature_dataset, batch_size=1, collate_fn=feature_collate_fn)

# Running counter across all batches; used to name the per-video output folder.
video_idx = 0

for batch in feature_loader:
    # Per-batch accumulators; each entry gets a leading axis of size 1 so the
    # stacked arrays have one (single-point) track per sample.
    query_points = []
    gt_occluded = []
    gt_tracks = []
    pred_tracks = []

    for sample in batch:
        video_features = concatenate_video_features(
            {
                #'up_block': sample['features']['up_block'][0:3],
                'decoder_block': [sample['features']['decoder_block'][0]],
            },
            perform_pca=True,
            n_components=N_PCA_COMPONENTS,
        )

        print(video_features.shape)

        # Evaluate a single, randomly chosen query point per video.
        idx = query_point_rng.randint(0, len(sample['query_points'][0]) - 1)
        query_point = sample['query_points'][0][idx]
        query_points.append(query_point[None, :])

        occluded = sample['occluded'][0, idx]
        gt_occluded.append(occluded[None])

        gt_track = sample['target_points'][0, idx]
        gt_tracks.append(gt_track[None])

        # Start from an empty output folder so stale files from a previous run
        # cannot be mistaken for results of this one.
        folder_path = os.path.join(OUTPUT_ROOT, f'video_{video_idx}')
        if os.path.exists(folder_path):
            shutil.rmtree(folder_path)
        os.makedirs(folder_path)

        # Record which query point was tracked next to the visualisations.
        query_point_file_name = os.path.join(folder_path, 'query_point.txt')
        with open(query_point_file_name, 'w') as query_point_file:
            query_point_file.write(str(query_point))

        # Move axis 1 to the end before heatmap generation.
        # NOTE(review): presumably (frames, channels, H, W) -> (frames, H, W, channels) -- confirm.
        video_features = video_features.permute(0, 2, 3, 1).float()

        # The query point is re-ordered so that its first component goes last (as an int).
        # NOTE(review): presumably query_point is (t, y, x) as in TAP-Vid -- confirm.
        heatmaps = heatmap_generator.generate(
            video_features,
            (query_point[1], query_point[2], int(query_point[0])),
        )
        pred_track = zero_shot_tracker.track(heatmaps)

        # detach()/cpu() leave a plain CPU tensor unchanged, but keep the
        # conversion working if the tracker ever returns a GPU or grad-tracking
        # tensor (assumes pred_track is a torch.Tensor; it already exposes .numpy()).
        pred_tracks.append(pred_track.detach().cpu().numpy()[None])

        # Swap the two coordinate columns of the ground truth for the visualisation.
        # NOTE(review): the metrics call below receives the *unswapped* gt_track
        # together with pred_track -- confirm both use the same axis order there.
        gt_track_switched = np.zeros_like(gt_track)
        gt_track_switched[:, 1] = gt_track[:, 0]
        gt_track_switched[:, 0] = gt_track[:, 1]

        place_marker_in_frames(
            sample['video'].squeeze(),
            pred_track,
            ground_truth_tracks=gt_track_switched,
            folder_path=folder_path,
        )
        safe_heatmap_as_gif(heatmaps, True, sample['video'].squeeze(), folder_path=folder_path)

        video_idx += 1

    # The ground-truth occlusion flags are passed as the predicted ones as well
    # (no occlusion prediction is produced in this cell).
    # NOTE(review): occlusion-related metrics are therefore presumably not informative.
    metrics = compute_tapvid_metrics(
        query_points=np.array(query_points),
        gt_occluded=np.array(gt_occluded),
        gt_tracks=np.array(gt_tracks),
        pred_occluded=np.array(gt_occluded),
        pred_tracks=np.array(pred_tracks),
        query_mode='strided',
    )

    print(metrics)