In [1]:
import os
import torch
import numpy as np
from torch.utils.data import DataLoader
import random
import shutil

from algorithms.feature_extraction_loading import FeatureDataset, extract_diffusion_features, feature_collate_fn, concatenate_video_features
from evaluation.visualization import safe_heatmap_as_gif, place_marker_in_frames

from evaluation.evaluation_datasets import compute_tapvid_metrics

from algorithms.heatmap_generator import HeatmapGenerator
from algorithms.zero_shot_tracker import ZeroShotTracker

import evaluation.visualization as v

# Build the heatmap generator and the zero-shot tracker once; the tracking
# cells below reuse these two instances for every video.
heatmap_generator, zero_shot_tracker = HeatmapGenerator(), ZeroShotTracker()

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
#extract_diffusion_features(input_dataset_paths={'davis': '../tapvid_davis/tapvid_davis.pkl'}, diffusion_model_path='../text-to-video-ms-1.7b/', restrict_frame_size=False, max_frame_size=2**18)

In [4]:
# Feature reduction applied to each diffusion block before the blocks are
# concatenated: "pca" keeps PCA_COMPONENTS[block] principal components per
# block, "pooling" uses concatenate_video_features' pooling option instead.
mode = "pca"

# Order in which the reduced blocks are concatenated into the final features.
BLOCK_NAMES = ('down_block', 'mid_block', 'up_block', 'decoder_block')

# PCA components kept per block (only used when mode == "pca").
PCA_COMPONENTS = {
    'down_block': 20,
    'mid_block': 20,
    'up_block': 20,
    'decoder_block': 10,
}


def reduce_block(sample, block_name, reduction_mode):
    """Reduce a single feature block of `sample` with `reduction_mode`.

    Args:
        sample: one element of a batch yielded by `feature_loader`.
        block_name: key into sample['features'], e.g. 'down_block'.
        reduction_mode: "pca" or "pooling".

    Returns:
        Whatever concatenate_video_features returns for that single block.

    Raises:
        ValueError: for an unknown mode. Without this check a typo in `mode`
            would leave the per-block features undefined (NameError) or
            silently reuse the features of the previous video.
    """
    block = {block_name: sample['features'][block_name][:]}
    if reduction_mode == "pca":
        return concatenate_video_features(
            block,
            perform_pca=True,
            n_components=PCA_COMPONENTS[block_name],
        )
    if reduction_mode == "pooling":
        return concatenate_video_features(block, perform_pooling=True)
    raise ValueError(f"Unknown feature reduction mode: {reduction_mode!r}")


feature_dataset = FeatureDataset(feature_dataset_path='output/features/davis')
feature_loader = DataLoader(feature_dataset, batch_size=1, collate_fn=feature_collate_fn)

video_idx = 0

for batch in feature_loader:
    query_points = []
    gt_occluded = []
    gt_tracks = []
    pred_tracks = []

    for sample in batch:
        num_query_points = len(sample['query_points'][0])
        if num_query_points == 0:
            # Nothing to track (random selection would fail on an empty range);
            # skip the video but keep video_idx aligned with dataset order.
            video_idx += 1
            continue

        # Per-block min-max scaling before concatenation was tried and performed
        # worse, so the reduced blocks are concatenated unscaled.
        video_features = concatenate_video_features(
            {name: [reduce_block(sample, name, mode)] for name in BLOCK_NAMES}
        )

        # Seeded per video so the same query point is chosen on every run.
        idx = random.Random(video_idx).randrange(num_query_points)
        query_point = sample['query_points'][0][idx]
        query_points.append(query_point[None, :])

        occluded = sample['occluded'][0, idx]
        gt_occluded.append(occluded[None])

        gt_track = sample['target_points'][0, idx]
        gt_tracks.append(gt_track[None])

        # Start from a clean per-video output folder on every run.
        folder_path = os.path.join('output', 'video_' + str(video_idx))
        if os.path.exists(folder_path):
            shutil.rmtree(folder_path)
        os.makedirs(folder_path)
        query_point_file_name = os.path.join(folder_path, 'query_point.txt')
        with open(query_point_file_name, 'w') as query_point_file:
            query_point_file.write(str(query_point))

        # Targets are an N x 3 tensor; the first query-point entry is cast to int
        # (presumably the frame index -- TODO confirm).
        target = torch.tensor([[int(query_point[0]), query_point[1], query_point[2]]])
        heatmaps = heatmap_generator.generate(video_features, target, device="cpu")

        pred_track = zero_shot_tracker.track(heatmaps)
        pred_tracks.append(pred_track.numpy()[None])

        # Swap coordinate columns 0 and 1 of the ground-truth track before drawing
        # it next to pred_track (presumably x/y vs. row/col order -- TODO confirm).
        gt_track_switched = np.zeros_like(gt_track)
        gt_track_switched[:, 1] = gt_track[:, 0]
        gt_track_switched[:, 0] = gt_track[:, 1]

        video = sample['video'].squeeze()
        v.place_marker_in_frames(video, pred_track, ground_truth_tracks=gt_track_switched, folder_path=folder_path)
        v.safe_heatmap_as_gif(heatmaps, True, video, folder_path=folder_path)

        video_idx += 1

    # TODO: metrics are disabled. Note that pred_occluded below is the ground
    # truth, so occlusion accuracy would be trivially perfect once enabled.
    #metrics = compute_tapvid_metrics(query_points=np.array(query_points), gt_occluded=np.array(gt_occluded), gt_tracks=np.array(gt_tracks), pred_occluded=np.array(gt_occluded), pred_tracks=np.array(pred_tracks), query_mode='strided')

    #print(metrics)

