# Importing libraries

In [1]:
from wfield import *
import tifffile
import matplotlib.pyplot as plt
import cv2
import os
import numpy as np
import re
import shutil
import pandas as pd
import h5py
import matplotlib.gridspec as gridspec


# Defining functions

In [2]:
def list_scans(data_folder, keyword):
    # Find folders containing the keyword
    scan_folders = [folder for folder in os.listdir(data_folder) if os.path.isdir(os.path.join(data_folder, folder)) and keyword in folder]

    # Print list of found folders
    print(f"Folders containing '{keyword}' keyword:")
    for i, folder in enumerate(scan_folders):
        print(f"{i + 1}. {folder}")

    # Prompt user to choose a folder
    while True:
        choice = input("Enter the number of the scan you want to choose: ")
        if choice.isdigit() and 1 <= int(choice) <= len(scan_folders):
            chosen_folder = scan_folders[int(choice) - 1]
            break
        else:
            print("Invalid input. Please enter a valid number.")
    print("Selected",chosen_folder)
    # Return the path to the chosen folder
    return os.path.join(data_folder, chosen_folder), chosen_folder


In [3]:
def find_bins(localdisk):
    # Find all .tif files in the directory
    tif_files = [f for f in os.listdir(localdisk) if f.endswith('.tif')]
    
    # Find all .dat files in the directory
    dat_files = glob(pjoin(localdisk,'*.dat'))
    
    if len(dat_files) > 1:
        print("Multiple .dat files found:")
        for i, dat_file in enumerate(dat_files):
            print(f"{i + 1}. {os.path.basename(dat_file)}")
        
        while True:
            choice = input("Enter the number of the .dat file you want to choose: ")
            if choice.isdigit() and 1 <= int(choice) <= len(dat_files):
                dat_path = dat_files[int(choice) - 1]
                break
            else:
                print("Invalid input. Please enter a valid number.")
    
    elif len(dat_files) == 1:
        dat_path = dat_files[0]
    else:
        print("Binaries file not found")
        print("-----------------------")
        print("Concatenating TIF files...")
        print("-----------------------")
        tif_data_list = []
        for tif_file in sorted(tif_files):
            tif_file_path = os.path.join(localdisk, tif_file)
            tif_data = tifffile.imread(tif_file_path);
            tif_data_list.append(tif_data)
        concatenated_data = np.concatenate(tif_data_list, axis=0)
        print(concatenated_data.shape)
        # Splitting into violet and green channels
        violet_channel = concatenated_data[::2]  # Every second frame starting from the first
        green_channel = concatenated_data[1::2]  # Every second frame starting from the second

        # Ensure both channels have the same length (in case the number of frames is odd)
        if violet_channel.shape[0] > green_channel.shape[0]:
            violet_channel = violet_channel[:-1]
        h, w = green_channel.shape[1:]
        merged_data = np.stack((violet_channel, green_channel), axis=1)
        n_frames = merged_data.shape[0]
        filename_parts = re.split(r'_|\.', tif_files[0])
        scan_info = '_'.join(filename_parts[0:4])  # Joining parts 1 to 4 with underscores
        frame_shape = f"{h}_{w}"
        filename = f"{scan_info}_{frame_shape}_2_uint16.dat"
        print(f"Saving file in {filename}")
        save_path = os.path.join(localdisk, filename)
        merged_data.astype(np.uint16).tofile(save_path)
        del merged_data
        dat_path = dat_files[0]
    
    print("Binaries file found")
    print("Loading...")
    dat = mmap_dat(dat_path)
    n_frames, h, w, n_channels = dat.shape
    
    dat = mmap_dat(dat_path, mode='r+', nframes=n_frames, shape=(2, h, w))
    print(f'Loaded {n_frames} frames with {h} x {w} (height x width)')
    print("Selected", os.path.basename(dat_path))
    return dat, n_frames, h, w


In [4]:
def save_video_from_array(array, filename_prefix, folder, num_channels):
    n_frames, *shape = array.shape
    height, width = shape[-2:]

    # Create folder if it doesn't exist
    os.makedirs(folder, exist_ok=True)

    if len(shape) == 2:  # Single channel
        channels_to_save = 1
    elif len(shape) == 3:  # Multiple channels
        channels_to_save = min(num_channels, shape[0])
    else:
        raise ValueError("Invalid array shape.")

    # Iterate over each channel
    for channel_idx in range(channels_to_save):
        # Create VideoWriter object
        filename = os.path.join(folder, f"{filename_prefix}_channel{channel_idx}.mp4")
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')  # Change codec to MP4V for MP4 format
        out = cv2.VideoWriter(filename, fourcc, 25.0, (width, height))

        # Iterate over each frame
        for frame_idx in range(n_frames):
            # Extract channel if multiple channels, else use array directly
            channel_array = array[frame_idx, channel_idx] if len(shape) == 3 else array[frame_idx]

            # Normalize pixel values to 0-255 range
            channel_array = (channel_array - np.min(channel_array)) / (np.max(channel_array) - np.min(channel_array)) * 255

            # Convert to uint8
            frame = channel_array.astype(np.uint8)

            # Apply colormap if needed (e.g., for grayscale images)
            frame = cv2.applyColorMap(frame, cv2.COLORMAP_VIRIDIS)

            # Write frame to video file
            out.write(frame)

        # Release VideoWriter object
        out.release()

    print("Videos saved successfully.")

In [5]:



def dFF_z(fluorescence_data, base_start, base_end):
    """
    Calculate dF/F₀ and z-scores for fluorescence data.

    Parameters:
    - fluorescence_data (array-like): Fluorescence data.
    - base_start (int): Start index of the baseline period.
    - base_end (int): End index of the baseline period.

    Returns:
    - dF_over_F0 (DataFrame): DataFrame containing the dF/F₀ values.
    - z_scores (DataFrame): DataFrame containing the z-scores.
    """
    # Calculate F₀ (baseline fluorescence)
    if base_end:
        F0 = np.mean(fluorescence_data[base_start:base_end])
        # print(f"Using as baseline frames: [{base_start}, {base_end}]")
    else:
        F0 = np.mean(fluorescence_data)
        # print("Using as baseline scan average")
    
    # Calculate dF (change in fluorescence)
    dF = fluorescence_data - F0
    
    # Calculate dF/F₀
    dF_over_F0 = dF / F0
    
    # Normalize dF/F₀ to 0-1 range
    dF_over_F0_min = np.min(dF_over_F0)
    dF_over_F0_max = np.max(dF_over_F0)
    dF_over_F0_normalized = (dF_over_F0 - dF_over_F0_min) / (dF_over_F0_max - dF_over_F0_min)
    
    # Convert normalized dF/F₀ to percentage
    dF_over_F0_percent = dF_over_F0_normalized * 100
    
    # Calculate z-scores for the normalized dF/F₀ values
    z_scores = (dF_over_F0 - dF_over_F0.mean()) / np.std(dF_over_F0)
    # z_scores = (dF_over_F0_normalized - np.mean(dF_over_F0_normalized)) / np.std(dF_over_F0_normalized)
    
    
    return dF_over_F0, z_scores

In [6]:
def list_bpod(localdisk):
    import scipy.io
    mat_files = glob(pjoin(localdisk,'*.mat'))
    if len(mat_files) > 1:
        print("Multiple Bpod files found:")
        for i, mat_file in enumerate(mat_files):
            print(f"{i + 1}. {os.path.basename(mat_file)}")
        
        while True:
            choice = input("Enter the number of the .dat file you want to choose: ")
            if choice.isdigit() and 1 <= int(choice) <= len(mat_files):
                mat_path = mat_files[int(choice) - 1]
                break
            else:
                print("Invalid input. Please enter a valid number.")
    elif len(mat_files) == 1:
        mat_path = mat_files[0]
        print("Bpod file found")
        print("Loading...")
    else:
        print("Bpod file not found")
        print("-----------------------")
    
    try:
        bpod_data = scipy.io.loadmat(mat_path)
        print("Selected", os.path.basename(mat_path))
    except:
        print("Failed")
    return bpod_data


