# Stretch and Squeeze Psychophysics Data Analysis and Figure Generation

This notebook generates figure panels and performs statistical analyses for the psychophysics experiments in the Stretch and Squeeze paper. 

A de-identified version of all experimental data is being released for this project. Set the variable `DEIDENTIFIED_DATA` to `True` to run the analysis on the de-identified dataset. Setting it to `False` is intended for authors who have access to the raw source data, and additionally allows the demographic tables to be generated.

In [None]:
# Selects the data source for the whole notebook: True analyzes the released
# de-identified dataset; False is for authors with access to the raw source
# data and allows the demographic tables to be generated.
DEIDENTIFIED_DATA = True

In [None]:
import xarray as xr
import pandas as pd
import numpy as np
import scipy.stats as stats
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib
import re
import os

# Font type 42 (TrueType) keeps text editable in exported PDF and PostScript figures
matplotlib.rcParams['pdf.fonttype'] = 42
matplotlib.rcParams['ps.fonttype'] = 42

def get_df_from_xarray(data_paths, drop_columns=None):
    """Load one or more xarray datasets into a single one-row-per-trial DataFrame.

    For every trial only the row of the chosen slot is kept
    (``choice_slot == i_choice``); trials without a response (``i_choice`` is
    NaN) are represented by their ``choice_slot == 0`` row. Participant ids of
    each file are shifted by the number of participants kept from the files
    before it, so that ids do not collide across files.
    NOTE(review): this offsetting presumes ids in each file run 0..n-1 -- confirm.

    Parameters:
    data_paths (iterable of str): Paths of the dataset files, in load order.
    drop_columns (list of str, optional): Columns to remove from the result;
        names that are not present are ignored.

    Returns:
    pd.DataFrame: Trials sorted by ('participant', 'obs'), with the core
        columns first followed by any remaining columns.

    Raises:
    AssertionError: If one of the paths is not an existing file.
    """
    ordered_cols = ['participant', 'condition_idx', 'block', 'obs', 'trial_type', 'class', 'stimulus_image_url', 'stimulus_name', 'choice_name', 'i_correct_choice', 'i_choice', 'perf', 'reaction_time_msec', 'rel_timestamp_response', 'timestamp_start', 'monitor_width_px', 'monitor_height_px', 'stimulus_width_px', 'choice_width_px', 'stimulus_duration_msec', 'post_stimulus_delay_duration_msec', 'pre_choice_lockout_delay_duration_msec']

    participant_offset = 0
    per_file_frames = []
    for data_path in data_paths:
        assert os.path.isfile(data_path), f"File {data_path} does not exist"

        # The context manager releases the file handle as soon as the data
        # has been materialized as a DataFrame
        with xr.open_dataset(data_path) as ds:
            raw_df = ds.to_dataframe().reset_index()

        chose_this_slot = raw_df['choice_slot'] == raw_df['i_choice']
        no_response = raw_df['i_choice'].isna() & (raw_df['choice_slot'] == 0)
        # .copy() gives an independent frame, so the participant renumbering
        # below never writes into a slice of raw_df
        df = raw_df[chose_this_slot | no_response].copy()

        df["participant"] = participant_offset + df["participant"]
        participant_offset += df["participant"].nunique()
        per_file_frames.append(df)

    combined_df = pd.concat(per_file_frames, axis=0).reset_index(drop=True)

    if drop_columns is not None:
        combined_df = combined_df.drop(drop_columns, axis=1, errors='ignore')

    # Sort dataframe such that trials for each participant appear in order
    combined_df = combined_df.sort_values(by=['participant', 'obs'])

    # Core columns first, everything else afterwards in its existing order
    other_cols = [col for col in combined_df.columns if col not in ordered_cols]
    combined_df = combined_df[ordered_cols + other_cols]

    return combined_df


