In [None]:
import sys
import torch
sys.path.append('../')
from experiment import OCTClassification
sys.path.append('../Transformer-Explainability')
from baselines.ViT.ViT_LRP import VisionTransformer
import matplotlib.pyplot as plt

In [None]:
# Fine-tuned checkpoint whose predictions are explained below.
# NOTE(review): absolute, machine-specific path -- make it configurable before sharing.
trained_savedict = '/home/clement/Documents/Clement/runs/mlruns/1/8cde2d9e96fb475681128bce91f90c07/artifacts/iteration_100000_mIoU_0901.pth'

# Load onto the CPU so the checkpoint opens regardless of the device it was saved
# from; the model itself is moved to the GPU after the weights are loaded.
state_dict = torch.load(trained_savedict, map_location='cpu')
state_dict = state_dict['model_state_dict']

# Checkpoint keys carry a `network.` segment (presumably the training wrapper's
# attribute name); everything after its FIRST occurrence is the bare ViT key.
NETWORK_PREFIX = 'network.'
keys_without_prefix = [k for k in state_dict if NETWORK_PREFIX not in k]
if keys_without_prefix:
    raise KeyError(
        f'{len(keys_without_prefix)} checkpoint key(s) lack the {NETWORK_PREFIX!r} '
        f'segment, e.g. {keys_without_prefix[:3]}'
    )
reformated_state_dict = {
    k.split(NETWORK_PREFIX, 1)[1]: v for k, v in state_dict.items()
}

In [None]:
from timm.models.vision_transformer import vit_base_patch32_384

In [None]:
from baselines.ViT.ViT_explanation_generator import LRP

In [None]:
# ViT-Base sized backbone (12 layers, 12 heads, 768-d embeddings) on 384 px inputs
# split into 32 px patches, with a 4-class head.
model = VisionTransformer(
    img_size=384,
    patch_size=32,
    embed_dim=768,
    depth=12,
    num_heads=12,
    num_classes=4,
    qkv_bias=True,
)
# strict=True: checkpoint keys and model parameters must match one-to-one.
model.load_state_dict(reformated_state_dict, strict=True)
# Inference on the GPU; .cuda() and .eval() both return the module itself.
model = model.cuda().eval()
attribution_generator = LRP(model)
import cv2
import numpy as np
def show_cam_on_image(img, mask):
    """Blend a JET-coloured rendering of `mask` over `img`, rescaled so its peak is 1.

    Args:
        img: image array blended in with weight 0.75
            (assumed H x W x 3 with values in [0, 1] -- TODO confirm).
        mask: attribution map; it is multiplied by 255 and cast to uint8 before
            colouring, so values are expected to lie in [0, 1].

    Returns:
        float32 array ``0.25 * colour_map / 255 + 0.75 * img`` divided by its maximum.
    """
    # NOTE(review): OpenCV colour maps come out in BGR order; the caller converts
    # the final overlay with COLOR_RGB2BGR, which swaps it back for matplotlib.
    colour_map = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    scaled_colour_map = np.float32(colour_map) / 255
    overlay = scaled_colour_map * 0.25 + 0.75 * np.float32(img)
    return overlay / overlay.max()

In [None]:
from nntools.utils import load_yaml

# Rebuild the training experiment from the project config and take its test
# split (experiment.test_dataset) as the images to explain.
config_path = '../config.yaml'
config = load_yaml(config_path)
experiment = OCTClassification(config)
dataset = experiment.test_dataset


def _min_max_normalize(array):
    """Linearly rescale `array` to [0, 1].

    A constant array (max == min) maps to all zeros instead of the NaNs that a
    plain ``(x - min) / (max - min)`` would produce.
    """
    low = array.min()
    span = array.max() - low
    if span == 0:
        return np.zeros_like(array, dtype=np.result_type(array, np.float32))
    return (array - low) / span


