# Class for predicting a label, given a model and an image

In [145]:
from torchvision import transforms
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
import torch


class Predict:
    '''
    Runs a classification model over one image, a directory of images or an
    image URL and reports the predicted class index and class name.

    Images are opened through `Image` (expected to be PIL's `Image` module,
    imported elsewhere in the notebook) and pre-processed with `self.transformer`.
    '''

    # Size handed to transforms.Resize for every input image
    image_size = (100, 100)

    def __init__(self, model, normalize=True):
        '''
        model:torchvision.models - pytorch model with loaded weights
        normalize:bool - if True, normalize with mean=[0.485, 0.456, 0.406] and std=[0.229, 0.224, 0.225] for respective RGB channels
        '''
        self.model = model
        steps = [transforms.Resize(self.image_size), transforms.ToTensor()]
        if normalize:
            steps.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
        self.transformer = transforms.Compose(steps)
        self.normalize = normalize

    def predict(self, source, source_type="infer", file_type="infer", batch_size=32, outFile="predictions.csv",
                labels_path="labels_mapping.csv"):
        '''
        Predict the class of every image found at `source`.

        source:str - image file, directory of images, or http(s) URL of an image
        source_type:str - "infer" (default), "local" or "url"; anything else raises NotImplementedError
        file_type:str - currently unused, kept for backward compatibility
        batch_size:int - number of images per forward pass when `source` is a directory
        outFile:str - csv file the predictions are written to
        labels_path:str - headerless csv whose first column is the class index and second column the class name

        Results are written to `outFile` when there are more than 10 predictions;
        otherwise the user is asked whether to write them, and they are printed if not.
        '''
        if source_type == "infer":
            # Only a real http(s) prefix counts as a URL; a local path that merely contains "http" stays local
            is_url = str(source).lower().startswith(("http://", "https://"))
            source_type = "url" if is_url else "local"
        if source_type == "url":
            file_path = self.cache_file(source)
        elif source_type == "local":
            file_path = source
        else:
            raise NotImplementedError(f"Unsupported source_type: {source_type!r}")

        dataloader = self.build_loader(file_path, batch_size)
        self.model.eval()
        batch_preds = []
        with torch.no_grad():
            for inputs in dataloader:
                outputs = self.model(inputs)
                _, preds = torch.max(outputs, 1)
                batch_preds.append(preds.cpu())
        # torch.cat rejects an empty list, which happens for an empty directory
        class_idx = torch.cat(batch_preds).tolist() if batch_preds else []

        df = pd.DataFrame({"class_idx": pd.Series(class_idx, dtype="int64")})
        labels_mapping = pd.read_csv(labels_path, header=None, index_col=0)
        # First non-index column holds the class name; an unknown index raises KeyError
        class_names = labels_mapping.iloc[:, 0]
        df["class_name"] = class_names.loc[class_idx].to_numpy()

        if len(df) <= 10:
            if input("Write predictions to file?(y/n)").lower() == "y":
                df.to_csv(outFile, header=True, index=False)
            else:
                print(df)
        else:
            df.to_csv(outFile, header=True, index=False)
            print(f"Written to '{outFile}'")

    def build_loader(self, path, batch_size=32):
        '''
        Return an iterable of input batches for `path`.

        path:str - a .jpg/.jpeg/.png file, or a directory whose files are all images
        batch_size:int - images per batch when `path` is a directory

        Raises NotImplementedError for .csv sources, which are not supported yet.
        '''
        import os
        from glob import glob

        path = os.fspath(path)
        lowered = path.lower()
        if lowered.endswith(".csv"):
            raise NotImplementedError("CSV sources are not supported yet")
        if lowered.endswith((".jpg", ".jpeg", ".png")):
            # model expects 4 dimensions, hence unsqueeze to add dummy batch dimension
            return [self._load_image(path).unsqueeze(0)]
        # Assuming source to be a directory with all the images; sorted so that
        # row i of the predictions always refers to the same file
        files = sorted(f for f in glob(os.path.join(path, "*")) if os.path.isfile(f))
        return self.datagen(files, batch_size)

    def datagen(self, files, batch_size):
        '''
        Yield stacked image tensors, `batch_size` files at a time.

        The last batch is shorter when len(files) is not a multiple of
        batch_size (slicing past the end of a list is safe).
        '''
        for start in range(0, len(files), batch_size):
            batch_files = files[start:start + batch_size]
            yield torch.stack([self._load_image(f) for f in batch_files])

    def _load_image(self, path):
        '''Open one image, force 3 RGB channels and apply `self.transformer`.'''
        with Image.open(path) as img:
            return self.transformer(img.convert('RGB'))

    @staticmethod
    def transform_single(img, input_size=100):
        '''
        Resize and normalize a single image without torchvision.

        img - image object offering `convert` and `resize` (e.g. a PIL image)
        input_size:int - side length of the square the image is resized to

        Returns a float32 tensor of shape (1, 3, input_size, input_size).
        '''
        # Shaped (3, 1, 1) so they broadcast over height and width
        mean = np.array([0.485, 0.456, 0.406]).reshape(3, 1, 1)
        std = np.array([0.229, 0.224, 0.225]).reshape(3, 1, 1)
        resized = img.convert('RGB').resize((input_size, input_size))
        # Scale to [0, 1] first: the mean/std above are expressed on that scale
        pixels = np.asarray(resized, dtype=np.float64) / 255.0
        # transpose since model expects (channel, height, width)
        normalized = (pixels.transpose(2, 0, 1) - mean) / std
        return torch.from_numpy(normalized).unsqueeze(0).type(torch.float32)

    @staticmethod
    def cache_file(url, cachedir="cache"):
        '''
        Download `url` into `cachedir` and return the local file path.

        url:str - http(s) address of the file
        cachedir:str - directory the download is stored in (created if missing)

        Raises requests.HTTPError when the server answers with an error status.
        '''
        import os
        from urllib.parse import urlparse

        import requests

        res = requests.get(url, timeout=30)
        res.raise_for_status()
        # Name the file after the URL path only, so "?query" parts do not end up in the extension
        file_name = os.path.basename(urlparse(url).path) or "download"
        os.makedirs(cachedir, exist_ok=True)
        local_file = f"{cachedir}/{file_name}"
        with open(local_file, 'wb') as fw:
            fw.write(res.content)
        return local_file

