# Train the Models for some/all of your datasets

### First, just run the cell below; it should complete without error (expect some warnings from TensorFlow)

In [1]:
%matplotlib notebook
%load_ext autoreload
%autoreload 2


#load some packages in
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import random as python_random
from numba import njit
from tensorboard.plugins.hparams import api as hp
from stemutils.io import Path
import hyperspy.api as hs
import concurrent.futures
from skimage.transform import resize
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from functools import lru_cache

#set some variables
# Report the TensorFlow version in use (printed below the cell).
print('Using TensorFlow v%s' % tf.__version__)
plt.style.use('default')
# Seed the Python, NumPy and TensorFlow random number generators with the same fixed value.
python_random.seed(0)
np.random.seed(0)
tf.random.set_seed(0)


#define some functions

###################################################
########### Data Preprocessing ####################
###################################################

def batch_resize(d, bs=512):
    """Resize every frame of ``d`` to 128x128, processing batches in parallel.

    Parameters
    ----------
    d : array
        Stack of frames indexed along the first axis. A 4-D array is first
        passed through ``flatten_nav`` so that its first two axes are merged
        into a single frame axis.
    bs : int
        Number of frames handed to each worker task.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_frames, 128, 128)`` where ``n_frames`` is the
        length of the (flattened) frame axis.
    """
    flat_d = flatten_nav(d) if len(d.shape) == 4 else d
    n_frames = flat_d.shape[0]
    if n_frames == 0:
        # Nothing to resize; avoid indexing into an empty batch list below.
        return np.empty((0, 128, 128))

    # Step through the frame axis in strides of ``bs``; the last batch may be short.
    batches = [flat_d[start:start + bs] for start in range(0, n_frames, bs)]
    print(len(batches[-1]))

    with concurrent.futures.ProcessPoolExecutor() as exe:
        futures = [
            exe.submit(resize, batch, (batch.shape[0], 128, 128))
            for batch in batches
        ]
        r_batches = [f.result() for f in futures]

    # Use the flattened frame count: ``d.shape[0]`` is only one navigation axis
    # when ``d`` is 4-D, which made the original reshape fail for such input.
    return np.concatenate(r_batches, axis=0).reshape((n_frames, 128, 128))

def data_manip(d, bs = 512):
    """Resize, normalise and log-scale a stack of frames.

    Steps: convert a non-NumPy input to NumPy via ``.compute()``, resize with
    ``batch_resize``, divide each frame by its own maximum, then apply
    ``log(1000 * x + 1) / log(1001)`` so that a value of 1 maps to 1.

    Parameters
    ----------
    d : numpy.ndarray or object exposing ``.compute()``
        Input frames, as accepted by ``batch_resize``.
    bs : int
        Batch size forwarded to ``batch_resize``.

    Returns
    -------
    numpy.ndarray
        The processed frames, same shape as the ``batch_resize`` output.
    """
    # isinstance also accepts ndarray subclasses, which have no ``.compute()``.
    if not isinstance(d, np.ndarray):
        print('dask to numpy')
        d = d.compute()
        print('dask to numpy done')
    print('started data manipulations')
    d = batch_resize(d, bs)
    print('resized')
    d_maxes = np.max(d,(1,2))[:,None,None]
    # Frames whose maximum is 0 are left untouched instead of being divided by
    # zero, which previously filled them with NaN and poisoned the training loss.
    np.divide(d, d_maxes, out=d, where=d_maxes != 0)
    scaler = np.log(1001)
    return np.log((d*1000)+1)/scaler


###################################################
###################################################
###################################################

def flatten_nav(sig):
    """Merge the first two axes of ``sig`` into one, leaving the other axes as they are."""
    merged_len = sig.shape[0] * sig.shape[1]
    trailing_dims = list(sig.shape[2:])
    return sig.reshape([merged_len] + trailing_dims)


