In [3]:
# config file
import torch
import os

# Base path of dataset
dataset_path = os.path.join('seg', 'train')

# define path to images and masks dataset
image_dataset_path = os.path.join(dataset_path, 'images')
mask_dataset_path = os.path.join(dataset_path, 'masks')

# define the test split (fraction of the samples held out from training)
test_split = 0.15

# determine the device to be used for training and evaluation
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# pin host memory during data loading only when a CUDA device is used
pin_memory = device == 'cuda'

# determine the number of channels in input, number of classes and number of levels in u-net model
# NOTE(review): the visible cells build the model with `UNET()` (default 3 input
# channels) and load images as RGB, so these three values are not consumed
# anywhere in the visible code -- confirm whether num_channels should be 3.
num_channels = 1
num_classes = 1
num_levels = 3

# initialize the learning rate, number of epochs to train for and batch size
init_lr = 0.001
num_epochs = 40
batch_size = 64

# define input image dimensions
input_image_width = 128
input_image_height = 128

# define threshold to filter weak predictions
threshold = 0.5

# define path to base output directory
base_output = 'output'

# define path to output serialized model, model training plot and testing image paths
model_path = os.path.join(base_output, 'unet_tgs_salt.pth')
plot_path = os.path.join(base_output, 'unet_tgs_salt.png')
test_path = os.path.join(base_output, 'test_paths.txt')

# The training and prediction cells refer to the test-paths file as
# `test_paths`; expose both spellings so either name resolves.
test_paths = test_path

In [4]:
# dataset
from torch.utils.data import Dataset
import cv2 as cv

class SegmentationDataset(Dataset):
    """Map-style dataset that yields ``(image, mask)`` pairs read from disk.

    The image at ``image_paths[idx]`` is paired with the mask at
    ``mask_paths[idx]``, so both lists must be in the same order.

    Args:
        image_paths: sequence of paths to the input images.
        mask_paths: sequence of paths to the masks, aligned with ``image_paths``.
        transforms: callable applied separately to the image and to the mask,
            or ``None`` to return the raw arrays produced by OpenCV.
    """

    def __init__(self, image_paths, mask_paths, transforms):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transforms = transforms

    def __len__(self):
        # return the number of total samples in dataset
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load and return the ``(image, mask)`` pair stored at ``idx``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image or the mask.
        """
        # grab image and mask paths for the current index
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        # cv.imread signals an unreadable file by returning None; fail with a
        # message that names the file instead of a later, opaque OpenCV error
        image = cv.imread(image_path)
        if image is None:
            raise FileNotFoundError(f'could not read image: {image_path}')

        # swap the channels from BGR to RGB and read the mask in grayscale mode
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
        mask = cv.imread(mask_path, 0)
        if mask is None:
            raise FileNotFoundError(f'could not read mask: {mask_path}')

        # check to see if we are applying any transforms
        if self.transforms is not None:
            # apply the same transforms to both image and mask
            image = self.transforms(image)
            mask = self.transforms(mask)

        # return image and mask
        return (image, mask)

    # Backward-compatible alias for the original (misspelled) method name.
    __get_item__ = __getitem__

In [5]:
# model
from torch.nn import ConvTranspose2d
from torch.nn import Conv2d
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import ModuleList
from torch.nn import ReLU
from torchvision.transforms import CenterCrop
from torch.nn import functional as F
import torch

class Block(Module):
    """Convolution -> ReLU -> convolution building block.

    Both convolutions use a 3x3 kernel with the ``Conv2d`` defaults for
    stride and padding.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        kernel_size = 3
        self.conv1 = Conv2d(in_channels, out_channels, kernel_size)
        self.relu = ReLU()
        self.conv2 = Conv2d(out_channels, out_channels, kernel_size)

    def forward(self, x):
        """Run ``x`` through conv1, the activation and conv2, in that order."""
        hidden = self.conv1(x)
        hidden = self.relu(hidden)
        return self.conv2(hidden)
    
