In [None]:
from VR_Trajectory_analysis import *

In [None]:
# Run-data folders of the three recording sessions that get loaded.
run_dirs = [
    '/Users/apaula/ownCloud/MatrexVR1/20250217_GreenBlue_Geometry_Data/RunData',
    '/Users/apaula/ownCloud/MatrexVR1/20250121_ants_StartingPosition_Data/RunData',
    '/Users/apaula/ownCloud/MatrexVR1/20250121_ants_GreenBlue_Data/RunData',
]

# Load every folder with the same trim_seconds setting; `directory` ends up
# holding the last folder, exactly as with the sequential assignments.
session_frames = []
for directory in run_dirs:
    session_frames.append(get_combined_df(directory, trim_seconds=1.0))
df1, df2, df3 = session_frames

In [None]:
# Stack the three per-folder frames into one table; ignore_index=True gives
# the result a fresh 0..n-1 row index instead of repeating each frame's own.
df_combined = pd.concat([df1, df2, df3], ignore_index=True)

In [None]:
# Number of distinct FlyID values in the combined data (shown as cell output).
df_combined["FlyID"].nunique()


In [None]:
# Work on an independent copy so df_combined keeps the data as loaded.
df = df_combined.copy()

In [None]:
# Pass df through two helpers that are not defined in this notebook; each
# return value replaces df.
# NOTE(review): judging by their names and by the 'UniqueTrialID' / 'trial_time'
# columns used further down, they presumably add a per-trial ID, a displacement
# measure and a per-trial clock — confirm against the helper definitions.
df = add_trial_id_and_displacement(df)
df = add_trial_time(df)

In [None]:
# Restrict to rows of the 'Choice_noBG' scene, then split them with the
# displacement bounds min_disp=0 and max_disp=500. The helper returns three
# frames followed by three ID collections (stationary / normal / excessive).
choice_scene_rows = df[df['Scene'] == 'Choice_noBG']
(
    df_stationary,
    df_normal,
    df_excessive,
    stationary_ids,
    normal_ids,
    excessive_ids,
) = classify_trials_by_displacement(choice_scene_rows, min_disp=0, max_disp=500)

In [None]:
#plot_trajectories(df_normal, 'normal')

In [None]:
# Keep only the rows recorded with the binary-choice configuration file named
# below; .copy() detaches the selection from df_normal.
df_binary = df_normal[df_normal['ConfigFile'] == 'BinaryChoice11_constantSize_BlackCylinder_BlackCylinder.json'].copy()


In [None]:
def get_first_goal_reached(df_normal,
                           goals,
                           threshold=3.5):
    """
    Find, for every trial, the first goal that was reached and when.

    A goal counts as reached at the first sample (in ascending 'trial_time'
    order) whose (GameObjectPosX, GameObjectPosZ) position lies within
    ``threshold`` of the goal position.

    Parameters
    ----------
    df_normal : pandas.DataFrame
        Must contain the columns 'UniqueTrialID', 'ConfigFile', 'trial_time',
        'GameObjectPosX' and 'GameObjectPosZ'.
    goals : iterable of (name, position)
        ``position`` is indexable; element 0 is compared with GameObjectPosX
        and element 1 with GameObjectPosZ. If several goals are within range
        at the same sample, the one listed first wins. The iterable is
        materialised once, so a generator is safe to pass.
    threshold : float, default 3.5
        Largest Euclidean distance (inclusive) that counts as reaching a goal.

    Returns
    -------
    pandas.DataFrame
        One row per UniqueTrialID with the columns 'UniqueTrialID',
        'ConfigFile', 'FirstReachedGoal' and 'GoalReachedTime'. The last two
        are missing for trials that never came within range of any goal.
    """
    # Materialise the goals once: they are reused for every trial.
    goal_list = list(goals)
    goal_names = [name for name, _ in goal_list]
    goal_x = np.array([pos[0] for _, pos in goal_list], dtype=float)
    goal_z = np.array([pos[1] for _, pos in goal_list], dtype=float)

    results = []

    for trial_id, trial_data in df_normal.groupby('UniqueTrialID'):
        # The config is read from the first row in the frame's original order.
        config = trial_data['ConfigFile'].iloc[0]

        # Stable sort: samples with identical timestamps keep their order.
        trial_data = trial_data.sort_values(by='trial_time', kind='stable')

        pos_x = trial_data['GameObjectPosX'].to_numpy(dtype=float)
        pos_z = trial_data['GameObjectPosZ'].to_numpy(dtype=float)

        # within[i, j] is True when sample i is inside the radius of goal j.
        # NaN positions compare False and therefore never count as a hit.
        dist = np.sqrt((pos_x[:, None] - goal_x[None, :])**2
                       + (pos_z[:, None] - goal_z[None, :])**2)
        within = dist <= threshold

        first_reached = None
        reached_time = None

        hit_samples = np.flatnonzero(within.any(axis=1))
        if hit_samples.size:
            first_sample = hit_samples[0]
            # argmax on a boolean row returns the first True -> first listed goal.
            first_reached = goal_names[int(np.argmax(within[first_sample]))]
            reached_time = trial_data['trial_time'].iloc[first_sample]

        results.append((trial_id, config, first_reached, reached_time))

    # Convert to DataFrame
    results_df = pd.DataFrame(results, columns=[
        'UniqueTrialID',
        'ConfigFile',
        'FirstReachedGoal',
        'GoalReachedTime'
    ])

    return results_df

