In [2]:
# visit http://127.0.0.1:8050/ in your web browser.
# Imports
import torch
import torch.nn.functional as F

import matplotlib.pyplot as plt

import sys
# Add the upper directory to the path
sys.path.append("../../models/")
from CustomCNNVessel import CustomResNet
sys.path.append("../../data/")
from VessMapDatasetLoader import vess_map_dataloader

from dash import html, dcc, Dash, html, dcc, Input, Output, callback, State
from dash.dependencies import Input, Output
import plotly.express as px
import json

# All inference and the Dash callbacks below run on the CPU (the callbacks call
# .numpy() on gradient tensors, which requires CPU tensors).
device = torch.device("cpu")
# Return unoccupied memory held by PyTorch's CUDA caching allocator; harmless
# when CUDA is not in use.
torch.cuda.empty_cache()

In [3]:
# Dataloaders
# The dataset root defaults to the original author's machine; set the
# VESSMAP_ROOT environment variable to point at a local copy instead.
import os

VESSMAP_ROOT = os.environ.get(
    "VESSMAP_ROOT", '/home/fonta42/Desktop/interpretacao-redes-neurais/data/VessMap'
)
image_dir = os.path.join(VESSMAP_ROOT, 'images')
mask_dir = os.path.join(VESSMAP_ROOT, 'labels')
skeleton_dir = os.path.join(VESSMAP_ROOT, 'skeletons')

batch_size = 10
train_size = 0.8  # fraction of samples assigned to the training split

# NOTE(review): if vess_map_dataloader shuffles or splits randomly, the
# "Image N" numbering in the app changes between runs unless a seed is set.
train_loader, test_loader = vess_map_dataloader(image_dir,
                                                mask_dir,
                                                skeleton_dir,
                                                batch_size,
                                                train_size=train_size)

In [4]:
# Concatenate every sample from both splits into flat per-sample lists
# (training samples first, then test samples) so the app can index them.
all_images, all_masks, all_skeletons = [], [], []

for loader in (train_loader, test_loader):
    for images, masks, skeletons in loader:
        # Extending with a batch tensor appends one tensor per batch element.
        all_images.extend(images.to(device))
        all_masks.extend(masks.to(device))
        all_skeletons.extend(skeletons.to(device))

In [5]:
# Models
model = CustomResNet(num_classes=2).to(device)
# Load the weights. map_location remaps tensors saved from a GPU onto `device`,
# so the checkpoint also loads on machines without CUDA.
# NOTE(review): torch.load unpickles the file; only load trusted checkpoints
# (consider weights_only=True on PyTorch versions that support it).
model.load_state_dict(
    torch.load("../../models/vess_map_regularized_none_200.pth", map_location=device)
)
model = model.eval()

In [6]:
def get_all_gradients(model, image, sampling_rate = 10, device = "cuda", vectorize = True):
    """Jacobian of ``model`` evaluated on a strided subsample of ``image``.

    The image is subsampled by keeping every ``sampling_rate``-th element along
    dimensions 2 and 3. The model is run on that subsample, and the Jacobian of
    its output with respect to the subsample is returned.

    Args:
        model: Callable module; it is moved to ``device`` in place.
        image: Tensor with at least 4 dims (presumably (N, C, H, W) — confirm).
        sampling_rate: Stride applied to dimensions 2 and 3.
        device: Device the model and image are moved to.
        vectorize: Forwarded to ``torch.autograd.functional.jacobian``.

    Returns:
        The Jacobian tensor with all size-1 dimensions squeezed out.
    """
    model.to(device)
    # jacobian() builds its own grad-tracking copy of the input, so the image
    # itself does not need requires_grad set.
    image = image.to(device)

    sampled_image = image[:, :, ::sampling_rate, ::sampling_rate]
    jacobian_gradient = torch.autograd.functional.jacobian(model,
                                                           sampled_image,
                                                           vectorize=vectorize)
    return jacobian_gradient.squeeze()

In [7]:
import plotly.express as px
import numpy as np

