In [1]:
import datetime
import os

import cartopy
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import animation
import numpy as np
import shapely.geometry as sgeom

In [2]:
from dgmr import DGMR
import os
from torch import load
import urllib

def get_pretrained(model_folder="../model_weights"):
    """Return a DGMR model with the pretrained Open Climate Fix weights loaded.

    The weights file ``pytorch_model.bin`` is downloaded from the Hugging Face
    hub the first time and cached in ``model_folder``; later calls reuse the
    cached file.

    Args:
        model_folder: Directory used to cache the weights. Created if missing.
            Defaults to ``"../model_weights"`` (relative to the current
            working directory).

    Returns:
        A ``DGMR`` instance with the downloaded state dict loaded.
    """
    # `import urllib` alone does not guarantee that the `request` submodule
    # is loaded, so import it explicitly.
    import urllib.request

    os.makedirs(model_folder, exist_ok=True)
    model_file = os.path.join(model_folder, "pytorch_model.bin")

    if not os.path.isfile(model_file):
        # Download link found on https://huggingface.co/openclimatefix/dgmr/tree/main
        model_url = "https://huggingface.co/openclimatefix/dgmr/resolve/main/pytorch_model.bin?download=true"
        # Download under a temporary name and rename only on success: an
        # interrupted download must not leave a truncated file that the
        # `isfile` check above would mistake for a valid cache.
        partial_file = model_file + ".part"
        try:
            urllib.request.urlretrieve(model_url, partial_file)
        except BaseException:
            if os.path.exists(partial_file):
                os.remove(partial_file)
            raise
        os.replace(partial_file, model_file)

    # map_location="cpu" lets weights that were saved from GPU tensors load on
    # a CPU-only machine; load_state_dict then copies them into the model.
    state_dict = load(model_file, map_location="cpu")
    model = DGMR()
    model.load_state_dict(state_dict)
    return model

In [3]:
# Load the pretrained DGMR and switch it to inference mode: `.eval()` makes
# any dropout / batch-norm layers behave as at inference time (it returns the
# module itself, so the assignment still yields the model).
model = get_pretrained().eval()

In [4]:
import numpy as np
import os
import pandas as pd
import pyproj
from wradlib.io import read_opera_hdf5
import xarray as xr


def _decode_attr(value):
    """Return an HDF5 attribute as ``str`` when it is stored as bytes."""
    if isinstance(value, bytes):  # also covers numpy.bytes_
        return value.decode("UTF-8")
    return value


def _build_grid(projection, gridspec):
    """Compute pixel-centre coordinates for a Cartesian radar grid.

    Args:
        projection: PROJ definition string of the grid.
        gridspec: Mapping with ``UL_x``/``UL_y`` (treated as the outer corner
            of the upper-left pixel), ``xsize``/``ysize`` (pixel counts) and
            ``xscale``/``yscale`` (pixel size, same units as ``UL_x``).
            Missing keys default to 0.

    Returns:
        ``(x, y, lon, lat)``: 1-D projected centre coordinates along columns
        and rows, and 2-D ``(ysize, xsize)`` longitude/latitude arrays.
    """
    ul_x = gridspec.get('UL_x', 0)
    ul_y = gridspec.get('UL_y', 0)
    xsize = gridspec.get('xsize', 0)
    ysize = gridspec.get('ysize', 0)
    xscale = gridspec.get('xscale', 0)
    yscale = gridspec.get('yscale', 0)

    # Both axes are offset by half a pixel from the corner so that x and y are
    # pixel centres (previously x was shifted by a full pixel, i.e. it pointed
    # at the right-hand pixel edge). y decreases with the row index.
    x = ul_x + (np.arange(xsize) + 0.5) * xscale
    y = ul_y - (np.arange(ysize) + 0.5) * yscale

    x_2d, y_2d = np.meshgrid(x, y)
    lon, lat = pyproj.Proj(projection)(x_2d.ravel(), y_2d.ravel(), inverse=True)
    return x, y, np.reshape(lon, (ysize, xsize)), np.reshape(lat, (ysize, xsize))


