In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib as mpl
import seaborn as sns
from PyQt5.QtCore.QByteArray import length
from scipy import stats
from figure_making import data_loader

In [None]:
# --- SZ animals: loaded with day_0 = 2023-11-30 -----------------------------
# (For a quick single-animal run, swap in: animal_ids = ["SZ036"])
animal_ids = ["SZ036", "SZ037", "SZ038", "SZ039", "SZ042", "SZ043"]
master_df1 = data_loader.load_dataframes_for_animal_summary(
    animal_ids,
    'DA_vs_lastR_df',
    file_format='parquet',
    hemisphere_qc=1,
    day_0='2023-11-30',
)

# --- RK animals: loaded with day_0 = 2025-06-17 -----------------------------
# NOTE: `animal_ids` is deliberately rebound here, so after this cell it holds
# only the RK animals.
animal_ids = ["RK007", "RK008"]
master_df2 = data_loader.load_dataframes_for_animal_summary(
    animal_ids,
    'DA_vs_lastR_df',
    file_format='parquet',
    hemisphere_qc=1,
    day_0='2025-06-17',
)

# Stack both groups row-wise under a fresh 0..n-1 index.
master_lastR_df = pd.concat((master_df1, master_df2), ignore_index=True)

In [None]:
# Preview the first rows of the combined (SZ + RK) frame as a sanity check.
master_lastR_df.head()

In [None]:
def adjust_da_signals(df, adjustment_map, block='0.8'):
    """Return a copy of ``df`` with per-animal/hemisphere offsets added to ``DA``.

    Only rows whose ``block`` column equals ``block`` are shifted; every other
    row is returned unchanged. The input frame is never modified.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``'animal'``, ``'hemisphere'``, ``'block'``
        and ``'DA'``.
    adjustment_map : dict
        Maps ``(animal, hemisphere)`` tuples to the numeric offset that is
        added to ``DA`` for the matching rows. Keys that match no row are
        silently ignored.
    block : object, default ``'0.8'``
        Value of the ``block`` column that selects the rows to adjust. The
        comparison is a plain ``==``, so the type must match how ``block`` is
        stored in ``df`` (the default assumes string labels).

    Returns
    -------
    pandas.DataFrame
        Adjusted copy of ``df``.

    Raises
    ------
    KeyError
        If one of the required columns is missing from ``df``.
    """
    df_adjusted = df.copy()

    # The block filter does not depend on the animal/hemisphere being
    # processed, so compute it once instead of once per map entry.
    in_block = df_adjusted['block'] == block

    for (animal, hemi), offset in adjustment_map.items():
        mask = (
            in_block
            & (df_adjusted['animal'] == animal)
            & (df_adjusted['hemisphere'] == hemi)
        )
        df_adjusted.loc[mask, 'DA'] += offset

    return df_adjusted

In [None]:
# Manual DA offsets, keyed by (animal, hemisphere). adjust_da_signals adds each
# value to the 'DA' column of the matching rows in block '0.8' only.
adjustment_map = {
    # RK animals
    ('RK007', 'left'): 0.15,
    ('RK008', 'left'): 0.1,
    # SZ animals (one line per animal)
    ('SZ036', 'left'): 0.25, ('SZ036', 'right'): 0.4,
    ('SZ037', 'left'): 0.3,  ('SZ037', 'right'): 0.1,
    ('SZ038', 'left'): 0.3,  ('SZ038', 'right'): 0.25,
    ('SZ039', 'left'): 0.3,
    ('SZ042', 'left'): 0.05,
    ('SZ043', 'right'): 0.0,
}

# Work on a shifted copy; master_lastR_df itself stays untouched.
df_adjusted = adjust_da_signals(df=master_lastR_df, adjustment_map=adjustment_map)

