In [1]:
import torch
import os

from torch.utils.data import Dataset
import cv2

from torch.nn import ConvTranspose2d
from torch.nn import Conv2d
from torch.nn import MaxPool2d
from torch.nn import Module
from torch.nn import ModuleList
from torch.nn import ReLU
from torchvision.transforms import CenterCrop
from torch.nn import functional as F
import torch

from torch.nn import BCEWithLogitsLoss
from torch.optim import Adam
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from torchvision import transforms
from imutils import paths
from tqdm import tqdm
import matplotlib.pyplot as plt
import time

In [2]:
DATASET_PATH = os.path.join("./../data")
IMAGE_DATASET_PATH = os.path.join(DATASET_PATH, "images")
MASK_DATASET_PATH = os.path.join(DATASET_PATH, "masks")


TEST_SPLIT = 0.15
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print("running on", DEVICE)

PIN_MEMORY = True if DEVICE == "cuda" else False

running on cuda


In [3]:
# define the number of channels in the input, number of classes,
# and number of levels in the U-Net model
NUM_CHANNELS = 1
NUM_CLASSES = 1
NUM_LEVELS = 3

# initialize learning rate, number of epochs to train for, and the
# batch size
INIT_LR = 0.001
NUM_EPOCHS = 40
BATCH_SIZE = 64

# define the input image dimensions
INPUT_IMAGE_WIDTH = 128
INPUT_IMAGE_HEIGHT = 128

# define threshold to filter weak predictions
THRESHOLD = 0.5

# define the path to the base output directory
BASE_OUTPUT = "output"

# define the path to the output serialized model, model training
# plot, and testing image paths
MODEL_PATH = os.path.join(BASE_OUTPUT, "unet_tgs_salt.pth")
PLOT_PATH = os.path.sep.join([BASE_OUTPUT, "plot.png"])
TEST_PATHS = os.path.sep.join([BASE_OUTPUT, "test_paths.txt"])

In [4]:

class SegmentationDataset(Dataset):
    def __init__(self, imagePaths, maskPaths, transforms):
        # store the image and mask filepaths, and augmentation
        # transforms
        self.imagePaths = imagePaths
        self.maskPaths = maskPaths
        self.transforms = transforms
    def __len__(self):
        # return the number of total samples contained in the dataset
        return len(self.imagePaths)
    def __getitem__(self, idx):
        # grab the image path from the current index
        imagePath = self.imagePaths[idx]
        # load the image from disk, swap its channels from BGR to RGB,
        # and read the associated mask from disk in grayscale mode
        image = cv2.imread(imagePath)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        mask = cv2.imread(self.maskPaths[idx], 0)
        # check to see if we are applying any transformations
        if self.transforms is not None:
            # apply the transformations to both image and its mask
            image = self.transforms(image)
            mask = self.transforms(mask)
        # return a tuple of the image and its mask
        return (image, mask)

In [5]:
class Block(Module):
    def __init__(self, inChannels, outChannels):
        super().__init__()
        # store the convolution and RELU layers
        self.conv1 = Conv2d(inChannels, outChannels, 3)
        self.relu = ReLU()
        self.conv2 = Conv2d(outChannels, outChannels, 3)
    def forward(self, x):
        # apply CONV => RELU => CONV block to the inputs and return it
        return self.conv2(self.relu(self.conv1(x)))

In [6]:
class Encoder(Module):
    def __init__(self, channels=(3, 16, 32, 64)):
        super().__init__()
        # store the encoder blocks and maxpooling layer
        self.encBlocks = ModuleList(
            [Block(channels[i], channels[i + 1])
                 for i in range(len(channels) - 1)])
        self.pool = MaxPool2d(2)
    def forward(self, x):
        # initialize an empty list to store the intermediate outputs
        blockOutputs = []
        # loop through the encoder blocks
        for block in self.encBlocks:
            # pass the inputs through the current encoder block, store
            # the outputs, and then apply maxpooling on the output
            x = block(x)
            blockOutputs.append(x)
            x = self.pool(x)
        # return the list containing the intermediate outputs
        return blockOutputs

