In [18]:
# visit http://127.0.0.1:8050/ in your web browser.
# Imports
import functools
import json
import os
import sys

import cv2
import dash
import numpy as np
import plotly.express as px
import torch
import torch.nn.functional as F
from dash import html, dcc, Dash, html, dcc, Input, Output, callback, State
from dash.dependencies import Input, Output, MATCH
from PIL import Image
from torch.autograd.functional import jacobian

# Project-local modules live outside the notebook directory
sys.path.append("../data/")
from vess_map_dataset_loader import vess_map_dataloader

sys.path.append("../models/")
from vess_map_custom_cnn import CustomResNet

# Release any cached CUDA memory held by this process, then choose the compute
# device: CUDA when it is available, otherwise fall back to the CPU so the
# notebook still runs on machines without a GPU.
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [19]:
# Load the pre-trained model
# The network is created on `device` (defined in the setup cell) instead of a
# hard-coded .cuda(), and map_location remaps the checkpoint tensors onto that
# same device regardless of where they were saved.
model_weighted = CustomResNet(num_classes=2).to(device)
model_weighted.load_state_dict(
    torch.load("../models/trained-models/vess_map_custom_cnn.pth", map_location=device)
)
# eval() returns the module itself, so `model` and `model_weighted` refer to
# the same object, now in evaluation mode.
model = model_weighted.eval()

In [20]:
def load_images_from_directory(directory_name):
    """Load every ``.png`` file in a directory as a NumPy array.

    Files are visited in sorted file-name order, so the position of an image
    in the returned list is stable across runs.

    Args:
        directory_name: Path of the directory to scan (not recursive).

    Returns:
        A list with one ``np.ndarray`` per ``.png`` file; empty when the
        directory contains no such file.

    Raises:
        OSError: If the directory cannot be listed or an image cannot be read.
    """
    images = []
    for file_name in sorted(os.listdir(directory_name)):
        # Only lower-case ".png" names are picked up (same filter as before)
        if not file_name.endswith('.png'):
            continue
        img_path = os.path.join(directory_name, file_name)
        # The context manager closes the file handle as soon as the pixel data
        # has been copied into the array, instead of leaving it to the GC.
        with Image.open(img_path) as img:
            images.append(np.array(img))

    return images

# Load the cropped source images (a list of NumPy arrays, one per .png file)
original_images = load_images_from_directory('../data/cropped_images')

# Stack the list into a single float tensor, divide by 255 (assumes 8-bit pixel
# values -- TODO confirm) and move it to `device` from the setup cell rather
# than a hard-coded 'cuda'.
original_images = torch.tensor(np.array(original_images) / 255.0, dtype=torch.float).to(device)
original_images.shape

torch.Size([5, 64, 64])

In [21]:
# Directory from which the callbacks load the per-image jacobian_gradient_<i>.pt files
gradient_path = '../gradient-extraction/gradients'

# thresholded_gradients images: one saved array per original image, stacked
# along a new leading axis so it can be indexed by image number
thresholded_gradients = np.array([
    np.load(f'../gradient-extraction/thresholded_gradients/image_{image_idx}.npy')
    for image_idx in range(original_images.shape[0])
])

thresholded_gradients.shape

(5, 64, 64)

