In [None]:
import torch
import numpy as np
from PIL import Image, ImageSequence

from algorithms.diffusion_wrapper import DiffusionWrapper
from algorithms.heatmap_generator import HeatmapGenerator
from algorithms.zero_shot_tracker import ZeroShotTracker


# Inputs for this run, named once so they are easy to spot and change.
# NOTE(review): MODEL_PATH presumably points at a local text-to-video
# checkpoint directory — confirm against DiffusionWrapper.
MODEL_PATH = '../text-to-video-ms-1.7b'
VIDEO_PATH = '../videos/rocket256.gif'
PROMPT = "A rocket starting on Mars."

# The three pipeline stages used by the cells below.
diffusion_wrapper = DiffusionWrapper(MODEL_PATH)
heatmap_generator = HeatmapGenerator()
zero_shot_tracker = ZeroShotTracker()

# Extract features for the clip, passing the text prompt alongside it.
# The result is used below as a mapping from a block name (e.g. 'up_block',
# 'decoder_block') to a sliceable sequence of per-block feature tensors.
video_features_dict = diffusion_wrapper.extract_video_features(VIDEO_PATH, PROMPT)

In [None]:
# Print, for every feature group: its name, how many feature maps it holds,
# and the shape of each map.
#
# The loop names are deliberately distinct from `video_features`: that name is
# used elsewhere in this notebook for the concatenated feature tensor, and
# because loop targets are notebook globals, reusing it here would overwrite
# that tensor with a list whenever this cell is re-run after the
# concatenation cell.
for block_name, block_features in video_features_dict.items():
    print(block_name)
    print(len(block_features))
    for feature_map in block_features:
        print(feature_map.shape)

In [None]:
from algorithms.feature_extraction_loading import concatenate_video_features

# Choose which feature maps to fuse: the single 'up_block' entry at index 2
# and the first three 'decoder_block' entries. Slices (not plain indexing)
# are used, so each value stays a sequence.
selected_features = {
    'up_block': video_features_dict['up_block'][2:3],
    'decoder_block': video_features_dict['decoder_block'][0:3],
}

# Fuse the selection into one tensor, move axis 1 to the last position,
# cast to float32 and bring the result to host memory.
# NOTE(review): presumably this turns a channels-first layout into
# channels-last — confirm against concatenate_video_features.
video_features = (
    concatenate_video_features(selected_features)
    .permute(0, 2, 3, 1)
    .float()
    .cpu()
)

In [None]:
from evaluation.visualization import safe_heatmap_as_gif, place_marker_in_frames

# NOTE(review): the meaning and axis order of this triple are not visible in
# this notebook — presumably the location whose features the heatmaps are
# computed against; confirm with HeatmapGenerator.generate.
query_point = (93, 137, 0)

# Compute the heatmaps from the fused features and export them as a GIF.
heatmaps = heatmap_generator.generate(video_features, query_point)
safe_heatmap_as_gif(heatmaps)

# Disabled manual export, kept for reference: it upsamples the heatmaps to
# 256 px with bilinear interpolation and writes them to output/heatmaps.gif.
#heatmaps = torch.permute(heatmaps, (0, 3, 1, 2))
#heatmaps = torch.nn.functional.interpolate(heatmaps, size=256, mode="bilinear", align_corners=True) * 255
#
#heatmaps = heatmaps.squeeze().numpy()
#frames_gif = [Image.fromarray(f) for f in heatmaps]
#frames_gif[0].save("output/heatmaps.gif", save_all=True, append_images=frames_gif[1:], duration=100, loop=0)

# Derive the tracks from the heatmaps.
tracks = zero_shot_tracker.track(heatmaps)

def load_frames(image: "Image.Image", mode: str = 'RGB') -> np.ndarray:
    """Decode every frame of a multi-frame PIL image into one NumPy array.

    Args:
        image: An open PIL image object (for example an animated GIF). It is
            walked frame by frame with ``ImageSequence.Iterator``.
        mode: PIL mode each frame is converted to before it is copied into
            an array. Defaults to ``'RGB'``.

    Returns:
        A single array whose first axis indexes the frames in iteration
        order; the remaining axes are whatever ``np.array`` yields for one
        converted frame (typically height, width, channels for ``'RGB'``).
    """
    # Convert each frame first, then stack: `convert` returns a new image, so
    # the pixel data is copied out while the frame is current.
    decoded = [
        np.array(frame.convert(mode))
        for frame in ImageSequence.Iterator(image)
    ]
    return np.array(decoded)

# Decode the source GIF into an array of frames (RGB, the default mode of
# load_frames); the file is closed as soon as the frames are copied out.
gif_path = '../videos/rocket256.gif'
with Image.open(gif_path) as im:
    frames = load_frames(im)

# NOTE(review): presumably overlays the tracked positions on the frames —
# confirm in evaluation.visualization. Kept as the final bare expression so
# its return value remains the cell's output.
place_marker_in_frames(frames, tracks)