Could use pysteps' `animate` function.
Could use `plot_geography` or `plot_map_cartopy` under the hood to draw the map.
Observed mm/dd/yyyy HH:MM UTC ---> Forecast +5 min ---> Forecast +10 min

https://pysteps.readthedocs.io/en/latest/pysteps_reference/visualization.html


1. Import packages
2. Load data
3. Plot data
A. Observed data
B. DGMR 273
C. DGMR 378
D. STEPS
E. LINDA
4. 30 minute steps


TODO: Resolve the problem with the odd plot background.
TODO: Create versions with 30-minute steps and with 5-minute steps.

TODO: Check whether cartopy is working properly.

In [None]:
import os
import re
import time
import torch
from torch.utils.data import DataLoader
from netCDF4 import Dataset
import numpy as np
from datetime import datetime
from sprite_core.config import Config
from dgmr import DGMR
import random
import pysteps
import matplotlib.pyplot as plt
from matplotlib import cm, colors
from pysteps.visualization import plot_precip_field

from CASADataset import CASADataset
from NIMRODDataset import NIMRODDataset

from contextlib import contextmanager

In [2]:
@contextmanager
def timer(model_name):
    """
    Print the wall-clock time spent inside the ``with`` block.

    Prints ``"<model_name> | Elapsed time: <seconds> seconds"`` when the block
    exits. The time is printed even if the block raises; the exception still
    propagates to the caller.

    Args:
        model_name (str): Label printed in front of the elapsed time.
    """
    # perf_counter is monotonic and high-resolution, unlike time.time, which can
    # jump if the system clock is adjusted.
    start_time = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start_time
        print(f"{model_name} | Elapsed time: {elapsed:.4f} seconds")

In [None]:
# --- Experiment setup ---------------------------------------------------------
# Frame counts: the models receive NUM_INPUT_FRAMES observed frames and forecast
# NUM_TARGET_FRAMES frames.
NUM_INPUT_FRAMES = 4
NUM_TARGET_FRAMES = 18
TOTAL_FRAMES = NUM_INPUT_FRAMES + NUM_TARGET_FRAMES

RANDOM_SAMPLE = False
# Setting USE_CASA = False switches to NIMROD. That path can only be executed via
# sbatch, where enough RAM can be requested to load the data without running out of memory.
USE_CASA = True
metadata = None

# Prefer the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

test_data_loader = DataLoader(
    CASADataset(split="test", include_datetimes=True, data_dir=Config.DATA_DIR)
    if USE_CASA
    else NIMRODDataset(split="validation"),
    batch_size=1,
)

# Either reuse a fixed, hand-picked set of test samples or draw 5 at random.
if not RANDOM_SAMPLE:
    sampled_ids = [116, 198, 273, 145, 148]
    print("Predefined ids:", sampled_ids)
else:
    sampled_ids = random.sample(range(len(test_data_loader.dataset)), 5)
    print("Sampled ids:", sampled_ids)

In [4]:
# Checkpoints to evaluate. Each entry is [run directory, checkpoint file stem]; the
# loader reads "<run directory>/<stem>.ckpt". Uncomment entries to include more runs.
_GRIDCELL_RUNS = "/work/pi_mzink_umass_edu/SPRITE/outputs/GridCellLoss_500"
model_names = [
    # [f"{_GRIDCELL_RUNS}/J26840027-DGMR–CASA-N1-G4-P32_GS6_BS32-PW1.0_E500_GLR5e-5_DLR2e-4", "last"],
    # [f"{_GRIDCELL_RUNS}/J26840028-DGMR–CASA-N1-G4-P32_GS6_BS32-PW24.0_E500_GLR5e-5_DLR2e-4", "last"],
    [f"{_GRIDCELL_RUNS}/J26840029-DGMR–CASA-N1-G4-P32_GS6_BS32-PW64.0_E500_GLR5e-5_DLR2e-4", "last"],
    [f"{_GRIDCELL_RUNS}/J26982810-DGMR–CASA-128-N2-G4-P32_GS6_BS32-PW24_E500_GLR5e-5_DLR2e-4", "last"],
    # [f"{_GRIDCELL_RUNS}/J26840030-DGMR–CASA-N1-G4-P32_GS6_BS32-PW128.0_E500_GLR5e-5_DLR2e-4", "last"],
    # [f"{_GRIDCELL_RUNS}/J26840031-DGMR–CASA-N1-G4-P32_GS6_BS32-PW203.2_E500_GLR5e-5_DLR2e-4", "last"],
]

