# Training notebook 

In this notebook, we will create the graph, build the STGNN model, perform training, and evaluation.

# Create SoilWaterDataset object 

SoilWaterDataset is a fundamental object created based on TabularDataset from the TSL library. Using this object, we will be able to enabling efficient loading, preprocessing, and spatiotemporal structuring of data.



In [25]:
from typing import Optional, Union, List
import numpy as np 

from tsl.datasets.prototypes import TabularDataset

class SoilWaterDataset(TabularDataset):

    similarity_options = {'distance', 'correlation'}

    def __init__(self,
                 root: str = None
                 ):

        self.root = root

        # Load data
        target, mask, u, dist, metadata = self.load()

        covariates = {
            'u': (u),
            'metadata' : (metadata),
            'distances': (dist)
        }

        super().__init__(target=target,
                         mask=mask,
                         covariates=covariates,
                         similarity_score='distance',
                         temporal_aggregation='mean',
                         spatial_aggregation='mean',
                         name='DroughtDataset')

    def load(self):
        """
        Load data from files.

        Returns:
            tuple: Containing target, mask, covariates, distances, and metadata.
        """
        target_path = f"{self.root}target.npy"
        mask_path = f"{self.root}mask.npy"
        dist_path = f"{self.root}distances.npy"
        covariates_path = f"{self.root}covariates.npy"
        metadata_path = f"{self.root}metadata.npy"

        # Load main data
        target = np.load(target_path)
        mask = np.load(mask_path)
        u = np.load(covariates_path)
        dist = np.load(dist_path)
        metadata = np.load(metadata_path)

        return target, mask, u, dist, metadata


    def compute_similarity(self, method: str, **kwargs):
        """
        Compute similarity matrix based on the specified method.

        Args:
            method (str): The similarity computation method ('distance' or 'correlation').
            **kwargs: Additional keyword arguments for similarity computation.

        Returns:
            numpy.ndarray: Computed similarity matrix.

        Raises:
            ValueError: If an unknown similarity method is provided.
        """
        if method == "distance":
            # Calculate a Gaussian kernel similarity from the distance matrix, using a default or provided 'theta'
            theta = kwargs.get('theta', np.std(self.distances))
            return self.gaussian_kernel(self.distances, theta=theta)
        elif method == "correlation":
            # Compute the average correlation between nodes over the target features
            # Reshape target data to have nodes as columns
            target_values = self.target.values.reshape(len(self.target), -1, len(self.target_node_feature))
            # Average over the target features
            target_mean = target_values.mean(axis=2)
            # Compute correlation between nodes
            corr = np.corrcoef(target_mean, rowvar=False)
            return (corr + 1) / 2  # Normalize to [0, 1]
        else:
            raise ValueError(f"Unknown similarity method: {method}")

    @staticmethod
    def gaussian_kernel(distances, theta):
        """
        Compute Gaussian kernel similarity from distances.

        Args:
            distances (numpy.ndarray): Distance matrix.
            theta (float): Kernel bandwidth parameter.

        Returns:
            numpy.ndarray: Gaussian kernel similarity matrix.
        """
        return np.exp(-(distances ** 2) / (2 * (theta ** 2)))

In [26]:
dataset = SoilWaterDataset(root='ml-drought-forecasting/soil-water-forecasting/data/05_model_input/')

In [27]:
dataset.target

