In [1]:
import os
# Restrict the process to GPU 0. This has to happen before TensorFlow sets up
# its GPU runtime, so keep it above the tensorflow import.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import tensorflow as tf

# Query the visible GPUs once and report both the count and the device list.
_gpus = tf.config.experimental.list_physical_devices('GPU')
print("Num GPUs Available: ", len(_gpus))
print("Num GPUs Available: ", _gpus)

import io
import os
import pickle
import h5py
import argparse
from datetime import datetime
from functools import partial
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv3D, ConvLSTM2D, BatchNormalization
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
from solarsat.utils import SolarSatSequence


2024-04-29 16:03:20.111201: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcuda.so.1
2024-04-29 16:03:20.124203: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1555] Found device 0 with properties: 
pciBusID: 0000:2f:00.0 name: Tesla V100-PCIE-16GB computeCapability: 7.0
coreClock: 1.38GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s
2024-04-29 16:03:20.125168: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcudart.so.10.1
2024-04-29 16:03:20.212929: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcublas.so.10
2024-04-29 16:03:20.267524: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcufft.so.10
2024-04-29 16:03:20.309089: I tensorflow/stream_executor/platform/default/dso_loader.cc:44] Successfully opened dynamic library libcurand.so

Num GPUs Available:  1
Num GPUs Available:  [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [2]:
# from solarsat.display import get_cmap
# Root directory for everything this notebook writes (logs, trained networks,
# data). Defaults to the original cluster path; set the SOLARSAT_HOME
# environment variable to run it somewhere else.
HOME_PATH = os.environ.get(
    'SOLARSAT_HOME', '/gpfs/data1/lianggp/lir/solar_data/process_results')

# Timestamp shared by this run's log directory and trained-network directory.
datetag = datetime.now().strftime("%Y%m%d_%H%M%S")

# TensorBoard: scalar logs under logdir, test-image summaries under logdir/imgs.
logdir = f'{HOME_PATH}/solarsat_cnnlstm/logs/' + datetag
print(logdir)
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=logdir)
file_writer_tstimg = tf.summary.create_file_writer(logdir + '/imgs')

# Directory for saved networks/checkpoints of this run (created eagerly).
netsave_dir = f'{HOME_PATH}/solarsat_cnnlstm/trained_networks/' + datetag
Path(netsave_dir).mkdir(parents=True, exist_ok=True)

def CNNLSTM(input_shape=None, filters=40, n_convlstm=4, optimizer='adadelta'):
    """Build and compile a stacked ConvLSTM2D sequence-to-sequence model.

    Architecture: BatchNormalization, then ``n_convlstm`` stages of
    (ConvLSTM2D with ``return_sequences=True`` + BatchNormalization), then a
    single-filter sigmoid Conv3D head. Compiled with MSE loss.

    Args:
        input_shape: Optional per-sample shape ``(time, rows, cols, channels)``
            (``None`` entries allowed), applied to the first layer. When
            ``None`` (default) the model is built lazily from the first batch
            it receives, which is how the original model was effectively built.
        filters: Number of filters in every ConvLSTM2D layer.
        n_convlstm: Number of ConvLSTM2D + BatchNormalization stages.
        optimizer: Anything ``Model.compile`` accepts as an optimizer.

    Returns:
        The compiled ``Sequential`` model.
    """
    seq = Sequential()
    # input_shape only takes effect on the FIRST layer of a Sequential model.
    # The original code passed input_shape=(None, 60, 60, 1) to the second
    # layer, where it had no effect: the recorded run trained on 30x30 frames
    # without a shape error.
    first_layer_kwargs = {} if input_shape is None else {'input_shape': input_shape}
    seq.add(BatchNormalization(**first_layer_kwargs))

    # Same layer order as before, so auto-generated layer names are unchanged.
    for _ in range(n_convlstm):
        seq.add(ConvLSTM2D(filters=filters, kernel_size=(3, 3),
                           padding='same', return_sequences=True))
        seq.add(BatchNormalization())

    # Sigmoid output: presumably the targets are scaled to [0, 1] -- TODO confirm.
    seq.add(Conv3D(filters=1, kernel_size=(3, 3, 3),
                   activation='sigmoid',
                   padding='same', data_format='channels_last'))
    # NOTE: tf.keras's Adadelta defaults to learning_rate=0.001 (legacy Keras
    # used 1.0), which can make training very slow. Pass e.g.
    # tf.keras.optimizers.Adadelta(learning_rate=1.0) via `optimizer` if needed.
    seq.compile(loss='mse', optimizer=optimizer)
    return seq
    

