In [1]:
from Model.FastSCNN import *
from Dataset.dataset import *

In [2]:
import os
from yaml import load, dump
try:
    from yaml import CLoader as Loader, CDumper as Dumper
except ImportError:
    from yaml import Loader, Dumper

# Parameters

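The parameters are read from `params.yml`, a [YAML](https://en.wikipedia.org/wiki/YAML) file. The keys `dataset_path`, `crop_height` and `crop_width` are required; every other setting is optional and falls back to a default.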

In [3]:
# Read the run configuration from params.yml in the current working directory.
# NOTE(review): `Loader` is PyYAML's full loader, which can construct arbitrary
# Python objects from tagged YAML. That is acceptable only while params.yml is a
# trusted, locally authored file; consider a SafeLoader variant otherwise.
# The encoding is pinned so parsing does not depend on the machine's locale.
with open("params.yml", encoding="utf-8") as file:
    params = load(file, Loader=Loader)

# An empty or non-mapping file would otherwise only fail later, with an opaque
# "not subscriptable" error, when the required keys are read.
if not isinstance(params, dict):
    raise TypeError("params.yml must contain a mapping of parameter names to values")

params

{'dataset_path': '/mnt/sda/datasets/Audi/roadline',
 'crop_height': 960,
 'crop_width': 1920}

In [4]:
# --- Required settings: a missing key raises KeyError ---

# Path at which the dataset is located
dataset_path = params["dataset_path"]
# Height of cropped/resized input image
crop_height = params["crop_height"]
# Width of cropped/resized input image
crop_width = params["crop_width"]

# --- Optional settings: each falls back to the default given here ---

# Number of epochs to train for
num_epochs = params.get("num_epochs", 100)
# Start counting epochs from this number
epoch_start = params.get("epoch_start", 0)
# Number of images in each batch
batch_size = params.get("batch_size", 2)
# How often to save checkpoints (epochs)
checkpoint_step = params.get("checkpoint_step", 2)
# How often to perform validation (epochs)
validation_step = params.get("validation_step", 2)
# Learning rate used for training
learning_rate = params.get("learning_rate", 0.01)
# GPU ids used for training
cuda = params.get("cuda", "0,1")
# Whether to use the GPU for training
use_gpu = params.get("use_gpu", True)
# Path to pretrained model
pretrained_model_path = params.get("pretrained_model_path", None)
# Path to save model
save_model_path = params.get("save_model_path", "./checkpoints")

In [5]:
# The crop must have a 2:1 width-to-height ratio; stop here otherwise.
if crop_width != 2 * crop_height:
    raise AssertionError("Crop width must be exactly twice the size of crop height")

# Dataset

In [6]:
# Fail fast if the dataset layout under dataset_path is incomplete.

# os.path.join only builds the path string (always truthy), so the existence of
# the class dictionary has to be tested explicitly with os.path.isfile.
class_dict_path = os.path.join(dataset_path, "class_dict.csv")
if not os.path.isfile(class_dict_path):
    raise AssertionError(class_dict_path + " does not exist")

# Each of these sub-directories must exist directly under dataset_path.
for directory in ("train", "train_labels", "test", "test_labels", "val", "val_labels"):
    directory_path = os.path.join(dataset_path, directory)
    if not os.path.isdir(directory_path):
        raise AssertionError(directory_path + " does not exist")

In [7]:
# Build the dataset in "train" mode from dataset_path at the configured crop size.
# NOTE(review): Dataset presumably comes from the Dataset.dataset star import — confirm.
dataset = Dataset(
    dataset_path,
    crop_height,
    crop_width,
    mode="train",
)

# Train

In [8]:
# Instantiate FastSCNN for 3-channel inputs of the configured crop size,
# with 5 output classes.
model = FastSCNN(
    image_height=crop_height,
    image_width=crop_width,
    image_channels=3,
    num_classes=5,
)

In [9]:
# Sum the element counts (numel) of everything yielded by model.parameters().
num_parameters = sum(
    tensor.numel()
    for tensor in model.parameters()
)
num_parameters

1136245