# Experiment 2: Transfer learning for ConvNP

When training the ConvNP for data assimilation, we want to first train the model on a broad dataset covering a large portion of Australasia, and then refine (fine-tune) it for predictions on a narrower dataset, specifically New Zealand. To plan which context, auxiliary and target sets to use, we need to know how the default model structure changes in response to different input structures.

Things which need to be considered:
- Target resolution/model internal density
- Input scale
- Number of context sets
- Number of target sets
- Multi-output vs single-output?
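Two of these considerations can be made concrete with a little arithmetic. The sketch below is a rough mental model, not DeepSensor's own code. It assumes that the ConvNP's SetConv encoder produces one density channel plus one data channel per variable for each context set, and that the internal grid spacing is `1 / internal_density` in normalised coordinate units. Under those assumptions, the number of context sets and variables fixes the number of CNN input channels (and so the shape of the first layer's weights), while the internal density and the coordinate normalisation set the physical size of the CNN's receptive field. The helper names and all numbers are hypothetical.

```python
def encoder_input_channels(context_set_dims):
    """Hypothetical helper: CNN input channels produced by a SetConv encoder.

    Assumes each context set contributes one density channel plus one data
    channel per variable, so the count depends only on how many context sets
    there are and how many variables each holds, not on the domain size.
    """
    return sum(1 + n_vars for n_vars in context_set_dims)


def receptive_field_degrees(receptive_field_points, internal_density, degrees_per_unit):
    """Hypothetical helper: convert a receptive field from grid points to degrees.

    Assumes a grid spacing of 1 / internal_density in normalised units and a
    fixed number of degrees per normalised unit, set by the data processor.
    """
    receptive_field_units = receptive_field_points / internal_density
    return receptive_field_units * degrees_per_unit


# Made-up example: a station context set with 1 variable and an
# ERA5 context set with 3 variables.
print(encoder_input_channels([1, 3]))  # 6 (2 density + 4 data channels)

# Same network and internal density under two normalisations:
# keeping the broad-domain scale vs re-normalising to a smaller NZ extent.
print(receptive_field_degrees(40, 200, degrees_per_unit=60.0))
print(receptive_field_degrees(40, 200, degrees_per_unit=15.0))
```

Here the first receptive field is four times wider in degrees than the second, even though the network is identical. If this mental model holds, reusing Australasia-trained weights for New Zealand would require keeping the same context-set layout (so channel counts match) and probably the same coordinate normalisation, so that one grid step means the same physical distance in both stages.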

Note: if the encoder changes structure but the decoder stays the same, we might be able to train separate encoders while sharing the decoder.
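A toy sketch of that idea, written with NumPy rather than DeepSensor: each region gets its own encoder that maps its particular input channels to a common number of latent channels, and a single decoder object is reused by both. The `ChannelMix` layers are hypothetical 1x1 channel-mixing maps standing in for the real SetConv and U-Net components, and the channel counts and grid sizes are made up. The point is only that the shared decoder's weights depend on the latent channel count, not on how many context sets feed each encoder or on the grid size.

```python
import numpy as np


class ChannelMix:
    """Hypothetical 1x1 'convolution': mixes channels at every grid point."""

    def __init__(self, in_channels, out_channels, rng):
        scale = 1.0 / np.sqrt(in_channels)
        self.weights = rng.standard_normal((out_channels, in_channels)) * scale

    def __call__(self, grid):
        # grid has shape (channels, height, width)
        return np.einsum("oc,chw->ohw", self.weights, grid)


rng = np.random.default_rng(seed=0)
latent_channels = 16

# Region-specific encoders with different input channel counts (made-up values).
australasia_encoder = ChannelMix(in_channels=6, out_channels=latent_channels, rng=rng)
nz_encoder = ChannelMix(in_channels=8, out_channels=latent_channels, rng=rng)

# One decoder shared by both regions, e.g. predicting a mean and a std channel.
shared_decoder = ChannelMix(in_channels=latent_channels, out_channels=2, rng=rng)

# Grids of different sizes: a broad Australasian domain and a smaller NZ domain.
australasia_grid = rng.standard_normal((6, 120, 160))
nz_grid = rng.standard_normal((8, 40, 30))

australasia_out = shared_decoder(australasia_encoder(australasia_grid))
nz_out = shared_decoder(nz_encoder(nz_grid))
print(australasia_out.shape, nz_out.shape)  # (2, 120, 160) (2, 40, 30)
```

Under this scheme, one option for fine-tuning would be to train `nz_encoder` from scratch while initialising (or freezing) the shared decoder from the Australasia run. Whether the DeepSensor fork used here exposes the ConvNP's encoder and decoder separately enough to do this would need checking.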

In [1]:
# set up project root for imports (required for all notebooks in this repo)
import sys
from pathlib import Path

# Make project root importable
ROOT = Path().resolve().parents[1]
sys.path.append(str(ROOT))

In [2]:
%load_ext autoreload

In [3]:
%autoreload 2
from src.utils.variables.var_names import *
from src.utils.variables.coord_names import *
from src.data_processing.conversions.scalar_conversions import *
from src.config.env_loader import get_env_var
import src.learning.model_diagnostics as model_diagnostics
from src.learning.model_training import batch_data_by_num_stations, compute_val_loss

from src.data_processing.station_processor import ProcessStations
from src.data_processing.topography_processor import ProcessTopography
from src.data_processing.era5_processor import ProcessERA5

from src.data_processing.auxiliary.sun_position import get_sun_culmination

In [4]:
# DeepSensor imports
# note this pulls from a fork of DeepSensor.
import deepsensor.torch
from deepsensor.train.train import train_epoch, set_gpu_default_device
from deepsensor.data.loader import TaskLoader
from deepsensor.data.processor import DataProcessor
from deepsensor.model.convnp import ConvNP
from deepsensor.data.utils import construct_x1x2_ds

import numpy as np
import matplotlib.pyplot as plt
import xarray as xr
from mpl_toolkits.basemap import Basemap
import torch
from torch import optim
import os
import lab as B
from tqdm import tqdm
import cartopy.crs as ccrs
import cartopy.feature as cf

In [None]:
# set up variables for the experiment
train_years = [2010, 2011, 2012, 2013]
validation_years = [2014]

# all years to load
years = train_years + validation_years

# GPU settings
use_gpu = True
if use_gpu:
    cuda_device = int(get_env_var("CUDA_DEVICE"))
    set_gpu_default_device(backend="cuda", dev_id=cuda_device)

# toggle diagnostic plots of the data
DEBUG_PLOTS = True
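The year lists above are meant to keep training and validation data disjoint. The sketch below turns them into daily task dates. It assumes daily tasks and uses pandas, which the notebook does not import here. The names `all_dates`, `train_dates` and `validation_dates` are hypothetical, and later cells may select dates differently, for example directly on `era5_ds.time`.

```python
import pandas as pd

# Same values as the setup cell above.
train_years = [2010, 2011, 2012, 2013]
validation_years = [2014]
years = train_years + validation_years

# Guard against data leakage between splits and against missing years.
assert not set(train_years) & set(validation_years), "train/validation years overlap"
assert set(train_years) | set(validation_years) == set(years)

# Daily dates spanning every loaded year.
all_dates = pd.date_range(f"{min(years)}-01-01", f"{max(years)}-12-31", freq="D")

# Split by calendar year.
train_dates = all_dates[all_dates.year.isin(train_years)]
validation_dates = all_dates[all_dates.year.isin(validation_years)]
print(len(train_dates), len(validation_dates))  # 1461 365 (2012 is a leap year)
```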

In [6]:
# Dataset loader modules from src.data_processing.
#    Each uses a file loader module in src.data_processing.file_loaders to read the raw data files;
#    if the raw data come in a different structure, those file loader modules will need to be changed.
station_processor = ProcessStations()
topography_processor = ProcessTopography()
era5_processor = ProcessERA5()

# topography and ERA5 datasets are loaded as simple xarray datasets
topography_ds = topography_processor.load_ds(standardise_var_names=True, standardise_coord_names=True)
era5_ds = era5_processor.load_ds(mode="surface", years=years, standardise_var_names=True, standardise_coord_names=True)