Téléchargement de la base de données et prétraitement

Pour Linux

In [None]:
!wget https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz > /dev/null 2>&1
!tar zxvf imagenette2.tgz > /dev/null 2>&1

Pour macOS

In [None]:
!curl -O -# https://s3.amazonaws.com/fast-ai-imageclas/imagenette2.tgz
!tar -zxvf imagenette2.tgz > /dev/null 2>&1

In [None]:
# Google Drive mounting only exists on Colab; skip it elsewhere so the rest of
# the notebook still runs locally (the dataset is extracted to ./imagenette2).
try:
    from google.colab import drive
except ImportError:  # not running on Colab
    drive = None

if drive is not None:
    drive.mount("/content/drive")

In [None]:
import torchvision
import torchvision.transforms as transforms
import torch
import os
from torch.utils.data import Dataset

# Per-channel normalization statistics (presumably the standard ImageNet
# values expected by torchvision's pretrained models).
means, stds = (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)


def _resize_and_normalize():
    """Build the deterministic preprocessing: resize to 224x224, to tensor, normalize."""
    steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(means, stds),
    ]
    return transforms.Compose(steps)


# Train and test use the same preprocessing (no data augmentation).
train_transform = _resize_and_normalize()
test_transform = _resize_and_normalize()


def get_imagenette2_loaders(root_path="./imagenette2", **kwargs):
    """Build the train and validation DataLoaders for the extracted Imagenette2 data.

    Args:
        root_path: directory containing the ``train`` and ``val`` sub-folders,
            each laid out as ``ImageFolder`` expects (one sub-folder per class).
        **kwargs: forwarded to both ``DataLoader``s (e.g. ``batch_size``,
            ``num_workers``). ``shuffle`` only applies to the training loader;
            the validation loader is always iterated in a fixed order.

    Returns:
        ``(trainloader, testloader)``.
    """
    trainset = torchvision.datasets.ImageFolder(
        os.path.join(root_path, "train"), transform=train_transform
    )
    testset = torchvision.datasets.ImageFolder(
        os.path.join(root_path, "val"), transform=test_transform
    )
    trainloader = torch.utils.data.DataLoader(trainset, **kwargs)
    # Evaluation does not benefit from shuffling; a fixed order keeps it reproducible.
    testloader = torch.utils.data.DataLoader(testset, **{**kwargs, "shuffle": False})
    return trainloader, testloader


# Seed torch so batch order and the new classifier's initialization are reproducible.
torch.manual_seed(0)

trainloader, testloader = get_imagenette2_loaders(
    batch_size=64, shuffle=True, num_workers=2
)

# Human-readable class names, expected to follow ImageFolder's class-index order.
labels = [
    "tench",
    "English springer",
    "cassette player",
    "chain saw",
    "church",
    "French horn",
    "garbage truck",
    "gas pump",
    "golf ball",
    "parachute",
]
assert len(labels) == len(trainloader.dataset.classes), "labels do not match dataset classes"

In [None]:
from torchvision.utils import make_grid
import matplotlib.pyplot as plt

# Inverse of Normalize(means, stds): (x - (-m/s)) / (1/s) == x * s + m.
inv_normalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(means, stds)], std=[1 / s for s in stds]
)

# Show the first 16 images of one training batch.
x, _ = next(iter(trainloader))
# De-normalize each image *before* tiling so the grid padding stays black, and
# clamp to [0, 1] so matplotlib does not have to clip out-of-range values.
img_grid = make_grid(torch.stack([inv_normalize(im) for im in x[:16]])).clamp(0, 1)
plt.figure(figsize=(20, 15))
plt.imshow(img_grid.permute(1, 2, 0))
plt.axis("off")
plt.show()

Modèle 1 : VGG11 pré-entraîné

In [None]:
import torch.nn as nn
import torch

# NOTE: recent torchvision versions deprecate `pretrained=True` in favour of
# `weights=...`; it is kept here for compatibility with older versions.
model = torchvision.models.vgg11(pretrained=True)

# Freeze the convolutional trunk. Iterating over `model.features` directly
# yields the layer *modules*, so setting `requires_grad` on them froze nothing;
# the parameter tensors must be reached through `.parameters()`.
for param in model.features.parameters():
    param.requires_grad = False

