In [None]:
from ml_tools.ML_datasetmaster import DragonDataset
from ml_tools.ML_models import DragonGateModel
from ml_tools.ML_configuration import (
    FormatRegressionMetrics,
    FinalizeRegression, 
    DragonGateParams,
    DragonTrainingConfig
)

from ml_tools.ML_trainer import DragonTrainer
from ml_tools.ML_callbacks import DragonModelCheckpoint, DragonPatienceEarlyStopping, DragonPlateauScheduler
from ml_tools.ML_utilities import build_optimizer_params
from ml_tools.ML_utilities import inspect_model_architecture
from ml_tools.utilities import load_dataframe_with_schema
from ml_tools.IO_tools import train_logger
from ml_tools.schema import FeatureSchema
from ml_tools.keys import TaskKeys

import torch
from torch.optim import AdamW

from paths import PM

## 1. Config

In [None]:
# Training configuration shared by every cell below: split sizes, optimiser
# settings, callback thresholds and the finalized-model file name.
# Keyword order is intentionally left as-is (this object is passed to
# train_logger in section 8).
train_config = DragonTrainingConfig(
    # Held-out split sizes, passed to DragonDataset in section 3.
    validation_size=0.2,
    test_size=0.1,
    # Optimiser / data-loader settings.
    initial_learning_rate=0.001,
    batch_size=64,
    task=TaskKeys.REGRESSION,
    # NOTE(review): hard-coded CUDA device — confirm whether DragonTrainer
    # falls back to CPU on machines without a GPU.
    device="cuda",
    # File name used by finalize_model_training in section 9.
    finalized_filename="gate_step2.pth",
    # Seeds the train/validation/test split in section 3.
    random_state=101,

    # Early-stopping and LR-plateau callback settings, plus AdamW weight decay.
    early_stop_patience=20,
    scheduler_patience=3,
    scheduler_lr_factor=0.5,
    weight_decay=0.01,
)

## 2. Load Schema and Dataframe

In [None]:
# Load the feature schema saved alongside the step-2 engineering artifacts,
# then read the step-2 data file against that schema. The second value
# returned by load_dataframe_with_schema is not needed in this notebook.
schema = FeatureSchema.from_json(PM.engineering_artifacts_2)

df, _ = load_dataframe_with_schema(
    df_path=PM.step2_data_file,
    schema=schema,
)

## 3. Make Datasets

In [None]:
# Split the dataframe into train / validation / test datasets using the
# sizes and random_state from train_config.
# feature_scaler / target_scaler = "fit" presumably fits fresh scalers here
# (assumed on the training split only — confirm in DragonDataset).
dataset = DragonDataset(
    pandas_df=df,
    schema=schema,
    kind=train_config.task,
    feature_scaler="fit",
    target_scaler="fit",
    validation_size=train_config.validation_size,
    test_size=train_config.test_size,
    random_state=train_config.random_state,
)

## 4. Model and Trainer

In [None]:
# GATE hyper-parameters. This object is also passed to train_logger in
# section 8, so keyword order is left unchanged.
model_params = DragonGateParams(
    schema=schema,
    out_targets=dataset.number_of_targets,
    embedding_dim=16,
    gflu_stages=6,
    gflu_dropout=0.1,
    num_trees=20,
    tree_depth=4,
    tree_dropout=0.1,
    chain_trees=False,
    tree_wise_attention=True,
    tree_wise_attention_dropout=0.1,
    binning_activation="entmoid",
    feature_mask_function="entmax",
    share_head_weights=True,
    batch_norm_continuous=True,
)

# Seed torch's global RNG immediately before building the model so that weight
# initialisation (and any later torch RNG draws) is reproducible, and so that
# re-running this cell yields the same starting model. The dataset split is
# already seeded through train_config.random_state.
# NOTE(review): if ml_tools also draws from numpy / `random`, seed those too.
torch.manual_seed(train_config.random_state)

model = DragonGateModel(**model_params)

# Initialize decision thresholds from training data before training.
model.data_aware_initialization(train_dataset=dataset.train_dataset, num_samples=1000)

