# In [None]:
#!/usr/bin/env python3
"""
Binaural Hearing Threshold Analysis with Genetic Markers

This script analyzes temporal interval discrimination thresholds under different conditions
(baseline, control, binaural) and examines their relationship with genetic variants.

Author: M. Mahdi Abolghsemi
Date: 7/7/2025
"""


# In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import os
from scipy.io import loadmat
from scipy.stats import wilcoxon, shapiro
from scipy.optimize import curve_fit
import warnings
warnings.filterwarnings("ignore")

# Data loading and processing functions
def load_data_from_path(path_base):
    """Load every file in ``path_base`` as a MATLAB file, one dataframe per file.

    Each file is read with ``scipy.io.loadmat`` and must contain a ``Data``
    variable. Every cell of ``Data`` is squeezed (singleton dimensions
    removed) and the result is wrapped in a ``pandas.DataFrame``.

    Files are processed in sorted name order. ``os.listdir`` returns entries
    in arbitrary order, which would make "subject i" differ between
    directories and between runs; sorting keeps the position of a subject
    stable as long as the file names are consistent across directories.

    Parameters
    ----------
    path_base : str
        Directory containing the data files.

    Returns
    -------
    numpy.ndarray
        1-D object array holding one dataframe per file, in sorted file order.
    """
    files = sorted(os.listdir(path_base))
    sub_data = np.empty(len(files), dtype=object)

    for i, file in enumerate(files):
        # Progress output: position of the subject and the file it came from.
        print(f"{i}: {file}")
        tmp = loadmat(os.path.join(path_base, file))['Data']
        # Strip the singleton dimensions loadmat adds around every cell.
        for trial_idx, trial in enumerate(tmp):
            for feat_idx, feature in enumerate(trial):
                tmp[trial_idx, feat_idx] = feature.squeeze()
        sub_data[i] = pd.DataFrame(tmp)
    return sub_data

def _is_empty_cell(x):
    """Return True when a cell carries no data: None, '', [] or a zero-size array."""
    if x is None:
        return True
    if isinstance(x, np.ndarray):
        return x.size == 0
    # Explicit type checks instead of ``x in ['', [], None]``: comparing a
    # NumPy scalar with ``[]`` yields an empty array whose truth value is
    # ambiguous (an error on recent NumPy versions).
    if isinstance(x, (str, list)):
        return len(x) == 0
    return False


def process_dataframes(data):
    """Blank out empty cells and drop rows that end up completely empty.

    A cell counts as empty when it is ``None``, an empty string, an empty
    list or a zero-size NumPy array. Empty cells are replaced by NaN and rows
    in which every cell is NaN are removed.

    Parameters
    ----------
    data : sequence of pandas.DataFrame
        Mutable sequence (list or object array); it is updated in place.

    Returns
    -------
    The same ``data`` object, with every element replaced by its cleaned copy.
    """
    # DataFrame.applymap was renamed to DataFrame.map in pandas 2.1 and is
    # deprecated; use whichever name the installed pandas provides.
    use_map = hasattr(pd.DataFrame, "map")
    for i, df in enumerate(data):
        elementwise = df.map if use_map else df.applymap
        data[i] = df.mask(elementwise(_is_empty_cell)).dropna(how='all')
    return data

def get_correct_answers(data):
    """Return, for every dataframe, only the rows whose column 5 equals column 10.

    The comparison is element-wise; rows where the two columns differ are
    dropped. The result is a new list with one filtered dataframe per input.
    """
    matching = []
    for frame in data:
        is_match = frame[5] == frame[10]
        matching.append(frame[is_match])
    return matching

def process_data(data, SI):
    """Extract, per dataframe, the values paired with ``SI`` in columns 6 and 8.

    Rows are kept when column 6 or column 8 equals ``SI``. For each kept row
    the returned value is ``column 6 + column 8 - SI``, i.e. the value of the
    other column when exactly one of them equals ``SI``. A row in which both
    columns equal ``SI`` is selected twice and therefore appears twice.

    Returns a list with one Series per input dataframe, ordered by row index.
    """
    extracted = []
    for frame in data:
        hits_in_6 = frame[frame[6] == SI]
        hits_in_8 = frame[frame[8] == SI]
        trials = pd.concat([hits_in_6, hits_in_8]).sort_index()
        extracted.append(trials[6] + trials[8] - SI)
    return extracted

