## DeepSensor key classes

* DataProcessor: Maps xarray and pandas data from their native units to a normalised and standardised format (and vice versa).

* TaskLoader: Slices and samples normalised xarray and pandas data to generate Task objects for training and inference.

* Task: Container for context and target data. Subclass of dict with additional methods for processing and summarising the data.

* DeepSensorModel: Base class for DeepSensor models, implementing a high-level .predict method for predicting straight to xarray/pandas in original coordinates and units.

* ConvNP: Convolutional neural process (ConvNP) model class (subclass of DeepSensorModel). Uses the neuralprocesses library. This is currently the only model provided by DeepSensor.

* Trainer: Class for training on Task objects using backpropagation and the Adam optimiser.

* AcquisitionFunction: Base class for active learning acquisition functions.

* GreedyAlgorithm: Greedy search algorithm for active learning.

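The GreedyAlgorithm entry above describes greedy search for active learning. The sketch below illustrates only the general idea. It is not DeepSensor's implementation, which scores candidates with a trained model through an AcquisitionFunction. New sensor sites are picked one at a time, and each pick is the candidate that maximises a toy acquisition score. The distance-to-nearest-sensor score and the helper names greedy_select and distance_to_nearest are hypothetical.

```python
import numpy as np


def distance_to_nearest(candidate, selected):
    """Toy acquisition score: distance from candidate to the closest chosen site.

    Hypothetical stand-in for a real acquisition function.
    """
    if not selected:
        return 0.0
    return min(float(np.linalg.norm(candidate - s)) for s in selected)


def greedy_select(candidates, existing, n_new, acquisition):
    """Greedily add n_new sites, re-scoring the remaining candidates after each pick."""
    selected = [np.asarray(p, dtype=float) for p in existing]
    remaining = [np.asarray(p, dtype=float) for p in candidates]
    chosen = []
    for _ in range(min(n_new, len(remaining))):
        scores = [acquisition(c, selected) for c in remaining]
        best = int(np.argmax(scores))  # index of the highest-scoring candidate
        point = remaining.pop(best)
        selected.append(point)  # later picks are conditioned on this one
        chosen.append(point)
    return np.array(chosen)


# One existing sensor at the origin; four candidate sites in (x1, x2) space
candidates = [(0.0, 0.0), (0.5, 0.5), (1.0, 1.0), (1.0, 0.0)]
chosen = greedy_select(candidates, existing=[(0.0, 0.0)], n_new=2,
                       acquisition=distance_to_nearest)
print(chosen)
```

Each pick is added to the selected set before the next round of scoring, so the second site is chosen away from both the existing sensor and the first new site. Re-scoring after every addition is what makes this greedy rather than a one-shot top-k selection.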
In addition, a deepsensor.plot module provides useful plotting functions for visualising:

* Task context and target sets

* DeepSensorModel predictions

* ConvNP internals (encoding and feature maps)

* GreedyAlgorithm active learning outputs

In [1]:
%pip install git+https://github.com/scott-hosking/get-station-data.git

Collecting git+https://github.com/scott-hosking/get-station-data.git
  Cloning https://github.com/scott-hosking/get-station-data.git to /tmp/pip-req-build-lap9fuds
  Running command git clone --filter=blob:none --quiet https://github.com/scott-hosking/get-station-data.git /tmp/pip-req-build-lap9fuds
  Resolved https://github.com/scott-hosking/get-station-data.git to commit f7eaa50823ff75117b577b370e67d194c3d342a1
  Preparing metadata (setup.py) ... done
Note: you may need to restart the kernel to use updated packages.


In [1]:
from deepsensor.data import DataProcessor
from deepsensor.data.sources import get_ghcnd_station_data, get_era5_reanalysis_data, get_earthenv_auxiliary_data, get_gldas_land_mask
import os

In [2]:
import logging
logging.captureWarnings(True)

import xarray as xr
import pandas as pd

# Using the same settings allows us to use pre-downloaded cached data
data_range = ("2015-06-25", "2015-06-30")
extent = "europe"
station_var_IDs = ["TAVG", "PRCP"]
era5_var_IDs = ["2m_temperature", "10m_u_component_of_wind", "10m_v_component_of_wind"]
aux_var_IDs = ["elevation", "tpi"]
cache_dir = "../../.datacache"

In [3]:
station_raw_df = get_ghcnd_station_data(station_var_IDs, extent, date_range=data_range, cache=True, cache_dir=cache_dir)
era5_raw_ds = get_era5_reanalysis_data(era5_var_IDs, extent, date_range=data_range, cache=True, cache_dir=cache_dir)
aux_raw_ds = get_earthenv_auxiliary_data(aux_var_IDs, extent, "1KM", cache=True, cache_dir=cache_dir)
land_mask_raw_ds = get_gldas_land_mask(extent, cache=True, cache_dir=cache_dir)

In [4]:
era5_raw_ds

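### Initialising a DataProcessor
* By default the DataProcessor expects coordinates named time, x1 and x2. Here x1_name and x2_name tell it that x1 and x2 correspond to the lat and lon coordinates of our data.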

In [5]:
data_processor = DataProcessor(x1_name="lat", x2_name="lon")
print(data_processor)

DataProcessor with normalisation params:
{'coords': {'time': {'name': 'time'},
            'x1': {'map': None, 'name': 'lat'},
            'x2': {'map': None, 'name': 'lon'}}}


### Normalising data with a DataProcessor

In [6]:
era5_ds = data_processor(era5_raw_ds)
era5_ds

In [7]:
station_raw_df

| time | lat | lon | station | PRCP | TAVG |
|---|---|---|---|---|---|
| 2015-06-25 | 35.017 | -1.450 | AGM00060531 | 0.0 | 23.0 |
| 2015-06-25 | 35.100 | -1.850 | AGE00147716 | 0.0 | 23.4 |
| 2015-06-25 | 35.117 | 36.750 | SYM00040030 | NaN | 25.4 |
| 2015-06-25 | 35.167 | 2.317 | AGM00060514 | 0.0 | 25.9 |
| 2015-06-25 | 35.200 | -0.617 | AGM00060520 | 0.0 | 24.9 |
| ... | ... | ... | ... | ... | ... |
| 2015-06-30 | 45.933 | 7.700 | ITM00016052 | NaN | 5.7 |
| 2015-06-30 | 38.367 | -0.500 | SPM00008359 | 0.0 | 27.6 |
| 2015-06-30 | 55.383 | 36.700 | RSM00027611 | 0.0 | 17.2 |
| 2015-06-30 | 59.080 | 17.860 | SWE00138750 | 0.0 | NaN |


In [15]:
station_df = data_processor(station_raw_df)
station_df

| time | x1 | x2 | station | PRCP | TAVG |
|---|---|---|---|---|---|
| 2015-06-25 | 0.000309 | 0.246364 | AGM00060531 | -0.278036 | 0.759107 |
| 2015-06-25 | 0.001818 | 0.239091 | AGE00147716 | -0.278036 | 0.836200 |
| 2015-06-25 | 0.002127 | 0.940909 | SYM00040030 | NaN | 1.221668 |
| 2015-06-25 | 0.003036 | 0.314855 | AGM00060514 | -0.278036 | 1.318035 |
| 2015-06-25 | 0.003636 | 0.261509 | AGM00060520 | -0.278036 | 1.125301 |
| ... | ... | ... | ... | ... | ... |
| 2015-06-30 | 0.198782 | 0.412727 | ITM00016052 | NaN | -2.575188 |
| 2015-06-30 | 0.061218 | 0.263636 | SPM00008359 | -0.278036 | 1.645682 |
| 2015-06-30 | 0.370600 | 0.940000 | RSM00027611 | -0.278036 | -0.358749 |
| 2015-06-30 | 0.437818 | 0.597455 | SWE00138750 | -0.278036 | NaN |

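The x1 and x2 values in station_df are consistent with a simple linear rescaling of latitude and longitude. The sketch below assumes the ranges (35.0, 90.0) for lat and (-15.0, 40.0) for lon, which the DataProcessor stores under coords in its configuration. normalise_coords is a hypothetical helper, not a DeepSensor function.

```python
import numpy as np

# Assumed coordinate ranges, as recorded under 'coords' in the DataProcessor config
X1_MAP = (35.0, 90.0)   # lat
X2_MAP = (-15.0, 40.0)  # lon


def normalise_coords(lat, lon):
    """Linearly map lat/lon onto x1/x2 using the stored ranges (illustrative sketch)."""
    x1 = (np.asarray(lat, dtype=float) - X1_MAP[0]) / (X1_MAP[1] - X1_MAP[0])
    x2 = (np.asarray(lon, dtype=float) - X2_MAP[0]) / (X2_MAP[1] - X2_MAP[0])
    return x1, x2


# First two stations from station_raw_df
x1, x2 = normalise_coords([35.017, 35.100], [-1.450, -1.850])
print(np.round(x1, 6), np.round(x2, 6))
```

The results agree with the first two rows of station_df to six decimals. Both ranges span 55 degrees, so a step of 0.1 in either x1 or x2 corresponds to 5.5 degrees.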

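### Processing multiple datasets in one DataProcessor call
* Pass a list of datasets to normalise them all in a single call. method="min_max" rescales each variable to the range [-1, 1] instead of applying the default mean_std standardisation.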

In [8]:
aux_raw_ds

In [9]:
land_mask_raw_ds

In [10]:
aux_ds, land_mask_ds = data_processor([aux_raw_ds, land_mask_raw_ds], method="min_max")
aux_ds

In [12]:
land_mask_ds

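Assuming min_max rescales each variable using its minimum and maximum as described above, the transformation and its inverse can be sketched as follows. min_max_normalise and min_max_unnormalise are hypothetical helpers, not DeepSensor functions. The GLDAS_mask minimum and maximum of 0.0 and 1.0 are the values the DataProcessor records in its configuration, so the binary mask maps to -1 and 1.

```python
import numpy as np


def min_max_normalise(x, vmin, vmax):
    """Rescale values from [vmin, vmax] to [-1, 1]."""
    return 2 * (np.asarray(x, dtype=float) - vmin) / (vmax - vmin) - 1


def min_max_unnormalise(z, vmin, vmax):
    """Map values from [-1, 1] back to [vmin, vmax]."""
    return (np.asarray(z, dtype=float) + 1) / 2 * (vmax - vmin) + vmin


# GLDAS_mask parameters as recorded by the DataProcessor: min 0.0, max 1.0
mask = np.array([0.0, 1.0, 1.0, 0.0])
scaled = min_max_normalise(mask, vmin=0.0, vmax=1.0)
print(scaled)  # [-1.  1.  1. -1.]
restored = min_max_unnormalise(scaled, vmin=0.0, vmax=1.0)
```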
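### DataProcessor configuration
* The DataProcessor stores the normalisation method and parameters computed for each variable, along with the coordinate mappings, so the same transformation can be reapplied to new data or inverted later.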

In [22]:
print(data_processor)

DataProcessor with normalisation params:
{'10m_u_component_of_wind': {'method': 'mean_std',
                             'params': {'mean': 0.6799039244651794,
                                        'std': 2.8934481143951416}},
 '10m_v_component_of_wind': {'method': 'mean_std',
                             'params': {'mean': -0.20978610217571259,
                                        'std': 3.282301187515259}},
 '2m_temperature': {'method': 'mean_std',
                    'params': {'mean': 288.9769287109375,
                               'std': 5.505463123321533}},
 'GLDAS_mask': {'method': 'min_max', 'params': {'max': 1.0, 'min': 0.0}},
 'PRCP': {'method': 'mean_std',
          'params': {'mean': 1.1599447860459278, 'std': 4.171918456167085}},
 'TAVG': {'method': 'mean_std',
          'params': {'mean': 19.0613726868119, 'std': 5.18850380936251}},
 'coords': {'time': {'name': 'time'},
            'x1': {'map': (35.0, 90.0), 'name': 'lat'},
            'x2': {'map': (-15.0, 40.0),

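The mean_std entries in the configuration above are enough to reproduce the normalised station values. The sketch below assumes the transformation is a plain standardisation, (x - mean) / std, using the stored parameters. mean_std_normalise is a hypothetical helper.

```python
import numpy as np

# Parameters copied from the printed DataProcessor config
TAVG_PARAMS = {"mean": 19.0613726868119, "std": 5.18850380936251}
PRCP_PARAMS = {"mean": 1.1599447860459278, "std": 4.171918456167085}


def mean_std_normalise(x, params):
    """Standardise x with a stored mean and standard deviation."""
    return (np.asarray(x, dtype=float) - params["mean"]) / params["std"]


# Raw values from the first rows of station_raw_df
tavg = mean_std_normalise([23.0, 23.4], TAVG_PARAMS)
prcp = mean_std_normalise([0.0], PRCP_PARAMS)
print(np.round(tavg, 6), np.round(prcp, 6))
```

To six decimals this gives 0.759107 and 0.836200 for TAVG and -0.278036 for zero precipitation, matching the station_df table.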
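### Unnormalising data
* unnormalise inverts the transformation, mapping normalised data back to the original units and coordinate names.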

In [13]:
era5_raw_ds_unnormalised = data_processor.unnormalise(era5_ds)
xr.testing.assert_allclose(era5_raw_ds, era5_raw_ds_unnormalised, atol=1e-6)
era5_raw_ds_unnormalised

In [16]:
station_df_unnormalised = data_processor.unnormalise(station_df)
pd.testing.assert_frame_equal(station_raw_df, station_df_unnormalised)
station_df_unnormalised

| time | lat | lon | station | PRCP | TAVG |
|---|---|---|---|---|---|
| 2015-06-25 | 35.017 | -1.450 | AGM00060531 | 0.0 | 23.0 |
| 2015-06-25 | 35.100 | -1.850 | AGE00147716 | 0.0 | 23.4 |
| 2015-06-25 | 35.117 | 36.750 | SYM00040030 | NaN | 25.4 |
| 2015-06-25 | 35.167 | 2.317 | AGM00060514 | 0.0 | 25.9 |
| 2015-06-25 | 35.200 | -0.617 | AGM00060520 | 0.0 | 24.9 |
| ... | ... | ... | ... | ... | ... |
| 2015-06-30 | 45.933 | 7.700 | ITM00016052 | NaN | 5.7 |
| 2015-06-30 | 38.367 | -0.500 | SPM00008359 | 0.0 | 27.6 |
| 2015-06-30 | 55.383 | 36.700 | RSM00027611 | 0.0 | 17.2 |
| 2015-06-30 | 59.080 | 17.860 | SWE00138750 | 0.0 | NaN |

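Unnormalising simply inverts the two linear maps. The sketch below assumes the same stored parameters as the config printed earlier. It recovers the raw TAVG and latitude values from the rounded normalised values shown in station_df and checks an exact round trip. mean_std_unnormalise and unnormalise_coord are hypothetical helpers, not DeepSensor functions.

```python
import numpy as np

# Parameters from the DataProcessor config printed earlier in this notebook
TAVG_PARAMS = {"mean": 19.0613726868119, "std": 5.18850380936251}
X1_MAP = (35.0, 90.0)  # lat range


def mean_std_unnormalise(z, params):
    """Invert (x - mean) / std."""
    return np.asarray(z, dtype=float) * params["std"] + params["mean"]


def unnormalise_coord(x, coord_map):
    """Invert (c - lo) / (hi - lo)."""
    lo, hi = coord_map
    return np.asarray(x, dtype=float) * (hi - lo) + lo


# Normalised values as displayed in station_df (rounded to 6 decimals)
tavg = mean_std_unnormalise([0.759107, 0.836200], TAVG_PARAMS)
lat = unnormalise_coord([0.000309, 0.001818], X1_MAP)
print(np.round(tavg, 1), np.round(lat, 3))

# Exact round trip on raw values
raw = np.array([23.0, 23.4, 25.4])
z = (raw - TAVG_PARAMS["mean"]) / TAVG_PARAMS["std"]
assert np.allclose(mean_std_unnormalise(z, TAVG_PARAMS), raw)
```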

### Saving and loading a DataProcessor

In [16]:

data_processor.save("deepsensor_config/")
data_processor2 = DataProcessor("deepsensor_config/")
print(data_processor2)

DataProcessor with normalisation params:
{'10m_u_component_of_wind': {'method': 'mean_std',
                             'params': {'mean': 0.6799039244651794,
                                        'std': 2.8934481143951416}},
 '10m_v_component_of_wind': {'method': 'mean_std',
                             'params': {'mean': -0.20978610217571259,
                                        'std': 3.282301187515259}},
 '2m_temperature': {'method': 'mean_std',
                    'params': {'mean': 288.9769287109375,
                               'std': 5.505463123321533}},
 'GLDAS_mask': {'method': 'min_max', 'params': {'max': 1.0, 'min': 0.0}},
 'PRCP': {'method': 'mean_std',
          'params': {'mean': 1.1599447860459278, 'std': 4.171918456167085}},
 'TAVG': {'method': 'mean_std',
          'params': {'mean': 19.0613726868119, 'std': 5.18850380936251}},
 'coords': {'time': {'name': 'time'},
            'x1': {'map': (35.0, 90.0), 'name': 'lat'},
            'x2': {'map': (-15.0, 40.0),

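This notebook does not show DeepSensor's on-disk format for data_processor.save. The sketch below only illustrates why a save/load round trip works: the configuration is a nested dictionary of plain values, which serialises cleanly to JSON. The file name config.json, the abridged config, and the helpers save_config and load_config are hypothetical. One wrinkle is that JSON has no tuple type, so the coordinate maps come back as lists and are converted back explicitly.

```python
import json
import os
import tempfile

# Hypothetical, abridged config mirroring the structure printed above
config = {
    "TAVG": {"method": "mean_std",
             "params": {"mean": 19.0613726868119, "std": 5.18850380936251}},
    "coords": {"time": {"name": "time"},
               "x1": {"map": (35.0, 90.0), "name": "lat"},
               "x2": {"map": (-15.0, 40.0), "name": "lon"}},
}


def save_config(config, folder, filename="config.json"):
    """Write the config dict to folder/filename as JSON."""
    os.makedirs(folder, exist_ok=True)
    path = os.path.join(folder, filename)
    with open(path, "w") as f:
        json.dump(config, f, indent=2)
    return path


def load_config(path):
    """Read the JSON config and restore tuple-valued coordinate maps."""
    with open(path) as f:
        loaded = json.load(f)
    for key in ("x1", "x2"):
        loaded["coords"][key]["map"] = tuple(loaded["coords"][key]["map"])
    return loaded


with tempfile.TemporaryDirectory() as tmp:
    path = save_config(config, os.path.join(tmp, "deepsensor_config"))
    restored = load_config(path)

print(restored == config)  # True
```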
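### Computing normalisation parameters over a subset of the data
* Normalisation parameters are computed the first time a variable is processed and reused on every later call. This lets you fit them on a subset (e.g. a training period) and then apply them to the full dataset.

In [25]:
data_processor = DataProcessor(x1_name="lat", x2_name="lon")  # Fresh processor with no stored parameters
_ = data_processor(era5_raw_ds.sel(time=slice("2015-06-25", "2015-06-27")))
era5_ds = data_processor(era5_raw_ds)  # Will use the normalisation parameters computed above when called on the full dataset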
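The reuse behaviour described by the comment in the cell above can be illustrated with a toy normaliser. The class ToyNormaliser and the daily temperature values are made up for illustration. Parameters for a variable are computed on the first call and reused afterwards, so fitting on the first three days fixes the mean and standard deviation applied to the whole series. The sketch uses pandas' sample standard deviation (ddof=1), which is an arbitrary choice here and not necessarily what DeepSensor uses.

```python
import pandas as pd


class ToyNormaliser:
    """Computes mean/std per variable on the first call, then reuses them."""

    def __init__(self):
        self.params = {}

    def __call__(self, series: pd.Series) -> pd.Series:
        name = series.name
        if name not in self.params:
            # Only the first call for a given variable sets its parameters
            self.params[name] = {"mean": float(series.mean()),
                                 "std": float(series.std())}
        p = self.params[name]
        return (series - p["mean"]) / p["std"]


times = pd.date_range("2015-06-25", "2015-06-30", freq="D")
temps = pd.Series([20.0, 21.0, 22.0, 30.0, 31.0, 32.0], index=times, name="t2m")

norm = ToyNormaliser()
_ = norm(temps.loc["2015-06-25":"2015-06-27"])  # fit on the first three days
full = norm(temps)                               # reuses those parameters
print(norm.params["t2m"])  # {'mean': 21.0, 'std': 1.0}
```

Because the parameters come from the cooler first three days, the later values normalise to large positive numbers (30.0 becomes 9.0), which is the expected effect when statistics are deliberately restricted to a subset.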