In [1]:
import os
import pickle
import numpy as np
import pandas as pd
from pathlib import Path
from tqdm import tqdm
from allensdk.brain_observatory.ecephys.ecephys_project_cache import EcephysProjectCache # Still needed for session data
import sys
import io
from contextlib import contextmanager
 
# --- deal with verbos output ---
@contextmanager
def suppress_output_for_tqdm():
    original_stdout = sys.stdout
    original_stderr = sys.stderr
    try:
        # Redirect both stdout and stderr to dummy streams
        sys.stdout = io.StringIO()
        sys.stderr = io.StringIO()
        yield
    finally:
        sys.stdout = original_stdout
        sys.stderr = original_stderr


        
# --- Configuration ---
BASE_CACHE_DIR = Path("/home/pinky/PSTH_VisualData/")
MANIFEST_PATH = str(BASE_CACHE_DIR / "manifest.json") # For EcephysProjectCache init
SAVED_SESSIONS_TABLE_PATH = BASE_CACHE_DIR / "ecephys_sessions_table.csv" # Match saved format
# SAVED_SESSIONS_TABLE_PATH = BASE_CACHE_DIR / "ecephys_sessions_table.pkl" # If saved as pickle

OUTPUT_DIR = Path("./visual_coding_psth_output")
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

VISUAL_AREAS = ['VISal', 'VISam', 'VISl', 'VISp', 'VISpm', 'VISrl']
TARGET_STIMULUS_NAME = 'drifting_gratings'
PRE_TIME = 0.1
POST_TIME = 0.5
BIN_SIZE = 0.01
time_bins = np.arange(-PRE_TIME, POST_TIME + BIN_SIZE, BIN_SIZE)
time_bin_centers = time_bins[:-1] + BIN_SIZE / 2

# --- Load Session Table ---
if SAVED_SESSIONS_TABLE_PATH.exists():
    print(f"Loading pre-saved session table from: {SAVED_SESSIONS_TABLE_PATH}")
    if SAVED_SESSIONS_TABLE_PATH.suffix == '.csv':
        sessions_df = pd.read_csv(SAVED_SESSIONS_TABLE_PATH, index_col=0) # Assuming first col is index
    elif SAVED_SESSIONS_TABLE_PATH.suffix == '.pkl':
        sessions_df = pd.read_pickle(SAVED_SESSIONS_TABLE_PATH)
    else:
        raise ValueError("Unsupported saved sessions table format.")
    all_session_ids = sessions_df.index.tolist()
    print(f"Loaded {len(all_session_ids)} session IDs from saved file.")
else:
    print(f"Saved session table not found at {SAVED_SESSIONS_TABLE_PATH}.")
    print("Please run the script/cell to generate and save it first, or fall back to live fetching.")
    # Fallback to live fetching (will be slow)
    print("Initializing EcephysProjectCache to fetch session table live...")
    try:
        cache_for_table = EcephysProjectCache.from_warehouse(manifest=MANIFEST_PATH) # Temporary cache instance
        sessions_df = cache_for_table.get_session_table()
        all_session_ids = sessions_df.index.tolist()
        print(f"Fetched {len(all_session_ids)} session IDs live from AllenSDK.")
        # Optionally save it now if fetched live
        # sessions_df.to_csv(SAVED_SESSIONS_TABLE_PATH)
    except Exception as e:
        print(f"Error fetching session table live: {e}")
        exit()


# --- Initialize AllenSDK Cache (still needed for get_session_data) ---
# This initialization should be faster now if manifest.json is already downloaded
# by the script in Step 1 or a previous run.
try:
    cache = EcephysProjectCache.from_warehouse(manifest=MANIFEST_PATH)
    print("Main EcephysProjectCache initialized successfully for fetching session data.")
except Exception as e:
    print(f"Error initializing main cache: {e}")
    exit()

# --- Main Processing Loop (using all_session_ids from loaded/fetched table) ---
if not all_session_ids:
    print("No session IDs to process. Exiting.")
    exit()

print(f"Found {len(all_session_ids)} total Ecephys sessions to process.")


  from .autonotebook import tqdm as notebook_tqdm


Loading pre-saved session table from: /home/pinky/PSTH_VisualData/ecephys_sessions_table.csv
Loaded 58 session IDs from saved file.
Main EcephysProjectCache initialized successfully for fetching session data.
Found 58 total Ecephys sessions to process.


In [2]:
# --- Get all experiment session IDs ---
sessions_df = cache.get_session_table()
all_session_ids = sessions_df.index.tolist()
print(f"Found {len(all_session_ids)} total Ecephys sessions.")

Found 58 total Ecephys sessions.


In [4]:

