In [None]:
# Import necessary packages
import os
# Request deterministic GPU ops from XLA and cuDNN for reproducible training runs.
# These are set here, before jax/flax/optax are imported below, because the
# backend presumably reads them when it initializes; setting them later may have
# no effect. NOTE: the assignment replaces any XLA_FLAGS already in the environment.
os.environ['XLA_FLAGS'] = '--xla_gpu_deterministic_ops'
os.environ['TF_CUDNN_DETERMINISTIC'] = '1'

import json
import matplotlib.pyplot as plt
import pandas as pd
import optax

from datetime import datetime
from flax import serialization
from flax.training import checkpoints
from jax import random

# Project-local spot-detection data loading and training helpers
from cellori.applications.spots import data
from cellori.applications.spots import training

In [None]:
# Define training parameters
# Run identity: the id names the model directory and prefixes every artifact in it.
model_id = 'spots059'
model_description = 'Train with 08232022 dataset with new weighted binary cross entropy loss.'
checkpoint_prefix = f'{model_id}_checkpoint'

# Locations (relative to the working directory)
models_path = 'models/spots'
dataset_path = 'datasets/rajlab/08232022'

# Optimization hyperparameters
random_seed = 0
batch_size = 16
# Per-term loss weights, forwarded to training.train_epoch
loss_weights = dict(sl1l=1, bcel=1, invf1=3)
learning_config = dict(
    schedule='exponential_decay',
    init_value=0.0005,
    transition_steps=1,
    decay_rate=0.999,
)

# Run record: dumped to the model's metadata file and flattened into toc.csv.
# Seed and batch size are stored as strings; key order fixes the toc column order.
metadata = dict(
    model_id=model_id,
    model_description=model_description,
    dataset_path=dataset_path,
    date=datetime.now().strftime('%m/%d/%Y, %H:%M:%S'),
    random_seed=str(random_seed),
    batch_size=str(batch_size),
    loss_weights=loss_weights,
    learning_config=learning_config,
)

In [None]:
# Generate paths
# Shared table of contents listing every trained model under models_path
toc_path = os.path.join(models_path, 'toc.csv')
# Per-model directory holding checkpoints, metadata and metric logs
model_path = os.path.join(models_path, model_id)
# Per-model artifact files are named '<model_id>_<suffix>' inside model_path
metadata_path, batch_metrics_log_path, epoch_metrics_log_path = (
    os.path.join(model_path, f'{model_id}_{suffix}')
    for suffix in ('metadata', 'batch_metrics_log', 'epoch_metrics_log')
)

In [None]:
# Load the dataset splits stored under dataset_path.
# The training loop reads ds['train'] and ds['valid'], so the result must be
# indexable by those keys — presumably a dict of splits; TODO confirm in cellori's data module.
ds = data.load_datasets(dataset_path)

In [None]:
# Create train state
rng = random.PRNGKey(random_seed)

# Build the learning-rate schedule named by learning_config['schedule'] (an optax
# schedule factory such as 'exponential_decay'); every other key is passed as a
# keyword argument. This keeps the schedule actually used in sync with the one
# recorded in metadata/toc, instead of hard-coding exponential_decay.
schedule_name = learning_config.get('schedule', 'exponential_decay')
schedule_kwargs = {key: value for key, value in learning_config.items() if key != 'schedule'}
learning_rate = getattr(optax, schedule_name)(**schedule_kwargs)

state = training.create_train_state(rng, learning_rate)

# Check and load in previous train state
if os.path.isdir(model_path):

    # Resume: restore the most recent checkpoint with this prefix (the freshly
    # created `state` serves as the restore template) and reload the metric logs
    # written by earlier runs so the training loop can continue numbering epochs.
    state = checkpoints.restore_checkpoint(model_path, state, prefix=checkpoint_prefix)
    with open(batch_metrics_log_path, 'r') as f_batch_metrics_log:
        batch_metrics_log = json.load(f_batch_metrics_log)
    with open(epoch_metrics_log_path, 'r') as f_epoch_metrics_log:
        epoch_metrics_log = json.load(f_epoch_metrics_log)

