In [70]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import dask.array as da
import torch 
import os 
from glob import glob
from utils import get_train_test_data_without_scales_batched, get_train_test_data_without_scales_batched_monthly

# Dataset handle

In [71]:
# Open every yearly NetCDF file as one lazily-combined dataset, then snap it
# onto a regular 6-hourly time axis (nearest sample, at most 1 hour away).
# NOTE(review): recent pandas releases may warn that the upper-case "H" alias
# is deprecated in favour of "h" - confirm against the installed version.
data = (
    xr.open_mfdataset('data/geopotential_500_5.625deg/*.nc', combine='by_coords')
    .resample(time="6H")
    .nearest(tolerance="1H")
)

  self.index_grouper = pd.Grouper(
  flat_indexer = index.get_indexer(flat_labels, method=method, tolerance=tolerance)


## get_train_test_data_without_scales_batched

In [72]:
# Label-based slices include BOTH endpoints, so the training window must stop
# at 2015; ending it at 2016 would also pull the validation year into the
# training split (train_times further down covers 2006-2015 only).
train_time_scale = slice('2006', '2015')
val_time_scale = slice('2016', '2016')
test_time_scale = slice('2017', '2018')

In [73]:
# Materialise the three splits in memory, in order: train, validation (2016),
# test (2017 - 2018).
data_train, data_val, data_test = (
    data.sel(time=period).load()
    for period in (train_time_scale, val_time_scale, test_time_scale)
)

In [74]:
# Whole 2006-2018 span loaded into memory.
data_global = data.sel({"time": slice('2006', '2018')}).load()

In [75]:
# Global extrema of "z" as plain Python scalars.
# NOTE(review): these are taken over data_global (2006-2018), so they include
# the validation and test years - confirm that this is intended.
max_val = data_global["z"].max().item()

min_val = data_global["z"].min().item()

In [76]:
# Min-max scale each split with the same global extrema.
data_train_final, data_val_final, data_test_final = (
    (split - min_val) / (max_val - min_val)
    for split in (data_train, data_val, data_test)
)

In [77]:
# Calendar years belonging to each split.
train_times = list(range(2006, 2016))
val_times = [2016]
test_times = [2017, 2018]

# Time coordinate of the (scaled) test split.
time_vals = data_test_final["time"].values

In [78]:
# Show the list of training years built from range(2006, 2016).
print(train_times)

[2006, 2007, 2008, 2009, 2010, 2011, 2012, 2013, 2014, 2015]


### get_batched

In [79]:
def get_batched(train_times, data_train_final, lev, steps_per_day=4):
    """Stack per-year fields of one variable along a new "year" axis.

    For every year in ``train_times`` the function selects that year from
    ``data_train_final``, takes the values of variable ``lev`` and reshapes
    them to ``(steps, 1, 1, H, W)``. In leap years the time steps of
    29 February are removed so that all years have the same length. The
    per-year tensors are then concatenated along dimension 1.

    Args:
        train_times: Iterable of calendar years (ints) to include.
        data_train_final: Dataset-like object supporting
            ``.sel(time=slice(...)).load()[lev].values``; each year must start
            on 1 January and contain ``steps_per_day`` samples per day.
        lev: Name of the variable to extract (e.g. ``"z"``).
        steps_per_day: Number of time steps per day. The default of 4 matches
            6-hourly data and reproduces the former fixed 236:240 cut.

    Returns:
        torch.Tensor of shape ``(steps, len(train_times), 1, H, W)``.

    Raises:
        ValueError: If ``train_times`` is empty.
    """
    # Local import so this cell does not depend on an edit to the import cell.
    import calendar

    # 31 days of January + 28 days of February precede 29 February.
    feb29_start = 59 * steps_per_day
    feb29_stop = feb29_start + steps_per_day

    yearly_blocks = []
    for year in train_times:
        data_per_year = data_train_final.sel(
            time=slice(str(year), str(year))).load()
        data_values = data_per_year[lev].values
        print(f"data_values shape: {data_values.shape} || at year: {year}")

        # (steps, H, W) -> (steps, year=1, channel=1, H, W)
        year_block = torch.from_numpy(data_values).reshape(
            -1, 1, 1, data_values.shape[-2], data_values.shape[-1])
        print(f"train_data shape: {year_block.shape}")

        # Full Gregorian rule: `year % 4 == 0` alone is wrong for 1900, 2100, ...
        if calendar.isleap(int(year)):
            print(f"Leap year: {year}")
            year_block = torch.cat(
                (year_block[:feb29_start], year_block[feb29_stop:]))

        yearly_blocks.append(year_block)

    if not yearly_blocks:
        raise ValueError("train_times must contain at least one year")

    # One concatenation at the end instead of re-copying the growing tensor
    # on every iteration.
    return torch.cat(yearly_blocks, dim=1)



In [81]:
# Build the (steps, year, channel, H, W) training tensor for variable "z".
train_data_batched = get_batched(train_times, data_train_final, "z")

data_values shape: (1460, 32, 64) || at year: 2006
train_data shape: torch.Size([1460, 1, 1, 32, 64])
data_values shape: (1460, 32, 64) || at year: 2007
train_data shape: torch.Size([1460, 1, 1, 32, 64])
data_values shape: (1464, 32, 64) || at year: 2008
train_data shape: torch.Size([1464, 1, 1, 32, 64])
Leap year: 2008
data_values shape: (1460, 32, 64) || at year: 2009
train_data shape: torch.Size([1460, 1, 1, 32, 64])
data_values shape: (1460, 32, 64) || at year: 2010
train_data shape: torch.Size([1460, 1, 1, 32, 64])
data_values shape: (1460, 32, 64) || at year: 2011
train_data shape: torch.Size([1460, 1, 1, 32, 64])
data_values shape: (1464, 32, 64) || at year: 2012
train_data shape: torch.Size([1464, 1, 1, 32, 64])
Leap year: 2012
data_values shape: (1460, 32, 64) || at year: 2013
train_data shape: torch.Size([1460, 1, 1, 32, 64])
data_values shape: (1460, 32, 64) || at year: 2014
train_data shape: torch.Size([1460, 1, 1, 32, 64])
data_values shape: (1460, 32, 64) || at year: 2015

In [82]:
# Report the shape of the stacked training tensor returned by get_batched.
print(f"train_data_batched shape: {train_data_batched.shape}")

train_data_batched shape: torch.Size([1460, 10, 1, 32, 64])
