In [1]:
import numpy as np
import xarray as xr
import pandas as pd
# Basic imports
import xarray as xr
import matplotlib
import matplotlib.pyplot as plt
import numpy as np

# Plotting and animations
import cartopy 
import cartopy.crs as ccrs
import math
import ipywidgets as widgets
import datetime
from IPython.display import HTML
from matplotlib import animation
from typing import Optional

# Data preparation
import dataclasses

# Define dimensions


In [2]:
import numpy as np
import pandas as pd
import xarray as xr

# Hourly timestamps for the configured window (last month).
# Note: the end date resolves to 2024-04-30 00:00, so that midnight is the
# final timestamp in the index.
start_date, end_date = '2024-04-01', '2024-04-30'
time_index = pd.date_range(start_date, end_date, freq='h')

# Regular latitude/longitude grid covering South Africa.
# Change `num` to adjust the grid resolution as needed.
latitudes = np.linspace(start=-35, stop=-23, num=16)
longitudes = np.linspace(start=16, stop=33, num=48)

# Create a function for synthetic wind speed components (u10 and v10) with noise
def generate_wind_components(t, lat, lon):
    """Return one synthetic (u10, v10) wind sample for a time and grid point.

    u10 is a sine wave in latitude and v10 a cosine wave in longitude; both
    are phase-shifted by the hour of day and each gets one independent draw
    of Gaussian noise (standard deviation 0.5) from NumPy's global generator.

    Args:
        t: Object exposing an ``hour`` attribute (the caller passes elements
            of ``time_index``).
        lat: Latitude in degrees; only affects u10.
        lon: Longitude in degrees; only affects v10.

    Returns:
        Tuple ``(u10, v10)`` of wind components in m/s.

    Note:
        Reads the module-level ``latitudes`` and ``longitudes`` arrays to
        anchor the waves at the grid's minimum latitude/longitude.
    """
    full_turn = 2 * np.pi
    hour_of_day = t.hour

    # Wave parameters: amplitudes in m/s, wavelengths in degrees.
    u_amplitude, v_amplitude = 7.0, 6.0
    u_wavelength, v_wavelength = 20, 15

    # Time-dependent phases. The v phase divides the hour by 30 (not 24) and
    # adds a constant pi/4 offset, so the two components drift differently.
    u_phase = full_turn * hour_of_day / 24
    v_phase = full_turn * hour_of_day / 30 + np.pi / 4

    # Spatial angle measured from the lower edge of the grid.
    u_angle = full_turn * (lat - latitudes.min()) / u_wavelength + u_phase
    v_angle = full_turn * (lon - longitudes.min()) / v_wavelength + v_phase

    # Noise is drawn for u10 first, then v10.
    u10 = u_amplitude * np.sin(u_angle) + np.random.normal(scale=0.5)
    v10 = v_amplitude * np.cos(v_angle) + np.random.normal(scale=0.5)

    return u10, v10

# Generate synthetic wind speed components data based on the function.
# Seed NumPy's global generator first: generate_wind_components draws its
# noise from np.random, so without a seed the saved file differs on every run.
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# Both components share the same (time, latitude, longitude) layout.
grid_shape = (len(time_index), len(latitudes), len(longitudes))
u10_data = np.empty(grid_shape)
v10_data = np.empty(grid_shape)
for i, t in enumerate(time_index):
    for j, lat in enumerate(latitudes):
        for k, lon in enumerate(longitudes):
            u10_data[i, j, k], v10_data[i, j, k] = generate_wind_components(t, lat, lon)

# Dimension names and coordinates common to u10 and v10
wind_dims = ('time', 'latitude', 'longitude')
wind_coords = {'time': time_index,
               'latitude': latitudes,
               'longitude': longitudes}

# Create xarray DataArray for u10 and v10 wind speed components
u10_xr = xr.DataArray(u10_data,
                      dims=wind_dims,
                      coords=wind_coords,
                      name='u10',
                      attrs={'units': 'm/s'})

v10_xr = xr.DataArray(v10_data,
                      dims=wind_dims,
                      coords=wind_coords,
                      name='v10',
                      attrs={'units': 'm/s'})

# Create xarray Dataset
ds = xr.Dataset({'u10': u10_xr, 'v10': v10_xr})

# Save the Dataset to a NetCDF file
ds.to_netcdf('WindComponents_16_48.nc')


