In [24]:
import datetime
import os

import cartopy
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import shapely.geometry as sgeom
import tensorflow as tf
import tensorflow_hub

2024-07-18 16:29:15.257334: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-18 16:29:15.264983: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-18 16:29:15.267345: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-18 16:29:15.275071: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [25]:
# Folder path to the model weights downloaded from
# https://console.cloud.google.com/storage/browser/dm-nowcasting-example-data?pageState=(%22StorageObjectListTable%22:(%22f%22:%22%255B%255D%22))&project=friendly-retina-382415
# The folder is expected to contain one sub-folder per input size, named
# "<height>x<width>" (this is how load_module builds the snapshot path).
# Set the TFHUB_BASE_PATH environment variable to point at a different
# location; the default below is the original machine-specific path.
TFHUB_BASE_PATH = os.environ.get(
    "TFHUB_BASE_PATH", "/home/armoraux/Pysteps/tfhub_snapshots")

def load_module(input_height, input_width):
  """Return the 'default' signature of a 'Generative Method' TF-Hub snapshot.

  Args:
    input_height: Height used to pick the snapshot sub-folder.
    input_width: Width used to pick the snapshot sub-folder.

  Returns:
    The entry stored under 'default' in the loaded module's signatures.
  """
  # Snapshots live in "<TFHUB_BASE_PATH>/<height>x<width>".
  snapshot_dir = os.path.join(TFHUB_BASE_PATH, f"{input_height}x{input_width}")
  legacy_module = tensorflow_hub.load(snapshot_dir)
  # The snapshot is a legacy TF1 model running under TF2 eager mode, so it has
  # to be reached through its "signatures" attribute rather than called
  # directly. See
  # https://github.com/tensorflow/hub/blob/master/docs/migration_tf2.md#using-lower-level-apis
  # for more information.
  return legacy_module.signatures['default']

In [26]:
# Load the 256x256 snapshot; prep() crops the radar fields to the same 256x256
# size before they are passed to the model.
module = load_module(256, 256)

In [27]:
def predict(module, input_frames, num_samples=1,
            include_input_frames_in_result=False):
  """Draw nowcast samples from a TF-Hub snapshot of the 'Generative Method' model.

  Args:
    module: A raw TF-Hub module signature as returned by load_module.
    input_frames: Conditioning frames of shape (T_in,H,W,C), with T_in = 4.
    num_samples: How many independent samples to draw.
    include_input_frames_in_result: When True the 4 input frames are kept in
      front of the 18 predicted frames (22 frames in total along the time
      axis); when False only the 18 predicted frames are returned.

  Returns:
    A tensor of shape (num_samples,T_out,H,W,C), with T_out equal to 18 or 22
    depending on include_input_frames_in_result.
  """
  # Negative rain rates are clipped to zero before conditioning the model.
  conditioning = tf.math.maximum(input_frames, 0.)
  # Prepend a batch axis and repeat the frames once per requested sample.
  conditioning = tf.tile(
      tf.expand_dims(conditioning, 0), multiples=[num_samples, 1, 1, 1, 1])

  # The latent size is read off the module's own input signature, then one
  # latent vector z is drawn per sample.
  _, keyword_specs = module.structured_input_signature
  latent_dim = keyword_specs['z'].shape[1]
  latents = tf.random.normal(shape=(num_samples, latent_dim))

  model_inputs = {
      "z": latents,
      "labels$onehot": tf.ones(shape=(num_samples, 1)),
      "labels$cond_frames": conditioning,
  }
  samples = module(**model_inputs)['default']

  if not include_input_frames_in_result:
    # The module output starts with the input frames; drop them so only the
    # sampled predictions remain.
    samples = samples[:, NUM_INPUT_FRAMES:, ...]

  # Keep non-negative rainfall values only.
  return tf.math.maximum(samples, 0.)


# Frame counts used when slicing radar sequences for the snapshotted model.
# NUM_INPUT_FRAMES is also what predict() uses to slice the conditioning frames
# off the front of the model output.
NUM_INPUT_FRAMES = 4
# NOTE(review): predict() documents 18 predicted frames, and the trailing "#18"
# suggests that was the earlier value here. With 5,
# extract_input_and_target_frames() returns only the last 5 frames as targets.
# Confirm which value is intended.
NUM_TARGET_FRAMES = 5 #18