# Percentile-bootstrap confidence interval for the mean of a sample
def bootstrap_ci(data, num_bootstrap_samples=10000, confidence_level=0.95):
    """Return the [lower, upper] percentile-bootstrap bounds for the mean of `data`.

    `data` is resampled with replacement (same size as the original sample)
    `num_bootstrap_samples` times using numpy's global random state; the
    bounds are the percentiles of the resampled means that enclose
    `confidence_level` of them.
    """
    sample_size = len(data)
    resampled_means = []
    for _ in range(num_bootstrap_samples):
        resample = np.random.choice(data, size=sample_size, replace=True)
        resampled_means.append(np.mean(resample))

    lower_pct = (1 - confidence_level) / 2 * 100
    upper_pct = (1 + confidence_level) / 2 * 100
    return np.percentile(np.array(resampled_means), [lower_pct, upper_pct])


def perform_chi_square(df, condition, control=0):
    """Chi-square test of independence between condition and trial outcome.

    Only trials whose 'condition_idx' is `control` or `condition` are used.
    They are cross-tabulated against 'perf' and the table is passed to
    scipy's `chi2_contingency`.

    Returns:
    tuple: (chi2, p, dof, expected, contingency_table)
    """
    groups_to_compare = [control, condition]
    in_comparison = df['condition_idx'].isin(groups_to_compare)
    trials = df[in_comparison]

    # Rows: condition index, columns: observed 'perf' values
    contingency_table = pd.crosstab(trials['condition_idx'], trials['perf'])

    chi2, p, dof, expected = stats.chi2_contingency(contingency_table)
    return chi2, p, dof, expected, contingency_table


def chi_square_comparisons(df_trials_test, condition_idx_ordering, condition_labels, control_condition_idx=0):
    """Print a chi-square comparison of every non-control condition against control.

    `condition_labels[i]` is the display name of `condition_idx_ordering[i]`.
    For each condition other than `control_condition_idx` the test statistic,
    p-value, degrees of freedom, a verdict at alpha = 0.05, the contingency
    table and the expected frequencies are printed. Nothing is returned.
    """
    control_position = condition_idx_ordering.index(control_condition_idx)
    control_group_name = condition_labels[control_position]
    separator = "\n" + "="*50

    for label_idx, condition_idx in enumerate(condition_idx_ordering):
        if condition_idx == control_condition_idx:
            continue  # the control is never compared with itself

        label = condition_labels[label_idx]
        chi2, p, dof, expected, contingency_table = perform_chi_square(
            df_trials_test, condition_idx, control=control_condition_idx
        )

        print(f"\nComparing {label} to {control_group_name}:")
        print(f"Chi-square value: {chi2:.4f}")
        print(f"P-value: {p:.4f}")
        print(f"Degrees of freedom: {dof}")

        # Verdict at the conventional 0.05 threshold
        if not (p < 0.05):
            print(f"There is no significant difference in performance between {label} and {control_group_name}.")
        else:
            print(f"There is a significant difference in performance between {label} and {control_group_name}.")

        print("\nContingency Table:")
        print(contingency_table)

        # Expected counts, labelled like the observed table
        expected_table = pd.DataFrame(
            expected,
            index=contingency_table.index,
            columns=contingency_table.columns,
        )
        print("\nExpected frequencies:")
        print(expected_table)

        print(separator)


def _summarize_elapsed_minutes(df_trials, condition_idx_ordering, condition_labels, mean_column):
    """Per-condition mean elapsed time (minutes) with a bootstrap CI over participants.

    A participant's elapsed time is the maximum 'rel_timestamp_response' among
    their rows in `df_trials`; the value is divided by 1000*60 to report
    minutes (i.e. it is treated as milliseconds).
    """
    elapsed = df_trials.groupby('participant')['rel_timestamp_response'].max().reset_index()

    # Attach each participant's condition using ONE row per participant.
    # Merging against the trial-level table instead would repeat every
    # participant once per trial, so the bootstrap would resample duplicated
    # values and report confidence intervals that are far too narrow.
    participant_conditions = df_trials[['participant', 'condition_idx']].drop_duplicates()
    elapsed = elapsed.merge(participant_conditions, on='participant', how='left')

    msec_per_minute = 1000 * 60
    rows = []
    for c_idx, condition in enumerate(condition_idx_ordering):
        # Skip empty placeholders (e.g. None) while still accepting condition index 0
        if condition == 0 or condition:
            condition_data = elapsed[elapsed['condition_idx'] == condition]['rel_timestamp_response']
            ci_lower, ci_upper = bootstrap_ci(condition_data)
            mean = np.mean(condition_data)
            rows.append({
                'condition': condition_labels[c_idx],
                mean_column: round(mean / msec_per_minute, 4),
                'ci_lower': round(ci_lower / msec_per_minute, 4),
                'ci_upper': round(ci_upper / msec_per_minute, 4),
            })

    return pd.DataFrame(rows)


