In [None]:
import os
import numpy as np
from tqdm.autonotebook import tqdm
from torch.utils.data import DataLoader
import torch
import matplotlib.pyplot as plt

from model_gray import *
import gray_dataClass

# Seed numpy's and torch's global RNGs so runs are reproducible (torch's default
# generator also drives DataLoader shuffling when no generator is passed).
SEED = 100
np.random.seed(SEED)
# Trailing ';' suppresses the returned torch.Generator repr in the cell output.
torch.manual_seed(SEED);

In [None]:
# load data
batch_size = 1
testdata = gray_dataClass.create_dataset(datadir='path/to/gray_images')
# shuffle=True only changes the order in which samples are visited/plotted; the
# correct/incorrect counts do not depend on it, and the order is fixed by the seeds.
all_dl = DataLoader(testdata, batch_size=batch_size, shuffle=True)
progress = tqdm(enumerate(all_dl), total=len(all_dl))

# load model
# NOTE(review): `model` and `device` are not defined in this notebook; presumably
# they come from `from model_gray import *` — TODO confirm.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load trusted checkpoints (consider weights_only=True on torch >= 1.13).
state_dict = torch.load('grayclassification.model', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
# The evaluation loop sends inputs to `device`; put the model there too so the
# forward pass cannot fail with a CPU/GPU device mismatch (no-op if already there).
model.to(device)
model.eval()

# Register forward hooks that capture the outputs of selected ResNet layers
# Latest captured output of each hooked layer, keyed by capture name.
activation = {}


def get_activation(name):
    """Create a forward hook that records a module's output under ``name``.

    The returned callable matches the ``(module, input, output)`` signature that
    ``register_forward_hook`` expects. It stores a detached copy of ``output`` in
    the module-level ``activation`` dict, overwriting any previous entry.
    """
    def _store_output(module, inputs, result):
        activation[name] = result.detach()

    return _store_output

# register_forward_hook appends a new hook on every call, so re-running this cell
# would stack duplicate hooks on the same model; remove handles from a previous run.
for _old_handle in globals().get('hook_handles', []):
    _old_handle.remove()

# Capture key -> submodule. The 'conv1' entry hooks `model.relu`, presumably the
# activation right after the stem convolution — TODO confirm against model_gray.
hooked_modules = {
    'conv1': model.relu,
    'layer1': model.layer1,
    'layer2': model.layer2,
    'layer3': model.layer3,
    'layer4': model.layer4,
}
# Keep the RemovableHandles so the hooks can be removed later (see above).
hook_handles = [module.register_forward_hook(get_activation(key))
                for key, module in hooked_modules.items()]

# Figures for every evaluated sample, kept so they can be displayed at the end.
images = []
# Running counts of correct (`true`) and incorrect (`false`) predictions.
true = 0
false = 0

# Panel titles for each confusion-matrix outcome key.
OUTCOME_TITLES = {'true_pos': 'True Positive', 'true_neg': 'True Negative',
                  'false_pos': 'False Positive', 'false_neg': 'False Negative'}


def _classify_outcome(prediction, label):
    """Return the confusion-matrix key ('true_pos', 'false_neg', ...) for one sample.

    Args:
        prediction: predicted class, 0 or 1.
        label: ground-truth label (number or one-element tensor); must be 0 or 1.

    Raises:
        ValueError: if ``label`` is not 0 or 1, so an unexpected label can never
            silently reuse the outcome of a previous sample.
    """
    label = float(label)
    if label not in (0.0, 1.0):
        raise ValueError(f'expected a binary label (0 or 1), got {label}')
    correctness = 'true' if prediction == label else 'false'
    polarity = 'pos' if prediction == 1 else 'neg'
    return f'{correctness}_{polarity}'


def _false_color(img):
    """Min-max scale a 2-D array into [0, 1], brighten by 0.2, and clip to [0, 1].

    A constant image (max == min) maps to 0.2 everywhere instead of dividing by zero.
    """
    lo, hi = img.min(), img.max()
    span = hi - lo
    if span > 0:
        scaled = (img - lo) / span
    else:
        scaled = np.zeros_like(img, dtype=float)
    return np.clip(0.2 + scaled, 0.0, 1.0)


def _plot_sample(img, layer2_map, outcome, out_path):
    """Plot grayscale, false-color and layer2-activation panels for one sample.

    Saves the figure to ``out_path``, closes it (bounding memory use inside the
    loop) and returns it so it can still be displayed afterwards.
    """
    fig, (ax_gray, ax_false, ax_act) = plt.subplots(3, 1, figsize=(1, 3))

    ax_gray.imshow(img, cmap='gray', origin='upper')
    ax_gray.set_title(OUTCOME_TITLES[outcome], fontsize=8)
    ax_gray.set_ylabel('Grayscale', fontsize=8)

    # matplotlib ignores `cmap` for 3-channel (RGB) input, so a 2-D array is
    # required for the 'hot' colormap to actually be applied.
    ax_false.imshow(_false_color(img), cmap='hot', vmin=0.0, vmax=1.0, origin='upper')
    ax_false.set_ylabel('False Color', fontsize=8)

    # Fixed color range keeps panels comparable across samples.
    # TODO: document why 50-150 was chosen.
    ax_act.imshow(layer2_map, vmin=50, vmax=150)
    ax_act.set_ylabel('Layer2', fontsize=8)

    for ax in (ax_gray, ax_false, ax_act):
        ax.set_xticks([])
        ax.set_yticks([])

    fig.subplots_adjust(0.05, 0.02, 0.95, 0.9, 0.05, 0.05)
    fig.savefig(out_path, dpi=200)
    plt.close(fig)
    return fig


# NOTE(review): `save_path`, `device` and `model` are not defined in this notebook;
# presumably they come from `from model_gray import *` — TODO confirm.
if save_path:
    os.makedirs(save_path, exist_ok=True)

# Inference only: no autograd graph is needed (forward hooks still fire).
with torch.no_grad():
    for _, batch in progress:
        x = batch['img'].float().to(device)
        y = batch['lbl'].float().to(device)
        output = model(x)  # the forward hooks refresh `activation` here

        # Evaluate every sample in the batch, not only the first one.
        for j in range(x.shape[0]):
            prediction = 1 if output[j] > 0 else 0
            res = _classify_outcome(prediction, y[j])
            if res.startswith('true'):
                true += 1
            else:
                false += 1

            # Move to CPU before plotting: numpy/matplotlib cannot read CUDA tensors.
            img = x[j][0].cpu().numpy()
            layer2_map = activation['layer2'][j].sum(dim=0).cpu().numpy()

            # '<stem>_eval.png'; ':' is replaced because it is invalid in Windows filenames.
            stem = os.path.splitext(os.path.basename(batch['imgfile'][j]))[0]
            filename = (stem + '_eval.png').replace(':', '_')

            images.append(_plot_sample(img, layer2_map, res,
                                       os.path.join(save_path, filename)))

# Report overall accuracy from the counters accumulated by the evaluation loop.
n_evaluated = true + false
if n_evaluated:
    print(f'Accuracy: {true}/{n_evaluated} = {true / n_evaluated:.3f}')

# Show every collected figure inline. The figures were closed after saving, and
# Figure.show() does not render on Jupyter's inline backend. IPython's `display`
# (available without import inside a Jupyter kernel) renders them instead.
for image in images:
    display(image)
