In [None]:
import numpy as np
from ethograph import TrialTree
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.pyplot as plt


# Session files whose trials are pooled for the analysis below.
paths = [
    r"D:\Alice\AK_data\derivatives\sub-01_id-Ivy\ses-000_date-20250306_01\behav\Trial_data_all_s3d_temp.nc",
    r"D:\Alice\AK_data\derivatives\sub-01_id-Ivy\ses-000_date-20250309_01\behav\Trial_data_all_s3d_temp.nc",
    r"D:\Alice\AK_data\derivatives\sub-01_id-Ivy\ses-000_date-20250503_02\behav\Trial_data_all_s3d_temp.nc",
    r"D:\Alice\AK_data\derivatives\sub-01_id-Ivy\ses-000_date-20250514_01\behav\Trial_data_all_s3d_temp.nc",
]

# One (s3d, labels) tuple per retained trial, in file order then trial order.
trial_data = []

for path in paths:
    dt = TrialTree.open(path)

    for trial in dt.trials:
        ds = dt.trial(trial)
        s3d = ds.s3d.values
        labels = ds.labels.values.squeeze()

        # Keep a trial only if at least one label is non-zero
        # (presumably 0 marks unlabelled/background samples -- confirm).
        if not np.all(labels == 0):
            trial_data.append((s3d, labels))

def cohens_d_feature_selection(s3d, labels, target_labels=(18, 19)):
    """
    Score how well each feature separates selected labels from all other samples.

    For every selected label the rows of ``s3d`` are split into "carrying that
    label" and "everything else", and Cohen's d is computed per feature column as
    ``|mean_in - mean_out| / sqrt((var_in + var_out) / 2)``.  Variances are
    sample variances (``ddof=1``); a group with a single row contributes zero
    variance.  The two variances are averaged without weighting by group size.

    Parameters
    ----------
    s3d : array-like, shape (n_samples, n_features)
        Feature matrix, one row per sample.
    labels : array-like, shape (n_samples,)
        Label of each sample.
    target_labels : collection or None, optional
        Labels for which Cohen's d is evaluated.  Columns belonging to any
        other label stay at zero.  ``None`` evaluates every label present.
        The default ``(18, 19)`` reproduces the previously hard-coded filter.

    Returns
    -------
    max_cohens_d : ndarray, shape (n_features,)
        Largest Cohen's d of each feature across the label columns.
    cohens_d_matrix : ndarray, shape (n_features, n_unique_labels)
        Cohen's d per feature and label; columns follow ``unique_labels``.
    unique_labels : ndarray
        Sorted unique values found in ``labels``.
    """
    s3d = np.asarray(s3d)
    labels = np.asarray(labels)

    n_features = s3d.shape[1]
    unique_labels = np.unique(labels)
    cohens_d_matrix = np.zeros((n_features, len(unique_labels)))

    # With a single label there is nothing to contrast against.
    if len(unique_labels) < 2:
        return np.zeros(n_features), cohens_d_matrix, unique_labels

    for label_idx, target_label in enumerate(unique_labels):
        if target_labels is not None and target_label not in target_labels:
            continue

        in_label = labels == target_label
        during_label = s3d[in_label]
        not_during_label = s3d[~in_label]
        if len(during_label) == 0 or len(not_during_label) == 0:
            continue

        # Sample variance is undefined for one row; treat it as zero spread.
        var_during = (
            during_label.var(axis=0, ddof=1)
            if len(during_label) > 1
            else np.zeros(n_features)
        )
        var_not = (
            not_during_label.var(axis=0, ddof=1)
            if len(not_during_label) > 1
            else np.zeros(n_features)
        )

        pooled_std = np.sqrt((var_during + var_not) / 2)
        mean_gap = np.abs(during_label.mean(axis=0) - not_during_label.mean(axis=0))

        # Features with zero (or NaN) pooled spread -- including all-zero
        # columns -- keep a score of 0 instead of dividing by zero.
        separable = pooled_std > 0
        cohens_d_matrix[separable, label_idx] = (
            mean_gap[separable] / pooled_std[separable]
        )

    # For each feature, take the maximum Cohen's d across all labels
    max_cohens_d = np.max(cohens_d_matrix, axis=1)

    return max_cohens_d, cohens_d_matrix, unique_labels