# Loading data

In [7]:
data_folder = r'/datajoint-data/data/aeltona/'
# tif_file_path = pjoin(data_folder, 'scan9FN2ANVG_Oddball_AA_ROS-1706_2025_MMStack_Default.ome.tif')
# localdisk = r'C:\datatemp'
localdisk, scan_idx = list_scans(data_folder,"AA")



Folders containing 'AA' keyword:
1. Habituation_AA_WEZ-8950_2024-04-16_scan9FNN1Y64_sess9FNN0LFP
2. Habituation_AA_WEZ-8948_2024-04-17_scan9FNNP1N7_sess9FNNO3Z1
3. Oddball_AA_ROS-1706_2024-03-12_scan9FN2BCOS_sess9FN2ANVG
4. AA_ROS-1688_2024_01_27_scan000WQU9_sess000EAEIO
5. Habituation_AA_WEZ-8950_2024-04-18_scan9FNO8CLT_sess9FNO8CLT
6. Habituation_AA_WEZ-8950_2024-04-17_scan9FNNO3Z1_sess9FNNO3Z1
7. AA_ROS-1706_2024-03-12_scan9FN2ANVG_sess9FN2ANVG
8. Habituation_AA_WEZ-8950_2024-04-16_scan9FNN1YXK_sess9FNN1YXK
9. Habituation_AA_WEZ-8948_2024-04-18_scan9FNO99ZE_sess9FNO8CLT
10. AA_ROS-1688_2024_01_27_scan000EAEIO_sess000EAEIO
11. Habituation_AA_WEZ-8950_2024-04-16_scan9FNN2FX7_sess9FNN1YXK
12. Habituation_AA_WEZ-8948_2024-04-16_scan9FNN1M1R_sess9FNN0LFP


Enter the number of the scan you want to choose:  10


Selected AA_ROS-1688_2024_01_27_scan000EAEIO_sess000EAEIO


In [21]:
dat,n_frames, h, w = find_bins(localdisk)

Multiple .dat files found:
1. scan000EAEIO_AA_ROS-1688_2024_647_730_2_uint16.dat
2. motioncorrected_scan000EAEIO_AA_ROS-1688_2024_647_730_2_uint16.dat


Enter the number of the .dat file you want to choose:  2


Binaries file found
Loading...
Loaded 3047 frames with 647 x 730 (height x width)
Selected motioncorrected_scan000EAEIO_AA_ROS-1688_2024_647_730_2_uint16.dat


In [None]:
    # save_video_from_array(violet_channel, f"{scan_idx}_violet_uncorrected",localdisk,1)
# save_video_from_array(green_channel, f"{scan_idx}_green_uncorrected",localdisk,1)

# Motion Correction

In [None]:
# out = pjoin(localdisk,'motion_corrected.dat')
out = np.empty_like(dat)

(yshifts,xshifts),rot = motion_correct(dat, out = out, chunksize=512,
                                     apply_shifts=True)
del dat # close and finish writing


In [None]:
tif_files = [f for f in os.listdir(localdisk) if f.endswith('.tif')]

filename_parts = re.split(r'_|\.', tif_files[0])
scan_info = '_'.join(filename_parts[0:4])  # Joining parts 1 to 4 with underscores
frame_shape = f"{h}_{w}"
filename = f"motioncorrected_{scan_info}_{frame_shape}_2_uint16.dat"
print(f"Saving file in {filename}")
save_path = os.path.join(localdisk, filename)
out.astype(np.uint16).tofile(save_path)

In [None]:

# save the shifts
shifts = np.rec.array([yshifts,xshifts],dtype=[('y','float32'),('x','float32')])
np.save(pjoin(localdisk,'motion_correction_shifts.npy'),shifts)
# np.save(pjoin(data_folder,'motion_correction_shifts.npy'),shifts)

In [None]:
import pylab as plt
plt.matplotlib.style.use('ggplot')

%matplotlib inline

# localdisk = '/mnt/dual/temp_folder/CSP23_20200226' # this should be an SSD or a fast drive

shifts = np.load(pjoin(localdisk,'motion_correction_shifts.npy'))

plot_summary_motion_correction(shifts,localdisk);

# Trial Onsets

### Reading Aux file

In [None]:
data_folder = r'/datajoint-data/data/aeltona/'
# tif_file_path = pjoin(data_folder, 'scan9FN2ANVG_Oddball_AA_ROS-1706_2025_MMStack_Default.ome.tif')
# localdisk = r'C:\datatemp'

localdisk, scan_idx = list_scans(data_folder,"AA")



In [14]:
h5_path = glob(pjoin(localdisk,'*.h5'))[0]
h5_path

'/datajoint-data/data/aeltona/AA_ROS-1688_2024_01_27_scan000EAEIO_sess000EAEIO/scan000EAEIO_AA_ROS-1688_2024_01_27_Preconditioning.h5'

In [15]:

# Open the HDF5 file
with h5py.File(h5_path, 'r') as f:
    # Access specific group
    # print(f.keys())
    sweep_data_key = list(f.keys())[1]
    sweep_data = f[sweep_data_key]
    header = f['header']
    # List all keys within the group
    # print(f"Keys in {sweep_data_key}: %s" % sweep_data.keys())
    # print("Keys in header: %s" % header.keys())
    
    
    AIChannelNames = header['AIChannelNames'][:]
    YLimitsPerAIChannel = header['YLimitsPerAIChannel'][:]
    AIChannelNames = [name.decode('utf-8') for name in AIChannelNames]
    SampleRate = header['AcquisitionSampleRate'][:]
    analogData = sweep_data['analogScans'][:]
    # # Check if 'data' exists within the group
    # if 'analogScans' in sweep_data:
    #     data = sweep_data['analogScans'][:]  # accessing array dataset
    #     print("Data from sweep_0001:", data)
    # else:
    #     print("Dataset 'data' does not exist in sweep_0001.")

In [16]:
# Convert analogData to DataFrame
df = pd.DataFrame(analogData.T, columns=AIChannelNames)
sampling_rate_hz = SampleRate[0,0]
df['time_seconds'] = df.index / sampling_rate_hz
# Print DataFrame
# print(df)

In [17]:
df.columns

Index(['camera_trigger', 'blue_470nm', 'violet_405nm', 'HiFi_module', 'Reward',
       'SoftCode', 'Lick Spout', 'Rotary Encoder', 'time_seconds'],
      dtype='object')

In [None]:
# plt.plot(df['Rotary Encoder']);

In [22]:
Oddball = False
SP = True
if Oddball:
    norm_stim_id = df['Stim_ID']/3000
    norm_stim_id = norm_stim_id.round(0)
    norm_stim_on = (df['Stim_ON'] - df['Stim_ON'].min()) / (df['Stim_ON'].max() - df['Stim_ON'].min())
    norm_stim_on = norm_stim_on.round(0)
    df['norm_stim_on'] = norm_stim_on
    df['norm_stim_id'] = norm_stim_id
elif SP:
    print('tbd')

norm_blue_on = (df['blue_470nm'] - df['blue_470nm'].min()) / (df['blue_470nm'].max() - df['blue_470nm'].min())
norm_blue_on = norm_blue_on.round(0)
df['norm_blue_on'] = norm_blue_on

blue_df = df[df['norm_blue_on'] == 1]

points_per_frame = len(blue_df) // n_frames
blue_df_reshaped = pd.DataFrame()
for col in blue_df.columns:
    reshaped_col = blue_df[col].values[:points_per_frame * n_frames].reshape(n_frames, -1)
    blue_df_reshaped[col] = reshaped_col[:, 0]
data_df = blue_df_reshaped
data_df.shape
# norm_blue_on.unique()

tbd


(3047, 10)

### Reading Bpod Data


In [23]:
bpod_data = list_bpod(localdisk)