In [None]:
# 3. Run the first-goal-reached detection for the binary (left/right) choice.
#    Each goal is an (x, z) pair; get_first_goal_reached compares element 0
#    with GameObjectPosX and element 1 with GameObjectPosZ.
left_goal = (-10.416, 59.088)
right_goal = (10.416, 59.088)


# A goal counts as reached once the position is within `threshold` of it.
results_binary = get_first_goal_reached(
    df_binary,
    goals=[
        ('left', left_goal),
        ('right', right_goal)
    ],
    threshold=3.5
)


In [None]:
# 1. Discard trials for which no goal was reached (FirstReachedGoal missing).
reached_any_goal = results_binary['FirstReachedGoal'].notna()
valid_results = results_binary[reached_any_goal]

# 2. Attach each trial's goal time and goal name to its samples. The inner
#    join on 'UniqueTrialID' also drops every trial absent from valid_results.
goal_columns = valid_results[['UniqueTrialID', 'GoalReachedTime', 'FirstReachedGoal']]
df_merged = df.merge(goal_columns, on='UniqueTrialID', how='inner')

# 3. Keep only the samples recorded up to and including the goal time.
up_to_goal = df_merged['trial_time'].le(df_merged['GoalReachedTime'])
df_cut = df_merged[up_to_goal]

In [None]:
# 4. Draw the goal-truncated trajectories and annotate the left/right split
fig, ax = plt.subplots(figsize=(8, 6))

# One translucent line per trial
for trial_id, trial_data in df_cut.groupby('UniqueTrialID'):
    xs = trial_data['GameObjectPosX']
    zs = trial_data['GameObjectPosZ']
    ax.plot(xs, zs, alpha=0.3)

# Axis limits, aspect, labels and grid
ax.set_ylim(20, 60)
ax.set_xlim(-20, 20)
ax.set_aspect('equal', adjustable='box')
ax.set_title("Trajectories: BinaryChoice11_constantSize_BlackCylinder_BlackCylinder.json")
ax.set_xlabel("X Position (cm)")
ax.set_ylabel("Z Position (cm)")
ax.grid(True)

# Tally the first-reached goal over the valid trials (two goals: left / right)
first_goals = valid_results['FirstReachedGoal']
left_count = first_goals.eq('left').sum()
right_count = first_goals.eq('right').sum()

total = left_count + right_count
if total <= 0:
    ratio_text = "No goals reached"
else:
    left_ratio = left_count / total
    right_ratio = right_count / total
    ratio_text = (f"Left: {left_count} ({left_ratio:.2f}), "
                  f"Right: {right_count} ({right_ratio:.2f})")