class My_Custom_Generator(keras.utils.Sequence) :
    """Keras ``Sequence`` that loads batches of ``.npy`` images from disk.

    Each item is a pair ``(x, x)``: the same array is returned as input and
    target. Loaded batches are cached per instance so each file is read once.
    """

    def __init__(self, image_filenames,  batch_size) :
        super().__init__()
        self.image_filenames = image_filenames
        self.batch_size = batch_size
        # Per-instance cache (batch index -> batch). A class-level lru_cache on
        # the method would keep every generator and its data alive indefinitely.
        self._batch_cache = {}

    def __len__(self) :
        # Number of batches, counting a final partial batch.
        # Plain int(): the ``np.int`` alias has been removed from NumPy.
        return int(np.ceil(len(self.image_filenames) / float(self.batch_size)))

    def __getitem__(self, idx) :
        cached = self._batch_cache.get(idx)
        if cached is not None:
            return cached
        batch_x = self.image_filenames[idx * self.batch_size : (idx+1) * self.batch_size]
        # Each loaded array gets a trailing axis of length 1 before stacking.
        out_img = np.asarray([np.load(file_name)[:,:,None] for file_name in batch_x])
        batch = (out_img, out_img)
        self._batch_cache[idx] = batch
        return batch
        
        
class Array_Generator(keras.utils.Sequence) :
    """Keras ``Sequence`` that serves batches sliced from an in-memory array.

    Each item is a pair ``(x, x)``: the same slice is returned as input and
    target. ``images`` is indexed as ``images[a:b, :, :, None]``, so it must
    support that form of indexing.
    """

    def __init__(self, images,  batch_size) :
        super().__init__()
        self.images = images
        self.batch_size = batch_size

    def __len__(self) :
        # Number of batches, counting a final partial batch.
        # Plain int(): the ``np.int`` alias has been removed from NumPy.
        return int(np.ceil(len(self.images) / float(self.batch_size)))

    def __getitem__(self, idx) :
        # No caching: the slice is cheap to take, and an lru_cache on the method
        # would pin every generator (and its whole image array) in memory.
        start = idx * self.batch_size
        stop = (idx+1) * self.batch_size
        out_img = self.images[start:stop, :,:,None]
        return out_img, out_img

class Sampling(layers.Layer):
    """Draw a latent sample as ``z_mean + exp(0.5 * z_log_var) * noise``."""

    def call(self, inputs):
        z_mean, z_log_var = inputs
        # Standard-normal noise with the same shape as the mean tensor.
        noise = tf.keras.backend.random_normal(shape=tf.shape(z_mean))
        std_dev = tf.exp(0.5 * z_log_var)
        return z_mean + std_dev * noise
    