In [22]:
def plot_gradients_with_bounding_box(gradient, model_name, threshold=0.01):
    """Plot the region of a gradient map whose magnitude exceeds ``threshold``.

    The map is cropped to the tightest axis-aligned box containing every
    element with ``|gradient| > threshold`` and drawn on a symmetric diverging
    colour scale. A text annotation reports the pixel count, the box limits,
    the box area, the fill ratio ("fulfillment" = pixels above threshold /
    box area) and the threshold.

    Args:
        gradient: Array-like gradient map; singleton dimensions are squeezed
            away first. Assumed to be 2-D after squeezing -- the box uses the
            first two axes only.
        model_name: Label inserted into the figure title.
        threshold: Absolute magnitude an element must exceed to be counted.

    Returns:
        A Plotly figure. When no element exceeds the threshold, a blank
        grey-scale image with the same shape as the squeezed input is returned.
    """
    gradient = gradient.squeeze()
    abs_gradient = np.abs(gradient)  # computed once, reused for mask and colour range
    mask = abs_gradient > threshold

    if not mask.any():
        # Nothing to box: return a blank placeholder sized like the input
        # rather than a fixed 128x128 image unrelated to the data.
        return px.imshow(np.zeros(gradient.shape), color_continuous_scale=px.colors.sequential.gray)

    non_zero_coords = np.nonzero(mask)
    rows, cols = non_zero_coords[0], non_zero_coords[1]
    y_min, y_max = rows.min(), rows.max()
    x_min, x_max = cols.min(), cols.max()

    num_pixels_above_threshold = mask.sum()
    # +1 because both box limits are inclusive
    bounding_box_area = (y_max - y_min + 1) * (x_max - x_min + 1)
    fulfillment = num_pixels_above_threshold / bounding_box_area

    # Symmetric range around zero so the diverging scale is centred on 0;
    # it uses the peak of the whole map, not just of the cropped region.
    peak = abs_gradient.max()

    # Create the figure using Plotly Express
    fig = px.imshow(
        gradient[y_min:y_max+1, x_min:x_max+1],
        title=f'Gradient Analysis for {model_name}',
        labels={'x': 'x-axis', 'y': 'y-axis'},
        color_continuous_scale=px.colors.diverging.RdYlGn,
        range_color=[-peak, peak],
    )

    summary_text = (
        f"Pixels: {num_pixels_above_threshold}<br>"
        f"X(min, max): ({x_min}, {x_max})<br>"
        f"Y(min, max): ({y_min}, {y_max})<br>"
        f"Area: {bounding_box_area}<br>"
        f"Fulfillment: {fulfillment:.2f}<br>"
        f"Threshold: {threshold:.6f}<br>"
    )

    fig.update_layout(
        annotations=[{
            'text': summary_text,
            'showarrow': False,
            # Paper coordinates pin the box to the top-left of the plot area
            'xref': 'paper',
            'yref': 'paper',
            'x': 0, 'y': 1,
            'xanchor': 'left', 'yanchor': 'top',
            'font': {'size': 12, 'color': 'black'},
            'bgcolor': 'white',
            'opacity': 0.7
        }],
        xaxis={'visible': False},
        yaxis={'visible': False},
    )
    return fig

In [23]:
app = Dash(__name__)


def _graph_panel(heading, graph_id, is_last=False):
    """Build one column of the plot row: a heading stacked above a graph.

    Args:
        heading: Component shown above the graph (an ``html.H3`` here).
        graph_id: ``id`` given to the ``dcc.Graph`` so callbacks can target it.
        is_last: When True the right-hand margin separating panels is omitted.

    Returns:
        An ``html.Div`` containing ``heading`` and the graph.
    """
    panel_style = {
        'display': 'flex',
        'flexDirection': 'column',
        'alignItems': 'center',
        'flex': '1',
        'borderRadius': '15px',
    }
    if not is_last:
        panel_style['marginRight'] = '2%'
    return html.Div(
        style=panel_style,
        children=[
            heading,
            dcc.Graph(
                id=graph_id,
                config={'displayModeBar': True},
                style={'height': '80%', 'width': '100%'},
            ),
        ],
    )


# Style keys are camelCased, which is the form Dash documents for `style` dicts.
app.layout = html.Div(
    style={
        'display': 'flex',
        'flexDirection': 'column',
        'alignItems': 'center',
        'height': '95vh',  # Slightly below full viewport height to leave room for padding
        'width': '95vw',   # Slightly below full viewport width to leave room for padding
        'padding': '10px',
        'margin': 'auto',  # Center the div
        'boxShadow': '0 4px 8px 0 rgba(0, 0, 0, 0.2)',
        'backgroundColor': '#f9f9f9',
        'borderRadius': '15px',
    },
    children=[
        html.H1("Dynamic Gradient Plots", style={'marginBottom': '20px'}),
        # Hidden div used as a store: holds the last hovered pixel as a JSON string
        html.Div(id='hover-data', style={'display': 'none'}),
        # Controls row: image selector, threshold input, update button
        html.Div(
            style={
                'display': 'flex',
                'flexDirection': 'row',
                'justifyContent': 'space-around',
                'width': '95vw',
                'marginBottom': '20px',
            },
            children=[
                dcc.Dropdown(
                    id='image-dropdown',
                    options=[{'label': f'Image {i}', 'value': i} for i in range(len(original_images))],
                    value=0,
                    # Clearing the selection would send None to every callback
                    # that indexes the image tensors, so do not allow it.
                    clearable=False,
                    style={
                        'width': '30vw',
                        'marginRight': '10px',
                        'backgroundColor': '#fff',
                        'borderRadius': '15px',
                    }
                ),
                dcc.Input(
                    id='threshold-input',
                    type='number',
                    placeholder='Percentage of max value of gradient as threshold',
                    min=0.001,
                    max=100,
                    style={
                        'width': '30vw',
                        'marginRight': '10px',
                        'padding': '5px',
                        'borderRadius': '15px',
                    }
                ),
                html.Button(
                    'Update',
                    id='update-button',
                    n_clicks=0,
                    style={
                        'width': '30vw',
                        'padding': '5px',
                        'backgroundColor': '#007bff',  # Bootstrap primary color
                        'color': '#fff',
                        'border': 'none',
                        'cursor': 'pointer',
                        'borderRadius': '15px',
                    }
                )
            ]
        ),
        # Plot row: five equally sized panels side by side
        html.Div(
            style={
                'display': 'flex',
                'flexDirection': 'row',
                'width': '95vw',
                'height': '95vh',
            },
            children=[
                _graph_panel(
                    html.H3("Original Image", style={'padding': '10px'}),
                    'image-display',
                ),
                _graph_panel(
                    html.H3("Normalized quantity of pixels above 1% of max value for gradient at (x,y)"),
                    'fulfillment-image-display',
                ),
                _graph_panel(
                    html.H3("Model Mask", style={'padding': '10px'}),
                    'model-mask-display',
                ),
                _graph_panel(
                    html.H3(id='hover-coordinates', children='Coordinates: (x, y)', style={'padding': '10px'}),
                    'result-image-display',
                ),
                _graph_panel(
                    html.H3(id='threshold-title', children="Pixels of gradient above Threshold: 0"),
                    'gradient-display',
                    is_last=True,
                ),
            ]
        )
    ]
)

