# Run the detection and classification pipeline on a few sample images

In [None]:
import os
import csv
import torch
import torchvision
import torchvision.transforms as transforms
from PIL import Image, ImageDraw
import pandas as pd
import timm
import numpy as np

# Every model and tensor in this notebook is placed on the CPU.
device = torch.device(type='cpu')

In [None]:
# Load the localisation (object detection) model: torchvision's Faster R-CNN
# with a ResNet-50 FPN backbone, re-headed for two classes and then filled
# with the fine-tuned weights stored at `weights_path`.
weights_path = "/bask/homes/f/fspo1218/amber/data/mila_models/v1_localizmodel_2021-08-17-12-06.pt"

num_classes = 2  # a single "object" class plus background

model_loc = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None)

# Swap the stock box predictor for one sized to our class count.
in_features = model_loc.roi_heads.box_predictor.cls_score.in_features
model_loc.roi_heads.box_predictor = torchvision.models.detection.faster_rcnn.FastRCNNPredictor(
    in_features,
    num_classes,
)

# Some checkpoints wrap the weights under "model_state_dict"; others are the
# bare state dict itself.
checkpoint = torch.load(weights_path, map_location=device)
state_dict = checkpoint.get("model_state_dict") or checkpoint
model_loc.load_state_dict(state_dict)

model_loc = model_loc.to(device)
model_loc.eval()

print('localisation model loaded')

In [None]:
# Load the binary moth / non-moth classifier (EfficientNetV2-B3, two outputs).
weights_path = "/bask/homes/f/fspo1218/amber/data/mila_models/moth-nonmoth-effv2b3_20220506_061527_30.pth"
# NOTE(review): labels_path is not read in the visible cells; classify_box
# hard-codes the {'moth': 0, 'nonmoth': 1} label order instead.
labels_path = "/bask/homes/f/fspo1218/amber/data/mila_models/05-moth-nonmoth_category_map.json"

num_classes = 2
# timm.create_model has no `weights` argument (timm silently discards
# None-valued keyword arguments, so `weights=None` did nothing). State the
# intent explicitly: build the architecture only; weights come from the
# checkpoint below.
classification_model = timm.create_model(
    "tf_efficientnetv2_b3",
    num_classes=num_classes,
    pretrained=False,
)
classification_model = classification_model.to(device)
checkpoint = torch.load(weights_path, map_location=device)
# The model state dict is nested in some checkpoints, and not in others
state_dict = checkpoint.get("model_state_dict") or checkpoint
classification_model.load_state_dict(state_dict)
classification_model.eval()

print('binary classifier model loaded')

