In [1]:
"""
Author       : Aditya Jain
Date Started : May 21, 2022
About        : This file does DL-based localization and classification on raw images and saves annotation information
"""

import torch
import torchvision.models as torchmodels
import torchvision
import os
import numpy as np
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision import transforms
from PIL import Image
import matplotlib.pyplot as plt
import cv2
import json
import timm

#### User-defined variables

In [2]:
# Root directory of the trap data; the folder named by image_folder is looked up under it.
# NOTE(review): absolute, user-specific path — adjust before running on another machine.
data_dir     = '/home/mila/a/aditya.jain/scratch/TrapData_QuebecVermont_2022/Vermont/'
# Sub-folder of data_dir to process (presumably a capture date — confirm);
# its name is also embedded in the annotation file name.
image_folder = '2022_05_13'

In [3]:
# Derived locations. os.path.join is used instead of string concatenation so the
# results are correct whether or not data_dir was written with a trailing '/'.
# The empty last component keeps a trailing separator, because other cells
# build file paths by appending a file name to these strings.
data_path  = os.path.join(data_dir, image_folder, '')   # folder holding the raw images
save_path  = os.path.join(data_dir, '')                 # annotation JSON is written directly under data_dir
annot_file = f'localize_classify_annotation-{image_folder}.json'

In [4]:
# Pick the device used for inference: the GPU when CUDA is usable, otherwise the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


#### Loading Localization Model

In [6]:
# Build torchvision's Faster R-CNN (ResNet-50 FPN) with its COCO pre-trained weights
model_localize = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
num_classes    = 2  # 1 foreground class + background (foreground is presumably the insect detections, not COCO "person" — TODO confirm)
in_features    = model_localize.roi_heads.box_predictor.cls_score.in_features
# Replace the box-predictor head with one sized for num_classes outputs
model_localize.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)


# Overwrite the weights with those stored under the checkpoint's 'model_state_dict' key;
# map_location remaps the stored tensors onto `device` while loading.
model_path  = '/home/mila/a/aditya.jain/logs/v1_localizmodel_2021-08-17-12-06.pt'
checkpoint  = torch.load(model_path, map_location=device)
model_localize.load_state_dict(checkpoint['model_state_dict'])

<All keys matched successfully>

#### Model Class Definition

In [7]:
class ModelInference:
    """Single-image classifier built on a timm ``tf_efficientnetv2_b3`` network.

    Args:
        model_path: Path to a state dict loadable with ``torch.load`` whose keys
            match the timm model created in ``_load_model``.
        category_map_json: Path to a JSON file mapping category name -> class id.
        device: Device (string or ``torch.device``) the model and inputs are moved to.
        input_size: Side length, in pixels, that images are resized to before
            being fed to the network.
    """

    def __init__(self, model_path, category_map_json, device, input_size=300):
        self.device = device
        self.input_size = input_size
        self.id2categ = self._load_category_map(category_map_json)
        self.transforms = self._get_transforms()
        self.model = self._load_model(model_path, num_classes=len(self.id2categ))
        # Inference only: switch the network to evaluation mode.
        self.model.eval()

    def _load_category_map(self, category_map_json):
        """Read the name -> id JSON file and return the inverted id -> name dict."""
        with open(category_map_json, 'r') as f:
            categories_map = json.load(f)

        return {class_id: categ for categ, class_id in categories_map.items()}

    def _get_transforms(self):
        """Return the preprocessing pipeline: square resize, tensor conversion, normalization."""
        mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5]

        return transforms.Compose([
            transforms.Resize((self.input_size, self.input_size)),
            transforms.ToTensor(),
            transforms.Normalize(mean, std),
        ])

    def _load_model(self, model_path, num_classes):
        """Create the timm model with ``num_classes`` outputs and load the saved weights."""
        model = timm.create_model('tf_efficientnetv2_b3',
                                  pretrained=False,
                                  num_classes=num_classes)
        model = model.to(self.device)
        model.load_state_dict(torch.load(model_path,
                                         map_location=torch.device(self.device)))

        return model

    def predict(self, image, confidence=False):
        """Classify one image.

        Args:
            image: Image accepted by the transform pipeline (a PIL image in this
                notebook). Images whose ``mode`` is not ``'RGB'`` are converted
                to RGB first.
            confidence: When true, also return the softmax probability of the
                predicted category.

        Returns:
            The predicted category name, or ``(category name, probability)``
            when ``confidence`` is true.
        """
        # Normalize() is configured with 3-channel statistics, so grayscale,
        # RGBA or palette images would fail there; convert them up front.
        # Objects without a `mode` attribute are passed through untouched.
        if getattr(image, 'mode', 'RGB') != 'RGB':
            image = image.convert('RGB')

        with torch.no_grad():
            # Add the batch dimension out-of-place, then move to the model's device.
            batch = self.transforms(image).unsqueeze(0).to(self.device)
            logits = self.model(batch)
            probs = torch.nn.functional.softmax(logits, dim=1).cpu().numpy()

        best_id = probs.argmax(axis=1)[0]
        categ = self.id2categ[best_id]

        if confidence:
            return categ, probs.max(axis=1)[0]
        return categ