array([[[-4.2025931e-06],
        [-4.2025931e-06],
        [-4.2025931e-06],
        ...,
        [-4.2025931e-06],
        [-4.2025931e-06],
        [-4.2025931e-06]],

       [[-9.8175369e-06],
        [-9.8175369e-06],
        [-9.8175369e-06],
        ...,
        [-9.8175369e-06],
        [-9.8175369e-06],
        [-9.8175369e-06]],

       [[ 4.5859342e-06],
        [ 4.5859342e-06],
        [ 4.5859342e-06],
        ...,
        [ 4.5859342e-06],
        [ 4.5859342e-06],
        [ 4.5859342e-06]],

       ...,

       [[ 1.5608966e-06],
        [ 1.5608966e-06],
        [ 1.5608966e-06],
        ...,
        [ 1.5608966e-06],
        [ 1.5608966e-06],
        [ 1.5608966e-06]],

       [[ 1.2002420e-06],
        [ 1.2002420e-06],
        [ 1.2002420e-06],
        ...,
        [ 1.2002420e-06],
        [ 1.2002420e-06],
        [ 1.2002420e-06]],

       [[-4.9313530e-06],
        [-4.9313530e-06],
        [-4.9313530e-06],
        ...,
        [-4.9313530e-06],
        [-4.931

In [28]:
print(f"Has missing values: {dataset.has_mask}")

Has missing values: True


In [29]:
dataset.mask

array([[[ True],
        [ True],
        [ True],
        ...,
        [ True],
        [ True],
        [ True]],

       [[ True],
        [ True],
        [ True],
        ...,
        [ True],
        [ True],
        [ True]],

       [[ True],
        [ True],
        [ True],
        ...,
        [ True],
        [ True],
        [ True]],

       ...,

       [[ True],
        [ True],
        [ True],
        ...,
        [ True],
        [ True],
        [ True]],

       [[ True],
        [ True],
        [ True],
        ...,
        [ True],
        [ True],
        [ True]],

       [[ True],
        [ True],
        [ True],
        ...,
        [ True],
        [ True],
        [ True]]])

In [30]:
dataset.set_mask(dataset.mask)

In [31]:
dataset.covariates

{'u': array([[[ 2.94796997e+02,  2.91070618e+02,  1.01160688e+05, ...,
          -1.19325705e-09,  5.65088913e-03,  0.00000000e+00],
         [ 2.94720825e+02,  2.90953430e+02,  1.01172688e+05, ...,
          -1.19325705e-09,  5.64719364e-03,  0.00000000e+00],
         [ 2.94691528e+02,  2.90701477e+02,  1.01184812e+05, ...,
          -1.19325705e-09,  5.64600155e-03,  0.00000000e+00],
         ...,
         [ 2.91923950e+02,  2.87041321e+02,  1.01571438e+05, ...,
          -1.19325705e-09,  5.84162399e-03,  0.00000000e+00],
         [ 2.92054810e+02,  2.87088196e+02,  1.01587562e+05, ...,
          -1.19325705e-09,  5.82457706e-03,  0.00000000e+00],
         [ 2.92066528e+02,  2.87248352e+02,  1.01605938e+05, ...,
          -1.19325705e-09,  5.80800697e-03,  0.00000000e+00]],
 
        [[ 2.95360138e+02,  2.91539581e+02,  1.01597125e+05, ...,
           1.60071068e-09,  5.51155582e-03,  0.00000000e+00],
         [ 2.95342560e+02,  2.91631378e+02,  1.01593375e+05, ...,
           1.600

# Create connectivity to our graph 

Here, we adjust the connectivity to retain only the five nearest neighbors (knn=5) per node, while excluding self-loops and normalizing along the specified axis.

In [1]:
# dataset.distances

In [33]:
# sim = dataset.compute_similarity("distance")  # or dataset.compute_similarity()

In [34]:
# sim

array([[1.        , 0.9998158 , 0.9992636 , ..., 0.4123332 , 0.41279253,
        0.41306838],
       [0.9998158 , 1.        , 0.9998158 , ..., 0.411691  , 0.4123332 ,
        0.41279253],
       [0.9992636 , 0.9998158 , 1.        , ..., 0.41086674, 0.411691  ,
        0.4123332 ],
       ...,
       [0.4123332 , 0.411691  , 0.41086674, ..., 1.        , 0.9998158 ,
        0.9992636 ],
       [0.41279253, 0.4123332 , 0.411691  , ..., 0.9998158 , 1.        ,
        0.9998158 ],
       [0.41306838, 0.41279253, 0.4123332 , ..., 0.9992636 , 0.9998158 ,
        1.        ]], dtype=float32)

In [35]:
# Adjust connectivity to reduce the number of edges
connectivity = dataset.get_connectivity(          
    knn=5,     
    include_self=False,
    normalize_axis=1,
    layout="edge_index"
)

In [36]:
edge_index, edge_weight = connectivity

In [37]:
edge_index

array([[    0,     0,     0, ..., 21959, 21959, 21959],
       [    1,   359,   360, ..., 21599, 21600, 21958]])

In [38]:
edge_weight

array([0.20002224, 0.20002224, 0.20000282, ..., 0.20000282, 0.20002224,
       0.20002224], dtype=float32)

In [39]:
# def print_matrix(matrix):
#     return pd.DataFrame(matrix)

In [40]:
# import pandas as pd
# from tsl.ops.connectivity import edge_index_to_adj
# 
# adj = edge_index_to_adj(edge_index, edge_weight)
# print(f'A {adj.shape}:')
# print_matrix(adj)

A (21960, 21960):


Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,21950,21951,21952,21953,21954,21955,21956,21957,21958,21959
0,0.000000,0.200022,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
1,0.200022,0.000000,0.200022,0.000000,0.000000,0.000000,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
2,0.000000,0.200022,0.000000,0.200022,0.000000,0.000000,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
3,0.000000,0.000000,0.200022,0.000000,0.200022,0.000000,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
4,0.000000,0.000000,0.000000,0.200022,0.000000,0.200022,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
21955,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.200022,0.000000,0.200022,0.000000,0.000000,0.000000
21956,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.000000,0.200022,0.000000,0.200022,0.000000,0.000000
21957,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.000000,0.000000,0.200022,0.000000,0.200022,0.000000
21958,0.000000,0.000000,0.000000,0.000000,0.000000,0.000000,0.0,0.0,0.0,0.0,...,0.0,0.0,0.0,0.0,0.000000,0.000000,0.000000,0.200022,0.000000,0.200022


# Create torch_dataset 

torch_dataset, created using SpatioTemporalDataset from the tsl library, structures time-series and spatial data (target, covariates, mask, and connectivity) into a form optimized for spatiotemporal model training, enabling easy handling of lookback windows and prediction horizons in forecasting tasks.


In [41]:
from tsl.data import SpatioTemporalDataset

# covariates=dict(u=dataset.covariates['u'])
covariates=dataset.covariates
mask = dataset.mask

torch_dataset = SpatioTemporalDataset(target=dataset.dataframe(),
                                      mask=mask,
                                      covariates=covariates,
                                      connectivity=connectivity,
                                      horizon=6, # Predict 7 step ahead
                                      window=12, # Use 30 timestamps to predict the next one
                                      stride=1 # Move 7 step forward each time
                                      )

In [42]:
torch_dataset.mask

tensor([[[True],
         [True],
         [True],
         ...,
         [True],
         [True],
         [True]],

        [[True],
         [True],
         [True],
         ...,
         [True],
         [True],
         [True]],

        [[True],
         [True],
         [True],
         ...,
         [True],
         [True],
         [True]],

        ...,

        [[True],
         [True],
         [True],
         ...,
         [True],
         [True],
         [True]],

        [[True],
         [True],
         [True],
         ...,
         [True],
         [True],
         [True]],

        [[True],
         [True],
         [True],
         ...,
         [True],
         [True],
         [True]]])

# Create datamodule

datamodule, created with SpatioTemporalDataModule, manages the SpatioTemporalDataset by applying scaling, splitting data into train/validation/test sets, and preparing data loaders with batch processing, enabling efficient, modular, and scalable data handling for deep learning models.

In [44]:
from tsl.data.preprocessing import StandardScaler, MinMaxScaler

scalers = {
    'target': MinMaxScaler(axis=(0, 1)),
    'u': MinMaxScaler(axis=(0, 1))
}

In [45]:
from tsl.data.datamodule import (SpatioTemporalDataModule,
                                 TemporalSplitter)
                                 
# Split data sequentially:
#   |------------ dataset -----------|
#   |--- train ---|- val -|-- test --|
splitter = TemporalSplitter(val_len=0.35, test_len=0.1)


In [46]:
# Create a SpatioTemporalDataModule
datamodule = SpatioTemporalDataModule(
    dataset=torch_dataset,
    scalers=scalers,
    mask_scaling=True,
    splitter=splitter,
    batch_size=8,
    workers=15,
    )

print(datamodule)

SpatioTemporalDataModule(train_len=None, val_len=None, test_len=None, scalers=[target, u], batch_size=8)


In [47]:
datamodule.setup()

In [48]:
datamodule

SpatioTemporalDataModule(train_len=141, val_len=69, test_len=25, scalers=[target, u], batch_size=8)

In [49]:
datamodule.splitter.indices

{'train': array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
         13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
         26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
         39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
         52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
         65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
         78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
         91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
        104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
        117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
        130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140]),
 'val': array([153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165,
        166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178,
        179, 180, 181, 182, 183, 184, 185, 

# Create STGNN Model Architecture

The **TimeAndGraphAnisoModel** is built based on code from [HD-TTS
 repository](https://github.com/marshka/hdtts/blob/main/lib/nn/models/baselines/stgnns/time_and_graph_anisotropic.py) and the research paper by Cini et al. (2023d). This model utilizes spatiotemporal architectures equipped with anisotropic message passing for effective time and space representation.

### Reference
Cini, A., Marisca, I., Zambon, D., and Alippi, C. *Taming Local Effects in Graph-Based Spatiotemporal Forecasting.* In *Advances in Neural Information Processing Systems,* volume 36, pp. 55375–55393. Curran Associates, Inc., 2023.  
[https://arxiv.org/abs/2302.04071](https://arxiv.org/abs/2302.04071)


In [None]:
from einops import rearrange
from torch import Tensor, nn
from torch_geometric.typing import Adj, OptTensor

from tsl.nn.blocks.decoders import MLPDecoder
from tsl.nn.blocks.encoders import DCRNN, ConditionalBlock
from tsl.nn.models.base_model import BaseModel


class DCRNNModel(BaseModel):
    r"""The Diffusion Convolutional Recurrent Neural Network from the paper
    `"Diffusion Convolutional Recurrent Neural Network: Data-Driven Traffic
    Forecasting" <https://arxiv.org/abs/1707.01926>`_ (Li et al., ICLR 2018).

    Differently from the original implementation, the recurrent decoder is
    substituted with a fixed-length nonlinear readout.

    Args:
        input_size (int): Number of features of the input sample.
        output_size (int): Number of output channels.
        horizon (int): Number of future time steps to forecast.
        exog_size (int): Number of features of the input covariate,
            if any. (default: :obj:`0`)
        hidden_size (int): Number of hidden units.
            (default: :obj:`32`)
        kernel_size (int): Order of the spatial diffusion process.
            (default: :obj:`2`)
        ff_size (int): Number of units in the nonlinear readout.
            (default: :obj:`256`)
        n_layers (int): Number of DCRNN cells.
            (default: :obj:`1`)
        dropout (float): Dropout probability.
            (default: :obj:`0`)
        activation (str): Activation function in the readout.
            (default: :obj:`'relu'`)
    """

    return_type = Tensor

    def __init__(self,
                 input_size: int,
                 output_size: int,
                 horizon: int,
                 exog_size: int = 0,
                 hidden_size: int = 32,
                 kernel_size: int = 2,
                 ff_size: int = 256,
                 n_layers: int = 1,
                 cache_support: bool = False,
                 dropout: float = 0.,
                 activation: str = 'relu'):
        super(DCRNNModel, self).__init__()
        if exog_size:
            self.input_encoder = ConditionalBlock(input_size=input_size,
                                                  exog_size=exog_size,
                                                  output_size=hidden_size,
                                                  activation=activation)
        else:
            self.input_encoder = nn.Linear(input_size, hidden_size)

        self.dcrnn = DCRNN(input_size=hidden_size,
                           hidden_size=hidden_size,
                           n_layers=n_layers,
                           k=kernel_size,
                           return_only_last_state=True)
        self.cache_support = cache_support

        self.readout = MLPDecoder(input_size=hidden_size,
                                  hidden_size=ff_size,
                                  output_size=output_size,
                                  horizon=horizon,
                                  activation=activation,
                                  dropout=dropout)

    def forward(self,
                x: Tensor,
                edge_index: Adj,
                edge_weight: OptTensor = None,
                u: OptTensor = None) -> Tensor:
        """"""
        if u is not None:
            if u.dim() == 3:
                u = rearrange(u, 'b s c -> b s 1 c')
            x = self.input_encoder(x, u)
        else:
            x = self.input_encoder(x)

        out = self.dcrnn(x,
                         edge_index,
                         edge_weight,
                         cache_support=self.cache_support)
        return self.readout(out)

# Setup model 

Model is configured with hidden units, feed-forward layers, multiple SpatioTemporalConvNet blocks, and utilizes temporal and spatial convolution kernels, layer normalization, and gated mechanisms; it adapts to the dataset’s input size, number of nodes, horizon, and available exogenous features.

In [53]:
hidden_size = 32          # Number of hidden units
ff_size = 64             # Number of units in the feed-forward layers
n_layers = 3              # Number of SpatioTemporalConvNet blocks
kernel_size = 3  
# spatial_kernel_size = 3   # Order of the spatial diffusion process
# norm='layer'
# gated=True

input_size = torch_dataset.n_channels
output_size = torch_dataset.n_channels
# n_nodes = torch_dataset.n_nodes
horizon = torch_dataset.horizon
exog_size = torch_dataset.input_map.u.shape[-1] if 'u' in torch_dataset else 0

In [54]:
model = DCRNNModel(
    input_size=input_size,
    output_size=output_size,
    horizon=horizon,
    exog_size=exog_size,
    hidden_size=hidden_size,
    kernel_size=kernel_size,
    ff_size=ff_size,
    n_layers=n_layers,
)

# Print the model architecture
print(model)

TimeAndGraphAnisoModel(
  (emb): None
  (encoder): Linear(in_features=20, out_features=32, bias=True)
  (decoder): MLPDecoder(
    (readout): MLP(
      (mlp): Sequential(
        (0): Dense(
          (affinity): Linear(in_features=32, out_features=32, bias=True)
          (activation): ELU(alpha=1.0)
          (dropout): Identity()
        )
      )
      (readout): Linear(in_features=32, out_features=6, bias=True)
    )
    (rearrange): Rearrange('b n (h f) -> b h n f', f=1, h=6)
  )
  (stmp_conv): GraphAnisoGRU(cell=GraphAnisoGRUCell, return_only_last_state=True)
)


In [55]:
def print_model_size(model):
    tot = sum([p.numel() for p in model.parameters() if p.requires_grad])
    out = f"Number of model ({model.__class__.__name__}) parameters:{tot:10d}"
    print("=" * len(out))
    print(out)

In [56]:
print_model_size(model)

Number of model (TimeAndGraphAnisoModel) parameters:     23913


# Setup training 

This setup initializes a Predictor for the model with a masked mean-squared error loss function and multiple evaluation metrics (MSE, MAE, MAPE, and specific MSE at selected timesteps), then configures a Trainer using PyTorch Lightning with early stopping and model checkpointing based on validation MSE, gradient clipping to prevent exploding gradients, and 16-bit precision for efficient training.

In [57]:
from tsl.metrics.torch import MaskedMSE, MaskedMAE, MaskedMAPE
from tsl.engines import Predictor

# Define the loss function
loss_fn = MaskedMSE()

# Setup metrics
metrics = {
    'mse': MaskedMSE(),
    'mae': MaskedMAE(),
    'mape': MaskedMAPE(),
    'mse_at_3': MaskedMSE(at=2),  # '2' indicates the third time step
    'mse_at_6': MaskedMSE(at=5)
}

# Setup predictor
predictor = Predictor(
    model=model,
    optim_class=torch.optim.Adam,
    optim_kwargs={'lr': 0.001},
    loss_fn=loss_fn,
    metrics=metrics
)


In [58]:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

early_stop_callback = EarlyStopping(
    monitor='val_mse',
    patience=30,
    mode='min'
)

from pathlib import Path

# Logs 
dirpath = Path('ml-drought-forecasting/soil-water-forecasting/data/06_models/DCRNN/logs')

dirpath.mkdir(parents=True, exist_ok=True)

checkpoint_callback = ModelCheckpoint(
    dirpath=dirpath,
    save_top_k=1,
    monitor='val_mse', 
    mode='min',
)

# Setup trainer
trainer = pl.Trainer(max_epochs=100,
                    #  logger=logger,
                    #  limit_train_batches=100,  # end an epoch after 200 updates
                     callbacks=[early_stop_callback, checkpoint_callback],
                     log_every_n_steps=2,
                     gradient_clip_val=1.0,    # Prevent exploding gradients
                     precision=16
                     )


/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


# Training 

In [59]:
# Set float32 matmul precision to 'medium' or 'high'
torch.set_float32_matmul_precision('medium')

In [63]:
import multiprocessing
multiprocessing.set_start_method('spawn', force=True)
trainer.fit(predictor, datamodule=datamodule)

/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:654: Checkpoint directory /teamspace/studios/this_studio/logs exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                   | Params | Mode 
-----------------------------------------------------------------
0 | loss_fn       | MaskedMSE              | 0      | train
1 | train_metrics | MetricCollection       | 0      | train
2 | val_metrics   | MetricCollection       | 0      | train
3 | test_metrics  | MetricCollection       | 0      | train
4 | model         | TimeAndGraphAnisoModel | 23.9 K | train
-----------------------------------------------------------------
23.9 K    Trainable params
0         Non-trainable params
23.9 K    Total params
0.096     Total estimated model params size (MB)
81        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:419: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.
Arguments ['metadata', 'distances'] are filtered out. Only args ['x', 'u', 'edge_index', 'edge_weight'] are forwarded to the model (TimeAndGraphAnisoModel).
/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:419: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


In [None]:
# Path to the best checkpoint
best_checkpoint_path = checkpoint_callback.best_model_path

print(f"Best checkpoint saved at: {best_checkpoint_path}")

# Load the model from the best checkpoint
from pytorch_lightning.utilities.memory import load_obj

trained_model = TimeAndGraphAnisoModel(
    input_size=input_size,
    horizon=horizon,
    n_nodes=n_nodes,
    output_size=input_size,
    exog_size=exog_size,
    hidden_size=hidden_size
)

# Load checkpoint weights
state_dict = torch.load(best_checkpoint_path)["state_dict"]
trained_model.load_state_dict(state_dict)

# Evaluation 

In [64]:
# print(checkpoint_callback.best_model_path)


/teamspace/studios/this_studio/logs/epoch=86-step=1479.ckpt


In [65]:
predictor.load_model(checkpoint_callback.best_model_path)

  storage = torch.load(filename, lambda storage, loc: storage)
Predictor with already instantiated model is loading a state_dict from /teamspace/studios/this_studio/logs/epoch=86-step=1479.ckpt. Cannot  check if model hyperparameters are the same.


In [60]:
predictor.load_model('/teamspace/studios/this_studio/logs/epoch=86-step=1479.ckpt')

  storage = torch.load(filename, lambda storage, loc: storage)
Predictor with already instantiated model is loading a state_dict from /teamspace/studios/this_studio/logs/epoch=86-step=1479.ckpt. Cannot  check if model hyperparameters are the same.


In [61]:
predictor.freeze()

In [62]:
trainer.test(predictor, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: |          | 0/? [00:00<?, ?it/s]

Arguments ['metadata', 'distances'] are filtered out. Only args ['x', 'edge_index', 'edge_weight', 'u'] are forwarded to the model (TimeAndGraphAnisoModel).


[{'test_mae': 0.007187105715274811,
  'test_mape': 2584.482177734375,
  'test_mse': 0.000349479349097237,
  'test_mse_at_3': 0.00036978835123591125,
  'test_mse_at_6': 0.00036310526775196195,
  'test_loss': 0.000349479349097237}]

In [63]:
trainer.validate(predictor, datamodule=datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

[{'val_mae': 0.0071140737272799015,
  'val_mape': 636.721435546875,
  'val_mse': 0.00033514376264065504,
  'val_mse_at_3': 0.0003464221372269094,
  'val_mse_at_6': 0.0003566509112715721,
  'val_loss': 0.00033514376264065504}]

In [64]:
predictions_test = trainer.predict(predictor, datamodule.test_dataloader())

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Predicting: |          | 0/? [00:00<?, ?it/s]

In [65]:
# predictions_test

In [66]:
splitter.indices

{'train': array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
         13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
         26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
         39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
         52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
         65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
         78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
         91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
        104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
        117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
        130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140]),
 'val': array([153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165,
        166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178,
        179, 180, 181, 182, 183, 184, 185, 

In [67]:
dates = pd.to_datetime(ds['date'].values)

In [68]:
dates

DatetimeIndex(['2000-01-01', '2000-02-01', '2000-03-01', '2000-04-01',
               '2000-05-01', '2000-06-01', '2000-07-01', '2000-08-01',
               '2000-09-01', '2000-10-01',
               ...
               '2022-03-01', '2022-04-01', '2022-05-01', '2022-06-01',
               '2022-07-01', '2022-08-01', '2022-09-01', '2022-10-01',
               '2022-11-01', '2022-12-01'],
              dtype='datetime64[ns]', length=276, freq=None)

In [69]:
datamodule.splitter.indices

{'train': array([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,
         13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,
         26,  27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,
         39,  40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,
         52,  53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,
         65,  66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,
         78,  79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,
         91,  92,  93,  94,  95,  96,  97,  98,  99, 100, 101, 102, 103,
        104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116,
        117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129,
        130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140]),
 'val': array([153, 154, 155, 156, 157, 158, 159, 160, 161, 162, 163, 164, 165,
        166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178,
        179, 180, 181, 182, 183, 184, 185, 

In [70]:
import numpy as np
import pandas as pd
import torch
import xarray as xr

# 1. Load Metadata and Dates
metadata_array = np.load('ml-drought-forecasting/ml-modeling-pipeline/data/05_model_input/metadata.npy')  # Shape: (21960, 2)
latitudes = metadata_array[:, 0]
longitudes = metadata_array[:, 1]
num_nodes = len(latitudes)

# Load the original dataset
# ds = xr.open_dataset('path_to_your_dataset.nc')  # Replace with your actual path
dates = pd.to_datetime(ds['date'].values)
total_time_steps = len(dates)

# Initialize splitter
# splitter = TemporalSplitter(val_len=0.35, test_len=0.1)

# Assume you have a SpatioTemporalDataset instance
# dataset = your_spatio_temporal_dataset  # Replace with your actual dataset variable

# Fit the splitter to the dataset
# splitter.fit(dataset)

# 2. Correctly Retrieve the Split Indices
split_indices = datamodule.splitter.indices
train_indices = split_indices['train']
val_indices = split_indices['val']
test_indices = split_indices['test']

print(f"Train indices: {train_indices}")
print(f"Validation indices: {val_indices}")
print(f"Test indices: {test_indices}")

# 3. Process Predictions and Actual Values
y_list = []
y_hat_list = []
mask_list = []

for batch in predictions_test:
    y_tensor = batch['y']
    y_hat_tensor = batch['y_hat']
    mask_tensor = batch['mask']
    
    y_np = y_tensor.cpu().numpy().squeeze(-1)      # Shape: (batch_size, steps, nodes)
    y_hat_np = y_hat_tensor.cpu().numpy().squeeze(-1)
    mask_np = mask_tensor.cpu().numpy().squeeze(-1)
    
    y_list.append(y_np)
    y_hat_list.append(y_hat_np)
    mask_list.append(mask_np)

# Concatenate along the batch dimension
y = np.concatenate(y_list, axis=0)         # Shape: (num_test_batches, steps, nodes)
y_hat = np.concatenate(y_hat_list, axis=0) # Shape: (num_test_batches, steps, nodes)
mask = np.concatenate(mask_list, axis=0)   # Shape: (num_test_batches, steps, nodes)

print(f"Shape of y: {y.shape}")
print(f"Shape of y_hat: {y_hat.shape}")
print(f"Shape of mask: {mask.shape}")

# 4. Map Node IDs to Geographical Coordinates
num_steps = y.shape[1]          # Number of prediction steps (e.g., 6)
num_test_dates = y.shape[0]     # Number of test dates (e.g., 25)

# Create node_ids, lat, lon repeated for each step
node_ids = np.tile(np.arange(num_nodes), num_test_dates * num_steps)  # Shape: (3,294,000,)
lat_column = np.tile(latitudes, num_test_dates * num_steps)
lon_column = np.tile(longitudes, num_test_dates * num_steps)

print(f"Shape of node_ids: {node_ids.shape}")
print(f"Shape of lat_column: {lat_column.shape}")
print(f"Shape of lon_column: {lon_column.shape}")

# 5. Align Predictions with Dates

# Repeat each test date for the number of prediction steps
test_dates_steps = np.repeat(dates[test_indices], num_steps)  # Shape: (25 * 6,) => (150,)

# Now, repeat each date for all nodes to align with the predictions
all_dates_expanded = np.repeat(test_dates_steps, num_nodes)   # Shape: (150 * 21960,) => (3,294,000,)

print(f"Shape of test_dates_steps: {test_dates_steps.shape}")       # (150,)
print(f"Shape of all_dates_expanded: {all_dates_expanded.shape}") # (3,294,000,)

# 6. Flatten the Predictions and Mask
y_flat = y.flatten()         # Shape: (3,294,000,)
y_hat_flat = y_hat.flatten()
mask_flat = mask.flatten()

print(f"Shape of y_flat: {y_flat.shape}")
print(f"Shape of y_hat_flat: {y_hat_flat.shape}")
print(f"Shape of mask_flat: {mask_flat.shape}")

# 7. Apply Mask to Filter Valid Entries
valid_indices = mask_flat.astype(bool)  # Ensure mask is boolean

filtered_dates = all_dates_expanded[valid_indices]  # Shape: (num_valid_entries,)
filtered_lat = lat_column[valid_indices]
filtered_lon = lon_column[valid_indices]
filtered_y = y_flat[valid_indices]
filtered_y_hat = y_hat_flat[valid_indices]
filtered_node_ids = node_ids[valid_indices]

print(f"Shape of filtered_dates: {filtered_dates.shape}")
print(f"Shape of filtered_lat: {filtered_lat.shape}")
print(f"Shape of filtered_lon: {filtered_lon.shape}")
print(f"Shape of filtered_y: {filtered_y.shape}")
print(f"Shape of filtered_y_hat: {filtered_y_hat.shape}")
print(f"Shape of filtered_node_ids: {filtered_node_ids.shape}")

# 8. Create the DataFrame
df_predictions = pd.DataFrame({
    'date': filtered_dates,
    'lat': filtered_lat,
    'lon': filtered_lon,
    'node_id': filtered_node_ids,
    'y': filtered_y,
    'y_hat': filtered_y_hat
})

print(df_predictions.head())
print(df_predictions.shape)


Train indices: [  0   1   2   3   4   5   6   7   8   9  10  11  12  13  14  15  16  17
  18  19  20  21  22  23  24  25  26  27  28  29  30  31  32  33  34  35
  36  37  38  39  40  41  42  43  44  45  46  47  48  49  50  51  52  53
  54  55  56  57  58  59  60  61  62  63  64  65  66  67  68  69  70  71
  72  73  74  75  76  77  78  79  80  81  82  83  84  85  86  87  88  89
  90  91  92  93  94  95  96  97  98  99 100 101 102 103 104 105 106 107
 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125
 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140]
Validation indices: [153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170
 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206
 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221]
Test indices: [234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251
 252 253 254 255 256 257

In [71]:
df = df_predictions

In [72]:
df 

Unnamed: 0,date,lat,lon,node_id,y,y_hat
0,2019-07-01,-30.0,-180.0,0,0.000010,-0.000218
1,2019-07-01,-30.0,-179.0,1,0.000010,-0.000157
2,2019-07-01,-30.0,-178.0,2,0.000010,-0.000154
3,2019-07-01,-30.0,-177.0,3,0.000010,-0.000120
4,2019-07-01,-30.0,-176.0,4,0.000010,-0.000367
...,...,...,...,...,...,...
3293995,2021-07-01,30.0,175.0,21955,-0.000005,-0.001788
3293996,2021-07-01,30.0,176.0,21956,-0.000005,-0.001751
3293997,2021-07-01,30.0,177.0,21957,-0.000005,-0.001880
3293998,2021-07-01,30.0,178.0,21958,-0.000005,-0.001606


In [73]:
df.to_parquet('ml-drought-forecasting/ml-modeling-pipeline/data/07_model_output/predictions.parquet')

# Visualization

In [None]:
from typing import Optional
import plotly.express as px
import plotly.graph_objects as go
import pandas as pd

def visualize_variable_on_map(
    dataset: pd.DataFrame,
    variable: str,
    lat_dim: str = 'latitude',
    lon_dim: str = 'longitude',
    projection: str = 'natural earth',
    color_scale: str = 'Viridis',
    title: Optional[str] = None,
    animation_frame: Optional[str] = None,
    hover_precision: int = 2,
) -> go.Figure:

    df = dataset

    # Drop NaNs
    df = df.dropna(subset=[variable])

    # Handle animation frame
    if animation_frame:
        df[animation_frame] = df[animation_frame].astype(str)
    else:
        df['Frame'] = 'Frame'
    
    # Create scatter_geo plot
    fig = px.scatter_geo(
        df,
        lat=lat_dim,
        lon=lon_dim,
        color=variable,
        animation_frame=animation_frame if animation_frame else 'Frame',
        projection=projection,
        color_continuous_scale=color_scale,
        title=title,
        labels={variable: variable.upper()},
        hover_data={variable: f':.{hover_precision}f'},
    )
    
    # Add Play/Pause buttons
    fig.update_layout(
        updatemenus=[dict(
            type='buttons',
            buttons=[
                dict(label='Play',
                     method='animate',
                     args=[None, {"frame": {"duration": 500, "redraw": True}, "fromcurrent": True}]),
                dict(label='Pause',
                     method='animate',
                     args=[[None], {"frame": {"duration": 0, "redraw": False}, "mode": "immediate"}])
            ],
            showactive=False,
            x=0.1,
            y=0
        )]
    )

    return fig

In [75]:
# Example usage:
fig = visualize_variable_on_map(
    dataset=df,
    variable='y',
    lat_dim='lat',
    lon_dim='lon',
    animation_frame='date',
    title='Volumetric soil water layer 1',
    color_scale = [
        [0.0, "darkred"],
        [0.1, "red"],
        [0.2, "orangered"],
        [0.3, "lightgreen"],
        [0.4, "limegreen"],
        [0.5, "green"],
        [0.55, "darkseagreen"],
        [0.6, "darkgreen"],
        [0.7, "lightblue"],
        [0.8, "skyblue"],
        [0.9, "deepskyblue"],
        [1.0, "blue"]
    ]
)
fig.show()

In [125]:
import plotly.graph_objects as go
import plotly.subplots as sp
import pandas as pd
import numpy as np

def visualize_predictions_plotly(predictions, metadata, idx, horizon=None, 
                                 batches_to_visualize=None, variables_to_visualize=None):
    """
    Visualize predicted vs actual values using Plotly for specified batches and variables with variable-specific normalization.
    
    Parameters:
    - predictions (dict): Contains 'y' and 'y_hat' tensors.
    - metadata (pd.DataFrame): DataFrame with 'lat' and 'lon' columns.
    - idx (int): Index of the item being visualized.
    - horizon (int, optional): Number of time steps to visualize. Defaults to the maximum horizon in data.
    - batches_to_visualize (list or range, optional): List of batch indices to visualize. Defaults to all batches.
    - variables_to_visualize (list or range, optional): List of variable indices to visualize. Defaults to all variables.
    """
    # Extract predictions and actual values
    y_hat = predictions['y_hat'].squeeze().numpy()  # Shape: [batch, horizon, spatial_dim, variables]
    y = predictions['y'].squeeze().numpy()          # Shape: [batch, horizon, spatial_dim, variables]

    # Ensure y_hat and y have the same shape
    assert y_hat.shape == y.shape, "Predictions and actual values must have the same shape"

    # Determine the horizon if not provided
    if horizon is None:
        horizon = y_hat.shape[1]

    # Determine the number of variables
    if y_hat.ndim == 3:
        # Shape: [batch, horizon, spatial_dim]
        num_variables = 1
        y_hat = y_hat[..., np.newaxis]  # Add a variables dimension
        y = y[..., np.newaxis]
    else:
        num_variables = y_hat.shape[-1]

    if variables_to_visualize is None:
        variables_to_visualize = list(range(num_variables))
    else:
        # Ensure the variable indices are within the correct range
        variables_to_visualize = [v for v in variables_to_visualize if v < num_variables]
        if not variables_to_visualize:
            raise ValueError("No valid variables to visualize. Check 'variables_to_visualize' indices.")

    # Get latitude and longitude from metadata
    lats = metadata['lat'].values
    lons = metadata['lon'].values

    # Ensure that the number of spatial points matches the number of lats and lons
    spatial_dim = y_hat.shape[2]
    if spatial_dim != len(lats):
        raise ValueError(f"The number of spatial points in predictions ({spatial_dim}) does not match the number of locations in metadata ({len(lats)}).")

    # Calculate the min and max for lats and lons for setting map extent
    lat_min, lat_max = lats.min(), lats.max()
    lon_min, lon_max = lons.min(), lons.max()

    # If batches_to_visualize is None, visualize all batches
    if batches_to_visualize is None:
        batches_to_visualize = range(y_hat.shape[0])
    else:
        # Ensure the batch indices are within the correct range
        batches_to_visualize = [b for b in batches_to_visualize if b < y_hat.shape[0]]
        if not batches_to_visualize:
            raise ValueError("No valid batches to visualize. Check 'batches_to_visualize' indices.")

    # Precompute min and max for each variable in variables_to_visualize
    var_min_max = {}
    for var in variables_to_visualize:
        var_data_y = y[..., var]  # Shape: [batch, horizon, spatial_dim]
        var_data_y_hat = y_hat[..., var]  # Shape: [batch, horizon, spatial_dim]
        min_val = min(np.min(var_data_y), np.min(var_data_y_hat))
        max_val = max(np.max(var_data_y), np.max(var_data_y_hat))
        import numpy as np

        # ...

        # Loop through each batch and variable
        for batch in batches_to_visualize:
            for var in variables_to_visualize:
                frames = []
                slider_steps = []

                # Retrieve normalization for the current variable
                min_val, max_val = var_min_max[var]

                # Get the highest and lowest values of the variable
                var_min = np.min(y_hat[batch, :, :, var])
                var_max = np.max(y_hat[batch, :, :, var])

                # Update the min and max values if necessary
                if var_min < min_val:
                    min_val = var_min
                if var_max > max_val:
                    max_val = var_max

                # Update the normalization dictionary
                var_min_max[var] = (min_val, max_val)

                # ...

                # Predicted data for current month
                pred_data = go.Scattergeo(
                    lon=lons,
                    lat=lats,
                    mode='markers',
                    marker=dict(
                        size=6,
                        color=y_hat[batch, time_step, :, var],
                        colorscale='RdYlBu_r',
                        cmin=min_val,
                        cmax=max_val,
                        colorbar=dict(title='Predicted'),
                        opacity=0.8
                    ),
                    name='Predicted',
                    showlegend=False
                )

                # Actual data for current month
                actual_data = go.Scattergeo(
                    lon=lons,
                    lat=lats,
                    mode='markers',
                    marker=dict(
                        size=6,
                        color=y[batch, time_step, :, var],
                        colorscale='RdYlBu_r',
                        cmin=min_val,
                        cmax=max_val,
                        colorbar=dict(title='Actual'),
                        opacity=0.8
                    ),
                    name='Actual',
                    showlegend=False
                )

                # ...

                # Initial data (first frame)
                pred_data_initial = go.Scattergeo(
                    lon=lons,
                    lat=lats,
                    mode='markers',
                    marker=dict(
                        size=6,
                        color=y_hat[batch, 0, :, var],
                        colorscale='RdYlBu_r',
                        cmin=min_val,
                        cmax=max_val,
                        colorbar=dict(title='Predicted'),
                        opacity=0.8
                    ),
                    name='Predicted',
                    showlegend=True
                )

                actual_data_initial = go.Scattergeo(
                    lon=lons,
                    lat=lats,
                    mode='markers',
                    marker=dict(
                        size=6,
                        color=y[batch, 0, :, var],
                        colorscale='RdYlBu_r',
                        cmin=min_val,
                        cmax=max_val,
                        colorbar=dict(title='Actual'),
                        opacity=0.8
                    ),
                    name='Actual',
                    showlegend=True
                )

                # ...
                        visible=True,
                        prefix="Date: ",
                        xanchor="right",
                        font=dict(size=14, color="#666")
                    ),
                )],
                geo=dict(
                    scope='world',
                    projection_type='natural earth',
                    showland=True,
                    landcolor='lightgray',
                    showcountries=True,
                    countrycolor='black',
                    lataxis=dict(range=[lat_min - 1, lat_max + 1]),
                    lonaxis=dict(range=[lon_min - 1, lon_max + 1]),
                ),
                geo2=dict(
                    scope='world',
                    projection_type='natural earth',
                    showland=True,
                    landcolor='lightgray',
                    showcountries=True,
                    countrycolor='black',
                    lataxis=dict(range=[lat_min - 1, lat_max + 1]),
                    lonaxis=dict(range=[lon_min - 1, lon_max + 1]),
                )
            )

            # Update frames to assign traces to correct subplots
            for frame in fig.frames:
                time_step_idx = int(frame.name.split()[1]) - 1
                frame.data = [
                    go.Scattergeo(
                        lon=lons,
                        lat=lats,
                        mode='markers',
                        marker=dict(
                            size=6,
                            color=y_hat[batch, time_step_idx, :, var],
                            colorscale='RdYlBu_r',
                            cmin=min_val,
                            cmax=max_val,
                            opacity=0.8
                        ),
                        showlegend=False
                    ),
                    go.Scattergeo(
                        lon=lons,
                        lat=lats,
                        mode='markers',
                        marker=dict(
                            size=6,
                            color=y[batch, time_step_idx, :, var],
                            colorscale='RdYlBu_r',
                            cmin=min_val,
                            cmax=max_val,
                            opacity=0.8
                        ),
                        showlegend=False
                    )
                ]

            # Update layout for better appearance
            fig.update_layout(
                height=600,
                width=1200,
                margin=dict(l=50, r=50, t=100, b=50)
            )

            # Display the figure
            fig.show()

            # Optionally, save the figure to an HTML file
            # fig.write_html(f'item_{idx+1}_batch_{batch+1}_variable_{var+1}.html')

# Example Usage

# Assuming you have the metadata DataFrame available
# metadata = pd.read_parquet('ml-drought-forecasting/ml-modeling-pipeline/data/05_model_input/metadata.parquet')

# Ensure that predictions_test is a list of dictionaries with 'y' and 'y_hat'
# For example:
# predictions_test = [{'y': tensor1, 'y_hat': tensor2}, {'y': tensor3, 'y_hat': tensor4}, ...]

# # Specify the indices of the items you want to visualize
# items_to_visualize = [1]  # Replace with desired item indices (0-based indexing)
# batches_to_visualize = [0, 1]  # Replace with desired batch indices within each item
# variables_to_visualize = [0]  # Replace with desired variable indices (e.g., if multiple EDDI metrics)

# # Call the function for specified items in predictions_test
# for idx in items_to_visualize:
#     pred = predictions_test[idx]
#     visualize_predictions_plotly(
#         pred, metadata, idx, horizon=pred['y'].shape[1],
#         batches_to_visualize=batches_to_visualize, 
#         variables_to_visualize=variables_to_visualize
#     )
#     print(f"Processed item {idx + 1}/{len(predictions_test)}")
