# Model Inference with Alluxio

This notebook shows an example of model inference with Alluxio by classifying some images.

Before running this notebook, run `AI-training-demo.ipynb` to train a model and save it to Alluxio.

```python
import torch
import torchvision.transforms as transforms
import torchvision.models as models
import matplotlib.pyplot as plt

from PIL import Image
```

Confirm that the model has been written to the Alluxio FUSE mount.

```python
!ls /mnt/alluxio/fuse/models/demo
```
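The `!ls` shell escape only works inside IPython. A plain Python script, such as a batch inference job, can run the same check with the standard library, because the FUSE mount behaves like an ordinary directory. The following is a minimal sketch. The helper name `find_checkpoints` is hypothetical, and its default directory is the demo path used in this notebook.

```python
from pathlib import Path


def find_checkpoints(model_dir="/mnt/alluxio/fuse/models/demo", suffix=".pth"):
    """Return sorted (file_name, size_in_bytes) pairs for checkpoint files in model_dir.

    Hypothetical helper: it relies only on the FUSE mount looking like a
    local directory, so it works the same on any local path.
    """
    directory = Path(model_dir)
    if not directory.is_dir():
        raise FileNotFoundError(f"{directory} does not exist; is Alluxio FUSE mounted?")
    # Keep regular files with the checkpoint suffix; skip subdirectories.
    return sorted(
        (p.name, p.stat().st_size)
        for p in directory.iterdir()
        if p.is_file() and p.suffix == suffix
    )
```

Calling `find_checkpoints()` before loading turns a missing mount or an unfinished training run into a clear error, instead of a failure inside `torch.load`.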

## Model Loading

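Here, we load the PyTorch model from Alluxio.

Alluxio is mounted on the node through FUSE, so the model file appears as an ordinary local file. It can be loaded with `torch.load` in the same way as a checkpoint on local disk, without any Alluxio-specific client code.

If the cell outputs `<All keys matched successfully>`, every key in the saved state dict matched a parameter of the ResNet-50 architecture, and the model has been loaded successfully.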

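```python
# Standard ImageNet preprocessing: resize, center-crop to 224x224,
# convert to a [0, 1] tensor, then normalize with the ImageNet channel mean/std.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Build an untrained ResNet-50; its weights come from the checkpoint below.
# (weights=None replaces the deprecated pretrained=False in torchvision >= 0.13.)
model = models.resnet50(weights=None)

# Read the checkpoint through the FUSE mount like any local file.
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only node.
model_path = "/mnt/alluxio/fuse/models/demo/ai-demo.pth"
model.load_state_dict(torch.load(model_path, map_location="cpu"))
```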
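The following plain-Python sketch reproduces what the preprocessing pipeline does to one image:

- `Resize(256)` scales the shorter side to 256 pixels and keeps the aspect ratio.
- `CenterCrop(224)` cuts out the central 224×224 region.
- `ToTensor` maps 8-bit pixel values (0 to 255) into [0, 1].
- `Normalize` subtracts the per-channel mean and divides by the per-channel standard deviation.

The integer rounding follows our reading of torchvision's behavior. This is for illustration only: the function names are hypothetical, and the notebook should keep using `transforms`.

```python
# ImageNet channel statistics, as passed to transforms.Normalize above.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


def resize_shorter_side(width, height, size=256):
    """Return the (width, height) after scaling the shorter side to `size`.

    The longer side is truncated to an int, keeping the aspect ratio.
    """
    if width <= height:
        return size, int(size * height / width)
    return int(size * width / height), size


def center_crop_box(width, height, crop=224):
    """Return (left, top, right, bottom) of a centered crop x crop window."""
    left = int(round((width - crop) / 2.0))
    top = int(round((height - crop) / 2.0))
    return left, top, left + crop, top + crop


def normalize_pixel(rgb):
    """Map one 8-bit RGB pixel to the normalized values the model receives."""
    return tuple(
        (value / 255.0 - mean) / std
        for value, mean, std in zip(rgb, IMAGENET_MEAN, IMAGENET_STD)
    )
```

Take a 500×375 photo as an example. The shorter side (375) becomes 256, and the longer side becomes int(256 × 500 / 375) = 341. The center crop box is therefore (58, 16, 282, 240).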

## Model Inference

We preprocess a few validation images from the imagenet-mini dataset, also read through the Alluxio FUSE mount, and classify them with the model we just loaded.

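```python
image_paths = ['/mnt/alluxio/fuse/imagenet-mini/val/n01818515/ILSVRC2012_val_00007081.JPEG',
               '/mnt/alluxio/fuse/imagenet-mini/val/n02088238/ILSVRC2012_val_00024881.JPEG',
               '/mnt/alluxio/fuse/imagenet-mini/val/n02123045/ILSVRC2012_val_00016389.JPEG',
               '/mnt/alluxio/fuse/imagenet-mini/val/n01855032/ILSVRC2012_val_00011488.JPEG']

# Read each image through the FUSE mount and apply the same preprocessing.
images = []
for image_path in image_paths:
    image = Image.open(image_path).convert('RGB')
    image = transform(image)
    images.append(image)

# Switch BatchNorm and Dropout layers to inference behavior before predicting.
model.eval()
with torch.no_grad():
    # Stack the images into one batch of shape (N, 3, 224, 224).
    inputs = torch.stack(images)
    outputs = model(inputs)

# The index of the largest logit in each row is the predicted class.
_, predicted_labels = torch.max(outputs, 1)
```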
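`torch.max(outputs, 1)` returns the largest value and its index for each row of the `(N, 1000)` logit matrix; only the index is kept here. Raw logits are not probabilities. To get a confidence score as well, the usual approach is to apply a softmax over the class dimension and then take the top k. In PyTorch, that is `torch.softmax(outputs, dim=1)` followed by `torch.topk(probs, k, dim=1)`.

The plain-Python sketch below spells out that computation for one row of logits. The function name `top_k_softmax` is hypothetical.

```python
import math


def top_k_softmax(logits, k=5):
    """Return the k most likely (class_index, probability) pairs for one row of logits."""
    # Subtract the largest logit before exponentiating to avoid overflow;
    # this shift leaves the softmax result unchanged.
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Rank class indices by probability, highest first.
    ranked = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in ranked[:k]]
```

Softmax preserves the ordering of the logits. The first index returned is therefore the same class that `torch.max(outputs, 1)` selects, apart from exact ties.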

## Plotting Results

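Here, we load the human-readable labels and plot the sample images with their predicted labels. The code expects `imagenet_classes.txt` in the notebook's working directory, with one class name per line in class-index order.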

```python
with open('imagenet_classes.txt') as f:
    class_labels = [line.strip() for line in f.readlines()]
```
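The label lookup only works if the file has exactly one line per class, in the same order as the model's output indices. An extra line, such as a header, silently shifts every title.

The following is a small defensive loader. It assumes a plain-text file with one label per line and the 1000 ImageNet classes this ResNet-50 outputs. The name `load_labels` and its `num_classes` check are our own additions.

```python
def load_labels(path="imagenet_classes.txt", num_classes=1000):
    """Read one class label per line and check the count matches the model's outputs."""
    with open(path, encoding="utf-8") as f:
        # Strip the trailing newline and surrounding whitespace from each line.
        labels = [line.strip() for line in f]
    if len(labels) != num_classes:
        raise ValueError(f"{path} has {len(labels)} labels, expected {num_classes}")
    return labels
```

Failing early on a count mismatch is cheaper than noticing wrong titles in the plot.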

```python
# One column per image, titled with its predicted class name.
fig, axs = plt.subplots(1, len(image_paths), figsize=(12, 4))
for i, image_path in enumerate(image_paths):
    image = Image.open(image_path).convert('RGB')
    label = class_labels[predicted_labels[i].item()]
    axs[i].imshow(image)
    axs[i].set_title(label)
    axs[i].axis('off')
plt.show()
```