In [None]:
import os
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from netCDF4 import Dataset
import numpy as np
from datetime import datetime

from torch.utils.tensorboard import SummaryWriter
import torchvision

from config import Config
from dgmr import DGMR

In [None]:
# Sequence layout used by NetCDFDataset: each item is NUM_INPUT_FRAMES frames that are
# fed to the model, followed by NUM_TARGET_FRAMES frames used as targets.
NUM_INPUT_FRAMES, NUM_TARGET_FRAMES = 4, 4
TOTAL_FRAMES = NUM_TARGET_FRAMES + NUM_INPUT_FRAMES  # files consumed per sequence

In [None]:
# Load real test data
class NetCDFDataset(torch.utils.data.dataset.Dataset):
    """
    Dataset of radar frame sequences read from NetCDF files.

    Typically, dataset returns an individual item from the dataset in __getitem__ method.
    Also, it should return the number of items in the dataset in __len__ method.

    Here, dataset returns a batch of frames in __getitem__ method, therefore __len__ method returns the number of batches.
    Also, DGMRDataModule should have batch_size=1, since we are returning a batch of frames in __getitem__ method.

    Layout: file paths are collected from ``Config.DATA_DIR/<split>/<day folder>/`` with day
    folders and files in sorted order; day folders whose name sorts before ``start_day`` are
    skipped. Every run of TOTAL_FRAMES consecutive files forms one sequence: the first
    NUM_INPUT_FRAMES files are the inputs, the remaining ones are the targets.

    Args:
        split: sub-folder of ``Config.DATA_DIR`` to read. For "train", ``len()`` is
            ``batches_per_epoch``; for any other split at most ``batches_per_epoch``
            sequences are exposed.
        num_epochs: stored on the instance; not otherwise used by this class.
        batches_per_epoch: see ``split``.
        batch_offset: added to the requested index before it is wrapped modulo the number
            of available sequences.
        start_day: day-folder names comparing lower than this string are ignored.
    """

    def __init__(
        self,
        split,
        num_epochs,
        batches_per_epoch=5,
        batch_offset=0,
        start_day="20160301",
    ):
        super().__init__()
        self.split = split
        self.local_folder_path = os.path.join(Config.DATA_DIR, split)
        self.num_epochs = num_epochs
        self.batches_per_epoch = batches_per_epoch
        self.batch_offset = batch_offset
        self.all_files = self._get_all_files(start_day)

        # Adjust total_batches based on the split type
        if self.split == "train":
            # For training, expose every complete sequence; __len__ still caps each epoch
            # at batches_per_epoch and __getitem__ wraps indices around total_batches.
            self.total_batches = len(self.all_files) // TOTAL_FRAMES

        else:
            # For validation (and potentially test), limit to first n batches if split is not training
            self.total_batches = min(len(self.all_files) // TOTAL_FRAMES, self.batches_per_epoch)

        print(f"For split: {split}: total_batches: {self.total_batches} | number of files: {len(self.all_files)}")

    def __len__(self):
        """
        Return size of the data set for DataLoader, but if Dataset gives complete batches
        and not individual items from the dataset, then it should return total number of batches.

        Returns 0 when not even one complete sequence is available, so a DataLoader over an
        under-filled folder yields nothing instead of failing inside __getitem__.
        """
        if self.total_batches == 0:
            return 0
        if self.split == "train":
            # Only limit the length for training to process batches_per_epoch batches each epoch
            return self.batches_per_epoch
        return self.total_batches

    def _get_all_files(self, start_day="20160301"):
        """Return sorted file paths from day folders named >= ``start_day``.

        For "test"/"validation" collection stops once at least
        ``batches_per_epoch * TOTAL_FRAMES`` paths have been gathered (checked after each
        whole day folder). Non-directory entries in the split folder are skipped.
        """
        limit = float("inf")
        if self.split == "test" or self.split == "validation":
            limit = self.batches_per_epoch * TOTAL_FRAMES

        all_files = []

        day_folders = sorted(os.listdir(self.local_folder_path))
        # Remove the days before the start_day (lexicographic comparison of folder names)
        day_folders = [d for d in day_folders if d >= start_day]

        for day_folder in day_folders:
            day_folder_path = os.path.join(self.local_folder_path, day_folder)
            if os.path.isdir(day_folder_path):
                files = sorted(os.listdir(day_folder_path))
                all_files.extend([os.path.join(day_folder_path, f) for f in files])
                if len(all_files) >= limit:
                    break
        return all_files

    def _check_batch(self, frames, batch_idx, frame_type):
        """Sanity-check a stacked array of frames.

        Side effect: values >= 65535 are overwritten with 0 in place.

        Raises:
            ValueError: if ``frames`` contains NaN or negative values, or if ``frames`` is
                (or contains) a numpy masked array.
        """
        # Check for abnormal values.
        # NOTE(review): 65535 looks like an unsigned-16-bit fill value — confirm.
        if (frames >= 65535).any():
            frames[frames >= 65535] = 0  # 65534 What should be the value?

        # Check for NaN values
        if np.isnan(frames).any():
            raise ValueError(f"NaN values found in {frame_type} frames of batch {batch_idx}")

        # Check for negative values
        if (frames < 0).any():
            raise ValueError(f"Negative values found in {frame_type} frames of batch {batch_idx}")

        # _load_frame fills masks, so no frame should still be masked. Reject the array if it
        # is masked itself or any single frame inside it is (an empty array passes).
        if isinstance(frames, np.ma.MaskedArray) or any(isinstance(f, np.ma.MaskedArray) for f in frames):
            raise ValueError(f"Frame(s) of batch {batch_idx} have masked array")

    def _load_frame(self, file_path):
        """Read the "RRdata" variable of one NetCDF file, with masked entries filled by 0."""
        print(f"Loading file: {file_path}")
        with Dataset(file_path, "r") as nc_data:
            return np.ma.filled(nc_data.variables["RRdata"][:], 0)

    def __getitem__(self, idx):
        """Return one sequence (batch) of frames as ``(input_frames, target_frames)``.

        ``idx`` is shifted by ``batch_offset`` and wrapped modulo ``total_batches``.

        Raises:
            IndexError: if no complete sequence of frames is available.
            ValueError: propagated from ``_check_batch`` on invalid frame data.
        """
        if self.total_batches == 0:
            # Guards the modulo below against ZeroDivisionError.
            raise IndexError(f"No complete frame sequence available for split {self.split!r}")

        # Adjust index based on the batch offset
        actual_idx = (self.batch_offset + idx) % self.total_batches
        start_idx = actual_idx * TOTAL_FRAMES
        end_idx = min(start_idx + TOTAL_FRAMES, len(self.all_files))
        frame_paths = self.all_files[start_idx:end_idx]

        frames = [self._load_frame(fp) for fp in frame_paths]
        input_frames = np.stack(frames[:NUM_INPUT_FRAMES])
        target_frames = np.stack(frames[NUM_INPUT_FRAMES:])

        # Checks for the entire batch (error messages report the requested idx)
        self._check_batch(input_frames, idx, "input")
        self._check_batch(target_frames, idx, "target")

        return input_frames, target_frames

In [None]:
# Select the compute device: CUDA when available, otherwise CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# Build the DGMR model. Trained weights are loaded from the checkpoint in a later cell.
# forecast_steps is tied to NUM_TARGET_FRAMES (presumably the number of frames generated per
# forward pass) so it cannot drift from the target length produced by NetCDFDataset.
model = DGMR(
    forecast_steps=NUM_TARGET_FRAMES,
    input_channels=1,
    output_shape=256,
    latent_channels=768,
    context_channels=384,
    num_samples=3,
    visualize=True,  # NOTE(review): kept True as originally required; its effect is defined inside DGMR.
).to(device)

# Put the model in inference mode; the mode is retained through the later load_state_dict call.
model.eval()
print("Model set to evaluation mode")

In [None]:
# Function to plot data
def plot_data_and_save(data, title, directory):
    """Plot each time step of the first sample in ``data`` side by side and save it as a PNG.

    Args:
        data: 5-D tensor unpacked as (batch, time, channel, height, width); only
            ``data[0, :, 0]`` is drawn, one panel per time step.
        title: panel title prefix; also used, with spaces replaced by underscores, as the
            file name ``<title>.png``.
        directory: directory in which the PNG is written (must already exist).

    The colour scale is shared by all panels: vmin is 0 and vmax is the maximum over the whole
    ``data`` tensor, but not less than 10.

    Raises:
        ValueError: if ``data`` has no time steps.
    """
    _, time_steps, _, _, _ = data.shape  # unpacking also enforces the 5-D layout
    if time_steps == 0:
        raise ValueError(f"Cannot plot {title!r}: data has no time steps")

    vmin = 0  # Minimum value of the data range (e.g., 0 mm/hr)
    # One reduction, converted to a Python float for matplotlib. The explicit comparison
    # (rather than max()) keeps the original fallback to 10 when the maximum is NaN.
    data_max = float(torch.max(data))
    vmax = data_max if data_max > 10 else 10

    frames = data[0, :, 0].detach().cpu().numpy()

    fig, axes_grid = plt.subplots(1, time_steps, figsize=(time_steps * 4, 4), squeeze=False)
    axes = list(axes_grid[0])
    try:
        for i, (ax, frame) in enumerate(zip(axes, frames)):
            im = ax.imshow(frame, cmap="viridis", vmin=vmin, vmax=vmax)
            ax.set_title(f"{title} - Time {i + 1}")
            ax.axis("off")

        # One shared colour bar for all panels.
        fig.colorbar(im, ax=axes, orientation="horizontal", fraction=0.1, pad=0.04, label="RRdata (mm/hr)")

        save_path = os.path.join(directory, f"{title.replace(' ', '_')}.png")
        fig.savefig(save_path)
    finally:
        # Always release the figure, even if colorbar/savefig fails, so repeated calls in a
        # loop do not accumulate open figures.
        plt.close(fig)
    print(f"Plot saved as {save_path}")


def visualize_step(
    tensorboard_writer,
    x: torch.Tensor,
    y: torch.Tensor,
    y_hat: torch.Tensor,
    batch_idx: int,
    step: str,
    input_channels: int,
) -> None:
    """Log channel grids for the first sample of the inputs, targets and predictions.

    For each of ``x`` (Input), ``y`` (Target) and ``y_hat`` (Predicted), the first batch element
    is moved to CPU. Each entry along its first dimension is split into single-channel images,
    which are tiled into a grid with ``input_channels`` images per row. The grid is written
    with ``tensorboard_writer.add_image`` under ``{step}/<Kind>_Image_Stack_Frame_{i}`` at
    ``global_step=batch_idx``. All inputs are logged first, then all targets, then all predictions.
    """
    # Detach every first sample up front, before any logging happens.
    samples_by_kind = {
        "Input": x[0].cpu().detach(),
        "Target": y[0].cpu().detach(),
        "Predicted": y_hat[0].cpu().detach(),
    }

    for kind, sample in samples_by_kind.items():
        for frame_idx, frame in enumerate(sample):
            channel_images = [torch.unsqueeze(channel, dim=0) for channel in frame]
            grid = torchvision.utils.make_grid(channel_images, nrow=input_channels)
            tag = f"{step}/{kind}_Image_Stack_Frame_{frame_idx}"
            tensorboard_writer.add_image(tag, grid, global_step=batch_idx)

    print("Visualized a batch of data")

In [None]:
# Select which trained checkpoint to visualize; outputs go to a timestamped run folder.
model_name = "DGMR-V1_20240426_0127"
timestamp = datetime.now().strftime("%Y%m%d_%H%M")

# Paths are built from separate components so os.path.join inserts the platform separator.
checkpoint_path = os.path.join(Config.ROOT_DIR, "output", "models", f"{model_name}.ckpt")

plot_output_path = os.path.join(Config.ROOT_DIR, "output", "visualize_predicted_results", model_name, timestamp)
tensorboard_output_path = os.path.join(plot_output_path, "TensorBoard")

# tensorboard_output_path is nested inside plot_output_path, so one call creates both folders.
os.makedirs(tensorboard_output_path, exist_ok=True)

# NOTE(review): torch.load unpickles the file — only load checkpoints from trusted sources.
checkpoint = torch.load(checkpoint_path, map_location=device)
# The model weights are stored under the "state_dict" key of the checkpoint.
model.load_state_dict(checkpoint["state_dict"])
print("Model loaded successfully")

In [None]:
# Build the evaluation loader. NetCDFDataset already returns one complete sequence per item,
# so the DataLoader must use batch_size=1 (see the NetCDFDataset docstring).
# NOTE(review): this reads the "train" split (day folders from 20180908 onward) although it is
# used as test data — confirm that is intentional.
test_data_loader = DataLoader(
    dataset=NetCDFDataset(start_day="20180908", num_epochs=1, split="train"),
    batch_size=1,
)
print("Test data loaded successfully")

# TensorBoard event files for this run are written under tensorboard_output_path.
tensorboard_writer = SummaryWriter(log_dir=tensorboard_output_path)

In [None]:
# Inference settings.
# MAX_TEST_BATCHES: number of sequences run through the model; set to None for the whole loader.
MAX_TEST_BATCHES = 1
# LOG_TO_TENSORBOARD: additionally write image grids to TensorBoard via visualize_step.
# NOTE(review): that path reads model.input_channels — confirm DGMR exposes this attribute.
LOG_TO_TENSORBOARD = False

with torch.no_grad():
    for batch_idx, (test_inputs, test_targets) in enumerate(test_data_loader):
        print("Processing a batch of test data")
        print(f"Input shape: {test_inputs.shape}")

        # Move data to device
        test_inputs = test_inputs.to(device)
        test_targets = test_targets.to(device)

        # Forward pass
        test_outputs = model(test_inputs)

        # Prefix the titles (and therefore the PNG file names) with the batch index so that
        # processing several batches does not overwrite earlier plots.
        title_prefix = f"Batch {batch_idx} "
        plot_data_and_save(test_inputs, f"{title_prefix}Real Test Input Data", directory=plot_output_path)
        plot_data_and_save(test_outputs, f"{title_prefix}Generated Data", directory=plot_output_path)
        plot_data_and_save(test_targets, f"{title_prefix}Real Test Target Data", directory=plot_output_path)

        if LOG_TO_TENSORBOARD:
            visualize_step(
                tensorboard_writer,
                test_inputs,
                test_targets,
                test_outputs,
                batch_idx,
                "test",
                model.input_channels,
            )

        # Checked at the end of the body so no extra sequence is loaded from disk once the
        # limit has been reached.
        if MAX_TEST_BATCHES is not None and batch_idx + 1 >= MAX_TEST_BATCHES:
            break

if LOG_TO_TENSORBOARD:
    tensorboard_writer.flush()

# To inspect logged images: tensorboard --logdir <tensorboard_output_path>
# (port 6006 can then be exposed remotely, e.g. via ngrok).