# Custom DataLoader

In [1]:
import datetime
import math
import pickle
from pathlib import Path
from typing import List, Tuple, Dict, Union, Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pytorch_lightning as pl
import torch
import xarray as xr
from torch.utils.data import Dataset
from tqdm import tqdm

data_dir = Path("/cats/datastore/data")

In [2]:
# Load the pickled dataset from `data_dir`.
# NOTE(review): pickle.load can execute arbitrary code; only load files from a trusted source.
# The context manager closes the file handle, which the previous one-liner left open.
with (data_dir / "kenya.pkl").open("rb") as pickle_file:
    ds = pickle.load(pickle_file)
ds

In [3]:
def compute_scaler(df: pd.DataFrame) -> Dict[str, pd.Series]:
    """Collect the column-wise mean and standard deviation of ``df``.

    Args:
        df: frame whose columns are summarised with ``df.mean()`` / ``df.std()``.

    Returns:
        Dict with two entries, ``"attribute_means"`` and ``"attribute_stds"``,
        each holding the corresponding per-column statistic.
    """
    return {
        "attribute_means": df.mean(),
        "attribute_stds": df.std(),
    }

In [4]:
# Train / test period boundaries, kept as ISO date strings.
train_period_start = "2001-01-01"
train_period_end = "2015-12-01"
# NOTE(review): the test period starts and ends on the same date — confirm this is intended.
test_period_start = "2016-01-01"
test_period_end = "2016-01-01"
# df.loc[test_period_start: test_period_end]

In [5]:
# # pd.infer_freq([pd.to_datetime(d) for d in df.index])
# df.index = [pd.Timestamp(dt) for dt in df.index]
# dt = pd.Timestamp(train_period_start)
# # df.index[df.index.get_loc(dt, method='nearest')]
# # df.index.get_loc(dt, method='nearest')
# # df.loc[dt]
# null_pixels = _get_null_pixels(df, "boku_VCI")
# df = df.loc[~np.isin(df["pixel"], null_pixels)]

In [6]:
# # df.index[df.index.get_loc(dt, method='nearest')]
# dt = pd.to_datetime("2001-12-31")
# # df.index.get_loc(dt, method='nearest')
# df.loc[dt]

In [7]:
from torch.utils.data import random_split, DataLoader

In [8]:
def _get_null_pixels(df: pd.DataFrame, target_var: str) -> pd.Series:
    return df["pixel"].loc[df[target_var].isnull()].unique()


In [9]:
def _stack_xarray(ds:xr.Dataset, spatial_coords: List[str]) -> Tuple[xr.Dataset, xr.DataArray]:
    """Stack the spatial coordinates of ``ds`` into a single ``"sample"`` dimension.

    Args:
        ds: dataset to stack.
        spatial_coords: names of the dimensions combined into ``"sample"``.

    Returns:
        ``(stacked, samples)``: the stacked dataset whose ``"sample"`` coordinate has
        been overwritten with string labels, and a DataArray holding those labels.
    """
    # stack values
    stacked = ds.stack(sample=spatial_coords)
    # captured BEFORE the coordinate is overwritten below
    samples = stacked.sample
    # label = "<first element>_<last element>" of each stacked sample value
    pixel_strs = [f"{ll[0]}_{ll[-1]}" for ll in samples.values]
    # NOTE(review): overwriting a stacked (multi-index) coordinate may be rejected by
    # some xarray versions — confirm against the installed version.
    stacked["sample"] = pixel_strs

    # the intermediate dataset is only used to read back its `sample` coordinate
    samples = samples.to_dataset(name="pixel")
    samples = xr.DataArray(pixel_strs, dims=["sample"], coords={"sample": samples.sample})
    return stacked, samples

In [12]:
# Stack (lat, lon) into the single "sample" dimension and keep the label array.
stacked_ds, samples = _stack_xarray(ds, spatial_coords=["lat", "lon"])
stacked_ds