# Put the counts and ratios in a rounded box at the top-left of the axes
ax.text(
    0.05, 0.95, ratio_text,
    transform=ax.transAxes,
    va='top',
    bbox={'boxstyle': "round", 'fc': "w", 'ec': "0.5"},
)

plt.tight_layout()
plt.show()

In [None]:
# Rebind df_merged to a copy of the goal-truncated rows (df_cut); from here on
# the name no longer refers to the untruncated merge result.
df_merged = df_cut.copy()

In [None]:
# Show which columns are available in the truncated frame.
print(df_merged.columns)

In [None]:
# Go through each FlyID, plotting all trials for that FlyID
for fly_id, df_fly in df_merged.groupby('FlyID'):

    fig, ax = plt.subplots(figsize=(8, 6))

    # Identify the recording days of this fly for colour mapping:
    # 'Current Time' is reduced to its calendar date.
    unique_days = df_fly['Current Time'].dt.date.unique()

    # Resample "tab10" to one discrete colour per day. A fixed size of 5 made
    # every day after the fifth map to the same (last) colour.
    # tab10 has ten base colours, so more than ten days will still repeat some.
    cmap = plt.colormaps.get_cmap('tab10').resampled(max(len(unique_days), 1))

    day2color = {day: cmap(i) for i, day in enumerate(unique_days)}

    # Plot each trial's trajectory in the colour of that trial's date
    for trial_id, trial_data in df_fly.groupby('UniqueTrialID'):
        # The first row's date is taken to represent the whole trial
        day_for_trial = trial_data['Current Time'].dt.date.iloc[0]
        color = day2color[day_for_trial]

        ax.plot(
            trial_data['GameObjectPosX'],
            trial_data['GameObjectPosZ'],
            alpha=0.8,
            color=color
        )

    # Count how many trials reached each goal (left/right);
    # drop_duplicates keeps one row per trial so each trial is counted once
    trial_goals = df_fly.drop_duplicates(subset='UniqueTrialID')[['UniqueTrialID', 'FirstReachedGoal']]
    left_count  = (trial_goals['FirstReachedGoal'] == 'left').sum()
    right_count = (trial_goals['FirstReachedGoal'] == 'right').sum()

    total = left_count + right_count
    if total > 0:
        left_ratio = left_count / total
        right_ratio = right_count / total
        ratio_text = (
            f"Left: {left_count} ({left_ratio:.2f}), "
            f"Right: {right_count} ({right_ratio:.2f})"
        )
    else:
        ratio_text = "No goals reached"

    # Basic axes settings
    ax.set_xlim(-20, 20)
    ax.set_ylim(20, 60)
    ax.set_aspect('equal', adjustable='box')
    ax.set_title(f"Trajectories for FlyID = {fly_id}")
    ax.set_xlabel("X Position (cm)")
    ax.set_ylabel("Z Position (cm)")
    ax.grid(True)

    # Add text box for ratio info
    ax.text(
        0.05, 0.95, ratio_text,
        transform=ax.transAxes,
        va='top',
        bbox=dict(boxstyle="round", fc="w", ec="0.5")
    )

    # Optional per-day legend, left disabled. It was previously kept as a bare
    # string literal after the loop, which the notebook echoed as cell output.
    # Enabling it requires `from matplotlib.lines import Line2D`.
    # handles = [
    #     Line2D([], [], color=day2color[day], label=str(day))
    #     for day in unique_days
    # ]
    # ax.legend(handles=handles, title="Day", loc="best")

    plt.tight_layout()
    plt.show()



In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Input: df_cut, the goal-truncated trajectory samples.

# 1-3. Build the mirrored twin in one step: flip the sign of X and append
#      "_mirror" to the (stringified) trial ID so that mirrored trials form
#      their own groups. assign() works on a copy, so df_cut is left untouched.
df_mirrored = df_cut.assign(
    GameObjectPosX=-df_cut['GameObjectPosX'],
    UniqueTrialID=df_cut['UniqueTrialID'].astype(str) + "_mirror",
)

# 4. Originals first, mirrored copies after, under a fresh row index.
df_duplicated = pd.concat([df_cut, df_mirrored], ignore_index=True)

