In [1]:
# Imports
import json
import os
import sys

import matplotlib.pyplot as plt
import plotly.express as px
import torch
import torch.nn.functional as F
from dash import Dash, Input, Output, State, callback, dcc, html
from dash.dependencies import Input, Output

# Project-local modules live outside this notebook's directory
sys.path.append("../../models/")
from CustomCNNVessel import CustomResNet
sys.path.append("../../data/")
from VessMapDatasetLoader import vess_map_dataloader

# Everything below (model, images, gradients) runs on the CPU.
device = torch.device("cpu")
# Release any CUDA memory cached by earlier work in this kernel.
torch.cuda.empty_cache()

In [2]:
# Dataloaders
# Root of the VessMap dataset. Set the VESSMAP_DATA_DIR environment variable to
# point elsewhere instead of editing this machine-specific absolute default.
data_dir = os.environ.get(
    "VESSMAP_DATA_DIR",
    '/home/fonta42/Desktop/interpretacao-redes-neurais/data/VessMap',
)
image_dir = os.path.join(data_dir, 'images')
mask_dir = os.path.join(data_dir, 'labels')
skeleton_dir = os.path.join(data_dir, 'skeletons')

batch_size = 10
train_size = 0.8  # fraction of samples assigned to the training split

train_loader, test_loader = vess_map_dataloader(image_dir,
                                                mask_dir,
                                                skeleton_dir,
                                                batch_size,
                                                train_size=train_size)

In [3]:
# Concatenating images: flatten every batch from both splits (train first,
# then test) into flat per-sample lists.
all_images, all_masks, all_skeletons = [], [], []

for loader in (train_loader, test_loader):
    for images, masks, skeletons in loader:
        # extend() iterates over dim 0, appending one tensor per sample.
        all_images.extend(images.to(device))
        all_masks.extend(masks.to(device))
        all_skeletons.extend(skeletons.to(device))

In [4]:
# Models
model = CustomResNet(num_classes=2).to(device)
# Load the trained weights. map_location remaps any CUDA tensors stored in the
# checkpoint onto `device` (CPU here), so the file loads without a GPU.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
state_dict = torch.load("../../models/vess_map_regularized_none_200.pth", map_location=device)
model.load_state_dict(state_dict)
# Inference mode for layers whose behaviour differs between train and eval.
model = model.eval()

In [5]:
import plotly.express as px
import numpy as np

def plot_gradients_with_bounding_box(gradient, model_name, threshold=0.01, colorscale=None):
    """Plot the crop of a gradient map that contains every significant pixel.

    A pixel is significant when ``|gradient| > threshold``. The figure shows the
    tightest axis-aligned bounding box around those pixels. It is annotated
    with the number of significant pixels, the box area, their ratio
    ("fulfillment"), and the threshold used.

    Args:
        gradient: 2-D array-like gradient map; singleton dimensions are squeezed.
        model_name: Label inserted into the figure title.
        threshold: Absolute cut-off applied to ``|gradient|``.
        colorscale: Plotly colour scale. Defaults to the notebook-level
            ``diverging_colorscale``, so existing calls behave as before.

    Returns:
        A Plotly figure. When no pixel passes the threshold, an all-zero map
        with the same shape as ``gradient`` is returned instead.
    """
    if colorscale is None:
        # Resolved at call time: diverging_colorscale is defined in a later cell.
        colorscale = diverging_colorscale

    gradient = np.asarray(gradient).squeeze()
    magnitude = np.abs(gradient)
    mask = magnitude > threshold
    rows, cols = np.nonzero(mask)

    if rows.size == 0:
        # Nothing to crop: show an empty map of the input's own shape.
        return px.imshow(np.zeros_like(gradient), color_continuous_scale=colorscale)

    y_min, y_max = rows.min(), rows.max()
    x_min, x_max = cols.min(), cols.max()
    num_pixels_above_threshold = np.sum(mask)
    bounding_box_area = (y_max - y_min + 1) * (x_max - x_min + 1)
    fulfillment = num_pixels_above_threshold / bounding_box_area

    # Symmetric colour range so zero maps to the midpoint of the diverging scale.
    max_abs = magnitude.max()
    fig = px.imshow(
        gradient[y_min:y_max + 1, x_min:x_max + 1],
        title=f'Gradient Analysis for {model_name}',
        labels={'x': 'x-axis', 'y': 'y-axis'},
        color_continuous_scale=colorscale,
        range_color=[-max_abs, max_abs],
    )

    fig.update_layout(
        annotations=[{
            'text': f"Pixels: {num_pixels_above_threshold}<br>"
                    f"Area: {bounding_box_area}<br>"
                    f"Fulfillment: {fulfillment:.2f}<br>"
                    f"Threshold: {threshold:.4f}<br>",
            'showarrow': False,
            'xref': 'paper',
            'yref': 'paper',
            'x': 0, 'y': 1,
            'xanchor': 'left', 'yanchor': 'top',
            'font': {'size': 12, 'color': 'black'},
            'bgcolor': 'white',
            'opacity': 0.7
        }],
        xaxis={'visible': False},
        yaxis={'visible': False},
    )

    return fig



