# Setting up and training a SimpleVGG16 model

This notebook assumes that your images are neatly segregated into subdirectories, one per class, for example:
```
├── training
|   ├── cat
|   |   ├──cat0001.jpg
.   .   .
.   .   .
|   |   └──cat0999.jpg
|   ├── drop_bear
|   |   ├──dropbear01.jpg
.   .   .
.   .   .
|   |   └──dr0p-b34r.jpg
.   .
.   .
|   └── zebra
|       ├──zebra.jpg
.       .
.       .
|       └──z12.jpg
└── validation
    ├── cat
    |   ├──cat1000.jpg
    .   .
    .   .
    |   └──cat1200.jpg
    ├── drop_bear
    |   ├──koala1.jpg
    .   .
    .   .
    |   └──kbear.jpg
    .
    .
    └── zebra
        ├──z1.jpg
        .
        .
        └──striped_horse.jpg
```

In [1]:
import json
import os

### Manage configuration file inside the notebook
def backup_config():
    """Move ``config.json`` to the first unused ``config.json.bak<N>`` name.

    Candidate suffixes are tried in order starting from 1. The file is renamed,
    not copied, so ``config.json`` no longer exists once this returns.
    """
    suffix = 1
    backup_path = "config.json.bak%d" % suffix
    # Advance the numeric suffix until a name that is not taken yet is found.
    while os.path.exists(backup_path):
        suffix += 1
        backup_path = "config.json.bak%d" % suffix
    os.rename("config.json", backup_path)

def write_config(updated_config: dict):
    """Write ``updated_config`` to ``config.json`` as JSON indented by 4 spaces.

    Any existing ``config.json`` in the current working directory is
    overwritten.

    Args:
        updated_config: The settings dictionary to serialize. It must be
            JSON-serializable, otherwise ``json.dump`` raises ``TypeError``.
    """
    with open("config.json", "w") as cf:
        # Serialize the argument itself; the previous version dumped the global
        # CONFIG and silently ignored whatever the caller passed in.
        json.dump(updated_config, cf, indent=4)

# Notebook-wide settings dictionary; starts empty and is replaced by reload_config().
CONFIG = {}
def reload_config():
    """Rebind the global ``CONFIG`` to the parsed contents of ``config.json``.

    The file is read from the current working directory. If it cannot be
    opened or parsed, the exception propagates and ``CONFIG`` is left as is.
    """
    global CONFIG
    with open("config.json", "r") as config_file:
        raw_text = config_file.read()
    CONFIG = json.loads(raw_text)
        
### Load constants from configuration file
# Fills the global CONFIG from config.json in the current working directory;
# later cells read their settings (directories, classes, batch size, ...) from it.
reload_config()

## Removing invalid images
The cell below finds damaged files and non-RGB files; the cell after it removes them from disk.

In [2]:
from PIL import Image
from itertools import chain
import os
import json

from os import path

# Directories to scan, taken from the configuration loaded into CONFIG.
training_dir = CONFIG["training_dir"]
# The original read CONFIG["training_dir"] here as well, so the validation set
# was never scanned.
# NOTE(review): assumes the key is named "validation_dir"; if it is absent this
# falls back to training_dir (the old behaviour) - confirm against config.json.
validation_dir = CONFIG.get("validation_dir", training_dir)

# Paths of files that cannot be opened as images or are not in RGB mode.
# The next cell deletes everything collected here.
to_remove = []

# dict.fromkeys drops a repeated directory while keeping order, so no tree is
# walked twice and no file is reported twice.
scan_roots = dict.fromkeys((training_dir, validation_dir))

for dirpath, _, filenames in chain.from_iterable(os.walk(root) for root in scan_roots):
    for filename in filenames:
        full_path = path.join(dirpath, filename)
        try:
            # The context manager closes the underlying file handle again.
            with Image.open(full_path) as img:
                mode = img.mode
        except IOError as e:
            print("Bad file:", str(e))
            to_remove.append(full_path)
        else:
            # Only inspect the mode when the file opened successfully. The
            # original did this in `finally`, which hit an unbound or stale
            # `img` after a failed open and could list the same path twice.
            if mode != "RGB":
                print("Image mode is not RGB", full_path)
                to_remove.append(full_path)

if not to_remove:
    print("No bad images found, nice!")

No bad images found, nice!


In [3]:
### The files displayed above will be removed from the disk ###
removed_count = 0
# dict.fromkeys drops duplicate paths (order preserved) so no file is deleted twice.
for file_path in dict.fromkeys(to_remove):
    try:
        os.remove(file_path)
    except FileNotFoundError:
        # Already gone, e.g. because this cell was run before; keep going
        # instead of aborting halfway through the list.
        print("Already removed:", file_path)
        continue
    removed_count += 1
# Report the number of files actually deleted, not the length of the list.
print(f"Removed {removed_count} bad images")

Removed 0 bad images


## Adding classes to config file
All subdirectories of `training_dir` will be stored as class names.

In [4]:
print("Classes in config:", CONFIG["classes"])

# The first tuple yielded by os.walk describes training_dir itself; its second
# element lists the immediate subdirectories, which become the class names.
_, dirs, _ = next(os.walk(CONFIG["training_dir"]))
print("Classes after update:", dirs)

# Only back up and rewrite config.json when the class list actually changed,
# so re-running this cell does not pile up identical config.json.bak<N> files.
if CONFIG["classes"] != dirs:
    CONFIG["classes"] = dirs
    backup_config()
    write_config(CONFIG)

Classes in config: ['slav', 'wagon']
Classes after update: ['slav', 'wagon']


## Review important settings
If the notebook crashes due to lack of memory, halve `batch_size` and restart the notebook.

To change any of these values, update them in `config.json` and then call `reload_config()` in this notebook.

In [5]:
# (printed label, CONFIG key) pairs, shown in this order.
settings_to_review = (
    ("train_epochs:", "train_epochs"),  # How long should the model be trained for
    ("batch_size", "batch_size"),       # decrease if you run out of memory
    ("model_name", "model_name"),       # check the model directory for more models
    ("input_width", "input_width"),     # Downscale training images to this size
    ("input_height", "input_height"),   # Should be the same as input_width
    ("patch_size", "patch_size"),       # The size of the generated patch (cannot be changed later!)
    ("target_class", "target_class"),   # The class to generate patch for (can be changed later)
)
for label, key in settings_to_review:
    print(label, CONFIG[key])

train_epochs: 40
batch_size 1
model_name SimpleVGG16
input_width 224
input_height 224
patch_size [224, 224]
target_class slav


## Training the model
For better performance use `python train.py`.

In [None]:
# `train` is imported from the project's own train module (train.py).
from train import train

# Start training with the output name "network_weights.h5".
# NOTE(review): presumably the trained weights are saved to that file in the
# working directory - confirm against train.py.
train(weights_output="network_weights.h5")