# Correlation plots: combine datasets and compare conditions

#### Note: Run this notebook with the epi-paint kernel only.

In [None]:
# Import Dependencies

import os as _os
import os.path as _ospath
import numpy as _np
import pandas as _pd
import matplotlib.pyplot as _plt
from matplotlib.colors import LogNorm
import seaborn as _sns
from scipy.stats import spearmanr

In [None]:
# Build and/or define output folder
def output_folder_path(folder):
    """Return ``<folder>/Analysis/Correlation``, creating it if it does not exist.

    Parameters
    ----------
    folder : str
        Root directory under which the output directory lives.

    Returns
    -------
    str
        Path of the (now existing) correlation output directory.
    """
    output_folder = _ospath.join(folder, 'Analysis', 'Correlation')
    # exist_ok=True replaces a separate exists() check, which could race with
    # another process creating the directory in between.
    _os.makedirs(output_folder, exist_ok=True)
    return output_folder

# Data combiner
def combine_data(folder, file_name):
    """Concatenate every CSV file named ``file_name`` found anywhere under ``folder``.

    Each file is read with its first column as the index. The frames are stacked
    row-wise with a fresh integer index. Rows that are entirely NaN are dropped,
    and the remaining NaNs are replaced by 0.

    Parameters
    ----------
    folder : str
        Directory searched recursively.
    file_name : str
        Exact file name to collect; only names ending in ``.csv`` can match.

    Returns
    -------
    pandas.DataFrame
        The combined data.

    Raises
    ------
    FileNotFoundError
        If no CSV file called ``file_name`` exists under ``folder``.
    """
    # Sorted so the concatenation (and therefore row) order is reproducible;
    # os.walk order depends on the filesystem.
    files_list = sorted(
        _ospath.join(root, file)
        for root, _dirs, files in _os.walk(folder)
        for file in files
        if file == file_name and file.endswith('.csv')
    )
    # Checked after the .csv filter: previously a non-CSV file_name passed the
    # check and failed later inside pd.concat with an unhelpful error.
    if not files_list:
        raise FileNotFoundError(f'No CSV files with name {file_name} found in {folder}.')

    print(f'Found {len(files_list)} CSV files with name {file_name} in {folder}.')
    for file in files_list:
        print(file)
    data = _pd.concat([_pd.read_csv(file, index_col=0) for file in files_list], ignore_index=True)
    return data.dropna(how='all').fillna(0)

def combine_and_correlate(folder, file_name):
    """Load every CSV named ``file_name`` under ``folder``, correlate each, and average.

    Each file is read with its first column as the index. Rows that are
    entirely NaN are dropped and the remaining NaNs are replaced by 0. A
    correlation matrix (``DataFrame.corr`` default, i.e. Pearson) is computed
    per file, and the element-wise mean of those matrices is returned.

    Parameters
    ----------
    folder : str
        Directory searched recursively.
    file_name : str
        Exact file name to collect; only names ending in ``.csv`` can match.

    Returns
    -------
    tuple
        ``(data_list, corr_mats, avg_corr)``: the cleaned frames, their
        correlation matrices, and the averaged correlation matrix.

    Raises
    ------
    FileNotFoundError
        If no CSV file called ``file_name`` exists under ``folder``.
    """
    # Sorted for a reproducible dataset order; os.walk order is filesystem-dependent.
    files_list = sorted(
        _ospath.join(root, file)
        for root, _dirs, files in _os.walk(folder)
        for file in files
        if file == file_name and file.endswith('.csv')
    )
    # Checked after the .csv filter: previously a non-CSV file_name passed the
    # check and crashed later with ZeroDivisionError in the average below.
    if not files_list:
        raise FileNotFoundError(f'No CSV files with name {file_name} found in {folder}.')

    print(f'Found {len(files_list)} CSV files with name {file_name} in {folder}.')
    for file in files_list:
        print(file)

    data_list = [_pd.read_csv(file, index_col=0).dropna(how='all').fillna(0) for file in files_list]
    corr_mats = [df.corr() for df in data_list]

    # DataFrame addition aligns on labels: a column absent from any file ends up
    # NaN in the average, as does any pair involving a column that is constant
    # (zero variance) in some file.
    avg_corr = sum(corr_mats) / len(corr_mats)
    return data_list, corr_mats, avg_corr


