In [1]:
import connectiviz
import pandas as pd
from connectiviz.load_networks import create_mapped_timeseries, load_network_names
from connectiviz.calculating_connectivity import calculate_intra_network_connectivity, calculate_inter_network_connectivity
from connectiviz.plots import plot_inter_intra_network_connectivity, plot_correlation_matrix


import numpy as np
import pandas as pd
import os
import seaborn as sns
import scipy
import matplotlib.pyplot as plt
import json

TypeError: count_inter_network_connectivity() missing 4 required positional arguments: 'original_timeseries', 'new_timeseries', 'network_json_path', and 'network_names'

In [19]:
# Single subject parameters and data locations
parent_dir = '/Users/molly/Documents/code'
timeseries_csv = ('/Users/molly/Documents/code/13281_2016-10-14_rois_timeseries.csv')
network_json_path = '/Users/molly/Documents/code/connectiviz/data/Yeo_7network_names.json'
subregions_csv = '/Users/molly/Documents/code/connectiviz/data/subregions_Yeo7networks.csv'

In [None]:
original_timeseries = pd.read_csv(timeseries_csv)
original_timeseries = original_timeseries.drop(original_timeseries.columns[233], axis=1)

networks = pd.read_csv(subregions_csv)
network_names = load_network_names(network_json_path)
ordered_networks = network_names.keys()
network_mapping = networks.set_index('Label')['Yeo_7network'].to_dict()
mapped_timeseries = original_timeseries.iloc[:, 0].map(network_mapping)
new_timeseries = original_timeseries.iloc[:, 4:]
new_timeseries['Yeo_7network'] = mapped_timeseries
original_timeseries = original_timeseries.iloc[:, 4:]

In [None]:
# Intra-Network Connectivity
intra_network_connectivity = calculate_intra_network_connectivity(original_timeseries, new_timeseries, network_json_path, network_names)
# Inter-Network Connectivity
inter_network_connectivity = calculate_inter_network_connectivity(original_timeseries, new_timeseries, network_json_path, network_names)
# Both plots
corr_matrix, plot_labels = plot_inter_intra_network_connectivity(intra_network_connectivity, inter_network_connectivity, ordered_networks, network_json_path, network_names)
# Plot correlation matrix with circles
# new_matrix = corr_matrix.to_numpy()
# plot_correlation_matrix(new_matrix, plot_labels)


In [None]:
def plot_correlation_matrix(matrix_df,network_labels, max_size=2000):
    # Calculate the size and color of each dot based on the correlation value
    x, y = np.meshgrid(np.arange(matrix_df.shape[0]), np.arange(matrix_df.shape[1]))
    x = x.flatten()
    y = y.flatten()
    size = matrix_df.flatten() * max_size
    color = matrix_df.flatten()

    # Makes a mask to show only bottom triangle of data
    mask = np.tril(np.ones_like(matrix_df, dtype=bool))
    x_masked = x[mask.flatten()]
    y_masked = y[mask.flatten()]
    size_masked = size[mask.flatten()]
    color_masked = color[mask.flatten()]
    plt.scatter(x_masked+0.5, y_masked+0.5, s=size_masked, c=color_masked, cmap='coolwarm',vmin=0.1)

    # Add labels and title
    plt.title("Inter-Network Connectivity")
    #plt.xticks(np.arange(len(network_labels)), labels=network_labels, rotation=45)
    #plt.yticks(np.arange(len(network_labels)), labels=network_labels)
    plt.ylim(len(network_labels), -1)

    x_ticks = np.unique(x_masked)
    y_ticks = np.unique(y_masked)

    # Add grid lines at the midpoints between circle positions
    plt.xticks(np.arange(len(network_labels)) +0.5, labels=network_labels, rotation=90, ha='center', va='top')
    plt.yticks(np.arange(len(network_labels)) +0.5, labels=network_labels, va='center')

    #plt.grid(True)
    plt.colorbar()
    plt.show()

new_matrix = corr_matrix.to_numpy()
plot_correlation_matrix(new_matrix, plot_labels)

In [23]:
#Inter-Network Connectivity Counter 
def count_inter_network_connectivity(original_timeseries, new_timeseries, network_json_path, network_names):
    """
    Inter-Network Connectivity
    
    Inputs:
    - orginal single-subject timeseries
    - new timeseries
    - path to the network names
    - network names

    Parameters:
    - Key of network numbers mapped to network names (ordered_networks)
    - Iterates over pairs of networks (network_i, network_j)
    - For each network pair, it selects the relevant data from original_timeseries based on the network labels in new_timeseries['Yeo_7network']
    - Calculates the correlation between each pair of rows (one from each network) and then gets the average correlations

    Returns:
    - Inter-network connectivity average correlation matrix
    """
    inter_network_connectivity = {}
    inter_network_counter = {}  # Counter for the number of correlation values
    ordered_networks = network_names.keys()
    for i, network_i in enumerate(ordered_networks):
        for j, network_j in enumerate(ordered_networks):
            if network_i != network_j and network_i != 0 and network_j != 0:
                data_i = original_timeseries[new_timeseries['Yeo_7network'] == network_i]
                data_j = original_timeseries[new_timeseries['Yeo_7network'] == network_j]
                correlation_values = []

                for index_i, row_i in data_i.iterrows():
                    for index_j, row_j in data_j.iterrows():
                        correlation_values.append(row_i.corr(row_j))

                for index_i, row_i in data_i.iterrows():
                    for index_j, row_j in data_j.iterrows():
                        correlation_values.append(row_i.corr(row_j))
                average_corr = np.mean(correlation_values)
                key = (network_i, network_j)
                inter_network_connectivity[key] = average_corr
                # Update the counter
                if key in inter_network_counter:
                    inter_network_counter[key] += len(correlation_values)
                else:
                    inter_network_counter[key] = len(correlation_values)
    return inter_network_counter

connectivity_count = count_inter_network_connectivity(original_timeseries, new_timeseries, network_json_path, network_names)
connectivity_count

{(1, 2): 2244,
 (1, 3): 2040,
 (1, 4): 1496,
 (1, 5): 1768,
 (1, 6): 1700,
 (1, 7): 2448,
 (2, 1): 2244,
 (2, 3): 1980,
 (2, 4): 1452,
 (2, 5): 1716,
 (2, 6): 1650,
 (2, 7): 2376,
 (3, 1): 2040,
 (3, 2): 1980,
 (3, 4): 1320,
 (3, 5): 1560,
 (3, 6): 1500,
 (3, 7): 2160,
 (4, 1): 1496,
 (4, 2): 1452,
 (4, 3): 1320,
 (4, 5): 1144,
 (4, 6): 1100,
 (4, 7): 1584,
 (5, 1): 1768,
 (5, 2): 1716,
 (5, 3): 1560,
 (5, 4): 1144,
 (5, 6): 1300,
 (5, 7): 1872,
 (6, 1): 1700,
 (6, 2): 1650,
 (6, 3): 1500,
 (6, 4): 1100,
 (6, 5): 1300,
 (6, 7): 1800,
 (7, 1): 2448,
 (7, 2): 2376,
 (7, 3): 2160,
 (7, 4): 1584,
 (7, 5): 1872,
 (7, 6): 1800}