In [5]:
# One timestamped output folder per run, named "<YYYYmmdd_HHMM>-<dataset>".
dataset = "NIMROD"
if USE_CASA:
    dataset = "CASA"

plot_output_path = os.path.join(
    Config.OUTPUTS_DIR,
    "GridCellLoss_500/visualize_predicted_results/" + "-".join([datetime.now().strftime("%Y%m%d_%H%M"), dataset]),
)
os.makedirs(plot_output_path, exist_ok=True)

In [None]:
# Load models
# Each checkpoint in `model_names` is loaded into a fresh DGMR and registered in
# `models` under a short display name "GCR_<PW value>" (GCR = Grid Cell Regularization),
# where the PW value is parsed from the run directory name.

models = {}
for path_to_dir, checkpoint_stem in model_names:
    print(f"Loading model {checkpoint_stem}")
    model = DGMR(
        forecast_steps=NUM_TARGET_FRAMES,
        input_channels=1,
        output_shape=256,
        latent_channels=768,
        context_channels=384,
        num_samples=3,
        visualize=True,
    ).to(device)

    # os.path.join drops Config.ROOT_DIR when path_to_dir is absolute, so absolute
    # run directories are used as-is and relative ones resolve against ROOT_DIR.
    checkpoint_path = os.path.join(Config.ROOT_DIR, path_to_dir, f"{checkpoint_stem}.ckpt")
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()

    # Find the PW value from the path, e.g. "...-PW64.0_E500..." -> "64.0".
    pwc_match = re.search(r"PW(\d+\.?\d*)", path_to_dir)
    pwc_value = pwc_match.group(1) if pwc_match else "Unknown"

    display_name = f"GCR_{pwc_value}"
    # Two runs can share a PW value (or both be "Unknown"). Without this guard the
    # later model would silently replace the earlier one in `models`.
    if display_name in models:
        suffix = 2
        while f"{display_name}_{suffix}" in models:
            suffix += 1
        display_name = f"{display_name}_{suffix}"
    models[display_name] = model

    print(f"Model {display_name} loaded")

In [7]:
def clean_up_axes(ax):
    """Remove tick marks and all four frame spines from ``ax`` so only the image is shown."""
    for set_ticks in (ax.set_xticks, ax.set_yticks):
        set_ticks([])
    # Hide the frame on every side.
    for side in ("top", "bottom", "left", "right"):
        ax.spines[side].set_visible(False)


def calculate_pixel_percentage(data_slice, levels, verbose=True):
    """
    Calculate (and optionally print) the percentage of pixels in each range of levels.

    Bins are half-open ``[levels[i], levels[i + 1])``. A final open-ended bin
    collects values ``>= levels[-1]``. Values below ``levels[0]`` and NaNs fall in
    no bin, but they still count toward the total used as the denominator.

    Args:
        data_slice (array-like): The input data to analyze; any shape, it is flattened.
        levels (sequence of float): Ascending precipitation intensity boundaries.
        verbose (bool): If True (default), print one line per bin.

    Returns:
        percentages (list): ``len(levels)`` percentages. There is one per consecutive
        pair of levels, followed by the ``>= levels[-1]`` bin. All are 0.0 when
        ``data_slice`` is empty.

    Raises:
        ValueError: If ``levels`` is empty.
    """
    levels = list(levels)
    if not levels:
        raise ValueError("levels must contain at least one boundary")

    flat_data = np.asarray(data_slice).ravel()
    total = flat_data.size

    def _percent(count):
        # Avoid 0/0 (NaN plus a RuntimeWarning) for an empty slice.
        return (count / total) * 100 if total else 0.0

    percentages = []
    for lower_bound, upper_bound in zip(levels[:-1], levels[1:]):
        count = np.count_nonzero((flat_data >= lower_bound) & (flat_data < upper_bound))
        percentage = _percent(count)
        percentages.append(percentage)
        if verbose:
            print(f"{lower_bound:.1f}-{upper_bound:.1f}: {percentage:.1f}%")

    # Open-ended top bin.
    percentage_above = _percent(np.count_nonzero(flat_data >= levels[-1]))
    percentages.append(percentage_above)
    if verbose:
        print(f">= {levels[-1]:.1f}: {percentage_above:.1f}%")

    return percentages