def create_vae_model(hparams):
    """Build and compile a convolutional variational autoencoder for 128x128x1 images.

    Parameters
    ----------
    hparams : dict
        Must provide the keys read below:
        'KN1'..'KN5' (filter counts of the five encoder convolutions, reused in
        reverse by the decoder), 'D1' and 'D2' (dense layer widths),
        'LAT' (latent dimension), 'B' (weight on the KL term in the training
        loss) and 'LR' (Adam learning rate).

    Returns
    -------
    keras.Model
        A compiled ``VAE`` instance wrapping the encoder and decoder, already
        built for input shape (1, 128, 128, 1).
    """
    
    # Image side length; also used to scale the reconstruction loss in train_step.
    n_img = 128
    latent_dim = hparams['LAT']
    beta = hparams['B']

    # ---- Encoder: five stride-2 5x5 convolutions, then a stack of dense layers ----
    image_input = keras.Input(shape=(n_img, n_img,1), name = 'enc_input')
    x = layers.Conv2D(hparams['KN1'],5, strides = 2, activation='relu',padding='same', input_shape=image_input.shape, name = 'enc_conv1')(image_input)
    x = layers.Conv2D(hparams['KN2'],5, strides = 2, activation='relu',padding='same', name = 'enc_conv2')(x)
    x = layers.Conv2D(hparams['KN3'],5, strides = 2, activation='relu',padding='same', name = 'enc_conv3')(x)
    x = layers.Conv2D(hparams['KN4'],5, strides = 2, activation='relu',padding='same', name = 'enc_conv4')(x)
    x = layers.Conv2D(hparams['KN5'],5, strides = 2, activation='relu',padding='same', name = 'enc_conv5')(x)
    x = layers.Flatten()(x)
    # One dense layer of width D1 followed by seven of width D2.
    x = layers.Dense(hparams['D1'], activation='relu', name = 'enc_d1')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d2_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d3_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d4_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d5_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d6_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d7_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d8_t')(x)
    # Two linear heads give the latent mean and log-variance; Sampling draws z from them.
    z_mean = layers.Dense(latent_dim, name="z_mean_t")(x)
    z_log_var = layers.Dense(latent_dim, name="z_log_var_t")(x)
    z_output = Sampling()([z_mean, z_log_var])
    encoder_VAE = keras.Model(image_input, [z_mean, z_log_var, z_output])

    # ---- Decoder: dense stack, reshape to 4x4xKN5, five stride-2 transposed convolutions ----
    z_input = keras.Input(shape=(latent_dim,), name = 'dec_input_t')
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d1_t')(z_input)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d2')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d3')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d4')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d5')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d6')(x)
    x = layers.Dense(hparams['D1'], activation="relu", name = 'dec_d7')(x)
    # 4x4 is the spatial size after five stride-2 halvings of 128; the five
    # stride-2 transposed convolutions below double it back up to 128.
    x = layers.Dense(4*4*hparams['KN5'], activation="relu", name = 'dec_d8')(x)
    x = layers.Reshape((4, 4,hparams['KN5']))(x)
    x = layers.Conv2DTranspose(hparams['KN4'],5, strides = 2, activation='relu',padding='same', name = 'dec_conv1')(x)
    x = layers.Conv2DTranspose(hparams['KN3'],5, strides = 2, activation='relu',padding='same', name = 'dec_conv2')(x)
    x = layers.Conv2DTranspose(hparams['KN2'],5, strides = 2, activation='relu',padding='same', name = 'dec_conv3')(x)
    x = layers.Conv2DTranspose(hparams['KN1'],5, strides = 2, activation='relu',padding='same', name = 'dec_conv4')(x)
    # Single-channel sigmoid output, matching the binary cross-entropy reconstruction loss.
    image_output = layers.Conv2DTranspose(1,5, strides = 2, activation='sigmoid',padding='same', name = 'dec_conv5')(x)
    #image_output = layers.Conv2DTranspose(16,3, strides = 2, activation='sigmoid',padding='same')
    #image_output = layers.Reshape((n_img, n_img,1))(x)
    decoder_VAE = keras.Model(z_input, image_output)

    # VAE class
    class VAE(keras.Model):
        """Encoder/decoder pair with a custom training step (reconstruction + beta * KL)."""

        # constructor
        def __init__(self, encoder, decoder, **kwargs):
            super(VAE, self).__init__(**kwargs)
            self.encoder = encoder
            self.decoder = decoder

        # customise train_step() to implement the loss 
        def train_step(self, x):
            """Run one optimisation step on a batch and return ``{"loss": loss}``."""
            # The generators yield (input, target) tuples; only the input is used.
            if isinstance(x, tuple):
                x = x[0]
            with tf.GradientTape() as tape:
                # encoding
                z_mean, z_log_var, z = self.encoder(x)
                # decoding
                x_prime = self.decoder(z)
                # reconstruction error by binary crossentropy loss
                # (mean over all elements, scaled by the number of pixels per image)
                reconstruction_loss = tf.reduce_mean(keras.losses.binary_crossentropy(x, x_prime)) * n_img * n_img
                # KL divergence (reduce_mean without an axis: averaged over every element)
                kl_loss = -0.5 * tf.reduce_mean(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
                # loss = reconstruction error + KL divergence
                loss = reconstruction_loss + beta* kl_loss
            # apply gradient
            grads = tape.gradient(loss, self.trainable_weights)
            self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
            # return loss for metrics log
            return {"loss": loss}


        def call(self, x):
            """Encode ``x``, sample a latent vector and return the decoded image."""
            if isinstance(x, tuple):
                x = x[0]
            # encoding
            z_mean, z_log_var, z = self.encoder(x)
            # decoding
            x_prime = self.decoder(z)
            return x_prime
    # build the VAE
    vae_model = VAE(encoder_VAE, decoder_VAE)

    # compile the VAE
    # train_step above computes its own loss and does not use the compiled one.
    # NOTE(review): presumably custom_loss is what Keras evaluates for val_loss,
    # since test_step is not overridden - confirm.
    vae_model.compile(optimizer=keras.optimizers.Adam(learning_rate=hparams['LR']),loss=custom_loss)
    vae_model.build((1,128,128,1))
    
    return vae_model



def custom_loss(x,y):
    """Mean binary cross-entropy between ``x`` and ``y``, scaled by 128 * 128."""
    n_img = 128
    mean_bce = tf.reduce_mean(keras.losses.binary_crossentropy(x, y))
    return mean_bce * n_img * n_img


2021-11-12 10:28:55.979830: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1


Using TensorFlow v2.3.0


### Now check that you can find the GPU

In [2]:
# List the devices TensorFlow can see; a 'GPU' entry should appear in the output.
tf.config.list_physical_devices()

2021-11-12 10:29:16.534909: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcuda.so.1


[PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'),
 PhysicalDevice(name='/physical_device:XLA_CPU:0', device_type='XLA_CPU'),
 PhysicalDevice(name='/physical_device:XLA_GPU:0', device_type='XLA_GPU'),
 PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]

2021-11-12 10:29:16.545158: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1716] Found device 0 with properties: 
pciBusID: 0000:3b:00.0 name: Tesla V100-PCIE-32GB computeCapability: 7.0
coreClock: 1.38GHz coreCount: 80 deviceMemorySize: 31.75GiB deviceMemoryBandwidth: 836.37GiB/s
2021-11-12 10:29:16.545183: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcudart.so.10.1
2021-11-12 10:29:16.549311: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcublas.so.10
2021-11-12 10:29:16.552300: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcufft.so.10
2021-11-12 10:29:16.553691: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcurand.so.10
2021-11-12 10:29:16.556847: I tensorflow/stream_executor/platform/default/dso_loader.cc:48] Successfully opened dynamic library libcusolv