def generate_visualization(original_image, class_index=None):
    """Compute an LRP "transformer attribution" map and overlay it on the input.

    Args:
        original_image: a batch holding one channels-first image, (1, C, H, W),
            already on the model's device.
        class_index: class to explain, forwarded to ``generate_LRP`` as ``index``;
            ``None`` leaves the choice to ``generate_LRP``.

    Returns:
        (image, overlay): the input as an (H, W, C) array min-max scaled to [0, 1],
        and the uint8 overlay built by ``show_cam_on_image``.

    Raises:
        ValueError: if the attribution cannot be laid out as a square patch grid.
    """
    attribution = attribution_generator.generate_LRP(
        original_image, method="transformer_attribution", index=class_index
    ).detach()

    # One relevance value per patch token: recover the square patch grid
    # (12 x 12 for a 384 px input with 32 px patches).
    n_patches = attribution.numel()
    grid_side = int(round(n_patches ** 0.5))
    if grid_side * grid_side != n_patches:
        raise ValueError(
            f'Attribution has {n_patches} entries, which is not a square patch grid'
        )
    attribution = attribution.reshape(1, 1, grid_side, grid_side)

    # Upsample the patch grid to the input's spatial size.
    attribution = torch.nn.functional.interpolate(
        attribution, size=tuple(original_image.shape[-2:]), mode='bilinear'
    )
    mask = np.squeeze(_min_max_normalize(attribution.cpu().numpy()))

    image = original_image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    image = _min_max_normalize(image)

    vis = show_cam_on_image(image, mask)
    vis = np.uint8(255 * vis)
    vis = cv2.cvtColor(np.array(vis), cv2.COLOR_RGB2BGR)
    return image, vis

In [None]:
import os

# Explain N_EXPLAINED test images in a seeded random order. For each image, save a
# side-by-side figure (input + ground truth | LRP overlay + top-2 predictions) under
# `output`, and record the predicted / true class ids.
N_EXPLAINED = 1000  # dataset indices 0..N_EXPLAINED-1 are visited
SEED = 0  # makes the visiting (and figure display) order reproducible across runs
output = 'ViT-LRP-Explained_Kermany/'

preds = []
gts = []
index = np.random.default_rng(SEED).permutation(N_EXPLAINED)

# Invert dataset.map_class so a class id can be shown by its name.
labels = {v: k for k, v in dataset.map_class.items()}

os.makedirs(output, exist_ok=True)

for i in index:
    img, gt = dataset[i]
    filename = dataset.filename(i)
    img = img.unsqueeze(0).cuda()
    # NOTE(review): kept outside torch.no_grad() -- the LRP ViT appears to register
    # gradient hooks in its forward pass; verify before disabling autograd here.
    logits = model(img)
    prob = torch.softmax(logits, dim=1)
    pred = int(torch.argmax(prob, dim=1))
    argsort = torch.argsort(prob, dim=1, descending=True)
    # Store plain ints rather than GPU tensors so the lists hold no device memory.
    preds.append(pred)
    gts.append(int(gt))

    inp, out = generate_visualization(img, pred)

    fig, axs = plt.subplots(1, 2, figsize=(17, 9))
    axs[0].imshow(inp)
    axs[0].set_title('Groundtruth: ' + labels[int(gt)])
    top2 = ' '.join(
        labels[int(arg)] + ', Pr: {:.2%}'.format(float(prob[0][arg]))
        for arg in argsort.squeeze(0)[:2]
    )
    axs[1].set_title('Predicted: ' + top2)
    axs[1].imshow(out)

    for ax in axs:
        ax.grid(False)     # hide grid lines
        ax.set_xticks([])  # hide axis ticks
        ax.set_yticks([])
    plt.tight_layout()
    fig.savefig(os.path.join(output, filename))
    plt.show()
    # Release the figure explicitly; non-inline backends would otherwise keep
    # every one of the N_EXPLAINED figures alive.
    plt.close(fig)

In [None]:
# Ground-truth class ids as an integer array (int() also accepts 1-element tensors).
gts = np.array([int(label) for label in gts])

In [None]:
# Predicted class ids as an integer array (int() also accepts 1-element tensors).
preds = np.array([int(predicted) for predicted in preds])

In [None]:
# Top-1 accuracy over the explained images: the denominator is the number of
# recorded predictions rather than a hard-coded 1000.
assert len(gts) == len(preds), 'gts and preds must be aligned one-to-one'
print(np.mean(gts == preds))