Load data

In [9]:
import numpy as np
import pandas as pd
import nibabel as nib

coords =  pd.read_csv('coordinates_DK.tsv', sep='\t')
coords = np.delete(coords.to_numpy(), [3, 37], axis=0)
names = coords[:,0]
coords = np.delete(coords, [0,4], axis=1)

FCcorr = np.load('epoched_correlation.npy')

signal = np.load('CC110033.npy')
epoch, n_regions, n_timepoints = signal.shape
new_sig=np.moveaxis(signal, 0, 1)
new_sig=np.reshape(new_sig, (n_regions, epoch*n_timepoints))
signal = new_sig

#normalize the signal between 0 and 1
signal = (signal - np.min(signal)) / (np.max(signal) - np.min(signal))
print(signal.shape)

# Load the .gii file to extract coordinates and other relevant data
def load_gii_file(gii_file_path):
    # Load the GIFTI file
    gii_data = nib.load(gii_file_path)
    
    # Extract vertex coordinates (assuming 'vertices' are stored in the GIFTI file)
    coords = gii_data.darrays[0].data  # First data array typically contains vertex coordinates (x, y, z)
    
    # Extract faces (the triangles connecting vertices to form the mesh surface)
    faces = gii_data.darrays[1].data  # Second data array typically contains the faces (triangular mesh)
    
    return coords, faces

C_coords, faces = load_gii_file('Cortex.surf.gii')

(68, 51414)


In [10]:
import plotly.graph_objects as go
from IPython.display import display
import ipywidgets as widgets
import asyncio

