In [None]:
import os
import pandas as pd
from tqdm import tqdm
import sys
import csv

import matplotlib.pyplot as plt

import warnings
warnings.simplefilter("ignore", category=UserWarning)
warnings.filterwarnings("ignore", category=FutureWarning)

In [None]:
# Run configuration.
#   program_prefix: root under which log/, models/, feature_maps/, cm/ and misclassified/ live
#   program_id:     run name; every output folder is namespaced by it
#   num_epochs:     training epochs per cross-validation fold
# program_prefix = "/scratch1/lczhu/OP_CV/"
program_prefix: str = "./"
program_id: str = "species_net_test_multiclass_test"
num_epochs: int = 2

In [None]:
image_class_labels_path = "./image_class_labels_filtered.csv"

# Every row of the filtered label file becomes a (filepath, integer class id) pair;
# `data` is later indexed by the split indices in df_splits.
with open(image_class_labels_path, newline='') as csvfile:
    data = [
        (row["filepath"], int(row["class_id"]))
        for row in csv.DictReader(csvfile)
    ]

print(data[:10])
print(len(data))

In [None]:
# Create one CSV per logged quantity (truncating any previous run with this program_id)
# and write each file's header row. The handles stay open until the final cell closes them.
import csv
log_folder_path = program_prefix + "/log/" + program_id + "/"
os.makedirs(log_folder_path, exist_ok=True)

train_csv_file_path = log_folder_path + "train_scores.csv"
val_csv_file_path = log_folder_path + "val_scores.csv"
time_csv_file_path = log_folder_path + "time_per_epoch.csv"
test_csv_file_path = log_folder_path + "test_scores.csv"
loss_csv_file_path = log_folder_path + "loss_per_epoch.csv"

train_csv, val_csv, time_csv, test_csv, loss_csv = (
    open(csv_path, mode='w', newline='')
    for csv_path in (
        train_csv_file_path,
        val_csv_file_path,
        time_csv_file_path,
        test_csv_file_path,
        loss_csv_file_path,
    )
)

train_writer, val_writer, time_writer, test_writer, loss_writer = (
    csv.writer(handle) for handle in (train_csv, val_csv, time_csv, test_csv, loss_csv)
)

# Per-class score columns shared by the train, validation and test score logs.
score_columns = [
    'accuracy_pangolin', 'accuracy_other',
    'recall_pangolin', 'recall_other',
    'precision_pangolin', 'precision_other',
    'f1_pangolin', 'f1_other',
    'auc_pangolin', 'auc_other'
]

train_writer.writerow(['fold', 'epoch', *score_columns])
val_writer.writerow(['fold', 'epoch', *score_columns])
time_writer.writerow(['fold', 'epoch', 'time'])
test_writer.writerow(['fold', *score_columns])
loss_writer.writerow(['fold', 'epoch', 'train_loss', 'val_loss'])

In [None]:
import torch.optim as optim
from torch import nn
from speciesnet import SpeciesNet
from sklearn.metrics import ConfusionMatrixDisplay
import time
import random

# Input model location plus per-run output folders (all namespaced by program_id).
speciesnet_model_path = program_prefix + 'speciesnet_model'
model_folder_path = program_prefix + "models/" + program_id + "/"                 # best checkpoint per fold
feature_map_folder_path = program_prefix + "feature_maps/" + program_id + "/"
confusion_matrix_folder_path = program_prefix + "cm/" + program_id + "/"
misclassified_folder_path = program_prefix + "misclassified/" + program_id + "/"

for output_folder in (
    model_folder_path,
    feature_map_folder_path,
    confusion_matrix_folder_path,
    misclassified_folder_path,
):
    os.makedirs(output_folder, exist_ok=True)

In [None]:
# 0-based positions (file line number - 1) of every SpeciesNet label mentioning "pangolin",
# matched case-insensitively.
with open("./speciesnet_model/always_crop_99710272_22x8_v12_epoch_00148.labels.txt", "r") as f:
    pangolin_indices = [
        idx for idx, line in enumerate(f) if "pangolin" in line.lower()
    ]

# lines 882, 1405, 1406 (i.e. indices 881, 1404, 1405)
# 2f336cdf-a62f-4587-a516-6e6c74d07353;mammalia;pholidota;manidae;phataginus;tricuspis;white-bellied pangolin
# b1cefdc9-af34-4f28-b077-1186dd6b5072;mammalia;pholidota;manidae;;;pangolin family
# ade3ecab-c110-429a-849e-b6afdb290219;mammalia;pholidota;manidae;manis;;pangolin species