def create_metadata(
    x0_resized, y0_resized, projection="+proj=longlat +datum=WGS84 +no_defs", yorigin="upper"
):
    """
    Build a pysteps-style geodata dict for a raster with the given coordinate axes.

    Keys:
        projection  PROJ.4-compatible projection definition
        x1          x-coordinate of the lower-left corner of the data raster
        y1          y-coordinate of the lower-left corner of the data raster
        x2          x-coordinate of the upper-right corner of the data raster
        y2          y-coordinate of the upper-right corner of the data raster
        yorigin     location of the first element in the data raster w.r.t. the y-axis:
                    'upper' = upper border, 'lower' = lower border

    The corners are taken from the min/max of each axis, so they are correct
    whether the coordinate arrays are ascending or descending. ``yorigin`` only
    says which border the first raster row sits on; it does not change which
    corner is lower-left. The previous version took y1 from ``y0_resized[-1]``,
    which is the maximum for an ascending axis.

    Args:
        x0_resized (array-like): x-coordinates of the raster columns.
        y0_resized (array-like): y-coordinates of the raster rows.
        projection (str): PROJ.4 string. The default assumes lon/lat coordinates;
            TODO confirm this matches the CASA x0/y0 variables.
        yorigin (str): 'upper' or 'lower'.

    Returns:
        dict: geodata with the keys listed above.
    """
    x = np.asarray(x0_resized, dtype=float)
    y = np.asarray(y0_resized, dtype=float)
    return {
        "projection": projection,
        "x1": float(x.min()),  # Lower-left x-coordinate
        "y1": float(y.min()),  # Lower-left y-coordinate
        "x2": float(x.max()),  # Upper-right x-coordinate
        "y2": float(y.max()),  # Upper-right y-coordinate
        "yorigin": yorigin,
    }


# Load the coordinates from the original 366x350 netCDF file. The context manager
# closes the file even if reading a variable fails; `[:]` reads the data into
# memory, so the arrays remain usable after the file is closed.
with Dataset("/work/pi_mzink_umass_edu/SPRITE/CASAData/test/20221221/20221221_062335.nc", "r") as nc_file:
    x0_full, y0_full = nc_file.variables["x0"][:], nc_file.variables["y0"][:]

# Crop the coordinates to the 350x350 grid.
# NOTE(review): both axes are trimmed by 8 cells at each end. That only yields 350x350
# if x0 and y0 both have length 366; confirm against the file's dimensions.
x0_cropped = x0_full[8:-8]
y0_cropped = y0_full[8:-8]

# Resize to 256x256. Only the extent (min/max) survives this resampling.
x0_resized = np.linspace(x0_cropped.min(), x0_cropped.max(), 256)
y0_resized = np.linspace(y0_cropped.min(), y0_cropped.max(), 256)

metadata = create_metadata(x0_resized, y0_resized)

In [8]:
# Create custom colormap

from matplotlib import cm, colors


