In [6]:
import os

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from flask import Flask, request
from flask_restful import reqparse
from matplotlib import pyplot as plt
from torch import nn
from torch.utils.data import DataLoader

# Flask application; the /predict route is registered on it in a later cell.
app = Flask(__name__)


In [7]:

class NeuralNetwork_OL_v2(nn.Module):
    """Small CNN with a 10-way classification head and a 2-value regression head.

    Layer-by-layer spatial sizes for a 90x90 single-channel input:

        conv0  (1 -> 16, 3x3, padding=2)    90 -> 92
        pool0  (2x2, stride 2)              92 -> 46
        conv1  (16 -> 16, 3x3, padding=3)   46 -> 50
        pool1  (2x2, stride 2)              50 -> 25
        flatten                             16 * 25 * 25 = 10000 features
        linear_relu_stack                   10000 -> 256, ReLU
        heads: ``linear`` (256 -> 10) and ``linear_all`` (256 -> 2)

    Because the first fully-connected layer hard-codes 16*25*25 input
    features, only inputs whose height and width are each in 90..93 px
    produce a valid shape.
    """

    def __init__(self):
        super().__init__()
        # NOTE: with 3x3 kernels, padding=2 / padding=3 is *not* "same" padding:
        # conv0 grows the map by 2 px and conv1 by 4 px per spatial dimension.
        self.conv0 = nn.Conv2d(1, 16, 3, padding=(2, 2))
        self.pool0 = nn.MaxPool2d(2, stride=2)
        self.conv1 = nn.Conv2d(16, 16, 3, padding=(3, 3))
        self.pool1 = nn.MaxPool2d(2, stride=2)
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(16 * 25 * 25, 256),  # 25x25 = spatial size after pool1
            nn.ReLU(),
        )
        self.linear = nn.Linear(256, 10)
        # linear_x / linear_y are not used in forward(). Keep them so that
        # checkpoints containing their weights still load under
        # load_state_dict's default strict=True.
        self.linear_x = nn.Linear(256, 1)
        self.linear_y = nn.Linear(256, 1)
        self.linear_all = nn.Linear(256, 2)

    def forward(self, x):
        """Run both heads on a batch.

        Args:
            x: float tensor of shape (N, 1, H, W) with H and W in 90..93.

        Returns:
            Tuple ``(logits, centr)`` with shapes (N, 10) and (N, 2).
        """
        # ReLU after max-pooling equals pooling after ReLU (ReLU is monotonic).
        x = F.relu(self.pool0(self.conv0(x)))
        x = F.relu(self.pool1(self.conv1(x)))
        x = self.linear_relu_stack(self.flatten(x))
        logits = self.linear(x)
        centr = self.linear_all(x)
        return logits, centr

        
# Path to the trained weights; override with the MODEL_PATH environment variable.
MODEL_PATH = os.environ.get('MODEL_PATH', '/home/george/new_model2')

model = NeuralNetwork_OL_v2()
# map_location='cpu': /predict runs inference on CPU tensors, and this lets a
# checkpoint saved from a GPU session load on a CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (weights_only=True restricts unpickling where the installed torch supports it).
model.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
model.eval()



NeuralNetwork_OL_v2(
  (conv0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(2, 2))
  (pool0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv1): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(3, 3))
  (pool1): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (linear_relu_stack): Sequential(
    (0): Linear(in_features=10000, out_features=256, bias=True)
    (1): ReLU()
  )
  (linear): Linear(in_features=256, out_features=10, bias=True)
  (linear_x): Linear(in_features=256, out_features=1, bias=True)
  (linear_y): Linear(in_features=256, out_features=1, bias=True)
  (linear_all): Linear(in_features=256, out_features=2, bias=True)
)

In [8]:


def _decode_grayscale(raw):
    """Decode encoded image bytes into a 2-D uint8 array, or None on failure."""
    buf = np.frombuffer(raw, dtype=np.uint8)
    if buf.size == 0:
        # cv2.imdecode rejects an empty buffer instead of returning None.
        return None
    # IMREAD_GRAYSCALE guarantees one channel, matching conv0's single input
    # channel; IMREAD_UNCHANGED yields (H, W, 3|4) for colour images.
    return cv2.imdecode(buf, cv2.IMREAD_GRAYSCALE)


@app.route('/predict', methods=['POST'])
def predict():
    """Classify an uploaded image and regress its 2-value center output.

    Expects a multipart POST with the image under the ``file`` field. The
    image is decoded as single-channel 8-bit and fed to ``model`` as a
    (1, 1, H, W) float batch.

    Returns:
        ``{'prediction': int, 'y_center': float, 'x_center': float}``, or
        ``({'error': ...}, 400)`` if the upload cannot be decoded.
    """
    img = _decode_grayscale(request.files['file'].read())
    if img is None:
        return {'error': 'could not decode uploaded image'}, 400

    # NOTE(review): the network only accepts H, W in 90..93 px; other sizes
    # raise inside the model and surface as a 500.
    batch = torch.from_numpy(img).float().reshape(1, 1, *img.shape)
    with torch.no_grad():  # inference only: no autograd graph needed
        digit_pred, center_pred = model(batch)

    center = center_pred[0].tolist()
    predicted_digit = int(digit_pred[0].argmax())
    # Index 0 -> y_center, index 1 -> x_center, as in the original endpoint.
    return {'prediction': predicted_digit, 'y_center': center[0], 'x_center': center[1]}


if __name__ == '__main__':
    # Starts Flask's built-in development server (not for production use).
    # In a notebook __name__ is '__main__', so this blocks the kernel until
    # interrupted. The port defaults to 8090 and can be overridden via PORT.
    app.run(port=int(os.environ.get('PORT', 8090)))


 * Serving Flask app '__main__' (lazy loading)
 * Environment: production
[2m   Use a production WSGI server instead.[0m
 * Debug mode: off


 * Running on http://127.0.0.1:8090/ (Press CTRL+C to quit)
