Téléchargement de la base de données + prétraitement

Pour Linux

In [None]:
!wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz > /dev/null 2>&1
!tar zxvf imagenette2.tgz > /dev/null 2>&1

Pour MacOS

In [None]:
# Download the Imagenette2 archive with curl unless it is already in the working directory
import os
if not os.path.exists('imagenette2.tgz'):
    # -O keeps the remote file name, -# asks curl for a progress bar
    !curl -O -# https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
# Skip the extraction when an 'imagenette2' path already exists
if not os.path.exists('imagenette2'):
    !tar -zxvf imagenette2.tgz > /dev/null 2>&1

In [None]:
# Google Colab only: mount the user's Google Drive at /content/drive.
# NOTE(review): no visible cell reads from the mounted drive — confirm this is still needed.
from google.colab import drive

drive.mount("/content/drive")

In [None]:
import torchvision
import torchvision.transforms as transforms
import torch
import os
from torch.utils.data import Dataset

# Per-channel normalisation statistics (presumably the usual ImageNet values — confirm).
means = (0.485, 0.456, 0.406)
stds = (0.229, 0.224, 0.225)

# Training pipeline: resize to 224x224, convert to a tensor, normalise per channel.
train_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(means, stds),
])

# Evaluation pipeline: same three steps, built as an independent object.
test_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(means, stds),
])


def get_imagenette2_loaders(root_path="./imagenette2", **kwargs):
    """Build the training and validation DataLoaders for an Imagenette2 folder.

    Args:
        root_path: Directory containing the ``train`` and ``val`` sub-folders.
        **kwargs: Forwarded unchanged to both ``torch.utils.data.DataLoader``
            constructors (e.g. ``batch_size``, ``shuffle``, ``num_workers``).

    Returns:
        A ``(trainloader, testloader)`` tuple; the first reads ``train`` with
        ``train_transform``, the second reads ``val`` with ``test_transform``.
    """
    loaders = []
    for split, transform in (("train", train_transform), ("val", test_transform)):
        dataset = torchvision.datasets.ImageFolder(
            os.path.join(root_path, split), transform=transform
        )
        loaders.append(torch.utils.data.DataLoader(dataset, **kwargs))
    return tuple(loaders)


# Build both loaders with identical settings.
# NOTE(review): shuffle=True is forwarded to the validation loader as well; this does
# not change the accuracy but makes the evaluation order non-deterministic.
trainloader, testloader = get_imagenette2_loaders(
    batch_size=64, shuffle=True, num_workers=2
)

# Human-readable class names, indexed by the integer class id predicted by the models.
# NOTE(review): presumably listed in the same order as ImageFolder's sorted class
# directories — confirm against trainloader.dataset.classes.
labels = [
    "tench",
    "English springer",
    "cassette player",
    "chain saw",
    "church",
    "French horn",
    "garbage truck",
    "gas pump",
    "golf ball",
    "parachute",
]

In [None]:
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# Inverse of Normalize(means, stds): normalising with mean=-m/s and std=1/s
# maps (v - m) / s back to v, which restores displayable pixel values.
inv_normalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(means, stds)], std=[1 / s for s in stds]
)

# Draw one batch from the training loader (x, y are reused by later cells)
# and show its first 16 images as a grid.
x, y = next(iter(trainloader))
img_grid = make_grid(x[:16])
img_grid = inv_normalize(img_grid)
plt.figure(figsize=(20, 15))
# imshow expects channels last, hence the (C, H, W) -> (H, W, C) permutation
plt.imshow(img_grid.permute(1, 2, 0))
plt.axis("off")

Modèle 1 : VGG11 pré-entraîné

In [None]:
import torch.nn as nn
import torch

model_vgg11 = torchvision.models.vgg11(pretrained=True)
# Freeze the convolutional feature extractor. We must iterate over .parameters():
# looping over the Sequential itself yields sub-modules, and assigning
# `requires_grad` on a module only creates an unused attribute, freezing nothing.
for param in model_vgg11.features.parameters():
    param.requires_grad = False

