In [None]:
# IPython magic tools
%load_ext autoreload
%autoreload 2

import os

# Plotting and data managing libraries
import matplotlib.pyplot as plt
from matplotlib.backends.backend_pdf import PdfPages
import matplotlib.patches as mpatches
import seaborn as sns
import pandas as pd
import numpy as np
# Larger fonts and line widths for every figure in the notebook
sns.set_context('talk')

# NOTE(review): the filters below silence FutureWarning, UserWarning and RuntimeWarning
# for the whole kernel, which also hides pandas deprecations and divide-by-zero
# warnings raised by later cells.
import warnings
pd.options.mode.chained_assignment = None  # Ignore SettingWithCopyWarning
warnings.simplefilter(action='ignore', category=FutureWarning)
warnings.simplefilter("ignore", UserWarning)
warnings.filterwarnings("ignore", category=RuntimeWarning)

# Locations used by the notebook. `data_path` is where the CSV files are read from and
# `results_path` is where figures are saved; `pdf_path` and `base_path` are not
# referenced in the visible part of the notebook.
# NOTE(review): three of these are absolute, machine-specific Windows paths - consider
# making them configurable before sharing the notebook.
pdf_path = r'Z:\scratch\vr-foraging\sessions'
base_path = r'Z:\scratch\vr-foraging\data'
data_path = r'../../../data/'
results_path = r'C:\Users\tiffany.ona\OneDrive - Allen Institute\Documents\VR foraging\experiments\batch 4 - manipulating cost of travelling and global statistics\results'

# Modelling libraries
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import StandardScaler, PolynomialFeatures
from sklearn.model_selection import cross_val_score, GridSearchCV, StratifiedKFold
from sklearn.pipeline import Pipeline
from sklearn.metrics import confusion_matrix, roc_curve, auc
from sklearn.feature_selection import RFE, RFECV

# Statistical tools
from scipy.stats import ttest_1samp

In [None]:
def grid_search(X_mouse_scaled, y_mouse):
    """Pick the logistic-regression regularisation strength ``C`` by grid search.

    Parameters:
    X_mouse_scaled (array-like): Feature matrix passed to ``GridSearchCV.fit``.
    y_mouse (array-like): Target labels.
    Returns:
    The value of ``C`` with the best mean 5-fold cross-validated accuracy.

    The best ``C`` and its cross-validation score are also printed.
    """
    # Candidate values for C, spanning 0.001 to 100
    param_grid = {'C': [0.001, 0.01, 0.1, 1, 10, 100]}

    # The search object gets its own name so it does not shadow this function
    search = GridSearchCV(LogisticRegression(), param_grid, cv=5, scoring='accuracy')
    search.fit(X_mouse_scaled, y_mouse)

    best_C = search.best_params_['C']
    print(f"The best value for C is: {best_C}")
    print(f"The best cross-validation score is: {search.best_score_:.2f}")

    return best_C

In [None]:
def calculate_metrics(metrics_list, y_mouse, y_pred):
    """
    Calculate per-class classification metrics and append them to the provided metrics list.
    Parameters:
    metrics_list (list): A list to which the calculated metrics dictionary will be appended.
    y_mouse (array-like): True labels (0 or 1).
    y_pred (array-like): Predicted labels (0 or 1).
    Returns:
    list: The same list object, with one more metrics dictionary appended.

    The metrics dictionary contains the following keys:
    - "Accuracy 0": Fraction of true class-0 samples predicted as 0 (equal to "Recall 0").
    - "Precision 0": TN / (TN + FN), fraction of predicted 0s that are truly 0.
    - "Recall 0": TN / (TN + FP), fraction of true 0s that are predicted 0.
    - "F1 Score 0": Harmonic mean of "Precision 0" and "Recall 0".
    - "Accuracy 1": Fraction of true class-1 samples predicted as 1 (equal to "Recall 1").
    - "Precision 1": TP / (TP + FP), fraction of predicted 1s that are truly 1.
    - "Recall 1": TP / (TP + FN), fraction of true 1s that are predicted 1.
    - "F1 Score 1": Harmonic mean of "Precision 1" and "Recall 1".
    - "TN": True negatives.
    - "FP": False positives.
    - "FN": False negatives.
    - "TP": True positives.

    Any ratio whose denominator is zero is reported as 0.
    """

    def _ratio(numerator, denominator):
        # Report 0 instead of dividing by zero when a class is absent
        return numerator / denominator if denominator > 0 else 0

    # Rows are true labels, columns are predictions. Passing `labels` keeps the matrix
    # 2x2 even when only one class is present in the inputs.
    cm = confusion_matrix(y_mouse, y_pred, labels=[0, 1])
    TN, FP = cm[0, 0], cm[0, 1]
    FN, TP = cm[1, 0], cm[1, 1]

    # Class 0 (negative class): precision conditions on predictions, recall on true labels
    precision_0 = _ratio(TN, TN + FN)
    recall_0 = _ratio(TN, TN + FP)
    f1_0 = _ratio(2 * (precision_0 * recall_0), precision_0 + recall_0)
    accuracy_0 = recall_0  # Proportion of true `0` samples classified correctly

    # Class 1 (positive class)
    precision_1 = _ratio(TP, TP + FP)
    recall_1 = _ratio(TP, TP + FN)
    f1_1 = _ratio(2 * (precision_1 * recall_1), precision_1 + recall_1)
    accuracy_1 = recall_1  # Proportion of true `1` samples classified correctly

    # Collect the metrics for this fold as a dictionary
    fold_metrics = {
        "Accuracy 0": accuracy_0,
        "Precision 0": precision_0,
        "Recall 0": recall_0,
        "F1 Score 0": f1_0,
        "Accuracy 1": accuracy_1,
        "Precision 1": precision_1,
        "Recall 1": recall_1,
        "F1 Score 1": f1_1,
        "TN": TN,
        "FP": FP,
        "FN": FN,
        "TP": TP
    }

    # Append the metrics dictionary to the list
    metrics_list.append(fold_metrics)
    return metrics_list

In [None]:
def plotting_roc_curve(y_probs, y_mouse, plot=False):
    """Find the probability threshold that maximises Youden's J on the ROC curve.

    Parameters:
    y_probs (array-like): Predicted scores for the positive class.
    y_mouse (array-like): True binary labels.
    plot (bool): When True, draw the ROC curve and mark the selected threshold.
    Returns:
    The threshold at which TPR - FPR is largest.

    The selected threshold and the ROC AUC are printed on every call.
    """
    false_pos_rate, true_pos_rate, cutoffs = roc_curve(y_mouse, y_probs)
    roc_auc = auc(false_pos_rate, true_pos_rate)

    # Youden's J statistic is TPR - FPR; take the cutoff where it peaks
    best_idx = np.argmax(true_pos_rate - false_pos_rate)
    best_threshold = cutoffs[best_idx]

    if plot:
        plt.figure(figsize=(5, 5))
        plt.plot(false_pos_rate, true_pos_rate, color='darkorange', lw=2,
                 label=f'ROC curve (AUC = {roc_auc:.2f})')
        # Diagonal = performance of a random classifier
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.scatter(false_pos_rate[best_idx], true_pos_rate[best_idx], color='red',
                    label=f'Best threshold = {best_threshold:.2f}')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        sns.despine()
        plt.title('ROC Curve')
        plt.legend(loc='lower right')
        plt.show()

    print(f"Best threshold: {best_threshold}")
    print(f"ROC AUC: {roc_auc:.2f}")
    return best_threshold

In [None]:
def logistic_session(summary_df, 
                              use_polynomial_features=True, 
                              orig_features = ['reward_probability', 'consecutive_failures', 'visit_number', 'cumulative_rewards', 'active_patch'],
                              random_state=42):
    """Fit one class-balanced logistic regression per (mouse, session) predicting ``has_choice``.

    Parameters:
    summary_df (DataFrame): Must contain 'mouse', 'session', 'has_choice', 'active_patch'
        and every column named in ``orig_features``.
    use_polynomial_features (bool): When True, add pairwise interaction terms
        (degree 2, interaction only, no bias column).
    orig_features (list): Regressor columns. 'odor_label', if listed, is one-hot encoded.
    random_state: Seed for the shuffled stratified 5-fold split, so cross-validation
        scores are reproducible. Pass None for a different split on every call.
    Returns:
    tuple: (weights_df, cv_results_df, metrics_df, new_mouse_df)
    - weights_df: one row per regressor and session ('regressors', 'weights', 'mouse', 'session');
      'mouse' is converted to str.
    - cv_results_df: mean and std of the cross-validated ROC AUC per session.
    - metrics_df: output of ``calculate_metrics`` on the threshold-adjusted predictions.
    - new_mouse_df: the fitted rows with 'y_pred', 'y_pred_prob', 'y_pred_adjusted'
      and 'norm_active_patch' added.

    Sessions with fewer than 20 rows or a single outcome class are skipped.
    """
    # Per-session results are collected in lists and assembled once at the end
    weight_frames = []
    cv_records = []
    session_frames = []
    metrics_list = []

    for (mouse, session), group_df in summary_df.groupby(['mouse', 'session']):
        print(f"Mouse: {mouse}, Session: {session}")

        # Work on a copy so the prediction columns are never written into summary_df
        mouse_df = group_df.copy()
        y_mouse = mouse_df['has_choice'].astype(int)

        # Too little data, or nothing to discriminate: skip before doing any work
        if len(mouse_df) < 20 or y_mouse.nunique() == 1:
            continue

        # Select features
        X_mouse = mouse_df[list(orig_features)]
        if 'odor_label' in orig_features:
            X_mouse = pd.get_dummies(X_mouse, columns=['odor_label'])

        if use_polynomial_features:
            poly = PolynomialFeatures(degree=2, interaction_only=True, include_bias=False)
            X_design = poly.fit_transform(X_mouse)
            features = poly.get_feature_names_out()
        else:
            X_design = X_mouse
            features = X_mouse.columns

        # Cross-validate with the scaler inside the pipeline so it is fitted on the
        # training folds only and held-out rows do not leak into the standardisation
        cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)
        cv_pipeline = Pipeline([
            ('scaler', StandardScaler()),
            ('log_reg', LogisticRegression(C=1, class_weight='balanced'))
        ])
        cv_scores = cross_val_score(cv_pipeline, X_design, y_mouse, cv=cv, scoring='roc_auc')

        # Final model: standardise and fit on every row of the session
        X_mouse_scaled = StandardScaler().fit_transform(X_design)
        log_reg = LogisticRegression(C=1, class_weight='balanced')
        log_reg.fit(X_mouse_scaled, y_mouse)

        # Class labels at the default 0.5 threshold and probabilities of class 1
        mouse_df['y_pred'] = log_reg.predict(X_mouse_scaled)
        y_probs = log_reg.predict_proba(X_mouse_scaled)[:, 1]
        mouse_df['y_pred_prob'] = y_probs

        # Re-threshold the probabilities at the ROC-optimal cutoff
        best_threshold = plotting_roc_curve(y_probs, y_mouse)
        y_pred_adjusted = (y_probs >= best_threshold).astype(int)
        mouse_df['y_pred_adjusted'] = y_pred_adjusted

        mouse_df['norm_active_patch'] = mouse_df['active_patch'] / mouse_df['active_patch'].max()
        metrics_list = calculate_metrics(metrics_list, y_mouse, y_pred_adjusted)

        weight_frames.append(pd.DataFrame({
            'regressors': list(features),
            'weights': log_reg.coef_[0],
            'mouse': mouse,
            'session': session
        }))
        cv_records.append({'session': session, 'mouse': mouse,
                           'cv_std': cv_scores.std(), 'cv_score': cv_scores.mean()})
        session_frames.append(mouse_df)

    if weight_frames:
        weights_df = pd.concat(weight_frames, ignore_index=True)
    else:
        weights_df = pd.DataFrame(columns=['regressors', 'weights', 'mouse', 'session'])
    # Mouse ids become strings so they plot as categories; this also works for
    # non-numeric ids, which `.round(0)` would not accept
    weights_df['mouse'] = weights_df['mouse'].astype(str)

    cv_results_df = pd.DataFrame(cv_records)
    metrics_df = pd.DataFrame(metrics_list)
    new_mouse_df = pd.concat(session_frames, ignore_index=True) if session_frames else pd.DataFrame()
    return weights_df, cv_results_df, metrics_df, new_mouse_df


