In [1]:
import sys
import torch
sys.path.append('../')
from experiment import OCTClassification
from nntools.utils import Config
sys.path.append('../Transformer-Explainability')
from baselines.ViT.ViT_LRP import VisionTransformer
import matplotlib.pyplot as plt
import cv2
import numpy as np
import os


In [2]:
# NOTE(review): absolute, machine-specific checkpoint path -- consider moving it to the config.
trained_savedict = '/home/clement/Documents/Clement/runs/mlruns/1/505896c1301b4f6790efd15dca57b8f9/artifacts/iteration_25000_mIoU_0920.pth'

# Load on CPU: load_state_dict copies the tensors into the freshly built (CPU) model
# before model.cuda() is called, so there is no need to materialise them on the GPU here.
# NOTE: torch.load unpickles the file -- only load checkpoints you trust.
state_dict = torch.load(trained_savedict, map_location='cpu')
state_dict = state_dict['model_state_dict']

# Checkpoint keys carry a `network.` prefix (presumably the training wrapper's attribute
# name). Drop everything up to and including the FIRST `network.` so keys match
# VisionTransformer's own parameter names; a plain split() would also cut at any later
# occurrence (e.g. inside `subnetwork.`) and silently corrupt the key.
reformated_state_dict = {}
for k, tensor in state_dict.items():
    _, sep, stripped = k.partition('network.')
    if not sep:
        raise ValueError("unexpected checkpoint key without 'network.' prefix: {!r}".format(k))
    reformated_state_dict[stripped] = tensor

In [3]:
from timm.models.vision_transformer import vit_base_patch32_384

In [4]:
from baselines.ViT.ViT_explanation_generator import LRP

In [5]:
# ViT-Base with 32 px patches at 384 px input and a 4-class head. These hyper-parameters
# must match the checkpoint exactly because the weights are loaded with strict=True.
model = VisionTransformer(
    img_size=384,
    patch_size=32,
    embed_dim=768,
    depth=12,
    num_heads=12,
    qkv_bias=True,
    num_classes=4,
)
model.load_state_dict(reformated_state_dict, strict=True)
# nn.Module.cuda() and .eval() both return the module itself, so they can be chained.
model = model.cuda().eval()
# LRP wrapper used by generate_visualization() to compute relevance maps.
attribution_generator = LRP(model)


In [15]:
# Build the experiment from the project config and take its test split.
config_path = '../config.yaml'
config = Config(config_path)
experiment = OCTClassification(config)
dataset = experiment.test_dataset

# map_class maps class name -> integer class; invert it so integer predictions and
# ground truths can be rendered as readable names in plot titles.
labels = {class_idx: class_name for class_name, class_idx in dataset.map_class.items()}

