# Region Selector Script
- Retrieves the score dictionaries (scoreDict_Real.pkl) generated during classifySamples analyses
- Retrieves lightsheet data
- For each region selected in a contrast, shows the mean count per treatment class, including saline
- Sorts the table of regions by the difference between the first two class columns

In [6]:
def process_lightSheet(lightsheet_data, regionList, drug, salineScale, acaNorm):
    """Build a region x treatment-class table of mean lightsheet counts.

    Each row of ``lightsheet_data`` is labelled with a treatment class using the
    mapping from ``hf.create_drugClass_dict`` (with 'SAL' mapped to 'Saline').
    Rows whose drug has no class are dropped. Counts are averaged per (region
    abbreviation, class), pivoted to one column per class, and restricted to
    the regions in ``regionList``.

    Args:
        lightsheet_data: DataFrame with at least 'drug', 'dataset',
            'abbreviation' and 'count' columns. It is copied, not modified.
        regionList: Iterable of region abbreviations to keep.
        drug: Tag used to build the class column label ``f"class_{drug}"``.
        salineScale: If True, replace every non-saline class column with its
            relative change from the 'Saline' column, (x - saline) / saline,
            and drop the 'Saline' column.
        acaNorm: If True, rescale every dataset whose name does not contain
            'SAL' so that its total ACAv count equals the mean ACAv total over
            those datasets.

    Returns:
        Tuple ``(lsdc_plot, fmtStr, hvAxIdx)``:
        - ``lsdc_plot``: the table plus a 'Diff' column (first column minus
          second column), sorted ascending by 'Diff'.
        - ``fmtStr``: annotation format, '.0f' for raw means or '.0%' when
          saline-scaled.
        - ``hvAxIdx``: row position of the first positive 'Diff' (0 if none).

    Raises:
        KeyError: if ``salineScale`` is True but no 'Saline' column exists.
    """
    label = f"class_{drug}"

    # Label each row with its treatment class; rows with an unmapped drug are dropped.
    ls_data_comp = lightsheet_data.copy()
    conv_dict = hf.create_drugClass_dict(dict(label=label))
    conv_dict['SAL'] = 'Saline'

    ls_data_comp[label] = ls_data_comp['drug'].map(conv_dict)
    ls_data_comp = ls_data_comp.dropna(subset=[label])

    if acaNorm:
        # The rescale below writes non-integer values. Work in float so pandas
        # does not have to upcast an integer column in place (deprecated).
        ls_data_comp['count'] = ls_data_comp['count'].astype(float)

        # Total ACAv count per dataset.
        per_region = ls_data_comp.groupby(['abbreviation', 'dataset'])['count'].sum().reset_index()
        aca_rows = per_region[per_region['abbreviation'].isin(['ACAv'])]
        aca_totals = aca_rows.groupby(['dataset'])['count'].sum()

        # Only datasets without 'SAL' in their name are rescaled; saline data is left as is.
        aca_drug = aca_totals[~aca_totals.index.str.contains('SAL')]
        aca_drug_mean = aca_drug.mean()

        for dSet, aca_count in aca_drug.items():
            in_set = ls_data_comp['dataset'] == dSet
            ls_data_comp.loc[in_set, 'count'] = ls_data_comp.loc[in_set, 'count'] * aca_drug_mean / aca_count

    # Mean count per (region, class), one column per class, restricted to the requested regions.
    means = ls_data_comp.groupby(['abbreviation', label])['count'].mean().reset_index()
    pivoted = means.pivot(index='abbreviation', columns=label, values='count')
    regions = pivoted.loc[pivoted.index.isin(regionList), :]

    fmtStr = '.0f'

    if salineScale:
        class_names = list(regions.columns)
        if 'Saline' not in class_names:
            raise KeyError(f"salineScale requires a 'Saline' column; found {class_names}")

        sal_idx = [name == 'Saline' for name in class_names]
        other_idx = [not is_sal for is_sal in sal_idx]
        remaining_classes = [name for name, keep in zip(class_names, other_idx) if keep]

        values = regions.values
        # values[:, sal_idx] has shape (n_regions, 1) and broadcasts across the other classes.
        scaled = (values[:, other_idx] - values[:, sal_idx]) / values[:, sal_idx]
        lsdc_plot = pd.DataFrame(scaled, index=regions.index, columns=remaining_classes)
        fmtStr = '.0%'
    else:
        lsdc_plot = regions

    # Sort by the difference between the first two columns (at least two are required).
    lsdc_plot = lsdc_plot.copy()
    lsdc_plot['Diff'] = lsdc_plot.iloc[:, 0] - lsdc_plot.iloc[:, 1]
    lsdc_plot = lsdc_plot.sort_values('Diff')

    # Row where Diff first turns positive; used to draw the zero-crossing line.
    positive = (lsdc_plot['Diff'] > 0).tolist()
    hvAxIdx = positive.index(True) if any(positive) else 0

    return lsdc_plot, fmtStr, hvAxIdx

def find_feature_dict_files(root_dir, tags, target_name='scoreDict_Real.pkl'):
    """Recursively collect the paths of files named ``target_name`` under ``root_dir``.

    A file is kept only when every string in ``tags`` occurs somewhere in its
    full path, as built by ``os.path.join``. Directory components count, so
    tags may match folder names.

    Args:
        root_dir: Directory to walk with ``os.walk``.
        tags: Iterable of substrings that must all appear in the path.
        target_name: Exact file name to match. Defaults to 'scoreDict_Real.pkl'.

    Returns:
        List of matching paths in ``os.walk`` traversal order.
    """
    # Materialize once so a one-shot iterable of tags works for every candidate file.
    tags = list(tags)
    feature_dict_files = []

    for foldername, _subfolders, filenames in os.walk(root_dir):
        for filename in filenames:
            if filename != target_name:
                continue
            file_path = os.path.join(foldername, filename)
            if all(tag in file_path for tag in tags):
                feature_dict_files.append(file_path)

    return feature_dict_files

