In [None]:

import matplotlib.pyplot as plt
import numpy as np
import matplotlib.colors as mcolors
import os, sys, random
import pickle as pkl
from os.path import exists, join

# Make the helper modules that live in ../dependencies importable
sys.path.append('../dependencies/')

# Figure dir (exist_ok replaces the separate isdir check)
figDir = os.path.join(os.getcwd(), 'figures_output')
os.makedirs(figDir, exist_ok=True)

# Define the target directory.
# Every backslash is escaped explicitly: sequences such as "\O", "\K" or "\L"
# are invalid escapes that newer Python versions warn about. The resulting
# string is unchanged.
targDir = "C:\\OneDrive\\KwanLab\\Lightsheet_cFos_Pipeline\\1.scaled_Output\\"

# Name of the pickle file searched for in every sub-directory
scoreFileName = 'scoreDict_Real.pkl'

# Define a list of substrings to filter directories
tagList = ['data=count_norm-', 'PowerTrans_RobScal_fSel_BorFS_clf_LogReg(multinom)_CV100']

# Walk through the directory tree and keep the score file of every directory
# whose path contains all of the tags in tagList
score_dict_paths = [
    os.path.join(root, scoreFileName)
    for root, dirs, files in os.walk(targDir)
    if scoreFileName in files and all(tag in root for tag in tagList)
]

# Markers that delimit the comparison label inside a directory name.
# Labels are currently read from the pickle itself ('compLabel'), so these
# markers are not used below; they are kept for reference.
startStr = 'count_norm-'
endStr = '\\PowerTrans'
featureLists, countNames = [], []

# Print the result
print(f"Found '{scoreFileName}' files in directories containing {tagList}:")
for path in score_dict_paths:
    print(path)

    # Load the score dictionary.
    # NOTE: pickle.load can execute arbitrary code - only open files that were
    # produced by this pipeline.
    with open(path, 'rb') as f:
        featureDict = pkl.load(f)

    # Features selected per model, and the label of the comparison
    featureLists.append(featureDict['featuresPerModel'])
    countNames.append(featureDict['compLabel'])

In [None]:
# Modifying Score names
def replace_strings_with_dict(input_strings, translate_dict):
    """Return a new list in which every string has been translated.

    For each input string, every key of ``translate_dict`` is replaced by its
    value. The replacements are applied one after another in the dictionary's
    iteration order, so a later replacement sees the result of earlier ones.
    The input strings themselves are not modified.
    """
    def _translate(text):
        # Apply all substitutions sequentially to a single string
        for old, new in translate_dict.items():
            text = text.replace(old, new)
        return text

    return [_translate(entry) for entry in input_strings]

# Long comparison-label fragments mapped to the short forms used in figures.
# The entries are applied in the order in which they are written here.
translateDict = {
    'Ag_5HT2A': '5HT2A Agonists',
    'Acute SSRI': 'A-SSRI',
    'Chronic SSRI': 'C-SSRI',
    'Psilocybin': 'PSI',
    'Ketamine': 'KET',
    'H_Trypt': 'H.Trypt',
    'Non Halluc Trypt': 'Non-H.Trypt',
    '6-Fluoro-DET': '6-F-DET',
}

# Shortened labels, in the same order as countNames
scoreNames = replace_strings_with_dict(countNames, translateDict)

## Create Violin Plot for Feature counts

In [None]:
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# Display-name fixes: put the pooled PSI/5MEO group first in these two labels
swapDict = {
    'MDMA vs PSI/5MEO': 'PSI/5MEO vs MDMA',
    '6-F-DET vs PSI/5MEO': 'PSI/5MEO vs 6-F-DET',
}