In [13]:
# Inspect the label DataArray returned by _stack_xarray.
samples

In [15]:
# Drop every sample whose target variable is null at all timesteps.
target_var = "boku_VCI"
# fraction of null timesteps per sample == 1  ->  the sample has no valid target at all
isnull = stacked_ds[target_var].isnull().mean(dim=["time"]) == 1
isnull = isnull.to_dataframe().reset_index()
# first coordinate that is not "time" (expected to be "sample" after stacking — TODO confirm)
pixel_var = [c for c in stacked_ds.coords if c!="time"][0]
null_pixels = isnull[pixel_var].loc[isnull[target_var].astype(bool)]


# stacked_ds
# NOTE: rebinds `stacked_ds`; re-running this cell filters the already-filtered dataset.
stacked_ds = stacked_ds.sel(sample = ~np.isin(stacked_ds.sample, null_pixels.values))

In [16]:
# Inspect the stacked dataset after the all-null samples were removed.
stacked_ds

In [None]:
## VALIDATE SAMPLES
# Marks entries of `flag` as 0 when a sample lacks enough history or contains NaNs.
# NOTE(review): this cell has no execution count and reads names that no visible cell
# defines (prange, n_samples, frequency_maps, seq_length, predict_last_n, flag, j).
# NOTE(review): the loop variable is `i`, yet `flag` and `frequency_maps[i]` are indexed
# with `j`, which is never assigned here — an enclosing/inner loop appears to be missing.
for i in prange(n_samples):  # iterate through lowest-frequency samples
    # find the last sample in this frequency that belongs to the lowest-frequency step j
    last_sample_of_freq = frequency_maps[i][j]
    if last_sample_of_freq < seq_length[i] - 1:
        flag[j] = 0  # too early for this frequency's seq_length (not enough history)
        continue

    # any NaN in the dynamic inputs makes the sample invalid
    if x_d is not None:
        # window of seq_length[i] steps ending at (and including) last_sample_of_freq
        _x_d = x_d[i][last_sample_of_freq - seq_length[i] + 1:last_sample_of_freq + 1]
        if np.any(np.isnan(_x_d)):
            flag[j] = 0
            continue

    # all-NaN in the targets makes the sample invalid
    if y is not None:
        _y = y[i][last_sample_of_freq - predict_last_n[i] + 1:last_sample_of_freq + 1]
        # an empty target window is not treated as invalid
        if np.prod(np.array(_y.shape)) > 0 and np.all(np.isnan(_y)):
            flag[j] = 0
            continue

    # any NaN in the static features makes the sample invalid
    if x_s is not None:
        _x_s = x_s[i][last_sample_of_freq]
        if np.any(np.isnan(_x_s)):
            flag[j] = 0

In [33]:
# Labels of every remaining sample plus the variable lists used by the extraction loop.
sample_coordinates = stacked_ds["sample"].values.tolist()
dynamic_inputs = ["precip"]
# NOTE(review): same name as the target variable set earlier ("boku_VCI") and no shift is
# applied in this cell — confirm the target is not leaking into the inputs.
lagged_variables = ["boku_VCI"]

def _check_no_missing_times_in_time_series(df):
    min_timestamp = df_native.index.min()
    max_timestamp = df_native.index.max()
    inf_freq = pd.infer_freq(df.index)
    assert list(pd.date_range(start=min_timestamp, end=max_timestamp, freq=inf_freq).difference(df.index)) == [], f"Missing data"


# store data
# create a lookup object from INDEX to PIXEL str
lookup = []
# per-sample arrays keyed by the sample label; x_s is created but never filled in this cell
x_d, x_s, y = {}, {}, {}
for sample in tqdm(sample_coordinates):    
    # get raw dataframe
    # NOTE(review): keep_cols does not depend on `sample` and could be built once before the loop.
    keep_cols = [target_var] + dynamic_inputs + lagged_variables
    df_native = stacked_ds[keep_cols].sel({"sample": sample}).to_dataframe()
    # ensure that sorted by time
    df_native = df_native.sort_index()
    # ensure that every expected timestep is in the dataframe
    _check_no_missing_times_in_time_series(df_native)
    
    # target as a column vector: one row per timestep
    y[sample] = df_native[target_var].to_numpy().reshape(-1, 1)
    # dynamic inputs followed by the lagged variables, one column each
    x_d[sample] = df_native[dynamic_inputs + lagged_variables].to_numpy()