Bpod file found
Loading...
Selected ROS-1688_SensoryPreconditioning_20240127_175031.mat


In [None]:
trial_onsets = bpod_data['SessionData'][0, 0]['TrialStartTimestamp'][0]
trial_types = bpod_data['SessionData'][0, 0]['TrialTypes'][0]
n_trials = bpod_data['SessionData'][0, 0]['nTrials'][0][0]

stim_1 = bpod_data['SessionData'][0, 0]['RawEvents']['Trial'][0,0][0,0]['States'][0,0]['Stimulus4On'][0,0][0,:]


reward =  bpod_data['SessionData'][0, 0]['RawEvents']['Trial'][0,0][0,0]['States'][0,0]['Reward'][0,0][0,:]
odd_trials = np.where(trial_types==2)


In [None]:
first_stim_bpod = stim_1[0]

# Find the first stimulus time in data_df
first_stim_df_index = data_df[data_df['norm_stim_on'] == 1].index[0]
first_stim_df_time = data_df.loc[first_stim_df_index, 'time_seconds']

# Calculate the time delay between data_df and Bpod
time_delay = first_stim_df_time - (first_stim_bpod + trial_onsets[0])

# Adjust the Bpod trial onsets by the time delay
adjusted_trial_onsets = trial_onsets + time_delay
# Create a new column in data_df for trial onsets, initialized to 0
data_df['trial_onsets'] = 0
# Mark the frames (rows) corresponding to the adjusted trial onsets
for onset in adjusted_trial_onsets:
    # Find the closest frame (row) in data_df corresponding to each adjusted trial onset
    closest_idx = (np.abs(data_df['time_seconds'] - onset)).idxmin()
    data_df.at[closest_idx, 'trial_onsets'] = 1

adjusted_odd_stim = []
for trial_num in odd_trials[0]:
    oddball = bpod_data['SessionData'][0, 0]['RawEvents']['Trial'][0,0][0,trial_num]['States'][0,0]['Stimulus1On'][0,0][0,:]
    
    adjusted = oddball + trial_onsets[trial_num] +  time_delay# + time_delay
    # print(oddball,trial_onsets[trial_num],adjusted)
    adjusted_odd_stim.append(adjusted)

adjusted_reward_onsets = []
for trial_num in range(0, n_trials):
    reward =  bpod_data['SessionData'][0, 0]['RawEvents']['Trial'][0,0][0,trial_num]['States'][0,0]['Reward'][0,0][0,:]
    adjusted = reward + trial_onsets[trial_num] +  time_delay
    adjusted_reward_onsets.append(adjusted)


# adjusted_odd_stim
data_df['oddball_onsets'] = 0

for onset_offset in adjusted_odd_stim:
    onset, offset = onset_offset
    # Find indices where the time is within the onset and offset range
    indices_in_range = data_df[(data_df['time_seconds'] >= onset) & (data_df['time_seconds'] <= offset)].index
    # Set the corresponding rows in 'oddball_onsets' to 1
    data_df.loc[indices_in_range, 'oddball_onsets'] = 1

data_df['reward_onsets'] = 0

for onset_offset in adjusted_reward_onsets:
    onset, offset = onset_offset
    indices_in_range = data_df[(data_df['time_seconds'] >= onset) & (data_df['time_seconds'] <= offset)].index
    data_df.loc[indices_in_range, 'reward_onsets'] = 1

start = 0
end = 1500


plt.plot(data_df['time_seconds'][start:end],data_df['norm_stim_on'][start:end],color = 'black')
plt.plot(data_df['time_seconds'][start:end],data_df['oddball_onsets'][start:end],color = 'r')
plt.plot(data_df['time_seconds'][start:end],data_df['trial_onsets'][start:end],color = 'b')
plt.plot(data_df['time_seconds'][start:end],data_df['reward_onsets'][start:end],color = 'purple')



In [None]:
trial_types

In [None]:
start = 0
end = 5000


plt.plot(data_df['time_seconds'][start:end],data_df['norm_stim_on'][start:end],color = 'black')
plt.plot(data_df['time_seconds'][start:end],data_df['oddball_onsets'][start:end],color = 'r')
plt.plot(data_df['time_seconds'][start:end],data_df['trial_onsets'][start:end],color = 'b')
plt.plot(data_df['time_seconds'][start:end],data_df['reward_onsets'][start:end],color = 'purple')


# Match Dorsal Cortex to the Allen Atlas

In [8]:
from bokeh.plotting import show



## Frame averaging

In [9]:
try:
    frames_average = np.load(pjoin(localdisk,'frames_average.npy'))
    print("Found frames average")
except:
    print("Frames average not found, averaging...")
    nbaseline_frames = 10
    trial_onsets = data_df['trial_onsets']
    frames_average_trials = frames_average_for_trials(dat,
                                                        trial_onsets,
                                                        nbaseline_frames)

  
        
    np.save(pjoin(localdisk,'frames_average.npy'), frames_average)


Found frames average


In [13]:
frame_averages = []
trial_onsets = data_df['trial_onsets']
for on in tqdm(trial_onsets):
    frame_averages.append(dat[on:on+nbaseline_frames].mean(axis=0))
frame_average = np.stack(frames_averages)
    
np.save(pjoin(localdisk,'frames_average.npy'), frames_average)

NameError: name 'data_df' is not defined

In [None]:
nbaseline_frames = 10
trial_onsets = data_df['trial_onsets']
frames_average_trials = frames_average_for_trials(dat,trial_onsets,nbaseline_frames)        
np.save(pjoin(localdisk,'frames_average.npy'), frames_average_trials.mean(axis=0))


In [None]:



resolution = 0.015 # mm per pixel
image = frames_average[0] # load a frame from the references 
bregma_offset = (np.array(image.shape[::-1])/2).astype('int')

ccf_regions,proj,brain_outline = allen_load_reference('dorsal_cortex')

# Get or load landmarks, these are for the allen dorsal cortex dataset
landmarks = {'x': [-1.95, 0, 1.95, 0],
             'y': [-3.45, -3.45, -3.45, 3.2],
             'name': ['OB_left', 'OB_center', 'OB_right', 'RSP_base'],
             'color': ['#fc9d03', '#0367fc', '#fc9d03', '#fc4103']}
landmarks = pd.DataFrame(landmarks)

# Load holoviews with the bokeh backend
import holoviews as hv
hv.extension('bokeh')

In [None]:
landmarks = {'x': [-1.95, 0, 1.95, 0],
             'y': [-3.45, -3.45, -3.45, 3.2],
             'name': ['OB_left', 'OB_center', 'OB_right', 'RSP_base'],
             'color': ['#fc9d03', '#0367fc', '#fc9d03', '#fc4103']}

wid,lmark_wid = hv_adjust_reference_landmarks(landmarks,ccf_regions)
wid

In [None]:
# Get the new landmarks (save these somewhere)
# WARNING: This only works if you change something (edit the table)
landmarks = pd.DataFrame(lmark_wid.data)[['x','y','name','color']] # landmarks in allen_coords

save_allen_landmarks(landmarks,
                     resolution = resolution,
                     bregma_offset = bregma_offset)
landmarks

In [None]:
lmarks['resolution']


In [None]:
from wfield import *

# load data and references
# localdisk = '/home/joao/dual/temp_folder/'
frames_average = np.load(pjoin(localdisk,'frames_average.npy'))
image = frames_average[0]
lmarks = load_allen_landmarks(None)
landmarks = lmarks['landmarks']
if 'landmarks_match' in lmarks.keys():
    landmarks_match = lmarks['landmarks_match']
else:
    landmarks_match = None
# bregma_offset = lmarks['bregma_offset']
bregma_offset = [300,320]
resolution = lmarks['resolution']

# The following line lets you plot previous landmarks
# how to load a transform
# landmarks_match = landmarks_match[['x','y','name','color']]


wid,lmark_wid,landmarks_im = hv_adjust_image_landmarks(image,landmarks,
                                                       landmarks_match = landmarks_match,
                                                       bregma_offset = bregma_offset,
                                                       resolution = resolution)