def get_data_as_xarray(data_folder):
    """Load every OPERA/ODIM HDF5 radar file in a folder into one Dataset.

    Each file contributes one time step. The time stamp is parsed from the
    file name up to its first ``.`` using the format ``%Y%m%d%H%M%S``.
    Hidden files (names starting with ``.``) and sub-directories are skipped.

    Args:
        data_folder: Directory containing the HDF5 files.

    Returns:
        xarray.Dataset sorted by time with a ``precip_intensity`` variable of
        dims ``(time, y, x)`` (units ``mm h-1``), 1-D projected ``x``/``y``
        pixel-centre coordinates and 2-D ``lon``/``lat`` coordinates.

    Raises:
        ValueError: if the folder contains no candidate files.
        KeyError: if a file has no ``dataset1/data1/data`` array.
        Exception: if a file stores a quantity other than ``RATE``.

    NOTE(review): the ``gain``/``offset``/``nodata``/``undetect`` attributes of
    ``dataset1/data1/what`` are not applied; the raw array is used as-is.
    Confirm the product already stores physical values.
    """
    file_paths = [
        os.path.join(data_folder, name)
        for name in sorted(os.listdir(data_folder))
        # Hidden files (e.g. .DS_Store) and directories carry no time stamp.
        if not name.startswith('.') and os.path.isfile(os.path.join(data_folder, name))
    ]
    if not file_paths:
        raise ValueError(f"No radar files found in {data_folder!r}.")

    grid_keys = ('UL_x', 'UL_y', 'xsize', 'ysize', 'xscale', 'yscale')
    # Files usually share one grid; project it to lon/lat only once per grid.
    grid_cache = {}
    datasets = []

    for file_path in file_paths:
        file_content = read_opera_hdf5(file_path)

        # Time stamp: file name up to its first '.'.
        time_str = os.path.basename(file_path).split('.', 1)[0]
        time = pd.to_datetime(time_str, format='%Y%m%d%H%M%S')

        quantity = _decode_attr(file_content['dataset1/data1/what']['quantity'])
        if quantity == 'RATE':
            short_name = 'precip_intensity'
            long_name = 'instantaneous precipitation rate'
            units = 'mm h-1'
        else:
            raise Exception(f"Quantity {quantity} not yet implemented.")

        if "dataset1/data1/data" not in file_content:
            raise KeyError(f"{file_path}: missing 'dataset1/data1/data'.")
        data = file_content["dataset1/data1/data"]

        projection = _decode_attr(file_content.get("where", {}).get("projdef", ""))
        gridspec = file_content.get("dataset1/where", {})
        key = (projection,) + tuple(float(gridspec.get(k, 0)) for k in grid_keys)
        if key not in grid_cache:
            grid_cache[key] = _build_grid(projection, gridspec)
        x, y, lon, lat = grid_cache[key]

        datasets.append(xr.Dataset(
            data_vars={
                # Label the image (y, x) to agree with lon/lat, which are
                # shaped (ysize, xsize); ['x', 'y'] mislabels the axes and
                # fails outright for non-square grids.
                short_name: (['y', 'x'], data,
                             {'long_name': long_name, 'units': units})
            },
            coords={
                'x': (['x'], x, {'axis': 'X', 'standard_name': 'projection_x_coordinate',
                                 'long_name': 'x-coordinate in Cartesian system', 'units': 'm'}),
                'y': (['y'], y, {'axis': 'Y', 'standard_name': 'projection_y_coordinate',
                                 'long_name': 'y-coordinate in Cartesian system', 'units': 'm'}),
                'lon': (['y', 'x'], lon, {'standard_name': 'longitude', 'long_name': 'longitude coordinate',
                                          'units': 'degrees_east'}),
                'lat': (['y', 'x'], lat, {'standard_name': 'latitude', 'long_name': 'latitude coordinate',
                                          'units': 'degrees_north'}),
                # Scalar coordinate; becomes the concatenation dimension below.
                'time': time,
            }
        ))

    return xr.concat(datasets, dim='time').sortby('time')



In [5]:
# Folder with the RMI radar QPE files for 2021-07-04. Set the RMI_RADAR_DIR
# environment variable to point elsewhere instead of editing this path.
rmi_radar_fp = os.environ.get(
    "RMI_RADAR_DIR", "/home/armoraux/Pysteps/pysteps_data/radar/rmi/radqpe/20210704")

