In [None]:
import os
import numpy as np
from tqdm.autonotebook import tqdm
from torch.utils.data import DataLoader
import torch
import matplotlib.pyplot as plt

import sys
sys.path.append('/path/to/classification')
from model_classification import *
from data_classification import create_dataset

# Seed NumPy's and PyTorch's global RNGs so random draws (e.g. the shuffled DataLoader
# order) are reproducible between runs.
SEED = 100
np.random.seed(SEED)
torch.manual_seed(SEED)

In [None]:
import warnings

# Hide the "Clipping ..." UserWarnings (presumably matplotlib's imshow clipping message);
# images are clipped to [0, 1] explicitly before plotting.
warnings.filterwarnings("ignore", category=UserWarning, message=".*Clipping.*")

# Output directory for the per-image diagnostic figures. Create it up front so that
# savefig does not fail with FileNotFoundError when the directory does not exist yet.
save_path = '/path/to/classification/plotImageFin/'
os.makedirs(save_path, exist_ok=True)

# load data
batch_size = 1  # 1 to create diagnostic images, any value otherwise
testdata = create_dataset(datadir='/path/to/val')
all_dl = DataLoader(testdata, batch_size=batch_size, shuffle=True)
progress = tqdm(enumerate(all_dl), total=len(all_dl))

# Restore the trained weights into `model` and switch it to evaluation mode.
# NOTE(review): torch.load unpickles the checkpoint, which can execute arbitrary code;
# only load files from a trusted source.
model.load_state_dict(
    torch.load(
        '/path/to/classification/classification.model',
        map_location='cpu',  # materialise all tensors on the CPU first
    )
)
model.eval()  # eval-mode behaviour for layers such as dropout / batch norm

# Forward hooks on the ResNet layers stash each layer's most recent output here.
activation = {}


def get_activation(name):
    """Return a forward hook that saves the hooked module's output under ``name``.

    The returned callable has the ``(module, input, output)`` signature expected by
    ``register_forward_hook``. Each call stores ``output.detach()`` in the module-level
    ``activation`` dict, replacing any previous entry for ``name``. It returns ``None``,
    so the module's output passes through unchanged.
    """
    def _store(_module, _inputs, output):
        activation[name] = output.detach()

    return _store

# Attach one forward hook per layer. The handles are kept so that re-running this cell
# first detaches the hooks from the previous run. Otherwise every re-run stacks another
# copy on each layer (the model object persists across runs), and all copies fire on
# every forward pass.
for _handle in globals().get('activation_hook_handles', []):
    _handle.remove()
activation_hook_handles = [
    getattr(model, module_name).register_forward_hook(get_activation(key))
    for module_name, key in (
        ('relu', 'conv1'),  # model.relu's output is stored under the key 'conv1'
        ('layer1', 'layer1'),
        ('layer2', 'layer2'),
        ('layer3', 'layer3'),
        ('layer4', 'layer4'),
    )
]


# Figures produced for each evaluated image (only populated when batch_size == 1).
images = []
# Running counts of correctly (`true`) and incorrectly (`false`) classified samples.
true = 0
false = 0

# Panel title for each classification outcome.
RESULT_TITLES = {
    'true_pos': 'True Positive',
    'true_neg': 'True Negative',
    'false_pos': 'False Positive',
    'false_neg': 'False Negative',
}


def _stretch_bands(image, bands, offset, gain=1.0):
    """Compose three channels of ``image`` into a display-ready (H, W, 3) array.

    The selected channels are min-max normalised jointly (one min/max over all three),
    mapped through ``offset + gain * value`` and clipped to [0, 1] for imshow.

    Args:
        image: channel-first array; ``image[c]`` must be a 2-D band.
        bands: three channel indices, in display (R, G, B) order.
        offset: brightness offset added after normalisation.
        gain: contrast multiplier applied to the normalised values.

    Returns:
        Array of shape (H, W, 3) with values in [0, 1]. If the three bands have no
        positive value range, the result is ``clip(offset)`` everywhere instead of NaN.
    """
    stack = np.dstack([image[b] for b in bands])
    lo, hi = stack.min(), stack.max()
    scaled = (stack - lo) / (hi - lo) if hi > lo else np.zeros_like(stack)
    return np.clip(offset + gain * scaled, 0.0, 1.0)


