# Plotter for KDE plots of all targets vs. all targets.

### Note: Use the Epi-PAINT kernel.

In [None]:
import os as _os
import os.path as _ospath
import numpy as _np
import pandas as _pd
import h5py as _h5py
import yaml as _yaml
from PyQt5.QtWidgets import QMessageBox as _QMessageBox
import matplotlib.pyplot as _plt
import seaborn as _sns
import itertools
from tqdm import tqdm
from matplotlib.colors import LogNorm
from matplotlib.colors import LinearSegmentedColormap, to_hex

In [None]:
# Root folder of the dataset; results are read from and written below it.
folder = ''
# Correlation values strictly above maximum_threshold / strictly below
# minimum_threshold form the upper / lower tails counted and shaded in the plots.
maximum_threshold = 0.7
minimum_threshold = -0.7
# Radius sweep parameters. In this notebook they only select which result
# folder is read and which plot folder is written.
min_radius = 100
step_size = 100
maximum_radius = 1000

# Build the '<min>_<step>_<max>' tag once so input and output folders always match.
radius_tag = f'{min_radius}_{step_size}_{maximum_radius}'
correlation_data_folder = _ospath.join(folder, 'Analysis', 'Correlations', radius_tag)
correlation_data_files = [f for f in _os.listdir(correlation_data_folder) if f.endswith('.csv')]

output_folder = _ospath.join(folder, 'Analysis', 'Correlations', 'Plots_' + radius_tag)
# exist_ok avoids the exists()/makedirs() race and fails loudly if a regular
# file already occupies the path (instead of failing later at savefig).
_os.makedirs(output_folder, exist_ok=True)

# Proteins present in the data. Rows and columns of the plot grid follow this order.
proteins = ['S2P', 'S5P', 'SC35', 'H3K4me3', 'H3K27ac', 'CTCF', 'H3K27me3', 'H3K9me3', 'LaminB1']

In [None]:
# Signature colour of each protein; each colour map below runs from white (0.0)
# to this colour (1.0).
_protein_end_colors = {
    'S2P': '#FF0000',
    'S5P': '#FFAA00',
    'SC35': '#AAFF00',
    'H3K4me3': '#00FF00',
    'H3K27ac': '#00FFAA',
    'CTCF': '#00AAFF',
    'H3K27me3': '#0000FF',
    'H3K9me3': '#AA00FF',
    'LaminB1': '#FF00AA',
}

# One two-stop colour map per protein, keyed (and named) by the protein.
cmap_proteins_white = {
    protein_name: LinearSegmentedColormap.from_list(protein_name, ['#FFFFFF', end_color])
    for protein_name, end_color in _protein_end_colors.items()
}

In [None]:
# Grid of KDE plots of pairwise correlation values for every ordered pair of
# proteins. Panel (row, column) shows the values loaded from
# '<row>_vs_<column>.csv', drawn in the row protein's colour; the tails beyond
# maximum_threshold / minimum_threshold are shaded and their percentages are
# printed above the panel. diff_matrix[row, column] stores
# (% above maximum_threshold) - (% below minimum_threshold); it stays NaN on the
# diagonal and wherever no usable data file exists.


def _style_blank_axis(ax):
    """Hide all ticks and tick labels on a panel that shows no data."""
    ax.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False,
                   labelbottom=False, labelleft=False)


def _label_outer_axis(ax, row_id, column_id, row_protein, column_protein, n_proteins):
    """Fix the panel aspect and put protein names on the left column and bottom row."""
    ax.set_aspect('equal', adjustable='box')
    if column_id == 0:
        ax.set_ylabel(row_protein)
    if row_id == n_proteins - 1:
        ax.set_xlabel(column_protein)


def _load_correlations(path):
    """Load a correlation CSV as a flat 1-D array of finite values."""
    # A single-value file loads as a 0-d array (len() fails on it); ravel()
    # always yields 1-D. NOTE(review): multi-column files are pooled into one
    # sample here - confirm the CSVs hold a single column/row of values.
    values = _np.loadtxt(path, delimiter=',').ravel()
    # Drop NaN/inf so they do not inflate the denominator of the tail percentages.
    return values[_np.isfinite(values)]


def _plot_kde_panel(ax, corr, color, upper_threshold, lower_threshold):
    """Draw the KDE of *corr* on *ax*, shade both tails and annotate them.

    Returns (per_above, per_below): percentages of values strictly above
    *upper_threshold* and strictly below *lower_threshold*.
    """
    _sns.kdeplot(corr, ax=ax, color=color, linewidth=2.5)
    per_above = _np.count_nonzero(corr > upper_threshold) / corr.size * 100
    per_below = _np.count_nonzero(corr < lower_threshold) / corr.size * 100

    # kdeplot may draw no curve (e.g. for a zero-variance sample), so only
    # shade the tails when a line exists instead of indexing ax.lines blindly.
    if ax.lines:
        x, y = ax.lines[0].get_data()
        upper_mask = x > upper_threshold
        ax.fill_between(x[upper_mask], y[upper_mask], color=color, alpha=0.5, edgecolor=None)
        lower_mask = x < lower_threshold
        ax.fill_between(x[lower_mask], y[lower_mask], color=color, alpha=0.5, edgecolor=None)
    ax.text(0.95, 1.15, f'{per_above:.2f}%', ha='right', va='top', transform=ax.transAxes,
            fontsize=10, color=color)
    ax.text(0.05, 1.15, f'{per_below:.2f}%', ha='left', va='top', transform=ax.transAxes,
            fontsize=10, color=color)
    ax.set_ylim(0, 2)

    # Inward bottom ticks only; y tick labels are hidden on every panel
    # (labelleft=False), so no separate set_yticklabels call is needed.
    ax.tick_params(axis='both', which='both', bottom=True, top=False, left=False, right=False,
                   labelleft=False, direction='in')
    return per_above, per_below


n_proteins = len(proteins)
# squeeze=False keeps axes 2-D even for a single protein.
fig, axes = _plt.subplots(n_proteins, n_proteins, figsize=(20, 20), sharex=True, sharey=True,
                          squeeze=False)

diff_matrix = _np.full((n_proteins, n_proteins), _np.nan)

for row_id, row_protein in enumerate(proteins):
    for column_id, column_protein in enumerate(proteins):
        ax = axes[row_id, column_id]
        corr = None
        if row_protein != column_protein:
            corr_file_name = row_protein + '_vs_' + column_protein + '.csv'
            if corr_file_name in correlation_data_files:
                corr = _load_correlations(_ospath.join(correlation_data_folder, corr_file_name))
                if corr.size == 0:
                    print(f'Data for {row_protein}_vs_{column_protein} is empty')
                    corr = None
            else:
                print(f'Data for {row_protein}_vs_{column_protein} not found')

        if corr is None:
            # Diagonal, missing or empty pair: blank panel. The outer labels are
            # still applied below so no row/column name is lost.
            _style_blank_axis(ax)
        else:
            color = to_hex(cmap_proteins_white[row_protein](1.0))
            per_above, per_below = _plot_kde_panel(ax, corr, color, maximum_threshold, minimum_threshold)
            diff_matrix[row_id, column_id] = per_above - per_below

        _label_outer_axis(ax, row_id, column_id, row_protein, column_protein, n_proteins)

_plt.savefig(_ospath.join(output_folder, 'DoC_KDE_All.svg'), format = 'svg', bbox_inches = 'tight')
_plt.show()
    