In [None]:
def load(filename= 'batch_4.csv', interpatch_name = 'PostPatch'):
    """Read a summary CSV from ``data_path`` and return its filtered 'RewardSite' rows.

    Parameters:
    filename (str): CSV file inside ``data_path``. 'batch_4.csv' selects the batch-4
        experiment names; any other file selects 'base', 'experiment1', 'experiment2'.
    interpatch_name (str): Label of the rows used to compute the inter-patch columns.
    Returns:
    DataFrame: rows labelled 'RewardSite' with 'interpatch_length' and 'interpatch_time'
    merged in, kept when ``active_patch <= 20`` or ``engaged`` is True.
    """
    # The set of experiments to keep depends only on the requested file
    experiment_list = (
        ['data_collection', 'friction', 'control', 'distance_long', 'distance_short', 'friction_low','friction_med', 'friction_high', 'distance_extra_long', 'distance_extra_short']
        if filename == 'batch_4.csv'
        else ['base', 'experiment1', 'experiment2']
    )

    print('Loading')
    csv_path = os.path.join(data_path, filename)
    summary_df = pd.read_csv(csv_path, index_col=0)

    # Drop two mice by id, then keep only the selected experiments
    not_excluded = summary_df['mouse'].ne(754573) & summary_df['mouse'].ne(754572)
    summary_df = summary_df[not_excluded]
    summary_df = summary_df[summary_df['experiment'].isin(experiment_list)]

    # Epoch duration = index value of the next row minus index value of this row
    next_index_value = summary_df.index.to_series().shift(-1)
    summary_df['END'] = next_index_value
    summary_df['START'] = summary_df.index
    summary_df['duration_epoch'] = summary_df['END'].sub(summary_df['START'])

    # 'PostPatch' rows take the active_patch value of the row that follows them
    summary_df['active_real'] = summary_df['active_patch'].shift(-1)
    is_post_patch = summary_df['label'] == 'PostPatch'
    summary_df['active_patch'] = np.where(is_post_patch, summary_df['active_real'], summary_df['active_patch'])

    # Inter-patch length (mean) and time (first) per mouse, session and patch
    merge_keys = ['mouse','session', 'active_patch']
    interpatch_rows = summary_df[summary_df['label'] == interpatch_name]
    df = interpatch_rows.groupby(merge_keys, as_index=False).agg({'length': 'mean', 'duration_epoch': 'first'})
    df = df.rename(columns={'length':'interpatch_length', 'duration_epoch': 'interpatch_time'})
    summary_df = summary_df.merge(df, on=merge_keys, how='left')

    # Keep reward-site rows that are early in the session or flagged as engaged
    summary_df = summary_df[summary_df['label'] == 'RewardSite']
    keep_rows = (summary_df['active_patch'] <= 20) | (summary_df['engaged'] == True)
    return summary_df[keep_rows]

**Load the dataset**

In [None]:
# Load the dataset here so this cell also runs on a fresh kernel; no earlier cell
# defines `summary_df`
summary_df = load()

# Percentage distribution of stops/leaves
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
sns.barplot(data=summary_df, x='mouse', y='has_choice', ax=ax)
plt.xticks(rotation=90)
# No legend: the plot has no hue, so there are no labelled artists to list
sns.despine()

### **Fit each session and mouse independently**

#### **Run model without any interaction**

In [None]:
# Reload the reward-site table and keep only the experiment under study
epoch = 'control'
summary_df = load()
summary_df = summary_df[summary_df['experiment'] == epoch]

In [None]:
# Label shown as the title of the cross-validation score figures
scoring = 'roc_auc'
# Regressors of the per-session logistic regression
features = [
    'reward_probability',
    'consecutive_failures',
    'visit_number',
    'cumulative_rewards',
    'active_patch',
]

In [None]:
# Fit every session with the main-effect regressors only (no interaction terms)
weights_df, cv_results_df, metrics_df, new_mouse_df = logistic_session(
    summary_df,
    use_polynomial_features=False,
    orig_features=features,
)

In [None]:
# Fraction of correct default-threshold predictions along the session, split by outcome
fig, ax = plt.subplots(1, 1, figsize=(6, 3))

# 'score' is 1 where the prediction matches the observed choice; later cells reuse it
is_correct = new_mouse_df['y_pred'] == new_mouse_df['has_choice']
new_mouse_df['score'] = np.where(is_correct, 1, 0)

score_keys = ['mouse', 'norm_active_patch', 'has_choice']
results = new_mouse_df.groupby(score_keys).agg({'score': 'mean'}).reset_index()
results['norm_active_patch'] = results['norm_active_patch'].round(2)

choice_palette = {True: sns.color_palette()[0], False: sns.color_palette()[1]}
sns.lineplot(data=results, x='norm_active_patch', y='score', hue='has_choice',
             palette=choice_palette, legend=False)
sns.despine()
plt.ylim(0, 1)

In [None]:
# Evaluate the metrics of the fit: one panel per metric, 'stop' (class 1) drawn first
# and 'leave' (class 0) second
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
metric_bins = np.arange(0, 1.1, 0.02)
stop_columns = ['Accuracy 1', 'Precision 1', 'Recall 1', 'F1 Score 1']
leave_columns = ['Accuracy 0', 'Precision 0', 'Recall 0', 'F1 Score 0']
for ax, stop_column, leave_column in zip(axes.flatten(), stop_columns, leave_columns):
    sns.histplot(data=metrics_df, x=stop_column, ax=ax, bins=metric_bins, label='stop')
    sns.histplot(data=metrics_df, x=leave_column, ax=ax, bins=metric_bins, label='leave')
# plt.legend acts on the current axes, i.e. the last panel
plt.legend()
sns.despine()
plt.tight_layout()

In [None]:
## Distribution across sessions of the cross-validation score (left) and its std (right)
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
score_ax, std_ax = ax
fig.suptitle(scoring)
sns.histplot(data=cv_results_df, x='cv_score', bins=30, ax=score_ax, legend=False)
sns.histplot(data=cv_results_df, x='cv_std', bins=30, ax=std_ax)
score_ax.set_xlabel('Cross-validation scores')
plt.tight_layout()
sns.despine()

In [None]:
## Compare the predictions with the actual values, one panel per mouse
# Mean 'score' per mouse, session and observed choice
plot_df = new_mouse_df.groupby(['mouse', 'session', 'has_choice'])['score'].mean().reset_index()

fig, axes = plt.subplots(4, 4, figsize=(12, 12))
mouse_ids = plot_df['mouse'].unique()
# zip stops at the shorter sequence, so at most 16 mice are drawn
for ax, mouse in zip(axes.flatten(), mouse_ids):
    mouse_scores = plot_df[plot_df['mouse'] == mouse]
    sns.barplot(data=mouse_scores, y='score', hue='has_choice', ax=ax, legend=False)
    ax.set_ylim(0, 1)
    ax.set_title(f'Mouse {mouse}')

# Manual legend: first and second colors of the default palette, labelled 0 and 1
default_colors = sns.color_palette()
handles = [mpatches.Patch(color=default_colors[class_id], label=class_id) for class_id in (0, 1)]
fig.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

sns.despine()
plt.tight_layout()

In [None]:
## Compare the predictions with the actual values averaged all mice
plot_df = new_mouse_df.groupby(['mouse',  'has_choice']).score.mean().reset_index()

# One palette drives both the bars and the legend so the two cannot disagree
choice_palette = {True: sns.color_palette()[0], False: sns.color_palette()[1]}

fig, ax = plt.subplots(1, 1, figsize=(4, 4))
sns.barplot(data=plot_df, y='score', hue='has_choice', ax=ax, palette=choice_palette, dodge=True, legend=False)
ax.set_ylim(0, 1)
# Manually create the legend from the palette: label 1 gets the color of
# has_choice == True and label 0 the color of has_choice == False
handles = [mpatches.Patch(color=choice_palette[bool(label)], label=label) for label in [0, 1]]
fig.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

sns.despine()
plt.tight_layout()

In [None]:
# Plot the per-session coefficient weights, one panel per mouse
fig, axes = plt.subplots(4, 4, figsize=(20, 14), sharex=True)
regressor_names = weights_df['regressors'].unique()
palette = dict(zip(regressor_names, sns.color_palette('tab10', len(regressor_names))))

# p-value cutoffs, strictest first, with the marker drawn for each
star_levels = ((0.001, '***'), (0.01, '**'), (0.05, '*'))

for (mouse, group), ax in zip(weights_df.groupby('mouse'), axes.flatten()):
    group_regressors = list(group['regressors'].unique())

    # One-sample t-test of each regressor's session weights against 0
    significant_regressors = []
    for regressor in group_regressors:
        regressor_data = group.loc[group['regressors'] == regressor, 'weights']
        t_stat, p_value = ttest_1samp(regressor_data, 0)
        significance = next((stars for cutoff, stars in star_levels if p_value < cutoff), None)
        if significance:
            significant_regressors.append((regressor, regressor_data.max(), significance))

    # One point per session
    sns.swarmplot(data=group, x='regressors', y='weights', palette=palette,
                  ax=ax, hue='regressors', legend=False)
    ax.set_title(f'Mouse {mouse}')
    ax.set_xlabel('')
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
    ax.hlines(0, -0.5, len(group_regressors) - 0.5, color='black', linestyle='--')

    # Stars sit just above the largest weight of each significant regressor
    for regressor, max_value, significance in significant_regressors:
        ax.text(group_regressors.index(regressor), max_value + 0.05, significance,
                ha='center', va='bottom', fontsize=12, color='black')

# Manual legend, one patch per regressor color; `palette` and `handles` are reused
# by the aggregated plot in the next cell
handles = [mpatches.Patch(color=color, label=regressor) for regressor, color in palette.items()]

fig.legend(
    handles=handles,
    bbox_to_anchor=(0.6, 0.05),
    loc='upper center',
    ncol=3,
    title='Features',
    prop={'size': 12}
)

sns.despine()
plt.tight_layout()
plt.subplots_adjust()  # called without arguments
plt.xticks(rotation=45, ha='right')
plt.show()
fig.savefig(os.path.join(results_path, f'weights_per_mouse_small_model_{epoch}.pdf'), bbox_inches='tight')

In [None]:
# Aggregate the weights by mouse and regressor
aggregated_df = weights_df.groupby(['mouse', 'regressors'], as_index=False).weights.mean()

