### Import Required Modules and Functions

In [1]:
import numpy as np

import torch
import torchvision
import torchvision.transforms as T

from torch.utils.data import Dataset, DataLoader
from PIL import Image

import os
import json

import matplotlib.pyplot as plt

### Set Device to GPU

In [2]:
# Toggle for GPU usage and the default floating-point type for tensors.
USE_GPU = True
dtype = torch.float32

# Use CUDA only when it is both requested and actually available;
# otherwise everything runs on the CPU.
device = torch.device('cuda' if USE_GPU and torch.cuda.is_available() else 'cpu')


### Prepare Data Loaders
##### Ensure that the WildCam_3classes folder is in the notebook's working directory (all paths below are relative to it)
##### Run Brightness_subset_maker.ipynb to create "brightest" image folder

In [4]:
class WildCamDataset(Dataset):
    """Image dataset that yields ``(image, label, location)`` triples.

    Args:
        img_paths: Sequence of image file names, relative to ``directory``.
        annotations: Mapping with ``'labels'`` and ``'locations'`` entries,
            each of which maps an image file name to its value.
        transform: Callable applied to the RGB PIL image. ``None`` (the
            default) means ``T.ToTensor()``.
        directory: Folder that contains the image files.
    """

    def __init__(self, img_paths, annotations, transform=None, directory='WildCam_3classes/train'):
        self.img_paths = img_paths
        self.annotations = annotations
        # Resolve the default lazily so no transform object is created at
        # class-definition time and shared through the signature.
        self.transform = T.ToTensor() if transform is None else transform
        self.dir = directory

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.img_paths)

    def __getitem__(self, index):
        """Load image ``index`` and return ``(X, y, loc)``.

        Raises:
            KeyError: If the image name is missing from the annotations.
        """
        name = self.img_paths[index]
        # os.path.join avoids a doubled separator when ``directory`` already
        # ends with a slash.
        path = os.path.join(self.dir, name)
        # The context manager closes the underlying file right away;
        # ``convert`` returns an independent, fully loaded copy.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        X = self.transform(img)
        y = self.annotations['labels'][name]
        loc = self.annotations['locations'][name]
        return X, y, loc
    
# Per-channel normalization; these values match the widely used ImageNet
# channel statistics.
normalize = T.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Preprocessing pipeline: resize to 112x112, convert to a tensor, normalize.
transform = T.Compose([T.Resize((112, 112)), T.ToTensor(), normalize])

# DataLoader keyword arguments: training batches are shuffled,
# validation/test batches keep the dataset order.
param_train = dict(batch_size=256, shuffle=True)
param_valtest = dict(batch_size=256, shuffle=False)

# Load the annotations once; the context manager closes the file handle
# instead of leaving it to the garbage collector.
with open('WildCam_3classes/annotations.json') as annotations_file:
    annotations = json.load(annotations_file)

# File names are sorted so the (unshuffled) val/test order is deterministic.
# The dataset directories are the same paths that are listed, without a
# trailing slash, so image paths are not built with a doubled separator.
train_images = sorted(os.listdir('WildCam_3classes/train'))
train_dset = WildCamDataset(train_images, annotations, transform, directory='WildCam_3classes/train')
train_loader = DataLoader(train_dset, **param_train)

val_images = sorted(os.listdir('WildCam_3classes/val'))
val_dset = WildCamDataset(val_images, annotations, transform, directory='WildCam_3classes/val')
val_loader = DataLoader(val_dset, **param_valtest)

test_images = sorted(os.listdir('WildCam_3classes/test'))
test_dset = WildCamDataset(test_images, annotations, transform, directory='WildCam_3classes/test')
test_loader = DataLoader(test_dset, **param_valtest)

### Grab the Brightest Images and Save Them

