In [None]:
import pandas as pd

import seaborn as sns

import cv2
import matplotlib.pyplot as plt
import numpy as np
import sys
import math
from pathlib import Path

from utils_behavior import Processing

In [None]:
# Load the table of contact events together with their embedding coordinates
# (the columns "UMAP1", "UMAP2" and "trial" are used further down).
Umap = pd.read_feather(
    path="/mnt/upramdya_data/MD/BallPushing_Learning/Datasets/250326_StdContacts_Ctrl_300frames_Data/UMAP/250313_pooled_standardized_contacts_Allfeatures.feather"
)

In [None]:
# Folder that receives every figure and video written by this notebook.
output_dir = Path("/mnt/upramdya_data/MD/BallPushing_Learning/Plots/UMAP/250327_StdContacts_Ctrl_300frames_Data")

# Create it (and any missing parents); do nothing if it already exists.
output_dir.mkdir(exist_ok=True, parents=True)

In [None]:
# Overview of the embedding: every event drawn at (UMAP1, UMAP2).
# NOTE(review): `size` is seaborn's semantic size-mapping argument, not the
# marker size (`s`); confirm that passing the constant 0.01 is intended.
sns.scatterplot(
    x="UMAP1",
    y="UMAP2",
    data=Umap,
    alpha=0.1,
    size=0.01,
)

# Write the current figure into the output folder.
plt.savefig(f"{output_dir}/UMAP_Main.png")

In [None]:
# Same overview as a filled two-dimensional kernel density estimate.
sns.kdeplot(
    x="UMAP1",
    y="UMAP2",
    data=Umap,
    fill=True,
)

# Write the current figure into the output folder.
plt.savefig(f"{output_dir}/UMAP_Main_kde.png")

# By trial

In [None]:
# Inspect the available columns (the "trial" column is used for the subsetting below).

Umap.columns

In [None]:
# Distinct values found in the "trial" column.
Umap["trial"].unique()

In [None]:
# Restrict the events to trial numbers strictly between 0 and 5
# (i.e. trials 1 to 4 when trial numbers are whole numbers).
Umap_subset = Umap[(Umap["trial"] < 5) & (Umap["trial"] > 0)]

In [None]:
# Working subset used by the following cells: events with 0 < trial < 5.
# NOTE(review): this cell was headed "Keep all", yet it applies the same bounds
# as the previous cell; drop the upper bound if every trial should be kept.
#
# `.copy()` turns the filtered slice into an independent DataFrame, so the
# columns added to it later ("cluster", "unique_id", ...) are written to a
# real frame and do not raise pandas' SettingWithCopyWarning.
Umap_subset = Umap[(Umap["trial"] > 0) & (Umap["trial"] < 5)].copy()

In [None]:
# Embedding of the subset, coloured by trial number.
# NOTE(review): `size` is seaborn's semantic size-mapping argument, not the
# marker size (`s`); confirm that passing the constant 0.01 is intended.
sns.scatterplot(
    x="UMAP1",
    y="UMAP2",
    hue="trial",
    palette="viridis",
    data=Umap_subset,
    alpha=0.1,
    size=0.01,
)

In [None]:
# One density panel per trial value (wrapped after three panels per row),
# coloured by trial and sharing both axes.
g = sns.FacetGrid(
    Umap_subset,
    col="trial",
    col_wrap=3,
    hue="trial",
    palette="viridis",
    height=6,
    aspect=1.2,
    sharex=True,
    sharey=True,
)

# Draw a filled 2-D KDE of the embedding in every panel.
g.map(sns.kdeplot, "UMAP1", "UMAP2", alpha=0.5, fill=True)

# Panel titles show the trial value; leave head room for the figure title.
g.set_titles(col_template="{col_name}")
g.add_legend()
plt.subplots_adjust(top=0.85)
g.fig.suptitle("Density Plots of UMAP by trial number")

# Save, display, then release the figure.
plt.savefig(f"{output_dir}/UMAP_trials_density_full.png")
plt.show()
plt.close()

In [None]:
# Number of rows (contact events) per trial value, computed on the full table
# `Umap` rather than on the trial-filtered subset.

Umap["trial"].value_counts()

# Clustering

In [None]:
# Partition the embedded events into 8 groups with k-means on (UMAP1, UMAP2).
# random_state is fixed so the labels are reproducible between runs.

from sklearn.cluster import KMeans

kmeans = KMeans(n_clusters=8, random_state=0).fit(Umap_subset[["UMAP1", "UMAP2"]])

# `assign` returns a new DataFrame instead of writing into Umap_subset, which
# was produced by filtering Umap; this avoids pandas' SettingWithCopyWarning
# and leaves the cell safe to re-run.
Umap_subset = Umap_subset.assign(cluster=kmeans.labels_)

In [None]:
# Embedding coloured by k-means cluster label.
plt.figure(figsize=(8, 6))

sns.scatterplot(
    x="UMAP1",
    y="UMAP2",
    hue="cluster",
    palette="tab20",
    data=Umap_subset,
    alpha=0.1,
)

# Saving of this figure is currently switched off:

#plt.savefig(f"{output_dir}/UMAP_Clusters_full.png")

In [None]:
# Share of each cluster among the events of every trial.

# Step 1: number of events per (trial, cluster) pair.
cluster_proportions = Umap_subset.groupby(['trial', 'cluster']).size().reset_index(name='count')

# Step 2: number of events per trial, used as the denominator.
totals = cluster_proportions.groupby(['trial'])['count'].sum().reset_index(name='total')

# Step 3: attach the denominators and divide.
cluster_proportions = cluster_proportions.merge(totals, on=['trial'])
cluster_proportions['proportion'] = cluster_proportions['count'] / cluster_proportions['total']

plt.figure(figsize=(20, 10))

# One line per cluster, in ascending label order; clusters without rows are skipped.
for cluster in sorted(Umap_subset["cluster"].unique()):
    cluster_data = cluster_proportions.loc[cluster_proportions["cluster"] == cluster]
    if cluster_data.empty:
        continue
    plt.plot(
        cluster_data["trial"],
        cluster_data["proportion"],
        label=f"Cluster {cluster}",
        linewidth=2,
        marker="o",
    )