#     valid_samples = np.argwhere(flag == 1)
#     for f in valid_samples:
#         # store pointer to pixel and the sample's index
    # NOTE(review): `(sample)` is the bare label, not a tuple, so each lookup entry is a
    # string; code that unpacks `pixel, indices = lookup_table[i]` will fail.
    lookup.append((sample))
    
# 
# integer position -> sample label
lookup_table = {i: elem for i, elem in enumerate(lookup)}
num_samples = len(lookup_table)

100%|██████████| 1433/1433 [00:09<00:00, 156.87it/s]


In [32]:
# Shape of the input array stored for one sample label: (timesteps, input columns).
x_d["6.0_33.75"].shape

(205, 2)

In [25]:
# NOTE(review): this rebinds `x_s` (a dict in the extraction cell) to False.
attributes = x_s = False

def getitem(item: int) -> Dict[str, torch.Tensor]:
    """Draft of a per-item sample builder keyed by position in ``lookup_table``.

    NOTE(review): this function cannot run as written:
      * ``lookup_table[item]`` holds a plain sample label, so the two-name unpacking
        raises the ValueError recorded when this function is called;
      * ``seq_len`` is both the loop target and the ``zip`` argument, which makes it a
        local that is read before assignment;
      * ``self``, ``basin_id``, ``per_basin_target_stds`` and ``one_hot`` are not
        defined in any visible cell.
    """
    basin, indices = lookup_table[item]

    sample = {}
    for seq_len, idx in zip(seq_len, indices):
        # slice until idx + 1 because slice-end is excluding
        sample['x_d'] = x_d[basin][idx - seq_len + 1:idx + 1]
        sample['y'] = y[basin][idx - seq_len + 1:idx + 1]

        # check for static inputs
        static_inputs = []
        if attributes:
            static_inputs.append(self.attributes[basin_id])
        if x_s:
            static_inputs.append(self.x_s[basin][idx])
        if static_inputs:
            # concatenate all static sources along the last axis
            sample['x_s'] = torch.cat(static_inputs, dim=-1)

    if per_basin_target_stds:
        sample['per_basin_target_stds'] = self.per_basin_target_stds[basin]
    if one_hot is not None:
        # zero_() resets the shared tensor in place before the single 1 is set
        x_one_hot = self.one_hot.zero_()
        x_one_hot[self.id_to_int[basin_id]] = 1
        sample['x_one_hot'] = x_one_hot

    return sample

In [26]:
# Fails with the ValueError shown below: lookup_table values are plain labels, not 2-tuples.
getitem(0)

ValueError: too many values to unpack (expected 2)

In [None]:
# First and last timestamps of the most recently extracted per-sample frame.
min_timestamp = df_native.index.min()
max_timestamp = df_native.index.max()

# Integer position of the timestamp nearest to `min_timestamp`, offset by 3.
# Index.get_loc(..., method="nearest") is no longer accepted by pandas 2.x;
# get_indexer is the supported way to do a nearest-match lookup.
df_native.index.get_indexer([min_timestamp], method="nearest")[0] + 3

# df.index[df.index.get_loc(min_timestamp, method='nearest')]

In [None]:
    
# Re-run the gap check on the last per-sample frame left over from the extraction loop.
_check_no_missing_times_in_time_series(df_native)