In [3]:
def select(
    data: xr.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None
    ) -> xr.Dataset:
  """Pick one variable out of `data` and optionally keep only its first steps.

  Args:
    data: Container indexed by variable name (called with an xarray Dataset).
    variable: Name of the variable to extract via `data[variable]`.
    level: Accepted for call compatibility but currently unused; the level
      (and batch) selection is not active in this version.
    max_steps: If given, and the result has a "time" dimension longer than
      this, only time indices `0 .. max_steps - 1` are kept.

  Returns:
    `data[variable]`, truncated along "time" when the conditions above hold.
  """
  selected = data[variable]
  can_truncate = max_steps is not None and "time" in selected.sizes
  if can_truncate and max_steps < selected.sizes["time"]:
    selected = selected.isel(time=range(0, max_steps))
  return selected

def scale(
    data: xr.Dataset,
    center: Optional[float] = None,
    robust: bool = False,
    ) -> tuple[xr.Dataset, matplotlib.colors.Normalize, str]:
  """Compute a colour normalisation and colormap name for `data`.

  Args:
    data: Values to be plotted; passed to `np.nanpercentile` unchanged.
    center: If given, the colour range is made symmetric around this value.
    robust: If True, use the 2nd/98th percentiles instead of the 0th/100th
      (min/max) as the colour limits.

  Returns:
    Tuple `(data, norm, cmap_name)` where `data` is the input unchanged,
    `norm` is a `matplotlib.colors.Normalize` over the computed limits and
    `cmap_name` is "RdBu_r" when `center` is given, otherwise "viridis".
  """
  if robust:
    low_pct, high_pct = 2, 98
  else:
    low_pct, high_pct = 0, 100
  lower = np.nanpercentile(data, low_pct)
  upper = np.nanpercentile(data, high_pct)

  if center is None:
    cmap_name = "viridis"
  else:
    # Widen the range so that `center` sits exactly in the middle.
    half_span = max(upper - center, center - lower)
    lower = center - half_span
    upper = center + half_span
    cmap_name = "RdBu_r"

  return (data, matplotlib.colors.Normalize(lower, upper), cmap_name)