# Replace the classification head with a freshly initialised one that ends in
# 10 outputs (one per entry of `labels`).
model_vgg11.classifier = nn.Sequential(
    nn.Linear(in_features=25088, out_features=4096, bias=True),
    nn.ReLU(inplace=True),
    nn.Linear(in_features=4096, out_features=4096, bias=True),
    nn.ReLU(inplace=True),
    nn.Linear(in_features=4096, out_features=10, bias=True),
)
if torch.cuda.is_available():
    model_vgg11 = model_vgg11.cuda()

Modèle 2 : ResNet18 pré-entraîné

In [None]:
model_resnet18 = torchvision.models.resnet18(pretrained=True)
# Freeze every pretrained weight in one call (sets requires_grad=False on all parameters).
model_resnet18.requires_grad_(False)

# New final layer with 10 outputs; created after the freeze, so it stays trainable.
model_resnet18.fc = nn.Linear(model_resnet18.fc.in_features, 10)

if torch.cuda.is_available():
    model_resnet18 = model_resnet18.cuda()

Entraînement, test et évaluation

In [None]:
from tqdm import tqdm

# Classification loss used by train(); reduction="mean" averages the loss over the batch
criterion_classifier = nn.CrossEntropyLoss(reduction="mean")


def train(model, optimizer, trainloader, epochs=30):
    """Train ``model`` with ``optimizer`` and report the last epoch's training accuracy.

    Args:
        model: Network called on each input batch; its output is scored with
            the module-level ``criterion_classifier``.
        optimizer: Optimizer stepped once per batch.
        trainloader: Iterable yielding ``(x, y)`` batches.
        epochs: Number of passes over ``trainloader``.

    Returns:
        Fraction (between 0 and 1) of correctly classified samples accumulated
        over the last epoch, measured while the weights were still being
        updated. Returns 0.0 when no sample was processed (``epochs == 0`` or
        an empty loader) instead of raising.

    Note:
        The train/eval mode of ``model`` is left exactly as the caller set it.
    """
    use_cuda = torch.cuda.is_available()
    # Defined up-front so the final return is valid even if the loops never run.
    corrects, total = 0, 0
    t = tqdm(range(epochs))
    for epoch in t:
        # Accuracy counters restart at every epoch.
        corrects = 0
        total = 0
        for x, y in trainloader:
            if use_cuda:
                x = x.cuda()
                y = y.cuda()
            y_hat = model(x)

            loss = criterion_classifier(y_hat, y)
            _, predicted = y_hat.max(1)
            corrects += predicted.eq(y).sum().item()
            total += y.size(0)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            t.set_description(
                f"epoch: {epoch}; current accuracy: {round(corrects / total * 100, 2)}%  "
            )
    return corrects / total if total else 0.0

In [None]:
# Hyper-parameters shared by both training runs
learning_rate = 5e-3  # step size given to the Adam optimizers
epochs = 1  # number of passes over the training loader

In [None]:
# vgg11: only the replaced classifier head is handed to the optimizer
optimizer = torch.optim.Adam(model_vgg11.classifier.parameters(), lr=learning_rate)
train(model_vgg11, optimizer, trainloader, epochs=epochs)

# resnet18: only the new final fully-connected layer is handed to the optimizer
optimizer = torch.optim.Adam(model_resnet18.fc.parameters(), lr=learning_rate)
train(model_resnet18, optimizer, trainloader, epochs=epochs)

In [None]:
def test(model, dataloader):
    test_corrects = 0
    total = 0
    with torch.no_grad():
        for x, y in dataloader:
            if torch.cuda.is_available():
                x = x.cuda()
                y = y.cuda()
            y_hat = model(x).argmax(1)
            test_corrects += y_hat.eq(y).sum().item()
            total += y.size(0)
    return test_corrects / total

In [None]:
# Switch each model to evaluation mode, then report its accuracy on the validation loader
model_vgg11.eval()
test_acc = test(model_vgg11, testloader) * 100  # fraction -> percentage
print(f"Test accuracy vgg11: {test_acc:.2f} %")

model_resnet18.eval()
test_acc = test(model_resnet18, testloader) * 100  # fraction -> percentage
print(f"Test accuracy resnet18: {test_acc:.2f} %")

Méthode 1 : RISE

