# EXAMPLE: Saving a trained model to the pickle model store

> Example of creating, training and testing a machine learning model, and then storing the trained model in the pickle model store.

In [18]:
import os
import sys

import logging
import pickle
from typing import List
import pandas as pd
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier
import datetime as dt

Local imports:

In [19]:
# Make the parent of the notebook's working directory importable (presumably
# where the local `model_store` module lives). The membership check keeps
# re-runs of this cell from appending the same entry to sys.path repeatedly.
current = os.path.abspath('')
parent_directory = os.path.dirname(current)
if parent_directory not in sys.path:
    sys.path.append(parent_directory)
from model_store import PickleModelStore, ModelSchemaContainer

Data:

In [20]:
# Load sample dataset; the "variety" column is the class label.
# df.pop removes that column from df, so X holds only the feature columns.
# (Re-running the cell is safe: the CSV is re-read each time.)
df = pd.read_csv('iris_dataset.csv')
y = df.pop('variety')
X = df

# Split data: hold out 20% for testing. A fixed random_state makes the split,
# and every metric computed from it, reproducible across Restart & Run All.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

Model:

In [21]:
# Create a decision tree classifier that chooses splits by entropy
# (information gain). random_state fixes the random feature permutation used
# at each split, so the fitted tree is reproducible between runs.
classifier = DecisionTreeClassifier(criterion="entropy", random_state=42)

# Train the model on the training set, then predict the held-out test set
classifier.fit(X_train, y_train)
y_pred = classifier.predict(X_test)

Evaluate on the test set (the model was trained above):

In [22]:
# Metrics
# Train/test/val metrics to be passed on to api & monitoring.
# With output_dict=True, classification_report returns per-label and average
# entries (dicts of precision/recall/f1-score/support) and, typically, a
# scalar "accuracy" entry.
metrics_raw = classification_report(y_test, y_pred, output_dict=True)


def _metric_entry(value, metric_type="numeric"):
    """Wrap a single value in the api/metrics record format."""
    return {"value": value, "description": "", "type": metric_type}


# Reformat metrics to api/metrics format: nested entries are flattened to
# "<label>_<metric>" keys; scalar entries keep their own key.
metrics_parsed = {}
for label, scores in metrics_raw.items():
    if isinstance(scores, dict):
        for metric_name, value in scores.items():
            metrics_parsed[f"{label}_{metric_name}"] = _metric_entry(value)
    else:
        # Scalar entries such as "accuracy" would otherwise be lost.
        metrics_parsed[label] = _metric_entry(scores)

# We can also pass metadata.
# NOTE(review): this value is a datetime but is tagged "numeric"; confirm the
# api/metrics schema accepts that combination.
metrics_parsed["model_update_time"] = _metric_entry(dt.datetime.now())

Schema:

In [23]:
# Use dtypes to determine api request and response models
# Derive the api request/response models from the pandas dtypes: one
# {"name": ..., "type": ...} record per feature column, plus one for the target.
# `.dtype.type` is the NumPy scalar type class (e.g. numpy.float64), not a string.
dtypes_x = []
for column_name in X.columns:
    dtypes_x.append({"name": column_name, "type": X[column_name].dtype.type})

target_record = {"name": y.name, "type": y.dtype.type}
dtypes_y = [target_record]


Save to model store:

In [25]:
# Destination of the pickled bundle: <parent_directory>/local_data/bundle_latest.pickle.
# os.path.join uses the platform's separator instead of a hard-coded "/".
MODEL_PATH = os.path.join(parent_directory, "local_data", "bundle_latest.pickle")
# Ensure the target directory exists so persisting does not fail on a fresh checkout.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)

# Pickle all in single file: the classifier together with the dtypes and
# metrics passed below.
model_store = PickleModelStore()
model_store.persist(classifier, MODEL_PATH, dtypes_x, dtypes_y, metrics_parsed)

You should now be able to run the API with the `MODEL_STORE=pickle` environment variable set and have it load the just-pickled model.