# origNames: the classifier labels expected in scoreNames.
# desiredNames: the same labels, position by position, after swapDict is applied.
desiredNames = ['6-F-DET vs PSI', '5MEO vs 6-F-DET', 'PSI/5MEO vs 6-F-DET', 'PSI/5MEO vs MDMA', 'KET vs PSI', 'A-SSRI vs C-SSRI', 'A-SSRI vs PSI', 'MDMA vs PSI', '5MEO vs PSI']
origNames = ['6-F-DET vs PSI', '5MEO vs 6-F-DET', '6-F-DET vs PSI/5MEO',    'MDMA vs PSI/5MEO', 'KET vs PSI', 'A-SSRI vs C-SSRI', 'A-SSRI vs PSI', 'MDMA vs PSI', '5MEO vs PSI']

# Number of selected features in every split: one inner list per classifier
data = [[len(sublist) for sublist in inner_list] for inner_list in featureLists]

# Fail early, with a readable message, if a loaded classifier is not expected
unknownNames = [name for name in scoreNames if name not in origNames]
if unknownNames:
    raise ValueError(f"Unexpected classifier labels: {unknownNames}")

# Row order of the plot, given as positions into scoreNames / featureLists.
# The order is fixed by hand, so it must be a permutation of all loaded classifiers.
sort_indices = np.array([4, 6, 7, 8, 5, 0, 1, 2, 3])
if sorted(sort_indices.tolist()) != list(range(len(scoreNames))):
    raise ValueError(
        f"sort_indices covers {len(sort_indices)} classifiers but {len(scoreNames)} were loaded"
    )

# Reorder labels and counts. Plain lists are used so that renaming a label
# cannot be truncated by a fixed-width NumPy string dtype.
names = [scoreNames[i] for i in sort_indices]
names = [swapDict.get(name, name) for name in names]
data = [data[i] for i in sort_indices]

# Long-form frame: one row per (classifier, split) with the feature count
df = pd.melt(pd.DataFrame(data, index=names).T, var_name='Category', value_name='Values')

# RGB colours scaled to the 0-1 range; only the first one is used below
colorsList = [[82, 211, 216], [56, 135, 190]]
colorsList = np.array(colorsList)/256

# Create horizontally oriented violin plot on an explicit figure/axes pair
fig, ax = plt.subplots(figsize=(5, 5))  # Adjust the width and height as needed
sns.violinplot(x='Values', y='Category', bw_adjust=.5, data=df, orient='h', color=colorsList[0], ax=ax)

# Set plot labels and title
ax.set_xlabel('Feature Count')
ax.set_ylabel('Classifier')
ax.set_title('Feature Count Per Split')

# NOTE(review): the figure is written to the working directory, not to figDir - confirm intent
fig.savefig("FeatureCountPerSplit_violin.svg", format='svg', bbox_inches='tight')

# Show the plot
plt.show()


# Create distance matrices to compare features across comparisons

In [None]:
from collections import Counter


def jaccard(u, v):
    """Return the Jaccard distance between two collections of hashable items.

    Both inputs are reduced to sets, so duplicates are ignored. The distance
    is ``1 - |u & v| / |u | v|``: 0.0 for identical sets and 1.0 for disjoint
    ones.

    When both inputs are empty the ratio is undefined. In that case 1.0 is
    returned, which mirrors ``weighted_jaccard_similarity`` scoring an empty
    union as 0 similarity, instead of raising ``ZeroDivisionError``.
    """
    u, v = set(u), set(v)
    union = u | v
    if not union:
        # Nothing was selected on either side: treat as "no overlap"
        return 1.0
    return 1 - len(u & v) / len(union)

def weighted_jaccard_similarity(u, v):
    """Return the frequency-weighted Jaccard similarity of two item sequences.

    Every item contributes the smaller of its two occurrence counts to the
    intersection and the larger one to the union. The result is
    ``intersection / union``, or 0 when both sequences are empty.
    """
    freq_u = Counter(u)
    freq_v = Counter(v)

    # Multiset union: per-item maximum of the two counts
    union_total = sum((freq_u | freq_v).values())
    if union_total == 0:
        return 0

    # Multiset intersection: per-item minimum of the two counts
    shared_total = sum((freq_u & freq_v).values())
    return shared_total / union_total

# Number of classifiers being compared; the grid is modelCount x modelCount
modelCount = len(featureLists)

