<h1>Training a Convolutional Neural Network for Foliage</h1>
<p>This example notebook is a more comprehensive example with a configuration for training a convolutional neural network for recognizing foliage. If you would like to learn step by step what is done, please refer to the  'example_config.ipynb' notebook.</p>

In [13]:
from inference_training import Configuration, ImageDataset
from inference_training import initCudaEnvironment, createTransforms
from inference_training import drawImageAndFeatureMasks
from inference_training import exportOnnxModel, writeONNXMeta, loadONNX
from inference_training import trainModel, saveModel, loadModel
from inference_training import createModelInstance, testInference
from inference_training import logger

In [14]:
initCudaEnvironment(numCudaDevices=1,
                    visibleCudaDevices="0",
                    clearCudaDeviceCount=False)

In [15]:
# train on the GPU or on the CPU, if a GPU is not available
config = Configuration()
print("Device: " + str(config.device))

trainDirectory = "/path/to/train_dataset/"
testDirectory = "/path/to/test_dataset/"

config.setDatasetPaths(trainPath=trainDirectory, testPath=testDirectory)
config.setFilePrefix("foliage_")
config.setModelName("foliage")
config.setInputSizes(inputWidth=250, inputHeight=250)
config.setInputCellSize(cellSizeM=0.25, minCellSizeM=0.1, maxCellSizeM=0.5)
config.setAutoLimitLabel(True)

logger.info("Version: " + str(config.version))

config.setModelInfo(channels=3, numClasses=8+1,  # (1 + background)
                    bboxOverlap=True, bboxPerImage=250, reuseModel=False)
config.setEpochs(0)

description = "Inference model to detect deciduous trees, pine trees, "\
    "heather, hedges, plants, reed, shrubbery, flowbeds. " \
    "Additionally regions of decidious trees without leaves can be detected."
config.setOnnxInfo(producer="Tygron", description=description)

config.addLegendEntry("Background", 0, "#00000000")
config.addLegendEntry("Deciduous Tree", 1, "#00ffbf")
config.addLegendEntry("Pine Tree", 2, "#12d900")
config.addLegendEntry("Heather", 3, "#f3a6b2")
config.addLegendEntry("Hedge", 4, "#8d5a99")
config.addLegendEntry("Shrubbery", 5, "#e80004")
config.addLegendEntry("Reed", 6, "#f8ff20")
config.addLegendEntry("Flowerbed", 7, "#b7484b")
config.addLegendEntry("Deciduous Tree (Leafless)", 8, "#e6994d")

config.setOnnxMetaData(scoreThreshold=0.2,
                       maskThreshold=0.3,
                       strideFraction=0.5)

config.setTensorInfo(tensorName='input_A:RGB_normalized', batchAmount=1)
trainingDataset = ImageDataset(config, True, createTransforms(True))
testDataset = ImageDataset(config, False, createTransforms(False))

logger.info("Train Image count: "+str(trainingDataset.__len__()))
logger.info("Test Image count: "+str(testDataset.__len__()))

trainingDataset.validate()
testDataset.validate()

logger.info("Pytorch model name " + config.getPytorchModelFileName())
logger.info("Onnx file name " + config.getOnnxFileName())

The datasets directory does not exist: /path/to/train_dataset
Please adjust the configured dataset directory in the configuration file!
The datasets directory does not exist: /path/to/test_dataset
Please adjust the configured dataset directory in the configuration file!
No files were found in the Training dataset
No files were found in the Test dataset


Device: cuda


In [17]:
imageNumber = 5
print(trainingDataset.getLabels(imageNumber))
drawImageAndFeatureMasks(config, trainingDataset, imageNumber)

IndexError: list index out of range

In [18]:
model = trainModel(config, trainingDataset, testDataset)
saveModel(config, model, path=config.getPytorchModelFileName())

No files were found in the Training dataset
Training dataset is invalid, please inspect the logs.


AttributeError: 'NoneType' object has no attribute 'state_dict'

In [None]:
model.eval()
testPrediction = testInference(config, model=model,
                               dataset=testDataset, imageNumber=88)

In [None]:
exportOnnxModel(config, model)

In [None]:
writeONNXMeta(config)

In [None]:
onnx_model = loadONNX(config)
print(f"metadata_props={onnx_model.metadata_props}")