# Titles, axis labels, grid and a legend placed outside the axes.
plt.title("Cluster Evolution Over trials", fontsize=14)
plt.xlabel("trial", fontsize=12)
plt.ylabel("Proportion of Events", fontsize=12)
plt.grid(True, alpha=0.3)
plt.legend(title="Cluster", loc="upper left", bbox_to_anchor=(1.05, 1))

plt.tight_layout()

# Write the figure to disk, then display it.
plt.savefig(f"{output_dir}/cluster_evolution_line_plots.png")
plt.show()

In [None]:
# Keep only the flies whose events cover at least 4 distinct trial values.
# NOTE(review): the variable name and the earlier wording of this comment
# mention 2 trials, but the threshold applied is 4 - confirm which is intended.
Data_2_trials = Umap_subset.groupby("fly").filter(
    lambda fly_events: fly_events["trial"].nunique() >= 4
)

In [None]:
# Per-cluster share of events in each trial, drawn as one stacked panel per cluster.
clusters = sorted(Data_2_trials["cluster"].unique())

# Events per (cluster, trial) pair ...
counts = Data_2_trials.groupby(["cluster", "trial"]).size().reset_index(name="count")

# ... and events per trial, the denominator of the percentage.
totals = Data_2_trials.groupby(["trial"]).size().reset_index(name="total")

# Attach the per-trial totals and express each count as a percentage.
counts = counts.merge(totals, how='left', on=['trial'])
counts['proportion'] = counts['count'] / counts['total'] * 100

# One row of axes per cluster; both axes are shared between panels.
fig, axes = plt.subplots(
    len(clusters),
    1,
    figsize=(6, 3 * len(clusters)),
    sharex=True,
    sharey=True,
)

# plt.subplots returns a bare Axes for a single panel; wrap it so indexing works.
if len(clusters) == 1:
    axes = np.array([axes])

# Trial values used as x tick positions in every panel.
all_trials = sorted(Data_2_trials["trial"].unique())

for i, cluster in enumerate(clusters):
    ax = axes[i]

    # Rows of this cluster, ordered by trial so the line is drawn left to right.
    cluster_data = counts[counts['cluster'] == cluster].sort_values('trial')

    if not cluster_data.empty:
        # Percentage of the trial's events that belong to this cluster.
        ax.plot(
            cluster_data['trial'],
            cluster_data['proportion'],
            color='blue',
            label=f"Cluster {cluster}",
            linestyle='-',
            linewidth=2.5,
            marker='o',
            markersize=8,
        )

        # Print the raw event count just above each marker.
        for trial, pct, n_events in zip(
            cluster_data['trial'], cluster_data['proportion'], cluster_data['count']
        ):
            ax.annotate(
                f"{int(n_events)}",
                (trial, pct),
                textcoords="offset points",
                xytext=(0, 5),
                ha='center',
                fontsize=8,
            )

    ax.set_title(f"Cluster {cluster}", fontsize=14, fontweight='bold')
    ax.set_ylabel("Percentage (%)", fontsize=12)
    # Only the bottom panel carries the x-axis label.
    if i == len(clusters) - 1:
        ax.set_xlabel("Trial", fontsize=12)

    ax.grid(True, linestyle='--', alpha=0.7)

    # Fixed y range of 0 to 35 %.
    ax.set_ylim(0, 35)

    # One tick per trial value.
    ax.set_xticks(all_trials)

# Tighten the layout, then reserve space at the top for the figure title.
plt.tight_layout()
fig.subplots_adjust(top=0.95)
fig.suptitle('Proportion of Clusters Across Trials (% of Total Events)', fontsize=18)

In [None]:
import numpy as np
import matplotlib.pyplot as plt

# Same per-cluster percentages as the stacked figure, laid out on a 3-column grid.
clusters = sorted(Data_2_trials["cluster"].unique())

# Events per (cluster, trial) pair and per trial.
counts = Data_2_trials.groupby(["cluster", "trial"]).size().reset_index(name="count")
totals = Data_2_trials.groupby(["trial"]).size().reset_index(name="total")

# Attach the per-trial totals and express each count as a percentage.
counts = counts.merge(totals, how="left", on=["trial"])
counts["proportion"] = counts["count"] / counts["total"] * 100

# Grid shape: fixed number of columns, as many rows as needed.
n_cols = 3
n_rows = math.ceil(len(clusters) / n_cols)

fig, axes = plt.subplots(
    n_rows,
    n_cols,
    figsize=(6, 3 * n_rows),
    sharex=True,
    sharey=True,
)

# Work with a flat sequence of axes so panel k is simply axes[k].
axes = axes.ravel()

# Trial values used as x tick positions in every panel.
all_trials = sorted(Data_2_trials["trial"].unique())

for i, cluster in enumerate(clusters):
    ax = axes[i]

    # Rows of this cluster, ordered by trial so the line is drawn left to right.
    cluster_data = counts[counts["cluster"] == cluster].sort_values("trial")

    if not cluster_data.empty:
        # Percentage of the trial's events that belong to this cluster.
        ax.plot(
            cluster_data["trial"],
            cluster_data["proportion"],
            color="blue",
            label=f"Cluster {cluster}",
            linestyle="-",
            linewidth=2.5,
            marker="o",
            markersize=8,
        )

        # Print the raw event count just above each marker.
        for trial, pct, n_events in zip(
            cluster_data["trial"], cluster_data["proportion"], cluster_data["count"]
        ):
            ax.annotate(
                f"{int(n_events)}",
                (trial, pct),
                textcoords="offset points",
                xytext=(0, 5),
                ha="center",
                fontsize=8,
            )

    ax.set_title(f"Cluster {cluster}", fontsize=14, fontweight="bold")
    ax.set_ylabel("Percentage (%)", fontsize=12)
    ax.set_xlabel("Trial", fontsize=12)

    ax.grid(True, linestyle="--", alpha=0.7)

    # Fixed y range of 0 to 35 %.
    ax.set_ylim(0, 35)

    # One tick per trial value.
    ax.set_xticks(all_trials)

# Switch off the grid cells that have no cluster to show.
for spare_ax in axes[len(clusters):]:
    spare_ax.axis("off")