class create_filtered_connectome_with_edges():
    def __init__(self, adj_matrix, coords, edge_threshold="99%", top_n=3):
        
        node_numbers = [i for i in range(1,69)]
        self.node_names = [f"{node_numbers[i]} {j}" for i,j in enumerate(names)]  # Replace with real names if available
        self.value = 10000
        self.current_node = None

        self.mean_adj_matrix = np.mean(adj_matrix, axis=0)
        self.current_adj_matrix = self.mean_adj_matrix
        

        # Filter edges by edge threshold
        if isinstance(edge_threshold, str) and "%" in edge_threshold:
            percentile = 100 - float(edge_threshold.strip('%'))
            sorted_edges = np.sort(self.mean_adj_matrix.flatten())
            n_lines = len(sorted_edges) - int((percentile / 100) * len(sorted_edges))
            threshold_value = sorted_edges[n_lines]
        else:
            threshold_value = edge_threshold

        # Extract edges above the threshold
        edge_x, edge_y, edge_z, edge_colors = [], [], [], []
        for i in range(len(self.mean_adj_matrix)):
            for j in range(i + 1, len(self.mean_adj_matrix)):
                if self.mean_adj_matrix[i, j] >= threshold_value:
                    edge_x.extend([coords[i][0], coords[j][0], None])
                    edge_y.extend([coords[i][1], coords[j][1], None])
                    edge_z.extend([coords[i][2], coords[j][2], None])
                    edge_colors.extend([self.mean_adj_matrix[i, j]] * 3)

        # Normalize edge weights
        visible_edges = np.array(edge_colors)
        cmin, cmax = visible_edges.min(), visible_edges.max()

        # Create edge trace
        edge_trace = go.Scatter3d(
            x=edge_x, y=edge_y, z=edge_z,
            mode='lines',
            line=dict(
                width=6,
                color=edge_colors,
                colorscale='Viridis',
                cmin=cmin, cmax=cmax,
                colorbar=dict(title="Edge Strength", thickness=15, x=0.85)
            ),
            hoverinfo='none'
        )

        # Create node trace
        x, y, z = coords[:, 0], coords[:, 1], coords[:, 2]
        node_trace = go.Scatter3d(
            x=x, y=y, z=z,
            mode='markers',
            marker=dict(size=4, color='black', opacity=1, colorscale='Viridis'),
            text=self.node_names,
            hoverinfo='text',
        )

        surface_trace = go.Mesh3d(
            x=C_coords[:, 0], y=C_coords[:, 1], z=C_coords[:, 2],
            i=faces[:, 0], j=faces[:, 1], k=faces[:, 2],
            opacity=0.15,  # Make the surface semi-transparent
            color='lightgray',  # Surface color
            showscale=False,
            hoverinfo='skip',
            flatshading=True
        )

        # Combine traces
        fig = go.FigureWidget(data=[edge_trace, node_trace, surface_trace])
        fig.update_layout(
        title="Filtered Interactive Connectome",
        #hide surface mesh

        scene=dict(
            xaxis=dict(visible=True),
            yaxis=dict(visible=True),
            zaxis=dict(visible=True),
            bgcolor="white"
        ),
        width=1000,  # Increase the width of the figure
        height=800,  # Increase the height of the figure
        dragmode='turntable',
        showlegend=False,
        updatemenus=[
                {
                    'buttons': [
                        {
                            'args': [{'visible': [True, True, False]}, {'title': 'Show Mesh'}],
                            'label': 'Hide Mesh',
                            'method': 'update'
                        },
                        {
                            'args': [{'visible': [True, True, True]}, {'title': 'Hide Mesh'}],
                            'label': 'Show Mesh',
                            'method': 'update'
                        }
                    ],
                    'direction': 'down',
                    'pad': {'r': 10, 't': 10},
                    'showactive': True,
                    'x': 0.17,
                    'xanchor': 'left',
                    'y': 1.15,
                    'yanchor': 'top'
                }
            ]
        )
        fig.update_legends(font=dict(size=15))
        # Output widget for connection details
        self.playS_button = widgets.Button(description="Play Signal")
        self.pauseS_button = widgets.Button(description="Pause Signal")
        self.slider = widgets.IntSlider(min=0, max=len(signal[0])-1, step=1, description="Time")
        self.playFC_button = widgets.Button(description="Play FC")
        self.pauseFC_button = widgets.Button(description="Pause FC")
        self.sliderFC = widgets.IntSlider(min=0, max=epoch-1, step=1, description="Time")
        self.back_Mean_button = widgets.Button(description="Back to Mean")
        self.output = widgets.Output()
        self.animation_running = [False]
        self.animationFC_running = [False]
        self.top_n = top_n
        self.edge_threshold = edge_threshold
        self.edge_trace = edge_trace
        self.node_trace = node_trace
        self.surface_trace = surface_trace
        self.coords = coords
        self.C_coords = C_coords
        self.faces = faces
        self.adj_matrix = adj_matrix
        self.fig = fig

        self.output_lock = asyncio.Lock()  # Create a lock

        @fig.data[1].on_click
        def handle_click(trace, points, selector):
            """Handle click events on nodes."""
            if points.point_inds:  # Ensure a node is clicked
                node_idx = points.point_inds[0]
                self.current_node = node_idx
                self.show_top_connections()

                # Update node color
                node_colors = ['black'] * len(coords)  # Reset all colors
                node_colors[node_idx] = 'red'  # Highlight the clicked node
                node_size = [8] * len(coords)
                node_size[node_idx] = 16
                
                # Update the node trace
                self.fig.data[1].marker.color = node_colors
                self.fig.data[1].marker.size = node_size
        
        # Function to show top connections for a clicked node
    def show_top_connections(self):
        """Print top connections for the given node."""
        #update the so that the node stands out
        #node_trace.marker[node_idx].size = 8
        
        if self.current_node is not None:
            connections = self.current_adj_matrix[self.current_node]
            top_indices = np.argsort(connections)[::-1][:self.top_n]
            print(f"  -> Node {self.node_names[top_indices[0]]}: {connections[top_indices[0]]:.2f}")
            with self.output:
                self.output.clear_output()
                print(f"Top {self.top_n} connections for Node {self.node_names[self.current_node]}:")
                for idx in top_indices:
                    print(f"  -> Node {self.node_names[idx]}: {connections[idx]:.2f}")

    # Handle node click event

    

    async def update_output(self):
        async with self.output_lock:  # Ensure only one update happens at a time
            self.output.clear_output(wait=True)
            connections = self.current_adj_matrix[self.current_node]
            top_indices = np.argsort(connections)[::-1][:self.top_n]
            print(f"  -> Node {self.node_names[top_indices[0]]}: {connections[top_indices[0]]:.2f}")
            with self.output:
                self.output.clear_output()
                print(f"Top {self.top_n} connections for Node {self.node_names[self.current_node]}:")
                for idx in top_indices:
                    print(f"  -> Node {self.node_names[idx]}: {connections[idx]:.2f}")

    
    def change_edges(self, activity):
        
        """Change the edges based on the slider value."""
        # Update the current adjacency matrix based on the slider value
        if activity < 0:
            self.current_adj_matrix = self.mean_adj_matrix
        else:
            self.current_adj_matrix = self.adj_matrix[activity]
        
        
        
        # Update edge colors
        if isinstance(self.edge_threshold, str) and "%" in self.edge_threshold:
            percentile = 100 - float(self.edge_threshold.strip('%'))
            sorted_edges = np.sort(self.current_adj_matrix.flatten())
            n_lines = len(sorted_edges) - int((percentile / 100) * len(sorted_edges))
            threshold_value = sorted_edges[n_lines]

        edge_x, edge_y, edge_z, edge_colors = [], [], [], []
        for i in range(len(self.current_adj_matrix)):
            for j in range(i+1, len(self.current_adj_matrix)):
                if self.current_adj_matrix[i, j] >= threshold_value:
                    edge_x.extend([coords[i][0], coords[j][0], None])
                    edge_y.extend([coords[i][1], coords[j][1], None])
                    edge_z.extend([coords[i][2], coords[j][2], None])
                    edge_colors.extend([self.current_adj_matrix[i, j]] * 3)
        
        self.fig.data[0].x = edge_x
        self.fig.data[0].y = edge_y
        self.fig.data[0].z = edge_z

        # Normalize edge weights
        visible_edges = np.array(edge_colors)
        if activity < 0:
            cmin, cmax = visible_edges.min(), visible_edges.max()
        else:
            cmin, cmax = 0.65, 0.98

        # Update edge trace
        self.fig.data[0].line.cmin = cmin
        self.fig.data[0].line.cmax = cmax
        self.fig.data[0].line.color = visible_edges

        
        
        
    
    async def animate(self):
        """Run the animation asynchronously."""
        while self.animation_running[0] and self.slider.value < signal.shape[1] - 1:
            self.slider.value += 1  # Advance the slider
            await asyncio.sleep(0.1)  # Non-blocking wait

    def play_animation(self, _):
        """Start the animation."""
        if not self.animation_running[0]:  # Prevent multiple tasks
            self.animation_running[0] = True
            self.animationFC_running[0] = False
            asyncio.create_task(self.animate())  # Start the animation taskn

    def pause_animation(self, _):
        """Pause the animation."""
        self.animation_running[0] = False

    def on_slider_change(self, change):
        """Update visualization when slider value changes."""
        frame = change['new']
        
        node_size = 20 * signal[:, frame]
        self.fig.data[1].marker.size = node_size
        if self.value != self.slider.value//n_timepoints:   #no change when already on the right frame
                self.value = self.slider.value//n_timepoints
                self.change_edges(self.value)
                 

    async def animateFC(self):
        """Run the animation asynchronously."""
        while self.animationFC_running[0] and self.sliderFC.value < epoch - 1:
            self.sliderFC.value += 1
            #change_edges(sliderFC.value)
            await self.update_output()
            await asyncio.sleep(3)
            
    def play_animationFC(self, _):
        """Start the animation."""
        if not self.animationFC_running[0]:  # Prevent multiple tasks
            self.animationFC_running[0] = True
            self.animation_running[0] = False
            asyncio.create_task(self.animateFC())

    def pause_animationFC(self, _):
        """Pause the animation."""
        self.animationFC_running[0] = False

    def on_sliderFC_change(self, change):
        """Update visualization when slider value changes."""
        frame = change['new']
        self.slider.value = self.sliderFC.value * n_timepoints
        #change_edges(frame)
        

    def back_to_mean(self, _):
        """Reset the visualization to the mean connectivity matrix."""
        self.slider.value = 0
        self.change_edges(-1)
        node_size = 4
        self.fig.data[1].marker.size = node_size
        node_colors = ['black'] * len(coords)
        self.fig.data[1].marker.color = node_colors
    
    def show_all(self):
        self.playS_button.on_click(self.play_animation)
        self.pauseS_button.on_click(self.pause_animation)
        self.slider.observe(self.on_slider_change, names='value')

        self.playFC_button.on_click(self.play_animationFC)
        self.pauseFC_button.on_click(self.pause_animationFC)
        self.sliderFC.observe(self.on_sliderFC_change, names='value')

        self.back_Mean_button.on_click(self.back_to_mean)
        



        # Display the figure and output widget
        display(self.fig)
        display(widgets.HBox([widgets.VBox([widgets.HBox([self.playS_button, self.pauseS_button]), self.slider, widgets.HBox([self.playFC_button, self.pauseFC_button]), self.sliderFC]), self.output, self.back_Mean_button]))


