# Notebook 01: Data Loading and Exploration

This notebook demonstrates:
- Loading waveform data from HDF5 files
- Visualizing raw pulse shapes
- Data quality checks
- Summary statistics
- Creating processed datasets for ML

In [None]:
# Setup and imports
import sys
sys.path.append('../')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

# Local imports
from src.io import WaveformLoader
from src.visualization import plot_waveform, plot_waveform_grid, plot_spectrum

# Configure plotting
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
%matplotlib inline

print("✓ Imports successful")

## 1.1 Data Directory Structure

Expected structure:
```
data/raw/
    ├── LYSO/
    │   ├── Cs137/
    │   ├── Co60/
    │   └── Na22/
    ├── BGO/
    ├── NaI/
    └── Plastic/
```

In [None]:
# Initialize loader
# The path is relative to the notebook's working directory. A 125 MHz sampling
# rate corresponds to 8 ns per sample, which is the spacing the time axes in
# the plotting cells of this notebook use.
data_dir = Path("../data/raw")
loader = WaveformLoader(data_dir, sampling_rate_MHz=125)

# Get available data
# NOTE(review): the return type of get_available_data() is not visible here;
# it only needs to support len() and rich display -- confirm against src.io.
available_data = loader.get_available_data()
print(f"Found {len(available_data)} data files:\n")
display(available_data)

## 1.2 Load Example Waveforms

In [None]:
# Load waveforms from each scintillator (Cs-137 source)
scintillators = ['LYSO', 'BGO', 'NaI', 'Plastic']
n_waveforms = 1000

waveforms_dict = {}

for scint in scintillators:
    try:
        waveforms = loader.load_waveforms(
            scintillator=scint,
            source='Cs137',
            n_waveforms=n_waveforms
        )
        waveforms_dict[scint] = waveforms
        print(f"✓ Loaded {len(waveforms)} waveforms from {scint}")
    except FileNotFoundError:
        print(f"✗ Data not found for {scint}")
        # Generate synthetic data for demonstration
        print(f"  Generating synthetic data for {scint}...")
        waveforms_dict[scint] = generate_synthetic_waveforms(scint, n_waveforms)

print(f"\nTotal scintillators loaded: {len(waveforms_dict)}")

In [None]:
# Helper function to generate synthetic data if needed
def generate_synthetic_waveforms(scintillator, n_waveforms=1000):
    """Generate synthetic waveforms for demonstration"""
    from src.io.waveform_loader import Waveform
    
    # Scintillator properties
    properties = {
        'LYSO': {'decay': 40, 'rise': 5, 'amplitude_range': (200, 2000)},
        'BGO': {'decay': 300, 'rise': 20, 'amplitude_range': (100, 1000)},
        'NaI': {'decay': 230, 'rise': 10, 'amplitude_range': (300, 2500)},
        'Plastic': {'decay': 2.4, 'rise': 1, 'amplitude_range': (150, 1500)}
    }
    
    props = properties[scintillator]
    waveforms = []
    
    for i in range(n_waveforms):
        # Generate synthetic pulse
        t = np.arange(1024) * 8  # 8 ns sampling
        baseline = 100
        amplitude = np.random.uniform(*props['amplitude_range'])
        peak_time = 200 + np.random.normal(0, 10)
        
        # Exponential pulse
        pulse = np.zeros(1024)
        peak_idx = int(peak_time / 8)
        
        # Rise
        rise_samples = int(props['rise'] / 0.8)
        pulse[:peak_idx] = amplitude * (1 - np.exp(-(np.arange(peak_idx) - peak_idx + rise_samples) / rise_samples))
        
        # Decay
        decay_samples = int(props['decay'] / 8)
        pulse[peak_idx:] = amplitude * np.exp(-(np.arange(1024 - peak_idx)) / decay_samples)
        
        # Add noise and baseline
        noise = np.random.normal(0, 5, 1024)
        waveform_data = pulse + baseline + noise
        
        waveforms.append(Waveform(
            waveform=waveform_data,
            timestamp=i * 10.0,
            baseline=baseline,
            amplitude=amplitude,
            scintillator=scintillator,
            source='Cs137'
        ))
    
    return waveforms

## 1.3 Waveform Visualization

In [None]:
# Plot representative waveforms from each scintillator
# savefig does not create missing directories, so make sure the target exists
figures_dir = Path('../results/figures')
figures_dir.mkdir(parents=True, exist_ok=True)

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

# Per-scintillator line colours; also used by the later plotting cells
colors = {'LYSO': 'blue', 'BGO': 'green', 'NaI': 'red', 'Plastic': 'purple'}

for idx, (scint, waveforms) in enumerate(waveforms_dict.items()):
    # Plot first waveform
    waveform = waveforms[0].waveform
    t_ns = np.arange(len(waveform)) * 8  # Time in ns (8 ns per sample)

    axes[idx].plot(t_ns, waveform, color=colors[scint], linewidth=1.5)
    axes[idx].axhline(waveforms[0].baseline, color='gray', linestyle='--', label='Baseline')
    axes[idx].set_xlabel('Time (ns)', fontsize=12)
    axes[idx].set_ylabel('Amplitude (ADC)', fontsize=12)
    axes[idx].set_title(f'{scint} Pulse', fontsize=14)
    axes[idx].grid(True, alpha=0.3)
    axes[idx].legend()

plt.tight_layout()
plt.savefig(figures_dir / 'individual_pulses.pdf', dpi=300, bbox_inches='tight')
plt.show()

print("Figure saved to results/figures/")

In [None]:
# Overlay normalized pulses for comparison
# savefig does not create missing directories, so make sure the target exists
figures_dir = Path('../results/figures')
figures_dir.mkdir(parents=True, exist_ok=True)

plt.figure(figsize=(12, 6))

for scint, waveforms in waveforms_dict.items():
    # Average first 10 waveforms (baseline-subtracted) to suppress noise
    avg_waveform = np.mean([w.baseline_corrected for w in waveforms[:10]], axis=0)

    # Normalize to peak so only the pulse shape is compared, not the amplitude
    avg_waveform_norm = avg_waveform / np.max(avg_waveform)

    t_ns = np.arange(len(avg_waveform_norm)) * 8  # 8 ns per sample
    plt.plot(t_ns, avg_waveform_norm, label=scint, linewidth=2, color=colors[scint])

plt.xlabel('Time (ns)', fontsize=14)
plt.ylabel('Normalized Amplitude', fontsize=14)
plt.title('Normalized Pulse Shapes Comparison (Cs-137)', fontsize=16)
plt.legend(fontsize=12)
plt.grid(True, alpha=0.3)
plt.xlim([0, 2000])
plt.tight_layout()
plt.savefig(figures_dir / 'pulse_comparison.pdf', dpi=300, bbox_inches='tight')
plt.show()

## 1.4 Statistical Summary

In [None]:
# Calculate statistics for each scintillator
statistics = []

for scint, waveforms in waveforms_dict.items():
    amplitudes = [w.amplitude for w in waveforms]
    baselines = [w.baseline for w in waveforms]

    # One summary row per scintillator
    stats = {
        'Scintillator': scint,
        'N Events': len(waveforms),
        'Amp Mean': np.mean(amplitudes),
        'Amp Std': np.std(amplitudes),
        'Amp Min': np.min(amplitudes),
        'Amp Max': np.max(amplitudes),
        'Baseline Mean': np.mean(baselines),
        'Baseline Std': np.std(baselines)
    }
    statistics.append(stats)

stats_df = pd.DataFrame(statistics)
display(stats_df)

# Save to CSV
# to_csv does not create missing directories, so make sure the target exists
tables_dir = Path('../results/tables')
tables_dir.mkdir(parents=True, exist_ok=True)
stats_df.to_csv(tables_dir / 'data_summary_statistics.csv', index=False)
print("\n✓ Statistics saved to results/tables/")

In [None]:
# Amplitude distributions
# savefig does not create missing directories, so make sure the target exists
figures_dir = Path('../results/figures')
figures_dir.mkdir(parents=True, exist_ok=True)

fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

for idx, (scint, waveforms) in enumerate(waveforms_dict.items()):
    amplitudes = [w.amplitude for w in waveforms]

    axes[idx].hist(amplitudes, bins=50, alpha=0.7, color=colors[scint], edgecolor='black')
    # Dashed vertical line marks the mean amplitude of this scintillator
    axes[idx].axvline(np.mean(amplitudes), color='red', linestyle='--', linewidth=2, label=f'Mean: {np.mean(amplitudes):.1f}')
    axes[idx].set_xlabel('Amplitude (ADC)', fontsize=12)
    axes[idx].set_ylabel('Counts', fontsize=12)
    axes[idx].set_title(f'{scint} Amplitude Distribution', fontsize=14)
    axes[idx].grid(True, alpha=0.3)
    axes[idx].legend()

plt.tight_layout()
plt.savefig(figures_dir / 'amplitude_distributions.pdf', dpi=300, bbox_inches='tight')
plt.show()

## 1.5 Data Quality Checks

In [None]:
def check_data_quality(waveforms, baseline_std_limit=5, saturation_threshold=4000,
                       min_amplitude=50):
    """Compute baseline, amplitude-range, saturation and anomaly metrics.

    Parameters
    ----------
    waveforms : sequence
        Non-empty sequence of objects exposing numeric ``baseline`` and
        ``amplitude`` attributes.
    baseline_std_limit : float, optional
        The baseline is reported as stable when the standard deviation of
        the baselines is strictly below this value (default 5 ADC).
    saturation_threshold : float, optional
        Waveforms with ``amplitude`` strictly above this value count as
        saturated (default 4000, chosen for an assumed ADC maximum of 4096).
    min_amplitude : float, optional
        Waveforms with ``amplitude`` strictly below this value count as
        anomalies (default 50).

    Returns
    -------
    dict
        Keys, in order: ``baseline_mean``, ``baseline_std``,
        ``baseline_stable``, ``amplitude_min``, ``amplitude_max``,
        ``amplitude_range``, ``saturated_count``, ``saturation_fraction``,
        ``anomaly_count``, ``anomaly_fraction``.

    Raises
    ------
    ValueError
        If ``waveforms`` is empty (the min/max and fractions are undefined).
    """
    n_events = len(waveforms)
    if n_events == 0:
        raise ValueError("check_data_quality requires at least one waveform")

    baselines = np.array([w.baseline for w in waveforms])
    amplitudes = np.array([w.amplitude for w in waveforms])

    # Compute each reduction once and reuse it below
    baseline_std = np.std(baselines)
    amplitude_min = np.min(amplitudes)
    amplitude_max = np.max(amplitudes)

    saturated = int(np.sum(amplitudes > saturation_threshold))
    anomalies = int(np.sum(amplitudes < min_amplitude))

    return {
        # Baseline stability
        'baseline_mean': np.mean(baselines),
        'baseline_std': baseline_std,
        'baseline_stable': baseline_std < baseline_std_limit,
        # Amplitude range
        'amplitude_min': amplitude_min,
        'amplitude_max': amplitude_max,
        'amplitude_range': amplitude_max - amplitude_min,
        # Saturation check
        'saturated_count': saturated,
        'saturation_fraction': saturated / n_events,
        # Anomaly detection (very low amplitude)
        'anomaly_count': anomalies,
        'anomaly_fraction': anomalies / n_events,
    }

# Run the quality checks on every scintillator and collect one row per material
quality_results = []

for scint, waveforms in waveforms_dict.items():
    # Tag the metrics with the material they belong to (appended as the last key)
    quality_results.append({**check_data_quality(waveforms), 'Scintillator': scint})

quality_df = pd.DataFrame(quality_results)

# Show only the headline columns
summary_columns = ['Scintillator', 'baseline_stable', 'saturation_fraction', 'anomaly_fraction']
display(quality_df[summary_columns])

# The overall verdict is based on baseline stability alone
if all(quality_df['baseline_stable']):
    print("\n✓ Data quality: All scintillators PASS")
else:
    print("⚠ Warning: Check baseline stability")

## 1.6 Save Processed Data for ML

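Combine the baseline-corrected waveforms from all scintillators into a single array with one integer class label per waveform, and save both as `.npy` files ready for machine learning.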

In [None]:
# Combine all waveforms into ML-ready format
# NOTE(review): save_processed_data is imported but not used in this cell;
# the arrays are written directly with np.save below.
from src.io import save_processed_data

# Create waveform array and labels
all_waveforms = []
all_labels = []
# Integer class id assigned to each scintillator
label_map = {'LYSO': 0, 'BGO': 1, 'NaI': 2, 'Plastic': 3}

for scint, waveforms in waveforms_dict.items():
    for w in waveforms:
        all_waveforms.append(w.baseline_corrected)
        all_labels.append(label_map[scint])

waveform_array = np.array(all_waveforms)
labels_array = np.array(all_labels)

print(f"Dataset shape: {waveform_array.shape}")
print(f"Labels shape: {labels_array.shape}")
print(f"Class distribution: {np.bincount(labels_array)}")

# Save
# np.save does not create missing directories, so make sure the target exists
processed_dir = Path('../data/processed')
processed_dir.mkdir(parents=True, exist_ok=True)
np.save(processed_dir / 'waveforms.npy', waveform_array)
np.save(processed_dir / 'labels.npy', labels_array)

print("\n✓ Processed data saved to data/processed/")

## 1.7 Summary Report

In [None]:
# Final text report. Relies on notebook state built by earlier cells:
# waveforms_dict (loading), quality_df (quality checks) and all_waveforms
# (ML dataset cell).
print("="*60)
print("DATA LOADING AND EXPLORATION SUMMARY")
print("="*60)
print(f"\nScintillators analyzed: {len(waveforms_dict)}")
print(f"Total waveforms loaded: {len(all_waveforms)}")
print(f"\nWaveforms per scintillator:")
for scint, waveforms in waveforms_dict.items():
    print(f"  {scint}: {len(waveforms)}")

# The PASS/CHECK verdict looks only at the baseline_stable column; the
# saturation and anomaly fractions are not part of it.
print(f"\nData quality: {'PASS' if all(quality_df['baseline_stable']) else 'CHECK'}")
print(f"\nSaved outputs:")
print("  - Figures: results/figures/")
print("  - Tables: results/tables/")
print("  - Processed data: data/processed/")
print("\n✓ Notebook complete!")