In [7]:
class Decoder(Module):
    def __init__(self, channels=(64, 32, 16)):
        super().__init__()
        # initialize the number of channels, upsampler blocks, and
        # decoder blocks
        self.channels = channels
        self.upconvs = ModuleList(
            [ConvTranspose2d(channels[i], channels[i + 1], 2, 2)
                 for i in range(len(channels) - 1)])
        self.dec_blocks = ModuleList(
            [Block(channels[i], channels[i + 1])
                 for i in range(len(channels) - 1)])
    def forward(self, x, encFeatures):
        # loop through the number of channels
        for i in range(len(self.channels) - 1):
            # pass the inputs through the upsampler blocks
            x = self.upconvs[i](x)
            # crop the current features from the encoder blocks,
            # concatenate them with the current upsampled features,
            # and pass the concatenated output through the current
            # decoder block
            encFeat = self.crop(encFeatures[i], x)
            x = torch.cat([x, encFeat], dim=1)
            x = self.dec_blocks[i](x)
        # return the final decoder output
        return x
    def crop(self, encFeatures, x):
        # grab the dimensions of the inputs, and crop the encoder
        # features to match the dimensions
        (_, _, H, W) = x.shape
        encFeatures = CenterCrop([H, W])(encFeatures)
        # return the cropped features
        return encFeatures

In [8]:
class UNet(Module):
    def __init__(self, encChannels=(3, 16, 32, 64),
         decChannels=(64, 32, 16),
         nbClasses=1, retainDim=True,
         outSize=(INPUT_IMAGE_HEIGHT,  INPUT_IMAGE_WIDTH)):
        super().__init__()
        # initialize the encoder and decoder
        self.encoder = Encoder(encChannels)
        self.decoder = Decoder(decChannels)
        # initialize the regression head and store the class variables
        self.head = Conv2d(decChannels[-1], nbClasses, 1)
        self.retainDim = retainDim
        self.outSize = outSize
        
    def forward(self, x):
        # grab the features from the encoder
        encFeatures = self.encoder(x)
        # pass the encoder features through decoder making sure that
        # their dimensions are suited for concatenation
        decFeatures = self.decoder(encFeatures[::-1][0],
            encFeatures[::-1][1:])
        # pass the decoder features through the regression head to
        # obtain the segmentation mask
        map = self.head(decFeatures)
        # check to see if we are retaining the original output
        # dimensions and if so, then resize the output to match them
        if self.retainDim:
            map = F.interpolate(map, self.outSize)
        # return the segmentation map
        return map

In [9]:
imagePaths = sorted(list(paths.list_images(IMAGE_DATASET_PATH)))
maskPaths = sorted(list(paths.list_images(MASK_DATASET_PATH)))
# partition the data into training and testing splits using 85% of
# the data for training and the remaining 15% for testing
split = train_test_split(imagePaths, maskPaths,
    test_size=TEST_SPLIT, random_state=42)
# unpack the data split
(trainImages, testImages) = split[:2]
(trainMasks, testMasks) = split[2:]
# write the testing image paths to disk so that we can use then
# when evaluating/testing our model
#print("[INFO] saving testing image paths...")
#f = open(TEST_PATHS, "w")
#f.write("\n".join(testImages))
#f.close()

# define transformations
transforms = transforms.Compose([transforms.ToPILImage(),
    transforms.Resize((INPUT_IMAGE_HEIGHT,
        INPUT_IMAGE_WIDTH)),
    transforms.ToTensor()])
# create the train and test datasets
trainDS = SegmentationDataset(imagePaths=trainImages, maskPaths=trainMasks,
    transforms=transforms)
testDS = SegmentationDataset(imagePaths=testImages, maskPaths=testMasks,
    transforms=transforms)
print(f"[INFO] found {len(trainDS)} examples in the training set...")
print(f"[INFO] found {len(testDS)} examples in the test set...")
# create the training and test data loaders
trainLoader = DataLoader(trainDS, shuffle=True,
    batch_size=BATCH_SIZE, pin_memory=PIN_MEMORY,
    num_workers=os.cpu_count())
testLoader = DataLoader(testDS, shuffle=False,
    batch_size=BATCH_SIZE, pin_memory=PIN_MEMORY,
    num_workers=os.cpu_count())


# initialize our UNet model
unet = UNet().to(DEVICE)
# initialize loss function and optimizer
lossFunc = BCEWithLogitsLoss()
opt = Adam(unet.parameters(), lr=INIT_LR)
# calculate steps per epoch for training and test set
trainSteps = len(trainDS) // BATCH_SIZE
testSteps = len(testDS) // BATCH_SIZE
# initialize a dictionary to store training history
H = {"train_loss": [], "test_loss": []}