In [None]:
# Convert `seq_len` steps of the inferred index frequency into an approximate day count.
# NOTE(review): `seq_len` is only assigned in a later cell, so this cell fails on a fresh kernel.
# NOTE(review): the comparisons are exact string matches; confirm that pd.infer_freq
# actually returns "M" / "Y" / "W" for this index and not an anchored variant.
inf_freq = pd.infer_freq(df_native.index)
if inf_freq == "M":
    freq = "D"
    value = 30 * seq_len  # a month is approximated as 30 days
elif inf_freq == "Y":
    freq = "D"
    value = 366 * seq_len  # a year is approximated as 366 days
elif inf_freq == "W":
    freq = "D"
    value = 7 * seq_len
else:
    freq = inf_freq
    value = seq_len
    
# NOTE(review): `freq` is never read; the unit is always "D", including in the else branch.
pd.Timedelta(value=value, unit="D")

In [None]:
# ?pd.Timedelta

In [None]:
# Inspect the variable subset used by the extraction loop (`keep_cols` is left over from it).
stacked_ds[keep_cols]

In [None]:
# Peek at the last per-sample frame produced by the extraction loop.
df_native.head()


In [None]:
# Sequence length in index steps.
# NOTE(review): an earlier cell already reads `seq_len`; consider moving this up.
seq_len = 3
df_time_min = df_native.index.min()

In [None]:
# NOTE(review): `df_nat` is not assigned in any visible cell — possibly a truncated `df_native`.
df_nat

In [None]:
# Select target, dynamic-input and lagged columns from the last per-sample frame.
df_native.loc[:, [target_var] + dynamic_inputs + lagged_variables]

In [None]:
# NOTE(review): `df` is first assigned in a later cell (create_pixel_dataframe); this
# fails on a fresh top-to-bottom run.
df

In [None]:
from torch.utils.data import random_split, DataLoader
import sys
from tqdm import tqdm


def _join_latlon_cols(df: pd.DataFrame, join_str: str = "_") -> pd.Series:
    points = df.astype({"lat": "str", "lon": "str"})
    return points['lat'] + join_str +  points['lon']

def _latlon_as_strings(ds, join_str: str = "_"):
    """Return ``"<lat><join_str><lon>"`` labels for the lat/lon entries of ``ds``.

    Args:
        ds: object supporting ``ds[["lat", "lon"]].to_dataframe()``
            (presumably an xarray Dataset — TODO confirm).
        join_str: separator placed between latitude and longitude.

    Returns:
        The Series produced by ``_join_latlon_cols`` for the flattened lat/lon table.
    """
    latlon_table = ds[["lat", "lon"]].to_dataframe()
    latlon_table = latlon_table.reset_index()
    return _join_latlon_cols(latlon_table, join_str=join_str)

def create_pixel_dataframe(ds: xr.Dataset, pixel_vars: List[str]) -> pd.DataFrame:
    """Flatten ``ds`` to a DataFrame and append a string ``"pixel"`` column.

    Args:
        ds: dataset to convert with ``to_dataframe()``.
        pixel_vars: index levels moved into ordinary columns via ``reset_index``.

    Returns:
        The flattened frame with an extra ``"pixel"`` column built from its
        ``lat`` / ``lon`` columns.
    """
    flat = ds.to_dataframe().reset_index(pixel_vars)
    return flat.assign(pixel=_join_latlon_cols(flat))

def _add_lagged_feature(ds: xr.Dataset, variable: str, lag: int = 1) -> Tuple[xr.Dataset, str]:
    """Add a copy of ``variable`` shifted along ``time`` by ``lag`` steps.

    Args:
        ds: dataset containing ``variable``; the new variable is written into it
            (the input object is modified, not copied).
        variable: name of the variable to shift.
        lag: number of steps to shift along the ``time`` dimension.

    Returns:
        ``(ds, var_name)`` where ``var_name`` is ``"<variable>_shift<lag>"``.
    """
    var_name = f"{variable}_shift{lag}"
    # Shift by `lag`; the previous hard-coded shift of 1 ignored the argument, so any
    # lag other than 1 produced a variable whose name did not match its contents.
    ds[var_name] = ds[variable].shift(time=lag)
    return ds, var_name