def main():
    """Load SolarSat data, build the ConvLSTM model and train it to reconstruct its input.

    Returns:
        The trained Keras model. Previously it was discarded on return; note
        that nothing is written to ``netsave_dir`` unless ``checkpoint_cb`` is
        passed to ``fit``.
    """
    # args = parse_args()
    # NOTE(review): the recorded run loaded only 20 train / 1 test batches
    # despite these arguments. Presumably they are capped by n_batch_per_epoch
    # in load_generators -- confirm.
    x_trn, x_tst = load_datasets(n_batches_train=50, n_batches_test=10)
    print(len(x_trn))
    print(len(x_trn[0]))
    print(x_trn[0][0].shape)

    cnnlstm = CNNLSTM()

    # Callbacks. They are built here but NOT passed to fit(), matching the
    # original run. To enable them, pass e.g.
    # callbacks=[tensorboard_callback, checkpoint_cb].
    # NOTE(review): testimg_cb will fail on this data as-is. plot_test_images
    # indexes x_test[0] as (N, H, W, C) images, but x_tst[0][0] holds
    # (N, 8, 30, 30) arrays with no channel axis (see run log).
    testimg_cb = tf.keras.callbacks.LambdaCallback(
        on_epoch_end=partial(plot_test_images, x_test=x_tst[0], predict=cnnlstm.predict))
    # save_freq defaults to 'epoch', which replaces the deprecated period=1.
    checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
        netsave_dir + '/weights.{epoch:02d}-{val_loss:.2f}.tf',
        monitor='val_loss', verbose=0,
        save_best_only=True, save_weights_only=False,
        mode='auto')

    # Stack the per-image-type arrays along a new trailing channel axis. With
    # the single 'ssr' type, (N, 8, 30, 30) becomes (N, 8, 30, 30, 1).
    # ConvLSTM2D (channels_last) consumes axis 1 as time.
    x_trn_input = np.stack(x_trn[0], axis=-1)
    x_tst_input = np.stack(x_tst[0], axis=-1)
    print(x_trn_input.shape)

    # Autoencoder-style training: the target is the input itself.
    # max_queue_size / workers / use_multiprocessing were dropped because Keras
    # only uses them for generator / Sequence inputs, not NumPy arrays.
    cnnlstm.fit(x_trn_input, x_trn_input,
                epochs=20,
                batch_size=1,
                validation_data=(x_tst_input, x_tst_input),
                verbose=1)
    return cnnlstm


def _build_or_load_sequence(cache_file, batch_size, n_batch_per_epoch):
    """Return an 'ssr'-to-'ssr' SolarSatSequence, reloading ``cache_file`` if it exists.

    Args:
        cache_file: Path of a previously saved sequence. This code never writes
            it (saving is disabled), so it only exists if left over from an
            earlier session.
        batch_size: Samples per batch.
        n_batch_per_epoch: Batches the sequence yields per epoch.

    Returns:
        The loaded or freshly constructed sequence.
    """
    if os.path.exists(cache_file):
        # NOTE(review): presumably SolarSatSequence.load unpickles this file.
        # The default location is world-writable /tmp, so only trust cache
        # files you created yourself.
        return SolarSatSequence.load(cache_file)
    return SolarSatSequence(x_img_types=['ssr'], y_img_types=['ssr'],
                            batch_size=batch_size,
                            n_batch_per_epoch=n_batch_per_epoch,
                            # Original note: "don't generate sequences". The
                            # recorded run still yielded (N, 8, 30, 30)
                            # arrays -- confirm what unwrap_time does.
                            unwrap_time=True,
                            shuffle=True)


def load_generators():
    """Create (or reload from /tmp) the training and test SolarSat sequences.

    Returns:
        ``(data_gen_trn, data_gen_tst)``. When freshly built, the training
        sequence has 20 batches of 2 samples per epoch and the test sequence
        has 1 batch of 1 sample.

    NOTE(review): both sequences use identical data arguments because the
    train/test date split is disabled, so test samples may overlap training
    samples.
    """
    data_gen_trn = _build_or_load_sequence('/tmp/trn_gen.pkl',
                                           batch_size=2, n_batch_per_epoch=20)
    data_gen_tst = _build_or_load_sequence('/tmp/tst_gen.pkl',
                                           batch_size=1, n_batch_per_epoch=1)
    return data_gen_trn, data_gen_tst


