In [1]:
# Automatically reload edited project modules (e.g. the mtt.* package imported
# below) before each cell executes, so package edits are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import pickle
from typing import List, Dict

import imageio
import torch
import numpy as np
from matplotlib import pyplot as plt

from mtt.models.convolutional import load_model
from mtt.data.image import (
    StackedImageData,
    to_image,
    rolling_window,
    stack_images,
)
from mtt.data.sim import SimulationStep

# Seed for the random simulation choice below. None draws fresh OS entropy on
# every run (a different simulation each time); set an int to reproduce a run.
SEED = None
rng = np.random.default_rng(SEED)

In [3]:
# Scales (km) to evaluate; each scale has its own test pickle under data/test/{scale}km/.
scales = [1, 2]
# One randomly chosen simulation index, reused for every scale so the outputs
# for different scales refer to the same simulation.
# NOTE(review): assumes each test pickle holds at least 100 simulations — TODO confirm.
n_simulations = 100
simulation_idx = rng.integers(n_simulations)

# load data for each scale
data: Dict[int, List[StackedImageData]] = {}
for scale in scales:
    # Report the scale actually being loaded in this iteration.
    print(f"Loading simulation {simulation_idx} for scale {scale} km.")
    # SECURITY: pickle.load can execute arbitrary code — only load trusted files.
    with open(f"data/test/{scale}km/simulations.pkl", "rb") as f:
        simulations: List[List[SimulationStep]] = pickle.load(f)
    if simulation_idx >= len(simulations):
        raise IndexError(
            f"simulation_idx={simulation_idx} is out of range: "
            f"only {len(simulations)} simulations at {scale} km"
        )
    simulation = simulations[simulation_idx]
    # img_size grows linearly with the scale (128 per km; presumably pixels — TODO confirm).
    # `step` avoids shadowing the `data` dict being built in this loop.
    images = stack_images([to_image(step, img_size=128 * scale) for step in simulation])
    data[scale] = rolling_window(images)  # type: ignore

Loading simulation 24 for scales [1, 2] km.
Loading simulation 24 for scales [1, 2] km.


In [4]:
# Make CNN Predictions
from mtt.peaks import find_peaks

model, name = load_model("models/e7ivqipk.ckpt")
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# eval() switches off training-only behaviour (e.g. dropout, batch-norm
# statistic updates) for inference; harmless if load_model already did it.
model = model.to(device).eval()

# Per scale: one predicted image and one list of peak estimates per rolling window.
output_images: Dict[int, List[np.ndarray]] = {}
output_estimates: Dict[int, List] = {}
with torch.no_grad():
    for scale in scales:
        output_images[scale] = []
        output_estimates[scale] = []
        for d in data[scale]:
            # Call the module (not .forward) so registered hooks run. Keep only the
            # last element of the output — presumably the prediction for the most
            # recent frame in the window (TODO confirm). No .detach() is needed
            # under no_grad.
            output_image = model(d.sensor_images.to(device))[-1].cpu().numpy()
            output_images[scale].append(output_image)

            # Locate peaks in the predicted image using the last frame's window.
            output_estimate = find_peaks(output_image, d.info[-1]["window"]).means
            output_estimates[scale].append(output_estimate)

Lightning automatically upgraded your loaded checkpoint from v1.9.4 to v2.2.1. To apply the upgrade to your files permanently, run `python -m pytorch_lightning.utilities.upgrade_checkpoint models/e7ivqipk.ckpt`


In [5]:
from mtt.visualize import plot_mtt


def figure_to_rgba(fig) -> np.ndarray:
    """Render ``fig`` and return its pixels as a (height, width, 4) uint8 RGBA array.

    Wrapping the canvas buffer directly keeps the array shape equal to the buffer
    that was actually rendered. Reshaping by ``get_width_height()`` can disagree
    with the buffer on high-DPI canvases.
    """
    fig.canvas.draw()
    return np.asarray(fig.canvas.buffer_rgba())  # type: ignore


out_dir = "data/out/video"
os.makedirs(out_dir, exist_ok=True)
for scale in scales:
    # generate one still per rolling window; the three lists were built in lockstep
    stills = []
    for sample, output_image, estimates in zip(
        data[scale], output_images[scale], output_estimates[scale]
    ):
        fig = plot_mtt(
            sample.sensor_images[-1].cpu().numpy(),
            output_image,
            sample.info[-1],
            estimates=estimates,
        )
        stills.append(figure_to_rgba(fig))
        # Close this specific figure (not merely the "current" one) so figures
        # don't accumulate in memory across the loop.
        plt.close(fig)
    # make video using imageio ffmpeg (5 stills per second)
    imageio.mimsave(os.path.join(out_dir, f"{scale}km.mp4"), stills, fps=5)

