In [2]:
import os
import numpy as np
import pickle
from dotenv import dotenv_values
config = dotenv_values(".env")
data_path = os.path.normpath(config['DATA_PATH'])

# List of possible valid paths depending on machine
possible_mounts = [
    '/mnt/big_gulp/nc4_rat_data/Maze_Rats',
    '/media/nc4-gingerbeer/big_gulp/nc4_rat_data/Maze_Rats',
    os.path.expanduser('~/data/Maze_Rats'),  # local fallback
]

for path in possible_mounts:
    if os.path.exists(path):
        os.environ['DATA_PATH'] = path
        break
else:
    raise RuntimeError("Could not find a valid DATA_PATH location. Please update the paths.")

print(f"DATA_PATH is set to: {os.environ['DATA_PATH']}")

from helpers import *


DATA_PATH is set to: /mnt/big_gulp/nc4_rat_data/Maze_Rats


In [350]:
# Animal ID and session date
rat_id = 18
date = '250321'

# Always use the latest DATA_PATH from the environment
data_path = os.environ["DATA_PATH"]

# Build paths using consistent structure
rat_path = os.path.join(data_path, f"NC400{rat_id}", date)
processed_path = os.path.join(rat_path, "Processed")
session_data_path = os.path.join(processed_path, "session_data")
psytrack_data_path = os.path.join(processed_path, "psytrack_data")

# Print to confirm correctness
#print("processed_path =", processed_path)
#print("session_data_path =", session_data_path)
#print("psytrack_data_path =", psytrack_data_path)

# Create folders if they don’t exist
os.makedirs(session_data_path, exist_ok=True)
os.makedirs(psytrack_data_path, exist_ok=True)

print(f"Created or confirmed folders:\n- {session_data_path}\n- {psytrack_data_path}")

Created or confirmed folders:
- /mnt/big_gulp/nc4_rat_data/Maze_Rats/NC40018/250321/Processed/session_data
- /mnt/big_gulp/nc4_rat_data/Maze_Rats/NC40018/250321/Processed/psytrack_data


In [351]:
#Extract raw data from ROS
msg_list = get_log(rat_id, date)
msg_list.sort(key=lambda x: x[0])
msg_list

Using path: /mnt/big_gulp/nc4_rat_data/Maze_Rats
Using path: /mnt/big_gulp/nc4_rat_data/Maze_Rats
Looking in folder: /mnt/big_gulp/nc4_rat_data/Maze_Rats/NC40018
Looking in folder: /mnt/big_gulp/nc4_rat_data/Maze_Rats/NC40018/250321/Raw/ROS