# Perform t-tests on the aggregated data (one value per mouse, tested against 0)
t_test_results = []
for regressor in aggregated_df['regressors'].unique():
    regressor_data = aggregated_df[aggregated_df['regressors'] == regressor]['weights']
    t_stat, p_value = ttest_1samp(regressor_data, 0)
    
    # Determine the significance level
    if p_value < 0.001:
        significance = '***'
    elif p_value < 0.01:
        significance = '**'
    elif p_value < 0.05:
        significance = '*'
    else:
        significance = None

    t_test_results.append({'regressor': regressor, 'p_value': p_value, 'significance': significance})

t_test_results_df = pd.DataFrame(t_test_results)

# Plot. The figure is bound to `fig` so that savefig below writes THIS figure and not
# the per-mouse figure still referenced by `fig` from the previous cell.
fig = plt.figure(figsize=(8, 4))

# One point per mouse; `palette` comes from the per-mouse plot cell
sns.swarmplot(
    data=aggregated_df, 
    x='regressors', 
    y='weights', 
    hue='regressors', 
    palette=palette, 
    dodge=True
)

# Annotate significance levels
for i, row in t_test_results_df.iterrows():
    regressor = row['regressor']
    significance = row['significance']
    if significance:
        x = list(aggregated_df['regressors'].unique()).index(regressor)
        y = aggregated_df[aggregated_df['regressors'] == regressor]['weights'].max() + 0.1
        plt.text(x, y, significance, ha='center', va='bottom', fontsize=12, color='black')

# Add horizontal line at 0
plt.axhline(0, color='black', linestyle='--')

# Customize labels and legend
plt.xlabel('')
plt.xlim(-1, len(aggregated_df['regressors'].unique()))
plt.ylabel('Weight')
plt.xticks([])
plt.title('Weights Per Regressor \n (Aggregated by Mouse)')

# Legend reuses the `handles` built in the per-mouse plot cell
plt.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.tight_layout()
sns.despine()
plt.show()

fig.savefig(os.path.join(results_path, f'weights_all_small_model_{epoch}.pdf'), bbox_inches='tight')

#### **Run model with interactions**

In [None]:
# Refit every session, this time adding the pairwise interaction terms
weights_df, cv_results_df, metrics_df, new_mouse_df = logistic_session(
    summary_df,
    use_polynomial_features=True,
    orig_features=features,
)

In [None]:
# Fit metrics of the interaction model: one panel per metric, 'stop' (class 1) drawn
# first and 'leave' (class 0) second
fig, axes = plt.subplots(2, 2, figsize=(8, 8))
metric_bins = np.arange(0, 1.1, 0.02)
stop_columns = ['Accuracy 1', 'Precision 1', 'Recall 1', 'F1 Score 1']
leave_columns = ['Accuracy 0', 'Precision 0', 'Recall 0', 'F1 Score 0']
for ax, stop_column, leave_column in zip(axes.flatten(), stop_columns, leave_columns):
    sns.histplot(data=metrics_df, x=stop_column, ax=ax, bins=metric_bins, label='stop')
    sns.histplot(data=metrics_df, x=leave_column, ax=ax, bins=metric_bins, label='leave')
# plt.legend acts on the current axes, i.e. the last panel
plt.legend()
sns.despine()
plt.tight_layout()

In [None]:
# Distribution across sessions of the cross-validation score (left) and its std (right)
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
fig.suptitle(scoring)
for column, panel in zip(('cv_score', 'cv_std'), ax):
    sns.histplot(data=cv_results_df, x=column, multiple='stack', bins=30, color='black', ax=panel)
plt.tight_layout()
sns.despine()

In [None]:
# Per-session weights of the interaction model, one panel per mouse
fig, axes = plt.subplots(2, 5, figsize=(26, 8), sharex=True)

# p-value cutoffs, strictest first, with the marker drawn for each
star_levels = ((0.001, '***'), (0.01, '**'), (0.05, '*'))

for (mouse, group), ax in zip(weights_df.groupby('mouse'), axes.flatten()):
    group_regressors = list(group['regressors'].unique())

    # One-sample t-test of each regressor's session weights against 0
    significant_regressors = []
    for regressor in group_regressors:
        regressor_data = group.loc[group['regressors'] == regressor, 'weights']
        t_stat, p_value = ttest_1samp(regressor_data, 0)
        significance = next((stars for cutoff, stars in star_levels if p_value < cutoff), None)
        if significance:
            significant_regressors.append((regressor, regressor_data.max(), significance))

    # One point per session
    sns.swarmplot(data=group, x='regressors', y='weights', palette='tab20',
                  ax=ax, hue='regressors', legend=False)
    ax.set_title(f'Mouse {mouse}')
    ax.set_xlabel('')
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
    ax.hlines(0, -0.5, len(group_regressors) - 0.5, color='black', linestyle='--')

    # Stars sit just above the largest weight of each significant regressor
    for regressor, max_value, significance in significant_regressors:
        ax.text(group_regressors.index(regressor), max_value + 0.05, significance,
                ha='center', va='bottom', fontsize=12, color='black')

# Manual legend: regressors in order of first appearance paired with 'tab20' colors
legend_regressors = weights_df['regressors'].unique()
legend_colors = sns.color_palette('tab20', len(legend_regressors))
handles = [mpatches.Patch(color=color, label=regressor)
           for regressor, color in zip(legend_regressors, legend_colors)]

fig.legend(
    handles=handles,
    bbox_to_anchor=(0.5, 0.05),
    loc='upper center',
    ncol=3,
    title='Features',
    prop={'size': 12}
)

sns.despine()
plt.tight_layout()
plt.subplots_adjust(bottom=0.10)  # leave room under the panels for the legend
plt.xticks(rotation=45, ha='right')
plt.show()
fig.savefig(os.path.join(results_path, f'weights_per_mouse_big_model_{epoch}.pdf'), bbox_inches='tight')

In [None]:
# Plotting the average weights per regressor
# Aggregate the weights by mouse and regressor. Only the 'weights' column is averaged:
# the frame also holds a 'session' column that is not meant to be averaged.
aggregated_df = weights_df.groupby(['mouse', 'regressors'], as_index=False).weights.mean()

# Perform t-tests on the aggregated data (one value per mouse, tested against 0)
t_test_results = []
for regressor in aggregated_df['regressors'].unique():
    regressor_data = aggregated_df[aggregated_df['regressors'] == regressor]['weights']
    t_stat, p_value = ttest_1samp(regressor_data, 0)
    
    # Determine the significance level
    if p_value < 0.001:
        significance = '***'
    elif p_value < 0.01:
        significance = '**'
    elif p_value < 0.05:
        significance = '*'
    else:
        significance = None

    t_test_results.append({'regressor': regressor, 'p_value': p_value, 'significance': significance})

t_test_results_df = pd.DataFrame(t_test_results)

# Plot. The figure is bound to `fig` so that savefig below writes THIS figure and not
# the per-mouse figure still referenced by `fig` from the previous cell.
fig = plt.figure(figsize=(10, 6))

# One point per mouse
sns.swarmplot(
    data=aggregated_df, 
    x='regressors', 
    y='weights', 
    hue='regressors', 
    palette='tab20', 
    dodge=True
)

# Annotate significance levels
for i, row in t_test_results_df.iterrows():
    regressor = row['regressor']
    significance = row['significance']
    if significance:
        x = list(aggregated_df['regressors'].unique()).index(regressor)
        y = aggregated_df[aggregated_df['regressors'] == regressor]['weights'].max() + 0.1
        plt.text(x-0.2, y, significance, ha='center', va='bottom', fontsize=12, color='black')

# Add horizontal line at 0
plt.axhline(0, color='black', linestyle='--')

# Customize labels
plt.xlabel('')
plt.xlim(-1.5, len(aggregated_df['regressors'].unique()) - 0.5)
plt.ylabel('Weight')
plt.xticks(rotation=45, ha='right')
plt.title('Weights Per Regressor (Aggregated by Mouse)')

# Legend handles are built but no legend is drawn in this cell; the rotated tick
# labels already name each regressor
handles = []
for regressor, color in zip(aggregated_df['regressors'].unique(), sns.color_palette('tab20', len(aggregated_df['regressors'].unique()))):
    handles.append(mpatches.Patch(color=color, label=regressor))
plt.tight_layout()
plt.subplots_adjust(bottom=0.3)  # Room for the rotated tick labels
sns.despine()
plt.show()
fig.savefig(os.path.join(results_path, f'weights_all_big_model_{epoch}.pdf'), bbox_inches='tight')

#### **Remove variables one at a time to find which ones are needed for the model**

In [None]:
# Mean has_choice per active_patch for engaged rows only. No `palette` is passed:
# without a `hue` variable seaborn ignores it.
sns.lineplot(data=summary_df.loc[summary_df.engaged == True], x='active_patch', y='has_choice', errorbar=None)

In [None]:
# Reload the reward-site table and keep only the experiment under study
epoch = 'control'
summary_df = load()
summary_df = summary_df[summary_df['experiment'] == epoch]

In [None]:
# Run the logistic regression model removing one feature at a time
weights_df = pd.DataFrame(columns=features)

# Rows are collected in a list and turned into a frame once, after the loops
cv_result_columns = ['mouse', 'cv_score', 'feature_removed', 'session']
cv_records = []

# The same metric is used for the baseline and for every ablation so that the
# scores can be compared with each other
ablation_scoring = 'roc_auc'

# Scaler and classifier are cross-validated together; cross_val_score clones the
# pipeline for every fold, so one instance can be shared by all sessions
log_reg = LogisticRegression(class_weight='balanced')
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('log_reg', log_reg)
])

# Option to include polynomial interaction features
use_poly = False  # Set to True to enable polynomial features

if use_poly:
    poly = PolynomialFeatures(degree=2, interaction_only=True, include_bias=False)
    poly.fit(summary_df[features])
    # get_feature_names_out already lists the original features followed by the
    # interaction terms, so it is the complete column list on its own
    all_features = list(poly.get_feature_names_out(features))
else:
    all_features = list(features)

# Iterate over all mice and sessions first
for mouse in summary_df['mouse'].unique():
    mouse_sessions = summary_df.loc[summary_df['mouse'] == mouse, 'session'].unique()
    for session in mouse_sessions:
        mouse_df = summary_df[(summary_df['mouse'] == mouse) & (summary_df['session'] == session)].copy()
        y_mouse = mouse_df['has_choice'].astype(int)

        # Same exclusions as logistic_session: too few rows, or a single outcome
        # class (ROC AUC is undefined in that case)
        if len(y_mouse) < 20 or y_mouse.nunique() == 1:
            continue

        if use_poly:
            X_mouse_full = pd.DataFrame(poly.transform(mouse_df[features]),
                                        columns=all_features, index=mouse_df.index)
        else:
            X_mouse_full = mouse_df[all_features]

        # First pass with all features
        cv_scores = cross_val_score(pipeline, X_mouse_full, y_mouse, cv=5, scoring=ablation_scoring)
        cv_records.append({
            'feature_removed': 'baseline',
            'session': session,
            'mouse': mouse,
            'cv_score': cv_scores.mean()
        })

        # Iterate over features to remove one at a time
        for feature in all_features:
            features_to_use = [f for f in all_features if f != feature]
            print(f"Mouse: {mouse}, Session: {session}, Removing feature: {feature}")

            X_mouse = X_mouse_full[features_to_use]
            cv_scores = cross_val_score(pipeline, X_mouse, y_mouse, cv=5, scoring=ablation_scoring)

            cv_records.append({
                'feature_removed': feature,
                'session': session,
                'mouse': mouse,
                'cv_score': cv_scores.mean()
            })