In [3]:
def get_brightest_images(loader, percentage=20, num_to_display=5):
    """Select the brightest images of every class and preview a few per class.

    Brightness is the mean over all channels and pixels of an image exactly
    as the loader yields it. For each distinct label, the images whose
    brightness is at or above that class's ``100 - percentage`` percentile are
    kept (ties at the threshold are all kept, so slightly more than
    ``percentage`` percent may be returned).

    Args:
        loader: Iterable of ``(images, labels, locations)`` batches. Each
            element must expose ``.numpy()``; ``images`` is reduced over axes
            1-3, so it must be 4-dimensional with the batch on axis 0.
        percentage: Share (0-100) of each class to keep.
        num_to_display: Maximum number of selected images to plot per class.
            Zero or a negative value disables plotting.

    Returns:
        Tuple ``(brightest_images, brightest_labels, brightest_locations)`` of
        three parallel lists, grouped by class in ascending label order.
    """

    def _merge(chunks):
        # np.concatenate rejects an empty list, so an empty loader needs a
        # fallback to keep returning three empty lists.
        return np.concatenate(chunks) if chunks else np.array([])

    brightness_chunks = []
    image_chunks = []
    label_chunks = []
    location_chunks = []

    for images, labels, locations in loader:
        images = images.numpy()
        brightness_chunks.append(images.mean(axis=(1, 2, 3)))
        image_chunks.append(images)
        label_chunks.append(labels.numpy())
        location_chunks.append(locations.numpy())

    all_brightness = _merge(brightness_chunks)
    all_images = _merge(image_chunks)
    all_labels = _merge(label_chunks)
    all_locations = _merge(location_chunks)

    brightest_images = []
    brightest_labels = []
    brightest_locations = []
    # Positions (into the three result lists) of the images to preview.
    preview_positions = []
    per_class_preview = max(num_to_display, 0)

    for cls in np.unique(all_labels):
        class_indices = np.where(all_labels == cls)[0]
        class_brightness = all_brightness[class_indices]

        threshold = np.percentile(class_brightness, 100 - percentage)
        selected = class_indices[class_brightness >= threshold]

        first_position = len(brightest_images)
        brightest_images.extend(all_images[selected])
        brightest_labels.extend(all_labels[selected])
        brightest_locations.extend(all_locations[selected])

        # Preview from every class. Taking the first N entries of the
        # concatenated result would show (almost) only the first class.
        last_position = min(first_position + per_class_preview, len(brightest_images))
        preview_positions.extend(range(first_position, last_position))

    print(f"Displaying {min(num_to_display, len(brightest_images))} images per class from the top {percentage}% brightest images:")
    for i in preview_positions:
        plt.figure(figsize=(3, 3))
        # Channel-first -> channel-last for imshow.
        image = brightest_images[i].transpose(1, 2, 0)
        # NOTE(review): clipping to [0, 1] only gives a faithful preview when
        # pixel values are in that range; if the loader's transform applies
        # mean/std normalization, the preview will look distorted.
        plt.imshow(np.clip(image, 0, 1))
        plt.title(f"Label: {brightest_labels[i]} Location: {brightest_locations[i]}")
        plt.axis('off')
        plt.show()

    return brightest_images, brightest_labels, brightest_locations

In [None]:
# Pick the top 20% brightest test images of each class; num_to_display
# controls how many of them are previewed.
brightest_images, brightest_labels, brightest_locations = get_brightest_images(
    test_loader, percentage=20, num_to_display=5
)

In [4]:
def save_brightest_images_and_labels(images, labels, locations, output_dir="WildCam_3classes/brightest",
                                     labels_json_path=None):
    """Write images as PNG files and their labels/locations as one JSON file.

    Image ``i`` is saved as ``image_{i}.png`` inside ``output_dir``. The JSON
    file holds a single object with two keys, ``"labels"`` and
    ``"locations"``, each mapping the PNG file name to an integer.

    Args:
        images: Sequence of channel-first arrays; values are clipped to
            ``[0, 1]`` and scaled to 8-bit before saving.
        labels: Sequence of integer-convertible labels, parallel to ``images``.
        locations: Sequence of integer-convertible locations, parallel to
            ``images``.
        output_dir: Folder for the PNG files; created if missing.
        labels_json_path: Destination of the JSON file. ``None`` (default)
            means ``WildCam_3classes/brightest_labels.json``.
    """
    if labels_json_path is None:
        labels_json_path = os.path.join("WildCam_3classes", "brightest_labels.json")

    os.makedirs(output_dir, exist_ok=True)
    json_parent = os.path.dirname(labels_json_path)
    if json_parent:
        os.makedirs(json_parent, exist_ok=True)

    labels_dict = {}
    locations_dict = {}

    for i, (image, label, location) in enumerate(zip(images, labels, locations)):
        # Channel-first -> channel-last, then scale [0, 1] to [0, 255].
        pixels = np.clip(image.transpose(1, 2, 0), 0, 1) * 255
        # Round before the cast: a bare astype truncates, so float round-off
        # such as 254.99999 would be stored as 254.
        pil_image = Image.fromarray(np.rint(pixels).astype(np.uint8))

        image_filename = f"image_{i}.png"
        pil_image.save(os.path.join(output_dir, image_filename))

        # Plain ints: numpy scalars are not JSON serializable.
        labels_dict[image_filename] = int(label)
        locations_dict[image_filename] = int(location)

    # One JSON document. Dumping two objects back to back into the same file
    # yields content that json.load cannot parse.
    with open(labels_json_path, "w") as json_file:
        json.dump({"labels": labels_dict, "locations": locations_dict}, json_file, indent=4)

    print(f"Saved {len(images)} images and labels to {output_dir}.")

In [None]:
# Persist the selected images and their metadata.
save_brightest_images_and_labels(
    brightest_images,
    brightest_labels,
    brightest_locations,
    output_dir="WildCam_3classes/brightest",
)