In [None]:
import os
import sys
from time import time
from time import perf_counter

# Make the project's Research package (networks, utils) importable; kept ahead of
# the third-party imports so module resolution order is unchanged.
sys.path.insert(0, './../../Research')

import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from tqdm.notebook import tqdm

from networks import FashionCNN
from utils import get_classes

## Inference Dataset

In [None]:
class Inference(Dataset):
    """PyTorch Dataset over a folder of holdout images for inference.

    Each file name is expected to start with the integer class label followed
    by an underscore (e.g. ``3_xxx.png``); the label is parsed from that prefix.
    """

    def __init__(self, holdout_folder):
        """Initialise the dataset.

        Args:
            holdout_folder: Folder with images for inferencing. Only regular
                files directly inside it are used; sub-directories (such as
                Jupyter's ``.ipynb_checkpoints``) are skipped.
        """
        self.images = self._list_images(holdout_folder)
        self.transforms = transforms.ToTensor()
        self.classes = get_classes()

    @staticmethod
    def _list_images(holdout_folder):
        """Return the sorted paths of the regular files in ``holdout_folder``.

        ``os.listdir`` returns entries in arbitrary order; sorting makes the
        dataset order (and any per-index result) reproducible across runs.
        """
        candidates = (os.path.join(holdout_folder, name) for name in os.listdir(holdout_folder))
        return sorted(path for path in candidates if os.path.isfile(path))

    def __len__(self):
        """Number of images."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load one image and its label.

        Args:
            idx: Index of the image in ``self.images``.

        Returns:
            Tuple ``(img, img_label)``: the image converted by
            ``self.transforms`` (``transforms.ToTensor()``) and its integer
            label, parsed from the file-name prefix, as a 0-d tensor.
        """
        img_name = self.images[idx]
        img_label = torch.tensor(int(os.path.basename(img_name).split('_')[0]))
        # The context manager closes the file handle right away; otherwise every
        # item would keep a file open until garbage collection.
        with Image.open(img_name) as img_PIL:
            img = self.transforms(img_PIL)
        return img, img_label

In [None]:
# Holdout dataset and loader. batch_size=1 means each iteration of the
# inference loop handles exactly one image; shuffle=False keeps dataset order.
infer_dataset = Inference(holdout_folder="./../data/fashionmnist/images/")
infer_loader = DataLoader(
    infer_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

# class labels, indexed by class id
classes = get_classes()

## Model loading

In [None]:
# Pytorch Model
model = FashionCNN()
# The inference loop never moves data to a GPU, so map the checkpoint onto the
# CPU; without map_location a checkpoint saved from CUDA fails to load on a
# CPU-only machine.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints
# (consider weights_only=True on torch >= 1.13).
model.load_state_dict(torch.load('./../../Models/fashionNet.pth', map_location='cpu'));
model.eval();

## Inference

In [None]:
# Running tallies used for the accuracy figure printed after inference.
correct_predictions, total_predictions = 0, 0

In [None]:
# Reset every tally here so re-running this cell on its own does not add to
# counts left over from a previous run (hidden kernel state).
correct_predictions = 0
total_predictions = 0
time_per_image = []

# perf_counter is a monotonic, high-resolution clock intended for measuring
# durations; time() may be coarse and can jump when the system clock changes.
time_start = perf_counter()

# Inferencing images batch by batch (one image per batch when batch_size=1)
with torch.no_grad():
    for data, labels in tqdm(infer_loader):
        time_batch_st = perf_counter()
        outs = model(data)
        preds = torch.argmax(outs, dim=1)
        # Compare predicted and true class ids directly (vectorised) instead of
        # looking up and comparing class names one sample at a time.
        batch_size = labels.size(0)
        correct_predictions += int((preds == labels).sum())
        total_predictions += batch_size
        time_batch_en = perf_counter()
        # Record a per-image time so np.mean(time_per_image) stays a per-image
        # latency even if the loader's batch_size is raised above 1.
        time_per_image.extend([(time_batch_en - time_batch_st) / batch_size] * batch_size)
time_end = perf_counter()

## Printing stats

In [None]:
# An empty holdout folder leaves total_predictions at 0: the accuracy division
# would raise ZeroDivisionError and np.mean([]) would return nan with a warning.
if total_predictions == 0:
    print("No images were inferred; accuracy and latency are undefined.")
    print("Total time (sec) = ", time_end - time_start)
else:
    print("Accuracy         = ", correct_predictions*100/total_predictions)
    print("Total time (sec) = ", time_end - time_start)
    print("Latency (sec)    = ", np.mean(time_per_image))

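## Stats from my run

Results recorded from earlier runs of this notebook, one cell per model. Timings depend on the hardware they were measured on.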

In [None]:
# FashionCNN

# Accuracy         =  92.33844103930713
# Total time (sec) =  10.233047008514404
# Latency          =  0.001875086834556178

In [None]:
# torNet

# Accuracy         =  91.87208527648235
# Total time (sec) =  117.98538613319397
# Latency          =  0.03715702940987238

In [None]:
# efficientNet

# Accuracy         =  89.44037308461026
# Total time (sec) =  166.0295968055725
# Latency          =  0.05312728746822085