# Tighten the layout, then reserve space at the top for the figure title.
plt.tight_layout()
fig.subplots_adjust(top=0.90)
fig.suptitle("Proportion of Clusters Across Trials (% of Total Events)", fontsize=18)

plt.show()

In [None]:
# Load the contact table the video export below reads its "fly", "event_id",
# "frame" and "flypath" columns from.
interaction_data = pd.read_feather(
    path="/mnt/upramdya_data/MD/BallPushing_Learning/Datasets/250326_StdContacts_Ctrl_300frames_Data/standardized_contacts/250326_pooled_standardized_contacts.feather"
)

# Only the last expression of a cell is rendered, so the column list is what
# gets displayed; the head() call is evaluated but not shown.
interaction_data.head()
interaction_data.columns


In [None]:

# Build an event key of the form "<fly>_<event_id>" in both tables so that
# embedded events can be matched to their frame-level rows.
# NOTE(review): only fly and event_id enter the key (no event type); confirm
# that this pair identifies an event uniquely.

# Both parts of the key are taken from Umap_subset itself, so the result does
# not depend on index alignment with the unfiltered Umap table.
Umap_subset["unique_id"] = Umap_subset["fly"].astype(str) + "_" + Umap_subset["event_id"].astype(str)

interaction_data["unique_id"] = interaction_data["fly"].astype(str) + "_" + interaction_data["event_id"].astype(str)


In [None]:

# Looping over all clusters
# Configuration parameters
# NOTE: frames are rotated by 90 degrees before tiling, and the code below uses
# these two constants "swapped": MAX_CELL_HEIGHT is the horizontal extent of a
# grid cell and MAX_CELL_WIDTH its vertical extent (see the
# "swapped dimensions" remarks in process_cluster).
MAX_CELL_WIDTH = 96   # Maximum width for grid cells
MAX_CELL_HEIGHT = 516  # Maximum height for grid cells
# Upper bounds of the output video; the number of grid columns and rows is
# obtained by integer-dividing these by the cell extents.
MAX_OUTPUT_WIDTH = 3840
MAX_OUTPUT_HEIGHT = 2160
FPS = 5  # frame rate passed to cv2.VideoWriter
CODEC = "mp4v"  # FourCC code passed to cv2.VideoWriter_fourcc
OUTPUT_DIR = output_dir  # videos are written next to the figures

def resize_with_padding(frame, target_w, target_h):
    """Scale ``frame`` to fit inside ``target_w`` x ``target_h`` and pad it with black.

    The image is scaled by a single factor (the smaller of the two axis
    ratios), so its aspect ratio is kept. The remaining space is filled with a
    constant (0, 0, 0) border split between the two sides; when the leftover
    is odd, the extra pixel goes to the bottom / right.
    """
    src_h, src_w = frame.shape[:2]

    # Largest uniform scale at which both dimensions still fit the target.
    factor = min(target_w / src_w, target_h / src_h)
    scaled_w, scaled_h = int(src_w * factor), int(src_h * factor)
    scaled = cv2.resize(frame, (scaled_w, scaled_h))

    # Space left over on each axis, split between the two sides.
    extra_h = target_h - scaled_h
    extra_w = target_w - scaled_w
    pad_top = extra_h // 2
    pad_left = extra_w // 2

    return cv2.copyMakeBorder(
        scaled,
        pad_top,
        extra_h - pad_top,
        pad_left,
        extra_w - pad_left,
        cv2.BORDER_CONSTANT,
        value=(0, 0, 0),
    )

