# Notebook 3: Advanced Feature Extraction

## Introduction

### Beyond Simple PSD Ratio

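The basic tail-to-total pulse-shape discrimination (PSD) ratio works well, but each waveform carries more information than a single ratio. This notebook extracts **several dozen pulse-shape features** (roughly 70 in the implementation below) from each waveform to improve discrimination: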

**Feature Categories:**
1. **Multiple Charge Ratios**: Different gate pairs (fast, medium, slow)
2. **Rise Time Features**: 10-90%, 20-80%, 10-50%, time to peak
3. **Shape Moments**: Mean, variance, skewness, kurtosis
4. **Cumulative Charge**: Time to reach 10%, 20%, ..., 90% of total charge
5. **Decay Parameters**: Bi-exponential fit (fast/slow components)
6. **Frequency Domain**: FFT-based spectral centroid, spectral variance, and band power
7. **Time-over-Threshold**: Duration above various thresholds
8. **Template Matching**: Correlation with neutron/gamma templates
9. **Gatti Filter**: Linear filter weighted by the neutron/gamma template difference
10. **Wavelet Features**: Multi-resolution decomposition

### Why So Many Features?

**Advantages:**
- Machine learning can find optimal combinations
- Different features work better at different energies
- Redundancy provides robustness to noise
- Physics-informed features improve interpretability

**Physics Motivation:**
- Scintillation process is complex (multiple decay components)
- Different features capture different physical phenomena
- Energy-dependent quenching affects pulse shape

### Learning Objectives

1. Extract timing features from waveforms
2. Calculate cumulative charge timestamps
3. Perform bi-exponential decay fits
4. Extract frequency-domain features
5. Implement template matching and Gatti filter
6. Analyze feature importance
7. Handle quality control (saturation, pile-up)

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy import signal, optimize, stats
from scipy.stats import skew, kurtosis
import pywt  # pip install PyWavelets

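# The 'seaborn-v0_8-*' style names exist from Matplotlib 3.6; fall back on older versions
try:
    plt.style.use('seaborn-v0_8-darkgrid')
except OSError:
    plt.style.use('seaborn-darkgrid')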
plt.rcParams['figure.figsize'] = (14, 6)
np.random.seed(42)

print("✓ Libraries imported")

## 1. Generate Synthetic Waveforms

In [None]:
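def generate_waveform(particle_type, energy_kev, 
                     sampling_rate_mhz=250, num_samples=368):
    """Generate a negative-going bi-exponential scintillation pulse (14-bit ADC counts)"""
    dt = 1000.0 / sampling_rate_mhz  # ns per sample
    time = np.arange(num_samples) * dt  # ns
    
    tau_fast = 3.2   # fast decay constant (ns)
    tau_slow = 32.0  # slow decay constant (ns)
    
    # Neutron pulses carry a larger share of their light in the slow component
    if particle_type == 'gamma':
        fast_fraction = 0.75
    else:  # neutron
        fast_fraction = 0.55
    
    amplitude = energy_kev * 3.0  # peak height in ADC counts (toy calibration)
    t0 = 200  # pulse onset (ns)
    
    pulse = np.zeros_like(time)
    active_time = time - t0
    valid = active_time >= 0
    
    pulse[valid] = amplitude * (
        fast_fraction * np.exp(-active_time[valid] / tau_fast) +
        (1 - fast_fraction) * np.exp(-active_time[valid] / tau_slow)
    )
    
    baseline = 8192  # mid-scale of the 14-bit range
    waveform = baseline - pulse
    waveform += np.random.normal(0, 10, num_samples)  # Gaussian noise, sigma = 10 ADC counts
    waveform = np.clip(waveform, 0, 16383)
    
    # Round to the nearest ADC count (a plain int cast would truncate toward zero)
    return np.rint(waveform).astype(int)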

# Generate example waveforms
gamma_wf = generate_waveform('gamma', 500)
neutron_wf = generate_waveform('neutron', 500)

print(f"✓ Generated example waveforms")
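
Before extracting features, it helps to know how much separation the synthetic model can provide in principle. The sketch below integrates the noiseless bi-exponential pulse analytically and reports the fraction of charge arriving later than a gate time measured from the pulse onset. It reuses the decay constants and fast fractions from `generate_waveform`; the helper name `tail_fraction` is an illustrative assumption, not part of the notebook's code.

```python
import numpy as np

# Model parameters copied from generate_waveform above
TAU_FAST_NS = 3.2
TAU_SLOW_NS = 32.0
FAST_FRACTION = {'gamma': 0.75, 'neutron': 0.55}

def tail_fraction(fast_fraction, gate_ns, tau_fast=TAU_FAST_NS, tau_slow=TAU_SLOW_NS):
    """Fraction of total charge arriving later than gate_ns after onset (noiseless model)."""
    # The integral of A*exp(-t/tau) from T to infinity is A*tau*exp(-T/tau)
    fast_tail = fast_fraction * tau_fast * np.exp(-gate_ns / tau_fast)
    slow_tail = (1 - fast_fraction) * tau_slow * np.exp(-gate_ns / tau_slow)
    total = fast_fraction * tau_fast + (1 - fast_fraction) * tau_slow
    return (fast_tail + slow_tail) / total

for gate_ns in (20, 60, 200):
    g = tail_fraction(FAST_FRACTION['gamma'], gate_ns)
    n = tail_fraction(FAST_FRACTION['neutron'], gate_ns)
    print(f"gate {gate_ns:>3} ns: gamma {g:.4f}, neutron {n:.4f}, difference {n - g:.4f}")
```

With these decay constants, less than 1% of the charge arrives later than 200 ns after onset for either particle type, so in this synthetic model most of the separation sits in gates of a few tens of nanoseconds. Detectors with longer slow components will behave differently.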

## 2. Feature Extractor Class

In [None]:
class FeatureExtractor:
    """
    Comprehensive feature extraction for PSD analysis
    Extracts several dozen features (roughly 70) from each waveform
    """
    
    def __init__(self, sampling_rate_mhz=250, baseline_samples=50):
        self.sampling_rate = sampling_rate_mhz
        self.dt = 1000.0 / sampling_rate_mhz  # ns per sample
        self.baseline_samples = baseline_samples
        
        # Will be set during template building
        self.neutron_template = None
        self.gamma_template = None
        self.gatti_weights = None
    
    def extract_all_features(self, waveform):
        """
        Extract complete feature set from waveform
        
        Returns:
        --------
        features : dict
            Dictionary of all extracted features
        """
        features = {}
        
        # 1. Quality control
        features.update(self._qc_features(waveform))
        
        if features.get('saturated', False):
            return features  # Skip bad pulses
        
        # 2. Baseline characterization
        baseline = np.mean(waveform[:self.baseline_samples])
        baseline_rms = np.std(waveform[:self.baseline_samples])
        features['baseline_mean'] = baseline
        features['baseline_rms'] = baseline_rms
        
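        # 3. Baseline-subtract and normalize
        pulse = baseline - waveform
        # Clip negative samples to zero. Note that this rectifies baseline noise,
        # which adds a positive bias to integrals taken over long gates.
        pulse[pulse < 0] = 0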
        amplitude = np.max(pulse)
        
        if amplitude < 10:
            return features  # Too small
        
        pulse_norm = pulse / amplitude
        features['amplitude'] = amplitude
        
        # 4. Multiple charge ratios (KEY FEATURES)
        features.update(self._charge_ratio_features(pulse))
        
        # 5. Rise time features
        features.update(self._rise_time_features(pulse_norm))
        
        # 6. Cumulative charge timestamps
        features.update(self._cumulative_charge_features(pulse))
        
        # 7. Shape moments
        features.update(self._shape_moments(pulse_norm))
        
        # 8. Time-over-threshold
        features.update(self._tot_features(pulse_norm))
        
        # 9. Decay fit
        features.update(self._decay_fit_features(pulse))
        
        # 10. Frequency domain
        features.update(self._frequency_features(pulse_norm))
        
        # 11. Template matching (if templates available)
        if self.neutron_template is not None:
            features.update(self._template_features(pulse_norm))
        
        # 12. Wavelet features
        features.update(self._wavelet_features(pulse))
        
        return features
    
    def _qc_features(self, waveform):
        """Quality control features"""
        features = {}
        
        # Saturation check
        saturated = (waveform <= 10) | (waveform >= 16373)
        features['saturated'] = saturated.any()
        features['n_saturated_samples'] = saturated.sum()
        
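        # Pile-up detection: count prominent peaks in the baseline-subtracted pulse.
        # Counting raw derivative sign changes would mostly count noise wiggles.
        # The prominence thresholds below are starting values and should be tuned.
        base = waveform[:self.baseline_samples]
        pulse = np.mean(base) - waveform  # negative-going pulse made positive
        min_prominence = max(8 * np.std(base), 0.2 * pulse.max())
        peaks, _ = signal.find_peaks(pulse, prominence=min_prominence)
        n_peaks = len(peaks)
        features['n_peaks'] = n_peaks
        features['pile_up_likely'] = n_peaks > 1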
        
        return features
    
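    def _charge_ratio_features(self, pulse):
        """Multiple charge integration ratios, with gates measured from pulse onset"""
        features = {}
        
        # Locate the pulse onset: the sample after the last one before the peak
        # that is still below 10% of the peak. Gates measured from the start of
        # the trace would integrate only pre-trigger baseline.
        peak_idx = int(np.argmax(pulse))
        below = np.where(pulse[:peak_idx] < 0.1 * pulse[peak_idx])[0]
        onset_idx = int(below[-1]) + 1 if len(below) > 0 else 0
        
        # Define gate pairs (in ns, relative to onset)
        gate_pairs = [
            (0, 20),    # Very fast
            (0, 60),    # Fast
            (0, 200),   # Medium (traditional short)
            (0, 800),   # Long (traditional long)
            (20, 60),   # Early tail
            (60, 200),  # Mid tail
            (200, 800), # Late tail
        ]
        
        total_integral = pulse[onset_idx:].sum()
        
        for start_ns, end_ns in gate_pairs:
            start_idx = onset_idx + int(start_ns / self.dt)
            end_idx = min(onset_idx + int(end_ns / self.dt), len(pulse))
            
            if end_idx > start_idx:
                gate_integral = pulse[start_idx:end_idx].sum()
                ratio = gate_integral / total_integral if total_integral > 0 else 0
                features[f'Q_ratio_{start_ns}_{end_ns}ns'] = ratio
        
        # Traditional PSD: (Q_long - Q_short) / Q_long, both gates starting at onset
        short_end = min(onset_idx + int(200 / self.dt), len(pulse))
        long_end = min(onset_idx + int(800 / self.dt), len(pulse))
        Q_short = pulse[onset_idx:short_end].sum()
        Q_long = pulse[onset_idx:long_end].sum()
        features['psd_traditional'] = (Q_long - Q_short) / Q_long if Q_long > 0 else 0
        
        return features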
    
    def _rise_time_features(self, pulse_norm):
        """Rise time calculations"""
        features = {}
        
        peak_idx = np.argmax(pulse_norm)
        
        # Find crossing times
        thresholds = {'10': 0.10, '20': 0.20, '50': 0.50, '80': 0.80, '90': 0.90}
        crossing_times = {}
        
        for name, thresh in thresholds.items():
            idx = np.where(pulse_norm[:peak_idx] >= thresh)[0]
            if len(idx) > 0:
                crossing_times[name] = idx[0] * self.dt
            else:
                crossing_times[name] = 0
        
        # Rise time combinations
        features['rise_10_90'] = crossing_times['90'] - crossing_times['10']
        features['rise_20_80'] = crossing_times['80'] - crossing_times['20']
        features['rise_10_50'] = crossing_times['50'] - crossing_times['10']
        features['time_to_peak'] = peak_idx * self.dt
        
        # Peak position (fraction of total trace)
        features['peak_position_frac'] = peak_idx / len(pulse_norm)
        
        return features
    
    def _cumulative_charge_features(self, pulse):
        """Cumulative charge timestamps"""
        features = {}
        
        cumsum = np.cumsum(pulse)
        total_charge = cumsum[-1]
        
        if total_charge > 0:
            for pct in range(10, 100, 10):
                threshold = pct / 100.0 * total_charge
                idx = np.searchsorted(cumsum, threshold)
                
                if idx < len(cumsum):
                    # Linear interpolation
                    if idx > 0:
                        frac = (threshold - cumsum[idx-1]) / (cumsum[idx] - cumsum[idx-1] + 1e-10)
                        time = (idx - 1 + frac) * self.dt
                    else:
                        time = 0
                else:
                    time = len(cumsum) * self.dt
                
                features[f'charge_t{pct}pct'] = time
            
            # Charge collection speed
            t10 = features['charge_t10pct']
            t50 = features['charge_t50pct']
            t90 = features['charge_t90pct']
            
            if t90 > t10:
                features['charge_speed_10_50'] = 0.4 / (t50 - t10) if t50 > t10 else 0
                features['charge_speed_50_90'] = 0.4 / (t90 - t50) if t90 > t50 else 0
                features['charge_asymmetry'] = (t50 - t10) / (t90 - t10)
            else:
                features['charge_speed_10_50'] = 0
                features['charge_speed_50_90'] = 0
                features['charge_asymmetry'] = 0.5
        
        return features
    
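    def _shape_moments(self, pulse_norm):
        """Temporal shape moments, treating the pulse as a distribution over time"""
        features = {}
        
        total = pulse_norm.sum()
        if total > 0:
            # The area-normalized pulse acts as a probability weight for each time bin.
            # (scipy's skew/kurtosis on the raw samples would instead describe the
            # distribution of amplitudes, not the pulse shape in time.)
            time = np.arange(len(pulse_norm)) * self.dt
            weights = pulse_norm / total
            mean_time = np.sum(weights * time)
            std_time = np.sqrt(np.sum(weights * (time - mean_time)**2))
            features['mean_time'] = mean_time
            features['std_time'] = std_time
            
            if std_time > 0:
                z = (time - mean_time) / std_time
                features['skewness'] = np.sum(weights * z**3)
                features['kurtosis'] = np.sum(weights * z**4) - 3.0  # excess kurtosis
            else:
                features['skewness'] = 0
                features['kurtosis'] = 0
        else:
            features['skewness'] = 0
            features['kurtosis'] = 0
            features['mean_time'] = 0
            features['std_time'] = 0
        
        return features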
    
    def _tot_features(self, pulse_norm):
        """Time-over-threshold features"""
        features = {}
        
        thresholds = [0.1, 0.3, 0.5, 0.7, 0.9]
        
        for thresh in thresholds:
            above = pulse_norm > thresh
            if above.any():
                tot = above.sum() * self.dt
                first_cross = np.where(above)[0][0] * self.dt
                last_cross = np.where(above)[0][-1] * self.dt
            else:
                tot = 0
                first_cross = 0
                last_cross = 0
            
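            features[f'tot_{int(thresh*100)}pct'] = tot
            features[f'tot_start_{int(thresh*100)}pct'] = first_cross
            # Store the last crossing as well; it was computed above but never used
            features[f'tot_end_{int(thresh*100)}pct'] = last_cross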
        
        return features
    
    def _decay_fit_features(self, pulse):
        """Bi-exponential decay fit"""
        features = {}
        
        peak_idx = np.argmax(pulse)
        
        if peak_idx < len(pulse) - 50:
            tail = pulse[peak_idx:peak_idx+150]
            x = np.arange(len(tail)) * self.dt
            
            def biexp(t, A_fast, tau_fast, A_slow, tau_slow):
                return A_fast * np.exp(-t/tau_fast) + A_slow * np.exp(-t/tau_slow)
            
            try:
                p0 = [tail[0]*0.7, 10, tail[0]*0.3, 50]
                popt, _ = optimize.curve_fit(biexp, x, tail, p0=p0, maxfev=5000,
                                            bounds=([0,1,0,10], [np.inf,100,np.inf,500]))
                
                A_fast, tau_fast, A_slow, tau_slow = popt
                
                features['decay_tau_fast'] = tau_fast
                features['decay_tau_slow'] = tau_slow
                features['decay_A_ratio'] = A_slow / (A_fast + A_slow + 1e-10)
                features['decay_tau_ratio'] = tau_slow / (tau_fast + 1e-10)
                
                # Fit quality
                y_pred = biexp(x, *popt)
                r2 = 1 - np.sum((tail - y_pred)**2) / (np.sum((tail - np.mean(tail))**2) + 1e-10)
                features['decay_fit_r2'] = r2
                
            except (RuntimeError, ValueError):  # fit failed to converge or inputs were invalid
                features['decay_tau_fast'] = 0
                features['decay_tau_slow'] = 0
                features['decay_A_ratio'] = 0
                features['decay_tau_ratio'] = 0
                features['decay_fit_r2'] = 0
        
        return features
    
    def _frequency_features(self, pulse_norm):
        """Frequency domain features"""
        features = {}
        
        # FFT
        fft = np.fft.rfft(pulse_norm)
        power = np.abs(fft)**2
        freqs = np.fft.rfftfreq(len(pulse_norm), d=self.dt/1000)  # MHz
        
        # Spectral centroid
        if power.sum() > 0:
            features['spectral_centroid'] = np.average(freqs, weights=power)
            features['spectral_variance'] = np.average((freqs - features['spectral_centroid'])**2, 
                                                       weights=power)
        else:
            features['spectral_centroid'] = 0
            features['spectral_variance'] = 0
        
        # Power in frequency bands
        bands = [(0, 10), (10, 50), (50, 100)]  # MHz
        for low, high in bands:
            mask = (freqs >= low) & (freqs < high)
            band_power = power[mask].sum()
            features[f'power_{low}_{high}MHz'] = band_power
        
        return features
    
    def _template_features(self, pulse_norm):
        """Template matching features"""
        features = {}
        
        min_len = min(len(pulse_norm), len(self.neutron_template))
        pulse_trunc = pulse_norm[:min_len]
        template_n = self.neutron_template[:min_len]
        template_g = self.gamma_template[:min_len]
        
        # Correlation
        corr_n = np.corrcoef(pulse_trunc, template_n)[0, 1]
        corr_g = np.corrcoef(pulse_trunc, template_g)[0, 1]
        
        features['template_n_corr'] = corr_n if not np.isnan(corr_n) else 0
        features['template_g_corr'] = corr_g if not np.isnan(corr_g) else 0
        
        # L2 distance
        features['template_n_l2'] = np.linalg.norm(pulse_trunc - template_n)
        features['template_g_l2'] = np.linalg.norm(pulse_trunc - template_g)
        
        # Discrimination score (uses the NaN-safe correlations stored above)
        features['template_score'] = (features['template_n_corr'] - features['template_g_corr']) + \
                                     (features['template_g_l2'] - features['template_n_l2'])
        
        # Gatti filter
        if self.gatti_weights is not None:
            features['gatti_score'] = np.dot(self.gatti_weights[:min_len], pulse_trunc)
        
        return features
    
    def _wavelet_features(self, pulse):
        """Wavelet decomposition features"""
        features = {}
        
        try:
            # Discrete wavelet transform
            coeffs = pywt.wavedec(pulse, 'db4', level=4)
            
            # Energy in each level
            for i, c in enumerate(coeffs):
                features[f'wavelet_energy_L{i}'] = np.sum(c**2)
            
            # Wavelet entropy
            energies = [np.sum(c**2) for c in coeffs]
            total_energy = sum(energies)
            if total_energy > 0:
                probs = [e / total_energy for e in energies]
                entropy = -sum([p * np.log2(p + 1e-10) for p in probs if p > 0])
                features['wavelet_entropy'] = entropy
            else:
                features['wavelet_entropy'] = 0
        except ValueError:  # raised by pywt for invalid input; other errors propagate
            for i in range(5):
                features[f'wavelet_energy_L{i}'] = 0
            features['wavelet_entropy'] = 0
        
        return features
    
    def set_templates(self, neutron_waveforms, gamma_waveforms):
        """Build templates from training data"""
        self.neutron_template = self._build_template(neutron_waveforms)
        self.gamma_template = self._build_template(gamma_waveforms)
        
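        # Simplified Gatti-style weights: template difference divided by a single
        # noise variance, i.e. assuming white, stationary noise. The variance is in
        # ADC^2 while the templates are peak-normalized, so it only rescales the score.
        diff = self.neutron_template - self.gamma_template
        noise_var = self._estimate_noise_variance(neutron_waveforms, gamma_waveforms)
        self.gatti_weights = diff / (noise_var + 1e-10)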
        
        print("✓ Templates and Gatti filter computed")
    
    def _build_template(self, waveforms):
        """Build a median template from peak-normalized waveforms"""
        templates = []
        for wf in waveforms[:500]:  # Use subset
            baseline = np.mean(wf[:self.baseline_samples])
            pulse = baseline - wf
            if np.max(pulse) > 100:
                pulse_norm = pulse / np.max(pulse)
                templates.append(pulse_norm)
        
        if templates:
            return np.median(templates, axis=0)
        else:
            return np.zeros(waveforms.shape[1])
    
    def _estimate_noise_variance(self, wf_n, wf_g):
        """Estimate noise from baseline"""
        baselines_n = wf_n[:, :self.baseline_samples]
        baselines_g = wf_g[:, :self.baseline_samples]
        return np.mean([baselines_n.var(), baselines_g.var()])

# Initialize extractor
extractor = FeatureExtractor()

print("✓ FeatureExtractor class defined")
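
The `gatti_weights` computed in `set_templates` divide the template difference by a single noise variance. A commonly quoted form of the Gatti weights instead assumes that the fluctuations in each time bin scale with the signal in that bin, giving weights `(n_i - g_i) / (n_i + g_i)` for area-normalized mean pulses `n` and `g`. The sketch below illustrates that form on noiseless templates. The function names and the `eps` cutoff are assumptions made for illustration and are not part of `FeatureExtractor`.

```python
import numpy as np

def gatti_weights(template_n, template_g, eps=1e-12):
    """Per-sample weights (n - g) / (n + g) from area-normalized templates."""
    n = np.clip(np.asarray(template_n, dtype=float), 0, None)
    g = np.clip(np.asarray(template_g, dtype=float), 0, None)
    n = n / n.sum()
    g = g / g.sum()
    denom = n + g
    weights = np.zeros_like(n)
    mask = denom > eps  # leave the weight at zero where both templates vanish
    weights[mask] = (n[mask] - g[mask]) / denom[mask]
    return weights

def gatti_score(pulse, weights):
    """Weighted sum of the area-normalized pulse; positive means neutron-like."""
    p = np.clip(np.asarray(pulse, dtype=float), 0, None)
    total = p.sum()
    return float(np.dot(weights, p / total)) if total > 0 else 0.0

# Noiseless templates using the decay constants of the synthetic generator
t = np.arange(0, 400, 4.0)  # ns, 250 MHz sampling

def biexp(fast_fraction):
    return fast_fraction * np.exp(-t / 3.2) + (1 - fast_fraction) * np.exp(-t / 32.0)

tmpl_g, tmpl_n = biexp(0.75), biexp(0.55)
w = gatti_weights(tmpl_n, tmpl_g)
score_n = gatti_score(tmpl_n, w)
score_g = gatti_score(tmpl_g, w)
print(f"neutron template score: {score_n:+.4f}, gamma template score: {score_g:+.4f}")
```

With measured templates, samples where both templates are close to zero (for example the pre-trigger baseline) produce noisy weights, so in practice the sum is usually restricted to the region after the pulse onset.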

## 3. Extract Features from Example Waveforms

In [None]:
# Extract features
features_gamma = extractor.extract_all_features(gamma_wf)
features_neutron = extractor.extract_all_features(neutron_wf)

print("Feature extraction results:\n")
print(f"Total features extracted: {len(features_gamma)}")
print(f"\nKey discriminating features:")
print(f"{'Feature':<30} {'Gamma':<12} {'Neutron':<12}")
print("-" * 60)

key_features = [
    'psd_traditional',
    'Q_ratio_200_800ns',
    'charge_t50pct',
    'charge_t90pct',
    'decay_A_ratio',
    'tot_50pct',
    'rise_10_90'
]

for feat in key_features:
    if feat in features_gamma and feat in features_neutron:
        print(f"{feat:<30} {features_gamma[feat]:<12.4f} {features_neutron[feat]:<12.4f}")

print(f"\n✓ Features extracted successfully")

## 4. Build Templates and Extract Dataset

In [None]:
# Generate training dataset
print("Generating training dataset...")

n_train = 1000  # Per particle type

gamma_waveforms = []
neutron_waveforms = []

for i in range(n_train):
    energy = np.random.exponential(400) + 50
    energy = min(energy, 2000)
    
    gamma_waveforms.append(generate_waveform('gamma', energy))
    neutron_waveforms.append(generate_waveform('neutron', energy))

gamma_waveforms = np.array(gamma_waveforms)
neutron_waveforms = np.array(neutron_waveforms)

# Build templates
extractor.set_templates(neutron_waveforms, gamma_waveforms)

# Extract features for all waveforms
print("Extracting features for all waveforms...")

all_features = []

for i, wf in enumerate(gamma_waveforms):
    features = extractor.extract_all_features(wf)
    features['particle'] = 'gamma'
    all_features.append(features)
    
    if (i+1) % 200 == 0:
        print(f"  Processed {i+1}/{n_train} gamma events")

for i, wf in enumerate(neutron_waveforms):
    features = extractor.extract_all_features(wf)
    features['particle'] = 'neutron'
    all_features.append(features)
    
    if (i+1) % 200 == 0:
        print(f"  Processed {i+1}/{n_train} neutron events")

# Create DataFrame
df_features = pd.DataFrame(all_features)

print(f"\n✓ Feature extraction complete")
print(f"  Dataset shape: {df_features.shape}")
print(f"  Total features: {len(df_features.columns) - 1}")
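
Because `extract_all_features` returns early for saturated or very small pulses, and the decay fit is skipped when the peak sits near the end of the trace, some rows of `df_features` can contain missing values. The sketch below shows one way to inspect this on a toy table; the records are made up for illustration and stand in for real extractor output.

```python
import numpy as np
import pandas as pd

# Toy records standing in for extractor output: the second event returned
# early (e.g. saturated), so it lacks the later feature groups.
records = [
    {'saturated': False, 'amplitude': 1500.0, 'rise_10_90': 4.0},
    {'saturated': True},
    {'saturated': False, 'amplitude': 900.0, 'rise_10_90': 8.0},
]
df_toy = pd.DataFrame(records)

# Fraction of missing values per column, largest first
missing_fraction = df_toy.isna().mean().sort_values(ascending=False)
print(missing_fraction)

# Keep only events where every feature was computed
df_complete = df_toy.dropna()
print(f"{len(df_complete)} of {len(df_toy)} events have a complete feature set")
```

Applying the same two calls to `df_features` before the correlation analysis shows whether any feature group was skipped for a meaningful share of events.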

## 5. Feature Correlation Analysis

In [None]:
# Select numeric features
numeric_features = df_features.select_dtypes(include=[np.number]).columns.tolist()

# Remove QC flags
numeric_features = [f for f in numeric_features if not f.startswith('n_') and 
                   f not in ['saturated', 'pile_up_likely']]

# Calculate correlation with particle type
df_features['particle_label'] = (df_features['particle'] == 'neutron').astype(int)

correlations = []
for feat in numeric_features:
    if df_features[feat].std() > 0:  # Exclude constant features
        corr = df_features[['particle_label', feat]].corr().iloc[0, 1]
        correlations.append({'feature': feat, 'correlation': abs(corr)})

corr_df = pd.DataFrame(correlations).sort_values('correlation', ascending=False)

print("Top 20 most discriminating features:\n")
print(corr_df.head(20).to_string(index=False))

# Plot top features
fig, ax = plt.subplots(figsize=(12, 8))

top_features = corr_df.head(20)
ax.barh(range(len(top_features)), top_features['correlation'], 
        color='steelblue', edgecolor='black', linewidth=1.2)
ax.set_yticks(range(len(top_features)))
ax.set_yticklabels(top_features['feature'])
ax.set_xlabel('Absolute Correlation with Particle Type', fontsize=12, fontweight='bold')
ax.set_title('Top 20 Discriminating Features', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='x')
ax.invert_yaxis()

plt.tight_layout()
plt.show()

print("\n✓ Feature importance analysis complete")
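
The loop above computes the Pearson correlation between each feature and a 0/1 label, which is the point-biserial correlation. It measures how far the class means are separated relative to the spread, so it can miss features that separate the classes in a non-linear way. As a minimal sketch, the same ranking can be obtained in one expression with `DataFrame.corrwith`; the toy columns below are invented for illustration.

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
label = np.repeat([0, 1], 100)  # 0 = gamma, 1 = neutron

# Toy data: one feature that shifts with the label and one that is pure noise
df_toy = pd.DataFrame({
    'separating_feature': label + rng.normal(0, 0.1, label.size),
    'noise_feature': rng.normal(0, 1.0, label.size),
    'particle_label': label,
})

feature_cols = ['separating_feature', 'noise_feature']
ranking = (df_toy[feature_cols]
           .corrwith(df_toy['particle_label'])
           .abs()
           .sort_values(ascending=False))
print(ranking)
```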

## 6. Visualize Key Feature Distributions

In [None]:
# Plot distributions of top features
top_4_features = corr_df.head(4)['feature'].tolist()

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.ravel()

for i, feat in enumerate(top_4_features):
    ax = axes[i]
    
    # Drop events where this feature was not computed (NaN)
    gamma_vals = df_features[df_features['particle'] == 'gamma'][feat].dropna()
    neutron_vals = df_features[df_features['particle'] == 'neutron'][feat].dropna()
    
    # Limit the plotted range to the 1st-99th percentile (NaN-safe) for readability
    vmin = np.nanpercentile(df_features[feat], 1)
    vmax = np.nanpercentile(df_features[feat], 99)
    
    bins = np.linspace(vmin, vmax, 50)
    
    ax.hist(gamma_vals, bins=bins, alpha=0.6, label='Gamma', 
           color='blue', edgecolor='black', linewidth=0.5)
    ax.hist(neutron_vals, bins=bins, alpha=0.6, label='Neutron', 
           color='red', edgecolor='black', linewidth=0.5)
    
    ax.set_xlabel(feat, fontsize=11, fontweight='bold')
    ax.set_ylabel('Counts', fontsize=11, fontweight='bold')
    ax.set_title(f'{feat}\n(corr = {corr_df[corr_df["feature"]==feat]["correlation"].values[0]:.3f})',
                fontsize=12, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3, axis='y')

plt.tight_layout()
plt.show()

print("✓ Feature distributions plotted")

## Summary

### What We Learned

1. **Comprehensive Feature Extraction**: Extracted several dozen pulse-shape features (roughly 70) from each waveform
   - Charge ratios over several gate pairs probe different parts of the decay
   - Cumulative charge timestamps describe how quickly charge is collected
   - Bi-exponential fits estimate the fast/slow component amplitudes and decay times
   - Frequency-domain features summarize how pulse power is distributed across frequency bands
   - Wavelet features provide multi-resolution analysis

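2. **Template Matching**: Built particle-specific templates
   - Correlation measures shape similarity
   - The Gatti filter is an optimal linear discriminator only under its noise assumptions; the version used here is a simplified constant-variance weighting of the template difference
   - Template-based scores may help at low energies, although that was not evaluated in this notebook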

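3. **Feature Importance**: Ranked features by absolute correlation with particle type
   - Charge-ratio and tail-related features (cumulative charge times) are expected to rank near the top; confirm this against the ranking printed in Section 5
   - Many features are strongly correlated with one another, which adds robustness but not much independent information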

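4. **Quality Control**: Flag problematic events
   - Saturation detection flags clipped pulses, whose amplitude and shape are distorted
   - Pile-up detection flags traces that likely contain more than one pulse; the extractor records the flag but does not discard the event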

### Best Practices

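- **Feature Selection**: Use correlation analysis as a first pass; Pearson correlation only captures linear separation, so treat the ranking as a guide rather than a final selection
- **Normalization**: Essential before machine learning (next notebook)
- **Template Building**: Use a large, clean dataset; `_build_template` uses at most the first 500 waveforms per class, so passing more than that has no effect here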
- **Energy Dependence**: Features may perform differently at different energies
- **Computational Cost**: For real-time applications, select subset of features
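
The normalization step mentioned above can be sketched as a z-score standardization whose mean and standard deviation are fitted on the training set only and then reused for any later data. The function names and the toy feature matrix below are assumptions for illustration; the next notebook may use a different scaler.

```python
import numpy as np

def fit_standardizer(X_train):
    """Per-feature mean and standard deviation, ignoring NaN entries."""
    mean = np.nanmean(X_train, axis=0)
    std = np.nanstd(X_train, axis=0)
    std = np.where(std > 0, std, 1.0)  # constant features: avoid division by zero
    return mean, std

def apply_standardizer(X, mean, std):
    """Scale features using statistics fitted on the training set."""
    return (X - mean) / std

# Toy feature matrix: columns with very different scales, plus a constant column
rng = np.random.default_rng(1)
X_train = np.column_stack([
    rng.normal(100.0, 20.0, 500),  # e.g. an amplitude-like feature
    rng.normal(0.5, 0.05, 500),    # e.g. a charge-ratio-like feature
    np.full(500, 3.0),             # constant feature
])

mean, std = fit_standardizer(X_train)
X_scaled = apply_standardizer(X_train, mean, std)
print("means:", X_scaled.mean(axis=0).round(6))
print("stds: ", X_scaled.std(axis=0).round(6))
```

Fitting the statistics on the training set and applying them unchanged to validation or test data avoids leaking information from those sets into the scaling.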

### Next Steps

In Notebook 4, we'll use these features to train machine learning classifiers for improved neutron/gamma discrimination.