def plot_correlation_matrices(corr_mats, avg_corr, labels=None, cmap='coolwarm'):
    """Display a heatmap of the average correlation matrix on a fixed [-1, 1] scale.

    Parameters
    ----------
    corr_mats : list of pandas.DataFrame
        Per-dataset correlation matrices. Not drawn by this function; accepted
        so existing calls remain valid.
    avg_corr : pandas.DataFrame
        The matrix that is drawn.
    labels : list of str, optional
        Per-dataset titles. Unused; kept for call compatibility.
    cmap : str, default 'coolwarm'
        Colormap passed to ``seaborn.heatmap``.
    """
    _sns.heatmap(avg_corr, cmap=cmap, vmin=-1, vmax=1)
    _plt.tight_layout()
    _plt.show()

# Correlation Plot for Each Data Type
def correlationplot(data1, data2, method='pearson'):
    """Correlate every column of ``data1`` with every column of ``data2``.

    Entry ``[i, j]`` is the correlation between column ``i`` of ``data1`` and
    column ``j`` of ``data2``. Only rows where both values are non-zero and
    non-NaN are used, so zeros are treated as missing. Entries with fewer than
    two usable rows are NaN. The result is generally not symmetric unless
    ``data1`` and ``data2`` are the same frame.

    Parameters
    ----------
    data1, data2 : pandas.DataFrame
        Frames of identical shape. Rows are matched by index label (pandas
        alignment), so both are expected to share the same index.
    method : {'pearson', 'spearman'}, default 'pearson'
        Correlation estimator: ``numpy.corrcoef`` or ``scipy.stats.spearmanr``.

    Returns
    -------
    pandas.DataFrame
        Square matrix; rows labelled by ``data1.columns``, columns by ``data2.columns``.

    Raises
    ------
    ValueError
        If the shapes differ or ``method`` is unknown.
    """
    if data1.shape != data2.shape:
        raise ValueError("Dataframes must have the same dimensions.")
    estimators = {
        'pearson': lambda x, y: _np.corrcoef(x, y)[0, 1],
        'spearman': lambda x, y: spearmanr(x, y)[0],
    }
    if method not in estimators:
        raise ValueError(f"method must be 'pearson' or 'spearman', got {method!r}.")
    estimate = estimators[method]

    n_cols = data1.shape[1]
    # Usability masks are built once instead of once per (i, j) pair.
    usable1 = (data1 != 0) & data1.notna()
    usable2 = (data2 != 0) & data2.notna()
    correlation_matrix = _np.full((n_cols, n_cols), _np.nan)

    for i in range(n_cols):
        column_i = data1.iloc[:, i]
        usable_i = usable1.iloc[:, i]
        for j in range(n_cols):
            valid_idx = usable_i & usable2.iloc[:, j]
            if valid_idx.sum() > 1:  # at least two points are needed for a correlation
                correlation_matrix[i, j] = estimate(column_i[valid_idx], data2.iloc[:, j][valid_idx])

    return _pd.DataFrame(correlation_matrix, index=data1.columns, columns=data2.columns)

# Plotting the correlation matrix
def _nan_diagonal_copy(correlation_matrix):
    """Return a float copy of ``correlation_matrix`` whose main diagonal is NaN."""
    values = correlation_matrix.to_numpy(dtype=float, copy=True)
    _np.fill_diagonal(values, _np.nan)
    return _pd.DataFrame(values, index=correlation_matrix.index, columns=correlation_matrix.columns)


