In [1]:
from flask import Flask, request
from flask_cors import CORS
import PIL
import torch
from torchvision import datasets, transforms, models
from torch import nn
from collections import OrderedDict

In [2]:
app = Flask(__name__)
cors = CORS(app, resources={r"/*": {"origins": "*"}})

In [6]:
# Load saved model
def load_ckpt(ckpt_path):
    ckpt = torch.load(ckpt_path)
    model = models.resnet18(pretrained=True)
    model.fc = nn.Sequential(OrderedDict([
        ('fc1', nn.Linear(512, 400)),
        ('relu', nn.ReLU()),
        ('fc2', nn.Linear(400, 2)),
        ('output', nn.LogSoftmax(dim=1))
    ]))
    model.load_state_dict(ckpt, strict=False)
    return model

In [7]:
SAVE_PATH = 'res18_10.pth'

In [8]:
# load model
model = load_ckpt(SAVE_PATH)



RuntimeError: Error(s) in loading state_dict for ResNet:
	size mismatch for fc.fc2.weight: copying a param with shape torch.Size([3, 400]) from checkpoint, the shape in current model is torch.Size([2, 400]).
	size mismatch for fc.fc2.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([2]).

In [9]:
test_transforms = transforms.Compose([transforms.Resize(255),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize([0.485, 0.456, 0.406],
                                                           [0.229, 0.224, 0.225])])

In [10]:
def process_image(image):
    ''' Scales, crops, and normalizes a PIL image for a PyTorch model,
        returns an torch Tensor
    '''
    im = PIL.Image.open(image)
    return test_transforms(im)

In [11]:
def predict(image_path, model):
    # Predict the class of an image using a trained deep learning model.
    model.eval()
    img_pros = process_image(image_path)
    img_pros = img_pros.view(1, 3, 224, 224)
    with torch.no_grad():
        output = model(img_pros)
    return output

In [15]:
@app.route('/', methods=['GET'])
def home():
    return 'API is not running'

AssertionError: View function mapping is overwriting an existing endpoint function: home

In [16]:
@app.route('/pred', methods=['POST'])
def pred():
    Img = request.files['Sq-Tri-Cir']
    ps = torch.exp(predict(Img, model))
    cls_score = int(torch.argmax(ps))
    if cls_score == 0:
        return 'Circle'
    elif cls_score == 1:
        return 'Square'
    else:
        return 'Triangle'

In [17]:
if __name__ == '__main__':
    app.run(port=8002, debug=True)

 * Serving Flask app "__main__" (lazy loading)
 * Environment: production
   Use a production WSGI server instead.
 * Debug mode: on


 * Restarting with windowsapi reloader


SystemExit: 1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
