### Import Required Modules and Functions

In [34]:
import numpy as np

import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as T
from torchvision.transforms import ToPILImage

from torch.utils.data import Dataset, DataLoader
from PIL import Image

import os
import json

import random
import matplotlib.pyplot as plt

### Select Device (GPU if Available, Otherwise CPU)

In [13]:
# Runtime configuration: set USE_GPU to False to force CPU execution.
USE_GPU = True
dtype = torch.float32

# CUDA is picked only when it is both requested and actually available;
# every other combination falls back to the CPU.
device = torch.device('cuda' if (USE_GPU and torch.cuda.is_available()) else 'cpu')


### Prepare Data Loaders
##### Ensure that the WildCam_3classes folder is in the notebook's working directory (all data paths below are relative to it)
##### Run Brightness_subset_maker.ipynb to create the "brightest" image folder

In [16]:
class WildCamDataset(Dataset):
    """Dataset yielding ``(image, label, location)`` triples for WildCam images.

    Args:
        img_paths: sequence of image file names, relative to ``directory``.
        annotations: mapping with ``'labels'`` and ``'locations'`` entries,
            each of which is indexed by image file name.
        transform: callable applied to the RGB ``PIL.Image`` before it is
            returned; defaults to ``T.ToTensor()``.
        directory: folder that contains the image files.
    """

    def __init__(self, img_paths, annotations, transform=T.ToTensor(), directory='WildCam_3classes/train'):
        self.img_paths = img_paths
        self.annotations = annotations
        self.transform = transform
        self.dir = directory

    def __len__(self):
        """Return the number of images in this split."""
        return len(self.img_paths)

    def __getitem__(self, index):
        """Load one image and return ``(X, y, loc)``.

        ``X`` is the transformed image, ``y`` the entry of
        ``annotations['labels']`` and ``loc`` the entry of
        ``annotations['locations']`` for this file name.
        """
        name = self.img_paths[index]
        path = '{}/{}'.format(self.dir, name)
        # convert() returns an independent, fully loaded image, so the
        # underlying file can be closed as soon as the conversion is done
        # instead of waiting for garbage collection.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        X = self.transform(img)
        y = self.annotations['labels'][name]
        loc = self.annotations['locations'][name]
        return X, y, loc
    
# Preprocessing shared by every split: resize to 112x112, convert to a tensor,
# then normalise each channel with the mean/std below (these match the
# commonly published ImageNet channel statistics).
normalize = T.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
transform = T.Compose([
            T.Resize((112,112)),
            T.ToTensor(),
            normalize
])

# Only the training loader shuffles; validation/test order stays fixed.
param_train = {
    'batch_size': 256,
    'shuffle': True
    }

param_valtest = {
    'batch_size': 256,
    'shuffle': False
    }


def _load_json(path):
    """Parse a JSON file and close it promptly.

    ``json.load(open(path))`` would leave the file handle open until the
    garbage collector reclaims it.
    """
    with open(path) as fh:
        return json.load(fh)


def _build_split(subdir, labels, loader_params):
    """Return ``(image_names, dataset, loader)`` for ``WildCam_3classes/<subdir>``.

    File names are sorted so that the sample order is reproducible.
    """
    folder = 'WildCam_3classes/{}'.format(subdir)
    images = sorted(os.listdir(folder))
    dset = WildCamDataset(images, labels, transform, directory=folder)
    return images, dset, DataLoader(dset, **loader_params)


annotations = _load_json('WildCam_3classes/annotations.json')

train_images, train_dset, train_loader = _build_split('train', annotations, param_train)
val_images, val_dset, val_loader = _build_split('val', annotations, param_valtest)
test_images, test_dset, test_loader = _build_split('test', annotations, param_valtest)

# The "brightest" subset ships with its own label file.
brightest_labels = _load_json('WildCam_3classes/brightest_labels.json')

bright_images, bright_dset, bright_loader = _build_split('brightest', brightest_labels, param_valtest)

### Define the ResNet+ and BrightNet model classes so that the saved models can be loaded

In [17]:
# ResNet+ model
# Hyperparameters, grouped by role. Within this cell only `learning_rate` is
# read (when the optimizer is built); the other values are just assigned here.
channel_1, channel_2, channel_3 = 64, 128, 256
hidden_layer_1, hidden_layer_2 = 256, 128
learning_rate = 1e-3
epochs = 5
dropout_rate = 0.4