def create_pixel_id_encoding(pixels: Optional[List[str]] = None) -> Dict[str, int]:
    """Map each pixel label to an integer id, assigned in random order.

    The previous version read ``self.pixels`` inside a free function (a NameError)
    and discarded the mapping; the labels are now passed in and the mapping is
    returned.

    Args:
        pixels: pixel labels to encode.

    Returns:
        Dict from pixel label to an integer in ``range(len(pixels))``. The
        assignment follows ``np.random.permutation``, so it is only reproducible
        if the NumPy global seed is fixed beforehand.

    Raises:
        ValueError: if ``pixels`` is not provided.
    """
    if pixels is None:
        raise ValueError("`pixels` must be provided to build the pixel id encoding")
    id_to_int = {b: i for i, b in enumerate(np.random.permutation(pixels))}
    return id_to_int

    
def _get_null_pixels(df: pd.DataFrame, target_var: str) -> pd.Series:
    return df["pixel"].loc[df[target_var].isnull()].unique()


def get_all_null_pixels_from_dataset(ds, target_var: str) -> pd.Series:
    """Find the samples whose ``target_var`` is null at every timestep.

    Args:
        ds: dataset with exactly two coordinates, one of them ``"time"``.
        target_var: variable checked for nulls.

    Returns:
        Series holding the non-time coordinate values of the all-null samples.

    Raises:
        AssertionError: if ``ds`` does not have exactly two coordinates.
    """
    assert len(ds.coords) == 2, f"Expect coords = [time sample], got: {[c for c in ds.coords]}"

    # a null fraction of exactly 1 over time means the sample has no valid value
    null_fraction = ds[target_var].isnull().mean(dim=["time"])
    all_null_table = (null_fraction == 1).to_dataframe().reset_index()

    non_time_coords = [c for c in ds.coords if c!="time"]
    pixel_column = non_time_coords[0]
    all_null_mask = all_null_table[target_var].astype(bool)
    return all_null_table[pixel_column].loc[all_null_mask]


def _stack_xarray(ds: xr.Dataset, spatial_coords: List[str]) -> Tuple[xr.Dataset, xr.DataArray]:
    """Stack the spatial coordinates of ``ds`` into a single ``"sample"`` dimension.

    Args:
        ds: dataset to stack.
        spatial_coords: names of the dimensions combined into ``"sample"``.

    Returns:
        ``(stacked, samples)``: the stacked dataset whose ``"sample"`` coordinate has
        been overwritten with string labels, and a DataArray holding those labels.
    """
    # stack values
    stacked = ds.stack({"sample": spatial_coords})
    # captured BEFORE the coordinate is overwritten below
    samples = stacked["sample"]
    # label = "<first element>_<last element>" of each stacked sample value
    pixel_strs = [f"{ll[0]}_{ll[-1]}" for ll in samples.values]
    # NOTE(review): overwriting a stacked (multi-index) coordinate may be rejected by
    # some xarray versions — confirm against the installed version.
    stacked["sample"] = pixel_strs

    # the intermediate dataset is only used to read back its `sample` coordinate
    samples = samples.to_dataset(name="pixel")
    samples = xr.DataArray(pixel_strs, dims=["sample"], coords={"sample": samples.sample})
    return stacked, samples