def analyze_with_cohens_d(trial_list):
    """
    Run the Cohen's d feature scoring on every trial and average the results.

    Parameters
    ----------
    trial_list : sequence of (s3d, labels) tuples
        One entry per trial; each pair is passed unchanged to
        ``cohens_d_feature_selection``.  All trials are expected to share the
        same number of feature columns.

    Returns
    -------
    mean_max_d : ndarray, shape (n_features,)
        Per-feature maximum Cohen's d, averaged over trials.
    mean_d_matrix : ndarray, shape (n_features, n_labels)
        Per-feature, per-label Cohen's d averaged over trials.  A label that
        does not occur in a trial contributes a zero for that trial, so the
        average runs over all trials, not only those containing the label.
    per_trial_max_d : ndarray, shape (n_trials, n_features)
        Un-averaged per-trial maxima.
    label_names : ndarray
        Sorted union of the labels found across all trials; indexes the last
        axis of ``mean_d_matrix``.
    """
    n_trials = len(trial_list)
    # Take the feature count from the data rather than assuming a fixed width;
    # the historical 1024 is kept only as the fallback for an empty trial list.
    n_features = np.shape(trial_list[0][0])[1] if n_trials else 1024

    # Get all unique labels across all trials
    all_labels = set()
    for _, labels in trial_list:
        all_labels.update(np.unique(labels))
    label_names = np.array(sorted(all_labels))
    n_labels = len(label_names)

    # Store results for each trial
    per_trial_max_d = np.zeros((n_trials, n_features))
    per_trial_d_matrix = np.zeros((n_trials, n_features, n_labels))

    for trial_idx, (s3d, labels) in enumerate(trial_list):
        print(f"Processing trial {trial_idx+1}/{n_trials}...", end='\r')

        max_d, d_matrix, trial_labels = cohens_d_feature_selection(s3d, labels)
        per_trial_max_d[trial_idx] = max_d

        # label_names is sorted and contains every trial label, so a binary
        # search gives each trial column its position on the global label axis.
        global_idx = np.searchsorted(label_names, trial_labels)
        per_trial_d_matrix[trial_idx][:, global_idx] = d_matrix

    # Aggregate across trials
    mean_max_d = np.mean(per_trial_max_d, axis=0)
    mean_d_matrix = np.mean(per_trial_d_matrix, axis=0)

    return mean_max_d, mean_d_matrix, per_trial_max_d, label_names




# Run the analysis over all collected trials.
mean_max_d, mean_d_matrix, per_trial_max_d, label_names = analyze_with_cohens_d(trial_data)

# Rank features by trial-averaged maximum Cohen's d, best first.
top_k = 20
top_indices = np.argsort(mean_max_d)[::-1][:top_k]
# Previously selected feature indices, kept for reference:
# Ivy: [192, 232, 330,  21, 115, 210, 102, 100, 114, 352, 107, 199, 813, 265,   4, 342, 460, 454, 21, 57] # last two have lower Cohen's d but are specifically good for motifs 18, 19
# Freddy: [ 326, 327, 292, 363, 219, 192, 260, 66, 332, 199, 288, 763, 837, 182, 24, 218, 213, 21, 733, 242]

# Heatmap layout: one row per motif label, one column per selected feature.
heatmap_data = mean_d_matrix[top_indices].T

fig, ax = plt.subplots(figsize=(14, 8))
im = ax.imshow(heatmap_data, aspect='auto', cmap='YlOrRd', interpolation='nearest')

# Axis titles.
ax.set_xlabel('S3D feature', fontsize=24)
ax.set_ylabel('Motifs', fontsize=24)

# Columns are tagged with the original feature index.
x_ticks = range(top_k)
ax.set_xticks(x_ticks)
ax.set_xticklabels(
    [top_indices[i] for i in x_ticks],
    rotation=45,
    ha='right',
    fontsize=12,
)

# Rows are tagged with the label value.
ax.set_yticks(range(len(label_names)))
ax.set_yticklabels(label_names, fontsize=14)

# Colour scale.
cbar = fig.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
cbar.set_label("Cohen's d", fontsize=24)
cbar.ax.tick_params(labelsize=12)

plt.show()