In [2]:
import xarray as xr

test = xr.open_dataset('era5_testing.nc')
test.load()

Cannot find the ecCodes library


In [28]:
import xarray as xr
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

import cartopy.crs as ccrs

from torch.utils.data import Dataset

from IPython.display import HTML

import torch

from typing import Tuple

class WeatherData(Dataset):
    def __init__(self, dataset: xr.Dataset, window_size: int = 24, steps: int = 3, 
                        auto: bool = True, coarsen: int = 1, use_forcings: bool = True,
                        lightning: bool = False, variable: str = 'wdpd') -> None:
        
        self.dataset = dataset
        self.window_size = window_size
        self.steps = steps
        self.variable = variable

        if self.variable == 'wspd':
            self.min_value = self.dataset.wind_speed.min().item()
            self.max_value = self.dataset.wind_speed.max().item()

            self.mean_value = self.dataset.wind_speed.mean().item()
            self.std_value = self.dataset.wind_speed.std().item()

        elif self.variable == 'temp':
            self.min_value = self.dataset.t.min().item()
            self.max_value = self.dataset.t.max().item()

            self.mean_value = self.dataset.t.mean().item()
            self.std_value = self.dataset.t.std().item()

        elif self.variable == 'all':
            self.min_value = self.dataset.min().item()
            self.max_value = self.dataset.max().item()

            self.mean_value = self.dataset.mean().item()
            self.std_value = self.dataset.std().item()

        if lightning:
            self.land_sea_mask = np.load('/teamspace/studios/this_studio/WeatherForecasting/data/land_sea_mask.npy')
        else:
            self.land_sea_mask = np.load('land_sea_mask.npy')
        

        self.use_forcings = use_forcings

        self.input_size = self.window_size * self.dataset.latitude.size * self.dataset.longitude.size
        self.forcing_size = 2  
        self.output_size = 1 * self.dataset.latitude.size * self.dataset.longitude.size 

        if auto:
            self.subset_data(coarsen)
                
            self.split_data()
            self.normalize_data()


    def __len__(self) -> int:
        return len(self.dataset.time) - self.window_size - self.steps + 1


    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        M = torch.tensor(self.land_sea_mask).float()
        x = self.X_test_t[idx:idx + self.window_size]
        F = self.F_test_t[idx + self.window_size]
        y = self.X_test_t[idx + self.window_size:idx + self.window_size + self.steps] 
        return x, F, M, y


    def subset_data(self, coarsen: int = 1) -> None:
        lat_slice = slice(1, 49, coarsen)
        lon_slice = slice(2, 66, coarsen)
        self.land_sea_mask = self.land_sea_mask[1::coarsen, 2:66:coarsen]

        self.dataset = self.dataset.isel(latitude=lat_slice, longitude=lon_slice)


    def split_data(self, test_size: float = 0.1, val_size: float = 0.2, random_state: int = 42) -> None:
        
        if self.variable == 'wspd':
            data = self.dataset.wind_speed.values.squeeze()

        elif self.variable == 'temp':
            data = self.dataset.t.values.squeeze()
        else:
            base = self.dataset.to_array(dim="variable")
            base = base.transpose("time", "latitude", "longitude", "variable", "pressure_level")
            data = base.values.squeeze()
        forcings = np.stack([self.dataset.time.dt.hour.values, self.dataset.time.dt.month.values], axis=-1)
        time_values = self.dataset.time.values


        self.X_test, self.F_test, self.T_test = data, forcings, time_values


    def normalize_data(self, method: str = 'min_max') -> None:
        self.X_test_t = self.normalize(self.X_test, method)

        self.X_test_t = torch.tensor(self.X_test_t).float()

        self.F_test_t = torch.tensor(self.F_test).float()


    def normalize(self, data: np.ndarray, method: str = 'avg_std') -> np.ndarray:

        if method == 'min_max':
            return (data - self.min_value) / (self.max_value - self.min_value)
        else:
            return (data - self.mean_value) / self.std_value


    def unnormalize(self, data: np.ndarray, method: str = 'avg_std') -> np.ndarray:
            
            if method == 'min_max':
                return data * (self.max_value - self.min_value) + self.min_value
            else:
                return data * self.std_value + self.mean_value
            

    def plot_from_data(self, seed: int = 0, frame_rate: int = 16, levels: int = 10) -> HTML:
        
        bounds = [self.dataset.longitude.min().item(), self.dataset.longitude.max().item(), self.dataset.latitude.min().item(), self.dataset.latitude.max().item()]

        features = self.X_test[seed:seed + self.window_size]
        targets = self.X_test[seed + self.window_size:seed + self.window_size + self.steps]
        
        time_features = self.T_test[seed:seed + self.window_size]
        time_targets = self.T_test[seed + self.window_size:seed + self.window_size + self.steps]

        time_features = pd.to_datetime(time_features)
        time_targets = pd.to_datetime(time_targets)

        fig, axs = plt.subplots(1, 2, figsize=(21, 7), subplot_kw={'projection': ccrs.PlateCarree()})

        vmin = min(features.min().item(), targets.min().item())
        vmax = max(features.max().item(), targets.max().item())

        fig.subplots_adjust(left=0.05, right=0.95, bottom=0.1, top=0.9, wspace=0.2)

        print('Features:', features.shape, '\nTargets:', targets.shape)

        for ax in axs:
            ax.set_extent(bounds, crs=ccrs.PlateCarree())
            ax.coastlines()


        feat = axs[0].contourf(self.dataset.longitude, self.dataset.latitude, features[0], levels=levels, vmin=vmin, vmax = vmax, transform=ccrs.PlateCarree())
        tar = axs[1].contourf(self.dataset.longitude, self.dataset.latitude, targets[0], levels=levels, vmin=vmin, vmax = vmax, transform=ccrs.PlateCarree())
        axs[1].set_title('Target')

        fig.colorbar(feat, ax=axs[0], orientation='vertical', label='Wind Speed (m/s)')
        fig.colorbar(tar, ax=axs[1], orientation='vertical', label='Wind Speed (m/s)')

        def animate(i):
            axs[0].clear()
            axs[0].coastlines()

            axs[0].contourf(self.dataset.longitude, self.dataset.latitude, features[i], levels=levels, vmin=vmin, vmax = vmax)

            axs[0].set_title(f'Window {i} - {time_features[i].strftime("%Y-%m-%d %H:%M:%S")}')
            if self.steps > 1:
                axs[1].contourf(self.dataset.longitude, self.dataset.latitude, targets[i % self.steps], levels=levels, vmin=vmin, vmax = vmax)
                axs[1].set_title(f'Target - {time_targets[i % self.steps].strftime("%Y-%m-%d %H:%M:%S")}')
            # return pcm

            
        frames = features.shape[0]

        interval = 1000 / frame_rate

        ani = FuncAnimation(fig, animate, frames=frames, interval=interval)

        plt.close(fig)

        return HTML(ani.to_jshtml())


    def test_class(self) -> None:
        print('self.X_test:', self.X_test.shape)
        print('self.F_test:', self.F_test.shape)

        print('self.X_test_t:', self.X_test_t.shape)
        print('self.F_test_t:', self.F_test_t.shape)

        print('self.input_size:', self.input_size, 'self.forcing_size:', self.forcing_size, 'self.output_size:', self.output_size)

   

