In [1]:
import os
import pathlib
import shutil
from contextlib import contextmanager

import joblib
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
from tqdm import tqdm
tqdm.pandas()

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import matplotlib.animation as animation
from matplotlib.animation import PillowWriter
from matplotlib.patches import Ellipse, Circle
import seaborn as sns

pd.set_option('display.max_rows', 500)

from IPython.display import HTML

In [2]:
# Board size as (x extent, y extent); agent positions are compared against
# these coordinates as (x, y) tuples.
grid = (11, 11)

# The four single reward zones, one at the middle of each board edge:
# "u" has the maximum y, "r" the maximum x, "d" has y == 0, "l" has x == 0.
# Every other coordinate below is derived from these, so changing `grid`
# (or a zone) updates both lookup tables consistently.
_edge_zones = {
    "u": (grid[0] // 2, grid[1] - 1),
    "r": (grid[0] - 1, grid[1] // 2),
    "d": (grid[0] // 2, 0),
    "l": (0, grid[1] // 2),
}

# Mapping from reward identifiers to the tuple of board coordinates in that zone.
# Multi-letter identifiers combine single zones; the tuple order is the order in
# which compute_trial_stats scans the coordinates.
# NOTE: the original author marked this mapping as due to be replaced with an
# updated version.
reward_place_to_coord = {
    "u": (_edge_zones["u"],),
    "r": (_edge_zones["r"],),
    "d": (_edge_zones["d"],),
    "l": (_edge_zones["l"],),
    "ur": (_edge_zones["u"], _edge_zones["r"]),
    "rd": (_edge_zones["r"], _edge_zones["d"]),
    "dl": (_edge_zones["d"], _edge_zones["l"]),
    "ul": (_edge_zones["l"], _edge_zones["u"]),  # l listed before u (original order kept)
    "ud": (_edge_zones["u"], _edge_zones["d"]),
    "rl": (_edge_zones["r"], _edge_zones["l"]),
}

# Inverse lookup: single-zone coordinate -> zone letter. Built from the same
# edge coordinates so it can never disagree with reward_place_to_coord.
reverse_reward_place_to_coord = {_edge_zones[zone]: zone for zone in ("l", "r", "d", "u")}

# Column name -> pandas dtype for trial-level statistics tables.
# NOTE(review): compute_trial_stats also emits trial_length, wz_cancel and
# collected_zone, which have no entry here — confirm whether that is intended.
stats_dtype = dict(
    trial_id="int32",
    activated="bool",
    collected="bool",
    reward_loc="object",
    activated_by="object",
    activated_frame="object",
    first_close_to_zone="object",
    first_close_to_zone_frame="object",
    first_to_zone="object",
    first_to_zone_frame="object",
    reward_counter="float64",
)

# Context manager that makes joblib report finished batches to a tqdm bar.
@contextmanager
def tqdm_joblib(tqdm_object):
    """Route joblib batch-completion events into a tqdm progress bar.

    While the ``with`` block is active, ``joblib.parallel.BatchCompletionCallBack``
    is swapped for a subclass that advances ``tqdm_object`` by each finished
    batch's ``batch_size`` before delegating to the stock callback. On exit,
    normal or via an exception, the original callback class is restored and
    the bar is closed.

    Yields:
        The same ``tqdm_object`` that was passed in.
    """
    original_cls = joblib.parallel.BatchCompletionCallBack

    class _ProgressCallback(original_cls):
        def __call__(self, *args, **kwargs):
            tqdm_object.update(n=self.batch_size)
            return super().__call__(*args, **kwargs)

    joblib.parallel.BatchCompletionCallBack = _ProgressCallback
    try:
        yield tqdm_object
    finally:
        # Always undo the monkey-patch, even if the parallel job raised.
        joblib.parallel.BatchCompletionCallBack = original_cls
        tqdm_object.close()

In [3]:
# Per-trial statistics; compute_folded_stats applies this to every trial in parallel.
def _first_zone_leader(a1_hits, a2_hits, index):
    """Return ``(leader, row_label)`` for the first row in which an agent hits a zone.

    ``a1_hits`` / ``a2_hits`` are boolean arrays of shape ``(n_rows, n_coords)``
    saying whether agent 1 / agent 2 passes the zone test for each trial row and
    each zone coordinate. Rows are scanned in order. Within the first row that has
    any hit, coordinates are scanned in order, and the first coordinate hit by either
    agent decides the leader: ``"tie"`` if both agents hit it, otherwise ``"a1"``
    or ``"a2"``. ``row_label`` is that row's label from ``index``.
    Returns ``(None, None)`` when no row hits any coordinate (including when
    there are no coordinates at all).
    """
    any_hit = a1_hits | a2_hits
    hit_rows = np.flatnonzero(any_hit.any(axis=1))
    if hit_rows.size == 0:
        return None, None
    row = hit_rows[0]
    col = int(np.argmax(any_hit[row]))  # first coordinate with a hit in this row
    a1, a2 = bool(a1_hits[row, col]), bool(a2_hits[row, col])
    if a1 and a2:
        leader = "tie"
    elif a1:
        leader = "a1"
    else:
        leader = "a2"
    label = index[row]
    # DataFrame.iterrows() yields plain Python scalars; keep that type in the output.
    if isinstance(label, np.generic):
        label = label.item()
    return leader, label


def compute_trial_stats(trial_df, closeness_thresh=2, wz_cancel_max_steps=51):
    """Aggregate one trial's frames into a single row of trial-level statistics.

    Parameters
    ----------
    trial_df : pd.DataFrame
        All rows of one trial (non-empty), assumed to be in temporal order: the
        last row is taken as the trial's final state. Reads the columns
        ``frame_idx``, ``reward_loc``, ``r1``, ``collected``,
        ``steps_without_reward``, ``activated``, ``a1x``, ``a1y``, ``a2x``, ``a2y``.
    closeness_thresh : float, default 2
        Euclidean distance within which an agent counts as "close" to a zone
        coordinate (Leader #2).
    wz_cancel_max_steps : int, default 51
        ``wz_cancel`` is True when the reward was not collected and the last
        row's ``steps_without_reward`` is below this value.

    Returns
    -------
    pd.Series
        Keys: trial_length, activated, collected, wz_cancel, collected_zone,
        reward_loc, activated_by, activated_frame, first_close_to_zone,
        first_close_to_zone_frame, first_to_zone, first_to_zone_frame,
        reward_counter.

    Notes
    -----
    Reads the module-level ``grid``, ``reward_place_to_coord`` and
    ``reverse_reward_place_to_coord``.
    NOTE(review): ``activated_frame`` is ``frame_idx`` relative to the trial's
    first row, while ``first_close_to_zone_frame`` / ``first_to_zone_frame`` are
    ``trial_df`` index labels (as in the original iterrows loop). Confirm which
    unit downstream analysis expects.
    """
    first_row = trial_df.iloc[0]
    last_row = trial_df.iloc[-1]

    trial_length = len(trial_df)
    trial_start = int(first_row["frame_idx"])
    reward_loc = last_row["reward_loc"]
    reward_counter = trial_df["r1"].sum()
    collected = last_row["collected"]

    wz_cancel = bool(
        not collected and last_row["steps_without_reward"] < wz_cancel_max_steps
    )

    # Zone in which the reward was collected. Agent 1's position is checked first,
    # as before; agent 2 is the fallback so a reward collected by a2 is not
    # reported as None.
    collected_zone = None
    if collected:
        for x_col, y_col in (("a1x", "a1y"), ("a2x", "a2y")):
            collected_zone = reverse_reward_place_to_coord.get(
                (last_row[x_col], last_row[y_col])
            )
            if collected_zone is not None:
                break

    # Leader #1: who was on the activation spot at the first activated frame.
    activated = False
    activated_by = None
    activated_frame = None
    act_rows = trial_df[trial_df["activated"]]
    if not act_rows.empty:
        activated = True
        row_act = act_rows.iloc[0]
        activated_frame = int(row_act["frame_idx"]) - trial_start  # relative to trial start
        # Activation spot is the grid centre: (5, 5) on the 11x11 board, matching the
        # original hard-coded 5. Both agents must be on it for a "tie" (the original
        # test never checked a2x, so a2 anywhere on the diagonal produced "tie").
        center = (grid[0] // 2, grid[1] // 2)
        a1_on = (row_act["a1x"], row_act["a1y"]) == center
        a2_on = (row_act["a2x"], row_act["a2y"]) == center
        if a1_on and a2_on:
            activated_by = "tie"
        elif a1_on:
            activated_by = "a1"
        elif a2_on:
            activated_by = "a2"

    first_close_to_zone = first_close_to_zone_frame = None
    first_to_zone = first_to_zone_frame = None
    # reward_loc may be the string 'None' or an actual None when no reward is set.
    if reward_loc not in ("None", None):
        reward_coords = reward_place_to_coord.get(reward_loc, ())
        zone = np.asarray(reward_coords, dtype=float).reshape(-1, 2)  # (k, 2)
        # Offsets of every agent position from every zone coordinate: (n_rows, k, 2).
        a1_delta = trial_df[["a1x", "a1y"]].to_numpy(dtype=float)[:, None, :] - zone
        a2_delta = trial_df[["a2x", "a2y"]].to_numpy(dtype=float)[:, None, :] - zone

        # Leader #2: first agent within `closeness_thresh` of any zone coordinate.
        first_close_to_zone, first_close_to_zone_frame = _first_zone_leader(
            np.sqrt((a1_delta ** 2).sum(axis=2)) <= closeness_thresh,
            np.sqrt((a2_delta ** 2).sum(axis=2)) <= closeness_thresh,
            trial_df.index,
        )
        # Leader #3: first agent exactly on a zone coordinate.
        first_to_zone, first_to_zone_frame = _first_zone_leader(
            (a1_delta == 0).all(axis=2),
            (a2_delta == 0).all(axis=2),
            trial_df.index,
        )

    return pd.Series({
        "trial_length": trial_length,
        "activated": activated,
        "collected": collected,
        "wz_cancel": wz_cancel,
        "collected_zone": collected_zone,
        "reward_loc": reward_loc,
        "activated_by": activated_by,
        "activated_frame": activated_frame,
        "first_close_to_zone": first_close_to_zone,
        "first_close_to_zone_frame": first_close_to_zone_frame,
        "first_to_zone": first_to_zone,
        "first_to_zone_frame": first_to_zone_frame,
        "reward_counter": reward_counter,
    })

In [4]:
# Folded statistics computation
def compute_folded_stats(df, trail_function, n_folds=20):
    gids = df.groupby("trial_id", group_keys=False).indices
    chunk_idxs = np.array_split(list(gids.keys()), n_folds)
    start_ends = [(gids[chunk[0]][0], gids[chunk[-1]][-1]) for chunk in chunk_idxs]

    def process_chunk(start, end):
        sdf = df.loc[start:end]
        return (
            sdf.groupby("trial_id", group_keys=False)
               .apply(trail_function, include_groups=False)
               .reset_index()
        )

    with tqdm_joblib(tqdm(total=len(start_ends), desc="Processing chunks")) as progress_bar:
        stats = pd.concat(
            Parallel(n_jobs=n_folds, backend="loky", prefer="processes")(
                delayed(process_chunk)(start, end) for start, end in start_ends
            )
        )
    return stats.reset_index(drop=True)

In [5]:
# Recompute trial-level statistics for every training run under `base`:
# each .parquet log in <run>/logs/ is summarised with compute_trial_stats and
# written under the same file name to <run>/trial_stats/.
base = "../store_bk"
N_FOLDS = 50  # chunks (and joblib workers) per log file

# Only sub-directories are runs; stray files (e.g. .DS_Store) are ignored.
trains = sorted(
    name for name in os.listdir(base) if os.path.isdir(os.path.join(base, name))
)

for train in trains:
    store = f"{base}/{train}/logs/"
    export = f"{base}/{train}/trial_stats/"

    # Check before deleting anything, so a run without logs keeps its old outputs.
    if not os.path.isdir(store):
        print(f"Skipping {train}: no logs/ folder")
        continue

    # Start from an empty export folder so outputs of earlier runs never linger.
    if os.path.exists(export):
        shutil.rmtree(export)
    pathlib.Path(export).mkdir(parents=True, exist_ok=True)

    files = sorted(name for name in os.listdir(store) if name.endswith(".parquet"))
    for file in files:
        tdf = pd.read_parquet(f"{store}{file}")
        # Presumably `terminated` is True on the last row of each trial; shifting
        # the running count down one row keeps that row in the trial it closes.
        tdf["trial_id"] = tdf.terminated.cumsum().shift(fill_value=0)
        trial_stats = compute_folded_stats(tdf, compute_trial_stats, N_FOLDS)
        trial_stats.to_parquet(f"{export}{file}")

Processing chunks: 100%|██████████| 50/50 [00:02<00:00, 23.19it/s]
Processing chunks: 100%|██████████| 50/50 [00:02<00:00, 18.81it/s]
Processing chunks: 100%|██████████| 50/50 [00:02<00:00, 18.31it/s]
Processing chunks: 100%|██████████| 50/50 [00:12<00:00,  3.97it/s]
Processing chunks: 100%|██████████| 50/50 [00:01<00:00, 43.79it/s]
Processing chunks: 100%|██████████| 50/50 [00:02<00:00, 19.54it/s]
Processing chunks: 100%|██████████| 50/50 [00:02<00:00, 17.86it/s]
Processing chunks: 100%|██████████| 50/50 [00:13<00:00,  3.80it/s]
Processing chunks: 100%|██████████| 50/50 [00:01<00:00, 42.57it/s]
Processing chunks: 100%|██████████| 50/50 [00:02<00:00, 19.63it/s]
Processing chunks: 100%|██████████| 50/50 [00:02<00:00, 18.17it/s]
Processing chunks: 100%|██████████| 50/50 [00:12<00:00,  3.94it/s]
Processing chunks: 100%|██████████| 50/50 [00:01<00:00, 41.27it/s]
Processing chunks: 100%|██████████| 50/50 [00:02<00:00, 19.59it/s]
Processing chunks: 100%|██████████| 50/50 [00:03<00:00, 16.38i

In [6]:
# # get every folder in store_w
# base= "../store_bk"
# trains = sorted(os.listdir(base))

# for train in trains[:]:
#     store = f"{base}/{train}/logs/"
#     export = f"{base}/{train}/trial_stats/"
    
#     # Create the export folder if it doesn't exist
#     pathlib.Path(export).mkdir(parents=True, exist_ok=True)
    
#     file = 'testing_1.parquet'
#     tdf = pd.read_parquet(f'{store}{file}')
#     tdf["trial_id"] = tdf.terminated.cumsum().shift(fill_value=0)
#     trial_stats = compute_folded_stats(tdf, compute_trial_stats, 50)
#     trial_stats.to_parquet(f'{export}{file}')