def plot_gradients_with_bounding_box(gradient, model_name, threshold=0.01,
                                     colorscale=None, empty_colorscale=None):
    """Plot the bounding-box crop of a gradient map around its strong pixels.

    A pixel is "strong" when ``abs(gradient) > threshold``. The figure shows the
    smallest axis-aligned box containing every strong pixel. It is annotated with
    the strong-pixel count, the box area, their ratio ("fulfillment") and the
    threshold.

    Args:
        gradient: Array-like gradient map; squeezed before use (a 2-D map is
            expected after squeezing).
        model_name: Shown in the figure title.
        threshold: Absolute-value cut-off for a pixel to count as strong.
        colorscale: Plotly colour scale for the gradient; defaults to the
            diverging ``RdYlGn`` scale used by the notebook's other gradient plots.
        empty_colorscale: Colour scale for the blank figure returned when no
            pixel passes the threshold; defaults to plotly's sequential ``gray``.

    Returns:
        A plotly figure: the cropped, annotated gradient map, or an all-zero
        image when no pixel exceeds ``threshold``.
    """
    if colorscale is None:
        colorscale = px.colors.diverging.RdYlGn
    if empty_colorscale is None:
        empty_colorscale = px.colors.sequential.gray

    gradient = np.asarray(gradient).squeeze()
    mask = np.abs(gradient) > threshold
    non_zero_coords = np.nonzero(mask)

    if len(non_zero_coords[0]) == 0:
        # Nothing passes the threshold: blank map sized like the input when it
        # is 2-D, otherwise the previous fixed 224x224 placeholder.
        blank_shape = gradient.shape if gradient.ndim == 2 else (224, 224)
        return px.imshow(np.zeros(blank_shape), color_continuous_scale=empty_colorscale)

    y_min, y_max = non_zero_coords[0].min(), non_zero_coords[0].max()
    x_min, x_max = non_zero_coords[1].min(), non_zero_coords[1].max()
    num_pixels_above_threshold = np.sum(mask)
    bounding_box_area = (y_max - y_min + 1) * (x_max - x_min + 1)
    # Fraction of the bounding box occupied by strong pixels.
    fulfillment = num_pixels_above_threshold / bounding_box_area
    # Colour range symmetric around zero, taken from the *uncropped* map.
    max_abs = np.abs(gradient).max()

    fig = px.imshow(
        gradient[y_min:y_max + 1, x_min:x_max + 1],
        title=f'Gradient Analysis for {model_name}',
        labels={'x': 'x-axis', 'y': 'y-axis'},
        color_continuous_scale=colorscale,
        range_color=[-max_abs, max_abs],
    )

    fig.update_layout(
        annotations=[{
            'text': f"Pixels: {num_pixels_above_threshold}<br>"
                    f"Area: {bounding_box_area}<br>"
                    f"Fulfillment: {fulfillment:.2f}<br>"
                    f"Threshold: {threshold:.4f}<br>",
            'showarrow': False,
            'xref': 'paper',
            'yref': 'paper',
            'x': 0, 'y': 1,
            'xanchor': 'left', 'yanchor': 'top',
            'font': {'size': 12, 'color': 'black'},
            'bgcolor': 'white',
            'opacity': 0.7
        }],
        xaxis={'visible': False},
        yaxis={'visible': False},
    )
    return fig

In [52]:
app = Dash(__name__)


def _full_height_graph(graph_id):
    """Graph with the mode bar shown that stretches to its container's height."""
    return dcc.Graph(
        id=graph_id,
        config={'displayModeBar': True},
        style={'height': '100%'},
    )


def _control_style():
    """Sizing shared by the toolbar controls so they split the row evenly."""
    return {'width': '80%', 'flex': '1'}