def plot_correlation_confusion_matrix(correlation_matrix, output_folder, output_file_name, vmax=0.1):
    """Save and show a heatmap of ``correlation_matrix`` with the diagonal hidden.

    The input frame is not modified. Previously its diagonal was overwritten in
    place, which leaked NaNs into later uses of the same matrix.

    Parameters
    ----------
    correlation_matrix : pandas.DataFrame
        Square correlation matrix.
    output_folder : str
        Directory the figure is written to.
    output_file_name : str
        File name; the figure is always written in SVG format.
    vmax : float or None, default 0.1
        Symmetric colour limit; the scale spans ``[-vmax, vmax]``. ``None``
        scales to the largest absolute finite off-diagonal value (1.0 if none).
    """
    matrix = _nan_diagonal_copy(correlation_matrix)
    if vmax is None:
        magnitudes = _np.abs(matrix.to_numpy())
        magnitudes = magnitudes[_np.isfinite(magnitudes)]
        # Fall back to 1.0 so the diverging scale never collapses to [0, 0].
        vmax = float(magnitudes.max()) if magnitudes.size and magnitudes.max() > 0 else 1.0
    cmap = _sns.color_palette('PuOr_r', as_cmap=True)
    _sns.heatmap(matrix, annot=False, linewidths=0.5, square=True, cmap=cmap,
                 vmin=-vmax, vmax=vmax, mask=matrix.isna(), center=0)
    _plt.savefig(_ospath.join(output_folder, output_file_name), format='svg', bbox_inches='tight')
    _plt.show()

def _combine_triangles(correlation_matrix_bottom_left, correlation_matrix_top_right):
    """Merge two square matrices into one: lower triangle from the first, upper from the second.

    The strict lower triangle comes from ``correlation_matrix_bottom_left``,
    the strict upper triangle from ``correlation_matrix_top_right``, and the
    diagonal is NaN.

    Raises
    ------
    ValueError
        If the shapes differ, either matrix is not square with matching
        index/columns, or the two matrices do not share the same label order.
    """
    bottom_left = correlation_matrix_bottom_left
    top_right = correlation_matrix_top_right
    if bottom_left.shape != top_right.shape:
        raise ValueError("Matrices must be the same shape")
    for side, matrix in (("Bottom left", bottom_left), ("Top right", top_right)):
        # Shape first: comparing Index objects of different lengths raises.
        if matrix.shape[0] != matrix.shape[1] or not (matrix.index == matrix.columns).all():
            raise ValueError(f"{side} must be a square matrix with matching index/columns")
    # The merge is positional, so mismatched label orders would mix variables.
    if not bottom_left.index.equals(top_right.index):
        raise ValueError("Both matrices must have the same labels in the same order")

    n = bottom_left.shape[0]
    rows, cols = _np.indices((n, n))
    values = _np.where(
        rows > cols,
        bottom_left.to_numpy(dtype=float),
        _np.where(rows < cols, top_right.to_numpy(dtype=float), _np.nan),
    )
    return _pd.DataFrame(values, index=bottom_left.index, columns=bottom_left.columns)


def plot_correlation_confusion_matrix_two_conditions(correlation_matrix_bottom_left, correlation_matrix_top_right, output_folder_1, output_folder_2, output_file_name, vmax=0.11):
    """Plot two conditions in one heatmap, split along the diagonal, and save it twice.

    The bottom-left triangle shows ``correlation_matrix_bottom_left`` and the
    top-right triangle shows ``correlation_matrix_top_right``. The diagonal is
    masked, matching the single-condition plot; previously it was drawn as 0.

    Parameters
    ----------
    correlation_matrix_bottom_left, correlation_matrix_top_right : pandas.DataFrame
        Square matrices with identical labels in identical order.
    output_folder_1, output_folder_2 : str
        The SVG figure is written to both directories.
    output_file_name : str
        File name used in both directories.
    vmax : float or None, default 0.11
        Symmetric colour limit; ``None`` scales to the largest absolute finite value.

    Raises
    ------
    ValueError
        If the matrices are not compatible (see ``_combine_triangles``).
    """
    combined_matrix = _combine_triangles(correlation_matrix_bottom_left, correlation_matrix_top_right)
    if vmax is None:
        magnitudes = _np.abs(combined_matrix.to_numpy())
        magnitudes = magnitudes[_np.isfinite(magnitudes)]
        # Fall back to 1.0 so the diverging scale never collapses to [0, 0].
        vmax = float(magnitudes.max()) if magnitudes.size and magnitudes.max() > 0 else 1.0
    cmap = _sns.color_palette('PuOr_r', as_cmap=True)
    _sns.heatmap(combined_matrix, annot=False, linewidths=0.5, square=True, cmap=cmap,
                 vmin=-vmax, vmax=vmax, mask=combined_matrix.isna(), center=0)
    for output_folder in (output_folder_1, output_folder_2):
        _plt.savefig(_ospath.join(output_folder, output_file_name), format='svg', bbox_inches='tight')
    _plt.show()

