In [2]:
import io

import torchvision.transforms as transforms
from PIL import Image

In [3]:
def transform_image(image_bytes):
    my_transforms = transforms.Compose([transforms.Resize(255),
                                        transforms.CenterCrop(244),
                                        transforms.ToTensor(),
                                        transforms.Normalize(
                                            [0.485, 0.456, 0.406],
                                            [0.229, 0.224, 0.225])])
    image = Image.open(io.BytesIO(image_bytes))
    return my_transforms(image).unsqueeze(0)

In [5]:
with open('../static/img/cat.jpg', 'rb') as f:
    image_bytes = f.read()
    tensor = transform_image(image_bytes=image_bytes)
    print(tensor.shape)

torch.Size([1, 3, 244, 244])


In [6]:
from torchvision import models

model = models.densenet121(pretrained=True)
model.eval();

In [7]:
def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    return y_hat

In [8]:
with open('../static/img/cat.jpg', 'rb') as f:
    image_bytes = f.read()
    y_hat = get_prediction(image_bytes)
    print(y_hat)

tensor([281])


In [9]:
import json

imagenet_class_index = json.load(open('../static/imagenet_class_index.json'))

def get_prediction(image_bytes):
    tensor = transform_image(image_bytes=image_bytes)
    outputs = model.forward(tensor)
    _, y_hat = outputs.max(1)
    predicted_idx = str(y_hat.item())
    return imagenet_class_index[predicted_idx]

In [10]:
with open('../static/img/cat.jpg', 'rb') as f:
    image_bytes = f.read()
    print(get_prediction(image_bytes=image_bytes))

['n02123045', 'tabby']