: 

In [20]:
# Inspect the raw (not yet reduced) feature maps: mean and max of every block.
# NOTE: `sample` is the last sample left over from the tracking loop above,
# so that cell must have run first.

print(sample["features"].keys())

for block_name, block_features in sample["features"].items():
    stacked_block = concatenate_video_features(
        {
            'x': block_features[:]
        }
    )
    print(block_name)
    for stat_label, stat_fn in (("avg:", torch.mean), ("max:", torch.max)):
        print(stat_label)
        print(stat_fn(stacked_block))
        

dict_keys(['up_block', 'down_block', 'mid_block', 'decoder_block'])
up_block
avg:
tensor(-0.0637, dtype=torch.float16)
max:
tensor(284.7500, dtype=torch.float16)
down_block
avg:
tensor(-0.3726, dtype=torch.float16)
max:
tensor(98.2500, dtype=torch.float16)
mid_block
avg:
tensor(-0.5400, dtype=torch.float16)
max:
tensor(77.6875, dtype=torch.float16)
decoder_block
avg:
tensor(0.5352, dtype=torch.float16)
max:
tensor(6688., dtype=torch.float16)


In [None]:
## Use decoder for upsampling
# Work in progress: the up_block features are reduced to 4 PCA components,
# which are meant to be fed through the (not yet loaded) decoder for upsampling.

feature_dataset = FeatureDataset(feature_dataset_path='output/features/davis')
feature_loader = DataLoader(feature_dataset, batch_size=1, collate_fn=feature_collate_fn)

video_idx = 0

for batch in feature_loader:
    query_points = []
    gt_occluded = []
    gt_tracks = []
    pred_tracks = []

    for sample in batch:
        num_query_points = len(sample['query_points'][0])
        if num_query_points == 0:
            # Nothing to track; skip but keep video_idx aligned with dataset order.
            video_idx += 1
            continue

        concat_upblock = concatenate_video_features(
            {
                'up_block': sample['features']['up_block'][:]
            },
            perform_pca = True,
            n_components = 4
        )

        # TODO: load the decoder here and run a forward pass on concat_upblock.
        # Until that exists, track directly on the PCA-reduced up_block features
        # instead of a `video_features` left over from an earlier cell.
        video_features = concatenate_video_features({'up_block': [concat_upblock]})

        # Seeded per video so the same query point is chosen on every run
        # (a leftover `idx` from another cell belongs to a different video).
        idx = random.Random(video_idx).randrange(num_query_points)
        query_point = sample['query_points'][0][idx]
        query_points.append(query_point[None, :])

        occluded = sample['occluded'][0, idx]
        gt_occluded.append(occluded[None])

        gt_track = sample['target_points'][0, idx]
        gt_tracks.append(gt_track[None])

        # Separate output root so this experiment does not rmtree the
        # output/video_N folders written by the PCA tracking cell.
        folder_path = os.path.join('output', 'decoder_upsampling', 'video_' + str(video_idx))
        if os.path.exists(folder_path):
            shutil.rmtree(folder_path)
        os.makedirs(folder_path)
        query_point_file_name = os.path.join(folder_path, 'query_point.txt')
        with open(query_point_file_name, 'w') as query_point_file:
            query_point_file.write(str(query_point))

        # Targets are an N x 3 tensor; the first query-point entry is cast to int
        # (presumably the frame index -- TODO confirm).
        target = torch.tensor([[int(query_point[0]), query_point[1], query_point[2]]])
        heatmaps = heatmap_generator.generate(video_features, target, device="cpu")

        pred_track = zero_shot_tracker.track(heatmaps)
        pred_tracks.append(pred_track.numpy()[None])

        # Swap coordinate columns 0 and 1 of the ground-truth track before drawing
        # it next to pred_track (presumably x/y vs. row/col order -- TODO confirm).
        gt_track_switched = np.zeros_like(gt_track)
        gt_track_switched[:, 1] = gt_track[:, 0]
        gt_track_switched[:, 0] = gt_track[:, 1]

        video = sample['video'].squeeze()
        v.place_marker_in_frames(video, pred_track, ground_truth_tracks=gt_track_switched, folder_path=folder_path)
        v.safe_heatmap_as_gif(heatmaps, True, video, folder_path=folder_path)

        video_idx += 1

    # TODO: metrics are disabled. Note that pred_occluded below is the ground
    # truth, so occlusion accuracy would be trivially perfect once enabled.
    #metrics = compute_tapvid_metrics(query_points=np.array(query_points), gt_occluded=np.array(gt_occluded), gt_tracks=np.array(gt_tracks), pred_occluded=np.array(gt_occluded), pred_tracks=np.array(pred_tracks), query_mode='strided')

    #print(metrics)