def create_histogram_1000(data, bins=60, range_vals=(1000, 1700)):
    """Histogram every row of ``data``, ignoring NaN entries.

    ``data`` is converted to a dataframe (one row per entry); each row is
    binned with ``numpy.histogram`` using ``bins`` equal-width bins over
    ``range_vals``.

    Returns a list of ``(counts, bin_edges)`` tuples, one per row.
    """
    table = pd.DataFrame(data)
    histograms = []
    for _, values in table.iterrows():
        observed = values.dropna().values
        histograms.append(
            np.histogram(observed, bins=bins, range=range_vals, density=False)
        )
    return histograms

def create_histogram_300(data, bins=50, range_vals=(300, 600)):
    """Histogram every row of ``data`` for the 300 ms condition, ignoring NaNs.

    Mirrors ``create_histogram_1000``: the bin count and range are parameters
    whose defaults (50 bins over 300-600) reproduce the previous hard-coded
    behaviour.

    Parameters
    ----------
    data : sequence of array-likes or Series
        One entry per row to histogram.
    bins : int, optional
        Number of equal-width bins.
    range_vals : tuple of (float, float), optional
        Lower and upper edge of the histogram range.

    Returns
    -------
    list of (counts, bin_edges) tuples, one per row.
    """
    data_df = pd.DataFrame(data)
    return [np.histogram(row.dropna().values, bins=bins, range=range_vals)
            for _, row in data_df.iterrows()]

def fill_psychometric_data(row):
    """Fill gaps in one psychometric curve so it runs from 0.5 up to 1.

    The row is modified in place and also returned. Steps, in order:

    1. Everything before the first non-NaN entry is set to 0.5, and that
       first entry is raised to 0.5 if it is lower.
    2. Everything after the last non-NaN entry is set to 1, and that last
       entry itself is overwritten with 1.
    3. Each remaining entry that is NaN or exactly 0 is replaced by the mean
       of its nearest non-NaN, non-zero neighbours on either side (or by the
       single neighbour that exists). Entries are processed left to right,
       so values filled earlier serve as neighbours for later gaps.
    4. Any entry still below 0.5 is raised to 0.5.

    A row that is entirely NaN is returned unchanged.

    NOTE(review): the code mixes slices (``row[:k]``) with label lookups
    (``row[k]``); this presumably assumes the row is labelled 0..n-1 so that
    label and position coincide - confirm against the caller.
    """
    first_valid = row.first_valid_index()
    last_valid = row.last_valid_index()
    
    # Leading gap: no data yet, so fall back to 0.5 (the lower asymptote of
    # psychometric_func) and never let the first observed value drop below it.
    if first_valid is not None:
        row[:first_valid] = 0.5
        if row[first_valid] < 0.5:
            row[first_valid] = 0.5
    
    # Trailing gap: set to 1, and force the last observed value to 1 as well.
    if last_valid is not None:
        row[last_valid+1:] = 1
        if row[last_valid] != 1:
            row[last_valid] = 1
    
    # Interior gaps: NaN or exactly 0 entries are filled from their neighbours.
    for i in range(len(row)):
        if pd.isna(row[i]) or row[i] == 0:
            # Nearest usable (non-NaN, non-zero) entry to the left / right of i.
            prev_idx = row[:i][~row[:i].isna() & (row[:i] != 0)].last_valid_index()
            next_idx = row[i+1:][~row[i+1:].isna() & (row[i+1:] != 0)].first_valid_index()
            
            if prev_idx is not None and next_idx is not None:
                # Plain midpoint of the two neighbours (not distance-weighted).
                row[i] = (row[prev_idx] + row[next_idx]) / 2
            elif prev_idx is not None:
                row[i] = row[prev_idx]
            elif next_idx is not None:
                row[i] = row[next_idx]
        
        # Clamp to the 0.5 floor; a NaN that could not be filled stays NaN
        # because the comparison is False for NaN.
        if row[i] < 0.5:
            row[i] = 0.5
    
    return row

# Psychometric function and threshold calculation
def psychometric_func(x, alpha, beta):
    """Logistic psychometric function rising from 0.5 to 1.

    ``p(x) = 0.5 + 0.5 / (1 + exp(-(x - alpha) / beta))``

    ``alpha`` is the point where ``p`` equals 0.75 (midway between the two
    asymptotes) and ``beta`` sets the width of the transition.
    """
    return 0.5 + (1 - 0.5) / (1 + np.exp(-(x - alpha) / beta))