# --- Main Processing Loop ---
for session_id in tqdm(all_session_ids, desc="Processing Sessions"): # Your main tqdm progress bar
    # This print will go to the original stdout before suppression
    print(f"\nProcessing session: {session_id}")
    output_filepath = OUTPUT_DIR / f"session_{session_id}_psth_data.pkl"

    if output_filepath.exists():
        # This print will go to the original stdout
        print(f"Data for session {session_id} already processed. Skipping.")
        continue

    try:
        # This print will go to the original stdout
        print(f"Attempting to load/download data for session {session_id} (internal progress suppressed)...")
        
        # --- Suppress AllenSDK's internal tqdm here (now suppressing both stdout and stderr) ---
        with suppress_output_for_tqdm():
            session = cache.get_session_data(session_id,
                                             isi_violations_maximum = np.inf,
                                             amplitude_cutoff_maximum = np.inf,
                                             presence_ratio_minimum = -np.inf
                                            )
        # --- End suppression ---
        
        # This print will go to the original stdout
        print(f"Session {session_id} loaded successfully.")

    except Exception as e:
        # This print will go to the original stdout (or stderr if the exception handling prints there)
        print(f"Error loading session {session_id}: {e}")
        # If an error occurs, stdout/stderr are restored by the context manager's finally block
        continue


    # 1. Filter units by visual areas
    # Filter for "good" units within the specified visual areas
    visual_units_df = units_df[
        (units_df['ecephys_structure_acronym'].isin(VISUAL_AREAS)) & # Condition 1: In visual areas
        (units_df['quality'] == 'good')                             # Condition 2: Unit quality is 'good'
    ]
    
    if visual_units_df.empty:
        print(f"No 'good' units found in specified visual areas for session {session_id}. Skipping.")
        # continue # Or handle as appropriate
    else:
        visual_unit_ids = visual_units_df.index.tolist()
        print(f"Found {len(visual_unit_ids)} 'good' units in visual areas.")

    # 2. Filter stimulus presentations for stationary gratings
    stimulus_presentations_df = session.stimulus_presentations
    
    # Filter for the 'drifting_gratings' stimulus type
    gratings_stim_df = stimulus_presentations_df[
        stimulus_presentations_df['stimulus_name'] == TARGET_STIMULUS_NAME
    ]

    if gratings_stim_df.empty:
        print(f"No '{TARGET_STIMULUS_NAME}' stimuli found for session {session_id}. Skipping.")
        continue

    # Filter for stationary gratings (temporal frequency is 0)
    # Also, ensure relevant parameters are not NaN, as this can happen for blank sweeps
    # or other variations sometimes included in 'drifting_gratings'
    stationary_gratings_df = gratings_stim_df[
        (gratings_stim_df['temporal_frequency'] == 0.0) &
        gratings_stim_df['orientation'].notna() &
        gratings_stim_df['spatial_frequency'].notna() &
        gratings_stim_df['phase'].notna()
    ].copy() # Use .copy() to avoid SettingWithCopyWarning

    if stationary_gratings_df.empty:
        print(f"No stationary gratings (TF=0 with valid parameters) found for session {session_id}. Skipping.")
        continue

    # 3. Create a unique stimulus information matrix
    # These are the parameters that define a unique stationary grating stimulus
    stimulus_params = ['orientation', 'spatial_frequency', 'phase']
    unique_stim_conditions = stationary_gratings_df[stimulus_params].drop_duplicates().sort_values(by=stimulus_params).reset_index(drop=True)
    unique_stim_conditions['stimulus_condition_id'] = unique_stim_conditions.index # Add an ID

    if unique_stim_conditions.empty:
        print(f"No unique stationary grating conditions found for session {session_id}. Skipping.")
        continue
        
    print(f"Found {len(unique_stim_conditions)} unique stationary grating conditions.")

    # 4. Calculate PSTH for each visual unit and each unique stimulus condition
    # Dimensions: (num_visual_units, num_unique_stim_conditions, num_time_bins)
    all_psth_data = np.zeros((len(visual_unit_ids), len(unique_stim_conditions), len(time_bin_centers)))
    
    spike_times_dict = session.spike_times # More efficient to get all once

    for i, unit_id in enumerate(tqdm(visual_unit_ids, desc="Calculating PSTHs", leave=False)):
        unit_spike_times = spike_times_dict.get(unit_id, np.array([]))
        if unit_spike_times.size == 0:
            continue # Skip if unit has no spikes

        for j, stim_condition_row in unique_stim_conditions.iterrows():
            condition_id = stim_condition_row['stimulus_condition_id']
            
            # Find all presentations of this specific stimulus condition
            # Need to handle potential floating point precision issues for 'phase' and 'spatial_frequency'
            # by comparing with a tolerance or by ensuring they were exactly the same in the table.
            # Here, direct equality should work as they come from the same table.
            presentations_for_condition = stationary_gratings_df[
                (stationary_gratings_df['orientation'] == stim_condition_row['orientation']) &
                (np.isclose(stationary_gratings_df['spatial_frequency'], stim_condition_row['spatial_frequency'])) &
                (np.isclose(stationary_gratings_df['phase'], stim_condition_row['phase']))
            ]
            
            stim_start_times = presentations_for_condition['start_time'].values
            num_trials = len(stim_start_times)

            if num_trials == 0:
                continue # Should not happen if logic is correct, but good check

            # Aggregate histograms for all trials of this condition
            summed_histogram = np.zeros(len(time_bin_centers))
            
            for start_time in stim_start_times:
                # Align spike times to this stimulus presentation's start time
                aligned_spike_times = unit_spike_times - start_time
                
                # Select spikes within the PSTH window
                spikes_in_window = aligned_spike_times[
                    (aligned_spike_times >= -PRE_TIME) &
                    (aligned_spike_times < POST_TIME) # Use < for the right edge of the last bin
                ]
                
                # Create histogram for this trial
                hist, _ = np.histogram(spikes_in_window, bins=time_bins)
                summed_histogram += hist
            
            # Calculate mean firing rate (spikes/sec) for each bin
            if num_trials > 0:
                mean_firing_rate_psth = summed_histogram / (num_trials * BIN_SIZE)
                all_psth_data[i, condition_id, :] = mean_firing_rate_psth

    # 5. Save the data
    data_to_save = {
        'session_id': session_id,
        'visual_unit_ids': visual_unit_ids, # List of unit IDs corresponding to 1st dim of psth_data
        'stimulus_info': unique_stim_conditions, # DataFrame, index maps to 2nd dim of psth_data
        'psth_data': all_psth_data, # (n_units, n_stim_conditions, n_time_bins)
        'psth_time_bin_centers': time_bin_centers, # Time points for 3rd dim of psth_data
        'psth_configs': {'pre_time': PRE_TIME, 'post_time': POST_TIME, 'bin_size': BIN_SIZE}
    }

    with open(output_filepath, 'wb') as f:
        pickle.dump(data_to_save, f)
    
    print(f"Saved PSTH data for session {session_id} to {output_filepath}")

 # 5. *** MODIFIED: Prepare data and save as .mat file ***
    
    # Convert pandas DataFrame 'unique_stim_conditions_df' to a dict of arrays for MATLAB struct
    stimulus_info_for_matlab = {}
    for col in unique_stim_conditions_df.columns:
        # Ensure data is in a basic type, like numpy array
        # For object columns (like strings if 'orientation' was 'horizontal'), convert to cell array like structure
        if unique_stim_conditions_df[col].dtype == 'object':
             # For string arrays, convert to an object array of strings, which savemat handles as cell arrays of strings
            stimulus_info_for_matlab[col] = np.array(unique_stim_conditions_df[col].tolist(), dtype=object)
        else:
            stimulus_info_for_matlab[col] = unique_stim_conditions_df[col].to_numpy()

    data_to_save_mat = {
        'session_id': str(session_id), # Ensure session_id is a string
        'visual_unit_ids': np.array(visual_unit_ids, dtype=np.int64), # Explicitly make it a NumPy array
        'visual_unit_areas': np.array(visual_unit_areas_list, dtype=object), # Object array for strings -> MATLAB cell array
        'stimulus_info': stimulus_info_for_matlab, # This will become a MATLAB struct
        'psth_data': all_psth_data, # NumPy array is fine
        'psth_time_bin_centers': time_bin_centers, # NumPy array is fine
        'psth_configs': { # This will become a MATLAB struct
            'pre_time': PRE_TIME,
            'post_time': POST_TIME,
            'bin_size': BIN_SIZE
        }
    }

    try:
        savemat(str(output_filepath_mat), data_to_save_mat, do_compression=True)
        print(f"Saved MAT data for session {session_id} to {output_filepath_mat}")
    except Exception as e:
        print(f"Error saving .mat file for session {session_id}: {e}")


