
This is the minimal code for training a VGG-11 U-Net to segment ring boundaries on images of macroscopic colonies formed by motile microbes (e.g. the swarming bacterium, *Proteus mirabilis*).


**References/Acknowledgments:**
Our work in bacterial ring boundary segmentation greatly benefited from [Segmentation Models: Python library with Neural Networks for Image Segmentation based on PyTorch](https://github.com/qubvel/segmentation_models.pytorch#examples) (SMP for short). The model (its architecture and pretrained encoder) and many utility functions used below came from SMP. Overall, this script closely follows [SMP's car segmentation example](https://github.com/qubvel/segmentation_models.pytorch/blob/master/examples/cars%20segmentation%20(camvid).ipynb), particularly the functions needed for data loading, augmentation, and model training, with slight modifications. 

# Imports

In [None]:
# Install PyTorch segmentation models (SMP) directly from its GitHub repository.
# NOTE(review): no tag or commit is pinned, so this installs whatever is on the
# repository's default branch at run time. The cells below rely on the
# `smp.utils` API; consider pinning a release known to provide it (confirm).
!pip install git+https://github.com/qubvel/segmentation_models.pytorch

In [None]:
import numpy as np
import pandas as pd
import cv2
import csv
import copy
import time
from tqdm import tqdm
import os
import matplotlib.pyplot as plt

import torch
import torchvision
from torchvision import transforms
from torch import nn
from torch.nn import functional as F
from torchvision import models
from torch.utils.data import Dataset, DataLoader

import segmentation_models_pytorch as smp
from segmentation_models_pytorch import losses
from segmentation_models_pytorch.encoders import get_preprocessing_fn
import albumentations as albu

# Dataset

In [None]:
# Set the paths to the image and ring boundary mask folders.
# Each sample's filename is joined onto both folders when the sample is loaded.
# NOTE(review): the image folder is relative to the current directory while the
# mask folder is relative to its parent; confirm this asymmetry is intended.
img_dir = './images'
mask_dir = '../masks'

In [None]:
# Lists of filenames of the train, val, and test images.
# (These lists could be read in from an Excel file, for example.)
# Note: filenames of corresponding images and masks should be the same.
# TODO: fill in all three lists before running the cells below; the empty
# placeholders only keep this cell valid Python.
train_IDs = []
val_IDs = []
test_IDs = []

In [None]:
# Dataset class
class BacteriaDataset(Dataset):
    """Dataset of colony images paired with binary ring-boundary masks.

    For item ``i`` the image is read from ``img_dir`` and the mask from
    ``mask_dir``, both using the filename ``img_IDs[i]``.  The image is
    converted from OpenCV's BGR order to RGB.  The mask is binarised (every
    value >= 1 becomes 1.0, everything else 0.0) and given a trailing
    channel axis before the optional transforms are applied.

    Args:
        img_IDs: sequence of filenames shared by each image and its mask.
        img_dir: folder containing the images.
        mask_dir: folder containing the masks.
        classes: class names to segment (matched case-insensitively against
            ``CLASSES``).  ``None`` selects every entry of ``CLASSES``.
        augmentation: optional callable invoked as
            ``augmentation(image=..., mask=...)`` that returns a mapping with
            ``'image'`` and ``'mask'`` entries.
        preprocessing: optional callable with the same calling convention,
            applied after ``augmentation``.

    Raises:
        ValueError: if a name in ``classes`` is not listed in ``CLASSES``.
    """

    CLASSES = ['boundaries']

    def __init__(self, img_IDs, img_dir, mask_dir,
                 classes=None, augmentation=None, preprocessing=None):
        self.img_IDs = img_IDs
        self.img_dir = img_dir
        self.mask_dir = mask_dir
        self.augmentation = augmentation         # for augmentations
        self.preprocessing = preprocessing       # preprocessing to normalize images

        # ``None`` is the declared default, so it must be usable: fall back
        # to every known class instead of failing while iterating over None.
        if classes is None:
            classes = self.CLASSES

        # convert str names to class values on masks
        self.class_values = [self.CLASSES.index(cls.lower()) for cls in classes]

    def __len__(self):
        """Return the number of samples (one per filename)."""
        return len(self.img_IDs)

    def __getitem__(self, i):
        """Load, binarise and transform the ``i``-th (image, mask) pair.

        Raises:
            OSError: if the image or the mask file cannot be read.
        """
        # read data
        img_path = os.path.join(self.img_dir, self.img_IDs[i])
        mask_path = os.path.join(self.mask_dir, self.img_IDs[i])

        # cv2.imread signals failure by returning None rather than raising,
        # so check explicitly and name the offending path.
        img = cv2.imread(img_path)
        if img is None:
            raise OSError('Could not read image file: {}'.format(img_path))
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        mask = cv2.imread(mask_path, cv2.IMREAD_UNCHANGED)
        if mask is None:
            raise OSError('Could not read mask file: {}'.format(mask_path))
        # NOTE(review): assumes the mask file is single-channel, so that the
        # added axis yields an (H, W, 1) array; confirm for your data.
        mask = (mask >= 1).astype('float32')
        mask = np.expand_dims(mask, axis=2)

        # apply augmentations
        if self.augmentation:
            sample = self.augmentation(image=img, mask=mask)
            img, mask = sample['image'], sample['mask']

        # apply preprocessing
        if self.preprocessing:
            sample = self.preprocessing(image=img, mask=mask)
            img, mask = sample['image'], sample['mask']

        return img, mask

In [None]:
# Transformations definitions 
def get_training_augmentation():
    """Build the albumentations pipeline used to augment training samples.

    The pipeline pads to a minimum of 1024x1024, then applies (each with
    probability 0.5) a rotation limited to [-10, 10], a horizontal flip, a
    vertical flip, a shift-only ShiftScaleRotate and a scale-only
    ShiftScaleRotate.  All border handling uses reflect-101.
    """
    reflect = cv2.BORDER_REFLECT_101

    pad = albu.PadIfNeeded(min_height=1024, min_width=1024,
                           always_apply=True, border_mode=reflect)
    rotate = albu.Rotate(limit=(-10, 10), border_mode=reflect, p=0.5)
    flip_horizontal = albu.HorizontalFlip(p=0.5)
    flip_vertical = albu.VerticalFlip(p=0.5)
    # translate
    translate = albu.ShiftScaleRotate(shift_limit=0.05, scale_limit=0,
                                      rotate_limit=0, border_mode=reflect, p=0.5)
    # zoom
    zoom = albu.ShiftScaleRotate(shift_limit=0, scale_limit=0.5,
                                 rotate_limit=0, border_mode=reflect, p=0.5)

    return albu.Compose(
        [pad, rotate, flip_horizontal, flip_vertical, translate, zoom]
    )

def get_val_test_augmentation():
    """Build the pipeline for validation/test samples: padding only.

    Samples are padded (reflect-101 border) to a minimum of 1024x1024; no
    random augmentation is applied.
    """
    pad = albu.PadIfNeeded(
        min_height=1024,
        min_width=1024,
        always_apply=True,
        border_mode=cv2.BORDER_REFLECT_101,
    )
    return albu.Compose([pad])

def to_tensor(x, **kwargs):
    """Move the last axis of a 3-D array to the front and cast to float32.

    Extra keyword arguments are accepted and ignored.
    """
    channels_first = x.transpose((2, 0, 1))
    return channels_first.astype('float32')

def get_preprocessing(preprocessing_fn):
    """Compose the preprocessing applied after augmentation.

    Args:
        preprocessing_fn: callable applied to the image only.

    Returns:
        An albumentations Compose that first applies ``preprocessing_fn`` to
        the image, then applies ``to_tensor`` to both image and mask.
    """
    normalise_image = albu.Lambda(image=preprocessing_fn)
    channels_first = albu.Lambda(image=to_tensor, mask=to_tensor)
    return albu.Compose([normalise_image, channels_first])

# Model architecture and hyperparameters

In [None]:
# Experiment settings

# --- model -------------------------------------------------------------
model_name = 'vgg11unet'    # prefix for the saved checkpoints and metrics CSV
Encoder = 'vgg11'
Weights = 'imagenet'        # if initializing model with pretrained weights
Attention = None            # None / 'scse'
ACTIVATION = 'sigmoid'
CLASSES = ['boundaries']

# image preprocessing function for the chosen encoder / weights
preprocess_input = get_preprocessing_fn(Encoder, Weights)

# --- training ----------------------------------------------------------
to_augment = False          # False / True (whether or not to augment training data)
patience = 2                # for early stopping

# --- batch sizes -------------------------------------------------------
train_batch_size = 3
val_batch_size = 1
test_batch_size = 1

In [None]:
# Create the segmentation model (U-Net with a pretrained encoder)
# https://github.com/qubvel/segmentation_models.pytorch
_unet_options = {
    'encoder_name': Encoder,
    'encoder_weights': Weights,
    'decoder_attention_type': Attention,
    'in_channels': 3,
    'classes': len(CLASSES),   # one output channel per class
    'activation': ACTIVATION,
}
model = smp.Unet(**_unet_options)

In [None]:
# Initialize loss, metrics & optimizer
loss = smp.utils.losses.DiceLoss()

# Metrics computed for every epoch (IoU uses a 0.5 threshold)
_smp_metrics = smp.utils.metrics
metrics = [
    _smp_metrics.IoU(threshold=0.5),
    _smp_metrics.Fscore(),
    _smp_metrics.Accuracy(),
    _smp_metrics.Recall(),
    _smp_metrics.Precision(),
]

# Adam over all model parameters, as a single parameter group
optimizer = torch.optim.Adam([{'params': model.parameters(), 'lr': 1e-4}])

# Training & Validation

In [None]:
# Choose the transform for training samples: the full augmentation pipeline
# when `to_augment` is True, otherwise the same padding-only pipeline that is
# used for validation & test data.
training_aug = (
    get_training_augmentation()
    if to_augment is True
    else get_val_test_augmentation()
)

In [None]:
# Create transformed & preprocessed datasets
def _build_dataset(ids, augmentation):
    """Wrap one split's filenames in a BacteriaDataset with its own preprocessing."""
    return BacteriaDataset(
        ids, img_dir, mask_dir,
        classes=['boundaries'],
        augmentation=augmentation,
        preprocessing=get_preprocessing(preprocess_input),
    )


train_dataset = _build_dataset(train_IDs, training_aug)
val_dataset = _build_dataset(val_IDs, get_val_test_augmentation())
test_dataset = _build_dataset(test_IDs, get_val_test_augmentation())

# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=train_batch_size,
    shuffle=True,
    num_workers=12,
)
val_loader = DataLoader(
    val_dataset,
    batch_size=val_batch_size,
    shuffle=True,
    num_workers=4,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=test_batch_size,
    shuffle=False,
    num_workers=4,
)