def plot_data(
    data: dict[str, tuple[xr.Dataset, matplotlib.colors.Normalize, str]],
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    cols: int = 4
    ) -> HTML:
  """Render one animated image panel per entry of `data`.

  Args:
    data: Mapping from panel title to a `(array, norm, cmap)` tuple, i.e. the
      tuples returned by `scale`. All arrays must have the same number of
      "time" steps (arrays without a "time" dimension count as one step).
    fig_title: Figure title; the current frame's timestamp is appended when
      the first array has a "time" dimension.
    plot_size: Height of one panel row; each panel column is twice as wide.
    robust: If True the colorbars are drawn with extensions at both ends.
    cols: Maximum number of panel columns.

  Returns:
    An `IPython.display.HTML` object wrapping the animation's JS/HTML player.

  Raises:
    ValueError: If `data` is empty.
    AssertionError: If the arrays disagree on the number of time steps.
  """
  if not data:
    raise ValueError("plot_data requires at least one entry in `data`")

  first_data = next(iter(data.values()))[0]
  max_steps = first_data.sizes.get("time", 1)
  assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())

  cols = min(cols, len(data))
  rows = math.ceil(len(data) / cols)

  figure = plt.figure(figsize=(plot_size * 2 * cols,
                               plot_size * rows))
  figure.suptitle(fig_title, fontsize=16)
  figure.subplots_adjust(wspace=0, hspace=0)
  figure.tight_layout()

  images = []
  # `panel` (not `plot_data`) so the loop variable does not shadow this function.
  for i, (title, (panel, norm, cmap)) in enumerate(data.items()):
    ax = figure.add_subplot(rows, cols, i+1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    # Draw the first time step; later frames only swap the image data.
    im = ax.imshow(
        panel.isel(time=0, missing_dims="ignore"), norm=norm,
        origin="lower", cmap=cmap)
    plt.colorbar(
        mappable=im,
        ax=ax,
        orientation="vertical",
        pad=0.02,
        aspect=16,
        shrink=0.75,
        cmap=cmap,
        extend=("both" if robust else "neither"))
    images.append(im)

  def update(frame):
    """Refresh the title and every panel image for animation step `frame`."""
    reference_date = datetime.datetime(1970, 1, 1)
    if "time" in first_data.dims:
      # The raw coordinate value divided by 1000 is used as microseconds since
      # 1970-01-01, i.e. it is treated as nanoseconds since the epoch
      # (assumes a datetime64[ns] time coordinate -- TODO confirm).
      offset = datetime.timedelta(microseconds=first_data["time"][frame].item() / 1000)
      stamp = reference_date + offset
      figure.suptitle(f"{fig_title}, {stamp.strftime('%Y-%m-%d %H:%M:%S')}", fontsize=16)
    else:
      figure.suptitle(fig_title, fontsize=16)
    for im, (panel, _, _) in zip(images, data.values()):
      im.set_data(panel.isel(time=frame, missing_dims="ignore"))

  ani = animation.FuncAnimation(
      fig=figure, func=update, frames=max_steps, interval=250)
  # Close the figure so only the returned player shows, not a static plot.
  plt.close(figure.number)
  return HTML(ani.to_jshtml())

In [6]:
# Interactive controls for the example plot. Their `.value` attributes are
# read when the plot is built, so re-running this cell resets the selection.
plot_example_variable = widgets.Dropdown(
    options=ds.data_vars.keys(),
    value="u10",
    description="Variable")
# plot_example_level = widgets.Dropdown(
#     options=ds.coords["level"].values,
#     value=1000,
#     description="Level")
plot_example_robust = widgets.Checkbox(value=True, description="Robust")
# Use `ds.sizes` for the dimension length: mapping-style access through
# `Dataset.dims` is deprecated in xarray.
plot_example_max_steps = widgets.IntSlider(
    min=1, max=ds.sizes["time"], value=6,
    description="Max steps")

widgets.VBox([
    plot_example_variable,
    # plot_example_level,
    plot_example_robust,
    plot_example_max_steps,
    widgets.Label(value="Run the next cell to plot the data. Rerunning this cell clears your selection.")
])

  min=1, max=ds.dims["time"], value=6,


VBox(children=(Dropdown(description='Variable', options=('u10', 'v10'), value='u10'), Checkbox(value=True, des…

In [7]:
# NOTE: `select` is also defined in an earlier cell; running this cell
# replaces that earlier definition.
def select(
    data: xr.Dataset,
    variable: str,
    level: Optional[int] = None,
    max_steps: Optional[int] = None
    ) -> xr.Dataset:
  """Pick one variable out of `data` and optionally keep only its first steps.

  Args:
    data: Container indexed by variable name (called with an xarray Dataset).
    variable: Name of the variable to extract via `data[variable]`.
    level: Accepted for call compatibility but currently unused; the level
      (and batch) selection is not active in this version.
    max_steps: If given, and the result has a "time" dimension longer than
      this, only time indices `0 .. max_steps - 1` are kept.

  Returns:
    `data[variable]`, truncated along "time" when the conditions above hold.
  """
  selected = data[variable]
  can_truncate = max_steps is not None and "time" in selected.sizes
  if can_truncate and max_steps < selected.sizes["time"]:
    selected = selected.isel(time=range(0, max_steps))
  return selected

# NOTE: `scale` is also defined in an earlier cell; running this cell
# replaces that earlier definition.
def scale(
    data: xr.Dataset,
    center: Optional[float] = None,
    robust: bool = False,
    ) -> tuple[xr.Dataset, matplotlib.colors.Normalize, str]:
  """Compute a colour normalisation and colormap name for `data`.

  Args:
    data: Values to be plotted; passed to `np.nanpercentile` unchanged.
    center: If given, the colour range is made symmetric around this value.
    robust: If True, use the 2nd/98th percentiles instead of the 0th/100th
      (min/max) as the colour limits.

  Returns:
    Tuple `(data, norm, cmap_name)` where `data` is the input unchanged,
    `norm` is a `matplotlib.colors.Normalize` over the computed limits and
    `cmap_name` is "RdBu_r" when `center` is given, otherwise "viridis".
  """
  if robust:
    low_pct, high_pct = 2, 98
  else:
    low_pct, high_pct = 0, 100
  lower = np.nanpercentile(data, low_pct)
  upper = np.nanpercentile(data, high_pct)

  if center is None:
    cmap_name = "viridis"
  else:
    # Widen the range so that `center` sits exactly in the middle.
    half_span = max(upper - center, center - lower)
    lower = center - half_span
    upper = center + half_span
    cmap_name = "RdBu_r"

  return (data, matplotlib.colors.Normalize(lower, upper), cmap_name)

# NOTE: `plot_data` is also defined in an earlier cell; running this cell
# replaces that earlier definition.
def plot_data(
    data: dict[str, tuple[xr.Dataset, matplotlib.colors.Normalize, str]],
    fig_title: str,
    plot_size: float = 5,
    robust: bool = False,
    cols: int = 4
    ) -> HTML:
  """Render one animated image panel per entry of `data`.

  Args:
    data: Mapping from panel title to a `(array, norm, cmap)` tuple, i.e. the
      tuples returned by `scale`. All arrays must have the same number of
      "time" steps (arrays without a "time" dimension count as one step).
    fig_title: Figure title; the current frame's timestamp is appended when
      the first array has a "time" dimension.
    plot_size: Height of one panel row; each panel column is twice as wide.
    robust: If True the colorbars are drawn with extensions at both ends.
    cols: Maximum number of panel columns.

  Returns:
    An `IPython.display.HTML` object wrapping the animation's JS/HTML player.

  Raises:
    ValueError: If `data` is empty.
    AssertionError: If the arrays disagree on the number of time steps.
  """
  if not data:
    raise ValueError("plot_data requires at least one entry in `data`")

  first_data = next(iter(data.values()))[0]
  max_steps = first_data.sizes.get("time", 1)
  assert all(max_steps == d.sizes.get("time", 1) for d, _, _ in data.values())

  cols = min(cols, len(data))
  rows = math.ceil(len(data) / cols)

  figure = plt.figure(figsize=(plot_size * 2 * cols,
                               plot_size * rows))
  figure.suptitle(fig_title, fontsize=16)
  figure.subplots_adjust(wspace=0, hspace=0)
  figure.tight_layout()

  images = []
  # `panel` (not `plot_data`) so the loop variable does not shadow this function.
  for i, (title, (panel, norm, cmap)) in enumerate(data.items()):
    ax = figure.add_subplot(rows, cols, i+1)
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(title)
    # Draw the first time step; later frames only swap the image data.
    im = ax.imshow(
        panel.isel(time=0, missing_dims="ignore"), norm=norm,
        origin="lower", cmap=cmap)
    plt.colorbar(
        mappable=im,
        ax=ax,
        orientation="vertical",
        pad=0.02,
        aspect=16,
        shrink=0.75,
        cmap=cmap,
        extend=("both" if robust else "neither"))
    images.append(im)

  def update(frame):
    """Refresh the title and every panel image for animation step `frame`."""
    reference_date = datetime.datetime(1970, 1, 1)
    if "time" in first_data.dims:
      # The raw coordinate value divided by 1000 is used as microseconds since
      # 1970-01-01, i.e. it is treated as nanoseconds since the epoch
      # (assumes a datetime64[ns] time coordinate -- TODO confirm).
      offset = datetime.timedelta(microseconds=first_data["time"][frame].item() / 1000)
      stamp = reference_date + offset
      figure.suptitle(f"{fig_title}, {stamp.strftime('%Y-%m-%d %H:%M:%S')}", fontsize=16)
    else:
      figure.suptitle(fig_title, fontsize=16)
    for im, (panel, _, _) in zip(images, data.values()):
      im.set_data(panel.isel(time=frame, missing_dims="ignore"))

  ani = animation.FuncAnimation(
      fig=figure, func=update, frames=max_steps, interval=250)
  # Close the figure so only the returned player shows, not a static plot.
  plt.close(figure.number)
  return HTML(ani.to_jshtml())

# Plot settings. `level` is forwarded to `select`, whose level handling is
# currently commented out, so it has no effect on the selection.
plot_size = 7
level = 1000

# Pull the chosen variable (limited to the chosen number of steps) and attach
# its colour scaling; the single panel is keyed by a one-space title.
selected_variable = select(
    ds,
    plot_example_variable.value,
    level,
    plot_example_max_steps.value)
data = {" ": scale(selected_variable, robust=plot_example_robust.value)}

# The figure is titled with the name of the selected variable.
fig_title = plot_example_variable.value

plot_data(data, fig_title, plot_size, plot_example_robust.value)

In [144]:
# ds.to_netcdf('SyntheticData.nc')