In [13]:
from dash import Dash, dcc, html, Input, Output
import plotly.express as px
import json
import numpy as np
import plotly.graph_objects as go

# The first collected sample (train split first) is the image explored in the app.
# NOTE(review): if train_loader shuffles, this is a different image on every
# run — confirm, or select the sample explicitly.
image_data = all_images[0].squeeze()

# Grayscale colour map for the original image
gray_colorscale = [
    [0, 'rgb(0, 0, 0)'],
    [1, 'rgb(255, 255, 255)']
]

# Diverging colour map for the gradient plots
diverging_colorscale = [
    [0, 'rgb(255, 0, 0)'],      # Red for the most negative values
    [0.5, 'rgb(255, 255, 0)'],  # Yellow at the midpoint (zero when the colour range is symmetric)
    [1, 'rgb(0, 128, 0)']       # Green for the most positive values
]

app = Dash(__name__)

# Dash passes `style` dicts to React, which expects camelCased CSS property names.
app.layout = html.Div(
    style={
        'display': 'flex',
        'flexDirection': 'column',
        'alignItems': 'center',
        'height': '100vh',  # Full viewport height
        'width': '100vw'    # Full viewport width
    },
    children=[
        html.H1("Dynamic Gradient Plots"),
        html.Div(
            style={
                'display': 'flex',
                'flexDirection': 'row',
                'justifyContent': 'space-between',
                'width': '100%'
            },
            children=[
                # Left panel: the input image; hovering selects the target pixel.
                html.Div(
                    [
                        html.H2("Original Image"),
                        dcc.Graph(id='image-display', figure=px.imshow(image_data, color_continuous_scale=gray_colorscale)),
                        # Hidden holder for the last hovered (x, y), stored as JSON by a callback.
                        html.Div(id='hover-data', style={'display': 'none'})
                    ],
                    style={'flex': '1', 'marginRight': '10px'}
                ),
                # Middle panel: full input-gradient map for the hovered pixel.
                html.Div(
                    [
                        html.H2("Full Gradient"),
                        dcc.Graph(id='result-image-display'),
                        html.H3(id='hover-coordinates', children='Coordinates: (x, y)'),
                    ],
                    style={'flex': '1', 'marginRight': '10px'}
                ),
                # Right panel: gradient cropped to the above-threshold bounding box.
                html.Div(
                    [
                        html.H2("Scale Delimited Gradient"),
                        dcc.Graph(id='gradient-display'),
                        dcc.Input(id='threshold-input', type='number', placeholder='Enter proportion of max gradient value', debounce=True, min=0, step=0.001, style={'width': '90%', 'height': '5%', 'marginBottom': '10px'}),
                        html.Button('Update', id='update-button', n_clicks=0, style={'width': '30%', 'height': '10%'})
                    ],
                    style={'flex': '1', 'marginRight': '10px'}
                )
            ]
        )
    ]
)




