### Import Packages and Dependencies

In [1]:
import os
import sys
from pathlib import Path

# Run from the parent of the notebook's directory so the `src.*` imports and the relative
# `data/...` paths used below resolve. Guarded so re-running this cell does not keep climbing
# the directory tree or append duplicate entries to sys.path.
if "_REPO_ROOT" not in globals():
    os.chdir(Path(os.getcwd()).parents[0])
    _REPO_ROOT = os.getcwd()
if _REPO_ROOT not in sys.path:
    sys.path.append(_REPO_ROOT)

from pytorch_lightning import LightningDataModule
from matplotlib.colors import LinearSegmentedColormap
import torch
import matplotlib.pyplot as plt
import seaborn as sns
import glob
import numpy as np
import ast
import fnmatch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from src.data.modelnet40_datamodule import ModelNet40DataModule
from src.data.shapenet_datamodule import ShapeNetDataModule

# Attribution archives ordered oldest -> newest by modification time; the newest is analysed.
attr_files = sorted(glob.glob("data/attribution_maps/Point_Cloud/*"), key=os.path.getmtime)
print(attr_files)
if not attr_files:
    raise FileNotFoundError("No attribution maps found in data/attribution_maps/Point_Cloud/")
file = attr_files[-1]  # always selects the newest

# `with` closes the underlying .npz file handle once the arrays have been read.
# Only arr_0 is loaded; presumably arr_1/arr_2 hold attributions for the other entries of
# `models` (data[model] is paired with models[model] in the plots) — TODO confirm.
with np.load(file) as array:
    data = [
        array["arr_0"],  # array["arr_1"], array["arr_2"]
    ]


# Choose the datamodule and the class-name list matching the dataset named in the attribution
# file name (e.g. "attr_modelnet40_dataset_..."). `classes[k]` is the display name for label k.
if fnmatch.fnmatch(file, "*modelnet40*"):
    datamodule = ModelNet40DataModule(data_dir="data/datasets/", batch_size=20)
    classes = [
        "airplane",
        "bathtub",
        "bed",
        "bench",
        "bookshelf",
        "bottle",
        "bowl",
        "car",
        "chair",
        "cone",
        "cup",
        "curtain",
        "desk",
        "door",
        "dresser",
        "flower_pot",
        "glass_box",
        "guitar",
        "keyboard",
        "lamp",
        "laptop",
        "mantel",
        "monitor",
        "night_stand",
        "person",
        "piano",
        "plant",
        "radio",
        "range_hood",
        "sink",
        "sofa",
        "stairs",
        "stool",
        "table",
        "tent",
        "toilet",
        "tv_stand",
        "vase",
        "wardrobe",
        "xbox",
    ]
elif fnmatch.fnmatch(file, "*shapenet*"):
    datamodule = ShapeNetDataModule(data_dir="data/datasets/", batch_size=20)
    classes = [
        "Airplane",
        "Bag",
        "Cap",
        "Car",
        "Chair",
        "Earphone",
        "Guitar",
        "Knife",
        "Lamp",
        "Laptop",
        "Motorbike",
        "Mug",
        "Pistol",
        "Rocket",
        "Skateboard",
        "Table",
    ]
else:
    # Fail here with a clear message instead of a NameError on `datamodule` further down.
    raise ValueError(
        f"Cannot infer dataset (expected 'modelnet40' or 'shapenet') from file name: {file}"
    )

# Draw a single batch (point clouds and their labels) from the selected datamodule.
dataloader = datamodule.dataloader()
with torch.no_grad():
    first_batch = next(iter(dataloader))
    x_batch, y_batch = first_batch


['data/attribution_maps/Point_Cloud/attr_modelnet40_dataset_14_methods_2023-05-23_10-08-03.npz', 'data/attribution_maps/Point_Cloud/attr_modelnet40_dataset_14_methods_2023-05-23_10-48-25.npz', 'data/attribution_maps/Point_Cloud/attr_modelnet40_dataset_14_methods_2023-05-23_10-56-00.npz', 'data/attribution_maps/Point_Cloud/attr_modelnet40_dataset_14_methods_2023-05-23_11-05-45.npz']


In [28]:
# Display names of the XAI methods. The plotting cells index data[model][n, i, ...] with `i`
# as a position in this list, so the order presumably mirrors axis 1 of the saved attributions.
methods = [
    "Occlusion",
    "LIME (Mask)",
    "Kernel SHAP (Mask)",
    "Saliency",
    "Input x Gradient",
    "Guided Backprop",
    "GradCAM",
    "ScoreCAM",
    "GradCAM++",
    "IG",
    "EG",
    "Deeplift",
    "Deeplift SHAP",
    "LRP",
    "Raw Attention",
    "Rollout Attention",
    "LRP Attention",
]
models = ["PointNet", "DGCNN", "Pointcloud Transformer"]
n = 0  # index of the sample within the batch to visualise
model = 0  # index into `models` and `data`
# Only the archives loaded above are available; fail early with a readable message otherwise.
assert model < len(data), f"attributions for {models[model]} (data[{model}]) are not loaded"
img = x_batch[n].detach().cpu().numpy()