cv_results_df = pd.DataFrame(cv_records, columns=cv_result_columns)


In [None]:
# Cross-validation score per mouse, one box per removed feature
fig, ax = plt.subplots(figsize=(12, 4))
sns.boxplot(data=cv_results_df, x='mouse', y='cv_score', hue='feature_removed', palette='tab10', ax=ax)
plt.xticks(rotation=45)
sns.despine()
ax.legend(title='Feature removed', loc='upper left', bbox_to_anchor=(1, 1))

In [None]:
# Cross-validation score pooled over mice, one box per removed feature
fig, ax = plt.subplots(figsize=(6, 4))
sns.boxplot(data=cv_results_df, x='feature_removed', y='cv_score', hue='feature_removed', palette='tab10', ax=ax)
plt.xticks(rotation=-45, ha='left')
sns.despine()
ax.legend(title='Feature removed', loc='upper left', bbox_to_anchor=(1, 1))

**How many features should the model use, and how many of them are actually useful?**

In [None]:
# Define the features and target
X = summary_df[features]
y = summary_df['has_choice'].astype(int)

cv_scores = []  # Store the cross-validation scores
# Initialize the logistic regression model
log_reg = LogisticRegression(class_weight='balanced')

# RFE ranks features by the size of their coefficients, which is only meaningful
# when the features share a scale, so the data are standardised first
X_scaled = StandardScaler().fit_transform(X)

# Loop through different numbers of features to select
for num_features in range(1, len(features) + 1):
    # Scaling and feature selection run inside the cross-validation pipeline, so they
    # are fitted on the training folds only and the score is not inflated by
    # selecting features on the held-out rows
    rfe_pipeline = Pipeline([
        ('scaler', StandardScaler()),
        ('rfe', RFE(LogisticRegression(class_weight='balanced'), n_features_to_select=num_features)),
        ('log_reg', LogisticRegression(class_weight='balanced'))
    ])
    cv_score = cross_val_score(rfe_pipeline, X, y, cv=5, scoring='roc_auc').mean()
    cv_scores.append(cv_score)

    # Report which features are kept when the selection is fitted on all the data
    rfe = RFE(log_reg, n_features_to_select=num_features).fit(X_scaled, y)
    selected_features = [feature for feature, kept in zip(features, rfe.support_) if kept]
    print(f"Number of features: {num_features}, Selected features: {selected_features}, Cross-validation score: {cv_score:.2f}")
    
# Find the number of features that gives the highest cross-validation score
optimal_num_features = np.argmax(cv_scores) + 1  # Adding 1 because range starts from 1
print(f"Optimal number of features: {optimal_num_features}")

# Plot the cross-validation scores for different numbers of features
plt.plot(range(1, len(features) + 1), cv_scores, marker='o')
plt.title('Cross-validation Scores vs. Number of Features')
plt.xlabel('Number of Features')
plt.ylabel('Cross-validation Score')
plt.show()

In [None]:
# Calculate the correlation matrix
corr_matrix = summary_df[features].corr()

# Identify highly correlated feature pairs. The absolute value is used so that
# strong negative correlations (which are just as collinear) are reported too,
# and each unordered pair is listed once instead of as both (i, j) and (j, i).
corr_threshold = 0.9
corr_columns = list(corr_matrix.columns)
high_corr = [
    (first, second)
    for position, first in enumerate(corr_columns)
    for second in corr_columns[position + 1:]
    if abs(corr_matrix.loc[first, second]) > corr_threshold
]
print("Highly correlated features:", high_corr)

# Plot the correlation heatmap
plt.figure(figsize=(6, 5))
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', fmt='.2f', linewidths=0.5, center=0)
plt.title("Correlation Matrix")
plt.show()

#### **Fit simulated data with different strategies**

In [None]:
# Load the simulated sessions and reshape them to the column names used for the real data.
filename = 'simulation_data_separate_odors.csv'

simulation_df = pd.read_csv(os.path.join(data_path, filename), index_col=0)

simulation_df = simulation_df.rename(columns={'rewards_in_patch': 'cumulative_rewards',
                                              'time_in_patch':'visit_number',
                                              'failures_in_patch': 'cumulative_failures',
                                              'patch_id': 'odor_label',
                                              'patch_entry_time': 'active_patch',
                                              'prob_reward': 'reward_probability',
                                              'session_no':'session'})
simulation_df['mouse'] = 'simulation'

# Assign the interpolated column back explicitly: an in-place call on the
# selected column is a chained assignment and is not guaranteed to update the frame.
simulation_df['active_patch'] = simulation_df['active_patch'].interpolate(method='linear')

# Number the patches within each session: the counter increases whenever the
# value changes and restarts at 0 for every session. `transform` keeps the
# original row index, so the result lines up with `simulation_df` even when the
# index is not a clean 0..n-1 range sorted by session.
simulation_df['active_patch'] = simulation_df.groupby('session')['active_patch'].transform(
    lambda x: x.ne(x.shift()).cumsum() - 1  # Detect changes and assign numbers
)

# simulation_df['visit_number'] = np.where(simulation_df['odor_label'] == -1, 1, simulation_df['visit_number'])
# A row has a choice when the NEXT row is not an odor_label == -1 row.
# NOTE(review): the shift is applied over the whole table, so the last row of a
# session/strategy looks at the first row of the next one - confirm this is intended.
simulation_df['shift_has_choice'] = np.where(simulation_df['odor_label'] == -1, 0, 1)
simulation_df['has_choice'] = simulation_df['shift_has_choice'].shift(-1)
simulation_df = simulation_df.loc[simulation_df['odor_label'] != -1].copy()
simulation_df['has_choice'] = simulation_df['has_choice'].fillna(0)
# simulation_df.dropna(inplace=True)

In [None]:
# Fit the per-session logistic model separately for every simulated strategy
# and stack the resulting coefficient tables, tagged with the strategy name.
cum_weights_df = pd.DataFrame()
for strategy in simulation_df['strategy'].unique():
    print(f"Strategy: {strategy}")
    strategy_rows = simulation_df['strategy'] == strategy
    weights_df, cv_results_df, metrics_df, new_mouse_df = logistic_session(
        simulation_df.loc[strategy_rows],
        use_polynomial_features=False,
        orig_features=features,
    )
    weights_df['strategy'] = strategy
    cum_weights_df = pd.concat([cum_weights_df, weights_df], ignore_index=True)

In [None]:
# Plot the per-session coefficient weights, one panel per simulated strategy.
# The 2x3 grid means only the first six strategies (in sorted order) get a panel.
fig, axes = plt.subplots(2, 3, figsize=(12, 6), sharex=True)

# Perform t-tests and plot significance
for (strategy_name, group), ax in zip(cum_weights_df.groupby('strategy'), axes.flatten()):
    # One-sample t-test against zero for each regressor in the group
    significant_regressors = []
    for regressor in group['regressors'].unique():
        regressor_data = group[group['regressors'] == regressor]['weights']
        t_stat, p_value = ttest_1samp(regressor_data, 0)
        
        # Determine the significance level
        if p_value < 0.001:
            significance = '***'
        elif p_value < 0.01:
            significance = '**'
        elif p_value < 0.05:
            significance = '*'
        else:
            significance = None

        if significance:
            significant_regressors.append((regressor, regressor_data.max(), significance))

    # Plot the swarmplot
    sns.swarmplot(
        data=group, 
        x='regressors', 
        y='weights', 
        palette='tab10', 
        ax=ax, 
        hue='regressors', 
        legend=False, 
        order=['active_patch', 'consecutive_failures', 'cumulative_rewards', 'reward_probability', 'visit_number']
    )
    # The groups are strategies, not mice, so label the panel accordingly
    ax.set_title(f'Strategy {strategy_name}')
    ax.set_xlabel('')
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
    ax.hlines(0, -0.5, len(group['regressors'].unique()) - 0.5, color='black', linestyle='--')

    # # Annotate significant results
    # for regressor, max_value, significance in significant_regressors:
    #     x = list(group['regressors'].unique()).index(regressor)
    #     y = max_value + 0.05  # Position above max value
    #     ax.text(x, y, significance, ha='center', va='bottom', fontsize=12, color='black')

# Manually create the legend
handles = []
for regressor, color in zip(cum_weights_df['regressors'].unique(), sns.color_palette('tab10', len(cum_weights_df['regressors'].unique()))):
    handles.append(mpatches.Patch(color=color, label=regressor))

# Add legend at the bottom with 3 columns
fig.legend(
    handles=handles,
    bbox_to_anchor=(0.6, 0.05),  # Centered below the figure
    loc='upper center',
    ncol=3,  # Number of columns
    title='Features',
    prop={'size': 12}
)

sns.despine()
plt.tight_layout()
plt.subplots_adjust()  # Add space at the bottom for the legend
plt.xticks(rotation=45, ha='right')
plt.show()
# NOTE(review): `epoch` is not set in this cell; it must already exist in the
# kernel when this line runs - TODO confirm it is defined before this cell.
fig.savefig(os.path.join(results_path, f'weights_per_mouse_small_model_{epoch}.pdf'), bbox_inches='tight')

#### **Running AIC and BIC for model comparison**

In [None]:
import statsmodels.api as sm
from sklearn.metrics import log_loss
from scipy.stats import norm
from collections import Counter

In [None]:
def _balanced_sample_weights(y):
    """Per-sample weights equal to n_samples / (n_classes * class_count) for the sample's class."""
    class_counts = y.value_counts()
    class_weight = len(y) / (len(class_counts) * class_counts)
    return y.map(class_weight)


def _wald_p_values(X_scaled, y_probs, coef):
    """Approximate two-sided Wald p-values for logistic-regression coefficients.

    The covariance is taken from the inverse Fisher information
    X' diag(p * (1 - p)) X of the design matrix (intercept included). The
    L2 penalty and the class weights used when fitting are ignored, so the
    values are an approximation.
    """
    design = np.column_stack([np.ones(len(X_scaled)), X_scaled])
    bernoulli_var = y_probs * (1 - y_probs)
    fisher_info = design.T @ (design * bernoulli_var[:, None])
    # pinv instead of inv: a constant feature scales to an all-zero column,
    # which makes the information matrix singular.
    cov_matrix = np.linalg.pinv(fisher_info)
    stderr = np.sqrt(np.clip(np.diag(cov_matrix), 0, None))[1:]  # drop the intercept entry
    with np.errstate(divide='ignore', invalid='ignore'):
        z_scores = np.abs(coef / stderr)
    return 2 * norm.sf(z_scores)


def _format_mouse_id(mouse):
    """Return the mouse identifier as a string; float ids are rounded to 0 decimals first."""
    if isinstance(mouse, (float, np.floating)):
        return str(round(float(mouse), 0))
    return str(mouse)