In [None]:
import numpy as np
from skimage.transform import resize


def generate_masks(N, s, p1, image_size):
    """Generate ``N`` random smooth masks for RISE.

    Each mask starts as an ``s x s`` binary grid whose cells are set to 1 with
    probability ``p1``. The grid is upsampled with linear interpolation to
    ``(s + 1) * cell_size`` and a window of ``image_size`` is cropped at a
    random offset, so mask edges do not always align with the same pixels.

    Args:
        N: Number of masks to generate.
        s: Side length of the coarse binary grid.
        p1: Probability that a grid cell is kept (set to 1).
        image_size: ``(H, W)`` of the masks; any sequence of two integers
            (a ``torch.Size`` works).

    Returns:
        A float32 tensor of shape ``(N, 1, H, W)``.
    """
    image_size = tuple(int(d) for d in image_size)
    # Keep sizes as integers: they are used as an output shape and as a
    # random-integer bound, and float values there are not reliably accepted.
    cell_size = np.ceil(np.array(image_size) / s).astype(int)
    up_size = tuple(int(v) for v in (s + 1) * cell_size)

    grid = np.random.rand(N, s, s) < p1
    grid = grid.astype("float32")

    # Allocate directly in float32 (the dtype of the returned tensor) instead
    # of float64, halving the peak memory of this buffer.
    masks = np.empty((N, *image_size), dtype="float32")

    for i in range(N):
        # Random shifts
        x = np.random.randint(0, cell_size[0])
        y = np.random.randint(0, cell_size[1])
        # Linear interpolation, then crop the shifted window
        masks[i, :, :] = resize(
            grid[i], up_size, order=1, mode="reflect", anti_aliasing=False
        )[x : x + image_size[0], y : y + image_size[1]]

    masks = masks.reshape(-1, 1, *image_size)
    masks = torch.from_numpy(masks).float()

    return masks

In [None]:
def explain(model, N, p1, img, masks, batch_size=1):
    """Compute RISE saliency maps, one per output class of ``model``.

    Every mask is multiplied with the image, the masked image is scored by the
    model, and the masks are averaged with those scores as weights.

    Args:
        model: Callable mapping a ``(B, C, H, W)`` batch to ``(B, num_classes)``
            scores. The scores are used as returned (no softmax is applied here).
        N: Number of masks; ``masks`` must contain exactly ``N * H * W`` elements.
        p1: Keep-probability used to generate the masks (normalisation factor).
        img: Image tensor of shape ``(C, H, W)``; an already batched
            ``(1, C, H, W)`` tensor is accepted too.
        masks: Mask tensor broadcastable against the image, e.g. ``(N, 1, H, W)``
            as returned by ``generate_masks``.
        batch_size: Number of masked images scored per forward pass. The
            default of 1 reproduces the original one-by-one evaluation.

    Returns:
        A CPU tensor of shape ``(num_classes, H, W)``.
    """
    if img.dim() == 3:
        img = img.unsqueeze(0)
    img = img.detach().to("cpu")
    _, _, H, W = img.size()
    use_cuda = torch.cuda.is_available()

    scores = []
    with torch.no_grad():
        for start in tqdm(range(0, N, batch_size)):
            stop = min(start + batch_size, N)
            # Mask lazily, one chunk at a time: building all N masked images
            # up-front needs N * C * H * W floats (several GB for N = 10000
            # at 224 x 224).
            batch = masks[start:stop] * img
            if use_cuda:
                batch = batch.cuda()
            scores.append(model(batch).to("cpu"))

    p = torch.cat(scores)

    CL = p.size(1)
    # (num_classes, N) @ (N, H*W): score-weighted sum of the masks per class
    sal = torch.matmul(p.transpose(0, 1), masks.reshape(N, H * W))
    sal = sal.view((CL, H, W))
    sal = sal / N / p1
    return sal

In [None]:
# Display the image on which the models' predictions will be explained with an importance map

idx = 0