def load_datasets(n_batches_train=100, n_batches_test=50):
    """Load training and test batches from the SolarSat generators.

    Also ensures ``{HOME_PATH}/data`` exists, creating parent directories as needed.

    Args:
        n_batches_train: Number of batches requested from the training generator.
        n_batches_test: Number of batches requested from the test generator.

    Returns:
        ``(x_train, x_test)`` exactly as returned by each generator's
        ``load_batches``. In the recorded run, ``x_train`` had length 2 and its
        first element was a list holding one (N, 8, 30, 30) array.
    """
    # makedirs(exist_ok=True) avoids the exists/mkdir race, and also works
    # when HOME_PATH itself does not exist yet.
    os.makedirs(f'{HOME_PATH}/data', exist_ok=True)
    # TODO: an HDF5 cache ({HOME_PATH}/data/solarsat_cnnlstm_train.h5) used to
    # live here and is currently disabled; data is regenerated on every call.

    data_gen_trn, data_gen_tst = load_generators()
    print('Loading training data')
    x_train = data_gen_trn.load_batches(n_batches=n_batches_train, progress_bar=True)
    print('Loading test data')
    x_test = data_gen_tst.load_batches(n_batches=n_batches_test, progress_bar=True)
    return x_train, x_test


def plot_to_image(figure):
    """Render a matplotlib figure to PNG and return it as a TF image tensor.

    The figure is closed after rendering, so it is not also shown inline in
    the notebook and must not be used afterwards.

    Args:
        figure: The matplotlib ``Figure`` to render.

    Returns:
        A 4-channel (RGBA) image tensor with a leading batch axis of size 1,
        suitable for ``tf.summary.image``.
    """
    with io.BytesIO() as buf:
        # Render *this* figure. plt.savefig would render whichever figure is
        # current, which is not necessarily `figure`.
        figure.savefig(buf, format='png')
        plt.close(figure)
        png_bytes = buf.getvalue()
    image = tf.image.decode_png(png_bytes, channels=4)
    # Add the batch dimension.
    return tf.expand_dims(image, 0)



def plot_test_images(epoch, logs, x_test, predict):
    """Log a 3x6 grid of original/decoded test images to TensorBoard.

    Intended as a Keras ``LambdaCallback`` ``on_epoch_end`` hook, with
    ``x_test`` and ``predict`` bound via ``functools.partial``.

    Args:
        epoch: Epoch index, used as the summary step.
        logs: Keras metric logs (unused).
        x_test: Container whose first element is the array of test inputs.
            It is indexed as ``x_test[0][k, :, :, 0]``, i.e. this assumes
            (N, H, W, C) images.
        predict: Callable mapping a batch of inputs to outputs indexable as
            ``yhat[k, :, :, 0]``.

    NOTE(review): the sequence data used in this notebook ((N, 8, 30, 30)
    arrays without a channel axis, per the run log) does not match this
    layout. Adapt the indexing before enabling the callback.
    """
    # Fixed seed so the same samples are shown every epoch. A private
    # RandomState(2) yields exactly the indices that np.random.seed(2) +
    # np.random.choice did, without resetting the global NumPy RNG each epoch.
    rng = np.random.RandomState(2)
    samples = x_test[0]
    # Sampling is with replacement, so this also works when N < 9.
    idx = rng.choice(samples.shape[0], 9)
    yhat = predict(samples[idx])

    fig, axs = plt.subplots(3, 6, figsize=(15, 10))
    for k in range(9):
        # Three (original, decoded) pairs per row: columns 0/1, 2/3, 4/5.
        row, pair = divmod(k, 3)
        ax_orig, ax_dec = axs[row][2 * pair], axs[row][2 * pair + 1]

        ax_orig.imshow(samples[idx[k], :, :, 0])
        ax_orig.set_xticks([])
        ax_orig.set_yticks([])
        ax_orig.set_xlabel('Original Image')

        ax_dec.imshow(yhat[k, :, :, 0])
        ax_dec.set_xticks([])
        ax_dec.set_yticks([])
        ax_dec.set_xlabel('Decoded Image')

    tst_images = plot_to_image(fig)
    # Log the original/decoded grid as an image summary.
    with file_writer_tstimg.as_default():
        tf.summary.image("CNNLSTM Test Images", tst_images, step=epoch)








if __name__ == '__main__':
    # Entry point: run the full load -> build -> train pipeline.
    main()
    print('hhS')


/gpfs/data1/lianggp/lir/solar_data/process_results/solarsat_cnnlstm/logs/20240429_160332