class Encoder(Module):
    """Stack of ``Block`` stages, each followed by 2x2 max pooling.

    Args:
        channels: channel counts; consecutive pairs give the input and output
            channels of each stage.
    """

    def __init__(self, channels=(3, 16, 32, 64)):
        super().__init__()
        # one Block per consecutive (in, out) channel pair, plus a shared pooling layer
        stages = [Block(c_in, c_out)
                  for c_in, c_out in zip(channels[:-1], channels[1:])]
        self.encBlocks = ModuleList(stages)
        self.pool = MaxPool2d(2)

    def forward(self, x):
        """Return the list of pre-pooling outputs of every stage, first stage first."""
        block_outputs = []

        for stage in self.encBlocks:
            # keep the un-pooled features, then pool them to feed the next stage
            features = stage(x)
            block_outputs.append(features)
            x = self.pool(features)

        return block_outputs
    
class Decoder(Module):
    """Upsampling path: transposed convolution, skip concatenation, then a ``Block``.

    Args:
        channels: channel counts; consecutive pairs give the input and output
            channels of each upsampling stage and of each decoder block.
    """

    def __init__(self, channels=(64, 32, 16)):
        super().__init__()
        self.channels = channels
        stage_io = list(zip(channels[:-1], channels[1:]))

        # upsamplers (kernel size 2, stride 2) followed by the decoder blocks
        self.upconvs = ModuleList(
            [ConvTranspose2d(c_in, c_out, 2, 2) for c_in, c_out in stage_io])
        self.decod_blocks = ModuleList(
            [Block(c_in, c_out) for c_in, c_out in stage_io])

    def forward(self, x, enc_features):
        """Decode ``x`` using ``enc_features[i]`` as the skip connection of stage ``i``."""
        for stage, (upconv, dec_block) in enumerate(zip(self.upconvs, self.decod_blocks)):
            # upsample the running features
            x = upconv(x)

            # crop the matching encoder features to the upsampled size,
            # concatenate along the channel axis and run the decoder block
            skip = self.crop(enc_features[stage], x)
            x = dec_block(torch.cat([x, skip], dim=1))

        return x

    def crop(self, enc_features, x):
        """Center-crop ``enc_features`` to the height and width of the 4-D tensor ``x``."""
        _, _, height, width = x.shape
        return CenterCrop([height, width])(enc_features)
        

class UNET(Module):
    """Encoder/decoder network with a 1x1 convolution head.

    Args:
        enc_channels: channel counts handed to ``Encoder``.
        dec_channels: channel counts handed to ``Decoder``.
        num_classes: number of output channels of the head.
        retain_dim: when true, the head output is resized to ``out_size``.
        out_size: ``(height, width)`` target used when ``retain_dim`` is true.
    """

    def __init__(self, enc_channels=(3, 16, 32, 64), dec_channels=(64, 32, 16), num_classes=1, retain_dim=True,
                out_size=(input_image_height, input_image_width)):
        super().__init__()
        # encoder and decoder paths
        self.encoder = Encoder(enc_channels)
        self.decoder = Decoder(dec_channels)

        # 1x1 convolution head and the output-resizing settings
        self.head = Conv2d(dec_channels[-1], num_classes, 1)
        self.retain_dim = retain_dim
        self.out_size = out_size

    def forward(self, x):
        """Return the head output for ``x``, resized to ``out_size`` if ``retain_dim`` is set."""
        # encoder outputs, reordered so the deepest stage comes first
        deepest_first = self.encoder(x)[::-1]

        # the deepest output seeds the decoder; the rest act as skip connections
        decoded = self.decoder(deepest_first[0], deepest_first[1:])

        # project the decoder features to the segmentation map
        seg_map = self.head(decoded)

        if not self.retain_dim:
            return seg_map

        # resize the map to the configured output size
        return F.interpolate(seg_map, self.out_size)

In [None]:
# training
from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from torchvision import transforms
from imutils import paths
from tqdm import tqdm
import matplotlib.pyplot as plt
import torch
import time
import os

# load the image and mask paths (sorted so images and masks line up by index)
image_paths = sorted(paths.list_images(image_dataset_path))
mask_paths = sorted(paths.list_images(mask_dataset_path))

