In [1]:
import xarray as xr
import numpy as np
import pandas as pd
from os.path import join
import os, sys
import torch
import warnings

warnings.filterwarnings("ignore")

path = join(os.pardir, os.pardir) # '../../src/'
sys.path.append(path)

from src.constants import M_COLUMNS, S_COLUMNS, G_COLUMNS, TARGET, FOLDER, TARGET_TEST


In [2]:
ROOT_DIR = os.path.join(os.pardir, os.pardir)
DATA_PATH = os.path.join(ROOT_DIR, 'data', 'processed', 'augment_100_5', 'train_enriched.nc')

test: bool = False
s_times: int = 24
m_times: int = 120

xds = xr.load_dataset(DATA_PATH, engine='scipy')
xds

In [3]:
obs, nb_aug, aug = 9, 100, 45
idx = obs * nb_aug + 45
idx_obs = idx // nb_aug
idx, idx_obs

(945, 9)

In [4]:
g_arr = xds[G_COLUMNS].to_dataframe()
g_input = torch.tensor(g_arr.values)
g_input.shape

torch.Size([557, 3])

In [5]:
for i in range(g_input.shape[0]):
    assert (g_input[i] == torch.tensor(xds.sel(ts_obs=i)[G_COLUMNS].to_array().values)).all()

