In [32]:
import xarray as xr
import numpy as np

year = '2023'

# Open the NetCDF file inside a context manager so its handle is released as
# soon as the arrays have been read into memory (the original left it open).
with xr.open_dataset(f'data/{year}_850_SA.nc') as ds:
    # to_array() stacks the data variables along a new leading axis and
    # .values materialises the result as a NumPy array, so a separate
    # ds.load() call is not needed.
    # NOTE(review): the channel order follows the variable order in the file;
    # WeatherData assumes humidity, temperature, u, v, w - confirm.
    data = ds.to_array().values
    times = ds.time.values

print(data.shape)

# One array per year; WeatherData concatenates them along the time axis.
np.save(f'datasets/{year}_850_SA.npy', data)
np.save(f'datasets/{year}_850_SA_times.npy', times)

(5, 8760, 34, 71)


In [1]:
import numpy as np

import pandas as pd

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

from IPython.display import HTML

import cartopy.crs as ccrs

import torch
from torch.utils.data import Dataset, DataLoader

from typing import Tuple

# Add settings to choose variables and spatial extent

class WeatherData(Dataset):
    '''
    Sliding-window dataset built from the per-year ``datasets/{year}_850_SA.npy``
    arrays and their matching ``*_times.npy`` files.

    Each sample is a pair ``(features, targets)``:
        features - ``window_size`` consecutive time steps of all feature channels
        targets  - wind speed for the ``step_size`` time steps that follow

    Feature channels (last axis of ``self.features``):
        0 - humidity
        1 - temperature
        2 - u-wind
        3 - v-wind
        4 - w-wind
        5 - wind speed (derived from channels 2 and 3)

    NOTE(review): wind speed is computed from the *standardised* u and v
    components, so it is not a physical speed in m/s even though the plots
    label it that way - confirm whether that is intended.
    '''

    # Years that make up each split.
    _SPLIT_YEARS = {
        'train': ['2018', '2019', '2020', '2021'],
        'val': ['2022'],
        'test': ['2023'],
    }

    # Legend labels for the five raw channels, in channel order.
    _CHANNEL_LABELS = ('Humidity', 'Temperature', 'U-Wind', 'V-Wind', 'W-Wind')

    def __init__(self,
                 window_size: int = 24,
                 step_size: int = 12,
                 set: str = 'train',
                 area: Tuple[float, float] = (-31.667, 18.239),
                 spaces: int = 0,
                 intervals: int = 1,
                 verbose: bool = False):
        '''
        Load, crop and normalise the data for one split.

        Args:
            window_size: Number of time steps in each feature window.
            step_size: Number of time steps in each target window.
            set: Which split to load: 'train', 'val' or 'test'.
            area: (latitude, longitude) of the point of interest; the nearest
                grid point is used.
            spaces: Number of grid points kept on each side of the selected
                point. 0 keeps only the single nearest grid point.
            intervals: Stride between time steps, used by plot_animation only.
                NOTE(review): __getitem__ ignores it - confirm intended.
            verbose: Print the resulting shapes and coordinates.

        Raises:
            ValueError: If ``set`` is not one of the known splits.
        '''
        # Validate before touching the disk so a typo fails fast and clearly
        # (the original left `years` unbound and died with a NameError).
        if set not in self._SPLIT_YEARS:
            raise ValueError(
                f"Unknown set {set!r}; expected one of {list(self._SPLIT_YEARS)}"
            )
        years = self._SPLIT_YEARS[set]

        self.window_size = window_size
        self.step_size = step_size
        self.spaces = spaces
        self.intervals = intervals

        # Saved arrays are (variable, time, lat, lon); join the years along
        # time, then move the variable axis last -> (time, lat, lon, variable).
        stacked = np.concatenate(
            [np.load(f'datasets/{year}_850_SA.npy') for year in years], axis=1
        )
        self.data = stacked.transpose(1, 2, 3, 0)

        self.times = np.concatenate(
            [np.load(f'datasets/{year}_850_SA_times.npy') for year in years]
        )

        # Forward slashes keep these paths portable and avoid the invalid
        # '\S' escape sequence the backslash form contained.
        self.lon = np.load('datasets/SA_lon.npy')
        self.lat = np.load('datasets/SA_lat.npy')

        # Crop data and coordinates to the requested neighbourhood.
        self.get_area(area)

        # The variable axis is always last, for both the single-point
        # (time, variable) and the area (time, lat, lon, variable) layouts.
        q, t, u, v, w = (self.data[..., channel] for channel in range(5))
        q, t, u, v, w = self.normalize(q, t, u, v, w)

        # Sets self.wspd and self.wdir.
        self.calculate_wind(u, v)

        # Set up the tensors served by __getitem__.
        self.features = torch.tensor(
            np.stack([q, t, u, v, w, self.wspd], axis=-1), dtype=torch.float32
        )
        self.targets = torch.tensor(self.wspd, dtype=torch.float32)

        if verbose:
            print(f'Features shape: {self.features.shape}')
            print(f'Targets shape: {self.targets.shape}')

            print(f'Longitudes: {self.lon}')
            print(f'Latitudes: {self.lat}')

    def __len__(self):
        '''Number of complete (window, target) pairs; 0 if the series is too short.'''
        return max(0, self.features.shape[0] - self.window_size - self.step_size + 1)

    def __getitem__(self, idx):
        '''
        Return ``(features, targets)`` for the window starting at ``idx``.

        Raises:
            IndexError: If ``idx`` is outside ``range(len(self))``. Without
                this, plain iteration over the dataset would never stop.
        '''
        count = len(self)
        if idx < 0:
            idx += count
        if not 0 <= idx < count:
            raise IndexError(f'index {idx} out of range for {count} samples')

        split = idx + self.window_size
        # Serve the float32 tensor (not the NumPy wspd array) so both halves of
        # the sample have the same type and dtype.
        return self.features[idx:split], self.targets[split:split + self.step_size]

    def normalize(self, q, t, u, v, w, method = 'std'):
        '''
        Standardise each field to zero mean and unit standard deviation.

        Statistics are taken over the whole array of each field. Any ``method``
        other than 'std' returns the inputs unchanged.
        '''
        if method == 'std':
            q, t, u, v, w = (
                (field - field.mean()) / field.std() for field in (q, t, u, v, w)
            )

        return q, t, u, v, w

    def calculate_wind(self, u, v):
        '''Store wind magnitude in ``self.wspd`` and ``arctan2(u, v)`` (radians) in ``self.wdir``.'''
        self.wspd = np.sqrt(u**2 + v**2)
        self.wdir = np.arctan2(u, v)

    def get_area(self, area: Tuple[float, float]):
        '''
        Crop ``self.data``, ``self.lat`` and ``self.lon`` around ``area``.

        Args:
            area: (latitude, longitude); the nearest grid point is the centre.

        With ``self.spaces == 0`` only the nearest grid point is kept and the
        spatial axes are dropped. Otherwise ``spaces`` points are kept on each
        side; the window is truncated at the grid edge rather than wrapping
        around through negative indices.
        '''
        lat_idx = int(np.argmin(np.abs(self.lat - area[0])))
        lon_idx = int(np.argmin(np.abs(self.lon - area[1])))

        if self.spaces != 0:
            # Clamp the start at 0: a negative start would index from the end
            # of the axis and silently yield an empty or wrong window.
            lat_window = slice(max(lat_idx - self.spaces, 0), lat_idx + self.spaces)
            lon_window = slice(max(lon_idx - self.spaces, 0), lon_idx + self.spaces)

            self.lon = self.lon[lon_window]
            self.lat = self.lat[lat_window]

            self.data = self.data[:, lat_window, lon_window, :]
        else:
            self.lon = self.lon[lon_idx]
            self.lat = self.lat[lat_idx]

            self.data = self.data[:, lat_idx, lon_idx, :]

    def plot_area(self):
        '''Draw a filled-contour map of the first target field over the selected area.'''
        if self.spaces != 0:
            fig, ax = plt.subplots(subplot_kw={'projection': ccrs.PlateCarree()})

            ax.coastlines()

            ax.set_extent([self.lon.min(), self.lon.max(), self.lat.min(), self.lat.max()])

            lon, lat = self.lon, self.lat
            contour = ax.contourf(lon, lat, self.targets[0].detach().numpy(), transform=ccrs.PlateCarree())

            fig.colorbar(contour, ax=ax, orientation='vertical', label='Wind Speed (m/s)')

            plt.show()
        else:
            print('Cannot plot area with only one point')

    def plot_animation(self, seed: int = 0, frame_rate: int = 16, levels: int = 10) -> "HTML | None":
        """
        Animate the wind-speed history of one window next to its targets.

        Args:
            seed (int): Index of the first time step of the window. Default is 0.
            frame_rate (int): The frame rate for the animation. Default is 16.
            levels (int): Number of contour levels for the plot. Default is 10.

        Returns:
            HTML: An HTML object representing the animation, or None (after
            printing a message) when the dataset holds a single grid point.
        """
        if self.spaces == 0:
            print('Cannot plot area with only one point')
            return None

        bounds = [self.lon.min(), self.lon.max(), self.lat.min(), self.lat.max()]

        stride = self.intervals
        split = seed + self.window_size * stride
        stop = seed + (self.window_size + self.step_size) * stride

        # Plot the wind-speed channel (last feature) so that both panels show
        # the quantity named on the colour bars; the original plotted channel
        # 2 (u-wind) against wind-speed targets on a shared colour scale.
        history = self.features[seed:split:stride, :, :, -1].numpy()
        targets = self.targets[split:stop:stride].numpy()

        time_features = pd.to_datetime(self.times[seed:split:stride])
        time_targets = pd.to_datetime(self.times[split:stop:stride])

        fig, axs = plt.subplots(1, 2, figsize=(21, 7), subplot_kw={'projection': ccrs.PlateCarree()})
        fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.9, wspace=0.2)

        # Shared colour range computed from the plotted fields only.
        vmin = float(min(history.min(), targets.min()))
        vmax = float(max(history.max(), targets.max()))
        contour_kwargs = dict(levels=levels, vmin=vmin, vmax=vmax, transform=ccrs.PlateCarree())

        def draw(ax, field, title):
            # clear() drops extent, coastlines and old contour sets, so rebuild
            # the panel from scratch on every frame.
            ax.clear()
            ax.set_extent(bounds, crs=ccrs.PlateCarree())
            ax.coastlines()
            ax.set_title(title)
            return ax.contourf(self.lon, self.lat, field, **contour_kwargs)

        def stamp(moment):
            return moment.strftime("%Y-%m-%d %H:%M:%S")

        feat = draw(axs[0], history[0], f'Window 0 - {stamp(time_features[0])}')
        tar = draw(axs[1], targets[0], 'Target')

        fig.colorbar(feat, ax=axs[0], orientation='vertical', label='Wind Speed (m/s)')
        fig.colorbar(tar, ax=axs[1], orientation='vertical', label='Wind Speed (m/s)')

        def animate(i):
            draw(axs[0], history[i], f'Window {i} - {stamp(time_features[i])}')
            if self.step_size > 1:
                # The target panel cycles through its own, shorter, sequence.
                j = i % self.step_size
                draw(axs[1], targets[j], f'Target - {stamp(time_targets[j])}')

        frames = history.shape[0]
        interval = 1000 / frame_rate

        ani = FuncAnimation(fig, animate, frames=frames, interval=interval)

        # Close the static figure so only the animation is displayed.
        plt.close(fig)

        return HTML(ani.to_jshtml())

    def plot_point(self, seed: int = 0):
        '''
        Plot the five raw feature channels of one window and the wind-speed
        targets that follow it. Only available for single-point datasets.

        Args:
            seed: Index of the first time step of the window.
        '''
        if self.spaces == 0:
            plt.figure(figsize=(10, 5))

            history = slice(seed, seed + self.window_size)
            horizon = slice(seed + self.window_size, seed + self.window_size + self.step_size)

            for channel, label in enumerate(self._CHANNEL_LABELS):
                plt.plot(self.times[history], self.features[history, channel], label=label)

            plt.plot(self.times[horizon], self.targets[horizon], label='Wind Speed', linestyle='--')

            plt.xticks(rotation=45)
            plt.legend()
            plt.show()
        else:
            print('Cannot plot point with multiple points')

