In [1]:
import os
import pandas as pd
from tqdm import tqdm
import sys

import matplotlib.pyplot as plt

import warnings
warnings.simplefilter("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

In [2]:
program_prefix = "./"
program_id = "species_net_bounded_no_finetuning"

In [3]:
# get image filepaths
pangolin_path = program_prefix + "[6-3] Pangolin Images"
other_path = program_prefix + "[6-3] Other Animal Images"

pangolin_images = os.listdir(pangolin_path)
other_images = os.listdir(other_path)
print(len(pangolin_images), pangolin_images)
print(len(other_images), other_images)

206 ['IMG_0006.JPG', 'IMG_0015.JPG', 'IMG_0016.JPG', 'IMG_0017.JPG', 'IMG_0018.JPG', 'IMG_0153.JPG', 'IMG_0155.JPG', 'IMG_0156.JPG', 'IMG_0157.JPG', 'IMG_0158.JPG', 'IMG_0159.JPG', 'IMG_0160.JPG', 'IMG_0161.JPG', 'IMG_0162.JPG', 'IMG_0179.JPG', 'IMG_0180.JPG', 'IMG_0181.JPG', 'IMG_0182.JPG', 'IMG_0236.JPG', 'IMG_0247.JPG', 'IMG_0251.JPG', 'IMG_0252.JPG', 'IMG_0253.JPG', 'IMG_0254.JPG', 'IMG_0255.JPG', 'IMG_0256.JPG', 'IMG_0257.JPG', 'IMG_0258.JPG', 'IMG_0259.JPG', 'IMG_0271.JPG', 'IMG_0416.JPG', 'IMG_0417.JPG', 'IMG_0418.JPG', 'IMG_0420.JPG', 'IMG_0422.JPG', 'IMG_0423.JPG', 'IMG_0424.JPG', 'IMG_0425.JPG', 'IMG_0426.JPG', 'IMG_0549.JPG', 'IMG_0550.JPG', 'IMG_0551.JPG', 'IMG_0552.JPG', 'IMG_0553.JPG', 'IMG_0554.JPG', 'IMG_0555.JPG', 'IMG_0556.JPG', 'IMG_0557.JPG', 'IMG_0558.JPG', 'IMG_0559.JPG', 'IMG_0560.JPG', 'IMG_0561.JPG', 'IMG_0562.JPG', 'IMG_0563.JPG', 'IMG_0564.JPG', 'IMG_0577.JPG', 'IMG_0578.JPG', 'IMG_0579.JPG', 'IMG_0581.JPG', 'IMG_0582.JPG', 'IMG_0583.JPG', 'IMG_0584.JPG', 'IM

In [4]:
# combine filepath and label
data = ([(os.path.join(pangolin_path, img), 1) for img in pangolin_images if img != ".DS_Store"] +
    [(os.path.join(other_path, img), 0) for img in other_images])

print(data)
print(len(data))

[('./[6-3] Pangolin Images\\IMG_0006.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0015.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0016.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0017.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0018.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0153.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0155.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0156.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0157.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0158.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0159.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0160.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0161.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0162.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0179.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0180.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0181.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0182.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0236.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0247.JPG', 1), ('./[6-3] Pangolin Images\\IMG_0251.JPG', 1), ('./[6-3] Pangolin Images\\IMG_02

In [5]:
# create and write column headers to csv files for logging purposes
import csv
log_folder_path = program_prefix + "log/" + program_id + "/"
os.makedirs(log_folder_path, exist_ok=True)

test_csv_file_path = log_folder_path + "test_scores.csv"
test_csv = open(test_csv_file_path, mode='w', newline='')
test_writer = csv.writer(test_csv)
test_writer.writerow([
    'fold',
    'accuracy_pangolin', 'accuracy_other',
    'recall_pangolin', 'recall_other',
    'precision_pangolin', 'precision_other',
    'f1_pangolin', 'f1_other',
    'auc_pangolin', 'auc_other'
])

147

In [6]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
)

from PIL import Image
import numpy as np
from speciesnet.utils import BBox

# plot image of misclassifications
def plot_misclassified(fold, pred, data):
    true = [label for _, label in data]
    fn_indices = [i for i, (p, t) in enumerate(zip(pred, true)) if p == 0 and t == 1]   # false negatives
    fp_indices = [i for i, (p, t) in enumerate(zip(pred, true)) if p == 1 and t == 0]   # false positives

    # helper method to plot image
    def plot_images(indices, title):
        if not indices:
            print(f"No Misclassified Samples Found: " + title)
            return
        
        print(f"Misclassified Samples Found: " + str(len(indices)))
        indices = indices[:5]
        
        fig, axes = plt.subplots(1, len(indices), figsize=(15, 5))
        if len(indices) == 1:
            axes = [axes]
        
        for ax, idx in zip(axes, indices):
            filename, true_label = data[idx]
            img = Image.open(filename).convert("L")            
            ax.imshow(img, cmap="gray")
            ax.axis("off")
            ax.set_title(f"{title}\nPred: {pred[idx]}, True: {true_label}", fontsize=10)
            
        misclassified_image_path  = misclassified_folder_path + "misclassified_fold_" + str(fold + 1) + "/"     # path to store the misclassified images, name-specific to fold 
        misclassified_image = misclassified_image_path + str(idx) + ".png"

        os.makedirs(misclassified_image_path, exist_ok=True)
        plt.savefig(misclassified_image, bbox_inches='tight')
        plt.close()

    # plot images
    plot_images(fn_indices, "Fold " + str(fold + 1) + "- False Negatives")
    plot_images(fp_indices, "Fold " + str(fold + 1) + "- False Positives")

###############################################################################################################################################
# printing/logging methods (different for validation, training, and testing because they have different ways to log to csv files)
###############################################################################################################################################

# prints score for validation
def val_print_and_log_scores(fold, epoch, true, pred, prob):
    # calculating scores for pangolin class
    accuracy = accuracy_score(true, pred)
    pangolin_precision = precision_score(true, pred, pos_label=1, zero_division=0)
    pangolin_recall = recall_score(true, pred, pos_label=1, zero_division=0)
    pangolin_f1 = f1_score(true, pred, pos_label=1, zero_division=0)
    pangolin_prob = prob
    pangolin_auc = roc_auc_score(true, pangolin_prob)

    print(f"\tpangolin accuracy: {accuracy:.4f}")
    print(f"\tpangolin precision: {pangolin_precision:.4f}")
    print(f"\tpangolin recall: {pangolin_recall:.4f}")
    print(f"\tpangolin f1-score: {pangolin_f1:.4f}")
    print(f"\tpangolin auc: {pangolin_auc:.4f}")

    # calculating scores for other class
    other_precision = precision_score(true, pred, pos_label=0, zero_division=0)
    other_recall = recall_score(true, pred, pos_label=0, zero_division=0)
    other_f1 = f1_score(true, pred, pos_label=0, zero_division=0)

    other_prob = prob
    true_inverted = [1 if t != 1 else 0 for t in true]
    other_auc = roc_auc_score(true_inverted, other_prob)

    print(f"\tother accuracy: {accuracy:.4f}")
    print(f"\tother precision: {other_precision:.4f}")
    print(f"\tother recall: {other_recall:.4f}")
    print(f"\tother f1-score: {other_f1:.4f}")
    print(f"\tother auc: {other_auc:.4f}")

    # writing to csv file
    val_writer.writerow([
            fold, epoch,
            accuracy, accuracy,
            pangolin_recall, other_recall,
            pangolin_precision, other_precision,
            pangolin_f1, other_f1,
            pangolin_auc, other_auc
    ])


# print and log score for training
def train_print_and_log_scores(fold, epoch, true, pred, prob):
    # calculating scores for pangolin class
    accuracy = accuracy_score(true, pred)
    pangolin_precision = precision_score(true, pred, pos_label=1, zero_division=0)
    pangolin_recall = recall_score(true, pred, pos_label=1, zero_division=0)
    pangolin_f1 = f1_score(true, pred, pos_label=1, zero_division=0)
    pangolin_prob = prob.detach().numpy()
    pangolin_auc = roc_auc_score(true, pangolin_prob)

    print(f"\tpangolin accuracy: {accuracy:.4f}")
    print(f"\tpangolin precision: {pangolin_precision:.4f}")
    print(f"\tpangolin recall: {pangolin_recall:.4f}")
    print(f"\tpangolin f1-score: {pangolin_f1:.4f}")
    print(f"\tpangolin auc: {pangolin_auc:.4f}")

    # calculating scores for other class
    other_precision = precision_score(true, pred, pos_label=0, zero_division=0)
    other_recall = recall_score(true, pred, pos_label=0, zero_division=0)
    other_f1 = f1_score(true, pred, pos_label=0, zero_division=0)
    other_prob = prob.detach().numpy()
    true_inverted = (true != 1).detach().numpy().astype(int)  
    other_auc = roc_auc_score(true_inverted, other_prob)

    print(f"\tother accuracy: {accuracy:.4f}")
    print(f"\tother precision: {other_precision:.4f}")
    print(f"\tother recall: {other_recall:.4f}")
    print(f"\tother f1-score: {other_f1:.4f}")
    print(f"\tother auc: {other_auc:.4f}")

    # writing to csv file
    train_writer.writerow([
        fold, epoch,
        accuracy, accuracy,
        pangolin_recall, other_recall,
        pangolin_precision, other_precision,
        pangolin_f1, other_f1,
        pangolin_auc, other_auc
    ])

# print and log scores for testing
def test_print_and_log_scores(fold, true, pred, prob):
    # calculating scores for pangolin class
    accuracy = accuracy_score(true, pred)
    pangolin_precision = precision_score(true, pred, pos_label=1, zero_division=0)
    pangolin_recall = recall_score(true, pred, pos_label=1, zero_division=0)
    pangolin_f1 = f1_score(true, pred, pos_label=1, zero_division=0)
    pangolin_prob = prob
    pangolin_auc = roc_auc_score(true, pangolin_prob)

    print(f"\tpangolin accuracy: {accuracy:.4f}")
    print(f"\tpangolin precision: {pangolin_precision:.4f}")
    print(f"\tpangolin recall: {pangolin_recall:.4f}")
    print(f"\tpangolin f1-score: {pangolin_f1:.4f}")
    print(f"\tpangolin auc: {pangolin_auc:.4f}")

    # calculating scores for other class
    other_precision = precision_score(true, pred, pos_label=0, zero_division=0)
    other_recall = recall_score(true, pred, pos_label=0, zero_division=0)
    other_f1 = f1_score(true, pred, pos_label=0, zero_division=0)
    other_prob = prob
    true_inverted = [1 if t != 1 else 0 for t in true]
    other_auc = roc_auc_score(true_inverted, other_prob)

    print(f"\tother accuracy: {accuracy:.4f}")
    print(f"\tother precision: {other_precision:.4f}")
    print(f"\tother recall: {other_recall:.4f}")
    print(f"\tother f1-score: {other_f1:.4f}")
    print(f"\tother auc: {other_auc:.4f}")

    # writing to csv file
    test_writer.writerow([
        fold,
        accuracy, accuracy,
        pangolin_recall, other_recall,
        pangolin_precision, other_precision,
        pangolin_f1, other_f1,
        pangolin_auc, other_auc
    ])

# custom collate function for the data loader
def collate_fn(batch):
    images = torch.stack([item["image"] for item in batch])
    labels = torch.tensor([item["label"] for item in batch], dtype=torch.long)
    image_ids = [item["image_id"] for item in batch]
    detections = [item["detections"] for item in batch]

    return {
        "image": images,
        "label": labels,
        "image_id": image_ids,
        "detections": detections
    }

# custom image dataset
class ImagesDataset(Dataset):
    def __init__(self, data, detector):
        self.data = data
        self.detector = detector
        self.transform = transforms.Compose(
            [
                transforms.Resize((480, 480)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
                )
            ]
        )
        
    def __getitem__(self, index):
        filepath, label = self.data[index]
        image = Image.open(filepath).convert("RGB")

        # preprocessing: crop out bottom bar and convert to greyscale
        image = self._crop_bottom_bar(image)
        image = image.convert("L").convert("RGB") # converts back to rgb to preserve 3 channels

        label = torch.tensor(label, dtype=torch.long)

        # preprocessing with MD
        image_MD = self.detector.preprocess(image)
        detection_result = self.detector.predict(filepath, image_MD)
        detections = detection_result["detections"]

        image_bbox = image
        # use bounding box to crop image 
        if detections:
            # get detection with highest confidence
            bbox = BBox(*detections[0]["bbox"])

            x0 = bbox.xmin * image.width
            y0 = bbox.ymin * image.height
            x1 = (bbox.xmin + bbox.width) * image.width
            y1 = (bbox.ymin + bbox.height) * image.height

            image_bbox = image.crop((x0, y0, x1, y1))

            # extend cropped image for square dimensions
            width = x1 - x0
            height = y1 - y0
            diff = abs(width - height)

            if width > height:
                y0 = max(0, y0 - diff // 2)
                y1 = min(image.height, y1 + (diff - diff // 2))
            elif height > width:
                x0 = max(0, x0 - diff // 2)
                x1 = min(image.width, x1 + (diff - diff // 2))

            image_bbox = image.crop((x0, y0, x1, y1))

        # ensure final image is square
        image_bbox = self._pad_to_square(image_bbox)

        # transform image
        image = self.transform(image_bbox)

        sample = {"image": image, "label": label, "image_id": index, "detections": detections}
        return sample

    def __len__(self):
        return len(self.data)
        
    # crops out WI info bar at bottom of image 
    def _crop_bottom_bar(self, image):
        width, height = image.size
        if (width != 5376 and height != 3024): return image
        bar_height = int(height * 0.05)
        cropped_image = image.crop((0, 0, width, height - bar_height))
        return cropped_image

    # pads non square images with black pixels
    def _pad_to_square(self, image):
        width, height = image.size
        if width == height:
            return image

        max_side = max(width, height)
        image_padded = Image.new("RGB", (max_side, max_side), color=(0, 0, 0))
        x = (max_side - width) // 2
        y = (max_side - height) // 2
        image_padded.paste(image, (x, y))
        return image_padded

# store feature maps for each layer
feature_maps = {}

# hook function to save feature maps
def hook_fn(module, input, output, name):
    feature_maps[name] = output.detach()


layers_to_visualize = [
    'SpeciesNet/efficientnetv2-m/stem_conv/Conv2D.1',               # very early
    'SpeciesNet/efficientnetv2-m/block2a_expand_conv/Conv2D',       # early mid
    'SpeciesNet/efficientnetv2-m/block4a_project_conv/Conv2D',      # mid
    'SpeciesNet/efficientnetv2-m/block6b_project_conv/Conv2D',      # late mid
    'SpeciesNet/efficientnetv2-m/top_conv/Conv2D'                   # final conv
]



  from .autonotebook import tqdm as notebook_tqdm


In [7]:
import torch.optim as optim
from torch import nn
from speciesnet import SpeciesNet
from sklearn.metrics import ConfusionMatrixDisplay
import time
import random

speciesnet_model_path = program_prefix + 'speciesnet_model'
# model_folder_path = program_prefix + "models/" + program_id + "/"                 # create folder to store optimal model per fold
# feature_map_folder_path = program_prefix + "feature_maps/" + program_id + "/"
confusion_matrix_folder_path = program_prefix + "cm/" + program_id + "/"
misclassified_folder_path = program_prefix + "misclassified/" + program_id + "/"
# os.makedirs(model_folder_path, exist_ok=True)
# os.makedirs(feature_map_folder_path, exist_ok=True)
os.makedirs(confusion_matrix_folder_path, exist_ok=True)
os.makedirs(misclassified_folder_path, exist_ok=True)

In [8]:
pangolin_indices = []
with open("./speciesnet_model/always_crop_99710272_22x8_v12_epoch_00148.labels.txt", "r") as f:
    for idx, line in enumerate(f):
        if "pangolin" in line.lower():
            pangolin_indices.append(idx)

# lines 882, 1405, 1406
# 2f336cdf-a62f-4587-a516-6e6c74d07353;mammalia;pholidota;manidae;phataginus;tricuspis;white-bellied pangolin
# b1cefdc9-af34-4f28-b077-1186dd6b5072;mammalia;pholidota;manidae;;;pangolin family
# ade3ecab-c110-429a-849e-b6afdb290219;mammalia;pholidota;manidae;manis;;pangolin species

# pangolin_indices = [1404] # family level
# pangolin_indices = [1405] # species level

print(pangolin_indices)

[70, 881, 951, 1044, 1404, 1405, 2330]


In [9]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [10]:
%matplotlib inline

# load model
speciesnet_model = SpeciesNet(speciesnet_model_path)
speciesnet_classifier = speciesnet_model.classifier
speciesnet_detector = speciesnet_model.detector
speciesnet_ensembler = speciesnet_model.ensemble

base_model = speciesnet_classifier.model
base_model.fc = nn.Sequential(
    nn.Linear(2048, 100),
    nn.ReLU(inplace=True),
    nn.Dropout(0.1),
    nn.Linear(100, 1)
)

model = base_model.to(device)

full_data = data
full_dataset = ImagesDataset(full_data, speciesnet_detector)
full_dataloader = DataLoader(full_dataset, batch_size=32, collate_fn=collate_fn)

classifier_results = {}
detector_results = {}
geolocation_results = {}
filepaths = []

model.eval()

# iterate through dataloader and run the model
with torch.no_grad():
    for batch in tqdm(full_dataloader, total=len(full_dataloader)):
        batch["image"] = batch["image"].permute(0, 2, 3, 1)      

        images = batch['image'].to(device)
        labels = batch['label'].float().to(device)    
        image_ids = batch['image_id']

        logits = model.forward(images)
        
        pangolin_logits = logits[:, pangolin_indices].logsumexp(dim=1)
        pangolin_prob = torch.sigmoid(pangolin_logits)

        preds_df = pd.DataFrame(
            pangolin_prob.detach().cpu().numpy(),
            index=batch["image_id"],
            columns=["prob"]
        )
        # test_preds_collector.append(preds_df)

        # storing information for ensembling method
        for i in range(len(pangolin_prob)):
            prob = pangolin_prob[i].item()

            image_id = image_ids[i]
            filepath, _ = full_data[image_id]  # get filepath from index

            filepaths.append(filepath)

            # classifier prediction
            classifier_results[filepath] = {
                "classifications": {
                    "classes": [
                        "Animalia;Chordata;Mammalia;Pholidota;Manidae;Manis;pangolin",
                        "Unknown;Unknown;Unknown;Unknown;Unknown;Unknown;other"
                    ],
                    "scores": [prob, 1 - prob],
                }
            }

            # detector result from batch
            detector_result = batch["detections"][i] if "detections" in batch else []
            detector_results[filepath] = {
                "detections": detector_result
            }

ensemble_outputs = speciesnet_ensembler.combine(
    filepaths=filepaths,
    classifier_results=classifier_results,
    detector_results=detector_results,
    geolocation_results=geolocation_results,
    partial_predictions={},
)

ensemble_labels = [1 if "pangolin" in o["prediction"] else 0 for o in ensemble_outputs]
ensemble_scores = [o["prediction_score"] for o in ensemble_outputs]

ensemble_pred = torch.tensor(ensemble_labels)
ensemble_prob = torch.tensor(ensemble_scores)

test_true = [label for _, label in full_data]                               # get the true labels
test_print_and_log_scores(0, test_true, ensemble_pred, ensemble_prob)

# error analysis on misclassified images
plot_misclassified(0, ensemble_pred, full_data)

# display confusion matrix
fig, ax = plt.subplots(figsize=(10, 10))
cm = ConfusionMatrixDisplay.from_predictions(
    test_true,
    ensemble_pred,
    ax=ax,
    xticks_rotation=90,
    colorbar=True,
    normalize='true'
)

cm_image_path  = confusion_matrix_folder_path      # path to store the confusion matrix, name-specific to fold 
cm_image = cm_image_path + "cm.png"

os.makedirs(os.path.dirname(cm_image), exist_ok=True)
plt.savefig(cm_image, bbox_inches='tight')
plt.close()

100%|██████████| 21/21 [21:17<00:00, 60.81s/it]


	pangolin accuracy: 0.7749
	pangolin precision: 0.9155
	pangolin recall: 0.3155
	pangolin f1-score: 0.4693
	pangolin auc: 0.3501
	other accuracy: 0.7749
	other precision: 0.7577
	other recall: 0.9866
	other f1-score: 0.8571
	other auc: 0.6499
Misclassified Samples Found: 141
Misclassified Samples Found: 6


In [11]:
test_csv.flush()