# partition the data into training and testing sets; `test_split` is the held-out fraction
split = train_test_split(image_paths, mask_paths, test_size=test_split, random_state=42)

# unpack data split
(train_images, test_images) = split[:2]
(train_masks, test_masks) = split[2:]

# make sure the output directory exists before anything is written into it
os.makedirs(base_output, exist_ok=True)

# write the test image paths to disk so we can use them during evaluation
print('[INFO] saving test image paths...')
with open(test_path, 'w') as f:
    f.write('\n'.join(test_images))

# define transformations; stored under a separate name so the imported
# `transforms` module is not overwritten and the cell can be re-run
data_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((input_image_height, input_image_width)),
    transforms.ToTensor()])

# create train and test dataset
train_ds = SegmentationDataset(image_paths=train_images, mask_paths=train_masks, transforms=data_transforms)
test_ds = SegmentationDataset(image_paths=test_images, mask_paths=test_masks, transforms=data_transforms)
print(f'[INFO] found {len(train_ds)} examples in training set...')
print(f'[INFO] found {len(test_ds)} examples in test set...')

# create training and test data loaders
train_loader = DataLoader(train_ds, shuffle=True, batch_size=batch_size, pin_memory=pin_memory,
                          num_workers=os.cpu_count())
test_loader = DataLoader(test_ds, shuffle=False, batch_size=batch_size, pin_memory=pin_memory,
                         num_workers=os.cpu_count())

# initialize Unet model
unet = UNET().to(device)

# initialize the loss function and optimizer
loss_func = BCEWithLogitsLoss()
opt = Adam(unet.parameters(), lr=init_lr)

# number of batches per epoch; taken from the loaders so the partial last batch
# is counted and the value cannot be zero for a non-empty dataset
train_steps = len(train_loader)
test_steps = len(test_loader)

# initialize dictionary to store training history
H = {'train_loss': [], 'test_loss': []}

# loop over the epochs
print('[INFO] training the network...')
start_time = time.time()
for e in tqdm(range(num_epochs)):
    # set model in training mode
    unet.train()

    # initialize total training and validation loss
    total_train_loss = 0.0
    total_test_loss = 0.0

    # loop over the training set
    for (x, y) in train_loader:
        # send input to device
        (x, y) = (x.to(device), y.to(device))

        # perform a forward pass and calculate training loss
        pred = unet(x)
        loss = loss_func(pred, y)

        # first, zero any accumulated gradients, then perform backpropagation and update model parameters
        opt.zero_grad()
        loss.backward()
        opt.step()

        # accumulate a plain float so the autograd graph of each step is not kept alive
        total_train_loss += loss.item()

    # switch off auto grad
    with torch.no_grad():
        # set model in eval mode
        unet.eval()

        # loop over the validation dataset
        for (x, y) in test_loader:
            # send input to device
            (x, y) = (x.to(device), y.to(device))

            # make prediction and calculate validation loss
            pred = unet(x)
            total_test_loss += loss_func(pred, y).item()

    # calculate average training and test loss
    avg_train_loss = total_train_loss / train_steps
    avg_test_loss = total_test_loss / test_steps

    # update training history
    H['train_loss'].append(avg_train_loss)
    H['test_loss'].append(avg_test_loss)

    # print model training and validation information
    print(f'[INFO] Epoch: {e+1}/{num_epochs}')
    print(f'Train loss: {avg_train_loss:.6f}, Test loss: {avg_test_loss:.4f}')

# display the total time needed to perform training
end_time = time.time()
print(f'[INFO] Total time taken to train the model: {end_time - start_time:.2f}s')

# plot the training and test loss
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["test_loss"], label="test_loss")
plt.title("Training Loss on Dataset")
plt.xlabel("Epoch #")
plt.legend(loc="lower left")

# save the training plot; `plot_path` is already the full path of the image file
plt.savefig(plot_path)

# serialize model (the whole module is pickled, which is what the prediction
# cell's `torch.load(model_path)` expects)
torch.save(unet, model_path)

In [None]:
# Predict
import matplotlib.pyplot as plt
import numpy as np
import cv2 as cv
import torch
import os

