# Occlusion robustness analysis

Robustness of learned feature selectivity (left-convex boundary element) to progressive visual occlusion.

**Overview**

- Load inference recordings with occlusion applied at multiple levels
- Compute stimulus-specific information measures for L4 neurons regarding left-convex features
- Visualise degradation of feature selectivity as occlusion increases

**Key metrics:**

- Stimulus-specific information $\mathcal{I}(s, \vec{R})$ for left-convex boundary elements
- Number of informative neurons (exceeding threshold) per occlusion level
- Firing rate statistics across occlusion conditions

**Dependencies**

Generate the occlusion inference recordings first. This has a long runtime, because each inference trial is swept over multiple occlusion levels:

```bash
for trial in 3 7 15; do
    python scripts/analysis/inference_occlusion.py ./experiments/n4p2/train_n4p2_lrate_0_02_181023 $trial
done
```

**Plots**


In [None]:
from pprint import pprint

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import xarray as xr

from hsnn import analysis, viz
from hsnn.analysis import measures
from hsnn.utils import handler, io

# Figures produced by this notebook are written below <BASE_DIR>/out/figures/fig13
OUTPUT_DIR = io.BASE_DIR / "out" / "figures" / "fig13"
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

viz.setup_journal_env()

### 1) Load occlusion inference results

In [None]:
dataset_name = "n4p2"
model_type = "ALL"
logdir = f"{dataset_name}/train_{dataset_name}_lrate_0_02_181023"

# Get experiment and trials from metadata
expt = handler.ExperimentHandler(logdir)
trials_dict = expt.metadata.get_trials_dict(model_type)

if "inference" not in trials_dict:
    raise ValueError("No 'inference' trials found in metadata")

trial_names = trials_dict["inference"]
# Later cells index trials[0] / records_list[0], so an empty trial list must
# fail here with a clear message rather than as an IndexError further down.
# (len() instead of truthiness so array-like trial lists also work.)
if len(trial_names) == 0:
    raise ValueError(f"'inference' trial list is empty for model type {model_type!r}")
trials = [expt[name] for name in trial_names]
print(f"Found {len(trials)} inference trials for {model_type}: {trial_names}")

In [None]:
# Load occlusion inference results from trials
def load_occlusion_results(
    trial: handler.TrialView, subdir: str = "occlusion"
) -> xr.DataArray:
    """Return the pickled occlusion inference results stored for ``trial``.

    The file location is resolved with ``handler.get_results_path`` using
    checkpoint index ``-1`` (presumably the latest checkpoint) and ``subdir``.

    Raises:
        FileNotFoundError: If no results file exists at the resolved path.
    """
    path = handler.get_results_path(trial, chkpt=-1, subdir=subdir)
    if path.exists():
        return io.load_pickle(path)
    raise FileNotFoundError(f"Results not found: {path}")


# Load results from all trials
records_list = []
for trial in trials:
    records = load_occlusion_results(trial)
    records_list.append(records)
    print(f"Loaded {trial.name}: {records.dims}")

if not records_list:
    raise ValueError("No occlusion recordings were loaded")

# Downstream cells take the occlusion levels from the first recording and
# `.sel` them in every trial, so each trial must contain all of those levels.
reference_levels = records_list[0].coords["occlusion"].values
for trial, records in zip(trials, records_list):
    missing = np.setdiff1d(reference_levels, records.coords["occlusion"].values)
    if missing.size:
        raise ValueError(
            f"{trial.name} lacks occlusion levels {missing} present in {trials[0].name}"
        )

# Load metadata from the first recording (the same one that defines the
# occlusion levels), not from whichever trial happened to be loaded last.
metadata = records_list[0].attrs.copy()
pprint(metadata)

# Load config and labels
cfg = trials[0].config
imageset, labels = io.get_dataset(cfg["training"]["data"], return_annotations=True)
print(f"Loaded {len(imageset)} images")

In [None]:
# Common parameters
target = 1  # Convex
side = "left"  # Left-convex boundary element
layer = 4
offset = 50.0  # ms; subtracted from the recording duration below
duration = metadata["duration"] - offset

# Reference level log2(#classes) bits, drawn as a dotted line in the plots
num_classes = 2
max_info = np.log2(num_classes)