def print_main_stats(df_trials, condition_idx_ordering, condition_labels, chance_level=0.25, test_blocks=(8, 9)):
    """Print and return the headline statistics of an experiment.

    Three summaries are printed:
      1. Test-block accuracy per condition with a 95% bootstrap CI, plus the
         percentage increase of each condition's margin above chance relative
         to the control condition.
      2. Training time per condition (trials outside `test_blocks`).
      3. Completion time per condition (all trials).

    Parameters:
    df_trials (pd.DataFrame): One row per trial; uses the columns
        'participant', 'condition_idx', 'block', 'perf' and
        'rel_timestamp_response'.
    condition_idx_ordering (sequence): Condition indices in reporting order.
        The FIRST entry is treated as the control condition.
    condition_labels (sequence of str): Display names, parallel to
        `condition_idx_ordering`.
    chance_level (float): Accuracy expected from guessing.
    test_blocks (sequence): Block numbers that make up the test phase.

    Returns:
    tuple: (condition_acc_data, training_time_results_df, completion_time_results_df)
    """
    test_blocks = list(test_blocks)

    condition_accuracies = []
    condition_cis = []
    for cond in condition_idx_ordering:
        cond_df = df_trials[(df_trials["condition_idx"] == cond) & (df_trials["block"].isin(test_blocks))]
        condition_accuracies.append(cond_df['perf'].mean())
        condition_cis.append(bootstrap_ci(cond_df['perf']))

    # Create a DataFrame for plotting
    condition_acc_data = pd.DataFrame({
        'Condition': condition_labels,
        'Accuracy': condition_accuracies,
        'CI_lower': [ci[0] for ci in condition_cis],
        'CI_upper': [ci[1] for ci in condition_cis],
    })

    # Error-bar lengths (distance from the mean to each CI bound)
    condition_acc_data['yerr_lower'] = condition_acc_data['Accuracy'] - condition_acc_data['CI_lower']
    condition_acc_data['yerr_upper'] = condition_acc_data['CI_upper'] - condition_acc_data['Accuracy']

    print("Condition accuracies (mean with 95% CIs):")
    for condition, accuracy, ci in zip(condition_acc_data['Condition'], condition_acc_data['Accuracy'], condition_cis):
        print(f"{condition}: {accuracy:.2f} ({ci[0]:.2f}, {ci[1]:.2f})")

    # Control accuracy (assumed to be the first condition)
    control_accuracy = condition_acc_data['Accuracy'].iloc[0]
    margin_control = control_accuracy - chance_level

    # Relative gain of each condition's margin above chance over the control's margin
    print("\nPercentage increase in margin above chance:")
    for condition, accuracy in zip(condition_acc_data['Condition'][1:], condition_acc_data['Accuracy'][1:]):  # Skip the first (control) condition
        margin_condition = accuracy - chance_level
        percentage_increase = ((margin_condition - margin_control) / margin_control) * 100

        print("Condition acc:", accuracy, "| Control acc:", control_accuracy)
        print("Margin condition:", margin_condition, "| Margin control:", margin_control)
        print(f"{condition}: {percentage_increase:.1f}%")

    # Training time: elapsed time up to the last response outside the test blocks
    df_trials_train = df_trials[~df_trials["block"].isin(test_blocks)]
    training_time_results_df = _summarize_elapsed_minutes(
        df_trials_train, condition_idx_ordering, condition_labels, 'mean_training_time'
    )
    print("Training times (minutes):")
    print(training_time_results_df)

    # Completion time: elapsed time up to the last response of the whole session
    completion_time_results_df = _summarize_elapsed_minutes(
        df_trials, condition_idx_ordering, condition_labels, 'mean_completion_time'
    )
    print("Completion times (minutes):")
    print(completion_time_results_df)

    return condition_acc_data, training_time_results_df, completion_time_results_df