class ColormapConfig:
    """
    Colormap, norm and colorbar-tick configuration for rainfall intensity (mm/h).

    Attributes (populated by ``build_colormap``):
        cmap:   ListedColormap with one color per interval between consecutive ``clevs``.
        norm:   BoundaryNorm mapping values onto those intervals.
        clevs:  interval boundaries in mm/h.
        bounds: boundaries shown as colorbar ticks.
    """

    def __init__(self):
        self.cmap = None
        self.norm = None
        self.clevs = None
        self.bounds = None

        self.build_colormap()

    def build_colormap(self):
        """Populate ``cmap``, ``norm``, ``clevs`` and ``bounds``."""
        # Eric's custom colormap, kept for reference (not used):
        # color_list = [
        #     "#4cecec", "#44c6f0", "#429afb", "#3431fd", "#40f600", "#3ada0b", "#2eb612",
        #     "#2a8a0f", "#f8f915", "#e9d11c", "#dcb11f", "#bd751f", "#f39a9c", "#f23a43",
        #     "#da1622", "#a90c1b", "#fa31ff", "#d32ada", "#9f1fa3", "#751678", "#ffffff",
        #     "#c1bdff", "#c5ffff", "#fcfec0", "#fcfec0"
        # ]
        # self.clevs = [1, 6.35, 12.7, 19.05, 25.4, 31.75, 38.1, 44.45, 50.8, 57.15, 63.5, 69.85, 76.2, 82.55, 88.9, 95.25, 101.6, 114.3, 127.0, 139.7, 152.4, 165.1, 177.8, 190.5, 203.2]
        # self.bounds = [1, 12.7, 25.4, 38.1, 50.8, 63.5, 76.2, 88.9, 101.6, 127.0, 152.4, 177.8, 203.2]

        # Shortened version of Eric's custom colormap: three blues, three
        # yellows/oranges, three reds.
        color_list = [
            "#44c6f0",
            "#429afb",
            "#3431fd",
            "#f8f915",
            "#dcb11f",
            "#bd751f",
            "#f23a43",
            "#da1622",
            "#a90c1b",
        ]

        # N colors need N + 1 boundaries, because each color fills one interval. With
        # only 9 boundaries (8 intervals) for 9 colors, BoundaryNorm spread 8 intervals
        # over 9 colors and one color was never drawn. 57.15 continues the 6.35 mm/h
        # spacing and is also the next level in Eric's list.
        self.clevs = [1, 6.35, 12.7, 19.05, 25.4, 31.75, 38.1, 44.45, 50.8, 57.15]  # 10 boundaries -> 9 intervals
        self.bounds = [1, 12.7, 25.4, 38.1, 50.8]

        if len(self.clevs) - 1 != len(color_list):
            raise ValueError(
                f"{len(color_list)} colors need {len(color_list) + 1} boundaries, got {len(self.clevs)}"
            )

        self.cmap = colors.ListedColormap(color_list)
        # self.cmap.set_over("#fcfec0") # from Eric's custom colormap
        self.cmap.set_over("darkmagenta")  # values above the top boundary
        self.cmap.set_under("none")  # values below 1 mm/h are transparent
        self.cmap.set_bad("gray", alpha=0.5)  # masked / invalid pixels
        self.norm = colors.BoundaryNorm(self.clevs, self.cmap.N)
        self.cmap.name = "Custom Colormap"


cmap_config = ColormapConfig()

Verify whether the plotted colors match the colorbar.
Check the geographic metadata.
Plot the 4 frames before and the 16 frames after.

In [9]:
# Font size shared by every panel label and the colorbar.
_LABEL_FONTSIZE = 37
# Number of model rows drawn below the target row in one figure.
_MODELS_PER_FIGURE = 4


def _draw_panel(fig, cell, field, metadata):
    """Add a subplot at gridspec ``cell`` and draw one 2-D precipitation field on it."""
    ax = fig.add_subplot(cell)
    plot_precip_field(field, ax=ax, geodata=metadata, colorbar=False, axis="off", colormap_config=cmap_config)
    return ax


def _label_above(ax, text):
    """Write ``text`` centred just above ``ax``, in axes-fraction coordinates."""
    ax.text(0.5, 1.05, text, fontsize=_LABEL_FONTSIZE, ha="center", va="center", transform=ax.transAxes)


def _label_left(ax, text):
    """Write ``text`` rotated 90 degrees, centred just left of ``ax``, in axes-fraction coordinates."""
    ax.text(
        -0.05, 0.5, text, fontsize=_LABEL_FONTSIZE, ha="center", va="center", rotation=90, transform=ax.transAxes
    )


def _next_forecast_path(directory, idx):
    """Return ``directory/Forecast_{idx}_{n}.png`` for the first free ``n``, never overwriting a file."""
    prefix = f"Forecast_{idx}_"
    n = sum(1 for name in os.listdir(directory) if name.startswith(prefix)) + 1
    # Counting alone can collide when earlier numbers were deleted; skip taken names.
    while True:
        candidate = os.path.join(directory, f"{prefix}{n}.png")
        if not os.path.exists(candidate):
            return candidate
        n += 1