# Get occlusion levels from first recording
occlusion_levels = records_list[0].coords["occlusion"].values
print(
    f"Duration: {duration} ms, Offset: {offset} ms",
    f"Occlusion levels: {occlusion_levels}",
    f"Target: {side}-convex (label={target})",
    sep="\n",
)
print(f"Target: {side}-convex (label={target})")

In [None]:
# Split image indices by whether the `side` annotation equals `target`
is_target = np.asarray(labels[side] == target)
pos_ids = np.flatnonzero(is_target)
neg_ids = np.flatnonzero(~is_target)

print(f"Positive (left-convex) images: {pos_ids}")
print(f"Negative images: {neg_ids}")

In [None]:
def plot_occlusion_examples(
    image: np.ndarray,
    occlusion_levels: np.ndarray,
    bump_start: int,
    bump_width: int,
    ax: plt.Axes | None = None,
    fig_width: float = 3.5,
    title_fontsize: int = 8,
    show_titles: bool = True,
    # Inset positioning parameters (in axes coordinates [0, 1])
    inset_x: float = 0.55,
    inset_y: float = 0.55,
    inset_width: float = 0.4,
    inset_height: float | None = None,
    inset_spacing: float = 0.05,
    # Standalone figure parameters
    standalone_wspace: float = 0.15,
) -> tuple[plt.Figure | None, np.ndarray]:
    """Plot example images with progressive occlusion overlay.

    One panel is drawn per occlusion level. Each panel shows ``image`` in
    greyscale (display range 0-255). For ``occ_level > 0`` a semi-transparent
    grey band shades image columns ``[0, int(bump_start + bump_width * occ_level))``.

    Args:
        image: Input image array (H, W).
        occlusion_levels: Array of occlusion fractions [0.0, 1.0].
        bump_start: Pixel column where the feature starts.
        bump_width: Width of the feature in pixels.
        ax: If provided, use this axes for an inset (creates sub-axes).
        fig_width: Figure width in inches (ignored if ax provided).
        title_fontsize: Font size for panel titles.
        show_titles: Whether to show occlusion level titles.
        inset_x: X position of inset (left edge) in axes coordinates.
        inset_y: Y position of inset (bottom edge) in axes coordinates.
        inset_width: Total width of inset in axes coordinates.
        inset_height: Height of inset in axes coordinates. If None, computed
            from image aspect ratio.
        inset_spacing: Spacing between image panels as fraction of panel width.
        standalone_wspace: Width spacing for standalone figure.

    Returns:
        Tuple of (figure, axes). ``axes`` is always a 1-D array with one Axes
        per occlusion level. Figure is None if ax was provided.

    Raises:
        ValueError: If ``occlusion_levels`` is empty.
    """
    n_levels = len(occlusion_levels)
    if n_levels == 0:
        raise ValueError("occlusion_levels must contain at least one level")
    aspect = image.shape[0] / image.shape[1]

    if ax is not None:
        # Panels are laid out side by side as insets of `ax` (axes coordinates)
        panel_width = inset_width / n_levels
        panel_width_with_spacing = panel_width * (1 - inset_spacing)

        # Compute height from aspect if not specified
        if inset_height is None:
            inset_height = panel_width_with_spacing * aspect

        axes = np.array(
            [
                ax.inset_axes(
                    [
                        inset_x + i * panel_width,
                        inset_y,
                        panel_width_with_spacing,
                        inset_height,
                    ]
                )
                for i in range(n_levels)
            ]
        )
        created_fig = None
    else:
        # Create standalone figure
        panel_width = fig_width / n_levels
        panel_height = panel_width * aspect
        # squeeze=False keeps a 2-D grid even when n_levels == 1; otherwise a
        # bare Axes is returned and indexing it would fail.
        created_fig, axes_grid = plt.subplots(
            1,
            n_levels,
            figsize=(fig_width, panel_height + 0.3),
            gridspec_kw={"wspace": standalone_wspace},
            squeeze=False,
        )
        axes = axes_grid[0]

    for panel_ax, occ_level in zip(axes, occlusion_levels):
        # Show original image
        panel_ax.imshow(image, cmap="gray", vmin=0, vmax=255)

        # Overlay semi-transparent gray mask for occluded region
        mask_width = int(bump_start + bump_width * occ_level) if occ_level > 0 else 0
        if mask_width > 0:
            # imshow centres column j at x = j, so its edges lie at j +/- 0.5.
            # Spanning [-0.5, mask_width - 0.5] shades exactly columns
            # 0..mask_width-1 (a [0, mask_width] span is shifted half a pixel).
            # NOTE(review): assumes the occlusion covers image[:, :mask_width]
            # — confirm against scripts/analysis/inference_occlusion.py.
            panel_ax.axvspan(-0.5, mask_width - 0.5, color="gray", alpha=0.85)

        if show_titles:
            panel_ax.set_title(
                f"{occ_level * 100:.0f}%", fontsize=title_fontsize, pad=2
            )
        panel_ax.set_xticks([])
        panel_ax.set_yticks([])

        # Bounding box around each panel
        for spine in panel_ax.spines.values():
            spine.set_visible(True)
            spine.set_linewidth(0.5)
            spine.set_color("black")

    return created_fig, axes