wid

In [None]:
# Get the similarity transform and plot the result
landmarks_match = pd.DataFrame(lmark_wid.data)
M = allen_transform_from_landmarks(landmarks_im,landmarks_match)
overlay = hv_show_transformed_overlay(frames_average[0], M, ccf_regions,bregma_offset=bregma_offset,resolution = resolution)
show(hv.render(overlay))

In [None]:
# Save the transform and the reference points
save_allen_landmarks(landmarks,
                    #  filename = pjoin(localdisk,'ccf_transform.json'),
                     filename = pjoin(localdisk,'ccf_transform_landmarks.json'),
                     landmarks_match=landmarks_match,
                     transform=M,
                     resolution = resolution,
                     bregma_offset = bregma_offset)
print('''Transform:
    - scale {0}
    - translation {1}
    - rotation {2}'''.format(M.scale,M.translation,np.rad2deg(M.rotation)))

In [None]:
# Plot an allen reference with holoviews
allen_regions = hv_plot_allen_regions(ccf_regions).options(width = w, height=h)
show(hv.render(allen_regions))

In [None]:
# Plot with matplotlib (example)
%matplotlib inline
plt.figure(figsize=[8,8])
h,w = image.shape
plt.imshow(im_apply_transform(image,M),
           cmap='gray')
for i,r in landmarks_im.iterrows():
    m = landmarks_match.loc[i]
    plt.plot(m.x,m.y,'g+',ms=10)
    plt.plot(r.x,r.y,'rx',ms=20)
    plt.text(r.x,r.y,r['name'],color='y',fontsize=12,va='bottom',ha='center')

# Masked decomposition

In [None]:
from wfield import atlas_from_landmarks_file
from wfield import *

In [None]:
lmarks_path = pjoin(localdisk,'ccf_transform_landmarks.json')
atlas, areanames, brain_mask = atlas_from_landmarks_file(lmarks_path,dims = [h,w],do_transform=True)



In [None]:
# frames_average = np.load(pjoin(localdisk,'frames_average.npy'))
frames_average = frames_average_trials

ch1data = dat[:,1,:,:]
mask  = get_std_mask(ch1data,threshold=45)
U,SVT = approximate_svd(dat, frames_average, mask = atlas, k = 200)



In [None]:
import pylab as plt
from wfield import nb_play_movie
from wfield import SVDStack
%matplotlib widget
stack = SVDStack(U,SVT)
plt.figure()
nb_play_movie(stack,cmap='gray')

In [None]:
np.save(pjoin(localdisk,'U.npy'),U)
np.save(pjoin(localdisk,'SVT.npy'),SVT)

In [None]:
# # filter and interpolate
# t = np.arange(SVT_470.shape[1]*2) # interpolate the violet
# from scipy.interpolate import interp1d
# SVT_405 = interp1d(t[1::2],SVT_405,axis=1,
#                    fill_value='extrapolate')(t[0::2])
# freq_highpass = 0.1
# fs = 10.
# SVT_470 = highpass(SVT_470,w = freq_highpass, fs = fs).astype(np.float32)
# SVT_405 = highpass(SVT_405,w = freq_highpass, fs = fs).astype(np.float32)

# # ref_folder = '/home/aeltona/wfield/references'
# # Apply the transform, prepare for plotting
# lmarks = load_allen_landmarks(pjoin(localdisk,'ccf_transform_landmarks.json'))
# ccf_regions,proj,brain_outline = allen_load_reference('dorsal_cortex')
# bout = brain_outline/lmarks['resolution'] + np.array(lmarks['bregma_offset'])

# mask = contour_to_mask(*bout.T,dims = U.shape[:-1])

# if 'dims' in dir():
#     U = U.reshape(dims)

# # warp and mask U, first make boarders zero
# from wfield.imutils import mask_to_3d
# U[:,0,:] = 0
# U[0,:,:] = 0
# U[-1,:,:] = 0
# U[:,-1,:] = 0
# U = np.stack(runpar(im_apply_transform,
#                     U.transpose([2,0,1]),
#                     M = lmarks['transform']))
# U[mask_to_3d(mask,U.shape)==0] = np.nan
# U = U.transpose([1,2,0])

# Hemodynamics correction

In [None]:
folder = localdisk

In [None]:
# U = np.load(pjoin(folder,'U.npy'))
# SVT = np.load(pjoin(folder,'SVT.npy'))
SVT_470 = SVT[:,::2]
SVT_405 = SVT[:,1::2]
# SVT_corr = np.load(pjoin(folder,'SVTcorr.npy'))

# T = np.load(pjoin(folder,'T.npy'))

In [None]:
fs = 5

output_folder = localdisk

tstart = time.time()

SVT_470 = SVT[:,0::2]
t = np.arange(SVT.shape[1]) # interpolate the violet
from scipy.interpolate import interp1d
SVT_405 = interp1d(t[1::2],SVT[:,1::2],axis=1,
                    fill_value='extrapolate')(t[0::2])
SVTcorr, rcoeffs, T = hemodynamic_correction(U, 
                                             SVT_470, 
                                             SVT_405, 
                                             fs=fs,
                                             freq_lowpass=2)  

print('Done hemodynamic correction in {0} s '.format(time.time()-tstart))

np.save(pjoin(localdisk,'rcoeffs.npy'),rcoeffs)
np.save(pjoin(localdisk,'T.npy'),T)
np.save(pjoin(localdisk,'SVTcorr.npy'),SVTcorr)

In [None]:
try: # don't crash while plotting
    import pylab as plt
    plt.matplotlib.style.use('ggplot')
    from wfield import  plot_summary_hemodynamics_dual_colors
    frame_rate = 5.

    plot_summary_hemodynamics_dual_colors(rcoeffs,
                                            SVT_470,
                                            SVT_405,
                                            U,
                                            T,
                                            frame_rate=frame_rate,
                                            duration = 6,
                                            outputdir = output_folder);
except Exception as err:
    print('There was an issue plotting.')
    print(err)

In [None]:
plt.imshow(usvt.mean());

In [None]:
plt.imshow(vio.var());

In [None]:
plt.imshow(gre.var());

In [None]:
save_video_from_array(usvt, f"U_SVDStack",localdisk,1)


# Warping Stack ROIs

In [None]:


U = np.load(pjoin(localdisk,'U_atlas.npy'))
SVT = np.load(pjoin(localdisk,'SVTcorr.npy'))
stack = SVDStack(U,SVT)
lmarks = load_allen_landmarks(pjoin(localdisk,'ccf_transform_landmarks.json'))

ccf_regions_reference,proj,brain_outline = allen_load_reference('dorsal_cortex')
# the reference is in allen CCF space and needs to be converted
# this converts to warped image space (accounting for the transformation)
ccf_regions = allen_transform_regions(None,ccf_regions_reference,
                                      resolution = lmarks['resolution'],
                                        bregma_offset = lmarks['bregma_offset'])
atlas, areanames, brain_mask = atlas_from_landmarks_file(pjoin(localdisk,'ccf_transform_landmarks.json')) # this loads the atlas in transformed coords

# this does the transform (warps the original images)
stack.set_warped(1, M = lmarks['transform']) # this warps the spatial components in the stack


In [None]:
# this converts the reference to image space (unwarped)
ref_folder = '/home/aeltona/wfield/references'
atlas_im, areanames, brain_mask = atlas_from_landmarks_file(pjoin(ref_folder,'dorsal_cortex_landmarks.json'),do_transform = True) # this loads the untransformed atlas
ccf_regions_im = allen_transform_regions(lmarks['transform'],ccf_regions_reference,
                                        resolution = lmarks['resolution'],
                                        bregma_offset = lmarks['bregma_offset'])