In [None]:
graph = create_filtered_connectome_with_edges(FCcorr, coords, edge_threshold="99%", top_n=5)
graph.show_all()

FigureWidget({
    'data': [{'hoverinfo': 'none',
              'line': {'cmax': 0.3554867716096943,
                       'cmin': 0.3379759721356241,
                       'color': [0.3379759721356241, 0.3379759721356241,
                                 0.3379759721356241, 0.34450979411231536,
                                 0.34450979411231536, 0.34450979411231536,
                                 0.34053183790428054, 0.34053183790428054,
                                 0.34053183790428054, 0.3412089457874777,
                                 0.3412089457874777, 0.3412089457874777,
                                 0.34097238031408134, 0.34097238031408134,
                                 0.34097238031408134, 0.3382855374166086,
                                 0.3382855374166086, 0.3382855374166086,
                                 0.3500425788010897, 0.3500425788010897,
                                 0.3500425788010897, 0.33883213085080444,
                                 0.

HBox(children=(VBox(children=(HBox(children=(Button(description='Play Signal', style=ButtonStyle()), Button(de…

  -> Node 17 inferiorparietal_rh: 0.33
  -> Node 4 caudalmiddlefrontal_lh: 0.72
  -> Node 40 parsorbitalis_rh: 0.76
  -> Node 16 inferiorparietal_lh: 0.62
  -> Node 15 fusiform_rh: 0.65
  -> Node 33 middletemporal_rh: 0.66
  -> Node 6 corpuscallosum_lh: 0.60
  -> Node 6 corpuscallosum_lh: 0.60
  -> Node 6 corpuscallosum_lh: 0.60
  -> Node 32 middletemporal_lh: 0.65
  -> Node 23 isthmuscingulate_rh: 0.58
  -> Node 23 isthmuscingulate_rh: 0.58