def logistic_session(summary_df, 
                              use_polynomial_features=True, 
                              orig_features = ['reward_probability', 'consecutive_failures', 'visit_number', 'cumulative_rewards', 'active_patch'],
                              random_state=42):
    """Fit one class-balanced logistic regression per (mouse, session) predicting `has_choice`.

    Parameters
    ----------
    summary_df : pd.DataFrame
        Must contain 'mouse', 'session', 'has_choice' and every column in `orig_features`.
    use_polynomial_features : bool
        If True, add degree-2 interaction-only terms to the features.
    orig_features : list of str
        Columns used as regressors.
    random_state : int or None
        Seed for the shuffled stratified 5-fold split, so that the
        cross-validation scores are reproducible between runs.

    Returns
    -------
    weights_df : pd.DataFrame
        One row per regressor and session: 'regressors', 'weights', 'mouse', 'session', 'p_values'.
    cv_results_df : pd.DataFrame
        One row per session with the mean ('cv_score') and std ('cv_std') of the ROC AUC over folds.
    metrics_df : pd.DataFrame
        Whatever `calculate_metrics` accumulated for the thresholded predictions.
    new_mouse_df : pd.DataFrame
        The fitted rows with 'y_pred', 'y_pred_prob' and 'y_pred_adjusted' added.

    Sessions with fewer than 20 rows or a single outcome class are skipped.
    """
    # Collect per-session results in lists and concatenate once at the end;
    # this also avoids concatenating onto an empty, object-typed frame.
    weight_frames = []
    cv_records = []
    session_frames = []
    metrics_list = []
    
    for (mouse, session), mouse_df in summary_df.groupby(['mouse', 'session']):
        print(f"Mouse: {mouse}, Session: {session}")
        
        # Work on a copy so the prediction columns are not written into a view of summary_df
        mouse_df = mouse_df.copy()
        
        # Select features and target variable
        X_mouse = mouse_df[orig_features]
        y_mouse = mouse_df['has_choice'].astype(int)
        
        # Too little data, or nothing to discriminate: skip the session
        if len(mouse_df) < 20 or y_mouse.nunique() == 1:
            continue
        
        if use_polynomial_features:
            poly = PolynomialFeatures(degree=2, interaction_only=True, include_bias=False)
            X_mouse = poly.fit_transform(X_mouse)
            features = poly.get_feature_names_out()
        else:
            features = orig_features
        
        # Standardize the features
        X_mouse_scaled = StandardScaler().fit_transform(X_mouse)
        
        # Perform 5-fold cross-validation
        cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)
        log_reg = LogisticRegression(class_weight='balanced')
        cv_scores = cross_val_score(log_reg, X_mouse_scaled, y_mouse, cv=cv, scoring='roc_auc')

        # Fit on the whole session to read out the coefficients
        log_reg.fit(X_mouse_scaled, y_mouse)
        coef = log_reg.coef_[0]

        # Predicted labels at the default threshold and predicted probabilities
        y_probs = log_reg.predict_proba(X_mouse_scaled)[:, 1]
        mouse_df['y_pred'] = log_reg.predict(X_mouse_scaled)
        mouse_df['y_pred_prob'] = y_probs
        
        # Predicted labels at the threshold returned by the ROC helper
        best_threshold = plotting_roc_curve(y_probs, y_mouse)
        y_pred_adjusted = (y_probs >= best_threshold).astype(int)
        mouse_df['y_pred_adjusted'] = y_pred_adjusted
        
        # AIC of the sklearn fit: k = coefficients + intercept
        log_likelihood = -log_loss(y_mouse, y_probs, normalize=False)
        k = len(coef) + 1
        aic = 2 * k - 2 * log_likelihood

        # Approximate p-values for each coefficient
        p_values = _wald_p_values(X_mouse_scaled, y_probs, coef)

        # Class-balanced reference fit in statsmodels. Logit.fit() does not
        # accept observation weights (a `weights=` keyword is silently ignored),
        # so a binomial GLM with freq_weights is used instead.
        sample_weights = _balanced_sample_weights(y_mouse)
        X_sm = sm.add_constant(X_mouse_scaled)  # Add intercept
        model_sm = sm.GLM(y_mouse.to_numpy(), X_sm,
                          family=sm.families.Binomial(),
                          freq_weights=sample_weights.to_numpy())
        result_sm = model_sm.fit()
        aic_sm_balanced = result_sm.aic

        # Display both AIC values
        print(f"AIC: {aic}")
        print(f"AIC (statsmodels): {aic_sm_balanced}")
        
        metrics_list = calculate_metrics(metrics_list, y_mouse, y_pred_adjusted)
        
        weight_frames.append(pd.DataFrame({
            'regressors': list(features),
            'weights': coef,
            'p_values': p_values,
            'mouse': mouse,
            'session': session,
        }))
        cv_records.append({'session': session, 'mouse': mouse,
                           'cv_std': cv_scores.std(), 'cv_score': cv_scores.mean()})
        session_frames.append(mouse_df)

        print('\n')
        
    if weight_frames:
        weights_df = pd.concat(weight_frames, ignore_index=True)[
            ['regressors', 'weights', 'mouse', 'session', 'p_values']]
    else:
        weights_df = pd.DataFrame(columns=['regressors', 'weights', 'mouse', 'session'])
    # Mouse ids may be numeric or strings (e.g. 'simulation'); format both safely
    weights_df['mouse'] = weights_df['mouse'].map(_format_mouse_id)

    cv_results_df = pd.DataFrame(cv_records)
    new_mouse_df = pd.concat(session_frames, ignore_index=True) if session_frames else pd.DataFrame()
    metrics_df = pd.DataFrame(metrics_list)
    
    return weights_df, cv_results_df, metrics_df, new_mouse_df


In [None]:
# Per-session fit on the current summary table, main effects only.
weights_df, cv_results_df, metrics_df, new_mouse_df = logistic_session(
    summary_df,
    use_polynomial_features=False,
    orig_features=features,
)

#### **Fit different types of sessions and return the difference**

In [None]:
# Rebind `summary_df` to a fresh table from the notebook's `load()` helper
# (defined outside this section), replacing whatever `summary_df` held before.
summary_df = load()

In [None]:
# Fit the per-session model separately for every experiment (RewardSite rows only)
# and stack the coefficient and cross-validation tables.
# Results are collected in lists and concatenated once: concatenating onto an
# empty, object-typed frame inside the loop relies on pandas ignoring the empty
# frame's dtypes, which is deprecated behaviour.
experiment_weight_frames = []
experiment_cv_frames = []

for experiment in summary_df['experiment'].unique():
    print(f"Experiment: {experiment}")
    experiment_df = summary_df[(summary_df['experiment'] == experiment)&(summary_df.label == 'RewardSite')]
    weights_df, cv_results_df, metrics_df, new_mouse_df = logistic_session(experiment_df, 
                              use_polynomial_features=False)
    weights_df['experiment'] = experiment
    cv_results_df['experiment'] = experiment
    experiment_weight_frames.append(weights_df)
    experiment_cv_frames.append(cv_results_df)

if experiment_weight_frames:
    cum_weights_df = pd.concat(experiment_weight_frames, ignore_index=True)
    cum_cv_results_df = pd.concat(experiment_cv_frames, ignore_index=True)
else:
    # No experiments at all: keep the empty frames with the expected columns
    cum_weights_df = pd.DataFrame(columns=['regressors', 'weights', 'experiment'])
    cum_cv_results_df = pd.DataFrame(columns=['experiment', 'cv_score'])

In [None]:
## Histograms of the cross-validated ROC AUC (mean and std over folds), one figure per experiment
for experiment in cum_cv_results_df.experiment.unique():
    experiment_cv = cum_cv_results_df.loc[cum_cv_results_df.experiment == experiment]
    fig, ax = plt.subplots(1, 2, figsize=(10, 5))
    sns.histplot(data=experiment_cv, x='cv_score',  multiple="stack", bins=30, hue='experiment', ax=ax[0], stat = 'probability', legend=False)
    sns.histplot(data=experiment_cv, x='cv_std',  multiple="stack",bins=30, hue='experiment', stat = 'probability',  ax=ax[1])
    # seaborn's hue legend has no labelled artists, so plt.legend() would replace
    # it with an empty one; reposition the existing legend instead.
    if ax[1].get_legend() is not None:
        sns.move_legend(ax[1], loc='upper left', bbox_to_anchor=(1.05, 1), borderaxespad=0.)
    plt.tight_layout()
    sns.despine()

In [None]:
## Weights for each mouse compared across experiments (one PDF page per mouse)
with PdfPages(os.path.join(results_path, 'across_experiments_big_model.pdf')) as pdf:
    n_regressors = len(cum_weights_df['regressors'].unique())
    for mouse in cum_weights_df['mouse'].unique():
        mouse_weights = cum_weights_df.loc[cum_weights_df.mouse == mouse]
        fig, ax = plt.subplots(figsize=(16, 6))
        sns.boxplot(data=mouse_weights, x='regressors', y='weights', hue='experiment', ax=ax)
        ax.set_title(f'Mouse: {mouse}')
        ax.set_xlabel('')
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        ax.legend(title='Experiment', loc='upper left', bbox_to_anchor=(1, 1))
        # Dashed zero line spanning all regressor positions
        ax.hlines(0, -0.5, n_regressors - 0.5, color='black', linestyle='--')
        sns.despine(fig=fig)
        pdf.savefig(fig)

In [None]:
## Weights for all experiments: session weights are averaged per mouse first,
## so every box is built from one value per mouse.
fig, ax = plt.subplots(figsize=(16, 6))
results_df = (
    cum_weights_df
    .groupby(['mouse', 'regressors', 'experiment'], as_index=False)
    .weights.mean()
)
sns.boxplot(data=results_df, x='regressors', y='weights', hue='experiment', ax=ax)
ax.set_xlabel('')
plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
ax.legend(title='Experiment', loc='upper left', bbox_to_anchor=(1, 1))
ax.hlines(0, -0.5, len(results_df['regressors'].unique()) - 0.5, color='black', linestyle='--')
sns.despine(fig=fig)

In [None]:
# Function utilities to plot box plots with lines joining each animal

from scipy.stats import ttest_rel, ttest_ind
def plot_lines(data: pd.DataFrame, ax, variable = 'total_rewards', condition =  'mouse'):
    """Draw one thin black line per level of `condition`, joining its `variable` values across experiments.

    For every unique value in `data[condition]` the matching rows are plotted
    on `ax` with the 'experiment' column on x and `variable` on y.
    """
    for level in data[condition].unique():
        level_rows = data.loc[data[condition] == level]
        ax.plot(
            level_rows.experiment.values,
            level_rows[variable].values,
            marker='',
            linestyle='-',
            color='black',
            alpha=0.4,
            linewidth=1,
        )

def plot_significance(general_df: pd.DataFrame, axes, 
                      variable = 'total_rewards', 
                      experiment = 'distance_short'):
    """Compare `variable` between 'control' and `experiment` and draw a significance bracket.

    A paired t-test is attempted first, using the two groups in the row order
    they have in `general_df`. If it fails (for example because the groups have
    different lengths) an independent t-test is used instead. The bracket is
    drawn on `axes` between x = 0 and x = 1 with '***', '**', '*' or 'ns' above it.

    Parameters
    ----------
    general_df : pd.DataFrame
        Must contain an 'experiment' column and the `variable` column.
    axes : matplotlib axes
        Axes to draw on.
    variable : str
        Column to compare.
    experiment : str
        Experiment label compared against 'control'.
    """
    group1 = general_df.loc[general_df.experiment == 'control', variable]
    group2 = general_df.loc[general_df.experiment == experiment, variable]
    # Perform t-test
    try:
        t_stat, p_value = ttest_rel(group1, group2)
    except Exception as error:
        # Catch Exception rather than using a bare except, so that
        # KeyboardInterrupt / SystemExit are not swallowed.
        print(f'Error in t-test paired, running independent t-test ({error})')
        t_stat, p_value = ttest_ind(group1, group2)
    
    print(f'{variable} p-value: {p_value}')
    # Add significance annotation
    x1, x2 = 0, 1  # x-coordinates of the groups
    y, h, col = general_df[variable].max() + 1, 0.5, 'k'  # y-coord, line height, color
    if variable == 'reward_probability':
        # Fixed position for this variable instead of max + 1
        y = 0.6
        h = 0.05
        
    # A NaN p-value fails every comparison and is therefore reported as "ns"
    if p_value < 0.001:
        significance = "***" 
    elif p_value < 0.01:
        significance = "**" 
    elif p_value < 0.05:
        significance = "*"
    else:
        significance = "ns"
    
    axes.plot([x1, x1, x2, x2], [y, y + h, y + h, y], lw=1.5, c=col)
    axes.text((x1 + x2) * 0.5, y + h, significance, ha='center', va='bottom', color=col)