In [None]:
# Lets compare the warped with the unwarped average activity in an area
# area 33 is VISp
area = 33
stack.set_warped(True) # once this is done once the transform is set and you can alternate between the 2 modes.
warped = stack.get_timecourse(np.where(atlas == area)).mean(axis = 0)
stack.set_warped(False)
unwarped = stack.get_timecourse(np.where(atlas_im == area)).mean(axis = 0)

fig = plt.figure(figsize = [7,10])
fig.add_subplot(2,1,1)
plt.plot(unwarped,'k',lw = .5,label = 'unwarped average')
plt.plot(warped,'r',lw = .5,label = 'warped average')
plt.legend();
plt.xlim([2000,2500])
# plt.ylim([-0.04,0.06]);

area = -33 # plot the other side
stack.set_warped(True) 
warped = stack.get_timecourse(np.where(atlas == area)).mean(axis = 0)
stack.set_warped(False)
unwarped = stack.get_timecourse(np.where(atlas_im == area)).mean(axis = 0)

fig.add_subplot(2,1,2)
plt.plot(unwarped,'k',lw = .5,label = 'unwarped average')
plt.plot(warped,'r',lw = .5,label = 'warped average')
plt.legend()
plt.xlim([2000,2500])

# plt.xlim([20000,20500])
# plt.ylim([-0.04,0.06]);

In [None]:
# Plot the first 20 spatial components of the transformed dataset and the raw dataset
fig = plt.figure(figsize = [15,10])
ncomponents = 20
for icomponent in range(ncomponents):
    fig.add_subplot(5,4,icomponent+1)
    plt.imshow(np.concatenate([stack.originalU[:,:,icomponent],
                               stack.U_warped[:,:,icomponent]],axis = 1),
               clim=[-0.01,0.01],cmap='Spectral_r')
    # plot the regions overlayed on the raw images
    for i,r in ccf_regions_im.iterrows():
        plt.plot(r['left_x'],r['left_y'],'k',lw=0.3)
    # plot the raw reference because the images were converted
    for i,r in ccf_regions.iterrows():
        plt.plot(np.array(r['left_x'])+stack.U.shape[1],r['left_y'],'k',lw=0.3)
    plt.axis('off')

## Plot 20 random frames of the transformed dataset and the raw dataset 
(sanity check: what does the transform do to the images)

In [None]:

# Plot the 20 frames of the transformed dataset and the raw dataset
fig = plt.figure(figsize = [15,10])
for i,iframe in enumerate(np.random.choice(np.arange(0,stack.shape[0]),20)):
    fig.add_subplot(5,4,i+1)
    plt.imshow(np.concatenate([reconstruct(stack.originalU,stack.SVT[:,iframe]),
                               reconstruct(stack.U_warped,stack.SVT[:,iframe])],axis = 1),
               clim=[-0.03,0.03],cmap='Spectral_r')
    # plot the regions overlayed on the raw images
    for i,r in ccf_regions_im.iterrows():
        plt.plot(r['left_x'],r['left_y'],'k',lw=0.3)
    # plot the raw reference because the images were converted
    for i,r in ccf_regions.iterrows():
        plt.plot(np.array(r['left_x'])+stack.U.shape[1],r['left_y'],'k',lw=0.3)
    plt.axis('off')

## Highlight the differences in the atlas and ROIs in both warped and raw spaces

In [None]:
# Plot the overlap for the same region in the warped versus unwarped atlases 
plt.figure(figsize = [5,5])
area1 = 4
area2 = 33
reg = np.zeros([*atlas.shape,3])
reg[:,:,0] = np.array((atlas == area1) | (atlas == area2))*255
reg[:,:,1] = np.array((atlas_im == area1) | (atlas_im == area2))*255
plt.imshow(reg.astype('uint8'))

for i,r in ccf_regions.iterrows():
    for side in ['right']:
        plt.plot(np.array(r[side+'_x']),r[side +'_y'],'r',lw=1)
for i,r in ccf_regions_im.iterrows():
    for side in ['right']:
        plt.plot(np.array(r[side+'_x']),r[side +'_y'],'g',lw=1)
plt.axis('off');

## Explore the converted stack on the notebook

In [None]:
# play a movie of the stack with the regions overlayed
fig = plt.figure()
# This is for stack.set_warped(True)
stack.set_warped(True)
for i,r in ccf_regions.iterrows():
    for side in ['left','right']:
        plt.plot(np.array(r[side+'_x']),r[side +'_y'],'k',lw=0.3);
plt.axis('off')
nb_play_movie(stack,cmap = 'Spectral_r',clim=[-0.04,0.04]);


# Calculate ΔF/F and Z-score

In [None]:
data_df.shape[0]

In [None]:
data_df.columns

In [None]:
# Find intervals where norm_stim_on is 0
zero_intervals = []
start = None
for idx, row in data_df.iterrows():
    if row['norm_stim_on'] == 0:
        if start is None:
            start = idx
    else:
        if start is not None:
            zero_intervals.append((start, idx - 1))
            start = None
if start is not None:
    zero_intervals.append((start, len(data_df) - 1))

# Sort intervals by length in descending order
sorted_intervals = sorted(zero_intervals, key=lambda x: x[1] - x[0], reverse=True)

print("Sorted Intervals...")
# for interval in sorted_intervals:
#     print(interval)

# Find intervals between trial_onsets and next stimulus
# trial_intervals =pend((trial_onset, data_df.loc[idx + 1, 'frame']))

# print("\nIntervals between trial_onsets and next stimulus:")
# for interval in trial_intervals:
#     print(interval)

In [None]:
for i in range (0,5):
    print(sorted_intervals[i])
# print(base_start, base_end )

In [None]:
base_start, base_end = sorted_intervals[3]

U = np.load(pjoin(localdisk,'U.npy'))
SVT = np.load(pjoin(localdisk,'SVTcorr.npy'))

lmarks = load_allen_landmarks(pjoin(localdisk,'ccf_transform_landmarks.json'))
atlas, areanames, brain_mask = atlas_from_landmarks_file(pjoin(localdisk,'ccf_transform_landmarks.json')) # this loads the atlas in transformed coords

stack = SVDStack(U,SVT)
stack.set_warped(True)
for area_number in range(0,len(areanames)):
    area_name = str(areanames[area_number][1])
    area = areanames[area_number][0]
    # print(area,area_name)
    try:
        data = stack.get_timecourse(np.where(atlas == area)).mean(axis = 0)
        dFF, zscore = dFF_z(data,base_start,base_end);
        # dFF, zscore = dFF_z(data)
        data_df[f'{area_name}_dFF'] = dFF
        data_df[f'{area_name}_z'] = zscore
        print(area_name, 'loaded')
    except Exception as e:
        
        print(f"An error occurred: {e} ",area_name, 'failed')


In [None]:
# Find the intervals of stimuli
one_intervals = []
start = None
for idx, row in data_df.iterrows():
    if row['norm_stim_on'] == 1:
        if start is None:
            start = idx
    else:
        if start is not None:
            one_intervals.append((start, idx - 1))
            start = None
if start is not None:
    one_intervals.append((start, len(data_df) - 1))

# Sort intervals by the order of appearance
stim_intervals = sorted(one_intervals, key=lambda x: x[0])

# Create a DataFrame for the stimulus intervals
stimulus_data = {
    'stim_num': [],
    'onset': [],
    'offset': [],
    'stim_id': []
}

for stim_num, (onset, offset) in enumerate(stim_intervals):
    stim_id = data_df.loc[onset, 'norm_stim_id']
    stimulus_data['stim_num'].append(stim_num+1)
    stimulus_data['onset'].append(onset)
    stimulus_data['offset'].append(offset)
    stimulus_data['stim_id'].append(stim_id)

