# Chapter 2 - Model Inference on Scale Images
Here we show how to load a pretrained model and run inference on a folder of images.

## Dataset and Dataloader
First, we define a dataset that finds and loads the images. This dataset returns only the images and their file names (no labels or other information), since it is used for inference only.

In [1]:
import os
from os import listdir
from os.path import isfile, join
import numpy as np
from PIL import Image
from torchvision import transforms
from torchvision.io import read_image
from torch.utils.data.dataset import Dataset  # For custom datasets
from torchvision import datasets
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader

class FishTestDataset(Dataset):
    """Unlabelled image dataset for inference over a flat folder of images.

    Every regular file directly inside ``image_dir`` is treated as an image
    (sub-directories are skipped). Each item is an ``(image, file_name)``
    pair; no labels are returned.

    Args:
        image_dir: Directory containing the images.
        transform: Optional callable applied to each loaded PIL image
            (e.g. a ``torchvision.transforms.Compose``). When ``None`` the
            RGB PIL image itself is returned.
    """

    def __init__(self, image_dir, transform=None):
        # Directory the images are read from.
        self.image_dir = image_dir

        # Preprocessing applied in __getitem__ (None = no transform).
        self.transforms = transform

        # File names (not full paths) of every regular file in image_dir.
        # os.listdir returns entries in arbitrary order, so sort them to make
        # indexing -- and therefore the order of any written results --
        # reproducible across runs and machines.
        self.image_name = sorted(
            f for f in listdir(image_dir) if isfile(join(image_dir, f))
        )

    def __len__(self):
        """Return the number of image files found in ``image_dir``."""
        return len(self.image_name)

    def __getitem__(self, index):
        """Load image ``index`` and return ``(image, file_name)``.

        The image is converted to 3-channel RGB so that every item has the
        same channel count regardless of the file's mode (grayscale,
        palette, RGBA, ...).
        """
        name = self.image_name[index]
        img_path = join(self.image_dir, name)

        # Image.open is lazy; convert() forces the pixel data to be read so
        # the context manager can release the file handle immediately.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        if self.transforms is not None:
            image = self.transforms(image)

        return image, name
        
# Folder (relative to the working directory) holding the images to classify.
data_dir = 'cropped'

# Preprocessing pipeline: resize so the shorter side is 224 px, take the
# central 224x224 crop, convert to a float tensor, then normalize each channel
# with the standard ImageNet mean/std.
# NOTE(review): this should mirror the evaluation transforms used when the
# checkpoint was trained -- confirm against the training notebook.
data_transforms = transforms.Compose([
    transforms.Resize(size=224),
    transforms.CenterCrop(size=224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])

# Iterate in dataset order (shuffle=False) and keep the final partial batch
# (drop_last=False) so that every image receives a prediction.
test_dataset = FishTestDataset(image_dir=data_dir, transform=data_transforms)
test_loader = DataLoader(
    test_dataset,
    batch_size=24,
    shuffle=False,
    drop_last=False,
)


## Create and load model
Here we use a simple ResNet-18 classification architecture from PyTorch's torchvision library for our model. We have to provide the desired number of classes, since the default is 1000 classes (for ImageNet).

In [4]:
from torchvision.models import resnet18, ResNet18_Weights
from tqdm import tqdm
import torch
# Run on the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# ResNet-18 with a 5-way classification head instead of the default 1000
# ImageNet classes. No pretrained weights are requested here because the
# trained weights are loaded from the checkpoint below.
model = resnet18(num_classes=5)

# Load the trained weights. map_location remaps tensors that were saved on a
# GPU onto `device`, so a checkpoint written on a CUDA machine also loads on a
# CPU-only one.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source.
model.load_state_dict(torch.load("best_model.pth", map_location=device))

# Evaluation mode: BatchNorm layers use their running statistics instead of
# per-batch statistics.
model.eval()
model.to(device)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

## Load images and Inference

Here we use our dataloader to load the images in batches (of size 24) and feed them into the model for inference. We also write the results to a CSV file.

In [5]:
import torch

import csv

output_path = "inference_results.csv"

# The `with` block guarantees the CSV is closed even if inference fails
# part-way. newline="" is what the csv module expects for files it writes;
# csv.writer also quotes file names that contain commas or quotes.
with open(output_path, "w", newline="") as csv_file:
    writer = csv.writer(csv_file, lineterminator="\n")
    writer.writerow(["Image Name", "Predicted Age"])

    # Inference only: skip building the autograd graph to save memory/time.
    with torch.no_grad():
        for images, img_names in tqdm(test_loader):
            images = images.to(device)
            outputs = model(images)  # one row of class scores per image
            # Index of the highest score per image. Do NOT squeeze here: a
            # final batch holding a single image would lose its batch
            # dimension and make dim=1 invalid.
            preds = outputs.argmax(dim=1).cpu().tolist()
            for name, pred in zip(img_names, preds):
                # Class 4 is reported as "4+" (presumably an open-ended
                # "4 or older" bucket -- confirm against the label scheme).
                age = "4+" if pred == 4 else str(pred)
                writer.writerow([name, age])

100%|████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:00<00:00,  8.47it/s]


The resulting CSV file should look something like this:

![image.png](attachment:6a5d61e3-2d68-4cd4-a960-d300f9c886be.png)