In [None]:
from torchvision import transforms
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau

from ml_tools.ML_vision_datasetmaster import DragonDatasetVision
from ml_tools.ML_trainer import DragonTrainer
from ml_tools.ML_callbacks import DragonModelCheckpoint, DragonEarlyStopping, DragonLRScheduler
from ml_tools.ML_utilities import inspect_model_architecture
from ml_tools.ML_configuration import (BinaryImageClassificationMetricsFormat, 
                                       FinalizeBinaryImageClassification,  
                                       MultiClassImageClassificationMetricsFormat, 
                                       FinalizeMultiClassImageClassification, 
                                       DragonTrainingConfig)
from ml_tools.IO_tools import train_logger

from rootpaths import PM
from visual_ccc.gradcam import custom_alexnet, SIZE_REQUIREMENT
from visual_ccc.visualcnn_model import VisualCNN

In [None]:
# Experiment choices. Later cells branch on these with `== "2-class"` and
# `== "alexnet"` plus a catch-all `else`, so a typo would silently select the
# 3-class dataset and/or VisualCNN. Validate them up front instead.
CLASSES = "2-class"    # "2-class" or "3-class"
MODEL = "visualcnn"    # "alexnet" or "visualcnn"

_VALID_CLASSES = ("2-class", "3-class")
_VALID_MODELS = ("alexnet", "visualcnn")
if CLASSES not in _VALID_CLASSES:
    raise ValueError(f"classes must be one of {_VALID_CLASSES}, got {CLASSES!r}")
if MODEL not in _VALID_MODELS:
    raise ValueError(f"model must be one of {_VALID_MODELS}, got {MODEL!r}")

# `classes`, `model` and `finalized_file` are extra fields read back by later
# cells (they are accessed with `# type: ignore`, suggesting they are not
# declared attributes of DragonTrainingConfig).
train_config = DragonTrainingConfig(validation_size=0.2,
                                    test_size=0.1,
                                    initial_learning_rate=0.0002,
                                    batch_size=2,
                                    early_stop_patience=25,
                                    scheduler_patience=4,
                                    scheduler_lr_factor=0.5,
                                    classes=CLASSES,
                                    model=MODEL,
                                    # NOTE(review): alternative name previously noted here was
                                    # "AlloysDendritesSpheroids" (presumably for the 3-class run);
                                    # keep this name in sync with CLASSES / MODEL.
                                    finalized_file="V-DendritesSpheroids")

## Dataset: load, split and preprocess images (binary: Dendrites vs. Spheroids, or 3-class — chosen by the `classes` setting)

In [None]:
# Choose the dataset folder and the trainer task label from the class setup.
# Any value other than "2-class" falls through to the 3-class dataset.
is_binary = train_config.classes == "2-class"  # type: ignore
dataset_root = PM.two_dataset if is_binary else PM.three_dataset
vision_dataset = DragonDatasetVision.from_folder(dataset_root)
TASK = "binary image classification" if is_binary else "multiclass image classification"

# Train / validation / test split, seeded from the shared config.
vision_dataset.split_data(
    val_size=train_config.validation_size,
    test_size=train_config.test_size,
    random_state=train_config.random_state,
)

# Resize target is 20% larger than SIZE_REQUIREMENT; the crop is SIZE_REQUIREMENT.
# A 1-channel Grayscale transform is supplied via `pre_transforms`
# (presumably applied before resize/crop). mean/std are passed as None —
# NOTE(review): confirm what normalization the library applies in that case.
resize_to = int(1.2 * SIZE_REQUIREMENT)
vision_dataset.configure_transforms(
    resize_size=resize_to,
    crop_size=SIZE_REQUIREMENT,
    mean=None,
    std=None,
    pre_transforms=[transforms.Grayscale(num_output_channels=1)],
)

In [None]:
# Persist the label mapping (the returned `class_map` is reused by the
# finalizer) and the transform recipe — presumably so inference can rebuild
# the same preprocessing.
artifacts_dir = PM.artifacts
class_map = vision_dataset.save_class_map(save_dir=artifacts_dir)
vision_dataset.save_transform_recipe(filepath=PM.transform_recipe_file)

