# Imports and Data Open

In [1]:
import xarray as xr
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from matplotlib.animation import FuncAnimation, PillowWriter
from IPython.display import HTML

In [2]:
ds = xr.open_dataset('data_850/2022_850_SA.nc')
ds.load()

Cannot find the ecCodes library


# Class

In [37]:
class WeatherData:
    def __init__(self, dataset: xr.Dataset, window_size: int = 24, steps = 3):
        self.dataset = dataset
        self.window_size = window_size
        self.steps = steps
        self.calculate_wind_speed()
        self.dataset = self.dataset.sortby('latitude')
    
    def subset_data(self):
        lat_slice = slice(1, 33)  
        lon_slice = slice(3, 67)  

        self.dataset = self.dataset.isel(latitude=lat_slice, longitude=lon_slice)

    def calculate_wind_speed(self):
        self.dataset['wspd'] = np.sqrt(self.dataset.u**2 + self.dataset.v**2)
        self.dataset.attrs['wspd_units'] = 'm/s'
        self.dataset['wdir'] = np.arctan2(self.dataset.v, self.dataset.u) * 180 / np.pi
        self.dataset.attrs['wdir_units'] = 'degrees'

    def window_dataset(self, variable: str = 'wspd'):
        features = []
        targets = []
        forcings = []
        time_values = []

        time_dim = self.dataset.sizes['time']
        total_windows = time_dim - self.window_size - self.steps

        for i in range(total_windows):
            print(f'{i}/{total_windows}', end='\r')
            features.append(self.dataset[variable].isel(time=slice(i, i + self.window_size)))
            targets.append(self.dataset[variable].isel(time=slice(i + self.window_size, i + self.window_size +  self.steps)))        
            time_values.append(self.dataset.time.isel(time=slice(i, i + self.window_size)).values)

            # Forcings with hour and month values
            forcings.append([self.dataset.time.isel(time=i + self.window_size).dt.hour.values, self.dataset.time.isel(time=i + self.window_size).dt.month.values])

        self.features = np.stack(features)
        self.targets = np.stack(targets)
        self.forcings = np.array(forcings)
        self.time_values = time_values

        return self.features, self.targets, self.forcings, self.time_values

    def slice_dataset(self, end_time):
        start_time = pd.to_datetime(end_time) - pd.Timedelta(hours=self.window_size)
        return self.dataset.sel(time=slice(start_time, end_time))
    
    def weather_gifs(self, ds_, ds_f = None, feature='wspd', metric='m/s', levels=20, frames=0, frame_rate=16):

        if ds_f is None:
            vmin = ds_[feature].min().item()
            vmax = ds_[feature].max().item()
        
            fig, axs = plt.subplots(figsize=(10, 6), subplot_kw={'projection': ccrs.PlateCarree()})

            contour = ds_[feature].isel(time=0).plot.contourf(ax=axs, levels=levels, vmin=vmin, vmax=vmax, add_colorbar=False)
            colorbar = plt.colorbar(contour, ax=axs)
            
        else:

            vmax = max(ds_[feature].max().values, ds_f[feature].max().values)
            vmin = min(ds_[feature].min().values, ds_f[feature].min().values)
            
            fig, axs = plt.subplots(1, 3, figsize=(21, 7), subplot_kw={'projection': ccrs.PlateCarree()})

            fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.9, wspace=0.2)

            contour_f = ds_f[feature].isel(time=0).plot.contourf(ax=axs[0], levels=levels, vmin=vmin, vmax=vmax, add_colorbar=False)
            colorbar = plt.colorbar(contour_f, ax=axs[0], shrink=0.5, aspect=10)

            error = ds_[feature].values - ds_f[feature].values

            ds_error = xr.Dataset({
                feature: (('time', 'latitude', 'longitude'), error),
                'latitude': ds_.latitude,
                'longitude': ds_.longitude,
                'time': ds_.time})
            
            vmax_e = ds_error[feature].max().values
            vmin_e = ds_error[feature].min().values
            
            contour_e = ds_error[feature].isel(time=0).plot.contourf(ax=axs[2], levels=levels, vmin=vmin_e, vmax=vmax_e, add_colorbar=False, cmap='coolwarm')
            colorbar = plt.colorbar(contour_e, ax=axs[2], shrink=0.5, aspect=10)

            contour_a = ds_[feature].isel(time=0).plot.contourf(ax=axs[1], levels=levels, vmin=vmin, vmax=vmax, add_colorbar=False)
            colorbar = plt.colorbar(contour_a, ax=axs[1], shrink=0.5, aspect=10)

        def animate(i):
            if ds_f is None:
                axs.clear()  
                axs.coastlines()  
                contour = ds_[feature].isel(time=i).plot.contourf(ax=axs, levels=levels, vmin=vmin, vmax=vmax, add_colorbar=False)

                str_time = ds_.time.isel(time=i).values
                str_time = pd.to_datetime(str_time)

                axs.set_title(f'Observed {feature} {metric} at {str_time.strftime("%Y-%m-%d %H:%M:%S")} UTC')

            else:
                for ax in axs:
                    ax.clear()
                    ax.coastlines()

                contour_a = ds_[feature].isel(time=i).plot.contourf(ax=axs[1], levels=levels, vmin=vmin, vmax=vmax, add_colorbar=False)
                axs[1].set_title(f'Analysis ({feature}) {metric}')

                contour_f = ds_f[feature].isel(time=i).plot.contourf(ax=axs[0], levels=levels, vmin=vmin, vmax=vmax, add_colorbar=False)
                axs[0].set_title(f'Forecast ({feature}) {metric}')

                contour_e = ds_error[feature].isel(time=i).plot.contourf(ax=axs[2], levels=levels, vmin=vmin_e, vmax=vmax_e, add_colorbar=False, cmap='coolwarm')
                axs[2].set_title(f'Error ({feature}) {metric}')

        if frames == 0:
            frames = ds_.time.size

        interval = 1000 / frame_rate

        ani = FuncAnimation(fig, animate, frames=frames, interval=interval)

        plt.close(fig)

        return HTML(ani.to_jshtml())

    def plot_window_target(self, seed = 0, frame_rate=16, bounds = [ds.longitude.min().item(), ds.longitude.max().item(), ds.latitude.min().item(), ds.latitude.max().item()]):
        features = self.features[seed]
        targets = self.targets[seed]
        time_values = self.time_values

        fig, axs = plt.subplots(1, 2, figsize=(21, 7), subplot_kw={'projection': ccrs.PlateCarree()})

        vmin = min(features.min().item(), targets.min().item())
        vmax = max(features.max().item(), targets.max().item())

        fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.9, wspace=0.2)

        for ax in axs:
            ax.set_extent(bounds, crs=ccrs.PlateCarree())
            ax.coastlines()


        feat = axs[0].contourf(self.dataset.longitude, self.dataset.latitude, features[0], levels=20, vmin=vmin, vmax = vmax, transform=ccrs.PlateCarree())
        tar = axs[1].contourf(self.dataset.longitude, self.dataset.latitude, targets[0], levels=20, vmin=vmin, vmax = vmax, transform=ccrs.PlateCarree())
        axs[1].set_title('Target')

        fig.colorbar(feat, ax=axs[0], orientation='vertical', label='Wind Speed (m/s)')
        fig.colorbar(tar, ax=axs[1], orientation='vertical', label='Wind Speed (m/s)')

        def animate(i):
            axs[0].clear()
            axs[0].coastlines()

            pcm = axs[0].contourf(self.dataset.longitude, self.dataset.latitude, features[i], levels=20, vmin=vmin, vmax = vmax)
            

            start_time = time_values[i][0]
            end_time = time_values[i][-1]

            start_time = pd.to_datetime(start_time)
            end_time = pd.to_datetime(end_time)

            axs[0].set_title(f'Window {i} - {start_time.strftime("%Y-%m-%d %H:%M:%S")} to {end_time.strftime("%Y-%m-%d %H:%M:%S")}')
            if self.steps > 1:
                ptm = axs[1].contourf(self.dataset.longitude, self.dataset.latitude, targets[i % self.steps], levels=20, vmin=vmin, vmax = vmax)
                axs[1].set_title(f'Target - {end_time.strftime("%Y-%m-%d %H:%M:%S")}')
            return pcm

            
        frames = features.shape[0]

        interval = 1000 / frame_rate

        ani = FuncAnimation(fig, animate, frames=frames, interval=interval)

        plt.close(fig)

        return HTML(ani.to_jshtml())
        


# Class Use

In [38]:
import gc

gc.collect()

10664

In [39]:
weather_data = WeatherData(ds, window_size=24, steps=24)

weather_data.subset_data()
features, targets, forcings, time_values = weather_data.window_dataset()



8711/8712

In [40]:
weather_data.plot_window_target(seed=0, frame_rate=16)

In [41]:
step = 3

for i in range(0, 24):
    print(f'Window {i % step}')

Window 0
Window 1
Window 2
Window 0
Window 1
Window 2
Window 0
Window 1
Window 2
Window 0
Window 1
Window 2
Window 0
Window 1
Window 2
Window 0
Window 1
Window 2
Window 0
Window 1
Window 2
Window 0
Window 1
Window 2
