# Image Classification on AMD Ryzen™ AI Using ResNet-50

In [None]:
import requests
from PIL import Image

In [None]:
# Download one sample image (train split, row 0 of the "beans" dataset) to classify at the end.
# The timeout keeps the cell from hanging forever on a stalled connection, and
# raise_for_status() surfaces HTTP errors instead of passing an error page to PIL.
url = "https://datasets-server.huggingface.co/assets/beans/--/default/train/0/image/image.jpg"
response = requests.get(url, stream=True, timeout=30)
response.raise_for_status()
image = Image.open(response.raw)
image

## Export the model to ONNX

In [None]:
from pathlib import Path
from optimum.exporters.onnx import main_export
from optimum.amd.ryzenai import RyzenAIModelForImageClassification

In [None]:
# Hugging Face Hub checkpoint to export (judging by its name, a ResNet-50 tuned on "beans").
model_id = "eugenecamus/resnet-50-base-beans-demo"
# Task string handed to the ONNX exporter.
task = "image-classification"

# Local directory that receives the exported ONNX model.
onnx_dir = "demo_resnet_onnx"

In [None]:
# Export the Hub checkpoint to ONNX; the model file is then read back as `onnx_dir/model.onnx`.
main_export(
    model_id,
    onnx_dir,
    task=task,
)

# Number of output classes, read from the checkpoint config instead of being hard-coded.
# This keeps the logits shape correct if `model_id` points at a checkpoint with a different
# label set.
from transformers import AutoConfig

num_labels = AutoConfig.from_pretrained(model_id).num_labels

# Pin the model's input/output shapes to fixed sizes:
# - input: batch of 1, 3 channels, 224x224 pixels
# - output: one row of `num_labels` logits
# The reshaped model's path is kept in `static_onnx_path` for the quantizer.
static_onnx_path = RyzenAIModelForImageClassification.reshape(
    Path(onnx_dir) / "model.onnx",
    input_shape_dict={"pixel_values": [1, 3, 224, 224]},
    output_shape_dict={"logits": [1, num_labels]},
)


## Quantize the model

In [None]:
from optimum.amd.ryzenai import RyzenAIOnnxQuantizer

# Point the Ryzen AI quantizer at the static-shape model written by the reshape step.
static_model_file = static_onnx_path.name
quantizer = RyzenAIOnnxQuantizer.from_pretrained(
    onnx_dir,
    file_name=static_model_file,
)

In [None]:
from functools import partial
from transformers import AutoFeatureExtractor
from optimum.amd.ryzenai.configuration import QuantizationConfig

# Feature extractor of the checkpoint.
# It is used both to preprocess calibration data and by the inference pipeline.
feature_extractor = AutoFeatureExtractor.from_pretrained(model_id)

# Quantization settings, built with no arguments (the library's defaults).
qconfig = QuantizationConfig()

def preprocess_fn(ex, feature_extractor):
    """Apply ``feature_extractor`` to the ``"image"`` entry of a dataset example.

    Args:
        ex: Mapping that contains an ``"image"`` key.
        feature_extractor: Callable applied to ``ex["image"]``; in this notebook it is
            bound via ``functools.partial`` to the checkpoint's feature extractor.

    Returns:
        Whatever ``feature_extractor`` returns for that entry.
    """
    images = ex["image"]
    features = feature_extractor(images)
    return features

# Calibration data: num_samples=128 examples from the "beans" dataset.
# Each example is preprocessed with the checkpoint's feature extractor.
# NOTE(review): this draws from the "test" split. Any later accuracy measured on that split
# is then on data the quantizer has seen; consider "train" or "validation" instead.
calibration_preprocess = partial(preprocess_fn, feature_extractor=feature_extractor)
calibration_dataset = quantizer.get_calibration_dataset(
    "beans",
    dataset_split="test",
    num_samples=128,
    preprocess_function=calibration_preprocess,
)

In [None]:
output_dir = "demo_resnet_onnx_quantized"  # destination of the quantized model

# Quantize with the settings in `qconfig`, using `calibration_dataset` for calibration,
# and save the result under `output_dir`.
quantizer.quantize(
    quantization_config=qconfig,
    dataset=calibration_dataset,
    save_dir=output_dir,
)

## Run inference using the quantized model

In [None]:
# Runtime configuration file for the Ryzen AI model.
# NOTE(review): this is a relative path, presumably resolved against the current working
# directory — confirm it ships alongside the notebook.
vaip_config = "vaip_config.json"

# Load the quantized model saved by the previous cell.
model = RyzenAIModelForImageClassification.from_pretrained(
    output_dir,
    vaip_config=vaip_config,
)

In [None]:
from transformers import pipeline

# Wrap the quantized model and the checkpoint's feature extractor in an
# image-classification pipeline.
cls_pipe = pipeline(
    task="image-classification",
    model=model,
    feature_extractor=feature_extractor,
)
# Classify the sample image downloaded at the top of the notebook.
cls_pipe(image)