# 5. Plot both sets together, one faint line per (original or mirrored) trial.
fig, ax = plt.subplots(figsize=(8,6))

for trial_id, trial_data in df_duplicated.groupby('UniqueTrialID'):
    xs = trial_data['GameObjectPosX']
    zs = trial_data['GameObjectPosZ']
    ax.plot(xs, zs, alpha=0.2)

ax.set_ylim(20, 60)
ax.set_xlim(-20, 20)
ax.set_aspect('equal', adjustable='box')
ax.set_title("Original + Mirrored Trajectories")
ax.set_xlabel("X Position (cm)")
ax.set_ylabel("Z Position (cm)")
ax.grid(True)

plt.tight_layout()
plt.show()


In [None]:
import matplotlib.pyplot as plt

# Input: df_duplicated, holding (X, Z) samples of both original and mirrored trajectories

fig, ax = plt.subplots(figsize=(6, 5))

# Plot a 2D histogram of all samples.
# 'bins' can be a single integer or a tuple/list specifying bins for X and Z
# 'range' sets the min/max for X and Z; samples outside it are not counted.
h = ax.hist2d(
    df_duplicated['GameObjectPosX'], 
    df_duplicated['GameObjectPosZ'], 
    bins=(32, 32),                 # 32 bins along each dimension (adjust as needed)
    range=[[-20, 20], [20, 60]],   # X range: -20 to 20, Z range: 20 to 60
    cmap='viridis'
)

# Add a colorbar to show counts; h[3] is the image object returned by hist2d
fig.colorbar(h[3], ax=ax, label='Count')

ax.set_xlabel("X Position (cm)")
ax.set_ylabel("Z Position (cm)")
ax.set_title("2D Histogram of (X, Z) Positions (Original + Mirrored)")
plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# --------------------------------------------------------------------
# 1) Input data
# 'df_duplicated' holds both original and mirrored samples,
# with columns 'GameObjectPosX' (X) and 'GameObjectPosZ' (Z).

x_data = df_duplicated['GameObjectPosX'].values
z_data = df_duplicated['GameObjectPosZ'].values

# --------------------------------------------------------------------
# 2) Build a 2D histogram with np.histogram2d(x, z):
#       H[i, j] = count of points in x_bin_i, z_bin_j
#    so H.shape == (n_xbins, n_zbins): first axis X-bins, second axis Z-bins.

n_xbins = 64
n_zbins = 64

x_range = (-20, 20)
z_range = (20, 60)

H, xedges, zedges = np.histogram2d(
    x_data,
    z_data,
    bins=(n_xbins, n_zbins),
    range=[x_range, z_range]
)

# --------------------------------------------------------------------
# 3) Transpose so that rows are Z-bins and columns are X-bins:
#      H_plot[z_bin, x_bin], shape (n_zbins, n_xbins)
#    which makes Z the vertical axis in imshow.
#    .copy() matters: H.T alone is a view, and the in-place scaling below
#    would silently overwrite the raw counts in H.
H_plot = H.T.copy()

# --------------------------------------------------------------------
# 4) Per-row (per-Z-bin) min-max scaling, done for all rows at once.
#    Rows whose values are all equal (e.g. empty rows) have no spread and are
#    left as they are, which also avoids a division by zero.
row_min = H_plot.min(axis=1, keepdims=True)
row_span = H_plot.max(axis=1, keepdims=True) - row_min
has_spread = row_span[:, 0] > 0
H_plot[has_spread] = (H_plot[has_spread] - row_min[has_spread]) / row_span[has_spread]

# Every row that had a spread now covers exactly [0, 1].

# --------------------------------------------------------------------
# 5) Plot using imshow.
#    origin='lower' puts the first row (smallest Z) at the bottom;
#    'extent' makes the axes show the actual X and Z ranges.

fig, ax = plt.subplots(figsize=(8,6))

# imshow expects array shape (nrows, ncols), here (n_zbins, n_xbins)
img = ax.imshow(
    H_plot,
    origin='lower',
    aspect='auto',
    cmap='viridis',
    extent=[xedges[0], xedges[-1], zedges[0], zedges[-1]]
)