In [None]:
def _binned_mean_sem(subset, bin_size):
    """Bin ``subset['RXI']`` into ``bin_size``-wide bins and summarise ``DA``.

    Returns a DataFrame with columns ``'mean'``, ``'sem'`` and ``'bin_center'``
    (one row per non-empty bin, ordered by bin), or ``None`` when no valid
    binning exists (all-NaN RXI, or no positive RXI so fewer than two edges).
    """
    max_rxi = subset['RXI'].max()
    if pd.isna(max_rxi):
        return None

    edges = np.arange(0, max_rxi + bin_size, bin_size)
    if len(edges) < 2:
        # max_rxi <= 0: there is no [0, x] interval to cut into.
        return None

    # labels=False yields the integer bin index (NaN when outside the edges).
    # This avoids assigning a new column into a filtered slice, and lets the
    # centre be computed from the real edges: with interval labels,
    # include_lowest=True shifts the first left edge to -0.001 and skews its
    # midpoint.
    bin_idx = pd.cut(subset['RXI'], bins=edges, include_lowest=True, labels=False)

    # Group by a plain array so no index alignment is involved; NaN keys
    # (out-of-range or missing RXI) are dropped by groupby.
    binned = subset['DA'].groupby(bin_idx.to_numpy()).agg(['mean', 'sem'])
    binned = binned.dropna(subset=['mean'])

    left = binned.index.to_numpy().astype(int)
    binned['bin_center'] = (edges[left] + edges[left + 1]) / 2
    return binned.reset_index(drop=True)


def plot_da_vs_rxi(df, bin_size=0.2, xlim=(0, 8)):
    """Plot DA against RXI, one figure per (animal, hemisphere) pair.

    For each pair present in ``df`` a figure is drawn containing, for every
    block in ``{'0.4', '0.8'}`` that has data:

    * a faint scatter of the raw ``(RXI, DA)`` points, and
    * the binned mean of ``DA`` with a +/- SEM band.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns ``'animal'``, ``'hemisphere'``, ``'block'``,
        ``'RXI'`` and ``'DA'``. ``block`` is compared against the string
        labels ``'0.4'`` and ``'0.8'``.
    bin_size : float, default 0.2
        Width of the RXI bins (same unit as ``RXI``); must be positive.
    xlim : tuple of float, default (0, 8)
        x-axis limits applied to every figure.

    Returns
    -------
    None
        Figures are shown with ``plt.show()`` and then closed.

    Raises
    ------
    ValueError
        If ``bin_size`` is not strictly positive.

    Notes
    -----
    Sets ``mpl.rcParams['figure.dpi'] = 300`` globally, so figures created
    later in the same session inherit that resolution.
    """
    if bin_size <= 0:
        raise ValueError(f"bin_size must be positive, got {bin_size!r}")

    # Validate the input before touching any global plotting state.
    groups = df[['animal', 'hemisphere']].drop_duplicates()
    if groups.empty:
        print("No data found.")
        return

    mpl.rcParams['figure.dpi'] = 300

    palette = sns.color_palette('Set2')
    color_map = {
        '0.4': palette[0],  # Greenish
        '0.8': palette[1],  # Orangeish
    }
    label_map = {
        '0.4': 'Low',
        '0.8': 'High',
    }

    for animal, hemi in groups.itertuples(index=False, name=None):
        subset_ah = df[(df['animal'] == animal) & (df['hemisphere'] == hemi)]

        fig, ax = plt.subplots(figsize=(6, 5))

        for block_val, color in color_map.items():
            subset = subset_ah[subset_ah['block'] == block_val]
            if subset.empty:
                continue

            # --- A) Raw data (unlabelled, to keep the legend clean) ---
            ax.scatter(subset['RXI'], subset['DA'], color=color, alpha=0.2, s=10)

            # --- B) Binned statistics ---
            binned = _binned_mean_sem(subset, bin_size)
            if binned is None or binned.empty:
                continue

            # --- C) Mean curve and SEM band ---
            ax.plot(
                binned['bin_center'],
                binned['mean'],
                color=color,
                linewidth=2.5,
                label=label_map[block_val],
            )
            ax.fill_between(
                binned['bin_center'],
                binned['mean'] - binned['sem'],
                binned['mean'] + binned['sem'],
                color=color,
                alpha=0.3,
            )

        # Formatting
        ax.set_title(f"DA vs RXI: {animal} - {hemi}")
        ax.set_xlabel("Reward-to-Exit Interval (s)")
        ax.set_ylabel("DA Response")
        # Only draw a legend when a curve was plotted; otherwise matplotlib
        # warns about missing labelled artists.
        handles, _ = ax.get_legend_handles_labels()
        if handles:
            ax.legend()
        ax.set_xlim(*xlim)

        plt.tight_layout()
        plt.show()
        # Release the figure so looping over many groups does not pile up
        # open figures on backends that keep them alive after show().
        plt.close(fig)

In [None]:
# Per-animal/hemisphere DA-vs-RXI figures using the offset-adjusted DA values.
plot_da_vs_rxi(df_adjusted)

In [None]:
# Same figures on the original, unadjusted frame (presumably for side-by-side
# comparison with the adjusted plots above).
plot_da_vs_rxi(master_lastR_df)