In [None]:
!pip install albumentations==0.4.6

In [None]:
!pip install captum

In [None]:
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn

from PIL import Image

import pandas as pd

import os
import json
import numpy as np
from matplotlib.colors import LinearSegmentedColormap

import torchvision
from torchvision import models
from torchvision import transforms

from captum.attr import GradientShap
from captum.attr import visualization as viz

import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2

In [None]:
# Report whether a GPU is usable ("SI" = yes, "NO" = no) and pick the device
# every later cell moves the model and batches to.
_cuda_available = torch.cuda.is_available()
print("SI" if _cuda_available else "NO")
device = torch.device("cuda" if _cuda_available else "cpu")

In [None]:
# ResNet-34 with its final layer replaced by a 5-way classification head.
# Every parameter is then overwritten by the fine-tuned checkpoint
# (load_state_dict is strict by default), so the ImageNet weights are not
# needed and are not downloaded.
model = models.resnet34(pretrained=False)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 5)

# map_location lets a checkpoint saved during a GPU run load even when
# `device` fell back to CPU.
checkpoint = torch.load("/content/drive/MyDrive/run3/checkpoint.pth.tar", map_location=device)
model.load_state_dict(checkpoint['state_dict'])

# eval() switches BatchNorm layers to inference mode before prediction and
# attribution; move the complete model (including the new head) in one step.
model = model.eval().to(device)

In [None]:
class Diabetic_Retinopathy_Dataset(Dataset):
  """Image dataset indexed by the rows of the project's metadata CSV.

  Each CSV row supplies an image file name (``ID_Imatge``) and a retinopathy
  class (``Retinopatia``). The image is read from
  ``<image_root>/classe_<Retinopatia>/<ID_Imatge>`` and converted to RGB.

  Args:
    transform: albumentations-style callable, invoked as
      ``transform(image=array)["image"]``. When ``None``, the RGB numpy array
      is returned unchanged.
    csv_path: metadata CSV (``;``-separated, latin-1 encoded). Defaults to the
      Drive path used throughout this notebook.
    image_root: directory that contains the ``classe_<label>`` folders.

  ``__getitem__`` returns ``(image, label - 1)``. Presumably the CSV classes
  are 1-based and the model's 5 outputs are 0-based; TODO confirm.
  """

  def __init__(self, transform=None,
               csv_path="/content/drive/MyDrive/Projecte_RD/dades_imatges_act3.csv",
               image_root="/content/drive/MyDrive/Projecte_RD/dataset_processat"):
    self.image_data = pd.read_csv(csv_path, sep=';', encoding="latin-1")
    self.transform = transform
    self.image_root = image_root

  def __getitem__(self, i):
    ID_imatge, label = self.image_data.loc[i, ['ID_Imatge', 'Retinopatia']]
    path = f"{self.image_root}/classe_{label}/{ID_imatge}"
    # Read raw bytes and decode them, instead of using cv2.imread directly.
    # NOTE(review): presumably this copes with non-ASCII paths; confirm.
    image = cv2.imdecode(np.fromfile(path, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
    if image is None:
      # imdecode signals undecodable data by returning None; fail with the path
      # instead of an opaque cvtColor error.
      raise ValueError(f"could not decode image at {path!r}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    if self.transform is not None:
      image = self.transform(image=image)["image"]
    return image, label - 1

  def __len__(self):
    return len(self.image_data)

In [None]:
# Standard ImageNet per-channel statistics.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

# Model input pipeline: normalise the HWC array, then convert it to a CHW tensor.
transform = A.Compose([
    A.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
    ToTensorV2(),
])

# Display-only pipeline: plain tensor conversion with no normalisation.
transform_simple = transforms.Compose([transforms.ToTensor()])

In [None]:
# Index of the CSV row to inspect. Candidate rows: 6311, 6312, 6906, 48770, 64488.
i = 6311
image_data = pd.read_csv("/content/drive/MyDrive/Projecte_RD/dades_imatges_act3.csv", sep=';', encoding="latin-1")
ID_imatge, label = image_data.loc[i, ['ID_Imatge', 'Retinopatia']]
path = f"/content/drive/MyDrive/Projecte_RD/dataset_processat/classe_{label}/{ID_imatge}"
image = cv2.imdecode(np.fromfile(path, dtype=np.uint8), cv2.IMREAD_UNCHANGED)
if image is None:
    # imdecode returns None for unreadable data; report the path explicitly.
    raise ValueError(f"could not decode image at {path!r}")
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Un-normalised tensor of the same image; it is the background of the
# attribution plot in the GradientShap cell.
image_transform = transform_simple(image)

print("Class:", label)
# Must be the cell's last expression: Jupyter only renders the final value.
Image.fromarray(image)

In [None]:
# Restrict the normalised dataset to row `i`, so the loader yields exactly the
# image inspected above, as a batch of one.
full_dataset = Diabetic_Retinopathy_Dataset(transform=transform)
dataset = torch.utils.data.Subset(full_dataset, [i])
# shuffle=False: order is irrelevant for a single-item inference loader, and a
# shuffling sampler would also draw from the global torch RNG.
dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

In [None]:
# Predict the class of the selected image. `images` (the normalised batch on
# `device`) and `output` (the predicted class indices) are reused by the
# GradientShap cell.
with torch.no_grad():  # pure inference: no autograd graph needed
    for images, target in dataloader:
        images = images.to(device)
        logits = model(images)
        # Take the argmax of the raw logits directly. Exponentiating first
        # does not change the ranking, but it can overflow to inf and create
        # spurious ties.
        output = logits.argmax(dim=1)
        print(output)

In [None]:
# Grayscale colour map for the heat map, despite its 'custom blue' label.
# It is white at 0, reaches black at 0.25, and stays black up to 1, so the
# upper three quarters of the range all render black.
_cmap_stops = [
    (0, '#ffffff'),
    (0.25, '#000000'),
    (1, '#000000'),
]
default_cmap = LinearSegmentedColormap.from_list('custom blue', _cmap_stops, N=256)

In [None]:
# Seed both RNGs so the randomly sampled GradientShap explanation is repeatable.
torch.manual_seed(0)
np.random.seed(0)

gradient_shap = GradientShap(model)

# Baseline distribution with two entries: an all-zero tensor and a copy of the
# input itself. The batch is ImageNet-normalised, so zero corresponds to the
# mean colour rather than black.
# NOTE(review): a baseline equal to the input gives (input - baseline) = 0;
# confirm this weighting is intended.
rand_img_dist = torch.cat([torch.zeros_like(images), images.clone()])

# Explain the predicted class (`output`) for the selected image.
attributions_gs = gradient_shap.attribute(
    images,
    baselines=rand_img_dist,
    target=output,
    n_samples=50,
    stdevs=0.0001,
)

# Captum's image visualiser takes channel-last (H, W, C) numpy arrays.
attr_hwc = attributions_gs.squeeze().cpu().detach().numpy().transpose(1, 2, 0)
img_hwc = image_transform.squeeze().cpu().detach().numpy().transpose(1, 2, 0)

_ = viz.visualize_image_attr_multiple(
    attr_hwc,
    img_hwc,
    ["original_image", "heat_map"],
    ["all", "absolute_value"],
    cmap=default_cmap,
    show_colorbar=True,
)