In [None]:
import os
import sys

# Put the directory one level above the notebook's working directory on the
# import path. This is presumably where the project packages imported in the
# next cell (`tools`, `treeseg`) live — TODO confirm against the repo layout.
current_dir = os.getcwd()
parent_dir = os.path.abspath(os.path.join(current_dir, os.pardir))

# Only append once, so re-running this cell does not grow sys.path.
if not parent_dir in sys.path:
    sys.path.append(parent_dir)

In [None]:
from tools.config import setup

from treeseg.data.dataset import prepare_datasets
from treeseg.data.imagepaths import get_image_label_paths

from treeseg.modeling.trainer import resume_or_load, trainer
from treeseg.evaluation.evaluator import evaluator

In [None]:
def run(conf, eval_only):
    """Build the dataset splits and model, then either evaluate or train.

    Args:
        conf: Configuration object (in this notebook, the object returned by
            ``setup``). This function reads ``conf.data_folder`` and
            ``conf.test_batch_size`` directly, and passes ``conf`` on to
            ``prepare_datasets``, ``resume_or_load`` and ``trainer``. Those
            helpers may read further fields.
        eval_only: If truthy, run ``evaluator`` on the test split.
            Otherwise run ``trainer`` on the train/validation splits.

    Returns:
        The model object produced by ``resume_or_load(conf)``, after it has
        been handed to ``evaluator`` or ``trainer``.

    Raises:
        FileNotFoundError: If ``conf.data_folder`` does not exist.
    """
    data_folder = conf.data_folder
    # Explicit check rather than ``assert``: assert statements are removed
    # when Python runs with -O, which would silently skip this validation.
    if not os.path.exists(data_folder):
        raise FileNotFoundError(f"Data folder {data_folder} does not exist.")

    train_images, train_labels, test_images, test_labels = get_image_label_paths(
        data_folder
    )

    # prepare_datasets returns three splits from the train/test path lists.
    # The validation split is presumably carved out of the training data
    # (TODO confirm in treeseg.data.dataset).
    train_dataset, val_dataset, test_dataset = prepare_datasets(
        train_images, train_labels, test_images, test_labels, conf
    )

    model = resume_or_load(conf)

    if eval_only:
        print("Evaluation only mode")
        # The number of test images is passed alongside the dataset; it is
        # presumably used for step/progress bookkeeping (TODO confirm).
        evaluator(model, test_dataset, len(test_images), conf.test_batch_size)
    else:
        print("Training mode")
        trainer(model, train_dataset, val_dataset, len(train_images), conf)

    return model


if __name__ == "__main__":
    # Inside a Jupyter/IPython kernel ``__name__`` is "__main__", so this
    # cell runs when it is executed.
    config_file_path = "../configs/kokonet_bs8_cs256.txt"

    conf = setup(config_file_path)

    # Modified Config Variables for Local Execution.
    # The data folder can be supplied through the TREESEG_DATA_FOLDER
    # environment variable instead of editing this cell. An unset or empty
    # variable falls back to the original hard-coded path.
    conf.data_folder = (
        os.environ.get("TREESEG_DATA_FOLDER")
        or "/Users/anis/Documents/AerialImageModel_ITD"
    )
    # The notebook runs one directory below the config-relative paths
    # (presumably the repo root — TODO confirm). os.path.join leaves an
    # absolute output_dir unchanged.
    conf.output_dir = os.path.join("..", conf.output_dir)

    eval_only = True

    # Keep the returned model so later cells can inspect it.
    model = run(conf, eval_only)