In [1]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

import pandas as pd

from IPython.display import HTML
import cartopy.crs as ccrs
# Location of the Zarr store that holds the forecasts used throughout this notebook.
FORECAST_STORE = 'gs://weatherbench2/datasets/keisler/2020-64x32_equiangular_conservative.zarr'

# xarray opens the store without reading the arrays; data is fetched on demand.
ds = xr.open_zarr(FORECAST_STORE)

Cannot find the ecCodes library


In [27]:
class WeatherData:
    """Select, load and animate a regional subset of a forecast dataset.

    The wrapped dataset must expose ``level``, ``latitude``, ``longitude``,
    ``prediction_timedelta`` and ``time`` coordinates, because those are the
    names the methods of this class select on.

    Attributes:
        ds: Dataset currently in use. This is the dataset handed to the
            constructor until ``select_data`` replaces it with a subset.
        horizon: Upper bound of the selected lead times, in nanoseconds
            (set by ``select_data``).
        latitude: Latitude values whose min/max bound the selection
            (set by ``select_data``).
        longitude: Longitude values whose min/max bound the selection
            (set by ``select_data``).
        level: Selected ``level`` value, in hPa (set by ``select_data``).
    """

    # Fields animated by ``plot_from_ds`` when the caller does not name any.
    DEFAULT_VARIABLES = ('geopotential', 'wind_speed', 'temperature', 'specific_humidity')

    def __init__(self, ds: "xr.Dataset") -> None:
        """Store ``ds``; nothing is selected or loaded here."""
        self.ds = ds
        # Dataset that ``select_data`` cuts from, and the last subset it
        # produced. Together they make repeated selections restart from the
        # full data instead of narrowing an already narrowed subset.
        self._source_ds = ds
        self._selection = None

    def _selection_source(self):
        """Return the dataset that ``select_data`` should cut from."""
        # ``self.ds`` differing from our own last subset means it was assigned
        # from outside (constructor or user code): adopt it as the new source.
        if self.ds is not getattr(self, '_selection', None):
            self._source_ds = self.ds
        return self._source_ds

    def select_data(self, horizon: int = 36, offset: int = 0) -> None:
        """Restrict the dataset to one level, a lat/lon box and a lead-time window.

        The selection always starts from the source dataset, so the method can
        be called again (for example with another ``offset``) without
        recreating the object.

        Args:
            horizon: Longest lead time to keep, in hours.
            offset: Degrees by which the default box (latitude -35 to -23,
                longitude 16 to 33) is widened on every side.

        Side effects:
            Sets ``horizon``, ``latitude``, ``longitude`` and ``level``,
            replaces ``ds`` with the subset and prints the selection criteria.
        """
        # ``horizon`` hours expressed in nanoseconds, plus one nanosecond so
        # the upper bound of the slice lies just past the last requested hour.
        self.horizon = int(horizon * 1e9 * 3600 + 1)

        # Only the extremes of these arrays are used for the selection below.
        self.latitude = np.arange(-35 - offset, -22.75 + offset, 0.25)
        self.longitude = np.arange(16 - offset, 33.25 + offset, 0.25)

        self.level = 850  # hPa

        # Print selection criteria
        print(f"Latitude: {self.latitude.min()} to {self.latitude.max()}")
        print(f"Longitude: {self.longitude.min()} to {self.longitude.max()}")
        print(f"Level: {self.level} hPa")
        # ``horizon`` is multiplied by 3600 s above, so the unit is hours.
        print(f"Horizon: {horizon} hours")

        subset = self._selection_source().sel(level=self.level)
        subset = subset.sel(
            latitude=slice(self.latitude.min(), self.latitude.max()),
            longitude=slice(self.longitude.min(), self.longitude.max()),
        )
        subset = subset.sel(
            prediction_timedelta=slice(np.timedelta64(0, 'ns'), np.timedelta64(self.horizon, 'ns'))
        )

        self.ds = subset
        self._selection = subset

    def load_data(self) -> None:
        """Load the current dataset into memory (``Dataset.load`` works in place)."""
        self.ds.load()

    def describe_data(self) -> None:
        """Print the coordinates of the current dataset."""
        print(self.ds.coords)

    @staticmethod
    def _format_valid_times(seed_time, timedeltas) -> list:
        """Return ``seed_time + timedeltas`` as ISO-8601 strings with nanosecond precision."""
        valid_times = np.atleast_1d(np.asarray(seed_time) + np.asarray(timedeltas))
        # Keeps the full sub-second part (milli-, micro- and nanoseconds).
        return np.datetime_as_string(valid_times, unit='ns').tolist()

    @staticmethod
    def _contour_levels(vmin, vmax, levels):
        """Turn an integer level count into fixed boundaries spanning ``vmin``..``vmax``.

        Anything that cannot be turned into increasing boundaries (explicit
        level sequences, non-positive counts, constant or non-finite ranges)
        is returned unchanged for matplotlib to interpret.
        """
        if isinstance(levels, bool) or not isinstance(levels, (int, np.integer)):
            return levels
        if levels < 1 or not (np.isfinite(vmin) and np.isfinite(vmax)) or vmax <= vmin:
            return levels
        return np.linspace(vmin, vmax, levels + 1)

    def plot_from_ds(self, seed: int = 0, frame_rate: int = 8, levels: int = 10,
                     variables: "list[str] | None" = None) -> "HTML":
        """Animate filled-contour maps of the chosen variables over all lead times.

        Args:
            seed: Integer position along the ``time`` dimension of the entry
                to animate.
            frame_rate: Frames per second of the animation.
            levels: Number of contour bands. The boundaries are fixed across
                frames so the colorbar drawn for the first frame stays valid.
            variables: Names of the data variables to draw, one map each.
                Defaults to ``DEFAULT_VARIABLES``.

        Returns:
            An ``IPython.display.HTML`` object wrapping the JavaScript animation.

        Raises:
            ValueError: If ``variables`` is empty.
        """
        variables = list(self.DEFAULT_VARIABLES if variables is None else variables)
        if not variables:
            raise ValueError("variables must name at least one data variable")

        sample = self.ds[variables].isel(time=seed)

        # One timestamp per frame: the selected time plus each lead time.
        formatted_times = self._format_valid_times(
            sample.time.values, sample.prediction_timedelta.values
        )
        n_frames = len(formatted_times)

        bounds = [sample.longitude.min().item(), sample.longitude.max().item(),
                  sample.latitude.min().item(), sample.latitude.max().item()]

        print(bounds)

        # Grids of shape (n_latitude, n_longitude); the data is transposed to
        # the same layout below, so the plot does not depend on storage order.
        lon_grid, lat_grid = np.meshgrid(sample.longitude.values, sample.latitude.values)
        layout = ('prediction_timedelta', 'latitude', 'longitude')

        # Materialise each field once and derive its colour range and contour
        # boundaries up front, instead of recomputing them for every frame.
        fields = {}
        styles = {}
        for variable in variables:
            data = sample[variable].transpose(*layout).values
            vmin, vmax = np.nanmin(data), np.nanmax(data)
            fields[variable] = data
            styles[variable] = (vmin, vmax, self._contour_levels(vmin, vmax, levels))

        n_columns = 2
        n_rows = int(np.ceil(len(variables) / n_columns))

        # Set up the figure and axes with dynamic rows and 2 columns
        fig, axs = plt.subplots(n_rows, n_columns, figsize=(14, 7 * n_rows),
                                subplot_kw={'projection': ccrs.PlateCarree()},
                                squeeze=False)

        fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.9, wspace=0.2, hspace=0.2)

        # Flatten the axes array for easy indexing
        axs = axs.flatten()

        def draw(ax, variable, frame):
            """Draw one variable at one lead time on ``ax`` and return the contour set."""
            vmin, vmax, contour_levels = styles[variable]
            ax.coastlines()
            ax.set_extent(bounds, crs=ccrs.PlateCarree())
            feat = ax.contourf(lon_grid, lat_grid, fields[variable][frame],
                               levels=contour_levels, vmin=vmin, vmax=vmax,
                               transform=ccrs.PlateCarree())
            ax.set_title(f"{variable} - {formatted_times[frame]}")
            return feat

        for ax, variable in zip(axs, variables):
            print(f"Variable: {variable}")
            feat = draw(ax, variable, 0)
            fig.colorbar(feat, ax=ax, orientation="vertical", label=f"{variable} levels", shrink=0.7, pad=0.02)

        # An odd number of variables leaves one subplot without content.
        for ax in axs[len(variables):]:
            ax.set_visible(False)

        def animate(i):
            """Redraw every map for frame ``i``."""
            print(f"Frame: {i + 1}/{n_frames}", end='\r')
            artists = []
            for ax, variable in zip(axs, variables):
                # Contour sets cannot be updated in place; clear and redraw.
                ax.clear()
                artists.append(draw(ax, variable, i))
            return tuple(artists)

        interval = 1000 / frame_rate  # milliseconds between frames

        ani = FuncAnimation(fig, animate, frames=n_frames, interval=interval)

        # Prevent the static figure from also being shown as cell output.
        plt.close(fig)

        return HTML(ani.to_jshtml())