def process_cluster(cluster_id, Umap, interaction_data):
    """Write a grid video showing the contact events of one cluster.

    NOTE(review): a function with the same name (taking an extra ``best_disp``
    argument) is defined again further down in this file; once that cell has
    run, it replaces this definition.

    Args:
        cluster_id: value of the "cluster" column selecting the events.
        Umap: DataFrame providing the "cluster" and "unique_id" columns.
        interaction_data: DataFrame providing the "unique_id", "frame" and
            "flypath" columns.

    Raises:
        ValueError: if no frame could be read for any event of the cluster.

    Reads the notebook-level settings MAX_CELL_WIDTH, MAX_CELL_HEIGHT,
    MAX_OUTPUT_WIDTH, MAX_OUTPUT_HEIGHT, FPS, CODEC and OUTPUT_DIR, and writes
    "cluster_<cluster_id>_video.mp4" into OUTPUT_DIR.
    """
    # Frame-level rows of the events that belong to this cluster
    cluster_data = Umap[Umap["cluster"] == cluster_id]
    cluster_interactions = interaction_data[interaction_data["unique_id"].isin(cluster_data["unique_id"])]
    
    # Calculate frame ranges for each unique_id
    frame_ranges = (cluster_interactions
                    .groupby('unique_id')['frame']
                    .agg(frame_start=('min'), frame_end=('max'))
                    .reset_index())

    # Merge with path information
    event_metadata = (cluster_interactions[['unique_id', 'flypath']]
                      .drop_duplicates()
                      .merge(frame_ranges, on='unique_id'))

    # Calculate grid layout based on max output dimensions
    cols = MAX_OUTPUT_WIDTH // MAX_CELL_HEIGHT  # Note the swapped dimensions
    rows = MAX_OUTPUT_HEIGHT // MAX_CELL_WIDTH  # Note the swapped dimensions
    max_events = cols * rows

    # Sample events if needed
    # (fixed random_state, so the same events are drawn on every run)
    if len(event_metadata) > max_events:
        event_metadata = event_metadata.sample(max_events, random_state=42)

    # Initialize frame storage and video metadata
    frames_dict = {}
    max_duration = 0
    valid_events = 0

    # Process videos in optimized groups
    # (events are grouped by flypath so each source video is opened once)
    for flypath, group in event_metadata.groupby('flypath'):
        # First .mp4 in the folder whose stem does not contain "_preprocessed"
        video_files = list(Path(flypath).glob("*.mp4"))
        video_file = next((vf for vf in video_files if "_preprocessed" not in vf.stem), None)
        
        if not video_file:
            print(f"Skipping {flypath} - no suitable MP4 found")
            continue

        cap = cv2.VideoCapture(str(video_file))
        if not cap.isOpened():
            print(f"Couldn't open {video_file}")
            continue

        # Process all events from this video
        for _, event in group.iterrows():
            try:
                start = int(event['frame_start'])
                end = int(event['frame_end'])
                if start > end:
                    print(f"Invalid frames for {event['unique_id']}")
                    continue
                    
                # Read event frames with boundary checks
                cap.set(cv2.CAP_PROP_POS_FRAMES, start)
                frames = []
                for _ in range(end - start + 1):
                    ret, frame = cap.read()
                    # Stop at the first frame that cannot be read
                    if not ret:
                        break
                    # Rotate frame 90° clockwise
                    frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
                    # Convert color space and resize with padding
                    frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                    frame = resize_with_padding(frame, MAX_CELL_HEIGHT, MAX_CELL_WIDTH)  # Note the swapped dimensions
                    frames.append(frame)
                
                # Events for which nothing could be read are left out of the grid
                if frames:
                    frames_dict[event['unique_id']] = frames
                    max_duration = max(max_duration, len(frames))
                    valid_events += 1

            # Best effort: a failing event is reported and the loop continues
            except Exception as e:
                print(f"Error processing {event['unique_id']}: {str(e)}")
        
        cap.release()

    # Early exit if no valid events
    if valid_events == 0:
        raise ValueError(f"No processable events found for cluster {cluster_id}")

    # Pad all clips to max duration with black frames
    for uid in frames_dict:
        frames = frames_dict[uid]
        if len(frames) < max_duration:
            padding = [np.zeros((MAX_CELL_WIDTH, MAX_CELL_HEIGHT, 3), dtype=np.uint8)] * (max_duration - len(frames))  # Note the swapped dimensions
            frames_dict[uid] = frames + padding

    # Final output dimensions
    # (given as (width, height), the order cv2.VideoWriter expects)
    output_size = (cols * MAX_CELL_HEIGHT, rows * MAX_CELL_WIDTH)  # Note the swapped dimensions

    # Initialize video writer
    fourcc = cv2.VideoWriter_fourcc(*CODEC)
    output_path = Path(OUTPUT_DIR) / f"cluster_{cluster_id}_video.mp4"
    out = cv2.VideoWriter(str(output_path), fourcc, FPS, output_size)

    # Generate grid frames
    for frame_idx in range(max_duration):
        # Black canvas indexed as (rows, columns, channels)
        grid = np.zeros((output_size[1], output_size[0], 3), dtype=np.uint8)
        
        for idx, (uid, frames) in enumerate(frames_dict.items()):
            if frame_idx >= len(frames):
                continue
                
            # Cells are filled row by row, left to right
            row = idx // cols
            col = idx % cols
            
            # Calculate position
            x = col * MAX_CELL_HEIGHT  # Note the swapped dimensions
            y = row * MAX_CELL_WIDTH  # Note the swapped dimensions
            
            # Place frame in grid cell
            grid[y:y+MAX_CELL_WIDTH, x:x+MAX_CELL_HEIGHT] = frames[frame_idx]  # Note the swapped dimensions

        # Frames were converted to RGB on read; convert back to BGR for writing
        out.write(cv2.cvtColor(grid, cv2.COLOR_RGB2BGR))

    out.release()
    print(f"Successfully created grid video for cluster {cluster_id} at {output_path}")