In [None]:
# For each manipulation, compare every regressor's per-mouse mean weight
# against the control sessions (one panel per regressor).
for experiment in ['distance_short', 'distance_long', 'distance_extra_short', 'distance_extra_long', 'friction_low', 'friction_med', 'friction_high']:
    fig, axes = plt.subplots(1, 5, figsize=(20, 5))

    # Per-mouse mean weights for control and the current experiment;
    # this does not depend on the regressor, so it is computed once per experiment.
    in_comparison = cum_weights_df.experiment.isin(['control', experiment])
    comparison_means = (
        cum_weights_df.loc[in_comparison]
        .groupby(['mouse', 'experiment', 'regressors'], as_index=False)
        .weights.mean()
    )

    for regressor, ax in zip(cum_weights_df['regressors'].unique(), axes.flatten()):
        general_df = comparison_means.loc[comparison_means.regressors == regressor]

        # Keep only mice that have both the control and the current experiment
        mice_with_both = general_df.groupby("mouse")["experiment"].nunique().eq(2)
        general_df = general_df[general_df["mouse"].isin(mice_with_both[mice_with_both].index)]

        sns.boxplot(x='experiment', y='weights', data=general_df, hue='experiment', legend=False, width=0.5, ax=ax)
        ax.hlines(0, -0.5, 1.5, color='black', linestyle='--')
        ax.set_xlabel(regressor)
        plot_lines(general_df, ax, 'weights', 'mouse')
        plot_significance(general_df, ax, 'weights', experiment=experiment)
        plt.suptitle(experiment)
        sns.despine()
        plt.tight_layout()
    plt.show()

##### **Cross-validate the number of features**

In [None]:
# Function to generate interaction terms and keep track of their names
def generate_interactions(df, features):
    """Build all pairwise products of the given feature columns.

    Parameters
    ----------
    df : pd.DataFrame
        Table containing every column named in `features`.
    features : list of str
        Column names to combine, two at a time, in their given order.

    Returns
    -------
    interactions : list of pd.Series
        One product `df[a] * df[b]` per pair.
    interaction_names : list of str
        Matching names of the form 'a*b'.
    """
    # Imported locally so the function does not depend on `itertools`
    # having been imported elsewhere in the notebook.
    import itertools

    feature_pairs = list(itertools.combinations(features, 2))
    interaction_names = [f'{first}*{second}' for first, second in feature_pairs]
    interactions = [df[first] * df[second] for first, second in feature_pairs]
    return interactions, interaction_names

In [None]:
def crossvalidate_feature_selection_iteration(summary_df):
    """Sweep the number of RFE-selected features (main effects + pairwise interactions).

    For every feature count from 1 to the number of candidate columns, the
    5-fold cross-validated ROC AUC of a class-balanced logistic regression is
    computed with the feature selection performed inside each fold.

    Parameters
    ----------
    summary_df : pd.DataFrame
        Must contain 'has_choice' and the five base feature columns.

    Returns
    -------
    cv_scores : list of float
        Mean cross-validated ROC AUC for each feature count.
    selected_features_list : list of (list, list)
        For each feature count, the (interaction, non-interaction) names kept
        by RFE when fitted on all rows.
    optimal_num_features : int
        Feature count with the highest score.
    """
    # Define the features and target
    features = ['reward_probability', 'consecutive_failures', 'visit_number', 'cumulative_rewards', 'active_patch']
    X = summary_df[features]
    y = summary_df['has_choice'].astype(int)

    # Generate interaction terms and append them as named columns
    interaction_terms, interaction_names = generate_interactions(X, features)
    interaction_columns = [pd.Series(term, name=name) for term, name in zip(interaction_terms, interaction_names)]
    X_with_interactions = pd.concat([X] + interaction_columns, axis=1)

    # Initialize the logistic regression model
    log_reg = LogisticRegression(class_weight='balanced', C=1)

    cv_scores = []
    selected_features_list = []

    # Loop through different numbers of features to select
    for num_features in range(1, len(X_with_interactions.columns) + 1):
        rfe = RFE(log_reg, n_features_to_select=num_features)

        # Score with the selection step inside the CV folds. Selecting on the
        # full data first and cross-validating afterwards leaks the held-out
        # folds into the selection and inflates the score.
        selection_pipeline = Pipeline([('rfe', rfe), ('clf', log_reg)])
        cv_score = cross_val_score(selection_pipeline, X_with_interactions, y, cv=5, scoring='roc_auc').mean()
        cv_scores.append(cv_score)

        # Fit on all rows only to report which features would be kept
        rfe.fit(X_with_interactions, y)
        selected_features = X_with_interactions.columns[rfe.support_]

        # Interaction names contain '*', plain features do not
        selected_interactions = [name for name in selected_features if '*' in name]
        selected_non_interactions = [name for name in selected_features if '*' not in name]
        selected_features_list.append((selected_interactions, selected_non_interactions))

        # Print the selected interaction and non-interaction features at this iteration
        print(f"Selected interaction terms for {num_features} features: {selected_interactions}, {selected_non_interactions}")

    # Find the number of features that gives the highest cross-validation score
    optimal_num_features = np.argmax(cv_scores) + 1  # Adding 1 because range starts from 1
    print(f"Optimal number of features (with interactions): {optimal_num_features}")
    
    return cv_scores, selected_features_list, optimal_num_features


In [None]:
from sklearn.metrics import accuracy_score

def crossvalidate_feature_selection(summary_df):
    """Run RFECV over main effects plus pairwise interactions and print a summary.

    Prints the number and names of the features kept by RFECV (stratified
    5-fold, ROC AUC), followed by the in-sample accuracy of a class-balanced
    logistic regression fitted on all candidate features and on the selected
    ones. Nothing is returned.
    """
    features = ['reward_probability', 'consecutive_failures', 'visit_number', 'cumulative_rewards', 'active_patch']
    X = summary_df[features]
    y = summary_df['has_choice'].astype(int)

    # Candidate design matrix: base features followed by their named pairwise products
    interaction_terms, interaction_names = generate_interactions(X, features)
    interaction_columns = [
        pd.Series(term, name=name)
        for term, name in zip(interaction_terms, interaction_names)
    ]
    X_with_interactions = pd.concat([X, *interaction_columns], axis=1)

    log_reg = LogisticRegression(class_weight='balanced', C=1)

    # Cross-validated recursive feature elimination
    rfecv = RFECV(estimator=log_reg, cv=StratifiedKFold(5), scoring='roc_auc')
    rfecv.fit(X_with_interactions, y)
    print("Optimal number of features:", rfecv.n_features_)
    print("Selected Features:", X_with_interactions.columns[rfecv.support_])

    # In-sample accuracy: first with every candidate, then with the RFECV selection
    designs = (
        ("Accuracy with all features:", X_with_interactions),
        ("Accuracy with selected features:", X_with_interactions.loc[:, rfecv.support_]),
    )
    for message, design in designs:
        log_reg.fit(design, y)
        print(message, accuracy_score(y, log_reg.predict(design)))


In [None]:
# Cross-validated score as a function of the number of selected features, one curve per experiment
fig, ax = plt.subplots(figsize=(6, 4))
for experiment in summary_df['experiment'].unique():
    print(f"Experiment: {experiment}")
    experiment_df = summary_df[(summary_df['experiment'] == experiment)&(summary_df.label == 'RewardSite')]
    
    cv_scores, selected_features_list, optimal_num_features = crossvalidate_feature_selection_iteration(experiment_df)

    # One score per feature count, starting at 1. The x-range is derived from
    # cv_scores because the design matrix is local to the function above and
    # is not available in this cell.
    feature_counts = range(1, len(cv_scores) + 1)
    ax.plot(feature_counts, cv_scores, marker='o', label=experiment)
    ax.plot(optimal_num_features, max(cv_scores), 'ro')  # Highlight the optimal number of features
    
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax.set_title('Cross-validation Scores vs. Number of Features (with Interactions)')
ax.set_xlabel('Number of Features')
ax.set_ylabel('Cross-validation Score')
plt.show()

In [None]:
# Print the RFECV summary for the RewardSite rows of every experiment.
for experiment in summary_df['experiment'].unique():
    print(f"Experiment: {experiment}")
    reward_site_rows = (summary_df['experiment'] == experiment) & (summary_df['label'] == 'RewardSite')
    experiment_df = summary_df[reward_site_rows]

    crossvalidate_feature_selection(experiment_df)

### **GLM for median split of the data**

In [None]:
# Reload the summary table and keep only the rows of the chosen epoch.
summary_df = load()
epoch = 'control'
summary_df = summary_df[summary_df['experiment'].eq(epoch)]

In [None]:
# Min-max normalize the patch number within each (mouse, session):
# 0 is the session's first patch value and 1 its last.
patch_by_session = summary_df.groupby(['mouse', 'session'])['active_patch']
session_patch_min = patch_by_session.transform('min')
session_patch_max = patch_by_session.transform('max')
summary_df['norm_patch_number'] = (
    (summary_df['active_patch'] - session_patch_min) / (session_patch_max - session_patch_min)
)

In [None]:
# Fit the per-session model separately on the early and late part of each
# session, split on the normalized patch number.
features = ['reward_probability', 'consecutive_failures', 'visit_number', 'cumulative_rewards', 
                                     'active_patch']

split_threshold = 0.3
# '<' and '>=' so that rows exactly at the threshold are not dropped from both parts
session_parts = {
    'first': summary_df.loc[summary_df['norm_patch_number'] < split_threshold],
    'second': summary_df.loc[summary_df['norm_patch_number'] >= split_threshold],
}

# Collect the results and concatenate once, instead of growing the frames
# from empty, object-typed frames inside the loop.
half_weight_frames = []
half_cv_frames = []
for name, half_df in session_parts.items():
    weights_df, cv_results_df, metrics_df, new_mouse_df = logistic_session(half_df, 
                                use_polynomial_features=False, orig_features=features)

    weights_df['half'] = name
    half_weight_frames.append(weights_df)
    half_cv_frames.append(cv_results_df)

cum_weights_df = pd.concat(half_weight_frames, ignore_index=True)
cum_cv_results_df = pd.concat(half_cv_frames, ignore_index=True)


In [None]:
# Plotting the average weights per regressor for both halves
fig, axes = plt.subplots(1, 2, figsize=(12, 6), sharey=True)

# Colours and legend handles are built here from the regressors actually
# present, so the cell does not depend on `palette` / `handles` having been
# created by other cells of the notebook.
regressor_order = list(cum_weights_df['regressors'].unique())
half_palette = dict(zip(regressor_order, sns.color_palette('tab10', len(regressor_order))))
half_handles = [mpatches.Patch(color=half_palette[name], label=name) for name in regressor_order]