2024-04-29 16:03:32.412470: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 AVX512F FMA
2024-04-29 16:03:32.421363: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2400000000 Hz
2024-04-29 16:03:32.423381: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x562d015a52b0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2024-04-29 16:03:32.423403: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2024-04-29 16:03:32.424329: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1555] Found device 0 with properties: 
pciBusID: 0000:2f:00.0 name: Tesla V100-PCIE-16GB computeCapability: 7.0
coreClock: 1.38GHz coreCount: 80 deviceMemorySize: 15.78GiB deviceMemoryBandwidth: 836.37GiB/s
2024-04-29 16:03:32.424374: I tensorflow/stream_executor/platform/default/dso

        tile  year  start_index
0     h15v03  2018           60
1     h15v03  2018           64
2     h15v03  2018          156
3     h15v03  2018          160
4     h15v03  2018          252
...      ...   ...          ...
1791  h15v03  2018        34816
1792  h15v03  2018        34908
1793  h15v03  2018        34912
1794  h15v03  2018        35004
1795  h15v03  2018        35008

[1796 rows x 3 columns]
        tile  year  start_index
0     h15v03  2018           60
1     h15v03  2018           64
2     h15v03  2018          156
3     h15v03  2018          160
4     h15v03  2018          252
...      ...   ...          ...
1791  h15v03  2018        34816
1792  h15v03  2018        34908
1793  h15v03  2018        34912
1794  h15v03  2018        35004
1795  h15v03  2018        35008

[1796 rows x 3 columns]
Loading training data


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 20/20 [00:09<00:00,  2.14it/s]


Loading test data


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  6.31it/s]

2
1
(40, 8, 30, 30)





(40, 8, 30, 30, 1)
Train on 40 samples, validate on 1 samples
Epoch 1/20


2024-04-29 16:03:54.390907: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:561] layout failed: Invalid argument: MutableGraphView::SortTopologically error: detected edge(s) creating cycle(s) {'Func/sequential/conv_lst_m2d_1/while_grad/body/_1098/input/_2220' -> 'sequential/conv_lst_m2d_1/while_grad/body/_1098/gradients/AddN', 'Func/sequential/conv_lst_m2d/while_grad/body/_1289/input/_2336' -> 'sequential/conv_lst_m2d/while_grad/body/_1289/gradients/AddN', 'Func/sequential/conv_lst_m2d_2/while_grad/body/_907/input/_2104' -> 'sequential/conv_lst_m2d_2/while_grad/body/_907/gradients/AddN', 'Func/sequential/conv_lst_m2d_3/while_grad/body/_717/input/_1988' -> 'sequential/conv_lst_m2d_3/while_grad/body/_717/gradients/AddN', 'sequential/conv_lst_m2d_3/while/body/_538/mul_5' -> 'sequential/conv_lst_m2d_3/while/body/_538/Identity_4', 'sequential/conv_lst_m2d_3/while/body/_538/mul_2' -> 'sequential/conv_lst_m2d_3/while/body/_538/add_5', 'sequential/conv_lst_m2d_2/while/body/_359/mul_5' 



2024-04-29 16:04:04.134717: E tensorflow/core/grappler/optimizers/meta_optimizer.cc:561] layout failed: Invalid argument: MutableGraphView::SortTopologically error: detected edge(s) creating cycle(s) {'Func/sequential/conv_lst_m2d_3/while/body/_148/input/_330' -> 'sequential/conv_lst_m2d_3/while/body/_148/mul_2', 'sequential/conv_lst_m2d_3/while/body/_148/Identity_4' -> 'sequential/conv_lst_m2d_3/while/next_iteration/_192', 'sequential/conv_lst_m2d_2/while/body/_99/Identity_4' -> 'sequential/conv_lst_m2d_2/while/next_iteration/_143', 'Func/sequential/conv_lst_m2d_2/while/body/_99/input/_292' -> 'sequential/conv_lst_m2d_2/while/body/_99/mul_2', 'sequential/conv_lst_m2d_1/while/body/_50/Identity_4' -> 'sequential/conv_lst_m2d_1/while/next_iteration/_94', 'sequential/conv_lst_m2d_1/while/body/_50/mul_2' -> 'sequential/conv_lst_m2d_1/while/body/_50/add_5', 'sequential/conv_lst_m2d/while/body/_1/add_5' -> 'sequential/conv_lst_m2d/while/body/_1/Identity_5', 'sequential/conv_lst_m2d/while/bod

Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20
hhS