In [3]:
import time

# Point fdp at the folder that contains all of your timestamped data folders.
fdp = Path('/dls/e02/data/2021/mg28749-1/processing/Calibrated')

# Show each entry of the folder next to its index, for use in the cells below.
dl = fdp.ls()
list(enumerate(dl))

[(0,
  Path('/dls/e02/data/2021/mg28749-1/processing/Calibrated/20210925_152115')),
 (1,
  Path('/dls/e02/data/2021/mg28749-1/processing/Calibrated/20210925_150802')),
 (2,
  Path('/dls/e02/data/2021/mg28749-1/processing/Calibrated/20210925_135457')),
 (3,
  Path('/dls/e02/data/2021/mg28749-1/processing/Calibrated/20210925_141948')),
 (4,
  Path('/dls/e02/data/2021/mg28749-1/processing/Calibrated/20210925_151250')),
 (5,
  Path('/dls/e02/data/2021/mg28749-1/processing/Calibrated/20210925_150242')),
 (6,
  Path('/dls/e02/data/2021/mg28749-1/processing/Calibrated/20210925_110741')),
 (7,
  Path('/dls/e02/data/2021/mg28749-1/processing/Calibrated/20210925_141615')),
 (8,
  Path('/dls/e02/data/2021/mg28749-1/processing/Calibrated/20210925_123525')),
 (9,
  Path('/dls/e02/data/2021/mg28749-1/processing/Calibrated/20210925_115723')),
 (10,
  Path('/dls/e02/data/2021/mg28749-1/processing/Calibrated/20210925_142234')),
 (11,
  Path('/dls/e02/data/2021/mg28749-1/processing/Calibrated/20210925_1

# Training a Single Dataset

### Either enter the path directly to the dp variable or use the index from the list above

In [None]:
#Either
# Point dp directly at the .hdf5 data file (the path below is a placeholder to replace).
dp = Path('something/something/file.hdf5')

In [4]:
#Or
#set select as the index of the dataset you want from the list above
select = 32
# walk() is given two phrases and the first result is taken.
# NOTE(review): presumably '.hdf5' is an include phrase and 'Model' an exclude
# phrase, so saved model files are skipped - confirm against stemutils.
dp = dl[select].walk('.hdf5', 'Model')[0]
print(dp)

/dls/e02/data/2021/mg28749-1/processing/Calibrated/20210925_141303/20210925_141303_data.hdf5


### Create a directory to save our intermediate model checkpoints 

In [6]:
# Checkpoint directory 'CropModel'; per the printed paths it sits next to the data file.
mp = dp.redirect('CropModel')
if not mp.exists():
    mp.mkdir()
print(mp)

/dls/e02/data/2021/mg28749-1/processing/Calibrated/20210925_141303/CropModel


### Load in the Data

In [7]:
# Load the dataset with hyperspy and keep its original (pre-rebin) data shape in s.
ds = hs.load(dp)
s = ds.data.shape



### Create a dictionary to hold some useful information

In [8]:
# Filled in later, keyed by sample name, and written to a JSON file next to the final model.
info = {}

### It is recommended to first train the model on data rebinned 3x3, for improved signal-to-noise and reduced training time

In [9]:
# Rebin the first two axes by rebin_factor and leave the last two unchanged.
rebin_factor = 3
ds = ds.rebin(None, [rebin_factor,rebin_factor,1,1])
# NOTE(review): presumably rebin sums the combined pixels, so dividing by
# rebin_factor**2 turns the sum into a mean - confirm against hyperspy.
ds.data = ds.data/(rebin_factor*rebin_factor)

### Now we want to preprocess the data and split it into batches for training

In [11]:
#Can modify the data preprocessing by changing the Data Preprocessing functions defined in the first cell

# Resize and normalise every frame, then shuffle the frame order once.
input_data = data_manip(flatten_nav(ds.data))
# shuffle() already returns a shuffled copy. Passing the array twice and then
# wrapping the result in np.array() only created extra full-size copies.
data_shuffled_numpy = shuffle(input_data)

# NOTE(review): training and validation use the same frames, so val_loss is
# not a held-out metric - confirm this is intended.
X_train, X_val = data_shuffled_numpy, data_shuffled_numpy
print(X_train.shape) # (n_frames, 128, 128)
print(X_val.shape)   # same array, so same shape

batch_size = 1024

train_gen = Array_Generator(X_train, batch_size)
valid_gen = Array_Generator(X_val, batch_size)

# Shapes taken from the first batch (input and target).
batch_shape = train_gen[0][0].shape
input_shape = (batch_shape[1],batch_shape[2],batch_shape[3])
out_dims = int(train_gen[0][1].shape[1])

started data manipulations
57
resized
(7225, 128, 128)
(7225, 128, 128)


### Set the checkpointing

In [12]:
# Checkpoint file names embed the epoch number and the validation loss; the
# clean-up step later reads the validation loss back out of the file name.
chkpoint_filepath = str(mp)+'/chk-{epoch:02d}-{val_loss:.5e}.hdf5'
# Save weights only, once per epoch, and only when val_loss improves (lower is better).
chkpoint_model = tf.keras.callbacks.ModelCheckpoint(
    filepath = chkpoint_filepath,
    monitor="val_loss",
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
    mode="min",
    save_freq="epoch",
    options=None)

### Redefine the model parameters if you want

In [None]:
# sample_name is used in the output file names and as the key in info.
sample_name = 'full_3bin'
# Keys as read by create_vae_model: KN1-KN5 conv filter counts, D1/D2 dense widths,
# LAT latent dimension, LR Adam learning rate, B weight on the KL term.
hparams= {'KN1':32,'KN2':64,'KN3':128, 'KN4':128, 'KN5':256,'D1':128,'D2':512,'LAT':2,'LR':0.00001, 'B':1}

vae_model = create_vae_model(hparams)
# Record the settings used so they can be saved alongside the trained weights.
info[sample_name] = {'rebin': rebin_factor, 'hparams':hparams}

### It is helpful to start from a previously trained model, so set one here (otherwise training will just take longer)

In [13]:
# Weights file of a previously trained model, loaded below before training starts.
old_model = '/dls/e02/data/2021/mg28749-1/processing/Calibrated/20210925_123232/Final_Models/full_3bin_best_model.hdf5'

In [None]:
# Initialise the new model from the previously trained weights.
vae_model.load_weights(old_model)

### Train the Model

In [None]:
# Train for up to 2000 epochs; the checkpoint callback saves weights whenever val_loss improves.
history = vae_model.fit(train_gen, validation_data=valid_gen, epochs=2000, callbacks= [chkpoint_model])

### If the training has exited correctly you can delete the intermediate checkpoints

In [None]:
import re

# Checkpoints are named 'chk-{epoch:02d}-{val_loss:.5e}.hdf5'. The loss is in
# scientific notation, so it can itself contain a '-' (e.g. '1.23456e-01').
# Splitting on '-' and taking the last piece would then read '01' instead of
# the loss; match the whole file name instead.
_chk_name_pattern = re.compile(r'chk-\d+-([^/]+)\.hdf5$')


def _checkpoint_val_loss(path):
    """Return the validation loss encoded in a checkpoint file name."""
    match = _chk_name_pattern.search(str(path))
    if match is None:
        raise ValueError(f'not a checkpoint file name: {path}')
    return float(match.group(1))


cps = mp.walk('.hdf5')
best_model_ind = np.asarray([_checkpoint_val_loss(i) for i in cps]).argmin()
best_model = cps[best_model_ind]
# Delete every checkpoint except the one with the lowest validation loss.
for x, mod in enumerate(cps):
    if x != best_model_ind:
        mod.unlink()
print('cleared')

### You can then resave this model in a different folder along with the model architecture information

In [None]:
import json

# Put the best checkpoint's weights back into the model.
vae_model.load_weights(best_model)

# 'Final_Models' lives in the same parent folder as the checkpoint directory.
parent_dir = '/'.join(str(mp).split('/')[:-1])
final_path = Path(parent_dir + '/Final_Models')
if not final_path.exists():
    final_path.mkdir()

# Move the best checkpoint there under a name derived from the sample name.
new_name = f'{final_path}/{sample_name}_best_model.hdf5'
best_model.rename(new_name)

# Store the training settings next to the weights.
info_file = f'{final_path}/{sample_name}_info.json'
with open(info_file, 'w') as f:
    json.dump(info, f)

### Done

# If you want to train them all sequentially, this for loop should work the same as the steps above

In [None]:
import json
import re
import time

# Folder containing all the timestamped data folders.
fdp = Path('/dls/e02/data/2021/mg28749-1/processing/Calibrated')
dl = fdp.ls()

#set your starting pretrained model here - each subsequent model will retrain off the previous
new_name = '/dls/e02/data/2021/mg28749-1/processing/Calibrated/20210925_123232/Final_Models/full_3bin_best_model.hdf5'

# Index in dl of the first dataset to train; every later entry is processed too.
start_index = 31

# Settings shared by every dataset in the loop.
rebin_factor = 3
batch_size = 1024
sample_name = 'full_3bin'
hparams= {'KN1':32,'KN2':64,'KN3':128, 'KN4':128, 'KN5':256,'D1':128,'D2':512,'LAT':2,'LR':0.00001, 'B':1}

# Checkpoints are named 'chk-{epoch:02d}-{val_loss:.5e}.hdf5'. The loss is in
# scientific notation, so it can itself contain a '-' (e.g. '1.23456e-01').
# Splitting on '-' and taking the last piece would then read '01' instead of
# the loss; match the whole file name instead.
_chk_name_pattern = re.compile(r'chk-\d+-([^/]+)\.hdf5$')


def _checkpoint_val_loss(path):
    """Return the validation loss encoded in a checkpoint file name."""
    match = _chk_name_pattern.search(str(path))
    if match is None:
        raise ValueError(f'not a checkpoint file name: {path}')
    return float(match.group(1))


for select in range(start_index, len(dl)):
    dp = dl[select].walk('.hdf5', 'Model')[0]

    # Checkpoint directory next to the data file.
    mp = Path('/'.join(str(dp).split('/')[:-1])+'/CropModel')
    if not mp.exists():
        mp.mkdir()

    ds = hs.load(dp)
    s = ds.data.shape

    ds = ds.rebin(None, [rebin_factor,rebin_factor,1,1])
    ds.data = ds.data/(rebin_factor*rebin_factor)

    # Resize and normalise every frame, then shuffle the frame order once.
    # shuffle() already returns a shuffled copy, so no further copies are made.
    input_data = data_manip(flatten_nav(ds.data))
    data_shuffled_numpy = shuffle(input_data)

    # NOTE(review): training and validation use the same frames, so val_loss is
    # not a held-out metric - confirm this is intended.
    X_train, X_val = data_shuffled_numpy, data_shuffled_numpy
    print(X_train.shape) # (n_frames, 128, 128)
    print(X_val.shape)   # same array, so same shape

    train_gen = Array_Generator(X_train, batch_size)
    valid_gen = Array_Generator(X_val, batch_size)

    batch_shape = train_gen[0][0].shape
    input_shape = (batch_shape[1],batch_shape[2],batch_shape[3])
    out_dims = int(train_gen[0][1].shape[1])

    chkpoint_filepath = str(mp)+'/chk-{epoch:02d}-{val_loss:.5e}.hdf5'
    chkpoint_model = tf.keras.callbacks.ModelCheckpoint(
        filepath = chkpoint_filepath,
        monitor="val_loss",
        verbose=1,
        save_best_only=True,
        save_weights_only=True,
        mode="min",
        save_freq="epoch",
        options=None)

    vae_model = create_vae_model(hparams)
    info = {sample_name: {'rebin': rebin_factor, 'hparams':hparams}}

    # Start from the previous dataset's best weights (or the pretrained model on the first pass).
    vae_model.load_weights(new_name)

    history = vae_model.fit(train_gen, validation_data=valid_gen, epochs=2000, callbacks= [chkpoint_model])
    print('get cps')
    cps = mp.walk('.hdf5')
    best_model_ind = np.asarray([_checkpoint_val_loss(i) for i in cps]).argmin()
    best_model = cps[best_model_ind]
    # Delete every checkpoint except the one with the lowest validation loss.
    for x, mod in enumerate(cps):
        if x != best_model_ind:
            mod.unlink()
    print('cleared')

    vae_model.load_weights(best_model)

    final_path = Path('/'.join(str(mp).split('/')[:-1]) + '/Final_Models')

    if not final_path.exists():
        final_path.mkdir()

    # new_name is also the starting point for the next iteration.
    new_name = str(final_path) + f'/{sample_name}_best_model.hdf5'

    best_model.rename(new_name)

    with open(f'{final_path}/{sample_name}_info.json', 'w') as f:
        json.dump(info, f)

    # Drop references to the large arrays and the model before the next dataset is loaded.
    input_data = None
    data_shuffled_numpy = None
    X_train, X_val = None, None
    train_gen = None
    valid_gen = None
    vae_model = None
    history = None
    # Building many models in one process accumulates Keras global state; reset it.
    keras.backend.clear_session()

    time.sleep(10)
    print('next')



Using TensorFlow v2.3.0
