In [None]:
import torch
from torchvision import transforms
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

from ml_tools.ML_vision_datasetmaster import DragonDatasetVision
from ml_tools.ML_trainer import DragonTrainer
from ml_tools.ML_callbacks import DragonModelCheckpoint, DragonEarlyStopping, DragonLRScheduler
from ml_tools.ML_utilities import inspect_model_architecture
from ml_tools.ML_configuration import BinaryImageClassificationMetricsFormat, FinalizeBinaryImageClassification
from ml_tools import custom_logger

from rootpaths import PM
from visual_ccc.gradcam import custom_alexnet, SIZE_REQUIREMENT

## Binary classification: Dendrites vs. Spheroids

In [None]:
VAL_SIZE = 0.2
TEST_SIZE = 0.1
RANDOM_STATE = 101
# The resize target is 20% larger than the crop size (SIZE_REQUIREMENT).
RESIZE_FACTOR = 1.2

vision_dataset = DragonDatasetVision.from_folder(PM.dataset)

# Partition into train / validation / test; RANDOM_STATE makes the split repeatable.
vision_dataset.split_data(
    val_size=VAL_SIZE,
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
)

# Grayscale (single-channel) input. mean/std are left as None —
# NOTE(review): presumably this means the library's default normalization; confirm.
vision_dataset.configure_transforms(
    resize_size=int(RESIZE_FACTOR * SIZE_REQUIREMENT),
    crop_size=SIZE_REQUIREMENT,
    mean=None,
    std=None,
    pre_transforms=[transforms.Grayscale(num_output_channels=1)],
)

train_dataset, validation_dataset, test_dataset = vision_dataset.get_datasets()

In [None]:
# Persist the label mapping next to the other artifacts; the returned `class_map`
# is passed to the finalizer at the end of the notebook.
class_map = vision_dataset.save_class_map(save_dir=PM.artifacts)

# Persist the preprocessing recipe configured above so inference can reproduce it.
recipe_path = PM.transform_recipe_file
vision_dataset.save_transform_recipe(filepath=recipe_path)

In [None]:
INITIAL_LR = 0.0002
SCHEDULER_PATIENCE = 2
SCHEDULER_FACTOR = 0.7  # multiplicative LR decay applied by ReduceLROnPlateau on a plateau
STOP_PATIENCE = 12
# Fall back to CPU so the notebook still runs on a machine without a CUDA GPU.
DEVICE = "cuda:0" if torch.cuda.is_available() else "cpu"

# RANDOM_STATE is otherwise only used for the train/val/test split; seeding torch's
# global RNG here also makes the model's weight initialization reproducible (and,
# presumably, any later draws from the global generator such as batch shuffling).
torch.manual_seed(RANDOM_STATE)

# Model
model = custom_alexnet()

inspect_model_architecture(model=model, save_dir=PM.artifacts)

# Optimizer
optimizer = AdamW(params=model.parameters(), lr=INITIAL_LR)

# LR scheduler: multiply the LR by SCHEDULER_FACTOR when the monitored value stops
# decreasing for SCHEDULER_PATIENCE epochs. Which value is monitored is decided by
# DragonLRScheduler (presumably validation loss — confirm).
lr_scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                 mode="min",
                                 factor=SCHEDULER_FACTOR,
                                 patience=SCHEDULER_PATIENCE)

# Trainer
# NOTE(review): mode="min" on the checkpoint / early-stopping callbacks presumably
# refers to validation loss; confirm against the ml_tools callback defaults.
trainer = DragonTrainer(model=model,
                        train_dataset=train_dataset,
                        validation_dataset=validation_dataset,
                        kind="binary image classification",
                        optimizer=optimizer,
                        criterion="auto",
                        device=DEVICE,
                        checkpoint_callback=DragonModelCheckpoint(save_dir=PM.checkpoints, mode="min"),
                        early_stopping_callback=DragonEarlyStopping(patience=STOP_PATIENCE, mode="min"),
                        lr_scheduler_callback=DragonLRScheduler(scheduler=lr_scheduler)
                        )

In [None]:
BATCH_SIZE = 2
# Upper bound on training epochs; the early-stopping callback given to the trainer
# may end training sooner.
EPOCHS = 100

history = trainer.fit(save_dir=PM.artifacts, epochs=EPOCHS, batch_size=BATCH_SIZE)

In [None]:
# Record the data-split settings, the optimisation settings and the training history
# together, so each run leaves a self-describing JSON log in PM.results.
split_settings = {
    "validation size": VAL_SIZE,
    "test size": TEST_SIZE,
    "images per dataset": vision_dataset.images_per_dataset(),
    "random state": RANDOM_STATE,
}
optim_settings = {
    "initial lr": INITIAL_LR,
    "scheduler patience": SCHEDULER_PATIENCE,
    "stop patience": STOP_PATIENCE,
    "batch size": BATCH_SIZE,
}
# Dict unpacking preserves insertion order, so the JSON key order is unchanged.
train_log = {**split_settings, **optim_settings, "history": history}

custom_logger(
    data=train_log,
    save_directory=PM.results,
    log_name="train_log",
    dict_as="json",
)

In [None]:
trainer.evaluate(
    save_dir=PM.metrics, 
    # model_checkpoint=loaded_best_path,
    model_checkpoint="latest",
    classification_threshold=0.442814,
    test_data=test_dataset,
    val_format_configuration=BinaryImageClassificationMetricsFormat(cmap='BuGn',
                                                                    ROC_PR_line="darkorange"),
    test_format_configuration=BinaryImageClassificationMetricsFormat(cmap='BuPu',
                                                                     ROC_PR_line="forestgreen")
)

In [None]:
finalizer = FinalizeBinaryImageClassification(filename="DendritesSpheroids",
                                              classification_threshold=0.442814,
                                              class_map=class_map)

trainer.finalize_model_training(model_checkpoint="current",
                                save_dir=PM.artifacts,
                                finalize_config=finalizer)