In [1]:

import torchvision.transforms as transforms
import torch
import engine
from model import CaptchaModel
import config
from train import decode_predictions
# Flask
from flask import Flask, redirect, url_for, request, render_template, Response, jsonify, redirect
from werkzeug.utils import secure_filename
from gevent.pywsgi import WSGIServer

import numpy as np
from util import base64_to_pil

In [2]:

app = Flask(__name__)

In [3]:
model_path = 'Models/Model'

In [4]:
def preproc_image(image):
    """
    :param image_path: path to the test image
    :return: {'images': image tensor}
    """
    transformer = transforms.Compose([
        transforms.Resize((config.image_height, config.image_width)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])

    image = torch.as_tensor(transformer(image), dtype=torch.float)
    image = torch.unsqueeze(image, dim=0)
    image = image.to(config.DEVICE)
    return {'images': image}

In [5]:
def remove_duplicates(x):
    if len(x) < 2:
        return x
    fin = ""
    for j in x:
        if fin == "":
            fin = j
        else:
            if j == fin[-1]:
                continue
            else:
                fin = fin + j
    return fin

In [6]:
def get_predictions(image, model_path):
    le =  np.load('Data/lbl_enc.npy',allow_pickle=True)
    classs = np.load('Data/num_class.npy',allow_pickle=True)
    model = CaptchaModel(num_chars=len(classs))
    model.to(config.DEVICE)
    model.load_state_dict(torch.load(model_path))
    model.eval()

    data = preproc_image(image)

    with torch.no_grad():
        preds, _ = model(**data)

    # Now decode the preds
    preds = decode_predictions(preds, le)
    preds = remove_blanks(preds)
    return preds

In [7]:
# home page
@app.route("/")
def home():
    return render_template("base.html")


@app.route("/predict", methods=['GET', 'POST'])
def predict():
    if request.method == 'POST':
        # Get the image from post request
        img = base64_to_pil(request.json)
        prediction = get_predictions(img, model_path)      

        # Serialize the result, you can add additional fields
        return jsonify(result=prediction)


In [None]:
if __name__ == "__main__":
    app.run(port=5002,debug=False)
 
    # Serve the app with gevent
    http_server = WSGIServer(('0.0.0.0', 5000), app)
    http_server.serve_forever()

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: off


 * Running on http://127.0.0.1:5002/ (Press CTRL+C to quit)
127.0.0.1 - - [22/Mar/2021 10:18:01] "[37mGET / HTTP/1.1[0m" 200 -
[2021-03-22 10:18:11,208] ERROR in app: Exception on /predict [POST]
Traceback (most recent call last):
  File "C:\Users\ramji\anaconda3\lib\site-packages\flask\app.py", line 2447, in wsgi_app
    response = self.full_dispatch_request()
  File "C:\Users\ramji\anaconda3\lib\site-packages\flask\app.py", line 1952, in full_dispatch_request
    rv = self.handle_user_exception(e)
  File "C:\Users\ramji\anaconda3\lib\site-packages\flask\app.py", line 1821, in handle_user_exception
    reraise(exc_type, exc_value, tb)
  File "C:\Users\ramji\anaconda3\lib\site-packages\flask\_compat.py", line 39, in reraise
    raise value
  File "C:\Users\ramji\anaconda3\lib\site-packages\flask\app.py", line 1950, in full_dispatch_request
    rv = self.dispatch_request()
  File "C:\Users\ramji\anaconda3\lib\site-packages\flask\app.py", line 1936, in dispatch_request
    return self.

In [8]:
import numpy
print(numpy.__version__)

1.19.2
