In [1]:
import os
import numpy as np
import pickle
from dotenv import dotenv_values
# Read key/value settings from the local .env file.
config = dotenv_values(".env")

# Normalise the configured data root, then publish it through the process
# environment as well (presumably so imported helpers can read it -- TODO confirm).
configured_root = config['DATA_PATH']
data_path = os.path.normpath(configured_root)
os.environ.update(DATA_PATH=data_path)

print(f"DATA_PATH is set to: {os.environ['DATA_PATH']}")

from helpers import *


DATA_PATH is set to: /mnt/big_gulp/nc4_rat_data/Maze_Rats


In [6]:
# Animal ID and session date
rat_id = 19
date = '250310'  # presumably YYMMDD -- TODO confirm

# Folder layout: <data_path>/NC400<rat_id>/<date>/Processed/{session_data,psytrack_data}
rat_path = os.path.join(data_path, f"NC400{rat_id}", date)
processed_path = os.path.join(rat_path, "Processed")
session_data_path, psytrack_data_path = (
    os.path.join(processed_path, subfolder)
    for subfolder in ("session_data", "psytrack_data")
)

# exist_ok=True keeps this cell safe to re-run for an already-processed session.
for output_dir in (session_data_path, psytrack_data_path):
    os.makedirs(output_dir, exist_ok=True)

print(
    "Created or confirmed folders:\n"
    f"- {session_data_path}\n"
    f"- {psytrack_data_path}"
)


Created or confirmed folders:
- /mnt/big_gulp/nc4_rat_data/Maze_Rats/NC40019/250310/Processed/session_data
- /mnt/big_gulp/nc4_rat_data/Maze_Rats/NC40019/250310/Processed/psytrack_data


In [None]:
#Extract raw data from ROS
# Each entry is indexed as entry[0] = timestamp, entry[1] = message text by the
# cells below.
msg_list = get_log(rat_id, date)

# Put the log in chronological order; later cells rely on this ordering when
# they pair trial-start and trial-end messages.
msg_list.sort(key=lambda x: x[0])

# Show a short preview only: echoing the full log floods the notebook output
# and bloats the saved file.
print(f"Loaded {len(msg_list)} log messages")
msg_list[:10]

In [4]:
# Filter messages
# Trial boundaries are picked out of the log by message-text prefix:
# 'Current trial number:' entries are used as trial starts and
# 'INTER_TRIAL_INTERVAL' entries as trial ends.
trial_start_msg_list = list(
    filter(lambda entry: entry[1].startswith('Current trial number:'), msg_list)
)
trial_end_msg_list = list(
    filter(lambda entry: entry[1].startswith('INTER_TRIAL_INTERVAL'), msg_list)
)

# Only the last expression of a cell is displayed.
trial_end_msg_list



