In [1]:
"""
Un des enjeu sur le DeepLearning et en particulier le Hs NN du WV est d etre capable
de refaire le meme model heteroskedastic_2017.h5 que celui fourni par Sadowski en feb 2021)
Je vais utilier les mêmes libs et le meme training dataset
A Grouazel
April 2021
based on the ntebook https://github.com/hawaii-ai/SAR-Wave-Height/blob/master/notebooks/train_model_heteroskedastic.ipynb
"""

'\nUn des enjeu sur le DeepLearning et en particulier le Hs NN du WV est d etre capable\nde refaire le meme model heteroskedastic_2017.h5 que celui fourni par Sadowski en feb 2021)\nJe vais utilier les mêmes libs et le meme training dataset\nA Grouazel\nApril 2021\n'

In [1]:
# Train neural network to predict significant wave height from SAR spectra.
# Train with heteroskedastic regression uncertainty estimates.
# Author: Peter Sadowski, Dec 2020
import os, sys
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true' # Needed to avoid cudnn bug.
import numpy as np
import h5py

import tensorflow as tf
from tensorflow.keras.utils import Sequence, plot_model
from tensorflow.keras.callbacks import *
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model

#sys.path = ['../'] + sys.path
sys.path.append('/home1/datahome/agrouaze/git/sar_hs_nn/')
from sarhspredictor.lib.sarhs.generator import SARGenerator
from sarhspredictor.lib.sarhs.heteroskedastic import Gaussian_NLL, Gaussian_MSE

# model definition