ax.set_xlabel("X Position (cm)")
ax.set_ylabel("Z Position (cm)")
ax.set_title("2D Per-Z-bin Normalized Heat Map")

# The colorbar shows the normalized range, with 1.0 meaning
# "highest count in that specific row (Z-bin)."
cbar = fig.colorbar(img, ax=ax, label="Normalized count per Z-bin")

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter

# --------------------------------------------------------------------
# 1) Input data
# 'df_duplicated' holds both original and mirrored samples,
# with columns 'GameObjectPosX' (X) and 'GameObjectPosZ' (Z).

x_data = df_duplicated['GameObjectPosX'].values
z_data = df_duplicated['GameObjectPosZ'].values

# --------------------------------------------------------------------
# 2) Build a 2D histogram; H has shape (n_xbins, n_zbins)
n_xbins = 64
n_zbins = 64

x_range = (-20, 20)
z_range = (20, 60)

H, xedges, zedges = np.histogram2d(
    x_data,
    z_data,
    bins=(n_xbins, n_zbins),
    range=[x_range, z_range]
)

# --------------------------------------------------------------------
# 3) Transpose for plotting so that Z is the vertical axis.
#    .copy() keeps the raw counts in H intact: H.T alone is a view and the
#    in-place scaling below would write through to H.
H_plot = H.T.copy()  # shape: (n_zbins, n_xbins)

# --------------------------------------------------------------------
# 4) Per-row (per-Z-bin) min-max scaling for all rows at once.
#    Rows without any spread (e.g. empty rows) are left unchanged.
row_min = H_plot.min(axis=1, keepdims=True)
row_span = H_plot.max(axis=1, keepdims=True) - row_min
has_spread = row_span[:, 0] > 0
H_plot[has_spread] = (H_plot[has_spread] - row_min[has_spread]) / row_span[has_spread]

# --------------------------------------------------------------------
# 5) Apply a Gaussian blur
#    'sigma' controls how blurry the result is (try e.g. 0.5, 1, 2, ...).
#    mode='nearest' extends the array at its borders by repeating edge values.
H_blurred = gaussian_filter(H_plot, sigma=1.0, mode='nearest')

# --------------------------------------------------------------------
# 6) Plot using imshow
fig, ax = plt.subplots(figsize=(8,6))

img = ax.imshow(
    H_blurred,      # <-- Display the blurred data
    origin='lower',
    aspect='auto',
    cmap='viridis',
    extent=[xedges[0], xedges[-1], zedges[0], zedges[-1]]
)

ax.set_xlabel("X Position (cm)")
ax.set_ylabel("Z Position (cm)")
ax.set_title("2D Per-Z-bin Normalized Heat Map (Blurred)")

cbar = fig.colorbar(img, ax=ax, label="Normalized count per Z-bin (blurred)")

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter

# --- Inputs from the earlier cells ---
# xedges, zedges, H_blurred (blurred, row-normalized histogram; rows = Z-bins)

# --- Step 1: Compute mode (X-peak) for each Z-bin ---
z_centers = 0.5 * (zedges[:-1] + zedges[1:])
x_centers = 0.5 * (xedges[:-1] + xedges[1:])

# Only the X >= 0 half is searched; the mask is the same for every row,
# so it is built once outside the loop.
positive_mask = x_centers >= 0
x_positive = x_centers[positive_mask]

x_peakline = []

for z_idx in range(H_blurred.shape[0]):
    row_positive = H_blurred[z_idx, positive_mask]

    if row_positive.sum() == 0:
        # No density at all on this side: no peak for this Z-bin
        x_peakline.append(np.nan)
    else:
        x_peakline.append(x_positive[np.argmax(row_positive)])

x_peakline = np.array(x_peakline)

# --- Step 2: Fit piecewise linear model with variable baseline a ---
# Model: x = a for z < z_c, then a straight line from (a, z_c) towards
# the target point (x_target, z_target).
# Filter to z <= 57
z_mask = z_centers <= 57
z_fit = z_centers[z_mask]
x_fit = x_peakline[z_mask]

