<img align="left" src="https://panoptes-uploads.zooniverse.org/project_avatar/86c23ca7-bbaa-4e84-8d8a-876819551431.png" type="image/png" height=100 width=100>
</img>
<h1 align="right">KSO Tutorials #5: Train ML models</h1>
<h3 align="right">Written by the KSO Team</h3>

# 1. Set up and requirements

### Install and import Python packages

In [None]:
from IPython.display import clear_output

try:
    import google.colab
    import os

    IN_COLAB = True
    print("Running in Colab...")

    # Clone repo
    !git clone --recurse-submodules https://github.com/ocean-data-factory-sweden/koster_yolov4.git
    !pip install -q --upgrade pip
    !pip install -qr koster_yolov4/requirements.txt
    !pip install -qr koster_yolov4/yolov5_tracker/requirements.txt

    # Fix libmagic issue
    !apt-get -qq update && apt-get -qq install -y libmagic-dev > /dev/null

    # Replace upsampling script with custom version
    os.chdir("koster_yolov4/tutorials")
    !mv ../src/upsampling.py /usr/local/lib/python3.7/dist-packages/torch/nn/modules/upsampling.py

    # Replace nearest neighbours script with custom version (due to relative path issue)
    !cp ../src/multi_tracker_zoo.py ../yolov5_tracker/trackers/strong_sort/multi_tracker_zoo.py

    # Enable external widgets
    from google.colab import output

    output.enable_custom_widget_manager()

    # Ensure widgets are shown properly
    !jupyter nbextension enable --user --py widgetsnbextension
    !jupyter nbextension enable --user --py jupyter_bbox_widget

    print("All packages are installed and ready to go!")
    try:
        clear_output()
        print("All packages are installed and ready to go!")
    except:
        clear_output()
        print("There have been some issues installing the packages!")
except:
    IN_COLAB = False
    import sys
    import pkgutil

    if pkgutil.find_loader("torch") is None:
        !pip install -q --upgrade pip
        !pip install -q torch==1.8.0 torchvision==0.9.0

    # Replace nearest neighbours script with custom version (due to relative path issue)
    !cp ../src/multi_tracker_zoo.py ../yolov5_tracker/trackers/strong_sort/multi_tracker_zoo.py
    # Ensure widgets are shown properly
    !jupyter nbextension enable --user --py widgetsnbextension
    !jupyter nbextension enable --user --py jupyter_bbox_widget
    clear_output()
    print("Running locally... you're good to go!")

In [None]:
# Set the directory of the libraries
import sys, os

sys.path.append("..")

# Enables testing changes in utils
%load_ext autoreload
%autoreload 2

# Import required modules
from pathlib import Path
from ipyfilechooser import FileChooser
import kso_utils.tutorials_utils as t_utils
import kso_utils.project_utils as p_utils
import kso_utils.server_utils as s_utils
import kso_utils.t5_utils as t5
import wandb

clear_output()
print("Packages loaded successfully")

# Select the model type for training
model_type = t5.choose_model_type()

In [None]:
# Import the train / predict / validate entry points that match the chosen model type.
# Later cells use them through the aliases `train`, `detect` and `val`.
if model_type.value == 1:
    import yolov5.train as train
    import yolov5.detect as detect
    import yolov5.val as val

    print("Object detection model loaded")
elif model_type.value == 2:
    import yolov5.classify.train as train
    import yolov5.classify.predict as detect
    import yolov5.classify.val as val

    print("Image classification model loaded")
elif model_type.value == 3:
    import yolov5.segment.train as train
    import yolov5.segment.predict as detect
    import yolov5.segment.val as val

    print("Image segmentation model loaded")
else:
    # Fail fast: without a valid choice, `train`/`detect`/`val` are never bound and
    # every later cell would fail with a confusing NameError.
    raise ValueError(f"Invalid model specification: {model_type.value!r}")

# 2. Train the model

🔴 <span style="color:red">&nbsp;NOTE: To train your own models, you need access to the Koster WANDB group. You can request access by contacting jurie.germishuys@combine.se. </span>

### Choose your project

In [None]:
project_name = t_utils.choose_project()  # project-selection widget; read via `.value` below

In [None]:
project = p_utils.find_project(project_name=project_name.value)  # look up the chosen project's metadata

In [None]:
# Download the prepared ML data (intended for the Template Project).
# NOTE(review): the call is not guarded by project, so it runs for every project —
# confirm get_ml_data is a no-op for projects other than the Template Project.
s_utils.get_ml_data(project)

### Configure data paths

In [None]:
# Folder picker for the directory holding the images and labels folders. It starts in
# the project's photo folder, or the working directory when that is the string "None".
output_folder = t_utils.choose_folder(
    "." if project.photo_folder == "None" else project.photo_folder,
    "output",
)

