### Imports/set-up

In [1]:
# Load the "autoreload" extension so that code can change
%load_ext autoreload
# Always reload modules so that as you change code in src, it gets loaded
%autoreload 2

In [2]:
# You must import either the torch or tensorflow extensions of deepsensor before other deepsensor modules.
# This ensures deepsensor has access to the deep learning library backend.
# This notebook trains with `tf.GradientTape` / Keras `trainable_weights` and saves Keras weights,
# so the TensorFlow backend is the one that must be imported here.
import deepsensor.tensorflow as ds
# import deepsensor.torch as ds

2023-05-24 10:42:18.879906: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-05-24 10:42:19.005224: I tensorflow/core/util/port.cc:104] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.


In [4]:
from deepsensor.model.models import ConvNP
from deepsensor.data.loader import TaskLoader
from deepsensor.model.nps import compute_encoding_tensor
from deepsensor.plot.utils import plot_context_encoding

In [5]:
model = ConvNP(
    points_per_unit=300,
    unet_channels=(32,) * 4,
    encoder_scales=(0.002, 0.005),
)

2023-05-24 10:42:55.848704: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1613] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 13163 MB memory:  -> device: 0, name: NVIDIA A2, pci bus id: 0000:98:00.0, compute capability: 8.6


In [6]:
import numpy as np
import pandas as pd
import xarray as xr
import dask

import os

import matplotlib.pyplot as plt

In [7]:
# Check GPU visible to tf
# print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

### Load data

In [8]:
# Only the `tas_anom` subdirectory is opened. To open every processed variable instead,
# use the glob '../deepsensor_old/data/antarctica/gridded/processed/*/*.nc'.
era5_path = '../deepsensor_old/data/antarctica/gridded/processed/tas_anom/*.nc'
era5_ds = xr.open_mfdataset(era5_path)
era5_ds

Unnamed: 0,Array,Chunk
Bytes,202.60 kiB,2.86 kiB
Shape,"(25933,)","(366,)"
Dask graph,71 chunks in 143 graph layers,71 chunks in 143 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray
"Array Chunk Bytes 202.60 kiB 2.86 kiB Shape (25933,) (366,) Dask graph 71 chunks in 143 graph layers Data type int64 numpy.ndarray",25933  1,

Unnamed: 0,Array,Chunk
Bytes,202.60 kiB,2.86 kiB
Shape,"(25933,)","(366,)"
Dask graph,71 chunks in 143 graph layers,71 chunks in 143 graph layers
Data type,int64 numpy.ndarray,int64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,7.57 GiB,109.46 MiB
Shape,"(25933, 280, 280)","(366, 280, 280)"
Dask graph,71 chunks in 143 graph layers,71 chunks in 143 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 7.57 GiB 109.46 MiB Shape (25933, 280, 280) (366, 280, 280) Dask graph 71 chunks in 143 graph layers Data type float32 numpy.ndarray",280  280  25933,

Unnamed: 0,Array,Chunk
Bytes,7.57 GiB,109.46 MiB
Shape,"(25933, 280, 280)","(366, 280, 280)"
Dask graph,71 chunks in 143 graph layers,71 chunks in 143 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [9]:
from deepsensor.data.loader import construct_x1x2_ds, construct_circ_time_ds

aux_ds = xr.open_mfdataset('../deepsensor_old/data/antarctica/auxiliary/processed/*25000m/*.nc')

# Gridded coordinate fields derived from the auxiliary grid
x1x2_ds = construct_x1x2_ds(aux_ds)

# Circular (cos/sin) day-of-year encoding over the ERA5 time span at daily frequency
dates = pd.date_range(era5_ds.time.values.min(), era5_ds.time.values.max(), freq="D")
doy_ds = construct_circ_time_ds(dates, freq="D")

# Attach the derived fields as extra auxiliary context variables (same order as before)
aux_ds = aux_ds.assign(
    x1_arr=x1x2_ds['x1_arr'],
    x2_arr=x1x2_ds['x2_arr'],
    cos_D=doy_ds["cos_D"],
    sin_D=doy_ds["sin_D"],
)

aux_ds

Unnamed: 0,Array,Chunk
Bytes,306.25 kiB,306.25 kiB
Shape,"(280, 280)","(280, 280)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 306.25 kiB 306.25 kiB Shape (280, 280) (280, 280) Dask graph 1 chunks in 5 graph layers Data type float32 numpy.ndarray",280  280,

Unnamed: 0,Array,Chunk
Bytes,306.25 kiB,306.25 kiB
Shape,"(280, 280)","(280, 280)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,306.25 kiB,306.25 kiB
Shape,"(280, 280)","(280, 280)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 306.25 kiB 306.25 kiB Shape (280, 280) (280, 280) Dask graph 1 chunks in 5 graph layers Data type float32 numpy.ndarray",280  280,

Unnamed: 0,Array,Chunk
Bytes,306.25 kiB,306.25 kiB
Shape,"(280, 280)","(280, 280)"
Dask graph,1 chunks in 5 graph layers,1 chunks in 5 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,612.50 kiB,612.50 kiB
Shape,"(280, 280)","(280, 280)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray
"Array Chunk Bytes 612.50 kiB 612.50 kiB Shape (280, 280) (280, 280) Dask graph 1 chunks in 2 graph layers Data type float64 numpy.ndarray",280  280,

Unnamed: 0,Array,Chunk
Bytes,612.50 kiB,612.50 kiB
Shape,"(280, 280)","(280, 280)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float64 numpy.ndarray,float64 numpy.ndarray

Unnamed: 0,Array,Chunk
Bytes,306.25 kiB,306.25 kiB
Shape,"(280, 280)","(280, 280)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray
"Array Chunk Bytes 306.25 kiB 306.25 kiB Shape (280, 280) (280, 280) Dask graph 1 chunks in 2 graph layers Data type float32 numpy.ndarray",280  280,

Unnamed: 0,Array,Chunk
Bytes,306.25 kiB,306.25 kiB
Shape,"(280, 280)","(280, 280)"
Dask graph,1 chunks in 2 graph layers,1 chunks in 2 graph layers
Data type,float32 numpy.ndarray,float32 numpy.ndarray


In [10]:
# Station observations indexed by (time, x1, x2), keeping only the `tas` column
station_df = (
    pd.read_csv('../deepsensor_old/data/antarctica/station/processed/XY_station.csv')
    .rename(columns={'date': 'time'})
    .assign(time=lambda df: pd.to_datetime(df['time']))
    .set_index(['time', 'x1', 'x2'])
    .sort_index()
    .loc[:, ['tas']]
)
print(station_df)

                                     tas
time       x1        x2                 
1948-04-01 -1.039598  0.400453  0.998104
1948-04-02 -1.039598  0.400453  1.043380
1948-04-03 -1.039598  0.400453  0.934391
1948-04-04 -1.039598  0.400453  0.955902
1948-04-05 -1.039598  0.400453  0.892086
...                                  ...
2022-06-14  0.413597 -0.197276 -0.634638
2022-06-15  0.153729 -0.459005 -0.864647
            0.230653 -0.386157 -0.709768
            0.264879 -0.659692  0.092832
            0.281365 -0.490629 -0.467813

[941850 rows x 1 columns]


### TODO: Normalise data with DataProcessor

In [11]:
# TODO

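### Set up a TaskLoader object to generate tasks
Note the flexibility of the `TaskLoader` init arguments below.
With slight changes to the arguments, we can instantiate a `TaskLoader` that generates tasks for forecasting or interpolation. With the arguments used here, every context and target variable is taken at `t0` (see the printed variable IDs), so the tasks are interpolation tasks.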

In [12]:
# Context: ERA5 t2m plus the auxiliary fields; target: ERA5 t2m
context_sets = [era5_ds['t2m'], aux_ds]
target_sets = era5_ds['t2m']
task_loader = TaskLoader(context=context_sets, target=target_sets)
print(task_loader)

TaskLoader(2 context sets, 1 target sets)
Context variable IDs: (('t2m_t0',), ('mask_t0', 'surface_t0', 'x1_arr_t0', 'x2_arr_t0', 'cos_D_t0', 'sin_D_t0'))
Target variable IDs: (('t2m_t0',),)


In [13]:
# task_loader.load_dask()  # Load any dask arrays into memory for faster training

### The TaskLoader outputs Tasks, which are dict-like objects containing context and target data

Calling a `TaskLoader` with a `date` generates a `Task` for that date by slicing the context and target variables.
`Task`s inherit from `dict` and provide `__str__` and `__repr__` methods for debugging.
`TaskLoader` offers several data sampling methods for generating `Task`s.
```
"grid": 2D grid of data
N, int: uniform random sampling of N grid cells
```

In [14]:
# Gridded context, 10 random target points
task = task_loader("2000-01-01", "grid", 10)
print(type(task), "\n")

summaries = (
    ("Concise task summary:", str(task)),
    ("Verbose task summary:", repr(task)),
)
for heading, rendered in summaries:
    print(heading)
    print(rendered)

<class 'deepsensor.data.task.Task'> 

Concise task summary:
time: 2000-01-01 00:00:00
modify: None
X_c: [((280,), (280,)), ((280,), (280,))]
Y_c: [(1, 280, 280), (6, 280, 280)]
X_t: [(2, 10)]
Y_t: [(10,)]

Verbose task summary:
time: Timestamp/2000-01-01 00:00:00
modify: NoneType/None
X_c: [('ndarray/float32/(280,)', 'ndarray/float32/(280,)'), ('ndarray/float32/(280,)', 'ndarray/float32/(280,)')]
Y_c: ['ndarray/float32/(1, 280, 280)', 'ndarray/float32/(6, 280, 280)']
X_t: ['ndarray/float32/(2, 10)']
Y_t: ['ndarray/float32/(10,)']



### Set up ConvNP object

`ConvNP` wraps around the `neuralprocesses` library to provide a convenient interface for inference with environmental data in the form of `Task`s.

There are several ways to set up a `ConvNP` object:

In [15]:
# Inspect which concrete class the `TFModel` type hint resolves to
from plum import resolve_type_hint
from deepsensor.model.models import TFModel

resolve_type_hint(TFModel)

keras.engine.training.Model

In [16]:
# Option 1: pass `neuralprocesses.construct_ConvNP` keyword arguments straight through
model = ConvNP(
    points_per_unit=300,
    unet_channels=(32,) * 4,
    encoder_scales=(0.002, 0.005),
)

In [17]:
# Option 2: wrap an already-constructed TensorFlow model (here, the one built in the previous cell)
existing_tf_model = model.model
model = ConvNP(existing_tf_model)

In [18]:
# Option 3 (used from here on): pass a TaskLoader so sensible defaults are inferred;
# any explicit kwargs, such as `unet_channels`, still take precedence.
model_dim = 128
unet_channels = (model_dim,) * 4
model = ConvNP(task_loader, unet_channels=unet_channels)

dim_yc inferred from TaskLoader: (1, 6)
dim_yt inferred from TaskLoader: 1
points_per_unit inferred from TaskLoader: 167
encoder_scales inferred from TaskLoader: [0.0035714285913854837, 0.0035714285913854837]


A `ConvNP`'s `__call__` method accepts a `Task` and returns a distribution object.
The distribution object can be used to compute predictions, draw samples, compute entropy, etc., without having to run the model again.

In [19]:
# Run model on a random task to build model, with all context data and 3 random target locations
task = task_loader("2000-01-01", "grid", 3)
dist = model(task)
print('Distribution object: ', type(dist))

# Every query below reuses `dist`, so the model is not re-run for each statistic
pred_mean = model.predict(dist)
print('Mean: ', pred_mean.shape, pred_mean)

pred_samples = model.sample(dist, n_samples=2)
print('Sample: ', pred_samples.shape, pred_samples)

pred_var = model.variance(dist)
print('Variance: ', pred_var.shape, pred_var)

pred_std = model.stddev(dist)
print('Std dev: ', pred_std.shape, pred_std)

pred_cov = model.covariance(dist)
print('Covariance: ', pred_cov.shape)

task_logpdf = model.logpdf(dist, task)
print('logpdf: ', task_logpdf)

pred_entropy = model.entropy(dist)
print('Entropy: ', pred_entropy)

2023-05-24 10:43:32.129500: I tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:630] TensorFloat-32 will be used for the matrix multiplication. This will only be logged once.
2023-05-24 10:43:42.062577: I tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:428] Loaded cuDNN version 8201


Distribution object:  <class 'neuralprocesses.dist.normal.MultiOutputNormal[matrix.matrix.Dense, matrix.lowrank.LowRank, matrix.diagonal.Diagonal]'>
Mean:  (3,) [7.047846 8.363125 8.253919]
Sample:  (2, 3) [[7.619549  8.991036  8.889617 ]
 [7.1698184 8.6254835 8.54682  ]]
Variance:  (3,) [0.19775026 0.2094001  0.20240806]
Std dev:  (3,) [0.44469118 0.45760256 0.44989783]
Covariance:  (3, 3)
logpdf:  -183.18462
Entropy:  0.09392309


Note that all of the above methods can be called on a `Task` rather than a distribution object, in which case the model is run internally, e.g.:

In [20]:
# Passing the Task directly: the model is run internally before the mean is computed
pred_mean = model.predict(task)
print('Mean: ', pred_mean.shape, pred_mean)

Mean:  (3,) [7.047846 8.363125 8.253919]


After the model has been built, we can inspect some useful information about it:

In [21]:
receptive_field = model.model.receptive_field
n_params = ds.backend.nps.num_params(model.model)
print(f"Model receptive field: {receptive_field:.2f}")
print(f"Model has {n_params:,} parameters")

Model receptive field: 0.39
Model has 4,134,341 parameters


The model has a method to reshape the task data into a format that can be passed to the model:

In [None]:
# Store the reshaped task under a new name: overwriting `task` would make this cell
# non-idempotent (re-running it would apply `modify_task` to an already-modified task).
modified_task = ConvNP.modify_task(task)
print(repr(modified_task))

### Visualise encoded context data in the model
Inspecting the gridded encoding of the context data helps with understanding the context sampling schemes.

This can also be an extremely useful debugging tool. For example:
* Do the length scales of the encoded data seem reasonable (i.e. small enough to avoid blurring high-frequency components, but not so small as to induce checkerboard artefacts)?
* Are the channel magnitudes in the encoding reasonable?
* Are there any `nan` values?

In [None]:
# One sampling strategy per context set: 250 random points from ERA5 t2m, the full auxiliary grid
context_sampling = (250, "grid")
# 5000 random target points
target_sampling = 5000
task = task_loader("2000-01-01", context_sampling, target_sampling)

encoding = compute_encoding_tensor(model, task)
print(f"\nEncoding is shape {encoding.shape}")

fig = plot_context_encoding(model, task, task_loader)
plt.show()

### Train model

In [None]:
import tensorflow as tf
from tqdm.notebook import tqdm

opt = tf.keras.optimizers.Adam(1e-5)

def train_step(tasks):
    """Take one optimiser step on the mean loss over a batch of tasks.

    Args:
        tasks: a single `Task` or a list of `Task`s forming one batch.

    Returns:
        Scalar tensor: the mean over the batch of `model.loss_fn(task, normalise=True)`.
    """
    if not isinstance(tasks, list):
        tasks = [tasks]
    with tf.GradientTape() as tape:
        task_losses = [model.loss_fn(task, normalise=True) for task in tasks]
        mean_batch_loss = tf.reduce_mean(task_losses)
    grads = tape.gradient(mean_batch_loss, model.model.trainable_weights)
    opt.apply_gradients(zip(grads, model.model.trainable_weights))
    return mean_batch_loss

n_epochs = 10
n_t = 5000  # number of random target points per training task
held_out_year = 2000  # the "Predict on heldout data" section evaluates on this year

# Roughly one date per year over 1980-2009. The held-out test year is excluded: without
# the filter, 2000-12-26 would be trained on and then "predicted" as held-out data.
# Built once, since it does not change between epochs.
train_dates = pd.date_range('1980-01-01', '2009-12-31')[::365]
train_dates = train_dates[train_dates.year != held_out_year]

epoch_losses = []
for epoch in tqdm(range(n_epochs), position=0):
    pbar = tqdm(train_dates, position=1, smoothing=1)
    batch_losses = []
    for date in pbar:
        # Random number of ERA5 context points; the auxiliary context is always the full grid
        n_obs = np.random.randint(5, 500)
        task = task_loader(date, (n_obs, "grid"), n_t)
        batch_loss = train_step(task)
        # Store plain floats rather than tensors for the running mean below
        batch_losses.append(float(batch_loss))
        pbar.set_description('avg loss: {:.2f}'.format(np.mean(batch_losses)))
    epoch_loss = np.mean(batch_losses)
    epoch_losses.append(epoch_loss)
    print(f"Loss: {epoch_loss}")

In [None]:
# Mean training loss per epoch
fig, ax = plt.subplots()
ax.plot(epoch_losses)
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax.set_title("Training loss");

### Predict on heldout data

In [None]:
# One task per day of the test year, with gridded context and gridded targets
test_dates = pd.date_range(start="2000-01-01", end="2000-12-31")
tasks = task_loader(test_dates, "grid", "grid")

In [None]:
pred_ds = model.predict_ongrid(
    tasks,
    reference_grid=era5_ds,
    n_samples=3,
    progress_bar=1,
)
pred_ds

In [None]:
# Convert time coord to pandas timestamps (datetime64) so date-based `.sel` works
time_as_datetime = pd.to_datetime(pred_ds["time"].values)
pred_ds = pred_ds.assign_coords(time=time_as_datetime)

In [None]:
# The previous cell already converts the time coordinate with `pd.to_datetime`; repeating the
# conversion in place here was redundant. Check the invariant instead of mutating again.
assert np.issubdtype(pred_ds['time'].dtype, np.datetime64), pred_ds['time'].dtype

In [None]:
# Display the predictions after the time-coordinate conversion
pred_ds

In [None]:
# Convert init time to forecast time: shift by the target offset, interpreted in days
lead_time = pd.Timedelta(days=task_loader.target_delta_t[0])
pred_ds = pred_ds.assign_coords(time=pred_ds['time'] + lead_time)
pred_ds

In [None]:
# Ground truth for evaluation: the first (and, per the TaskLoader summary, only) target set.
# NOTE(review): presumably this is the same `era5_ds['t2m']` passed as `target=`; confirm.
true_da = task_loader.target[0]

In [None]:
# Prediction error (prediction minus truth) over all test days. xarray aligns the operands
# on shared coordinates (inner join by default), so only dates present in `pred_ds` remain.
err_da = pred_ds['mean'] - true_da
err_da

In [None]:
# Spatially averaged absolute error for each test day. Uses xarray's native `abs` instead of
# `dask.array.fabs`, which only worked on an xarray object via dask's numpy fallback.
mae_per_day = abs(err_da).mean(['x1', 'x2'])
fig, ax = plt.subplots()
mae_per_day.plot(ax=ax)
ax.set_ylabel("MAE (K)");

In [None]:
# Scalar error metrics over all test days and grid cells, using xarray-native ops
mae = abs(err_da).mean()
rmse = np.sqrt((err_da ** 2).mean())
# Compute both together so the shared (possibly dask-backed) error graph is evaluated once
mae, rmse = dask.compute(mae, rmse)
print(f"Test MAE: {mae.item():.2f} K, RMSE: {rmse.item():.2f} K")

In [None]:
init_date = pd.Timestamp("2000-06-25")
# Same day offset as used to shift `pred_ds` to forecast time
target_date = init_date + pd.Timedelta(days=task_loader.target_delta_t[0])
plot_dates = [init_date, target_date]
true_da.sel(time=plot_dates).load().plot(col="time")

In [None]:
# Predicted mean at `target_date`, matching the ground-truth panel above
# (`isel(time=0)` would show 2000-01-01 instead).
pred_ds['mean'].sel(time=target_date).plot()

In [None]:
# One panel per sample at `target_date`. Select a single time so each facet is 2D;
# presumably `samples` is (time, sample, x1, x2) — confirm against predict_ongrid.
pred_ds['samples'].sel(time=target_date).plot(col='sample')

In [None]:
# Error map at `target_date`. The prediction and the truth are both taken at the same date
# (previously the Jan-1 prediction was compared with the Jun-25 truth). A new name is used so
# the full-year `err_da` from above is not clobbered.
err_target_da = pred_ds['mean'].sel(time=target_date) - true_da.sel(time=target_date)
err_target_da.plot()

In [None]:
# Predictive std map at `target_date` (without a time selection xarray falls back to a histogram)
pred_ds['std'].sel(time=target_date).plot()

### Test model saving and loading

In [None]:
# Checkpoint prefix for the Keras weights; its parent directory must exist before saving
weights_prefix = 'model/'
os.makedirs(os.path.dirname(weights_prefix), exist_ok=True)
model.model.save_weights(weights_prefix)

In [None]:
# Confirm the TaskLoader object is still available for re-creating the model
task_loader.__class__

In [None]:
# Re-create the model with the same architecture as the trained one; without the matching
# `unet_channels`, the saved weights cannot be loaded into it below.
model = ConvNP(task_loader, unet_channels=(model_dim,) * 4)

Randomly initialised prediction

In [None]:
# Use an explicit gridded task instead of whatever `task` the training loop last sampled.
# `.squeeze()` drops a singleton time dimension (if any) so a 2D map is drawn.
task = task_loader("2000-06-25", "grid", "grid")
model.predict_ongrid(task, reference_grid=era5_ds)['std'].squeeze().plot()

Load weights and predict using trained model

In [None]:
# Restore the trained weights saved above and repeat the prediction on the same task;
# `.squeeze()` drops a singleton time dimension (if any) so a 2D map is drawn.
model.model.load_weights('model/')
model.predict_ongrid(task, reference_grid=era5_ds)['std'].squeeze().plot()