for ax, half in zip(axes, ['first', 'second']):
    # Aggregate the weights by mouse and regressor
    aggregated_df = cum_weights_df.loc[cum_weights_df.half == half].groupby(['mouse', 'regressors'], as_index=False)['weights'].mean()

    # One-sample t-test against zero on the per-mouse means
    t_test_results = []
    for regressor in aggregated_df['regressors'].unique():
        regressor_data = aggregated_df[aggregated_df['regressors'] == regressor]['weights']
        t_stat, p_value = ttest_1samp(regressor_data, 0)
        
        # Determine the significance level
        if p_value < 0.001:
            significance = '***'
        elif p_value < 0.01:
            significance = '**'
        elif p_value < 0.05:
            significance = '*'
        else:
            significance = None

        t_test_results.append({'regressor': regressor, 'p_value': p_value, 'significance': significance})

    t_test_results_df = pd.DataFrame(t_test_results)

    # Plot
    sns.swarmplot(
        data=aggregated_df, 
        x='regressors', 
        y='weights', 
        hue='regressors', 
        palette=half_palette, 
        dodge=True,
        ax=ax
    )

    # Annotate significance levels just above the highest point of each regressor
    for i, row in t_test_results_df.iterrows():
        regressor = row['regressor']
        significance = row['significance']
        if significance:
            x = list(aggregated_df['regressors'].unique()).index(regressor)
            y = aggregated_df[aggregated_df['regressors'] == regressor]['weights'].max() + 0.1
            ax.text(x-0.2, y, significance, ha='center', va='bottom', fontsize=12, color='black')

    # Add horizontal line at 0
    ax.axhline(0, color='black', linestyle='--')

    # Customize labels and legend
    ax.set_xlabel('')
    ax.set_xlim(-1.5, len(aggregated_df['regressors'].unique()) - 0.5)
    ax.set_ylabel('Weight')
    ax.set_xticks(range(len(aggregated_df['regressors'].unique())))
    ax.set_xticklabels(aggregated_df['regressors'].unique(), rotation=45, ha='right')
    ax.set_title(f'{half.capitalize()} Half')

fig.legend(handles=half_handles, bbox_to_anchor=(1.05, 1), loc='upper left', title='Regressors')
plt.tight_layout()
plt.subplots_adjust(bottom=0.3, right=0.8)  # Adjust space for the legend
sns.despine()
plt.show()
fig.savefig(os.path.join(results_path, 'glm_weights_split_session.pdf'), bbox_inches='tight')

## **Fit all sessions of a mouse together, add friction and distance as parameters**

Not sure if this is correct, since I am pooling all the sessions of a mouse together without acknowledging the session structure in the model.

In [None]:
def logistic_mouse(summary_df, 
                    use_polynomial_features=True, 
                    orig_features = ['reward_probability', 'consecutive_failures', 'visit_number', 'cumulative_rewards', 
                                     'active_patch', 'torque_friction', 'session_n'],
                    random_state=42):
    """Fit one class-balanced logistic regression per mouse, pooling all of its sessions.

    Parameters
    ----------
    summary_df : pd.DataFrame
        Must contain 'mouse', 'has_choice' and every column in `orig_features`.
    use_polynomial_features : bool
        If True, add degree-2 interaction-only terms to the features.
    orig_features : list of str
        Columns used as regressors. If 'session_n' is included it is one-hot encoded.
    random_state : int or None
        Seed for the shuffled stratified 5-fold split, so that the
        cross-validation scores are reproducible between runs.

    Returns
    -------
    weights_df : pd.DataFrame
        One row per regressor and mouse ('regressors', 'weights', 'mouse';
        the 'session' column is kept for compatibility and left empty).
    cv_results_df : pd.DataFrame
        One row per mouse with the mean ('cv_score') and std ('cv_std') of the ROC AUC over folds.
    metrics_df : pd.DataFrame
        Whatever `calculate_metrics` accumulated for the thresholded predictions.
    new_mouse_df : pd.DataFrame
        The fitted rows with 'y_pred', 'y_pred_prob' and 'y_pred_adjusted' added.

    Mice with fewer than 20 rows or a single outcome class are skipped.
    """
    # Collect per-mouse results in lists and concatenate once at the end
    weight_frames = []
    cv_records = []
    mouse_frames = []
    metrics_list = []
    
    # Group by the column name (not a one-element list) so the key is the
    # mouse id itself rather than a 1-tuple.
    for mouse, mouse_df in summary_df.groupby('mouse'):
        print(f"Mouse: {mouse}")
        
        # Work on a copy so the prediction columns are not written into a view of summary_df
        mouse_df = mouse_df.copy()
        
        # Select features and target variable
        X_mouse = mouse_df[orig_features]
        y_mouse = mouse_df['has_choice'].astype(int)
        
        # Too little data, or nothing to discriminate: skip the mouse
        # (same guards as logistic_session)
        if len(mouse_df) < 20 or y_mouse.nunique() == 1:
            continue
        
        if 'session_n' in orig_features:
            X_mouse = pd.get_dummies(X_mouse, columns=['session_n'], prefix='session_n')
        
        if use_polynomial_features:
            poly = PolynomialFeatures(degree=2, interaction_only=True, include_bias=False)
            X_mouse = poly.fit_transform(X_mouse)
            features = poly.get_feature_names_out()
        else:
            features = X_mouse.columns
        
        # Standardize the features
        X_mouse_scaled = StandardScaler().fit_transform(X_mouse)
        
        # Perform 5-fold cross-validation
        cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=random_state)
        log_reg = LogisticRegression(class_weight='balanced')
        cv_scores = cross_val_score(log_reg, X_mouse_scaled, y_mouse, cv=cv, scoring='roc_auc')

        # Fit on all rows of the mouse to read out the coefficients
        log_reg.fit(X_mouse_scaled, y_mouse)

        # Predicted labels at the default threshold and predicted probabilities
        y_probs = log_reg.predict_proba(X_mouse_scaled)[:, 1]
        mouse_df['y_pred'] = log_reg.predict(X_mouse_scaled)
        mouse_df['y_pred_prob'] = y_probs
        
        # Predicted labels at the threshold returned by the ROC helper
        best_threshold = plotting_roc_curve(y_probs, y_mouse)
        y_pred_adjusted = (y_probs >= best_threshold).astype(int)
        mouse_df['y_pred_adjusted'] = y_pred_adjusted
        
        metrics_list = calculate_metrics(metrics_list, y_mouse, y_pred_adjusted)
        
        weight_frames.append(pd.DataFrame({
            'regressors': list(features),
            'weights': log_reg.coef_[0],
            'mouse': mouse,
        }))
        cv_records.append({'mouse': mouse, 'cv_std': cv_scores.std(), 'cv_score': cv_scores.mean()})
        mouse_frames.append(mouse_df)

    if weight_frames:
        weights_df = pd.concat(weight_frames, ignore_index=True).reindex(
            columns=['regressors', 'weights', 'mouse', 'session'])
    else:
        weights_df = pd.DataFrame(columns=['regressors', 'weights', 'mouse', 'session'])
    # Mouse ids may be numeric or strings; float ids are rounded before formatting
    weights_df['mouse'] = weights_df['mouse'].map(
        lambda m: str(round(float(m), 0)) if isinstance(m, (float, np.floating)) else str(m)
    )

    cv_results_df = pd.DataFrame(cv_records)
    new_mouse_df = pd.concat(mouse_frames, ignore_index=True) if mouse_frames else pd.DataFrame()
    metrics_df = pd.DataFrame(metrics_list)
    return weights_df, cv_results_df, metrics_df, new_mouse_df


In [None]:
# Reload the summary table with interpatch_name='PostPatch' and drop the 'data_collection' experiment.
summary_df = load(interpatch_name='PostPatch')
summary_df = summary_df.loc[summary_df['experiment'].ne('data_collection')]

In [None]:
# Replace missing interpatch values with 0. The result is assigned back to
# `summary_df`: an in-place fillna on a column selected from a filtered frame
# is a chained assignment and is not guaranteed to modify `summary_df`.
summary_df = summary_df.fillna({'interpatch_time': 0, 'interpatch_length': 0})

In [None]:
# Per-mouse fit with the task features plus friction and interpatch regressors.
cum_weights_df = pd.DataFrame(columns=['regressors', 'weights'])
cum_cv_results_df = pd.DataFrame(columns=['experiment', 'cv_score'])

features = [
    'reward_probability',
    'consecutive_failures',
    'visit_number',
    'cumulative_rewards',
    'active_patch',
    'torque_friction',
    'interpatch_time',
    'interpatch_length',
]

weights_df, cv_results_df, metrics_df, new_mouse_df = logistic_mouse(
    summary_df,
    use_polynomial_features=False,
    orig_features=features,
)

if 'session_n' in features:
    # Drop every row in which any column, read as text, starts with 'session_'
    # (the one-hot session dummies).
    has_session_prefix = weights_df.apply(
        lambda column: column.astype(str).str.startswith('session_')
    ).any(axis=1)
    weights_df = weights_df[~has_session_prefix]

In [None]:
# Fixed colour (RGB tuple, components in 0-1) for each regressor name, so a
# regressor keeps the same colour in every figure that uses this mapping.
# The triplets are the first eight colours of matplotlib's 'tab10' palette, in order.
palette = {
    'reward_probability': (0.12156862745098039, 0.4666666666666667, 0.7058823529411765),
    'consecutive_failures': (1.0, 0.4980392156862745, 0.054901960784313725),
    'visit_number': (0.17254901960784313, 0.6274509803921569, 0.17254901960784313),
    'cumulative_rewards': (0.8392156862745098, 0.15294117647058825, 0.1568627450980392),
    'active_patch': (0.5803921568627451, 0.403921568627451, 0.7411764705882353),
    'torque_friction': (0.5490196078431373, 0.33725490196078434, 0.29411764705882354),
    'interpatch_time': (0.8901960784313725, 0.4666666666666667, 0.7607843137254902),
    'interpatch_length': (0.4980392156862745, 0.4980392156862745, 0.4980392156862745)
}

In [None]:
# Aggregate the weights by mouse and regressor
aggregated_df = weights_df.groupby(['mouse', 'regressors'], as_index=False).weights.mean()

# One-sample t-test against zero on the per-mouse means
t_test_results = []
for regressor in aggregated_df['regressors'].unique():
    regressor_data = aggregated_df[aggregated_df['regressors'] == regressor]['weights']
    t_stat, p_value = ttest_1samp(regressor_data, 0)
    
    # Determine the significance level
    if p_value < 0.001:
        significance = '***'
    elif p_value < 0.01:
        significance = '**'
    elif p_value < 0.05:
        significance = '*'
    else:
        significance = None

    t_test_results.append({'regressor': regressor, 'p_value': p_value, 'significance': significance})

t_test_results_df = pd.DataFrame(t_test_results)

# Plot
plt.figure(figsize=(10, 4))

# One point per mouse. The explicit palette is required: without it seaborn
# assigns its default colours in order of appearance, which does not match the
# `palette`-based legend handles created below.
sns.swarmplot(
    data=aggregated_df, 
    x='regressors', 
    y='weights', 
    hue='regressors', 
    palette=palette,
    dodge=True
)