# Persist the point under the cursor on the original image as a JSON string
@app.callback(
    Output('hover-data', 'children'),
    Input('image-display', 'hoverData')
)
def store_hover_data(hover_data):
    """Serialise the first hovered point's x/y, or "{}" when nothing is hovered."""
    if not hover_data:
        return "{}"  # Empty JSON object signals "no hover"
    point = hover_data['points'][0]
    return json.dumps({'x': point['x'], 'y': point['y']})

# Mirror the stored hover position into the heading under the full gradient
@app.callback(
    Output('hover-coordinates', 'children'),
    Input('hover-data', 'children')
)
def update_hover_coordinates(hover_data_json):
    """Format the stored hover position, falling back to the placeholder text."""
    position = json.loads(hover_data_json)
    if not position:
        return 'Coordinates: (x, y)'
    x = position['x']
    y = position['y']
    return f'Coordinates: ({x}, {y})'

def _parse_hover(hover_data_json):
    """Return the stored hover position as (x, y), or None when nothing is hovered."""
    # `or "{}"` tolerates an unset (None) hidden div as well as the empty object.
    hover_data = json.loads(hover_data_json or "{}")
    if not hover_data:
        return None
    return hover_data['x'], hover_data['y']


def _input_gradient_at(x, y):
    """Gradient of P(class 1) at pixel (x, y) with respect to the input image.

    ``x`` is the column and ``y`` the row, as reported by Plotly hover data.
    Returns the gradient squeezed to ``image_data``'s shape.
    """
    # Fresh leaf tensor per call, so no autograd state leaks into image_data.
    image = image_data.detach().clone().unsqueeze(0).unsqueeze(0).requires_grad_(True)

    # Forward pass, then softmax over the class dimension.
    probabilities = F.softmax(model(image), dim=1)

    # Tensors are indexed [batch, class, row, col]. Plotly's x is the column
    # and y is the row, so the hovered pixel is [..., y, x].
    score = probabilities[0, 1, int(y), int(x)]

    # Differentiate w.r.t. the input only; unlike score.backward(), this does
    # not accumulate into the model parameters' .grad on every hover event.
    (grad,) = torch.autograd.grad(score, image)
    return grad.squeeze()


# Full gradient map for the hovered pixel; the Update button re-triggers it.
@app.callback(
    Output('result-image-display', 'figure'),
    Input('update-button', 'n_clicks'),
    Input('hover-data', 'children')
)
def update_image(n_clicks, hover_data_json):
    """Render d P(class 1 at hovered pixel) / d input; n_clicks is only a trigger."""
    coords = _parse_hover(hover_data_json)
    if coords is None:
        return px.imshow(torch.zeros_like(image_data), color_continuous_scale=diverging_colorscale)
    gradient = _input_gradient_at(*coords)
    return px.imshow(gradient.numpy(), color_continuous_scale=diverging_colorscale)


# Gradient cropped to its above-threshold bounding box.
@app.callback(
    Output('gradient-display', 'figure'),
    Input('hover-data', 'children'),
    Input('threshold-input', 'value')
)
def update_gradient_display(hover_data_json, threshold):
    """Render the thresholded crop of the hovered pixel's input gradient.

    ``threshold`` from the input box is a fraction of the largest |gradient|.
    When it is empty, an absolute threshold of 0.001 is used instead.
    """
    coords = _parse_hover(hover_data_json)
    if coords is None:
        return px.imshow(torch.zeros_like(image_data))

    gradient = _input_gradient_at(*coords).numpy()

    if threshold is not None:
        # Convert the proportion into an absolute cut-off.
        threshold = threshold * np.abs(gradient).max()
    else:
        threshold = 0.001
    return plot_gradients_with_bounding_box(gradient, "Model", threshold=threshold)

if __name__ == '__main__':
    # Dash.run replaces the deprecated Dash.run_server (removed in Dash 3);
    # fall back to run_server on older Dash releases that lack run().
    run_app = getattr(app, 'run', None) or app.run_server
    run_app(debug=True)