# pangolin_indices = [1404] # family level
# pangolin_indices = [1405] # species level

print(pangolin_indices)

In [None]:
import torch
from torch.utils.data import Dataset
from torchvision import transforms
import torchvision.transforms.functional as F
from torch.utils.data import DataLoader
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
)

from PIL import Image
import numpy as np
from speciesnet.utils import BBox

# plot image of misclassifications
def plot_misclassified(fold, pred, data):
    """Save example images of false negatives and false positives for one fold.

    Up to five images per error type are drawn side by side and saved as a PNG under
    ``misclassified_folder_path/misclassified_fold_<fold + 1>/``.

    Args:
        fold: fold id; displayed and used in the folder name as ``fold + 1``.
        pred: binary predictions (1 = pangolin, 0 = other), one per entry of ``data``.
            Anything ``np.asarray`` accepts works (list, numpy array, CPU tensor).
        data: list of ``(filepath, class_id)`` pairs in the same order as ``pred``.

    Raises:
        ValueError: if ``pred`` and ``data`` have different lengths. The indices would
            not line up, and ``zip`` would otherwise hide that silently.
    """
    binary_pred = np.asarray(pred).astype(int).ravel()
    binary_true = np.isin([label for _, label in data], pangolin_indices).astype(int)
    if binary_pred.shape[0] != binary_true.shape[0]:
        raise ValueError(
            f"pred has {binary_pred.shape[0]} entries but data has {binary_true.shape[0]}"
        )

    fn_indices = np.flatnonzero((binary_pred == 0) & (binary_true == 1)).tolist()   # false negatives
    fp_indices = np.flatnonzero((binary_pred == 1) & (binary_true == 0)).tolist()   # false positives

    # helper method to plot image
    def plot_images(indices, title):
        """Plot the first five samples in ``indices`` and save the figure."""
        if not indices:
            print("No Misclassified Samples Found: " + title)
            return

        print("Misclassified Samples Found: " + str(len(indices)) + " (" + title + ")")
        shown = indices[:5]

        # squeeze=False keeps `axes` 2-D even when only one image is shown
        fig, axes = plt.subplots(1, len(shown), figsize=(15, 5), squeeze=False)
        for ax, idx in zip(axes[0], shown):
            filename, true_label = data[idx]
            # context manager releases the file handle once the greyscale copy exists
            with Image.open(filename) as img_file:
                img = img_file.convert("L")
            ax.imshow(img, cmap="gray")
            ax.axis("off")
            # compare binary prediction with binary truth; keep the class id for reference
            ax.set_title(
                f"{title}\nPred: {binary_pred[idx]}, True: {binary_true[idx]} (class {true_label})",
                fontsize=10,
            )

        # path to store the misclassified images, name-specific to fold
        misclassified_image_path = misclassified_folder_path + "misclassified_fold_" + str(fold + 1) + "/"
        misclassified_image = misclassified_image_path + str(shown[-1]) + ".png"

        os.makedirs(misclassified_image_path, exist_ok=True)
        fig.savefig(misclassified_image, bbox_inches='tight')
        plt.close(fig)

    # plot images
    plot_images(fn_indices, "Fold " + str(fold + 1) + "- False Negatives")
    plot_images(fp_indices, "Fold " + str(fold + 1) + "- False Positives")

###############################################################################################################################################
# printing/logging methods (different for validation, training, and testing because they have different ways to log to csv files)
###############################################################################################################################################

def _binary_scores(true, pred, prob):
    """Compute per-class metrics for the binary pangolin (1) vs other (0) task.

    Args:
        true: binary ground-truth labels (1 = pangolin).
        pred: binary predictions, same length as ``true``.
        prob: pangolin score per sample (higher = more pangolin-like), same length.

    Returns:
        dict with ``accuracy`` and ``{pangolin,other}_{precision,recall,f1,auc}``.
        Both AUCs are NaN when ``true`` contains a single class. ROC AUC is undefined
        there and ``roc_auc_score`` would raise ValueError.
    """
    true = np.asarray(true).astype(int).ravel()
    pred = np.asarray(pred).astype(int).ravel()
    pangolin_prob = np.asarray(prob, dtype=float).ravel()

    scores = {"accuracy": accuracy_score(true, pred)}
    for class_name, pos_label in (("pangolin", 1), ("other", 0)):
        scores[class_name + "_precision"] = precision_score(true, pred, pos_label=pos_label, zero_division=0)
        scores[class_name + "_recall"] = recall_score(true, pred, pos_label=pos_label, zero_division=0)
        scores[class_name + "_f1"] = f1_score(true, pred, pos_label=pos_label, zero_division=0)

    if np.unique(true).size < 2:
        print("\tWARNING: only one class present in labels; AUC is undefined (logged as nan)")
        scores["pangolin_auc"] = scores["other_auc"] = float("nan")
    else:
        scores["pangolin_auc"] = roc_auc_score(true, pangolin_prob)
        # The "other" class must be ranked by its own score, 1 - p(pangolin). Scoring the
        # inverted labels with p(pangolin) itself always yields 1 - pangolin AUC.
        scores["other_auc"] = roc_auc_score(1 - true, 1.0 - pangolin_prob)
    return scores