def extract_input_and_target_frames(radar_frames):
  """Split a window of radar frames into (input_frames, target_frames).

  The last NUM_TARGET_FRAMES frames are the targets and the NUM_INPUT_FRAMES
  frames immediately before them are the inputs; anything earlier in the
  window is ignored.
  """
  # Both slices are anchored to the end of the window via negative indices.
  targets_start = -NUM_TARGET_FRAMES
  inputs_start = targets_start - NUM_INPUT_FRAMES
  return radar_frames[inputs_start:targets_start], radar_frames[targets_start:]


def reshape_sample(samples):
    """Append a channel axis of size 1 to a rank-3 (t, h, w) stack of frames."""
    # Unpacking into exactly three names rejects inputs that are not rank 3.
    num_frames, height, width = samples.shape
    target_shape = [num_frames, height, width, 1]
    return tf.reshape(samples, target_shape)

In [30]:
import numpy as np
import os
import pandas as pd
import pyproj
from wradlib.io import read_opera_hdf5
import xarray as xr


def get_data_as_xarray(data_folder):
    """Read every radar file in ``data_folder`` into one time-stacked dataset.

    Each directory entry is read with ``read_opera_hdf5``. The timestamp of a
    frame is parsed from the part of its file name before the first dot, which
    must be formatted as ``%Y%m%d%H%M%S``. Files are processed in sorted
    file-name order, so the frames are concatenated chronologically.

    Args:
        data_folder: Path of a directory; every entry in it is read as a
            radar file.

    Returns:
        An ``xarray.Dataset`` built by concatenating the per-file datasets
        along ``time``. It holds the variable ``precip_intensity`` on the
        ``(y, x)`` grid, the projected ``x``/``y`` cell-centre coordinates and
        the 2-D ``lon``/``lat`` coordinates.

    Raises:
        ValueError: If ``data_folder`` contains no entries.
        NotImplementedError: If a file stores a quantity other than ``'RATE'``.
    """
    # os.listdir returns entries in arbitrary order. The names start with a
    # fixed-width timestamp, so a lexicographic sort is a chronological sort.
    file_paths = [os.path.join(data_folder, entry)
                  for entry in sorted(os.listdir(data_folder))]
    if not file_paths:
        raise ValueError(f"No radar files found in {data_folder!r}.")

    datasets = []
    for file_path in file_paths:
        # Read the content
        file_content = read_opera_hdf5(file_path)

        # Extract time information
        time_str = os.path.splitext(os.path.basename(file_path))[0].split('.', 1)[0]
        time = pd.to_datetime(time_str, format='%Y%m%d%H%M%S')

        # Extract quantity information (stored as bytes or str depending on
        # how the attribute was read).
        quantity = file_content['dataset1/data1/what']['quantity']
        if isinstance(quantity, bytes):
            quantity = quantity.decode()

        # Set variable properties based on quantity
        if quantity != 'RATE':
            raise NotImplementedError(f"Quantity {quantity} not yet implemented.")
        short_name = 'precip_intensity'
        long_name = 'instantaneous precipitation rate'
        units = 'mm h-1'

        # Create the grid
        projection = file_content.get("where", {}).get("projdef", "")
        if isinstance(projection, bytes):
            projection = projection.decode("UTF-8")

        gridspec = file_content.get("dataset1/where", {})
        ul_x, ul_y = gridspec.get('UL_x', 0), gridspec.get('UL_y', 0)
        xsize, ysize = gridspec.get('xsize', 0), gridspec.get('ysize', 0)
        xscale, yscale = gridspec.get('xscale', 0), gridspec.get('yscale', 0)

        # linspace(..., endpoint=False) yields the cell edges starting at the
        # upper-left corner; shifting by half a cell gives the cell centres.
        # The original shifted x by a full cell while y was shifted by half a
        # cell; both axes now use the same half-cell offset.
        x = np.linspace(ul_x, ul_x + xsize * xscale, num=xsize, endpoint=False)
        x += xscale / 2
        y = np.linspace(ul_y, ul_y - ysize * yscale, num=ysize, endpoint=False)
        y -= yscale / 2

        x_2d, y_2d = np.meshgrid(x, y)

        pr = pyproj.Proj(projection)

        # Inverse projection: projected metres -> geographic lon/lat.
        lon, lat = pr(x_2d.flatten(), y_2d.flatten(), inverse=True)
        lon = lon.reshape(ysize, xsize)
        lat = lat.reshape(ysize, xsize)

        # Build the xarray dataset. The data array has ysize rows and xsize
        # columns (same layout as lon/lat), so its dims are ('y', 'x').
        ds = xr.Dataset(
            data_vars={
                short_name: (['y', 'x'], file_content.get("dataset1/data1/data", np.nan),
                             {'long_name': long_name, 'units': units})
            },
            coords={
                'x': (['x'], x, {'axis': 'X', 'standard_name': 'projection_x_coordinate',
                                 'long_name': 'x-coordinate in Cartesian system', 'units': 'm'}),
                'y': (['y'], y, {'axis': 'Y', 'standard_name': 'projection_y_coordinate',
                                 'long_name': 'y-coordinate in Cartesian system', 'units': 'm'}),
                'lon': (['y', 'x'], lon, {'standard_name': 'longitude', 'long_name': 'longitude coordinate',
                                          'units': 'degrees_east'}),
                'lat': (['y', 'x'], lat, {'standard_name': 'latitude', 'long_name': 'latitude coordinate',
                                          'units': 'degrees_north'})
            }
        )
        ds['time'] = time

        # Append the dataset to the list
        datasets.append(ds)

    # Concatenate datasets along the time dimension
    final_dataset = xr.concat(datasets, dim='time')
    return final_dataset

