In [1]:
# Load the line_profiler extension (provides the %lprun line-by-line profiling magic).
%load_ext line_profiler
# Load the autoreload extension and select mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import pickle
from typing import List, Dict
from tempfile import TemporaryDirectory

import imageio
import torch
import numpy as np
from matplotlib import pyplot as plt
from tqdm import trange

from mtt.data import StackedImageData, vector_to_image, simulation_window, stack_images, collate_fn
from mtt.visualize import plot_mtt
from mtt.models import load_model

# Seed for every random choice made in this notebook, so that
# "Restart Kernel -> Run All" selects the same simulation each time.
# Set SEED = None to draw fresh entropy (a different simulation on every run).
SEED = 0
rng = np.random.default_rng(SEED)

In [3]:
# Pick one test simulation at random and build the model inputs for it at
# every scale.
# NOTE(review): 100 is presumably the number of simulations stored in each
# simulations.pkl -- confirm against the dataset.
simulation_idx = rng.integers(100)
scales = [1,2]
print(f"Loading simulation {simulation_idx} for scales {scales} km.")

# load data for each scale
data: Dict[int,List[StackedImageData]] = {}
for scale in scales:
    # Keep the file open only for as long as it takes to unpickle it.
    # NOTE: pickle can execute arbitrary code while loading; only open
    # simulation files that come from a trusted source.
    with open(f"../data/test/{scale}km/simulations.pkl", "rb") as f:
        simulation = pickle.load(f)[simulation_idx]
    # The loop variable is deliberately NOT called `data`, so it cannot be
    # confused with the `data` dict that is being filled here.
    # The requested image size grows with the scale (128 * scale).
    images = stack_images([vector_to_image(vector, img_size=128*scale) for vector in simulation])
    data[scale] = simulation_window(images)

Loading simulation 0 for scales [1, 2] km.


In [4]:
# Make CNN Predictions
model, name = load_model("wandb://damowerko/mtt/4uc51x21")

# Run on the GPU when one is available; otherwise fall back to the CPU so the
# notebook still executes on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
# Inference only: make sure layers such as dropout / batch-norm are in eval mode.
# NOTE(review): assumes the object returned by load_model is a torch.nn.Module.
model.eval()

# output_images[scale] holds one numpy prediction per entry of data[scale],
# in the same order.
output_images = {}
with torch.no_grad():
    for scale in scales:
        # Call the module itself (not .forward) so registered hooks run;
        # no .detach() is needed because gradients are disabled here.
        output_images[scale] = [
            model(d.sensor_images.to(device)).cpu().numpy()
            for d in data[scale]
        ]

[34m[1mwandb[0m:   1 of 1 files downloaded.  


In [5]:
# Render one video per scale: every window is plotted with plot_mtt, the
# figure is rasterized to an RGB array, and the frames are written as an mp4.
out_dir = "../out/video/"
os.makedirs(out_dir, exist_ok=True)
for scale in scales:
    # generate stills
    stills = []
    # data[scale] and output_images[scale] are parallel lists (one prediction
    # per window), so walk them together.
    for window, prediction in zip(data[scale], output_images[scale]):
        fig = plot_mtt(window.sensor_images.cpu().numpy(), prediction, window.info)
        # save fig to numpy array
        fig.canvas.draw()
        # buffer_rgba() replaces tostring_rgb(), which is deprecated since
        # Matplotlib 3.8. Drop the alpha channel to get (height, width, 3)
        # uint8, and copy so the frame outlives the figure's canvas buffer.
        image = np.asarray(fig.canvas.buffer_rgba())[..., :3].copy()
        stills.append(image)
        # Close this specific figure (not just the "current" one) so that
        # figures do not accumulate in memory across iterations.
        plt.close(fig)
    # make video using imageio ffmpeg
    imageio.mimsave(os.path.join(out_dir, f"{scale}km.mp4"), stills, fps=10)

