Skip to content

Commit

Permalink
[Metrics] Cleaned code
Browse files Browse the repository at this point in the history
  • Loading branch information
YanSte committed Sep 9, 2023
1 parent 5fc96f5 commit eb3987e
Showing 1 changed file with 79 additions and 11 deletions.
90 changes: 79 additions & 11 deletions src/skit/ModelMetrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,16 @@ class Metric(Enum):

@property
def train_metric_key(self):
    """Return the history key used to look up this metric's training values.

    Returns
    -------
    str
        The underlying enum value, used unchanged as the training key.
    """
    key = self.value
    return key

@property
def val_metric_key(self):
"""
Get the validation metric key corresponding to the Metric enum value.
"""
if self == Metric.ACCURACY:
return "val_accuracy"
elif self == Metric.AUC:
Expand All @@ -56,12 +62,6 @@ def val_metric_key(self):
def plot_labels(self):
"""
Get the curve labels corresponding to the given Metric enum.
Parameters:
metric_enum (Metric): The Metric enum value.
Returns:
dict: A dictionary mapping curve labels to metric names.
"""
if self == Metric.ACCURACY or self == Metric.VAL_ACCURACY:
return {
Expand All @@ -88,6 +88,16 @@ def plot_labels(self):

class ModelMetrics:
def __init__(self, versions, metric_to_monitor=Metric.ACCURACY):
"""
Initialize ModelMetrics class.
Parameters
----------
versions : list
List of model versions to track.
metric_to_monitor : Metric
The metric to monitor (default is Accuracy).
"""
self.output = {}
self.metric_to_monitor = metric_to_monitor
for version in versions:
Expand All @@ -99,6 +109,14 @@ def __init__(self, versions, metric_to_monitor=Metric.ACCURACY):
}

def reset(self, version=None):
"""
Reset the tracking for a specific version or all versions.
Parameters
----------
version : str, optional
The specific version to reset. If None, reset all versions.
"""
default_dict = {
"history": None,
"duration": None,
Expand All @@ -114,6 +132,17 @@ def reset(self, version=None):
self.output[version] = default_dict.copy()

def get_best_metric(self, version):
"""
Get the best training and validation metrics for a specific model version.
Parameters
----------
version : str
    The model version to retrieve metrics for.

Returns
-------
dict
    Dictionary containing the best training and validation metrics.
"""
history = self.output[version]['history'].history

train_metric_key = self.metric_to_monitor.train_metric_key
Expand All @@ -137,11 +166,9 @@ def get_best_report(self, version):
version : str
The model version for which to get the best model report.
Returns
-------
dict or None
The best model report containing training and validation metrics, duration, and paths.
Returns None if the specified version is not found in the output.
Returns:
dict or None: The best model report containing training and validation metrics,
duration, and paths. Returns None if the specified version is not found in the output.
"""
if version not in self.output:
return None
Expand Down Expand Up @@ -287,6 +314,47 @@ def show_history(
version,
figsize=(8,6)
):
"""
Visualizes the training and validation metrics from the model's history using matplotlib.
The function generates separate plots for each main category (like 'Accuracy' and 'Loss')
defined in the `plot` parameter. For each main category, multiple curves (like 'Training Accuracy'
and 'Validation Accuracy') can be plotted based on the nested dictionary values.
Parameters:
-----------
history : dict
The history object typically returned from the .fit() method of a Keras model. It should
have a 'history' attribute containing the training and validation metrics.
figsize : tuple, optional
The width and height in inches for the figure. Defaults to (8,6).
plot : dict, optional
A nested dictionary defining the metrics to be plotted.
- The top-level key corresponds to the main category (e.g., 'Accuracy' or 'Loss').
- The associated nested dictionary's keys are the curve labels (e.g., 'Training Accuracy')
and the values are the corresponding metric names in the 'history' object (e.g., 'accuracy').
Defaults to plotting both training and validation accuracy and loss.
Example:
--------
show_history(
model_history,
figsize=(10,8),
plot={
"Title A": {
"Legend Title 1": "metric_name_1",
"Legend Title 2": "metric_name_2"
}
}
)
Note:
-----
The `plot` parameter allows you to customize which metrics to plot and how they are labeled
in the generated visualization.
"""
history = self.output[version]['history']
plot = self.metric_to_monitor.plot_labels
display(show_history(history, figsize=figsize, plot=plot))
Expand Down

0 comments on commit eb3987e

Please sign in to comment.