[[1742598315623918163, 'Subscribing to /rosout_agg'],
 [1742598315623929303, 'Subscribing to /rosout'],
 [1742598315623930953, 'Subscribing to /natnet_ros/Gantry/pose'],
 [1742598315623932433, 'Subscribing to /natnet_ros/Gantry/marker0/pose'],
 [1742598315623935523, 'Subscribing to /natnet_ros/Gantry/marker1/pose'],
 [1742598315623936803, 'Subscribing to /natnet_ros/Gantry/marker2/pose'],
 [1742598315623938263, 'Subscribing to /natnet_ros/Harness/pose'],
 [1742598315623942053,
  "Recording to 'pseudorandomExperimentData_2025-03-21-16-05-15.bag'."],
 [1742598315623943493, 'Subscribing to /natnet_ros/Harness/marker0/pose'],
 [1742598315623944713, 'Subscribing to /natnet_ros/Harness/marker1/pose'],
 [1742598315623947503, 'Subscribing to /natnet_ros/Harness/marker2/pose'],
 [1742598315623948753, 'Subscribing to /natnet_ros/MazeBoundary/pose'],
 [1742598315623949983, 'Subscribing to /natnet_ros/MazeBoundary/marker0/pose'],
 [1742598315623952963, 'Subscribing to /natnet_ros/MazeBoundary/mark

In [352]:
# Filter messages
trial_start_msg_list = [msg for msg in msg_list if msg[1].startswith('Current trial number:')]
trial_start_msg_list

trial_end_msg_list = [msg for msg in msg_list if msg[1].startswith('INTER_TRIAL_INTERVAL')]
trial_end_msg_list

[[1742598409448085219, 'INTER_TRIAL_INTERVAL'],
 [1742598459456149634, 'INTER_TRIAL_INTERVAL'],
 [1742598501374694263, 'INTER_TRIAL_INTERVAL'],
 [1742598570886389725, 'INTER_TRIAL_INTERVAL'],
 [1742598618397696417, 'INTER_TRIAL_INTERVAL'],
 [1742598672290196998, 'INTER_TRIAL_INTERVAL'],
 [1742598730718060299, 'INTER_TRIAL_INTERVAL'],
 [1742598779056857735, 'INTER_TRIAL_INTERVAL'],
 [1742598835597843464, 'INTER_TRIAL_INTERVAL'],
 [1742598920541631717, 'INTER_TRIAL_INTERVAL'],
 [1742598956581139381, 'INTER_TRIAL_INTERVAL'],
 [1742599004190800077, 'INTER_TRIAL_INTERVAL'],
 [1742599045031269104, 'INTER_TRIAL_INTERVAL'],
 [1742599096768397142, 'INTER_TRIAL_INTERVAL'],
 [1742599137481576107, 'INTER_TRIAL_INTERVAL'],
 [1742599170930969946, 'INTER_TRIAL_INTERVAL'],
 [1742599199121122995, 'INTER_TRIAL_INTERVAL'],
 [1742599224541753495, 'INTER_TRIAL_INTERVAL'],
 [1742599254140928160, 'INTER_TRIAL_INTERVAL'],
 [1742599280725271148, 'INTER_TRIAL_INTERVAL'],
 [1742599314081687931, 'INTER_TRIAL_INTE

In [353]:
from collections import deque
import pprint

# Define exclusion conditions
exclusion_conditions = [
    {
        "choice_correctness": 0,
        "correct_choice": 1,
        "gantry_sequence": ["4", "3", "0", "3"]
    },
    {
        "choice_correctness": 0,
        "correct_choice": 2,
        "gantry_sequence": ["4", "3", "6", "3"]
    },
    {
        "choice_correctness": 1,
        "correct_choice": 2,
        "gantry_sequence": ["4", "3", "0", "3"]
    },
    {
        "choice_correctness": 1,
        "correct_choice": 1,
        "gantry_sequence": ["4", "3", "6", "3"]
    }
]

# Track used trial numbers and start times
used_trial_numbers = set()
used_trial_start_times = set()

trials = []
for msg in trial_end_msg_list:
    # Find the closest previous trial start message
    trial_start_msg_candidates = [msg_start for msg_start in trial_start_msg_list if msg_start[0] < msg[0] and msg_start[0] not in used_trial_start_times]
    if not trial_start_msg_candidates:
        continue  # Skip if no valid start message found

    trial_start_msg = trial_start_msg_candidates[-1]
    trial_start_time = trial_start_msg[0]
    trial_end_time = msg[0]

    used_trial_start_times.add(trial_start_time)

    # Extract trial number
    trial_number_msg = next((m for m in msg_list if m[1].startswith("Current trial number:") and trial_start_time <= m[0] < trial_end_time), None)
    if trial_number_msg:
        trial_number = int(trial_number_msg[1].split(": ")[1]) + 1  # shift trial numbers up by 1
        if trial_number in used_trial_numbers:
            continue  # Skip duplicates
        used_trial_numbers.add(trial_number)
    else:
        trial_number = None

    # Dynamically collect trial messages (+ small margin to catch outcome messages)
    trial_msgs = [m for m in msg_list if trial_start_time <= m[0] < trial_end_time + 3]

    ## Extract key trial information
    choice_correctness_msg = next((m for m in trial_msgs if "Choice is correct" in m[1] or "Choice is incorrect" in m[1]), None)
    correct_choice_msg = next((m for m in trial_msgs if "Correct stimulus-response choice is:" in m[1]), None)
    choice_y_msg = next((m for m in trial_msgs if "choice (y) is" in m[1].lower()), None)
    stimulus_msg = next((m for m in trial_msgs if "Selected stimulus is:" in m[1]), None)
    choice_msg = next((m for m in trial_msgs if m[1] == "CHOICE"), None)

    if not choice_msg:
        continue  # Skip if no CHOICE detected

    # Get index of CHOICE inside msg_list (not trial_msgs!)
    choice_msg_full = next((m for m in msg_list if m[0] == choice_msg[0]), None)
    if not choice_msg_full:
        continue

    choice_index = msg_list.index(choice_msg_full)

    # Extract gantry messages after CHOICE
    gantry_messages_after_choice = []
    for m in msg_list[choice_index:]:
        if "Move to chamber command received: chamber[" in m[1]:
            chamber = m[1].split("chamber[")[1].split("]")[0]
            gantry_messages_after_choice.append(chamber)
        if len(gantry_messages_after_choice) == 4:
            break

    # Parse values carefully
    choice_correctness = int(choice_correctness_msg[1].split(": ")[1]) if choice_correctness_msg else None
    correct_choice = int(correct_choice_msg[1].split(": ")[1]) if correct_choice_msg else None
    choice_y = int(choice_y_msg[1].split()[-1]) if choice_y_msg else None
    stimulus = int(stimulus_msg[1].split(": ")[1]) if stimulus_msg else None

    # Calculate response time (start_to_choice to success/error)
    # Calculate response time (handle START_TO_CHOICE or START_TO_CENTRAL to SUCCESS/ERROR)
    start_msg = next(
        (m for m in trial_msgs if "START_TO_CHOICE" in m[1] or "START_TO_CENTRAL" in m[1]),
        None
    )
    success_msg = next((m for m in trial_msgs if "SUCCESS" in m[1]), None)
    error_msg = next((m for m in trial_msgs if "ERROR" in m[1]), None)

    if start_msg and (success_msg or error_msg):
        response_end_time = success_msg[0] if success_msg else error_msg[0]
        response_time = (response_end_time - start_msg[0]) * 1e-9
    else:
        response_time = np.nan


    # Check exclusion conditions
    exclude = False
    for condition in exclusion_conditions:
        if (
            choice_correctness == condition["choice_correctness"] and
            correct_choice == condition["correct_choice"] and
            gantry_messages_after_choice == condition["gantry_sequence"]
        ):
            print(f"Skipping trial {trial_number} due to exclusion rule")
            exclude = True
            break

    if not exclude:
        trials.append({
            "trial_number": trial_number,
            "y": choice_y,
            "answer": correct_choice,
            "correct": choice_correctness,
            "stimulus": stimulus,
            "response_time": response_time,
        })

print(f"Processed {len(trials)} valid trials after filtering.")
import pandas as pd

# Convert to DataFrame
trials_df = pd.DataFrame(trials)
display(trials_df)



Skipping trial 37 due to exclusion rule
Skipping trial 39 due to exclusion rule
Skipping trial 49 due to exclusion rule
Skipping trial 50 due to exclusion rule
Processed 55 valid trials after filtering.


Unnamed: 0,trial_number,y,answer,correct,stimulus,response_time
0,1,1,1,1,1,4.813987
1,2,2,1,0,1,3.021001
2,3,1,1,1,1,4.358452
3,4,2,2,1,-1,3.240798
4,5,1,2,0,-1,11.214115
5,6,1,1,1,1,4.988071
6,7,2,2,1,-1,2.820511
7,8,1,1,1,1,2.837691
8,9,2,1,0,1,2.84066
9,10,1,1,1,1,13.123491


In [355]:
# Calculate percent correct
percent_correct_small = np.mean(trials_df["correct"][:40]) * 100
print(f"Percent correct for 40 trials: {percent_correct_small:.2f}%")

percent_correct_big = np.mean(trials_df["correct"]) * 100
print(f"Percent correct for whole session: {percent_correct_big:.2f}%")

# Calculate average response time
avg_response_time = np.nanmean(trials_df["response_time"])
print(f"Average response time: {avg_response_time:.3f} seconds")

Percent correct for 40 trials: 62.50%
Percent correct for whole session: 47.27%
Average response time: 4.104 seconds


In [356]:
# Convert extracted trial data into NumPy arrays
y = np.array([t["y"] for t in trials])
answer = np.array([t["answer"] for t in trials])
correct = np.array([t["correct"] for t in trials])

# Fix stimulus_values: force it to be flat, simple numbers
stimulus_values = np.array([t["stimulus"] for t in trials], dtype=float)  # force to float or int!

# Define history length (M) for inputs
M = 3  
N = len(trials)

inputs = {}

# Create stimulus matrix shape (N, 2) with current and 1-back
# stim_current = stimulus_values.reshape(N, 1)
# stim_prev = np.roll(stimulus_values, 1).reshape(N, 1)
# stim_prev[0] = 0  # pad first row with 0 since there's no previous trial

# stimulus = np.hstack([stim_current, stim_prev])  # shape (N, 2)


# Reshape into (N, 1) matrix
stimulus = stimulus_values.reshape(N, 1)

# Check rank
rank = np.linalg.matrix_rank(stimulus)
print(f"Stimulus shape: {stimulus.shape}, rank: {rank}")

# Fix if rank-deficient
if rank < 2:
    print("⚠️ Stimulus regressor is rank-deficient — dropping to single column.")
    stimulus = stimulus[:, [0]]  # keep only the current trial column

inputs['stimulus'] = stimulus

# stimH (N, N)
stimH = np.zeros((N, M))
for i in range(N):
    for j in range(M):
        if i - j >= 0:
            stimH[i, j] = stimulus_values[i - j]
inputs['stimH'] = stimH

# actionH (N, N)
actionH = np.zeros((N, M))
for i in range(N):
    for j in range(M):
        if i - j >= 0:
            actionH[i, j] = y[i - j]
inputs['actionH'] = actionH

# win-stay
win_stay = np.zeros((N, M))
for i in range(N):
    for j in range(M):
        if i - j >= 1 and correct[i - j] == 1:
            win_stay[i, j] = 1 if y[i] == y[i - j] else 0
inputs['actionXposRewardH'] = win_stay

# lose-switch
lose_switch = np.zeros((N, M))
for i in range(N):
    for j in range(M):
        if i - j >= 1 and correct[i - j] == 0:
            lose_switch[i, j] = 1 if y[i] != y[i - j] else 0
inputs['actionXnegRewardH'] = lose_switch

dayLength = np.array([len(trials)])

# Print array shapes to confirm correctness
print(f"y shape: {y.shape}, answer shape: {answer.shape}, correct shape: {correct.shape}, inputs['stimulus'] shape: {inputs['stimulus'].shape}")
print("y:", y)          # Should see values 1 or 2
print("answer:", answer[:5]) # Should see values 1 or 2
print("correct:", correct[:5]) # Should see values 0 or 1
print("inputs['stimulus']:", inputs["stimulus"][:5]) # Should see matrix of -1s or 1s
print("inputs['stimulus'][7]:", inputs["stimulus"][7]) # Should see stimulus hear on the 7th trial
#print("inputs['stimulus'][7, 1]:", inputs["stimulus"][7, 1]) # Should see stimulus hear on the 6th trial
print("inputs['stimulus'][6, 0]:", inputs["stimulus"][6, 0]) # Should see stimulus hear on the 6th trial
print("inputs['stimH']:", inputs["stimH"][:5]) # Should see matrix of -1s or 1s
print("inputs['actionH']:", inputs["actionH"][:5]) # Should see matrix of 1s or 2s
print("inputs['actionXposRewardH']:", inputs["actionXposRewardH"][:5]) # Should see matrix of 0s or 1s
print("inputs['actionXnegRewardH']:", inputs["actionXnegRewardH"][:5]) # Should see matrix of 0s or 1s

D = {
    'y': y,
    'inputs': {'stimulus': stimulus, 'stimH': stimH, 'actionH': actionH, 'actionXposRewardH': win_stay, 'actionXnegRewardH': lose_switch},
    'name': rat_id,
    'answer': answer,
    'correct': correct,
    'dayLength': dayLength
}


Stimulus shape: (55, 1), rank: 1
⚠️ Stimulus regressor is rank-deficient — dropping to single column.
y shape: (55,), answer shape: (55,), correct shape: (55,), inputs['stimulus'] shape: (55, 1)
y: [1 2 1 2 1 1 2 1 2 1 1 1 1 2 1 2 1 1 1 2 1 1 1 1 1 1 1 2 1 1 1 1 1 2 1 1 1
 1 2 1 1 1 1 2 1 1 1 2 1 1 1 2 1 2 1]
answer: [1 1 1 2 2]
correct: [1 0 1 1 0]
inputs['stimulus']: [[ 1.]
 [ 1.]
 [ 1.]
 [-1.]
 [-1.]]
inputs['stimulus'][7]: [1.]
inputs['stimulus'][6, 0]: -1.0
inputs['stimH']: [[ 1.  0.  0.]
 [ 1.  1.  0.]
 [ 1.  1.  1.]
 [-1.  1.  1.]
 [-1. -1.  1.]]
inputs['actionH']: [[1. 0. 0.]
 [2. 1. 0.]
 [1. 2. 1.]
 [2. 1. 2.]
 [1. 2. 1.]]
inputs['actionXposRewardH']: [[0. 0. 0.]
 [0. 0. 0.]
 [1. 0. 0.]
 [1. 0. 0.]
 [0. 0. 1.]]
inputs['actionXnegRewardH']: [[0. 0. 0.]
 [0. 0. 0.]
 [0. 1. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]


In [357]:
# Save the session data as a CVS file to folder
trials_df.to_csv(os.path.join(session_data_path, f"{rat_id}_{date}_trials_df.csv"), index=False)

# Save data for PsyTrack as an array
np.savez(os.path.join(psytrack_data_path, f"{rat_id}_{date}_psytrack_data.npz"), D=D)

summary_path = os.path.join(processed_path, "summary_metrics.csv")
summary_entry = {
    "rat_id": rat_id,
    "date": date,
    "n_trials": len(trials_df),
    "percent_correct": np.mean(trials_df["correct"]) * 100,
    "avg_response_time": np.nanmean(trials_df["response_time"])
}

# Load existing summary file or create new one
if os.path.exists(summary_path):
    summary_df = pd.read_csv(summary_path)
    # Avoid duplicate entry
    if not ((summary_df["rat_id"] == rat_id) & (summary_df["date"] == date)).any():
        summary_df = pd.concat([summary_df, pd.DataFrame([summary_entry])], ignore_index=True)
        summary_df.to_csv(summary_path, index=False)
else:
    pd.DataFrame([summary_entry]).to_csv(summary_path, index=False)


In [100]:
## Construct trial list
#trial_list = []
#k=1
#for msg in trial_end_msg_list:
#    # Find the corresponding trial start message whose timestamp is smaller than the trial end message
#    trial_start_msg = [msg_start for msg_start in trial_start_msg_list if msg_start[0] < msg[0]][-1]
#    trial_list.append([k, trial_start_msg[0], msg[0]])
#    k += 1

#trial_list

# Get duration of each trial
#for trial in trial_list:
#    trial.append((trial[2] - trial[1])*1.0/1e9)

#trial_list

In [None]:
# Find messages within a trial that contains 'Selected stimulus is:'

#selected_stimulus_msg_list = [msg for msg in msg_list if msg[1].startswith('Selected stimulus is:')]

#for msg in selected_stimulus_msg_list:
    # Find the corresponding trial
    #trial = [trial for trial in trial_list if trial[1] < msg[0] and trial[2] > msg[0]][0]
    #print(f"Trial {trial[0]}: {msg[1]}")