def plot_data_and_save(
    inputs, targets, outputs, metadata, directory, idx, moment_datetime="Unspecified", title=None, model_subset_idx=0
):
    """
    Plot observations, targets and up to four model forecasts for one sample, then save a PNG.

    Layout:
        - The top-left row shows the last four observed frames.
        - The right half is a 5-row grid. Row 0 holds the first (up to) four target
          frames, and each following row holds one model.
        - The models shown are ``outputs`` entries ``4 * model_subset_idx`` through
          ``4 * model_subset_idx + 3``, in insertion order.
        - A shared horizontal colorbar sits at the bottom.

    Args:
        inputs: tensor indexed as ``inputs[0, t, 0, :, :]``. The assumed layout is
            (batch, time, channel, H, W); TODO confirm.
        targets: tensor indexed the same way as ``inputs``.
        outputs (dict): model name -> sequence of forecast frames. Each frame is
            squeezed to 2-D before plotting.
        metadata (dict | None): pysteps geodata for the fields.
        directory (str): existing output directory.
        idx: sample index, used in the file name.
        moment_datetime (str): label shown above the last observed frame.
        title: currently unused; kept for call compatibility.
        model_subset_idx (int): which block of four models to plot.

    Returns:
        None. The figure is written to ``directory/Forecast_{idx}_{n}.png``.
    """
    # Select the appropriate 2D slices.
    input_slices = inputs[0, -4:, 0, :, :].cpu().numpy()
    target_slices = [targets[0, i, 0, :, :].cpu().numpy() for i in range(min(4, targets.shape[1]))]

    # Subset of models for this figure (slicing clamps at the end of the list).
    start_idx = model_subset_idx * _MODELS_PER_FIGURE
    current_outputs = dict(list(outputs.items())[start_idx : start_idx + _MODELS_PER_FIGURE])

    fig = plt.figure(figsize=(51, 36))  # width x height in inches
    try:
        # Left half: one observation row. Right half: 5 rows (targets + up to 4 models).
        gs1 = fig.add_gridspec(1, 4, bottom=0.775, top=0.95, wspace=0.00, hspace=0.00, left=0.00, right=0.5)
        gc2 = fig.add_gridspec(5, 4, bottom=0.04, top=0.95, wspace=0.00, hspace=0.05, left=0.5, right=1.0)

        # Observations. The last frame is labelled with its timestamp, earlier ones with their offset.
        for i, input_slice in enumerate(input_slices):
            ax = _draw_panel(fig, gs1[0, i], input_slice, metadata)
            if i == 0:
                _label_left(ax, "Observation")
            _label_above(ax, f"{moment_datetime}" if i == 3 else f"{-15 + (i * 5)} min")
            clean_up_axes(ax)

        # Targets (ground truth) at +5, +10, ... minutes.
        for i, target_slice in enumerate(target_slices):
            ax = _draw_panel(fig, gc2[0, i], target_slice, metadata)
            _label_above(ax, f"+{(i + 1) * 5} min")
            clean_up_axes(ax)

        # One row per model in the current subset.
        for row, (model_name, output_slices) in enumerate(current_outputs.items(), start=1):
            print(f"Model: {model_name}")
            print(f"Output slices shape: {output_slices.shape}")
            for i in range(min(len(output_slices), gc2.ncols)):
                # np.asarray handles both NumPy outputs and CPU torch tensors (STEPS/LINDA).
                frame = np.squeeze(np.asarray(output_slices[i]))
                ax = _draw_panel(fig, gc2[row, i], frame, metadata)
                if i == 0:
                    _label_left(ax, model_name)
                clean_up_axes(ax)

        # Shared colorbar. Sizing kwargs such as fraction/shrink/anchor are ignored
        # when `cax` is supplied, so they are omitted.
        cbaxes = fig.add_axes([0.3, 0.01, 0.4, 0.02])
        cbar = fig.colorbar(
            cm.ScalarMappable(norm=cmap_config.norm, cmap=cmap_config.cmap),
            cax=cbaxes,
            orientation="horizontal",
            ticks=cmap_config.bounds,
            extend="both",
            extendfrac="auto",
            spacing="uniform",
        )
        cbar.ax.set_xlabel(r"Rainfall intensity (mm h$^{-1}$)", fontsize=_LABEL_FONTSIZE)
        cbar.ax.tick_params(labelsize=_LABEL_FONTSIZE)

        outfile = _next_forecast_path(directory, idx)
        fig.savefig(outfile, bbox_inches="tight")
    finally:
        # Always release the (large) figure, even if plotting fails mid-way.
        plt.close(fig)

    print(f"Saved plot to {outfile}")

In [10]:
from pysteps import motion
from pysteps.nowcasts import steps, linda
from pysteps.motion.lucaskanade import dense_lucaskanade
from pysteps.utils import transformation
import torch