# Z-bins without a peak (NaN) are excluded from the error
valid = ~np.isnan(x_fit)
if not valid.any():
    # Without this guard best_zc would stay None and Step 3 would fail
    # with an unrelated-looking TypeError.
    raise ValueError("No valid peak positions available for the bifurcation fit.")

# Candidate search space
z_range_fit = (45, 55)
a_range = (-1.0, 3.0)

z_candidates = np.linspace(*z_range_fit, num=100)
a_candidates = np.linspace(*a_range, num=100)

z_target = 59.088
x_target = 10.416

best_zc = None
best_a = None
best_model = None
best_error = np.inf

# Exhaustive grid search minimising the mean squared error
for z_c in z_candidates:
    above_split = z_fit >= z_c
    for a in a_candidates:
        slope = (x_target - a) / (z_target - z_c)
        model = np.where(above_split, a + slope * (z_fit - z_c), a)

        error = np.mean((x_fit[valid] - model[valid])**2)

        if error < best_error:
            best_error = error
            best_zc = z_c
            best_a = a
            best_model = model

# --- Step 3: Rebuild model across full z_centers ---
slope = (x_target - best_a) / (z_target - best_zc)
full_model = np.where(
    z_centers >= best_zc,
    best_a + slope * (z_centers - best_zc),
    best_a,
)

# --- Step 4: Plot on original heatmap ---
fig, ax = plt.subplots(figsize=(8,6))

img = ax.imshow(
    H_blurred,
    origin='lower',
    aspect='auto',
    cmap='viridis',
    extent=[xedges[0], xedges[-1], zedges[0], zedges[-1]]
)

# Overlay the fitted bifurcation line
ax.plot(full_model, z_centers, color='red', linewidth=2,
        label=f"Fitted Bifurcation\nzc={best_zc:.2f}, a={best_a:.2f}")

ax.set_xlabel("X Position (cm)")
ax.set_ylabel("Z Position (cm)")
ax.set_title("Heatmap with Fitted Bifurcation Line (Mode-based)")
ax.legend()

cbar = fig.colorbar(img, ax=ax, label="Normalized count per Z-bin (blurred)")

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.ndimage import gaussian_filter

# --- Inputs from the earlier cells: xedges, zedges, H_blurred ---

# Step 1: Compute mode-based X peak line (X >= 0 half only)
z_centers = 0.5 * (zedges[:-1] + zedges[1:])
x_centers = 0.5 * (xedges[:-1] + xedges[1:])

# The mask does not depend on the row, so it is built once.
positive_mask = x_centers >= 0
x_positive = x_centers[positive_mask]

x_peakline = []

for z_idx in range(H_blurred.shape[0]):
    row_positive = H_blurred[z_idx, positive_mask]

    if row_positive.sum() == 0:
        # No density on this side: no peak for this Z-bin
        x_peakline.append(np.nan)
    else:
        x_peakline.append(x_positive[np.argmax(row_positive)])

x_peakline = np.array(x_peakline)

# Step 2: Fit z_c (with a = 0 fixed)
# Model: x = 0 for z < z_c, then a straight line from (0, z_c) towards
# the target point (x_target, z_target).
z_mask = z_centers <= 57
z_fit = z_centers[z_mask]
x_fit = x_peakline[z_mask]

# Z-bins without a peak (NaN) are excluded from the error
valid = ~np.isnan(x_fit)
if not valid.any():
    # Without this guard best_zc would stay None and Step 3 would fail
    # with an unrelated-looking TypeError.
    raise ValueError("No valid peak positions available for the bifurcation fit.")

z_range_fit = (45, 55)
z_candidates = np.linspace(*z_range_fit, num=100)

z_target = 59.088
x_target = 10.416

best_zc = None
best_model = None
best_error = np.inf

# One-dimensional grid search minimising the mean squared error
for z_c in z_candidates:
    slope = x_target / (z_target - z_c)
    model = np.where(z_fit >= z_c, slope * (z_fit - z_c), 0.0)

    error = np.mean((x_fit[valid] - model[valid])**2)

    if error < best_error:
        best_error = error
        best_zc = z_c
        best_model = model

