# CNN Sequence Model Dataset

This notebook shows how to use the `CNNSeqDataset` class to generate training datasets for the CNN Sequence image models. It requires that the image data has been downloaded to the local machine first; see the notebook __Download Training Data from S3__ for instructions on how to do that.

In [1]:
# Required imports:

import os
import sys

sys.path.append(os.path.join("..", "code"))
from cnn_model import *


# Project paths, all derived from the directory the kernel was started in.
# DATA_DIR is the "data" folder that sits next to (not inside) that directory.
DIR = os.getcwd()
DATA_DIR = os.path.join(os.path.split(DIR)[0], "data")

# Sub-folders of DATA_DIR used by the notebook.
TRAIN_DIR, DEV_DIR = (
    os.path.join(DATA_DIR, "training_data"),
    os.path.join(DATA_DIR, "img_dir"),
)

### Check the docstring

In [2]:
# IPython-only syntax: a trailing `?` displays the object's help
# (the init signature is visible in the output below).
CNNSeqDataset?

Init signature:
CNNSeqDataset(
    precip_dirs: list,
    temp_dirs: list,
    et_dirs: list,
    swe_dirs: list,
    y_fp: str,
    y_col: str = 'm3',
    n_d_precip: int = 7,
    n_d_temp: int = 7,
    n_d_et: int = 8,
    swe_d_rel: list = range(7

### Create an instance

In [3]:
# DATA_DIR and TRAIN_DIR are defined once in the setup cell at the top of the
# notebook; they are reused here rather than recomputed so the two cells
# cannot drift apart.
Y_FP = os.path.join(DATA_DIR, "streamgage-full.csv")  # CSV passed as `y_fp`

cnn_data = CNNSeqDataset(
    # Every image type is read from the same local download directory.
    precip_dirs=[TRAIN_DIR],
    temp_dirs=[TRAIN_DIR],
    et_dirs=[TRAIN_DIR],
    swe_dirs=[TRAIN_DIR],
    # Column of the CSV used for the 'y' values.
    y_fp=Y_FP,
    y_col='m3',
    # Window sizes. The first three equal the defaults shown in the init
    # signature; n_d_y = 14 matches the 14 daily 'y' values in the samples
    # printed further down.
    # NOTE(review): presumably the n_d_* values are numbers of days per
    # sample -- confirm against the class docstring.
    n_d_precip=7,
    n_d_temp=7,
    n_d_et=8,
    swe_d_rel=range(7, 85, 7),
    n_d_y=14,
    # Date limits, written as 'YYYY_MM_DD' strings.
    # NOTE(review): val_start and test_start are later than max_date, so
    # presumably only training samples exist for this demo window -- confirm
    # that this is intended.
    min_date='2010_01_01',
    max_date='2010_06_30',
    val_start='2015_01_01',
    test_start='2016_01_01',
    # Fixed seed so the shuffled training order is repeatable.
    random_seed=42,
    shuffle_train=True,
)

### Create the filepath generator

In [4]:
# Iterator over training samples described by file paths rather than loaded
# arrays (see the sample printed by the next cell).
fp_gen = cnn_data.train_filepath_generator()

In [5]:
# Example usage: pull a single sample from the generator.
# In the output below, 'y' is a date-indexed pandas Series named 'm3' with
# 14 daily values and 'temp' is a list of .tif file paths (the output is
# truncated, so further keys may follow).
next(fp_gen)

{'y': 2010-05-01    0.311485
 2010-05-02    0.311485
 2010-05-03    0.311485
 2010-05-04    0.283168
 2010-05-05    0.283168
 2010-05-06    0.283168
 2010-05-07    0.283168
 2010-05-08    0.311485
 2010-05-09    0.311485
 2010-05-10    0.311485
 2010-05-11    0.311485
 2010-05-12    0.311485
 2010-05-13    0.311485
 2010-05-14    0.311485
 Name: m3, dtype: float64,
 'temp': ['/Users/tp/projects/discharge-estimation/data/training_data/11208000__EPSG_4326__11131_95__ECMWF_ERA5_LAND_HOURLY__temperature_2m__2010_04_24.tif',
  '/Users/tp/projects/discharge-estimation/data/training_data/11208000__EPSG_4326__11131_95__ECMWF_ERA5_LAND_HOURLY__temperature_2m__2010_04_25.tif',
  '/Users/tp/projects/discharge-estimation/data/training_data/11208000__EPSG_4326__11131_95__ECMWF_ERA5_LAND_HOURLY__temperature_2m__2010_04_26.tif',
  '/Users/tp/projects/discharge-estimation/data/training_data/11208000__EPSG_4326__11131_95__ECMWF_ERA5_LAND_HOURLY__temperature_2m__2010_04_27.tif',
  '/Users/tp/projects/di

### Create the data generator

In [6]:
# Iterator over training samples with the data loaded into arrays
# (see the sample printed by the next cell).
data_gen = cnn_data.train_data_generator()

In [7]:
# Example usage: pull a single sample from the generator.
# In the output below, 'y' is an array of 14 values and 'temp' is a list of
# small float32 arrays (the output is truncated, so further keys may follow).
next(data_gen)

{'y': array([0.31148531, 0.31148531, 0.31148531, 0.28316847, 0.28316847,
        0.28316847, 0.28316847, 0.31148531, 0.31148531, 0.31148531,
        0.31148531, 0.31148531, 0.31148531, 0.31148531]),
 'temp': [array([[ 0.2824267 , -0.17220944, -0.39573893],
         [ 0.76749545,  0.26693508, -0.22122857]], dtype=float32),
  array([[ 0.6811154 ,  0.10682356, -0.10941901],
         [ 1.1166869 ,  0.5666891 ,  0.05795383]], dtype=float32),
  array([[0.7745476 , 0.20319271, 0.0075217 ],
         [1.2825272 , 0.7209651 , 0.20746155]], dtype=float32),
  array([[0.74667203, 0.25263447, 0.21487653],
         [1.323935  , 0.790432  , 0.35529163]], dtype=float32),
  array([[ 0.12749757, -0.30015096, -0.48252878],
         [ 0.64526564,  0.15519382, -0.29205725]], dtype=float32),
  array([[-0.69660705, -1.1343173 , -1.3260139 ],
         [-0.1503188 , -0.6516305 , -1.1185268 ]], dtype=float32),
  array([[-5.4362875e-01, -1.0025892e+00, -1.2399712e+00],
         [ 6.1044469e-04, -5.3219676e-01, -1