# New classifier head with one output per Imagenette class (10, matching `labels`).
# 25088 is presumably 512 * 7 * 7, the flattened trunk output after VGG's
# adaptive average pooling — confirm if the input resolution changes.
model.classifier = nn.Sequential(
    nn.Linear(in_features=25088, out_features=4096, bias=True),
    nn.ReLU(inplace=True),
    nn.Linear(in_features=4096, out_features=4096, bias=True),
    nn.ReLU(inplace=True),
    nn.Linear(in_features=4096, out_features=10, bias=True),
)
if torch.cuda.is_available():
    model = model.cuda()

In [None]:
from tqdm import tqdm

# Cross-entropy averaged over the batch ("mean" is PyTorch's default reduction).
criterion_classifier = nn.CrossEntropyLoss()


def train(model, optimizer, trainloader, epochs=30):
    """Train ``model`` with ``criterion_classifier`` and report running accuracy.

    Args:
        model: classifier returning a (batch, classes) score tensor.
        optimizer: optimizer over the parameters to update.
        trainloader: iterable of ``(x, y)`` batches with integer class labels.
        epochs: number of passes over ``trainloader``.

    Returns:
        Training accuracy (fraction in [0, 1]) measured during the last epoch,
        or 0.0 if no sample was seen (``epochs == 0`` or an empty loader).
    """
    # The notebook switches the model to eval mode later; make sure a re-run trains.
    model.train()
    use_cuda = torch.cuda.is_available()
    corrects, total = 0, 0
    progress = tqdm(range(epochs))
    for epoch in progress:
        corrects, total = 0, 0
        for x, y in trainloader:
            if use_cuda:
                x, y = x.cuda(), y.cuda()
            y_hat = model(x)
            loss = criterion_classifier(y_hat, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accuracy of the predictions made *before* this optimizer step.
            corrects += y_hat.argmax(1).eq(y).sum().item()
            total += y.size(0)
            progress.set_description(
                f"epoch:{epoch} current accuracy:{round(corrects / total * 100, 2)}%"
            )
    return corrects / total if total else 0.0

In [None]:
learning_rate = 5e-3
epochs = 1
# Only the classifier head's parameters are handed to the optimizer.
optimizer = torch.optim.Adam(model.classifier.parameters(), lr=learning_rate)
# Re-running this cell after the evaluation cell must train in training mode.
model.train()
train_acc = train(model, optimizer, trainloader, epochs=epochs)
print(f"Train accuracy (last epoch): {train_acc:.2%}")

In [None]:
def test(model, dataloader):
    test_corrects = 0
    total = 0
    with torch.no_grad():
        for x, y in dataloader:
            if torch.cuda.is_available():
                x = x.cuda()
                y = y.cuda()
            y_hat = model(x).argmax(1)
            test_corrects += y_hat.eq(y).sum().item()
            total += y.size(0)
    return test_corrects / total


model.eval()
test_acc = test(model, testloader)
# test() returns a fraction in [0, 1]; `.2%` multiplies by 100 and appends "%".
print(f"Test accuracy: {test_acc:.2%}")

Méthode 1 : RISE (Randomized Input Sampling for Explanation)

In [None]:
import numpy as np
from skimage.transform import resize


def generate_masks(N, s, p1, image_size):
    """Generate ``N`` random RISE masks of size ``image_size``.

    Each mask starts as an ``s x s`` binary grid whose cells are kept with
    probability ``p1``. It is linearly upsampled (``order=1``) to
    ``(s + 1) * cell_size`` and then randomly shifted by cropping an
    ``image_size`` window at a random offset below one cell.

    Args:
        N: number of masks.
        s: grid size (cells per side).
        p1: probability that a grid cell is kept (value 1).
        image_size: ``(H, W)`` of the images to mask (ints or a torch.Size).

    Returns:
        float32 tensor of shape ``(N, 1, H, W)``.
    """
    image_size = tuple(int(d) for d in image_size)
    cell_size = np.ceil(np.array(image_size) / s).astype(int)
    up_size = (s + 1) * cell_size

    grid = (np.random.rand(N, s, s) < p1).astype("float32")

    # Allocate float32 up front: the default float64 buffer plus the later
    # conversion copy made the peak memory ~3x larger for large N.
    masks = np.empty((N, *image_size), dtype=np.float32)

    for i in range(N):
        # Random shifts (strictly less than one cell)
        x = np.random.randint(0, cell_size[0])
        y = np.random.randint(0, cell_size[1])
        # Linear interpolation, then crop the shifted window
        masks[i, :, :] = resize(
            grid[i], up_size, order=1, mode="reflect", anti_aliasing=False
        )[x : x + image_size[0], y : y + image_size[1]]

    masks = masks.reshape(-1, 1, *image_size)
    return torch.from_numpy(masks).float()

In [None]:
def explain(model, N, p1, img, masks, gpu_batch=2, apply_softmax=True):
    """Compute RISE saliency maps for every output class of ``model``.

    Each class map is the average of the masks weighted by the model's score on
    the correspondingly masked image, normalized by ``N * p1``.

    Args:
        model: classifier mapping a ``(B, C, H, W)`` batch to ``(B, classes)`` scores.
        N: number of masks to use (the first ``N`` of ``masks``).
        p1: keep probability that was used to generate ``masks``.
        img: single image ``(C, H, W)``, preprocessed the same way as the
            model's training inputs.
        masks: tensor of shape ``(>= N, 1, H, W)``, e.g. from ``generate_masks``.
        gpu_batch: number of masked images per forward pass (must be >= 1).
        apply_softmax: weight masks by class probabilities, as RISE specifies,
            rather than by raw logits (set False for the previous behaviour).

    Returns:
        CPU tensor of shape ``(classes, H, W)``.

    Raises:
        ValueError: if ``gpu_batch < 1``.
    """
    if gpu_batch < 1:
        raise ValueError(f"gpu_batch must be >= 1, got {gpu_batch}")

    img = img.unsqueeze(0).to("cpu")
    _, _, H, W = img.size()
    masks = masks[:N].to("cpu")

    # Run on whatever device the model lives on instead of hard-coding "cuda".
    param = next(model.parameters(), None)
    device = param.device if param is not None else torch.device("cpu")

    scores = []
    with torch.no_grad():
        for start in tqdm(range(0, N, gpu_batch)):
            # Mask one chunk at a time rather than materializing all N masked
            # copies of the image (N x C x H x W floats) up front.
            batch = masks[start : start + gpu_batch] * img
            out = model(batch.to(device))
            if apply_softmax:
                out = torch.softmax(out, dim=1)
            scores.append(out.to("cpu"))
    p = torch.cat(scores)

    CL = p.size(1)
    sal = torch.matmul(p.transpose(0, 1), masks.reshape(N, H * W))
    sal = sal.view((CL, H, W))
    return sal / N / p1

In [None]:
# Show one image of the batch and the VGG11 prediction that will be explained below.

idx = 7

# De-normalized copy, for display only.
img = inv_normalize(x[idx])
np_img = (np.transpose(img.cpu().numpy(), (1, 2, 0)).clip(0, 1) * 255).astype(np.uint8)
plt.imshow(np_img)
plt.axis("off")

# The model consumes the *normalized* tensor.
model_input = x[idx].unsqueeze(0)
if torch.cuda.is_available():
    model_input = model_input.cuda()
with torch.no_grad():
    output = model(model_input)
_, prediction = torch.topk(output, 1)
print(f"Model's prediction: {labels[prediction.item()]}")

In [None]:
# Spatial size (H, W) of the displayed image; this is the mask size passed to generate_masks.
print(img.shape[-2:])

In [None]:
# RISE hyper-parameters:
#   N  - how many random masks to sample
#   s  - masks start as an s x s grid before upsampling
#   p1 - probability that a grid cell is kept
N, s, p1 = 12000, 8, 0.1

# One mask per sample, at the image's (H, W) resolution.
masks = generate_masks(N, s, p1, img.shape[-2:])

In [None]:
# Explain the normalized tensor the model was trained on; `img` is the
# de-normalized copy kept for display. Batch masked images for speed.
saliency_maps = explain(model, N, p1, x[idx], masks, gpu_batch=64)
predicted_class = prediction.item()

plt.imshow(np_img)
# Overlay the importance map of the *predicted* class (not class 0).
plt.imshow(saliency_maps[predicted_class], cmap="turbo", alpha=0.5)
plt.colorbar()
plt.title(f"RISE saliency for '{labels[predicted_class]}'")
plt.axis("off")
plt.show()