class ForecastModel:
    """
    Classical pysteps nowcasting baselines (STEPS, LINDA, S-PROG) for one sequence of observed frames.

    The motion field is estimated once with pysteps' Lucas-Kanade optical flow and
    shared by all forecasts.
    """

    def __init__(self, rainrate_field, prediction_step, seed=None):
        """
        Args:
            rainrate_field: observed rain rates, as a torch.Tensor or array-like of
                shape (T, H, W) or (T, 1, H, W). A singleton axis 1 is squeezed away.
            prediction_step (int): number of lead times to forecast.
            seed (int | None): random seed for the STEPS ensemble.
        """
        if isinstance(rainrate_field, torch.Tensor):
            rainrate_field = rainrate_field.detach().cpu().numpy()  # transfer to np.ndarray
        else:
            # Previously non-tensor input left self.rainrate_field unset.
            rainrate_field = np.asarray(rainrate_field)
        # (T, 1, H, W) -> (T, H, W)
        if rainrate_field.ndim > 1 and rainrate_field.shape[1] == 1:
            rainrate_field = rainrate_field.squeeze(1)
        self.rainrate_field = rainrate_field

        self.prediction_step = prediction_step
        self.seed = seed
        # motion.get_method("LK") is pysteps' dense_lucaskanade, so computing
        # `advection` separately ran the same optical flow twice. Compute it once and
        # keep `advection` as an alias for backward compatibility.
        self.velocity = self.calculate_velocity_field()
        self.advection = self.velocity

    def calculate_velocity_field(self):
        """Return the dense Lucas-Kanade motion field of ``self.rainrate_field``."""
        oflow_method = motion.get_method("LK")
        return oflow_method(self.rainrate_field)

    @staticmethod
    def _select_input_frames(field, prediction_step, required_frames):
        """
        Return the frames to feed a nowcast.

        If dropping the trailing ``prediction_step`` frames still leaves
        ``required_frames`` frames, they are dropped; otherwise the whole ``field``
        is used. Presumably the trailing frames are verification frames when the
        input is long enough; confirm with the data source.
        """
        n_frames = field.shape[0]
        # prediction_step > 0 guards against field[:-0], which would be empty.
        if prediction_step > 0 and n_frames > prediction_step and n_frames - prediction_step >= required_frames:
            return field[:-prediction_step]
        return field

    @staticmethod
    def _to_db(field):
        """dB-transform rain rates (threshold 0.1 mm/h, no-rain value -15 dB)."""
        return transformation.dB_transform(field, threshold=0.1, zerovalue=-15.0)[0]

    @staticmethod
    def _to_output_tensor(forecast):
        """Convert a (T, H, W) array to a (T, 1, H, W) torch tensor with NaNs replaced by 0."""
        forecast = torch.from_numpy(np.expand_dims(forecast, axis=1))
        return torch.where(torch.isnan(forecast), torch.tensor(float(0)), forecast)

    def get_steps_forecast(self):
        """Return the STEPS ensemble-mean nowcast as a (T, 1, H, W) tensor in mm/h."""
        ar_order = 2  # Default autoregressive order for STEPS; needs ar_order + 1 frames.
        precip_input = self._select_input_frames(
            self._to_db(self.rainrate_field), self.prediction_step, ar_order + 1
        )

        forecast_steps = steps.forecast(
            precip_input,
            self.velocity,
            self.prediction_step,
            20,
            n_cascade_levels=6,
            precip_thr=-10.0,
            kmperpixel=2.0,
            timestep=5,
            noise_method="nonparametric",
            vel_pert_method="bps",
            mask_method="incremental",
            seed=self.seed,
        )

        # Back from dB to mm/h, then average the 20 ensemble members (axis 0).
        forecast_steps = transformation.dB_transform(forecast_steps, threshold=-10.0, inverse=True)[0]
        return self._to_output_tensor(np.mean(forecast_steps, axis=0))

    def get_linda_forecast(self):
        """Return the deterministic LINDA nowcast as a (T, 1, H, W) tensor in mm/h."""
        ari_order = 2  # Default ARI order for LINDA; needs ari_order + 2 frames.
        # LINDA takes linear-scale rain rates and returns the same units. The previous
        # version fed dB values and never inverted the transform, so LINDA output was on
        # a different scale from the observations and STEPS.
        precip_input = self._select_input_frames(
            np.asarray(self.rainrate_field, dtype=float), self.prediction_step, ari_order + 2
        )

        # With measure_time=True, linda.forecast returns a tuple whose first element is the nowcast.
        forecast_linda = linda.forecast(
            precip_input,
            self.advection if self.advection is not None else self.velocity,
            self.prediction_step,
            max_num_features=15,
            add_perturbations=False,
            num_workers=8,
            measure_time=True,
        )[0]

        return self._to_output_tensor(forecast_linda)

    def get_sprog_forecast(self):
        """Return the S-PROG nowcast as a (T, H, W) NumPy array in mm/h; rates below 0.5 mm/h are set to 0."""
        # sprog is not imported at file level.
        from pysteps.nowcasts import sprog

        ar_order = 2  # S-PROG default; needs ar_order + 1 frames.
        rainrate_field_db = self._to_db(self.rainrate_field)
        rainrate_thr = transformation.dB_transform(np.array([0.5]), threshold=0.1, zerovalue=-15.0)[0]
        # Replace every non-finite value (NaN, +/-inf) with the no-rain dB value.
        rainrate_field_db = np.nan_to_num(rainrate_field_db, nan=-15.0, posinf=-15.0, neginf=-15.0).astype(np.float64)

        # Same frame selection as STEPS/LINDA. The old unconditional [:-prediction_step]
        # left no frames when only 4 inputs were given.
        precip_input = self._select_input_frames(rainrate_field_db, self.prediction_step, ar_order + 1)
        # `precip_thr` matches the parameter naming already used for steps.forecast above.
        forecast_sprog = sprog.forecast(
            precip_input,
            self.velocity,
            self.prediction_step,
            n_cascade_levels=6,
            precip_thr=rainrate_thr[0],
        )
        forecast_sprog = transformation.dB_transform(forecast_sprog, threshold=-10.0, inverse=True)[0]
        forecast_sprog[forecast_sprog < 0.5] = 0.0
        return forecast_sprog