### 2) Compute information measures per occlusion level

Compute stimulus-specific information for L4 neurons regarding left-convex features at each occlusion level. Measures are computed separately for each trial and cached, then averaged across trials.

In [None]:
# Per-trial information measures (before averaging) are cached in OUTPUT_DIR,
# with the model type and feature side encoded in the file name.
_measures_cache_name = f"specific_measures_occ_{model_type}_{side}_per_trial.pkl"
measures_cache_path = OUTPUT_DIR / _measures_cache_name


def compute_measures_for_trial(
    records: xr.DataArray,
    occlusion_levels: np.ndarray,
    labels: pd.DataFrame,
    target: int,
    duration: float,
    offset: float,
    layer: int,
) -> dict[float, pd.DataFrame]:
    """Return information measures for one trial, keyed by occlusion level.

    For every level in ``occlusion_levels`` the recordings are sliced with
    ``records.sel(occlusion=level)`` and passed to
    ``measures.get_combined_measures`` together with ``labels``, ``target``,
    the analysis window (``duration``, ``offset``) and ``layer``.
    """
    return {
        occ: measures.get_combined_measures(
            records.sel(occlusion=occ),
            labels=labels,
            target=target,
            duration=duration,
            offset=offset,
            layer=layer,
        )
        for occ in occlusion_levels
    }


all_trial_measures: list[dict[float, pd.DataFrame]] | None = None

if measures_cache_path.exists():
    print(f"Loading cached measures from: {measures_cache_path}")
    cached = io.load_pickle(measures_cache_path)
    # The cache file name does not encode which trials were analysed, so
    # verify it still matches the loaded trials and occlusion levels before
    # trusting it; otherwise fall through and recompute.
    cache_ok = (
        isinstance(cached, list)
        and len(cached) == len(records_list)
        and all(all(occ in tm for occ in occlusion_levels) for tm in cached)
    )
    if cache_ok:
        all_trial_measures = cached
    else:
        print(
            "  Cached measures do not match the loaded trials/occlusion levels; recomputing."
        )

if all_trial_measures is None:
    print(f"Computing information measures for {len(trials)} trials...")

    # Compute measures for each trial
    all_trial_measures = []
    for i, records in enumerate(records_list):
        print(f"  Processing trial {i + 1}/{len(trials)}...")
        trial_measures = compute_measures_for_trial(
            records, occlusion_levels, labels, target, duration, offset, layer
        )
        all_trial_measures.append(trial_measures)

        # Print max info per occlusion level for this trial
        for occ_level in occlusion_levels:
            max_info_val = trial_measures[occ_level][side].max()
            print(
                f"    Occlusion {occ_level * 100:.0f}%: max I(S;R) = {max_info_val:.3f} bits"
            )

    # Save to cache (per-trial, before averaging)
    io.save_pickle(all_trial_measures, measures_cache_path, parents=True)
    print(f"Saved per-trial measures to: {measures_cache_path}")

print(f"\nLoaded measures for {len(all_trial_measures)} trials")

In [None]:
# Rank neurons by specific information (descending) separately within each
# trial and occlusion level; the ranked curves are then averaged across trials.
ranked_per_trial: dict[float, list[np.ndarray]] = {
    occ_level: [
        trial_measures[occ_level][side].sort_values(ascending=False).values
        for trial_measures in all_trial_measures
    ]
    for occ_level in occlusion_levels
}

