# Install the required packages
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Bergschaf/visualime_guide/blob/master/Get_Started.ipynb)


In [None]:
!pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
!pip3 install numpy
!pip3 install matplotlib

# Import the required packages

In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
from torchvision import datasets, transforms

# Define the transformations to prepare the data

In [None]:
transform = transforms.Compose(
    [
        # image -> float tensor with pixel values in [0, 1]
        transforms.ToTensor(),
        # (x - 0.5) / 0.5 maps pixel values from [0, 1] to [-1, 1]
        transforms.Normalize(mean=(0.5,), std=(0.5,)),
    ]
)

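1. ```transforms.ToTensor()``` converts the image to a tensor with pixel values in [0, 1]
2. ```transforms.Normalize((0.5,), (0.5,))``` normalizes the image: each pixel value x becomes (x - 0.5) / 0.5, which maps the range [0, 1] to [-1, 1]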

# Download the dataset

In [None]:
!mkdir -p data

testset = datasets.MNIST('data', download=True, train=False, transform=transform)

testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)

# Analyze the dataset

In [None]:
dataiter = iter(testloader)
# Take a single batch; a second next() call would only discard this one.
images, labels = next(dataiter)

print(images.shape)  # expected: torch.Size([64, 1, 28, 28])
print(labels.shape)  # expected: torch.Size([64])

`images` has shape [64, 1, 28, 28]: a batch of 64 grayscale (single-channel) images of 28x28 pixels.
`labels` has shape [64] and holds the digit shown in each image.

# Download the model

In [None]:
!wget https://github.com/Bergschaf/visualime_guide/raw/master/models/mnist_model.pt
model = torch.load("mnist_model.pt")

# Test the model on a single image

In [None]:
img = images[0]  # first image of the batch, shape [1, 28, 28]
with torch.no_grad():
    # The model takes flattened 784-pixel vectors; add a batch dimension of 1.
    logps = model(img.reshape(1, 784))

# exp() turns the model's output (treated as log-probabilities) into probabilities.
ps = torch.exp(logps)
print("Predicted Digit =", ps.argmax(dim=1).item())

# Reshape for display instead of resizing `img` in place with resize_().
plt.imshow(img.reshape(28, 28).numpy(), cmap='Greys_r')

# Explain the classification with visualime

## Install and import visualime

In [None]:
!pip3 install visualime
from visualime.explain import explain_classification, render_explanation

## Define helper functions

In [None]:
def to_visualime(image: np.ndarray):
    """
    Turn a single-channel image into the 3-channel layout used by visualime.

    :param image: image of shape [1, 28, 28]; the notebook also passes torch
        tensors here, which numpy converts while stacking.
    :return: numpy array of shape [28, 28, 3] whose three channels are
        identical copies of the input.
    """
    gray = image.squeeze()
    channels = [gray] * 3
    return np.stack(channels, axis=2)

```to_visualime()``` converts the image to the shape [28,28,3] (three identical channels), which is the format visualime requires

In [None]:
def predict(imgs: np.ndarray):
    """
    Classify a batch of visualime images with the MNIST model.

    :param imgs: visualime RGB images of shape [num_samples, 28, 28, 3]. The
        three channels are assumed identical (as produced by ``to_visualime``),
        so only the first channel is used.
    :return: float64 numpy array of shape [num_samples, 10] holding
        ``exp`` of the model output, i.e. the class probabilities when the
        model outputs log-probabilities.
    """
    # Keep channel 0 and copy it into a contiguous float32 buffer. The model
    # presumably uses torch's default float32 dtype, so this also guards
    # against a float64 batch.
    gray = np.ascontiguousarray(imgs[:, :, :, 0], dtype=np.float32)
    # Flatten each 28x28 image to the 784-vector the model consumes.
    batch = torch.from_numpy(gray).reshape(-1, 784)
    # One forward pass for the whole batch instead of one per sample.
    with torch.no_grad():
        log_probs = model(batch)
    # float64 matches the np.zeros-based result of the per-sample version.
    return torch.exp(log_probs).numpy().astype(np.float64)


```predict()``` takes a batch of images of shape [num_samples, 28, 28, 3] (an array of visualime images) and returns the model's class probabilities as an array of shape [num_samples, 10].
visualime needs this function to explain the classification.

## Explain the classification

In [None]:
sample_index = 5  # which image of the loaded batch to explain
img = images[sample_index]

In [None]:
# Convert once and reuse: the same [28, 28, 3] array is classified, explained, and rendered.
vis_img = to_visualime(img)

# predict() expects a batch, so add a leading sample axis of size 1.
print("The network predicts: ", np.argmax(predict(vis_img[np.newaxis])))

segment_mask, segment_weights = explain_classification(image=vis_img, predict_fn=predict, num_of_samples=512)

explanation = render_explanation(
    vis_img,
    segment_mask,
    segment_weights,
    positive="green",
    negative="red",
    coverage=0.5,
    opacity=1,
)

plt.imshow(explanation)
