# Test Phase Analysis Notebook

This notebook analyzes behavioral data from the cue-reward association test phase for a specific animal. It processes monitoring data, incorporates test-specific data (like reward sizes) from `.mat` files, performs various analyses, and generates visualizations.

In [None]:
# Cell 2: Imports and Setup (Revert to Inline Backend)

# --- Autoreload Extension ---
%load_ext autoreload
%autoreload 2
# --------------------------

import os
import sys
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.io import loadmat
from scipy.stats import sem
import json
import traceback

# Add module path
module_path = os.path.abspath(os.path.join('.'))
if module_path not in sys.path:
    sys.path.append(module_path)
    print(f"Appended module path: {module_path}")

# Import custom modules
try:
    from modules.data_loader import (extract_animal_id_from_folder, find_animal_sessions,
                                     extract_session_day, extract_date, load_test_data)
    from modules.signal_processing import (process_monitoring_data, detect_rising_edges,
                                           detect_falling_edges, calculate_cut_duration_in_window)
    from modules.session_analysis import analyze_test_session
    from modules.visualization import (visualize_animal_sessions, visualize_session_details,
                                       visualize_post_reward_licking, visualize_test_session_analysis,
                                       visualize_cross_session_test_analysis, visualize_time_resolved_licking,
                                       visualize_licking_by_prev_extremes,
                                       visualize_stress_comparison_by_prev_reward,
                                       visualize_stress_comparison_pooled,
                                       visualize_cross_session_time_resolved_licking,
                                       visualize_cross_session_prev_extremes_stress_summary)
    print("Modules imported successfully.")
except ImportError as e:
    print(f"Error importing modules: {e}")
    traceback.print_exc()

# Define base directories
# BASE_DIR is passed to find_animal_sessions() in Cell 3; FIGURES_DIR is the
# root under which Cell 5 creates a per-animal figures folder. Both are
# machine-specific absolute Windows paths (raw strings keep the backslashes
# literal) and must be edited when the notebook runs on another machine.
BASE_DIR = r"E:\OneDrive - New York State Office of Information Technology Services\Ehsan\Data"
FIGURES_DIR = r"E:\OneDrive - New York State Office of Information Technology Services\Ehsan\Figures"

# --- Notebook Specific Setup ---
# Use standard inline backend
%matplotlib inline
# -----------------------------

# Optional: Configure plot appearance (less likely to conflict with inline)
# NOTE(review): the 'seaborn-v0_8-*' style names presumably need a recent
# Matplotlib release (3.6+) — confirm against the installed version.
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams['figure.figsize'] = (10, 6) # Restore original size
plt.rcParams['font.size'] = 12

# Echo the configuration so the cell output records which paths were used.
print(f"Base Data Directory: {BASE_DIR}")
print(f"Base Figures Directory: {FIGURES_DIR}")
print("Setup complete with autoreload and inline backend enabled.")

In [None]:
# Cell 3: Specify Animal ID, Find Sessions, Categorize Phases, and Store Mapping

# Animal whose sessions are analyzed by the rest of the notebook.
animal_id = 'b6_e1_g9' # <--- Set to the correct ID

print(f"Analyzing data for animal: {animal_id}")

# Both containers are created up front so later code can rely on them
# existing even if the lookup below fails.
all_session_folders = []     # session folder paths, sorted by date once found
session_phase_mapping = {}   # folder_path -> phase_name, filled in further down

try:
    all_session_folders = find_animal_sessions(animal_id, BASE_DIR)

    if all_session_folders:
        # Folders without an extractable date get a sentinel key so they
        # sort after every dated folder.
        all_session_folders.sort(key=lambda folder: extract_date(folder) or '9999-99-99')
        print(f"Found {len(all_session_folders)} session folders (sorted chronologically).")
    else:
        print(f"Warning: No session folders found for animal '{animal_id}' in '{BASE_DIR}'.")

except Exception as lookup_error:
    # Report and continue; the halt check later in this cell stops execution
    # when no folders are available.
    print(f"Error finding or sorting animal sessions: {lookup_error}")
    traceback.print_exc()

# --- Phase Categorization and Mapping ---
# Assigns every session folder to an experiment phase, using keywords in the
# folder name (HAB / TEST / STRESS / PL37) together with the folder's position
# in the chronologically sorted list. Fills `session_phase_mapping`
# (folder_path -> phase name) and prints a per-phase summary.
if all_session_folders: # Only categorize if sessions were found
    # Phase names, in the order they are printed in the summary.
    phase_names = [
        "Habituation",
        "Pre-Test Training",
        "Initial Test",
        "Stress Test (Pre-PL37)",
        "PL37 Test",
        "Post-PL37 Test",
        "Other/Uncategorized" # Catch-all
    ]
    # phase name -> list of (folder_name, date_str), used only for the summary
    phase_summary_data = {name: [] for name in phase_names}

    def _first_index_with(keyword, names_upper):
        """Return the index of the first name containing `keyword`, or -1 if none does."""
        return next((idx for idx, name in enumerate(names_upper) if keyword in name), -1)

    # Upper-cased folder names, computed once and shared by both steps below.
    folder_names_upper = [os.path.basename(path).upper() for path in all_session_folders]

    # Index of the FIRST TEST, STRESS and PL37 folder (-1 when absent).
    # pl37_idx used to be overwritten on every match and therefore pointed at
    # the LAST 'PL37' folder; with several such folders the earlier ones were
    # filed under Stress/Initial Test. Now the first match marks the PL37
    # session and later test sessions fall into "Post-PL37 Test".
    first_test_idx = _first_index_with('TEST', folder_names_upper)
    first_stress_idx = _first_index_with('STRESS', folder_names_upper)
    pl37_idx = _first_index_with('PL37', folder_names_upper)

    # Categorize based on indices and keywords, storing folder details AND mapping
    total_categorized = 0
    for i, folder_path in enumerate(all_session_folders):
        folder_name = os.path.basename(folder_path)
        folder_name_upper = folder_names_upper[i]
        date_str = extract_date(folder_path) or 'N/A'

        is_test = 'TEST' in folder_name_upper
        is_hab = 'HAB' in folder_name_upper

        category = "Other/Uncategorized" # Default
        if is_hab:
            # Habituation wins over every positional rule.
            category = "Habituation"
        elif first_test_idx == -1:
            # No folder contains 'TEST', so this one cannot be a test session.
            category = "Pre-Test Training"
        elif i < first_test_idx:
            category = "Pre-Test Training"
        elif is_test:
            if i == pl37_idx:
                category = "PL37 Test"
            elif pl37_idx != -1 and i > pl37_idx:
                category = "Post-PL37 Test"
            elif first_stress_idx == -1 or i < first_stress_idx:
                category = "Initial Test"
            else:
                # At or after the first STRESS folder, before PL37.
                category = "Stress Test (Pre-PL37)"
        # Non-test folders at/after the first test stay "Other/Uncategorized".

        phase_summary_data[category].append((folder_name, date_str))
        session_phase_mapping[folder_path] = category
        total_categorized += 1

    # Print Summary with details
    print("\n--- Experiment Phase Summary ---")
    for phase, session_list in phase_summary_data.items():
        count = len(session_list)
        if count > 0:
            session_label = "session" if count == 1 else "sessions"
            print(f"\n* {phase}: {count} {session_label}")
            for f_name, d_str in session_list:
                print(f"  - {f_name} (Date: {d_str})")
    print("\n--------------------------------")

    # Verification check
    if total_categorized != len(all_session_folders):
        print(f"Warning: Categorization count ({total_categorized}) does not match total sessions ({len(all_session_folders)}). Review logic.")
    else:
        print(f"Total sessions categorized and mapped: {total_categorized}")

    # Optional: Display the first few items of the mapping dictionary for verification
    # print("\nSession Phase Mapping (sample):")
    # for k, v in list(session_phase_mapping.items())[:5]:
    #     print(f"  '{os.path.basename(k)}': '{v}'")


# --- Halt Execution Check ---
# Stop the notebook here when Cell 3 produced nothing usable. An explicit
# exception is raised instead of `assert False`, because assert statements
# are removed when Python runs with -O and the halt would silently vanish.
if 'all_session_folders' not in locals() or not all_session_folders:
    print("\nHalting execution as no session folders were found.")
    raise RuntimeError("No session folders found. Cannot proceed.")
elif 'session_phase_mapping' not in locals() or not session_phase_mapping:
    print("\nHalting execution as session phase mapping failed.")
    raise RuntimeError("Session phase mapping failed. Cannot proceed.")
else:
    print("\nPhase categorization complete. Mapping stored in 'session_phase_mapping'.")



In [None]:

# Cell 4: Process Training Sessions

# --- Dependency Checks ---
# Both names are produced by Cell 3. An explicit exception is raised instead
# of `assert False`, which is skipped when Python runs with -O.
if 'all_session_folders' not in locals() or not all_session_folders:
    print("Error: 'all_session_folders' not found. Please run Cell 3 first.")
    raise RuntimeError("Dependency error: Cell 3 must be run successfully.")
if 'session_phase_mapping' not in locals() or not session_phase_mapping:
    print("Error: 'session_phase_mapping' not found. Please run Cell 3 first.")
    raise RuntimeError("Dependency error: Cell 3 must be run successfully.")
# -------------------------

print("\nProcessing Training Sessions (excluding Habituation)...")
training_sessions_data = []            # one result dict per successfully processed session
training_session_paths_processed = []  # every training folder that was attempted

# Walk the folders in chronological order and process only those that
# Cell 3 mapped to the 'Pre-Test Training' phase.
for session_folder_path in all_session_folders:
    session_name = os.path.basename(session_folder_path)
    phase = session_phase_mapping.get(session_folder_path, "Other/Uncategorized")

    if phase != "Pre-Test Training":
        continue

    print(f"\nProcessing Training Session: {session_name} (Phase: {phase})")
    training_session_paths_processed.append(session_folder_path)
    try:
        session_data = process_monitoring_data(session_folder_path)

        if not session_data:
            # A falsy result (None / empty) is treated as "nothing to keep".
            print(f"  Processing returned no data or failed for {session_name}.")
        else:
            # Tag the result with where it came from and which phase it belongs to.
            session_data['session_folder'] = session_folder_path
            session_data['phase'] = phase
            training_sessions_data.append(session_data)
            print(f"  Successfully processed Training Day {session_data.get('session_day', 'N/A')}")
    except Exception as processing_error:
        # One bad session must not stop the remaining ones.
        print(f"  ERROR processing training session {session_name}: {processing_error}")
        traceback.print_exc()

# --- Summary of Training Session Processing ---
print(f"\nFinished processing potential training sessions.")
num_attempted = len(training_session_paths_processed)
num_processed_successfully = len(training_sessions_data)
print(f"Attempted processing for {num_attempted} sessions identified as 'Pre-Test Training'.")
print(f"Successfully processed and gathered data for {num_processed_successfully} training sessions.")

# Highest training day seen; Cell 6 uses it as the offset when it renumbers
# the test days. Sessions whose 'session_day' is missing or None are ignored.
valid_training_days = [
    day
    for day in (s.get('session_day') for s in training_sessions_data)
    if day is not None
]
num_training_days = max(valid_training_days, default=0)
print(f"\nHighest 'session_day' number found in processed training data: {num_training_days}")

num_failed = num_attempted - num_processed_successfully
if num_failed > 0:
    print(f"Warning: {num_failed} training session(s) could not be processed successfully. Check logs above.")
if not training_sessions_data:
    print("Warning: No 'Pre-Test Training' session data was successfully processed.")


In [None]:
# Cell 5: Process Test Sessions and Display Saved Plots

# --- Dependency Checks ---
# Fail early with a message that names the missing prerequisite. Explicit
# exceptions are used instead of `assert False, "..."`: asserts are removed
# under -O and the old placeholder text gave no hint about what was missing.
if 'all_session_folders' not in locals() or not all_session_folders:
    raise RuntimeError("Dependency error: 'all_session_folders' not found or empty. Run Cell 3 first.")
if 'session_phase_mapping' not in locals() or not session_phase_mapping:
    raise RuntimeError("Dependency error: 'session_phase_mapping' not found or empty. Run Cell 3 first.")
if 'animal_id' not in locals():
    raise RuntimeError("Dependency error: 'animal_id' not defined. Set it in Cell 3.")
if 'FIGURES_DIR' not in locals():
    raise RuntimeError("Dependency error: 'FIGURES_DIR' not defined. Run Cell 2 first.")
# -------------------------

import io
from contextlib import redirect_stdout
# --- Import IPython display tools ---
from IPython.display import Image, display
# ------------------------------------


print("\nProcessing Test Sessions...")
test_sessions_data = []            # one result dict per successfully analyzed test session
test_session_paths_attempted = []  # every test folder that was attempted

# Phases (as assigned in Cell 3) that count as test sessions.
VALID_TEST_PHASES = ["Initial Test", "Stress Test (Pre-PL37)", "PL37 Test", "Post-PL37 Test"]

# --- Plotting Setup ---
# Windows passed to visualize_post_reward_licking in the main loop below.
windows_to_analyze_post_reward = [
    {'label': label, 'start': start, 'end': end}
    for label, start, end in (
        ('0-1s', 0.0, 1.0),
        ('1-2s', 1.0, 2.0),
        ('5-15s', 5.0, 15.0),
    )
]

# Per-animal figure root, plus a sub-folder for the individual session plots.
figures_dir_absolute = os.path.abspath(os.path.join(FIGURES_DIR, animal_id))
session_figures_dir = os.path.join(figures_dir_absolute, 'session_plots')
os.makedirs(session_figures_dir, exist_ok=True)
print(f"Individual session plots will be saved to: {session_figures_dir}")
# --- End plotting setup ---


# --- Main Loop ---
# For every folder mapped to a test phase: run analyze_test_session with its
# stdout captured, tag and store the result, generate the per-session plots,
# then display whichever of the expected image files exist.
for session_folder_path in all_session_folders:
    session_name = os.path.basename(session_folder_path)
    phase = session_phase_mapping.get(session_folder_path, "Other/Uncategorized")

    if phase not in VALID_TEST_PHASES:
        continue

    print(f"\nProcessing Test Session: {session_name} (Phase: {phase})")
    test_session_paths_attempted.append(session_folder_path)

    captured_stdout = io.StringIO()
    session_data = None
    analysis_error = None
    try:
        with redirect_stdout(captured_stdout): # Hide internal prints
            session_data = analyze_test_session(session_folder_path)
    except Exception as e:
        print(f"    ERROR occurred during analyze_test_session call for {session_name}: {str(e)}")
        traceback.print_exc()
        analysis_error = e

    if not session_data:
        print(f"  Analysis function returned no data or failed for {session_name}.")
        if analysis_error:
            print(f"    Analysis Error: {analysis_error}")
        # The captured prints are hidden on success, but on failure they are
        # often the only clue about what went wrong, so surface them here.
        captured_text = captured_stdout.getvalue().strip()
        if captured_text:
            print("    Output captured from analyze_test_session:")
            print(captured_text)
        continue

    session_data['session_folder'] = session_folder_path
    session_data['phase'] = phase
    session_data['is_test_session'] = True
    test_sessions_data.append(session_data)

    session_info = session_data['session_info']
    print(f"  Successfully analyzed Test Day {session_info.get('test_day', 'N/A')}")

    # --- Generate Plots and Display Saved Files ---
    print(f"  Generating/Saving plots for {session_name}...")
    plt.close('all') # Close any previous figures before calling vis functions

    session_label_for_plots = f"TEST_Day{session_info.get('test_day', 'unknown')}"
    # Common prefix of the image files this loop expects to find afterwards.
    plot_prefix = os.path.join(session_figures_dir, f"{animal_id}_{session_label_for_plots}")

    # (plot_type, expected filename suffix, zero-argument callable).
    # Each callable runs immediately below, inside this same iteration.
    plot_jobs = []
    if 'post_reward_lick_analysis' in session_data:
        plot_jobs.append((
            'post_reward', "_post_reward_lick_duration.png",
            lambda: visualize_post_reward_licking(
                session_data['post_reward_lick_analysis'], windows_to_analyze_post_reward,
                animal_id, session_label_for_plots, session_figures_dir
            ),
        ))
    else:
        print("    Skipping post-reward lick plot (data not found).")
    plot_jobs.extend([
        ('test_analysis', "_analysis.png",
         lambda: visualize_test_session_analysis(session_data, animal_id, session_figures_dir)),
        # Note: this might save TWO files (_session_analysis.png,
        # _lick_sensor_cut_analysis.png); only the main one is displayed.
        ('details', "_session_analysis.png",
         lambda: visualize_session_details(session_data, animal_id, session_figures_dir)),
        ('time_resolved', "_time_resolved_licking.png",
         lambda: visualize_time_resolved_licking(
             session_data, animal_id, session_label_for_plots, session_figures_dir)),
        ('prev_extremes', "_licking_by_prev_extremes.png",
         lambda: visualize_licking_by_prev_extremes(
             session_data, animal_id, session_label_for_plots, session_figures_dir)),
    ])

    # Run each plot in its own try so one failing visualization no longer
    # prevents the remaining plots (and the display step) from running.
    plot_filenames = {}
    for plot_type, suffix, make_plot in plot_jobs:
        try:
            make_plot()
        except Exception as plot_e:
            print(f"    ERROR during plot generation for {session_name} ({plot_type}): {plot_e}")
            traceback.print_exc()
        else:
            plot_filenames[plot_type] = plot_prefix + suffix

    print(f"  Displaying saved plots for {session_name}:")
    for plot_type, filepath in plot_filenames.items():
        if not os.path.exists(filepath):
            print(f"    - Warning: Expected plot file not found: {filepath}")
            continue
        print(f"    - Displaying {plot_type} plot...")
        try:
            display(Image(filename=filepath))
        except Exception as display_e:
            print(f"    ERROR during plot display for {session_name} ({plot_type}): {display_e}")
            traceback.print_exc()
    # --- End Plot Generation/Display ---

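# --- Summary ---
# NOTE: no summary code is present in this cell; the original placeholder
# only said it "remains the same". Restore the summary here if a per-cell
# report of attempted vs. successfully analyzed test sessions is needed.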

In [None]:
# Cell 6: Renumber Test Days and Combine Data (Corrected Sorting)

# --- Dependency Checks ---
# Missing training data is tolerated (treated as zero training days); every
# other missing prerequisite stops the cell. Explicit exceptions replace
# `assert False`, which is removed when Python runs with -O.
if 'training_sessions_data' not in locals():
    print("Warning: 'training_sessions_data' not found. Assuming 0 training sessions processed.")
    training_sessions_data = []
    num_training_days = 0
elif 'num_training_days' not in locals():
    # The old message dropped the words "not found" and read as a statement.
    print("Error: 'num_training_days' not found (it is calculated in Cell 4). Cannot renumber test days.")
    raise RuntimeError("Dependency error: Cell 4 must calculate num_training_days.")

if 'test_sessions_data' not in locals() or not test_sessions_data:
    print("Error: 'test_sessions_data' not found or empty. Please run Cell 5 first.")
    raise RuntimeError("Dependency error: Cell 5 must populate test_sessions_data.")

# Check if extract_date is available (imported in Cell 2)
if 'extract_date' not in globals():
    print("Error: 'extract_date' function not found. Ensure Cell 2 ran correctly.")
    raise RuntimeError("Dependency error: extract_date needed for sorting.")
# -------------------------

# --- Ensure test sessions are sorted chronologically BY DATE ---
# Sorts test_sessions_data in place by the date extracted from each session's
# stored folder path; entries without a date get a sentinel key and sort last.
print(f"\nSorting {len(test_sessions_data)} test sessions by date before renumbering...")
try:
    test_sessions_data.sort(key=lambda x: extract_date(x.get('session_folder', '')) or '9999-99-99')
except Exception as sort_e:
    print(f"ERROR during test session sorting: {sort_e}")
    traceback.print_exc()
    # Raise (chained to the cause) instead of `assert False`, which is
    # removed under -O and would let the cell continue with unsorted data.
    raise RuntimeError("Stopping due to sorting error.") from sort_e
else:
    print("Test sessions sorted chronologically. Order for renumbering:")
    # Actually list the order that the header announces.
    for idx, s_data in enumerate(test_sessions_data):
        print(f"  {idx+1}. {os.path.basename(s_data.get('session_folder', 'N/A'))}")
# ----------------------------------------------------------------

print(f"\nRenumbering Test Session Days (Offset by {num_training_days} training days)...")
renumbered_test_sessions = []

# test_sessions_data is already in chronological order, so the position in
# the list (1-based) is the sequential test day; the overall day continues
# counting from the last training day.
for seq_test_day, session_data in enumerate(test_sessions_data, start=1):
    file_test_day = session_data['session_info'].get('test_day', 'N/A')
    day_overall = num_training_days + seq_test_day

    # Keep the day number found in the file next to the new overall day.
    session_data['original_test_day'] = file_test_day
    session_data['session_day'] = day_overall
    renumbered_test_sessions.append(session_data)

    folder_label = os.path.basename(
        session_data.get('session_folder', f'Unknown Session {seq_test_day - 1}')
    )
    print(f"  Mapping original Test Day {file_test_day} ({folder_label}) "
          f"to sequential Test Day {seq_test_day} -> Overall Day {day_overall}")


# --- Combine Training and Renumbered Test Data ---
# Both lists are always defined earlier in this cell (training data by the
# dependency check, renumbered data by the loop above), so they are simply
# concatenated into a new list.
processed_sessions = list(training_sessions_data) + list(renumbered_test_sessions)

# Order by session day. A session whose 'session_day' is missing OR None
# sorts last: the previous key returned None for a present-but-None value,
# which raises TypeError when compared with integer days.
processed_sessions.sort(
    key=lambda s: float('inf') if s.get('session_day') is None else s['session_day']
)

print(f"\nCombined and sorted data for {len(processed_sessions)} sessions.")
print("Variable 'processed_sessions' now contains all analyzed session data.")

# --- Final Check ---
if not processed_sessions:
    print("Warning: No sessions (training or test) were successfully processed and combined.")
else:
    print(f"First session day in combined list: {processed_sessions[0].get('session_day')}")
    print(f"Last session day in combined list: {processed_sessions[-1].get('session_day')}")

In [None]:
# Cell 7: Generate Focused Cross-Session Summary Visualizations (Phase-Aware)

# --- Imports and Setup ---
import os
import matplotlib.pyplot as plt
from IPython.display import display, Image
import traceback

# --- MODIFICATION: Import REQUIRED Phase-Aware Functions ---
# The summary plots in this cell cannot run without these, so a failed
# import stops the cell. The error is raised explicitly (chained to the
# ImportError) instead of via `assert False`, which is removed under -O.
try:
    from modules.phase_aware_visualization import (
        PHASE_ORDER, PHASE_COLORS,
        visualize_animal_sessions_phase_aware,
        visualize_pooled_comparison_four_phases,
        visualize_pooled_stress_epoch_comparison,
        visualize_cross_session_test_analysis_with_phase_lines
    )
    print("Successfully imported required phase-aware visualization functions.")
except ImportError as e:
    print(f"ERROR: Could not import required functions from modules.phase_aware_visualization: {e}")
    print("Please ensure 'modules/phase_aware_visualization.py' contains the necessary functions.")
    raise RuntimeError("ImportError for phase-aware visualizations.") from e
# ------------------------------------------------------------

# --- Dependency Checks ---
if 'processed_sessions' not in locals() or not processed_sessions:
    raise RuntimeError("CRITICAL ERROR: 'processed_sessions' not found. Ensure Cell 6 ran successfully.")
if 'renumbered_test_sessions' not in locals() or not renumbered_test_sessions:
    raise RuntimeError("CRITICAL ERROR: 'renumbered_test_sessions' not found. Ensure Cell 6 ran successfully.")
if 'animal_id' not in locals():
    raise RuntimeError("CRITICAL ERROR: 'animal_id' not defined. Ensure it's set in Cell 3.")
if 'figures_dir_absolute' not in locals():
    raise RuntimeError("CRITICAL ERROR: 'figures_dir_absolute' not defined. Ensure it's set in Cell 5 (derived from Cell 2 & 3).")
if 'ANIMAL_PHASE_DAYS' not in locals():
    # Optional input: without it the phase start days are derived from the
    # session data further down in this cell.
    print("Warning: ANIMAL_PHASE_DAYS not found. Phase line annotations on some plots might be derived dynamically or missing if derivation fails.")
    ANIMAL_PHASE_DAYS = {}
# ------------------------

# Use the renumbered test sessions for plots that focus only on test data with phase info.
test_data_for_summary = renumbered_test_sessions

os.makedirs(figures_dir_absolute, exist_ok=True)

# "\n" (not "\\n"): the escaped form printed a literal backslash-n.
print("\nGenerating Focused Phase-Aware Cross-Session Summary Plots...")
print(f"Plots will be saved to: {figures_dir_absolute}")

summary_figure_paths = []


def _first_day_of_phase(sessions, phase_name):
    """Return 'session_day' of the first session in `phase_name`, or None if there is none."""
    return next(
        (s['session_day'] for s in sessions
         if s.get('phase') == phase_name and s.get('session_day') is not None),
        None,
    )


# --- Determine PL37 and Post-PL37 start days for the CURRENT animal ---
# Predefined values in ANIMAL_PHASE_DAYS take precedence; otherwise the day is
# taken from the first session of the corresponding phase.
_predefined_phase_days = ANIMAL_PHASE_DAYS.get(animal_id, {})
pl37_day_for_animal = _predefined_phase_days.get("PL37_Start_Day")
first_post_pl37_day_for_animal = _predefined_phase_days.get("Post_PL37_Start_Day")

# Derive PL37 start day if not predefined
if pl37_day_for_animal is None:
    print(f"PL37_Start_Day for {animal_id} not in ANIMAL_PHASE_DAYS. Attempting to derive from session data...")
    pl37_day_for_animal = _first_day_of_phase(test_data_for_summary, "PL37 Test")
    if pl37_day_for_animal is None:
        print(f"  Warning: Could not derive PL37_Start_Day for {animal_id}. Phase line might be absent or incorrect.")

# Derive Post-PL37 start day if not predefined
if first_post_pl37_day_for_animal is None:
    print(f"Post_PL37_Start_Day for {animal_id} not in ANIMAL_PHASE_DAYS. Attempting to derive from session data...")
    first_post_pl37_day_for_animal = _first_day_of_phase(test_data_for_summary, "Post-PL37 Test")
    if first_post_pl37_day_for_animal is None:
        print(f"  Warning: Could not derive Post_PL37_Start_Day for {animal_id}. Phase line might be absent or incorrect.")

print(f"For animal {animal_id}:")
print(f"  PL37 Test Overall Start Day (used for plotting): {pl37_day_for_animal}")
print(f"  First Post-PL37 Test Overall Day (used for plotting): {first_post_pl37_day_for_animal}")
# ---------------------------------------------

def _show_summary_figure(path, failure_message):
    """Display `path` and record it in summary_figure_paths if it is an existing file.

    Prints `failure_message` and returns False otherwise.
    """
    if path and os.path.exists(path):
        summary_figure_paths.append(path)
        display(Image(filename=path))
        return True
    print(failure_message)
    return False


# All prints below use "\n"; the previous "\\n" printed a literal backslash-n
# instead of starting a new line.
try:
    plt.close('all') # Close any pre-existing plots

    # Plot 1: Overall Performance Across All Sessions (with Phase Lines)
    current_animal_all_sessions = [s for s in processed_sessions if animal_id in s.get('session_folder', '')]

    if not current_animal_all_sessions:
        print(f"CRITICAL WARNING: No sessions found for animal '{animal_id}'. Skipping plots.")
    else:
        print(f"Total sessions for animal '{animal_id}' to be used in plots: {len(current_animal_all_sessions)}")

        print("\n1. Generating overall performance plot with phase lines (`visualize_animal_sessions_phase_aware`)...")

        fig_paths_perf = visualize_animal_sessions_phase_aware(
            animal_sessions=current_animal_all_sessions,
            animal_id=animal_id,
            output_dir=figures_dir_absolute,
            pl37_vline_day=pl37_day_for_animal,
            post_pl37_vline_day=first_post_pl37_day_for_animal
        )

        # The function may hand back one path or a sequence of paths.
        generated_paths = []
        if isinstance(fig_paths_perf, (list, tuple)):
            generated_paths.extend(fig_paths_perf)
        elif isinstance(fig_paths_perf, str):
            generated_paths.append(fig_paths_perf)

        if generated_paths:
            for path in generated_paths:
                _show_summary_figure(path, f"   - Warning: Plot file not found: {path}")
        else:
            print(f"   - Plot 1 (visualize_animal_sessions_phase_aware) did not return a valid path or file was not found.")

    # --- Plots based only on TEST sessions for the CURRENT animal (using test_data_for_summary) ---
    if not test_data_for_summary:
        print(f"\nNo TEST session data available for animal {animal_id}. Skipping subsequent plots.")
    else:
        plt.close('all')
        print("\n2. Generating cross-session summary (5x2 grid) with phase lines...")
        fig_path_cross_lines = visualize_cross_session_test_analysis_with_phase_lines(
            test_sessions=test_data_for_summary,
            animal_id=animal_id,
            output_dir=figures_dir_absolute,
            pl37_vline_day=pl37_day_for_animal,
            post_pl37_vline_day=first_post_pl37_day_for_animal
        )
        _show_summary_figure(
            fig_path_cross_lines,
            f"   - Plot 2 (cross-session summary) did not return a valid path or file not found."
        )

        plt.close('all')
        print("\n3. Generating pooled comparison across 4 phases...")
        fig_path_pooled_4phase = visualize_pooled_comparison_four_phases(
            test_sessions=test_data_for_summary,
            animal_id=animal_id,
            output_dir=figures_dir_absolute
        )
        _show_summary_figure(
            fig_path_pooled_4phase,
            f"   - Plot 3 (pooled comparison) did not return a valid path or file not found."
        )

        plt.close('all')
        print(f"\n4. Generating pooled stress epoch comparison...")
        pooled_stress_epoch_fig_path = visualize_pooled_stress_epoch_comparison(
            test_sessions=test_data_for_summary,
            animal_id=animal_id,
            output_dir=figures_dir_absolute
        )
        _show_summary_figure(
            pooled_stress_epoch_fig_path,
            f'   - Skipping display of pooled stress epoch comparison as it was not generated.'
        )

except Exception as e:
    # Report and fall through so the closing summary is still printed.
    print(f"\nAn ERROR occurred during plot generation in Cell 7 for animal {animal_id}: {e}")
    traceback.print_exc()

print(f"\nFinished generating focused phase-aware summary plots for {animal_id}.")
print(f"Generated {len(summary_figure_paths)} summary plot(s) in Cell 7 for {animal_id}.")

In [None]:

# Cell 8: Generate Phase Comparison Plot Grouped by Previous Reward Size

# --- Imports and Setup ---
import os
import matplotlib.pyplot as plt
from IPython.display import display, Image
import traceback

# --- Ensure phase-aware module is loaded and function exists ---
# A failed import stops the cell with an explicit exception chained to the
# ImportError; `assert False` would be removed when Python runs with -O.
try:
    from modules.phase_aware_visualization import visualize_phase_comparison_by_prev_reward
    print("Successfully imported visualize_phase_comparison_by_prev_reward.")
except ImportError as e:
    print(f"ERROR: Could not import function: {e}")
    print("Please ensure 'modules/phase_aware_visualization.py' contains the visualize_phase_comparison_by_prev_reward function.")
    raise RuntimeError("ImportError for visualize_phase_comparison_by_prev_reward.") from e
# ------------------------------------------------------------

# --- Dependency Checks ---
# These names are expected from earlier cells for the current animal_id.
if 'renumbered_test_sessions' not in locals() or not renumbered_test_sessions:
    raise RuntimeError("'renumbered_test_sessions' list not found. Ensure Cell 6 ran.")
if 'animal_id' not in locals():
    raise RuntimeError("animal_id not defined.")
if 'figures_dir_absolute' not in locals():
    raise RuntimeError("figures_dir_absolute not defined.")
# ------------------------

# Input for the plot: the renumbered test sessions (they carry 'phase').
# They are expected to belong to the animal_id being processed.
test_data_for_plot = renumbered_test_sessions

print(f"\nGenerating Phase Comparison by Previous Reward Size plot for {animal_id}...")
print(f"Plots will be saved to: {figures_dir_absolute}")

figure_generated = False
try:
    plt.close('all') # Start from a clean figure state

    phase_comparison_path = visualize_phase_comparison_by_prev_reward(
        test_sessions=test_data_for_plot,
        animal_id=animal_id,
        output_dir=figures_dir_absolute
    )

    if not (phase_comparison_path and os.path.exists(phase_comparison_path)):
        print(f"Skipping display as figure was not generated or file not found for {animal_id}.")
    else:
        print(f"--- Phase Comparison by Previous Reward Figure for {animal_id} ---")
        display(Image(filename=phase_comparison_path))
        # Only counted as generated once the image was displayed as well.
        figure_generated = True

except Exception as plot_error:
    # NameError gets its own message; everything else is reported per animal.
    if isinstance(plot_error, NameError):
        print(f"\nERROR during plot generation (NameError): {plot_error}")
    else:
        print(f"\nERROR during plot generation for animal {animal_id}: {plot_error}")
    traceback.print_exc()

print(
    f"\nSuccessfully generated Phase Comparison by Previous Reward plot for {animal_id}."
    if figure_generated
    else f"\nFailed to generate Phase Comparison by Previous Reward plot for {animal_id}."
)


In [None]:

# Cell 9: Generate Time-Resolved Licking Plots (By Phase and By Reward Size)

# --- Imports and Setup ---
import os
import matplotlib.pyplot as plt
from IPython.display import display, Image
import traceback

# --- Ensure phase-aware module is loaded and functions exist ---
try:
    # Import the required functions
    from modules.phase_aware_visualization import (
        visualize_time_resolved_licking_by_phase,
        visualize_time_resolved_licking_by_reward_size
    )
    print("Successfully imported time-resolved licking functions.")
except ImportError as e:
    print(f"ERROR: Could not import function(s): {e}")
    print("Please ensure 'modules/phase_aware_visualization.py' contains the necessary functions.")
    # Re-raise explicitly (rather than `assert False`) so the cell still stops
    # under `python -O`, and keep the original import failure as the cause.
    raise ImportError("ImportError for time-resolved licking visualizations.") from e
# ------------------------------------------------------------

# --- Dependency Checks ---
# Fail fast when an earlier cell has not defined the inputs these plots need.
# Explicit raises are used instead of `assert False` so the checks cannot be
# stripped when assertions are disabled.
if 'renumbered_test_sessions' not in locals() or not renumbered_test_sessions:
    raise RuntimeError("'renumbered_test_sessions' list not found. Ensure Cell 6 ran.")
if 'animal_id' not in locals():
    raise RuntimeError("animal_id not defined.")
if 'figures_dir_absolute' not in locals():
    raise RuntimeError("figures_dir_absolute not defined.")
# ------------------------

# Use the renumbered test sessions which include 'phase' info
# Ensure this data pertains to the animal_id being processed
test_data_for_plots = renumbered_test_sessions

print(f"\nGenerating Time-Resolved Licking Plots for {animal_id}...")
print(f"Plots will be saved to: {figures_dir_absolute}")

# --- Plot 1: Grouped by Current Reward Size (5x1) ---
print("\n1. Generating Time-Resolved Licking by Phase (Grouped by Current Reward)...")
# Flipped to True only once the returned figure path is confirmed on disk.
figure_1_generated = False
try:
    plt.close('all')  # Close previous plots

    time_resolved_fig_path_1 = visualize_time_resolved_licking_by_phase(
        test_sessions=test_data_for_plots,
        animal_id=animal_id,
        output_dir=figures_dir_absolute
    )

    if time_resolved_fig_path_1 and os.path.exists(time_resolved_fig_path_1):
        print(f"--- Time-Resolved Licking by Phase (Grouped by Current Reward) Figure for {animal_id} ---")
        display(Image(filename=time_resolved_fig_path_1))
        figure_1_generated = True
    else:
        print(f"Skipping display as figure 1 was not generated or file not found for {animal_id}.")

except NameError as ne:
    print(f"\nERROR during Plot 1 generation (NameError): {ne}")
    traceback.print_exc()
except Exception as e:
    # Cell-level boundary: report and continue so Plot 2 is still attempted.
    print(f"\nERROR during Plot 1 generation for animal {animal_id}: {e}")
    traceback.print_exc()

# --- Plot 2: Grouped by Phase (1x4) ---
print("\n2. Generating Time-Resolved Licking by Reward Size (Grouped by Phase)...")
figure_2_generated = False
try:
    plt.close('all')  # Close previous plots

    time_resolved_fig_path_2 = visualize_time_resolved_licking_by_reward_size(
        test_sessions=test_data_for_plots,
        animal_id=animal_id,
        output_dir=figures_dir_absolute
    )

    if time_resolved_fig_path_2 and os.path.exists(time_resolved_fig_path_2):
        print(f"--- Time-Resolved Licking by Reward Size (Grouped by Phase) Figure for {animal_id} ---")
        display(Image(filename=time_resolved_fig_path_2))
        figure_2_generated = True
    else:
        print(f"Skipping display as figure 2 was not generated or file not found for {animal_id}.")

except NameError as ne:
    print(f"\nERROR during Plot 2 generation (NameError): {ne}")
    traceback.print_exc()
except Exception as e:
    # Cell-level boundary: report the failure and fall through to the summary.
    print(f"\nERROR during Plot 2 generation for animal {animal_id}: {e}")
    traceback.print_exc()


# --- Summary ---
print("\nFinished generating Time-Resolved Licking plots.")
if figure_1_generated:
    print("  - Plot 1 (Grouped by Current Reward) generated.")
else:
    print("  - Plot 1 (Grouped by Current Reward) FAILED.")
if figure_2_generated:
    print("  - Plot 2 (Grouped by Phase) generated.")
else:
    print("  - Plot 2 (Grouped by Phase) FAILED.")


In [None]:

# Cell 10: Focused Pre-Stress vs. Stress Comparison for Previous Extremes Summary

# --- Imports and Setup (ensure these are loaded from previous cells) ---
# import os # Should be imported (e.g. Cell 2)
# import matplotlib.pyplot as plt # Should be imported (e.g. Cell 2)
# from IPython.display import display, Image # Should be imported (e.g. Cell 5, 7)
# import traceback # Already imported in Cell 2
# Ensure `visualize_cross_session_prev_extremes_stress_summary` is available (imported in Cell 2)
# Ensure `renumbered_test_sessions`, `animal_id`, `figures_dir_absolute` are available

# The session list is read outside the try-block below, so check it up front;
# otherwise a missing variable surfaces as a bare NameError without guidance.
if 'renumbered_test_sessions' not in locals():
    raise RuntimeError("'renumbered_test_sessions' is not defined. Ensure Cell 6 ran.")

print("\n--- Generating Focused Previous Extremes Summary: Initial Test vs. Stress Test (Pre-PL37) ---")

# --- Define phases of interest for this specific comparison ---
PRE_STRESS_PHASE_NAME = "Initial Test"
STRESS_PHASE_NAME = "Stress Test (Pre-PL37)"
target_phases_for_comparison = [PRE_STRESS_PHASE_NAME, STRESS_PHASE_NAME]

# --- Filter session data for the specified phases ---
# Assumes 'renumbered_test_sessions' is a list of dictionaries, each with a 'phase' key.
focused_sessions_data = [
    session for session in renumbered_test_sessions
    if session.get('phase') in target_phases_for_comparison
]

# --- Report the number of sessions included in each condition ---
num_pre_stress_sessions = sum(1 for session in focused_sessions_data if session.get('phase') == PRE_STRESS_PHASE_NAME)
num_stress_sessions = sum(1 for session in focused_sessions_data if session.get('phase') == STRESS_PHASE_NAME)

print(f"Number of '{PRE_STRESS_PHASE_NAME}' (pre-stress) sessions included: {num_pre_stress_sessions}")
print(f"Number of '{STRESS_PHASE_NAME}' (stress) sessions included: {num_stress_sessions}")

# --- Generate and Display the Plot ---
# Both conditions must contribute at least one session; a one-sided comparison
# is skipped rather than plotted.
if not focused_sessions_data:
    print(f"Warning: No sessions found for the specified phases: {target_phases_for_comparison}. Skipping plot generation.")
elif num_pre_stress_sessions == 0 or num_stress_sessions == 0:
    print(f"Warning: At least one of the required phases ({PRE_STRESS_PHASE_NAME} or {STRESS_PHASE_NAME}) has no sessions. Plot might be misleading or fail. Skipping plot generation.")
else:
    try:
        plt.close('all')  # Close any pre-existing Matplotlib figures

        print(f"\nGenerating focused 'previous extremes stress summary' plot for animal: {animal_id}...")

        # Call the visualization function with the filtered data
        # This function is expected to save the plot and return its path.
        # NOTE(review): the same function and output directory are presumably
        # used for the unfiltered summary, so the file may be overwritten —
        # confirm, or pass a filename suffix if the function supports one.
        fig_path_focused = visualize_cross_session_prev_extremes_stress_summary(
            focused_sessions_data,  # Use the filtered data
            animal_id,
            figures_dir_absolute
            # If your function supports a suffix to make the filename unique for this specific plot:
            # plot_suffix="_initial_vs_stress_prePL37"
        )

        if fig_path_focused and os.path.exists(fig_path_focused):
            print(f"  Focused plot generated: {fig_path_focused}")
            print(f"  Displaying focused plot for animal {animal_id}...")
            display(Image(filename=fig_path_focused))  # This line displays the image
        else:
            print("  - Focused 'previous extremes stress summary' plot was not generated or the file was not found.")
            if fig_path_focused:
                print(f"    Expected path: {fig_path_focused}")

    except NameError as ne:
        print(f"  ERROR during focused plot generation (NameError): {ne}")
        print("  Please ensure `renumbered_test_sessions`, `animal_id`, `figures_dir_absolute`, "
              "`plt`, `os`, `display`, `Image`, `traceback`, "
              "and the `visualize_cross_session_prev_extremes_stress_summary` function are defined and available.")
        traceback.print_exc()
    except Exception as e:
        # Cell-level boundary: report the failure and let the notebook continue.
        print(f"  ERROR during focused plot generation for animal {animal_id}: {e}")
        traceback.print_exc()

print("\n--- Finished generating focused previous extremes summary plot ---")

In [None]:
# Cell 10: Calculate Averaged Time-Resolved Lick Traces (Simplified - Direct Mapping)

import numpy as np
import pandas as pd 
from scipy.stats import sem
import os
import time 
# Imports needed for plotting if you uncomment the sanity check plots:
import matplotlib.pyplot as plt
from IPython.display import display, Image
import matplotlib.colors as mcolors # Needed for reward colors in plot 2

print("\n--- Initializing Simplified Calculation of Averaged Lick Traces ---")

# --- 1. Define Constants and Configuration ---

# Assuming animal_id and figures_dir_absolute are available from previous cells.
# Fallback definitions if not found (for robustness, but ideally they are set)
if 'animal_id' not in locals():
    print("Warning: 'animal_id' not found in global scope. Using a placeholder.")
    animal_id = "unknown_animal"
if 'figures_dir_absolute' not in locals():
    print("Warning: 'figures_dir_absolute' not found. Using current directory for saving plots.")
    figures_dir_absolute = "."
    # Ensure the base figures directory exists if using fallback
    os.makedirs(figures_dir_absolute, exist_ok=True)


# Phase Definitions
# Only sessions whose 'phase' is one of these names are pooled later in this cell.
PHASE_ORDER = ["Initial Test", "Stress Test (Pre-PL37)", "PL37 Test", "Post-PL37 Test"]
# Define PHASE_COLORS here for self-containment, or ensure it's loaded
if 'PHASE_COLORS' not in locals():
    print("Defining default PHASE_COLORS")
    PHASE_COLORS = {
        "Initial Test": "cornflowerblue",
        "Stress Test (Pre-PL37)": "salmon",
        "PL37 Test": "mediumseagreen",
        "Post-PL37 Test": "orchid",
        "Other/Uncategorized": "grey"
    }

# Reward Definitions
# Your data uses numerical sizes 1, 2, 3, 4, 5
REWARD_CATEGORY_STRINGS = ["Very Small", "Small", "Medium", "Large", "Very Large"]
# Direct map from your numerical size to the desired string category
NUMERICAL_SIZE_TO_CATEGORY_STRING_MAP = {
    1: "Very Small",
    2: "Small",
    3: "Medium",
    4: "Large",
    5: "Very Large"
}
# Map string category back to numerical size (needed for consistent colormapping)
CATEGORY_STRING_TO_NUMERICAL_SIZE_MAP = {v: k for k, v in NUMERICAL_SIZE_TO_CATEGORY_STRING_MAP.items()}

# Define REWARD_COLORS here for self-containment, or ensure it's loaded
if 'REWARD_COLORS' not in locals():
    print("Defining default REWARD_COLORS")
    REWARD_COLORS = {
        "Very Small": "#FDBA74",
        "Small":      "#A5B4FC",
        "Medium":     "#6EE7B7",
        "Large":      "#FDE047",
        "Very Large": "#F0ABFC"
    }
# List of numerical sizes we expect based on the map keys
EXPECTED_NUMERICAL_SIZES = list(NUMERICAL_SIZE_TO_CATEGORY_STRING_MAP.keys())

# Analysis Parameters (matching module for consistency)
ANALYSIS_START_TIME_REL_REWARD = -2.0
ANALYSIS_END_TIME_REL_REWARD = 8.0
WINDOW_SIZE = 0.1  # Rolling window for smoothing hires trace
STEP_SIZE = 0.02   # Step size for final time axis
TIME_RESOLUTION = 0.001  # Internal hires time resolution

# Shared Calculation Setup (matching module)
# Output time axis: one sample every STEP_SIZE, from start + STEP_SIZE up to
# and including the analysis end.
time_points_relative_to_reward = np.arange(ANALYSIS_START_TIME_REL_REWARD + STEP_SIZE,
                                           ANALYSIS_END_TIME_REL_REWARD + STEP_SIZE,
                                           STEP_SIZE)
# np.arange with a float step can overshoot by one element; trim anything past
# the analysis end (the 1e-9 slack absorbs floating-point noise).
time_points_relative_to_reward = time_points_relative_to_reward[
    time_points_relative_to_reward <= ANALYSIS_END_TIME_REL_REWARD + 1e-9
]
num_target_time_points = len(time_points_relative_to_reward)

# Length of the high-resolution trace covering the whole analysis window.
hires_trace_duration_rel_reward = ANALYSIS_END_TIME_REL_REWARD - ANALYSIS_START_TIME_REL_REWARD
num_hires_points_total_analysis = int(np.round(hires_trace_duration_rel_reward / TIME_RESOLUTION))

# For every output time point, the index into the high-resolution trace that is
# sampled for it (the sample ending at that time, hence the -1), clipped to the
# valid index range.
target_indices_in_hires = np.round(
    (time_points_relative_to_reward - ANALYSIS_START_TIME_REL_REWARD) / TIME_RESOLUTION
).astype(int) - 1
target_indices_in_hires = np.clip(target_indices_in_hires, 0, num_hires_points_total_analysis - 1)

# Smoothing window expressed in high-resolution samples; never less than one.
win_samples = max(1, int(np.round(WINDOW_SIZE / TIME_RESOLUTION)))

# --- 2. Copied Helper Function `calculate_single_trial_ts` from Module ---
# (Kept in this cell because Cell 11 calls it as well.)
def calculate_single_trial_ts(mtrial, reward_latency):
    """Return one trial's smoothed lick trace, aligned to reward delivery.

    Lick onset/offset times are read from ``mtrial['lick_downs_relative']``
    and ``mtrial['lick_ups_relative']`` and painted into a binary trace sampled
    every ``TIME_RESOLUTION``. The stretch from
    ``ANALYSIS_START_TIME_REL_REWARD`` to ``ANALYSIS_END_TIME_REL_REWARD``
    around ``reward_latency`` is cut out (zero-padded where the trial trace
    does not reach), smoothed with a centred rolling mean of ``win_samples``
    samples and read out at ``target_indices_in_hires``.

    Args:
        mtrial: Mapping holding the lick time lists (missing or ``None``
            entries are treated as "no licks").
        reward_latency: Reward time on the same clock as the lick times
            (presumably seconds from trial start — TODO confirm).

    Returns:
        Array of length ``num_target_time_points`` with the fraction of time
        spent licking around each output time point; all-NaN if the extracted
        window does not have the expected length.
    """
    numeric_types = (int, float, np.number)

    onsets = mtrial.get('lick_downs_relative', [])
    offsets = mtrial.get('lick_ups_relative', [])
    if onsets is None:
        onsets = []
    if offsets is None:
        offsets = []

    # Analysis window expressed on the trial clock.
    window_start = reward_latency + ANALYSIS_START_TIME_REL_REWARD
    window_end = reward_latency + ANALYSIS_END_TIME_REL_REWARD

    # Binary trace long enough to reach one sample past the window end.
    n_trial_samples = int(np.ceil((window_end + TIME_RESOLUTION) / TIME_RESOLUTION))
    trial_trace = np.zeros(n_trial_samples, dtype=np.int8)

    # Non-numeric entries are dropped, then the i-th onset is paired with the
    # i-th offset after sorting each list independently.
    # NOTE(review): an unmatched onset or offset would shift every later
    # pairing — confirm the upstream lists are always balanced.
    onsets_sorted = sorted(t for t in onsets if isinstance(t, numeric_types))
    offsets_sorted = sorted(t for t in offsets if isinstance(t, numeric_types))

    for onset, offset in zip(onsets_sorted, offsets_sorted):
        if not (offset > onset):
            continue
        first = int(np.floor((onset + 1e-9) / TIME_RESOLUTION))
        last = int(np.ceil(offset / TIME_RESOLUTION))
        first = max(0, first)
        last = min(n_trial_samples, last)
        if first < last:
            trial_trace[first:last] = 1

    # Cut the analysis window out of the trial trace, padding with zeros on
    # either side where the window extends beyond the trace.
    window_first = int(np.round(window_start / TIME_RESOLUTION))
    window_stop = window_first + num_hires_points_total_analysis
    segment = trial_trace[max(0, window_first):min(len(trial_trace), window_stop)]
    pad_left = max(0, -window_first)
    pad_right = max(0, num_hires_points_total_analysis - (len(segment) + pad_left))
    hires_window = np.pad(segment, (pad_left, pad_right), 'constant', constant_values=0)

    if len(hires_window) != num_hires_points_total_analysis:
        return np.full(num_target_time_points, np.nan)

    # Centred rolling mean turns the binary trace into a lick proportion.
    smoothed = (
        pd.Series(hires_window)
        .rolling(window=win_samples, min_periods=1, center=True)
        .mean()
        .to_numpy()
    )

    # Read the smoothed trace at the output time points; anything that falls
    # outside the trace stays NaN.
    in_range = (target_indices_in_hires >= 0) & (target_indices_in_hires < len(smoothed))
    result = np.full(num_target_time_points, np.nan)
    if np.any(in_range):
        result[in_range] = smoothed[target_indices_in_hires[in_range]]
    return result

# --- 3. Data Pooling and Aggregation Function (Direct Mapping) ---
def calculate_averaged_lick_traces_direct(sessions_data,
                                         size_to_cat_map,
                                         target_reward_cats,
                                         target_phases,
                                         time_axis,
                                         single_trial_func):
    """Pool per-trial lick traces and average them per (phase, reward category).

    Only rewarded trials from sessions whose ``'phase'`` is in
    ``target_phases`` are used. Each trial's ``'reward_size'`` is translated
    through ``size_to_cat_map``; trials whose size cannot be mapped (including
    unhashable values such as arrays), or that have no ``'reward_latency'``,
    are skipped and counted in the printed summary.

    Args:
        sessions_data: Iterable of session dicts with ``'phase'`` and
            ``'trial_results'`` (a list of trial dicts).
        size_to_cat_map: Mapping from numerical reward size to category string.
        target_reward_cats: Category strings to aggregate.
        target_phases: Phase names to aggregate.
        time_axis: Output time axis; its length defines the expected trace
            shape, and the same object is stored in every result entry.
        single_trial_func: Callable ``(trial_dict, reward_latency) -> trace``.
            Exceptions it raises are counted as calculation errors, not
            propagated.

    Returns:
        Dict keyed by ``(phase, reward_category)`` for every requested
        combination. Each value holds ``'time_axis'``,
        ``'mean_lick_probability'``, ``'sem_lick_probability'`` and
        ``'n_trials'``. Combinations without data get all-NaN arrays and
        ``n_trials == 0``; a single trial gets a SEM of zeros.
    """
    print("Pooling single trial time series data (Direct Mapping)...")
    aggregated_trial_traces = {
        (phase, reward_cat): []
        for phase in target_phases
        for reward_cat in target_reward_cats
    }

    sessions_processed_count = 0
    trials_pooled_count = 0
    skipped_reward_map = 0
    skipped_latency = 0
    skipped_calc_error = 0
    first_calc_error = None  # kept so failures are not completely silent

    for session in sessions_data:
        session_phase = session.get('phase')
        trial_results = session.get('trial_results')
        if not session_phase or session_phase not in target_phases:
            continue
        if not trial_results or not isinstance(trial_results, list):
            continue
        sessions_processed_count += 1

        for trial_dict in trial_results:
            if not isinstance(trial_dict, dict) or not trial_dict.get('rewarded', False):
                continue
            numerical_reward = trial_dict.get('reward_size')
            reward_latency = trial_dict.get('reward_latency')

            # Map numerical size to string category. An unhashable size (e.g.
            # an array left over from a .mat load) cannot be looked up; treat
            # it as unmappable instead of aborting the whole pooling run.
            try:
                reward_category_str = size_to_cat_map.get(numerical_reward)
            except TypeError:
                reward_category_str = None

            if not reward_category_str or reward_category_str not in target_reward_cats:
                skipped_reward_map += 1
                continue
            if reward_latency is None:
                skipped_latency += 1
                continue

            # Best-effort per trial: one bad trial must not abort the run, but
            # the first failure is remembered and reported below.
            try:
                trial_trace = single_trial_func(trial_dict, reward_latency)
            except Exception as exc:
                skipped_calc_error += 1
                if first_calc_error is None:
                    first_calc_error = exc
                continue

            if trial_trace is not None and not np.all(np.isnan(trial_trace)):
                aggregated_trial_traces[(session_phase, reward_category_str)].append(trial_trace)
                trials_pooled_count += 1
            else:
                skipped_calc_error += 1

    print(f"Pooling Summary: Processed {sessions_processed_count} sessions, pooled {trials_pooled_count} traces. (Skipped: RewardMap={skipped_reward_map}, Latency={skipped_latency}, CalcErr/NaN={skipped_calc_error})")
    if first_calc_error is not None:
        print(f"  First trace-calculation error encountered: {first_calc_error!r}")

    averaged_lick_traces_final = {}
    num_target_time_points_agg = len(time_axis)
    for (phase, reward_cat), traces_list in aggregated_trial_traces.items():
        # Keep only traces that match the output axis and carry some data.
        valid_traces = [
            tr for tr in traces_list
            if isinstance(tr, np.ndarray)
            and tr.shape == (num_target_time_points_agg,)
            and not np.all(np.isnan(tr))
        ]
        n_valid_trials = len(valid_traces)
        mean_ts = np.full(num_target_time_points_agg, np.nan)
        sem_ts = np.full(num_target_time_points_agg, np.nan)
        if n_valid_trials > 0:
            traces_array = np.array(valid_traces)
            mean_ts = np.nanmean(traces_array, axis=0)
            if n_valid_trials > 1:
                with np.errstate(invalid='ignore'):
                    sem_ts = sem(traces_array, axis=0, nan_policy='omit')
                # Time points where the SEM is undefined are reported as 0.
                sem_ts = np.nan_to_num(sem_ts, nan=0.0)
            else:
                # SEM is undefined for a single trial; report zero spread.
                sem_ts = np.zeros_like(mean_ts)
        averaged_lick_traces_final[(phase, reward_cat)] = {
            'time_axis': time_axis,
            'mean_lick_probability': mean_ts,
            'sem_lick_probability': sem_ts,
            'n_trials': n_valid_trials,
        }
    print(f"Aggregation complete. Produced {len(averaged_lick_traces_final)} averaged lick traces.")
    return averaged_lick_traces_final

# --- 4. Execute Data Calculation ---
# Start from an empty result so the name exists even if the inputs are missing.
averaged_lick_traces = {}
# The session list comes from Cell 6; only run the pooling when it is present
# and non-empty.
if 'renumbered_test_sessions' in locals() and renumbered_test_sessions:
    start_calc_time = time.time()
    averaged_lick_traces = calculate_averaged_lick_traces_direct(
        sessions_data=renumbered_test_sessions,
        size_to_cat_map=NUMERICAL_SIZE_TO_CATEGORY_STRING_MAP,  # direct size -> category map
        target_reward_cats=REWARD_CATEGORY_STRINGS,
        target_phases=PHASE_ORDER,
        time_axis=time_points_relative_to_reward,  # final output time axis
        single_trial_func=calculate_single_trial_ts,  # helper copied from the module
    )
    end_calc_time = time.time()
    print(f"\nCalculation finished in {end_calc_time - start_calc_time:.2f} seconds.")
    print(f"Final `averaged_lick_traces` dictionary created with {len(averaged_lick_traces)} entries.")
else:
    print("ERROR: `renumbered_test_sessions` is not defined. Please run previous cells (especially Cell 6).")

# --- 5. Sanity Check Plotting (Optional - Uncomment to Run) ---
# (Plotting code remains the same as the last corrected version, using Viridis for Plot 2)

# print("\\n--- Generating Sanity Check Plots (Optional) ---")
# if 'averaged_lick_traces' in locals() and averaged_lick_traces: # Check if dict exists and is not empty
#     os.makedirs(figures_dir_absolute, exist_ok=True) # Ensure dir exists
#     # Plot 1: Licking by Phase (grouped by Current Reward Category)
#     if REWARD_CATEGORY_STRINGS:
#         num_reward_cats = len(REWARD_CATEGORY_STRINGS)
#         fig1_height = max(5, 2.5 * num_reward_cats)
#         fig1, axes1 = plt.subplots(num_reward_cats, 1, figsize=(12, fig1_height), sharex=True, sharey=True)
#         if num_reward_cats == 1: axes1 = [axes1]
#
#         for i, reward_cat in enumerate(REWARD_CATEGORY_STRINGS):
#             ax = axes1[i] ; ax.set_title(f"Reward: {reward_cat}"); ax.set_ylabel("Mean Lick Prob.")
#             ax.grid(True, linestyle=':', alpha=0.7); ax.axvline(0, color='k', linestyle='--', lw=1, alpha=0.8, label="Reward (t=0)")
#             ax.set_ylim(bottom=-0.05) ; has_data_in_subplot = False
#             for phase in PHASE_ORDER:
#                 data = averaged_lick_traces.get((phase, reward_cat)) 
#                 if data and data['n_trials'] > 0 and not np.all(np.isnan(data['mean_lick_probability'])):
#                     ax.plot(data['time_axis'], data['mean_lick_probability'], label=f"{phase} (n={data['n_trials']})", color=PHASE_COLORS.get(phase, 'k'), linewidth=1.5) # Uses PHASE_COLORS
#                     if data['n_trials'] > 1 and not np.all(np.isnan(data['sem_lick_probability'])):
#                         ax.fill_between(data['time_axis'], data['mean_lick_probability'] - data['sem_lick_probability'], data['mean_lick_probability'] + data['sem_lick_probability'], color=PHASE_COLORS.get(phase, 'k'), alpha=0.2)
#                     has_data_in_subplot = True
#             if has_data_in_subplot: ax.legend(fontsize='small', loc='upper right')
#
#         if num_reward_cats > 0 : axes1[-1].set_xlabel("Time from Reward (s)"); axes1[-1].set_xlim(time_points_relative_to_reward[0], time_points_relative_to_reward[-1])
#         fig1.suptitle(f"{animal_id} - Sanity Check: Licking by Phase (Grouped by Current Reward)", fontsize=16, y=0.99)
#         plt.tight_layout(rect=[0, 0.03, 1, 0.96])
#         plot1_filename_sc = os.path.join(figures_dir_absolute, f"{animal_id}_sanity_check_lick_by_phase.png")
#         plt.savefig(plot1_filename_sc); print(f"Saved Sanity Check Plot 1: {plot1_filename_sc}"); 
#         display(Image(filename=plot1_filename_sc)); 
#         plt.close(fig1) # Close figure object
#
#     # Plot 2: Licking by Reward Category (grouped by Phase)
#     if PHASE_ORDER:
#         num_phases = len(PHASE_ORDER)
#         fig2_height = max(5, 2.5 * num_phases)
#         fig2, axes2 = plt.subplots(num_phases, 1, figsize=(12, fig2_height), sharex=True, sharey=True)
#         if num_phases == 1: axes2 = [axes2]
#
#         reward_size_cmap = plt.cm.viridis 
#         num_sizes_for_norm = [k for k,v in NUMERICAL_SIZE_TO_CATEGORY_STRING_MAP.items() if v in REWARD_CATEGORY_STRINGS]
#         if num_sizes_for_norm: reward_size_norm = mcolors.Normalize(vmin=min(num_sizes_for_norm), vmax=max(num_sizes_for_norm))
#         else: reward_size_norm = mcolors.Normalize(vmin=1, vmax=5) 
#
#         for i, phase in enumerate(PHASE_ORDER):
#             ax = axes2[i] ; ax.set_title(f"Phase: {phase}"); ax.set_ylabel("Mean Lick Prob.")
#             ax.grid(True, linestyle=':', alpha=0.7); ax.axvline(0, color='k', linestyle='--', lw=1, alpha=0.8, label="Reward (t=0)")
#             ax.set_ylim(bottom=-0.05) ; has_data_in_subplot = False
#             
#             for reward_cat in REWARD_CATEGORY_STRINGS: 
#                 data = averaged_lick_traces.get((phase, reward_cat))
#                 if data and data['n_trials'] > 0 and not np.all(np.isnan(data['mean_lick_probability'])):
#                     numerical_size = CATEGORY_STRING_TO_NUMERICAL_SIZE_MAP.get(reward_cat) 
#                     if numerical_size is not None: plot_color = reward_size_cmap(reward_size_norm(numerical_size))
#                     else: plot_color = 'grey' 
#
#                     ax.plot(data['time_axis'], data['mean_lick_probability'], label=f"{reward_cat} (n={data['n_trials']})", color=plot_color, linewidth=1.5)
#                     if data['n_trials'] > 1 and not np.all(np.isnan(data['sem_lick_probability'])):
#                         ax.fill_between(data['time_axis'], data['mean_lick_probability'] - data['sem_lick_probability'], data['mean_lick_probability'] + data['sem_lick_probability'], color=plot_color, alpha=0.2)
#                     has_data_in_subplot = True
#             if has_data_in_subplot: ax.legend(fontsize='small', loc='upper right')
#
#         if num_phases > 0: axes2[-1].set_xlabel("Time from Reward (s)"); axes2[-1].set_xlim(time_points_relative_to_reward[0], time_points_relative_to_reward[-1])
#         fig2.suptitle(f"{animal_id} - Sanity Check: Licking by Reward (Grouped by Phase)", fontsize=16, y=0.99)
#         plt.tight_layout(rect=[0, 0.03, 1, 0.96])
#         plot2_filename_sc = os.path.join(figures_dir_absolute, f"{animal_id}_sanity_check_lick_by_reward.png")
#         plt.savefig(plot2_filename_sc); print(f"Saved Sanity Check Plot 2: {plot2_filename_sc}"); 
#         display(Image(filename=plot2_filename_sc)); 
#         plt.close(fig2) # Close figure object
# else:
#    print("No averaged lick traces calculated. Skipping sanity check plots.")


# --- End of Cell ---
# Only announce readiness when the calculation actually produced entries; the
# dictionary stays empty when the session list was missing.
if averaged_lick_traces:
    print("\nVariable `averaged_lick_traces` is ready for AUC analysis in the next cell.")
else:
    print("\nWARNING: `averaged_lick_traces` is empty; the AUC analysis in the next cell will have no data.")


In [None]:
# Cell 11: Calculate and Visualize AUC with SEM (Adding Session/Trial Counts)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from scipy.stats import sem
from collections import defaultdict # Use defaultdict for easier list creation
from IPython.display import display, Image

print("\n--- Initializing AUC Calculation with SEM and Visualization ---")

# --- Prerequisites (Ensure these variables/functions are available from Cell 10) ---
# Needed:
#   renumbered_test_sessions, NUMERICAL_SIZE_TO_CATEGORY_STRING_MAP,
#   REWARD_CATEGORY_STRINGS, PHASE_ORDER, time_points_relative_to_reward,
#   calculate_single_trial_ts (function), animal_id, figures_dir_absolute,
#   PHASE_COLORS, REWARD_COLORS (or redefine below)

# Check prerequisites exist
# Every problem is reported (not just the first) and the result is recorded in
# `prerequisites_ok`, which the rest of the cell uses to decide whether to run.
prerequisites_ok = True
required_vars = ['renumbered_test_sessions', 'NUMERICAL_SIZE_TO_CATEGORY_STRING_MAP',
                 'REWARD_CATEGORY_STRINGS', 'PHASE_ORDER', 'time_points_relative_to_reward',
                 'calculate_single_trial_ts', 'animal_id', 'figures_dir_absolute',
                 'PHASE_COLORS', 'REWARD_COLORS']
for var in required_vars:
    # A name counts as missing when it is undefined or is an empty list/dict.
    if var not in locals() or (isinstance(locals()[var], (list, dict)) and not locals()[var]):
        print(f"ERROR: Prerequisite '{var}' is missing or empty. Ensure Cell 10 ran successfully.")
        prerequisites_ok = False
# The helper must also be callable, not merely defined.
if not callable(locals().get('calculate_single_trial_ts')):
    print("ERROR: Prerequisite function 'calculate_single_trial_ts' is missing.")
    prerequisites_ok = False

# --- 1. Define AUC Time Windows ---
# (start, end) bounds relative to reward delivery at t = 0, on the same time
# axis as `time_points_relative_to_reward`.
AUC_WINDOWS = {
    "Pre-Reward": (-2.0, 0.0),
    "Consumption": (0.0, 0.5),
    "Post-Consumption": (0.5, 6.0)
}
print(f"Using AUC Windows: {AUC_WINDOWS}")

# --- 2. AUC Calculation Function (for single trace) ---
def calculate_auc_single_trace(time_axis, lick_trace, window_start, window_end):
    """Calculate the area under a single lick trace within a time window.

    The trace is integrated with the trapezoidal rule over all samples whose
    time lies in ``[window_start, window_end]``. The bounds are compared with a
    small tolerance so that floating-point noise in a time axis built with a
    float step does not drop the sample sitting exactly on a window edge.

    Args:
        time_axis: 1-D sequence of sample times.
        lick_trace: 1-D sequence of trace values, same length as ``time_axis``.
        window_start: Lower bound of the window (inclusive).
        window_end: Upper bound of the window (inclusive).

    Returns:
        The area as a float; ``0.0`` when fewer than two samples fall inside
        the window; ``nan`` when the windowed trace contains NaN or the inputs
        cannot be interpreted as matching numeric arrays.
    """
    # np.trapz is deprecated in NumPy 2.x in favour of np.trapezoid. Resolve
    # the integrator outside the try-block so that a missing function fails
    # loudly instead of being turned into NaN for every trial.
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    edge_tolerance = 1e-9

    try:
        times = np.asarray(time_axis, dtype=float)
        trace = np.asarray(lick_trace, dtype=float)
        in_window = (times >= window_start - edge_tolerance) & (times <= window_end + edge_tolerance)
        # At least two samples are needed to span any area.
        if np.count_nonzero(in_window) < 2:
            return 0.0
        trace_subset = trace[in_window]
        if np.any(np.isnan(trace_subset)):
            return np.nan
        return float(integrate(trace_subset, times[in_window]))
    except (TypeError, ValueError, IndexError):
        # Non-numeric input or mismatched lengths: report "no value".
        return np.nan

# --- 3. Recalculate/Aggregate Trial Traces and Compute AUC Stats ---
# `time` is needed for the elapsed-time report below and is not part of this
# cell's import block, so import it here to keep the cell self-contained.
import time

auc_stats_results = []

if prerequisites_ok:
    print("Processing sessions to calculate AUC statistics per trial...")
    start_auc_calc_time = time.time()

    # Key:   (phase, reward_cat, window_name)
    # Value: list of (session_folder, trial_auc_value) tuples, one per trial
    #        that produced a non-NaN AUC for that window.
    trial_auc_values_by_condition_window = defaultdict(list)

    sessions_processed_count = 0
    trials_processed_count = 0
    trials_failed_count = 0      # trials whose trace/AUC computation raised
    first_trial_error = None     # kept so the failure report is actionable
    time_axis_for_auc = time_points_relative_to_reward

    for session in renumbered_test_sessions:
        session_phase = session.get('phase')
        # Fall back to a synthetic ID so sessions without a folder are still counted separately.
        session_folder = session.get('session_folder', f'UnknownSession_{sessions_processed_count}')
        trial_results = session.get('trial_results')
        if not session_phase or session_phase not in PHASE_ORDER:
            continue
        if not trial_results or not isinstance(trial_results, list):
            continue
        sessions_processed_count += 1

        for trial_dict in trial_results:
            # Only rewarded trials with a known reward category and latency are analyzed.
            if not isinstance(trial_dict, dict) or not trial_dict.get('rewarded', False):
                continue
            numerical_reward = trial_dict.get('reward_size')
            reward_latency = trial_dict.get('reward_latency')
            reward_category_str = NUMERICAL_SIZE_TO_CATEGORY_STRING_MAP.get(numerical_reward)
            if not reward_category_str or reward_category_str not in REWARD_CATEGORY_STRINGS:
                continue
            if reward_latency is None:
                continue

            try:
                trial_trace = calculate_single_trial_ts(trial_dict, reward_latency)
                if trial_trace is not None and not np.all(np.isnan(trial_trace)):
                    trials_processed_count += 1
                    for window_name, (start_t, end_t) in AUC_WINDOWS.items():
                        trial_auc = calculate_auc_single_trace(time_axis_for_auc, trial_trace, start_t, end_t)
                        if not np.isnan(trial_auc):
                            condition_window_key = (session_phase, reward_category_str, window_name)
                            trial_auc_values_by_condition_window[condition_window_key].append(
                                (session_folder, trial_auc))
            except Exception as trial_error:
                # Best effort: one bad trial must not abort the analysis, but
                # it must not vanish silently either. Count it and report below.
                trials_failed_count += 1
                if first_trial_error is None:
                    first_trial_error = f"{session_folder}: {type(trial_error).__name__}: {trial_error}"

    print(f"Processed {sessions_processed_count} sessions, calculated AUCs for {trials_processed_count} valid trial traces.")
    if trials_failed_count:
        print(f"WARNING: {trials_failed_count} trial(s) raised an error and were skipped "
              f"(first error: {first_trial_error}).")

    # Calculate Mean, SEM, N_trials, N_sessions
    print("Calculating statistics (Mean AUC, SEM AUC, N Trials, N Sessions)...")
    for (phase, reward_cat, window_name), auc_tuples_list in trial_auc_values_by_condition_window.items():
        auc_list = [item[1] for item in auc_tuples_list]     # AUC values only
        session_ids = [item[0] for item in auc_tuples_list]  # contributing session IDs

        n_trials = len(auc_list)            # number of trials with a valid AUC
        n_sessions = len(set(session_ids))  # number of distinct sessions contributing

        mean_auc_val = np.nan
        sem_auc_val = np.nan
        if n_trials > 0:
            mean_auc_val = np.mean(auc_list)
            # SEM is undefined for a single sample; it is reported as 0 in that case.
            sem_auc_val = sem(auc_list, nan_policy='omit') if n_trials > 1 else 0

        auc_stats_results.append({
            "phase": phase, "reward_category": reward_cat, "window": window_name,
            "n_sessions": n_sessions, "n_trials": n_trials,
            "mean_auc": mean_auc_val, "sem_auc": np.nan_to_num(sem_auc_val, nan=0.0)})

    end_auc_calc_time = time.time()
    print(f"AUC Statistics calculation finished in {end_auc_calc_time - start_auc_calc_time:.2f} seconds.")

# Convert results to DataFrame. Columns are named explicitly so the frame has the
# expected schema even when no results were produced (it is then still `.empty`).
auc_df = pd.DataFrame(
    auc_stats_results,
    columns=["phase", "reward_category", "window", "n_sessions", "n_trials", "mean_auc", "sem_auc"],
)

# --- 4. Print Summary of Session/Trial Counts ---
if not auc_df.empty:
    print("\n--- Summary of Data Points per Condition ---")

    # Rows become (window, phase); reward categories are spread across the columns
    # underneath each of the two count fields.
    condition_keys = ['window', 'phase', 'reward_category']
    counts_by_condition = auc_df.groupby(condition_keys)[['n_sessions', 'n_trials']].first()
    summary_grouped = counts_by_condition.unstack()

    # Impose the configured window/phase/reward ordering; combinations that never
    # occurred show up as missing values after the reindex.
    ordered_rows = pd.MultiIndex.from_product([AUC_WINDOWS.keys(), PHASE_ORDER])
    ordered_columns = pd.MultiIndex.from_product([['n_sessions', 'n_trials'], REWARD_CATEGORY_STRINGS])
    summary_grouped = summary_grouped.reindex(index=ordered_rows, columns=ordered_columns)

    # The table can be wide, so lift pandas' display limits just for this printout.
    with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.width', 1000):
        # Missing combinations are shown as a count of 0.
        summary_counts_as_int = summary_grouped.fillna(0).astype(int)
        print(summary_counts_as_int)
    print("--------------------------------------------")

# --- 5. Visualize AUC Results with SEM ---
# One grouped bar chart per AUC window: phases on the x-axis, one bar per reward
# category, error bars showing the SEM. Each figure is saved as a PNG and shown inline.
if not auc_df.empty:
    print("\nGenerating AUC Bar Plots with SEM...")

    # Define colors if not globally available
    if 'PHASE_COLORS' not in locals(): PHASE_COLORS = {p: plt.cm.tab10(i) for i, p in enumerate(PHASE_ORDER)}
    if 'REWARD_COLORS' not in locals(): REWARD_COLORS = {rc: plt.cm.viridis(i/len(REWARD_CATEGORY_STRINGS)) for i, rc in enumerate(REWARD_CATEGORY_STRINGS)}

    # Define Mapping for X-axis Labels (phases without an entry keep their own name)
    PHASE_DISPLAY_NAME_MAP = {
        "Initial Test": "Pre Stress", "Stress Test (Pre-PL37)": "Stress",
        "PL37 Test": "PL37", "Post-PL37 Test": "Post Stress" }
    phase_display_labels = [PHASE_DISPLAY_NAME_MAP.get(p, p) for p in PHASE_ORDER]

    num_phases = len(PHASE_ORDER)
    num_rewards = len(REWARD_CATEGORY_STRINGS)
    # The bars of one phase share 80% of the slot width; the rest is spacing.
    bar_width = 0.8 / num_rewards

    for window_name, window_times in AUC_WINDOWS.items():
        print(f"  Plotting for window: {window_name}")
        window_df = auc_df[auc_df['window'] == window_name].copy()

        try:
            pivot_mean = window_df.pivot_table(index='phase', columns='reward_category', values='mean_auc')
            pivot_sem = window_df.pivot_table(index='phase', columns='reward_category', values='sem_auc')
            # Force the configured ordering; absent combinations become NaN.
            pivot_mean = pivot_mean.reindex(index=PHASE_ORDER, columns=REWARD_CATEGORY_STRINGS)
            pivot_sem = pivot_sem.reindex(index=PHASE_ORDER, columns=REWARD_CATEGORY_STRINGS)
        except Exception as e:
            print(f"    Error pivoting data for window {window_name}: {e}. Skipping plot.")
            continue

        if pivot_mean.isnull().all().all():
            print(f"    No valid AUC data to plot for window {window_name}. Skipping plot.")
            continue

        fig, ax = plt.subplots(figsize=(max(8, 2 * num_phases), 6))
        try:
            x_indices = np.arange(num_phases)

            for i, reward_cat in enumerate(REWARD_CATEGORY_STRINGS):
                # Center the group of bars on each phase tick.
                offset = (i - (num_rewards - 1) / 2) * bar_width
                mean_values = pivot_mean[reward_cat].values
                sem_values = pivot_sem[reward_cat].fillna(0).values
                ax.bar(x_indices + offset, mean_values, bar_width, label=reward_cat,
                       color=REWARD_COLORS.get(reward_cat, 'grey'), yerr=sem_values, capsize=3,
                       error_kw={'elinewidth':1, 'capthick':1})

            ax.set_ylabel("Total Lick Activity (AUC)")
            ax.set_title(f"{animal_id} - Lick Activity during {window_name} ({window_times[0]}s to {window_times[1]}s)", fontsize=12)
            ax.set_xticks(x_indices)
            ax.set_xticklabels(phase_display_labels, rotation=30, ha='right', fontsize=9)
            ax.legend(title="Reward Category", bbox_to_anchor=(1.02, 1), loc='upper left', fontsize='small')
            ax.grid(True, axis='y', linestyle='--', alpha=0.6)
            ax.axhline(0, color='black', linewidth=0.7)
            ax.tick_params(axis='y', labelsize=9)

            # Leave room on the right for the legend placed outside the axes.
            fig.tight_layout(rect=[0, 0.05, 0.88, 0.95])

            auc_plot_filename = os.path.join(figures_dir_absolute, f"{animal_id}_AUC_{window_name.replace(' ', '_')}_withSEM.png")
            try:
                # Save this figure explicitly instead of pyplot's implicit "current" figure.
                fig.savefig(auc_plot_filename, dpi=150)
                print(f"  Saved AUC plot: {auc_plot_filename}")
                display(Image(filename=auc_plot_filename))
            except Exception as e_save:
                print(f"  Error saving AUC plot for {window_name}: {e_save}")
        finally:
            # Always release the figure, even if drawing failed part-way through.
            plt.close(fig)
else:
    print("Skipping summary printout and plotting as prerequisites were not met or AUC DataFrame is empty.")

print("\n--- AUC Analysis and Visualization Finished ---")