print("\n--- All processing finished ---")

# --- Example of how to load and use the saved data ---
# (You would typically do this in a separate script)
#
# output_files = list(OUTPUT_DIR.glob("*.pkl"))
# if output_files:
#     first_file = output_files[0]
#     print(f"\n--- Example: Loading data from {first_file} ---")
#     with open(first_file, 'rb') as f:
#         loaded_data = pickle.load(f)
    
#     print(f"Session ID: {loaded_data['session_id']}")
#     print(f"Number of visual units: {len(loaded_data['visual_unit_ids'])}")
#     print(f"Shape of PSTH data: {loaded_data['psth_data'].shape}") # (units, stim_conditions, time_bins)
#     print("Stimulus Information DataFrame (first 5 rows):")
#     print(loaded_data['stimulus_info'].head())
#     print("PSTH time bin centers (first 5):")
#     print(loaded_data['psth_time_bin_centers'][:5])

#     # Example: Plot PSTH for the first unit and first stimulus condition
#     # import matplotlib.pyplot as plt
#     # if loaded_data['psth_data'].size > 0 : # Check if there's actual data
#     #     plt.figure()
#     #     plt.plot(loaded_data['psth_time_bin_centers'], loaded_data['psth_data'][0, 0, :])
#     #     plt.xlabel("Time from stimulus onset (s)")
#     #     plt.ylabel("Firing rate (spikes/s)")
#     #     plt.title(f"PSTH: Unit {loaded_data['visual_unit_ids'][0]}, Stimulus Cond. 0")
#     #     # Get stimulus parameters for title
#     #     stim_params_for_plot = loaded_data['stimulus_info'].iloc[0]
#     #     plt.suptitle(f"Ori: {stim_params_for_plot['orientation']}, SF: {stim_params_for_plot['spatial_frequency']:.2f}, Phase: {stim_params_for_plot['phase']:.2f}")
#     #     plt.show()