# Stack each level to (n_trials, n_neurons), then take mean/std over trials
stacked_by_level = {
    occ_level: np.vstack(ranked_per_trial[occ_level]) for occ_level in occlusion_levels
}
ranked_mean: dict[float, np.ndarray] = {
    occ_level: stack.mean(axis=0) for occ_level, stack in stacked_by_level.items()
}
ranked_std: dict[float, np.ndarray] = {
    occ_level: stack.std(axis=0) for occ_level, stack in stacked_by_level.items()
}

# One column per occlusion level, one row per neuron rank
ranked_df = pd.DataFrame(ranked_mean).rename_axis("neuron_rank")
ranked_std_df = pd.DataFrame(ranked_std).rename_axis("neuron_rank")

print(f"Averaged ranked measures across {len(all_trial_measures)} trials")
ranked_df.head(10)

### 3) Visualise information degradation with occlusion

**Figure 1:** Rank-ordered specific information across occlusion levels

In [None]:
f, ax = plt.subplots(figsize=(4, 2))

nrn_ids = np.arange(1, len(ranked_df) + 1)
cmap = plt.cm.copper
n_levels = len(occlusion_levels)
# Evenly spaced colours from the colormap; max(..., 1) avoids a division by
# zero when only a single occlusion level is present.
colors = [cmap(i / max(n_levels - 1, 1)) for i in range(n_levels)]

for color, occ_level in zip(colors, occlusion_levels):
    ax.plot(
        nrn_ids,
        ranked_df[occ_level].values,
        color=color,
        label=f"{occ_level * 100:.0f}%",
        linewidth=1.5,
    )

ax.set_xlim(1, len(nrn_ids))
# Upper limit follows max_info (= log2(num_classes)) instead of a hard-coded 1.0
ax.set_ylim(0, max_info * 1.05)
ax.set_xscale("log")
ax.set_xlabel("Neuron rank #")
ax.set_ylabel(r"$\mathcal{I}\; (s, \vec{R})$")
ax.grid()
ax.legend(title="Occlusion", loc="upper right", title_fontsize="small")
f.tight_layout()
plt.show()

In [None]:
# Save the rank-ordered specific-information figure (Figure 1)
output_path_ranked = OUTPUT_DIR / "fig_information_ranked_occlusion.pdf"
viz.save_figure(f, output_path_ranked, overwrite=True, dpi=300)

**Figure 2:** Number of informative neurons vs. occlusion level

In [None]:
threshold = 2 / 3
n_neurons = len(ranked_df)

# Compute informative neuron counts PER TRIAL, then average
counts_per_trial: dict[float, list[int]] = {occ: [] for occ in occlusion_levels}

for trial_measures in all_trial_measures:
    for occ_level in occlusion_levels:
        ISR = trial_measures[occ_level]
        count = np.sum(ISR[side].values >= threshold)
        counts_per_trial[occ_level].append(count)

# Compute mean and std (across trials) for counts and percentages
num_selective_mean = []
num_selective_std = []
pct_selective_mean = []
pct_selective_std = []

for occ_level in occlusion_levels:
    counts = np.array(counts_per_trial[occ_level])
    num_selective_mean.append(counts.mean())
    num_selective_std.append(counts.std())
    pct_selective_mean.append(counts.mean() / n_neurons * 100)
    pct_selective_std.append(counts.std() / n_neurons * 100)

# Create summary DataFrame with errors
informative_summary_df = pd.DataFrame(
    {
        "occlusion": occlusion_levels,
        "count_mean": num_selective_mean,
        "count_std": num_selective_std,
        "pct_mean": pct_selective_mean,
        "pct_std": pct_selective_std,
    }
)
print(informative_summary_df.to_string(index=False))

# Plot
f, ax = plt.subplots(figsize=(4, 2.5))

x_pos = np.arange(len(occlusion_levels))
bars = ax.bar(
    x_pos,
    num_selective_mean,
    yerr=num_selective_std,
    color="C0",
    edgecolor="black",
    linewidth=0.5,
    capsize=3,
)

ax.set_xticks(x_pos)
# Round rather than truncate: int(0.29 * 100) == 28 due to float error.
# Matches the "{:.0f}%" labels used in the other figures.
ax.set_xticklabels([f"{occ * 100:.0f}%" for occ in occlusion_levels])