In [None]:
# --- Model ---------------------------------------------------------------
# "alexnet" selects the custom AlexNet; any other value selects VisualCNN.
model_factory = custom_alexnet if train_config.model == "alexnet" else VisualCNN  # type: ignore
model = model_factory(classes=train_config.classes)  # type: ignore

inspect_model_architecture(model=model, save_dir=PM.artifacts)

# --- Optimizer -----------------------------------------------------------
optimizer = AdamW(params=model.parameters(), lr=train_config.initial_learning_rate)

# --- Callbacks -----------------------------------------------------------
# All callbacks use mode="min": the monitored quantity (presumably the
# validation loss) is expected to decrease.
checkpoint_cb = DragonModelCheckpoint(save_dir=PM.checkpoints, mode="min")
early_stop_cb = DragonEarlyStopping(patience=train_config.early_stop_patience, mode="min")  # type: ignore
plateau_scheduler = ReduceLROnPlateau(optimizer=optimizer,
                                      mode="min",
                                      factor=train_config.scheduler_lr_factor,  # type: ignore
                                      patience=train_config.scheduler_patience)  # type: ignore
lr_scheduler_cb = DragonLRScheduler(scheduler=plateau_scheduler)

# --- Trainer -------------------------------------------------------------
trainer = DragonTrainer(model=model,
                        train_dataset=vision_dataset.train_dataset,
                        validation_dataset=vision_dataset.validation_dataset,
                        kind=TASK,
                        optimizer=optimizer,
                        criterion="auto",
                        device="cuda:0",
                        checkpoint_callback=checkpoint_cb,
                        early_stopping_callback=early_stop_cb,
                        lr_scheduler_callback=lr_scheduler_cb)

In [None]:
# Upper bound on training epochs; the early-stopping callback attached to the
# trainer may end training sooner.
MAX_EPOCHS = 100
history = trainer.fit(
    save_dir=PM.artifacts,
    epochs=MAX_EPOCHS,
    batch_size=train_config.batch_size,
)

In [None]:
# Record the run: config, task label and the history returned by fit().
run_parameters = {"Task": TASK}
train_logger(
    train_config=train_config,
    model_parameters=run_parameters,
    train_history=history,
    save_directory=PM.results,
)

In [None]:
# Pick the metrics-format class and colormaps for the class setup, then build
# the validation and test plot configurations from them.
if train_config.classes == "2-class":  # type: ignore
    metrics_format_cls = BinaryImageClassificationMetricsFormat
    val_cmap, test_cmap = "BuGn", "BuPu"
else:
    metrics_format_cls = MultiClassImageClassificationMetricsFormat
    val_cmap, test_cmap = "YlGn", "Oranges"

validation_configuration = metrics_format_cls(cmap=val_cmap, ROC_PR_line="darkorange")
test_configuration = metrics_format_cls(cmap=test_cmap, ROC_PR_line="forestgreen")

# Evaluate the "latest" checkpoint on validation and test data.
# The 0.5 threshold matches the one given to the binary finalizer below.
trainer.evaluate(
    save_dir=PM.metrics,
    model_checkpoint="latest",
    classification_threshold=0.5,
    test_data=vision_dataset.test_dataset,
    val_format_configuration=validation_configuration,
    test_format_configuration=test_configuration,
)

In [None]:
# Build the finalizer matching the class setup; only the binary variant takes
# a decision threshold (0.5, the same value used in evaluation).
if train_config.classes != "2-class":  # type: ignore
    finalizer = FinalizeMultiClassImageClassification(
        filename=train_config.finalized_file,  # type: ignore
        class_map=class_map,
    )
else:
    finalizer = FinalizeBinaryImageClassification(
        filename=train_config.finalized_file,  # type: ignore
        classification_threshold=0.5,
        class_map=class_map,
    )

# NOTE(review): evaluation uses model_checkpoint="latest" while this uses
# "current" — presumably the weights held by the trainer right now. If so, the
# saved model may depend on whether the evaluation cell ran first; confirm.
trainer.finalize_model_training(
    model_checkpoint="current",
    save_dir=PM.artifacts,
    finalize_config=finalizer,
)