else:

    # New model: create the model directory FIRST. This also creates models_path
    # on a first-ever run, so toc.csv (which lives in models_path) can be written
    # below, and if creation fails the table of contents is left untouched.
    os.makedirs(model_path)

    # Append this run's metadata to the shared table of contents; nested dicts
    # are flattened into '_'-joined columns (e.g. 'loss_weights_bcel').
    toc_entry = pd.json_normalize(metadata, sep='_')
    if os.path.isfile(toc_path):
        toc = pd.read_csv(toc_path, index_col=0)
        toc = pd.concat((toc, toc_entry), ignore_index=True)
    else:
        toc = toc_entry
    toc.to_csv(toc_path)

    with open(metadata_path, 'w') as f_metadata:
        json.dump(metadata, f_metadata, indent=4)

    # Start with empty metric logs and persist them so a later resume can load them
    batch_metrics_log = []
    epoch_metrics_log = []
    with open(batch_metrics_log_path, 'w') as f_batch_metrics_log:
        json.dump(batch_metrics_log, f_batch_metrics_log, indent=4)
    with open(epoch_metrics_log_path, 'w') as f_epoch_metrics_log:
        json.dump(epoch_metrics_log, f_epoch_metrics_log, indent=4)

In [None]:
# Training loop
num_epochs = 50


def _dump_json_atomic(obj, path):
    """Write ``obj`` as indented JSON to ``path`` without leaving a partial file.

    The JSON is written to a sibling temporary file that then replaces ``path``
    via ``os.replace``. Interrupting the kernel mid-write therefore leaves the
    previous log intact, so the resume branch can still ``json.load`` it.
    """
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w') as f_tmp:
        json.dump(obj, f_tmp, indent=4)
    os.replace(tmp_path, path)


# Epochs are numbered from 1 and continue after those already in the log
epoch_count = len(epoch_metrics_log)
for epoch in range(epoch_count + 1, epoch_count + num_epochs + 1):

    # Run one full epoch over ds['train']; ds['valid'] is also passed, presumably
    # to produce the validation metrics (e.g. 'val_loss') in epoch_metrics.
    state, batch_metrics, epoch_metrics = training.train_epoch(epoch, state, ds['train'], ds['valid'], batch_size, loss_weights, learning_rate)

    batch_metrics_log += batch_metrics
    epoch_metrics_log += [epoch_metrics]

    # Checkpoint at step=epoch; every 10th checkpoint is kept permanently
    checkpoints.save_checkpoint(model_path, state, epoch, prefix=checkpoint_prefix, keep_every_n_steps=10)
    _dump_json_atomic(batch_metrics_log, batch_metrics_log_path)
    _dump_json_atomic(epoch_metrics_log, epoch_metrics_log_path)

In [None]:
# Plot training and validation losses per epoch
losses = [(epoch_metrics['loss'], epoch_metrics['val_loss']) for epoch_metrics in epoch_metrics_log]
# The training loop numbers epochs from 1, so the x-axis does too
epochs = list(range(1, len(losses) + 1))

fig, ax = plt.subplots(dpi=300)
# One explicit call per series: works for any log length, including an empty one
ax.plot(epochs, [train_loss for train_loss, _ in losses], label='Training')
ax.plot(epochs, [val_loss for _, val_loss in losses], label='Validation')
ax.set(title=f'{model_id} loss', xlabel='Epoch', ylabel='Loss')
ax.legend();

In [None]:
# Save model
# Bundle the trained params and the `batch_stats` collection into a single
# variables dict and serialize it to raw bytes with flax's to_bytes
# (restore later with serialization.from_bytes against a matching template).
variables = dict(params=state.params, batch_stats=state.batch_stats)
bytes_output = serialization.to_bytes(variables)

model_file_path = os.path.join(model_path, f'{model_id}_model')
with open(model_file_path, 'wb') as f_model:
    f_model.write(bytes_output)