In [1]:
import os
import pickle
from datetime import datetime, timedelta
from pathlib import Path

import numpy as np
import pandas as pd
from sklearn.preprocessing import (LabelEncoder, MinMaxScaler,
                                   QuantileTransformer, StandardScaler)
from tqdm.auto import tqdm


# Helper functions: data preparation

In [2]:
def sort_loc_time(data_path: str, output_path: str) -> None:
    """Reorder a parquet file's rows by (location, time) and write the result.

    Args:
        data_path: Parquet file to read.
        output_path: Parquet file to write; its parent directories are
            created if they do not exist yet.
    """
    ordered = pd.read_parquet(data_path).sort_values(by=['location', 'time'])
    Path(output_path).parent.mkdir(exist_ok=True, parents=True)
    ordered.to_parquet(output_path)


# Processing functions: build TFT train/validation/test windows

In [9]:
# Default chronological split, as half-open [start, end) intervals on `time`.
DEFAULT_SPLITS = {
    'train': (datetime(1982, 1, 1), datetime(2012, 1, 1)),
    'validation': (datetime(2012, 1, 1), datetime(2017, 1, 1)),
    'test': (datetime(2017, 1, 1), datetime(2022, 1, 1)),
}


def _fit_scalers(fit_df, vocab_df, feature_cols, categorical_attrs):
    """Fit one scaler (numeric) or LabelEncoder (categorical) per feature column.

    Numeric scalers are fit on ``fit_df`` (the training period) so that
    validation/test statistics do not leak into the normalisation.
    Label encoders are fit on ``vocab_df`` (all rows) so categories that first
    appear after the training period can still be encoded.

    Returns:
        (scalers, categorical_cardinalities), where ``scalers`` is
        ``{'numeric': {col: scaler}, 'categorical': {col: encoder}}``.
    """
    scalers = {'numeric': {}, 'categorical': {}}
    categorical_cardinalities = {}
    for col in tqdm(feature_cols, desc="fit_scalers"):
        if col in categorical_attrs:
            scalers['categorical'][col] = LabelEncoder().fit(vocab_df[col].values)
            categorical_cardinalities[col] = vocab_df[col].nunique()
            continue
        if col == 'sif_clear_inst':
            scaler = StandardScaler()
        elif col == 'day_of_year':
            scaler = MinMaxScaler()
        else:
            # Seeded so the transformer's internal subsampling is reproducible.
            scaler = QuantileTransformer(n_quantiles=256, random_state=0)
        scalers['numeric'][col] = scaler.fit(
            fit_df[col].to_numpy(dtype=float).reshape(-1, 1))
    return scalers, categorical_cardinalities


def _apply_scalers(df, scalers, feature_cols):
    """Return a copy of ``df`` with every feature column scaled or label-encoded."""
    out = df.copy()
    for col in tqdm(feature_cols, desc="transform"):
        if col in scalers['categorical']:
            encoder = scalers['categorical'][col]
            out[col] = encoder.transform(out[col].to_numpy()).astype(np.int32)
        else:
            values = out[col].to_numpy(dtype=float).reshape(-1, 1)
            out[col] = scalers['numeric'][col].transform(values)[:, 0].astype(np.float32)
    return out


def _build_windows(subset, feature_map, hist_len, fut_len, target_col='sif_clear_inst'):
    """Cut ``subset`` into non-overlapping windows of ``hist_len + fut_len`` rows.

    Window starts are taken every ``hist_len + fut_len`` rows over the whole
    subset (not per location). A window is kept only if its first and last row
    share a location, which relies on rows being ordered by (location, time).

    NOTE(review): rows removed by ``dropna`` are not re-filled, so a window of
    N rows may span more than N time steps. Add a contiguity check if that matters.

    Returns:
        Dict of arrays with leading dimension ``n_windows``:
        ``time_index`` (time of the last history row), static features
        ``(n_windows, n_static)``, historical features
        ``(n_windows, hist_len, n_cols)``, future features
        ``(n_windows, fut_len, n_cols)``, ``target`` ``(n_windows, fut_len)``
        and ``id`` ``(n_windows, fut_len)`` strings ``"<location>_<time>"``.
    """
    window = hist_len + fut_len
    locations = subset['location'].to_numpy()
    starts = np.arange(0, len(subset) - window + 1, window)
    starts = starts[locations[starts] == locations[starts + window - 1]]

    hist_idx = starts[:, None] + np.arange(hist_len)          # (n_windows, hist_len)
    fut_idx = starts[:, None] + np.arange(hist_len, window)   # (n_windows, fut_len)

    def columns(key, dtype):
        # (n_rows, n_cols) array of the feature group; n_cols may be 0.
        return subset[feature_map[key]].to_numpy(dtype=dtype)

    ids = (subset['location'].astype(str) + '_' + subset['time'].astype(str)).to_numpy(dtype=str)

    return {
        'time_index': subset['time'].to_numpy()[starts + hist_len - 1],
        'static_feats_numeric': columns('static_feats_numeric', np.float32)[starts],
        'static_feats_categorical': columns('static_feats_categorical', np.int32)[starts],
        'historical_ts_numeric': columns('historical_ts_numeric', np.float32)[hist_idx],
        'historical_ts_categorical': columns('historical_ts_categorical', np.int32)[hist_idx],
        'future_ts_numeric': columns('future_ts_numeric', np.float32)[fut_idx],
        'future_ts_categorical': columns('future_ts_categorical', np.int32)[fut_idx],
        'target': subset[target_col].to_numpy(dtype=np.float32)[fut_idx],
        'id': ids[fut_idx],
    }


