-
Notifications
You must be signed in to change notification settings - Fork 0
/
app.py
74 lines (58 loc) · 2.35 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
'''
The Flask app that loads the model and generates predictions
'''
import io
import json
import os
import pickle
import sys

from flask import Flask, request, jsonify
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
# Initializing the flask app
app = Flask(__name__)

# Loading the model
# Setting the working directory to the script's directory so the relative
# model path below resolves. NOTE(review): sys.path[0] is the script directory
# only when launched as `python app.py`; under other launchers it may differ.
os.chdir(sys.path[0])
model_file_path = './models/model.pkl'
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only ever load model files produced by a trusted source.
with open(model_file_path, 'rb') as model_file:  # closes the handle deterministically
    model = pickle.load(model_file)
print('Model Loaded')

# Name of the species to return, keyed by the column index of the model's
# predict_proba output (the argmax index is looked up here in /predict).
class_dict = {0: 'Setosa', 1: 'Versicolour', 2: 'Virginica'}
@app.route('/test/', methods=['GET', 'POST'])
def test():
    '''
    Health-check endpoint used to confirm the Flask app is running.

    Returns a 200 JSON response whose ``predictions`` field holds the
    JSON-encoded status message (encoded the same way /predict encodes its
    payload, so clients can decode both endpoints uniformly).
    '''
    status_message = json.dumps('The flask app is working')
    health_response = jsonify(predictions=status_message)
    health_response.status_code = 200
    return health_response
def _bad_request(message):
    '''Build a JSON HTTP 400 response carrying ``message`` under the ``error`` key.'''
    response = jsonify(error=message)
    response.status_code = 400
    return response


@app.route('/predict', methods=['POST'])
def predict():
    '''
    API call to make predictions on received data.

    Receives a pandas DataFrame that was serialized with
    ``DataFrame.to_json(orient='split')`` and sent as the request payload
    (either as a JSON-encoded string or as the plain JSON object), and
    generates batch predictions for every row.

    Returns a 200 JSON response whose ``predictions`` field is itself a JSON
    string (``DataFrame.to_json(orient='records')``) with one record per input
    row: the predicted species name, the three class probabilities and the
    row index.

    Responds with HTTP 400 when the body is missing or ``null``, cannot be
    parsed as a split-oriented DataFrame, or decodes to an empty DataFrame.
    '''
    json_data = request.get_json()
    if json_data is None:
        return _bad_request('Request body must be JSON')
    # The original contract sends the split-oriented JSON as a JSON-encoded
    # string; also accept the decoded object by re-serializing it to text.
    if not isinstance(json_data, str):
        json_data = json.dumps(json_data)
    try:
        # Wrap in StringIO: passing literal JSON text to read_json is
        # deprecated since pandas 2.1.
        data = pd.read_json(io.StringIO(json_data), orient='split')
    except ValueError as err:
        return _bad_request(
            'Could not parse payload as a split-oriented DataFrame: {0}'.format(err))
    # TODO: Put any preprocessing steps here
    if data.empty:
        return _bad_request('Received an empty DataFrame')

    # Creating predictions
    print('Creating predictions for {0} records'.format(data.shape[0]))
    class_probabilities = model.predict_proba(data)
    # Index of the most probable class per row, mapped to its species name.
    predicted_indices = np.argmax(class_probabilities, axis=1)
    predicted_names = [class_dict[index] for index in predicted_indices]
    predictions = pd.DataFrame({'prediction': predicted_names,
                                'probability_setosa': class_probabilities[:, 0],
                                'probability_versicolour': class_probabilities[:, 1],
                                'probability_virginica': class_probabilities[:, 2]})
    predictions['index'] = range(predictions.shape[0])

    # Returning the responses
    responses = jsonify(predictions=predictions.to_json(orient='records'))
    responses.status_code = 200
    return responses
if __name__ == "__main__":
    # Port, host and debug mode can be overridden from the environment; when
    # the variables are unset the original settings (5000, localhost, debug on)
    # are used.
    port = int(os.environ.get('PORT', 5000))
    host = os.environ.get('HOST', 'localhost')
    # WARNING: debug mode enables the Werkzeug interactive debugger, which lets
    # anyone who can reach the server execute arbitrary code. Set FLASK_DEBUG=0
    # for anything other than local development.
    debug = os.environ.get('FLASK_DEBUG', '1').strip().lower() not in ('0', 'false', 'no')
    app.run(host=host, port=port, debug=debug)