def calculate_thresholds(data, num_samples=50, target_accuracy=0.75, SI=300,
                         bin_width=None):
    """Fit ``psychometric_func`` to every row and return the threshold per row.

    Parameters
    ----------
    data : pandas.DataFrame
        One row per subject, ``num_samples`` accuracy values per row; column
        ``k`` is treated as x position ``k`` (bin index).
    num_samples : int, optional
        Number of x positions (histogram bins) per row.
    target_accuracy : float, optional
        Accuracy at which the threshold is read off the fitted curve. Must
        lie strictly between 0.5 and 1, the asymptotes of the curve.
    SI : float, optional
        Used only to derive the default ``bin_width``.
    bin_width : float, optional
        Width of one bin, used to convert the threshold from bin units to
        the units of the histogram axis. Defaults to ``SI / num_samples``,
        which equals the true bin width only when the histogram range spans
        exactly ``SI`` (e.g. 300-600 in 50 bins); pass the real width
        otherwise.

    Returns
    -------
    pandas.DataFrame
        Single column ``"Threshold"``; NaN where the fit failed.

    Raises
    ------
    ValueError
        If ``target_accuracy`` is not strictly between 0.5 and 1.
    """
    if not 0.5 < target_accuracy < 1:
        raise ValueError(
            "target_accuracy must lie strictly between 0.5 and 1, "
            f"got {target_accuracy!r}"
        )
    if bin_width is None:
        bin_width = SI / num_samples

    x_data = np.linspace(0, num_samples - 1, num_samples)

    # Inverse of psychometric_func: solving p(x) = t gives
    #   x = alpha + beta * ln((t - 0.5) / (1 - t)).
    # The logarithm is essential; without it the 75 % threshold would come
    # out as alpha + beta instead of alpha.
    log_odds = np.log((target_accuracy - 0.5) / (1 - target_accuracy))

    # Keep the initial guess inside the bounds even for short curves.
    alpha0 = 15 if num_samples > 15 else num_samples / 2

    thresholds = []
    for subject in range(data.shape[0]):
        subject_data = data.iloc[subject, :]

        try:
            params, _ = curve_fit(psychometric_func, x_data, subject_data,
                                  p0=[alpha0, 1],
                                  bounds=([0, 0], [num_samples, 10]),
                                  maxfev=1000)
        except (RuntimeError, ValueError, TypeError):
            # RuntimeError: no convergence; ValueError: NaN/inf in the data
            # or infeasible start; TypeError: too few points for the fit.
            thresholds.append(np.nan)
            continue

        alpha, beta = params
        thresholds.append(float(alpha + beta * log_odds) * bin_width)

    return pd.DataFrame(thresholds, columns=["Threshold"])

# Plotting functions
def _significance_label(p_value):
    """Format a p-value with the star notation used on the plots."""
    if p_value < 0.001:
        # ".3f" would print the misleading "p=0.000" here.
        return "*** p<0.001"
    if p_value < 0.01:
        return f"** p={p_value:.3f}"
    if p_value <= 0.05:
        return f"* p={p_value:.3f}"
    return f"p={p_value:.3f}"