# De-normalized copy of sample `idx` of the batch `x` (np_img below is its uint8 version for display)
img = inv_normalize(x[idx])
# (C, H, W) float tensor -> (H, W, C) array scaled by 255
np_img = np.transpose(img.cpu().detach().numpy(), (1, 2, 0)) * 255
np_img = np_img.astype(np.uint8)
plt.imshow(np_img)
plt.axis("off")
# Normalized sample with a leading batch dimension, fed to the models in the next cell.
# NOTE(review): `input` shadows the Python builtin of the same name.
input = x[idx].unsqueeze(0)
if torch.cuda.is_available():
    input = input.cuda()

In [None]:
# Top-1 prediction of each model on the selected sample; torch.topk returns
# (values, indices) and only the class index is kept.
output = model_vgg11(input)
_, prediction_vgg11 = torch.topk(output, 1)
print(f"VGG11 prediction: {labels[prediction_vgg11.item()]} (item number {prediction_vgg11.item()})")

output = model_resnet18(input)
_, prediction_resnet18 = torch.topk(output, 1)
print(f"Resnet18 prediction: {labels[prediction_resnet18.item()]} (item number {prediction_resnet18.item()})")

In [None]:
# RISE hyper-parameters
N = 10000  # Number of masks
s = 8  # Size of grid
p1 = 0.1  # Probability of inclusion

# Masks are generated at the spatial size (last two dimensions) of the selected image
masks = generate_masks(N, s, p1, img.shape[-2:])

In [None]:
# Explain the same normalized sample (x[idx]) that the predictions were computed
# on. `img` is the de-normalized display copy: feeding it to the models would
# score inputs outside the normalized range they were trained on.
saliency_maps_vgg11 = explain(model_vgg11, N, p1, x[idx], masks)
saliency_maps_resnet18 = explain(model_resnet18, N, p1, x[idx], masks)

In [None]:
# vgg11
# Original image as background
plt.imshow(np_img)

# Half-transparent overlay of the RISE map of the class predicted by VGG11
plt.imshow(saliency_maps_vgg11[prediction_vgg11.item()], cmap="turbo", alpha=0.5)
plt.colorbar()
plt.show()

In [None]:
# resnet18
# Original image as background
plt.imshow(np_img)

# Half-transparent overlay of the RISE map of the class predicted by ResNet18
plt.imshow(saliency_maps_resnet18[prediction_resnet18.item()], cmap="turbo", alpha=0.5)
plt.colorbar()
plt.show()

Méthode 2 : Vanilla gradient back-propagation

In [None]:
# Build the gradient-saliency input from the normalized sample x[idx]: this is
# what the models were trained on and what the predictions were computed from.
# Starting from x[idx] (instead of re-wrapping `img`) also keeps this cell
# idempotent: re-running it no longer stacks an extra batch dimension.
img = x[idx].detach().clone().unsqueeze(0)
# Same CUDA guard as the other cells; an unconditional .cuda() fails on CPU-only machines.
if torch.cuda.is_available():
    img = img.cuda()
# The tensor must be on its final device before requires_grad_(): moving it
# afterwards would create a non-leaf copy whose .grad is never populated.
img.requires_grad_();

In [None]:
# VGG11
# Clear any gradient left on the input by a previous backward pass
img.grad = None

output = model_vgg11(img)
# Score of the top-scoring class for this input
output_idx = output.argmax()
output_max = output[0, output_idx]

# Back-propagate that score to the input; the result is stored in img.grad
output_max.backward()

In [None]:
# Saliency = maximum absolute input gradient across the channel dimension (dim=1)
saliency_vgg11, _ = torch.max(img.grad.data.abs(), dim=1)
saliency_vgg11 = saliency_vgg11.squeeze(0)


# Side by side: original image (left) and gradient saliency (right)
plt.figure(figsize=(15,10))
plt.subplot(1,2,1)
plt.imshow(np_img)
plt.axis('off')
plt.subplot(1,2,2)
plt.imshow(saliency_vgg11.cpu(), cmap='hot')
plt.axis('off')


In [None]:
# Resnet18
# Clear the gradient left on the input by the previous backward pass
img.grad = None

output = model_resnet18(img)
# Score of the top-scoring class for this input
output_idx = output.argmax()
output_max = output[0, output_idx]

# Back-propagate that score to the input; the result is stored in img.grad
output_max.backward()