# Toolbar: image picker, threshold (a proportion of the max gradient), update button.
_controls = html.Div(
    style={
        'display': 'flex',
        'flex-direction': 'row',
        'justify-content': 'space-around',
        'width': '90%',
        'margin-bottom': '20px',
        'align-items': 'center',  # centre the controls vertically
    },
    children=[
        dcc.Dropdown(
            id='image-dropdown',
            options=[{'label': f'Image {idx}', 'value': idx} for idx in range(len(all_images))],
            value=0,
            style=_control_style(),
        ),
        dcc.Input(
            id='threshold-input',
            type='number',
            placeholder='Enter the proportion of max gradient value',
            debounce=True,
            min=0,
            step=0.001,
            value=0.005,
            style=_control_style(),
        ),
        html.Button('Update', id='update-button', n_clicks=0, style=_control_style()),
    ],
)

# Three side-by-side panels: original image, full gradient, thresholded gradient.
_panels = html.Div(
    style={
        'display': 'flex',
        'flex-direction': 'row',
        'justify-content': 'space-between',
        'width': '95vw',
        'height': '80vh',
    },
    children=[
        html.Div(
            [
                html.H2("Original Image"),
                _full_height_graph('image-display'),
                # Hidden holder for the hovered coordinates (JSON) shared by callbacks.
                html.Div(id='hover-data', style={'display': 'none'}),
            ],
            style={'flex': '1', 'margin-right': '10px'},
        ),
        html.Div(
            [
                html.H2(id='hover-coordinates', children='Coordinates: (x, y)'),
                _full_height_graph('result-image-display'),
            ],
            style={'flex': '1', 'margin-right': '10px'},
        ),
        html.Div(
            [
                html.H2(id='threshold-title', children="Scale Delimited Gradient (Threshold: 0)"),
                _full_height_graph('gradient-display'),
            ],
            style={'flex': '1'},
        ),
    ],
)

# Full-viewport column: title, toolbar, panels.
app.layout = html.Div(
    style={
        'display': 'flex',
        'flex-direction': 'column',
        'align-items': 'center',
        'height': '100vh',
        'width': '100vw',
    },
    children=[html.H1("Dynamic Gradient Plots"), _controls, _panels],
)


# Callback to update the original image based on the dropdown selection
@app.callback(
    Output('image-display', 'figure'),
    Input('image-dropdown', 'value')
)
def update_original_image(selected_index):
    """Show the dataset image chosen in the dropdown in grayscale.

    When the dropdown is cleared, ``selected_index`` is None and cannot index
    ``all_images``. The callback then raises PreventUpdate, which keeps the
    current figure.
    """
    from dash.exceptions import PreventUpdate

    if selected_index is None:
        raise PreventUpdate
    image_data = all_images[selected_index].squeeze().detach().cpu().numpy()
    return px.imshow(image_data, color_continuous_scale=px.colors.sequential.gray)

# Define callback to update the hover data
@app.callback(
    Output('hover-data', 'children'),
    Input('image-display', 'hoverData')
)
def store_hover_data(hover_data):
    """Serialise the first hovered point of the original image to JSON.

    Returns '{"x": ..., "y": ...}' while hovering, or "{}" when ``hover_data``
    is empty or None.
    """
    if not hover_data:
        return "{}"
    point = hover_data['points'][0]
    return json.dumps({'x': point['x'], 'y': point['y']})

# Define callback to update the hover coordinates display
@app.callback(
    Output('hover-coordinates', 'children'),
    Input('hover-data', 'children')
)
def update_hover_coordinates(hover_data_json):
    """Header text for the full-gradient panel, naming the hovered pixel.

    ``hover_data_json`` is the JSON produced by store_hover_data; an empty
    object yields the placeholder header.
    """
    coords = json.loads(hover_data_json)
    if not coords:
        return 'Full Gradient, Coordinates: (x, y)'
    return f"Full Gradient, at Coordinates: ({coords['x']}, {coords['y']})"