class SpatioTemporalDataset(Dataset):
    """
    Per-pixel spatio-temporal dataset built from an ``xr.Dataset`` (work in progress).

    Tasks:
    -----
    • normalise input and output data
    • split into X, y variables (per-pixel)
    • split X into: static, dynamic data
    • convert xr.Dataset -> torch.Tensor
    • keep track of the metadata (time, lat, lon) [Ask Sharan]
    • create train, test, validation
    • Remove missing data
    • Create lagged target variable as input
    
    To implement:
    ------------
    prepare_data (how to download(), tokenize, etc…)
    setup (how to split, etc…)
    train_dataloader
    val_dataloader(s)
    test_dataloader(s)

    NOTE(review): this class subclasses ``torch.utils.data.Dataset`` but defines neither
    ``__len__`` nor ``__getitem__``; the hooks listed above look like a data-module
    interface — confirm which base class is intended.
    """    
    def __init__(
        self, 
        ds: xr.Dataset, 
        target_var: str,
        seq_len: int,
        forecast_horizon: int = 1,
        static_ds: Optional[xr.Dataset] = None,
        dynamic_inputs: List[str] = [],
        static_inputs: List[str] = [],
        batch_size: int = 32, 
        pixel_vars: List[str] = ["lat", "lon"],
        use_pixel_id_encoding: bool = False,
        lagged_variables: List[str] = [],
        lag: int = 1,
    ):
        """Store the configuration and validate it against ``ds``.

        Args:
            ds: dataset holding the target and input variables.
            target_var: name of the variable to predict.
            seq_len: input sequence length.
            forecast_horizon: stored as-is; not yet used by any method.
            static_ds: optional dataset of static features (stored, not yet used).
            dynamic_inputs: names of time-varying input variables in ``ds``.
            static_inputs: names of static input variables (stored, not yet used).
            batch_size: batch size passed to the DataLoaders.
            pixel_vars: spatial coordinate names stacked into one ``"sample"`` dimension.
            use_pixel_id_encoding: stored as-is; not yet used by any method.
            lagged_variables: variables for which a shifted copy is added as input.
            lag: number of time steps used for the shifted copies.

        Raises:
            AssertionError: if any of the checks in ``_init_checks`` fails.
        """
        super().__init__()
        self.batch_size = batch_size
        # Copy list arguments so the shared default lists are never aliased by an instance.
        self.pixel_vars = list(pixel_vars)
        self.lag = lag
        # Previously a bare `self.seq_len` expression, which raised AttributeError.
        self.seq_len = seq_len
        self.forecast_horizon = forecast_horizon
        
        # get X, y variables
        self.target_var = target_var
        self.dynamic_inputs = list(dynamic_inputs)
        self.lagged_variables = list(lagged_variables)

        # These arguments were accepted but silently dropped before.
        self.static_ds = static_ds
        self.static_inputs = list(static_inputs)
        self.use_pixel_id_encoding = use_pixel_id_encoding
        
        # run checks on input arguments
        self._init_checks(ds)
        
        self.ds = ds
        self.scaler: Dict[str, pd.Series] = {}
            
    def _create_lookup_table(self, ds: xr.Dataset):
        """Placeholder for the per-sample lookup table; not implemented yet."""
        samples_without_data = []
        sample_coordinates = ds["sample"].values.tolist()
        
        for sample in tqdm(sample_coordinates, file=sys.stdout):
            # store data
            # NOTE(review): these dicts are re-created on every iteration and never
            # filled — the body is still a stub.
            x_d, x_s, y = {}, {}, {}
            df_native = ds.sel({"sample": sample}).to_dataframe()
            
            pass
        
        pass
    
    def prepare_data(self):
        """Stack, filter and augment ``self.ds``.

        NOTE(review): the processed dataset is a local and is discarded on return;
        ``self.lagged_variables`` is replaced by the shifted names, so a second call
        would look for variables that do not exist in ``self.ds``.
        """
        # 1. stack data (time, lat, lon) -> (time, pixel)
        # Use the instance configuration rather than notebook globals of the same name.
        ds, _samples = _stack_xarray(self.ds, self.pixel_vars)
        null_pixels = get_all_null_pixels_from_dataset(ds, self.target_var)
        ds = ds.sel(sample = ~np.isin(ds.sample, null_pixels.values))

        # 2. create the lagged_variables
        _lagged_variables = []
        for var in self.lagged_variables:
            ds, var_name = _add_lagged_feature(ds, var, lag=self.lag)
            _lagged_variables.append(var_name)
        self.lagged_variables = _lagged_variables

        # 3. get only the used data variables
        keep_cols = ["sample"] + self.dynamic_inputs + [self.target_var] + self.lagged_variables
        ds = ds[keep_cols]
                
        # 4. normalise the data based on the scaler
        # 5. create the pixel id to integer OHE
        return

    def _get_data_variables(self) -> List[str]:
        """Return the target, lagged and dynamic variable names, in that order."""
        # The list was previously built and dropped, so the method returned None.
        return [self.target_var] + self.lagged_variables + self.dynamic_inputs
        
    def compute_static_scaler(self, df: pd.DataFrame):
        """Store the column-wise mean / std of ``df`` under the ``static_*`` scaler keys."""
        self.scaler["static_means"] = df.mean()
        self.scaler["static_stds"] = df.std()
    
    def _apply_static_normalisation(self, df: pd.DataFrame) -> pd.DataFrame:
        """Standardise ``df`` with the statistics stored by ``compute_static_scaler``.

        Raises:
            KeyError: if ``compute_static_scaler`` has not been called first.
        """
        # Read the keys that compute_static_scaler writes; the former 'attribute_*'
        # keys are never stored in self.scaler by this class.
        df = (df - self.scaler["static_means"]) / self.scaler["static_stds"]
        return df
    
    def _setup_normalization(self, ds: xr.Dataset):
        """Store the NaN-skipping std / mean of ``ds`` under the ``dynamic_*`` scaler keys."""
        self.scaler["dynamic_stds"] = ds.std(skipna=True)
        self.scaler["dynamic_means"] = ds.mean(skipna=True)

    def normalise_data(self):
        """Placeholder; not implemented yet."""
        pass

    def setup(self, stage: Optional[str] = None):
        """Assign train/val/test splits.

        NOTE(review): the body is still template code. ``mnist_full``, ``MNIST``,
        ``self.data_dir`` and ``self.transform`` are not defined anywhere visible, and
        the test branch assigns ``self.mnist_test`` but then reads ``self.test_data``.
        """
        # Assign Train/val split(s) for use in Dataloaders
        if stage == 'fit' or stage is None:
            # load in data
            # transform data (normalise with scaler)
            # transformations (log)
            self.train_data, self.val_data = random_split(mnist_full, [55000, 5000])
            self.dims = self.train_data[0][0].shape

        # Assign Test split(s) for use in Dataloaders
        if stage == 'test' or stage is None:
            self.mnist_test = MNIST(
                self.data_dir,
                train=False,
                download=True,
                transform=self.transform
            )
            self.dims = getattr(self, 'dims', self.test_data[0][0].shape)
    
    def train_dataloader(self):
        """DataLoader over ``self.train_data`` (set by ``setup``)."""
        return DataLoader(self.train_data, batch_size=self.batch_size)

    def val_dataloader(self):
        """DataLoader over ``self.val_data`` (set by ``setup``)."""
        return DataLoader(self.val_data, batch_size=self.batch_size)

    def test_dataloader(self):
        """DataLoader over ``self.test_data``."""
        return DataLoader(self.test_data, batch_size=self.batch_size)

    def _init_checks(self, ds: xr.Dataset) -> None:
        """Validate the constructor arguments against ``ds``.

        Raises:
            AssertionError: if a configured coordinate or variable is missing from
                ``ds``, if no input features were requested, or if ``ds.time`` is not
                of dtype ``datetime64[ns]``.
        """
        # check all init arguments
        assert all(np.isin(self.pixel_vars, [l for l in ds.coords])), f"Expect to find {self.pixel_vars} in {ds.coords}"
        # Check the instance's own configuration; the bare names used before resolved to
        # notebook globals and could validate a different setup than the one passed in.
        assert self.target_var in ds.data_vars, f"Expect to find {self.target_var} in {ds.data_vars}"
        if self.lagged_variables != []:
            assert all(np.isin(self.lagged_variables, [v for v in ds.data_vars])), f"Expect to find {self.lagged_variables} in {ds.data_vars}"
        if self.dynamic_inputs != []:
            assert all(np.isin(self.dynamic_inputs, [v for v in ds.data_vars])), f"Expect to find {self.dynamic_inputs} in {ds.data_vars}"

        assert self.dynamic_inputs or self.lagged_variables, "Need to include some input features in `lagged_variables` or `dynamic_inputs`"
        
        assert ds.time.dtype == np.dtype('<M8[ns]'), "Time should be of datetime type"
            
            
