<a href="https://colab.research.google.com/github/adubowski/redi-xai/blob/main/classifier/compare_predictions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Load Libraries and Initialise Parameters

In [None]:
import torch
from torch.utils.data import TensorDataset
from torchvision import models

from PIL import Image
import numpy as np
from numpy.random import randint

from google.colab import drive
import os
from os.path import join as oj
from tqdm import tqdm
import gc

from sklearn.metrics import recall_score
from matplotlib import pyplot as plt
# The 'seaborn' style sheet was renamed 'seaborn-v0_8' in matplotlib 3.6 and the old name was removed in 3.8,
# so prefer the new name and fall back to the old one on older matplotlib versions.
plt.style.use('seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn')

##### Mount Google Drive and create & store various directory paths

In [None]:
# Make the shared Google Drive folder visible inside the Colab VM.
drive.mount("/content/drive")

# Project root on Drive; every other path below is built from it.
dir_path = "/content/drive/MyDrive/redi-detecting-cheating"

# Saved classifier checkpoints (*.pt files) and the root of the data folder.
model_path = oj(dir_path, "models", "initial_classifier")
data_path = oj(dir_path, "data")

# Test inputs: a text file listing the original test image paths (presumably 224x224, going by its name),
# plus the directories of inpainted test images with and without patches.
test_path = oj(dir_path, "models", "test_files_224.txt")
test_path_patches = oj(data_path, "test", "inpainted_patches")
test_path_no_patches = oj(data_path, "test", "inpainted_no_patches")

##### Parameters for standardising the data 

In [None]:
# Per-channel (R, G, B) statistics used to standardise images the way the pretrained VGG16 expects
# (the usual ImageNet values); also used to undo the standardisation when plotting.
mean = np.array((0.485, 0.456, 0.406))
std = np.array((0.229, 0.224, 0.225))

##### Function to plot an individual image.

In [None]:
def plot_lesion(dataset, idx):
  """ Plot a single lesion image from a dataset after undoing the standardisation.

  Input:    dataset   -   Tensor dataset of images. Each element of the dataset is an (image, label) pair,
                          where the image is a standardised, channels-first tensor.
            idx       -   The image index to plot.
  Returns:  None      -   Plots the original lesion image to screen.
  """
  plt.style.use('default')   # Reset the style so the notebook-wide seaborn styling is not drawn over the image.

  img = dataset[idx][0]      # Each element is an (image, label) pair; only the image is plotted.

  # The colour channels were moved to the first axis when the dataset was built; permute() moves them back last.
  # The image was standardised, so multiply by the std and add the mean to reverse this. Clip the result so
  # float round-off does not push values outside the [0, 1] range imshow expects for float RGB data.
  orig_img = np.clip(img.permute(1, 2, 0).cpu().numpy() * std + mean, 0, 1)
  plt.imshow(orig_img)

### Load Data
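Load the test datasets from file to test the effect of altering the patches.
After the first run each dataset is saved as a tensor, because reading it back in that format is much quicker. On later runs the slow image-reading step is skipped automatically, since `get_dataset` loads the saved tensor when it exists (the function definitions in this section must still be run).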

##### Function to read the relevant files.

In [None]:
def extract_filenames(dataset_path):
  """ Extracts the paths, names and root directory for the image files in a given directory
          or given a file containing image filepaths.
      Returns:
        filenames     tuple of filenames sorted alphabetically.
        file_list     tuple of paths in the same order as filenames. If dataset_path is a directory, then
                      file_list is equivalent to filenames. If dataset_path is a file containing filepaths,
                      then file_list holds those paths (blank lines are ignored).
        root_dir      If dataset_path is a directory, then root_dir=dataset_path, otherwise root_dir=''.
      Raises:
        FileNotFoundError   if dataset_path is neither an existing file nor an existing directory.
  """
  if os.path.isfile(dataset_path):
    # Context manager so the list file is closed; blank lines are not image paths.
    with open(dataset_path, 'rt') as f:
      file_list = [line for line in f.read().splitlines() if line.strip()]
    filenames = [os.path.basename(path) for path in file_list]   # Extract the filename from the full filepath.
    root_dir = ''
  elif os.path.isdir(dataset_path):
    file_list = os.listdir(dataset_path)
    filenames = file_list
    root_dir = dataset_path
  else:
    # Raise rather than print + exit(), so the caller gets an exception that names the bad path.
    raise FileNotFoundError('Invalid testing data file/folder path: {}'.format(dataset_path))

  if not file_list:
    return (), (), root_dir    # zip(*[]) below cannot be unpacked into two names for an empty listing.

  # Sort alphabetically based on the filename, keeping each path paired with its name.
  zip_sorted = sorted(zip(filenames, file_list), key=lambda tup: tup[0])

  filenames, file_list = zip(*zip_sorted)   # Unzip the sorted results (as tuples).

  return filenames, file_list, root_dir


def load_files(dataset_path, imsize = (224,224)):
  """ Load every image listed by dataset_path into a single float array.

      dataset_path should be either a text file containing the full paths of the relevant image files,
         or a path to the directory containing the images (see extract_filenames).
      imsize is the expected (rows, cols) of every image.
      Returns (imgs_np, filenames): the images as a numpy array of shape (N, imsize[0], imsize[1], 3) with
         float values between 0 and 1, and the alphabetically sorted filenames in the same order.
      Images that cannot be read, or whose size does not match imsize, are reported and left as all-zero
         entries so that imgs_np stays aligned with filenames.
  """
  filenames, file_list, root_dir = extract_filenames(dataset_path)

  num_files = len(file_list)
  # zeros rather than empty: a failed load must not leave uninitialised memory in the returned array.
  imgs_np = np.zeros((num_files, imsize[0], imsize[1], 3))
  for i, path in enumerate(tqdm(file_list)):
    full_path = oj(root_dir, path)
    try:
      # The context manager closes the file even on error; convert('RGB') makes greyscale/RGBA images 3-channel.
      with Image.open(full_path) as img:
        imgs_np[i] = np.asarray(img.convert('RGB')) / 255.0   # Transform to float between 0 and 1 from integer between 0-255
    except (OSError, ValueError) as err:
      # OSError: missing or unreadable image file; ValueError: image shape does not match imsize.
      print('Could not load image {} ({}): {}'.format(i, full_path, err))
  return imgs_np, filenames


def get_dataset(dataset_path, save_path, imsize = (224,224)):
  """ Build (or load from a cache) a TensorDataset of standardised images and binary cancer targets.

  Input:    dataset_path  -   text file of image paths, or a directory of images (see extract_filenames).
            save_path     -   path of the cached TensorDataset; loaded if it exists, otherwise created and saved.
            imsize        -   expected (rows, cols) of every image.
  Returns:  (dataset, filenames) - TensorDataset of (image, target) pairs with channels-first float images,
                                   and the alphabetically sorted image filenames in the same order.
  """
  if os.path.isfile(save_path):        # If the dataset has previously been saved as a tensor, load this for efficiency.
    # The cache is a pickled TensorDataset written by this function (not a plain state dict), so recent PyTorch
    # versions, which default to weights_only=True, must be told to unpickle it. Only load caches you created.
    try:
      dataset = torch.load(save_path, weights_only=False)
    except TypeError:                  # Older PyTorch without the weights_only argument.
      dataset = torch.load(save_path)
    # NOTE: caches written by the earlier swapaxes(1,3).swapaxes(2,2) version store images as (C, W, H);
    #       delete them to rebuild with the (C, H, W) layout used below.

    filenames, _, _ = extract_filenames(dataset_path)   # Get the associated image filenames.

  else:
    ims, filenames = load_files(dataset_path, imsize)   # Load in all of the images.

    ims -= mean[None, None, :]    # Standardise the images as expected by the VGG16 model.
    ims /= std[None, None, :]

    # Check if the image comes from the 'no_cancer' directory or the 'cancer' directory. Cancer images have
    # target=1. The check must use the full path: `filenames` only holds basenames, so a directory name
    # would never be seen.
    _, file_list, root_dir = extract_filenames(dataset_path)
    targets = [0 if "no_cancer" in oj(root_dir, path) else 1 for path in file_list]
    targets = np.array(targets).astype(np.int8)

    # (N, H, W, C) -> (N, C, H, W): channels first, as PyTorch conv layers expect, without swapping H and W.
    images = torch.from_numpy(ims).permute(0, 3, 1, 2).float()

    # Create a tensor dataset with the images and targets.
    dataset = TensorDataset(images, torch.from_numpy(targets))

    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)   # The cache directory may not exist yet.
    torch.save(dataset, save_path)    # save for more efficient loading the next time.

  return dataset, filenames    # Return the TensorDataset and list of image filenames

##### Read the test datasets.
If they have previously been saved as a tensor dataset, load those instead for efficiency.

In [None]:
# Build (or load the cached copies of) the original test set and its two inpainted variants.
# Each call returns the TensorDataset together with its alphabetically sorted filenames.
test_dataset, test_files = get_dataset(
    dataset_path=test_path,
    save_path=oj(data_path, 'saved-tensors', 'test_dataset.pt'))

inpainted_patch_dataset, patch_files = get_dataset(
    dataset_path=test_path_patches,
    save_path=oj(data_path, 'saved-tensors', 'inpainted_patch_dataset.pt'))
inpainted_no_patch_dataset, no_patch_files = get_dataset(
    dataset_path=test_path_no_patches,
    save_path=oj(data_path, 'saved-tensors', 'inpainted_no_patch_dataset.pt'))

### Load Model
Load the most recently trained classifier. Only the weights of the classification section of the model have been saved; the weights of the feature-extraction part are taken from the original pretrained VGG16 model.

In [None]:
# Get a list of the models in the directory and their modified times 
model_list = [(f, os.path.getmtime(oj(model_path,f))) for f in os.listdir(model_path) if f.endswith('.pt')]
model_list.sort(key=lambda tup: tup[1], reverse=True)  # sorts in place from most to least recent.

model_name = model_list[0][0]                       # Take the most recent model.

model_dict = torch.load(oj(model_path, model_name)) # Read the paramater dict from file.

model = models.vgg16(pretrained=True)               # Read in the original VGG16 pretrained model.
model.classifier[-1] = torch.nn.Linear(4096, 2)           # Set the final classification layer to have only 2 output nodes.

model.classifier.load_state_dict(model_dict)        # Use the saved model parameters.

device = torch.device(0)
model = model.to(device)

Free up memory (the model file is approx. 0.5 GB).

In [None]:
# The classifier weights are now held by `model`, so the raw state dict and the
# checkpoint listing are no longer needed; drop both and force a garbage collection.
del model_dict
del model_list
gc.collect()

##### Get and save predictions for the original images

In [None]:
from sklearn.metrics import auc, roc_auc_score, f1_score

def get_output(model, dataset, batch_size=16, num_workers=4):
    """Run `model` over `dataset` and return the labels and the predicted probability of class 1.

    Args:
        model:       torch.nn.Module producing one logit per class (the notebook's model has 2 outputs).
        dataset:     dataset yielding (inputs, labels) pairs, e.g. a TensorDataset.
        batch_size:  number of samples per forward pass.
        num_workers: DataLoader worker processes (0 loads data in the main process).

    Returns:
        (y, y_hat): numpy arrays of the labels and of the softmax probability of class 1 (cancer),
        in dataset order. Note the order: labels first, probabilities second.
    """
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
                                              shuffle=False, num_workers=num_workers)
    model = model.eval()
    # Send inputs to wherever the model's parameters live instead of assuming CUDA.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
    y = []
    y_hat = []
    with torch.no_grad():
        for inputs, labels in data_loader:
            y.append(labels.cpu().numpy())
            probs = torch.softmax(model(inputs.to(device)), dim=1)
            y_hat.append(probs[:, 1].cpu().numpy())   # take the probability for cancer
    if not y:   # Empty dataset: np.concatenate raises on an empty list.
        return np.empty(0), np.empty(0)
    y = np.concatenate(y, axis=0)
    y_hat = np.concatenate(y_hat, axis=0)
    return y, y_hat

def get_auc_f1(model, dataset, fname = None):
    """Compute the ROC AUC and the best F1 score of `model` on `dataset`.

    Args:
        model:   the model to evaluate; when `fname` is given it must have a `.classifier` attribute.
        dataset: dataset yielding (inputs, labels) pairs.
        fname:   optional path to saved weights. If they contain "classifier.0.weight" they are loaded
                 into the whole `model`, otherwise into `model.classifier`.
                 NOTE(review): in both cases only `model.classifier` is then evaluated, so `dataset`
                 presumably holds pre-extracted features rather than images -- confirm with callers.

    Returns:
        (auc, f1): ROC AUC of the class-1 probabilities, and the maximum F1 over the thresholds
        0.1, 0.2, ..., 1.0 that put predictions on both sides (NaN if no threshold does).
    """
    if fname is not None:
        with open(fname, 'rb') as f:
            weights = torch.load(f)
        if "classifier.0.weight" in weights.keys():
            model.load_state_dict(weights)
        else:
            model.classifier.load_state_dict(weights)
        y, y_hat = get_output(model.classifier, dataset)
    else:
        y, y_hat = get_output(model, dataset)
    # Named auc_score so it does not shadow sklearn.metrics.auc imported in this cell.
    auc_score = roc_auc_score(y, y_hat)
    # Only use thresholds that split the predictions; otherwise the binarised predictions are constant.
    f1_scores = [f1_score(y, y_hat > x) for x in np.linspace(0.1, 1, num=10)
                 if (y_hat > x).any() and (y_hat < x).any()]
    f1 = max(f1_scores, default=float('nan'))   # .max() on an empty array would raise.
    return auc_score, f1

In [None]:
# get_output returns (labels, probabilities) in that order, so the targets are unpacked first.
targets_test_unaltered, preds_test_unaltered = get_output(model, test_dataset)
# np.savez(oj(data_path, 'saved-tensors', 'preds_test_unaltered.npz'), test_targets=targets_test_unaltered, test_preds=preds_test_unaltered)

targets_inpainted_patches, preds_inpainted_patches = get_output(model, inpainted_patch_dataset)
# np.savez(oj(data_path, 'saved-tensors', 'preds_inpainted_patches.npz'), test_targets=targets_inpainted_patches, test_preds=preds_inpainted_patches)

targets_inpainted_no_patches, preds_inpainted_no_patches = get_output(model, inpainted_no_patch_dataset)
# np.savez(oj(data_path, 'saved-tensors', 'preds_inpainted_no_patches.npz'), test_targets=targets_inpainted_no_patches, test_preds=preds_inpainted_no_patches)

#### Compare probabilities for original and altered images

In [None]:
def plot_compare_probs(probs_original, probs_altered, output_dir = oj(dir_path, 'plots'), output_add = None):
  """ Plots and saves three plots to compare the predicted probabilities before & after altering the images.
  Input:
    probs_original, probs_altered     The output probabilities of the classification model for the original & altered
                                      images (same length and order), as a CPU Tensor, numpy array or list.
    output_dir                        The path to the directory for saving the plots (created if it does not exist).
    output_add                        An ID to add to the output filename. If left blank, then a random 10 digit ID is created.
  Returns:
    None    The three plots are saved to the relevant directory and also printed to screen.
  """
  # Accept lists / CPU tensors as documented: the difference below needs element-wise array arithmetic.
  probs_original = np.asarray(probs_original, dtype=float)
  probs_altered = np.asarray(probs_altered, dtype=float)

  if output_add is None:
    # Create a random ID to avoid overwriting previous files. randint's upper bound is exclusive, so 10 gives digits 0-9.
    output_add = ''.join(str(randint(0, 10)) for _ in range(10))

  os.makedirs(output_dir, exist_ok=True)   # savefig fails if the output directory does not exist.

  ## Plot Histogram Comparison
  fig, ax = plt.subplots(2, 1, figsize = (10,8))

  ax[0].hist(probs_original)
  ax[0].set_title('Original')
  ax[0].set_xlabel('Predicted Probability')
  ax[0].set_ylabel('Number of Samples')
  ax[1].hist(probs_altered)
  ax[1].set_title('Patches Altered')
  ax[1].set_xlabel('Predicted Probability')
  ax[1].set_ylabel('Number of Samples')

  fig.suptitle('Predicted Probabilities Before & After Altering Patches', fontsize=16)
  fig.tight_layout(rect=[0, 0.03, 1, 0.95])

  fig.savefig(oj(output_dir, 'Probs Comparison Hist ' + output_add + '.png'))
  plt.show()

  ## Plot Scatterplot Comparison.
  fig, ax = plt.subplots(1, 1, figsize = (8,8))

  ax.scatter(probs_original, probs_altered, alpha=0.4)
  ax.set_xlabel('Original', fontsize=12)
  ax.set_ylabel('Altered',fontsize=12)
  ax.set_title('Predicted Probs for Original Images & After Altering Patches', fontsize=16)

  fig.savefig(oj(output_dir, 'Probs Comparison Scatter ' + output_add + '.png'))
  plt.show()

  ## Calculate differences and plot histogram
  diff_preds = probs_altered - probs_original

  fig, ax = plt.subplots(1, 1, figsize = (8,8))

  max_abs_diff = np.abs(diff_preds).max() if diff_preds.size else 0.0
  hist_range = (-max_abs_diff, max_abs_diff)  # Make sure histogram is symmetric about zero.

  ax.hist(diff_preds, range = hist_range, bins = 20)
  ax.set_title('Histogram of differences in predicted probs after altering patches', fontsize=16)
  ax.set_ylabel('Number of Samples')
  ax.set_xlabel('Difference in Predicted Probability')

  fig.savefig(oj(output_dir, 'Diff in Predicted Probs Hist ' + output_add + '.png'))
  plt.show()

##### Get results

In [None]:
# Boolean mask over the (sorted) test files: True where the test file also appears among the inpainted-patch files.
# The filenames come from `test_files` (the name returned when loading the test set). A set gives O(1) lookups,
# and a numpy bool array makes ~patch_ind an element-wise negation (~ on a Python list raises TypeError).
patch_file_set = set(patch_files)
patch_ind = np.array([file in patch_file_set for file in test_files], dtype=bool)

# Compare the outputs for the inpainted images vs. the originals. Both file lists are sorted by filename,
# so the masked originals line up with the inpainted predictions, assuming the inpainted files keep their names.
plot_compare_probs(preds_test_unaltered[patch_ind], preds_inpainted_patches, output_add='(inpainted patches)')
plot_compare_probs(preds_test_unaltered[~patch_ind], preds_inpainted_no_patches, output_add='(inpainted no patches)')

In [None]:
# Specificity is the recall of the negative (label 0) class. recall_score takes (y_true, y_pred), so the
# targets come first and the thresholded predicted probabilities second.
has_patch = np.asarray(patch_ind, dtype=bool)   # element-wise ~ works whether patch_ind is a list or an array

print("Specificity for images with patches is: {:.2f}".format(
    recall_score(targets_test_unaltered[has_patch], (preds_test_unaltered[has_patch] > 0.5).astype(int), pos_label=0)))

print("Specificity for images without a patch is:\t {:.2f}".format(
    recall_score(targets_test_unaltered[~has_patch], (preds_test_unaltered[~has_patch] > 0.5).astype(int), pos_label=0)))