def plot_comparison_stats(groups, group_names, title="", save_path=None, ylim=None):
    """Box-plot the three conditions per group and annotate paired tests.

    For every group a Wilcoxon signed-rank test is run for each pair of the
    columns ``base``, ``control`` and ``binaural``; a bracket with the
    p-value is drawn above the boxes. Comparisons that cannot be computed
    (missing column, no complete pairs, all differences zero) are skipped.

    Parameters
    ----------
    groups : list of pandas.DataFrame
        One dataframe per group; each column is a condition. The inputs are
        not modified.
    group_names : list of str
        Label for each group, in the same order as ``groups``.
    title : str, optional
        Figure title.
    save_path : str, optional
        If given, the figure is saved to ``f"{save_path}.png"`` at 300 dpi.
    ylim : tuple, optional
        Passed to ``plt.ylim`` when given.
    """
    # Tag copies rather than the caller's dataframes (which are usually
    # slices of a larger frame).
    tagged = [group.assign(Genotypes=group_names[i]) for i, group in enumerate(groups)]

    df = pd.concat(tagged, axis=0)
    df_melted = df.melt(id_vars=['Genotypes'], var_name='Condition', value_name='Log Threshold (ms)')

    # Create plot
    plt.figure(figsize=(10, 6))
    sns.boxplot(data=df_melted, x="Genotypes", y="Log Threshold (ms)", hue="Condition", showfliers=False)

    # Statistical tests
    y_max = df_melted['Log Threshold (ms)'].max()
    # Bracket spacing is proportional to |y_max| so brackets always stack
    # upwards, even when the largest value is negative or zero.
    y_unit = abs(y_max) or 1.0
    comparisons = [('base', 'control'), ('base', 'binaural'), ('control', 'binaural')]
    # Approximate horizontal offsets of the three hue boxes inside a group.
    condition_pos = {'base': -0.25, 'control': 0, 'binaural': 0.25}

    for i, group_name in enumerate(group_names):
        group_data = df[df['Genotypes'] == group_name]

        for j, (cond1, cond2) in enumerate(comparisons):
            try:
                # The test is paired: keep only subjects with both values.
                paired = group_data[[cond1, cond2]].dropna()
                if paired.empty:
                    continue
                _, p_value = wilcoxon(paired[cond1], paired[cond2])
            except (KeyError, ValueError):
                # KeyError: condition column absent; ValueError: the test is
                # undefined for this sample (e.g. all differences are zero).
                continue
            if np.isnan(p_value):
                continue

            # Position annotations
            x1 = i + condition_pos[cond1]
            x2 = i + condition_pos[cond2]
            y = y_max + ((j + 1) + i * 0.5) * 0.05 * y_unit
            h = 0.02 * y_unit

            # Draw significance line
            plt.plot([x1, x1, x2, x2], [y, y + h, y + h, y], lw=1.5, c='black')

            # Add significance text
            plt.text((x1 + x2) / 2, y + h, _significance_label(p_value),
                     ha='center', va='bottom', fontsize=10, fontweight='bold')

    plt.ylabel("Log Threshold (ms)")
    plt.title(title)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

    if ylim:
        plt.ylim(ylim)

    # Finalise the layout before saving so the file matches what is shown.
    plt.tight_layout()

    if save_path:
        plt.savefig(f"{save_path}.png", format='png', dpi=300, bbox_inches='tight')

    plt.show()

# Main analysis pipeline
def run_analysis():
    """Run the full pipeline and return log-thresholds joined with genotypes.

    Steps:
      1. Read the genotype table from ``ID_data/all/genome.xlsx`` and name
         its fourth column ``COMT``.
      2. Load and clean the behavioural data of the three conditions.
      3. For SI = 1000 and SI = 300: histogram all trials and correct trials,
         divide them to get accuracy per bin, fill gaps, and fit thresholds.
      4. Join the three threshold columns (``base``, ``control`` - the
         monaural data - and ``binaural``) with the genotype columns and
         log10-transform the thresholds.

    Returns
    -------
    dict
        ``{300: DataFrame, 1000: DataFrame}``; each frame has the columns
        ``base``, ``control``, ``binaural``, ``DAT``, ``COMT``,
        ``ACHEA (U/ML)``, ``rs6313``, ``SLC6A4`` and ``DAT_Category``.

    NOTE(review): thresholds and genotypes are joined by row position, so
    subject i of every data folder must correspond to row i of the genotype
    table - confirm the file ordering guarantees this.
    NOTE(review): for SI = 1000 the histogram spans 1000-1700 in 60 bins
    (bin width 700/60) while calculate_thresholds scales by SI / num_bins
    (1000/60) - confirm which scale is intended.
    """
    condition_names = ['base', 'binaural', 'monaural']
    threshold_columns = ['base', 'control', 'binaural']
    dat_mapping = {9: '9R9R', 9.5: '9R10R', 10: '10R10R'}

    # Load genetic data
    alleles = pd.read_excel('ID_data/all/genome.xlsx')
    # Replace the column labels as a whole; writing into
    # ``alleles.columns.values`` mutates an Index buffer that pandas treats
    # as immutable and is not guaranteed to take effect.
    column_names = list(alleles.columns)
    column_names[3] = 'COMT'
    alleles.columns = column_names

    # Load behavioral data
    paths = {
        'base': 'ID_data/all/base',
        'binaural': 'ID_data/all/binaural',
        'monaural': 'ID_data/all/monoral'
    }

    data = {}
    for condition, path in paths.items():
        data[condition] = load_data_from_path(path)
        data[condition] = process_dataframes(data[condition])

    # Process for correct responses
    correct_data = {k: get_correct_answers(v) for k, v in data.items()}

    # Calculate psychometric curves for both SI conditions
    conditions = {}
    for si, hist_func in [(1000, create_histogram_1000), (300, create_histogram_300)]:
        conditions[si] = {}
        # Must match the bin count used by the corresponding histogram helper.
        num_bins = 60 if si == 1000 else 50

        for condition in condition_names:
            # Column 0 of the frame holds the counts of each (counts, edges) pair.
            # All trials
            hist_all = pd.DataFrame(hist_func(process_data(data[condition], si)))[0]
            # Correct trials only
            hist_correct = pd.DataFrame(hist_func(process_data(correct_data[condition], si)))[0]

            # Accuracy per bin; bins without trials become NaN and are
            # filled by fill_psychometric_data.
            accuracy = pd.DataFrame((hist_correct / hist_all).to_list())
            accuracy = accuracy.apply(fill_psychometric_data, axis=1)

            # Calculate thresholds
            thresholds = calculate_thresholds(accuracy, num_bins, SI=si)
            conditions[si][condition] = thresholds

    # Combine with genetic data
    results = {}
    for si in [300, 1000]:
        combined = pd.concat([
            conditions[si]['base'],
            conditions[si]['monaural'],
            conditions[si]['binaural']
        ], axis=1)
        # The monaural condition is reported under the name "control".
        combined.columns = threshold_columns

        # Add genetic data
        combined = pd.concat([combined, alleles[['DAT', 'COMT', 'ACHEA (U/ML)', 'rs6313', 'SLC6A4']]], axis=1)

        # Map DAT categories
        combined['DAT_Category'] = combined['DAT'].map(dat_mapping)

        # Log transform for analysis
        for col in threshold_columns:
            combined[col] = np.log10(combined[col])

        results[si] = combined

    return results