def tft_process(path: str, hist_len: int, fut_len: int, output_path: str,
                splits=None) -> None:
    """Create TFT training data from a parquet file and pickle it.

    Steps:
      1. Load the required columns, drop rows with any missing value, parse
         ``time`` and order rows by (location, time).
      2. Fit one scaler per numeric feature on the training period only, and
         one LabelEncoder per categorical feature on all rows; apply them.
      3. Split rows by time according to ``splits`` and cut each split into
         non-overlapping ``hist_len + fut_len`` row windows.
      4. Pickle ``{'data_sets', 'feature_map', 'scalers',
         'categorical_cardinalities'}`` to ``<output_path>/<input stem>.pkl``.

    Args:
        path: Input parquet file.
        hist_len: Rows of history per window (>= 1).
        fut_len: Rows of forecast horizon per window (>= 1).
        output_path: Output directory; created if missing.
        splits: Optional mapping ``name -> (start, end)`` of half-open time
            intervals. Must contain ``'train'``, which is used to fit the
            numeric scalers. Defaults to ``DEFAULT_SPLITS``.

    Raises:
        ValueError: if ``hist_len`` or ``fut_len`` is < 1, or no rows fall in
            the training period.
    """
    if hist_len < 1 or fut_len < 1:
        raise ValueError(f"hist_len and fut_len must be >= 1, got {hist_len}, {fut_len}")
    splits = DEFAULT_SPLITS if splits is None else splits
    output_filename = Path(path).with_suffix('.pkl').name

    df = (
        pd.read_parquet(path)
        [['time', 'location', 'latitude', 'longitude', 'tmin', 'tmax',
          'precipitation', 'radiation', 'photoperiod', 'swvl1', 'sif_clear_inst']]
        .dropna()
        .assign(time=lambda d: pd.to_datetime(d['time']))
        # Windowing assumes rows are grouped by location and ordered in time.
        .sort_values(['location', 'time'])
    )

    meta_attrs = ['time', 'location']
    static_attrs = ['latitude', 'longitude']
    categorical_attrs = []
    # Not yet wired into feature_map: all non-static columns are treated as historical inputs.
    known_attrs = ['tmin', 'tmax', 'radiation', 'precipitation', 'swvl1', 'photoperiod']

    feature_cols = [c for c in df.columns if c not in meta_attrs]
    feature_map = {
        'static_feats_numeric': [c for c in feature_cols if c in static_attrs and c not in categorical_attrs],
        'static_feats_categorical': [c for c in feature_cols if c in static_attrs and c in categorical_attrs],
        'historical_ts_numeric': [c for c in feature_cols if c not in static_attrs and c not in categorical_attrs],
        'historical_ts_categorical': [c for c in feature_cols if c not in static_attrs and c in categorical_attrs],
        'future_ts_numeric': [],
        'future_ts_categorical': [],
    }

    train_start, train_end = splits['train']
    train_mask = (df['time'] >= train_start) & (df['time'] < train_end)
    if not train_mask.any():
        raise ValueError("no rows fall inside the training period; cannot fit scalers")
    scalers, categorical_cardinalities = _fit_scalers(
        df[train_mask], df, feature_cols, categorical_attrs)
    df = _apply_scalers(df, scalers, feature_cols)

    data_sets = {
        name: _build_windows(df[(df['time'] >= start) & (df['time'] < end)],
                             feature_map, hist_len, fut_len)
        for name, (start, end) in splits.items()
    }

    output_dir = Path(output_path)
    output_dir.mkdir(parents=True, exist_ok=True)
    payload = {'data_sets': data_sets,
               'feature_map': feature_map,
               'scalers': scalers,
               'categorical_cardinalities': categorical_cardinalities}
    with open(output_dir / output_filename, 'wb') as f:
        pickle.dump(payload, f, pickle.HIGHEST_PROTOCOL)


In [None]:
# Run configuration. The data root defaults to the lab cluster path but can be
# overridden with the TFT_DATA_DIR environment variable, so the notebook is not
# tied to one user's absolute path.
DATA_DIR = Path(os.environ.get('TFT_DATA_DIR', '/burg/glab/users/al4385/data'))

path = str(DATA_DIR / 'CSIFMETEO' / 'sorted_merged_BDT_1982_2021.parquet')
hist_len = 60  # rows of history per window
fut_len = 10   # rows of forecast horizon per window
output_path = str(DATA_DIR / 'CLM_ml_phenology')

tft_process(path, hist_len, fut_len, output_path)