# Step 3: Rebuild full model over every Z-bin centre
slope = x_target / (z_target - best_zc)
full_model = np.where(z_centers >= best_zc, slope * (z_centers - best_zc), 0.0)

# slope is dX/dZ, so arctan(slope) is the angle between one arm and the Z axis;
# the reported bifurcation angle is twice that value.
angle_rad = np.arctan(slope)
angle_deg = np.degrees(angle_rad)
bifurcation_angle = 2 * angle_deg

# Plot with bifurcation angle in legend
fig, ax = plt.subplots(figsize=(8,6))

img = ax.imshow(
    H_blurred,
    origin='lower',
    aspect='auto',
    cmap='viridis',
    extent=[xedges[0], xedges[-1], zedges[0], zedges[-1]]
)

ax.plot(full_model, z_centers, color='red', linewidth=2,
        label=f"Bifurcation Fit\nzc={best_zc:.2f}, angle={bifurcation_angle:.2f}°")

ax.set_xlabel("X Position (cm)")
ax.set_ylabel("Z Position (cm)")
ax.set_title("Heatmap with Fitted Bifurcation Line (a = 0)")
ax.legend()

fig.colorbar(img, ax=ax, label="Normalized count per Z-bin (blurred)")
plt.tight_layout()
plt.show()



In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.patches import Arc

# --- Bifurcation geometry from the fitted split point best_zc ---
# slope is dX/dZ of one arm; arctan(slope) is that arm's angle to the Z axis,
# and the reported bifurcation angle is twice that value.
slope = x_target / (z_target - best_zc)
angle_rad = np.arctan(slope)
angle_deg = np.degrees(angle_rad)
bifurcation_angle = 2 * angle_deg

# --- Both arms over all Z-bin centres: 0 below the split, +/- slope above ---
beyond_split = z_centers >= best_zc
x_right = np.where(beyond_split, slope * (z_centers - best_zc), 0.0)
x_left = np.where(beyond_split, -slope * (z_centers - best_zc), 0.0)

# --- Heatmap with the two arms on top ---
fig, ax = plt.subplots(figsize=(8, 6))

heatmap_extent = [xedges[0], xedges[-1], zedges[0], zedges[-1]]
img = ax.imshow(
    H_blurred,
    origin='lower',
    aspect='auto',
    cmap='viridis',
    extent=heatmap_extent,
)

# Only the right arm carries the legend entry
arms_label = f"Bifurcation Arms\nzc={best_zc:.2f}, angle={bifurcation_angle:.2f}°"
ax.plot(x_right, z_centers, color='red', linewidth=2, label=arms_label)
ax.plot(x_left, z_centers, color='red', linewidth=2)

# --- Arc spanning the opening between the arms, centred on the split point ---
arc_radius = 5  # visual size
arc_center = (0, best_zc)
arc_diameter = 2 * arc_radius
arc = Arc(
    arc_center,
    width=arc_diameter,
    height=arc_diameter,
    angle=0,
    theta1=90 - angle_deg,
    theta2=90 + angle_deg,
    color='white',
    lw=2,
)
ax.add_patch(arc)

# Angle value printed just above the arc
ax.text(
    0, best_zc + arc_radius + 0.5,
    f"{bifurcation_angle:.1f}°",
    color='white',
    ha='center',
    va='bottom',
    fontsize=10,
    fontweight='bold',
)

# --- Labels, legend, colorbar ---
ax.set_xlabel("X Position (cm)")
ax.set_ylabel("Z Position (cm)")
ax.set_title("Bifurcation Heatmap with Fitted Arms and Angle")
ax.legend()

fig.colorbar(img, ax=ax, label="Normalized count per Z-bin (blurred)")
plt.tight_layout()
plt.show()


In [None]:
import numpy as np

# Given values: target point (x_target, z_target) and split height z_c.
# NOTE(review): z_c is a hard-coded literal, not read from best_zc; it looks
# like a rounded fit result — confirm it is still current after re-fitting.
# This cell also rebinds the notebook-level names x_target, z_target and z_c.
x_target = 10.416
z_target = 59.088
z_c = 49.44