# Generate plots for genetic analysis
def _select_genotype_groups(data, column, group_spec):
    """Return (frames, names) for the non-empty genotype groups of ``column``."""
    frames = []
    names = []
    for name, genotypes in group_spec:
        subset = data[data[column].isin(genotypes)][['base', 'control', 'binaural']]
        if not subset.empty:
            # Independent copy so plotting code can never alter ``data``.
            frames.append(subset.copy())
            names.append(name)
    return frames, names


def plot_genetic_analysis(results):
    """Generate the genotype comparison plots for both SI conditions.

    For SI = 300 and SI = 1000, one figure is drawn per gene (COMT, DAT,
    HTR2A, SLC6A4, in that order), comparing the ``base``, ``control`` and
    ``binaural`` columns between genotype groups. Groups without any subject
    are left out, and a figure is only drawn when at least two groups remain;
    this rule now applies to COMT as well, which previously was plotted
    unconditionally.

    Parameters
    ----------
    results : dict
        Mapping ``{300: DataFrame, 1000: DataFrame}`` as returned by
        ``run_analysis``.
    """
    # (column, title prefix, ((group label, genotypes in that group), ...))
    comparisons = (
        ('COMT', 'COMT (rs4680)',
         (('AA', ('AA',)), ('AG-GG', ('AG', 'GG')))),
        ('DAT_Category', 'SLC6A3 (DAT)',
         (('9R9R', ('9R9R',)), ('10R10R', ('10R10R',)), ('9R10R', ('9R10R',)))),
        ('rs6313', 'HTR2A (rs6313)',
         (('CC', ('CC',)), ('CT', ('CT',)), ('TT', ('TT',)))),
        ('SLC6A4', 'SLC6A4',
         (('S', ('S',)), ('S/L', ('S/L',)), ('L', ('L',)))),
    )

    for si in [300, 1000]:
        data = results[si]

        for column, label, group_spec in comparisons:
            groups, names = _select_genotype_groups(data, column, group_spec)
            if len(groups) >= 2:
                plot_comparison_stats(groups, names, f'{label} - {si}ms')

# Run the complete analysis
if __name__ == "__main__":
    # `results` is deliberately bound at module level so it stays available
    # for inspection after the run (e.g. in later notebook cells). It is a
    # dict keyed by 300 and 1000, one dataframe per SI condition.
    results = run_analysis()
    # Draws one figure per gene and SI condition.
    plot_genetic_analysis(results)