# Optimiser: parameter groups come from build_optimizer_params (presumably
# excluding some parameters from weight decay — confirm), optimised with AdamW.
optim_params = build_optimizer_params(model=model, weight_decay=train_config.weight_decay)
optimizer = AdamW(params=optim_params, lr=train_config.initial_learning_rate)

# Checkpointing, early stopping and the LR scheduler all track the same metric.
MONITOR_METRIC = "Validation Loss"

trainer = DragonTrainer(
    model=model,
    train_dataset=dataset.train_dataset,
    validation_dataset=dataset.validation_dataset,
    kind=train_config.task,
    optimizer=optimizer,
    device=train_config.device,
    checkpoint_callback=DragonModelCheckpoint(
        save_dir=PM.train_checkpoints_2,
        monitor=MONITOR_METRIC,
    ),
    early_stopping_callback=DragonPatienceEarlyStopping(
        patience=train_config.early_stop_patience,
        monitor=MONITOR_METRIC,
    ),
    lr_scheduler_callback=DragonPlateauScheduler(
        monitor=MONITOR_METRIC,
        patience=train_config.scheduler_patience,
        factor=train_config.scheduler_lr_factor,
    ),
)

## 5. Training

In [None]:
# Upper bound on training epochs; the early-stopping callback configured on the
# trainer may end training sooner.
MAX_EPOCHS = 500

history = trainer.fit(
    save_dir=PM.train_artifacts_2,
    epochs=MAX_EPOCHS,
    batch_size=train_config.batch_size,
)

## 6. Evaluation

In [None]:
# Evaluate the best checkpoint, including the held-out test dataset.
# Validation and test results each get their own plot formatting.
validation_format = FormatRegressionMetrics(scatter_color='mediumspringgreen')
test_format = FormatRegressionMetrics(scatter_color='darkmagenta')

trainer.evaluate(
    save_dir=PM.train_evaluation_2,
    model_checkpoint="best",
    test_data=dataset.test_dataset,
    val_format_configuration=validation_format,
    test_format_configuration=test_format,
)

## 7. Explanation

In [None]:
# Captum feature attributions, written next to the evaluation outputs.
# NOTE(review): n_steps is presumably the integration-step count of an
# Integrated-Gradients-style method — confirm in DragonTrainer.explain_captum.
CAPTUM_SAMPLES = 1000
CAPTUM_STEPS = 500

trainer.explain_captum(
    save_dir=PM.train_evaluation_2,
    n_samples=CAPTUM_SAMPLES,
    n_steps=CAPTUM_STEPS,
)

## 8. Save Artifacts

In [None]:
# Every artifact except the training log goes to the same directory.
artifacts_dir = PM.train_artifacts_2

# Dataset artifacts
dataset.save_artifacts(artifacts_dir)

# Model architecture, plus an inspection report of that architecture
model.save_architecture(artifacts_dir)
inspect_model_architecture(model=model, save_dir=artifacts_dir)

# Feature schema the model was trained against
schema.to_json(artifacts_dir)

# Training log (config, model hyper-parameters, fit history) -> metrics directory
train_logger(
    train_config=train_config,
    model_parameters=model_params,
    train_history=history,
    save_directory=PM.train_metrics_2,
)

## 9. Finalize the Trained Model

In [None]:
# FinalizeRegression receives a single target name, so this cell only supports
# single-target regression. Fail loudly rather than silently keeping only the
# first target name.
assert dataset.number_of_targets == 1, (
    f"FinalizeRegression expects a single target, got {dataset.number_of_targets}"
)

# Finalize from the best checkpoint explicitly (the same one evaluated in
# section 6). The in-memory 'current' weights depend on which cells ran before
# this one; after early stopping, the last-epoch weights are generally not the
# best ones.
trainer.finalize_model_training(
    model_checkpoint='best',
    save_dir=PM.train_artifacts_2,
    finalize_config=FinalizeRegression(
        filename=train_config.finalized_filename,
        target_name=dataset.target_names[0],
    ),
)