# Correlation Difference Matrix
def _lower_triangle_difference(correlation_condition_bottom_left, correlation_condition_top_right):
    """Return ``(difference, mask, limit)`` for plotting ``top_right - bottom_left``.

    ``difference`` drops the first row and last column, which hold no
    strict-lower-triangle cells. ``mask`` is True on the remaining diagonal and
    upper cells, which are hidden. ``limit`` is the largest absolute finite
    value among the visible cells, or 1.0 if there is none or it is 0, so the
    diverging colour scale never collapses to a single value.
    """
    difference = correlation_condition_top_right - correlation_condition_bottom_left
    mask = _np.triu(_np.ones(difference.shape, dtype=bool), k=0)[1:, :-1]
    difference = difference.iloc[1:, :-1]
    # Only visible cells set the scale; a difference of exactly 1 is a valid value.
    visible = _np.abs(difference.to_numpy(dtype=float)[~mask])
    visible = visible[_np.isfinite(visible)]
    limit = float(visible.max()) if visible.size and visible.max() > 0 else 1.0
    return difference, mask, limit


def plot_difference_matrix(correlation_condition_bottom_left, correlation_condition_top_right, output_folder_1, output_folder_2, output_file_name, vmax=None):
    """Plot the lower triangle of ``top_right - bottom_left`` and save it to two folders.

    Parameters
    ----------
    correlation_condition_bottom_left, correlation_condition_top_right : pandas.DataFrame
        Square correlation matrices; subtraction aligns them on labels.
    output_folder_1, output_folder_2 : str
        The SVG figure is written to both directories.
    output_file_name : str
        File name used in both directories.
    vmax : float or None, default None
        Symmetric colour limit; ``None`` scales to the largest absolute visible difference.
    """
    difference_matrix, mask, auto_limit = _lower_triangle_difference(
        correlation_condition_bottom_left, correlation_condition_top_right)
    limit = auto_limit if vmax is None else vmax
    _sns.heatmap(difference_matrix, mask=mask, annot=False, linewidths=0.5, square=True,
                 cmap='PuOr_r', vmin=-limit, vmax=limit, center=0)
    for output_folder in (output_folder_1, output_folder_2):
        _plt.savefig(_ospath.join(output_folder, output_file_name), format='svg', bbox_inches='tight')
    _plt.show()

In [None]:
# Average Correlation Matrix

folder = ''  # <<< Set your folder path here
file_name = 'data_1.csv'

# Fail fast on the placeholder: with folder == '' nothing would be found, but
# only after output_folder_path('') had created Analysis/Correlation in the
# current working directory.
if not _ospath.isdir(folder):
    raise NotADirectoryError(f'Set `folder` to an existing directory (got {folder!r}).')

data_list, corr_mats, avg_corr = combine_and_correlate(folder, file_name)
# Create the output directory only once there is data to write into it.
output_folder = output_folder_path(folder)
plot_correlation_confusion_matrix(avg_corr, output_folder, output_file_name='correlation_matrix_mock.svg')