In [None]:
def process_cluster(cluster_id, Umap, interaction_data, best_disp=False):
    """Write a grid video showing the contact events assigned to one cluster.

    Every selected event is read from its source video, rotated 90° clockwise,
    letter-boxed into a fixed-size cell and tiled into one output frame. Clips
    shorter than the longest one are padded with black frames.

    Args:
        cluster_id: value of the "cluster" column selecting the events.
        Umap: DataFrame providing the "cluster" and "unique_id" columns;
            "raw_displacement" is also needed when ``best_disp`` is true and
            more events are available than grid cells.
        interaction_data: DataFrame providing the "unique_id", "frame" and
            "flypath" columns.
        best_disp: when more events are available than grid cells, keep the
            events with the largest raw_displacement if true; otherwise draw a
            random sample with a fixed seed.

    Raises:
        ValueError: if no frame could be read for any event of the cluster.
        RuntimeError: if the output video cannot be opened for writing.

    Reads the notebook-level settings MAX_CELL_WIDTH, MAX_CELL_HEIGHT,
    MAX_OUTPUT_WIDTH, MAX_OUTPUT_HEIGHT, FPS, CODEC and OUTPUT_DIR, and calls
    resize_with_padding. The file is written into OUTPUT_DIR as
    "cluster_<id>_video.mp4" or "cluster_<id>_video_best_disp.mp4".
    """
    # Frame-level rows of the events that belong to this cluster.
    cluster_data = Umap[Umap["cluster"] == cluster_id]
    cluster_interactions = interaction_data[interaction_data["unique_id"].isin(cluster_data["unique_id"])]

    # First and last frame of every event.
    frame_ranges = (cluster_interactions
                    .groupby('unique_id')['frame']
                    .agg(frame_start='min', frame_end='max')
                    .reset_index())

    # One row per event: source folder plus frame range.
    event_metadata = (cluster_interactions[['unique_id', 'flypath']]
                      .drop_duplicates()
                      .merge(frame_ranges, on='unique_id'))

    # Grid layout. Frames are rotated before tiling, so MAX_CELL_HEIGHT is the
    # horizontal extent of a cell and MAX_CELL_WIDTH the vertical one.
    cols = MAX_OUTPUT_WIDTH // MAX_CELL_HEIGHT
    rows = MAX_OUTPUT_HEIGHT // MAX_CELL_WIDTH
    max_events = cols * rows

    # Reduce to what fits on the grid.
    if len(event_metadata) > max_events:
        if best_disp:
            # One displacement value per event, taken from this cluster's rows
            # only. Merging the raw table instead would duplicate events whose
            # unique_id occurs more than once and waste grid cells on copies.
            displacement = (cluster_data
                            .groupby('unique_id', as_index=False)['raw_displacement']
                            .max())
            event_metadata = (event_metadata
                              .merge(displacement, on='unique_id', how='left')
                              .sort_values(by='raw_displacement', ascending=False)
                              .head(max_events))
        else:
            # Fixed seed: the same events are drawn on every run.
            event_metadata = event_metadata.sample(max_events, random_state=42)

    frames_dict = {}   # unique_id -> list of prepared RGB frames
    max_duration = 0   # length of the longest clip, in frames

    # Group by source folder so every video is opened only once.
    for flypath, group in event_metadata.groupby('flypath'):
        # First .mp4 in the folder whose stem does not contain "_preprocessed".
        video_files = list(Path(flypath).glob("*.mp4"))
        video_file = next((vf for vf in video_files if "_preprocessed" not in vf.stem), None)

        if not video_file:
            print(f"Skipping {flypath} - no suitable MP4 found")
            continue

        cap = cv2.VideoCapture(str(video_file))
        if not cap.isOpened():
            print(f"Couldn't open {video_file}")
            continue

        try:
            for _, event in group.iterrows():
                try:
                    start = int(event['frame_start'])
                    end = int(event['frame_end'])
                    if start > end:
                        print(f"Invalid frames for {event['unique_id']}")
                        continue

                    # Seek to the first frame, then read the whole range;
                    # stop early at the first frame that cannot be read.
                    cap.set(cv2.CAP_PROP_POS_FRAMES, start)
                    frames = []
                    for _ in range(end - start + 1):
                        ret, frame = cap.read()
                        if not ret:
                            break
                        frame = cv2.rotate(frame, cv2.ROTATE_90_CLOCKWISE)
                        frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                        # Target is (horizontal, vertical) = (MAX_CELL_HEIGHT, MAX_CELL_WIDTH).
                        frame = resize_with_padding(frame, MAX_CELL_HEIGHT, MAX_CELL_WIDTH)
                        frames.append(frame)

                    # Events without any readable frame are left out of the grid.
                    if frames:
                        frames_dict[event['unique_id']] = frames
                        max_duration = max(max_duration, len(frames))

                except Exception as e:
                    # Best effort: report the failing event and carry on.
                    print(f"Error processing {event['unique_id']}: {str(e)}")
        finally:
            # Release the capture even if iterating the group itself fails.
            cap.release()

    if not frames_dict:
        raise ValueError(f"No processable events found for cluster {cluster_id}")

    # Pad every clip with black frames up to the longest clip.
    black_cell = np.zeros((MAX_CELL_WIDTH, MAX_CELL_HEIGHT, 3), dtype=np.uint8)
    for uid, frames in frames_dict.items():
        missing = max_duration - len(frames)
        if missing > 0:
            frames_dict[uid] = frames + [black_cell] * missing

    # (width, height), the order cv2.VideoWriter expects.
    output_size = (cols * MAX_CELL_HEIGHT, rows * MAX_CELL_WIDTH)

    fourcc = cv2.VideoWriter_fourcc(*CODEC)
    if best_disp:
        output_path = Path(OUTPUT_DIR) / f"cluster_{cluster_id}_video_best_disp.mp4"
    else:
        output_path = Path(OUTPUT_DIR) / f"cluster_{cluster_id}_video.mp4"
    out = cv2.VideoWriter(str(output_path), fourcc, FPS, output_size)

    # VideoWriter does not raise when it cannot create the file; check it so
    # that a failure is not reported as a success.
    if not out.isOpened():
        out.release()
        raise RuntimeError(f"Could not open {output_path} for writing")

    try:
        for frame_idx in range(max_duration):
            # Black canvas indexed as (rows, columns, channels).
            grid = np.zeros((output_size[1], output_size[0], 3), dtype=np.uint8)

            # All clips have max_duration frames after padding, so every cell
            # has a frame at frame_idx. Cells are filled row by row.
            for idx, frames in enumerate(frames_dict.values()):
                row, col = divmod(idx, cols)
                x = col * MAX_CELL_HEIGHT
                y = row * MAX_CELL_WIDTH
                grid[y:y+MAX_CELL_WIDTH, x:x+MAX_CELL_HEIGHT] = frames[frame_idx]

            # Frames were converted to RGB on read; convert back for writing.
            out.write(cv2.cvtColor(grid, cv2.COLOR_RGB2BGR))
    finally:
        out.release()

    print(f"Successfully created grid video for cluster {cluster_id} at {output_path}")

In [None]:

# Get unique clusters
unique_clusters = Umap_subset["cluster"].unique()

# Create output directory if it doesn't exist
#Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# Process each cluster
# (best_disp=True: when a cluster has more events than grid cells, the events
# with the largest raw_displacement are kept). A failure in one cluster is
# printed and the remaining clusters are still processed.
for cluster_id in unique_clusters:
    try:
        process_cluster(cluster_id, Umap_subset, interaction_data, best_disp=True)
    except Exception as e:
        print(f"Error processing cluster {cluster_id}: {str(e)}")



## Checking the ball efficiency per cluster

In [None]:
# Rank the clusters by the median raw_displacement of their events, using only
# events with a strictly positive displacement.
Subset_positive = Umap_subset.loc[Umap_subset["raw_displacement"] > 0]

# Median per cluster, largest first.
cluster_ranking = (
    Subset_positive
    .groupby("cluster")["raw_displacement"]
    .median()
    .sort_values(ascending=False)
)

# Cluster labels in ranking order.
cluster_order = cluster_ranking.index

In [None]:
# Log-transformed displacement per event.
# NOTE(review): log1p yields NaN for values below -1, and raw_displacement is
# not restricted to positive values here - confirm this is acceptable.

Umap_subset["log_displacement"] = np.log1p(Umap_subset["raw_displacement"])

# Box plot of the log displacement per cluster, with the boxes drawn in the
# ranking order held by `cluster_order`.

plt.figure(figsize=(12, 6))

# `order` makes seaborn place AND label the boxes in ranking order. Without it
# the boxes follow seaborn's default order, and overwriting the tick labels
# with `cluster_order` afterwards attaches the wrong cluster number to each box.
sns.boxplot(
    data=Umap_subset,
    x="cluster",
    y="log_displacement",
    color="skyblue",
    order=list(cluster_order),
)