def prepare_plot(original_image, original_mask, pred_mask):
    """Show an image, its ground-truth mask and the predicted mask side by side.

    The first parameter used to be spelled ``oringinal_image`` while the body
    read ``original_image``, so every call raised ``NameError``. Positional
    callers are unaffected by the corrected spelling.

    Args:
        original_image: image drawn in the first subplot.
        original_mask: ground-truth mask drawn in the second subplot.
        pred_mask: predicted mask drawn in the third subplot.
    """
    # initialize figure with one row of three subplots
    figure, ax = plt.subplots(nrows=1, ncols=3, figsize=(10, 10))

    # plot original image, original mask and its prediction
    ax[0].imshow(original_image)
    ax[1].imshow(original_mask)
    ax[2].imshow(pred_mask)

    # set titles of subplots
    ax[0].set_title('Image')
    ax[1].set_title('Original Mask')
    ax[2].set_title('Predicted Mask')

    # set the layout of the figure and display it
    figure.tight_layout()
    figure.show()
    

def make_predictions(model, image_path):
    """Predict the mask of one image and plot it next to the ground truth.

    The ground-truth mask is looked up in ``mask_dataset_path`` under the same
    file name as the image.

    Args:
        model: network that is called on a ``(1, 3, H, W)`` float tensor.
        image_path: path of the image to segment.

    Raises:
        FileNotFoundError: if the image or its ground-truth mask cannot be read.
    """
    # set model to evaluation mode
    model.eval()

    # cv.resize takes its target size as (width, height)
    target_size = (input_image_width, input_image_height)

    # turn off gradient tracking
    with torch.no_grad():
        # load the image from disk; cv.imread returns None for unreadable files
        image = cv.imread(image_path)
        if image is None:
            raise FileNotFoundError(f'could not read image: {image_path}')

        # swap its color channels, cast it to float data type and scale its pixel values
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)
        image = image.astype('float32') / 255.0

        # resize the image to the configured input size and keep a copy for visualization
        image = cv.resize(image, target_size)
        orig = image.copy()

        # find the filename and generate path to ground truth mask
        file_name = os.path.basename(image_path)
        ground_truth_path = os.path.join(mask_dataset_path, file_name)

        # load the ground truth segmentation mask in grayscale mode and resize it
        ground_truth_mask = cv.imread(ground_truth_path, 0)
        if ground_truth_mask is None:
            raise FileNotFoundError(f'could not read mask: {ground_truth_path}')
        ground_truth_mask = cv.resize(ground_truth_mask, target_size)

        # make channel axis to be leading one, add a batch dimension, create a pytorch tensor,
        # and move it to the device
        image = np.transpose(image, (2, 0, 1))
        image = np.expand_dims(image, 0)
        image = torch.from_numpy(image).to(device)

        # make the predictions, pass the result through sigmoid function and convert result into numpy array
        pred_mask = model(image).squeeze()
        pred_mask = torch.sigmoid(pred_mask)
        pred_mask = pred_mask.cpu().numpy()

        # filter out the weak predictions using the configured `threshold` and convert them to integers
        pred_mask = (pred_mask > threshold) * 255.0
        pred_mask = pred_mask.astype(np.uint8)

        # prepare a plot for visualization
        prepare_plot(orig, ground_truth_mask, pred_mask)
        
# load the image paths in our testing file and randomly select up to 10 distinct image paths
print('[INFO] loading up test image paths...')
with open(test_path) as f:
    image_paths = f.read().strip().split('\n')

# sample without replacement so the same image is not shown twice; cap the
# sample size at the number of available paths, which replace=False requires
image_paths = np.random.choice(image_paths, size=min(10, len(image_paths)), replace=False)

# load our model from disk and move it to current device
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# produced yourself. Newer PyTorch releases may additionally require
# `weights_only=False` to load a fully pickled module -- confirm against the
# installed version.
print('[INFO] load model')
unet = torch.load(model_path, map_location=device).to(device)

# iterate over randomly selected test image paths
for path in image_paths:
    # make predictions and visualize the results
    make_predictions(unet, path)