# Bar labels are anchored at mean + std + 1. With the default annotation_clip,
# a label whose anchor lies outside the axes is not drawn, so the y-range must
# cover every anchor (plus headroom for the text). The floor of 1 avoids a
# degenerate (0, 0) range when no neuron is informative at any level.
label_top = max(
    mean + std + 1 for mean, std in zip(num_selective_mean, num_selective_std)
)
ax.set_ylim(0, max(max(num_selective_mean) * 1.3, label_top * 1.1, 1))
ax.set_xlabel("Occlusion Level")
ax.set_ylabel("# informative neurons")
ax.set_axisbelow(True)
ax.grid(axis="y")

# Add value labels on bars (mean ± std as percentage)
for i, (bar, mean_pct, std_pct) in enumerate(
    zip(bars, pct_selective_mean, pct_selective_std)
):
    if mean_pct > 0:
        ax.annotate(
            f"{mean_pct:.2g}±{std_pct:.1g}%",
            xy=(
                bar.get_x() + bar.get_width() / 2,
                bar.get_height() + num_selective_std[i] + 1,
            ),
            ha="center",
            va="bottom",
            fontsize=8,
        )

# Add occlusion examples as inset (upper right)
img_idx = pos_ids[2]  # Left-convex image
image = imageset[img_idx]
bump_start = metadata["bump_start"]
bump_width = metadata["bump_width"]

plot_occlusion_examples(
    image,
    occlusion_levels,
    bump_start,
    bump_width,
    ax=ax,
    title_fontsize=7,
    show_titles=True,
    inset_x=0.45,
    inset_y=0.68,
    inset_width=0.55,
    inset_height=0.3,
    inset_spacing=0.2,
)

f.tight_layout()
plt.show()

In [None]:
# Export the informative-neuron bar chart (Figure 2, with occlusion inset)
output_path_informative = OUTPUT_DIR / "fig_informative_neurons_occlusion.pdf"
viz.save_figure(f, output_path_informative, dpi=300, overwrite=True)

**Figure 3:** Summary statistics across occlusion levels

In [None]:
# Per-trial summary of the specific information at each occlusion level
def _info_stats(occ_level, isr):
    """Summarise one trial's specific-information values at one occlusion level."""
    return {
        "occlusion": occ_level,
        "max_info": isr.max(),
        "mean_info": isr.mean(),
        "median_info": isr.median(),
        "num_selective_90": np.sum(isr >= 0.9),
        "num_selective_67": np.sum(isr >= 2 / 3),
    }


summary_per_trial: list[list[dict]] = [
    [_info_stats(occ, trial_measures[occ][side]) for occ in occlusion_levels]
    for trial_measures in all_trial_measures
]

# Mean and population std (np.std, ddof=0) across trials per occlusion level
stat_keys = [key for key in summary_per_trial[0][0] if key != "occlusion"]
summary_stats = []
for level_idx, occ_level in enumerate(occlusion_levels):
    per_key = {
        key: [trial_stats[level_idx][key] for trial_stats in summary_per_trial]
        for key in stat_keys
    }
    summary_stats.append(
        {
            "occlusion": occ_level,
            "max_info_mean": np.mean(per_key["max_info"]),
            "max_info_std": np.std(per_key["max_info"]),
            "mean_info_mean": np.mean(per_key["mean_info"]),
            "mean_info_std": np.std(per_key["mean_info"]),
            "num_selective_67_mean": np.mean(per_key["num_selective_67"]),
            "num_selective_67_std": np.std(per_key["num_selective_67"]),
        }
    )

summary_df = pd.DataFrame(summary_stats)
summary_df

In [None]:
f, axes = plt.subplots(1, 2, figsize=(10, 4))
occ_pct = summary_df["occlusion"] * 100

# Panel A: Max and mean information
info_ax = axes[0]
for column, fmt, label in (
    ("max_info_mean", "o-", "Max"),
    ("mean_info_mean", "s-", "Mean"),
):
    info_ax.plot(occ_pct, summary_df[column], fmt, label=label)
info_ax.axhline(max_info, color="black", linestyle=":")
info_ax.set_xlabel("Occlusion Level (%)")
info_ax.set_ylabel("Specific Information (bits)")
info_ax.set_title("Information vs. Occlusion")
info_ax.legend()
info_ax.grid(True, alpha=0.3)