In [28]:
# Wrap the opened dataset; the constructor only stores it, nothing is selected or loaded yet.
weather_data = WeatherData(ds)

In [29]:
# Keep the 850 hPa level and widen the default lat/lon box by 5 degrees on every side;
# `horizon` is left at its default of 36.
weather_data.select_data(offset=5)

Latitude: -40.0 to -18.0
Longitude: 11.0 to 38.0
Level: 850 hPa
Horizon: 36 days


In [30]:
# Print the coordinates that remain after the selection above.
weather_data.describe_data()

Coordinates:
  * latitude              (latitude) float64 32B -36.56 -30.94 -25.31 -19.69
    level                 int64 8B 850
  * longitude             (longitude) float64 40B 11.25 16.88 22.5 28.12 33.75
  * prediction_timedelta  (prediction_timedelta) timedelta64[ns] 56B 00:00:00...
  * time                  (time) datetime64[ns] 6kB 2020-01-01 ... 2020-12-31...


In [31]:
# Display the selected dataset; `load_data()` has not been called, so the arrays are still chunked and unloaded.
weather_data.ds

Unnamed: 0,Array,Chunk
Bytes,400.31 kiB,2.19 kiB
Shape,"(732, 7, 5, 4)","(4, 7, 5, 4)"
Dask graph,183 chunks in 5 graph layers,183 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 400.31 kiB 2.19 kiB Shape (732, 7, 5, 4) (4, 7, 5, 4) Dask graph 183 chunks in 5 graph layers Data type float32 numpy.ndarray",732  1  4  5  7,