Processing Sessions:   0%|                                                     | 0/58 [00:00<?, ?it/s]


Processing session: 715093703
Attempting to load/download data for session 715093703 (internal progress suppressed)...


Processing Sessions:   2%|▋                                       | 1/58 [40:01<38:01:06, 2401.17s/it]

Error loading session 715093703: Download took 1200.056250861846 seconds, but timeout was set to 1200

Processing session: 719161530
Attempting to load/download data for session 719161530 (internal progress suppressed)...


Processing Sessions:   3%|█▎                                    | 2/58 [1:20:02<37:21:10, 2401.25s/it]

Error loading session 719161530: Download took 1200.1675209160894 seconds, but timeout was set to 1200

Processing session: 721123822
Attempting to load/download data for session 721123822 (internal progress suppressed)...


Processing Sessions:   5%|█▉                                    | 3/58 [2:00:03<36:41:13, 2401.34s/it]

Error loading session 721123822: Download took 1200.3066675253212 seconds, but timeout was set to 1200

Processing session: 732592105
Attempting to load/download data for session 732592105 (internal progress suppressed)...


Processing Sessions:   7%|██▌                                   | 4/58 [2:40:05<36:01:27, 2401.63s/it]

Error loading session 732592105: Download took 1200.0301601099782 seconds, but timeout was set to 1200

Processing session: 737581020
Attempting to load/download data for session 737581020 (internal progress suppressed)...


Processing Sessions:   9%|███▎                                  | 5/58 [3:20:07<35:21:17, 2401.47s/it]

Error loading session 737581020: Download took 1200.0183887421153 seconds, but timeout was set to 1200

Processing session: 739448407
Attempting to load/download data for session 739448407 (internal progress suppressed)...


Processing Sessions:  10%|███▉                                  | 6/58 [4:00:08<34:41:10, 2401.36s/it]

Error loading session 739448407: Download took 1200.0037562600337 seconds, but timeout was set to 1200

Processing session: 742951821
Attempting to load/download data for session 742951821 (internal progress suppressed)...


Processing Sessions:  12%|████▌                                 | 7/58 [4:40:09<34:01:06, 2401.31s/it]

Error loading session 742951821: Download took 1200.0047698239796 seconds, but timeout was set to 1200

Processing session: 743475441
Attempting to load/download data for session 743475441 (internal progress suppressed)...


Processing Sessions:  14%|█████▏                                | 8/58 [5:20:10<33:21:04, 2401.30s/it]

Error loading session 743475441: Download took 1200.0169138042256 seconds, but timeout was set to 1200

Processing session: 744228101
Attempting to load/download data for session 744228101 (internal progress suppressed)...


Processing Sessions:  16%|█████▉                                | 9/58 [6:00:12<32:41:02, 2401.28s/it]

Error loading session 744228101: Download took 1200.0034770970233 seconds, but timeout was set to 1200

Processing session: 746083955
Attempting to load/download data for session 746083955 (internal progress suppressed)...


Processing Sessions:  17%|██████▍                              | 10/58 [6:40:13<32:01:00, 2401.25s/it]

Error loading session 746083955: Download took 1200.0053078946657 seconds, but timeout was set to 1200

Processing session: 750332458
Attempting to load/download data for session 750332458 (internal progress suppressed)...


Processing Sessions:  17%|██████▍                              | 10/58 [6:57:29<33:23:55, 2504.91s/it]

Session 750332458 loaded successfully.





NameError: name 'units_df' is not defined