In [None]:
# Create epoch runners, as done in https://github.com/qubvel/segmentation_models.pytorch
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cpu'
if torch.cuda.is_available():
    device = 'cuda'

# Arguments shared by the training and the validation runner
_runner_options = dict(loss=loss, metrics=metrics, device=device, verbose=True)

# The training runner additionally receives the optimizer
train_epoch = smp.utils.train.TrainEpoch(model, optimizer=optimizer, **_runner_options)

val_epoch = smp.utils.train.ValidEpoch(model, **_runner_options)

In [None]:
# Create dataframe for storing metrics:
# an 'Epoch' column followed by a Train/Val column pair per tracked quantity.
_tracked_quantities = ('Loss', 'Accuracy', 'Precision', 'Recall',
                       'IoU', 'Fscore', 'Dice')
_metric_columns = ['Epoch']
for _quantity in _tracked_quantities:
    _metric_columns.extend(['Train ' + _quantity, 'Val ' + _quantity])

dfTrainVal = pd.DataFrame(columns=_metric_columns)

# File name of the metrics CSV: the model name followed by a fixed suffix
trainvalCSVname = ''.join([model_name, '_TrainValcsv.csv'])

In [None]:
# Train & Validate Model
EPOCHS = 35  # maximum number of epochs to train/validate for
es = 0       # early stopping counter: epochs in a row without a new best val loss