In [70]:
weather = WeatherData(test, window_size=5, steps=7, auto=True, coarsen=1, use_forcings=True, lightning=False, variable='wspd')
M = torch.tensor(weather.land_sea_mask).float() / 2
M = M.expand(730, 1, 48, 64)

weather.test_class()

self.X_test: (1464, 48, 64)
self.F_test: (1464, 2)
self.X_test_t: torch.Size([1464, 48, 64])
self.F_test_t: torch.Size([1464, 2])
self.input_size: 16905 self.forcing_size: 2 self.output_size: 3381


In [67]:
import torch
import torch.nn as nn

class batchnorm_relu(nn.Module):
    def __init__(self, in_c):
        super().__init__()

        self.bn = nn.BatchNorm2d(in_c)
        self.relu = nn.ReLU()

    def forward(self, inputs):
        x = self.bn(inputs)
        x = self.relu(x)
        return x
    
class residual_block(nn.Module):
    def __init__(self, in_c, out_c, stride=1):
        super().__init__()

        """ Convolutional layer """
        self.b1 = batchnorm_relu(in_c)
        self.c1 = nn.Conv2d(in_c, out_c, kernel_size=3, padding=1, stride=stride)
        self.b2 = batchnorm_relu(out_c)
        self.c2 = nn.Conv2d(out_c, out_c, kernel_size=3, padding=1, stride=1)

        """ Shortcut Connection (Identity Mapping) """
        self.s = nn.Conv2d(in_c, out_c, kernel_size=1, padding=0, stride=stride)

    def forward(self, inputs):
        x = self.b1(inputs)
        x = self.c1(x)
        x = self.b2(x)
        x = self.c2(x)
        s = self.s(inputs)

        skip = x + s
        return skip
    