def _print_scores(scores):
    """Print the metrics from ``_binary_scores`` in the notebook's tab-indented format."""
    for class_name in ("pangolin", "other"):
        print(f"\t{class_name} accuracy: {scores['accuracy']:.4f}")
        print(f"\t{class_name} precision: {scores[class_name + '_precision']:.4f}")
        print(f"\t{class_name} recall: {scores[class_name + '_recall']:.4f}")
        print(f"\t{class_name} f1-score: {scores[class_name + '_f1']:.4f}")
        print(f"\t{class_name} auc: {scores[class_name + '_auc']:.4f}")


def _score_row(scores):
    """Return metric values in the column order of the score CSV headers (after fold/epoch)."""
    return [
        scores["accuracy"], scores["accuracy"],
        scores["pangolin_recall"], scores["other_recall"],
        scores["pangolin_precision"], scores["other_precision"],
        scores["pangolin_f1"], scores["other_f1"],
        scores["pangolin_auc"], scores["other_auc"],
    ]


# prints score for validation
def val_print_and_log_scores(fold, epoch, true, pred, prob):
    """Print validation metrics for one epoch and append them to the validation CSV.

    Args:
        fold, epoch: identifiers written to the first two CSV columns.
        true, pred: binary labels / predictions (1 = pangolin).
        prob: pangolin score per sample, used for AUC.
    """
    scores = _binary_scores(true, pred, prob)
    _print_scores(scores)
    val_writer.writerow([fold, epoch, *_score_row(scores)])


# print and log score for training
def train_print_and_log_scores(fold, epoch, true, pred, prob):
    """Print training metrics for one epoch and append them to the training CSV.

    Args:
        fold, epoch: identifiers written to the first two CSV columns.
        true, pred: binary labels / predictions (1 = pangolin).
        prob: pangolin score per sample, used for AUC.
    """
    scores = _binary_scores(true, pred, prob)
    _print_scores(scores)
    train_writer.writerow([fold, epoch, *_score_row(scores)])


# print and log scores for testing
def test_print_and_log_scores(fold, true, pred, prob):
    """Print test metrics for one fold and append them to the test CSV.

    Args:
        fold: identifier written to the first CSV column.
        true, pred: binary labels / predictions (1 = pangolin).
        prob: pangolin score per sample, used for AUC.
    """
    scores = _binary_scores(true, pred, prob)
    _print_scores(scores)
    test_writer.writerow([fold, *_score_row(scores)])

# custom collate function for the data loader
def collate_fn(batch):
    """Merge a list of ImagesDataset samples into one batch dict.

    Images are stacked into a single tensor and labels become a long tensor. Image ids
    and detections stay plain Python lists, because detection lists vary in length per
    image.
    """
    return {
        "image": torch.stack([sample["image"] for sample in batch]),
        "label": torch.tensor([sample["label"] for sample in batch], dtype=torch.long),
        "image_id": [sample["image_id"] for sample in batch],
        "detections": [sample["detections"] for sample in batch],
    }