# Annotate significance levels just above the highest point of each regressor
for i, row in t_test_results_df.iterrows():
    regressor = row['regressor']
    significance = row['significance']
    if significance:
        x = list(aggregated_df['regressors'].unique()).index(regressor)
        y = aggregated_df[aggregated_df['regressors'] == regressor]['weights'].max() + 0.1
        plt.text(x, y, significance, ha='center', va='bottom', fontsize=12, color='black')

# Add horizontal line at 0
plt.axhline(0, color='black', linestyle='--')

# Customize labels and legend
plt.xlabel('')
plt.xlim(-1, len(aggregated_df['regressors'].unique()))
plt.ylabel('Weight')
plt.xticks(rotation=45, ha='right')
# plt.xticks([])
plt.suptitle('Weights Per Regressor \n (Aggregated by Mouse)')

# Manually create legend
handles = []
for regressor in aggregated_df['regressors'].unique():
    handles.append(mpatches.Patch(color=palette[regressor], label=regressor))
plt.legend(handles=handles, bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
plt.tight_layout()
sns.despine()
plt.show()


In [None]:
# Calculate the correlation matrix
corr_matrix = summary_df[features].corr()

# Identify highly correlated feature pairs. The absolute value is used so that
# strong negative correlations (which are just as collinear) are reported too,
# and each unordered pair is listed once instead of as both (i, j) and (j, i).
corr_threshold = 0.9
corr_columns = list(corr_matrix.columns)
high_corr = [
    (first, second)
    for position, first in enumerate(corr_columns)
    for second in corr_columns[position + 1:]
    if abs(corr_matrix.loc[first, second]) > corr_threshold
]
print("Highly correlated features:", high_corr)

# Plot the correlation heatmap
plt.figure(figsize=(8, 6))
sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', fmt='.2f', linewidths=0.5, center=0)
plt.title("Correlation Matrix")
plt.show()

## **Fit odor_labels separately**

In [None]:
# Reload the summary table and keep control-epoch rows with visit_number above 1.
summary_df = load()
features = ['reward_probability', 'consecutive_failures', 'visit_number', 'cumulative_rewards', 'active_patch']

epoch = 'control'
keep_rows = (summary_df.experiment == epoch) & (summary_df.visit_number > 1)
summary_df = summary_df.loc[keep_rows]

In [None]:
# Fit the per-session model separately for each odor and stack the results,
# tagging every row with its odor label.
cum_weights_df = pd.DataFrame(columns=['regressors', 'weights'])
cum_cv_results_df = pd.DataFrame()

features = [
    'reward_probability',
    'consecutive_failures',
    'visit_number',
    'cumulative_rewards',
    'active_patch',
]

for odor_label in summary_df['odor_label'].unique():
    odor_rows = summary_df['odor_label'] == odor_label
    weights_df, cv_results_df, metrics_df, new_mouse_df = logistic_session(
        summary_df.loc[odor_rows],
        use_polynomial_features=False,
        orig_features=features,
    )
    weights_df['odor_label'] = odor_label
    cv_results_df['odor_label'] = odor_label

    if 'session_n' in features:
        # Drop every row in which any column, read as text, starts with 'session_'
        has_session_prefix = weights_df.apply(
            lambda column: column.astype(str).str.startswith('session_')
        ).any(axis=1)
        weights_df = weights_df[~has_session_prefix]

    cum_weights_df = pd.concat([cum_weights_df, weights_df], ignore_index=True)
    cum_cv_results_df = pd.concat([cum_cv_results_df, cv_results_df], ignore_index=True)

In [None]:
# Mean cross-validated score per odor (displayed as the cell output)
cum_cv_results_df.groupby('odor_label')['cv_score'].mean()

In [None]:
from scipy.stats import ks_2samp

# Compare the two odors using the results accumulated across all per-odor
# fits (`cum_cv_results_df`). `cv_results_df` is reassigned on every iteration
# of the fitting loop and therefore only holds a single odor, which would
# leave one of the two samples below empty.
alpha_pinene_results = cum_cv_results_df[cum_cv_results_df['odor_label'] == 'Alpha-pinene']
methyl_butyrate_results = cum_cv_results_df[cum_cv_results_df['odor_label'] == 'Methyl Butyrate']

# Cross-validation scores per odor (missing values removed)
alpha_pinene_scores = alpha_pinene_results['cv_score'].dropna()
methyl_butyrate_scores = methyl_butyrate_results['cv_score'].dropna()

# Cross-validation standard deviations per odor (missing values removed)
alpha_pinene_std = alpha_pinene_results['cv_std'].dropna()
methyl_butyrate_std = methyl_butyrate_results['cv_std'].dropna()

# Two-sample Kolmogorov-Smirnov tests between the odors
ks_score_result = ks_2samp(alpha_pinene_scores, methyl_butyrate_scores)
ks_std_result = ks_2samp(alpha_pinene_std, methyl_butyrate_std)

print(f"KS test for cv_score: statistic={ks_score_result.statistic}, p-value={ks_score_result.pvalue}")
print(f"KS test for cv_std: statistic={ks_std_result.statistic}, p-value={ks_std_result.pvalue}")

# Histograms of both metrics, coloured by odor
fig, ax = plt.subplots(1, 2, figsize=(10, 5))
# NOTE(review): `scoring` is not defined in this cell; it is expected to exist
# from an earlier cell of the notebook - confirm.
plt.suptitle(scoring)
sns.histplot(data=cum_cv_results_df, x='cv_score',  bins=30, hue='odor_label', ax=ax[0], legend=False)
sns.histplot(data=cum_cv_results_df, x='cv_std',  bins=30, hue='odor_label',  ax=ax[1])
plt.tight_layout()
sns.despine()

In [None]:
# One panel per odor: each point is one mouse's mean weight for a regressor,
# with stars marking regressors whose across-mouse weights differ from zero.
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# p-value cut-offs ordered from strictest to loosest; the first one met gives the label
star_levels = ((0.001, '***'), (0.01, '**'), (0.05, '*'))

for odor_label, ax in zip(['Alpha-pinene', 'Methyl Butyrate'], axes.flatten()):
    # Average the weights within every (mouse, regressor) pair for this odor
    odor_rows = cum_weights_df.loc[cum_weights_df.odor_label == odor_label]
    aggregated_df = odor_rows.groupby(['mouse', 'regressors'], as_index=False).weights.mean()
    regressor_names = list(aggregated_df['regressors'].unique())

    # One-sample t-test against zero for each regressor
    t_test_results = []
    for regressor in regressor_names:
        regressor_weights = aggregated_df.loc[aggregated_df['regressors'] == regressor, 'weights']
        t_stat, p_value = ttest_1samp(regressor_weights, 0)
        # None when no cut-off is met
        significance = next((stars for cutoff, stars in star_levels if p_value < cutoff), None)
        t_test_results.append({'regressor': regressor, 'p_value': p_value, 'significance': significance})

    t_test_results_df = pd.DataFrame(t_test_results)

    # One point per mouse
    sns.swarmplot(data=aggregated_df, x='regressors', y='weights', hue='regressors',
                  palette='tab10', dodge=True, ax=ax)

    # Put the stars slightly above the largest weight of each significant regressor
    for result in t_test_results:
        if not result['significance']:
            continue
        star_x = regressor_names.index(result['regressor'])
        star_y = aggregated_df.loc[aggregated_df['regressors'] == result['regressor'], 'weights'].max() + 0.1
        ax.text(star_x, star_y, result['significance'], ha='center', va='bottom', fontsize=12, color='black')

    # Reference line at zero weight
    ax.axhline(0, color='black', linestyle='--')

    # Axis cosmetics; tick labels are blanked because the legend identifies the regressors
    ax.set_xlabel('')
    ax.set_xlim(-1, len(regressor_names))
    ax.set_ylabel('Weight')
    ax.set_title(f'Weights {odor_label}')
    ax.set_xticklabels([], rotation=45, ha='right')

# Build the legend by hand from the regressors of the last odor plotted
legend_regressors = aggregated_df['regressors'].unique()
legend_colors = sns.color_palette('tab10', len(legend_regressors))
handles = [
    mpatches.Patch(color=patch_color, label=regressor_name)
    for regressor_name, patch_color in zip(legend_regressors, legend_colors)
]
plt.legend(handles=handles, title='Features', loc='upper left', bbox_to_anchor=(1, 1))
plt.tight_layout()
sns.despine()
plt.show()

fig.savefig(os.path.join(results_path, f'weights_all_small_model_{epoch}.pdf'), bbox_inches='tight')

In [None]:
# Per-mouse regressor weights: one figure per odor, one panel per mouse
for odor_label in ['Alpha-pinene', 'Methyl Butyrate']:
    fig, axes = plt.subplots(4, 4, figsize=(20, 14), sharex=True)

    # Select this odor from the weights accumulated over all odors. The
    # `weights_df` variable left behind by the fitting loop only holds the
    # last odor that was fitted, so using it would leave the other odor's
    # figure without any data.
    odor_weights_df = cum_weights_df.loc[cum_weights_df.odor_label == odor_label]

    # One panel per mouse; zip stops after the 16 panels of the 4x4 grid
    for (mouse, group), ax in zip(odor_weights_df.groupby('mouse'), axes.flatten()):
        regressor_names = list(group['regressors'].unique())

        # One-sample t-test against zero for each regressor of this mouse
        significant_regressors = []
        for regressor in regressor_names:
            regressor_data = group[group['regressors'] == regressor]['weights']
            t_stat, p_value = ttest_1samp(regressor_data, 0)

            # Determine the significance level
            if p_value < 0.001:
                significance = '***'
            elif p_value < 0.01:
                significance = '**'
            elif p_value < 0.05:
                significance = '*'
            else:
                significance = None

            if significance:
                significant_regressors.append((regressor, regressor_data.max(), significance))

        # Plot the swarmplot
        sns.swarmplot(
            data=group,
            x='regressors',
            y='weights',
            palette='tab10',
            ax=ax,
            hue='regressors',
            dodge=True,
            legend=False
        )
        ax.set_title(f'Mouse {mouse}')
        ax.set_xlabel('')
        ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='right')
        # Reference line at zero weight spanning all regressor positions
        ax.hlines(0, -0.5, len(regressor_names) - 0.5, color='black', linestyle='--')

        # Annotate significant results just above the largest weight
        for regressor, max_value, significance in significant_regressors:
            x = regressor_names.index(regressor)
            y = max_value + 0.05  # Position above max value
            ax.text(x, y, significance, ha='center', va='bottom', fontsize=12, color='black')

    # Manually create the legend from every regressor in the accumulated weights
    legend_regressors = cum_weights_df['regressors'].unique()
    handles = []
    for regressor, color in zip(legend_regressors, sns.color_palette('tab10', len(legend_regressors))):
        handles.append(mpatches.Patch(color=color, label=regressor))

    # Add legend at the bottom with 3 columns
    fig.legend(
        handles=handles,
        bbox_to_anchor=(0.6, 0.05),  # Anchor point in figure coordinates
        loc='upper center',
        ncol=3,  # Number of columns
        title='Features',
        prop={'size': 12}
    )

    sns.despine()
    plt.tight_layout()
    plt.subplots_adjust()  # No arguments are passed, so no spacing value is overridden here
    plt.xticks(rotation=45, ha='right')
    plt.show()
    fig.savefig(os.path.join(results_path, f'weights_per_mouse_small_model_{epoch}_{odor_label}.pdf'), bbox_inches='tight')