# model = ForecastModel(rainrate_field, 18, 42)

# forecast_steps = model.get_steps_forecast()
# forecast_linda = model.get_linda_forecast()

# print(forecast_steps.shape)  # should be (18, 1, 256, 256)
# print(forecast_linda.shape)  # should be (18, 1, 256, 256)

In [None]:
# Run every model on each selected test sample, time it, and save comparison plots.
# NOTE(review): `timing` is never filled in; `timer` only prints elapsed times.
timing = {}

with torch.no_grad():
    for idx in sampled_ids:
        print("Plotting Sample", idx)
        title = f"Plotting Sample {idx}"

        if not USE_CASA:  # Nimrod
            test_inputs, test_targets = test_data_loader.dataset[idx]
            moment_datetime = "Nimrod unspecified"
            metadata = None
        else:
            test_inputs, test_targets, frame_datetimes = test_data_loader.dataset[idx]
            # NOTE(review): used as the time of the last *input* frame. Confirm that
            # frame_datetimes covers only input frames; otherwise [-1] is the last target.
            last_input_datetime = frame_datetimes[-1]
            moment_datetime = last_input_datetime.strftime("%d %B, %Y - %H:%M")

        # Convert to tensors, add a leading batch axis and move them to the compute device.
        test_inputs, test_targets = (
            torch.tensor(array).unsqueeze(0).to(device) for array in (test_inputs, test_targets)
        )

        print("Using ML models to predict")
        # Model display name -> forecast frames (batch axis removed).
        test_outputs = {}
        timing = {}
        for model_name, model in models.items():
            with timer(model_name):
                test_outputs[model_name] = model(test_inputs).squeeze(0).cpu().numpy()

        # Classical pysteps baselines, seeded for reproducibility.
        pysteps_model = ForecastModel(test_inputs.squeeze(0).squeeze(1), NUM_TARGET_FRAMES, 42)
        baselines = (("STEPS", pysteps_model.get_steps_forecast), ("LINDA", pysteps_model.get_linda_forecast))
        for baseline_name, run_baseline in baselines:
            with timer(baseline_name):
                test_outputs[baseline_name] = run_baseline()

        print("Plotting")
        # One figure per group of up to four models (ceiling division).
        num_plots = -(-len(test_outputs) // 4)
        for plot_idx in range(num_plots):
            plot_data_and_save(
                test_inputs,
                test_targets,
                test_outputs,
                metadata,
                directory=plot_output_path,
                idx=idx,
                moment_datetime=moment_datetime,
                title=title,
                model_subset_idx=plot_idx,
            )