def assert_constant_counts(df):
    """Report whether tracked trial types occur equally often in every trial set.

    Trials are counted per ('experiment_id', 'trialset_id') and trial type.
    For 'calibration' and 'repeat_stimulus' a message states whether the count
    is identical in all trial sets. Despite the name, nothing is asserted:
    the function only prints and returns None.
    """
    print("Unique trial types in the dataset:", df['trial_type'].unique())

    # One row per trial set, one column per trial type (0 where a type is absent)
    per_trialset = df.groupby(['experiment_id', 'trialset_id'])['trial_type'].value_counts().unstack(fill_value=0)

    for trial_type in ('calibration', 'repeat_stimulus'):
        if trial_type not in per_trialset.columns:
            print(f"\nWarning: Trial type '{trial_type}' is not present in counts DataFrame")
            print("This might indicate an issue with data processing")
            continue

        type_counts = per_trialset[trial_type]
        if type_counts.nunique() == 1:
            print(f"\nCount of {trial_type} is constant ({type_counts.iloc[0]}) across all trialset_ids")
        else:
            print(f"\nWarning: Count of {trial_type} is not constant across all trialset_ids")
            print(f"Unique counts for {trial_type}: {type_counts.unique()}")

    print("\nAssertion check completed.")


def reassign_blocks(df, verbose=True):
    """
    Recompute 'block' from 'obs' for participants who saw shuffled trials.

    A participant is affected if any of their trial_type values contains
    'shuffle'. For those participants every trial's block is derived from its
    'obs' value using a fixed layout: block 0 holds 18 trials, blocks 1-7 hold
    19 trials each and blocks 8-9 hold 25 trials each; 'obs' values beyond the
    layout fall into the last block. Afterwards the resulting block sizes are
    checked and the outcome is printed.

    Parameters:
    df (pd.DataFrame): Input dataframe containing columns 'participant', 'trial_type', 'obs', and 'block'
    verbose (bool): If True, also print the per-participant/per-block confirmations
        and the summary header; warnings are printed regardless.

    Returns:
    pd.DataFrame: Copy of `df` with updated block values (the input is not modified)
    """
    result = df.copy()

    has_shuffle = result['trial_type'].str.contains('shuffle', na=False)
    shuffle_participants = result.loc[has_shuffle, 'participant'].unique()

    # Number of trials in each block, in block order
    block_structure = {0: 18}
    block_structure.update({i: 19 for i in range(1, 8)})
    block_structure.update({i: 25 for i in range(8, 10)})

    print(block_structure)

    # Exclusive upper 'obs' bound of every block (running total of block sizes)
    upper_bounds = []
    running_total = 0
    for size in block_structure.values():
        running_total += size
        upper_bounds.append(running_total)
    last_block = len(block_structure) - 1

    def get_block(obs):
        # First block whose upper bound lies above obs; otherwise the last block
        return next((block_num for block_num, bound in enumerate(upper_bounds) if obs < bound), last_block)

    for participant in shuffle_participants:
        rows = result['participant'] == participant
        # The assignment aligns on the index, so each trial receives the block of its own obs
        new_blocks = result.loc[rows, 'obs'].sort_values().apply(get_block)
        result.loc[rows, 'block'] = new_blocks

    print("\nVerifying block sizes for each participant:")
    if verbose:
        print("------------------------------------------")

    all_correct = True
    for participant in shuffle_participants:
        participant_blocks = result.loc[result['participant'] == participant, 'block']

        for block, expected_trials in block_structure.items():
            block_trials = int((participant_blocks == block).sum())

            if block_trials == expected_trials:
                if verbose:
                    print(f"Participant {participant} has correct number of trials ({expected_trials}) in block {block}")
            else:
                print(f"WARNING: Participant {participant} has {block_trials} trials in block {block} (expected {expected_trials})")
                all_correct = False

    if not all_correct:
        print("\nVERIFICATION FAILED: Some participants have incorrect numbers of trials in certain blocks.")
    else:
        print("\nVERIFICATION PASSED: All participants have the correct number of trials in each block!")

    # Aggregate check over all affected participants
    if verbose:
        print("\nSummary across all participants with shuffled trials:")
        print("--------------------------------------------------")
    num_participants = len(shuffle_participants)
    for block, expected_trials in block_structure.items():
        in_block = result['block'] == block
        total_trials = sum(
            int(((result['participant'] == p) & in_block).sum())
            for p in shuffle_participants
        )
        if total_trials != expected_trials * num_participants:
            print(f"✗ Block {block}: Expected {expected_trials} trials per participant, "
                  f"found {total_trials/num_participants:.1f} on average")
        elif verbose:
            print(f"✓ Block {block}: All participants have exactly {expected_trials} trials")

    return result