# loop over epochs
print("[INFO] training the network...")
startTime = time.time()
for e in tqdm(range(NUM_EPOCHS)):
    # set the model in training mode
    unet.train()
    # initialize the total training and validation loss
    totalTrainLoss = 0
    totalTestLoss = 0
    # loop over the training set
    for (i, (x, y)) in enumerate(trainLoader):
        # send the input to the device
        (x, y) = (x.to(DEVICE), y.to(DEVICE))
        # perform a forward pass and calculate the training loss
        pred = unet(x)
        loss = lossFunc(pred, y)
        # first, zero out any previously accumulated gradients, then
        # perform backpropagation, and then update model parameters
        opt.zero_grad()
        loss.backward()
        opt.step()
        # add the loss to the total training loss so far
        totalTrainLoss += loss
    # switch off autograd
    with torch.no_grad():
        # set the model in evaluation mode
        unet.eval()
        # loop over the validation set
        for (x, y) in testLoader:
            # send the input to the device
            (x, y) = (x.to(DEVICE), y.to(DEVICE))
            # make the predictions and calculate the validation loss
            pred = unet(x)
            totalTestLoss += lossFunc(pred, y)
    # calculate the average training and validation loss
    avgTrainLoss = totalTrainLoss / trainSteps
    avgTestLoss = totalTestLoss / testSteps
    # update our training history
    H["train_loss"].append(avgTrainLoss.cpu().detach().numpy())
    H["test_loss"].append(avgTestLoss.cpu().detach().numpy())
    # print the model training and validation information
    print("[INFO] EPOCH: {}/{}".format(e + 1, NUM_EPOCHS))
    print("Train loss: {:.6f}, Test loss: {:.4f}".format(
        avgTrainLoss, avgTestLoss))
# display the total time needed to perform the training
endTime = time.time()
print("[INFO] total time taken to train the model: {:.2f}s".format(
    endTime - startTime))


# plot the training loss
plt.style.use("ggplot")
plt.figure()
plt.plot(H["train_loss"], label="train_loss")
plt.plot(H["test_loss"], label="test_loss")
plt.title("Training Loss on Dataset")
plt.xlabel("Epoch #")
plt.ylabel("Loss")
plt.legend(loc="lower left")
plt.savefig(PLOT_PATH)
# serialize the model to disk
torch.save(unet, MODEL_PATH)

[INFO] found 3400 examples in the training set...
[INFO] found 600 examples in the test set...
[INFO] training the network...


  2%|█                                           | 1/40 [00:05<03:39,  5.63s/it]

[INFO] EPOCH: 1/40
Train loss: 0.592669, Test loss: 0.6123


  5%|██▏                                         | 2/40 [00:09<02:56,  4.65s/it]

[INFO] EPOCH: 2/40
Train loss: 0.569056, Test loss: 0.5981


  8%|███▎                                        | 3/40 [00:13<02:39,  4.31s/it]

[INFO] EPOCH: 3/40
Train loss: 0.536217, Test loss: 0.6283


 10%|████▍                                       | 4/40 [00:17<02:29,  4.14s/it]

[INFO] EPOCH: 4/40
Train loss: 0.495325, Test loss: 0.4510


 12%|█████▌                                      | 5/40 [00:21<02:20,  4.02s/it]

[INFO] EPOCH: 5/40
Train loss: 0.446368, Test loss: 0.4577


 15%|██████▌                                     | 6/40 [00:25<02:15,  3.97s/it]

[INFO] EPOCH: 6/40
Train loss: 0.415988, Test loss: 0.4167


 18%|███████▋                                    | 7/40 [00:29<02:11,  3.99s/it]

[INFO] EPOCH: 7/40
Train loss: 0.413708, Test loss: 0.4874


 20%|████████▊                                   | 8/40 [00:33<02:07,  3.99s/it]

[INFO] EPOCH: 8/40
Train loss: 0.411863, Test loss: 0.4069


 22%|█████████▉                                  | 9/40 [00:37<02:04,  4.01s/it]

[INFO] EPOCH: 9/40
Train loss: 0.386746, Test loss: 0.4216


 25%|██████████▊                                | 10/40 [00:41<02:00,  4.01s/it]

[INFO] EPOCH: 10/40
Train loss: 0.385673, Test loss: 0.4070


 28%|███████████▊                               | 11/40 [00:45<01:56,  4.00s/it]

[INFO] EPOCH: 11/40
Train loss: 0.393252, Test loss: 0.4244


 30%|████████████▉                              | 12/40 [00:49<01:52,  4.01s/it]

[INFO] EPOCH: 12/40
Train loss: 0.382038, Test loss: 0.3885


 32%|█████████████▉                             | 13/40 [00:53<01:48,  4.01s/it]

[INFO] EPOCH: 13/40
Train loss: 0.388870, Test loss: 0.4099


 35%|███████████████                            | 14/40 [00:57<01:43,  3.99s/it]

