In [None]:
import shutil

import tensorflow as tf
import tensorflow_addons as tfa

import models.identification_models as models
import utils
from data_parameters import data_param

# Show which devices (CPU/GPU) TensorFlow can see before building the input pipeline.
print(tf.config.get_visible_devices())

# Dataset-level parameters come from the shared `data_param` configuration.
per_sys = data_param['per_sys']
N_elem = len(per_sys)
spectrum_length = data_param['spectrum_length']
# Shuffle-buffer size for the training set. With this buffer the whole training
# set is loaded into RAM for shuffling, which requires a lot of memory (more
# than the data itself, which is about 20 GB). Pre-shuffling the data on disk
# would allow a smaller buffer.
max_buffer = data_param['max_buffer']
batch_size = 32

TRAIN_FILE_PATTERN = 'Core-loss EELS TFRecord/trainingset/TRAIN*.tfrecords'
VAL_FILE_PATTERN = 'Core-loss EELS TFRecord/validationset/VALIDATION*.tfrecords'

# Training: shuffle the file order as well as the individual records, and
# reshuffle on every pass over the data.
train_files = tf.data.TFRecordDataset.list_files(TRAIN_FILE_PATTERN, shuffle=True)
train_dataset = (
    utils.get_dataset(train_files)
    .shuffle(buffer_size=max_buffer, reshuffle_each_iteration=True)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

# Validation: the order of the records does not influence the aggregated
# validation loss/metrics, so the data is read in a fixed order and no shuffle
# buffer is allocated. This avoids holding a second large buffer in RAM and
# refilling it at every validation pass.
val_files = tf.data.TFRecordDataset.list_files(VAL_FILE_PATTERN, shuffle=False)
val_dataset = (
    utils.get_dataset(val_files)
    .batch(batch_size)
    .prefetch(tf.data.AUTOTUNE)
)

In [None]:
# Remove TensorBoard logs from previous runs. `shutil.rmtree` replaces the
# POSIX-only `rm -rf` shell escape so the cell also works on Windows;
# `ignore_errors=True` keeps the "no error if the directory is missing" behavior.
shutil.rmtree('logs/fit', ignore_errors=True)

In [None]:
# Architecture to train. The same string names the TensorBoard log directory
# below, so runs of different architectures are kept apart.
model_type = "UNet_1x1conv"

# Recommended architectures, each paired with its initial learning rate.
if model_type == "UNet_1x1conv":
    model = models.UNet(spectrum_length=spectrum_length, N_elem=N_elem, reduction_method="1x1conv")
    LR = 0.001
elif model_type == "ViT_1x1conv":
    model = models.ViT(spectrum_length=spectrum_length, N_elem=N_elem, reduction_method="1x1conv")
    LR = 0.001
else:
    # Fail immediately with a clear message instead of hitting a NameError on
    # `model` / `LR` further down in the cell.
    raise ValueError(
        f"Unsupported model_type {model_type!r}; expected 'UNet_1x1conv' or 'ViT_1x1conv'."
    )

# Less recommended variants that this notebook previously listed as disabled
# branches (none of them came with a learning rate):
#   models.MLP, models.CNN,
#   models.ResNet with reduction_method "GAP" / "1x1conv" / "flatten",
#   models.UNet   with reduction_method "GAP",
#   models.CCT    with reduction_method "1x1conv" / "seq_pool",
#   models.ViT    with reduction_method "GAP" / "flatten" / "token".
# To train one of them, add an `elif` branch above that sets both `model` and `LR`.

log_dir = f"logs/fit/{model_type}"
tensorboard_callback = tf.keras.callbacks.TensorBoard(log_dir=log_dir)

model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=LR),
    loss=utils.custom_loss,
    # Weighted F1 over the N_elem outputs; predictions above 0.8 count as positive.
    # The metric is named 'f1' so that callbacks can monitor it under that key.
    metrics=[tfa.metrics.F1Score(num_classes=N_elem, average='weighted', threshold=0.8, name='f1')],
)
model.summary()

### TRAIN A SINGLE MODEL

In [None]:
# Halve the learning rate when the training F1 score ('f1') has not improved by
# at least 0.025 for 3 consecutive epochs.
reduce_LR = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='f1', factor=0.5, patience=3, mode='max', min_delta=0.025
)


class LR_limit_stop(tf.keras.callbacks.Callback):
    """Stop training once the learning rate has dropped below ``min_lr``.

    The learning rate is taken from the ``'lr'`` entry of the epoch logs when it
    is present. If it is missing, the value is read from the model's optimizer
    instead of failing on a ``None`` comparison.

    Args:
        min_lr: Training stops at the end of the first epoch whose learning
            rate is strictly below this value.
    """

    def __init__(self, min_lr=5e-6):
        super().__init__()
        self.min_lr = min_lr

    def on_epoch_end(self, epoch, logs=None):
        """Set ``model.stop_training`` when the learning rate is below the limit."""
        current_lr = (logs or {}).get('lr')
        if current_lr is None:
            # No 'lr' entry in the logs: fall back to the optimizer's current value.
            current_lr = float(tf.keras.backend.get_value(self.model.optimizer.learning_rate))
        if current_lr < self.min_lr:
            self.model.stop_training = True


lr_stop = LR_limit_stop()

# Each "epoch" is 1000 batches (steps_per_epoch), not a full pass over the data.
# NOTE(review): 35 epochs x 1000 steps request 35 000 batches, while `repeat(2)`
# provides only two passes over the training set; Keras ends training early if
# the data runs out before that - confirm the training set is large enough.
# `workers` / `use_multiprocessing` were dropped: they only apply to generator or
# keras.utils.Sequence inputs and have no effect on a tf.data.Dataset.
model.fit(
    train_dataset.repeat(2),
    validation_data=val_dataset,
    epochs=35,
    verbose=1,
    # Order matters: reduce_LR runs before lr_stop so the logs it fills in are
    # available when lr_stop checks the learning rate.
    callbacks=[tensorboard_callback, reduce_LR, lr_stop],
    steps_per_epoch=1000,
)

#### SAVE A TRAINED MODEL  

In [None]:
# Persist the trained model under a path named after the selected architecture.
trained_model_path = f'newly_trained_models/trained_{model_type}'
model.save(trained_model_path)