In [1]:
from flask import Flask
# Application object; the route decorator below registers its handler on it.
app = Flask(__name__)


@app.route('/')
def hello():
    """Handle requests to the root URL with a fixed plain-text greeting."""
    greeting = 'Hello World!'
    return greeting

In [2]:
from flask import Flask, jsonify
# NOTE(review): this rebinds `app` to a brand-new Flask instance, so the `/`
# route registered on the earlier instance is not attached to this one —
# confirm that dropping it is intended.
app = Flask(__name__)

# Placeholder version of the prediction endpoint, kept commented out; it
# returns a hard-coded class instead of running a model.
# @app.route('/predict', methods=['POST'])
# def predict():
#     return jsonify({'class_id': 'IMAGE_NET_XXX', 'class_name': 'Cat'})

In [3]:
import io

import torchvision.transforms as transforms
from PIL import Image

def transform_image(image_bytes):
    """Decode raw image bytes into a normalized, batched image tensor.

    Args:
        image_bytes: Encoded image file contents (e.g. the bytes of a JPEG).

    Returns:
        The transformed image as a tensor with a leading batch dimension of
        size 1 (added by ``unsqueeze(0)``), three channels and a 224x224
        center crop.
    """
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(224),
                                        transforms.ToTensor(),
                                        # Per-channel mean and std; both
                                        # lists assume exactly 3 channels.
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    # Force three channels: grayscale, palette or RGBA uploads would otherwise
    # reach Normalize with a channel count that does not match the 3-element
    # mean/std lists and raise.
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    return my_transforms(image).unsqueeze(0)

In [4]:
# Smoke-test the preprocessing on a sample image from disk.
with open("./img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()

# The file is only needed for the read; transform and display afterwards.
tensor = transform_image(image_bytes=image_bytes)
print(tensor)

tensor([[[[ 0.4508,  0.4166,  0.3994,  ..., -1.3473, -1.3473, -1.3473],
          [ 0.5364,  0.4851,  0.4508,  ..., -1.2959, -1.3130, -1.3302],
          [ 0.7248,  0.6392,  0.6049,  ..., -1.2959, -1.3302, -1.3644],
          ...,
          [ 1.3584,  1.3755,  1.4098,  ...,  1.1700,  1.3584,  1.6667],
          [ 1.8893,  1.7694,  1.4440,  ...,  1.2899,  1.4783,  1.5468],
          [ 1.6324,  1.8379,  1.8379,  ...,  1.4783,  1.7352,  1.4612]],

         [[ 0.5728,  0.5378,  0.5203,  ..., -1.3529, -1.3529, -1.3529],
          [ 0.6604,  0.6078,  0.5728,  ..., -1.3004, -1.3354, -1.3529],
          [ 0.8529,  0.7654,  0.7304,  ..., -1.3004, -1.3354, -1.3704],
          ...,
          [ 1.4657,  1.4657,  1.4832,  ...,  1.3256,  1.5357,  1.8508],
          [ 2.0084,  1.8683,  1.5182,  ...,  1.4657,  1.6583,  1.7283],
          [ 1.7458,  1.9384,  1.9209,  ...,  1.6583,  1.9209,  1.6408]],

         [[ 0.7228,  0.6879,  0.6531,  ..., -1.6476, -1.6302, -1.6302],
          [ 0.8099,  0.7576,  

In [5]:
import torch
from torchvision import models

# Load DenseNet-121 with its pretrained weights (selected by passing
# `weights='IMAGENET1K_V1'`) and put it straight into `eval` mode, since the
# model is used for inference only.
model = models.densenet121(weights='IMAGENET1K_V1').eval()


def get_prediction(image_bytes):
    """Return the index of the highest-scoring class for an encoded image.

    Args:
        image_bytes: Encoded image file contents, passed to
            ``transform_image``.

    Returns:
        The tensor of indices of the maximum along dimension 1 of the model
        output (one entry per batch element).
    """
    tensor = transform_image(image_bytes=image_bytes)
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        # Call the module itself rather than ``.forward`` so that any
        # registered hooks are run.
        outputs = model(tensor)
    _, y_hat = outputs.max(1)
    return y_hat

In [6]:
import json

# Class lookup table, indexed by the predicted class index as a string.
# Opened in a `with` block so the file handle is closed right after parsing.
with open('./imagenet_class_index.json') as class_index_file:
    imagenet_class_index = json.load(class_index_file)

def get_prediction(image_bytes):
    """Classify an encoded image and look up its class entry.

    Args:
        image_bytes: Encoded image file contents, passed to
            ``transform_image``.

    Returns:
        The entry of ``imagenet_class_index`` stored under the predicted
        class index (the index converted to a string key).
    """
    tensor = transform_image(image_bytes=image_bytes)
    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        # Call the module itself rather than ``.forward`` so that any
        # registered hooks are run.
        outputs = model(tensor)
    _, y_hat = outputs.max(1)
    # The lookup table is keyed by strings, so convert the scalar index.
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

In [7]:
# End-to-end check: run the sample image through the classifier.
with open("./img/sample_file.jpeg", 'rb') as f:
    image_bytes = f.read()

# The file is only needed for the read; predict and display afterwards.
print(get_prediction(image_bytes=image_bytes))

['n02124075', 'Egyptian_cat']


In [8]:
from flask import request

@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image and return the result as JSON.

    Expects a multipart form upload with the image under the ``file`` field.

    Returns:
        A JSON response with ``class_id`` and ``class_name`` on success, or a
        JSON ``error`` message with HTTP status 400 when no usable file was
        uploaded.
    """
    # The route only accepts POST, so no per-request method check is needed
    # (and no code path falls through to an implicit ``None`` return).
    file = request.files.get('file')
    if file is None:
        return jsonify({'error': "missing 'file' in request"}), 400
    # Convert the upload to bytes for the preprocessing step.
    img_bytes = file.read()
    if not img_bytes:
        return jsonify({'error': 'uploaded file is empty'}), 400
    class_id, class_name = get_prediction(image_bytes=img_bytes)
    return jsonify({'class_id': class_id, 'class_name': class_name})