plt.title("Efficiency of Interactions by Cluster", fontsize=16)

plt.xlabel("Cluster Number", fontsize=14)

plt.ylabel("Log(1 + Raw Displacement)", fontsize=14)

# Tick labels already come from `order`; only their font size is adjusted.
plt.xticks(fontsize=12)

plt.grid(axis='y', linestyle='--', alpha=0.7)

plt.tight_layout()

In [None]:
# Confidence interval of the mean raw_displacement per cluster, obtained from
# the project helper Processing.draw_bs_ci (its result is unpacked below as a
# (low, high) pair).

clusters_bs_ci = {}

for cluster in Umap_subset["cluster"].unique():
    cluster_data = Umap_subset[Umap_subset["cluster"] == cluster]

    bs_ci = Processing.draw_bs_ci(cluster_data["raw_displacement"], np.mean)
    clusters_bs_ci[cluster] = bs_ci

    print(f"Cluster {cluster}: {bs_ci}")

# Score each cluster by the midpoint of its interval and order the clusters
# from the highest to the lowest score.
cluster_ranking = {cluster: np.mean(ci) for cluster, ci in clusters_bs_ci.items()}
cluster_order = sorted(cluster_ranking, key=cluster_ranking.get, reverse=True)

# Mean displacement per cluster with its interval as asymmetric error bars,
# drawn at x positions that follow the ranking.
plt.figure(figsize=(12, 6))

for position, cluster in enumerate(cluster_order):
    low, high = clusters_bs_ci[cluster]
    mean_value = Umap_subset.loc[Umap_subset["cluster"] == cluster, "raw_displacement"].mean()
    plt.errorbar(
        position,
        mean_value,
        yerr=[[mean_value - low], [high - mean_value]],
        fmt='o',
        color='black',
        capsize=5,
    )

plt.title("Average Efficiency of Interactions by Cluster", fontsize=16)
plt.xlabel("Cluster Number", fontsize=14)
plt.ylabel("Raw Displacement", fontsize=14)
# Label every position with the cluster it shows.
plt.xticks(range(len(cluster_order)), cluster_order, fontsize=12)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()

# Write the figure to disk, then display it.
plt.savefig(f"{output_dir}/average_efficiency_interactions_by_cluster_subset.png")
plt.show()

In [None]:
# Mean, event count and standard deviation of raw_displacement for every trial.
efficiency_over_time = (
    Umap_subset
    .groupby('trial')['raw_displacement']
    .agg(['mean', 'count', 'std'])
    .reset_index()
)

# Half-width of a normal-approximation 95% interval: 1.96 * std / sqrt(n).
efficiency_over_time['ci'] = 1.96 * efficiency_over_time['std'] / np.sqrt(efficiency_over_time['count'])

# Mean per trial with the interval as error bars.
plt.figure(figsize=(12, 6))
plt.errorbar(
    efficiency_over_time['trial'],
    efficiency_over_time['mean'],
    yerr=efficiency_over_time['ci'],
    fmt='o-',
    linewidth=2,
    markersize=8,
    capsize=5,
)

# Whole-number ticks at 1, 2, 3 and 4.
plt.xticks(range(1, 5))

plt.title('Average Efficiency Over Time', fontsize=16)
plt.xlabel('Trial', fontsize=14)
plt.ylabel('Raw Displacement (Efficiency)', fontsize=14)
plt.grid(True, linestyle='--', alpha=0.7)
plt.tight_layout()

# Write the figure to disk, then display it.
plt.savefig(f"{output_dir}/average_efficiency_over_time.png")
plt.show()

In [None]:
# Split the clusters at the median of their ranking scores: scores at or above
# the median count as "efficient", scores below it as "inefficient".
# `cluster_ranking` must be the dict built in the bootstrap cell (its
# .values() / .items() are used here).
median_efficiency = np.median(list(cluster_ranking.values()))

threshold = median_efficiency

efficient_clusters = [c for c, v in cluster_ranking.items() if v >= threshold]
inefficient_clusters = [c for c, v in cluster_ranking.items() if v < threshold]

print(f"Efficient clusters: {efficient_clusters}")
print(f"Inefficient clusters: {inefficient_clusters}")

# For every trial, count the events falling into each of the two groups and
# express the counts as a percentage of that trial's events.
efficiency_dist = []

for time_bin in sorted(Umap_subset['trial'].unique()):
    clusters_in_bin = Umap_subset.loc[Umap_subset['trial'] == time_bin, 'cluster']
    total_count = len(clusters_in_bin)

    # Nothing to report for a trial without events.
    if total_count == 0:
        continue

    efficient_count = int(clusters_in_bin.isin(efficient_clusters).sum())
    inefficient_count = int(clusters_in_bin.isin(inefficient_clusters).sum())

    efficiency_dist.append({
        'trial': time_bin,
        'efficient_prop': efficient_count / total_count * 100,
        'inefficient_prop': inefficient_count / total_count * 100,
        'efficient_count': efficient_count,
        'inefficient_count': inefficient_count,
        'total_count': total_count
    })

efficiency_dist_df = pd.DataFrame(efficiency_dist)

# Two lines: share of events in efficient (green) and inefficient (red) clusters.
plt.figure(figsize=(12, 6))

plt.plot(efficiency_dist_df['trial'], efficiency_dist_df['efficient_prop'],
         'o-', color='green', linewidth=2, markersize=8, label='Efficient Clusters')
plt.plot(efficiency_dist_df['trial'], efficiency_dist_df['inefficient_prop'],
         'o-', color='red', linewidth=2, markersize=8, label='Inefficient Clusters')

# Print the raw counts above the markers, in the colour of their line.
for _, row in efficiency_dist_df.iterrows():
    for group, colour in (("efficient", "green"), ("inefficient", "red")):
        plt.annotate(
            f"{int(row[f'{group}_count'])}",
            (row['trial'], row[f'{group}_prop']),
            textcoords="offset points",
            xytext=(0,10),
            ha='center',
            color=colour,
            fontsize=9,
        )

