# Notebook 03: Pulse Shape Analysis

This notebook performs pulse shape analysis, including pulse shape discrimination (PSD) parameters, on scintillator waveforms loaded from CSV files.

## Overview

The notebook extracts temporal features from scintillator waveforms in the following steps:
1. Load event waveforms from CSV files (CAEN digitizer format)
2. Apply energy calibration from Notebook 02
3. Extract 15+ pulse features (rise time, decay constant, etc.)
4. Visualize characteristic waveforms for each scintillator
5. Compare pulse shapes across scintillator types
6. Prepare feature dataset for ML classification

## Key Pulse Features

### Temporal Features:
- **Rise time (10-90%)**: Time from 10% to 90% of peak
- **Fall time (90-10%)**: Time from 90% to 10% after peak
- **Peak position**: Time of maximum amplitude
- **Decay constant (τ)**: Exponential decay time constant

### Amplitude Features:
- **Amplitude**: Peak height above baseline
- **Baseline**: Average of pre-trigger samples
- **Baseline RMS**: Noise level

### Charge Features:
- **Total charge**: Integral of entire pulse
- **Tail charge**: Integral of late-time region
- **Tail-to-total ratio**: PSD parameter for particle discrimination

### Shape Features:
- **FWHM**: Full width at half maximum
- **Skewness**: Asymmetry of pulse shape
- **Kurtosis**: "Tailedness" of distribution

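The feature definitions above can be made concrete with a short NumPy sketch. It is not the implementation inside `PulseFeatureExtractor`, which is not shown in this notebook: the helper names `pulse_features`, `_rise_cross` and `_fall_cross`, the 50-sample baseline window, the 40 ns tail delay and the positive pulse polarity are all assumptions for illustration. Threshold crossings are linearly interpolated between samples, charges are sample sums times the sample interval, and skewness and kurtosis treat the baseline-subtracted pulse as a weight distribution over time.

```python
import numpy as np

DT_NS = 4.0  # sample interval at 250 MS/s


def _rise_cross(t, y, level, peak):
    """Interpolated time at which the leading edge reaches `level`."""
    below = np.nonzero(y[:peak] < level)[0]
    if below.size == 0:
        return np.nan
    i = below[-1]  # last pre-peak sample below the level; sample i+1 is at or above it
    return np.interp(level, [y[i], y[i + 1]], [t[i], t[i + 1]])


def _fall_cross(t, y, level, peak):
    """Interpolated time at which the trailing edge drops below `level`."""
    below = np.nonzero(y[peak:] < level)[0]
    if below.size == 0:
        return np.nan
    j = peak + below[0]  # first post-peak sample below the level; sample j-1 is at or above it
    return np.interp(level, [y[j], y[j - 1]], [t[j], t[j - 1]])


def pulse_features(samples, dt_ns=DT_NS, n_baseline=50, tail_delay_ns=40.0):
    """Hypothetical feature sketch for one positive-going pulse (not PulseFeatureExtractor)."""
    samples = np.asarray(samples, dtype=float)
    t = np.arange(samples.size) * dt_ns
    pre = samples[:n_baseline]                  # pre-trigger samples
    y = samples - pre.mean()                    # baseline-subtracted pulse
    peak = int(np.argmax(y))
    amp = y[peak]

    start = _rise_cross(t, y, 0.1 * amp, peak)  # pulse onset (10% of peak)
    gate = t >= start                           # integrate from onset to end of record
    total = y[gate].sum() * dt_ns
    tail = y[t >= t[peak] + tail_delay_ns].sum() * dt_ns

    # Shape moments: treat the (clipped) pulse as a weight distribution over time
    w, tt = np.clip(y[gate], 0.0, None), t[gate]
    mu = np.average(tt, weights=w)
    sd = np.sqrt(np.average((tt - mu) ** 2, weights=w))

    return {
        "baseline": pre.mean(),
        "baseline_rms": pre.std(),
        "amplitude": amp,
        "peak_position": t[peak],
        "rise_time_10_90": _rise_cross(t, y, 0.9 * amp, peak) - start,
        "fall_time_90_10": _fall_cross(t, y, 0.1 * amp, peak) - _fall_cross(t, y, 0.9 * amp, peak),
        "width_fwhm": _fall_cross(t, y, 0.5 * amp, peak) - _rise_cross(t, y, 0.5 * amp, peak),
        "total_charge": total,
        "tail_charge": tail,
        "tail_total_ratio": tail / total,
        "skewness": np.average((tt - mu) ** 3, weights=w) / sd**3,
        "excess_kurtosis": np.average((tt - mu) ** 4, weights=w) / sd**4 - 3.0,
    }


# Example: noiseless pulse with a Gaussian leading edge (sigma 2 ns) and a 40 ns decay
t = np.arange(1000) * DT_NS
t_peak, tau = 400.0, 40.0
shape = np.where(t < t_peak,
                 np.exp(-0.5 * ((t - t_peak) / 2.0) ** 2),
                 np.exp(-np.clip(t - t_peak, 0, None) / tau))
f = pulse_features(3100 + 200 * shape)
print(f"rise {f['rise_time_10_90']:.2f} ns, fall {f['fall_time_90_10']:.1f} ns "
      f"(tau*ln9 = {tau * np.log(9):.1f} ns), tail/total {f['tail_total_ratio']:.3f}")
```

With 4 ns samples, the roughly 3.4 ns 10-90% rise of this Gaussian edge spans less than one sample, so the interpolated rise time comes out noticeably longer, while the 90-10% fall time of the 40 ns exponential lands close to τ ln 9.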
## Expected Decay Times

| Scintillator | Primary Decay (ns) | Secondary Decay (ns) | Rise Time (ns) |
|--------------|--------------------|----------------------|----------------|
| LYSO         | 40                | 53             | ~1             |
| BGO          | 300               | -              | ~2             |
| NaI(Tl)      | 230               | -              | ~2             |
| Plastic      | 2.4               | 14.2           | ~1             |

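As a quick arithmetic check on the table, the snippet below uses the primary decay constants and the 250 MS/s (4 ns) sampling from the setup cell to compute how many samples span one decay constant, and what fraction of a single-exponential pulse's light falls inside an integration gate of length T, namely 1 - exp(-T/τ). Secondary components are ignored, and the 200 ns and 1000 ns gates are arbitrary example values.

```python
import numpy as np

DT_NS = 4.0  # 250 MS/s sampling, as in the setup cell
PRIMARY_DECAY_NS = {"LYSO": 40.0, "BGO": 300.0, "NaI(Tl)": 230.0, "Plastic": 2.4}


def gate_fraction(gate_ns, tau_ns):
    """Fraction of a single-exponential pulse's light emitted within gate_ns of its onset."""
    return 1.0 - np.exp(-gate_ns / tau_ns)


for name, tau in PRIMARY_DECAY_NS.items():
    fracs = ", ".join(f"{g} ns gate: {gate_fraction(g, tau):6.1%}" for g in (200, 1000))
    print(f"{name:8s} tau = {tau / DT_NS:5.1f} samples | {fracs}")
```

The Plastic decay constant is shorter than one sample interval, so its decay is only coarsely sampled, while BGO needs a gate on the order of a microsecond to collect most of its light (about 96% within 1000 ns).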
In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from scipy import signal
from scipy.optimize import curve_fit
from typing import Dict, List, Tuple, Optional
import json
import pandas as pd
from tqdm.auto import tqdm

# Import our package modules
import sys
sys.path.append('..')

from src.io.caen_parsers import import_waveforms_csv, convert_csv_to_waveform_objects
from src.pulse_analysis.feature_extraction import PulseFeatureExtractor
from src.pulse_analysis.pulse_fitting import fit_exponential_decay

# Configure plotting
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 8)
plt.rcParams['font.size'] = 10

# CAEN DT5720D sampling rate
SAMPLING_RATE_MHZ = 250.0  # 250 MS/s
DT_NS = 1000.0 / SAMPLING_RATE_MHZ  # Time step in nanoseconds

print("Pulse Shape Analysis Notebook - Ready")
print(f"Sampling rate: {SAMPLING_RATE_MHZ} MS/s")
print(f"Time step: {DT_NS:.3f} ns")

## 1. Load Energy Calibration

Load the calibration parameters from Notebook 02 to convert ADC values to energy.

In [None]:
# Load energy calibration from previous notebook
calibration_path = Path('../data/processed/energy_calibration.json')

calibrations = {}
if calibration_path.exists():
    with open(calibration_path, 'r') as f:
        calibrations = json.load(f)
    print("Energy calibrations loaded:")
    for scint, cal in calibrations.items():
        print(f"  {scint}: E = {cal['slope']:.4f} × ch + {cal['offset']:.2f} keV")
else:
    print(f"Calibration file not found: {calibration_path}")
    print("Using placeholder calibrations (illustrative values, not measured)...")
    calibrations = {
        'LYSO': {'slope': 0.5, 'offset': 0.0},
        'BGO': {'slope': 0.4, 'offset': 0.0},
        'NaI': {'slope': 0.6, 'offset': 0.0},
        'Plastic': {'slope': 0.3, 'offset': 0.0}
    }

def channel_to_energy(channel, scintillator):
    """Convert an ADC channel value to energy (keV); return the input unchanged if no calibration exists"""
    if scintillator in calibrations:
        cal = calibrations[scintillator]
        return cal['slope'] * channel + cal['offset']
    return channel  # Return raw channel if no calibration

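The calibration file is expected to hold one `{'slope', 'offset'}` pair per scintillator. How Notebook 02 derives these is not shown here; a minimal sketch, assuming a linear model E = slope × ch + offset fitted by least squares (`np.polyfit` with degree 1) to known photopeak energies, could look like this. The helper name and the channel values are hypothetical; the gamma-ray energies (Na-22 at 511.0 and 1274.5 keV, Cs-137 at 661.7 keV) are standard line energies.

```python
import json
import numpy as np


def fit_linear_calibration(channels, energies_keV):
    """Least-squares fit of E = slope * ch + offset (hypothetical helper)."""
    slope, offset = np.polyfit(np.asarray(channels, dtype=float),
                               np.asarray(energies_keV, dtype=float), 1)
    return {"slope": float(slope), "offset": float(offset)}


# Known photopeak energies (keV) -> fitted peak centroids (ADC channels; made-up values)
photopeaks = {511.0: 1022.0, 661.7: 1323.4, 1274.5: 2549.0}
cal = fit_linear_calibration(list(photopeaks.values()), list(photopeaks.keys()))
print(json.dumps({"LYSO": cal}, indent=2))
```

Because Section 4 applies `channel_to_energy` to the pulse `amplitude`, a calibration like this only transfers if its channels were measured in the same quantity (peak height rather than, say, integrated charge).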
## 2. Load Waveform Data

Load CSV files containing event waveforms for each scintillator and source.

In [None]:
# Define data directory
data_dir = Path('../data/raw')  # Update this path

# Define scintillators and sources
scintillators = ['LYSO', 'BGO', 'NaI', 'Plastic']
sources = ['Cs137', 'Co60', 'Na22', 'Am241', 'Background']

# Load waveform data
waveforms = {}

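# Find CSV files in the data directory. Match the extension case-insensitively in one pass:
# globbing '*.CSV' and '*.csv' separately can list each file twice where globbing is
# case-insensitive (e.g. on Windows).
if data_dir.exists():
    csv_files = sorted(p for p in data_dir.iterdir() if p.is_file() and p.suffix.lower() == '.csv')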
    print(f"Found {len(csv_files)} CSV files in {data_dir}\n")
    
    for csv_file in csv_files:
        filename = csv_file.name
        print(f"Loading: {filename}")
        
        # Parse scintillator and source from filename
        scint = None
        source = None
        
        for s in scintillators:
            if s.lower() in filename.lower():
                scint = s
                break
        
        for src in sources:
            if src.lower() in filename.lower():
                source = src
                break
        
        if scint and source:
            # Load waveforms (limit to first 1000 events for speed)
            events = import_waveforms_csv(str(csv_file), max_events=1000)
            
            # Convert to Waveform objects
            wf_objects = convert_csv_to_waveform_objects(
                events, 
                scintillator=scint,
                source=source,
                sampling_rate_MHz=SAMPLING_RATE_MHZ
            )
            
            # Store
            if scint not in waveforms:
                waveforms[scint] = {}
            waveforms[scint][source] = wf_objects
            
            print(f"  Scintillator: {scint}")
            print(f"  Source: {source}")
            print(f"  Events loaded: {len(wf_objects)}")
            if len(wf_objects) > 0:
                print(f"  Waveform length: {len(wf_objects[0].waveform)} samples")
                print(f"  Duration: {len(wf_objects[0].waveform) * DT_NS:.1f} ns\n")
        else:
            print("  Could not parse scintillator/source from filename; skipped\n")
else:
    print(f"Data directory not found: {data_dir}")
    print("Creating synthetic waveforms for demonstration...\n")
    
    # Create synthetic waveforms
    from src.io.waveform_loader import Waveform
    
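    np.random.seed(0)  # reproducible synthetic data
    
    def create_synthetic_pulse(decay_time_ns, amplitude, baseline=3100, noise_level=5,
                               peak_time_ns=240.0, rise_sigma_ns=2.0, n_samples=1000):
        """Create a synthetic positive-going scintillator pulse.
        
        Gaussian leading edge up to the peak, single-exponential decay after it, so the
        noiseless peak height equals `amplitude`. With the default peak_time_ns, the first
        50 samples (200 ns) contain only baseline and noise.
        """
        time_ns = np.arange(n_samples) * DT_NS
        
        # Leading edge: Gaussian of width rise_sigma_ns, used only before the peak
        rise = np.exp(-0.5 * ((time_ns - peak_time_ns) / rise_sigma_ns) ** 2)
        # Trailing edge: exponential decay from 1 at the peak (clip avoids growth before the peak)
        decay = np.exp(-np.clip(time_ns - peak_time_ns, 0, None) / decay_time_ns)
        shape = np.where(time_ns < peak_time_ns, rise, decay)
        
        # Add baseline and Gaussian electronic noise, then round to integer ADC counts
        pulse = baseline + amplitude * shape + np.random.normal(0, noise_level, n_samples)
        return np.rint(pulse).astype(int)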
    
    # Decay times and amplitudes for each scintillator
    decay_params = {
        'LYSO': (40, 200),
        'BGO': (300, 150),
        'NaI': (230, 250),
        'Plastic': (2.4, 180)
    }
    
    for scint in scintillators:
        waveforms[scint] = {}
        decay_ns, amp = decay_params[scint]
        
        for source in ['Cs137', 'Co60']:
            wf_list = []
            for _ in range(500):
                # Vary amplitude by ±20%
                pulse_amp = amp * np.random.uniform(0.8, 1.2)
                samples = create_synthetic_pulse(decay_ns, pulse_amp)
                
                wf = Waveform(
                    waveform=samples,
                    timestamp=0.0,
                    baseline=3100,
                    amplitude=pulse_amp,
                    scintillator=scint,
                    source=source
                )
                wf_list.append(wf)
            
            waveforms[scint][source] = wf_list

print(f"\n{'='*60}")
print(f"Loaded waveforms for {len(waveforms)} scintillators")
for scint, sources_dict in waveforms.items():
    total_events = sum(len(wf_list) for wf_list in sources_dict.values())
    print(f"  {scint}: {total_events} events from {list(sources_dict.keys())}")

## 3. Visualize Representative Waveforms

Plot example waveforms from each scintillator to visualize characteristic pulse shapes.

In [None]:
# Plot representative waveforms for each scintillator
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

colors = plt.cm.tab10(np.linspace(0, 1, 10))

for idx, scint in enumerate(scintillators):
    ax = axes[idx]
    plt.sca(ax)
    
    if scint not in waveforms or len(waveforms[scint]) == 0:
        plt.text(0.5, 0.5, f'{scint}\nNo data', ha='center', va='center', transform=ax.transAxes)
        plt.axis('off')
        continue
    
    # Get first available source
    source = list(waveforms[scint].keys())[0]
    wf_list = waveforms[scint][source]
    
    # Plot first 5 waveforms
    for i, wf in enumerate(wf_list[:5]):
        samples = wf.waveform
        time_ns = np.arange(len(samples)) * DT_NS
        
        plt.plot(time_ns, samples, linewidth=1.5, alpha=0.7, color=colors[i],
                label=f'Event {i+1}')
    
    plt.xlabel('Time (ns)', fontsize=11)
    plt.ylabel('ADC Value', fontsize=11)
    plt.title(f'{scint} - {source}', fontsize=12, fontweight='bold')
    plt.grid(True, alpha=0.3)
    plt.legend(fontsize=9)

plt.tight_layout()
plt.show()

print("Representative waveforms plotted")

## 4. Extract Pulse Features

Use `PulseFeatureExtractor` to compute the pulse features described in the overview for each waveform.

In [None]:
# Initialize feature extractor
feature_extractor = PulseFeatureExtractor(sampling_rate_MHz=SAMPLING_RATE_MHZ)

# Extract features for all waveforms
all_features = []
n_failed = 0

print("Extracting pulse features...\n")

for scint in scintillators:
    if scint not in waveforms:
        continue
    
    for source, wf_list in waveforms[scint].items():
        print(f"Processing {scint} - {source}: {len(wf_list)} events")
        
        for wf in tqdm(wf_list, desc=f"{scint}-{source}", leave=False):
            try:
                # Extract features (the result is used as a dict below)
                features = feature_extractor.extract_features(wf.waveform)
            except Exception:
                # Skip waveforms that fail feature extraction, but count them
                n_failed += 1
                continue
            
            # Add metadata
            features['scintillator'] = scint
            features['source'] = source
            features['timestamp'] = getattr(wf, 'timestamp', 0)
            
            # Convert amplitude to energy if a calibration is available
            # (assumes the Notebook 02 calibration is expressed in the same units as 'amplitude')
            if 'amplitude' in features:
                features['energy_keV'] = channel_to_energy(features['amplitude'], scint)
            
            all_features.append(features)

if n_failed:
    print(f"Skipped {n_failed} waveforms that failed feature extraction")

# Convert to DataFrame
df_features = pd.DataFrame(all_features)

print(f"\n{'='*60}")
print(f"Extracted features for {len(df_features)} waveforms")
print(f"Feature columns: {len(df_features.columns)}")
print(f"\nFeature list:")
feature_cols = [col for col in df_features.columns if col not in ['scintillator', 'source', 'timestamp']]
for i, col in enumerate(feature_cols, 1):
    print(f"  {i:2d}. {col}")

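# Display the number of successfully processed events per scintillator
print(f"\n{'='*60}")
print("Events per scintillator:")
if 'scintillator' in df_features.columns:
    print(df_features.groupby('scintillator').size())
else:
    print("  No features were extracted; check the data-loading step above.")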

## 5. Compare Decay Times

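Compare the measured decay constants for each scintillator with the nominal values in the table above. At 250 MS/s (4 ns per sample), the 2.4 ns plastic decay constant is shorter than one sample interval, so it can only be coarsely resolved.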

In [None]:
# Expected decay times
expected_decay = {
    'LYSO': 40,
    'BGO': 300,
    'NaI': 230,
    'Plastic': 2.4
}

# Plot decay time distributions
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

for idx, scint in enumerate(scintillators):
    ax = axes[idx]
    plt.sca(ax)
    
    # Get decay times for this scintillator
    scint_data = df_features[df_features['scintillator'] == scint]
    
    if len(scint_data) > 0 and 'decay_constant' in scint_data.columns:
        decay_times = scint_data['decay_constant'].dropna()
        
        # Remove outliers (> 3 sigma)
        mean_decay = decay_times.mean()
        std_decay = decay_times.std()
        decay_times_filtered = decay_times[
            (decay_times > mean_decay - 3*std_decay) & 
            (decay_times < mean_decay + 3*std_decay)
        ]
        
        # Plot histogram
        plt.hist(decay_times_filtered, bins=50, alpha=0.7, color='blue', 
                edgecolor='black', label='Measured')
        
        # Plot expected value
        expected = expected_decay[scint]
        plt.axvline(expected, color='red', linestyle='--', linewidth=2,
                   label=f'Expected ({expected} ns)')
        
        # Annotate
        measured_mean = decay_times_filtered.mean()
        measured_std = decay_times_filtered.std()
        plt.text(0.6, 0.95, 
                f"Measured: {measured_mean:.1f} ± {measured_std:.1f} ns\n"
                f"Expected: {expected} ns\n"
                f"Difference: {abs(measured_mean - expected):.1f} ns",
                transform=ax.transAxes, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8),
                fontsize=9)
        
        plt.xlabel('Decay Time (ns)', fontsize=11)
        plt.ylabel('Count', fontsize=11)
        plt.title(f'{scint} Decay Time', fontsize=12, fontweight='bold')
        plt.legend(fontsize=10)
        plt.grid(True, alpha=0.3)
    else:
        plt.text(0.5, 0.5, f'{scint}\nNo data', ha='center', va='center', transform=ax.transAxes)
        plt.axis('off')

plt.tight_layout()
plt.show()

print("Decay time analysis complete")

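How `decay_constant` is computed inside `PulseFeatureExtractor` (or `fit_exponential_decay`) is not shown in this notebook. One simple estimate that is easy to check by hand is a log-linear fit: for y(t) = A exp(-t/τ), ln y is a straight line in t with slope -1/τ. The sketch below (hypothetical helper `loglinear_decay_ns`) fits the falling edge between 80% and 5% of the peak. It assumes a baseline-subtracted positive pulse and a single decay component, so a two-component pulse (LYSO and Plastic in the table above) would give an effective value between the two constants.

```python
import numpy as np


def loglinear_decay_ns(y, dt_ns=4.0, start_frac=0.8, stop_frac=0.05):
    """Estimate one decay constant from a straight-line fit of ln(y) versus time.

    y: baseline-subtracted, positive-going pulse (hypothetical helper).
    Uses post-peak samples at or below start_frac * peak, stopping before the
    signal first drops below stop_frac * peak.
    """
    y = np.asarray(y, dtype=float)
    peak = int(np.argmax(y))
    tail = y[peak:]
    t = np.arange(tail.size) * dt_ns
    below = np.nonzero(tail < stop_frac * tail[0])[0]
    end = below[0] if below.size else tail.size   # stop before the signal reaches the noise
    idx = np.arange(end)
    idx = idx[tail[idx] <= start_frac * tail[0]]  # skip the rounded top of the pulse
    if idx.size < 3:
        return np.nan
    slope, _ = np.polyfit(t[idx], np.log(tail[idx]), 1)
    return -1.0 / slope


t = np.arange(1000) * 4.0
rng = np.random.default_rng(0)
noisy = 2000.0 * np.exp(-t / 40.0) + rng.normal(0.0, 5.0, t.size)
print(f"noiseless BGO-like: {loglinear_decay_ns(1000.0 * np.exp(-t / 300.0)):.1f} ns")
print(f"noisy LYSO-like:    {loglinear_decay_ns(noisy):.1f} ns")
```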
## 6. Visualize Feature Correlations

Compute the Pearson correlation matrix of key pulse features over all events, with every scintillator pooled into one dataset.

In [None]:
# Select key features for correlation analysis
key_features = [
    'amplitude', 'rise_time_10_90', 'fall_time_90_10', 'decay_constant',
    'total_charge', 'tail_total_ratio', 'width_fwhm', 'peak_position'
]

# Filter to available features
available_features = [f for f in key_features if f in df_features.columns]

if len(available_features) >= 4:
    # Compute correlation matrix
    corr_matrix = df_features[available_features].corr()
    
    # Plot heatmap
    plt.figure(figsize=(12, 10))
    sns.heatmap(corr_matrix, annot=True, fmt='.2f', cmap='coolwarm', 
                center=0, square=True, linewidths=1,
                cbar_kws={'label': 'Correlation Coefficient'})
    plt.title('Pulse Feature Correlation Matrix', fontsize=14, fontweight='bold', pad=20)
    plt.tight_layout()
    plt.show()
    
    print("\nHighly correlated feature pairs (|r| > 0.7):")
    for i in range(len(corr_matrix)):
        for j in range(i+1, len(corr_matrix)):
            if abs(corr_matrix.iloc[i, j]) > 0.7:
                print(f"  {corr_matrix.index[i]} <-> {corr_matrix.columns[j]}: "
                      f"{corr_matrix.iloc[i, j]:.3f}")
else:
    print("Insufficient features for correlation analysis")

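Because the matrix pools events from all scintillators, a large |r| can come from differences between scintillator types rather than from a relationship within any one detector. The synthetic example below (NumPy only, made-up data) shows two features that are independent within each of two classes but shifted together between them: the pooled correlation is close to 1 while the within-class correlations are near 0. In this notebook, a per-scintillator check could use pandas' `df_features.groupby('scintillator')[available_features].corr()`.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 500
# Two made-up classes; within each, x and y are independent unit-variance noise,
# but class B is shifted by +10 in both features (like two detector types).
x = np.concatenate([rng.normal(0.0, 1.0, n), rng.normal(10.0, 1.0, n)])
y = np.concatenate([rng.normal(0.0, 1.0, n), rng.normal(10.0, 1.0, n)])
labels = np.repeat(["A", "B"], n)

pooled_r = np.corrcoef(x, y)[0, 1]
within_r = {c: np.corrcoef(x[labels == c], y[labels == c])[0, 1] for c in ("A", "B")}
print(f"pooled r = {pooled_r:.2f}")
print("within-class r: " + ", ".join(f"{c} = {r:.2f}" for c, r in within_r.items()))
```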
## 7. Compare Pulse Shapes Across Scintillators

Overlay the average baseline-subtracted, peak-normalized pulse of each scintillator on linear and logarithmic scales to compare their characteristic shapes.

In [None]:
# Plot the average pulse shape for each scintillator on linear and log scales
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 6))

colors_scint = {'LYSO': 'red', 'BGO': 'blue', 'NaI': 'green', 'Plastic': 'purple'}

def average_pulse(wf_list, n_baseline=50, min_amp=10):
    """Baseline-subtract, peak-normalize and average waveforms; None if no valid pulses."""
    wf_arrays = []
    for wf in wf_list:
        samples = wf.waveform - wf.waveform[:n_baseline].mean()  # baseline subtract
        if samples.max() > min_amp:  # keep only valid pulses
            wf_arrays.append(samples / samples.max())  # normalize peak to 1
    return np.mean(wf_arrays, axis=0) if wf_arrays else None

for scint in scintillators:
    if scint not in waveforms or len(waveforms[scint]) == 0:
        continue
    
    # Use the first 100 events from the first available source
    source = list(waveforms[scint].keys())[0]
    avg_wf = average_pulse(waveforms[scint][source][:100])
    if avg_wf is None:
        continue
    
    time_ns = np.arange(len(avg_wf)) * DT_NS
    ax1.plot(time_ns, avg_wf, linewidth=2.5, label=scint, color=colors_scint[scint])
    # Small offset avoids log(0) on the log-scale panel
    ax2.plot(time_ns, avg_wf + 1e-4, linewidth=2.5, label=scint, color=colors_scint[scint])

# Left panel: linear scale
ax1.set_xlabel('Time (ns)', fontsize=12)
ax1.set_ylabel('Normalized Amplitude', fontsize=12)
ax1.set_title('Average Pulse Shapes (Linear)', fontsize=13, fontweight='bold')
ax1.legend(fontsize=11)
ax1.grid(True, alpha=0.3)
ax1.set_xlim(0, 500)

# Right panel: log scale (emphasizes the tail)
ax2.set_xlabel('Time (ns)', fontsize=12)
ax2.set_ylabel('Normalized Amplitude (log)', fontsize=12)
ax2.set_title('Average Pulse Shapes (Log)', fontsize=13, fontweight='bold')
ax2.legend(fontsize=11)
ax2.grid(True, alpha=0.3)
ax2.set_yscale('log')
ax2.set_xlim(0, 1000)
ax2.set_ylim(1e-3, 2)

plt.tight_layout()
plt.show()

print("Pulse shape comparison complete")

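The averages above are taken sample by sample, which assumes every pulse rises at the same sample index; trigger jitter of a few samples smears the leading edge of the average. One alternative, sketched below, aligns each baseline-subtracted, peak-normalized pulse on its maximum sample before averaging. The helper `peak_aligned_average` and its window parameters `pre` and `post` are hypothetical; it keeps the 50-sample baseline and the 10-count validity cut from the cell above, assumes positive-going pulses, and takes plain arrays (e.g. `[wf.waveform for wf in wf_list]`).

```python
import numpy as np


def peak_aligned_average(waveforms, n_baseline=50, pre=20, post=200, min_amp=10.0):
    """Average peak-normalized pulses after aligning them on their maximum sample.

    Hypothetical helper. Pulses with amplitude <= min_amp, or whose peak is too close
    to either end of the record for a [peak - pre, peak + post) window, are skipped.
    """
    aligned = []
    for wf in waveforms:
        y = np.asarray(wf, dtype=float)
        y = y - y[:n_baseline].mean()            # baseline subtraction
        peak = int(np.argmax(y))
        if y[peak] <= min_amp or peak < pre or peak + post > y.size:
            continue
        aligned.append(y[peak - pre: peak + post] / y[peak])
    return np.mean(aligned, axis=0) if aligned else None


# Example: identical 40 ns pulses whose onset jitters by up to +/-5 samples (4 ns each)
rng = np.random.default_rng(2)
t = np.arange(1000) * 4.0
onsets = 400.0 + 4.0 * rng.integers(-5, 6, size=50)
pulses = [3100 + 200 * np.where(t >= t0, np.exp(-np.clip(t - t0, 0, None) / 40.0), 0.0)
          for t0 in onsets]
avg = peak_aligned_average(pulses)
print(avg.shape, int(avg.argmax()))
```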
## 8. Feature Distributions by Scintillator

Visualize distributions of key features for each scintillator type.

In [None]:
# Select features to plot
plot_features = ['rise_time_10_90', 'fall_time_90_10', 'tail_total_ratio', 'width_fwhm']
plot_features = [f for f in plot_features if f in df_features.columns]

if len(plot_features) >= 2:
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()
    
    for idx, feature in enumerate(plot_features[:4]):
        ax = axes[idx]
        plt.sca(ax)
        
        # Plot violin plots for each scintillator
        scint_data = []
        scint_labels = []
        
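        for scint in scintillators:
            data = df_features[df_features['scintillator'] == scint][feature].dropna()
            if len(data) > 0:
                # Remove outliers outside the Tukey fences (1.5 x IQR)
                q1, q3 = data.quantile([0.25, 0.75])
                iqr = q3 - q1
                data_filtered = data[(data >= q1 - 1.5*iqr) & (data <= q3 + 1.5*iqr)]
                # Skip groups without spread: the violin's kernel density estimate needs >1 distinct value
                if data_filtered.nunique() > 1:
                    scint_data.append(data_filtered)
                    scint_labels.append(scint)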
        
        if len(scint_data) > 0:
            parts = plt.violinplot(scint_data, positions=range(len(scint_data)),
                                  showmeans=True, showmedians=True)
            
            # Color violins
            for i, pc in enumerate(parts['bodies']):
                scint = scint_labels[i]
                pc.set_facecolor(colors_scint[scint])
                pc.set_alpha(0.6)
            
            plt.xticks(range(len(scint_labels)), scint_labels)
            plt.ylabel(feature.replace('_', ' ').title(), fontsize=11)
            plt.title(f'{feature.replace("_", " ").title()} Distribution', 
                     fontsize=12, fontweight='bold')
            plt.grid(True, alpha=0.3, axis='y')
    
    # Hide unused subplots
    for idx in range(len(plot_features), 4):
        axes[idx].axis('off')
    
    plt.tight_layout()
    plt.show()
else:
    print("Insufficient features for distribution plots")

print("Feature distribution visualization complete")

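The outlier cut in the cell above keeps values within the Tukey fences [Q1 - 1.5·IQR, Q3 + 1.5·IQR]. A NumPy version that can be reused and tested on its own is sketched below (hypothetical helper `tukey_filter`); `np.percentile` defaults to linear interpolation, as does pandas' `Series.quantile`, so the fences match.

```python
import numpy as np


def tukey_filter(values, k=1.5):
    """Drop NaNs and values outside [Q1 - k*IQR, Q3 + k*IQR] (hypothetical helper)."""
    v = np.asarray(values, dtype=float)
    v = v[~np.isnan(v)]
    q1, q3 = np.percentile(v, [25, 75])
    iqr = q3 - q1
    return v[(v >= q1 - k * iqr) & (v <= q3 + k * iqr)]


print(tukey_filter([1.0, 2.0, 3.0, 4.0, 100.0, np.nan]))
```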
## 9. Save Feature Dataset

Save the extracted features to CSV for use in ML classification (Notebook 04).

In [None]:
# Save feature dataset
output_path = Path('../data/processed/pulse_features.csv')
output_path.parent.mkdir(parents=True, exist_ok=True)

df_features.to_csv(output_path, index=False)

print(f"Feature dataset saved to: {output_path}")
print(f"  Shape: {df_features.shape}")
print(f"  Scintillators: {df_features['scintillator'].unique().tolist()}")
print(f"  Features: {len([c for c in df_features.columns if c not in ['scintillator', 'source', 'timestamp']])}")

# Display sample
print("\nSample rows:")
print(df_features.head())

## Summary

This notebook covered the following pulse shape analysis steps:

1. **Loaded waveform data** from CSV files (CAEN digitizer format)
2. **Applied energy calibration** from Notebook 02
3. **Extracted 15+ pulse features** including temporal, amplitude, charge, and shape parameters
4. **Visualized characteristic waveforms** for each scintillator type
5. **Analyzed decay time constants** and compared to expected values
6. **Examined feature correlations** to understand parameter relationships
7. **Compared pulse shapes** across all four scintillators
8. **Saved feature dataset** for machine learning classification

### Key Observations:

- **LYSO**: Fast rise and decay (~40 ns), high amplitude
- **BGO**: Slow decay (~300 ns), broad pulse shape
- **NaI(Tl)**: Medium decay (~230 ns), excellent energy resolution
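- **Plastic**: Very fast (~2.4 ns primary decay, below the 4 ns sample interval), with a slower secondary component (~14 ns) contributing to the tail used for PSD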

### Discriminating Features:

The features expected to discriminate most strongly between scintillator types are:
1. **Decay constant** - Primary discriminator (2.4 vs 40 vs 230 vs 300 ns)
2. **Rise time** - Fast vs slow scintillators
3. **Tail-to-total ratio** - Particle discrimination capability
4. **Pulse width (FWHM)** - Overall temporal extent

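The ranking above is not computed anywhere in this notebook. One simple way to check it against the extracted features is a per-feature Fisher score: the event-weighted scatter of the class means around the overall mean divided by the pooled within-class variance, so larger values mean better separation by that feature alone. The sketch below (hypothetical helper `fisher_scores`, NumPy only, synthetic data) scores each feature independently and ignores correlations between features.

```python
import numpy as np


def fisher_scores(X, labels):
    """Per-feature Fisher score: between-class scatter / within-class scatter (hypothetical helper).

    X: (n_events, n_features) array; labels: (n_events,) class labels.
    """
    X = np.asarray(X, dtype=float)
    labels = np.asarray(labels)
    mu = X.mean(axis=0)
    between = np.zeros(X.shape[1])
    within = np.zeros(X.shape[1])
    for c in np.unique(labels):
        Xc = X[labels == c]
        between += len(Xc) * (Xc.mean(axis=0) - mu) ** 2
        within += len(Xc) * Xc.var(axis=0)
    return between / within


# Synthetic check: feature 0 separates the classes (means 0 vs 5), feature 1 does not
rng = np.random.default_rng(3)
n = 500
X = np.column_stack([
    np.concatenate([rng.normal(0.0, 1.0, n), rng.normal(5.0, 1.0, n)]),
    rng.normal(0.0, 1.0, 2 * n),
])
labels = np.repeat(["fast", "slow"], n)
scores = fisher_scores(X, labels)
print(scores)
```

Applied here, `X` would be the numeric feature columns of `df_features` with NaN rows dropped, and `labels` the matching `scintillator` column.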
### Next Steps:

The feature dataset is ready for:
- **Notebook 04**: ML classification using traditional and deep learning models
- **Notebook 06**: SiPM characterization using pulse features
- **Notebook 07**: Comprehensive performance comparison