# Synaptipy — Optogenetic Analysis Example

This notebook demonstrates Synaptipy’s optogenetic synchronization analysis:

1. Extract TTL/digital stimulus epochs from an auxiliary channel
2. Correlate action-potential times with light stimuli
3. Compute optical latency, response probability, and jitter

Because the sample ABF files in `examples/data/` may not contain a TTL
channel, this notebook generates **synthetic data** to illustrate the API.
Replace the synthetic section with your own file loading when applying the
analysis to real recordings (see *Applying to Real Data* at the end).

**Requirements:** `pip install synaptipy matplotlib`

In [8]:
import numpy as np
import matplotlib.pyplot as plt

from Synaptipy.core.analysis.optogenetics import (
    extract_ttl_epochs,
    calculate_optogenetic_sync,
)
from Synaptipy.core.analysis.spike_analysis import detect_spikes_threshold

print("Imports OK")

## 1. Create Synthetic Optogenetics Data

We simulate:
- A voltage trace resting at -70 mV in which each light pulse evokes a spike
  with 80% probability, at a latency of ~5 ms with 1 ms jitter (SD)
- A TTL channel with 10-ms square pulses at 2 Hz
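Responses are drawn at random for each stimulus, so the measured response probability will scatter around 0.8 rather than match it exactly. The sketch below estimates that spread. It assumes each stimulus is an independent Bernoulli trial, which is how the synthetic loop draws responses.

```python
import math

import numpy as np

# Same stimulus schedule as the synthetic data cell
duration = 5.0        # s
stim_interval = 0.5   # s (2 Hz)
stim_onsets = np.arange(0.1, duration - 0.1, stim_interval)
n_stim = len(stim_onsets)

p = 0.8  # per-stimulus response probability

# Number of responses ~ Binomial(n_stim, p)
expected_responses = n_stim * p
# Standard error of the observed fraction of responding stimuli
se_probability = math.sqrt(p * (1 - p) / n_stim)

# Probability of observing exactly the expected number of responses
k = round(expected_responses)
p_exact = math.comb(n_stim, k) * p**k * (1 - p) ** (n_stim - k)

print(f"{n_stim} stimuli, expected {expected_responses:.1f} responses")
print(f"Standard error of response probability: {se_probability:.3f}")
print(f"P(exactly {k} responses) = {p_exact:.3f}")
```

With only 10 stimuli, the standard error is about 0.13. An observed response probability within roughly ±0.25 (two standard errors) of 0.8 is therefore unremarkable. A longer `duration` tightens the estimate.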

In [None]:
np.random.seed(42)
fs = 20_000  # Hz
duration = 5.0  # seconds
n_samples = int(fs * duration)
time = np.arange(n_samples) / fs

# --- TTL channel: 2 Hz, 10 ms pulses ---
ttl = np.zeros(n_samples)
stim_interval = 0.5  # seconds (2 Hz)
pulse_width = 0.010  # 10 ms
stim_onsets = np.arange(0.1, duration - 0.1, stim_interval)
for onset in stim_onsets:
    # round() avoids off-by-one pulse widths caused by floating-point error
    i0 = int(round(onset * fs))
    i1 = int(round((onset + pulse_width) * fs))
    ttl[i0:i1] = 5.0  # 5 V TTL

# --- Voltage trace: resting at -70 mV with evoked spikes ---
voltage = -70.0 + np.random.normal(0, 0.5, n_samples)
response_probability = 0.8  # 80 % response rate
latency_mean_ms = 5.0
latency_std_ms = 1.0

for onset in stim_onsets:
    if np.random.rand() < response_probability:
        latency_s = (latency_mean_ms + np.random.normal(0, latency_std_ms)) / 1000.0
        spike_time = onset + max(latency_s, 0.001)
        idx = int(spike_time * fs)
        if idx + 20 < n_samples:
            # Simple triangular spike waveform
            voltage[idx:idx + 5] = np.linspace(-70, 30, 5)
            voltage[idx + 5:idx + 10] = np.linspace(30, -75, 5)
            voltage[idx + 10:idx + 20] = np.linspace(-75, -70, 10)

print(f"Created {len(stim_onsets)} stimuli over {duration} s")
print(f"Expected response rate: {response_probability * 100:.0f}%")

In [None]:
# Visualise the raw data
fig, axes = plt.subplots(2, 1, figsize=(14, 5), sharex=True)
axes[0].plot(time, voltage, linewidth=0.4, color="k")
axes[0].set_ylabel("Vm (mV)")
axes[0].set_title("Voltage trace with light-evoked spikes")

axes[1].plot(time, ttl, linewidth=0.8, color="dodgerblue")
axes[1].set_ylabel("TTL (V)")
axes[1].set_xlabel("Time (s)")
axes[1].set_title("TTL stimulus channel")

plt.tight_layout()
plt.show()

## 2. Extract TTL Epochs

`extract_ttl_epochs` detects rising and falling edges of the digital signal
and returns stimulus onset / offset times in seconds.
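The edge-detection idea can be sketched in plain NumPy. This is an illustrative reimplementation, not Synaptipy's code. The helper name `ttl_edges` is hypothetical. It treats samples strictly above `threshold` as high, and it discards pulses cut off at the start or end of the trace.

```python
import numpy as np


def ttl_edges(ttl, time, threshold=2.5):
    """Hypothetical helper: return (onsets, offsets) in seconds.

    A sample counts as 'high' when ttl > threshold. Onsets are low-to-high
    transitions and offsets are high-to-low transitions.
    """
    high = np.asarray(ttl) > threshold
    time = np.asarray(time)
    # diff is +1 at rising edges and -1 at falling edges; +1 shifts the
    # index to the first sample *after* the transition
    d = np.diff(high.astype(np.int8))
    rise_idx = np.flatnonzero(d == 1) + 1
    fall_idx = np.flatnonzero(d == -1) + 1

    # Drop a pulse that is already high at the start or still high at the end
    if fall_idx.size and (rise_idx.size == 0 or fall_idx[0] < rise_idx[0]):
        fall_idx = fall_idx[1:]
    rise_idx = rise_idx[: fall_idx.size]

    return time[rise_idx], time[fall_idx]


# Demo: two 10-sample pulses sampled at 1 kHz
demo_time = np.arange(100) / 1000.0
demo_ttl = np.zeros(100)
demo_ttl[10:20] = 5.0
demo_ttl[50:60] = 5.0
demo_on, demo_off = ttl_edges(demo_ttl, demo_time)
print(demo_on, demo_off)  # onsets at 0.01 and 0.05 s, offsets at 0.02 and 0.06 s
```

Pulse width is then `offsets - onsets`. Here that is 10 samples, or 10 ms at 1 kHz.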

In [None]:
onsets, offsets = extract_ttl_epochs(ttl, time, threshold=2.5)

print(f"Detected {len(onsets)} stimulus epochs")
print(f"First onset:  {onsets[0]:.4f} s")
print(f"First offset: {offsets[0]:.4f} s")
print(f"Pulse width:  {(offsets[0] - onsets[0]) * 1000:.1f} ms")

## 3. Detect Spikes

Run threshold-based spike detection on the voltage trace (threshold -20 mV,
2 ms refractory period).
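A minimal sketch of threshold-crossing detection with a refractory period follows. It assumes that each upward crossing of `threshold` marks one spike. `detect_crossings` is a hypothetical helper. Synaptipy's `detect_spikes_threshold` may differ in details such as interpolation or peak finding.

```python
import numpy as np


def detect_crossings(voltage, time, threshold=-20.0, refractory_samples=40):
    """Hypothetical helper: times (s) of upward threshold crossings.

    A crossing is a sample at or above `threshold` whose predecessor is
    below it. A crossing within `refractory_samples` of the previously
    accepted one is ignored.
    """
    above = np.asarray(voltage) >= threshold
    crossings = np.flatnonzero(~above[:-1] & above[1:]) + 1

    accepted = []
    last = None
    for idx in crossings:
        if last is None or idx - last >= refractory_samples:
            accepted.append(idx)
            last = idx
    return np.asarray(time)[np.asarray(accepted, dtype=int)]


# Demo: two triangular spikes plus a one-sample blip 0.9 ms after the first
fs = 20_000
demo_time = np.arange(2000) / fs
demo_v = np.full(2000, -70.0)
demo_v[100:105] = np.linspace(-70, 30, 5)
demo_v[500:505] = np.linspace(-70, 30, 5)
demo_v[120] = 0.0  # rejected: only 18 samples after the first crossing
spikes = detect_crossings(demo_v, demo_time, threshold=-20.0,
                          refractory_samples=int(0.002 * fs))
print(spikes)
```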

In [None]:
refractory = int(0.002 * fs)  # 2 ms
spike_result = detect_spikes_threshold(
    voltage, time,
    threshold=-20.0,
    refractory_samples=refractory,
)

ap_times = spike_result.spike_times if spike_result.spike_times is not None else np.array([])
print(f"Detected {len(ap_times)} action potentials")

## 4. Optogenetic Synchronization

`calculate_optogenetic_sync` correlates detected spikes with TTL stimuli
and computes:
- **Optical latency**: median time from stimulus onset to first spike
- **Response probability**: fraction of stimuli that evoke ≥1 spike
- **Jitter**: standard deviation of spike latencies
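The three metrics above can be sketched directly from onset and spike times. This is an illustrative reimplementation under stated assumptions, not the library's code. Only the first spike in `[onset, onset + window)` counts as a response, and jitter uses the population standard deviation (`ddof=0`). Synaptipy's conventions may differ. `sync_metrics` is a hypothetical helper.

```python
import numpy as np


def sync_metrics(onsets, spike_times, window_ms=20.0):
    """Hypothetical helper: (latency_ms, response_probability, jitter_ms).

    For each onset, the first spike in [onset, onset + window) is taken as
    the evoked response; stimuli without such a spike count as failures.
    """
    onsets = np.asarray(onsets, dtype=float)
    spikes = np.sort(np.asarray(spike_times, dtype=float))
    window = window_ms / 1000.0
    if onsets.size == 0:
        return np.nan, np.nan, np.nan

    # Index of the first spike at or after each onset
    first_idx = np.searchsorted(spikes, onsets, side="left")
    latencies = []
    for onset, i in zip(onsets, first_idx):
        if i < spikes.size and spikes[i] - onset < window:
            latencies.append((spikes[i] - onset) * 1000.0)

    prob = len(latencies) / onsets.size
    if not latencies:
        return np.nan, prob, np.nan
    lat = np.asarray(latencies)
    return float(np.median(lat)), prob, float(np.std(lat))


# Demo: 4 stimuli; the third has no spike inside its 20 ms window
demo_onsets = [0.1, 0.6, 1.1, 1.6]
demo_spikes = [0.104, 0.606, 1.5, 1.605]
latency, probability, jitter = sync_metrics(demo_onsets, demo_spikes)
print(f"latency {latency:.2f} ms, p = {probability:.2f}, jitter {jitter:.2f} ms")
```

Using `np.searchsorted` on the sorted spike times finds each stimulus's first candidate spike without scanning the whole spike train for every onset.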

In [None]:
opto_result = calculate_optogenetic_sync(
    ttl_data=ttl,
    action_potential_times=ap_times,
    time=time,
    ttl_threshold=2.5,
    response_window_ms=20.0,
)

print(f"Stimulus count:       {opto_result.stimulus_count}")
print(f"Response probability: {opto_result.response_probability:.2f}")
print(f"Optical latency:      {opto_result.optical_latency_ms:.2f} ms")
print(f"Spike jitter:         {opto_result.spike_jitter_ms:.2f} ms")

## 5. Peri-Stimulus Time Histogram (PSTH)

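Plot the distribution of first-spike latencies relative to stimulus onset.
Each responding stimulus evokes a single spike in this synthetic data, so
this first-spike histogram is equivalent to a PSTH.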
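The histogram cell below uses first-spike latencies only. A conventional PSTH instead counts every spike within a peri-stimulus window across all trials and normalises the counts to a firing rate. The following is a minimal sketch, assuming spike and onset times in seconds. `psth` is a hypothetical helper, and the 1 ms bin width is an arbitrary choice.

```python
import numpy as np


def psth(onsets, spike_times, pre_ms=10.0, post_ms=30.0, bin_ms=1.0):
    """Hypothetical helper: (bin_centers_ms, rate_hz) pooled over trials."""
    onsets = np.asarray(onsets, dtype=float)
    spikes = np.asarray(spike_times, dtype=float)
    edges_ms = np.arange(-pre_ms, post_ms + bin_ms / 2, bin_ms)

    # Time (ms) of every spike relative to every onset, kept if in the window
    rel_ms = (spikes[None, :] - onsets[:, None]) * 1000.0
    rel_ms = rel_ms[(rel_ms >= -pre_ms) & (rel_ms < post_ms)]

    counts, _ = np.histogram(rel_ms, bins=edges_ms)
    # Divide by (number of trials * bin width in s) to get spikes per second
    rate_hz = counts / (onsets.size * bin_ms / 1000.0)
    centers_ms = edges_ms[:-1] + bin_ms / 2
    return centers_ms, rate_hz


# Demo: two trials, one evoked spike each plus one pre-stimulus spike
centers, rate = psth([0.1, 0.6], [0.1055, 0.6052, 0.0925])
print(centers[np.argmax(rate)], rate.max())
```

The broadcast builds an onsets × spikes matrix. That is fine for a few hundred stimuli, but very long recordings are better split per trial first.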

In [None]:
# Collect first-spike latencies relative to each stimulus onset
latencies = []
stim_onset_times = opto_result.stimulus_onsets  # array of onset times (s)

for onset, spk_list in zip(stim_onset_times, opto_result.responding_spikes):
    if spk_list is not None and len(spk_list) > 0:
        latency_ms = (spk_list[0] - onset) * 1000.0  # convert s → ms
        latencies.append(latency_ms)

fig, ax = plt.subplots(figsize=(8, 4))
if latencies:
    ax.hist(latencies, bins=20, color="steelblue", edgecolor="white")
    ax.axvline(np.median(latencies), color="red", linestyle="--",
               label=f"Median = {np.median(latencies):.2f} ms")
    ax.legend()
ax.set_xlabel("Latency from stimulus onset (ms)")
ax.set_ylabel("Count")
ax.set_title("Peri-Stimulus Time Histogram")
plt.tight_layout()
plt.show()

## 6. Raster Plot

Each row is one stimulus trial, and tick marks show spike times relative to
stimulus onset. Stimuli that evoked no spike leave an empty row.
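To build raster rows yourself, `np.searchsorted` can split a sorted spike train into per-trial windows. One reason to do this is to include a pre-stimulus baseline, which `responding_spikes` may not contain. `align_to_onsets` is a hypothetical helper, and the 10 ms / 30 ms window is an arbitrary choice.

```python
import numpy as np


def align_to_onsets(onsets, spike_times, pre_s=0.01, post_s=0.03):
    """Hypothetical helper: one array per onset holding spike times (s)
    relative to that onset, restricted to [-pre_s, post_s)."""
    spikes = np.sort(np.asarray(spike_times, dtype=float))
    trials = []
    for onset in np.asarray(onsets, dtype=float):
        lo = np.searchsorted(spikes, onset - pre_s, side="left")
        hi = np.searchsorted(spikes, onset + post_s, side="left")
        trials.append(spikes[lo:hi] - onset)
    return trials


# Demo: a baseline spike and an evoked spike in trial 0, one evoked spike in trial 1
trials = align_to_onsets([0.1, 0.6], [0.095, 0.105, 0.4, 0.606])
for i, rel in enumerate(trials):
    print(i, np.round(rel * 1000.0, 2), "ms")
```

A list shaped like `trials` (after scaling to ms) can be passed to Matplotlib's `Axes.eventplot`, which draws one row per array.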

In [None]:
fig, ax = plt.subplots(figsize=(8, 5))
for trial_idx, (onset, spk_list) in enumerate(
    zip(opto_result.stimulus_onsets, opto_result.responding_spikes)
):
    if spk_list is not None and len(spk_list) > 0:
        # Convert absolute spike times to latencies (ms) relative to onset
        latencies_trial = [(s - onset) * 1000.0 for s in spk_list]
        ax.scatter(latencies_trial, [trial_idx] * len(latencies_trial),
                   marker="|", color="k", s=40)

ax.axvline(0, color="cyan", linewidth=2, label="Stimulus onset")
ax.set_xlabel("Latency (ms)")
ax.set_ylabel("Trial")
ax.set_title("Raster Plot — Spike Latencies Per Stimulus")
ax.legend()
plt.tight_layout()
plt.show()

---

### Applying to Real Data

Replace the synthetic-data section with:

```python
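import numpy as np
from Synaptipy.infrastructure.file_readers.neo_adapter import NeoAdapter

adapter = NeoAdapter()
recording = adapter.read_recording("path/to/your/file.abf")

# Pick the voltage and TTL channels; these keys are examples,
# so substitute the channel keys present in your file
v_ch = recording.channels["Channel 0"]
ttl_ch = recording.channels["Channel 1"]

# First trial (sweep) of each channel
voltage = v_ch.data_trials[0]
ttl = ttl_ch.data_trials[0]

# The analysis functions also need a time vector in seconds;
# set fs to your recording's sampling rate (Hz)
fs = 20_000
time = np.arange(len(voltage)) / fs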
```
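The notebook uses a threshold of 2.5, which assumes a 0 to 5 V TTL. Real auxiliary channels may be scaled differently (for example, in mV) or carry an offset, so a data-driven threshold halfway between the baseline and pulse levels may be safer. The sketch below is hypothetical (`estimate_ttl_threshold` is not a Synaptipy function). It assumes positive-going pulses with a duty cycle below 50%, so that the median is the baseline level and the maximum is the pulse level.

```python
import numpy as np


def estimate_ttl_threshold(ttl):
    """Hypothetical helper: midpoint between baseline and pulse level.

    Assumes positive-going pulses occupying under half the samples, so the
    median sits at the baseline and the maximum at the pulse level.
    """
    ttl = np.asarray(ttl, dtype=float)
    low = np.median(ttl)
    high = np.max(ttl)
    if np.isclose(low, high):
        raise ValueError("TTL channel looks flat; no pulses found")
    return (low + high) / 2.0


# Demo: a TTL recorded in mV with a 200 mV offset
demo_ttl = np.zeros(1000)
demo_ttl[100:110] = 5000.0
print(estimate_ttl_threshold(demo_ttl))          # midpoint of 0 and 5000
print(estimate_ttl_threshold(demo_ttl + 200.0))  # midpoint of 200 and 5200
```

The result can be passed as `threshold=` to `extract_ttl_epochs` and as `ttl_threshold=` to `calculate_optogenetic_sync`.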