In [154]:
# Batch the 2023 test split over a small neighbourhood, keeping the
# chronological sample order.
test_loader = DataLoader(
    WeatherData(set='test', spaces=4, verbose=True),
    batch_size=32,
    shuffle=False,
)


Features shape: torch.Size([8760, 8, 8, 5])
Targets shape: torch.Size([8760, 8, 8])
Longitudes: [17.3061   17.556114 17.80613  18.056143 18.306158 18.556171 18.806185
 19.0562  ]
Latitudes: [-30.637 -30.887 -31.137 -31.387 -31.637 -31.887 -32.137 -32.387]


In [155]:

# Pull just the first batch to sanity-check the tensor shapes; nothing is
# printed when the loader is empty.
first_batch = next(iter(test_loader), None)
if first_batch is not None:
    x, y = first_batch
    print(x.shape, y.shape)


torch.Size([32, 24, 8, 8, 5]) torch.Size([32, 12, 8, 8])


In [32]:
# Larger neighbourhood with longer windows for the visual check.
test_set = WeatherData(
    window_size=48,
    step_size=24,
    set='test',
    spaces=10,
    verbose=True,
)

# Last expression of the cell, so the returned HTML animation is rendered.
test_set.plot_animation()

Features shape: torch.Size([8760, 20, 20, 5])
Targets shape: torch.Size([8760, 20, 20])
Longitudes: [15.806014 16.056028 16.306044 16.556057 16.806072 17.056086 17.3061
 17.556114 17.80613  18.056143 18.306158 18.556171 18.806185 19.0562
 19.306213 19.556229 19.806242 20.056257 20.30627  20.556286]
Latitudes: [-29.137 -29.387 -29.637 -29.887 -30.137 -30.387 -30.637 -30.887 -31.137
 -31.387 -31.637 -31.887 -32.137 -32.387 -32.637 -32.887 -33.137 -33.387
 -33.637 -33.887]