# Panel B: selective-neuron count relative to the first row (taken as baseline)
count_ax = axes[1]
baseline = summary_df["num_selective_67_mean"].iloc[0]
relative_pct = summary_df["num_selective_67_mean"] / baseline * 100
count_ax.plot(occ_pct, relative_pct, "o-", color="C1")
count_ax.set_xlabel("Occlusion Level (%)")
count_ax.set_ylabel("Selective Neurons (% of baseline)")
count_ax.set_title("Robustness to Occlusion")
count_ax.axhline(100, color="black", linestyle=":")
count_ax.grid(True, alpha=0.3)

f.tight_layout()
plt.show()

### 4) Firing rate analysis across occlusion levels

Examine how neuronal firing rates change with occlusion.

In [None]:
# Compute firing rates per occlusion level (averaged across trials)
def _trial_averaged_rates(occ_level):
    """Average the EXC rates of `layer` at one occlusion level over all trials."""
    per_trial = [
        analysis.infer_rates(
            rec.sel(occlusion=occ_level, layer=layer, nrn_cls="EXC"), duration, offset
        )
        for rec in records_list
    ]
    return sum(per_trial) / len(per_trial)


rates_occ: dict[float, xr.DataArray] = {
    occ_level: _trial_averaged_rates(occ_level) for occ_level in occlusion_levels
}

In [None]:
# Mean rate for left-convex (positive) vs other (negative) stimuli, plus the
# overall mean and the peak rate, at each occlusion level
def _rate_summary(occ_level):
    """Collect scalar rate statistics for one occlusion level."""
    level_rates = rates_occ[occ_level]
    return {
        "occlusion": occ_level,
        "mean_rate_pos": float(level_rates.sel(img=pos_ids).mean()),
        "mean_rate_neg": float(level_rates.sel(img=neg_ids).mean()),
        "mean_rate_all": float(level_rates.mean()),
        "max_rate": float(level_rates.max()),
    }


rate_stats = [_rate_summary(occ_level) for occ_level in occlusion_levels]
rate_stats_df = pd.DataFrame(rate_stats)
rate_stats_df

In [None]:
f, ax = plt.subplots(figsize=(5, 4))
occ_pct = rate_stats_df["occlusion"] * 100

# One line per stimulus class
for column, fmt, label in (
    ("mean_rate_pos", "o-", "Left-convex stimuli"),
    ("mean_rate_neg", "s-", "Other stimuli"),
):
    ax.plot(occ_pct, rate_stats_df[column], fmt, label=label)

ax.set_xlabel("Occlusion Level (%)")
ax.set_ylabel("Mean Firing Rate (Hz)")
ax.set_title("L4 Firing Rates vs. Occlusion")
ax.legend()
ax.grid(True, alpha=0.3)
f.tight_layout()
plt.show()

### 5) Track individual selective neurons

Follow the most selective neurons, identified at 0% occlusion in the first trial, across occlusion levels.

In [None]:
# Get top-N neurons at baseline, taken as the lowest swept occlusion level
# (0% whenever it is included). Looking up a hard-coded 0.0 key would raise
# KeyError if 0% was not swept. Only the first trial is used here.
topN = 10
baseline_occ = min(occlusion_levels)
baseline_ISR = all_trial_measures[0][baseline_occ][side]
top_neurons = baseline_ISR.sort_values(ascending=False).head(topN).index.tolist()

print(f"Top {topN} neurons at baseline:")
for nrn in top_neurons:
    print(f"  Neuron {nrn}: I(S;R) = {baseline_ISR[nrn]:.3f} bits")

In [None]:
# Specific information of each tracked neuron at every occlusion level (first trial)
first_trial_measures = all_trial_measures[0]
tracking_data = [
    {
        "neuron": nrn,
        "occlusion": occ_level,
        "info": first_trial_measures[occ_level][side][nrn],
    }
    for nrn in top_neurons
    for occ_level in occlusion_levels
]

tracking_df = pd.DataFrame(tracking_data)
# Rows: occlusion level; columns: neuron id
tracking_pivot = tracking_df.pivot(index="occlusion", columns="neuron", values="info")
tracking_pivot