stim_df = pd.DataFrame(stimulus_data)
last_stim_num = stim_df['stim_num'].iloc[-1] // 4
trial_num = [i // 4 + 1 for i in range(4 * last_stim_num)]
stim_df['trial_num'] = trial_num
# Print the resulting DataFrame

odd_trials_array = odd_trials[0]  # Extracting the array from the tuple
true_odd_trials = [trial + 1 for trial in odd_trials_array]
stim_df.loc[(stim_df['stim_id'] == 9) & (stim_df['trial_num'].isin(true_odd_trials)), 'stim_id'] = 10
stim_df['duration'] = stim_df['offset'] - stim_df['onset']
mean_dur = stim_df.groupby('stim_id')['duration'].mean()
sem_dur = stim_df.groupby('stim_id')['duration'].sem()
stim_df['mean_dur'] = stim_df['stim_id'].map(mean_dur)
stim_df['sem_dur'] = stim_df['stim_id'].map(sem_dur)
display(stim_df)

In [None]:
# Select region and data type
region = 'ACAd'
data_type = 'z'

time_window_before = 10  # 10 frames before stimulus onset
time_window_after = 20  # 20 frames after stimulus onset

fig, axs = plt.subplots(nrows=1, ncols=len(sides), figsize=(21, 7))
sides = ['left', 'right']
for i, side in enumerate(sides):
    column_name = f'{region}_{side}_{data_type}'
    
    # Filter dataframe for the selected region and data type
    region_data = data_df[[column_name, 'time_seconds']]
    
    # Initialize lists for peri-stimulus data
    peri_stim_data_expected = []
    peri_stim_data_oddball = []
    offsets_expected = []
    offsets_oddball = []
    
    # Extract indices of stimulus onsets and offsets from stim_df
    exp_onsets = stim_df.loc[stim_df['stim_id'] == 9, 'onset'].tolist()
    odd_onsets = stim_df.loc[stim_df['stim_id'] == 10, 'onset'].tolist()
    exp_offsets = stim_df.loc[stim_df['stim_id'] == 9, 'offset'].tolist()
    odd_offsets = stim_df.loc[stim_df['stim_id'] == 10, 'offset'].tolist()
    exp_mean_dur = stim_df.loc[stim_df['stim_id'] == 9, 'mean_dur'].tolist()
    odd_mean_dur = stim_df.loc[stim_df['stim_id'] == 10, 'mean_dur'].tolist()
    exp_sem_dur = stim_df.loc[stim_df['stim_id'] == 9, 'sem_dur'].tolist()
    odd_sem_dur = stim_df.loc[stim_df['stim_id'] == 10, 'sem_dur'].tolist()
    
    # Extract peri-stimulus data for expected stimuli
    for onset, offset, mean_dur, sem_dur in zip(exp_onsets, exp_offsets, exp_mean_dur, exp_sem_dur):
        start_idx = onset - time_window_before
        end_idx = onset + time_window_after
        if start_idx >= 0 and end_idx < len(region_data):
            peri_stim_data_expected.append(region_data.iloc[start_idx:end_idx][column_name].values)
            offsets_expected.append(offset - onset)
    
    # Extract peri-stimulus data for oddball stimuli
    for onset, offset, mean_dur, sem_dur in zip(odd_onsets, odd_offsets, odd_mean_dur, odd_sem_dur):
        start_idx = onset - time_window_before
        end_idx = onset + time_window_after
        if start_idx >= 0 and end_idx < len(region_data):
            peri_stim_data_oddball.append(region_data.iloc[start_idx:end_idx][column_name].values)
            offsets_oddball.append(offset - onset)

    # Determine the maximum length of the peri-stimulus data
    max_length = max(max(len(x) for x in peri_stim_data_expected), max(len(x) for x in peri_stim_data_oddball))
    
    # Pad or truncate the data to ensure consistent length
    peri_stim_data_expected = [np.pad(x, (0, max_length - len(x)), 'constant', constant_values=np.nan) if len(x) < max_length else x[:max_length] for x in peri_stim_data_expected]
    peri_stim_data_oddball = [np.pad(x, (0, max_length - len(x)), 'constant', constant_values=np.nan) if len(x) < max_length else x[:max_length] for x in peri_stim_data_oddball]

    # Calculate mean activity for expected and oddball stimuli
    mean_activity_expected = np.nanmean(peri_stim_data_expected, axis=0)
    sem_activity_expected = np.nanstd(peri_stim_data_expected, axis=0) / np.sqrt(np.sum(~np.isnan(peri_stim_data_expected), axis=0))
    mean_activity_oddball = np.nanmean(peri_stim_data_oddball, axis=0)
    sem_activity_oddball = np.nanstd(peri_stim_data_oddball, axis=0) / np.sqrt(np.sum(~np.isnan(peri_stim_data_oddball), axis=0))

    ax = axs[i]
    
    # Plot mean activity for expected stimuli
    ax.plot(np.arange(-time_window_before, time_window_after), mean_activity_expected, label=f'{region} {side.capitalize()} - Expected', color='blue')
    ax.fill_between(np.arange(-time_window_before, time_window_after), mean_activity_expected - sem_activity_expected, mean_activity_expected + sem_activity_expected, color='blue', alpha=0.3)
    
    # Plot mean activity for oddball stimuli
    ax.plot(np.arange(-time_window_before, time_window_after), mean_activity_oddball, label=f'{region} {side.capitalize()} - Oddball', color='red')
    ax.fill_between(np.arange(-time_window_before, time_window_after), mean_activity_oddball - sem_activity_oddball, mean_activity_oddball + sem_activity_oddball, color='red', alpha=0.3)
    
    # Plot stimulus onset and offset
    ax.axvline(x=0, linestyle='--', color='gray', label='Stimulus Onset')  # Stimulus onset
    ax.axvline(x=np.mean(exp_mean_dur), linestyle='--', color='blue', label='Expected Stimulus Offset')  # Average offset for expected stimuli
    ax.axvline(x=np.mean(odd_mean_dur), linestyle='--', color='red', label='Oddball Stimulus Offset')  # Average offset for oddball stimuli
    
    # Shade stimulus period
    ax.axvspan(0, np.mean(exp_mean_dur), alpha=0.3, color='orange')  # Stimulus period for expected stimuli
    ax.axvspan(0, np.mean(odd_mean_dur), alpha=0.3, color='pink')  # Stimulus period for oddball stimuli

    # Shade SEM for offsets
    ax.axvspan(exp_mean_dur[0] - exp_sem_dur[0], exp_mean_dur[0] + exp_sem_dur[0], alpha=0.1, color='blue')
    ax.axvspan(odd_mean_dur[0] - odd_sem_dur[0], odd_mean_dur[0] + odd_sem_dur[0], alpha=0.1, color='red')
    
    ax.set_title(f'{region} {side}')
    ax.set_xlabel('Time (indices) relative to stimulus onset')
    ax.set_ylabel('Activity (z-score)')
    ax.legend();
    ax.grid(False)

plt.tight_layout()
plt.savefig(pjoin(localdisk, f'PSTH/{region}_PSTH.png'))

# plt.show()


In [None]:
# Iterate over areanames to find existing columns
for area_number in range(len(areanames)):
    area_name = str(areanames[area_number][1])
    column_name = f'{area_name}_{data}'
    if column_name in data_df.columns:
        column_data = data_df[column_name].values
        if not np.isnan(column_data).all():  # Check if the entire column is NaN
            columns_to_plot.append(column_name)
        else:
            excluded_columns.append(column_name) # This will print the first part of the name
# Use a set to store unique base region names
unique_regions = set()

for region in columns_to_plot:
    base_region_name = region.split('_')[0]
    unique_regions.add(base_region_name)

# Convert the set to a list if needed
unique_regions = list(unique_regions)

before = 10
after = 20
data_type = 'z'
# Print unique base region names
for region in unique_regions:
    plot_PSTH(data_df, stim_df, region, before, after,data_type);
    

In [None]:
def plot_PSTH(data_df, stim_df, region, time_window_before, time_window_after,data_type):
    sides = ['left', 'right']
    # data_type = 'zscore'  # Assuming 'zscore' is the data type, adjust if necessary
    print(f"PSTH for {region} in progress...")
    fig, axs = plt.subplots(nrows=1, ncols=len(sides), figsize=(21, 7))
    
    for i, side in enumerate(sides):
        column_name = f'{region}_{side}_{data_type}'
        
        # Filter dataframe for the selected region and data type
        region_data = data_df[[column_name, 'time_seconds']]
        
        # Initialize lists for peri-stimulus data
        peri_stim_data_expected = []
        peri_stim_data_oddball = []
        offsets_expected = []
        offsets_oddball = []
        
        # Extract indices of stimulus onsets and offsets from stim_df
        exp_onsets = stim_df.loc[stim_df['stim_id'] == 9, 'onset'].tolist()
        odd_onsets = stim_df.loc[stim_df['stim_id'] == 10, 'onset'].tolist()
        exp_offsets = stim_df.loc[stim_df['stim_id'] == 9, 'offset'].tolist()
        odd_offsets = stim_df.loc[stim_df['stim_id'] == 10, 'offset'].tolist()
        exp_mean_dur = stim_df.loc[stim_df['stim_id'] == 9, 'mean_dur'].tolist()
        odd_mean_dur = stim_df.loc[stim_df['stim_id'] == 10, 'mean_dur'].tolist()
        exp_sem_dur = stim_df.loc[stim_df['stim_id'] == 9, 'sem_dur'].tolist()
        odd_sem_dur = stim_df.loc[stim_df['stim_id'] == 10, 'sem_dur'].tolist()
        
        # Extract peri-stimulus data for expected stimuli
        for onset, offset, mean_dur, sem_dur in zip(exp_onsets, exp_offsets, exp_mean_dur, exp_sem_dur):
            start_idx = onset - time_window_before
            end_idx = onset + time_window_after
            if start_idx >= 0 and end_idx < len(region_data):
                peri_stim_data_expected.append(region_data.iloc[start_idx:end_idx][column_name].values)
                offsets_expected.append(offset - onset)
        
        # Extract peri-stimulus data for oddball stimuli
        for onset, offset, mean_dur, sem_dur in zip(odd_onsets, odd_offsets, odd_mean_dur, odd_sem_dur):
            start_idx = onset - time_window_before
            end_idx = onset + time_window_after
            if start_idx >= 0 and end_idx < len(region_data):
                peri_stim_data_oddball.append(region_data.iloc[start_idx:end_idx][column_name].values)
                offsets_oddball.append(offset - onset)

        # Determine the maximum length of the peri-stimulus data
        max_length = max(max(len(x) for x in peri_stim_data_expected), max(len(x) for x in peri_stim_data_oddball))
        
        # Pad or truncate the data to ensure consistent length
        peri_stim_data_expected = [np.pad(x, (0, max_length - len(x)), 'constant', constant_values=np.nan) if len(x) < max_length else x[:max_length] for x in peri_stim_data_expected]
        peri_stim_data_oddball = [np.pad(x, (0, max_length - len(x)), 'constant', constant_values=np.nan) if len(x) < max_length else x[:max_length] for x in peri_stim_data_oddball]

        # Calculate mean activity for expected and oddball stimuli
        mean_activity_expected = np.nanmean(peri_stim_data_expected, axis=0)
        sem_activity_expected = np.nanstd(peri_stim_data_expected, axis=0) / np.sqrt(np.sum(~np.isnan(peri_stim_data_expected), axis=0))
        mean_activity_oddball = np.nanmean(peri_stim_data_oddball, axis=0)
        sem_activity_oddball = np.nanstd(peri_stim_data_oddball, axis=0) / np.sqrt(np.sum(~np.isnan(peri_stim_data_oddball), axis=0))

        ax = axs[i]
        
        # Plot mean activity for expected stimuli
        ax.plot(np.arange(-time_window_before, time_window_after), mean_activity_expected, label=f'{region} {side.capitalize()} - Expected', color='blue')
        ax.fill_between(np.arange(-time_window_before, time_window_after), mean_activity_expected - sem_activity_expected, mean_activity_expected + sem_activity_expected, color='blue', alpha=0.3)
        
        # Plot mean activity for oddball stimuli
        ax.plot(np.arange(-time_window_before, time_window_after), mean_activity_oddball, label=f'{region} {side.capitalize()} - Oddball', color='red')
        ax.fill_between(np.arange(-time_window_before, time_window_after), mean_activity_oddball - sem_activity_oddball, mean_activity_oddball + sem_activity_oddball, color='red', alpha=0.3)
        
        # Plot stimulus onset and offset
        ax.axvline(x=0, linestyle='--', color='gray', label='Stimulus Onset')  # Stimulus onset
        ax.axvline(x=np.mean(exp_mean_dur), linestyle='--', color='blue', label='Expected Stimulus Offset')  # Average offset for expected stimuli
        ax.axvline(x=np.mean(odd_mean_dur), linestyle='--', color='red', label='Oddball Stimulus Offset')  # Average offset for oddball stimuli
        
        # Shade stimulus period
        ax.axvspan(0, np.mean(exp_mean_dur), alpha=0.3, color='orange')  # Stimulus period for expected stimuli
        ax.axvspan(0, np.mean(odd_mean_dur), alpha=0.3, color='pink')  # Stimulus period for oddball stimuli

        # Shade SEM for offsets
        ax.axvspan(exp_mean_dur[0] - exp_sem_dur[0], exp_mean_dur[0] + exp_sem_dur[0], alpha=0.1, color='blue')
        ax.axvspan(odd_mean_dur[0] - odd_sem_dur[0], odd_mean_dur[0] + odd_sem_dur[0], alpha=0.1, color='red')
        
        ax.set_title(f'{region} {side}')
        ax.set_xlabel('Time (indices) relative to stimulus onset')
        ax.set_ylabel('Activity (z-score)')
        ax.legend()
        ax.grid(False)

    plt.tight_layout()
    plt.savefig(pjoin(localdisk, f'PSTH/{region}_PSTH.png'))
    print(f"PSTH for {region} saved...")
    plt.clf()

In [None]:
# Define start and end points for the data slice
# start, end = onset_intervals[-1]  

# start -= 25
# end += 50

start, end = 0,data_df.shape[0]
# Initialize lists to store columns and corresponding data
columns_to_plot = []
data_to_plot = []
excluded_columns = []

# Specify the data type ('dFF' or 'z')
data = 'z'

# Iterate over areanames to find existing columns
for area_number in range(len(areanames)):
    area_name = str(areanames[area_number][1])
    column_name = f'{area_name}_{data}'
    if column_name in data_df.columns:
        column_data = data_df[column_name].values
        if not np.isnan(column_data).all():  # Check if the entire column is NaN
            columns_to_plot.append(column_name)
            data_to_plot.append(column_data[start:end])  # Slice data to the specified range
        else:
            excluded_columns.append(column_name)

# Check if any columns exist
if len(columns_to_plot) == 0:
    print("No columns found in the DataFrame.")
else:
    # Create a gridspec with a color bar
    fig = plt.figure(figsize=(12, 8))
    gs = gridspec.GridSpec(3, 2, width_ratios=[15, 1], height_ratios=[1, 1, len(columns_to_plot)], hspace=0.05, wspace=0.05)
    
    # Plot norm_stim_on as bars on the top left subplot
    ax1 = fig.add_subplot(gs[1, 0])
    norm_stim_on_values = data_df['norm_stim_on'].values[start:end]
    ax1.bar(data_df['time_seconds'].values[start:end], norm_stim_on_values, color='black', width=1)
    ax1.set_xlim([data_df['time_seconds'].values[start], data_df['time_seconds'].values[end-1]])
    ax1.axis('off')  # Turn off the axis

    # Plot rotary encoder data
    ax3 = fig.add_subplot(gs[0, 0])
    encoder_values = data_df['Rotary Encoder'].values[start:end]
    ax3.plot(data_df['time_seconds'].values[start:end], encoder_values, color='black')
    ax3.set_xlim([data_df['time_seconds'].values[start], data_df['time_seconds'].values[end-1]])
    ax3.axis('off')
    
    # Plot oddball_onsets and reward_onsets as bars
    oddball_values = data_df['oddball_onsets'].values[start:end]
    ax1.bar(data_df['time_seconds'].values[start:end], oddball_values, color='red', width=1)

    reward_values = data_df['reward_onsets'].values[start:end]
    ax1.bar(data_df['time_seconds'].values[start:end], reward_values, color='purple', width=1)

    # Plot the heatmap on the bottom left subplot
    ax2 = fig.add_subplot(gs[2, 0], sharex=ax1)
    cax = ax2.imshow(data_to_plot, aspect='auto', cmap='bwr', interpolation='nearest', extent=[data_df['time_seconds'].values[start], data_df['time_seconds'].values[end-1], 0, len(columns_to_plot)])
    ax2.set_xlabel('Time(s)')
    ax2.set_ylabel('Areas')

    # Set x-axis ticks for the heatmap
    time_seconds = data_df['time_seconds'].values[start:end]
    ax2.set_xticks(time_seconds[::int(len(time_seconds)/10)])  # Adjust the step for better labeling
    ax2.set_xticklabels(time_seconds[::int(len(time_seconds)/10)], rotation=90)
    
    # Set y-axis ticks for the heatmap
    ax2.set_yticks(range(len(columns_to_plot)))
    ax2.set_yticklabels(columns_to_plot)
    
    # Add a color bar in the right subplot
    cbar_ax = fig.add_subplot(gs[:, 1])
    cbar = fig.colorbar(cax, cax=cbar_ax)
    cbar.set_label(f'{data}')
    ax2.grid(False)
    # plt.show()
# plt.savefig(pjoin(localdisk, '')
# plt.savefig(pjoin(localdisk, 'whole_session.png'))


In [None]:

tif_files = [f for f in os.listdir(localdisk) if f.endswith('.tif')]

filename_parts = re.split(r'_|\.', tif_files[0])
scan_info = '_'.join(filename_parts[0:4])

csv_path = pjoin(localdisk, scan_info+'.csv')
print(csv_path)

# garb

In [None]:
# Select region and data type
region = 'ACAd'
data_type = 'z'
sides = ['left', 'right']
time_window_before = 5  # 5 seconds before stimulus onset
time_window_after = 10  # 10 seconds after stimulus offset

fig, axs = plt.subplots(nrows=1, ncols=len(sides), figsize=(15, 5))

for i, side in enumerate(sides):
    column_name = f'{region}_{side}_{data_type}'
    
    # Filter dataframe for the selected region and data type
    region_data = data_df[[column_name, 'norm_stim_on', 'oddball_onsets', 'norm_stim_id', 'time_seconds']]
    
    # Initialize lists for peri-stimulus data
    peri_stim_data_expected = []
    peri_stim_data_oddball = []
    offsets_expected = []
    offsets_oddball = []
    
    stim_onsets = region_data.index[region_data['norm_stim_on'] == 1].tolist()
    oddball_onsets = region_data.index[region_data['oddball_onsets'] == 1].tolist()
    expected_stimuli_indices = [idx for idx in stim_onsets if region_data.at[idx, 'norm_stim_id'] == 9 and idx not in oddball_onsets]

    # Extract peri-stimulus data for expected stimuli
    for onset in expected_stimuli_indices:
        offset = region_data.index[(region_data.index > onset) & (region_data['norm_stim_on'] == 0)].tolist()
        if offset:
            offset = offset[0]
            offsets_expected.append(offset - onset)
            start_idx = onset - time_window_before
            end_idx = offset + time_window_after
            if start_idx >= 0 and end_idx < len(region_data):
                peri_stim_data_expected.append(region_data.iloc[start_idx:end_idx][column_name].values)

    # Extract peri-stimulus data for oddball stimuli
    for onset in oddball_onsets:
        offset = region_data.index[(region_data.index > onset) & (region_data['norm_stim_on'] == 0)].tolist()
        if offset:
            offset = offset[0]
            offsets_oddball.append(offset - onset)
            start_idx = onset - time_window_before
            end_idx = offset + time_window_after
            if start_idx >= 0 and end_idx < len(region_data):
                peri_stim_data_oddball.append(region_data.iloc[start_idx:end_idx][column_name].values)

    # Determine the maximum length of the peri-stimulus data
    max_length = max(max(len(x) for x in peri_stim_data_expected), max(len(x) for x in peri_stim_data_oddball))

    # Pad or truncate the data to ensure consistent length
    peri_stim_data_expected = [np.pad(x, (0, max_length - len(x)), 'constant', constant_values=np.nan) if len(x) < max_length else x[:max_length] for x in peri_stim_data_expected]
    peri_stim_data_oddball = [np.pad(x, (0, max_length - len(x)), 'constant', constant_values=np.nan) if len(x) < max_length else x[:max_length] for x in peri_stim_data_oddball]

    # Calculate mean and SEM for expected stimuli
    peri_stim_data_expected = np.array(peri_stim_data_expected)
    mean_activity_expected = np.nanmean(peri_stim_data_expected, axis=0)
    sem_activity_expected = np.nanstd(peri_stim_data_expected, axis=0) / np.sqrt(np.sum(~np.isnan(peri_stim_data_expected), axis=0))
    
    # Calculate mean and SEM for oddball stimuli
    peri_stim_data_oddball = np.array(peri_stim_data_oddball)
    mean_activity_oddball = np.nanmean(peri_stim_data_oddball, axis=0)
    sem_activity_oddball = np.nanstd(peri_stim_data_oddball, axis=0) / np.sqrt(np.sum(~np.isnan(peri_stim_data_oddball), axis=0))
    
    time_vector = np.linspace(-time_window_before, time_window_after, max_length)  # Adjust time vector for the peri-stimulus window

    # Calculate average and SEM for offsets
    mean_offset_expected = np.mean(offsets_expected)
    sem_offset_expected = np.std(offsets_expected) / np.sqrt(len(offsets_expected))
    mean_offset_oddball = np.mean(offsets_oddball)
    sem_offset_oddball = np.std(offsets_oddball) / np.sqrt(len(offsets_oddball))

    ax = axs[i]
    
    # Plot mean activity for expected stimuli
    ax.plot(time_vector, mean_activity_expected, label=f'{region} {side.capitalize()} - Expected', color='blue')
    ax.fill_between(time_vector, mean_activity_expected - sem_activity_expected, mean_activity_expected + sem_activity_expected, color='blue', alpha=0.3)
    
    # Plot mean activity for oddball stimuli
    ax.plot(time_vector, mean_activity_oddball, label=f'{region} {side.capitalize()} - Oddball', color='red')
    ax.fill_between(time_vector, mean_activity_oddball - sem_activity_oddball, mean_activity_oddball + sem_activity_oddball, color='red', alpha=0.3)
    
    # Plot stimulus onset and offset
    ax.axvline(x=0, linestyle='--', color='gray', label='Stimulus Onset')  # Stimulus onset
    ax.axvline(x=mean_offset_expected, linestyle='--', color='blue', label='Expected Stimulus Offset')  # Average offset for expected stimuli
    ax.axvline(x=mean_offset_oddball, linestyle='--', color='red', label='Oddball Stimulus Offset')  # Average offset for oddball stimuli
    
    # Shade stimulus period
    ax.axvspan(0, mean_offset_expected, alpha=0.3, color='orange')
    
    # Shade SEM for offsets
    ax.axvspan(mean_offset_expected - sem_offset_expected, mean_offset_expected + sem_offset_expected, alpha=0.1, color='blue')
    ax.axvspan(mean_offset_oddball - sem_offset_oddball, mean_offset_oddball + sem_offset_oddball, alpha=0.1, color='red')
    
    ax.set_title(f'Peri-stimulus Activity - {side.capitalize()} Side')
    ax.set_xlabel('Time (s) relative to stimulus onset')
    ax.set_ylabel('Activity (z-score)')
    ax.legend('_')
    ax.grid(False)

plt.tight_layout()
plt.show()