# Peripersonal Space Experimental Framework

## Project Overview

This framework generates counterbalanced trial designs and participant-specific audio files for a peripersonal space (PPS) experiment that tests how breathing phase (inhalation vs. exhalation) influences responses to looming auditory stimuli paired with tactile stimuli.

## Core Objectives

The script:
1. Creates counterbalanced experimental designs for 100 participants
2. Produces two synchronized audio tracks per participant (looming stimuli and tactile stimuli)
3. Checks design balance with Doegen-inspired statistical validation

## Key Parameters

### Experimental Design
- **Participants**: 100 unique designs
- **Total Trials**: 204 trials per participant (as configured in `TRIAL_COUNTS`)
  * Inhalation Trials: 60
  * Exhalation Trials: 60
  * Baseline Trials: 60
  * Catch Trials: 24
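The per-condition counts follow from integer arithmetic on these totals. The sketch below is a minimal illustration, assuming the values in the script's `TRIAL_COUNTS` (60/60/60/24) and that, as in the generator, directions are split with `count // 3` and SOA indices are assigned as trial position modulo 5. `per_condition_counts` is a hypothetical helper, not part of the script.

```python
# Values taken from TRIAL_COUNTS in the script; 3 directions and 5 SOA conditions
TRIAL_COUNTS = {'inhalation': 60, 'exhalation': 60, 'baseline': 60, 'catch': 24}
N_DIRECTIONS = 3  # right, front, left
N_SOA = 5


def per_condition_counts(trial_counts, n_directions, n_soa):
    """Return {trial_type: (trials per direction, trials per SOA index)}."""
    table = {}
    for trial_type, count in trial_counts.items():
        per_direction = count // n_directions  # integer split of directions
        # SOA index = position % n_soa, so the first (count % n_soa) indices get one extra trial
        per_soa = [count // n_soa + (1 if i < count % n_soa else 0) for i in range(n_soa)]
        table[trial_type] = (per_direction, per_soa)
    return table


counts = per_condition_counts(TRIAL_COUNTS, N_DIRECTIONS, N_SOA)
print(sum(TRIAL_COUNTS.values()))  # 204
print(counts['catch'])             # (8, [5, 5, 5, 5, 4])
```

With 24 catch trials each direction gets 8 trials, but 24 is not a multiple of 5, so the last SOA index receives one trial fewer than the others.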

### Stimulus Characteristics
- **Spatial Audio Stimuli**: 
  * Right azimuth (90 degrees)
  * Front azimuth (0 degrees)
  * Left azimuth (-90 degrees)

### Timing Parameters
- **Stimulus Onset Asynchrony (SOA)**: 5 conditions (delay from looming onset to tactile onset)
  * 190 ms
  * 400 ms
  * 700 ms
  * 1000 ms
  * 1500 ms

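- **Temporal Jitter**: one delay drawn at random per trial from the values below and added to the stimulus onset to prevent anticipation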
  * 100 ms
  * 200 ms
  * 300 ms
  * 400 ms
  * 500 ms
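Each trial's onsets combine a reference time from the breathing timestamps with the jitter and the SOA. The sketch below is a minimal illustration, assuming onsets in integer milliseconds; `trial_onsets` and `format_mm_ss` are hypothetical helper names. Zero-padding the seconds field keeps `mm:ss.s` strings in chronological order when sorted as text (for recordings under 100 minutes).

```python
import random

SOA_CONDITIONS_MS = [190, 400, 700, 1000, 1500]
JITTER_OPTIONS_MS = [100, 200, 300, 400, 500]


def format_mm_ss(ms):
    """Format a millisecond value as zero-padded 'mm:ss.s'."""
    return f"{int(ms // 60000):02}:{(ms % 60000) / 1000:04.1f}"


def trial_onsets(base_ms, soa_idx, rng):
    """Return (looming_onset_ms, tactile_onset_ms) for one trial."""
    jitter_ms = rng.choice(JITTER_OPTIONS_MS)             # one uniform draw per trial
    looming_ms = base_ms + jitter_ms                      # jittered looming onset
    tactile_ms = looming_ms + SOA_CONDITIONS_MS[soa_idx]  # tactile follows by the SOA
    return looming_ms, tactile_ms


rng = random.Random(42)
looming_ms, tactile_ms = trial_onsets(65_000, soa_idx=2, rng=rng)
print(format_mm_ss(looming_ms), format_mm_ss(tactile_ms))
```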

## Computational Workflow

### Design Generation
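- Generates counterbalanced experimental designs, seeded for reproducibility
- Balances:
  * Stimulus directions, split equally within each trial type
  * Trial types, with fixed counts per participant
  * SOA conditions, cycled across the trials of each trial type
- Draws temporal jitter at random for each trial (not strictly balanced)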
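The balancing step for one trial type can be sketched as follows. This is a minimal illustration assuming NumPy's `Generator` API; it mirrors the approach listed above (equal split of directions, shuffle, SOA index cycled by position). Filling a non-divisible remainder with distinct random directions is an added assumption, and `balanced_block` is a hypothetical name.

```python
from collections import Counter

import numpy as np

DIRECTIONS = ['right', 'front', 'left']
N_SOA = 5  # number of SOA conditions


def balanced_block(count, rng):
    """Return (direction, soa_idx) pairs for one trial type, in presentation order."""
    reps, remainder = divmod(count, len(DIRECTIONS))
    sequence = DIRECTIONS * reps  # equal number of trials per direction
    if remainder:
        # Assumption: top up with distinct random directions instead of dropping trials
        sequence += [str(d) for d in rng.choice(DIRECTIONS, size=remainder, replace=False)]
    rng.shuffle(sequence)  # pseudorandomize the order in place
    # SOA condition index cycles with position
    return [(direction, i % N_SOA) for i, direction in enumerate(sequence)]


block = balanced_block(24, np.random.default_rng(42))
print(Counter(d for d, _ in block))
print(Counter(s for _, s in block))
```

Because the SOA index depends only on position, its counts are fixed by the block length, while the shuffle decides which direction each SOA is paired with.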

### Validation Techniques
- Doegen-inspired statistical checks
- Comprehensive validation including:
  * Transition matrix analysis (first-order counterbalancing)
  * Factor orthogonality (chi-square statistic)
  * Sequence randomness testing (runs test)
  * Lag-1 autocorrelation analysis
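Two of these checks are compact enough to state directly. The sketch below is an illustration, not the script's exact code: it computes the coefficient of variation of the observed first-order transition counts (0 means every observed transition occurs equally often) and a runs-test z-score. Because stimulus type and SOA have more than two levels, it uses the k-category generalization of the Wald-Wolfowitz runs test. `transition_cv` and `runs_test_z` are hypothetical names.

```python
import math
from collections import Counter


def transition_cv(sequence):
    """Coefficient of variation of observed first-order transition counts."""
    counts = list(Counter(zip(sequence, sequence[1:])).values())
    mean = sum(counts) / len(counts)
    # Population standard deviation (same as numpy.std's default)
    std = math.sqrt(sum((c - mean) ** 2 for c in counts) / len(counts))
    return std / mean


def runs_test_z(sequence):
    """z-score of the k-category runs test.

    Negative values indicate clustering (too few runs); positive values
    indicate over-alternation (too many runs).
    """
    n = len(sequence)
    if n < 2:
        return 0.0
    runs = 1 + sum(a != b for a, b in zip(sequence, sequence[1:]))
    level_counts = Counter(sequence).values()
    s2 = sum(c ** 2 for c in level_counts)
    s3 = sum(c ** 3 for c in level_counts)
    expected = 1 + (n ** 2 - s2) / n
    variance = (s2 * (s2 + n * (n + 1)) - 2 * n * s3 - n ** 3) / (n ** 2 * (n - 1))
    return (runs - expected) / math.sqrt(variance) if variance > 0 else 0.0


print(runs_test_z(list("abababab")), runs_test_z(list("aaaabbbb")))
print(transition_cv(list("abcabca")))  # 0.0: each observed transition occurs twice
```

For two levels these expressions reduce to the familiar Wald-Wolfowitz mean and variance, so the same function covers binary and multi-level factors.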

### Audio Generation
- Creates two synchronized audio files per participant:
  * Looming stimuli embedded in the box breathing audio
  * Tactile stimulus track, with each tactile onset offset from its looming onset by the trial's SOA
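Placing a stimulus in a track amounts to converting its onset from milliseconds to a sample index and mixing the stimulus in at that position. The sketch below is a minimal NumPy illustration, assuming both signals are float arrays of shape (samples, channels) at the same sample rate and that stimuli are added to (not substituted into) the base track; `inject` and `ms_to_sample` are hypothetical helpers. In practice the arrays would be read from and written to WAV files with `soundfile.read` and `soundfile.write` (the script imports `soundfile` as `sf`). The sketch does not guard against clipping.

```python
import numpy as np


def ms_to_sample(ms, sample_rate):
    """Convert an onset in milliseconds to the nearest sample index."""
    return int(round(ms * sample_rate / 1000))


def inject(base, stimulus, onset_ms, sample_rate):
    """Return a copy of `base` with `stimulus` added starting at `onset_ms`.

    Both arrays have shape (n_samples, n_channels). A stimulus that would run
    past the end of the base track is truncated.
    """
    out = base.copy()
    start = ms_to_sample(onset_ms, sample_rate)
    end = min(start + len(stimulus), len(out))
    if start < end:
        out[start:end] += stimulus[: end - start]  # additive mixing, no clipping guard
    return out


sr = 48_000
breathing = np.zeros((2 * sr, 2))   # placeholder for the box breathing audio
silence = np.zeros_like(breathing)  # empty track for the tactile stimuli
stimulus = np.full((480, 2), 0.5)   # hypothetical 10 ms stimulus
looming_onset_ms, soa_ms = 500, 190
looming_track = inject(breathing, stimulus, looming_onset_ms, sr)
tactile_track = inject(silence, stimulus, looming_onset_ms + soa_ms, sr)
```

Because the tactile onset is derived from the same jittered looming onset plus the SOA, the two tracks stay aligned as long as playback of both starts at the same moment.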

## Outputs
- Participant-specific design CSV files
- Detailed stimulus injection logs
- Synchronized audio files
- Validation report (`validation_report.json`) and plots (`validation_plots/`)

## Scientific Significance
Investigates the neural mechanisms of:
- Peripersonal space perception
- Multisensory integration
- Breathing's impact on sensory processing

## Technological Innovations
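- Onset times stored as raw millisecond values (`base_ms`, `jittered_ms`, `soa_ms`) for precise audio synchronization
- Stimulus directions counterbalanced within each trial type
- Reproducible designs through a fixed random seed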

## Usage
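Researchers can use this framework to run controlled experiments on the relationship between breathing, spatial perception, and multisensory integration. Before running the script, update the file paths under `USER-CONFIGURABLE PARAMETERS` (starting with `BASE_DIR`) to match your system.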

# ======================================================================
# COMBINED COUNTERBALANCED DESIGN AND AUDIO GENERATOR FOR PPS EXPERIMENT
# ======================================================================
# 
# This script integrates two key components of a Peripersonal Space (PPS) experiment:
# 1. A counterbalancer that generates experimental designs with balanced conditions
# 2. An audio generator that creates TWO participant-specific audio files with embedded stimuli:
#    - First audio: Box breathing with embedded directional looming stimuli
#    - Second audio: Synchronized tactile stimuli with applied SOA offsets
#
# PROJECT OVERVIEW:
# This experiment investigates how breathing phases (inhalation/exhalation) influence 
# responses to looming stimuli in peripersonal space. The design presents auditory stimuli 
# from different directions (left/front/right) during controlled breathing exercises.
# Each looming stimulus is paired with a tactile stimulus with a specific SOA (Stimulus Onset Asynchrony).
#
# VALIDATION:
# This script includes comprehensive counterbalancing validation using Doegen-inspired
# validation checks to ensure proper experimental design balance across all factors.
"""
Peripersonal Space (PPS) Audio Stimuli Generator

PROJECT OVERVIEW:
This script is part of a cognitive neuroscience experiment investigating peripersonal space (PPS) 
perception during controlled breathing exercises. PPS refers to the space immediately surrounding
the body, where external stimuli (e.g., sounds) are integrated with touch. The experiment aims
to understand how PPS representation may change during different phases of breathing.

PURPOSE:
This script generates participant-specific audio files by injecting spatial audio stimuli 
(sounds that appear to come from specific directions) into a guided box breathing meditation 
audio track. Each participant receives a unique audio file where stimuli are presented at 
specific breathing hold phases with controlled randomization (jitter).

EXPERIMENTAL DESIGN:
1. Participants follow a guided box breathing meditation (inhale, hold, exhale, hold)
2. During specific hold phases, 3D spatial audio stimuli are presented from different directions
3. The timing of stimuli is precisely controlled but includes small random variations (jitter)
   to prevent anticipation effects
4. Each participant's design specifies which type of spatial audio stimuli (left/front/right)
   should be presented during which breathing cycles

STIMULI DESCRIPTION:
- Base audio: A guided box breathing meditation instructing participants when to inhale, 
  hold, exhale, and hold
- Spatial stimuli: 3D audio recordings using FABIAN HRIR (Head-Related Impulse Response) that
  create perception of sounds approaching from three directions:
  * Right (90° azimuth)
  * Front (0° azimuth) 
  * Left (-90° azimuth)

DATA FLOW:
1. Load participant-specific design CSV (specifying which trial gets which stimulus)
2. Load base box breathing audio (meditation instructions)
3. For each trial in the design:
   a. Determine exact sample position based on reference timestamps
   b. Add controlled randomization (jitter)
   c. Insert appropriate spatial audio stimulus
4. Save personalized audio file and detailed logs of stimulus timing

EXPECTED USE:
The generated audio files will be used in experiments where participants:
1. Listen to the personalized box breathing audio through headphones
2. Follow the breathing instructions
3. Experience the spatial audio stimuli during specific breathing phases
4. Potentially respond to these stimuli (reaction time measurements)
5. Data will help understand how PPS perception might change during different 
   breathing phases

AUTHOR: George Fejer
"""
import os
import tkinter as tk
from tkinter import simpledialog
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import soundfile as sf
import scipy.signal as signal
from datetime import datetime
from itertools import product, combinations
import warnings
import json

# ======================================================================
# USER-CONFIGURABLE PARAMETERS
# ======================================================================

# SOA conditions in milliseconds (Stimulus Onset Asynchrony)
# Positive values mean tactile comes after looming, negative means tactile comes before
SOA_CONDITIONS_MS = [190, 400, 700, 1000, 1500]  # Replace with your desired values

# File paths - UPDATE THESE TO MATCH YOUR SYSTEM
BASE_DIR = r"C:\Users\cogpsy-vrlab\Documents\PPS_module\BreathingPilot"

# Box breathing audio
BASE_AUDIO_PATH = os.path.join(BASE_DIR, "AudioInstructionsBoxBreathing", "TargetStimuli", 
                              "6. Box Breathing Combined (Segment Injected).wav")

# Timestamp reference for breathing phases
TIMESTAMPS_CSV_PATH = os.path.join(BASE_DIR, "AudioInstructionsBoxBreathing", "TargetStimuli", 
                                  "BoxBreathing_extended_timestamps.csv")

# Output directories
DESIGN_OUTPUT_DIR = os.path.join(BASE_DIR, "PPS_Experiment_Module", "ExperimentLog")
AUDIO_OUTPUT_DIR = os.path.join(BASE_DIR, "PPS_Experiment_Module", "ExperimentAudio")

# Looming stimuli (directional audio)
LOOMING_STIMULI = {
    'right': os.path.join(BASE_DIR, "PPS_Experiment_Module", "right_az90_FABIAN_HRIR_natural.wav"),
    'front': os.path.join(BASE_DIR, "PPS_Experiment_Module", "front_az0_FABIAN_HRIR_natural.wav"),
    'left': os.path.join(BASE_DIR, "PPS_Experiment_Module", "left_az-90_FABIAN_HRIR_natural.wav")
}

# Tactile stimulus (to be synchronized with looming stimuli + SOA)
TACTILE_STIMULUS_PATH = os.path.join(BASE_DIR, "PPS_Experiment_Module", "tactile_stimulus.wav")

# Number of participants to generate designs for
NUM_PARTICIPANTS = 100

# Trial counts for each condition
TRIAL_COUNTS = {
    'inhalation': 60,   # Trials during inhalation phase
    'exhalation': 60,   # Trials during exhalation phase
    'baseline': 60,     # Baseline trials (neutral breathing)
    'catch': 24         # Catch trials to maintain attention
}

# Jitter options in milliseconds (temporal variability)
JITTER_OPTIONS_MS = [100, 200, 300, 400, 500]

# ======================================================================
# COUNTERBALANCED DESIGN GENERATOR
# ======================================================================

class ComprehensiveCounterbalancer:
    """
    Creates counterbalanced experimental designs for multiple participants.
    
    This class generates trial sequences ensuring balanced presentation of 
    experimental factors (stimulus direction, trial type, SOA conditions) 
    across participants. It can use reference timestamps from a breathing exercise
    to align stimuli with specific breathing phases.
    """
    
    def __init__(self, output_dir, soa_conditions, trial_counts, jitter_options):
        """
        Initialize the counterbalancer with experiment parameters.
        
        Args:
            output_dir: Directory to save design files
            soa_conditions: List of SOA values in milliseconds
            trial_counts: Dictionary with number of trials per condition
            jitter_options: List of jitter values in milliseconds
        """
        # Ensure output directory exists
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        
        # Define experimental factors
        self.factors = {
            'stimulus_type': ['right', 'front', 'left'],  # Stimulus directions
            'trial_type': ['inhalation', 'exhalation', 'baseline', 'catch'],  # Breathing phases
            'soa_condition': list(range(len(soa_conditions))),  # SOA condition indices (0-4)
            'jitter': jitter_options  # Jitter in milliseconds for temporal variability
        }
        
        # Store actual SOA values for reference
        self.soa_values_ms = soa_conditions
        
        # Experiment design parameters - number of trials per condition
        self.trial_counts = trial_counts
        
        # Default timing parameters (used if no reference CSV is provided)
        self.base_time = pd.Timestamp('2025-03-01 09:00:00')  # Starting time for the experiment
        self.trial_duration = 8  # Seconds per trial
        
        # This will be set if a reference CSV is provided
        self.reference_csv_path = None

    def generate_comprehensive_design(self, num_participants=100, reference_csv_path=None):
        """
        Generate a counterbalanced design for multiple participants.
        
        Args:
            num_participants: Number of participant designs to generate
            reference_csv_path: Path to CSV with breathing exercise timestamps
                                (used to align stimuli with breathing phases)
        
        Returns:
            List of DataFrames, one for each participant
        """
        self.reference_csv_path = reference_csv_path
        designs = []
        
        # Set a fixed seed for reproducibility
        np.random.seed(42)
        
        # Generate design for each participant
        for participant_id in range(num_participants):
            design = self._generate_single_participant_design(participant_id)
            designs.append(design)
        
        return designs
    
    def _generate_single_participant_design(self, participant_id):
        """
        Generate a balanced design for a single participant with precise timestamps.
        
        For each trial, computes three timestamps:
          - timestamp_original: baseline timing
          - timestamp_after_jitter: timing with added jitter
          - timestamp_with_soa: timing with jitter and SOA adjustment
        
        Args:
            participant_id: Unique identifier for the participant
        
        Returns:
            DataFrame with complete experimental design including timestamps
        """
        design = []
        
        # If a reference CSV is provided, use it for timing
        if self.reference_csv_path is not None:
            timestamp_df = pd.read_csv(self.reference_csv_path)
            csv_milliseconds = timestamp_df['milliseconds'].astype(int).values
            total_trials = sum(self.trial_counts.values())
            if total_trials > len(csv_milliseconds):
                raise ValueError(f"Not enough timestamps in CSV! Needed {total_trials}, but found {len(csv_milliseconds)} rows.")
            csv_index = 0  # Track which CSV row to use for each trial
        
        # Generate designs for each trial type (inhalation, exhalation, etc.)
        for trial_type, count in self.trial_counts.items():
            # Create a balanced stimulus sequence for each trial type
            stimulus_sequence = []
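            n_dir = len(self.factors['stimulus_type'])
            for stim_type in self.factors['stimulus_type']:
                # Distribute stimulus types evenly within each trial type
                stimulus_sequence.extend([stim_type] * (count // n_dir))
            if count % n_dir:  # Fill any remainder with distinct random directions so no trials are dropped
                stimulus_sequence.extend(np.random.choice(self.factors['stimulus_type'], count % n_dir, replace=False).tolist())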
            
            # Shuffle the stimulus sequence for pseudorandomization
            np.random.shuffle(stimulus_sequence)
            
            # Add design details for each trial
            for idx, stimulus_type in enumerate(stimulus_sequence):
                # Compute trial number (starting after an initial offset)
                trial_idx = len(design)
                trial_number = 20 + trial_idx  # Offset of 20 trials
                
                # Choose a random jitter value (in ms)
                current_jitter = np.random.choice(self.factors['jitter'])
                
                # Get SOA condition index and actual value in milliseconds
                soa_idx = idx % len(self.factors['soa_condition'])
                soa_value_ms = self.soa_values_ms[soa_idx]
                
                # Generate timestamps (either from CSV or calculated)
                if self.reference_csv_path is not None:
                    # Use CSV-based timestamps (linked to breathing audio)
                    base_ms = csv_milliseconds[csv_index]
                    csv_index += 1
                    jittered_ms = base_ms + current_jitter
                    soa_ms = jittered_ms + soa_value_ms
                    
                    # Format timestamps as "mm:ss.s" (seconds zero-padded so the strings sort correctly)
                    timestamp_original = f"{int(base_ms // 60000):02}:{(base_ms % 60000) / 1000:04.1f}"
                    timestamp_after_jitter = f"{int(jittered_ms // 60000):02}:{(jittered_ms % 60000) / 1000:04.1f}"
                    timestamp_with_soa = f"{int(soa_ms // 60000):02}:{(soa_ms % 60000) / 1000:04.1f}"
                    
                    # Store raw millisecond values for precise audio syncing
                    base_ms_raw = base_ms
                    jittered_ms_raw = jittered_ms
                    soa_ms_raw = soa_ms
                else:
                    # Use base_time and trial_duration method (regular intervals)
                    t_orig = self.base_time + pd.Timedelta(seconds=trial_number * self.trial_duration)
                    t_jitter = self.base_time + pd.Timedelta(seconds=(trial_number * self.trial_duration) + (current_jitter / 1000))
                    t_soa = self.base_time + pd.Timedelta(seconds=(trial_number * self.trial_duration) + (current_jitter / 1000) + (soa_value_ms / 1000))
                    
                    timestamp_original = t_orig.strftime('%M:%S.%f')[:-4]
                    timestamp_after_jitter = t_jitter.strftime('%M:%S.%f')[:-4]
                    timestamp_with_soa = t_soa.strftime('%M:%S.%f')[:-4]
                    
                    # Calculate raw millisecond values for precise audio syncing
                    base_ms_raw = trial_number * self.trial_duration * 1000
                    jittered_ms_raw = base_ms_raw + current_jitter
                    soa_ms_raw = jittered_ms_raw + soa_value_ms
                
                # Add this trial to the design
                design.append({
                    'participant_id': participant_id,
                    'trial_number': trial_number,
                    'stimulus_type': stimulus_type,
                    'trial_type': trial_type,
                    'soa_condition_idx': soa_idx,
                    'soa_value_ms': soa_value_ms,  # Store exact SOA value in ms
                    'jitter_ms': current_jitter,
                    'timestamp_original': timestamp_original,
                    'timestamp_after_jitter': timestamp_after_jitter,
                    'timestamp_with_soa': timestamp_with_soa,  # New column with jitter + SOA
                    'base_ms': base_ms_raw,
                    'jittered_ms': jittered_ms_raw,
                    'soa_ms': soa_ms_raw
                })
        
        # Convert to DataFrame
        design_df = pd.DataFrame(design)
        
        # Shuffle row order within each trial type. Because rows are re-sorted by
        # onset time below, this does not affect the final presentation order.
        shuffled_design = []
        for trial_type in self.trial_counts.keys():
            type_subset = design_df[design_df['trial_type'] == trial_type].copy()
            type_subset = type_subset.sample(frac=1, random_state=participant_id).reset_index(drop=True)
            shuffled_design.append(type_subset)
        
        # Recombine, sort chronologically by the numeric jittered onset (ms), and reset index
        design_df = pd.concat(shuffled_design).reset_index(drop=True)
        design_df = design_df.sort_values(by='jittered_ms').reset_index(drop=True)
        
        return design_df
    
    def save_designs(self, designs):
        """
        Save generated designs to CSV files.
        
        Args:
            designs: List of participant design DataFrames
        """
        for design in designs:
            participant_id = design['participant_id'].iloc[0]
            filename = f'participant_{participant_id}_design.csv'
            filepath = os.path.join(self.output_dir, filename)
            design.to_csv(filepath, index=False)
            print(f"Saved design for participant {participant_id}")
    
    def validate_designs(self, designs):
        """
        Comprehensively validate the generated designs to ensure balance using Doegen-inspired
        validation checks and statistical methods.
        
        Args:
            designs: List of participant design DataFrames
        
        Returns:
            Dictionary with validation statistics
        """
        print("Starting design validation...")
        
        validation_results = {
            'stimulus_type': {},
            'trial_type': {},
            'soa_condition_idx': {},
            'jitter': {},
            'doegen_validation': {}
        }
        
        try:
            # Combine all designs for overall validation
            all_designs = pd.concat(designs)
            
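            # ------- Basic Distribution Validation -------
            # Mean count per participant of each factor level within each trial type
            for result_key, column in [('stimulus_type', 'stimulus_type'), ('trial_type', 'trial_type'),
                                       ('soa_condition_idx', 'soa_condition_idx'), ('jitter', 'jitter_ms')]:
                group_keys = 'trial_type' if column == 'trial_type' else ['trial_type', column]
                grouped_validation = all_designs.groupby(group_keys).size() / len(designs)
                # JSON object keys must be strings, so flatten (trial_type, level) tuples
                validation_results[result_key] = {
                    (' | '.join(map(str, k)) if isinstance(k, tuple) else str(k)): v
                    for k, v in grouped_validation.items()}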
            
            print("Basic distribution validation complete.")
            
            # ------- Doegen-inspired Validation Checks -------
            # 1. Check for First-Order Counterbalancing (transitions between conditions)
            print("Checking first-order counterbalancing...")
            transition_counts = {}
            for factor in ['stimulus_type', 'soa_condition_idx']:
                try:
                    # Count transitions across all designs
                    transitions = {}
                    for design in designs:
                        # Sort by numeric onset (ms) to check transitions in presentation order
                        sorted_design = design.sort_values('jittered_ms')
                        for i in range(len(sorted_design) - 1):
                            curr = sorted_design.iloc[i][factor]
                            next_val = sorted_design.iloc[i+1][factor]
                            transition_key = f"{curr}->{next_val}"
                            if transition_key not in transitions:
                                transitions[transition_key] = 0
                            transitions[transition_key] += 1
                    
                    # Store transition counts
                    transition_counts[factor] = transitions
                    
                    # Calculate transition balance score (coefficient of variation)
                    values = np.array(list(transitions.values()))
                    cv = np.std(values) / np.mean(values) if np.mean(values) > 0 else 0
                    validation_results['doegen_validation'][f'{factor}_transition_cv'] = cv
                except Exception as e:
                    print(f"Warning: Error in counterbalancing check for {factor}: {str(e)}")
                    validation_results['doegen_validation'][f'{factor}_transition_cv'] = "Error"
            
            # 2. Check for Factor Orthogonality (independence between factors)
            print("Checking factor orthogonality...")
            orthogonality_results = {}
            for f1, f2 in combinations(['stimulus_type', 'trial_type', 'soa_condition_idx'], 2):
                try:
                    # Create contingency table
                    cont_table = pd.crosstab(all_designs[f1], all_designs[f2])
                    
                    # Calculate chi-square statistic for independence
                    chi2 = 0
                    total = cont_table.sum().sum()
                    for i in cont_table.index:
                        for j in cont_table.columns:
                            observed = cont_table.loc[i, j]
                            expected = cont_table.loc[i].sum() * cont_table[j].sum() / total
                            chi2 += ((observed - expected) ** 2) / expected if expected > 0 else 0
                    
                    orthogonality_results[f"{f1} vs {f2}_chi2"] = chi2
                except Exception as e:
                    print(f"Warning: Error in orthogonality check for {f1} vs {f2}: {str(e)}")
                    orthogonality_results[f"{f1} vs {f2}_chi2"] = "Error"
            
            validation_results['doegen_validation']['orthogonality'] = orthogonality_results
            
            # 3. Sequence Randomness Test (runs test for randomness)
            print("Checking sequence randomness...")
            randomness_results = {}
            for factor in ['stimulus_type', 'soa_condition_idx']:
                for design_idx, design in enumerate(designs[:5]):  # Test a sample of designs
                    try:
                        # Sort by numeric onset (ms) to get the presentation order
                        sorted_design = design.sort_values('jittered_ms')
                        sequence = sorted_design[factor].astype(str).tolist()
                        
                        # Count runs (consecutive identical values)
                        runs = 1
                        for i in range(1, len(sequence)):
                            if sequence[i] != sequence[i-1]:
                                runs += 1
                        
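                        # Expected runs and std dev for a random sequence with k categories
                        # (multi-category generalization of the Wald-Wolfowitz runs test;
                        # the two-category formula does not apply to 3- or 5-level factors)
                        counts = pd.Series(sequence).value_counts().to_numpy(dtype=float)
                        n = counts.sum()
                        s2 = np.sum(counts ** 2)
                        s3 = np.sum(counts ** 3)
                        if len(counts) > 1 and n > 1:  # Need at least two categories
                            expected_runs = 1 + (n ** 2 - s2) / n
                            var_runs = (s2 * (s2 + n * (n + 1)) - 2 * n * s3 - n ** 3) / (n ** 2 * (n - 1))
                            std_runs = np.sqrt(var_runs) if var_runs > 0 else 0
                            
                            # Z-score for runs test
                            z_score = (runs - expected_runs) / std_runs if std_runs > 0 else 0
                            randomness_results[f"{factor}_design{design_idx}_z"] = z_score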
                    except Exception as e:
                        print(f"Warning: Error in randomness check for {factor}, design {design_idx}: {str(e)}")
                        randomness_results[f"{factor}_design{design_idx}_z"] = "Error"
            
            validation_results['doegen_validation']['sequence_randomness'] = randomness_results
            
            # 4. Factor Sequence Autocorrelation Analysis
            print("Checking sequence autocorrelation...")
            autocorr_results = {}
            for factor in ['soa_condition_idx']:
                for design_idx, design in enumerate(designs[:5]):  # Test a sample of designs
                    try:
                        # Sort by numeric onset (ms) to get the presentation order
                        sorted_design = design.sort_values('jittered_ms')
                        
                        # Convert categorical variables to numeric for autocorrelation
                        if factor == 'stimulus_type':
                            # Map stimulus types to numbers
                            mapping = {stim: i for i, stim in enumerate(sorted_design[factor].unique())}
                            sequence = sorted_design[factor].map(mapping).tolist()
                        else:
                            sequence = sorted_design[factor].tolist()
                        
                        # Calculate lag-1 autocorrelation
                        mean_val = np.mean(sequence)
                        numerator = sum([(sequence[i] - mean_val) * (sequence[i+1] - mean_val) 
                                        for i in range(len(sequence)-1)])
                        denominator = sum([(val - mean_val)**2 for val in sequence])
                        
                        autocorr = numerator / denominator if denominator > 0 else 0
                        autocorr_results[f"{factor}_design{design_idx}_autocorr"] = autocorr
                    except Exception as e:
                        print(f"Warning: Error in autocorrelation check for {factor}, design {design_idx}: {str(e)}")
                        autocorr_results[f"{factor}_design{design_idx}_autocorr"] = "Error"
            
            validation_results['doegen_validation']['autocorrelation'] = autocorr_results
            
            # Plot factor distributions for visual validation
            print("Creating validation plots...")
            try:
                self._plot_factor_distributions(all_designs)
            except Exception as e:
                print(f"Warning: Error creating validation plots: {str(e)}")
            
            # Save detailed validation report to JSON
            validation_report_path = os.path.join(self.output_dir, 'validation_report.json')
            try:
                with open(validation_report_path, 'w') as f:
                    json.dump(validation_results, f, indent=4, default=str)
                print(f"Validation report saved to: {validation_report_path}")
            except Exception as e:
                print(f"Warning: Error saving validation report: {str(e)}")
            
            # Print summary of validation results
            print("\nValidation Results:")
            for factor, results in validation_results.items():
                if factor != 'doegen_validation':
                    print(f"\n{factor.replace('_', ' ').title()} Distribution:")
                    for key, value in sorted(results.items()):
                        print(f"  {key}: {value}")
            
            # Print Doegen validation summary
            print("\nDoegen Validation Checks:")
            for check_type, value in validation_results['doegen_validation'].items():
                if isinstance(value, dict):
                    print(f"  {check_type.replace('_', ' ').title()}:")
                    for subkey, subval in list(value.items())[:5]:  # Show first 5 items
                        print(f"    {subkey}: {subval}")
                    if len(value) > 5:
                        print(f"    ... and {len(value)-5} more (see validation_report.json)")
                else:
                    print(f"  {check_type.replace('_', ' ').title()}: {value}")
            
        except Exception as e:
            print(f"Error during validation: {str(e)}")
            print("Validation could not be completed. Returning partial results.")
        
        print("Validation complete.")
        return validation_results
    
    def _plot_factor_distributions(self, all_designs):
        """
        Create comprehensive visualizations of factor distributions for quality control,
        including Doegen-inspired validation plots.
        
        Args:
            all_designs: Combined DataFrame of all participant designs
        """
        plot_dir = os.path.join(self.output_dir, 'validation_plots')
        os.makedirs(plot_dir, exist_ok=True)
        
        # ------- Basic Distribution Plots -------
        # Create bar charts for factor distributions
        factors_to_plot = ['stimulus_type', 'soa_condition_idx']
        for factor in factors_to_plot:
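            # Create the figure explicitly and pass its axes to pandas; otherwise
            # DataFrame.plot opens a second figure and the sized one is never closed
            fig, ax = plt.subplots(figsize=(10, 6))
            grouped = all_designs.groupby(['trial_type', factor]).size().unstack(fill_value=0)
            grouped.plot(kind='bar', stacked=True, ax=ax)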
            plt.title(f'Distribution of {factor} Across Trial Types')
            plt.xlabel('Trial Type')
            plt.ylabel('Number of Occurrences')
            plt.legend(title=factor, bbox_to_anchor=(1.05, 1), loc='upper left')
            plt.tight_layout()
            plot_path = os.path.join(plot_dir, f'{factor}_distribution.png')
            plt.savefig(plot_path)
            plt.close()
        
        # Create heatmap for stimulus type distribution
        plt.figure(figsize=(12, 8))
        factor_pivot = all_designs.pivot_table(
            index='trial_type', 
            columns='stimulus_type', 
            aggfunc='size', 
            fill_value=0
        )
        sns.heatmap(factor_pivot, annot=True, cmap='YlGnBu', fmt='g')
        plt.title('Stimulus Type Distribution Across Trial Types')
        plt.tight_layout()
        heatmap_path = os.path.join(plot_dir, 'stimulus_type_heatmap.png')
        plt.savefig(heatmap_path)
        plt.close()
        
        # ------- Doegen-inspired Validation Plots -------
        
        # 1. Factor pairwise density plots (check for factor independence)
        factor_pairs = list(combinations(['stimulus_type', 'trial_type', 'soa_condition_idx'], 2))
        for f1, f2 in factor_pairs:
            plt.figure(figsize=(10, 8))
            
            # For categorical data, create a contingency table instead of 2D histogram
            contingency = pd.crosstab(all_designs[f1], all_designs[f2])
            
            # Plot as a heatmap
            sns.heatmap(contingency, annot=True, fmt='d', cmap='Blues')
            
            plt.xlabel(f2.replace('_', ' ').title())  # crosstab columns (f2) lie on the x-axis
            plt.ylabel(f1.replace('_', ' ').title())  # crosstab index (f1) lies on the y-axis
            plt.title(f'Factor Independence Check: {f1} vs {f2}')
            plt.tight_layout()
            
            plot_path = os.path.join(plot_dir, f'factor_independence_{f1}_vs_{f2}.png')
            plt.savefig(plot_path)
            plt.close()
        
        # 2. Transition matrix heatmaps (first-order counterbalancing)
        for factor in ['stimulus_type', 'soa_condition_idx']:
            try:
                # Get unique values in sorted order for consistent matrix
                unique_vals = sorted(all_designs[factor].unique())
                n_vals = len(unique_vals)
                
                # Create transition matrix
                transitions = np.zeros((n_vals, n_vals))
                
                # Count transitions within the first 100 trials (in time order) of each design
                sample_designs = (all_designs.sort_values('timestamp_after_jitter')
                                  .groupby('participant_id').head(100))  # Limit sample size
                
                for pid, group in sample_designs.groupby('participant_id'):
                    sorted_group = group.sort_values('timestamp_after_jitter')
                    
                    # Map each value to its row/column index in the matrix
                    mapping = {val: i for i, val in enumerate(unique_vals)}
                    sequence = sorted_group[factor].map(mapping).tolist()
                    
                    # Count transitions
                    for i in range(len(sequence) - 1):
                        from_idx = sequence[i]
                        to_idx = sequence[i + 1]
                        transitions[from_idx, to_idx] += 1
                
                # Plot transition matrix
                plt.figure(figsize=(10, 8))
                sns.heatmap(transitions, annot=True, fmt='.0f', cmap='YlGnBu',
                           xticklabels=unique_vals, yticklabels=unique_vals)
                plt.title(f'Transition Matrix for {factor}')
                plt.xlabel('To')
                plt.ylabel('From')
                plt.tight_layout()
                
                plot_path = os.path.join(plot_dir, f'transition_matrix_{factor}.png')
                plt.savefig(plot_path)
                plt.close()
            except Exception as e:
                print(f"Warning: Could not create transition matrix for {factor}: {str(e)}")
                continue
        
        # 3. Factor autocorrelation plots
        for factor in ['soa_condition_idx']:  # Only numeric factors
            try:
                # Compute autocorrelations for up to 5 participants with enough trials
                autocorrs = []
                participant_labels = []
                lags = list(range(1, 11))  # Lags 1 to 10
                
                for pid, design in all_designs.groupby('participant_id'):
                    if len(autocorrs) >= 5:  # Limit to 5 participants
                        break
                        
                    sorted_design = design.sort_values('timestamp_after_jitter')
                    sequence = sorted_design[factor].astype(float).to_numpy()
                    
                    # Skip if too few data points
                    if len(sequence) < 20:
                        continue
                    
                    # Center once; the denominator is the same for every lag
                    centered = sequence - sequence.mean()
                    denominator = np.sum(centered ** 2)
                    
                    # Calculate autocorrelations for different lags
                    design_autocorrs = []
                    for lag in lags:
                        if lag >= len(sequence):
                            design_autocorrs.append(np.nan)
                            continue
                        numerator = np.sum(centered[:-lag] * centered[lag:])
                        autocorr = numerator / denominator if denominator > 0 else 0
                        design_autocorrs.append(autocorr)
                    
                    autocorrs.append(design_autocorrs)
                    participant_labels.append(pid)
                
                # Skip if no autocorrelations were calculated
                if not autocorrs:
                    continue
                    
                # Plot autocorrelations, labelled with the actual participant IDs
                plt.figure(figsize=(10, 6))
                for pid, autocorr in zip(participant_labels, autocorrs):
                    plt.plot(lags, autocorr, marker='o', label=f'Participant {pid}')
                
                plt.axhline(y=0, color='r', linestyle='--', alpha=0.3)
                plt.xlabel('Lag')
                plt.ylabel('Autocorrelation')
                plt.title(f'Sequence Autocorrelation for {factor}')
                plt.legend()
                plt.grid(True, alpha=0.3)
                plt.tight_layout()
                
                plot_path = os.path.join(plot_dir, f'autocorrelation_{factor}.png')
                plt.savefig(plot_path)
                plt.close()
            except Exception as e:
                print(f"Warning: Could not create autocorrelation plot for {factor}: {str(e)}")
                continue
        
        # 4. SOA balance check across trial types
        try:
            plt.figure(figsize=(12, 6))
            
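            # Crosstab on string SOA labels to avoid issues with pandas pivot
            soa_labels = all_designs['soa_value_ms'].astype(str)
            soa_balance = pd.crosstab(index=all_designs['trial_type'], columns=soa_labels)
            
            # Order columns numerically; as strings, '1000' would sort before '190'
            numeric_order = pd.to_numeric(soa_balance.columns.to_series(), errors='coerce').to_numpy()
            soa_balance = soa_balance.iloc[:, np.argsort(numeric_order, kind='stable')]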
            
            # Plot SOA balance
            sns.heatmap(soa_balance, annot=True, fmt='g', cmap='YlGnBu')
            plt.title('SOA Values Across Trial Types')
            plt.xlabel('SOA Value (ms)')
            plt.ylabel('Trial Type')
            plt.tight_layout()
            
            plot_path = os.path.join(plot_dir, 'soa_balance.png')
            plt.savefig(plot_path)
            plt.close()
        except Exception as e:
            print(f"Warning: Could not create SOA balance plot: {str(e)}")
            
        print(f"\nValidation plots saved to: {plot_dir}")


# ======================================================================
# AUDIO GENERATOR
# ======================================================================

class ParticipantAudioGenerator:
    """
    Generates TWO participant-specific audio files with embedded stimuli:
    1. Looming audio: Box breathing with directional looming sounds
    2. Tactile audio: Synchronized tactile stimuli with SOA offsets
    
    This class takes the counterbalanced designs and injects stimuli
    into audio files, creating unique pairs of audio for each participant
    with precisely timed stimuli.
    """
    
    def __init__(self, base_audio_path, output_dir, looming_stimuli_paths, tactile_stimulus_path, timestamps_csv_path=None):
        """
        Initialize the audio generator.
        
        Args:
            base_audio_path: Path to the base box breathing audio
            output_dir: Directory to save participant-specific audio files
            looming_stimuli_paths: Dictionary of paths to looming stimuli sounds
            tactile_stimulus_path: Path to the tactile stimulus sound
            timestamps_csv_path: Optional path to CSV with breathing timestamps (not needed when using design CSV)
        """
        # Store stimulus paths
        self.looming_stimuli = looming_stimuli_paths
        self.tactile_stimulus_path = tactile_stimulus_path
        
        # Load base audio (breathing exercise)
        self.base_audio_data, self.sample_rate = sf.read(base_audio_path)
        
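        # Load tactile stimulus; it must share the base audio's sample rate,
        # otherwise its samples would play back at the wrong speed
        self.tactile_data, tactile_rate = sf.read(tactile_stimulus_path)
        if tactile_rate != self.sample_rate:
            raise ValueError(
                f"Tactile stimulus sample rate ({tactile_rate} Hz) does not match "
                f"base audio ({self.sample_rate} Hz)"
            )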
        
        # Load timestamps if provided (not needed when using design CSV directly)
        self.timestamps_df = None
        if timestamps_csv_path:
            self.timestamps_df = pd.read_csv(timestamps_csv_path)
        
        # Output directory for audio files
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
    
    def generate_participant_audio(self, participant_id, design_df):
        """
        Generate two synchronized audio files for a participant:
        1. Looming stimuli embedded in box breathing audio (stereo)
        2. Tactile stimuli with SOA adjustments (mono content in stereo format)
        
        Uses precise millisecond values from design_df to ensure exact synchronization.
        
        Args:
            participant_id: Participant identifier
            design_df: DataFrame with stimuli design (MUST contain stimulus_type, jittered_ms, soa_ms, and soa_value_ms columns)
                
        Returns:
            Tuple (looming_audio, tactile_audio, injection_log)
        """
        # Create two separate audio streams
        looming_audio = self.base_audio_data.copy()  # Box breathing with looming sounds (stereo)
        tactile_audio = np.zeros_like(self.base_audio_data)  # Empty stereo track matching base audio shape
        
        # Ensure tactile stimulus is in stereo format (same content in both channels)
        # This converts mono to stereo if needed
        if self.tactile_data.ndim == 1:  # If tactile data is mono
            tactile_stereo = np.column_stack((self.tactile_data, self.tactile_data))
        else:
            # If already stereo, use as is
            tactile_stereo = self.tactile_data
            
        # Initialize log to track all stimulus injections
        injection_log = []
        
        # Verify that design_df contains required columns
        required_columns = ['stimulus_type', 'jittered_ms', 'soa_ms', 'soa_value_ms']
        missing_columns = [col for col in required_columns if col not in design_df.columns]
        if missing_columns:
            raise ValueError(f"Design DataFrame missing required columns: {missing_columns}. Cannot ensure precise synchronization.")
        
        # Process each trial in the design
        for _, row in design_df.iterrows():
            # Get exact millisecond timestamps from the design CSV
            jittered_ms = row['jittered_ms']  # For looming stimuli
            soa_ms = row['soa_ms']  # For tactile stimuli
            
            # Convert to sample indices (round to the nearest sample rather than truncating)
            looming_injection_point = int(round(jittered_ms / 1000 * self.sample_rate))
            tactile_injection_point = int(round(soa_ms / 1000 * self.sample_rate))
            
            # Get SOA value for logging
            soa_value_ms = row['soa_value_ms']
            
            # Load the appropriate looming stimulus based on design
            stim_type = row['stimulus_type']
            looming_path = self.looming_stimuli.get(stim_type)
            if looming_path is None:
                continue  # Skip if stimulus type not found
            
            # Read the looming stimulus audio file
            looming_data, _ = sf.read(looming_path)
            
            # Check that both stimuli start inside the track and end within its bounds
            looming_fits = (looming_injection_point >= 0 and
                            looming_injection_point + len(looming_data) <= len(looming_audio))
            tactile_fits = (tactile_injection_point >= 0 and 
                           tactile_injection_point + len(tactile_stereo) <= len(tactile_audio))
            
            if looming_fits and tactile_fits:
                # Inject looming stimulus into the first audio track
                looming_audio[looming_injection_point:looming_injection_point+len(looming_data)] += looming_data
                
                # Inject tactile stimulus into the second audio track (now using stereo format)
                tactile_audio[tactile_injection_point:tactile_injection_point+len(tactile_stereo)] += tactile_stereo
                
                # Log details of this injection for record-keeping
                injection_log.append({
                    'participant_id': participant_id,
                    'trial_number': row.get('trial_number', 0),
                    'stimulus_type': stim_type,
                    'soa_condition_idx': row.get('soa_condition_idx', 0),
                    'soa_value_ms': soa_value_ms,
                    'looming_injection_sample': looming_injection_point,
                    'tactile_injection_sample': tactile_injection_point,
                    'looming_time_ms': looming_injection_point / self.sample_rate * 1000,
                    'tactile_time_ms': tactile_injection_point / self.sample_rate * 1000,
                    'jittered_ms': jittered_ms,
                    'soa_ms': soa_ms
                })
        
        return looming_audio, tactile_audio, pd.DataFrame(injection_log)
    
    def save_participant_files(self, participant_id, looming_audio, tactile_audio, injection_log, design_filename):
        """
        Save participant-specific audio and log files.
        
        Args:
            participant_id: Participant identifier
            looming_audio: Modified audio with looming stimuli
            tactile_audio: Audio track with tactile stimuli
            injection_log: DataFrame with injection details
            design_filename: Base design filename for naming consistency
        
        Returns:
            Tuple of saved file paths
        """
        # Use the design CSV base name (remove .csv extension)
        base_name = os.path.splitext(design_filename)[0]
        
        # Define filenames for output files
        looming_filename = f"{base_name}_looming.wav"
        tactile_filename = f"{base_name}_tactile.wav"
        log_filename = f"{base_name}_stimuli_log.csv"
        
        # Create full paths
        looming_path = os.path.join(self.output_dir, looming_filename)
        tactile_path = os.path.join(self.output_dir, tactile_filename)
        log_path = os.path.join(self.output_dir, log_filename)
        
        # Save the audio files
        sf.write(looming_path, looming_audio, self.sample_rate)
        sf.write(tactile_path, tactile_audio, self.sample_rate)
        
        # Save the detailed injection log (for analysis)
        injection_log.to_csv(log_path, index=False)
        
        print(f"Looming audio saved: {looming_filename}")
        print(f"Tactile audio saved: {tactile_filename}")
        print(f"Injection log saved: {log_filename}")
        
        return looming_path, tactile_path, log_path


# ======================================================================
# MAIN FUNCTION
# ======================================================================

def main():
    """
    Main function to run the complete PPS experiment setup process:
    1. Generate counterbalanced designs for multiple participants
    2. Create two synchronized audio files per participant:
       - Looming stimuli embedded in box breathing audio
       - Tactile stimuli with SOA adjustments
    
    The script creates and saves:
    1. Design CSV files in the DESIGN_OUTPUT_DIR directory, including:
       - exact SOA values in milliseconds
       - precise timestamps with jitter and SOA adjustments
    2. Two WAV files per participant in the AUDIO_OUTPUT_DIR directory:
       - participant_X_design_looming.wav (box breathing + looming stimuli)
       - participant_X_design_tactile.wav (synchronized tactile stimuli)
    3. Log files with detailed timing information
    
    CRITICAL: This implementation ensures precise synchronization between the audio files
    by directly referencing timestamps from the design CSV file rather than recalculating them.
    """
    # ===== Ensure output directories exist =====
    os.makedirs(DESIGN_OUTPUT_DIR, exist_ok=True)
    os.makedirs(AUDIO_OUTPUT_DIR, exist_ok=True)
    
    # ===== Step 1: Generate counterbalanced designs with exact SOA values =====
    print("===== STEP 1: GENERATING COUNTERBALANCED DESIGNS =====")
    print(f"Using SOA conditions (ms): {SOA_CONDITIONS_MS}")
    print("Each design CSV will contain exact SOA values and precise timestamps")
    
    counterbalancer = ComprehensiveCounterbalancer(
        output_dir=DESIGN_OUTPUT_DIR,
        soa_conditions=SOA_CONDITIONS_MS,
        trial_counts=TRIAL_COUNTS,
        jitter_options=JITTER_OPTIONS_MS
    )
    
    designs = counterbalancer.generate_comprehensive_design(
        num_participants=NUM_PARTICIPANTS, 
        reference_csv_path=TIMESTAMPS_CSV_PATH
    )
    
    # Save designs to CSV files - these now include exact SOA values and timestamp_with_soa
    counterbalancer.save_designs(designs)
    
    # Run comprehensive Doegen validation checks
    print("\n===== RUNNING DOEGEN COUNTERBALANCING VALIDATION CHECKS =====")
    print("Performing comprehensive validation including:")
    print("  - First-order counterbalancing (transition matrices)")
    print("  - Factor orthogonality/independence")
    print("  - Sequence randomness (runs test)")
    print("  - Autocorrelation analysis")
    print("  - SOA balance verification")
    
    try:
        validation_results = counterbalancer.validate_designs(designs)
        
        print("\n===== DOEGEN VALIDATION SUMMARY =====")
        if 'doegen_validation' in validation_results:
            # Check transition balance
            if 'stimulus_type_transition_cv' in validation_results['doegen_validation']:
                cv = validation_results['doegen_validation']['stimulus_type_transition_cv']
                if isinstance(cv, str):
                    print(f"⚠ Could not compute stimulus transition CV: {cv}")
                elif cv < 0.2:
                    print("✓ Stimulus transitions are well-balanced (CV < 0.2)")
                else:
                    print(f"⚠ Stimulus transitions show some imbalance (CV = {cv})")
                    
            # Check orthogonality
            if 'orthogonality' in validation_results['doegen_validation']:
                ortho = validation_results['doegen_validation']['orthogonality']
                non_error_values = [v for v in ortho.values() if not isinstance(v, str)]
                if non_error_values:
                    mean_chi2 = np.mean(non_error_values)
                    if mean_chi2 < 20:
                        print("✓ Factors are sufficiently orthogonal/independent")
                    else:
                        print(f"⚠ Some factors may have dependencies (mean χ² = {mean_chi2:.2f})")
                    
            # Check sequence randomness
            if 'sequence_randomness' in validation_results['doegen_validation']:
                random_vals = [v for v in validation_results['doegen_validation']['sequence_randomness'].values() 
                              if not isinstance(v, str)]
                if random_vals:
                    mean_abs_z = np.mean([abs(z) for z in random_vals])
                    if mean_abs_z < 1.96:
                        print("✓ Sequences appear sufficiently random (|z| < 1.96)")
                    else:
                        print(f"⚠ Sequences may have randomness issues (mean |z| = {mean_abs_z:.2f})")
                    
            # Check autocorrelation
            if 'autocorrelation' in validation_results['doegen_validation']:
                autocorr_vals = [v for v in validation_results['doegen_validation']['autocorrelation'].values()
                                if not isinstance(v, str)]
                if autocorr_vals:
                    mean_autocorr = np.mean([abs(a) for a in autocorr_vals])
                    if mean_autocorr < 0.2:
                        print("✓ Low sequential autocorrelation (< 0.2)")
                    else:
                        print(f"⚠ Some sequential dependencies detected (mean |autocorr| = {mean_autocorr:.2f})")
    
    except Exception as e:
        print(f"Warning: Validation checks encountered an error: {str(e)}")
        print("Continuing with design generation...")
    
    print("\nSee validation_plots directory for detailed visual analyses")
    
    # ===== Step 2: Generate participant-specific audio files =====
    print("\n===== STEP 2: GENERATING PARTICIPANT AUDIO FILES =====")
    print("Creating TWO audio files per participant:")
    print("  - Looming audio: Box breathing with directional looming sounds")
    print("  - Tactile audio: Synchronized tactile stimuli with SOA offsets")
    print("CRITICAL: Using exact millisecond timestamps from design CSV for precise synchronization")
    
    # Initialize the audio generator; timestamps_csv_path is optional because
    # all timing is read from the design CSV
    audio_generator = ParticipantAudioGenerator(
        base_audio_path=BASE_AUDIO_PATH,
        output_dir=AUDIO_OUTPUT_DIR,
        looming_stimuli_paths=LOOMING_STIMULI,
        tactile_stimulus_path=TACTILE_STIMULUS_PATH,
        timestamps_csv_path=TIMESTAMPS_CSV_PATH  # Optional
    )
    
    # Process each design and generate audio pairs
    for i, design_df in enumerate(designs):
        participant_id = design_df['participant_id'].iloc[0]
        design_filename = f"participant_{participant_id}_design.csv"
        
        print(f"\nProcessing participant {participant_id} ({i+1}/{len(designs)})...")
        
        # Verify that design_df contains necessary timing columns
        if 'jittered_ms' not in design_df.columns or 'soa_ms' not in design_df.columns:
            print(f"Warning: Design for participant {participant_id} missing critical timing columns.")
            print("Reading design from saved CSV file to ensure correct timing...")
            
            # Re-read the design CSV to ensure we have the most accurate timing info
            design_file_path = os.path.join(DESIGN_OUTPUT_DIR, design_filename)
            if os.path.exists(design_file_path):
                design_df = pd.read_csv(design_file_path)
                print("Successfully loaded design from CSV with all timing information.")
            else:
                print(f"Error: Cannot find design file {design_file_path}")
                continue
        
        # Generate two synchronized audio tracks for this participant
        try:
            looming_audio, tactile_audio, injection_log = audio_generator.generate_participant_audio(
                participant_id, 
                design_df
            )
            
            # Save both audio files and the log
            audio_generator.save_participant_files(
                participant_id, 
                looming_audio, 
                tactile_audio, 
                injection_log, 
                design_filename
            )
            
            # Verify synchronization
            print("Verifying synchronization...")
            if len(injection_log) > 0:
                # Use a separate counter so the outer loop variable `i` is not shadowed
                for trial_num, (_, row) in enumerate(injection_log.head(3).iterrows(), start=1):
                    looming_time = row['looming_time_ms']
                    tactile_time = row['tactile_time_ms']
                    expected_diff = row['soa_value_ms']
                    actual_diff = tactile_time - looming_time
                    
                    print(f"  Trial {trial_num}: SOA expected={expected_diff}ms, actual={actual_diff:.2f}ms, " +
                          f"difference={abs(actual_diff-expected_diff):.4f}ms")
                
                if len(injection_log) > 3:
                    print(f"  ... and {len(injection_log) - 3} more trials")
            else:
                print("  Warning: No stimuli were injected - check design parameters")
                
        except Exception as e:
            print(f"Error generating audio for participant {participant_id}: {str(e)}")
            continue
    
    print("\nSUMMARY:")
    print(f"All designs generated and saved to: {DESIGN_OUTPUT_DIR}")
    print(f"All audio files generated and saved to: {AUDIO_OUTPUT_DIR}")
    print(f"Created {NUM_PARTICIPANTS} participant designs with {len(SOA_CONDITIONS_MS)} SOA conditions")
    print("Each participant has TWO audio files: one with looming stimuli and one with tactile stimuli")
    print(f"Exact SOA values used (ms): {SOA_CONDITIONS_MS}")
    print("Every design CSV contains precise millisecond timing for exact synchronization")
    print("Process complete.")

# ===== For Jupyter notebook use =====
# This cell can be run directly in a Jupyter notebook
if __name__ == "__main__":
    main()

===== STEP 1: GENERATING COUNTERBALANCED DESIGNS =====
Using SOA conditions (ms): [190, 400, 700, 1000, 1500]
Each design CSV will contain exact SOA values and precise timestamps
Saved design for participant 0
Saved design for participant 1
Saved design for participant 2
Saved design for participant 3
Saved design for participant 4
Saved design for participant 5
Saved design for participant 6
Saved design for participant 7
Saved design for participant 8
Saved design for participant 9
Saved design for participant 10
Saved design for participant 11
Saved design for participant 12
Saved design for participant 13
Saved design for participant 14
Saved design for participant 15
Saved design for participant 16
Saved design for participant 17
Saved design for participant 18
Saved design for participant 19
Saved design for participant 20
Saved design for participant 21
Saved design for participant 22
Saved design for participant 23
Saved design for participant 24
Saved design for participant 25
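
The transition-matrix and autocorrelation checks in the validation plots above can be computed without Python-level loops. The sketch below is a vectorized equivalent, assuming each sequence is already sorted by `timestamp_after_jitter` and, for transitions, encoded as integer indices `0..n-1` (for example with the same `{val: i for i, val in enumerate(unique_vals)}` mapping the plotting code builds). The helper names `transition_counts` and `lag_autocorrelation` are hypothetical and not part of the script.

```python
import numpy as np


def transition_counts(sequence, n_states):
    """First-order transition counts: entry [a, b] = times state a is followed by b.

    Hypothetical helper; expects integer-coded states in 0..n_states-1.
    """
    seq = np.asarray(sequence, dtype=int)
    counts = np.zeros((n_states, n_states), dtype=int)
    # np.add.at accumulates correctly even when the same (from, to) pair repeats
    np.add.at(counts, (seq[:-1], seq[1:]), 1)
    return counts


def lag_autocorrelation(sequence, lag):
    """Lag-k autocorrelation normalized by the full-sequence sum of squares,
    the same formula the autocorrelation plots use (hypothetical helper)."""
    if lag < 1:
        raise ValueError("lag must be >= 1")
    x = np.asarray(sequence, dtype=float)
    if lag >= len(x):
        return np.nan
    centered = x - x.mean()
    denominator = np.sum(centered ** 2)
    if denominator == 0:
        return 0.0  # constant sequence: no variation to correlate
    return float(np.sum(centered[:-lag] * centered[lag:]) / denominator)
```

`lag_autocorrelation` divides by the full-sequence sum of squares, as the plotting code does, so estimates at larger lags are pulled slightly toward zero compared with a Pearson correlation computed on the overlapping window only.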

# PPS Practice Stimulus Generator - Concise Outline

This code generates short practice stimuli for a Peripersonal Space (PPS) experiment with these key features:

### What it Creates:
- 5 different randomized versions of practice stimuli
- Each version includes two synchronized audio files:
  1. A file containing looming sounds (approaching from left/front/right)
  2. A file containing tactile stimuli (vibrations)
- CSV log files documenting precise stimulus timing
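
Both practice files start as silent stereo buffers of identical length, so a given sample index refers to the same moment in either file; each stimulus is then added at its onset sample. A minimal sketch of that step, assuming float stereo buffers and onsets already converted to sample indices (`to_stereo` and `inject` are hypothetical helpers that follow the same mono-to-stereo conversion and bounds check used in `ParticipantAudioGenerator.generate_participant_audio`):

```python
import numpy as np


def to_stereo(clip):
    """Duplicate a mono clip into two identical channels; pass stereo through unchanged."""
    clip = np.asarray(clip, dtype=float)
    return np.column_stack((clip, clip)) if clip.ndim == 1 else clip


def inject(track, clip, start_sample):
    """Add `clip` into the stereo `track` in place, starting at `start_sample`.

    Returns True if the clip fit entirely inside the track; otherwise returns
    False and leaves the track untouched (hypothetical helper).
    """
    clip = to_stereo(clip)
    end = start_sample + len(clip)
    if start_sample < 0 or end > len(track):
        return False
    track[start_sample:end] += clip
    return True


# Usage: looming at onset, tactile at onset + SOA, in two aligned silent tracks
looming_track = np.zeros((100, 2))
tactile_track = np.zeros((100, 2))
inject(looming_track, np.ones(5), 10)       # looming onset at sample 10
inject(tactile_track, np.ones(3), 10 + 25)  # tactile onset 25 samples later
```

Returning `False` rather than raising lets the caller skip a trial that would run past the end of a track, which mirrors how `generate_participant_audio` silently skips stimuli that do not fit.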

### Content Structure:
- Each version contains 12 trials with 4-second intervals between them:
  - 4 baseline trials (tactile vibration only)
  - 4 catch trials (looming sound only) 
  - 4 PPS trials (both looming and tactile with random SOA)
- Different randomized order in each version
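
The script shuffles the 12 trial types and, separately, a list of 12 directions, so one direction can end up concentrated in a single trial type (for example, three of the four catch trials coming from the left). Only catch and PPS trials play a looming sound, so their direction mix is what participants actually hear. A sketch of an alternative that keeps a per-version seed but spreads directions within each trial type, assuming a local `numpy.random.Generator` is acceptable in place of the global seed (`practice_trial_order` is a hypothetical helper):

```python
import numpy as np

DIRECTIONS = ['left', 'front', 'right']
TRIAL_TYPES = ['baseline', 'catch', 'pps']


def practice_trial_order(version, n_per_type=4, seed_base=42):
    """Return a shuffled list of (trial_type, direction) pairs for one practice version.

    Each trial type receives every direction at least once, and the leftover
    trial of each type goes to a different direction, so with 4 trials per
    type each direction appears exactly 4 times overall.
    """
    # Local generator seeded like the script's np.random.seed(42 + iteration_num)
    rng = np.random.default_rng(seed_base + version)
    # Randomize which direction receives the extra slot in each trial type
    directions = [str(d) for d in rng.permutation(DIRECTIONS)]
    trials = []
    for offset, trial_type in enumerate(TRIAL_TYPES):
        # Rotate the direction list by `offset` so the leftover trials differ by type
        type_dirs = [directions[(offset + i) % len(directions)] for i in range(n_per_type)]
        trials.extend((trial_type, d) for d in type_dirs)
    # Shuffle the complete list so trial types and directions are interleaved
    order = rng.permutation(len(trials))
    return [trials[i] for i in order]
```

If swapped in, any later `np.random` draws in the iteration would no longer be seeded by `np.random.seed(42 + iteration_num)`, so they should move to the same generator to keep each version reproducible.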

### Technical Approach:
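- Creates short audio files (roughly 1-2 minutes, depending on stimulus length)
- Keeps each looming/tactile pair sample-aligned: both tracks are built with the same length and sample rate, so a tactile onset placed at looming onset + SOA plays at the intended time
- Uses a random SOA (Stimulus Onset Asynchrony) between 100 ms and 1500 ms (SOA_RANGE) for PPS trials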
- Includes trial timing information in CSV logs for verification
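
The timing rules above reduce to a few lines of arithmetic. A sketch, assuming the script's defaults (44.1 kHz when the base audio cannot be read, 2-second start and end buffers, a 4-second interval per trial slot) and an SOA drawn uniformly in whole milliseconds from SOA_RANGE (an assumption about how the random SOA is sampled); all function names are hypothetical:

```python
import numpy as np

DEFAULT_SAMPLE_RATE = 44100  # the script's fallback when the base audio cannot be read


def ms_to_samples(ms, sample_rate=DEFAULT_SAMPLE_RATE):
    """Convert a time in milliseconds to the nearest sample index."""
    return int(round(ms / 1000 * sample_rate))


def practice_length_samples(n_trials, max_stim_samples, sample_rate=DEFAULT_SAMPLE_RATE,
                            interval_s=4, buffer_s=2):
    """Total length: start buffer + one (interval + longest stimulus) slot per trial + end buffer."""
    return (buffer_s * sample_rate
            + n_trials * (interval_s * sample_rate + max_stim_samples)
            + buffer_s * sample_rate)


def draw_soa_ms(rng, soa_range=(100, 1500)):
    """Draw a whole-millisecond SOA uniformly from [low, high] (assumed sampling scheme)."""
    low, high = soa_range
    # Generator.integers excludes the upper bound, hence high + 1
    return int(rng.integers(low, high + 1))
```

For example, with a longest stimulus of 1 s, a practice file lasts 2 + 12 × (4 + 1) + 2 = 64 s. Rounding to the nearest sample keeps every onset within half a sample (about 0.011 ms at 44.1 kHz) of its nominal time, whereas `int()` truncation can land one sample early when floating-point error leaves a value just below an integer.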

The program outputs all files to the specified directory, ready for use in participant practice/training sessions.

# ======================================================================
# PRACTICE STIMULI GENERATOR FOR PPS EXPERIMENT
# ======================================================================
# 
# This script generates 5 different versions of practice stimuli for the 
# Peripersonal Space (PPS) experiment. Each version contains:
#
# 1. 12 trials divided into:
#    - 4 baseline trials (tactile stimuli only)
#    - 4 catch trials (looming stimuli only)
#    - 4 PPS trials (both looming and tactile with random SOA)
#
# 2. Two synchronized audio files:
#    - One for looming stimuli
#    - One for tactile stimuli
#
# Each of the 5 versions has a different ordering of these trials.
#
# PRACTICE STIMULI PURPOSE:
# These practice files will help familiarize participants with the different 
# trial types before the actual experiment starts.

import os
import numpy as np
import pandas as pd
import soundfile as sf
import matplotlib.pyplot as plt
from datetime import datetime
import random
import warnings

# Suppress warnings for cleaner output
warnings.filterwarnings("ignore")

# ======================================================================
# USER-CONFIGURABLE PARAMETERS
# ======================================================================

# File paths - UPDATE THESE TO MATCH YOUR SYSTEM
BASE_DIR = r"C:\Users\cogpsy-vrlab\Documents\PPS_module\BreathingPilot"

# Box breathing audio
BASE_AUDIO_PATH = os.path.join(BASE_DIR, "AudioInstructionsBoxBreathing", "TargetStimuli", 
                              "6. Box Breathing Combined (Segment Injected).wav")

# Timestamp reference for breathing phases
TIMESTAMPS_CSV_PATH = os.path.join(BASE_DIR, "AudioInstructionsBoxBreathing", "TargetStimuli", 
                                  "BoxBreathing_extended_timestamps.csv")

# Output directory for practice stimuli
PRACTICE_STIMULI_DIR = os.path.join(BASE_DIR, "PracticeStimuli")

# Looming stimuli (directional audio)
LOOMING_STIMULI = {
    'right': os.path.join(BASE_DIR, "PPS_Experiment_Module", "right_az90_FABIAN_HRIR_natural.wav"),
    'front': os.path.join(BASE_DIR, "PPS_Experiment_Module", "front_az0_FABIAN_HRIR_natural.wav"),
    'left': os.path.join(BASE_DIR, "PPS_Experiment_Module", "left_az-90_FABIAN_HRIR_natural.wav")
}

# Tactile stimulus (to be synchronized with looming stimuli + SOA)
TACTILE_STIMULUS_PATH = os.path.join(BASE_DIR, "PPS_Experiment_Module", "tactile_stimulus.wav")

# Range of SOA values (in milliseconds) for PPS trials
SOA_RANGE = [100, 1500]  # Random SOA between 100ms and 1500ms

# Jitter options in milliseconds (temporal variability)
JITTER_OPTIONS_MS = [100, 200, 300, 400, 500]

# Number of iterations (different versions)
NUM_ITERATIONS = 5

# ======================================================================
# PRACTICE STIMULI GENERATOR
# ======================================================================

class PracticeStimuli:
    """
    Generates practice stimuli for the PPS experiment with balanced trial types:
    - Baseline trials (tactile only)
    - Catch trials (looming only)
    - PPS trials (both looming and tactile with random SOA)
    """
    
    def __init__(self, output_dir, looming_stimuli_paths, tactile_stimulus_path, base_audio_path=None, timestamps_csv_path=None):
        """
        Initialize the practice stimuli generator
        
        Args:
            output_dir: Directory to save practice stimuli
            looming_stimuli_paths: Dictionary of paths to looming stimuli sounds
            tactile_stimulus_path: Path to the tactile stimulus sound
            base_audio_path: Path to an audio file to get sample rate (not used for content)
            timestamps_csv_path: Not used in this simplified version
        """
        # Create output directory
        self.output_dir = output_dir
        os.makedirs(output_dir, exist_ok=True)
        
        # Store stimulus paths
        self.looming_stimuli = looming_stimuli_paths
        self.tactile_stimulus_path = tactile_stimulus_path
        
        # Default sample rate, replaced by the base audio's rate if that file can be read
        self.sample_rate = 44100
        
        # Read only the file header to get the sample rate (not the full audio)
        if base_audio_path:
            try:
                self.sample_rate = sf.info(base_audio_path).samplerate
                print(f"Using sample rate from audio file: {self.sample_rate} Hz")
            except Exception as e:
                print(f"Could not load audio for sample rate: {e}")
                print(f"Using default sample rate: {self.sample_rate} Hz")
        
        # Load tactile stimulus and warn if its sample rate differs from the output
        try:
            self.tactile_data, tactile_rate = sf.read(tactile_stimulus_path)
            print(f"Tactile stimulus loaded: {tactile_stimulus_path}")
        except Exception as e:
            print(f"Error loading tactile stimulus: {e}")
            raise
        if tactile_rate != self.sample_rate:
            print(f"Warning: tactile stimulus is {tactile_rate} Hz but output is {self.sample_rate} Hz")
        
        # Load looming stimuli (each should match the output sample rate)
        self.loaded_looming_stimuli = {}
        for stim_type, path in self.looming_stimuli.items():
            try:
                data, looming_rate = sf.read(path)
                self.loaded_looming_stimuli[stim_type] = data
                print(f"Looming stimulus loaded: {stim_type}")
                if looming_rate != self.sample_rate:
                    print(f"Warning: looming stimulus '{stim_type}' is {looming_rate} Hz "
                          f"but output is {self.sample_rate} Hz")
            except Exception as e:
                print(f"Error loading looming stimulus {stim_type}: {e}")
    
    def generate_practice_iteration(self, iteration_num):
        """
        Generate a single practice iteration with 12 trials:
        - 4 baseline trials (tactile only)
        - 4 catch trials (looming only)
        - 4 PPS trials (looming + tactile with random SOA)
        
        Trials are separated by 4-second intervals; each trial slot also
        reserves the duration of the longest stimulus.
        
        Args:
            iteration_num: Iteration number (1-5)
        
        Returns:
            Tuple (looming_audio, tactile_audio, trial_log)
        """
        print(f"\nGenerating practice iteration {iteration_num}...")
        
        # Create trial types (4 of each type)
        trial_types = ['baseline'] * 4 + ['catch'] * 4 + ['pps'] * 4
        
        # Use a different random seed for each iteration to ensure different orderings
        np.random.seed(42 + iteration_num)
        np.random.shuffle(trial_types)
        
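        # Create stimulus directions (4 of each overall). They are shuffled
        # independently of trial type, so directions are not guaranteed to be
        # balanced within a given trial type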
        stimulus_directions = ['left', 'front', 'right'] * 4
        np.random.shuffle(stimulus_directions)
        
        # Use the longest looming stimulus so every trial slot can hold any direction
        looming_duration_samples = max(len(data) for data in self.loaded_looming_stimuli.values())
        
        # Convert the tactile stimulus to stereo (if mono) and get its duration
        if self.tactile_data.ndim == 1:  # If tactile data is mono
            tactile_stereo = np.column_stack((self.tactile_data, self.tactile_data))
        else:
            tactile_stereo = self.tactile_data
        tactile_duration_samples = len(tactile_stereo)
        
        # Calculate the total duration needed (in samples)
        # 4 seconds between trials + buffer at beginning and end
        interval_samples = 4 * self.sample_rate  # 4 seconds between trials
        max_stimulus_duration = max(looming_duration_samples, tactile_duration_samples)
        
        # Total duration: start buffer + (12 trials * (interval + max stimulus duration)) + end buffer
        total_samples = 2 * self.sample_rate + (12 * (interval_samples + max_stimulus_duration)) + 2 * self.sample_rate
        
        # Create two separate audio streams (initialize as silence)
        looming_audio = np.zeros((total_samples, 2))  # Stereo silence
        tactile_audio = np.zeros((total_samples, 2))  # Stereo silence
        
        # Track injection details
        injection_log = []
        
        # Process each trial
        for trial_idx, (trial_type, stim_direction) in enumerate(zip(trial_types, stimulus_directions)):
            # Calculate position for this trial (in samples)
            # Start with 2 second buffer, then add trial intervals
            position_samples = 2 * self.sample_rate + trial_idx * (interval_samples + max_stimulus_duration)
            position_ms = (position_samples / self.sample_rate) * 1000
            
            # Apply random jitter (100-500ms)
            jitter_ms = np.random.choice(JITTER_OPTIONS_MS)
            jittered_ms = position_ms + jitter_ms
            jittered_samples = int(jittered_ms / 1000 * self.sample_rate)
            
            # For PPS trials, draw a random SOA in ms; np.random.randint excludes the
            # upper bound, so SOAs fall in [SOA_RANGE[0], SOA_RANGE[1] - 1]
            soa_ms = 0
            if trial_type == 'pps':
                soa_ms = np.random.randint(SOA_RANGE[0], SOA_RANGE[1])
            
            # Calculate injection points
            looming_injection_point = jittered_samples
            tactile_injection_point = jittered_samples + int((soa_ms / 1000) * self.sample_rate)
            
            # Perform the appropriate injections based on trial type; a stimulus that
            # is missing or would overrun the buffer is skipped with a warning
            if trial_type in ['catch', 'pps']:
                # Inject looming stimulus
                looming_data = self.loaded_looming_stimuli.get(stim_direction)
                if looming_data is not None and looming_injection_point + len(looming_data) <= len(looming_audio):
                    looming_audio[looming_injection_point:looming_injection_point+len(looming_data)] += looming_data
                else:
                    print(f"Warning: looming stimulus ({stim_direction}) not injected in trial {trial_idx + 1}")
            
            if trial_type in ['baseline', 'pps']:
                # Inject tactile stimulus
                if tactile_injection_point + len(tactile_stereo) <= len(tactile_audio):
                    tactile_audio[tactile_injection_point:tactile_injection_point+len(tactile_stereo)] += tactile_stereo
                else:
                    print(f"Warning: tactile stimulus not injected in trial {trial_idx + 1}")
            
            # Log trial details
            injection_log.append({
                'iteration_num': iteration_num,
                'trial_number': trial_idx + 1,
                'trial_type': trial_type,
                'stimulus_direction': stim_direction,
                'position_ms': position_ms,
                'jitter_ms': jitter_ms,
                'jittered_ms': jittered_ms,
                'soa_ms': soa_ms,
                'looming_injection_sample': looming_injection_point,
                'tactile_injection_sample': tactile_injection_point,
                'looming_time_seconds': looming_injection_point / self.sample_rate,
                'tactile_time_seconds': tactile_injection_point / self.sample_rate
            })
        
        return looming_audio, tactile_audio, pd.DataFrame(injection_log)
    
    def save_practice_files(self, iteration_num, looming_audio, tactile_audio, trial_log):
        """
        Save practice stimuli files for an iteration
        
        Args:
            iteration_num: Iteration number
            looming_audio: Modified audio with looming stimuli
            tactile_audio: Audio track with tactile stimuli
            trial_log: DataFrame with trial details
        
        Returns:
            Tuple of saved file paths
        """
        # Define filenames
        looming_filename = f"practice_iteration_{iteration_num}_looming.wav"
        tactile_filename = f"practice_iteration_{iteration_num}_tactile.wav"
        log_filename = f"practice_iteration_{iteration_num}_log.csv"
        
        # Create full paths
        looming_path = os.path.join(self.output_dir, looming_filename)
        tactile_path = os.path.join(self.output_dir, tactile_filename)
        log_path = os.path.join(self.output_dir, log_filename)
        
        # Save the audio files
        sf.write(looming_path, looming_audio, self.sample_rate)
        sf.write(tactile_path, tactile_audio, self.sample_rate)
        
        # Save the trial log
        trial_log.to_csv(log_path, index=False)
        
        print(f"Saved practice files for iteration {iteration_num}:")
        print(f"  Looming audio: {looming_filename}")
        print(f"  Tactile audio: {tactile_filename}")
        print(f"  Trial log: {log_filename}")
        
        return looming_path, tactile_path, log_path
    

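In `generate_practice_iteration`, directions are shuffled independently of trial types, so one trial type can end up without a given direction. With 4 trials per type and 3 directions, exact balance within a type is impossible, but each type can still receive every direction at least once while each direction appears 4 times overall. A minimal sketch using only the standard library; `balanced_directions` is a hypothetical helper, not part of the original script:

```python
import random
from collections import Counter


def balanced_directions(trial_types, directions=("left", "front", "right"), seed=None):
    """Hypothetical helper: give every trial type each direction at least once.

    The "extra" trials of each type (beyond one per direction) cycle through a
    shuffled direction list, so with 3 types x 4 trials every direction also
    appears exactly 4 times overall.
    """
    rng = random.Random(seed)
    counts = Counter(trial_types)
    extras = list(directions)
    rng.shuffle(extras)

    pools = {}
    next_extra = 0
    for trial_type in sorted(counts):
        n_extra = counts[trial_type] - len(directions)
        if n_extra < 0:
            raise ValueError(f"{trial_type} has fewer trials than directions")
        # One of each direction, plus extras taken in turn from the shuffled list
        pool = list(directions)
        for _ in range(n_extra):
            pool.append(extras[next_extra % len(extras)])
            next_extra += 1
        rng.shuffle(pool)
        pools[trial_type] = pool

    # Hand out each type's shuffled directions in trial order
    return [pools[trial_type].pop() for trial_type in trial_types]


trial_types = ["baseline"] * 4 + ["catch"] * 4 + ["pps"] * 4
random.Random(43).shuffle(trial_types)
directions = balanced_directions(trial_types, seed=43)
print(Counter(directions))  # 4 of each direction
```

Inside the method this could stand in for the independent shuffle, for example `stimulus_directions = balanced_directions(trial_types, seed=42 + iteration_num)`. Baseline trials are tactile only, so their direction label is never played; the sketch balances them anyway to keep the logic uniform.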

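Both stimulus tracks are written into stereo `(n, 2)` buffers, and a mono array of shape `(n,)` cannot broadcast against an `(n, 2)` slice (NumPy raises a `ValueError`), so mono files have to be duplicated to two channels first. The conversion and the bounds check used for both tracks can be factored into two small helpers. This is a sketch; `to_stereo` and `inject` are hypothetical names, not part of the original script:

```python
import numpy as np


def to_stereo(data):
    """Return audio as an (n, 2) float array; mono (n,) input is copied to both channels."""
    data = np.asarray(data, dtype=float)
    if data.ndim == 1:
        return np.column_stack((data, data))
    if data.ndim == 2 and data.shape[1] == 2:
        return data
    raise ValueError(f"expected mono or stereo audio, got shape {data.shape}")


def inject(buffer, stimulus, start):
    """Add a stimulus into a stereo buffer at sample `start`.

    Returns False (leaving the buffer unchanged) if the stimulus would not fit.
    """
    stimulus = to_stereo(stimulus)
    end = start + len(stimulus)
    if start < 0 or end > len(buffer):
        return False
    buffer[start:end] += stimulus
    return True


track = np.zeros((10, 2))
print(inject(track, np.ones(3), start=2))  # True: mono stimulus written to both channels
print(inject(track, np.ones(9), start=2))  # False: would run past the end of the track
```

Returning a flag instead of silently skipping lets the caller log which trials were not injected.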
# ======================================================================
# MAIN FUNCTION
# ======================================================================

def main():
    """
    Main function to generate 5 different iterations of practice stimuli for the PPS experiment.
    
    Each iteration contains 12 trials (4 baseline, 4 catch, 4 PPS) with randomized order.
    For each iteration, two synchronized audio files are generated:
    - One for looming stimuli
    - One for tactile stimuli
    
    Trial slots are spaced 4 s plus the longest stimulus duration apart, which
    keeps the practice session short.
    """
    print("===== GENERATING PPS EXPERIMENT PRACTICE STIMULI =====")
    print(f"Creating {NUM_ITERATIONS} different iterations of practice stimuli")
    print("Each iteration will contain 12 trials with only 4 seconds between trials:")
    print("  - 4 baseline trials (tactile only)")
    print("  - 4 catch trials (looming only)")
    print("  - 4 PPS trials (both with random SOA)")
    print(f"Output directory: {PRACTICE_STIMULI_DIR}")
    
    # Initialize the practice stimuli generator
    practice_generator = PracticeStimuli(
        output_dir=PRACTICE_STIMULI_DIR,
        looming_stimuli_paths=LOOMING_STIMULI,
        tactile_stimulus_path=TACTILE_STIMULUS_PATH,
        base_audio_path=BASE_AUDIO_PATH,
        timestamps_csv_path=TIMESTAMPS_CSV_PATH
    )
    
    # Generate practice stimuli for each iteration
    for iteration in range(1, NUM_ITERATIONS + 1):
        try:
            # Generate the practice stimuli
            looming_audio, tactile_audio, trial_log = practice_generator.generate_practice_iteration(iteration)
            
            # Save the files
            practice_generator.save_practice_files(iteration, looming_audio, tactile_audio, trial_log)
            
            # Summary of this iteration
            print(f"\nIteration {iteration} summary:")
            trial_types = trial_log['trial_type'].value_counts()
            for trial_type, count in trial_types.items():
                print(f"  {trial_type}: {count} trials")
            
            # Show SOA distribution for PPS trials
            pps_trials = trial_log[trial_log['trial_type'] == 'pps']
            if not pps_trials.empty:
                min_soa = pps_trials['soa_ms'].min()
                max_soa = pps_trials['soa_ms'].max()
                mean_soa = pps_trials['soa_ms'].mean()
                print(f"  PPS trials SOA range: {min_soa}-{max_soa}ms (mean: {mean_soa:.1f}ms)")
            
        except Exception as e:
            print(f"Error generating practice iteration {iteration}: {e}")
    
    print("\n===== PRACTICE STIMULI GENERATION COMPLETE =====")
    print(f"Created {NUM_ITERATIONS} practice iterations with 12 trials each")
    print(f"All files saved to: {PRACTICE_STIMULI_DIR}")
    print("Each iteration contains two audio files (looming and tactile) with:")
    print("  - 4 baseline trials (tactile stimulus only)")
    print("  - 4 catch trials (looming stimulus only)")
    print("  - 4 PPS trials (both looming and tactile with random SOA)")
    print("Each iteration has a different randomized order of trials")
    print("There are only 4 seconds between trials for shorter practice sessions")
    print("\nEach CSV file records the exact timing of when looming and tactile stimuli are played")

if __name__ == "__main__":
    main()
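In `generate_practice_iteration`, trial `k` starts in a slot at 2 s + k × (4 s + the longer of the looming and tactile durations). The looming onset is that slot time plus the jitter, truncated to a whole sample, and the tactile onset adds the SOA. The sketch below restates that arithmetic so expected onsets can be checked against the `looming_injection_sample` and `tactile_injection_sample` columns of the CSV logs. `trial_onsets` and `total_samples` are hypothetical names; the usage assumes 48000 Hz (the rate reported in the run log) and a 1.5 s longest stimulus, which is an illustrative value, not the real stimulus length:

```python
# Hypothetical helpers restating the slot arithmetic of generate_practice_iteration.
def trial_onsets(trial_idx, sample_rate, longest_stimulus_samples,
                 jitter_ms, soa_ms=0, interval_s=4, buffer_s=2):
    """Return (looming_sample, tactile_sample) for one practice trial."""
    slot_samples = interval_s * sample_rate + longest_stimulus_samples
    position_samples = buffer_s * sample_rate + trial_idx * slot_samples
    position_ms = position_samples / sample_rate * 1000
    # Jitter is added in ms and truncated to a whole sample, as in the script
    looming_sample = int((position_ms + jitter_ms) / 1000 * sample_rate)
    # The tactile onset lags the looming onset by the SOA (0 for non-PPS trials)
    tactile_sample = looming_sample + int(soa_ms / 1000 * sample_rate)
    return looming_sample, tactile_sample


def total_samples(n_trials, sample_rate, longest_stimulus_samples,
                  interval_s=4, buffer_s=2):
    """Length of each practice track: start buffer + n slots + end buffer."""
    slot_samples = interval_s * sample_rate + longest_stimulus_samples
    return 2 * buffer_s * sample_rate + n_trials * slot_samples


# Assumed values: 48000 Hz and a 1.5 s (72000-sample) longest stimulus
print(trial_onsets(0, 48000, 72000, jitter_ms=200))              # (105600, 105600)
print(trial_onsets(1, 48000, 72000, jitter_ms=300, soa_ms=500))  # (374400, 398400)
print(total_samples(12, 48000, 72000) / 48000)                   # 70.0
```

Under these assumptions each slot lasts 5.5 s, so a 12-trial practice track is 70 s long.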

===== GENERATING PPS EXPERIMENT PRACTICE STIMULI =====
Creating 5 different iterations of practice stimuli
Each iteration will contain 12 trials with only 4 seconds between trials:
  - 4 baseline trials (tactile only)
  - 4 catch trials (looming only)
  - 4 PPS trials (both with random SOA)
Output directory: C:\Users\cogpsy-vrlab\Documents\PPS_module\BreathingPilot\PracticeStimuli
Using sample rate from audio file: 48000 Hz
Tactile stimulus loaded: C:\Users\cogpsy-vrlab\Documents\PPS_module\BreathingPilot\PPS_Experiment_Module\tactile_stimulus.wav
Looming stimulus loaded: right
Looming stimulus loaded: front
Looming stimulus loaded: left

Generating practice iteration 1...
Saved practice files for iteration 1:
  Looming audio: practice_iteration_1_looming.wav
  Tactile audio: practice_iteration_1_tactile.wav
  Trial log: practice_iteration_1_log.csv

Iteration 1 summary:
  pps: 4 trials
  catch: 4 trials
  baseline: 4 trials
  PPS trials SOA range: 327-985ms (mean: 738.5ms)

Generating