In [1]:
from misc import sample_images_with_labels, LoadLungImages, get_probabilities
from train import CNNClassification, ShallowResNetClassification, ShallowMobileNetClassification
from Dataloaders import CreateDataLoaders
from sklearn.metrics import roc_curve, auc
import matplotlib.pyplot as plt
from visualisation import visualize_attributions_panel, visualize_attributions_layer_panel
import numpy as np
import os

In [3]:
# Load data - set the PNU_DATA_DIR environment variable to your local copy of the
# chest_xray folder; when it is unset, the original author's path is used as a fallback.
path = os.environ.get("PNU_DATA_DIR", "/Users/antanas/GitRepo/PnuPred/PnuPred/PnuData/chest_xray/")
if not os.path.isdir(path):
    # Fail early with an actionable message instead of deep inside the image loader.
    raise FileNotFoundError(
        "Dataset directory not found: {!r}. Set PNU_DATA_DIR to the chest_xray folder.".format(path)
    )

# Each split lives in its own sub-folder of the dataset directory.
X_train, y_train = LoadLungImages(os.path.join(path, "train"))
X_test, y_test = LoadLungImages(os.path.join(path, "test"))
X_val, y_val = LoadLungImages(os.path.join(path, "val"))

# Fold the validation split into the training split.
# NOTE(review): `+` concatenates only if LoadLungImages returns Python lists; with NumPy
# arrays it would add element-wise instead - confirm the loader's return type.
X_train = X_train + X_val
y_train = y_train + y_val

In [4]:
# Train the three classifiers (CNN, shallow ResNet, shallow MobileNet) with the same
# hyper-parameters. Only the first return value (the model) is kept; the other two
# return values of each training function are discarded.
train_kwargs = dict(num_epochs=10, learning_rate=0.001, batch_size=32, preprocess=False)

cnn_model, _, _ = CNNClassification(X_train, y_train, X_test, y_test, **train_kwargs)
resnet_model, _, _ = ShallowResNetClassification(X_train, y_train, X_test, y_test, **train_kwargs)
mobilenet_model, _, _ = ShallowMobileNetClassification(X_train, y_train, X_test, y_test, **train_kwargs)

In [None]:
# Build the evaluation loader; the first returned loader is not needed here.
# NOTE(review): the models above were trained with preprocess=False while this loader is
# built with preprocess=True - confirm both paths feed the models equivalent inputs.
_, test_loader = CreateDataLoaders(X_train, y_train, X_test, y_test, batch_size=50, preprocess=True)

# Predicted probabilities on the test loader for each of the three models.
# NOTE(review): the ROC computation below pairs these with y_test positionally, so it
# assumes the loader yields samples in y_test order (no shuffling) - TODO confirm.
cnn_probs, resnet_probs, mobilenet_probs = (
    get_probabilities(net, test_loader)
    for net in (cnn_model, resnet_model, mobilenet_model)
)

# ROC curve points (thresholds are discarded) for each model ...
cnn_fpr, cnn_tpr, _ = roc_curve(y_test, cnn_probs)
resnet_fpr, resnet_tpr, _ = roc_curve(y_test, resnet_probs)
mobilenet_fpr, mobilenet_tpr, _ = roc_curve(y_test, mobilenet_probs)

# ... and the area under each curve.
cnn_auc = auc(cnn_fpr, cnn_tpr)
resnet_auc = auc(resnet_fpr, resnet_tpr)
mobilenet_auc = auc(mobilenet_fpr, mobilenet_tpr)


# Plot the three ROC curves on one set of axes, with the chance diagonal for reference.
roc_ax = plt.figure().gca()
roc_curves = (
    ('CNN', cnn_fpr, cnn_tpr, cnn_auc),
    ('Shallow ResNet', resnet_fpr, resnet_tpr, resnet_auc),
    ('Shallow MobileNet', mobilenet_fpr, mobilenet_tpr, mobilenet_auc),
)
for curve_name, curve_fpr, curve_tpr, curve_auc in roc_curves:
    roc_ax.plot(curve_fpr, curve_tpr, label='{} (area = {:.2f})'.format(curve_name, curve_auc))
roc_ax.plot([0, 1], [0, 1], 'k--')
roc_ax.set_xlim([0.0, 1.0])
roc_ax.set_ylim([0.0, 1.05])  # small headroom above TPR = 1 so the curves are not clipped
roc_ax.set_xlabel('False Positive Rate')
roc_ax.set_ylabel('True Positive Rate')
roc_ax.set_title('Receiver Operating Characteristic')
roc_ax.legend(loc="lower right")
plt.show()

In [None]:
# Pick the same number of test images for label 1 ("pos") and label 0 ("neg"), then
# join the two groups into the image/label collections used by the attribution cells.
# NOTE(review): `+` is assumed to concatenate (list-like returns) - confirm with the helper.
samples_per_label = 5
pos_img, pos_label = sample_images_with_labels(X_test, y_test, label=1, num_samples=samples_per_label)
neg_img, neg_label = sample_images_with_labels(X_test, y_test, label=0, num_samples=samples_per_label)
images, labels = pos_img + neg_img, pos_label + neg_label

### Integrated-Gradients

In [None]:
# Attribution panel for the MobileNet model over the sampled images (Integrated-Gradients
# section); sign='all' is passed straight through to the visualisation helper.
figure, axis = visualize_attributions_panel(
    mobilenet_model,
    images,
    labels,
    sign='all',
)

In [None]:
# Attribution panel for the CNN model over the sampled images (Integrated-Gradients
# section); sign='all' is passed straight through to the visualisation helper.
figure, axis = visualize_attributions_panel(
    cnn_model,
    images,
    labels,
    sign='all',
)

In [None]:
# Attribution panel for the shallow ResNet model over the sampled images
# (Integrated-Gradients section); sign='all' is passed straight through to the helper.
figure, axis = visualize_attributions_panel(
    resnet_model,
    images,
    labels,
    sign='all',
)

### Grad-CAM

In [None]:
# Layer-wise attribution panel for the MobileNet model (Grad-CAM section).
# The list is presumably the names of the model layers to attribute against - confirm.
fig, axs = visualize_attributions_layer_panel(
    mobilenet_model,
    images,
    labels,
    ['conv_dw1', 'conv_dw2'],
)

In [None]:
# Layer-wise attribution panel for the CNN model (Grad-CAM section).
# The list is presumably the names of the model layers to attribute against - confirm.
fig, axs = visualize_attributions_layer_panel(
    cnn_model,
    images,
    labels,
    ['conv1', 'conv2'],
)

In [None]:
# Layer-wise attribution panel for the shallow ResNet model (Grad-CAM section).
# The list is presumably the names of the model layers to attribute against - confirm.
fig, axs = visualize_attributions_layer_panel(
    resnet_model,
    images,
    labels,
    ['layer1', 'layer2'],
)

### Sanity Check

In [None]:
# Label-randomisation sanity check: retrain the shallow ResNet for one epoch on labels
# drawn at random (with replacement) from y_train, so the labels no longer correspond to
# the images in X_train.
# A fixed-seed Generator replaces the unseeded global NumPy state so the randomised
# labels are identical after every kernel restart.
# NOTE(review): the training run itself may still be stochastic (weight init, batch
# order) - seed the training framework too if fully reproducible results are needed.
label_rng = np.random.default_rng(0)
sample = label_rng.choice(y_train, size=len(y_train), replace=True)
model, loss_list, acc_list = ShallowResNetClassification(
    X_train, sample, X_test, y_test,
    num_epochs=1, learning_rate=0.001, batch_size=10, preprocess=False,
)