# Callback to update the original image based on the dropdown selection
@app.callback(
    Output('image-display', 'figure'),
    Output('fulfillment-image-display', 'figure'),
    Output('model-mask-display', 'figure'),
    Input('image-dropdown', 'value'),
)
def update_images(selected_index):
    """Redraw the three per-image panels for the selected image.

    Args:
        selected_index: Index of the image chosen in the dropdown, or None
            when the selection has been cleared.

    Returns:
        A tuple of three figures: the original image, the thresholded-gradient
        map masked by the model's predicted class, and the model's softmax
        probability for class 1.

    Raises:
        dash.exceptions.PreventUpdate: When no image is selected, so the
            panels keep their current content.
    """
    if selected_index is None:
        raise dash.exceptions.PreventUpdate

    image = original_images[selected_index]
    original_fig = px.imshow(image.cpu().numpy(), color_continuous_scale=px.colors.sequential.gray)

    # Inference only: no autograd graph is needed for the display
    with torch.no_grad():
        # Two unsqueezes add leading batch and channel axes before the forward pass
        model_output = model(image.unsqueeze(0).unsqueeze(0))
        predicted_mask = model_output.argmax(dim=1).squeeze(0).cpu().numpy()
        class_one_probs = F.softmax(model_output, dim=1)[0, 1, :, :].cpu().numpy()

    # Multiply into a new array: an in-place `*=` on the indexed slice would
    # overwrite the shared module-level `thresholded_gradients` data.
    masked_gradients = thresholded_gradients[selected_index] * predicted_mask
    fulfillment_fig = px.imshow(masked_gradients, color_continuous_scale=px.colors.diverging.RdYlGn)

    model_fig = px.imshow(class_one_probs, color_continuous_scale=px.colors.diverging.RdYlGn)

    return original_fig, fulfillment_fig, model_fig

# Define callback to update the hover data
@app.callback(
    Output('hover-data', 'children'),
    Input('image-display', 'hoverData'),
    Input('fulfillment-image-display', 'hoverData')
)
def store_hover_data(hover_data_original, hover_data_fulfillment):
    """Store the most recently hovered pixel as a JSON string.

    Args:
        hover_data_original: ``hoverData`` of the original-image graph.
        hover_data_fulfillment: ``hoverData`` of the fulfillment graph.

    Returns:
        ``'{"x": ..., "y": ...}'`` for the first hovered point of the graph
        that triggered the callback, or ``"{}"`` when neither graph has hover
        data yet.
    """
    # A graph's hoverData keeps its last value after the pointer leaves, so a
    # plain `original or fulfillment` would ignore the second graph for good
    # once the first had been hovered. Prefer whichever graph fired instead.
    triggered = dash.callback_context.triggered
    triggered_id = triggered[0]['prop_id'].rsplit('.', 1)[0] if triggered else ''
    if triggered_id == 'fulfillment-image-display':
        hover_data = hover_data_fulfillment or hover_data_original
    else:
        hover_data = hover_data_original or hover_data_fulfillment

    if hover_data:
        point = hover_data['points'][0]
        return json.dumps({'x': point['x'], 'y': point['y']})
    return "{}"

# Define callback to update the hover coordinates display
@app.callback(
    Output('hover-coordinates', 'children'),
    Input('hover-data', 'children')
)
def update_hover_coordinates(hover_data_json):
    """Render the heading above the full-gradient panel.

    Args:
        hover_data_json: JSON string read from the hidden ``hover-data`` div;
            an empty object means nothing has been hovered yet.

    Returns:
        The heading text, with the hovered ``x`` and ``y`` values filled in
        when they are available and a generic ``(x, y)`` label otherwise.
    """
    point = json.loads(hover_data_json)
    if not point:
        return 'Full Gradient, Coordinates: (x, y)'
    col, row = point['x'], point['y']
    return f'Full Gradient, at Coordinates: ({col}, {row})'

