# TorchScript

In [None]:
import torch
import torchvision

# Load ImageNet-pretrained VGG16 and export it to TorchScript by tracing.
# NOTE: torchvision >= 0.13 deprecates `pretrained=True` in favour of
# `weights=`; kept here for compatibility with older releases.
model = torchvision.models.vgg16(pretrained=True)
# torch.jit.trace records the operations that actually execute for the example
# input, so the model must be in inference mode *before* tracing; otherwise the
# Dropout layers are recorded with their training behaviour.
model.eval()
# Dummy input that only drives the trace; any data-dependent control flow
# would be frozen to the path taken for this input.
example_input = torch.rand(1, 3, 224, 224)
torch_script_model = torch.jit.trace(model, example_input)
torch_script_model.save("vgg16.pt")

In [None]:
import torch.nn as nn


class ControlFlowModel(nn.Module):
    """Toy module whose output depends on a value-dependent branch.

    ``forward`` returns its input unchanged when the elements sum to a
    positive number and the negated input otherwise. Because the branch
    depends on tensor *values*, it has to be compiled with
    ``torch.jit.script``; ``torch.jit.trace`` would only record the branch
    taken for the example input.

    Args:
        N: ``in_features`` of the ``fc`` layer. ``fc`` is registered as a
            submodule (so its parameters are saved with the module), but
            ``forward`` never calls it.
    """

    def __init__(self, N):
        super().__init__()
        self.fc = nn.Linear(N, 100)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Positive sum -> identity.
        if x.sum() > 0:
            return x
        # Any other sum (zero, negative, or NaN) -> negation.
        return -x


# Compile (rather than trace) so the runtime branch in forward() survives,
# then serialize the compiled program to disk.
model = ControlFlowModel(10)
torch_script_model = torch.jit.script(model)
torch_script_model.save("control.pt")

# TorchServe

In [None]:
# Install the TorchServe server and the torch-model-archiver CLI used below.
!pip install torchserve torch-model-archiver

In [None]:
# Package the traced model saved above (vgg16.pt) into vgg16.mar.
# --handler needs a value: `image_classifier` is TorchServe's built-in image
# classification handler. --export-path writes the archive straight into the
# model store that `torchserve --model-store /models` reads from.
# (Optionally add `--extra-files index_to_name.json` to get class names back.)
!mkdir -p /models
!torch-model-archiver --model-name vgg16 --version 1.0 --serialized-file vgg16.pt --handler image_classifier --export-path /models

In [None]:
# Start TorchServe and load every model archive found in /models.
# NOTE(review): recent TorchServe releases enable token authorization by
# default; the curl request below may need an auth header, or the server may
# need `--disable-token-auth` -- confirm against the installed version.
!torchserve --model-store /models --start --models all

In [None]:
# Upload hot_dog.jpg (expected in the working directory) to the vgg16
# prediction endpoint on port 8080.
!curl http://localhost:8080/predictions/vgg16 -T hot_dog.jpg

# ONNX

In [None]:
# Re-create the pretrained VGG16 and export it to ONNX as "vgg16.onnx".
model = torchvision.models.vgg16(pretrained=True)
# Put the model in inference mode explicitly (Dropout disabled) instead of
# relying on the exporter's own training-mode handling, which differs between
# the TorchScript-based and dynamo-based exporters.
model.eval()
example_input = torch.randn(1, 3, 224, 224)
# NOTE: the return value depends on the torch version (None for the legacy
# exporter); the artifact that matters is the file written to disk.
onnx_model = torch.onnx.export(model, example_input, "vgg16.onnx")

In [None]:
import onnx

# Reload the exported file and validate it with the ONNX checker
# (check_model raises if the model is invalid).
model = onnx.load("vgg16.onnx")
onnx.checker.check_model(model)
# printable_graph returns a string; print it so the graph is shown readably
# (a bare trailing expression is discarded outside a notebook and shown as an
# escaped repr inside one).
print(onnx.helper.printable_graph(model.graph))

# Flask

In [None]:
import io
import json

import torch
import torchvision.transforms as transforms
from flask import Flask, jsonify, request
from PIL import Image
from torchvision import models

In [None]:
import json

# Class-index lookup table: keyed by the stringified class index; each value is
# unpacked by predict() into (class_id, class_name).
# `with` closes the file handle instead of leaking it.
with open("./imagenet_class_index.json", encoding="utf-8") as class_index_file:
    imagenet_class_index = json.load(class_index_file)

model = models.vgg16(pretrained=True)
# Serving inference only: eval() disables Dropout so that repeated requests
# for the same image give the same prediction.
model.eval()

# Preprocessing applied to every uploaded image before it reaches the model.
# The Normalize constants are the ImageNet channel mean/std used by
# torchvision's pretrained weights.
image_transforms = transforms.Compose(
    [
        transforms.Resize(255),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)


def get_prediction(image_bytes):
    """Classify an encoded image with the module-level VGG16 ``model``.

    Args:
        image_bytes: Raw bytes of an image file in any format PIL can decode.

    Returns:
        The ``imagenet_class_index`` entry stored under the stringified index
        of the highest-scoring class.

    Raises:
        PIL.UnidentifiedImageError: if ``image_bytes`` is not a decodable image.
    """
    # convert("RGB") normalises grayscale / RGBA / palette images to the three
    # channels that Normalize and the model expect.
    image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    # The transforms yield a single (C, H, W) image; add a leading batch
    # dimension because the model's flatten(x, 1) assumes a batched input.
    batch = image_transforms(image).unsqueeze(0)
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        outputs = model(batch)
    _, y = outputs.max(1)
    predicted_idx = str(y.item())
    return imagenet_class_index[predicted_idx]

In [None]:
app = Flask(__name__)


# Flask/Werkzeug rule strings must start with a leading slash.
@app.route("/predict", methods=["POST"])
def predict():
    """Classify the image uploaded in the ``file`` field of a multipart POST.

    Returns:
        JSON ``{"class_id": ..., "class_name": ...}`` on success, or a 400
        response with an ``error`` message when no ``file`` field was sent.
    """
    # The route only accepts POST (Flask answers 405 otherwise), so no
    # request.method check is needed; the old check could leave `file` unbound.
    file = request.files.get("file")
    if file is None:
        return jsonify({"error": "missing 'file' field in multipart form data"}), 400
    img_bytes = file.read()
    class_id, class_name = get_prediction(img_bytes)
    return jsonify({"class_id": class_id, "class_name": class_name})


if __name__ == "__main__":
    app.run()