# Stretch and Squeeze Psychophysics Data Analysis and Figure Generation

This notebook generates figure panels and performs statistical analyses for the psychophysics experiments in the Stretch and Squeeze paper. 

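A de-identified version of all experimental data is being released for this project. Set the boolean `DEIDENTIFIED_DATA` to `True` (the default) to run the analysis on the de-identified dataset; in this mode the participant-identifying columns (`assignment_id`, `worker_id`) are dropped when the data are loaded. Authors who have access to the raw source data can set it to `False` to keep those columns and generate demographic tables.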

In [1]:
# True: analyse the released de-identified data (identifying columns are dropped on load).
# False: raw source data with identifying columns kept (only for authors with access).
DEIDENTIFIED_DATA: bool = True

In [2]:
import xarray as xr
import pandas as pd
import numpy as np
import seaborn as sns
import matplotlib
import re
import os
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

# Embed fonts as TrueType (Type 42) rather than matplotlib's default Type 3 so that
# text in exported PDF/PS figures stays editable in vector-graphics editors.
matplotlib.rcParams.update({'pdf.fonttype': 42, 'ps.fonttype': 42})


# Function to calculate bootstrap confidence intervals
def bootstrap_ci(data, num_bootstrap_samples=10000, confidence_level=0.95, rng=None):
    """Percentile-bootstrap confidence interval for the mean of ``data``.

    Parameters
    ----------
    data : 1-D array-like
        The sample (e.g. per-trial 0/1 accuracies).
    num_bootstrap_samples : int, default 10000
        Number of resamples; each is drawn with replacement and has ``len(data)`` elements.
    confidence_level : float, default 0.95
        Two-sided coverage of the interval (0.95 -> 2.5th and 97.5th percentiles).
    rng : None, int, numpy.random.Generator, optional
        Source of randomness. ``None`` (default) draws from NumPy's global RNG,
        exactly as before; pass a seed or a Generator for reproducible intervals.

    Returns
    -------
    numpy.ndarray
        ``[lower, upper]`` percentiles of the bootstrap distribution of the mean,
        or ``[nan, nan]`` when ``data`` is empty.
    """
    values = np.asarray(data)
    if values.size == 0:
        # Nothing to resample; avoid emitting one "mean of empty slice" warning per resample.
        return np.array([np.nan, np.nan])

    # np.random (module) and a Generator both expose choice(a, size, replace).
    sampler = np.random if rng is None else np.random.default_rng(rng)
    bootstrap_means = np.array([np.mean(sampler.choice(values, size=values.size, replace=True))
                                for _ in range(num_bootstrap_samples)])
    return np.percentile(bootstrap_means, [(1 - confidence_level) / 2 * 100, (1 + confidence_level) / 2 * 100])

In [3]:
# Work from the repository root (the parent of "notebooks") so the relative data and
# output paths used below resolve. The guard is the filesystem itself: change directory
# only when we are still inside "notebooks" (make_figs.ipynb is here) and not already at
# the root (notebooks/make_figs.ipynb is reachable), so re-running this cell is safe.
# (The previous in-cell flag was reset to False right before it was checked, so it never
# guarded anything.)
changed_dir = False
if not os.path.exists("./notebooks/make_figs.ipynb") and os.path.exists("./make_figs.ipynb"):
    os.chdir(os.path.dirname(os.getcwd()))
    changed_dir = True
assert os.path.exists("./notebooks/make_figs.ipynb"), "Make sure your working directory starts in 'notebooks'"

os.makedirs("notebooks/fig_outputs", exist_ok=True)

In [4]:
# Columns removed from every loaded dataset (passed to load_imagenet12_dataset below).
drop_columns = [
    "stimulus_image_url_l", "stimulus_image_url_r",
    "class_l", "class_r",
    "mask_duration_msec", "mask_image_url",
    "choice_slot", "choice_image_urls",
    "keep_stimulus_on", "query_string", "platform", "bonus_usd_if_correct",
]
if DEIDENTIFIED_DATA:
    # Also drop participant identifiers when producing the de-identified data.
    drop_columns += ["assignment_id", "worker_id"]

## Data loading and preprocessing

In [5]:
def load_imagenet12_dataset(data_path, RUN_TESTS=True, drop_columns=None):
  """Load one combined psychophysics dataset file as a trial-level DataFrame.

  Parameters
  ----------
  data_path : str
      Path to the combined dataset file (opened with ``xarray.open_dataset``).
  RUN_TESTS : bool, default True
      If True, verify that ``stimulus_name == choice_name`` on every trial where
      ``i_choice == i_correct_choice`` and that they differ on every other trial.
  drop_columns : list of str, optional
      Columns removed before returning; names not present are ignored.

  Returns
  -------
  pandas.DataFrame
      Rows with ``choice_slot == i_choice`` (one row per trial), sorted by
      ``participant`` then ``obs``, with a fixed set of leading columns and an
      added ``split`` column taken from the stimulus URL.

  Raises
  ------
  AssertionError
      If ``data_path`` is not a file, or the RUN_TESTS consistency check fails.
  """
  assert os.path.isfile(data_path), "Data file not found: " + data_path

  # to_dataframe() yields one row per combination of the dataset's dimensions and
  # reads the values into memory, so the file can be closed right afterwards
  # (the original code never closed it).
  with xr.open_dataset(data_path) as ds:
    raw_df = ds.to_dataframe().reset_index()

  # Keep only the row whose choice_slot equals the chosen index i_choice (1 row per trial).
  df = raw_df[raw_df['choice_slot'] == raw_df['i_choice']].copy()

  # Sort dataframe such that trials for each participant appear in order
  df = df.sort_values(by=['participant', 'obs'])

  # Put the key columns first, followed by all remaining columns in their existing order.
  ordered_cols = ['participant', 'condition_idx', 'block', 'obs', 'trial_type', 'class', 'stimulus_image_url', 'stimulus_name', 'choice_name', 'i_correct_choice', 'i_choice', 'perf', 'reaction_time_msec', 'rel_timestamp_response', 'timestamp_start', 'monitor_width_px', 'monitor_height_px', 'stimulus_width_px', 'choice_width_px', 'stimulus_duration_msec', 'post_stimulus_delay_duration_msec', 'pre_choice_lockout_delay_duration_msec']
  other_cols = [col for col in df.columns if col not in ordered_cols]
  df = df[ordered_cols + other_cols]

  # Recover info about whether each image was from train, val, etc.: the first path
  # component after the S3 host. Series.map is element-wise, avoiding a row-wise apply.
  df["split"] = df["stimulus_image_url"].map(
    lambda url: url.split(".s3.amazonaws.com/")[1].split("/")[0])

  if RUN_TESTS:
    # Sanity check that will almost certainly fail if stimulus_name and choice_name are
    # calculated incorrectly: names must agree on correct trials and differ on incorrect
    # ones. Vectorised comparison instead of a per-row iterrows() loop; the first
    # offending row (in sorted order) is reported, as before.
    is_correct = df["i_choice"] == df["i_correct_choice"]
    names_match = df["stimulus_name"] == df["choice_name"]
    inconsistent = df[is_correct != names_match]
    if not inconsistent.empty:
      first_bad = inconsistent.iloc[0]
      raise AssertionError("stim=" + str(first_bad["stimulus_name"]) + ", choice=" + str(first_bad["choice_name"]))

  if drop_columns is not None:
    df = df.drop(drop_columns, axis=1, errors='ignore')

  return df

In [None]:
# Load every dataset version. The saved de-identified CSVs are used when all of them
# exist (and DEIDENTIFIED_DATA is set); otherwise each version is rebuilt from its raw
# .h5 file and, in de-identified mode, the CSVs are written out for next time.
DATASET_VERSIONS = ["v1_2", "v1_4", "v1_6", "v2_0", "v2_1"]
csv_paths = {version: f"psych_data/df_main_6ani_6nonani_{version}.csv" for version in DATASET_VERSIONS}

if DEIDENTIFIED_DATA and all(os.path.isfile(path) for path in csv_paths.values()):
    print("Reading datasets from saved .csv")
    dfs_by_version = {version: pd.read_csv(path) for version, path in csv_paths.items()}
else:  # Load from raw .h5
    print("Reading datasets from .h5 files")
    dfs_by_version = {}
    for version in DATASET_VERSIONS:
        df_version = load_imagenet12_dataset(f"./psych_data/6ani_6nonani_{version}_combined_dataset.h5", drop_columns=drop_columns)
        df_version["experiment_id"] = f"6ani_6nonani_{version}"
        dfs_by_version[version] = df_version

    if DEIDENTIFIED_DATA:
        print("Saving de-identified version of the dataset")
        for version, df_version in dfs_by_version.items():
            df_version.to_csv(csv_paths[version], index=False)

# Expose each version under the per-version names used by later cells.
df_main_6ani_6nonani_v1_2 = dfs_by_version["v1_2"]
df_main_6ani_6nonani_v1_4 = dfs_by_version["v1_4"]
df_main_6ani_6nonani_v1_6 = dfs_by_version["v1_6"]
df_main_6ani_6nonani_v2_0 = dfs_by_version["v2_0"]
df_main_6ani_6nonani_v2_1 = dfs_by_version["v2_1"]

In [7]:
# Choose which version of the dataset to use: v2_1 is the version with 25 participants.
df_main_6ani_6nonani = df_main_6ani_6nonani_v2_1

In [8]:
# Preprocessing: keep only rows with block > 0, then remove participants with fewer than
# MIN_TRIALS_PER_PARTICIPANT distinct trials ('obs') among those rows (e.g., those who
# didn't pass the screening phase).
# NOTE(review): block 0 is presumably the screening/practice block — confirm.
MIN_TRIALS_PER_PARTICIPANT = 500

df_main_6ani_6nonani = df_main_6ani_6nonani[df_main_6ani_6nonani['block'] > 0]
unique_obs_counts_per_row = df_main_6ani_6nonani.groupby('participant')['obs'].transform('nunique')
# Explicit copy: later cells add columns to this frame, which must not be a view/slice
# of an earlier frame (avoids SettingWithCopyWarning / chained-assignment ambiguity).
df_main_6ani_6nonani = df_main_6ani_6nonani[unique_obs_counts_per_row >= MIN_TRIALS_PER_PARTICIPANT].copy()

In [9]:
# Preprocessing: Reconstruct split names

def reconstruct_split_name(url):
    """
    Build a condition ("split") label from a stimulus image URL.

    The label is "<folder>-<variant>", where
      * <folder> is the first path component after the host, and
      * <variant> is the text between "unit_[<int>]_" and the trailing "_<int>.png"
        of the filename;
    spaces in both pieces are replaced with underscores.

    Example Input: "https://.../vanilla resnet50/unit [805]/Stretch in conv25/unit_[805]_Stretch in conv25_1.png"
    Example Output: "vanilla_resnet50-Stretch_in_conv25"

    Returns None when the URL is missing or not a string, has no path component after
    the host, or its filename does not match the expected pattern.
    """
    if pd.isna(url) or not isinstance(url, str):
        return None

    try:
        segments = url.split('/')
        # For "scheme://host/..." the segments are ["scheme:", "", "host", <folder>, ...].
        if len(segments) <= 3:
            return None
        folder_label = segments[3].replace(' ', '_')

        # "unit_[<digits>]_" + captured text (non-greedy) + "_<digits>.png" at the end.
        name_match = re.search(r'unit_\[\d+\]_(.*?)_\d+\.png$', segments[-1])
        if name_match is None:
            return None
        variant_label = name_match.group(1).replace(' ', '_')

        return f"{folder_label}-{variant_label}"
    except Exception as e:
        # Best-effort: report and skip anything unexpected rather than abort the notebook.
        print(f"Error processing URL: {url} - Error: {e}")
        return None

# Rows whose 'split' is not 'natural' get a label rebuilt from their stimulus URL;
# every other row keeps 'natural'.
condition_not_natural = df_main_6ani_6nonani['split'].ne('natural')

# Reconstructed labels for the non-natural rows only, indexed like those rows.
reconstructed_names = df_main_6ani_6nonani.loc[condition_not_natural, 'stimulus_image_url'].map(reconstruct_split_name)

# Default all rows to 'natural', then overwrite the non-natural rows (aligned on index).
df_main_6ani_6nonani['split_recon'] = 'natural'
df_main_6ani_6nonani.loc[condition_not_natural, 'split_recon'] = reconstructed_names

In [10]:
# Preprocessing: handle missing values in 'perf'.
# Rows without a recorded 'perf' are dropped. (An alternative, scoring them as incorrect
# via perf.fillna(0), is not used.) The analyses below read the 'perf_cleaned' column.
df_main_6ani_6nonani = (
    df_main_6ani_6nonani
    .dropna(subset=['perf'])
    .assign(perf_cleaned=lambda frame: frame['perf'])
)


In [11]:
# Write the filtered, relabelled trial table to disk (without the index column).
df_main_6ani_6nonani.to_csv(path_or_buf="psych_data/df_main_6ani_6nonani_v2_1_cleaned.csv", index=False)

## Plotting (boxplots)

In [None]:
split_col = "split_recon"

# One row per (participant, condition): that participant's mean accuracy on the condition.
df_aggregated = (
    df_main_6ani_6nonani
    .groupby(['participant', split_col])['perf_cleaned']
    .mean()
    .reset_index()
    .rename(columns={'perf_cleaned': 'participant_accuracy'})
)

# Ordered (split name, (plot group id, display label)) pairs consulted by the mapper below.
_SPLIT_PLOT_GROUPS = (
    ('natural', ("natural_group", "natural")),
    ('robust_resnet50-mXDREAM_-_l2robust', ("Robust_MEI_group", "Robust MEI")),
    ('robust_resnet50-Stretch_in_pixelspace', ("Robust_Pixel_space_group", "Pixel space")),
    ('robust_resnet50-Stretch_in_conv25', ("Robust_Layer3_conv1_group", "Layer3_conv1")),
    ('robust_resnet50-Stretch_in_conv51', ("Robust_Layer4_conv7_group", "Layer4_conv7")),
    ('vanilla_resnet50-mXDREAM_-_vanilla', ("Standard_MEI_group", "Standard MEI")),
    ('vanilla_resnet50-Stretch_in_pixelspace', ("Standard_Pixel_Space_group", "Pixel Space")),
    ('vanilla_resnet50-Stretch_in_conv25', ("Standard_Layer3_conv1_group", "Layer3_conv1")),
    ('vanilla_resnet50-Stretch_in_conv51', ("Standard_Layer4_conv7_group", "Layer4_conv7")),
)

def map_split_to_plot_group_and_label(split_name):
    """Return ``(plot_group_id, display_label)`` for a reconstructed split name.

    Names not listed in ``_SPLIT_PLOT_GROUPS`` map to
    ``("Unknown_<name>", "Unknown: <name>")``.
    """
    for known_split, group_and_label in _SPLIT_PLOT_GROUPS:
        if split_name == known_split:
            return group_and_label
    return f"Unknown_{split_name}", f"Unknown: {split_name}"

# Map each condition to the box it is drawn in (the display label returned by the mapper is not used here).
mapped_values = df_aggregated[split_col].apply(map_split_to_plot_group_and_label)
df_aggregated['plot_group_id'] = [item[0] for item in mapped_values]

# Left-to-right order of the boxes: natural, then the 4 robust-model conditions, then the 4 standard-model ones.
plot_group_id_order = [
    "natural_group",
    "Robust_MEI_group", "Robust_Pixel_space_group", "Robust_Layer3_conv1_group", "Robust_Layer4_conv7_group",
    "Standard_MEI_group", "Standard_Pixel_Space_group", "Standard_Layer3_conv1_group", "Standard_Layer4_conv7_group"
]

# Conditions the mapper does not recognise get an "Unknown_..." id and would otherwise be
# silently missing from the figure; report them so a renamed condition is noticed.
omitted_group_ids = sorted(set(df_aggregated['plot_group_id']) - set(plot_group_id_order))
if omitted_group_ids:
    print(f"Warning: conditions not shown in the boxplot: {omitted_group_ids}")

# One tick label per entry of plot_group_id_order.
# NOTE(review): 'Pixel space' (robust) vs 'Pixel Space' (standard) capitalisation differs — intentional?
x_tick_display_labels = [
    'natural',
    'Robust MEI', 'Pixel space', 'Layer3_conv1', 'Layer4_conv7',
    'Standard MEI', 'Pixel Space', 'Layer3_conv1', 'Layer4_conv7'
]
assert len(x_tick_display_labels) == len(plot_group_id_order), "need exactly one tick label per plot group"

# Box x-positions: 'natural' alone at 0, then each model family as a tight cluster
# (intra_group_spacing apart) separated by a wider inter-cluster gap.
intra_group_spacing = 1.0
inter_group_spacing_val = 1.8
num_robust_items = 4
num_standard_items = 4

x_positions = [0.0]
cluster_start = inter_group_spacing_val
for n_items in (num_robust_items, num_standard_items):
    x_positions.extend(round(cluster_start + k * intra_group_spacing, 5) for k in range(n_items))
    cluster_start = round(cluster_start + (n_items - 1) * intra_group_spacing + inter_group_spacing_val, 5)
assert len(x_positions) == len(plot_group_id_order), "need exactly one x-position per plot group"

# Colors
natural_color = 'dodgerblue'
robust_mei_color = '#92edd1'
robust_general_color = '#1b9e77'
standard_mei_color = '#ffcdd2'
standard_color = '#e57373'

color_palette_categorical = {
    "natural_group": natural_color,
    "Robust_MEI_group": robust_mei_color,
    "Robust_Pixel_space_group": robust_general_color,
    "Robust_Layer3_conv1_group": robust_general_color,
    "Robust_Layer4_conv7_group": robust_general_color,
    "Standard_MEI_group": standard_mei_color,
    "Standard_Pixel_Space_group": standard_color,
    "Standard_Layer3_conv1_group": standard_color,
    "Standard_Layer4_conv7_group": standard_color,
}

# Plotting
plt.style.use('seaborn-v0_8-ticks')
fig, ax = plt.subplots(figsize=(17, 10))

fontsize_axis_labels = 28
fontsize_tick_labels = 22
fontsize_legend = 28
linewidth_spines = 3.5
linewidth_ticks = 3.5
tick_length_major = 15.0
linewidth_box = 1.5
linewidth_median = 2.0
linewidth_whisker = 1.5
linewidth_cap = 1.5
markeredgewidth_flier = 1.0
markersize_flier = 6

box_width = 0.7 * intra_group_spacing

# One ax.boxplot call per group so each box gets its own x-position and fill colour.
for group_id, current_x_position in zip(plot_group_id_order, x_positions):
    group_df = df_aggregated[df_aggregated['plot_group_id'] == group_id]
    participant_accuracies = group_df['participant_accuracy'].dropna()  # plt.boxplot does not skip NaNs

    ax.boxplot(
        participant_accuracies,
        positions=[current_x_position],
        widths=box_width,
        patch_artist=True,
        manage_ticks=False,  # ticks are configured manually below
        boxprops={'edgecolor': 'black', 'linewidth': linewidth_box, 'facecolor': color_palette_categorical[group_id]},
        medianprops={'color': 'black', 'linewidth': linewidth_median},
        whiskerprops={'color': 'black', 'linewidth': linewidth_whisker},
        capprops={'color': 'black', 'linewidth': linewidth_cap},
        flierprops={'marker': 'o', 'markersize': markersize_flier,
                    'markerfacecolor': 'white', 'markeredgecolor': 'black',
                    'markeredgewidth': markeredgewidth_flier}
    )

# Style the plot
ax.set_xlabel('Representation space', fontsize=fontsize_axis_labels, labelpad=25)
ax.set_ylabel('Accuracy', fontsize=fontsize_axis_labels, labelpad=20)

ax.set_xticks(x_positions)  # one tick at each box
ax.set_xticklabels(x_tick_display_labels, rotation=45, ha='center', fontsize=fontsize_tick_labels)

ax.tick_params(axis='y', labelsize=fontsize_tick_labels)
ax.set_ylim(-0.03, 1.0)

# Pad the x-axis around the outermost boxes
ax.set_xlim(x_positions[0] - intra_group_spacing * 0.7, x_positions[-1] + intra_group_spacing * 0.7)

ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)
ax.spines['left'].set_linewidth(linewidth_spines)
ax.spines['bottom'].set_linewidth(linewidth_spines)

ax.tick_params(axis='both', which='major', direction='out',
               width=linewidth_ticks, length=tick_length_major)

# Custom legend (model family colours only)
legend_patches = [
    mpatches.Patch(color=robust_general_color, label='Robust'),
    mpatches.Patch(color=standard_color, label='Standard')
]
ax.legend(handles=legend_patches,
          loc='lower left',
          frameon=False,
          fontsize=fontsize_legend,
          ncol=1,
          handlelength=2.0,
          handleheight=0.8,
          labelspacing=0.7
         )

ax.grid(False)

fig.tight_layout()
fig.savefig("notebooks/fig_outputs/boxplot_accuracy_by_split.pdf", format='pdf', bbox_inches='tight', dpi=600)
plt.show()

## Plotting (means with bootstrap confidence intervals)

In [None]:
# Mean accuracy per condition with 95% percentile-bootstrap confidence intervals.
# NOTE(review): individual trials are resampled, ignoring that they are nested within
# participants, so these intervals are likely too narrow for inference; a
# participant-level (cluster) bootstrap would be more conservative. The boxplot above
# shows the per-participant spread.
N_BOOTSTRAP = 10000
BOOTSTRAP_SEED = 0
# bootstrap_ci draws from NumPy's global RNG; seed it so the CIs are reproducible on re-run.
np.random.seed(BOOTSTRAP_SEED)

perf_by_split = df_main_6ani_6nonani.groupby(split_col)['perf_cleaned']
mean_perf = perf_by_split.mean()
ci_perf = perf_by_split.apply(
    lambda x: bootstrap_ci(x, num_bootstrap_samples=N_BOOTSTRAP, confidence_level=0.95)
)

# Combine means and CIs, then split each [lower, upper] array into two columns.
summary_df = pd.DataFrame({
    'mean': mean_perf,
    'ci': ci_perf
})
summary_df[['ci_lower', 'ci_upper']] = pd.DataFrame(summary_df['ci'].tolist(), index=summary_df.index)
summary_df = summary_df.drop(columns='ci')

# Order: 'natural' first (if present), then the remaining conditions alphabetically.
split_names = summary_df.index.unique().tolist()
if 'natural' in split_names:
    split_names.remove('natural')
    sorted_splits = ['natural'] + sorted(split_names)
else:
    sorted_splits = sorted(split_names)

# Reorder, then make the split name a regular column for plotting.
summary_df = summary_df.loc[sorted_splits].reset_index()

print("\nSummary Statistics (Sorted):")
print(summary_df)

# Error bars as distances from the mean: [[lower_errors], [upper_errors]].
lower_error = summary_df['mean'] - summary_df['ci_lower']
upper_error = summary_df['ci_upper'] - summary_df['mean']

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(summary_df[split_col], summary_df['mean'], yerr=[lower_error, upper_error],
              capsize=5, color='skyblue', edgecolor='black')

ax.set_xlabel('Split Condition', fontsize=12)
ax.set_ylabel('Mean Performance (Accuracy)', fontsize=12)
# N_BOOTSTRAP is the number of resamples, not the sample size, so say so explicitly.
ax.set_title(f'Mean Performance by Split with 95% Bootstrap CI ({N_BOOTSTRAP:,} resamples)', fontsize=14)

plt.setp(ax.get_xticklabels(), rotation=45, ha='right')  # long condition names overlap otherwise
ax.set_ylim(0, 1.05)  # accuracy lies in [0, 1]
ax.grid(axis='y', linestyle='--', alpha=0.7)

# Print each mean just above its upper CI whisker so the text does not sit on the error bar.
for bar, mean_val, ci_upper in zip(bars, summary_df['mean'], summary_df['ci_upper']):
    ax.text(bar.get_x() + bar.get_width() / 2.0, max(mean_val, ci_upper) + 0.02, f'{mean_val:.3f}',
            va='bottom', ha='center', fontsize=9)

fig.tight_layout()
plt.show()

## Statistical analysis in R

In [None]:
# Run the R script psych_code/scripts/glmm.r (the "Statistical analysis in R" step); its
# console output appears below. NOTE(review): presumably it reads the cleaned CSV written
# earlier (psych_data/df_main_6ani_6nonani_v2_1_cleaned.csv) — confirm. IPython's `!`
# does not raise on a non-zero exit status, so check the output for R errors.
!Rscript psych_code/scripts/glmm.r

In [None]:
# Run the R script psych_code/scripts/alexnet_glmm.r; its console output appears below.
# NOTE(review): inputs of this script are not visible from this notebook — confirm.
# IPython's `!` does not raise on a non-zero exit status, so check the output for R errors.
!Rscript psych_code/scripts/alexnet_glmm.r

In [None]:
# Per-participant row count next to the largest 'obs' value, as a quick look at missing
# trials (presumably trials where the participant did not respond).
# These counts are taken after the filtering above (block > 0 only, rows with missing
# 'perf' dropped), so removed trials show up as a gap between the two numbers.
# NOTE(review): whether 'obs' starts at 0 or 1 is not visible here — confirm before
# turning the gap into a count.
summary_stats = (
    df_main_6ani_6nonani
    .groupby('participant')
    .agg(
        number_of_rows=('participant', 'size'),
        max_obs_value=('obs', 'max'),
    )
)

separator_line = "-" * 30
print("Summary for each participant:")
for participant_id, stats_row in summary_stats.iterrows():
    print(f"Participant ID: {participant_id}")
    print(f"  Number of rows: {stats_row['number_of_rows']}")
    print(f"  Max value of obs: {stats_row['max_obs_value']}")
    print(separator_line)