# Evaluation only: no_grad() stops autograd from recording a graph on every forward pass.
# Forward hooks still fire under no_grad, so `activation` is populated as before.
with torch.no_grad():
    for i, batch in progress:
        x, y = batch['img'].float().to(device), batch['lbl'].float().to(device)

        output = model(x)
        # One logit per sample is assumed; a positive logit predicts class 1.
        preds = (output.reshape(-1) > 0).float()
        labels = y.reshape(-1)  # assumes binary 0/1 labels
        assert preds.shape == labels.shape, (
            f'expected one logit per label, got output {tuple(output.shape)} '
            f'for labels {tuple(y.shape)}')

        # Score every sample in the batch so the accuracy is correct for batch_size > 1.
        n_correct = int((preds == labels).sum().item())
        true += n_correct
        false += labels.numel() - n_correct

        if batch_size == 1:
            pred, label = int(preds[0].item()), int(labels[0].item())
            res = ('true_' if pred == label else 'false_') + ('pos' if pred == 1 else 'neg')

            # Move the image to host memory: numpy/matplotlib cannot read CUDA tensors.
            sample = x[0].cpu().numpy()

            f, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(1, 3))

            # 'RGB' composite from channels 3, 2, 1 (offset 0.2, gain 1.5).
            ax1.imshow(_stretch_bands(sample, (3, 2, 1), offset=0.2, gain=1.5), origin='upper')
            ax1.set_title(RESULT_TITLES[res], fontsize=8)
            ax1.set_ylabel('RGB', fontsize=8)

            # 'False Color' composite from channels 0, 9, 10 (offset 0.2, no extra gain).
            ax2.imshow(_stretch_bands(sample, (0, 9, 10), offset=0.2), origin='upper')
            ax2.set_ylabel('False Color', fontsize=8)

            # layer2 activations summed over the first two axes (batch and channel, assuming
            # an NCHW feature map). The fixed vmin/vmax of 50/150 puts every image on the same
            # colour scale; presumably tuned for this model. TODO confirm.
            map_layer2 = ax3.imshow(activation['layer2'].sum(dim=(0, 1)).cpu().numpy(),
                                    vmin=50, vmax=150)
            ax3.set_ylabel('Layer2', fontsize=8)

            for ax in (ax1, ax2, ax3):
                ax.set_xticks([])
                ax.set_yticks([])

            f.subplots_adjust(0.18, 0.02, 0.85, 0.9, 0.05, 0.05)

            # Save as "<source file stem>_eval.png". ':' is replaced because some file
            # systems do not allow it in file names.
            stem = os.path.splitext(os.path.basename(batch['imgfile'][0]))[0]
            filename = (stem + '_eval.png').replace(':', '_')
            complete_path = os.path.join(save_path, filename)
            f.savefig(complete_path, dpi=200)

            images.append(f)
            # Release pyplot's reference to this figure; the object stays in `images`.
            plt.close(f)

# Show all collected figures. They were closed with plt.close() in the loop. Figure.show()
# under a non-GUI backend (such as Jupyter's inline backend) only warns or does nothing,
# so use IPython's rich `display` (which IPython injects into builtins) when it is available.
import builtins

_show_figure = getattr(builtins, 'display', None)
for image in images:
    if _show_figure is not None:
        _show_figure(image)
    else:
        image.show()

# Guard against an empty evaluation set, which would otherwise raise ZeroDivisionError.
n_evaluated = true + false
if n_evaluated:
    print('test set accuracy:', true / n_evaluated)
else:
    print('test set accuracy: n/a (no samples were evaluated)')