In [None]:
# Saliency = maximum absolute input gradient across the channel dimension (dim=1)
saliency_resnet18, _ = torch.max(img.grad.data.abs(), dim=1)
saliency_resnet18 = saliency_resnet18.squeeze(0)


# Side by side: original image (left) and gradient saliency (right)
plt.figure(figsize=(15,10))
plt.subplot(1,2,1)
plt.imshow(np_img)
plt.axis('off')
plt.subplot(1,2,2)
plt.imshow(saliency_resnet18.cpu(), cmap='hot')
plt.axis('off')

Métriques d'évaluation

In [None]:
from torch.nn import functional as F

def deletion(model, img, saliency_map, target_class):
    """Deletion curve: zero out the most salient values first and track the score.

    The saliency map is repeated for every channel, all ``C * H * W`` values are
    ranked by decreasing saliency, and roughly 1% of them are set to 0 at each
    step. After every step the model is evaluated on the modified image.

    Args:
        model: Callable mapping a ``(1, C, H, W)`` batch to ``(1, num_classes)`` scores.
        img: Image tensor that squeezes to ``(C, H, W)``; it is not modified.
        saliency_map: Tensor that squeezes to ``(H, W)``.
        target_class: Index of the class whose softmax probability is recorded.

    Returns:
        List of floats, one softmax probability per deletion step.
    """
    work = img.clone().detach().squeeze(0)
    C, H, W = work.shape

    # Same saliency value for every channel of a pixel
    per_channel = saliency_map.squeeze().repeat(C, 1, 1)

    # Flat positions ranked from most to least salient
    order = per_channel.view(-1).sort(descending=True)[1]

    chunk = max(1, len(order) // 100)  # about 1% of all values per step
    num_steps = len(order) // chunk

    scores = []
    for step in range(num_steps):
        start = step * chunk
        # The final step also takes whatever remains after the integer division
        stop = None if step == num_steps - 1 else start + chunk

        # In-place write through a flat view of the working copy
        work.view(-1)[order[start:stop]] = 0

        with torch.no_grad():
            probabilities = F.softmax(model(work.unsqueeze(0)), dim=1)
            scores.append(probabilities[0, target_class].item())

    return scores


In [None]:
# Deletion curves for each (model, saliency method) pair; the target class is each model's own prediction
scores_vgg11_rise_deletion = deletion(model_vgg11, img, saliency_maps_vgg11[prediction_vgg11.item()], prediction_vgg11.item())
scores_resnet18_rise_deletion = deletion(model_resnet18, img, saliency_maps_resnet18[prediction_resnet18.item()], prediction_resnet18.item())
scores_vgg11_vanilla_deletion = deletion(model_vgg11, img, saliency_vgg11, prediction_vgg11.item())
scores_resnet18_vanilla_deletion = deletion(model_resnet18, img, saliency_resnet18, prediction_resnet18.item())

In [None]:
import matplotlib.pyplot as plt

# 2x2 grid with one deletion curve per (model, saliency method) pair;
# the x-axis is the index of the deletion step.
plt.figure(figsize=(12, 10))

# Top-left: VGG11 with RISE
plt.subplot(2, 2, 1)
plt.plot(scores_vgg11_rise_deletion, color='blue')
plt.title('VGG11 RISE Deletion')
plt.xlabel('Pixel removed (%)')
plt.ylabel('Model Confidence')

# Top-right: ResNet18 with RISE
plt.subplot(2, 2, 2)
plt.plot(scores_resnet18_rise_deletion, color='green')
plt.title('ResNet18 RISE Deletion')
plt.xlabel('Pixel removed (%)')
plt.ylabel('Model Confidence')

# Bottom-left: VGG11 with vanilla gradients
plt.subplot(2, 2, 3)
plt.plot(scores_vgg11_vanilla_deletion, color='red')
plt.title('VGG11 Vanilla Deletion')
plt.xlabel('Pixel removed (%)')
plt.ylabel('Model Confidence')

# Bottom-right: ResNet18 with vanilla gradients
plt.subplot(2, 2, 4)
plt.plot(scores_resnet18_vanilla_deletion, color='purple')
plt.title('ResNet18 Vanilla Deletion')
plt.xlabel('Pixel removed (%)')
plt.ylabel('Model Confidence')

plt.tight_layout()
plt.show()


In [None]:
from torch.nn import functional as F

def insertion(model, img, saliency_map, target_class):
    """Insertion curve: reveal the most salient values first and track the score.

    Starting from an all-zero image, the saliency map is repeated for every
    channel, all ``C * H * W`` values are ranked by decreasing saliency, and
    roughly 1% of the original values are copied in at each step. After every
    step the model is evaluated on the partially revealed image.

    Args:
        model: Callable mapping a ``(1, C, H, W)`` batch to ``(1, num_classes)`` scores.
        img: Image tensor that squeezes to ``(C, H, W)``; it is not modified.
        saliency_map: Tensor that squeezes to ``(H, W)``.
        target_class: Index of the class whose softmax probability is recorded.

    Returns:
        List of floats, one softmax probability per insertion step.
    """
    # Blank canvas with the image's shape
    canvas = torch.zeros_like(img).detach().squeeze(0)
    C, H, W = canvas.shape

    # Same saliency value for every channel of a pixel
    per_channel = saliency_map.squeeze().repeat(C, 1, 1)

    # Flat positions ranked from most to least salient
    ranked = per_channel.view(-1).sort(descending=True)[1]

    chunk = max(1, len(ranked) // 100)  # about 1% of all values per step
    num_steps = len(ranked) // chunk

    # Flattened copy of the original image, the source of the revealed values
    source = img.clone().detach().squeeze(0).view(-1)

    scores = []
    for step in range(num_steps):
        start = step * chunk
        # The final step also takes whatever remains after the integer division
        stop = None if step == num_steps - 1 else start + chunk
        picked = ranked[start:stop]

        # In-place write through a flat view of the canvas
        canvas.view(-1)[picked] = source[picked]

        with torch.no_grad():
            probabilities = F.softmax(model(canvas.unsqueeze(0)), dim=1)
            scores.append(probabilities[0, target_class].item())

    return scores


In [None]:
# Insertion curves for each (model, saliency method) pair; the target class is each model's own prediction
scores_vgg11_rise_insertion = insertion(model_vgg11, img, saliency_maps_vgg11[prediction_vgg11.item()], prediction_vgg11.item())
scores_resnet18_rise_insertion = insertion(model_resnet18, img, saliency_maps_resnet18[prediction_resnet18.item()], prediction_resnet18.item())
scores_vgg11_vanilla_insertion = insertion(model_vgg11, img, saliency_vgg11, prediction_vgg11.item())
scores_resnet18_vanilla_insertion = insertion(model_resnet18, img, saliency_resnet18, prediction_resnet18.item())

In [None]:
import matplotlib.pyplot as plt

# 2x2 grid with one insertion curve per (model, saliency method) pair;
# the x-axis is the index of the insertion step.
plt.figure(figsize=(12, 10))

# Top-left: VGG11 with RISE
plt.subplot(2, 2, 1)
plt.plot(scores_vgg11_rise_insertion, color='blue')
plt.title('VGG11 RISE Insertion')
plt.xlabel('Pixel added (%)')
plt.ylabel('Model Confidence')

# Top-right: ResNet18 with RISE
plt.subplot(2, 2, 2)
plt.plot(scores_resnet18_rise_insertion, color='green')
plt.title('ResNet18 RISE Insertion')
plt.xlabel('Pixel added (%)')
plt.ylabel('Model Confidence')

# Bottom-left: VGG11 with vanilla gradients
plt.subplot(2, 2, 3)
plt.plot(scores_vgg11_vanilla_insertion, color='red')
plt.title('VGG11 Vanilla Insertion')
plt.xlabel('Pixel added (%)')
plt.ylabel('Model Confidence')

# Bottom-right: ResNet18 with vanilla gradients
plt.subplot(2, 2, 4)
plt.plot(scores_resnet18_vanilla_insertion, color='purple')
plt.title('ResNet18 Vanilla Insertion')
plt.xlabel('Pixel added (%)')
plt.ylabel('Model Confidence')

plt.tight_layout()
plt.show()