# Build the dataset with one lagged target feature and one dynamic input, then prepare it.
# NOTE(review): `_ds` is not assigned in any visible cell — the loaded `ds` may be intended.
d = SpatioTemporalDataset(_ds, seq_len=3, target_var="boku_VCI", lagged_variables=["boku_VCI"], dynamic_inputs=["precip"])
d.prepare_data()

In [None]:
# DataFrame-based variant: flatten the dataset, add lagged features, keep the used columns.
target_var = "boku_VCI"
pixel_vars: List[str] = ["lat", "lon"]
in_lagged_variables = ["boku_VCI"]
dynamic_inputs = []

df = create_pixel_dataframe(ds, pixel_vars)

lagged_variables = []
for var in in_lagged_variables:
    # NOTE(review): _add_lagged_feature calls `.shift(time=...)`, an xarray-style keyword,
    # but receives a pandas DataFrame here — confirm this call works as intended.
    df, var_name = _add_lagged_feature(df, var, lag=1)
    lagged_variables.append(var_name)

keep_cols = ["pixel"] + dynamic_inputs + [target_var] + lagged_variables
df = df[keep_cols]
df.head()

In [None]:
def encode_doys(
    doys: Union[List[str], List[int], List[datetime.datetime], str, int, datetime.datetime],
    start_doy: int = 1,
    end_doy: int = 366
) -> Tuple[List[float], List[float]]:
    """Encode day-of-year values as points on the unit circle.

    Each entry may be a day-of-year (``1..366``), a ``YYYYMMDD`` date given as an
    int or numeric string, or a ``datetime.datetime``.

    Args:
        doys: a single value or a list of values to encode.
        start_doy: day-of-year mapped to angle 0.
        end_doy: last day-of-year of the cycle; the period is
            ``end_doy - start_doy + 1``.

    Returns:
        ``(doys_sin, doys_cos)``, two lists with one entry per input value.

    Raises:
        ValueError: if a day-of-year lies outside ``1..366``, or a string / date
            value cannot be parsed.
    """
    # The annotation allows scalars, but iterating a bare int raised TypeError and
    # iterating a str walked it character by character; wrap scalars in a list.
    if isinstance(doys, (str, int, np.integer, datetime.datetime)):
        doys = [doys]

    period = end_doy - start_doy + 1
    doys_sin = []
    doys_cos = []
    for doy in doys:
        if isinstance(doy, str):
            doy = int(doy)
        elif isinstance(doy, datetime.datetime):
            doy = int(doy.strftime('%j'))

        if doy > 9999999:
            # eight-digit values are interpreted as YYYYMMDD dates
            doy = int(datetime.datetime.strptime(
                str(doy), '%Y%m%d').strftime('%j'))
        elif doy > 366 or doy < 1:
            raise ValueError(f'Invalid date "{doy}"')

        angle = 2 * math.pi * (doy - start_doy) / period
        doys_sin.append(np.sin(angle))
        doys_cos.append(np.cos(angle))

    return doys_sin, doys_cos



# try xbatcher

In [None]:
import xbatcher

# Generate batches of 10 steps along `time` from the "boku_VCI" variable.
bgen = xbatcher.BatchGenerator(ds["boku_VCI"], input_dims={'time': 10})
# Exhaust the generator; afterwards `batch` holds the last batch produced.
for batch in bgen:
    pass
# batch["boku_VCI"]
# batch["boku_VCI"].to_dataframe()

In [None]:
# Dataset?