# Final subplot title: the ground-truth class of the selected sample.
methods.append("Original Class: " + str(classes[int(y_batch[n])]).title())

In [29]:
# Raw (un-prettified) class name of the selected sample.
label_idx = int(y_batch[n])
str(classes[label_idx])

'night_stand'

In [40]:
import plotly.graph_objects as go

def NormalizeData(data, eps=1e-11):
    """Min-max scale ``data`` to [0, 1].

    Args:
        data: numeric numpy array.
        eps: added to the denominator so constant input yields zeros instead of
            dividing by zero. The default equals the previously hard-coded constant.

    Returns:
        ``(data - min) / ((max - min) + eps)``, same shape as ``data``.
    """
    data_min = np.min(data)
    return (data - data_min) / ((np.max(data) - data_min) + eps)

# Single-view preview: colour each point by the normalised |attribution| of methods[1].
point_colors = NormalizeData(np.abs(data[model][n, 1, :, :])).flatten()
fig = go.Figure(
    data=[
        go.Scatter3d(
            x=img[0],
            y=img[1],
            z=img[2],
            mode='markers',
            marker=dict(
                size=4,
                color=point_colors,
                colorscale='Viridis',
                opacity=0.9,
            ),
        )
    ]
)

# Hide ticks/axis titles, keep a cubic aspect and a fixed camera.
fig.update_scenes(
    xaxis_showticklabels=False,
    yaxis_showticklabels=False,
    zaxis_showticklabels=False,
    xaxis_title=" ",
    yaxis_title=" ",
    zaxis_title="",
    aspectmode="cube",
    camera=dict(eye=dict(x=1.5, y=1., z=0.5)),
)

# Pin every axis of the scene to [-1, 1].
fig.update_layout(
    scene={axis: dict(range=[-1, 1]) for axis in ("xaxis", "yaxis", "zaxis")}
)
fig.show()

In [57]:
import plotly.graph_objects as go
import numpy as np
from plotly.subplots import make_subplots

# Subplot titles: model 2 also plots the attention methods (methods[14:17]); the other models
# show only the first 14 methods followed by the "Original Class" entry (methods[17]).
if model == 2:
    titles = methods
else:
    titles = methods[0:14] + [methods[17]]

# NOTE(review): `cmap` is not referenced by the plotting code in this notebook.
cmap = [[0, 'white'], [0.5, 'red'], [1, 'red']]

def NormalizeData(data, eps=1e-11):
    """Min-max scale ``data`` to [0, 1].

    Args:
        data: numeric numpy array.
        eps: added to the denominator so constant input yields zeros instead of
            dividing by zero. The default equals the previously hard-coded constant.

    Returns:
        ``(data - min) / ((max - min) + eps)``, same shape as ``data``.
    """
    data_min = np.min(data)
    return (data - data_min) / ((np.max(data) - data_min) + eps)

# Shared colour-bar styling for every attribution trace in the grid figure.
colorbar = dict(
    tickfont=dict(family="Helvetica", size=18),
    outlinewidth=0,
    thickness=20,
    len=0.8,
)

# Grid of 3-D scenes, 7 per row: two full rows, then a partial third row holding the
# remaining methods plus the original point cloud (4 cells for model 2, otherwise 1).
n_last_row_scenes = 4 if model == 2 else 1
subplot_specs = [[{"type": "scene"} for _ in range(7)] for _ in range(2)]
subplot_specs.append(
    [{"type": "scene"} if col < n_last_row_scenes else None for col in range(7)]
)

fig = make_subplots(
    rows=3,
    cols=7,
    specs=subplot_specs,
    subplot_titles=titles,
    vertical_spacing=0.05,
)

def _attribution_trace(method_idx):
    """Point-cloud scatter coloured by the normalised |attribution| of one XAI method."""
    return go.Scatter3d(
        x=img[0],
        y=img[1],
        z=img[2],
        mode='markers',
        showlegend=False,
        marker=dict(
            size=4,
            color=NormalizeData(np.abs(data[model][n, method_idx, :, :])).flatten(),
            colorscale='Viridis',   # choose a colorscale
            opacity=0.9,
            colorbar=colorbar,
        ),
    )


# Methods fill the 7-wide grid row by row; the attention methods (indices 14-16) are only
# plotted for model 2.
n_plotted_methods = 17 if model == 2 else 14
for method_idx in range(n_plotted_methods):
    grid_row, grid_col = divmod(method_idx, 7)
    fig.add_trace(_attribution_trace(method_idx), row=grid_row + 1, col=grid_col + 1)

# The plain, uniformly coloured point cloud goes in the first cell after the last method.
orig_row, orig_col = divmod(n_plotted_methods, 7)
fig.add_trace(
    go.Scatter3d(
        x=img[0],
        y=img[1],
        z=img[2],
        mode='markers',
        showlegend=False,
        marker=dict(size=4, opacity=0.9, color='#8c564b'),
    ),
    row=orig_row + 1,
    col=orig_col + 1,
)