for epoch in range(EPOCHS):

    print('\nEpoch: {}'.format(epoch))
    train_logs = train_epoch.run(train_loader)
    val_logs = val_epoch.run(val_loader)

    # Best validation loss recorded before this epoch (1 on the first epoch,
    # when nothing has been recorded yet)
    min_val_loss = 1 if epoch == 0 else dfTrainVal['Val Loss'].min()

    # Scores from this epoch, keyed by dataframe column
    epoch_scores = {
        'Epoch': epoch,
        'Train Loss': train_logs['dice_loss'],
        'Val Loss': val_logs['dice_loss'],
        'Train Accuracy': train_logs['accuracy'],
        'Val Accuracy': val_logs['accuracy'],
        'Train Precision': train_logs['precision'],
        'Val Precision': val_logs['precision'],
        'Train Recall': train_logs['recall'],
        'Val Recall': val_logs['recall'],
        'Train IoU': train_logs['iou_score'],
        'Val IoU': val_logs['iou_score'],
        'Train Fscore': train_logs['fscore'],
        'Val Fscore': val_logs['fscore'],
        # Dice score is reported as the complement of the Dice loss
        'Train Dice': 1 - train_logs['dice_loss'],
        'Val Dice': 1 - val_logs['dice_loss'],
    }
    # Write the row one cell at a time (the first write creates the row)
    for column, score in epoch_scores.items():
        dfTrainVal.loc[epoch, [column]] = score

    # Save the dataframe after every epoch
    dfTrainVal.to_csv(trainvalCSVname, index=False)

    # Save a checkpoint for this epoch
    checkpoint_path = './{}_epoch_{}.pth'.format(model_name, epoch)
    torch.save(
        {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
        },
        checkpoint_path,
    )

    # Early stopping: compare this epoch's val loss with the previous minimum
    val_loss = val_logs['dice_loss']
    if val_loss < min_val_loss:
        es = 0  # improved: reset the counter
        continue

    es += 1  # no improvement: count this epoch
    print("EarlyStopping Counter {} of {}".format(es, patience))

    if es >= patience:
        print("Early stopping with min_val_loss: ", min_val_loss, "and val_loss for this epoch: ", val_loss, "...")
        break