[[1741655944806720345, 'INTER_TRIAL_INTERVAL'],
 [1741656027126281166, 'INTER_TRIAL_INTERVAL'],
 [1741656095566304801, 'INTER_TRIAL_INTERVAL'],
 [1741656139955961842, 'INTER_TRIAL_INTERVAL'],
 [1741656200385988140, 'INTER_TRIAL_INTERVAL'],
 [1741656236955894697, 'INTER_TRIAL_INTERVAL'],
 [1741656269026574683, 'INTER_TRIAL_INTERVAL'],
 [1741656305956162137, 'INTER_TRIAL_INTERVAL'],
 [1741656354066268524, 'INTER_TRIAL_INTERVAL'],
 [1741656391866220345, 'INTER_TRIAL_INTERVAL'],
 [1741656430426535506, 'INTER_TRIAL_INTERVAL'],
 [1741656466416393934, 'INTER_TRIAL_INTERVAL'],
 [1741656523336471547, 'INTER_TRIAL_INTERVAL'],
 [1741656556456829387, 'INTER_TRIAL_INTERVAL'],
 [1741656615466425468, 'INTER_TRIAL_INTERVAL'],
 [1741656657116292945, 'INTER_TRIAL_INTERVAL'],
 [1741656709686194659, 'INTER_TRIAL_INTERVAL'],
 [1741656740756568388, 'INTER_TRIAL_INTERVAL'],
 [1741656768066076947, 'INTER_TRIAL_INTERVAL'],
 [1741656799066648056, 'INTER_TRIAL_INTERVAL'],
 [1741656836156557546, 'INTER_TRIAL_INTE

In [30]:
from collections import deque
import pprint

# Define exclusion conditions
# Each rule is (choice_correctness, correct_choice, gantry_sequence). The
# gantry sequence is a list of chamber ids kept as strings, because that is
# how they are sliced out of the log text.
exclusion_conditions = [
    dict(
        choice_correctness=outcome,
        correct_choice=answer_side,
        gantry_sequence=list(chambers),
    )
    for outcome, answer_side, chambers in (
        (0, 1, ("4", "3", "0", "3")),
        (0, 2, ("4", "3", "6", "3")),
        (1, 2, ("4", "3", "0", "3")),
        (1, 1, ("4", "3", "6", "3")),
    )
]

# Build one record per trial from the chronologically sorted message log.
#
# Reads:  msg_list, trial_start_msg_list, trial_end_msg_list, exclusion_conditions
# Writes: trials (list of dicts) and trials_df (DataFrame with the same rows)
#
# For each trial-end message the closest earlier, still-unused trial-start
# message is paired with it. The messages between the two are searched for the
# stimulus, the correct answer, the animal's choice and the outcome.

# Extra time past the trial end when collecting a trial's messages, in the
# same units as the log timestamps.
# NOTE(review): response times below are obtained by multiplying timestamp
# differences by 1e-9, so the timestamps look like nanoseconds and a margin of
# 3 is effectively zero. If a 3-second margin was intended this should be
# 3_000_000_000. Confirm before changing: a wider window can also pull in
# messages that belong to the next trial.
OUTCOME_MARGIN = 3

# Text that identifies a gantry move command, and how many moves after CHOICE
# are compared against the exclusion rules.
GANTRY_MOVE_MARKER = "Move to chamber command received: chamber["
N_GANTRY_MOVES = 4

# Track used trial numbers and start times
used_trial_numbers = set()
used_trial_start_times = set()

trials = []
for msg in trial_end_msg_list:
    trial_end_time = msg[0]

    # Find the closest previous trial start message that has not been paired yet
    trial_start_msg_candidates = [
        msg_start for msg_start in trial_start_msg_list
        if msg_start[0] < trial_end_time and msg_start[0] not in used_trial_start_times
    ]
    if not trial_start_msg_candidates:
        continue  # Skip if no valid start message found

    trial_start_msg = trial_start_msg_candidates[-1]
    trial_start_time = trial_start_msg[0]
    used_trial_start_times.add(trial_start_time)

    # The start message is itself the "Current trial number: N" entry, so the
    # number is parsed from it directly instead of re-searching the whole log.
    trial_number = int(trial_start_msg[1].split(": ")[1]) + 1  # shift trial numbers up by 1
    if trial_number in used_trial_numbers:
        continue  # Skip duplicates
    used_trial_numbers.add(trial_number)

    # Collect this trial's messages (+ margin, see OUTCOME_MARGIN)
    trial_msgs = [
        m for m in msg_list
        if trial_start_time <= m[0] < trial_end_time + OUTCOME_MARGIN
    ]

    ## Extract key trial information (first matching message in the trial window)
    choice_correctness_msg = next((m for m in trial_msgs if "Choice is correct" in m[1] or "Choice is incorrect" in m[1]), None)
    correct_choice_msg = next((m for m in trial_msgs if "Correct stimulus-response choice is:" in m[1]), None)
    choice_y_msg = next((m for m in trial_msgs if "choice (y) is" in m[1].lower()), None)
    stimulus_msg = next((m for m in trial_msgs if "Selected stimulus is:" in m[1]), None)
    choice_msg = next((m for m in trial_msgs if m[1] == "CHOICE"), None)

    if not choice_msg:
        continue  # Skip if no CHOICE detected

    # Position of this exact CHOICE entry inside msg_list (not trial_msgs!).
    # trial_msgs holds the same objects as msg_list, so an identity match always
    # succeeds and cannot land on a different message sharing the timestamp.
    choice_index = next(i for i, m in enumerate(msg_list) if m is choice_msg)

    # Chamber ids of the first N_GANTRY_MOVES gantry moves from CHOICE onwards.
    # NOTE(review): the scan is not limited to the current trial; if fewer
    # moves follow CHOICE within the trial, moves from later trials are used.
    gantry_messages_after_choice = []
    for m in msg_list[choice_index:]:
        if GANTRY_MOVE_MARKER in m[1]:
            chamber = m[1].split("chamber[")[1].split("]")[0]
            gantry_messages_after_choice.append(chamber)
            if len(gantry_messages_after_choice) == N_GANTRY_MOVES:
                break

    # Parse values carefully: a field stays None when its message is missing
    choice_correctness = int(choice_correctness_msg[1].split(": ")[1]) if choice_correctness_msg else None
    correct_choice = int(correct_choice_msg[1].split(": ")[1]) if correct_choice_msg else None
    choice_y = int(choice_y_msg[1].split()[-1]) if choice_y_msg else None
    stimulus = int(stimulus_msg[1].split(": ")[1]) if stimulus_msg else None

    # Response time in seconds: START_TO_CHOICE to SUCCESS (preferred) or ERROR;
    # NaN when either end of the interval is missing.
    start_to_choice_msg = next((m for m in trial_msgs if "START_TO_CHOICE" in m[1]), None)
    success_msg = next((m for m in trial_msgs if "SUCCESS" in m[1]), None)
    error_msg = next((m for m in trial_msgs if "ERROR" in m[1]), None)

    outcome_msg = success_msg or error_msg
    if start_to_choice_msg and outcome_msg:
        response_time = (outcome_msg[0] - start_to_choice_msg[0]) * 1e-9
    else:
        response_time = np.nan

    # A trial is dropped when outcome, correct answer and gantry sequence all
    # match one exclusion rule.
    exclude = any(
        choice_correctness == condition["choice_correctness"]
        and correct_choice == condition["correct_choice"]
        and gantry_messages_after_choice == condition["gantry_sequence"]
        for condition in exclusion_conditions
    )
    if exclude:
        print(f"Skipping trial {trial_number} due to exclusion rule")
        continue

    trials.append({
        "trial_number": trial_number,
        "y": choice_y,
        "answer": correct_choice,
        "correct": choice_correctness,
        "stimulus": stimulus,
        "response_time": response_time,
    })

print(f"Processed {len(trials)} valid trials after filtering.")
import pandas as pd

# Convert to DataFrame
trials_df = pd.DataFrame(trials)
display(trials_df)



Skipping trial 1 due to exclusion rule


Processed 59 valid trials after filtering.


Unnamed: 0,trial_number,y,answer,correct,stimulus,response_time
0,2,2,2,1,-1,3.849216
1,3,1,1,1,1,4.10979
2,4,1,2,0,-1,11.370144
3,5,2,1,0,1,5.249832
4,6,1,1,1,1,3.710375
5,7,1,1,1,1,2.979828
6,8,1,2,0,-1,3.27998
7,9,2,2,1,-1,3.000678
8,10,1,2,0,-1,3.280313
9,11,1,1,1,1,4.730345


In [31]:
# Session accuracy: mean of the "correct" column expressed as a percentage.
correct_column = trials_df["correct"]
percent_correct = 100 * np.mean(correct_column)
print(f"Percent correct: {percent_correct:.2f}%")

# Mean response time in seconds; nanmean ignores trials whose response time
# could not be computed (stored as NaN).
response_times = trials_df["response_time"]
avg_response_time = np.nanmean(response_times)
print(f"Average response time: {avg_response_time:.3f} seconds")

Percent correct: 45.76%
Average response time: 5.216 seconds


In [None]:
# Convert extracted trial data into NumPy arrays and assemble the dict `D`
# that is saved for PsyTrack.

# Fail loudly on trials with unparsed fields. A None stimulus would become NaN
# and then an arbitrary integer under .astype(int); a None in the other fields
# would silently turn the array into dtype=object.
required_fields = ("y", "answer", "correct", "stimulus")
incomplete_trials = [
    t["trial_number"] for t in trials
    if any(t[field] is None for field in required_fields)
]
if incomplete_trials:
    raise ValueError(
        f"Trials with missing values in {required_fields}: {incomplete_trials}"
    )

y = np.array([t["y"] for t in trials])
answer = np.array([t["answer"] for t in trials])
correct = np.array([t["correct"] for t in trials])

# Flat numeric stimulus vector (coerced through float, then stored as int)
stimulus_values = np.array([t["stimulus"] for t in trials], dtype=float).astype(int)

# Define history length for inputs
history_length = 3

# Construct stimulus history matrix (N, M):
# stimulus_history[i, j] is the stimulus of trial i - j, and stays 0 where
# trial i - j does not exist (the first rows of the session).
n_trials = len(stimulus_values)
stimulus_history = np.zeros((n_trials, history_length))
for lag in range(min(history_length, n_trials)):
    stimulus_history[lag:, lag] = stimulus_values[:n_trials - lag]

inputs = {"s1": stimulus_history}

# Single session, so one entry holding the number of trials
dayLength = np.array([len(trials)])

# Print array shapes to confirm correctness
print(f"y shape: {y.shape}, answer shape: {answer.shape}, correct shape: {correct.shape}, inputs['s1'] shape: {inputs['s1'].shape}")
print("y:", y[:5])          # Should see values 1 or 2
print("answer:", answer[:5]) # Should see values 1 or 2
print("correct:", correct[:5]) # Should see values 0 or 1
print("inputs['s1']:", inputs["s1"][:5]) # Should see -1s or 1s, with 0 padding in the first rows

D = {
    "y": y,
    "inputs": inputs,
    "name": rat_id,
    "answer": answer,
    "correct": correct,
    "dayLength": dayLength
}

y shape: (59,), answer shape: (59,), correct shape: (59,), inputs['s1'] shape: (59, 3)
y: [2 1 1 2 1]
answer: [2 1 2 1 1]
correct: [1 1 0 0 1]
inputs['s1']: [[-1.  0.  0.]
 [ 1. -1.  0.]
 [-1.  1. -1.]
 [ 1. -1.  1.]
 [ 1.  1. -1.]]


In [None]:
# Save the session data as a CSV file to folder
trials_df.to_csv(os.path.join(session_data_path, f"{rat_id}_{date}_trials_df.csv"), index=False)

# Save data for PsyTrack as an array
# D is a dict, so NumPy stores it as a pickled 0-d object array; reading it
# back requires np.load(path, allow_pickle=True)["D"].item().
np.savez(os.path.join(psytrack_data_path, f"{rat_id}_{date}_psytrack_data.npz"), D=D)

# One-row summary of this session, appended to a running CSV of metrics.
summary_path = os.path.join(processed_path, "summary_metrics.csv")
summary_entry = {
    "rat_id": rat_id,
    "date": date,
    "n_trials": len(trials_df),
    "percent_correct": np.mean(trials_df["correct"]) * 100,
    "avg_response_time": np.nanmean(trials_df["response_time"])
}

# Load existing summary file or create new one
if os.path.exists(summary_path):
    # Read "date" as text. Left to type inference, pandas parses values such
    # as 250310 as integers, so comparing against the string `date` never
    # matched and every re-run appended a duplicate row (a leading zero in the
    # date would also be lost).
    summary_df = pd.read_csv(summary_path, dtype={"date": str})

    # Avoid duplicate entry
    already_logged = (
        (summary_df["rat_id"] == rat_id) & (summary_df["date"] == str(date))
    ).any()
    if not already_logged:
        summary_df = pd.concat([summary_df, pd.DataFrame([summary_entry])], ignore_index=True)
        summary_df.to_csv(summary_path, index=False)
else:
    pd.DataFrame([summary_entry]).to_csv(summary_path, index=False)


In [None]:
## Construct trial list
#trial_list = []
#k=1
#for msg in trial_end_msg_list:
#    # Find the corresponding trial start message whose timestamp is smaller than the trial end message
#    trial_start_msg = [msg_start for msg_start in trial_start_msg_list if msg_start[0] < msg[0]][-1]
#    trial_list.append([k, trial_start_msg[0], msg[0]])
#    k += 1

#trial_list

# Get duration of each trial
#for trial in trial_list:
#    trial.append((trial[2] - trial[1])*1.0/1e9)

#trial_list

[[1, 1741658921977109052, 1741659010747037609, 88.769928557],
 [2, 1741659011767274436, 1741659069036904557, 57.269630121],
 [3, 1741659070057380277, 1741659099717657155, 29.660276878],
 [4, 1741659100737323817, 1741659297291432715, 196.554108898],
 [5, 1741659298301973246, 1741659373703859042, 75.401885796],
 [6, 1741659374714045791, 1741659432073978910, 57.359933119],
 [7, 1741659433084077337, 1741659513784160187, 80.70008285],
 [8, 1741659514794734644, 1741659547224713164, 32.42997852],
 [9, 1741659548244279456, 1741659624674059358, 76.429779902],
 [10, 1741659625684626321, 1741659670564337701, 44.87971138],
 [11, 1741659671575174981, 1741659717273904850, 45.698729869],
 [12, 1741659718285033096, 1741659781944051968, 63.659018872],
 [13, 1741659782964346810, 1741659821254493702, 38.290146892],
 [14, 1741659822264860810, 1741659873114470059, 50.849609249],
 [15, 1741659874124979901, 1741659927994375861, 53.86939596],
 [16, 1741659929004112113, 1741659964874826663, 35.87071455],
 [17,

In [None]:
# Find messages within a trial that contains 'Selected stimulus is:'

#selected_stimulus_msg_list = [msg for msg in msg_list if msg[1].startswith('Selected stimulus is:')]

#for msg in selected_stimulus_msg_list:
    # Find the corresponding trial
    #trial = [trial for trial in trial_list if trial[1] < msg[0] and trial[2] > msg[0]][0]
    #print(f"Trial {trial[0]}: {msg[1]}")