In [None]:
# The notebook is expected to be launched from inside "notebooks"; move up to
# its parent directory (the one that contains "notebooks") so that the
# relative paths used below resolve.
changed_dir = False
notebook_in_cwd = os.path.exists("./make_figs.ipynb")
if notebook_in_cwd and not changed_dir:
    parent_dir = os.path.dirname(os.getcwd())
    os.chdir(parent_dir)
    changed_dir = True
assert os.path.exists("./notebooks/make_figs.ipynb"), "Make sure your working directory starts in 'notebooks'"

# Output folder for the generated figure panels
os.makedirs("notebooks/fig_outputs", exist_ok=True)

In [None]:
# Columns removed from every loaded dataset
drop_columns = [
    "stimulus_image_url_l",
    "stimulus_image_url_r",
    "class_l",
    "class_r",
    "mask_duration_msec",
    "mask_image_url",
    "choice_slot",
    "choice_image_urls",
    "keep_stimulus_on",
    "query_string",
    "platform",
    "bonus_usd_if_correct",
]
# The de-identified release additionally drops the assignment and worker ids
if DEIDENTIFIED_DATA:
    drop_columns += ["assignment_id", "worker_id"]

## ImageNet 12-way object classification experiments

In [None]:
def load_imagenet12_dataset(data_path, RUN_TESTS=True, drop_columns=None):
    """Load an ImageNet-12 experiment dataset as a one-row-per-trial DataFrame.

    Only the row of the chosen slot is kept for each trial
    (``choice_slot == i_choice``). Rows are sorted by ('participant', 'obs'),
    the core columns are moved to the front, and a 'split' column is added
    holding the first path component that follows ".s3.amazonaws.com/" in
    'stimulus_image_url'.

    Parameters:
    data_path (str): Path of the dataset file opened with xarray.
    RUN_TESTS (bool): If True, verify that 'stimulus_name' equals
        'choice_name' exactly on the trials where 'i_choice' equals
        'i_correct_choice'.
    drop_columns (list of str, optional): Columns to remove from the result;
        names that are not present are ignored.

    Returns:
    pd.DataFrame: The trial-level data.

    Raises:
    AssertionError: If RUN_TESTS is True and a trial fails the consistency check.
    IndexError: If a stimulus URL does not contain ".s3.amazonaws.com/".
    """
    # The context manager releases the file handle as soon as the data has
    # been materialized as a DataFrame
    with xr.open_dataset(data_path) as ds:
        raw_df = ds.to_dataframe().reset_index()

    # Filter rows where choice_slot equals i_choice
    df = raw_df[raw_df['choice_slot'] == raw_df['i_choice']].copy()

    # Sort dataframe such that trials for each participant appear in order
    df = df.sort_values(by=['participant', 'obs'])

    # Core columns first, everything else afterwards in its existing order
    ordered_cols = ['participant', 'condition_idx', 'block', 'obs', 'trial_type', 'class', 'stimulus_image_url', 'stimulus_name', 'choice_name', 'i_correct_choice', 'i_choice', 'perf', 'reaction_time_msec', 'rel_timestamp_response', 'timestamp_start', 'monitor_width_px', 'monitor_height_px', 'stimulus_width_px', 'choice_width_px', 'stimulus_duration_msec', 'post_stimulus_delay_duration_msec', 'pre_choice_lockout_delay_duration_msec']
    other_cols = [col for col in df.columns if col not in ordered_cols]
    df = df[ordered_cols + other_cols]

    # Recover info about whether each image was from train, val, etc.
    # (plain iteration over the URL column avoids a slow row-wise DataFrame.apply)
    df["split"] = [
        url.split(".s3.amazonaws.com/")[1].split("/")[0]
        for url in df['stimulus_image_url']
    ]

    if RUN_TESTS:
        # Sanity check that will almost certainly fail if stimulus_name and
        # choice_name are calculated incorrectly: the names must match on
        # exactly those trials where the chosen slot is the correct slot.
        chose_correct = df["i_choice"] == df["i_correct_choice"]
        names_match = df["stimulus_name"] == df["choice_name"]
        inconsistent = chose_correct != names_match
        if inconsistent.any():
            first_bad = df.loc[inconsistent].iloc[0]
            raise AssertionError(f"stim={first_bad['stimulus_name']}, choice={first_bad['choice_name']}")

    if drop_columns is not None:
        df = df.drop(drop_columns, axis=1, errors='ignore')

    return df

