In [1]:
# config file
import torch
import os

# Base path of the dataset
dataset_path = os.path.join('seg', 'train')

# Paths to the image and mask sub-directories of the dataset
image_dataset_path = os.path.join(dataset_path, 'images')
mask_dataset_path = os.path.join(dataset_path, 'masks')

# Fraction of the data held out for testing
test_split = 0.15

# Device used for training and evaluation
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Pin host memory during data loading only when batches will be copied to a
# CUDA device; it has no benefit on a CPU-only run.
pin_memory = device == 'cuda'

# Number of input channels, number of output classes and number of levels in
# the U-Net model.
# NOTE(review): SegmentationDataset converts images BGR->RGB (3 channels);
# confirm how num_channels is consumed by the model before relying on it.
num_channels = 1
num_classes = 1
num_levels = 3

# Initial learning rate, number of training epochs and batch size
init_lr = 0.001
num_epochs = 40
batch_size = 64

# Input image dimensions (pixels)
input_image_width = 128
input_image_height = 128

# Threshold used to filter weak predictions
threshold = 0.5

# Base output directory
base_output = 'output'

# Paths to the serialized model, the training-loss plot and the list of test
# image paths. All are built with os.path.join, consistent with the dataset
# paths above.
model_path = os.path.join(base_output, 'unet_tgs_salt.pth')
plot_path = os.path.join(base_output, 'unet_tgs_salt.png')
test_path = os.path.join(base_output, 'test_paths.txt')

In [2]:
# dataset
from torch.utils.data import Dataset
import cv2 as cv

class SegmentationDataset(Dataset):
    """Map-style dataset pairing images with their segmentation masks.

    Args:
        image_paths: Sequence of file paths to the input images.
        mask_paths: Sequence of file paths to the masks, aligned index-by-index
            with ``image_paths``.
        transforms: Optional callable applied separately to the image and to
            the mask, or ``None`` to return the arrays as read from disk.
    """

    def __init__(self, image_paths, mask_paths, transforms):
        self.image_paths = image_paths
        self.mask_paths = mask_paths
        self.transforms = transforms

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load the ``idx``-th ``(image, mask)`` pair.

        The image is read from disk and converted from OpenCV's BGR channel
        order to RGB; the mask is read in grayscale mode.

        Raises:
            FileNotFoundError: If the image or the mask cannot be read.
                ``cv.imread`` returns ``None`` instead of raising, which would
                otherwise surface later as an opaque ``cvtColor`` or transform
                error.
        """
        image_path = self.image_paths[idx]
        mask_path = self.mask_paths[idx]

        image = cv.imread(image_path)
        if image is None:
            raise FileNotFoundError(f"could not read image: {image_path}")
        image = cv.cvtColor(image, cv.COLOR_BGR2RGB)

        mask = cv.imread(mask_path, cv.IMREAD_GRAYSCALE)
        if mask is None:
            raise FileNotFoundError(f"could not read mask: {mask_path}")

        if self.transforms is not None:
            # NOTE(review): the same callable is applied to image and mask
            # independently; any random augmentation inside it would not be
            # applied consistently to the pair.
            image = self.transforms(image)
            mask = self.transforms(mask)

        return (image, mask)

    # The original cell misspelled the dunder as ``__get_item__``, so
    # ``dataset[idx]`` was never served by this class. The old name is kept as
    # an alias for any caller that invoked it directly.
    __get_item__ = __getitem__

ModuleNotFoundError: No module named 'cv2'