In [31]:
# Folder with the radar files to load (presumably the RMI RADQPE files for
# 2021-07-04, going by the path).
# NOTE(review): machine-specific absolute path; adjust before running elsewhere.
rmi_radar_fp = "/home/armoraux/Pysteps/pysteps_data/radar/rmi/radqpe/20210704"

# Read every file in that folder into a single dataset stacked along time.
x = get_data_as_xarray(rmi_radar_fp)



In [32]:
# Display the loaded dataset.
x

In [87]:
# Display the time values attached to the loaded frames.
x.time

In [33]:
# Keep only the precipitation-rate variable of the dataset.
# NOTE(review): this rebinds `x` from the dataset returned by
# get_data_as_xarray to one of its variables, so the cell is not safe to
# re-run on its own; a distinct name would make it idempotent.
x = x['precip_intensity']
x.shape, x.dtype

((39, 700, 700), dtype('float32'))

In [41]:
def prep(field, size=256):
    '''
    Crop the central ``size`` x ``size`` window out of a stack of 2-D fields.

    The input is only cropped: it is not reshaped and not converted to another
    array or tensor type.

    args:
        - field: array-like of shape (T, H, W), e.g. an xarray.DataArray
            The precipitation data variable from the xarray
        - size: int
            Side length of the square window to keep. The default of 256
            matches the 256x256 model snapshot loaded in this notebook.
    returns:
        ``field`` restricted to the central window, shape (T, size, size).
    raises:
        ValueError if the field is smaller than ``size`` along H or W.
    '''
    height, width = field.shape[-2], field.shape[-1]
    if height < size or width < size:
        raise ValueError(
            f"Cannot crop a {size}x{size} window from a {height}x{width} field.")

    # Window of `size` cells starting size//2 before the centre of each axis.
    # For a 700x700 field and size 256 this is rows/cols 222:478.
    top = height // 2 - size // 2
    left = width // 2 - size // 2
    cropped = field[:, top:top + size, left:left + size]

    return cropped

In [42]:
# The first 4 frames condition the model (predict() expects T_in = 4 input
# frames); the remaining frames are kept as observations to compare against.
x_context = x[:4]
x_observed = x[4:]

In [44]:
# Crop both stacks to the central 256x256 window used by the loaded snapshot.
x_context_cropped = prep(x_context)
x_observed_cropped = prep(x_observed)

In [53]:
# Add a trailing channel axis: predict() expects input_frames shaped
# (T_in, H, W, C).
x_context_cropped = tf.reshape(x_context_cropped, [4, 256, 256, 1])