In [None]:
# Cached, de-identified exports of the two experiment versions
csv_path_v1_2 = "psych_data/df_main_6ani_6nonani_v1_2.csv"
csv_path_v1_4 = "psych_data/df_main_6ani_6nonani_v1_4.csv"

# The cached .csv files are only used in de-identified mode, and only if both exist
use_cached_csv = DEIDENTIFIED_DATA and all(os.path.isfile(path) for path in (csv_path_v1_2, csv_path_v1_4))

if not use_cached_csv:  # Load from raw .h5
    print("Reading datasets from .h5 files")

    df_main_6ani_6nonani_v1_2 = load_imagenet12_dataset(
        "./psych_data/6ani_6nonani_v1_2_combined_dataset.h5", drop_columns=drop_columns
    )
    df_main_6ani_6nonani_v1_2["experiment_id"] = "6ani_6nonani_v1_2"

    df_main_6ani_6nonani_v1_4 = load_imagenet12_dataset(
        "./psych_data/6ani_6nonani_v1_4_combined_dataset.h5", drop_columns=drop_columns
    )
    df_main_6ani_6nonani_v1_4["experiment_id"] = "6ani_6nonani_v1_4"

    if DEIDENTIFIED_DATA:
        # Cache the de-identified tables so later runs can skip the .h5 files
        print("Saving de-identified version of the dataset")
        df_main_6ani_6nonani_v1_2.to_csv(csv_path_v1_2, index=False)
        df_main_6ani_6nonani_v1_4.to_csv(csv_path_v1_4, index=False)
else:
    print("Reading datasets from saved .csv")
    df_main_6ani_6nonani_v1_2 = pd.read_csv(csv_path_v1_2)
    df_main_6ani_6nonani_v1_4 = pd.read_csv(csv_path_v1_4)

In [None]:
import re

# Captures the text between "unit_[<int>]_" and the trailing "_<int>.png" of a
# stimulus filename (non-greedy, so the final "_<int>.png" is left out)
_SPLIT_NAME_IN_FILENAME = re.compile(r'unit_\[\d+\]_(.*?)_\d+\.png$')


def reconstruct_split_name(url):
    """
    Build a split name of the form "base_folder-filename_part" from a stimulus URL.

    The base folder is the 4th '/'-separated component of the URL (the first
    folder after the host in an "https://host/..." URL); the filename part is
    the text between "unit_[<int>]_" and the final "_<int>.png" of the last
    component. Spaces in both parts are replaced with underscores.

    Example Input: "https://.../vanilla resnet50/unit [805]/Stretch in conv25/unit_[805]_Stretch in conv25_1.png"
    Example Output: "vanilla_resnet50-Stretch_in_conv25"

    Parameters:
    url: The stimulus image URL.

    Returns:
    str or None: The reconstructed name, or None if `url` is not a string
        (this includes None and NaN), has fewer than four '/'-separated
        components, or its filename does not follow the expected pattern.
    """
    # A single type check covers missing values too: None and NaN are not str
    if not isinstance(url, str):
        return None

    parts = url.split('/')
    if len(parts) <= 3:
        return None  # no base folder present

    match = _SPLIT_NAME_IN_FILENAME.search(parts[-1])
    if match is None:
        return None  # filename does not follow the expected pattern

    base_folder = parts[3].replace(' ', '_')
    filename_part = match.group(1).replace(' ', '_')
    return f"{base_folder}-{filename_part}"
    
#df_main_6ani_6nonani_v1_2['split_recon'] = df_main_6ani_6nonani_v1_2['stimulus_image_url'].apply(reconstruct_split_name)