# custom image dataset
class ImagesDataset(Dataset):
    """Dataset that greyscales each frame and crops it to its first MegaDetector box.

    Args:
        data: sequence of ``(filepath, class_id)`` pairs.
        detector: SpeciesNet detector. ``preprocess(image)`` and
            ``predict(filepath, preprocessed)`` are called on it, and ``predict`` must
            return a dict with a ``"detections"`` list whose items carry a ``"bbox"``
            in image-relative (0..1) coordinates.

    Each item is a dict with:
        ``image``: normalised 3x480x480 tensor;
        ``label``: 0-d long tensor;
        ``image_id``: index into ``data``;
        ``detections``: the raw detector output list.
    """

    # Frames of exactly this (width, height) carry the WI info bar along the bottom edge.
    INFO_BAR_IMAGE_SIZE = (5376, 3024)
    # Fraction of the frame height occupied by that info bar.
    INFO_BAR_HEIGHT_FRACTION = 0.05

    def __init__(self, data, detector):
        self.data = data
        self.detector = detector
        self.transform = transforms.Compose(
            [
                transforms.Resize((480, 480)),
                transforms.ToTensor(),
                transforms.Normalize(
                    mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)
                )
            ]
        )

    def __getitem__(self, index):
        filepath, label = self.data[index]
        # convert() returns an independent copy, so the file can be closed right away
        with Image.open(filepath) as raw_image:
            image = raw_image.convert("RGB")

        # preprocessing: crop out bottom bar and convert to greyscale; convert back to
        # RGB so the tensor keeps 3 channels
        image = self._crop_bottom_bar(image)
        image = image.convert("L").convert("RGB")

        label = torch.tensor(label, dtype=torch.long)

        # preprocessing with MD
        image_md = self.detector.preprocess(image)
        detection_result = self.detector.predict(filepath, image_md)
        detections = detection_result["detections"]

        image_bbox = image
        if detections:
            # NOTE(review): uses the first detection, which assumes the detector returns
            # them sorted by descending confidence; confirm against the detector.
            bbox = BBox(*detections[0]["bbox"])
            image_bbox = image.crop(self._square_box(bbox, image.width, image.height))

        # ensure final image is square (the box may still be non-square near image edges)
        image_bbox = self._pad_to_square(image_bbox)

        # transform image
        image = self.transform(image_bbox)

        return {"image": image, "label": label, "image_id": index, "detections": detections}

    def __len__(self):
        return len(self.data)

    @staticmethod
    def _square_box(bbox, image_width, image_height):
        """Convert a relative box to pixels and grow its shorter side toward a square.

        The growth is split evenly between both sides and clamped to the image, so a
        box touching an edge can remain non-square.

        Args:
            bbox: object with ``xmin``, ``ymin``, ``width``, ``height`` in 0..1 units.
            image_width, image_height: image size in pixels.

        Returns:
            ``(x0, y0, x1, y1)`` pixel crop box (floats).
        """
        x0 = bbox.xmin * image_width
        y0 = bbox.ymin * image_height
        x1 = (bbox.xmin + bbox.width) * image_width
        y1 = (bbox.ymin + bbox.height) * image_height

        width = x1 - x0
        height = y1 - y0
        diff = abs(width - height)

        if width > height:
            y0 = max(0, y0 - diff // 2)
            y1 = min(image_height, y1 + (diff - diff // 2))
        elif height > width:
            x0 = max(0, x0 - diff // 2)
            x1 = min(image_width, x1 + (diff - diff // 2))

        return (x0, y0, x1, y1)

    # crops out WI info bar at bottom of image
    def _crop_bottom_bar(self, image):
        """Remove the bottom info bar, but only from frames of exactly INFO_BAR_IMAGE_SIZE.

        The previous guard (``width != W and height != H``) also cropped any frame that
        merely shared the width or the height with the barred camera format.
        """
        width, height = image.size
        if (width, height) != tuple(self.INFO_BAR_IMAGE_SIZE):
            return image
        bar_height = int(height * self.INFO_BAR_HEIGHT_FRACTION)
        return image.crop((0, 0, width, height - bar_height))

    # pads non square images with black pixels
    def _pad_to_square(self, image):
        """Centre ``image`` on a black square canvas whose side is its longer dimension."""
        width, height = image.size
        if width == height:
            return image

        max_side = max(width, height)
        image_padded = Image.new("RGB", (max_side, max_side), color=(0, 0, 0))
        x = (max_side - width) // 2
        y = (max_side - height) // 2
        image_padded.paste(image, (x, y))
        return image_padded

# store feature maps for each layer
# Latest activation captured for each hooked layer, keyed by layer name.
feature_maps = {}


def hook_fn(module, input, output, name):
    """Forward-hook body: keep a graph-detached copy of ``output`` under ``name``.

    ``module`` and ``input`` are part of the hook signature but are not used. Each call
    overwrites the previous entry for ``name``.
    """
    activation = output.detach()
    feature_maps[name] = activation


# Module names (as reported by model.named_modules()) whose forward outputs are captured
# by hook_fn for feature-map visualisation, ordered from shallow to deep.
# NOTE(review): these names must match the loaded classifier exactly; a missing name
# raises KeyError when the hooks are registered.
layers_to_visualize = [
    'SpeciesNet/efficientnetv2-m/stem_conv/Conv2D.1',               # very early
    'SpeciesNet/efficientnetv2-m/block2a_expand_conv/Conv2D',       # early mid
    'SpeciesNet/efficientnetv2-m/block4a_project_conv/Conv2D',      # mid
    'SpeciesNet/efficientnetv2-m/block6b_project_conv/Conv2D',      # late mid
    'SpeciesNet/efficientnetv2-m/top_conv/Conv2D'                   # final conv
]



In [None]:
# read cross validation indices from stored csv file
# read cross validation indices from stored csv file
# (columns used below: "fold", "type" in {"train", "val", "test"}, and "index" into `data`)
df_splits = pd.read_csv("v2-multiclass_splits.csv")

# Smoke-test subsampling: keep every `split_subsample_step`-th row of the split table.
# Set to 1 for a full run. Large steps can leave a fold's train/val/test split tiny,
# empty, or single-class.
split_subsample_step = 100
df_splits = df_splits.iloc[::split_subsample_step]
print(df_splits)

In [None]:
label_file = "./speciesnet_model/always_crop_99710272_22x8_v12_epoch_00148.labels.txt"

# One SpeciesNet label per line, whitespace-stripped. The list position is the class id
# used for the loss weights and the ensembler's "classes".
with open(label_file, "r") as f:
    class_names = list(map(str.strip, f))

len(class_names)

In [None]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")  # GPU when visible, else CPU

In [None]:
# Cross-entropy weights: 1 for every label, 5 for each pangolin label, so pangolin
# mistakes cost five times as much in the training loss.
class_weights = torch.ones(len(class_names))
class_weights[torch.as_tensor(pangolin_indices, dtype=torch.long)] = 5.0
class_weights = class_weights.to(device)

In [None]:
%matplotlib inline

def _pangolin_probability(probs):
    """Return the total softmax probability assigned to the pangolin labels, per row.

    Args:
        probs: classifier softmax output, shape (batch, num_classes).

    Returns:
        1-D float CPU tensor of length ``batch``. It is all zeros when
        ``pangolin_indices`` is empty.
    """
    if not pangolin_indices:
        return torch.zeros(probs.shape[0])
    return probs[:, pangolin_indices].sum(dim=1).detach().cpu()


def _record_batch(probs, image_ids, detections, split_data,
                  filepaths, classifier_results, detector_results):
    """Append one batch of classifier/detector outputs in the layout the ensembler consumes.

    Args:
        probs: (batch, num_classes) softmax scores.
        image_ids: dataset indices of the batch items, i.e. indices into ``split_data``.
        detections: per-item detector outputs from ``collate_fn`` (None -> empty lists).
        split_data: ``(filepath, class_id)`` list backing the dataset.
        filepaths, classifier_results, detector_results: accumulators, mutated in place.
    """
    scores = probs.detach().cpu().tolist()      # one device->host copy per batch
    for i, image_id in enumerate(image_ids):
        filepath, _ = split_data[image_id]      # get filepath from index
        filepaths.append(filepath)
        classifier_results[filepath] = {
            "classifications": {
                "classes": class_names,
                "scores": scores[i],
            }
        }
        detector_results[filepath] = {
            "detections": detections[i] if detections is not None else []
        }


def _ensemble_predictions(ensembler, filepaths, classifier_results, detector_results):
    """Combine classifier and detector outputs with the SpeciesNet ensembler and binarise them.

    Returns:
        long tensor in ``filepaths`` order: 1 where the ensemble's predicted label text
        contains "pangolin" (case-insensitive), else 0.

    Raises:
        RuntimeError: if the ensembler returns a different number of predictions than
            ``filepaths``, because they could no longer be aligned with the labels.
    """
    ensemble_outputs = list(ensembler.combine(
        filepaths=filepaths,
        classifier_results=classifier_results,
        detector_results=detector_results,
        geolocation_results={},
        partial_predictions={},
    ))
    if len(ensemble_outputs) != len(filepaths):
        raise RuntimeError(
            f"ensembler returned {len(ensemble_outputs)} predictions for {len(filepaths)} images"
        )
    return torch.tensor(
        [1 if "pangolin" in str(o["prediction"]).lower() else 0 for o in ensemble_outputs],
        dtype=torch.long,
    )


for fold in df_splits["fold"].unique():
    image_index = 0     # indices for feature maps

    # get the stored indices
    fold_rows = df_splits[df_splits["fold"] == fold]
    train_index = fold_rows[fold_rows["type"] == "train"]["index"].values
    val_index = fold_rows[fold_rows["type"] == "val"]["index"].values
    test_index = fold_rows[fold_rows["type"] == "test"]["index"].values

    print(f"\n\nfold {fold + 1} -------------------------------------------------------------------------------------------------")

    # An empty split would crash below (loss averages divide by zero, torch.cat([]) fails),
    # and subsampling df_splits for quick tests can easily produce one.
    if len(train_index) == 0 or len(val_index) == 0 or len(test_index) == 0:
        print(f"\tskipping fold {fold + 1}: empty split "
              f"(train={len(train_index)}, val={len(val_index)}, test={len(test_index)})")
        continue

    min_val_loss = sys.float_info.max
    # Build the checkpoint path once and use it for both saving and loading. Previously
    # program_prefix was baked into the file name and two different joins were used, so
    # an absolute program_prefix made save and load disagree (and escape model_folder_path).
    model_name = "model_fold" + str(fold + 1) + ".pth"
    model_path = os.path.join(model_folder_path, model_name)

    # load model
    speciesnet_model = SpeciesNet(speciesnet_model_path)
    speciesnet_classifier = speciesnet_model.classifier
    speciesnet_detector = speciesnet_model.detector
    speciesnet_ensembler = speciesnet_model.ensemble

    base_model = speciesnet_classifier.model
    # Replacement head with one logit per SpeciesNet label. It is sized from the label file
    # so it always matches class_weights and class_names.
    # NOTE(review): confirm the loaded classifier's forward actually uses `.fc`; assigning
    # the attribute does not by itself rewire a converted graph model.
    base_model.fc = nn.Sequential(
        nn.Linear(2048, 100),
        nn.ReLU(inplace=True),
        nn.Dropout(0.1),
        nn.Linear(100, len(class_names)),
    )

    model = base_model.to(device)

    # fine-tune only the last block, the top conv and the new head
    for name, param in model.named_parameters():
        param.requires_grad = 'block7' in name or 'top_conv' in name or 'fc' in name

    criterion = nn.CrossEntropyLoss(weight=class_weights)
    optimizer = optim.SGD(
        [param for param in model.parameters() if param.requires_grad],
        lr=0.0001,
        momentum=0.9,
    )

    # split data and get dataloaders
    train_data = [data[i] for i in train_index]
    val_data = [data[i] for i in val_index]
    test_data = [data[i] for i in test_index]

    train_dataset = ImagesDataset(train_data, speciesnet_detector)
    val_dataset = ImagesDataset(val_data, speciesnet_detector)
    test_dataset = ImagesDataset(test_data, speciesnet_detector)

    # Training batches are shuffled. Labels and filepaths are collected per batch, so they
    # stay aligned with the predictions. Val/test keep dataset order, because their labels
    # are read straight from *_data below.
    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)
    val_dataloader = DataLoader(val_dataset, batch_size=32, collate_fn=collate_fn)
    test_dataloader = DataLoader(test_dataset, batch_size=32, collate_fn=collate_fn)

    # register hooks for each layer (removed again at the end of the fold)
    named_modules = dict(model.named_modules())
    hooks = [
        named_modules[layer_name].register_forward_hook(
            lambda module, input, output, name=layer_name: hook_fn(module, input, output, name)
        )
        for layer_name in layers_to_visualize
    ]

    for epoch in range(1, num_epochs + 1):
        start_time = time.time()

        print(f"\nepoch {epoch}")

        #############################################################################################################################################
        # training
        #############################################################################################################################################
        print("TRAINING")

        model.train()

        training_loss = 0.0
        training_num_loss = 0

        all_labels = []             # true class ids, in the order batches were seen
        train_pangolin_probs = []   # classifier p(pangolin), same order
        classifier_results = {}
        detector_results = {}
        filepaths = []

        # iterate through the dataloader batches. tqdm keeps track of progress.
        for batch in tqdm(train_dataloader, total=len(train_dataloader)):
            feature_map_images = batch['image']     # NCHW batch kept for the feature-map plots below

            # NCHW -> NHWC: the classifier is fed channels-last input
            images = batch['image'].permute(0, 2, 3, 1).to(device)
            labels = batch['label'].long().to(device)
            image_ids = batch['image_id']

            # 1) zero out the parameter gradients so that gradients from previous batches are not used in this step
            optimizer.zero_grad()

            # 2) run the forward step on this batch of images
            logits = model(images)

            # 3) compute the loss
            loss = criterion(logits, labels)

            # 4) compute our gradients and update our weights
            loss.backward()
            optimizer.step()

            training_loss += loss.item()
            training_num_loss += 1

            probs = torch.softmax(logits.detach(), dim=1)
            all_labels.append(labels.detach().cpu())
            train_pangolin_probs.append(_pangolin_probability(probs))

            # storing information for ensembling method
            _record_batch(probs, image_ids, batch.get("detections"), train_data,
                          filepaths, classifier_results, detector_results)

        # print and log scores
        training_loss /= training_num_loss
        print(f"\ttraining loss: {training_loss:.4f}")

        train_true = torch.cat(all_labels, dim=0).numpy()       # get the true labels
        binary_train_true = np.isin(train_true, pangolin_indices).astype(int)

        # Binary predictions come from the ensemble. AUC uses the classifier's summed
        # pangolin probability instead, because the ensemble's prediction_score is the
        # confidence of whichever label it picked, and a confident "other" prediction is
        # not a pangolin score.
        train_pred = _ensemble_predictions(speciesnet_ensembler, filepaths, classifier_results, detector_results)
        train_prob = torch.cat(train_pangolin_probs)

        train_print_and_log_scores(fold, epoch, binary_train_true, train_pred, train_prob)

        # outputting images for feature extraction
        # num_images = 10
        # random_indices = np.random.choice(len(feature_map_images), num_images, replace=False)

        # mean = [0.485, 0.456, 0.406]
        # std = [0.229, 0.224, 0.225]

        # label_1_indices = [i for i in range(len(image_ids)) if train_data[image_ids[i]][1] == 1]
        # label_0_indices = [i for i in range(len(image_ids)) if train_data[image_ids[i]][1] == 0]

        # sampled_indices = random.sample(label_1_indices, 3) + random.sample(label_0_indices, 7)

        # for img_idx in sampled_indices:
        #     file_name, label = train_data[image_ids[img_idx]]
        #     fig, axes = plt.subplots(1, len(layers_to_visualize) + 1, figsize=(30, 5))
        #     fig.suptitle(f"Label: {label}", fontsize=20)

        #     image = (feature_map_images[img_idx].cpu().numpy().transpose(1, 2, 0)) * std + mean
        #     image = np.clip(image, 0, 1)

        #     # image
        #     axes[0].imshow(image, cmap='gray')
        #     axes[0].set_title("original")
        #     axes[0].axis("off")

        #     for i, layer_name in enumerate(layers_to_visualize):
        #         fmap = feature_maps[layer_name][img_idx]
        #         fmap = fmap.mean(dim=0)
        #         fmap = (fmap - fmap.min()) / (fmap.max() - fmap.min())

        #         if fmap.ndimension() == 3:
        #             num_channels = fmap.shape[0]
        #             random_channel = np.random.randint(num_channels)
        #             fmap_to_show = fmap[random_channel].cpu().numpy()
        #         else:
        #             random_channel = 0
        #             fmap_to_show = fmap.cpu().numpy()

        #         fmap_to_show = (fmap_to_show - fmap_to_show.min()) / (fmap_to_show.max() - fmap_to_show.min() + 1e-5)
        #         axes[i + 1].imshow(fmap_to_show, cmap='viridis')

        #         title = '/'.join(layer_name.split('/')[-2:])

        #         axes[i + 1].set_title(f"{title}")
        #         axes[i + 1].axis("off")

        #     feature_map_image_path  = feature_map_folder_path + "feature_map_fold_" + str(fold + 1) + "/epoch_" + str(epoch) + "/"      # path to store the feature map images, name-specific to fold and epoch 
        #     feature_map_image =  feature_map_image_path + str(image_index) + ".png"

        #     image_index += 1

        #     os.makedirs(os.path.dirname(feature_map_image), exist_ok=True)
        #     plt.savefig(feature_map_image, bbox_inches='tight')
        #     plt.close()

        #############################################################################################################################################
        # validation
        #############################################################################################################################################
        print("VALIDATION")

        val_loss = 0.0
        val_num_loss = 0

        val_pangolin_probs = []
        classifier_results = {}
        detector_results = {}
        filepaths = []

        model.eval()

        # iterate through dataloader and run the model
        with torch.no_grad():
            for batch in tqdm(val_dataloader, total=len(val_dataloader)):
                images = batch['image'].permute(0, 2, 3, 1).to(device)
                labels = batch['label'].long().to(device)

                logits = model(images)
                probs = torch.softmax(logits, dim=1)

                val_loss += criterion(logits, labels).item()
                val_num_loss += 1

                val_pangolin_probs.append(_pangolin_probability(probs))

                # storing information for ensembling method
                _record_batch(probs, batch['image_id'], batch.get("detections"), val_data,
                              filepaths, classifier_results, detector_results)

        # print scores
        val_loss /= val_num_loss
        print(f"\tvalidation loss: {val_loss:.4f}")

        val_true = [label for _, label in val_data]          # get the true labels
        binary_val_true = np.isin(val_true, pangolin_indices).astype(int)

        val_pred = _ensemble_predictions(speciesnet_ensembler, filepaths, classifier_results, detector_results)
        val_prob = torch.cat(val_pangolin_probs)

        val_print_and_log_scores(fold, epoch, binary_val_true, val_pred, val_prob)

        # log train and val loss
        loss_writer.writerow([fold, epoch, training_loss, val_loss])

        # Keep the checkpoint with the lowest validation loss. Epoch 1 always saves, so a
        # checkpoint exists for testing even if the loss is NaN (NaN never compares lower).
        if epoch == 1 or val_loss < min_val_loss:
            torch.save(model.state_dict(), model_path)
            min_val_loss = val_loss

        # get and log the elapsed time for current epoch
        epoch_time = time.time() - start_time
        time_writer.writerow([fold, epoch, epoch_time])

    #############################################################################################################################################
    # testing
    #############################################################################################################################################
    print("TESTING")

    # Reload the best weights into the current model. The architecture is identical, so
    # constructing a second SpeciesNet (and detector) is unnecessary.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    test_pangolin_probs = []
    classifier_results = {}
    detector_results = {}
    filepaths = []

    # iterate through dataloader and run the model
    with torch.no_grad():
        for batch in tqdm(test_dataloader, total=len(test_dataloader)):
            images = batch['image'].permute(0, 2, 3, 1).to(device)

            logits = model(images)
            probs = torch.softmax(logits, dim=1)

            test_pangolin_probs.append(_pangolin_probability(probs))

            # storing information for ensembling method
            _record_batch(probs, batch['image_id'], batch.get("detections"), test_data,
                          filepaths, classifier_results, detector_results)

    # print and log scores
    test_true = [label for _, label in test_data]              # get the true labels
    binary_test_true = np.isin(test_true, pangolin_indices).astype(int)

    test_pred = _ensemble_predictions(speciesnet_ensembler, filepaths, classifier_results, detector_results)
    test_prob = torch.cat(test_pangolin_probs)

    test_print_and_log_scores(fold, binary_test_true, test_pred, test_prob)

    # error analysis on misclassified images
    plot_misclassified(fold, test_pred, test_data)

    # Confusion matrix of the binary task. The true labels must be binarised too; the
    # multiclass ids would otherwise be plotted against 0/1 predictions.
    fig, ax = plt.subplots(figsize=(10, 10))
    ConfusionMatrixDisplay.from_predictions(
        binary_test_true,
        test_pred.numpy(),
        labels=[0, 1],
        display_labels=["other", "pangolin"],
        ax=ax,
        xticks_rotation=90,
        colorbar=True,
        normalize='true',
    )
    ax.set_title("confusion matrix - fold " + str(fold + 1))

    cm_image = os.path.join(confusion_matrix_folder_path, "cm_fold_" + str(fold + 1) + ".png")
    os.makedirs(os.path.dirname(cm_image), exist_ok=True)
    fig.savefig(cm_image, bbox_inches='tight')
    plt.close(fig)

    # detach this fold's hooks and drop captured activations so they don't pile up across folds
    for hook in hooks:
        hook.remove()
    feature_maps.clear()

    # flush writes to logging files (including the validation log)
    train_csv.flush()
    val_csv.flush()
    time_csv.flush()
    test_csv.flush()
    loss_csv.flush()


In [None]:
# Flush every log file first, then close them all.
run_log_files = (train_csv, val_csv, time_csv, test_csv, loss_csv)

for log_file in run_log_files:
    log_file.flush()

for log_file in run_log_files:
    log_file.close()