In [None]:
# Transformations for the images
transform = transforms.Compose([
    transforms.Resize((300, 300)),  # Assuming models require 300x300 input images
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

transform_species = transforms.Compose(
            [
                transforms.Resize((300, 300)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5,0.5, 0.5], std=[0.5,0.5, 0.5]),
            ]
        )

# Directory containing images
image_dir = '/bask/homes/f/fspo1218/amber/projects/object-store-scripts/data'

all_images = [os.path.join(dp, f) for dp, dn, fn in os.walk(os.path.expanduser(image_dir)) for f in fn]
all_images = [x for x in all_images if x.endswith('jpg')]


print(len(all_images))

# print(all_images)

# CSV file to save results
csv_file = '/bask/projects/v/vjgo8416-amber/projects/object-store-scripts/mila_results.csv'

In [None]:
# Keep only the first five images so the run below stays quick.
all_images = all_images[:5]
all_images

In [None]:
class Resnet50(torch.nn.Module):
    """ResNet-50 trunk followed by global average pooling and a bias-free
    linear classifier.

    The attribute names (``backbone``, ``avgpool``, ``classifier``) determine
    the state-dict keys, so they must stay unchanged for saved checkpoints
    to load.
    """

    def __init__(self, num_classes, backbone_weights="DEFAULT"):
        """
        Args:
            num_classes: number of output logits (size of the classifier layer).
            backbone_weights: forwarded to
                ``torchvision.models.resnet50(weights=...)``. Defaults to
                "DEFAULT" (the previous behaviour). Pass None to skip
                downloading ImageNet weights when a full checkpoint is loaded
                afterwards anyway.
        """
        super().__init__()
        self.num_classes = num_classes
        resnet = torchvision.models.resnet50(weights=backbone_weights)
        out_dim = resnet.fc.in_features

        # Keep the convolutional trunk; drop ResNet's own avgpool and fc.
        self.backbone = torch.nn.Sequential(*list(resnet.children())[:-2])
        self.avgpool = torch.nn.AdaptiveAvgPool2d(output_size=(1, 1))
        self.classifier = torch.nn.Linear(out_dim, self.num_classes, bias=False)

    def forward(self, x):
        """Return raw (un-normalised) logits of shape (batch, num_classes)."""
        x = self.backbone(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.classifier(x)


In [None]:
import json

# Load the Costa Rica species classifier and its label map.
weights = '/bask/homes/f/fspo1218/amber/projects/species_classifier/outputs/turing-costarica_v03_resnet50_2024-06-04-16-17_state.pt'
# Read through a context manager so the file handle is closed promptly.
# category_map is presumably {species name: class index}; classify_species
# inverts it. Verify against the file.
with open('/bask/homes/f/fspo1218/amber/data/gbif_costarica/03_costarica_data_category_map.json') as category_map_file:
    category_map = json.load(category_map_file)

num_classes = len(category_map)
species_model = Resnet50(num_classes=num_classes)
species_model = species_model.to(device)
checkpoint = torch.load(weights, map_location=device)
# The model state dict is nested in some checkpoints, and not in others
state_dict = checkpoint.get("model_state_dict") or checkpoint

species_model.load_state_dict(state_dict)
species_model.eval()


print('species classifier loaded')

In [None]:
import torch
import torch.nn as nn
from torchvision import models
#from torchvision.models import ResNet50_Weights

class ResNet502(nn.Module):
    '''ResNet-50 (ImageNet-initialised) with a small MLP head; used here as
    the insect order classifier.

    Attribute names are kept unchanged because they form the state-dict keys
    of saved checkpoints.
    '''

    def __init__(self, use_cbam=True, image_depth=3, num_classes=20):
        '''Build the network.

        Args:
            use_cbam: accepted for compatibility; not used by this class.
            image_depth: accepted for compatibility; not used by this class.
            num_classes: number of output logits.
        '''
        super().__init__()

        self.expansion = 4
        self.out_channels = 512

        # `pretrained=True` is deprecated in torchvision; IMAGENET1K_V1 is the
        # exact weight set it maps to, so the initialisation is unchanged.
        self.model_ft = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)

        print("In features", self.model_ft.fc.in_features)
        # Replace the ImageNet classification layer with a pass-through so the
        # backbone emits its pooled features directly.
        self.model_ft.fc = nn.Identity()

        self.drop = nn.Dropout(p=0.5)
        # out_channels * expansion (512 * 4) must equal the backbone's pooled
        # feature size printed above.
        self.linear_lvl1 = nn.Linear(self.out_channels * self.expansion, self.out_channels)
        self.relu_lv1 = nn.ReLU(inplace=False)
        # Despite its name this is a plain linear layer: forward() returns
        # logits, not softmax probabilities.
        self.softmax_reg1 = nn.Linear(self.out_channels, num_classes)

    def forward(self, x):
        '''Return raw class logits of shape (batch, num_classes).'''
        x = self.model_ft(x)
        x = self.drop(x)  # regularisation; a no-op in eval() mode
        return self.softmax_reg1(self.relu_lv1(self.linear_lvl1(x)))
    

In [None]:
# Order classifier: ResNet502 weights plus a per-class table (class name,
# threshold, mean, std) read from CSV.
savedWeights = '/bask/homes/f/fspo1218/amber/projects/MCC24-trap/model_order_060524/dhc_best_128.pth'
thresholdFile = '/bask/homes/f/fspo1218/amber/projects/MCC24-trap/model_order_060524/thresholdsTestTrain.csv'
device = 'cpu'  # NOTE(review): rebinds the earlier torch.device to a plain string
img_size = 128

print("Order classifier - threshold file", thresholdFile, "and weights", savedWeights, "of image size", img_size)

# One list per column of the threshold table, in row order.
data_thresholds = pd.read_csv(thresholdFile)
order_labels, thresholds, means, stds = (
    data_thresholds[column].to_list()
    for column in ("ClassName", "Threshold", "Mean", "Std")
)

img_depth = 3

# The classifier head is sized to the number of rows in the threshold table.
num_classes = len(order_labels)
print("Use ResNet50 and load weights with num. classes", num_classes)

model_order = ResNet502(num_classes=num_classes)
model_order.load_state_dict(torch.load(savedWeights, map_location=device))
model_order = model_order.to(device)
model_order.eval()

print('order classifier loaded')

In [None]:
from scipy.stats import norm

def classify_order(image_tensor):
    """Predict the insect order of a single pre-processed crop.

    Args:
        image_tensor: batched input for one crop; only the first row of the
            output is used.

    Returns:
        (label, confidence): the order name from ``order_labels`` and a
        confidence in percent (0-100, two decimals). The confidence is the
        normal CDF of the winning logit, using that class's ``Mean``/``Std``
        from ``data_thresholds``.
    """
    # NOTE(review): the loader prints img_size = 128, but callers pass crops
    # built by transform_species (300x300, mean/std 0.5). Confirm the input
    # size and normalisation the order model was trained with.

    # Inference only; this also makes the function safe to call outside a
    # torch.no_grad() block.
    with torch.no_grad():
        pred = model_order(image_tensor)

    predictions = pred.cpu().detach().numpy()
    predicted_label = np.argmax(predictions, axis=1)[0]
    print('preds:', predictions)
    print('pred labels:', predicted_label)

    label = order_labels[predicted_label]
    # Positional lookup (.iloc) keeps the row aligned with
    # order_labels[predicted_label] even if the table's index is not 0..n-1.
    confidence_value = norm.cdf(predictions[0][predicted_label],
                                data_thresholds['Mean'].iloc[predicted_label],
                                data_thresholds['Std'].iloc[predicted_label])
    confidence_value = round(confidence_value * 10000) / 100

    return label, confidence_value
    


In [None]:
def classify_box(image_tensor):
    """Run the binary moth / non-moth classifier on one crop.

    Returns:
        (label, score): 'moth' or 'nonmoth' for the first item of the batch,
        and that item's highest softmax probability as a NumPy float.
    """
    index_to_label = {0: 'moth', 1: 'nonmoth'}

    logits = classification_model(image_tensor)
    probs = torch.nn.functional.softmax(logits, dim=1).detach().numpy()

    winners = [index_to_label[c] for c in probs.argmax(axis=1)]
    return winners[0], probs.max(axis=1).astype(float)[0]

In [None]:
def classify_species(image_tensor):
    """Run the species classifier on one crop.

    Returns:
        (label, score): the predicted species name for the first item of the
        batch (found by inverting ``category_map``) and its softmax
        probability.
    """
    output = species_model(image_tensor)

    predictions = torch.nn.functional.softmax(output, dim=1)
    predictions = predictions.detach().numpy()

    categories = predictions.argmax(axis=1)

    # category_map is presumably {species name: class index}; invert it to
    # look up the name of the winning index.
    index_to_label = {index: label for label, index in category_map.items()}

    label = index_to_label[categories[0]]
    # Confidence is the winning class's probability, consistent with
    # classify_box. (Previously this returned 1 - max_prob, which reported
    # near-zero confidence for the most certain predictions.)
    score = predictions.max(axis=1).astype(float)[0]
    return label, score

In [None]:
# Run detection + classification on every selected image, append one CSV row
# per kept detection, and save an annotated copy of each image.
result_columns = ['image_path',
                  'box_score', 'x_min', 'y_min', 'x_max', 'y_max',  # localisation info
                  'class_name', 'class_confidence',  # binary class info
                  'order_name', 'order_confidence',  # order info
                  'species_name', 'species_confidence']  # species info
min_box_score = 0.1   # detections scoring at or below this are dropped
loc_input_size = 300  # side length the detector input is resized to (see `transform`)

for image_path in all_images:
    image = Image.open(image_path).convert('RGB')
    # Annotations are drawn on this copy only. Crops are always taken from the
    # untouched `image`; otherwise rectangles/text drawn for earlier boxes
    # would leak into the crops of later, overlapping detections.
    original_image = image.copy()
    original_width, original_height = image.size
    input_tensor = transform(image).unsqueeze(0).to(device)

    rows = []

    # Perform object localization
    with torch.no_grad():
        localization_outputs = model_loc(input_tensor)
        boxes = localization_outputs[0]['boxes']
        # Convert the scores once rather than once per detection.
        scores = localization_outputs[0]['scores'].tolist()

        print(image_path)
        print('Number of objects:', len(boxes))

        draw = ImageDraw.Draw(original_image)

        # for each detection
        for box, box_score in zip(boxes, scores):
            x_min, y_min, x_max, y_max = box

            # Map the box from detector-input coordinates back to the original image.
            x_min = int(int(x_min) * original_width / loc_input_size)
            y_min = int(int(y_min) * original_height / loc_input_size)
            x_max = int(int(x_max) * original_width / loc_input_size)
            y_max = int(int(y_max) * original_height / loc_input_size)

            box_width = x_max - x_min
            box_height = y_max - y_min

            # Skip boxes taller or wider than half the image.
            if box_width > original_width / 2 or box_height > original_height / 2:
                continue

            # Skip low-confidence detections.
            if box_score <= min_box_score:
                continue

            # Skip boxes that collapse to zero width/height after integer
            # rounding; they cannot be cropped and resized.
            if box_width <= 0 or box_height <= 0:
                continue

            # Crop the detected region (from the clean image) and classify it.
            cropped_image = image.crop((x_min, y_min, x_max, y_max))
            cropped_tensor = transform_species(cropped_image).unsqueeze(0)

            class_name, class_confidence = classify_box(cropped_tensor)
            order_name, order_confidence = classify_order(cropped_tensor)

            # Annotate image with bounding box and class
            if class_name == 'moth':
                # Perform the species classification
                print('...Performing the inference')
                species_name, species_confidence = classify_species(cropped_tensor)
                draw.rectangle([x_min, y_min, x_max, y_max], outline='green', width=3)
                draw.text((x_min, y_min - 10), species_name + " , %.3f " % species_confidence, fill='green')
            else:
                species_name, species_confidence = None, None
                draw.rectangle([x_min, y_min, x_max, y_max], outline='red', width=3)
                draw.text((x_min, y_min - 10), f'order: {order_name}, binary: {class_name}', fill='red')

            draw.text((x_min, y_max), str(box_score), fill='black')

            row = [image_path,
                   box_score, x_min, y_min, x_max, y_max,
                   class_name, class_confidence,
                   order_name, order_confidence,
                   species_name, species_confidence]
            rows.append(row)
            # Append immediately so partial results survive an interruption;
            # write the header only when the CSV does not exist yet.
            pd.DataFrame([row], columns=result_columns).to_csv(
                csv_file, mode='a', header=not os.path.exists(csv_file), index=False
            )

    # Per-image table of kept detections, built once rather than by repeated
    # concatenation onto an empty frame.
    all_boxes = pd.DataFrame(rows, columns=result_columns)

    # NOTE(review): saved by basename into the working directory; images with
    # the same file name in different sub-directories overwrite each other.
    original_image.save(os.path.basename(image_path))
