# Layer-wise relevance propagation (LRP) on the hornet classifier model

Identify local directories and imports.

In [None]:
import os
import sys
import glob
import matplotlib.pyplot as plt
import torch
import numpy as np
import torchvision
from skimage import io
from skimage import transform

# Experiment name
exp_name = 'resnet18-lrp'

# Root directory
root_dir = '/Users/Holmes/Research/Projects/vespai'
sys.path.insert(0, root_dir)
data_dir = os.path.join(root_dir, 'datasets/extracts-21')
weights_dir = os.path.join(
    root_dir, 'models/classifier-runs/' + exp_name + '/weights',
)

# Add local lrp-resnet to path and import LRP package
lrp_dir = os.path.join(root_dir, 'explanation/lrp-resnet')
resnet_dir = os.path.join(root_dir, 'explanation/lrp-resnet/notebooks')
os.chdir(lrp_dir)
from LRP import LRP
os.chdir(resnet_dir)
from resnet import resnet18

# Automatically reload imported programmes
%load_ext autoreload
%autoreload 2

## ResNet and LRP models


In [None]:
# Build the ResNet-18 classifier with a two-class output
model = resnet18(num_classes=2)

# Load the trained weights from the experiment's weights directory.
# map_location='cpu' lets a checkpoint that was saved from a GPU run
# deserialize on a machine without CUDA; the model is never moved to
# another device in the visible cells, so CPU tensors are what it needs.
state_dict = torch.load(
    os.path.join(weights_dir, 'best.pt'), map_location='cpu',
)
model.load_state_dict(state_dict)

# Switch to evaluation mode before inference
model = model.eval()

# Instantiate the LRP object around the classifier with the 'z_rule' option
lrp_model = LRP(model, 'z_rule')

## Test on data

In [None]:
# The classifier helpers are imported after moving into <root_dir>/models.
# root_dir is an absolute path, so a single chdir is enough.
os.chdir(os.path.join(root_dir, 'models'))
from classifier.loader import get_hornet_loader
from classifier.utils import show_batch


# Collect the test images. glob returns matches in arbitrary order, so the
# list is sorted to make test_list[0] (used further down) reproducible.
test_files = sorted(glob.glob(os.path.join(data_dir, 'test/*.jpeg')))
if not test_files:
    # Fail here with a clear message rather than later on an empty list.
    raise FileNotFoundError(
        'No test images found in ' + os.path.join(data_dir, 'test')
    )

# Get data loader: one image per batch, augmentation switched off
test_loader = get_hornet_loader(test_files, batch_size=1, augment=False)
test_list = list(test_loader)

# Print samples of batches. The batches are already materialised in
# test_list, so reuse them instead of running the loader a second time.
for idx, sample_batch in enumerate(test_list):
    print('Batch number ', idx, ', batch size: ', sample_batch[0].size())

    # Observe 4th batch and stop.
    if idx == 3:
        plt.figure(figsize=(12, 12))
        show_batch(sample_batch)
        plt.axis('off')
        plt.ioff()
        # plt.savefig(os.path.join(fig_dir, 'augs.png'))
        plt.show()
        break


### Inference

In [None]:
# First entry of the materialised test set; element 0 is used as the image
data = test_list[0]
image = data[0]

# Drop singleton dimensions and move the leading axis to the end for imshow
mpl_image = image.squeeze().permute(1, 2, 0)
plt.imshow(mpl_image)
plt.show()

# Forward pass through the LRP-wrapped model
print('Image shape: ', image.shape)
image_output = lrp_model.forward(image)

# Softmax over the last dimension, printed next to element 1 of the sample
print('Prediction: ', image_output.softmax(dim=-1).detach().squeeze())
print('Target: ', data[1])

### Relevance

In [None]:
# Call the LRP object on the test image.
# NOTE(review): the random-input cell further down calls
# lrp_model.relprop(...) instead -- confirm both entry points are intended.
image_lrp = lrp_model(image)

# Move the leading axis to the end and plot index 0 along that axis
plt.imshow(image_lrp.squeeze().permute(1, 2, 0)[..., 0])

In [None]:
# Raw values of the slice plotted above (shown as the cell output)
image_lrp.squeeze().permute(1, 2, 0)[..., 0]

## Random predictions

In [None]:
# One 3x256x256 input filled with uniform random values from [0, 1)
random_image = torch.rand((1, 3, 256, 256))
random_output = model(random_image)

# Display the noise image with the leading axis moved to the end
plt.imshow(random_image.squeeze().permute(1, 2, 0))
plt.show()

# Softmax over the last dimension of the classifier output
print('Random predictions: ', random_output.softmax(dim=-1).detach().squeeze())


### Relevance

In [None]:
# Run relprop on the random input.
# NOTE(review): the test-image cell above calls lrp_model(image) rather than
# relprop -- confirm which entry point is the intended one.
random_lrp = lrp_model.relprop(random_image)

# Move the leading axis to the end and plot index 0 along that axis
plt.imshow(random_lrp.squeeze().permute(1, 2, 0)[..., 0])

In [None]:
# Raw values of the slice plotted above (shown as the cell output)
random_lrp.squeeze().permute(1, 2, 0)[..., 0]