In [None]:
# @title Setup
%load_ext autoreload
%autoreload 2
import tensorflow as tf

from usl_models.flood_ml.model import FloodModel
from usl_models.flood_ml import dataset
from usl_models.flood_ml import eval
from usl_models.flood_ml import visualizer

In [4]:
# @title Load model
# Colab form parameters. `sim_names` and `rainfall_durations` are comma-separated
# lists that are paired up by position: the i-th duration belongs to the i-th
# simulation.
model_uri = "gs://climateiq-vertexai/aiplatform-custom-training-2024-07-20-12:39:37.879/model"  # @param { type: "string" }
sim_names = "Manhattan-config_v1/Rainfall_Data_13.txt"  # @param { type: "string" }
# Strip whitespace (and drop empty entries) so that "a, b" does not produce a
# simulation literally named " b".
sim_names = [name.strip() for name in sim_names.split(",") if name.strip()]
rainfall_durations = "4"  # @param { type: "string" }
rainfall_durations = [int(n) for n in rainfall_durations.split(",")]
# The two lists are zipped together downstream; zip() would silently drop the
# unmatched tail, so fail loudly here instead.
if len(sim_names) != len(rainfall_durations):
    raise ValueError(
        f"Got {len(sim_names)} simulation name(s) but "
        f"{len(rainfall_durations)} rainfall duration(s); "
        "provide exactly one duration per simulation."
    )
max_chunks = None  # @param { type: "number" }
batch_size = 2  # @param { type: "number" }

# Instantiate the model from the checkpoint stored at `model_uri`.
model = FloodModel.from_checkpoint(model_uri)

In [None]:
# @title Visualize outputs
# For each (simulation, rainfall duration) pair: run the model over that
# simulation's test split, collect per-example error metrics, then plot the
# averaged temporal errors and the maps of the example with the highest mean
# spatial MAE.
for sim_name, rainfall_duration in zip(sim_names, rainfall_durations):
    # Load only the simulation being evaluated in this iteration, so that its
    # metrics and its rainfall duration are not mixed with other simulations.
    data = dataset.load_dataset(
        sim_names=[sim_name],
        dataset_split="test",
        batch_size=batch_size,
        max_chunks=max_chunks,
    )
    spatial_maes = []
    temporal_mae, temporal_rmse = [], []

    # Trackers for the worst example seen so far. They are reset for every
    # simulation so that values from a previous simulation never leak into the
    # plots of the current one.
    max_mae = 0.0
    max_spatial_mae = None
    max_mae_nse = None
    max_mae_pred = None
    max_mae_label = None

    for inputs, labels in data:
        predictions = model.call_n(inputs, n=rainfall_duration)
        # Iterate over the individual examples of the batch.
        for prediction, label in zip(tf.unstack(predictions), tf.unstack(labels)):
            # Collapse the leading axis of each example with a max
            # (presumably the time axis, giving peak values per cell —
            # TODO confirm against the model's output layout).
            max_pred = tf.reduce_max(prediction, axis=0)
            max_label = tf.reduce_max(label, axis=0)
            spatial_mae = eval.spatial_mae(max_pred, max_label)

            temporal_mae.append(eval.temporal_mae(prediction, label))
            temporal_rmse.append(eval.temporal_rmse(prediction, label))

            spatial_maes.append(spatial_mae)

            # Remember the example with the largest mean spatial MAE.
            mae = tf.reduce_mean(spatial_mae)
            if mae > max_mae:
                max_mae = mae
                max_spatial_mae = spatial_mae
                max_mae_nse = eval.spatial_nse(prediction, label)
                max_mae_pred = max_pred
                max_mae_label = max_label

    num_test_examples = len(spatial_maes)
    if num_test_examples == 0:
        # tf.stack() cannot stack an empty list; report and move on.
        print(f"No test examples found for {sim_name}; skipping plots.")
        continue

    overall_mae = tf.reduce_mean(tf.stack(spatial_maes))
    # Average the per-example temporal metrics across all examples.
    temporal_mae = tf.reduce_mean(tf.stack(temporal_mae), axis=0)
    temporal_rmse = tf.reduce_mean(tf.stack(temporal_rmse), axis=0)

    visualizer.plot_temporal_errors(
        sim_name,
        rainfall_duration,
        temporal_mae=temporal_mae,
        temporal_rmse=temporal_rmse,
    )
    if max_spatial_mae is None:
        # No example had a mean spatial MAE above zero, so there is no
        # "worst" example to show.
        print(f"No example with non-zero spatial MAE for {sim_name}; skipping maps.")
    else:
        visualizer.plot_maps(
            sim_name,
            spatial_mae=max_spatial_mae,
            nse=max_mae_nse,
            pred=max_mae_pred,
            label=max_mae_label,
        )