# Training notebook

In this notebook, we will show how to train a UNet network to segment defects in hazelnut images from the MVTec Anomaly Detection [dataset](https://www.mvtec.com/company/research/datasets/mvtec-ad).


First, import all the necessary libraries:

In [None]:
import os
import tensorflow as tf

from omegaconf import OmegaConf

import namegenerator as namegen
from models.model_functions import create_model, load_model
from models.train_test.train import train_model
from models.saving import (load_params, save_losses, save_model, save_params)
from utils.utils import (create_metadata, save_metadata, create_color_map, create_category_dict)

from processing.preprocessing import preprocess_data_from_images
from sklearn.model_selection import train_test_split

All configuration parameters are defined in the `configs/env.yaml` file. They include the locations of the dataset and results folders, the network architecture, and the training parameters.
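The cells below read a fixed set of keys from this configuration (`train_model` receives the whole `cfg` and may need more). As a sketch, assuming the config is converted to a plain nested dict (for example with `OmegaConf.to_container(cfg)`), a small helper can report which of those keys are missing before the notebook runs. The helper name `missing_keys` and all example values are hypothetical.

```python
# Dotted key paths read directly in this notebook's cells.
REQUIRED_KEYS = [
    "DIRS.data", "DIRS.results", "DIRS.history",
    "DATA.img_dims", "DATA.test_split", "DATA.seed",
    "MODEL.categories", "MODEL.model_name",
    "TRAINING.filters", "TRAINING.shuffle",
]


def missing_keys(config, required=REQUIRED_KEYS):
    """Return the dotted paths from `required` that are absent in `config`.

    Hypothetical helper operating on plain nested dicts.
    """
    missing = []
    for path in required:
        node = config
        for part in path.split("."):
            if not isinstance(node, dict) or part not in node:
                missing.append(path)
                break
            node = node[part]
    return missing


# Hypothetical example values, for illustration only.
example_cfg = {
    "DIRS": {"data": "data/hazelnut", "results": "results", "history": "model_versioning"},
    "DATA": {"img_dims": "(256, 256, 3)", "test_split": 0.2, "seed": 42},
    "MODEL": {"categories": ["crack", "cut", "hole", "print"], "model_name": ""},
    "TRAINING": {"filters": 16, "shuffle": True},
}
print(missing_keys(example_cfg))  # → []
```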

In [None]:
# load the configuration parameters from env.yaml
cfg = OmegaConf.load('configs/env.yaml')

# used for inference only
categ_dict = create_category_dict(cfg.MODEL.categories)
color_map = create_color_map(cfg.MODEL.categories)

print('Class categories', categ_dict)

num_classes = len(cfg.MODEL.categories) + 1  # defect categories plus background
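`num_classes` is one more than the number of defect categories, which suggests an extra class for defect-free (background) pixels. A minimal sketch of such a label mapping, assuming class id 0 is reserved for background; `make_category_dict` is a hypothetical stand-in, not the project's `create_category_dict`, and the category list is only an example.

```python
def make_category_dict(categories):
    """Map integer class ids to names, reserving id 0 for background (assumption)."""
    mapping = {0: "background"}
    for idx, name in enumerate(categories, start=1):
        mapping[idx] = name
    return mapping


categories = ["crack", "cut", "hole", "print"]  # example defect types
categ = make_category_dict(categories)
num_classes = len(categories) + 1

print(categ)        # → {0: 'background', 1: 'crack', 2: 'cut', 3: 'hole', 4: 'print'}
print(num_classes)  # → 5
```

With this convention, every pixel of a segmentation mask holds one id in `range(num_classes)`.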

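We can either load an existing model and use it for transfer learning, or create a new one. If `MODEL.model_name` is empty, a new model with a randomly generated name is created; otherwise the model with that name is loaded.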

In [None]:
if cfg.MODEL.model_name == '':
    # no name configured: generate a random model name and build a new model
    model_name = namegen.gen(separator='_')

    model = create_model(eval(cfg.DATA.img_dims), num_classes, cfg.TRAINING.filters)
else:
    # reuse the configured name to load an existing model
    model_name = cfg.MODEL.model_name

    model = load_model(model_name)
    load_params(model_name)

print(f'Model name is {model_name}')

metadata = create_metadata(model_name)

model.summary()
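`create_model` receives `eval(cfg.DATA.img_dims)`, so `img_dims` is evidently stored in the YAML as a string, for example `'(256, 256, 3)'` (the exact value is an assumption). If you prefer not to call `eval` on configuration text, the standard library's `ast.literal_eval` performs the same conversion but only accepts Python literals. The helper name `parse_dims` is hypothetical.

```python
import ast


def parse_dims(text):
    """Parse an image-shape string such as "(256, 256, 3)" into a tuple of ints.

    ast.literal_eval only evaluates literals, so expressions such as
    function calls raise ValueError instead of being executed.
    """
    dims = ast.literal_eval(text)
    if not isinstance(dims, (tuple, list)) or not all(isinstance(d, int) for d in dims):
        raise ValueError(f"expected a tuple of ints, got {text!r}")
    return tuple(dims)


print(parse_dims("(256, 256, 3)"))  # → (256, 256, 3)

try:
    parse_dims("__import__('os').getcwd()")
except ValueError as err:
    print("rejected:", err)
```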

Here we load and preprocess the images and masks from the dataset, then split them into training and testing sets. Only the training subset is used to train the model; the test subset is held out.

In [None]:
# === DATASET LOADING AND PREPROCESSING === #
X, y = preprocess_data_from_images(data_path=cfg.DIRS.data,
                                   shape=eval(cfg.DATA.img_dims),
                                   categories=cfg.MODEL.categories)

# === TRAIN/TEST SPLIT === #
X_train, X_test, y_train, y_test = train_test_split(X, y,
                                                    test_size=cfg.DATA.test_split,
                                                    shuffle=cfg.TRAINING.shuffle,
                                                    random_state=cfg.DATA.seed)

print(f'Number of TRAIN images: {len(X_train)}')
print(f'Number of TEST images: {len(X_test)}')
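`train_test_split` holds out a `test_size` fraction of the samples: with a float `test_size`, scikit-learn rounds the test count up, and a fixed `random_state` makes the shuffle reproducible. The sketch below reproduces only the split sizes and the seeded-shuffle idea with the standard library. Because it uses Python's `random` module, the actual indices will differ from scikit-learn's. The function name `split_indices` is hypothetical.

```python
import math
import random


def split_indices(n_samples, test_size, shuffle=True, seed=None):
    """Return (train_idx, test_idx) for a float test_size in (0, 1).

    The test count is rounded up, as scikit-learn does for a float
    test_size; the order of shuffled indices will NOT match scikit-learn's.
    """
    n_test = math.ceil(test_size * n_samples)
    n_train = n_samples - n_test
    indices = list(range(n_samples))
    if shuffle:
        random.Random(seed).shuffle(indices)  # seeded RNG, hence reproducible
    return indices[:n_train], indices[n_train:]


train_idx, test_idx = split_indices(10, test_size=0.25, seed=0)
print(len(train_idx), len(test_idx))  # → 7 3
```

With `shuffle=False` the split is simply the first `n_train` samples versus the rest, so `TRAINING.shuffle` matters if the images are loaded in a grouped order (for example, by defect type).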

Now it's time to train our model on the training subset.

In [None]:
history = train_model(model, model_name, Xs=X_train, Ys=y_train, cfg=cfg)

After training, we save the model, its metadata and configuration parameters, and plots of the loss function and metric evolution.

In [None]:
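# === Saving model information === #
PATH_RESULTS = os.path.join(cfg.DIRS.results, model_name)
PATH_LOG = os.path.join(cfg.DIRS.history, model_name)

# make sure both output folders exist before writing to them
os.makedirs(PATH_RESULTS, exist_ok=True)
os.makedirs(PATH_LOG, exist_ok=True)

save_losses(history, PATH_RESULTS)  # loss and metric evolution plots
save_model(model, PATH_LOG)         # trained model
save_metadata(metadata, PATH_LOG)   # model metadata
save_params(PATH_LOG, cfg)          # configuration used for this run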
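The cell above writes into two per-model folders: `<DIRS.results>/<model_name>` for the plots and `<DIRS.history>/<model_name>` for the model, metadata and parameters. The internals of `save_metadata` are not shown in this notebook; as a hypothetical illustration of the same pattern (create the folder, then write), a JSON-based metadata writer could look like this. The helper name, the file name `metadata.json`, and the metadata fields are assumptions.

```python
import json
import os
import tempfile


def write_metadata_json(metadata, log_dir):
    """Hypothetical helper: save a metadata dict as JSON inside log_dir."""
    os.makedirs(log_dir, exist_ok=True)  # create the per-model folder if needed
    path = os.path.join(log_dir, "metadata.json")
    with open(path, "w", encoding="utf-8") as fh:
        json.dump(metadata, fh, indent=2)
    return path


# Example run in a temporary directory; folder and field names are made up.
with tempfile.TemporaryDirectory() as root:
    log_dir = os.path.join(root, "model_versioning", "example_model")
    path = write_metadata_json({"model_name": "example_model"}, log_dir)
    with open(path, encoding="utf-8") as fh:
        print(json.load(fh))  # → {'model_name': 'example_model'}
```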

The results are now in the folders specified in `env.yaml`. By default, the results are saved in `results` and the model checkpoints in `model_versioning`.

We show how to perform inference (make predictions) in a separate notebook.