# Image Classification with PyTorch
Pytorch has been both researcher's and engineer's preferred choice of framework for DL development but when it comes to productionizing pytorch models, there still hasn't been a consensus on what to use. This guide run you through building a simple image classification model using Pytorch and then deploying that to RedisAI

In [12]:
import torchvision.models as models
import torch

In [13]:
model = models.resnet50(pretrained=True)
model.eval()

scripted_model = torch.jit.script(model)
torch.jit.save(scripted_model, 'resnet50.pt')

In [14]:
import json
import time
from redisai import Client
import ml2rt
from skimage import io

In [15]:
import os
from redisai import Client

REDIS_HOST = os.getenv("REDIS_HOST", "localhost")
REDIS_PORT = int(os.getenv("REDIS_PORT", 6379))

In [16]:
con = Client(host=REDIS_HOST, port=REDIS_PORT)

In [17]:
con.ping()

True

In [18]:
model = ml2rt.load_model("resnet50.pt")
con.modelstore("pytorch_model", backend="TORCH", device="CPU", data=model)

'OK'

In [28]:
script = """
def pre_process(tensors: List[Tensor], keys: List[str], args: List[str]):
    image = tensors[0]
    mean = torch.zeros(3).float().to(image.device)
    std = torch.zeros(3).float().to(image.device)
    mean[0], mean[1], mean[2] = 0.485, 0.456, 0.406
    std[0], std[1], std[2] = 0.229, 0.224, 0.225
    mean = mean.unsqueeze(1).unsqueeze(1)
    std = std.unsqueeze(1).unsqueeze(1)
    temp = image.float().div(255).permute(2, 0, 1)
    return temp.sub(mean).div(std).unsqueeze(0)


def post_process(tensors: List[Tensor], keys: List[str], args: List[str]):
    output = tensors[0]
    return output.max(1)[1]
"""
con.scriptstore("processing_script", device="CPU", script=script, entry_points=("pre_process", "post_process"))

'OK'

In [29]:
image = io.imread("../data/cat.jpg")

In [30]:
con.tensorset('image', image)

'OK'

In [31]:
con.scriptexecute('processing_script', 'pre_process', 'image', 'processed')

ResponseError: The following operation failed in the TorchScript interpreter. Traceback of TorchScript (most recent call last):   File "<string>", line 10, in pre_process     mean = mean.unsqueeze(1).unsqueeze(1)     std = std.unsqueeze(1).unsqueeze(1)     temp = image.float().div(255).permute(2, 0, 1)            ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE     return temp.sub(mean).div(std).unsqueeze(0) RuntimeError: number of dims don't match in permute 