# Load model

In [3]:
# Number of output classes the classifier head is rebuilt for
num_classes = 131

model_path = "weights/mobilenet"
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only machine;
# the rest of this notebook feeds CPU tensors to the model.
# NOTE: torch.load unpickles the file - only load weights from a trusted source.
weights = torch.load(model_path, map_location="cpu")
# `models` is expected to be torchvision.models, imported elsewhere in the notebook
model = models.mobilenet_v2(pretrained=False)
# Swap the final linear layer so the output size matches num_classes
num_ftrs = model.classifier[1].in_features
model.classifier[1] = torch.nn.Linear(num_ftrs, num_classes)
model.load_state_dict(weights)

<All keys matched successfully>

# Examples of using Predict

## Single image from local

In [130]:
# Classify one image stored on disk; the source type is inferred from the path
Predict(model).predict(source='images/Test/Pineapple/3_100.jpg')

torch.Size([1, 3, 100, 100])
Write predictions to file?(y/n)n
   class_idx class_name
0         99  Pineapple


## Single image from URL

In [146]:
# Classify an image fetched from the web; predict() downloads it before inference
url = "https://5.imimg.com/data5/PW/ND/MY-46595757/fresh-pineapple-281kg-29-500x500.png"
Predict(model).predict(source=url)

Write predictions to file?(y/n)n
   class_idx        class_name
0        117  Strawberry Wedge


In [119]:
# Live webcam capture: print the predicted class name for every captured frame
# until ESC is pressed, the camera stops delivering frames, or the kernel is interrupted
import cv2

cap = cv2.VideoCapture(0)
cap.set(cv2.CAP_PROP_FRAME_WIDTH, 640)
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 480)
labels_mapping = pd.read_csv("labels_mapping.csv", header=None, index_col=0)
# Inference mode, independent of whether an earlier cell already switched the model
model.eval()
try:
    while True:
        ret, frame = cap.read()
        if not ret:
            # No frame was delivered (camera unavailable or stream ended)
            break
        # `transform` is defined outside this cell.
        # NOTE(review): OpenCV frames are BGR arrays - confirm `transform` yields the RGB batch the model expects
        frame = transform(frame)
        with torch.no_grad():
            output = model(frame)
            _, pred = torch.max(output, 1)

        print(labels_mapping.loc[int(pred), 1])
        # NOTE(review): waitKey only sees key presses while a HighGUI window has focus; none is opened here
        k = cv2.waitKey(10) & 0xff  # Press 'ESC' for exiting video
        if k == 27:
            break
except KeyboardInterrupt:
    # Interrupting the kernel is the expected way to stop the loop
    pass
finally:
    # Always free the camera, whether the loop ended normally or with an error
    cap.release()
    cv2.destroyAllWindows()

Pear 2
Pear 2
Pear 2
Pear 2
Pear 2
Pear 2
Pear 2
Pear 2
Pear 2
Peach Flat
Pear 2
Peach Flat
Pear 2
Pear 2
Pear 2
Pear 2
Pear 2
Pear 2
Pear 2
Pear 2
Pear 2
Pear 2
Pear 2
Pear 2
Pear 2
Pear 2
Pear 2
Pear 2
Pear 2