class BasicBlock(nn.Module):
    """Residual block: two 3x3 conv + batch-norm stages plus a skip connection.

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by both convolutions.
        stride: stride of the first convolution (the second always uses 1).
        downsample: optional module applied to the input to build the skip
            branch; when ``None`` the input itself is used.
        use_se: when true, an ``SEBlock`` is applied to the main branch
            before the skip connection is added.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1, downsample=None, use_se=False):
        super().__init__()
        conv_kwargs = dict(kernel_size=3, padding=1, bias=False)
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, **conv_kwargs)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, stride=1, **conv_kwargs)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.downsample = downsample
        self.use_se = use_se
        if use_se:
            self.se_block = SEBlock(out_channels)

    def forward(self, x):
        """Return ``relu(main_branch(x) + skip(x))``."""
        residual = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        if self.use_se:
            out = self.se_block(out)

        out += residual
        return self.relu(out)

class SEBlock(nn.Module):
    """Channel gate: rescales each channel by a learned weight in (0, 1).

    The input is averaged over its two spatial dimensions, passed through
    ``fc1 -> ReLU -> fc2 -> Sigmoid`` and the resulting per-channel weights
    multiply the input.

    Args:
        channels: number of channels of the input tensor.
        reduction: divisor used to size the hidden layer
            (``channels // reduction`` units, but never fewer than one).
    """

    def __init__(self, channels, reduction=16):
        super().__init__()
        # channels // reduction is 0 whenever channels < reduction; a
        # zero-width hidden layer would make the gate depend only on the fc2
        # bias and ignore the input. Keep at least one hidden unit. Widths are
        # unchanged for every channels >= reduction.
        hidden = max(1, channels // reduction)
        self.fc1 = nn.Linear(channels, hidden)
        self.fc2 = nn.Linear(hidden, channels)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` scaled channel-wise; ``x`` must be 4-dimensional."""
        batch, channels, _, _ = x.size()
        y = x.mean((2, 3))  # average over the two trailing dimensions
        y = self.relu(self.fc1(y))
        y = self.sigmoid(self.fc2(y)).view(batch, channels, 1, 1)
        return x * y

class BrightResNet18(nn.Module):
    """ResNet+ classifier: 9x9 conv stem followed by four stages of ``BasicBlock``.

    Each stage holds two blocks; stages 3 and 4 are built with ``use_se=True``.
    The result is globally average-pooled and mapped to ``num_classes`` scores.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        # Running channel count, updated by _make_layer as stages are added.
        self.in_channels = 64

        # Stem
        self.conv1 = nn.Conv2d(3, 64, kernel_size=9, stride=2, padding=4, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages
        self.layer1 = self._make_layer(BasicBlock, 64, 2)
        self.layer2 = self._make_layer(BasicBlock, 128, 2, stride=2)
        self.layer3 = self._make_layer(BasicBlock, 256, 2, stride=2, use_se=True)
        self.layer4 = self._make_layer(BasicBlock, 512, 2, stride=2, use_se=True)

        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * BasicBlock.expansion, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride=1, use_se=False):
        """Build one stage of ``blocks`` blocks and advance ``self.in_channels``.

        Only the first block receives ``stride``; it also gets a 1x1 conv +
        batch-norm projection for its skip branch whenever the stride or the
        channel count changes.
        """
        expanded = out_channels * block.expansion

        projection = None
        if stride != 1 or self.in_channels != expanded:
            projection = nn.Sequential(
                nn.Conv2d(self.in_channels, expanded, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(expanded),
            )

        first = block(self.in_channels, out_channels, stride, projection, use_se=use_se)
        self.in_channels = expanded
        remaining = [block(self.in_channels, out_channels, use_se=use_se) for _ in range(blocks - 1)]
        return nn.Sequential(first, *remaining)

    def forward(self, x):
        """Return class scores for a batch of images."""
        stem = (self.conv1, self.bn1, self.relu, self.maxpool)
        stages = (self.layer1, self.layer2, self.layer3, self.layer4)
        for module in stem + stages:
            x = module(x)

        x = self.avgpool(x)
        return self.fc(torch.flatten(x, 1))


# Fresh three-class ResNet+ network with its loss and Adam optimizer.
# Note: `model`, `optimizer` and `criterion` are rebound by the BrightNet cell below.
criterion = nn.CrossEntropyLoss()
model = BrightResNet18(num_classes=3)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

In [18]:
# BrightNet model
# Hyperparameters, grouped by role. Within this cell only `learning_rate` is
# read (when the optimizer is built); the other values are just assigned here.
channel_1, channel_2, channel_3 = 64, 128, 256
hidden_layer_1, hidden_layer_2 = 256, 128
learning_rate = 1e-3
epochs = 5
dropout_rate = 0.4

class BrightFeatureBlock(nn.Module):
    """Residual block whose main branch is re-weighted by a spatial attention map.

    The main branch is conv-bn-relu-conv-bn; a 7x7 convolution followed by a
    sigmoid turns its output into a single-channel map that multiplies it.
    The skip branch is the input itself, or a 1x1 conv + batch-norm projection
    when the stride or the channel count changes.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        attention_conv = nn.Conv2d(out_channels, 1, kernel_size=7, stride=1, padding=3, bias=False)
        self.spatial_attention = nn.Sequential(attention_conv, nn.Sigmoid())

        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        else:
            # An empty Sequential passes its input through unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(attended_main_branch(x) + shortcut(x))``."""
        skip = self.shortcut(x)

        feats = self.relu(self.bn1(self.conv1(x)))
        feats = self.bn2(self.conv2(feats))

        # The single-channel map broadcasts over all channels of `feats`.
        feats = feats * self.spatial_attention(feats)

        feats += skip
        return self.relu(feats)

class BrightResNet(nn.Module):
    """BrightNet classifier: 9x9 conv stem followed by four stages of ``BrightFeatureBlock``.

    Each stage holds two blocks; the result is globally average-pooled and
    mapped to ``num_classes`` scores.
    """

    def __init__(self, num_classes=3):
        super().__init__()
        # Running channel count, updated by _make_layer as stages are added.
        self.in_channels = 64

        # Stem
        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size=9, stride=2, padding=4, bias=False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Residual stages
        self.layer1 = self._make_layer(BrightFeatureBlock, 64, 2)
        self.layer2 = self._make_layer(BrightFeatureBlock, 128, 2, stride=2)
        self.layer3 = self._make_layer(BrightFeatureBlock, 256, 2, stride=2)
        self.layer4 = self._make_layer(BrightFeatureBlock, 512, 2, stride=2)

        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def _make_layer(self, block, out_channels, blocks, stride=1):
        """Build one stage of ``blocks`` blocks and set ``self.in_channels`` to ``out_channels``.

        Only the first block receives ``stride`` and the previous channel
        count; the remaining ones map ``out_channels`` to ``out_channels``.
        """
        first = block(self.in_channels, out_channels, stride)
        self.in_channels = out_channels
        remaining = [block(self.in_channels, out_channels) for _ in range(blocks - 1)]
        return nn.Sequential(first, *remaining)

    def forward(self, x):
        """Return class scores for a batch of images."""
        stem = (self.conv1, self.bn1, self.relu, self.maxpool)
        stages = (self.layer1, self.layer2, self.layer3, self.layer4)
        for module in stem + stages:
            x = module(x)

        x = self.avgpool(x)
        return self.fc(torch.flatten(x, 1))

# Fresh three-class BrightNet network with its loss and Adam optimizer.
# Note: this rebinds the `model`, `optimizer` and `criterion` names set by the ResNet+ cell.
criterion = nn.CrossEntropyLoss()
model = BrightResNet(num_classes=3)
optimizer = optim.Adam(params=model.parameters(), lr=learning_rate)

In [None]:
def _load_full_model(path):
    """Load a saved model object from ``path`` and move it to ``device``.

    The loaded object is used as a module right away (``.to`` / ``.eval``),
    i.e. the files hold whole pickled models rather than state dicts.

    SECURITY: unpickling executes arbitrary code -- only load checkpoint
    files that come from a trusted source.
    """
    try:
        # map_location lets checkpoints saved on a GPU load on the CPU fallback.
        # weights_only=False is spelled out because recent PyTorch releases
        # default to weights_only=True, which rejects whole pickled modules.
        loaded = torch.load(path, map_location=device, weights_only=False)
    except TypeError:
        # Older torch.load signatures have no `weights_only` keyword.
        loaded = torch.load(path, map_location=device)
    return loaded.to(device)


resnet18_loaded = _load_full_model('resnet18_trained.pth')

resnetplus_loaded = _load_full_model('resnet_plus.pth')

bright_loaded = _load_full_model('bright3.0.pth')

In [None]:
# Put all three loaded networks into evaluation mode before running inference.
for _net in (resnet18_loaded, resnetplus_loaded):
    _net.eval()
# Left as the cell's final expression so the notebook echoes this model exactly as before.
bright_loaded.eval()

#### Grab visualizations

In [27]:
def _to_displayable(chw, mean, std):
    """Turn a (C, H, W) array into an (H, W, C) image with values in [0, 1].

    When both ``mean`` and ``std`` are given, the per-channel normalisation
    ``(p - mean) / std`` is undone before clipping.
    """
    img = chw.transpose(1, 2, 0)
    if mean is not None and std is not None:
        img = img * np.asarray(std, dtype=img.dtype) + np.asarray(mean, dtype=img.dtype)
    return np.clip(img, 0, 1)


def visualize_pred(models, loader, device, class_labels, num_images_to_show=5,
                   mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Plot images that at least one model misclassifies, one row per model.

    Batches are read from ``loader`` until at least ``num_images_to_show``
    candidate images have been collected (the check happens after each whole
    batch); a random sample of the candidates is then drawn and plotted with
    the true label, each model's prediction, its softmax confidence and
    whether it was correct.

    Args:
        models: sequence of models; row ``r`` of the figure belongs to ``models[r]``.
        loader: iterable yielding ``(x, y, _)`` batches.
        device: device the batches are moved to before inference.
        class_labels: list mapping an integer class id to its display name.
        num_images_to_show: number of columns in the figure.
        mean, std: per-channel statistics used to undo input normalisation for
            display. The defaults equal the ``normalize`` transform defined in
            this notebook; pass ``None`` to plot the raw (clipped) tensor.

    Returns:
        ``(y_true, y_pred_all_models)``: true labels and, per model, the
        predicted labels for every sample of the batches that were processed.
    """
    n_models = len(models)
    for model in models:
        model.eval()

    y_true = []
    y_pred_all_models = [[] for _ in models]
    valid_images = []

    with torch.no_grad():
        for x, y, _ in loader:
            x, y = x.to(device), y.to(device)

            batch_preds = []
            batch_probs = []
            for model_idx, model in enumerate(models):
                scores = model(x)
                _, preds = scores.max(1)
                preds = preds.cpu().numpy()
                batch_preds.append(preds)
                batch_probs.append(torch.softmax(scores, dim=1).cpu().numpy())
                # Previously these lists were returned without ever being filled.
                y_pred_all_models[model_idx].extend(preds)

            labels = y.cpu().numpy()
            y_true.extend(labels)
            images = x.cpu().numpy()

            for i in range(x.size(0)):
                true_idx = int(labels[i])
                pred_idx = [int(preds[i]) for preds in batch_preds]
                correct_preds = [idx == true_idx for idx in pred_idx]

                # Keep only images on which at least one model is wrong.
                if sum(correct_preds) < n_models:
                    pred_labels = [class_labels[idx] for idx in pred_idx]
                    # Index the probabilities by class id directly rather than
                    # searching class_labels for the label text.
                    confidences = [float(probs[i][idx]) for probs, idx in zip(batch_probs, pred_idx)]
                    img = _to_displayable(images[i], mean, std)
                    valid_images.append((img, class_labels[true_idx], pred_labels, confidences, correct_preds))

            if len(valid_images) >= num_images_to_show:
                break

    selected_images = random.sample(valid_images, min(num_images_to_show, len(valid_images)))

    # One row per model; squeeze=False keeps `axes` two-dimensional even for a
    # single row or column. For 3 models and 5 images the size is (20, 15).
    fig, axes = plt.subplots(n_models, num_images_to_show,
                             figsize=(4 * num_images_to_show, 5 * n_models), squeeze=False)
    # Hide every axis up front so unused columns do not show empty frames.
    for ax in axes.ravel():
        ax.axis("off")

    for col, (img, true_label, pred_labels, confidences, correct_preds) in enumerate(selected_images):
        for model_idx in range(n_models):
            ax = axes[model_idx, col]
            ax.imshow(img)
            ax.set_title(f"True: {true_label}\nPred: {pred_labels[model_idx]} ({confidences[model_idx]:.2f})\nCorrect: {correct_preds[model_idx]}")

    plt.tight_layout()
    plt.show()

    return y_true, y_pred_all_models

In [None]:
# Figure rows, from top to bottom:
#   ResNet-18
#   ResNet+
#   BrightNet

# class_labels[i] is the display name for integer label i
# (order presumably matches the label ids in the annotation files -- TODO confirm).
class_labels = ['Rabbit', 'Bobcat', 'Cat']
compared_models = [resnet18_loaded, resnetplus_loaded, bright_loaded]

# The results are bound to names: a bare call as the cell's last expression
# would echo the whole returned (y_true, y_pred) tuple below the figures.
print("Visualizing Test Set Predictions")
test_true, test_preds = visualize_pred(compared_models, test_loader, device, class_labels, 5)

print("Visualizing Bright Set Predictions")
bright_true, bright_preds = visualize_pred(compared_models, bright_loader, device, class_labels, 5)

In [36]:
def _undo_normalization(img, mean, std):
    """Return ``img`` (C, H, W) on the CPU with per-channel normalisation undone.

    When both ``mean`` and ``std`` are given the result is ``img * std + mean``
    clamped to [0, 1]; otherwise the tensor is returned unchanged.
    """
    img = img.detach().cpu()
    if mean is None or std is None:
        return img
    mean_t = torch.as_tensor(mean, dtype=img.dtype).view(-1, 1, 1)
    std_t = torch.as_tensor(std, dtype=img.dtype).view(-1, 1, 1)
    return (img * std_t + mean_t).clamp(0, 1)


def visualize_fm(models, loader, device, model_names, num_images,
                 mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Show the output of ``model.layer1[0]`` for each model.

    For every model the whole ``loader`` is run; for each of the first
    ``num_images`` batches, the first image of the batch is displayed next to
    the captured feature maps via ``display_feature_maps``.

    Args:
        models: sequence of models that expose ``layer1[0]``.
        loader: iterable yielding ``(x, y, _)`` batches.
        device: device the batches are moved to before inference.
        model_names: display names, paired with ``models`` by position.
        num_images: number of leading batches to display (one image each).
        mean, std: per-channel statistics used to undo input normalisation
            before the image is shown. The defaults equal the ``normalize``
            transform defined in this notebook; pass ``None`` to show the raw
            tensor.

    Returns:
        ``(y_true, y_pred)`` for the LAST model only (the lists are reset for
        each model); two empty lists when ``models`` is empty.
    """
    # Defined up front so an empty `models` returns empty lists instead of
    # failing with an unbound local at the return statement.
    y_true, y_pred = [], []
    to_pil = ToPILImage()

    for model, name in zip(models, model_names):
        model.eval()
        y_true, y_pred = [], []
        feature_maps = {}

        def hook_fn(module, input, output, _store=feature_maps):
            # Keep the most recent output of the hooked block.
            _store['layer1'] = output

        hook = model.layer1[0].register_forward_hook(hook_fn)
        try:
            with torch.no_grad():
                for i, (x, y, _) in enumerate(loader):
                    x, y = x.to(device), y.to(device)

                    scores = model(x)
                    _, preds = scores.max(1)

                    y_true.extend(y.cpu().numpy())
                    y_pred.extend(preds.cpu().numpy())

                    if i < num_images and 'layer1' in feature_maps:
                        original_image = to_pil(_undo_normalization(x[0], mean, std))
                        display_feature_maps(original_image, feature_maps['layer1'], name)
        finally:
            # Always detach the hook, even if inference or plotting raised;
            # otherwise it would stay registered on the model.
            hook.remove()

    return y_true, y_pred

def display_feature_maps(original_image, feature_map, model_name):
    """Plot an image next to up to five feature-map channels of the first sample.

    Args:
        original_image: image drawn in the left-most panel.
        feature_map: tensor indexed as ``[sample, channel]``; only sample 0 is
            shown, and at most the first five channels.
        model_name: name included in the title of the image panel.
    """
    shown = min(5, feature_map.size(1))

    # One panel for the image plus one per displayed channel.
    fig, axes = plt.subplots(1, shown + 1, figsize=(15, 5))

    image_ax = axes[0]
    image_ax.imshow(original_image)
    image_ax.axis('off')
    image_ax.set_title(f'Original Image\nModel: {model_name}')

    for number, panel in enumerate(axes[1:shown + 1], start=1):
        channel = feature_map[0, number - 1].detach().cpu().numpy()
        panel.imshow(channel, cmap='viridis')
        panel.axis('off')
        panel.set_title(f'Feature Map {number}')

    plt.tight_layout()
    plt.show()

In [None]:
# Feature maps are compared for the two custom networks only.
fm_models = [resnetplus_loaded, bright_loaded]
fm_names = ['ResNet+', 'BrightNet']

print("Visualizing Test Set Feature Maps")
visualize_fm(fm_models, test_loader, device, fm_names, 5)

print("Visualizing Bright Set Feature Maps")
# Left as the cell's final expression so the notebook output is unchanged.
visualize_fm(fm_models, bright_loader, device, fm_names, 5)