# Offsets from the split point (0, z_c) to the target
dx = x_target
dz = z_target - z_c

# Compute angle between vertical and branch:
# cos(theta) = dz / length of the (dx, dz) vector
theta_rad = np.arccos(dz / np.sqrt(dx**2 + dz**2))
theta_deg = np.degrees(theta_rad)

print(f"Angle between pre- and post-bifurcation: {theta_deg:.2f} degrees")


In [None]:
# --- Filter to z <= 56 ---
z_mask = z_centers <= 56
z_fit = z_centers[z_mask]
# `x_centerline` is not defined anywhere in this notebook, so this cell failed
# with a NameError on a fresh kernel. The mode-based peak line from the
# previous cells is the per-Z-bin X estimate that is available here.
# NOTE(review): if a different centre-line estimate was intended, define it
# explicitly before this cell.
x_fit = x_peakline[z_mask]

# Z-bins without an estimate (NaN) are excluded from the error
valid = ~np.isnan(x_fit)
if not valid.any():
    # Without this guard best_zc / best_a would stay None and the plot below
    # would fail when formatting them.
    raise ValueError("No valid X positions available for the piecewise fit.")

# --- Grid of candidate values ---
z_range_fit = (45, 55)
a_range = (-0.0, 5.0)  # Baseline offsets to try, from 0 to 5

z_candidates = np.linspace(*z_range_fit, num=100)
a_candidates = np.linspace(*a_range, num=100)

z_target = 59.088
x_target = 10.416

best_zc = None
best_a = None
best_model = None
best_error = np.inf

# Model: x = a for z < z_c, then a straight line from (a, z_c) towards
# (x_target, z_target). Exhaustive search minimising the mean squared error.
for z_c in z_candidates:
    above_split = z_fit >= z_c
    for a in a_candidates:
        slope = (x_target - a) / (z_target - z_c)
        model = np.where(above_split, a + slope * (z_fit - z_c), a)

        error = np.mean((x_fit[valid] - model[valid])**2)

        if error < best_error:
            best_error = error
            best_zc = z_c
            best_a = a
            best_model = model

# --- Plot ---
fig, ax = plt.subplots(figsize=(8,6))
ax.plot(z_fit, x_fit, label="Observed X-center", color='blue')
ax.plot(z_fit, best_model, label=f"Best Fit\nzc={best_zc:.2f}, a={best_a:.2f}", color='red', linestyle='--')

ax.set_xlabel("Z Position (cm)")
ax.set_ylabel("X Center (cm)")
ax.set_title("Piecewise Fit with Flexible Baseline")
ax.legend()
plt.grid(True)
plt.tight_layout()
plt.show()


In [None]:
# --- Evaluate the best-fit piecewise model at every Z-bin centre ---
# Below best_zc the model sits at the baseline best_a; from best_zc upwards it
# follows the line from (best_a, best_zc) towards (x_target, z_target).
slope = (x_target - best_a) / (z_target - best_zc)
full_model = np.where(
    z_centers >= best_zc,
    best_a + slope * (z_centers - best_zc),
    best_a,
)

# --- Draw the heatmap and lay the fitted line over it ---
fig, ax = plt.subplots(figsize=(8,6))

heatmap_extent = [xedges[0], xedges[-1], zedges[0], zedges[-1]]
img = ax.imshow(
    H_blurred,
    origin='lower',
    aspect='auto',
    cmap='viridis',
    extent=heatmap_extent,
)

fit_label = f"Fitted Bifurcation\nzc={best_zc:.2f}, a={best_a:.2f}"
ax.plot(full_model, z_centers, color='red', linewidth=2, label=fit_label)

ax.set_xlabel("X Position (cm)")
ax.set_ylabel("Z Position (cm)")
ax.set_title("Heatmap with Fitted Bifurcation Overlay")
ax.legend()

cbar = fig.colorbar(img, ax=ax, label="Normalized count per Z-bin (blurred)")

plt.tight_layout()
plt.show()