class decoder_block(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        self.upsample = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.r = residual_block(in_c+out_c, out_c)

    def forward(self, inputs, skip):
        x = self.upsample(inputs)
        x = torch.cat([x, skip], axis=1)
        x = self.r(x)
        return x
 
class build_res_unet_time(nn.Module):
    def __init__(self, in_c, out_c):
        super().__init__()

        ''' Encoder 1 '''
        self.c11 = nn.Conv2d(in_c, 64, kernel_size=3, padding=1)  # in_c + 2 to account for time inputs
        self.bn = nn.GroupNorm(64, 64)
        self.c12 = nn.Conv2d(64, 64, kernel_size=3, padding=1)
        self.c13 = nn.Conv2d(in_c, 64, kernel_size=1, padding=0)  # Shortcut feature

        """ Encoder 2 and 3 """
        self.r2 = residual_block(64, 128, stride=2)
        self.r3 = residual_block(128, 256, stride=2)

        """ Bridge """
        self.r4 = residual_block(256, 512, stride=2)

        """ Decoder """
        self.d1 = decoder_block(512, 256)
        self.d2 = decoder_block(256, 128)
        self.d3 = decoder_block(128, 64)

        """ Output """
        self.output = nn.Conv2d(64, out_c, kernel_size=1, padding=0)
        self.sigmoid = nn.Sigmoid()
        
    def step(self, x, M):

        M = M / torch.tensor(2) 

        combined_input = torch.cat([x, M], dim=1) #  F, , F_embedded , time_indices

        # print('combined_input shape:', combined_input.shape)

        """ Encoder 1 """
        out = self.c11(combined_input)
        out = self.bn(out)
        out = self.c12(out)
        s = self.c13(combined_input)
        skip1 = out + s
        # print('skip1 shape:', skip1.shape)

        """ Encoder 2 and 3 """
        skip2 = self.r2(skip1)
        # print('skip2 shape:', skip2.shape)
        skip3 = self.r3(skip2)
        # print('skip3 shape:', skip3.shape)

        """ Bridge """
        b = self.r4(skip3)
        # print('b shape:', b.shape)

        """ Decoder """
        d1 = self.d1(b, skip3)
        # print('d1 shape:', d1.shape)
        d2 = self.d2(d1, skip2)
        # print('d2 shape:', d2.shape)
        d3 = self.d3(d2, skip1)
        # print('d3 shape:', d3.shape)

        """ Output """
        output = self.output(d3)
        output = self.sigmoid(output)

        return output

    def rollout(self, x, M, steps):

        y_hats = torch.empty(x.size(0), steps, x.size(2), x.size(3), device=x.device)

        current_X = x
        for i in range(steps):
            y_hat = self.step(current_X, M)
            y_hats[:, i] = y_hat.squeeze(1)
            current_X = torch.cat((current_X[:, 1:], y_hat), dim=1)

        return y_hats

    def forward(self, x, M, steps):
        return self.rollout(x, M, steps)

input_size = 6
output_size = 1
    
model = build_res_unet_time(input_size,  output_size)


In [55]:
window_size = 5
inputs = []
for i in range(0, 1464, 2):
    print('Seed:', i)
    x = weather.X_test_t[i:i + window_size].unsqueeze(0)
    inputs.append(x)

Seed: 0
Seed: 2
Seed: 4
Seed: 6
Seed: 8
Seed: 10
Seed: 12
Seed: 14
Seed: 16
Seed: 18
Seed: 20
Seed: 22
Seed: 24
Seed: 26
Seed: 28
Seed: 30
Seed: 32
Seed: 34
Seed: 36
Seed: 38
Seed: 40
Seed: 42
Seed: 44
Seed: 46
Seed: 48
Seed: 50
Seed: 52
Seed: 54
Seed: 56
Seed: 58
Seed: 60
Seed: 62
Seed: 64
Seed: 66
Seed: 68
Seed: 70
Seed: 72
Seed: 74
Seed: 76
Seed: 78
Seed: 80
Seed: 82
Seed: 84
Seed: 86
Seed: 88
Seed: 90
Seed: 92
Seed: 94
Seed: 96
Seed: 98
Seed: 100
Seed: 102
Seed: 104
Seed: 106
Seed: 108
Seed: 110
Seed: 112
Seed: 114
Seed: 116
Seed: 118
Seed: 120
Seed: 122
Seed: 124
Seed: 126
Seed: 128
Seed: 130
Seed: 132
Seed: 134
Seed: 136
Seed: 138
Seed: 140
Seed: 142
Seed: 144
Seed: 146
Seed: 148
Seed: 150
Seed: 152
Seed: 154
Seed: 156
Seed: 158
Seed: 160
Seed: 162
Seed: 164
Seed: 166
Seed: 168
Seed: 170
Seed: 172
Seed: 174
Seed: 176
Seed: 178
Seed: 180
Seed: 182
Seed: 184
Seed: 186
Seed: 188
Seed: 190
Seed: 192
Seed: 194
Seed: 196
Seed: 198
Seed: 200
Seed: 202
Seed: 204
Seed: 206
Seed: 208
Seed:

In [58]:
inputs_ = torch.cat(inputs[:730], dim=0)


In [73]:
pred = model(inputs_, M, 6)

KeyboardInterrupt: 

torch.Size([730, 1, 48, 64])