In [6]:
s_arr = xds[S_COLUMNS].to_dataframe()[S_COLUMNS]
s_arr = s_arr.to_numpy()
s_arr = s_arr.reshape(s_arr.shape[0] // 24, 24, 8)
# s_arr.shape
# s_arr = s_arr.reshape((s_arr.shape[0], np.prod(s_arr.shape[1:3]), s_arr.shape[3]))
# s_arr = s_arr.T.swapaxes(0, 1)
s_input = torch.tensor(s_arr)
s_input.shape

torch.Size([55700, 24, 8])

In [7]:
for i in range(s_input.shape[0]):
    assert (s_input[i] == torch.tensor(xds.sel(ts_obs=i // 100, ts_aug=i % 100)[S_COLUMNS].to_array().values.T)).all()
    if i > 20:
        break

In [8]:
df_time = xds[['time', 'District']].to_dataframe()
df_time = df_time.groupby(['ts_obs', 'state_dev', 'District']).first()
df_time = df_time.reset_index('state_dev').drop(columns='ts_id')
df_time = df_time[df_time['state_dev'].isin([0, 23])]
df_time = df_time.pivot(columns='state_dev').droplevel(None, axis=1)
df_time.reset_index('District', inplace=True)

list_weather = []
for index, series in df_time.iterrows():
    all_dates = pd.date_range(series[0], series[23], freq='D')
    all_dates = all_dates[-m_times:]
    m_arr = xds.sel(datetime=all_dates, name=series['District'])[M_COLUMNS].to_array().values
    list_weather.append(m_arr.T)

m_arr = np.asarray(list_weather)
m_input = torch.tensor(m_arr)
m_input.shape

torch.Size([557, 120, 17])

In [9]:
def get_minput(i, xds):
    xds_time = xds.isel(ts_obs=i, ts_aug=0)
    all_dates = pd.date_range(xds_time['time'].min().values, xds_time['time'].max().values, freq='D')
    all_dates = all_dates[-m_times:]
    m_arr = xds_time.sel(datetime=all_dates, name=xds_time['District'])[M_COLUMNS].to_array().values
    m_arr = m_arr.reshape((len(M_COLUMNS), m_times)).T
    return torch.tensor(m_arr)

get_minput(1, xds).shape

torch.Size([120, 17])

In [10]:
for i in range(m_input.shape[0]):
    assert (m_input[i] == get_minput(i, xds)).all()

In [11]:
torch.tensor(xds[TARGET].to_dataframe().reset_index().to_numpy()).shape

torch.Size([557, 2])

In [12]:
import src.models.dataloader as old
from src.models.dataloader_jupyter import JupyterDataset, transform_data, create_train_val_idx
val_rate = .2
device = 'cpu'

dataset_path = join(ROOT_DIR, 'data', 'processed', FOLDER, 'train_enriched.nc')
xdf_train = xr.open_dataset(dataset_path, engine='scipy')

train_idx, val_idx = create_train_val_idx(xdf_train, val_rate)
train_array = xdf_train.sel(ts_obs=train_idx)
items = transform_data(train_array)
items['device'] = device
train_dataset = JupyterDataset(**items)

In [13]:
def get_item(i, xds: xr.Dataset):
    item = {}
    item['g_input'] = torch.tensor(xds.isel(ts_obs=i // 100)[G_COLUMNS].to_array().values).to(dtype=torch.float32)
    item['s_input'] = torch.tensor(xds.isel(ts_obs=i // 100, ts_aug=i % 100)[S_COLUMNS].to_array().values.T).to(dtype=torch.float32)
    item['m_input'] = torch.tensor(get_minput(i // 100, xds)).to(dtype=torch.float32)
    item['observation'] = torch.tensor(xds['ts_obs'].values[i//100]).to(dtype=torch.float32)
    item['target'] = torch.tensor(xds[TARGET].values[i//100]).to(dtype=torch.float32)
    return item

In [14]:
test_array = train_array.sortby(['ts_obs', 'ts_aug'])
train_shape = test_array['ts_id'].shape
test_array['ts_id'].values = np.arange(np.prod(train_shape)).reshape(train_shape)
old_dataset = old.CustomDataset(test_array)

In [15]:
for i in range(len(train_dataset)):
    item = train_dataset[i]
    item_ref = get_item(i, test_array)
    for key in item.keys():
        assert (item[key] == item_ref[key]).all()
    if i > 20:
        break

In [16]:
len(train_dataset) == len(old_dataset)

True

In [17]:
for i in range(len(train_dataset)):
    item = train_dataset[i]
    item_ref = old_dataset[i]
    for key in item.keys():
        assert (item[key] == item_ref[key]).all()
    if i > 20:
        break

In [22]:
from torch.utils.data import DataLoader, Dataset


dataloader = DataLoader(train_dataset,
        batch_size=16,
        drop_last=True,
        shuffle=False)

dataloader_ref = DataLoader(old_dataset,
        batch_size=16,
        drop_last=True,
        shuffle=False)

In [23]:
for i, (data, old_data) in enumerate(zip(dataloader, dataloader_ref)):
    for key in data.keys():
        # print(i, key)
        if i == 4 and key == 'target':
            print(data[key])
            print(old_data[key])
            assert (data[key] == old_data[key]).all()

0 observation
0 s_input
0 m_input
0 g_input
0 target
1 observation
1 s_input
1 m_input
1 g_input
1 target
2 observation
2 s_input
2 m_input
2 g_input
2 target
3 observation
3 s_input
3 m_input
3 g_input
3 target
4 observation
4 s_input
4 m_input
4 g_input
4 target
tensor([0.1071, 0.1071, 0.1071, 0.1071, 0.1071, 0.1071, 0.1071, 0.1071, 0.1071,
        0.1071, 0.1071, 0.1071, 0.1071, 0.1071, 0.1071, 0.1071])
tensor([[0.1071],
        [0.1071],
        [0.1071],
        [0.1071],
        [0.1071],
        [0.1071],
        [0.1071],
        [0.1071],
        [0.1071],
        [0.1071],
        [0.1071],
        [0.1071],
        [0.1071],
        [0.1071],
        [0.1071],
        [0.1071]])
5 observation
5 s_input
5 m_input
5 g_input
5 target
6 observation
6 s_input
6 m_input
6 g_input
6 target
7 observation
7 s_input
7 m_input
7 g_input
7 target
8 observation
8 s_input
8 m_input
8 g_input
8 target
9 observation
9 s_input
9 m_input
9 g_input
9 target
10 observation
10 s_input
10 m_input

KeyboardInterrupt: 