In [None]:
# Neural network libraries
import os
os.environ["SM_FRAMEWORK"] = "tf.keras"
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
import segmentation_models as sm
sm.set_framework("tf.keras")

# Plotting
import matplotlib.pyplot as plt
%matplotlib inline

# DeepD3 
from deepd3.model import DeepD3_Model
from deepd3.training.stream import DataGeneratorStream

## Load training data

In [None]:
TRAINING_DATA_PATH = r"DeepD3_Training.d3set"
VALIDATION_DATA_PATH = r"DeepD3_Validation.d3set"

# Settings that are identical for the training and the validation stream
shared_stream_settings = dict(
    batch_size=32,            # data processed at once, depends on your GPU
    target_resolution=0.094,  # fixed to 94 nm, can be None for mixed resolution training
    min_content=50,           # images need to have at least 50 segmented px
)

# Training stream: only the shared settings are passed explicitly
dg_training = DataGeneratorStream(TRAINING_DATA_PATH, **shared_stream_settings)

# Validation stream: same settings, with augmentation and shuffling switched off
dg_validation = DataGeneratorStream(VALIDATION_DATA_PATH,
                                    augment=False,
                                    shuffle=False,
                                    **shared_stream_settings)

## Visualize data

Take a quick look at one training batch to verify that the data generator settings are as expected.

In [None]:
# Fetch the first batch from the training stream and pick one sample from it
X, Y = dg_training[0]
i = 0

# Left to right: the input and the two targets belonging to sample i
# (presumably dendrite and spines, matching the loss order used in compile -- confirm)
panels = (X[i], Y[0][i], Y[1][i])

plt.figure(figsize=(12,4))

for position, panel in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    plt.imshow(panel.squeeze(), cmap='gray')
    plt.colorbar()

plt.tight_layout()

## Create the model and set training parameters

In [None]:
# Build a plain DeepD3 model; `filters` is the base filter count (e.g. 32)
m = DeepD3_Model(filters=32)

# Training settings, spelled out one by one before handing them to compile()
optimizer = Adam(learning_rate=0.0005)  # good default setting, can be tuned
losses = [sm.losses.dice_loss, "mse"]   # Dice loss for dendrite, MSE for spines
monitored = ['acc', sm.metrics.iou_score]  # metrics for monitoring progress

m.compile(optimizer, losses, metrics=monitored)

# Print the layer overview of the compiled model
m.summary()

## Fitting model

Load the training callbacks: one adjusts the learning rate over time, one logs the training progress, and one saves intermediate models.

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger, LearningRateScheduler

In [None]:
def schedule(epoch, lr):
    """Learning-rate schedule for ``LearningRateScheduler``.

    The learning rate is kept unchanged for the first 15 epochs
    (epoch indices 0-14). From epoch index 15 onwards the current rate
    is multiplied by ``exp(-0.1)`` once per epoch, i.e. it decays
    exponentially.

    Args:
        epoch: Zero-based index of the epoch that is about to start.
        lr: Learning rate currently in use.

    Returns:
        The learning rate to use for this epoch, as a plain Python float
        whenever ``lr`` is a float.
    """
    # Local import keeps this cell self-contained.
    import math

    if epoch < 15:
        return lr

    # math.exp yields a Python float. tf.math.exp would return a tensor,
    # which is not the float that LearningRateScheduler expects back.
    return lr * math.exp(-0.1)

# Train your own DeepD3 model

In [None]:
EPOCHS = 30

# Save best model automatically during training
# (save_best_only=True: the file is only overwritten when the monitored value improves)
mc = ModelCheckpoint("DeepD3_model.h5",
                     save_best_only=True)

# Save metrics of every epoch to a CSV file
csv = CSVLogger("DeepD3_model.csv")

# Adjust learning rate during training to allow for better convergence
lrs = LearningRateScheduler(schedule)

# Actually train the network.
# No `batch_size` is passed here on purpose: the data generators already
# deliver complete batches (their batch size is set where they are created),
# and Keras documents that `batch_size` must not be specified for
# generator / Sequence inputs.
h = m.fit(dg_training,
          epochs=EPOCHS,
          validation_data=dg_validation,
          callbacks=[mc, csv, lrs])

## Save model for use in GUI or batch processing

Use this cell to save the neural network manually in its current state. The best model is saved automatically during training (as `DeepD3_model.h5`).

In [None]:
# Write the model in its current state to a separate file. This does not touch
# "DeepD3_model.h5", which the ModelCheckpoint callback writes during training.
m.save("deepd3_custom_trained_model.h5")