# Notebook to verify Mlflow

## Import mlflow library

In [1]:
import numpy as np
from sklearn.tree import DecisionTreeClassifier
import mlflow
import os

import mlflow_util

## Prepare training data

In [2]:
X = np.array([[1, 1], [1, 2], [2, 2], [2, 3]])
y = np.dot(X, np.array([1, 2])) + 3

## Setup Mlflow
You will put the location of Mlflow server. Note that becuase your notebook and the Mlflow server is running on Kubernetes, we just put the location of Kubernetes Service. 

We are using our local Minio server as the S3 storage and therefore pass the variables named AWS_SECRET_ACCESS_KEY containing the password.

### Experiment Name
This is one important variable via which all of your experiment runs will be stored in the Mlflow server

In [3]:
HOST = "http://mlflow:5500"

EXPERIMENT_NAME = "HelloMlFlowCustom"

os.environ['MLFLOW_S3_ENDPOINT_URL']='http://minio-ml-workshop:9000'
os.environ['AWS_ACCESS_KEY_ID']='minio'
# os.environ['AWS_SECRET_ACCESS_KEY']='minio123'
os.environ['AWS_REGION']='us-east-1'
os.environ['AWS_BUCKET_NAME']='mlflow'

# Connect to local MLflow tracking server
mlflow.set_tracking_uri(HOST)

# Set the experiment name through which you will label all your exerpiments runs
mlflow.set_experiment(EXPERIMENT_NAME)

# enable autologging for scikit
mlflow.sklearn.autolog()

## Perform training as usual

In [4]:

model = DecisionTreeClassifier(max_depth=5, criterion='gini',min_samples_leaf = 3 ,min_samples_split = 10)


# Adding custom tags to the run
Mlflow api allows to associate the custom tags as shown below. 

record_libraries is a custom function which runs the pip freeze command and store it as a file to the mlflow run. You can find this function in the associated mlflow_util class in this repo.

In [6]:
with mlflow.start_run(tags={
    "mlflow.source.git.commit" : get_git_revision_hash() ,
    "mlflow.source.git.branch": get_git_branch(),
    "mlflow.source.type": "NOTEBOOK",
    "mlflow.source.git.repoURL": get_git_remote()
    }) as run:
    
    model.fit(X, y)
    record_libraries(mlflow)
    

    


NameError: name 'get_git_revision_hash' is not defined