plt.title('Proportion of Efficient vs. Inefficient Clusters Over Time', fontsize=16)
# NOTE(review): the x axis shows trial values although the label reads "Time Bin".
plt.xlabel('Time Bin', fontsize=14)
plt.ylabel('Percentage (%)', fontsize=14)
plt.legend(fontsize=12)
plt.grid(True, linestyle='--', alpha=0.7)
plt.ylim(0, 100)
plt.tight_layout()

# Write the figure to disk, then display it.
plt.savefig(f"{output_dir}/efficiency_proportion_over_time.png")
plt.show()

# Density based clustering?

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde
from sklearn.cluster import DBSCAN
from sklearn.neighbors import NearestNeighbors

# Coordinates of the embedded events (UMAP components).
x = Umap_subset['UMAP1']
y = Umap_subset['UMAP2']

# Gaussian kernel density estimate of the point cloud (bandwidth factor 0.05).
kde = gaussian_kde(np.vstack([x, y]), bw_method=0.05)

# Evaluate the density on a regular 100 x 100 grid spanning the data range.
xmin, xmax = x.min(), x.max()
ymin, ymax = y.min(), y.max()
xx, yy = np.mgrid[xmin:xmax:100j, ymin:ymax:100j]
positions = np.vstack([xx.ravel(), yy.ravel()])
kde_values = kde(positions).reshape(xx.shape)

# Filled contour plot; the returned contour set is kept because its levels are
# reused by the following cells.
plt.figure(figsize=(10, 8))
contour = plt.contourf(xx, yy, kde_values, levels=20, cmap="Blues")
plt.colorbar(label="Density")
plt.title("KDE Contour Plot")
plt.xlabel("UMAP1")
plt.ylabel("UMAP2")
plt.show()



In [None]:

# Keep the highest contour levels of the KDE plot: the last 15 entries of
# contour.levels (or all of them if there are fewer). Adjust the slice to
# keep more or fewer levels.
top_levels = contour.levels[-15:]
# The lowest retained level serves as the high-density threshold used by the
# following cells.
threshold = top_levels[0]

# Filled contours restricted to the retained levels.
plt.figure(figsize=(10, 8))

plt.contourf(xx, yy, kde_values, levels=top_levels, cmap="Blues")

plt.colorbar(label="Density")

# Report the number of levels actually retained instead of a fixed number.
plt.title(f"KDE Contour Plot (Top {len(top_levels)} Levels)")

plt.xlabel("UMAP1")
plt.ylabel("UMAP2")

plt.show()


In [None]:


# Embedded events with the outlines of all retained density levels on top.
plt.figure(figsize=(10, 6))

sns.scatterplot(x='UMAP1', y='UMAP2', data=Umap_subset, alpha=0.5)

# Red contour lines, one per retained level.
plt.contour(xx, yy, kde_values, colors='r', levels=top_levels)

# Axis labels are left as set by seaborn (the column names).

plt.title('Umap Plot of Behavior Map with Density Contour')

plt.show()



In [None]:

# Same scatter, but outlining only the threshold level (the lowest of the
# retained levels).
plt.figure(figsize=(10, 6))

sns.scatterplot(x='UMAP1', y='UMAP2', data=Umap_subset, alpha=0.5)

# A single red contour line at the threshold density.
plt.contour(xx, yy, kde_values, colors='r', levels=[threshold])

# Axis labels are left as set by seaborn (the column names).

plt.title('Umap of Behavior Map with Density Contour')

plt.show()



In [None]:

# For every pair of consecutive retained levels, sum the density values of the
# grid nodes lying in the band [lower, upper). Note that this is a sum of
# densities, not a geometric area.
contour_areas = []

for lower, upper in zip(top_levels[:-1], top_levels[1:]):
    area = np.sum(kde_values * (kde_values >= lower) * (kde_values < upper))
    contour_areas.append(area)

# Evaluated but not rendered: only a cell's last expression is displayed.
contour_areas

# Grid nodes (not data points) whose density reaches the threshold.
points_in_contour = np.nonzero(kde_values >= threshold)

# Coordinates of those grid nodes.
x_in_contour = xx[points_in_contour]
y_in_contour = yy[points_in_contour]

# Show the selected grid nodes together with the threshold contour.
plt.figure(figsize=(10, 6))

plt.scatter(x_in_contour, y_in_contour, alpha=0.5)

plt.contour(xx, yy, kde_values, colors='r', levels=[threshold])

# No axis labels are set for this figure.

plt.title('Points within the Density Contour')

plt.show()

In [None]:
# Imported under an alias on purpose: binding matplotlib's Path to the bare
# name `Path` would replace pathlib.Path for the whole notebook and break
# process_cluster (which uses Path(...).glob and Path(...) / name) on any
# later call.
from matplotlib.path import Path as MplPath

# Start with every event unassigned (-1).
Umap_subset['cluster_db'] = -1

# Contour lines at the retained density levels (this also draws them).
# NOTE(review): `ContourSet.collections` is deprecated in recent matplotlib
# releases - verify against the installed version before upgrading.
contour_obj = plt.contour(xx, yy, kde_values, levels=top_levels, colors='r')

# Event coordinates as an (n_events, 2) array; built once, outside the loops.
points = np.column_stack((Umap_subset['UMAP1'], Umap_subset['UMAP2']))

# One id per contour level, in the order of `top_levels`. Levels are visited
# in that order, so an event enclosed by several levels ends up with the id of
# the last level that contains it.
for cluster_id, collection in enumerate(contour_obj.collections):
    for path in collection.get_paths():
        # Polygon through the contour's vertices.
        contour_path = MplPath(path.vertices)

        # Boolean mask of the events lying inside this polygon.
        inside = contour_path.contains_points(points)

        Umap_subset.loc[inside, 'cluster_db'] = cluster_id

# Display the updated dataset
print(Umap_subset.head())

In [None]:
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde
from scipy.ndimage import label

# Grid nodes whose density reaches the threshold.
high_density_mask = kde_values >= threshold

# Connected regions of that mask; labels run from 1 to num_features, 0 is background.
labeled_array, num_features = label(high_density_mask)

# Assign every event the label of its nearest grid node. Testing against the
# bounding box of each region instead would also capture events outside the
# high-density region, and overlapping boxes would overwrite each other.
grid_x = xx[:, 0]   # x coordinate of grid row i (xx varies along axis 0)
grid_y = yy[0, :]   # y coordinate of grid column j (yy varies along axis 1)