def _class1_input_gradient(selected_index, x, y):
    """Gradient of P(class 1 at the hovered pixel) with respect to the input image.

    ``x`` and ``y`` come from Plotly hover data on ``px.imshow``, where ``x`` is
    the column and ``y`` the row. Returns a numpy array shaped like the
    squeezed input image.
    """
    image_data = all_images[selected_index].squeeze()
    # Fresh leaf tensor, so no gradient state is attached to the stored sample.
    # NOTE(review): assumes a single-channel image, so that squeeze() plus two
    # unsqueeze(0) calls give the (1, 1, H, W) batch the model expects.
    image = image_data.detach().clone().unsqueeze(0).unsqueeze(0).requires_grad_(True)

    probabilities = F.softmax(model(image), dim=1)
    # Assumes NCHW output: the row index (hover y) comes before the column (hover x).
    score = probabilities[0, 1, int(y), int(x)]

    # autograd.grad returns only the input gradient; unlike backward(), it does
    # not accumulate .grad on every model parameter at each hover event.
    (grad,) = torch.autograd.grad(score, image)
    return grad.squeeze().detach().cpu().numpy()


# Define callback to update the result image
@app.callback(
    Output('result-image-display', 'figure'),
    Input('update-button', 'n_clicks'),
    Input('hover-data', 'children'),
    Input('image-dropdown', 'value'),
    State('image-display', 'relayoutData')
)
def update_image(n_clicks, hover_data_json, selected_index, relayout_data):
    """Full class-1 gradient map for the hovered pixel of the selected image.

    ``n_clicks`` is unused; the Update button only serves as a trigger. The
    original image's zoom (``relayout_data``, read as State) is copied onto the
    gradient figure whenever this callback runs.
    """
    # The hidden hover-data div has no children until store_hover_data runs.
    hover_data = json.loads(hover_data_json or "{}")
    if not hover_data or selected_index is None:
        return px.imshow(torch.zeros(224, 224), color_continuous_scale=px.colors.diverging.RdYlGn)

    gradient = _class1_input_gradient(selected_index, hover_data['x'], hover_data['y'])
    fig = px.imshow(gradient, color_continuous_scale=px.colors.diverging.RdYlGn)

    # relayoutData may carry the range of only one axis, so apply each
    # axis's range independently.
    if relayout_data:
        if 'xaxis.range[0]' in relayout_data and 'xaxis.range[1]' in relayout_data:
            fig.update_layout(
                xaxis_range=[relayout_data['xaxis.range[0]'], relayout_data['xaxis.range[1]']]
            )
        if 'yaxis.range[0]' in relayout_data and 'yaxis.range[1]' in relayout_data:
            fig.update_layout(
                yaxis_range=[relayout_data['yaxis.range[0]'], relayout_data['yaxis.range[1]']]
            )
    return fig

# Define callback to update the gradient display and the threshold title
@app.callback(
    Output('gradient-display', 'figure'),
    Output('threshold-title', 'children'),
    Input('hover-data', 'children'),
    Input('threshold-input', 'value'),
    Input('image-dropdown', 'value')
)
def update_gradient_display(hover_data_json, threshold, selected_index):
    """Bounding-box view of the strong-gradient region, plus its panel title.

    ``threshold`` from the input box is a proportion of the largest absolute
    gradient value; it is converted to an absolute cut-off before plotting.
    """
    hover_data = json.loads(hover_data_json or "{}")
    if not hover_data or selected_index is None:
        return px.imshow(torch.zeros(224, 224), color_continuous_scale=px.colors.diverging.RdYlGn), "Scale Delimited Gradient (Threshold: 0)"

    gradient = _class1_input_gradient(selected_index, hover_data['x'], hover_data['y'])

    if threshold is not None:
        threshold = threshold * np.abs(gradient).max()
    else:
        # NOTE(review): the fallback is an absolute cut-off, unlike the
        # proportional value above — confirm this is intended.
        threshold = 0.001

    title = f"Scale Delimited Gradient (Threshold: {threshold:.3f})"
    fig = plot_gradients_with_bounding_box(gradient, "Model", threshold=threshold)
    fig.update_layout(title=title)
    return fig, title

if __name__ == '__main__':
    # Newer Dash releases provide app.run; app.run_server is deprecated and was
    # removed in Dash 3.0, so it is used only as a fallback for old releases.
    if hasattr(app, 'run'):
        app.run(debug=True)
    else:
        app.run_server(debug=True)

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
AttributeError: module 'plotly.express.colors' has no attribute 'converging'

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
AttributeError: module 'plotly.express.colors' has no attribute 'converging'