In [61]:
# Draw a single sample (num_samples defaults to 1) conditioned on the 4
# context frames. The latent vector is drawn at random inside predict(), so
# results differ between runs.
pred = predict(module, x_context_cropped)

In [62]:
# Select the first (only) sample and drop the channel axis, leaving
# (T_out, H, W).
# NOTE(review): this rebinds `pred`, so re-running this cell without
# re-running predict() indexes an already-reduced tensor.
pred = pred[0,:,:,:,0]

In [75]:
# Check the shape of the predicted stack after dropping sample and channel axes.
pred.shape

TensorShape([18, 256, 256])

In [78]:
import matplotlib
from matplotlib import animation
import matplotlib.pyplot as plt
import numpy as np

def plot_animation(field, figsize=None,
                   cmap="jet", **imshow_args):
  """Animate a stack of 2-D fields, one frame per index along the first axis.

  Args:
    field: Array-like indexed as field[frame] -> 2-D image (anything
      np.asarray can convert, e.g. a NumPy array, tensor or DataArray).
    figsize: Optional figure size forwarded to matplotlib.
    cmap: Colormap name used for the image.
    **imshow_args: Extra keyword arguments forwarded to Axes.imshow.

  Returns:
    A matplotlib.animation.FuncAnimation. As a side effect the global
    'animation.html' rc setting is switched to 'jshtml'.
  """
  matplotlib.rc('animation', html='jshtml')

  # Convert once so the input is not re-converted on every frame.
  frames = np.asarray(field)

  fig, ax = plt.subplots(figsize=figsize)
  ax.set_axis_off()
  plt.close(fig)  # Prevents extra axes being plotted below animation

  # NaN-aware limits: plain min/max would turn into NaN as soon as a single
  # pixel is NaN.
  img = ax.imshow(frames[0], vmin=np.nanmin(frames), vmax=np.nanmax(frames),
                  cmap=cmap, **imshow_args)
  fig.colorbar(img, ax=ax)
  title = ax.set_title('Frame 0')

  def animate(frame):
    current = frames[frame]
    img.set_data(current)
    # The colour limits are rescaled to each frame's own range, so colours
    # are not comparable between frames.
    img.set_clim(np.nanmin(current), np.nanmax(current))
    title.set_text(f'Frame {frame}')
    return (img,)

  # `interval` is the delay between frames in milliseconds.
  return animation.FuncAnimation(
      fig, animate, frames=frames.shape[0], interval=4, blit=False)
  
def plot_subplot(input, output, figsize=None,
                  vmin=0, vmax=10, cmap="jet", **imshow_args):
  """Draw a 2x4 grid of images: `input` on the top row, `output` below.

  Column i shows input[0, i, 0] and output[0, i, 0], each with its own
  colorbar. Both arguments are presumably laid out as (B, T, C, H, W);
  verify against the caller. A torch tensor passed as `output` is detached
  and converted to NumPy first; `input` is used as given.

  Args:
    input: Indexable as input[0, i, 0] -> 2-D image for i in 0..3.
    output: Indexable the same way, or a torch tensor.
    figsize: Optional figure size forwarded to plt.subplots.
    vmin: Lower colour limit shared by all panels.
    vmax: Upper colour limit shared by all panels.
    cmap: Colormap name.
    **imshow_args: Extra keyword arguments forwarded to imshow.

  Returns:
    None.
  """
  fig, axes = plt.subplots(2, 4, figsize=figsize)
  # Compare by type name so torch does not have to be imported here.
  if str(type(output)) == "<class 'torch.Tensor'>":
    output = output.detach().numpy()

  for col in range(4):
    for row, stack in enumerate((input, output)):
      panel = axes[row, col]
      image = panel.imshow(stack[0, col, 0], cmap=cmap, vmin=vmin, vmax=vmax,
                           **imshow_args)
      plt.colorbar(image, ax=panel)

  return None

In [79]:
# Animate the predicted frames.
plot_animation(pred)

In [84]:
# Keep the first 18 observed frames, the same number of frames as `pred`
# (its shape above is [18, 256, 256]).
x_observed_cropped = x_observed_cropped[:18]

In [85]:
# Animate the observed frames for visual comparison with the prediction.
plot_animation(x_observed_cropped)