# In [None]:
def train():
    """Re-fit the pickled model, score it on the test split and log to MLflow.

    Imports are performed inside the function. The import statements, logging
    setup and helper definitions always run; the workflow itself is guarded by
    ``__name__ == "__main__"`` and is skipped otherwise.

    Workflow (when the guard passes):
      1. Read the feature CSVs from ``3_X_fitted_dataframe`` and the target
         CSVs from ``4_y_dataframe``.
      2. Load ``6_models/xgBoost.pkl`` and fit it on the training data.
      3. Log the accuracy of that fitted model on the test split as the
         MLflow metric ``"accuracy"``.
      4. Log the fitted model to MLflow under the artifact path ``"model"``.

    Raises:
        Exception: whatever ``pandas.read_csv`` raised while reading the
            input files; it is logged and then re-raised.
    """
    import os.path as path
    from pickle import load, dump
    import warnings

    import pandas as pd
    import numpy as np

    from sklearn import metrics
    from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier
    from xgboost import XGBClassifier

    from urllib.parse import urlparse
    import mlflow
    import mlflow.sklearn

    import logging

    logging.basicConfig(level=logging.WARN)
    logger = logging.getLogger(__name__)

    # Single source of truth for the model file used throughout this function.
    MODEL_NAME = 'xgBoost.pkl'

    def eval_score(X_test, y_test, model_name=MODEL_NAME, *, model=None):
        """Return the accuracy of a model's predictions on ``X_test``.

        Args:
            X_test: features to predict on.
            y_test: true labels compared against the predictions.
            model_name: file name of the pickled model inside ``6_models``;
                only used when ``model`` is not given.
            model: optional already-loaded estimator. When given it is scored
                directly instead of re-reading ``model_name`` from disk.

        Returns:
            The value of ``sklearn.metrics.accuracy_score(y_test, preds)``.
        """
        preds = make_prediction(X_test, model_name, model=model)
        return metrics.accuracy_score(y_test, preds)

    def make_prediction(X_test, model_name=MODEL_NAME, *, model=None):
        """Predict on ``X_test`` with ``model``, or with the pickled ``model_name``."""
        estimator = load_model(model_name) if model is None else model
        return estimator.predict(X_test)

    def load_model(model_name=MODEL_NAME):
        """Unpickle and return the object stored at ``6_models/<model_name>``."""
        # NOTE: unpickling can execute arbitrary code; only load trusted files.
        # The context manager guarantees the file handle is closed.
        with open(path.join('6_models', model_name), 'rb') as model_file:
            return load(model_file)

    if __name__ == "__main__":

        warnings.filterwarnings("ignore")
        np.random.seed(40)

        print("\nInitializing the program\n...")

        # The reads are what can fail, so they (not the print) sit in the try.
        try:
            X_test = pd.read_csv(path.join("3_X_fitted_dataframe", "X_test_scaled.csv"))
            X_train = pd.read_csv(path.join("3_X_fitted_dataframe", "X_train_scaled.csv"))
            y_train = pd.read_csv(path.join('4_y_dataframe', 'y_train.csv'))
            y_test = pd.read_csv(path.join('4_y_dataframe', 'y_test.csv'))
        except Exception as e:
            logger.exception(
                "Unable to load training & test CSV files. Error: %s", e
            )
            # Nothing below can run without the data, so do not swallow this.
            raise
        print("\nData loaded\n...")

        # start the mlflow run, to fit, and score
        with mlflow.start_run():
            clf = load_model(MODEL_NAME)
            clf.fit(X_train, y_train)

            # Score the estimator that was just fit (and is logged below),
            # not a fresh copy of the file on disk, so the logged metric
            # describes the logged model.
            score = eval_score(X_test, y_test, MODEL_NAME, model=clf)
            mlflow.log_metric("accuracy", score)
            print("\nLogged accuracy\n...")

            print("\nTracking to MLflow UI\n...")

            mlflow.sklearn.log_model(clf, "model")
            print("\nModel registered <3\n...")


# In [None]:
# Invoke train(). Its imports, logging setup and helper definitions always run;
# the data loading / fitting / MLflow logging inside it is guarded by
# `__name__ == "__main__"`.
train()