def show_cam_on_image(img, mask, alpha):
    """Blend a JET-colormapped relevance mask over an image.

    Parameters
    ----------
    img : np.ndarray
        Image of shape (H, W, 3); expected in [0, 1] (generate_visualization
        min-max scales it before calling this).
    mask : np.ndarray
        2-D relevance map with the same spatial size as ``img``, expected in [0, 1].
        Values outside that range are clipped and NaNs are treated as 0.
    alpha : float
        Weight of the heatmap; the image is weighted by ``1 - alpha``.

    Returns
    -------
    np.ndarray
        float32 overlay rescaled so its maximum is 1 (left unscaled if it is all zeros).
        The heatmap component comes from ``cv2.applyColorMap`` and is therefore in
        OpenCV BGR channel order; ``img`` is blended in whatever order it was given.
    """
    # Out-of-range floats do not saturate when cast to uint8, and a constant
    # relevance map upstream becomes NaN after min-max scaling -- sanitise first.
    safe_mask = np.clip(np.nan_to_num(mask), 0.0, 1.0)
    heatmap = cv2.applyColorMap(np.uint8(255 * safe_mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap * alpha + (1 - alpha) * np.float32(img)
    peak = np.max(cam)
    if peak > 0:  # avoid 0/0 -> NaN for an all-black overlay (e.g. alpha=0, black image)
        cam = cam / peak
    return cam


def _min_max_scale(arr):
    """Rescale ``arr`` linearly to [0, 1]; a constant array maps to zeros instead of NaN."""
    lo = arr.min()
    span = arr.max() - lo
    if span == 0:
        return np.zeros_like(arr)
    return (arr - lo) / span


def generate_visualization(original_image, class_index=None, alpha=0.5):
    """Compute an LRP relevance map for one image and overlay it on that image.

    Parameters
    ----------
    original_image : torch.Tensor
        A batch holding a single image, (1, C, H, W), on the model's device.
    class_index : optional
        Class to explain; forwarded as ``index`` to
        ``attribution_generator.generate_LRP`` (``None`` lets it choose).
    alpha : float
        Heatmap weight passed to ``show_cam_on_image``.

    Returns
    -------
    tuple of np.ndarray
        ``(image, vis)``: the input as an (H, W, C) float array min-max scaled to
        [0, 1], and the uint8 overlay. The heatmap component of ``vis`` is in OpenCV
        BGR order -- convert with ``cv2.cvtColor(vis, cv2.COLOR_BGR2RGB)`` before
        displaying it with matplotlib; leave it as is for ``cv2.imwrite``.

    Raises
    ------
    ValueError
        If the number of patch relevance scores is not a perfect square.
    """
    relevance = attribution_generator.generate_LRP(
        original_image, method="transformer_attribution", index=class_index
    ).detach()

    # One relevance score per patch on a square grid (12 x 12 for a 384 px input with
    # 32 px patches). Derive the side length instead of hard-coding it.
    n_patches = relevance.numel()
    side = int(round(n_patches ** 0.5))
    if side * side != n_patches:
        raise ValueError(
            "expected a square patch grid, got {} relevance scores".format(n_patches))
    relevance = relevance.reshape(1, 1, side, side)

    # Upsample to the input's spatial size (for 12 -> 384 this equals scale_factor=32).
    relevance = torch.nn.functional.interpolate(
        relevance, size=tuple(original_image.shape[-2:]), mode='bilinear')
    relevance = _min_max_scale(np.squeeze(relevance.cpu().numpy()))

    image = original_image.squeeze(0).permute(1, 2, 0).detach().cpu().numpy()
    image = _min_max_scale(image)

    vis = show_cam_on_image(image, relevance, alpha)
    vis = np.uint8(255 * vis)
    return image, vis

Traceback (most recent call last):
  File "/home/clement/anaconda3/envs/py37/lib/python3.7/site-packages/mlflow/store/tracking/file_store.py", line 237, in list_experiments
    experiment = self._get_experiment(exp_id, view_type)
  File "/home/clement/anaconda3/envs/py37/lib/python3.7/site-packages/mlflow/store/tracking/file_store.py", line 311, in _get_experiment
    meta = read_yaml(experiment_dir, FileStore.META_DATA_FILE_NAME)
  File "/home/clement/anaconda3/envs/py37/lib/python3.7/site-packages/mlflow/utils/file_utils.py", line 170, in read_yaml
    raise MissingConfigException("Yaml file '%s' does not exist." % file_path)
mlflow.exceptions.MissingConfigException: Yaml file '/home/clement/Documents/Clement/runs/OCT-Classification/meta.yaml' does not exist.
Traceback (most recent call last):
  File "/home/clement/anaconda3/envs/py37/lib/python3.7/site-packages/mlflow/store/tracking/file_store.py", line 237, in list_experiments
    experiment = self._get_experiment(exp_id, view_type

In [32]:
# Accumulators filled by plot() / explain_map() and read by the accuracy cell.
# Both functions append, so re-run this cell before re-running either of them.
preds = []
gts = []

# Visit samples in a random but reproducible order. N_SAMPLES is the previous
# hard-coded 1000, capped at len(dataset) so indices stay within the dataset.
N_SAMPLES = 1000
SEED = 0
index = np.random.RandomState(SEED).permutation(min(N_SAMPLES, len(dataset)))

output_vit = 'Explained-ViT-LRP/'
output_img = 'KermanyTestSet/'
os.makedirs(output_vit, exist_ok=True)
os.makedirs(output_img, exist_ok=True)
    
def _blank_axes(size=9):
    """Create a square ``size`` x ``size`` inch figure with one axis without grid or ticks."""
    fig, ax = plt.subplots()
    fig.set_size_inches(size, size)
    ax.grid(False)
    ax.set_xticks([])
    ax.set_yticks([])
    return fig, ax


def plot(max_samples=1):
    """Display the LRP explanation for the first samples of ``index``.

    For each sample:
      * the input image, titled with its ground-truth label, is saved to
        ``output_img/<filename>`` (not displayed);
      * the explanation overlay, titled with the two most probable classes and
        their probabilities, is shown inline.
    Each prediction and ground truth is appended to the global ``preds`` / ``gts``.

    Parameters
    ----------
    max_samples : int or None
        Number of entries of ``index`` to process. Defaults to 1, matching the
        previous hard-coded ``break`` after the first sample; ``None`` processes all.
    """
    for i in index[:max_samples]:
        img, gt = dataset[i]
        filename = dataset.filename(i)
        img = img.unsqueeze(0).cuda()
        prob = torch.softmax(model(img), dim=1)
        pred = torch.argmax(prob, dim=1)
        ranking = torch.argsort(prob, dim=1, descending=True)
        preds.append(pred)
        gts.append(gt)
        inp, out = generate_visualization(img, pred)

        # Input image with its ground truth: written to disk only.
        fig, ax = _blank_axes()
        ax.imshow(inp)
        ax.set_title('Groundtruth: ' + labels[int(gt)])
        fig.tight_layout()
        fig.savefig(os.path.join(output_img, filename))
        plt.close(fig)

        # Top-2 predicted classes with their probabilities.
        title = ''.join(
            ' ' + labels[int(c)] + ', Pr: {:.2%}'.format(float(prob[0][c]))
            for c in ranking.squeeze(0)[:2]
        )

        fig, ax = _blank_axes()
        ax.set_title('Predicted: ' + title)
        # The heatmap component of `out` comes from cv2.applyColorMap (BGR order);
        # matplotlib expects RGB, so swap channels or high relevance is drawn blue.
        ax.imshow(cv2.cvtColor(out, cv2.COLOR_BGR2RGB))
        fig.tight_layout()
        plt.show()
        plt.close(fig)  # release the figure; matters when max_samples > 1

def explain_map():
    """Write a full-resolution LRP heatmap for every sample in ``index``.

    For each sample the top-1 predicted class is explained with ``alpha=1`` (pure
    heatmap, no image blended in). The map is resized to the resolution of the
    original image file and written to ``output_vit/<filename>``; the output of
    generate_visualization is passed to cv2.imwrite without channel conversion.
    Each prediction and ground truth is appended to the global ``preds`` / ``gts``.

    Raises
    ------
    FileNotFoundError
        If the original image file cannot be read (``cv2.imread`` returned None).
    OSError
        If ``cv2.imwrite`` reports failure.
    """
    for i in index:
        img, gt = dataset[i]
        filename = dataset.filename(i)
        img = img.unsqueeze(0).cuda()
        prob = torch.softmax(model(img), dim=1)
        pred = torch.argmax(prob, dim=1)
        preds.append(pred)
        gts.append(gt)
        _, heatmap = generate_visualization(img, pred, alpha=1)

        # The original file is read only to recover its resolution.
        filepath = dataset.img_filepath[i]
        original = cv2.imread(filepath)
        if original is None:
            raise FileNotFoundError('could not read image {!r}'.format(filepath))
        height, width = original.shape[:2]
        # cv2.resize expects dsize as (width, height).
        heatmap = cv2.resize(heatmap, dsize=(width, height))

        out_path = os.path.join(output_vit, filename)
        if not cv2.imwrite(out_path, heatmap):
            raise OSError('could not write heatmap to {!r}'.format(out_path))

        
# Export a heatmap for every sample in `index`; this also fills `preds` / `gts`,
# which the accuracy cell below reads.
explain_map()

        


    


In [12]:
# Accuracy over every prediction collected in `preds` / `gts`.
# int() accepts both the ground-truth labels and the single-element prediction tensors.
gts = np.asarray([int(label) for label in gts])
preds = np.asarray([int(label) for label in preds])
if len(gts) == 0 or len(gts) != len(preds):
    raise ValueError(
        'need the same, non-zero number of predictions and ground truths '
        '(got {} and {}); run explain_map() first'.format(len(preds), len(gts)))
# Divide by the number of samples actually evaluated rather than a hard-coded 1000:
# plot() appends too, and a partial run would otherwise be mis-scaled.
print(np.mean(gts == preds))

0.001