ix = np.rint(
    (Umap_subset['UMAP1'].to_numpy() - grid_x[0]) / (grid_x[-1] - grid_x[0]) * (len(grid_x) - 1)
)
iy = np.rint(
    (Umap_subset['UMAP2'].to_numpy() - grid_y[0]) / (grid_y[-1] - grid_y[0]) * (len(grid_y) - 1)
)
ix = np.clip(ix, 0, len(grid_x) - 1).astype(int)
iy = np.clip(iy, 0, len(grid_y) - 1).astype(int)

point_labels = labeled_array[ix, iy]

# Events outside every high-density region keep the value -1.
Umap_subset['cluster_db'] = np.where(point_labels > 0, point_labels, -1)

# Plot each distinct region with its own colour on top of the faded density map.
plt.figure(figsize=(12, 10))
plt.contourf(xx, yy, kde_values, levels=20, cmap="Blues", alpha=0.3)

colors = plt.cm.rainbow(np.linspace(0, 1, num_features))
for i in range(1, num_features + 1):
    cluster_points = Umap_subset[Umap_subset['cluster_db'] == i]
    plt.scatter(cluster_points['UMAP1'], cluster_points['UMAP2'],
                color=colors[i - 1], label=f'Cluster {i}', alpha=0.6)

plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.title('Distinct High-Density Areas in UMAP Plot')
plt.xlabel('UMAP1')
plt.ylabel('UMAP2')
plt.tight_layout()
plt.show()

# Number of events assigned to each region.
for i in range(1, num_features + 1):
    cluster_points = Umap_subset[Umap_subset['cluster_db'] == i]
    print(f"Cluster {i}: {len(cluster_points)} points")


In [None]:
# Calculate, for every trial, the share of events that belongs to each cluster
cluster_proportions = (
    Umap_subset
    .groupby(['trial', 'cluster_db'])
    .size()
    .reset_index(name='count')
)

# Total number of events per trial (denominator of the proportions)
totals = (
    cluster_proportions
    .groupby(['trial'])['count']
    .sum()
    .reset_index(name='total')
)

# Merge and calculate proportions (fraction in [0, 1] of the trial's events)
cluster_proportions = cluster_proportions.merge(totals, on=['trial'])
cluster_proportions['proportion'] = cluster_proportions['count'] / cluster_proportions['total']

# Plot
plt.figure(figsize=(20, 10))

# Plot one line per cluster; the x axis is the trial number
for cluster in sorted(Umap_subset["cluster_db"].unique()):
    cluster_data = cluster_proportions[cluster_proportions["cluster_db"] == cluster]
    if not cluster_data.empty:
        plt.plot(
            cluster_data["trial"],
            cluster_data["proportion"],
            marker="o",
            linewidth=2,
            label=f"Cluster {cluster}",
        )

plt.title("Cluster Evolution Over Time", fontsize=14)
# The plotted x variable is the "trial" column, so label the axis accordingly
plt.xlabel("Trial", fontsize=12)
plt.ylabel("Proportion of Events", fontsize=12)
plt.grid(True, alpha=0.3)
plt.legend(title="Cluster", bbox_to_anchor=(1.05, 1), loc="upper left")

plt.tight_layout()

# Save the plot
plt.savefig(f"{output_dir}/cluster_evolution_line_plots_db.png")
plt.show()

In [None]:
# Per-cluster view of the trial proportions: one subplot per "cluster_db" value

# Every cluster label present in the dataset, in ascending order
clusters = sorted(Umap_subset['cluster_db'].unique())

# Number of events for each (cluster, trial) pair
counts = Umap_subset.groupby(['cluster_db', 'trial']).size().reset_index(name='count')

# Number of events in each trial, used as the denominator
totals = Umap_subset.groupby(['trial']).size().reset_index(name='total')

# Attach the per-trial totals and express each count as a percentage of its trial
counts = counts.merge(totals, on=['trial'], how='left')
counts['proportion'] = counts['count'] / counts['total'] * 100

# One row per cluster, with both axes shared so the panels are directly comparable
fig, axes = plt.subplots(
    len(clusters), 1,
    figsize=(12, 3 * len(clusters)),
    sharex=True, sharey=True,
)

# plt.subplots returns a bare Axes for a single row; normalise to a 1-D array
axes = np.atleast_1d(axes)

# Trials used as x tick positions on every panel
all_trials = sorted(Umap_subset['trial'].unique())

for ax, cluster in zip(axes, clusters):

    # Rows of this cluster, ordered by trial so the line is drawn left to right
    cluster_data = counts.loc[counts['cluster_db'] == cluster].sort_values('trial')

    if not cluster_data.empty:
        # Percentage of the trial's events that fall into this cluster
        ax.plot(cluster_data['trial'], cluster_data['proportion'],
                marker='o', markersize=8,
                linestyle='-', linewidth=2.5,
                color='blue', label=f"Cluster {cluster}")

        # Annotate each marker with the raw event count for context
        for trial, count, proportion in zip(cluster_data['trial'],
                                            cluster_data['count'],
                                            cluster_data['proportion']):
            ax.annotate(f"{int(count)}",
                        (trial, proportion),
                        textcoords="offset points",
                        xytext=(0, 5),
                        ha='center',
                        fontsize=8)

    # Panel title and y label; only the bottom panel carries the x label
    ax.set_title(f"Cluster {cluster}", fontsize=14, fontweight='bold')
    ax.set_ylabel("Percentage (%)", fontsize=12)
    if ax is axes[-1]:
        ax.set_xlabel("Trial", fontsize=12)

    # Dashed grid for readability
    ax.grid(True, linestyle='--', alpha=0.7)

    # Fixed 0-30 % range on the y axis
    ax.set_ylim(0, 30)

    # One tick per trial
    ax.set_xticks(all_trials)

# Tighten the layout, then reserve room at the top for the figure title
plt.tight_layout()
fig.subplots_adjust(top=0.95)
fig.suptitle('Proportion of Clusters Across Trials (% of Total Events)', fontsize=18)