# Train solar models

In [None]:
# import packages
import json
import logging
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
%matplotlib inline
import imp
import numpy as np
import os
import random
import rasterio
import shapely
import tensorflow as tf

import descarteslabs as dl

In [None]:
# Import local modules
import train
import generator
import transforms

In [None]:
# Define parameters.
# NOTE: for demonstration this run uses epochs=1 and steps_per_epoch=2 so it
# finishes quickly; the value noted next to 'epochs' (150) was the original
# setting.

# For full training, use:
# params = train.params

# For testing, define the parameters here
params = {
    'seed': 21,  # for train/val data split

    # Training data specifications
    # DATASET METADATA #
    'data_metadata': {
        'products': ['airbus:oneatlas:spot:v2'],
        'bands': ['red', 'green', 'blue', 'nir'],
        'resolution': 1.5,
        'start_datetime': '2016-01-01',
        'end_datetime': '2018-12-31',
        'tilesize': 512,
        'pad': 0,
    },

    # GLOBAL METADATA #
    'global_metadata': {
        'local_ground': 'ground/',  # directory containing image-target pairs
        'local_model': 'model/',  # directory to write this model
    },

    # MODEL METADATA
    'model_name': 'solar_pv_airbus_spot_rgbn_v5',

    # TRAINING METADATA #
    # Metadata to define the training stage
    'training_kwargs': {
        'datalist': 'train_keys.txt',
        'batchsize': 16,
        'val_datalist': 'val_keys.txt',
        'val_batchsize': 16,
        'epochs': 1,  # reduced for the demo; previously 150
        'steps_per_epoch': 2,
        # Size of the training images. Matches tilesize x tilesize x len(bands)
        # above; presumably these must be kept in sync.
        'image_dim': (512, 512, 4),
    },
    # Transforms listed in the order they appear; the noise and flip entries
    # are training-time augmentations (behavior defined in the local
    # `transforms` module).
    'transforms': [
        transforms.CastTransform(feature_type='float32', target_type='bool'),
        transforms.SquareImageTransform(),
        transforms.AdditiveNoiseTransform(additive_noise=30.),
        transforms.MultiplicativeNoiseTransform(multiplicative_noise=0.3),
        transforms.NormalizeFeatureTransform(mean=128., std=1.),
        transforms.FlipFeatureTargetTransform(),
    ],
}


In [None]:
# Inspect the training settings held in params before launching training.
print(params["training_kwargs"])

In [None]:
# Train the model using the full parameter dictionary defined above.
# Judging by the paths read in later cells, the log and the saved model end up
# under params['global_metadata']['local_model'] -- TODO confirm in train.py.
train.train_from_document(params=params)

In [None]:
# Print the training log. The path is built from params, which reproduces the
# original hard-coded 'model/train_solar_pv_airbus_spot_rgbn_v5.log'; it
# assumes training writes '<local_model>/train_<model_name>.log' (confirm in
# train.py). Reading in Python avoids depending on the IPython `!cat` magic.
log_path = os.path.join(params['global_metadata']['local_model'],
                        'train_' + params['model_name'] + '.log')
with open(log_path) as log_file:
    print(log_file.read(), end='')

## Load the model and predict on one training image

In [None]:
# Load the trained model. The path is derived from params and matches the
# original hard-coded 'model/solar_pv_airbus_spot_rgbn_v5.hdf5'; it assumes
# training saves to '<local_model>/<model_name>.hdf5' (confirm in train.py).
model_path = os.path.join(params['global_metadata']['local_model'],
                          params['model_name'] + '.hdf5')
model = tf.keras.models.load_model(model_path)

In [None]:
# Inference-time transforms: the cast / square / normalize steps also used in
# params['transforms'], omitting the noise and flip augmentation transforms
# (presumably so the sampled batch is not augmented).
trf = [
    transforms.CastTransform(
        feature_type='float32',
        target_type='bool',
    ),
    transforms.SquareImageTransform(),
    transforms.NormalizeFeatureTransform(
        mean=128.,
        std=1.,
    ),
]

In [None]:
# Build a data generator over the training key list so a batch can be fed to
# the loaded model.
kw_train = params['training_kwargs']
data_list = os.path.join(params['global_metadata']['local_ground'], kw_train['datalist'])

# dim is taken from params so it stays in sync with the training configuration
# (it was previously hard-coded to the same value, (512, 512, 4)).
trn_generator = generator.DataGenerator(
    data_list,
    batch_size=2,
    dim=kw_train['image_dim'],
    shuffle=False,  # presumably keeps the batch order fixed across runs
    # NOTE(review): augment=True while sampling for prediction; trf contains no
    # noise/flip transforms, but confirm DataGenerator adds no other augmentation.
    augment=True,
    transforms=trf,
)

In [None]:
# Fetch the first batch (features, targets) via normal indexing rather than
# calling the __getitem__ dunder directly.
img, trg = trn_generator[0]

In [None]:
def img_plt(img):
    """Convert a normalized array back to displayable uint8 values.

    Adds 128 back (assumes features were normalized by subtracting 128, as in
    NormalizeFeatureTransform(mean=128., std=1.)), then clips to [0, 255].

    Clipping must happen *before* the uint8 cast: casting first makes
    out-of-range values wrap around (or be undefined for negatives), and the
    subsequent clip would then be a no-op.
    """
    return np.clip(img + 128, 0, 255).astype('uint8')

# Show sample `ii` of the batch: its first three bands (red, green, blue per
# data_metadata['bands']) next to its target.
ii = 0
fig, ax = plt.subplots(1, 2, figsize=(10, 8))
rgb = img[ii, ..., :3]
target = np.squeeze(trg[ii])
ax[0].imshow(img_plt(rgb))
ax[1].imshow(img_plt(target))

In [None]:
# Run the model on the sampled batch of features.
proba = model.predict(x=img)

In [None]:
# Display the shape of the prediction array.
np.shape(proba)

In [None]:
# Show channel 0 of the model output for the first sample in the batch.
prob_map = proba[0, ..., 0].squeeze()
plt.imshow(prob_map)