x = get_data_as_xarray(rmi_radar_fp)



In [6]:
# Keep only the precipitation DataArray. The guard makes the cell safe to
# re-run: after the first run `x` is already the DataArray.
if isinstance(x, xr.Dataset):
    x = x['precip_intensity']
x.shape, x.dtype

((39, 700, 700), dtype('float32'))

In [7]:
def prep(field, size=256):
    """Crop the central ``size`` x ``size`` window from a stack of 2-D fields.

    Only the last two (spatial) axes are cropped; leading axes such as time
    are kept whole. The result is NOT reshaped to [B, T, C, H, W] and NOT
    converted to a ``torch.Tensor``; both happen in later cells.

    Args:
        field: Array-like whose last two axes are (height, width), e.g. an
            ``xarray.DataArray`` of shape (time, 700, 700). Must support
            NumPy-style slicing with ``...``.
        size: Side length of the square crop, in pixels. Defaults to 256.

    Returns:
        The cropped object (same type as ``field``) with spatial shape
        (size, size). For a 700x700 input and the default size this is the
        window 222:478 on both axes.

    Raises:
        ValueError: if ``size`` is not positive or exceeds a spatial dimension.
    """
    height, width = field.shape[-2:]
    if not 0 < size <= min(height, width):
        raise ValueError(
            f"Cannot crop {size}x{size} from a {height}x{width} field.")
    # Centre the window on the middle pixel of each spatial axis.
    top = height // 2 - size // 2
    left = width // 2 - size // 2
    return field[..., top:top + size, left:left + size]

In [8]:
# First four frames are the conditioning context; the rest are kept as observations.
x_context, x_observed = x[:4], x[4:]

In [9]:
# Apply the same central crop to the context and the observed frames.
x_context_cropped, x_observed_cropped = (prep(frames) for frames in (x_context, x_observed))

In [10]:
import torch

In [15]:
# Copy the cropped context values into a new torch tensor.
x_context_cropped_tensor = torch.tensor(x_context_cropped.values)

In [19]:
# Model input layout is [B, T, C, H, W]: add singleton batch and channel axes,
# deriving the frame count and image size from the tensor rather than
# hard-coding them. Re-running the cell leaves the shape unchanged.
x_context_cropped_tensor = x_context_cropped_tensor.reshape(
    1, -1, 1, *x_context_cropped_tensor.shape[-2:])

In [20]:
# Inference only: no_grad skips building the autograd graph, which would
# otherwise keep every intermediate activation alive while `pred` exists.
with torch.no_grad():
    pred = model(x_context_cropped_tensor)

In [21]:
# Drop the singleton batch and channel axes -> (time, height, width).
# Guarded so re-running the cell does not index the already-reduced tensor.
if pred.dim() == 5:
    pred = pred[0, :, 0]

In [22]:
# Shape of the squeezed prediction: 18 predicted frames of 256x256 in this run.
pred.shape

torch.Size([18, 256, 256])

In [None]:
import matplotlib
from matplotlib import animation
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import numpy as np

def plot_animation(field, figsize=None,
                   cmap="jet", vmin=0.1, vmax=100, interval=4, **imshow_args):
    """Animate a stack of 2-D fields as an inline (jshtml) animation.

    Every frame uses the same logarithmic colour scale (``LogNorm``) between
    ``vmin`` and ``vmax``, with a single colour bar.

    Args:
        field: Array-like of frames indexed as ``field[i]``, e.g. of shape
            (frames, height, width); ``field.shape[0]`` is the frame count.
        figsize: Passed to ``plt.figure``.
        cmap: Colormap name.
        vmin, vmax: Limits of the logarithmic colour scale. The defaults
            reproduce the previously hard-coded 0.1 .. 100.
        interval: Delay between frames in milliseconds (default 4).
        **imshow_args: Extra keyword arguments forwarded to ``ax.imshow``.

    Returns:
        matplotlib.animation.FuncAnimation; leave it as the last expression
        of a cell to render it.
    """
    # Global rc setting: render animations as JavaScript HTML in the notebook.
    matplotlib.rc('animation', html='jshtml')

    fig = plt.figure(figsize=figsize)
    ax = plt.axes()
    ax.set_axis_off()
    # Close the figure so Jupyter does not also show a static copy below the
    # animation; FuncAnimation still draws into it.
    plt.close(fig)

    img = ax.imshow(field[0], norm=LogNorm(vmin=vmin, vmax=vmax), cmap=cmap, **imshow_args)
    fig.colorbar(img, ax=ax)
    title = ax.set_title('Frame 0')

    def animate(frame):
        # Swap in the next frame's data and update the title; the colour
        # scale stays fixed so frames are comparable.
        img.set_data(field[frame])
        title.set_text(f'Frame {frame}')
        return (img,)

    return animation.FuncAnimation(
        fig, animate, frames=field.shape[0], interval=interval, blit=False)
  