#### Loading Binary Classification Model (moth / non-moth)

In [10]:
# Category map (JSON) and saved weights for the moth / non-moth classifier
category_map_json = '/home/mila/a/aditya.jain/logs/05-moth-nonmoth_category_map.json'
model_path        = '/home/mila/a/aditya.jain/logs/moth-nonmoth-effv2b3_20220506_061527_30.pth'
   
model_binary      = ModelInference(model_path, category_map_json, device) 

# Sanity check on a single sample image; the path is relative to the notebook's working directory
image = Image.open('ant.jpeg')
print(np.shape(image))

# confidence=True makes predict() return (category name, softmax probability)
categ, conf = model_binary.predict(image, confidence=True)
print(f'Prediction: {categ}, Confidence: {conf}')
    

(246, 250, 3)
Prediction: nonmoth, Confidence: 0.9491909742355347


#### Loading Moth Classification Model 

In [8]:
# Category map (JSON) and saved weights for the moth species classifier.
# Note: this re-binds category_map_json and model_path, which earlier cells also assign.
category_map_json = '/home/mila/a/aditya.jain/logs/03-mothsv2_category_map.json'
model_path        = '/home/mila/a/aditya.jain/logs/mothsv2_20220421_110638_30.pth'
   
model_moth        = ModelInference(model_path, category_map_json, device) 

# Sanity check on a single sample image; the path is relative to the notebook's working directory.
# NOTE(review): the file name suggests the expected species is Orthopygia glaucinalis,
# while the recorded output shows a different prediction — verify the model / category map.
image = Image.open('Orthopygia_glaucinalis.jpg')

# confidence=True makes predict() return (category name, softmax probability)
categ, conf = model_moth.predict(image, confidence=True)
print(f'Prediction: {categ}, Confidence: {conf}')

Prediction: Chlorochlamys chloroleucaria, Confidence: 0.7936117649078369


#### Prediction on data

In [None]:
# Localize objects in every image of data_path, classify each detection
# (moth / non-moth, then moth species) and save all annotations to one JSON file.
model_localize  = model_localize.to(device)
model_localize.eval()

annot_data = {}
SCORE_THR  = 0.99   # keep only detections whose localization score exceeds this
# Sorted so the processing order (and JSON key order) is reproducible across runs.
image_list = sorted(os.listdir(data_path))

transform        = transforms.Compose([transforms.ToTensor()])
transform_to_PIL = transforms.ToPILImage()   # loop-invariant, so built once

for img in image_list:
    image_path = os.path.join(data_path, img)

    # The context manager closes the file handle once the pixels are loaded.
    # convert('RGB') guarantees the 3 channels both models expect.
    # Entries PIL cannot read (directories, non-image or corrupt files) are
    # reported and skipped so one bad entry does not abort the whole run.
    try:
        with Image.open(image_path) as raw_image:
            image = transform(raw_image.convert('RGB'))
    except OSError as err:
        print(f'Skipping {img}: {err}')
        continue

    image_pred = torch.unsqueeze(image, 0).to(device)

    # Inference only: no_grad avoids building an autograd graph for every image.
    with torch.no_grad():
        output = model_localize(image_pred)

    bboxes = output[0]['boxes'][output[0]['scores'] > SCORE_THR]

    bbox_list     = []
    label_list    = []
    class_list    = []   # moth / non-moth
    subclass_list = []   # moth species / non-moth
    conf_list     = []   # confidence (integer percent) of the stored subclass

    for box in bboxes:
        # Boxes are indexed as [x_min, y_min, x_max, y_max]; truncate to integer pixels.
        x_min, y_min, x_max, y_max = (int(v) for v in box.detach().cpu().tolist())

        # Truncation can collapse a sub-pixel box to zero area; an empty crop
        # cannot be resized by the classifiers, so skip it.
        if x_max <= x_min or y_max <= y_min:
            continue

        bbox_list.append([x_min, y_min, x_max, y_max])
        label_list.append(1)

        # The tensor is indexed (channel, row, column), so rows are sliced by y and columns by x.
        cropped_image = transform_to_PIL(image[:, y_min:y_max, x_min:x_max])

        # Stage 1: moth / non-moth. Stage 2 (moths only): species.
        categ, conf = model_binary.predict(cropped_image, confidence=True)
        if categ == 'nonmoth':
            class_name, subclass_name = 'nonmoth', 'nonmoth'
        else:
            class_name = 'moth'
            subclass_name, conf = model_moth.predict(cropped_image, confidence=True)

        class_list.append(class_name)
        subclass_list.append(subclass_name)
        conf_list.append(int(conf * 100))

    annot_data[img] = [bbox_list, label_list, class_list, subclass_list, conf_list]

with open(os.path.join(save_path, annot_file), 'w') as outfile:
    json.dump(annot_data, outfile)