def extract_substrings_between_a_and_b(string_list, a, b):
    """Return every substring found between the delimiters ``a`` and ``b``.

    Both delimiters are matched literally (``re.escape``). The match is
    non-greedy, and ``re.DOTALL`` lets it span newlines. Matches from all input
    strings are concatenated into one flat list in input order. A string with no
    match contributes nothing; a string with several matches contributes several.

    >>> extract_substrings_between_a_and_b(['x-A-y', 'x-B-y x-C-y'], 'x-', '-y')
    ['A', 'B', 'C']
    """
    # The pattern depends only on a and b, so compile it once rather than per string.
    pattern = re.compile(f'{re.escape(a)}(.*?){re.escape(b)}', re.DOTALL)
    return [match for input_string in string_list for match in pattern.findall(input_string)]

In [8]:
# Retrieve featureDicts
import pandas as pd
import pickle as pkl
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import os, re, sys
from collections import Counter
sys.path.append('..//../functionScripts/')
import helperFunctions as hf

# Load lightsheetData
lightsheet_data = pd.read_pickle(r"C:\\OneDrive\\KwanLab\\Lightsheet_cFos_Pipeline\\figureScripts\\lightsheet_data.pkl")

# Analysis switches
filtList = True      # keep only regions counted at least filtThresh times across featuresPerModel
filtThresh = 75
salineScale = False  # express each class as (x - saline) / saline
acaNorm = False      # rescale non-'SAL' datasets to the mean ACAv total
plotData = False     # draw a heatmap for each contrast

# Example usage:
# Escaped backslashes give the same path as before without invalid escape sequences.
root_directory = "C:\\OneDrive\\KwanLab\\Lightsheet_cFos_Pipeline\\1.scaled_Output\\classif\\"
tags_to_match = ['data=count_norm', 'PowerTrans_RobScal_fSel_BorFS']

matching_files = find_feature_dict_files(root_directory, tags_to_match)

# Parse the drug tag from each path individually. Zipping a flat list of all
# matches with the file list would misalign drugs and files whenever a path
# yields zero or several tags.
drug_path_pairs = []
for path in matching_files:
    found = extract_substrings_between_a_and_b([path], 'count_norm-', '\\PowerTrans')
    if len(found) != 1:
        print(f"Skipping {path}: expected exactly one drug tag, found {found}")
        continue
    drug_path_pairs.append((found[0], path))
drugList = [d for d, _ in drug_path_pairs]

# Regions and counts kept for each contrast, accumulated across the loop.
allRegion, allCount = [], []

for drug, dictPath in drug_path_pairs:

    # NOTE(review): pickle.load can execute arbitrary code; only load files produced by this pipeline.
    with open(dictPath, 'rb') as f:
        scoreDict = pkl.load(f)

    # File-name-safe title: spaces -> underscores, '/' -> '+'.
    pltTitle = scoreDict['compLabel'].replace(' ', '_').replace('/', '+')

    # Number of times each region appears across the per-model feature lists.
    regionDict = dict(Counter(np.concatenate(scoreDict['featuresPerModel'])))
    regionList = np.array(list(regionDict.keys()))
    featureCount = np.array(list(regionDict.values()))

    if filtList:
        keepInd = featureCount >= filtThresh
        regionList = regionList[keepInd]
        featureCount = featureCount[keepInd]

    allRegion.append(regionList)
    allCount.append(featureCount)

    # Mean counts per region and treatment class for the retained regions.
    lsdc_plot, fmtStr, hvAxIdx = process_lightSheet(lightsheet_data, regionList, drug, salineScale, acaNorm)

    # Save this for importing into brainrender (which has to run older versions of python/packages)
    fullSavePath = os.path.join(os.getcwd(), f'br_{pltTitle}.csv')
    lsdc_plot.to_csv(fullSavePath)

    if plotData:
        if lsdc_plot.empty:
            # A zero-row table would give a zero-height figure and an empty heatmap.
            print(f"No regions to plot for {pltTitle}")
            continue

        # Keep only the leading class columns (first 2 when saline-scaled, else 3) so that
        # 'Diff' is not drawn. This assumes that many classes are present; TODO confirm.
        if salineScale:
            lsdc_plot = lsdc_plot.iloc[:, 0:2]
        else:
            lsdc_plot = lsdc_plot.iloc[:, 0:3]

        # Plotting
        plt.figure(figsize=(1.5 * len(lsdc_plot.columns), .6 * len(lsdc_plot.index)))

        # Diverging alternative: cmap='RdBu', center=0, vmin=-1, vmax=1
        sns.heatmap(lsdc_plot, annot=True, fmt=fmtStr)
        titleStr = f"{pltTitle}"

        if acaNorm:
            titleStr += ' ACA norm'

        if salineScale:
            titleStr += ' (% Change vs Saline)'

        # Draw a line for the 0 cross point
        plt.axhline(y=hvAxIdx, color='white', linewidth=2)

        plt.ylabel("Region Abbv.")
        plt.xlabel("Treatment", fontsize=12)
        plt.tick_params(axis='x', which='both', length=0)
        plt.title(titleStr, fontsize=10)
        plt.yticks(rotation=0)

        # plt.savefig(dirDict['classifyDir'] + titleStr + '.png', dpi=300, format='png', bbox_inches='tight')
        plt.show()