@functools.lru_cache(maxsize=1)
def _load_gradient(selected_index):
    """Load the saved gradient tensor of one image and its peak magnitude.

    Hover events fire the callbacks below many times per second; caching the
    most recently used image avoids re-reading the file and re-scanning it for
    its maximum on every event. Only one entry is kept, so memory use matches
    a single load. The file is not re-read if it changes on disk while cached.

    Args:
        selected_index: Image index used to build the file name.

    Returns:
        ``(gradient, max_val)`` where ``max_val`` is the largest absolute
        value in the tensor as a Python float. Callers must not modify the
        returned tensor, since it is shared through the cache.
    """
    gradient = torch.load(f'{gradient_path}/jacobian_gradient_{selected_index}.pt')
    max_val = torch.max(torch.abs(gradient)).cpu().item()
    return gradient, max_val


# Define callback to update the result image
@app.callback(
    Output('result-image-display', 'figure'),
    Input('update-button', 'n_clicks'),
    Input('hover-data', 'children'),
    Input('image-dropdown', 'value'),
)
def update_image(n_clicks, hover_data_json, selected_index):
    """Show the full gradient map for the hovered pixel.

    Args:
        n_clicks: Click count of the update button; unused except that a
            click re-triggers this callback.
        hover_data_json: JSON string with the hovered ``x``/``y``, or ``"{}"``.
        selected_index: Index of the selected image, or None when cleared.

    Returns:
        A figure of the gradient slice at the hovered pixel on a colour range
        symmetric around zero, or a blank 128x128 placeholder when nothing has
        been hovered yet.

    Raises:
        dash.exceptions.PreventUpdate: When no image is selected.
    """
    if selected_index is None:
        raise dash.exceptions.PreventUpdate

    hover_data = json.loads(hover_data_json)
    if hover_data:
        x, y = hover_data['x'], hover_data['y']
        loaded_gradient, max_val = _load_gradient(selected_index)
        # Row index (y) first, then column (x)
        # assumes the stored tensor is laid out [row, col, ...] -- TODO confirm against the extraction script
        return px.imshow(
            loaded_gradient[y, x].to('cpu'),
            color_continuous_scale=px.colors.diverging.RdYlGn,
            range_color=[-max_val, max_val],
        )
    return px.imshow(torch.zeros(128, 128), color_continuous_scale=px.colors.diverging.RdYlGn)

# Define callback to update the gradient display and the threshold title
@app.callback(
    Output('gradient-display', 'figure'),
    Output('threshold-title', 'children'),
    Input('hover-data', 'children'),
    Input('threshold-input', 'value'),
    Input('image-dropdown', 'value')
)
def update_gradient_display(hover_data_json, threshold, selected_index):
    """Show the hovered pixel's gradient cropped to values above a threshold.

    Args:
        hover_data_json: JSON string with the hovered ``x``/``y``, or ``"{}"``.
        threshold: Value of the threshold input, treated as a percentage of
            the peak gradient magnitude (it is divided by 100); None when the
            field is empty.
        selected_index: Index of the selected image, or None when cleared.

    Returns:
        ``(figure, title)``: the bounding-box plot and the matching panel
        heading, or a blank placeholder and a zero-threshold heading when
        nothing has been hovered yet.

    Raises:
        dash.exceptions.PreventUpdate: When no image is selected.
    """
    if selected_index is None:
        raise dash.exceptions.PreventUpdate

    hover_data = json.loads(hover_data_json)
    if hover_data:
        x, y = hover_data['x'], hover_data['y']
        loaded_gradient, max_val = _load_gradient(selected_index)

        # NOTE(review): the fallback is 0.01 *percent* of the peak (it is
        # divided by 100 below) -- confirm this is intended rather than 1%.
        threshold = threshold if threshold is not None else 0.01
        threshold_val = (threshold / 100) * max_val

        # Index as [y, x], the same order update_image uses, so both panels
        # describe the same pixel (this previously read [x, y]).
        fig = plot_gradients_with_bounding_box(loaded_gradient[y, x].to('cpu').numpy(), "Model", threshold=threshold_val)
        heading = f"Pixels of gradient above Threshold: {threshold_val:.4f}"
        fig.update_layout(title=heading)
        return fig, heading
    return px.imshow(torch.zeros(128, 128), color_continuous_scale=px.colors.diverging.RdYlGn), "Pixels of gradient above Threshold: 0"

if __name__ == '__main__':
    # Prefer Dash.run, the current entry point; fall back to the legacy
    # run_server on Dash releases that do not provide `run`. The `or` only
    # touches run_server when `run` is missing.
    start_server = getattr(app, 'run', None) or app.run_server
    start_server(debug=True)