In [None]:
f, ax = plt.subplots(figsize=(6, 4))
occ_pct = tracking_pivot.index * 100

# Only the five highest-ranked baseline neurons are drawn
for nrn in top_neurons[:5]:
    ax.plot(occ_pct, tracking_pivot[nrn], "o-", label=f"Neuron {nrn}", markersize=4)

ax.axhline(max_info, color="black", linestyle=":")
ax.set_xlabel("Occlusion Level (%)")
ax.set_ylabel("Specific Information (bits)")
ax.set_title("Top Selective Neurons vs. Occlusion")
ax.legend(loc="lower left", fontsize="small")
ax.grid(True, alpha=0.3)
f.tight_layout()
plt.show()

### 6) Save analysis results

In [None]:
# Save summary statistics
output_dir = OUTPUT_DIR / dataset_name
output_dir.mkdir(parents=True, exist_ok=True)

# (description, file name, table, write index?)
exports = [
    ("Summary statistics", "occlusion_summary_stats.csv", summary_df, False),
    (
        "Informative neuron counts",
        "occlusion_informative_counts.csv",
        informative_summary_df,
        False,
    ),
    ("Rate statistics", "occlusion_rate_stats.csv", rate_stats_df, False),
    ("Neuron tracking", "occlusion_neuron_tracking.csv", tracking_pivot, True),
]
for _, file_name, table, write_index in exports:
    table.to_csv(output_dir / file_name, index=write_index)

print(f"Results saved to: {output_dir}")
for description, file_name, _, _ in exports:
    print(f"  - {description}: {file_name}")

## Appendix

**Visualise example images with occlusion mask overlay**

In [None]:
# Standalone version of the occlusion-example panels. Only the image and the
# bump geometry from `metadata` are used; no spike recordings are involved.
img_idx = pos_ids[2]  # Left-convex image
image = imageset[img_idx]
bump_start = metadata["bump_start"]
bump_width = metadata["bump_width"]

f, axes = plot_occlusion_examples(
    image, occlusion_levels, bump_start, bump_width, fig_width=3.5
)
plt.show()

In [None]:
# Save the standalone occluded-image examples figure
output_path_inset = OUTPUT_DIR / "occluded_images.pdf"
viz.save_figure(f, output_path_inset, overwrite=True, dpi=300)

In [None]:
# Topographic firing rates at different occlusion levels (from first trial)
img_idx = pos_ids[0]
rep = 0

# Use first trial's records for visualisation
records_first = records_list[0]

# Build every rate map first so that all panels can share one colour scale
rate_maps = []
for occ_level in occlusion_levels:
    records_occ = records_first.sel(occlusion=occ_level, layer=layer, nrn_cls="EXC")
    rates = analysis.infer_rates(records_occ, duration, offset)
    rates_img = rates.sel(img=img_idx, rep=rep).values

    # Reshape to spatial. Assumes the EXC population is a square sheet stored
    # row-major (previously hard-coded as 64 x 64) — TODO confirm with config.
    side_len = int(round(np.sqrt(rates_img.size)))
    if side_len * side_len != rates_img.size:
        raise ValueError(
            f"Cannot reshape {rates_img.size} L{layer} EXC rates into a square map"
        )
    rate_maps.append(rates_img.reshape(side_len, side_len))

# Shared colour limits: with independent per-panel scaling, colours could not
# be compared across occlusion levels.
vmax = max(float(np.nanmax(rate_map)) for rate_map in rate_maps)

# squeeze=False keeps `axes` indexable even for a single occlusion level
f, axes = plt.subplots(
    1, len(occlusion_levels), figsize=(3 * len(occlusion_levels), 3), squeeze=False
)
axes = axes[0]

for ax, occ_level, rate_map in zip(axes, occlusion_levels, rate_maps):
    im = ax.imshow(rate_map, cmap="inferno", vmin=0, vmax=vmax)
    ax.set_title(f"{occ_level * 100:.0f}%")
    ax.axis("off")

f.suptitle(f"L{layer} Firing Rates: Image {img_idx}", y=1.02)
f.tight_layout()
# Added after tight_layout so the colorbar can take space from the panels
f.colorbar(im, ax=list(axes), label="Firing rate (Hz)", shrink=0.8)
plt.show()