def plot_subplot(input, output, figsize=None,
                 vmin=0, vmax=10, cmap="jet", n_frames=4, **imshow_args):
    """Plot input frames (top row) above the matching output frames (bottom row).

    Column ``i`` shows ``input[0, i, 0]`` and ``output[0, i, 0]`` on the same
    linear colour scale, each panel with its own colour bar.

    Args:
        input: Array-like indexed as ``input[0, t, 0]`` to obtain a 2-D frame,
            e.g. a (batch, time, channel, height, width) stack. Torch tensors
            are accepted.
        output: Same layout as ``input``; torch tensors are accepted.
        figsize: Passed to ``plt.subplots``.
        vmin, vmax: Colour limits applied to every panel.
        cmap: Colormap name.
        n_frames: Number of time steps (columns) to plot. Defaults to 4.
        **imshow_args: Extra keyword arguments forwarded to ``imshow``.

    Returns:
        None.
    """
    def _as_array(data):
        # Tensors are recognised by their `detach` method (so torch need not
        # be imported here), detached from autograd and copied to host memory;
        # `.numpy()` alone fails for tensors that require grad or live on GPU.
        if hasattr(data, "detach"):
            return data.detach().cpu().numpy()
        return data

    inputs = _as_array(input)
    outputs = _as_array(output)

    # squeeze=False keeps `axes` 2-D even when n_frames == 1.
    _, axes = plt.subplots(2, n_frames, figsize=figsize, squeeze=False)
    for col in range(n_frames):
        for row, frames in enumerate((inputs, outputs)):
            ax = axes[row, col]
            image = ax.imshow(frames[0, col, 0], cmap=cmap, vmin=vmin, vmax=vmax, **imshow_args)
            plt.colorbar(image, ax=ax)

    return None

In [29]:
# `.numpy()` shares memory with the tensor, so copy before the in-place
# threshold below; otherwise `pred` would be silently modified too.
# `.cpu()` is a no-op for CPU tensors and makes this work on GPU as well.
pred_numpy = pred.detach().cpu().numpy().copy()
# Zero out values below 0.1 (presumably mm/h, like the input), the lower
# limit of the log colour scale used by plot_animation.
pred_numpy[pred_numpy < 0.1] = 0

In [30]:
# Animate the thresholded DGMR prediction.
plot_animation(field=pred_numpy)

In [31]:
# Keep as many observed frames as the model predicted, so both animations
# have the same number of frames (18 in this run) without a magic number.
x_observed_cropped = x_observed_cropped[:pred_numpy.shape[0]]

In [32]:
# Animate the observed radar frames for comparison with the prediction.
plot_animation(field=x_observed_cropped)

In [88]:
# Shape of the uncropped context frames: (time, height, width) = (4, 700, 700) here.
x_context.shape

(4, 700, 700)

In [89]:
# Channels-last array of the full context frames: (time, height, width, 1).
# Built with NumPy because `tf` is not imported anywhere in this notebook,
# so a `tf.reshape` call raises NameError on a fresh kernel.
x_context_input = np.asarray(x_context)[..., np.newaxis]

In [92]:
# FIXME(review): neither `predict` nor `module` is defined in any visible cell,
# so this fails with NameError on a fresh kernel. It appears to target a
# separate full-frame (700x700) model; add the missing definitions or remove it.
pred_full = predict(module, x_context_input)