[INFO] EPOCH: 14/40
Train loss: 0.396601, Test loss: 0.4429


 38%|████████████████▏                          | 15/40 [01:01<01:39,  3.99s/it]

[INFO] EPOCH: 15/40
Train loss: 0.386988, Test loss: 0.3841


 40%|█████████████████▏                         | 16/40 [01:05<01:35,  3.98s/it]

[INFO] EPOCH: 16/40
Train loss: 0.385390, Test loss: 0.3937


 42%|██████████████████▎                        | 17/40 [01:09<01:31,  4.00s/it]

[INFO] EPOCH: 17/40
Train loss: 0.376810, Test loss: 0.3816


 45%|███████████████████▎                       | 18/40 [01:13<01:27,  3.97s/it]

[INFO] EPOCH: 18/40
Train loss: 0.385474, Test loss: 0.3942


 48%|████████████████████▍                      | 19/40 [01:17<01:23,  3.98s/it]

[INFO] EPOCH: 19/40
Train loss: 0.373144, Test loss: 0.3933


 50%|█████████████████████▌                     | 20/40 [01:20<01:19,  3.97s/it]

[INFO] EPOCH: 20/40
Train loss: 0.378647, Test loss: 0.3891


 52%|██████████████████████▌                    | 21/40 [01:24<01:15,  3.99s/it]

[INFO] EPOCH: 21/40
Train loss: 0.372301, Test loss: 0.3932


 55%|███████████████████████▋                   | 22/40 [01:28<01:11,  3.98s/it]

[INFO] EPOCH: 22/40
Train loss: 0.368385, Test loss: 0.3929


 57%|████████████████████████▋                  | 23/40 [01:32<01:07,  3.97s/it]

[INFO] EPOCH: 23/40
Train loss: 0.353082, Test loss: 0.3766


 60%|█████████████████████████▊                 | 24/40 [01:36<01:03,  4.00s/it]

[INFO] EPOCH: 24/40
Train loss: 0.372754, Test loss: 0.3793


 62%|██████████████████████████▉                | 25/40 [01:40<00:59,  3.98s/it]

[INFO] EPOCH: 25/40
Train loss: 0.361836, Test loss: 0.3714


 65%|███████████████████████████▉               | 26/40 [01:44<00:55,  4.00s/it]

[INFO] EPOCH: 26/40
Train loss: 0.351548, Test loss: 0.3749


 68%|█████████████████████████████              | 27/40 [01:48<00:51,  3.99s/it]

[INFO] EPOCH: 27/40
Train loss: 0.352792, Test loss: 0.3587


 70%|██████████████████████████████             | 28/40 [01:52<00:48,  4.00s/it]

[INFO] EPOCH: 28/40
Train loss: 0.346817, Test loss: 0.3732


 72%|███████████████████████████████▏           | 29/40 [01:56<00:43,  3.99s/it]

[INFO] EPOCH: 29/40
Train loss: 0.354234, Test loss: 0.3760


 75%|████████████████████████████████▎          | 30/40 [02:00<00:39,  3.99s/it]

[INFO] EPOCH: 30/40
Train loss: 0.363867, Test loss: 0.3742


 78%|█████████████████████████████████▎         | 31/40 [02:04<00:35,  3.99s/it]

[INFO] EPOCH: 31/40
Train loss: 1185778.375000, Test loss: 0.3871


 80%|██████████████████████████████████▍        | 32/40 [02:08<00:31,  3.99s/it]

[INFO] EPOCH: 32/40
Train loss: 0.340187, Test loss: 0.3394


 82%|███████████████████████████████████▍       | 33/40 [02:12<00:27,  3.97s/it]

[INFO] EPOCH: 33/40
Train loss: 0.333442, Test loss: 0.3341


 85%|████████████████████████████████████▌      | 34/40 [02:16<00:23,  4.00s/it]

[INFO] EPOCH: 34/40
Train loss: 0.336080, Test loss: 0.3596


 88%|█████████████████████████████████████▋     | 35/40 [02:20<00:20,  4.00s/it]

[INFO] EPOCH: 35/40
Train loss: 0.327370, Test loss: 0.3362


 90%|██████████████████████████████████████▋    | 36/40 [02:24<00:15,  3.99s/it]

[INFO] EPOCH: 36/40
Train loss: 0.331734, Test loss: 0.3613


ERROR: Unexpected segmentation fault encountered in worker.
 90%|██████████████████████████████████████▋    | 36/40 [04:43<00:31,  7.88s/it]


RuntimeError: DataLoader worker (pid(s) 33805) exited unexpectedly