Unnamed: 0,Array,Chunk
Bytes,400.31 kiB,2.19 kiB
Shape,"(732, 7, 5, 4)","(4, 7, 5, 4)"
Dask graph,183 chunks in 5 graph layers,183 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,400.31 kiB,2.19 kiB
Shape,"(732, 7, 5, 4)","(4, 7, 5, 4)"
Dask graph,183 chunks in 5 graph layers,183 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 400.31 kiB 2.19 kiB Shape (732, 7, 5, 4) (4, 7, 5, 4) Dask graph 183 chunks in 5 graph layers Data type float32 numpy.ndarray",732  1  4  5  7,

Unnamed: 0,Array,Chunk
Bytes,400.31 kiB,2.19 kiB
Shape,"(732, 7, 5, 4)","(4, 7, 5, 4)"
Dask graph,183 chunks in 5 graph layers,183 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,400.31 kiB,2.19 kiB
Shape,"(732, 7, 5, 4)","(4, 7, 5, 4)"
Dask graph,183 chunks in 5 graph layers,183 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 400.31 kiB 2.19 kiB Shape (732, 7, 5, 4) (4, 7, 5, 4) Dask graph 183 chunks in 5 graph layers Data type float32 numpy.ndarray",732  1  4  5  7,

Unnamed: 0,Array,Chunk
Bytes,400.31 kiB,2.19 kiB
Shape,"(732, 7, 5, 4)","(4, 7, 5, 4)"
Dask graph,183 chunks in 5 graph layers,183 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,400.31 kiB,2.19 kiB
Shape,"(732, 7, 5, 4)","(4, 7, 5, 4)"
Dask graph,183 chunks in 5 graph layers,183 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 400.31 kiB 2.19 kiB Shape (732, 7, 5, 4) (4, 7, 5, 4) Dask graph 183 chunks in 5 graph layers Data type float32 numpy.ndarray",732  1  4  5  7,

Unnamed: 0,Array,Chunk
Bytes,400.31 kiB,2.19 kiB
Shape,"(732, 7, 5, 4)","(4, 7, 5, 4)"
Dask graph,183 chunks in 5 graph layers,183 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,400.31 kiB,2.19 kiB
Shape,"(732, 7, 5, 4)","(4, 7, 5, 4)"
Dask graph,183 chunks in 5 graph layers,183 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 400.31 kiB 2.19 kiB Shape (732, 7, 5, 4) (4, 7, 5, 4) Dask graph 183 chunks in 5 graph layers Data type float32 numpy.ndarray",732  1  4  5  7,

Unnamed: 0,Array,Chunk
Bytes,400.31 kiB,2.19 kiB
Shape,"(732, 7, 5, 4)","(4, 7, 5, 4)"
Dask graph,183 chunks in 5 graph layers,183 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,400.31 kiB,2.19 kiB
Shape,"(732, 7, 5, 4)","(4, 7, 5, 4)"
Dask graph,183 chunks in 5 graph layers,183 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 400.31 kiB 2.19 kiB Shape (732, 7, 5, 4) (4, 7, 5, 4) Dask graph 183 chunks in 5 graph layers Data type float32 numpy.ndarray",732  1  4  5  7,

Unnamed: 0,Array,Chunk
Bytes,400.31 kiB,2.19 kiB
Shape,"(732, 7, 5, 4)","(4, 7, 5, 4)"
Dask graph,183 chunks in 5 graph layers,183 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [32]:
# Animate all selected lead times for the entry at position 0 of the `time`
# dimension, at 8 frames per second, one map per variable.
weather_data.plot_from_ds(
    seed=0,
    frame_rate=8,
    levels=10,
    variables=['geopotential', 'wind_speed', 'temperature', 'specific_humidity'],
)



[11.25, 33.75, -36.5625, -19.687499999999996]
Variable: geopotential
Variable: wind_speed
Variable: temperature
Variable: specific_humidity
Frame: 6/7