🔴 <span style="color:red">&nbsp;NOTE: Each model type requires a specific folder structure. To train your own object detection models, your data folder must contain YAML files describing the dataset and the hyperparameters (see https://github.com/ultralytics/yolov5/wiki/Train-Custom-Data#11-create-datasetyaml). For image classification models, it must contain three folders (train, val, test), each holding one sub-folder of images per class, named after the class. For segmentation models, the labels must also include polygon coordinates. </span>

In [None]:
# Resolve the data / hyperparameter config paths for the chosen model type, and the
# per-project directory (lower-cased project name) inside the chosen output folder.
data_path, hyps_path = t5.setup_paths(output_folder.selected, model_type.value)
project_path = str(Path(output_folder.selected) / project.Project_name.lower())

### Choose a suitable experiment name

In [None]:
exp_name = t5.choose_experiment_name()  # widget for the run name passed to train.run(name=...)

### Choose model to use for training

In [None]:
# Folder picker for where the baseline model is downloaded. It starts in the project's
# photo folder, or the working directory when that is the string "None".
download_folder = t_utils.choose_folder(
    "." if project.photo_folder == "None" else project.photo_folder,
    "model download",
)

In [None]:
weights = t5.choose_baseline_model(download_folder.value)  # baseline model; `.artifact_path` is used for training

### Train model with given configuration

In [None]:
batch_size, epochs = t5.choose_train_params()  # widgets for batch size and number of epochs

In [None]:
# Train the selected model type. The run is logged under the "koster" entity, in a
# project named after the project folder (spaces -> underscores, lower-cased).
run_project = os.path.basename(project_path).replace(" ", "_").lower()

if model_type.value == 1:
    # NOTE(review): upstream yolov5 (v6+) `train.run` reads `imgsz` and `cache`;
    # `img_size` / `cache_images` are older option names. Confirm the bundled yolov5
    # still accepts them, otherwise they are silently ignored.
    train.run(
        entity="koster",
        data=data_path,
        hyp=hyps_path,
        weights=weights.artifact_path,
        project=run_project,
        name=exp_name.value,
        img_size=[720, 540],
        batch_size=int(batch_size.value),
        # Cast like batch_size: the training loop needs an integer epoch count.
        epochs=int(epochs.value),
        workers=1,
        single_cls=False,
        cache_images=True,
    )
elif model_type.value == 2:
    train.run(
        entity="koster",
        data=data_path,
        model=weights.artifact_path,
        project=run_project,
        name=exp_name.value,
        img_size=224,
        batch_size=int(batch_size.value),
        epochs=int(epochs.value),
        workers=1,
    )
else:
    print("Segmentation model training not yet supported.")

# 3. Evaluate model performance

In [None]:
conf_thres = t5.choose_eval_params()  # confidence-threshold widget used for evaluation

In [None]:
# File picker rooted at the project directory: select the training run to evaluate.
eval_model = FileChooser(project_path)
display(eval_model)

In [None]:
# Path to the best checkpoint (weights/best.pt) of the training run picked above.
if eval_model.selected is None:
    # Without this check, Path(...) fails with an unhelpful TypeError on None.
    raise ValueError("Select a trained model run in the file chooser above first.")
# NOTE(review): FileChooser.selected appears to be a full path, in which case it
# overrides project_path in the join — confirm.
tuned_weights = str(Path(project_path, eval_model.selected, "weights", "best.pt"))

In [None]:
# Evaluate the trained YOLO model on unseen test data.
val_kwargs = dict(
    data=data_path,
    weights=tuned_weights,
    # Classifiers are trained at 224 px (see the training cell); detection and
    # segmentation models use 640 px.
    imgsz=224 if model_type.value == 2 else 640,
    half=False,
)
# Confidence thresholding only applies to detection/segmentation. Upstream yolov5
# `classify/val.run` has no `conf_thres` parameter and would raise a TypeError.
if model_type.value != 2:
    val_kwargs["conf_thres"] = conf_thres.value
val.run(**val_kwargs)

# 4. (Optional) Enhance annotations using the trained model

Enhancement uses the trained model to increase the number of annotations in the training data. Only do this when it is absolutely necessary: inaccurate predictions added as labels make the next iteration of the model, trained on them, perform worse.


🔴 <span style="color:red">&nbsp;NOTE: We recommend using a relatively high confidence threshold when enhancing annotations with a trained model, as low-confidence predictions could significantly reduce the quality of your annotated data. This is currently only available for object detection models.  </span>

In [None]:
eh_conf_thres = t5.choose_eval_params()  # confidence-threshold widget for annotation enhancement

In [None]:
# Run the trained detector over the training images. With save_txt=True, yolov5 writes
# the predictions as .txt label files inside the detect run folder; that folder is
# picked in the next step.
if model_type.value == 1:
    if output_folder.selected is None:
        raise ValueError("Select the output folder (Configure data paths) first.")
    detect.run(
        weights=tuned_weights,
        # Join with pathlib instead of string concatenation
        source=str(Path(output_folder.selected, "images")),
        imgsz=[640, 640],
        conf_thres=eh_conf_thres.value,
        save_txt=True,
    )
elif model_type.value == 2:
    print("Enhancements not supported for image classification models at this time.")
else:
    print("Enhancements not supported for segmentation models at this time.")

### Choose run to use as enhanced annotations

In [None]:
# Folder picker starting in the working directory: select the detect run whose labels to use.
runs = FileChooser(".")
display(runs)

In [None]:
if model_type.value == 1:
    !mv {output_folder}"/labels" {output_folder}"/labels_org"
    !mv {runs.selected}"/labels" {output_folder}"/labels"

#### Once you have moved the new labels to the original label location, you can return to Step 2 and train your model again. 

🔴 <span style="color:red">&nbsp;NOTE: Run this cell to finish the WANDB run; otherwise the artifacts will not be shown.</span>

In [None]:
wandb.finish()  # close the active W&B run so its artifacts are uploaded and shown

In [None]:
# END