In [None]:
# Use the v1_4 experiment as the main dataset for the analyses below.
# NOTE: this binds a second name to the same DataFrame (no copy is made), so
# columns added to df_main_6ani_6nonani also appear on df_main_6ani_6nonani_v1_4.
df_main_6ani_6nonani = df_main_6ani_6nonani_v1_4

In [None]:
# Add a 'split_recon' column: trials whose 'split' is 'natural' keep that label,
# every other trial gets a name reconstructed from its stimulus image URL.
condition_not_natural = df_main_6ani_6nonani['split'].ne('natural')

# Start with 'natural' everywhere ...
df_main_6ani_6nonani['split_recon'] = 'natural'

# ... then overwrite only the non-natural rows. The reconstructed Series keeps
# the index of those rows, so the .loc assignment lines the values up by index.
# NOTE(review): URLs that reconstruct_split_name cannot parse yield None here.
non_natural_urls = df_main_6ani_6nonani.loc[condition_not_natural, 'stimulus_image_url']
reconstructed_names = non_natural_urls.apply(reconstruct_split_name)
df_main_6ani_6nonani.loc[condition_not_natural, 'split_recon'] = reconstructed_names

df_main_6ani_6nonani

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns # Often used for nicer default plot styles

split_col = "split_recon"

# --- Missing 'perf' values are counted as incorrect (0) ---
print("Original 'perf' value counts (including NaN):")
print(df_main_6ani_6nonani['perf'].value_counts(dropna=False))
df_main_6ani_6nonani['perf_cleaned'] = df_main_6ani_6nonani['perf'].fillna(0)
print("\n'perf_cleaned' value counts (NaN replaced with 0):")
print(df_main_6ani_6nonani['perf_cleaned'].value_counts(dropna=False))

# --- Accuracy and 95% bootstrap CI for every split ---
perf_by_split = df_main_6ani_6nonani.groupby(split_col)['perf_cleaned']
mean_perf = perf_by_split.mean()
ci_perf = perf_by_split.apply(
    lambda x: bootstrap_ci(x, num_bootstrap_samples=10000, confidence_level=0.95)
)

# One row per split: mean plus the two CI bounds as separate columns
summary_df = pd.DataFrame({
    'mean': mean_perf,
    'ci_lower': ci_perf.apply(lambda bounds: bounds[0]),
    'ci_upper': ci_perf.apply(lambda bounds: bounds[1]),
})

# --- Row order: 'natural' first (when present), remaining splits alphabetically ---
split_names = summary_df.index.unique().tolist()
has_natural = 'natural' in split_names
if has_natural:
    split_names.remove('natural')
sorted_splits = (['natural'] if has_natural else []) + sorted(split_names)

# Apply the order and turn the split index into a regular column for plotting
summary_df = summary_df.loc[sorted_splits].reset_index()

print("\nSummary Statistics (Sorted):")
print(summary_df)

# --- Bar plot with asymmetric error bars ---
# yerr takes [lower distances, upper distances] measured from the bar height
lower_error = summary_df['mean'] - summary_df['ci_lower']
upper_error = summary_df['ci_upper'] - summary_df['mean']
error_bars = [lower_error, upper_error]

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(
    summary_df[split_col],
    summary_df['mean'],
    yerr=error_bars,
    capsize=5,
    color='skyblue',
    edgecolor='black',
)

ax.set_xlabel('Split Condition', fontsize=12)
ax.set_ylabel('Mean Performance (Accuracy)', fontsize=12)
ax.set_title('Mean Performance by Split with 95% Bootstrap CI (N=10,000)', fontsize=14)

# Slanted tick labels avoid overlap; accuracy lives in [0, 1]
plt.xticks(rotation=45, ha='right')
ax.set_ylim(0, 1.05)
ax.grid(axis='y', linestyle='--', alpha=0.7)

# Print each mean just above its bar
for bar, mean_val in zip(bars, summary_df['mean']):
    yval = bar.get_height()
    label_x = bar.get_x() + bar.get_width()/2.0
    ax.text(label_x, yval + 0.02, f'{mean_val:.3f}', va='bottom', ha='center', fontsize=9)

plt.tight_layout()
plt.show()