# True:  compare the pooled (flattened) feature lists with the weighted Jaccard similarity
# False: compare the individual per-split lists with the set-based Jaccard distance
weightedList = True

if weightedList:
    # Flatten every model's per-split lists once, instead of once per pair
    flatFeatures = [[item for sublist in model for item in sublist] for model in featureLists]

    # grid[a][b] is the similarity between model a and model b
    grid = [
        [weighted_jaccard_similarity(flat_a, flat_b) for flat_b in flatFeatures]
        for flat_a in flatFeatures
    ]
else:
    # NOTE(review): this branch stores a mean Jaccard *distance*, whereas the
    # weighted branch stores a similarity. For idx_a == idx_b no pair is
    # collected, so np.mean([]) yields NaN on the diagonal. Pairs with i == j
    # are skipped for different models as well. Confirm the intent before
    # enabling this branch.
    grid = [[0 for _ in range(modelCount)] for _ in range(modelCount)]
    for idx_a, listA in enumerate(featureLists):
        for idx_b, listB in enumerate(featureLists):
            distances = []
            # compute the distance between each individual list
            for i in range(len(listA)):
                for j in range(len(listB)):
                    if i != j and idx_b != idx_a:
                        distances.append(jaccard(listA[i], listB[j]))
            # compute the mean distance
            grid[idx_a][idx_b] = np.mean(distances)





## Plot the similarity heatmap

In [None]:
from matplotlib import cm

# Draw the model-by-model grid as a heatmap with a colour bar
fig, ax = plt.subplots(figsize=(10, 10))
heatmap = ax.imshow(grid, cmap='Blues', interpolation='nearest')
colorbar = fig.colorbar(heatmap, shrink=0.8)
colorbar.ax.tick_params(labelsize=15)  # font size of the colour-bar ticks

# Same tick-label font size on both axes
ax.tick_params(axis='both', labelsize=12)

# One tick per model, labelled with the comparison name
tick_positions = range(modelCount)
ax.set_xticks(tick_positions)
ax.set_xticklabels(scoreNames, rotation=90, ha='right', rotation_mode='anchor')
# The rotation set above is replaced by 45 degrees here
ax.set_xticklabels(ax.get_xticklabels(), rotation=45)
ax.set_yticks(tick_positions)
ax.set_yticklabels(scoreNames)

ax.set_title('Similarity Between Feature Lists', fontdict={'fontsize': 25})
fig.savefig("MeanSimilarity_heatmap.svg", format='svg', bbox_inches='tight')
plt.show()


In [None]:
import matplotlib.pyplot as plt
from matplotlib_venn import venn2

# Collapse every model's per-split feature lists into one set of unique features
featureLists2 = [set(item for sublist in inner_list for item in sublist) for inner_list in featureLists]

# Draw one Venn diagram for every ordered pair of models whose feature sets differ
for list1 in featureLists2:
    for list2 in featureLists2:
        # Calculate the intersection and differences
        intersection = list1 & list2
        only_list1 = list1 - list2
        only_list2 = list2 - list1

        # Identical sets (including a model paired with itself): nothing to show
        if not only_list1 and not only_list2:
            continue

        # Region id -> members of that region ('100' = left only, '010' = right only, '110' = shared)
        venn_labels = {'100': only_list1, '010': only_list2, '110': intersection}
        venn_diagram = venn2(subsets=(len(only_list1), len(only_list2), len(intersection)),
                            set_labels=('List 1', 'List 2'))

        # Replace the default count labels with the member names
        for region_id, members in venn_labels.items():
            label = venn_diagram.get_label_by_id(region_id)
            if label is None:
                # Empty regions (disjoint or nested sets) have no label object
                continue
            # Sorted so that the text does not depend on set iteration order
            label.set_text('\n'.join(sorted(members)))
            label.set_fontsize(8)  # Adjust font size if needed

        # Customize the size of the Venn diagram
        plt.gcf().set_size_inches(8, 8)

        # Display the plot
        plt.show()