First load some utility modules

In [1]:
from stemutils.io import *
import json

Next we define the bulk of the script that we are going to modify for each training set

In [6]:
def get_script_text(dp, model_params, gen_model_weights, gen_model_params, epochs= 2000):
    """Return the source code of a standalone VAE training script for one dataset.

    The returned text is a complete Python script. As written, it:
      * defines preprocessing helpers, two ``keras.utils.Sequence`` generators
        and ``create_vae_model``;
      * builds a "general" model from ``gen_model_params``, loads its weights
        from ``gen_model_weights`` and copies selected encoder/decoder layers
        into a new model built from ``model_params``;
      * loads the dataset at ``dp``, draws patterns with probability
        proportional to ``entropy**10``, trains for ``epochs`` epochs, and
        keeps only the checkpoint with the lowest validation loss in the
        directory given by ``dp.redirect('FullModel')``.

    Parameters
    ----------
    dp : path-like
        Dataset path. Embedded as ``repr(str(dp))`` so that quotes or
        backslashes in the path cannot break the generated script.
    model_params : dict
        Hyperparameters of the model to train. ``create_vae_model`` reads the
        keys ``KN1``..``KN5``, ``D1``, ``D2``, ``LAT``, ``B`` and ``LR``. The
        dict is pasted via ``str()``, so its repr must be valid Python.
    gen_model_weights : path-like
        Weights file of the general model (embedded like ``dp``).
    gen_model_params : dict
        Hyperparameters of the general model (pasted like ``model_params``).
    epochs : int, optional
        Number of training epochs passed to ``fit`` (default 2000).

    Returns
    -------
    str
        The script text.
    """
    # Literal braces needed in the generated script are produced by
    # interpolating small string constants, e.g. {'{}'} -> {}.
    return f'''

#load some packages in
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import matplotlib.pyplot as plt
import random as python_random
from numba import njit
from tensorboard.plugins.hparams import api as hp
from stemutils.io import Path
import hyperspy.api as hs
import concurrent.futures
from skimage.transform import resize
from sklearn.utils import shuffle
from sklearn.model_selection import train_test_split
from functools import lru_cache
from stemseg.processing_funcs import *

#set some variables
print('Using TensorFlow v%s' % tf.__version__)
plt.style.use('default')


#define some functions

###################################################
########### Data Preprocessing ####################
###################################################

def batch_resize(d, bs=512):
    if len(d.shape) == 4:
        flat_d = flatten_nav(d)
    else:
        flat_d = d
    n_batches = int(np.ceil(flat_d.shape[0]/bs))
    batches = [flat_d[i*bs:(i+1)*bs] for i in range(n_batches)]
    print(len(batches[-1]))
    with concurrent.futures.ProcessPoolExecutor(max_workers=4) as exe:
        res = [exe.submit(resize, batch, (batch.shape[0],128,128)) for batch in batches]
    r_batches = [f.result() for f in res]
    return np.concatenate(r_batches, axis = 0).reshape((flat_d.shape[0],128,128))

def data_manip(d, bs = 512):
    if type(d) != np.ndarray:
        print('dask to numpy')
        d = d.compute()
        print('dask to numpy done')
    print('started data manipulations')
    #d = resize(d,(d.shape[0],128,128))
    print('resized')
    d = d.astype('float32')
    for i in range(d.shape[0]):
        d_max = np.max(d[i])
        d[i] = d[i]/d_max
    d = batch_resize(d, bs)
    scaler = np.log(1001)
    return np.log((d*1000)+1)/scaler 


def data_manip_lowq(d, central_box = 128):
    pxc, pyc = d.shape[1]//2, d.shape[2]//2 
    pxl, pxu = pxc - central_box//2, pxc + central_box//2 
    pyl, pyu = pyc - central_box//2, pyc + central_box//2 

    d = d[:, pxl:pxu, pyl:pyu]
    if type(d) != np.ndarray:
        print('dask to numpy')
        d = d.compute()
        print('dask to numpy done')
    print('started data manipulations')
    #d = resize(d,(d.shape[0],128,128))
    print('resized')
    d = d.astype('float32')
    for i in range(d.shape[0]):
        d_max = np.max(d[i])
        d[i] = d[i]/d_max

    scaler = np.log(1001)
    return np.log((d*1000)+1)/scaler 


###################################################
###################################################
###################################################

def flatten_nav(sig):
    shape = [sig.shape[0]*sig.shape[1]]
    for i in sig.shape[2:]:
        shape.append(i)
    return sig.reshape(shape)


class My_Custom_Generator(keras.utils.Sequence) :
    def __init__(self, image_filenames,  batch_size) :
        self.image_filenames = image_filenames
        self.batch_size = batch_size

    def __len__(self) :
        return int(np.ceil(len(self.image_filenames) / float(self.batch_size)))


    @lru_cache(None)
    def __getitem__(self, idx) :
        batch_x = self.image_filenames[idx * self.batch_size : (idx+1) * self.batch_size]
        out_img = np.asarray([np.load(file_name)[:,:,None] for file_name in batch_x])
        return out_img, out_img
        #return batch_x, batch_y


class Array_Generator(keras.utils.Sequence) :
    def __init__(self, images,  batch_size, target = 'same') :
        self.images = images
        self.batch_size = batch_size
        if target == 'same':
            self.target = images
        else:
            self.target = target

    def __len__(self) :
        return int(np.ceil(len(self.images) / float(self.batch_size)))


    @lru_cache(None)
    def __getitem__(self, idx) :
        out_img = self.images[idx * self.batch_size : (idx+1) * self.batch_size, :,:,None]
        out_targ = self.target[idx * self.batch_size : (idx+1) * self.batch_size, :,:,None]
        return out_img, out_targ
        #return batch_x, batch_y

class Sampling(layers.Layer):
    def call(self, inputs):
        z_mean, z_log_var = inputs
        epsilon = tf.keras.backend.random_normal(shape=tf.shape(z_mean))
        return z_mean + tf.exp(0.5 * z_log_var) * epsilon

def create_vae_model(hparams):

    n_img = 128
    latent_dim = hparams['LAT']
    beta = hparams['B']

    image_input = keras.Input(shape=(n_img, n_img,1), name = 'enc_input')
    x = layers.Conv2D(hparams['KN1'],5, strides = 2, activation='relu',padding='same', input_shape=image_input.shape, name = 'enc_conv1')(image_input)
    x = layers.Conv2D(hparams['KN2'],5, strides = 2, activation='relu',padding='same', name = 'enc_conv2')(x)
    x = layers.Conv2D(hparams['KN3'],5, strides = 2, activation='relu',padding='same', name = 'enc_conv3')(x)
    x = layers.Conv2D(hparams['KN4'],5, strides = 2, activation='relu',padding='same', name = 'enc_conv4')(x)
    x = layers.Conv2D(hparams['KN5'],5, strides = 2, activation='relu',padding='same', name = 'enc_conv5')(x)
    x = layers.Flatten()(x)
    x = layers.Dense(hparams['D1'], activation='relu', name = 'enc_d1')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d2_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d3_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d4_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d5_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d6_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d7_t')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'enc_d8_t')(x)
    z_mean = layers.Dense(latent_dim, name="z_mean_t")(x)
    z_log_var = layers.Dense(latent_dim, name="z_log_var_t")(x)
    z_output = Sampling()([z_mean, z_log_var])
    encoder_VAE = keras.Model(image_input, [z_mean, z_log_var, z_output])

    z_input = keras.Input(shape=(latent_dim,), name = 'dec_input_t')
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d1_t')(z_input)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d2')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d3')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d4')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d5')(x)
    x = layers.Dense(hparams['D2'], activation="relu", name = 'dec_d6')(x)
    x = layers.Dense(hparams['D1'], activation="relu", name = 'dec_d7')(x)
    x = layers.Dense(4*4*hparams['KN5'], activation="relu", name = 'dec_d8')(x)
    x = layers.Reshape((4, 4,hparams['KN5']))(x)
    x = layers.Conv2DTranspose(hparams['KN4'],5, strides = 2, activation='relu',padding='same', name = 'dec_conv1')(x)
    x = layers.Conv2DTranspose(hparams['KN3'],5, strides = 2, activation='relu',padding='same', name = 'dec_conv2')(x)
    x = layers.Conv2DTranspose(hparams['KN2'],5, strides = 2, activation='relu',padding='same', name = 'dec_conv3')(x)
    x = layers.Conv2DTranspose(hparams['KN1'],5, strides = 2, activation='relu',padding='same', name = 'dec_conv4')(x)
    image_output = layers.Conv2DTranspose(1,5, strides = 2, activation='sigmoid',padding='same', name = 'dec_conv5')(x)
    #image_output = layers.Conv2DTranspose(16,3, strides = 2, activation='sigmoid',padding='same')
    #image_output = layers.Reshape((n_img, n_img,1))(x)
    decoder_VAE = keras.Model(z_input, image_output)

    # VAE class
    class VAE(keras.Model):
        # constructor
        def __init__(self, encoder, decoder, **kwargs):
            super(VAE, self).__init__(**kwargs)
            self.encoder = encoder
            self.decoder = decoder

        # customise train_step() to implement the loss 
        def train_step(self, x):
            if isinstance(x, tuple):
                x = x[0]
            with tf.GradientTape() as tape:
                # encoding
                z_mean, z_log_var, z = self.encoder(x)
                # decoding
                x_prime = self.decoder(z)
                # reconstruction error by binary crossentropy loss
                reconstruction_loss = tf.reduce_mean(keras.losses.binary_crossentropy(x, x_prime)) * n_img * n_img
                # KL divergence
                kl_loss = -0.5 * tf.reduce_mean(1 + z_log_var - tf.square(z_mean) - tf.exp(z_log_var))
                # loss = reconstruction error + KL divergence
                loss = reconstruction_loss + beta* kl_loss
            # apply gradient
            grads = tape.gradient(loss, self.trainable_weights)
            self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
            # return loss for metrics log
            return {'{"loss": loss}'}


        def call(self, x):
            if isinstance(x, tuple):
                x = x[0]
            # encoding
            z_mean, z_log_var, z = self.encoder(x)
            # decoding
            x_prime = self.decoder(z)
            return x_prime
    # build the VAE
    vae_model = VAE(encoder_VAE, decoder_VAE)

    # compile the VAE
    vae_model.compile(optimizer=keras.optimizers.Adam(learning_rate=hparams['LR']),loss=custom_loss)
    vae_model.build((1,128,128,1))

    return vae_model



def custom_loss(x,y):
    n_img = 128
    return tf.reduce_mean(keras.losses.binary_crossentropy(x, y)) * n_img * n_img


print(tf.config.list_physical_devices())

general_model_path = Path({str(gen_model_weights)!r})
general_model_params = {gen_model_params}

gen_model = create_vae_model(general_model_params)
gen_model.load_weights(general_model_path)

def transfer_layer_weights(old_model, new_model, enc_layer_list, dec_layer_list):
    for l in enc_layer_list:
        w = old_model.get_layer(index = 0).get_layer(index=l).get_weights()
        new_model.get_layer(index = 0).get_layer(index=l).set_weights(w)

    for l in dec_layer_list:
        w = old_model.get_layer(index = 1).get_layer(index=l).get_weights()
        new_model.get_layer(index = 1).get_layer(index=l).set_weights(w)
    return new_model

enc_layer_list = list(range(12))

dec_layer_list = list(range(4,14))


dp = Path({str(dp)!r})
mp = dp.redirect('FullModel')
if not mp.exists():
    mp.mkdir()
print(mp)

### Load in the Data

sample = ProcessedSample(dp, 'Test')

### Create a dictionary to hold some useful information

info = {'{}'}

nds = flatten_nav(sample.raw_data.data)

nds.shape

ds = hs.load(dp)

sx, sy = ds.data.shape[2:]

from skimage.measure import shannon_entropy

entropy = np.zeros((sample.raw_data.data.shape[0:2]))

for i in range(ds.data.shape[0]):
    print(i)
    for j in range(ds.data.shape[1]):
        entropy[i,j] = shannon_entropy(ds.data[i,j,sx//2-128:sx//2+128, sy//2-128:sy//2+128])

data_patterns = flatten_nav(ds.data)

entr = flatten_nav(entropy**10)

sl_pdf = entr/ entr.sum()


nldps = []
for i in range(10):
    sdps = hs.signals.Signal2D(data_patterns[np.random.choice(np.arange(data_patterns.shape[0]), 5000, True, sl_pdf)])
    nldps.append(sdps.data)
nlds = flatten_nav(np.asarray(nldps))




def data_manip_lowq_resized(d, central_box = 256, bs = 256):
    pxc, pyc = d.shape[1]//2, d.shape[2]//2 
    pxl, pxu = pxc - central_box//2, pxc + central_box//2 
    pyl, pyu = pyc - central_box//2, pyc + central_box//2 

    d = d[:, pxl:pxu, pyl:pyu]
    if type(d) != np.ndarray:
        print('dask to numpy')
        d = d.compute()
        print('dask to numpy done')
    print('started data manipulations')
    #d = resize(d,(d.shape[0],128,128))
    print('resized')
    d = d.astype('float32')
    for i in range(d.shape[0]):
        d_max = np.max(d[i])
        d[i] = d[i]/d_max
    d = batch_resize(d, bs)
    scaler = np.log(1001)
    return np.log((d*1000)+1)/scaler 



input_data = data_manip_lowq_resized(nlds)
#np.random.shuffle(input_data)
#input_targets = data_manip_lowq(nlds)
val_data = data_manip_lowq_resized(nlds[::10])
np.random.shuffle(val_data)



##### Otherwise just skip

print(input_data.shape)

batch_size = 2048

train_gen = Array_Generator(input_data, batch_size)#, target = input_targets)
valid_gen = Array_Generator(val_data, batch_size)

batch_shape = train_gen[0][0].shape
input_shape = (batch_shape[1],batch_shape[2],batch_shape[3])
out_dims = int(train_gen[0][1].shape[1])

### Check the training data

rand_patt = np.random.randint(0, len(input_data))

### Set the checkpointing

chkpoint_filepath = str(mp)+'/chk-{'{epoch:02d}'}-{'{val_loss:.5e}'}.hdf5'
chkpoint_model = tf.keras.callbacks.ModelCheckpoint(
    filepath = chkpoint_filepath,
    monitor="val_loss",
    verbose=1,
    save_best_only=True,
    save_weights_only=True,
    mode="min",
    save_freq="epoch",
    options=None)

### Redefine the model parameters if you want

rebin_factor = 1

sample_name = 'full'
hparams= {model_params}

vae_model = create_vae_model(hparams)
info[sample_name] = {"{'rebin': rebin_factor, 'hparams':hparams}"}

vae_model = transfer_layer_weights(gen_model, vae_model, enc_layer_list, dec_layer_list)

### Train the Model

history = vae_model.fit(train_gen, validation_data=valid_gen, epochs={epochs}, callbacks= [chkpoint_model])

cps = mp.walk('.hdf5')
best_model_ind = np.asarray([float(str(i).split('-')[-1].split('.hd')[0]) for i in cps]).argmin()
best_model = cps[best_model_ind]

for x, mod in enumerate(cps):
    if x != best_model_ind:
        mod.unlink()
print('cleared')

'''

def job_submission_script(fpath):
    """Return the text of a bash job-submission script that runs ``fpath`` with Python.

    The script starts with ``#!/bin/bash`` on its very first line (a shebang
    is only recognised there), followed by the ``#$`` scheduler option lines
    (``h_rt=119:00:00``, ``m_mem_free=128G``, ``gpu=4``, project ``e02``,
    queue ``all.q``), the two ``module load`` commands and the ``python`` call.

    Parameters
    ----------
    fpath : path-like
        Path of the Python training script to execute. It is shell-quoted, so
        paths containing spaces or other shell metacharacters stay a single
        argument; plain paths are emitted unchanged.

    Returns
    -------
    str
        The submission script text.
    """
    # Function-scope import keeps this cell self-contained.
    import shlex

    return f'''#!/bin/bash
#$ -l h_rt=119:00:00
#$ -cwd
#$ -P e02
#$ -q all.q
#$ -l m_mem_free=128G
#$ -l gpu=4


module load python/epsic3.7
module load cuda/10.1
python {shlex.quote(str(fpath))}

'''

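Now we need to create a list of file paths for the datasets you want to train on.
You can do this however you like - I like to use the `walk` function within a list comprehension, using the "if" clause to exclude datasets that should be skipped (here, any dataset for which a `FullModel` entry is already found under its parent folder).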

In [7]:
# Top-level folder that is searched for datasets.
fdp = Path('/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/mg28034-1/processing/Merlin/Calibrated/O3_pure')

# Keep every match of walk('.hdf5', 'binned') whose parent folder yields no
# 'FullModel' match (searched to depth 3).
dps = list(filter(
    lambda candidate: len(candidate.parent.walk('FullModel', max_depth =3)) == 0,
    fdp.walk('.hdf5','binned', max_depth =1),
))

We also need to point to an existing model to use for the convolutional-layer transfer. This requires both the model weights and the parameters describing its architecture.

In [8]:
# Weights file of the general model whose layers are transferred into each new model.
gen_model_weights = Path('/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/mg28034-1/processing/Merlin/Calibrated/GeneralModels/O3_pure/general_model_weights.hdf5')
# JSON file holding the general model's info; its hyperparameters are read from it below.
gen_model_params_path = Path('/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/mg28034-1/processing/Merlin/Calibrated/GeneralModels/O3_pure/general_model_info.json')

In [9]:
# Parse the general model's info file into a dict and show it.
with open(gen_model_params_path, 'r') as f:
    info = json.loads(f.read())
print(info)

{'full': {'rebin': 1, 'hparams': {'KN1': 32, 'KN2': 64, 'KN3': 128, 'KN4': 128, 'KN5': 256, 'D1': 128, 'D2': 512, 'LAT': 128, 'LR': 0.0001, 'B': 1}}}


In [10]:
# Hyperparameters of the general model, stored under info['full']['hparams'].
gen_model_params = info['full']['hparams']
print(gen_model_params)

{'KN1': 32, 'KN2': 64, 'KN3': 128, 'KN4': 128, 'KN5': 256, 'D1': 128, 'D2': 512, 'LAT': 128, 'LR': 0.0001, 'B': 1}


Next, define the model parameters for the model that is being trained.

In [11]:
# Hyperparameters of the model to train, as read by create_vae_model in the
# generated script. Key order is kept because the dict is pasted via str().
hparams = {
    'KN1': 32,     # filters of the 1st encoder conv layer
    'KN2': 64,     # filters of the 2nd encoder conv layer
    'KN3': 128,    # filters of the 3rd encoder conv layer
    'KN4': 128,    # filters of the 4th encoder conv layer
    'KN5': 256,    # filters of the 5th encoder conv layer
    'D1': 128,     # units of the first encoder dense layer
    'D2': 512,     # units of the remaining dense layers
    'LAT': 2,      # latent dimension
    'LR': 0.0001,  # Adam learning rate
    'B': 1,        # weight applied to the KL term of the loss
}

Finally define the directory to save these to

In [None]:
# Base location from which the output script paths are derived (via `redirect`).
# NOTE(review): this cell shows no execution count (In [None]) although the loop
# below reads script_path - make sure it is run first on a fresh kernel.
script_path = Path('/dls/science/groups/imaging/ePSIC_students/Andy_Bridger/ClusterJobs/LoopedTrainingScripts/mg28034/O3_pure/')

In [12]:
for dp in dps:
    # Label for this dataset: its second-to-last path component, spaces replaced.
    timestamp = dp.parts[-2].replace(' ', '_')

    # Where the training script and its submission script are saved.
    script_py_path = script_path.redirect(f'{timestamp}_train.py', 0)
    script_sub_path = script_py_path.redirect(f'sub_job_{timestamp}.sh')

    # Text of the two scripts.
    example_script = get_script_text(dp, hparams, gen_model_weights, gen_model_params, epochs = 2000)
    script_sub = job_submission_script(script_py_path)

    # Write the training script first, then the submission script.
    for out_path, out_text in ((script_py_path, example_script), (script_sub_path, script_sub)):
        with open(out_path, 'w') as f:
            f.write(out_text)

Now submit the jobs to hamilton by running the following commands from the directory that contains the generated `.sh` scripts:
> module load hamilton

> for f in ./*.sh; do qsub $f; done