# Shared styling for every scene: no tick labels/axis titles, cubic aspect, fixed camera,
# and every axis pinned to [-1, 1].
fig.update_scenes(
    xaxis_showticklabels=False,
    yaxis_showticklabels=False,
    zaxis_showticklabels=False,
    xaxis_title=" ",
    yaxis_title=" ",
    zaxis_title="",
    aspectmode="cube",
    camera=dict(eye=dict(x=1.5, y=1., z=0.5)),
    xaxis=dict(range=[-1, 1]),
    yaxis=dict(range=[-1, 1]),
    zaxis=dict(range=[-1, 1]),
)

# make_subplots stores subplot titles as annotations, so this sets the title font.
fig.update_annotations(font=dict(family="Helvetica", size=22))

fig.update_layout(
    title=dict(
        text="<b>3D Attribution and Attention for " + models[model] + " Model</b>",
        font=dict(family="Helvetica", size=28),
        x=0.03,
    ),
    height=1200,
    width=2500,
    font=dict(
        family="Helvetica",
        color="#000000",
    ),
)

# Ensure the output folder exists (e.g. on a fresh checkout) before exporting the PNG.
os.makedirs("data/figures", exist_ok=True)
fig.write_image("data/figures/3DPC_" + str(model) + "_Importance.png", scale=2)
# fig.show()  # exported to PNG above instead of rendering inline


In [33]:
# Starting camera eye position for the rotating GIF (rotated about the z-axis per frame).
x_eye, y_eye, z_eye = -1.25, 2, 0.5


def rotate_z(x, y, z, theta):
    """Rotate the point (x, y, z) by ``theta`` radians about the z-axis.

    (x, y) is treated as the complex number x + iy and multiplied by e^(i*theta);
    z is returned unchanged.
    """
    rotated = np.exp(1j * theta) * (x + 1j * y)
    return np.real(rotated), np.imag(rotated), z


# Render one PNG per camera angle: the eye is rotated about the z-axis in 0.2 rad steps
# (0 to 6.2 rad, just under a full turn). Frames are named by zero-padded step index so that
# sorting the file names yields playback order.
gif_dir = "data/figures/gif"
os.makedirs(gif_dir, exist_ok=True)
for frame_idx, t in enumerate(np.arange(0, 6.26, 0.2)):
    xe, ye, ze = rotate_z(x_eye, y_eye, z_eye, -t)
    fig.update_scenes(camera_eye=dict(x=xe, y=ye, z=ze))
    fig.write_image(f"{gif_dir}/frame_{frame_idx:03d}_.png", scale=1)


In [34]:
from PIL import Image
# Assemble the rendered frames into a looping GIF.
# Frames are ordered by file name (their names encode the rotation step in increasing order);
# modification times can tie on coarse-resolution filesystems and scramble the order.
# Distinct names are used so the point-cloud array `img` is not overwritten by a PIL image.
frame_paths = sorted(glob.glob("data/figures/gif/frame_*"))
if not frame_paths:
    raise FileNotFoundError("No frames found in data/figures/gif/ - run the frame-rendering cell first")
frames = [Image.open(path) for path in frame_paths]
try:
    first_frame, *other_frames = frames
    first_frame.save(
        fp="data/figures/gif/3D_" + str(model) + "_.gif",
        format="GIF",
        append_images=other_frames,
        save_all=True,
        duration=120,
        loop=0,
    )
finally:
    # Release the file handles Image.open keeps for lazy loading.
    for frame in frames:
        frame.close()

In [35]:
# Delete the intermediate PNG frames now that the GIF has been written.
for frame_path in glob.glob('data/figures/gif/frame_*'):
    os.remove(frame_path)

In [None]:
# Intended: 8 panels, each showing slice `idx` of the input overlaid with |attribution| of method `xai`.
# NOTE(review): this cell appears to be left over from a volumetric (3-D image) workflow and is
# likely to fail here. It indexes data[model] with six indices and `img` as a 4-D volume, whereas
# the point-cloud cells above use data[model][n, i, :, :] and img[0], img[1], img[2] as x/y/z
# coordinates. Confirm intent before running.
fig, axes = plt.subplots(1, 8, figsize=(20, 7), sharey=True, sharex = True)


xai = 16  # position in `methods` (16 -> "LRP Attention")

for i in range(8):
    idx = np.min([(i*4), 27])  # slice index 0, 4, ..., 24, then capped at 27
    mask = np.abs(data[model][n,xai,:,:,:,idx]) # obs , XAI, c, w, h, z
    axes[i].imshow(img[:,:,idx,:],cmap='gray')
    sns.heatmap(ax =  axes[i], data = mask[0,:,:],cbar=False, cmap="viridis", alpha=0.5)
    axes[i].axis('off')
    axes[i].axes.set_title(str((i*4)),fontsize=8)  # NOTE(review): shows 28 for the last panel although idx is capped at 27

# Overwrites the first panel's title with the class / method / model summary.
axes[0].axes.set_title("Class: " + str(classes[int(y_batch[n])]+ "\nXAI: " + methods[xai] + "\nModel: " + models[model]),fontsize=8)