revised model calibration and its unit test
Adeemy committed Jun 9, 2024
1 parent 065effb commit 56f9826
Showing 3 changed files with 75 additions and 42 deletions.
src/training/evaluate.py (17 changes: 10 additions & 7 deletions)
@@ -62,6 +62,7 @@ def main(
     comparison_metric_name = config.params["train"]["comparison_metric"]
     exp_keys_file_name = config.params["files"]["experiments_keys_file_name"]
     train_set_file_name = config.params["files"]["train_set_file_name"]
+    valid_set_file_name = config.params["files"]["valid_set_file_name"]
     test_set_file_name = config.params["files"]["test_set_file_name"]
     ve_registered_model_name = config.params["modelregistry"][
         "voting_ensemble_registered_model_name"
@@ -76,6 +77,10 @@ def main(
         data_dir / train_set_file_name,
     )
 
+    valid_set = pd.read_parquet(
+        data_dir / valid_set_file_name,
+    )
+
     test_set = pd.read_parquet(
         data_dir / test_set_file_name,
     )
@@ -150,14 +155,12 @@ def main(
     best_model_exp_obj.log_metrics(test_scores)
 
     # Calibrate champ model before deployment
-    training_features = train_set.drop(class_col_name, axis=1)
-    training_class = np.array(train_set[class_col_name])
+    valid_features = valid_set.drop(class_col_name, axis=1)
+    valid_class = np.array(valid_set[class_col_name])
     calib_pipeline = champ_model_manager.calibrate_pipeline(
-        train_features=training_features,
-        train_class=training_class,
-        preprocessor_step=best_model_pipeline.named_steps["preprocessor"],
-        selector_step=best_model_pipeline.named_steps["selector"],
-        model=best_model_pipeline.named_steps["classifier"],
+        valid_features=valid_features,
+        valid_class=valid_class,
+        fitted_pipeline=best_model_pipeline,
         cv_folds=calib_cv_folds,
     )
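The net effect of these hunks is that the champion model is now calibrated on the held-out validation set instead of the training set, so the calibration map is not fitted on data the classifier has already seen. As a minimal sketch (not part of this commit), the benefit could be verified by comparing Brier scores on the test set; the names best_model_pipeline, calib_pipeline, test_set, and class_col_name are assumed from evaluate.py above:

    # Sketch: compare probability quality before and after calibration.
    # Assumes the variables from evaluate.py above are in scope.
    from sklearn.metrics import brier_score_loss

    test_features = test_set.drop(class_col_name, axis=1)
    test_class = test_set[class_col_name]

    # Probability of the positive class from each pipeline
    raw_probs = best_model_pipeline.predict_proba(test_features)[:, 1]
    calib_probs = calib_pipeline.predict_proba(test_features)[:, 1]

    # A lower Brier score indicates better-calibrated probabilities
    print("Brier, uncalibrated:", brier_score_loss(test_class, raw_probs))
    print("Brier, calibrated:  ", brier_score_loss(test_class, calib_probs))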
src/training/utils/model.py (59 changes: 31 additions & 28 deletions)
@@ -1212,49 +1212,52 @@ def select_best_performer(
 
         return best_challenger_name
 
+    @staticmethod
     def calibrate_pipeline(
-        self,
-        train_features: pd.DataFrame,
-        train_class: np.ndarray,
-        preprocessor_step: ColumnTransformer,
-        selector_step: VarianceThreshold,
-        model: Callable,
+        valid_features: pd.DataFrame,
+        valid_class: np.ndarray,
+        fitted_pipeline: Pipeline,
         cv_folds: int = 5,
     ) -> Pipeline:
         """Takes a fitted pipeline and returns a calibrated pipeline.
         Args:
-            train_features (pd.DataFrame): train features.
-            train_class (np.ndarray): train class labels.
-            preprocessor_step (ColumnTransformer): data preprocessing step.
-            selector_step (VarianceThreshold): feature selection step.
-            model (Callable): model object.
-            cv_folds (int): number of cross-validation
-                folds for calibration.
+            valid_features (pd.DataFrame): Validation features.
+            valid_class (np.ndarray): Validation class labels.
+            fitted_pipeline (Pipeline): Pipeline fitted on the training set.
+            cv_folds (int): Number of cross-validation folds for calibration.
         Returns:
-            calib_pipeline (Pipeline): calibrated pipeline.
+            calib_pipeline (Pipeline): Calibrated pipeline.
         """
 
-        # Fit a pipeline with a calibrated model
+        # Extract preprocessor, selector, and classifier from the fitted pipeline
+        preprocessor = fitted_pipeline.named_steps.get("preprocessor")
+        selector = fitted_pipeline.named_steps.get("selector")
+        model = fitted_pipeline.named_steps.get("classifier")
+
+        if not hasattr(model, "classes_"):
+            raise ValueError("The classifier in the fitted pipeline is not fitted.")
+
+        # Calibrate the fitted model using the validation set
+        calibrator = CalibratedClassifierCV(
+            estimator=model,
+            method=("isotonic" if len(valid_class) > 1000 else "sigmoid"),
+            cv=cv_folds,  # folds used to fit the calibration map on valid data
+        )
+
+        # Transform the validation features with the fitted preprocessing steps
+        # so the calibrator is fitted in the same feature space as the classifier
+        valid_transformed = selector.transform(preprocessor.transform(valid_features))
+        calibrator.fit(valid_transformed, valid_class)
+
         # Create a new pipeline with the calibrated classifier
         calib_pipeline = Pipeline(
             steps=[
-                ("preprocessor", preprocessor_step),
-                ("selector", selector_step),
-                (
-                    "classifier",
-                    CalibratedClassifierCV(
-                        estimator=model,
-                        method="isotonic" if len(train_class) > 1000 else "sigmoid",
-                        cv=cv_folds,
-                    ),
-                ),
+                ("preprocessor", preprocessor),
+                ("selector", selector),
+                ("classifier", calibrator),
             ]
         )
 
-        # Fit pipelines
-        calib_pipeline.fit(train_features, train_class)
-
         return calib_pipeline
 
     def log_and_register_champ_model(
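Note that with cv=cv_folds, CalibratedClassifierCV clones and refits the classifier on folds of the validation data, so the fitted classifier pulled from the pipeline serves as a template rather than being calibrated as-is. If the intent is to keep the trained classifier untouched and fit only the calibration map, scikit-learn also supports a prefit mode. A minimal sketch, not part of this commit: it assumes scikit-learn 1.2 through 1.5, where cv="prefit" is available (from 1.6 the same idea is spelled with sklearn.frozen.FrozenEstimator), and reuses model, valid_transformed, and valid_class from the method above:

    # Sketch: calibrate the already-fitted classifier without refitting it.
    from sklearn.calibration import CalibratedClassifierCV

    prefit_calibrator = CalibratedClassifierCV(
        estimator=model,   # the fitted classifier extracted from the pipeline
        method="sigmoid",  # Platt scaling is the safer choice for small sets
        cv="prefit",       # skip cross-validation; use the model as-is
    )
    prefit_calibrator.fit(valid_transformed, valid_class)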
tests/test_training/test_training_model.py (41 changes: 34 additions & 7 deletions)
@@ -994,23 +994,50 @@ def test_calibrate_pipeline():
     mock_self = ModelChampionManager(champ_model_name="champion_model")
 
     # Set up the test data
-    train_features = pd.DataFrame(
+    valid_features = pd.DataFrame(
         np.random.randint(0, 100, size=(100, 4)), columns=list("ABCD")
     )
-    train_class = np.random.randint(2, size=100)
+    valid_class = np.random.randint(2, size=100)
     preprocessor_step = ColumnTransformer(
         transformers=[("scaler", StandardScaler(), ["A", "B", "C", "D"])]
     )
     selector_step = VarianceThreshold()
     model = LogisticRegression()
     cv_folds = 5
 
+    # Create an unfitted pipeline to exercise the not-fitted error path
+    unfitted_pipeline = Pipeline(
+        steps=[
+            ("preprocessor", preprocessor_step),
+            ("selector", selector_step),
+            ("classifier", model),
+        ]
+    )
+
+    # Test that the function raises a ValueError if the supplied pipeline is not fitted yet
+    with pytest.raises(
+        ValueError, match="The classifier in the fitted pipeline is not fitted."
+    ):
+        _ = mock_self.calibrate_pipeline(
+            valid_features=valid_features,
+            valid_class=valid_class,
+            fitted_pipeline=unfitted_pipeline,
+            cv_folds=cv_folds,
+        )
+
+    fitted_pipeline = Pipeline(
+        steps=[
+            ("preprocessor", preprocessor_step),
+            ("selector", selector_step),
+            ("classifier", model),
+        ]
+    )
+    fitted_pipeline = fitted_pipeline.fit(valid_features, valid_class)
+
     calib_pipeline = mock_self.calibrate_pipeline(
-        train_features=train_features,
-        train_class=train_class,
-        preprocessor_step=preprocessor_step,
-        selector_step=selector_step,
-        model=model,
+        valid_features=valid_features,
+        valid_class=valid_class,
+        fitted_pipeline=fitted_pipeline,
         cv_folds=cv_folds,
     )
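The remainder of the test is collapsed in this view. As a hypothetical sketch (not from the commit), assertions along these lines would round out the happy path by checking the structure and output of the returned pipeline:

    # Sketch: verify the calibrated pipeline's structure and probabilities
    from sklearn.calibration import CalibratedClassifierCV

    assert isinstance(
        calib_pipeline.named_steps["classifier"], CalibratedClassifierCV
    )
    probs = calib_pipeline.predict_proba(valid_features)
    assert probs.shape == (100, 2)              # one row per sample, two classes
    assert np.allclose(probs.sum(axis=1), 1.0)  # rows are valid distributions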