In [2]:
def define_model():
    # Low-level features.
    inputs = Input(shape=(72, 60, 2))
    x = Conv2D(64, (3, 3), activation='relu')(inputs)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    x = Conv2D(128, (3, 3), activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    x = Conv2D(256, (3, 3), activation='relu')(x)
    x = MaxPooling2D(pool_size=(2, 2))(x)

    x = GlobalMaxPooling2D()(x)
    x = Dense(256, activation='relu')(x)
    #x = Dropout(0.5)(x)
    x = Dense(256, activation='relu')(x)
    x = Dropout(0.5)(x)
    cnn = Model(inputs, x)

    # High-level features.
    inp = Input(shape=(32, ))  # 'hsSM', 'hsWW3v2', 'hsALT', 'altID', 'target' -> dropped
    x = Dense(units=256, activation='relu')(inp)
    x = Dense(units=256, activation='relu')(x)
    x = Dense(units=256, activation='relu')(x)
    x = Dense(units=256, activation='relu')(x)
    x = Dense(units=256, activation='relu')(x)
    x = Dense(units=256, activation='relu')(x)
    x = Dense(units=256, activation='relu')(x)
    x = Dense(units=256, activation='relu')(x)
    x = Dense(units=256, activation='relu')(x)
    #x = Dropout(0.5)(x)
    x = Dense(units=256, activation='relu')(x)
    #x = Dropout(0.5)(x)
    x = Dense(units=256, activation='relu')(x)
    x = Dropout(0.5)(x)
    ann = Model(inputs=inp, outputs=x)
    
    # Combine
    combinedInput = concatenate([cnn.output, ann.output])
    x = Dense(256, activation="relu")(combinedInput)
    x = Dropout(0.5)(x)
    x = Dense(256, activation="relu", name='penultimate')(x)  
    x = Dropout(0.5)(x)
    x = Dense(2, activation="softplus", name='output')(x)
    model = Model(inputs=[cnn.input, ann.input], outputs=x)
    return model

In [3]:
momo = define_model()
print(momo.summary())

Model: "model_2"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_2 (InputLayer)            [(None, 32)]         0                                            
__________________________________________________________________________________________________
dense_2 (Dense)                 (None, 256)          8448        input_2[0][0]                    
__________________________________________________________________________________________________
input_1 (InputLayer)            [(None, 72, 60, 2)]  0                                            
__________________________________________________________________________________________________
dense_3 (Dense)                 (None, 256)          65792       dense_2[0][0]                    
____________________________________________________________________________________________

# cell added to get equivalent file sar_hs.h5
 inspired from https://github.com/hawaii-ai/SAR-Wave-Height/blob/master/scripts/create_dataset_from_nc.ipynb
 pas mal de petite modif sur le nom des variables et avec les variables deja assemblees
 
 dataflow:
 training dataset ALT_...nc -> ALT_...processed.nc -> aggregate.h5 -> split by groups .h5

In [4]:
# Reads NetCDF4 file, preprocesses data, and writes hdf5 file.
# This is much simpler than aggregating multiple files, then
# performing preprocessing.
# Author: Peter Sadowski, Dec 2020
import os, sys, h5py
import numpy as np
import glob
from netCDF4 import Dataset
import time
import traceback
from sarhspredictor.lib.sarhs import preprocess

# Source and destination filenames.
#file_src  = "/mnt/lts/nfs_fs02/sadow_lab/preserve/stopa/sar_hs/data/S1B_201905_test01S/S1B_201905_test01S.nc"  # Example file containing single observation.
#file_dest = "/mnt/lts/nfs_fs02/sadow_lab/preserve/stopa/sar_hs/data/S1B_201905_test01S/S1B_201905_test01S_processed.h5"
#file_src = '/home/datawork-cersat-public/cache/project/mpc-sentinel1/analysis/s1_data_analysis/hs_nn/quach2020/validation/input_output/final2/S1B_20190501_ifr_tmp_input_output_quach2020_pythonv2.nc'
file_src = '/home1/datawork/agrouaze/data/sentinel1/cwave/training_dataset_quach2020_python_v2/S1A_ALT_coloc201501S.nc'
out_dd = '/home1/scratch/agrouaze/training_quach_redo_model/'
for sat in ['S1A','S1B']:
    lst_training_files = glob.glob(os.path.join('/home1/datawork/agrouaze/data/sentinel1/cwave/training_dataset_quach2020_python_v2/',sat+'*.nc'))
    if sat=='S1A':
        satellite = 1 # 1=S1A, 0=S1B
    else:
        satellite = 0
    print('nb input files to train for %s : %s'%(sat,len(lst_training_files)))
    for ffii,file_src in enumerate(lst_training_files):
        file_dest = os.path.join(out_dd,os.path.basename(file_src).replace('.nc','_processed.nc'))
        print('file_dest',file_dest,ffii,'/',len(lst_training_files))
        if os.path.exists(os.path.dirname(file_dest)) is False:
            os.makedirs(os.path.dirname(file_dest))
            print('outputdir mkdir')
        # These variables are expected in the source file.
        keys = ['timeSAR', 'lonSAR',  'latSAR', 'incidenceAngle', 'cspcRe', 'cspcIm','py_S','sigma0','normalizedVariance'] # Needed for predictions.
        t0 = time.time()
        try:
            h5py.File(file_dest, 'r').close() #try to close the file if it is opened before
        except:
            print('traceback',traceback.format_exc())
            pass
            
        with Dataset(file_src) as fs, h5py.File(file_dest, 'w') as fd:
            # Check input file.
            src = fs.variables
            for k in keys:
                if k not in src.keys():
                    raise IOError(f'Variable {k} not found in input file.')
            num_examples = src[keys[0]].shape[0]
            print(f'Found {num_examples} events.')

            # Get 22 CWAVE features. Concatenate 20 parameters with sigma0 and normVar.
            #src['S'].set_auto_scale(False) # Some of the NetCDF4 files had some weird scaling.
            S = np.array(src['py_S'][:]) #* float(src['py_S'].scale_factor))
            cwave = np.hstack([S, src['sigma0'][:].reshape(-1,1), src['normalizedVariance'][:].reshape(-1,1)])
            #cwave = src['cwave'][:]
            cwave = preprocess.conv_cwave(cwave) # Remove extrema, then standardize with hardcoded mean, vars.
            fd.create_dataset('cwave', data=cwave)

            # Observation meta data.
            latSAR, lonSAR = src['latSAR'][:], src['lonSAR'][:]
            latSARcossin = preprocess.conv_position(latSAR) # Computes cos and sin used by NN.
            lonSARcossin = preprocess.conv_position(lonSAR)
            #latlonSARcossin = src['latlonSARcossin'][:]
            fd.create_dataset('latlonSAR', data=np.column_stack([latSAR, lonSAR]))
            fd.create_dataset('latlonSARcossin', data=np.hstack([latSARcossin, lonSARcossin]))
            #fd.create_dataset('latlonSARcossin', data=latlonSARcossin)

            timeSAR = src['timeSAR'][:]
            todSAR = preprocess.conv_time(timeSAR)
            #todSAR = src['todSAR'][:]
            fd.create_dataset('timeSAR', data=timeSAR, shape=(timeSAR.shape[0], 1))
            fd.create_dataset('todSAR', data=todSAR, shape=(todSAR.shape[0], 1))

            incidence = preprocess.conv_incidence(src['incidenceAngle'][:]) # Separates into 2 var.
            fd.create_dataset('incidence', data=incidence)

            satellite_indicator = np.ones((src['timeSAR'].shape[0], 1), dtype=float) * satellite
            fd.create_dataset('satellite', data=satellite_indicator, shape=(satellite_indicator.shape[0], 1))

            # Spectral data.
            re = preprocess.conv_real(src['cspcRe'][:])
            im = preprocess.conv_imaginary(src['cspcIm'][:])
            x = np.stack((re, im), axis=3)
            fd.create_dataset('spectrum', data=x)

            # Altimeter features.
            hsALT = src['hsALT'][:]
            fd.create_dataset('hsALT', data=hsALT, shape=(hsALT.shape[0], 1))
            dx = preprocess.conv_dx(src['dx'][:])
            dt = preprocess.conv_dt(src['dt'][:])
            fd.create_dataset('dxdt', data=np.column_stack([dx, dt]))
            
            timeALT = src['timeALT'][:] #added by agrouaze
            fd.create_dataset('timeALT',data=timeALT, shape=(todSAR.shape[0], 1))
            
            lonALT = src['lonALT'][:] #added by agrouaze
            fd.create_dataset('lonALT', data=lonALT)
            
            latALT = src['latALT'][:] #added by agrouaze
            fd.create_dataset('latALT', data=latALT)
              
            fd.create_dataset('hsSM', data=src['hsSM'][:]) #added by agrouaze
            fd.create_dataset('nk', data=src['nk'][:]) #added by agrouaze
            fd.create_dataset('dx', data=src['dx'][:]) #added by agrouaze
            fd.create_dataset('dt', data=src['dt'][:]) #added by agrouaze
            fd.create_dataset('sigma0', data=src['sigma0'][:]) #added by agrouaze
            fd.create_dataset('normalizedVariance', data=src['normalizedVariance'][:]) #added by agrouaze
            fd.create_dataset('incidenceAngle', data=src['incidenceAngle'][:]) #added by agrouaze
            fd.create_dataset('lonSAR', data=src['lonSAR'][:]) #added by agrouaze
            fd.create_dataset('latSAR', data=src['latSAR'][:]) #added by agrouaze
            fd.create_dataset('cspcRe', data=src['cspcRe'][:]) #added by agrouaze
            fd.create_dataset('cspcIm', data=src['cspcIm'][:]) #added by agrouaze
            fd.create_dataset('py_S', data=S) #added by agrouaze
        print('elapsed time to build %s: %1.3f seconds'%(file_dest,time.time()-t0))

nb input files to train for S1A : 42
file_dest /home1/scratch/agrouaze/training_quach_redo_model/S1A_ALT_coloc201607S_processed.nc
Found 11346 events.


  a.partition(kth, axis=axis, kind=kind, order=order)
cannot be safely cast to variable data type
cannot be safely cast to variable data type


elapsed time to build /home1/scratch/agrouaze/training_quach_redo_model/S1A_ALT_coloc201607S_processed.nc: 12.527 seconds
file_dest /home1/scratch/agrouaze/training_quach_redo_model/S1A_ALT_coloc201601S_processed.nc
Found 9760 events.


cannot be safely cast to variable data type


elapsed time to build /home1/scratch/agrouaze/training_quach_redo_model/S1A_ALT_coloc201601S_processed.nc: 10.456 seconds
file_dest /home1/scratch/agrouaze/training_quach_redo_model/S1A_ALT_coloc201705S_processed.nc
Found 139 events.


cannot be safely cast to variable data type


elapsed time to build /home1/scratch/agrouaze/training_quach_redo_model/S1A_ALT_coloc201705S_processed.nc: 0.663 seconds
file_dest /home1/scratch/agrouaze/training_quach_redo_model/S1A_ALT_coloc201602S_processed.nc
Found 12874 events.
elapsed time to build /home1/scratch/agrouaze/training_quach_redo_model/S1A_ALT_coloc201602S_processed.nc: 13.664 seconds
file_dest /home1/scratch/agrouaze/training_quach_redo_model/S1A_ALT_coloc201707S_processed.nc
Found 18632 events.
elapsed time to build /home1/scratch/agrouaze/training_quach_redo_model/S1A_ALT_coloc201707S_processed.nc: 20.422 seconds
file_dest /home1/scratch/agrouaze/training_quach_redo_model/S1A_ALT_coloc201506S_processed.nc
Found 5565 events.
elapsed time to build /home1/scratch/agrouaze/training_quach_redo_model/S1A_ALT_coloc201506S_processed.nc: 6.063 seconds
file_dest /home1/scratch/agrouaze/training_quach_redo_model/S1A_ALT_coloc201610S_processed.nc
Found 15740 events.
elapsed time to build /home1/scratch/agrouaze/training_quac

# aggregate the monthly processed files

In [6]:
from sarhspredictor.bin import aggregate_monthly_training_files
import glob
import logging
from importlib import reload
reload(logging)
logging.basicConfig(level=logging.INFO)
from importlib import reload
reload(aggregate_monthly_training_files)
files_src = sorted(glob.glob('/home1/scratch/agrouaze/training_quach_redo_model/*_processed.nc'))
print(f'Found {len(files_src)} files.')
print(files_src[0])
# file_dest =  "/mnt/lts/nfs_fs02/sadow_lab/preserve/stopa/sar_hs/data/alt/aggregated_ALT.h5"
# file_dest =  "/mnt/tmp/psadow/sar/aggregated_ALT.h5"
# file_dest = "/mnt/tmp/psadow/sar/aggregated_2019.h5"
# file_dest =  "/mnt/lts/nfs_fs02/sadow_lab/preserve/stopa/sar_hs/data/alt/aggregated_2019.h5"
file_dest = os.path.join('/home1/scratch/agrouaze/training_quach_redo_model/',"aggregated.h5")

# keys = ['timeSAR', 'timeALT', 'lonSAR', 'lonALT', 'latSAR', 'latALT', 'hsALT', 'dx', 'dt', 'nk', 'hsSM', 'incidenceAngle', 'sigma0', 'normalizedVariance', 'S']
# keys = ['timeSAR', 'lonSAR',  'latSAR', 'incidenceAngle', 'sigma0', 'normalizedVariance', 'S']
# keys += ['cspcRe', 'cspcIm']
# keys = ['timeSAR', 'lonSAR',  'latSAR', 'incidenceAngle', 'sigma0', 'normalizedVariance', 'py_S', 'cspcRe', 'cspcIm'] #'py_cspcRe', 'py_cspcIm']
keys = ['timeSAR','timeALT','lonSAR','lonALT','latSAR','latALT','hsALT','dx','dt','nk','hsSM','incidenceAngle','sigma0',
        'normalizedVariance','cspcRe','cspcIm','cwave','todSAR','py_S'] #
aggregate_monthly_training_files.aggregate(files_src,file_dest,keys=keys)
logging.info('done')

Found 65 files.
/home1/scratch/agrouaze/training_quach_redo_model/S1A_ALT_coloc201501S_processed.nc


  0%|          | 0/65 [00:00<?, ?it/s]

['S1A', 'ALT', 'coloc201501S', 'processed', 'nc']


  2%|▏         | 1/65 [00:02<02:29,  2.34s/it]

['S1A', 'ALT', 'coloc201502S', 'processed', 'nc']


  3%|▎         | 2/65 [00:03<02:06,  2.01s/it]

['S1A', 'ALT', 'coloc201503S', 'processed', 'nc']


  5%|▍         | 3/65 [00:04<01:51,  1.79s/it]

['S1A', 'ALT', 'coloc201504S', 'processed', 'nc']


  6%|▌         | 4/65 [00:06<01:37,  1.60s/it]

['S1A', 'ALT', 'coloc201505S', 'processed', 'nc']


  8%|▊         | 5/65 [00:08<01:43,  1.73s/it]

['S1A', 'ALT', 'coloc201506S', 'processed', 'nc']


  9%|▉         | 6/65 [00:13<02:44,  2.79s/it]

['S1A', 'ALT', 'coloc201507S', 'processed', 'nc']


 11%|█         | 7/65 [00:17<03:07,  3.24s/it]

['S1A', 'ALT', 'coloc201508S', 'processed', 'nc']


 12%|█▏        | 8/65 [00:24<04:04,  4.30s/it]

['S1A', 'ALT', 'coloc201509S', 'processed', 'nc']


 14%|█▍        | 9/65 [00:29<04:21,  4.66s/it]

['S1A', 'ALT', 'coloc201510S', 'processed', 'nc']


 15%|█▌        | 10/65 [00:37<05:07,  5.60s/it]

['S1A', 'ALT', 'coloc201511S', 'processed', 'nc']


 17%|█▋        | 11/65 [00:43<04:57,  5.52s/it]

['S1A', 'ALT', 'coloc201512S', 'processed', 'nc']


 18%|█▊        | 12/65 [00:51<05:36,  6.35s/it]

['S1A', 'ALT', 'coloc201601S', 'processed', 'nc']


 20%|██        | 13/65 [00:58<05:36,  6.47s/it]

['S1A', 'ALT', 'coloc201602S', 'processed', 'nc']


 22%|██▏       | 14/65 [01:07<06:13,  7.32s/it]

['S1A', 'ALT', 'coloc201603S', 'processed', 'nc']


 23%|██▎       | 15/65 [01:15<06:14,  7.48s/it]

['S1A', 'ALT', 'coloc201604S', 'processed', 'nc']


 25%|██▍       | 16/65 [01:27<07:15,  8.88s/it]

['S1A', 'ALT', 'coloc201605S', 'processed', 'nc']


 26%|██▌       | 17/65 [01:36<07:07,  8.91s/it]

['S1A', 'ALT', 'coloc201606S', 'processed', 'nc']


 28%|██▊       | 18/65 [01:42<06:26,  8.21s/it]

['S1A', 'ALT', 'coloc201607S', 'processed', 'nc']


 29%|██▉       | 19/65 [01:50<06:10,  8.05s/it]

['S1A', 'ALT', 'coloc201608S', 'processed', 'nc']


 31%|███       | 20/65 [02:01<06:34,  8.77s/it]

['S1A', 'ALT', 'coloc201609S', 'processed', 'nc']


 32%|███▏      | 21/65 [02:09<06:16,  8.55s/it]

['S1A', 'ALT', 'coloc201610S', 'processed', 'nc']


 34%|███▍      | 22/65 [02:19<06:34,  9.18s/it]

['S1A', 'ALT', 'coloc201611S', 'processed', 'nc']


 35%|███▌      | 23/65 [02:30<06:46,  9.69s/it]

['S1A', 'ALT', 'coloc201612S', 'processed', 'nc']


 37%|███▋      | 24/65 [02:42<07:08, 10.45s/it]

['S1A', 'ALT', 'coloc201701S', 'processed', 'nc']


 38%|███▊      | 25/65 [02:54<07:08, 10.70s/it]

['S1A', 'ALT', 'coloc201702S', 'processed', 'nc']


 40%|████      | 26/65 [03:03<06:44, 10.38s/it]

['S1A', 'ALT', 'coloc201703S', 'processed', 'nc']


 42%|████▏     | 27/65 [03:13<06:23, 10.09s/it]

['S1A', 'ALT', 'coloc201705S', 'processed', 'nc']


 43%|████▎     | 28/65 [03:13<04:29,  7.28s/it]

['S1A', 'ALT', 'coloc201706S', 'processed', 'nc']


 45%|████▍     | 29/65 [03:20<04:13,  7.04s/it]

['S1A', 'ALT', 'coloc201707S', 'processed', 'nc']


 46%|████▌     | 30/65 [03:33<05:05,  8.72s/it]

['S1A', 'ALT', 'coloc201708S', 'processed', 'nc']


 48%|████▊     | 31/65 [03:43<05:18,  9.36s/it]

['S1A', 'ALT', 'coloc201709S', 'processed', 'nc']


 49%|████▉     | 32/65 [03:52<05:06,  9.29s/it]

['S1A', 'ALT', 'coloc201710S', 'processed', 'nc']


 51%|█████     | 33/65 [04:01<04:47,  8.97s/it]

['S1A', 'ALT', 'coloc201711S', 'processed', 'nc']


 52%|█████▏    | 34/65 [04:11<04:52,  9.43s/it]

['S1A', 'ALT', 'coloc201712S', 'processed', 'nc']


 54%|█████▍    | 35/65 [04:20<04:33,  9.13s/it]

['S1A', 'ALT', 'coloc201801S', 'processed', 'nc']


 55%|█████▌    | 36/65 [04:29<04:25,  9.14s/it]

['S1A', 'ALT', 'coloc201802S', 'processed', 'nc']


 57%|█████▋    | 37/65 [04:38<04:13,  9.05s/it]

['S1A', 'ALT', 'coloc201803S', 'processed', 'nc']


 58%|█████▊    | 38/65 [04:52<04:49, 10.72s/it]

['S1A', 'ALT', 'coloc201804S', 'processed', 'nc']


 60%|██████    | 39/65 [05:02<04:29, 10.38s/it]

['S1A', 'ALT', 'coloc201805S', 'processed', 'nc']


 62%|██████▏   | 40/65 [05:13<04:23, 10.54s/it]

['S1A', 'ALT', 'coloc201806S', 'processed', 'nc']


 63%|██████▎   | 41/65 [05:18<03:32,  8.87s/it]

['S1A', 'ALT', 'coloc201807S', 'processed', 'nc']


 65%|██████▍   | 42/65 [05:22<02:50,  7.41s/it]

['S1B', 'ALT', 'coloc201606S', 'processed', 'nc']


 66%|██████▌   | 43/65 [05:23<01:59,  5.45s/it]

['S1B', 'ALT', 'coloc201607S', 'processed', 'nc']


 68%|██████▊   | 44/65 [05:29<01:59,  5.68s/it]

['S1B', 'ALT', 'coloc201608S', 'processed', 'nc']


 69%|██████▉   | 45/65 [05:40<02:25,  7.28s/it]

['S1B', 'ALT', 'coloc201609S', 'processed', 'nc']


 71%|███████   | 46/65 [05:49<02:27,  7.77s/it]

['S1B', 'ALT', 'coloc201610S', 'processed', 'nc']


 72%|███████▏  | 47/65 [06:00<02:40,  8.93s/it]

['S1B', 'ALT', 'coloc201611S', 'processed', 'nc']


 74%|███████▍  | 48/65 [06:12<02:47,  9.84s/it]

['S1B', 'ALT', 'coloc201612S', 'processed', 'nc']


 75%|███████▌  | 49/65 [06:26<02:54, 10.92s/it]

['S1B', 'ALT', 'coloc201701S', 'processed', 'nc']


 77%|███████▋  | 50/65 [06:37<02:44, 10.98s/it]

['S1B', 'ALT', 'coloc201702S', 'processed', 'nc']


 78%|███████▊  | 51/65 [06:47<02:31, 10.79s/it]

['S1B', 'ALT', 'coloc201703S', 'processed', 'nc']


 80%|████████  | 52/65 [06:53<01:59,  9.18s/it]

['S1B', 'ALT', 'coloc201707S', 'processed', 'nc']


 82%|████████▏ | 53/65 [07:05<02:02, 10.21s/it]

['S1B', 'ALT', 'coloc201708S', 'processed', 'nc']


 83%|████████▎ | 54/65 [07:17<01:56, 10.59s/it]

['S1B', 'ALT', 'coloc201709S', 'processed', 'nc']


 85%|████████▍ | 55/65 [07:26<01:43, 10.31s/it]

['S1B', 'ALT', 'coloc201710S', 'processed', 'nc']


 86%|████████▌ | 56/65 [07:35<01:27,  9.74s/it]

['S1B', 'ALT', 'coloc201711S', 'processed', 'nc']


 88%|████████▊ | 57/65 [07:45<01:19,  9.96s/it]

['S1B', 'ALT', 'coloc201712S', 'processed', 'nc']


 89%|████████▉ | 58/65 [07:54<01:06,  9.54s/it]

['S1B', 'ALT', 'coloc201801S', 'processed', 'nc']


 91%|█████████ | 59/65 [08:04<00:57,  9.61s/it]

['S1B', 'ALT', 'coloc201802S', 'processed', 'nc']


 92%|█████████▏| 60/65 [08:13<00:48,  9.65s/it]

['S1B', 'ALT', 'coloc201803S', 'processed', 'nc']


 94%|█████████▍| 61/65 [08:31<00:47, 11.93s/it]

['S1B', 'ALT', 'coloc201804S', 'processed', 'nc']


 95%|█████████▌| 62/65 [08:41<00:34, 11.50s/it]

['S1B', 'ALT', 'coloc201805S', 'processed', 'nc']


 97%|█████████▋| 63/65 [08:52<00:22, 11.41s/it]

['S1B', 'ALT', 'coloc201806S', 'processed', 'nc']


 98%|█████████▊| 64/65 [08:57<00:09,  9.52s/it]

['S1B', 'ALT', 'coloc201807S', 'processed', 'nc']


100%|██████████| 65/65 [09:09<00:00,  8.45s/it]
INFO:root:done


# split by groups

In [5]:
# the training dataset must be separated into sub groups
# long long task (about 30min)
import split_aggregated_into_groups
from importlib import reload
reload(split_aggregated_into_groups)
file_src2 = os.path.join('/home1/scratch/agrouaze/training_quach_redo_model/',"aggregated.h5")
print('source ',file_src2)
file_dest2 = '/home1/scratch/agrouaze/training_quach_redo_model/aggregated_grouped_final.h5'
if os.path.exists(file_dest2):
    os.remove(file_dest2)
split_aggregated_into_groups.split_aggregated_ds_v2(file_src2,file_dest2)

source  /home1/scratch/agrouaze/training_quach_redo_model/aggregated.h5
k cspcIm
k cspcRe
k cwave
k dt
k dx
k hsALT
k hsSM
k incidenceAngle
k latALT
k latSAR
k lonALT
k lonSAR
k month
k nk
k normalizedVariance
k py_S
k satellite
k sigma0
k timeALT
k timeSAR
k todSAR
k year
start creating the final .h5 file
indices 2015 (766426,)
indices 2016 (766426,)
Found 304379 events from years:  [2015, 2016]
timeSAR (766426, 1)
Done with [2015, 2016]
indices 2017 (766426,)
Found 265052 events from years:  [2017]
timeSAR (766426, 1)
Done with [2017]
indices 2018 (766426,)
Found 185151 events from years:  [2018]
timeSAR (766426, 1)
Done with [2018]
Done


# training keras

In [None]:
# Train
from sarhspredictor.config import model_IFR_replication_quach2020_sadowski_release_5feb2021
#file_model = '/home1/scratch/agrouaze/heteroskedastic_2017_agrouaze.h5'
file_model = model_IFR_replication_quach2020_sadowski_release_5feb2021 #
print(file_model)
model = define_model()
model.compile(loss=Gaussian_NLL, optimizer=Adam(lr=0.0001), metrics=[Gaussian_MSE])

# Dataset
batch_size = 128
epochs = 123
#filename = '../../data/alt/sar_hs.h5'
#filename = '/mnt/tmp/psadow/sar/sar_hs.h5'
filename = file_dest2
print(file_dest2)
train = SARGenerator(filename=filename, 
                     subgroups=['2015_2016', '2017'], 
                     batch_size=batch_size)
valid = SARGenerator(filename=filename, subgroups=['2018'], batch_size=batch_size)
# filename = '/mnt/tmp/psadow/sar/sar_hs.h5'
# epochs = 25
# train = SARGenerator(filename=filename, 
#                      subgroups=['2015_2016', '2017', '2018'], # Train on all data without early stopping.
#                      batch_size=batch_size)

# Callbacks
# This LR schedule is slower than in the paper.
reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.9, patience=1) 
check = ModelCheckpoint(file_model, monitor='val_loss', verbose=0,
                        save_best_only=True, save_weights_only=False,
                        mode='auto', save_freq='epoch')
stop = EarlyStopping(monitor='val_loss', min_delta=0, patience=10, verbose=0, 
                     mode='auto', baseline=None, restore_best_weights=False)
clbks = [reduce_lr, check, stop]

history = model.fit(train,
                    epochs=epochs,
                    validation_data=valid,
                    callbacks=clbks,
                    verbose=1)

/home1/datahome/agrouaze/sources/sentinel1/hs_total/validation_quach2020/heteroskedastic_2017_agrouaze.h5
/home1/scratch/agrouaze/training_quach_redo_model/aggregated_grouped_final.h5
Epoch 1/123
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has no attribute 'Index'
Please report this to the TensorFlow team. When filing the bug, set the verbosity to 10 (on Linux, `export AUTOGRAPH_VERBOSITY=10`) and attach the full output.
Cause: module 'gast' has n