# Notebook 01: Data Loading and Exploration

This notebook demonstrates:
- Loading real waveform data from CSV files (CAEN digitizer format)
- Loading spectrum data from N42 files
- Visualizing raw pulse shapes and spectra
- Data quality checks
- Summary statistics
- Creating processed datasets for ML

## Data Formats

### CSV Files (Waveforms)
- Format: Semicolon-delimited
- Header: BOARD;CHANNEL;TIMETAG;ENERGY;ENERGYSHORT;FLAGS;PROBE_CODE;SAMPLES
- Each row: One event with 7 scalar fields + variable-length waveform
- Sampling rate: 250 MS/s (CAEN DT5720D)

### N42 Files (Spectra)
- Format: ANSI N42.42 XML
- Contains: Histogram counts per channel
- Metadata: Acquisition time, instrument info, calibration (if present)

In [None]:
# Setup and imports
import sys
sys.path.append('../')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from tqdm.auto import tqdm

# Local imports - use new CAEN parsers
from src.io.caen_parsers import import_waveforms_csv, import_n42_spectrum, convert_csv_to_waveform_objects
from src.io.waveform_loader import Waveform, WaveformLoader

# Configure plotting
# Use seaborn's 'whitegrid' theme for the figures drawn in this notebook.
sns.set_style('whitegrid')
# Default size for figures that do not pass their own figsize.
plt.rcParams['figure.figsize'] = (12, 8)
# IPython magic (not plain Python): render matplotlib figures inline in the notebook.
%matplotlib inline

# Digitizer timing constants (CAEN DT5720D)
SAMPLING_RATE_MHZ = 250.0
# One sample period in nanoseconds: 1000 ns per microsecond / samples per microsecond
DT_NS = 1000.0 / SAMPLING_RATE_MHZ  # 4 ns

# Emit the three status lines with a single print call
print(
    "✓ Imports successful",
    f"Sampling rate: {SAMPLING_RATE_MHZ} MS/s",
    f"Time step: {DT_NS} ns",
    sep="\n",
)

## 1.1 Data Directory Structure

Expected structure for real CAEN data:
```
data/raw/
    ├── DataR_CH0@DT5720D_*_LYSO_*.CSV
    ├── DataR_CH0@DT5720D_*_BGO_*.CSV  
    ├── CH0@DT5720D_*_Espectrum_*_LYSO_*.n42
    └── CH0@DT5720D_*_Espectrum_*_BGO_*.n42
```

In [None]:
# Scan for available data files
data_dir = Path("../data/raw")

csv_files = []
n42_files = []

if data_dir.exists():
    # Match extensions case-insensitively in one directory pass. Globbing
    # '*.CSV' and '*.csv' separately returns every file twice wherever pattern
    # matching ignores case (e.g. Windows), which would load each file twice;
    # a lone '*.n42' glob also misses '.N42' files on case-sensitive systems.
    data_files = sorted(p for p in data_dir.iterdir() if p.is_file())
    csv_files = [p for p in data_files if p.suffix.lower() == '.csv']
    n42_files = [p for p in data_files if p.suffix.lower() == '.n42']

    print(f"Found {len(csv_files)} CSV waveform files")
    print(f"Found {len(n42_files)} N42 spectrum files\n")

    # Show at most three names of each kind as a quick sanity check
    if csv_files:
        print("Sample CSV files:")
        for f in csv_files[:3]:
            print(f"  {f.name}")

    if n42_files:
        print("\nSample N42 files:")
        for f in n42_files[:3]:
            print(f"  {f.name}")
else:
    # No data yet: create the expected folder so files can be dropped in
    print(f"Data directory not found: {data_dir}")
    print("Creating example directory structure...")
    data_dir.mkdir(parents=True, exist_ok=True)

## 1.2 Load Waveform Data from CSV

In [None]:
# Load CSV waveforms for each scintillator
# Names looked up (case-insensitively) as substrings of each data filename
scintillators = ['LYSO', 'BGO', 'NaI', 'Plastic']
sources = ['Cs137', 'Co60', 'Na22']

# scintillator name -> list of Waveform objects accumulated over all files
waveforms_dict = {}

if csv_files:
    print("Loading waveforms from CSV files...\n")

    for csv_file in csv_files[:8]:  # Load first 8 files
        filename = csv_file.name.lower()

        # The first listed name contained in the filename wins; None if no match
        scint = next((s for s in scintillators if s.lower() in filename), None)
        source = next((src for src in sources if src.lower() in filename), None)

        # Files that name no known scintillator are not loaded
        if not scint:
            continue

        print(f"Loading {csv_file.name}...")

        # Load events (limit to 500 for exploration)
        events = import_waveforms_csv(str(csv_file), max_events=500)

        # Wrap the parsed events as Waveform objects tagged with their origin
        waveforms = convert_csv_to_waveform_objects(
            events,
            scintillator=scint,
            source=source or 'Unknown',
            sampling_rate_MHz=SAMPLING_RATE_MHZ
        )

        waveforms_dict.setdefault(scint, []).extend(waveforms)

        print(f"  Loaded {len(waveforms)} waveforms from {scint}")
        if len(waveforms) > 0:
            # Record length is reported from the first event of the file only
            first_len = len(waveforms[0].waveform)
            print(f"  Waveform length: {first_len} samples")
            print(f"  Duration: {first_len * DT_NS:.1f} ns\n")

if len(waveforms_dict) == 0:
    print("No CSV files found. Generating synthetic data for demonstration...\n")

    # Synthetic generator defined inline so the notebook runs without real data
    def generate_synthetic_waveforms(scintillator, n_waveforms=500):
        """Generate synthetic waveforms matching CAEN format.

        Each pulse is a Gaussian leading edge joined at its maximum to an
        exponential tail, placed on a fixed baseline, with Gaussian noise
        added per sample. The noise-free peak height equals the ``amplitude``
        stored on the returned Waveform.

        Args:
            scintillator: One of 'LYSO', 'BGO', 'NaI', 'Plastic'; selects the
                decay constant (ns) and the amplitude range.
            n_waveforms: Number of pulses to generate.

        Returns:
            List of ``n_waveforms`` Waveform objects with integer samples.

        Raises:
            KeyError: If ``scintillator`` is not one of the supported names.
        """
        properties = {
            'LYSO': {'decay': 40, 'amplitude_range': (200, 2000)},
            'BGO': {'decay': 300, 'amplitude_range': (100, 1000)},
            'NaI': {'decay': 230, 'amplitude_range': (300, 2500)},
            'Plastic': {'decay': 2.4, 'amplitude_range': (150, 1500)}
        }

        props = properties[scintillator]
        n_samples = 1000
        baseline = 3100
        peak_time = 20.0
        rise_time = 2.0

        # The time axis and the unit-height pulse shape are the same for every
        # event, so build them once instead of once per waveform.
        t_ns = np.arange(n_samples) * DT_NS
        rise = np.exp(-0.5 * ((t_ns - peak_time) / rise_time) ** 2)
        decay = np.exp(-(t_ns - peak_time) / props['decay'])
        # Piecewise shape: Gaussian rise before the peak, exponential decay from
        # the peak onward. Summing rise + decay after the peak (as was done
        # before) made the peak 2x the amplitude recorded on the Waveform.
        shape = np.where(t_ns < peak_time, rise, decay)

        # NOTE(review): the largest amplitudes on a 3100 baseline exceed the
        # 12-bit range assumed by the saturation check in this notebook --
        # confirm whether synthetic pulses should be clipped.
        waveforms = []
        for i in range(n_waveforms):
            amplitude = np.random.uniform(*props['amplitude_range'])
            pulse = baseline + amplitude * shape
            pulse += np.random.normal(0, 5, n_samples)

            waveforms.append(Waveform(
                waveform=pulse.astype(int),
                timestamp=i * 1e-3,
                baseline=baseline,
                amplitude=amplitude,
                scintillator=scintillator,
                source='Cs137'
            ))

        return waveforms

    for scint in scintillators:
        waveforms_dict[scint] = generate_synthetic_waveforms(scint)
        print(f"  Generated {len(waveforms_dict[scint])} waveforms for {scint}")

# Tally of what is now held in waveforms_dict, one line per scintillator
print("\n" + "=" * 60)
print(f"Total scintillators loaded: {len(waveforms_dict)}")
for scint in waveforms_dict:
    waveforms = waveforms_dict[scint]
    print(f"  {scint}: {len(waveforms)} waveforms")

## 1.3 Load Spectrum Data from N42

In [None]:
# Load N42 spectrum files
# Nested mapping: scintillator name -> source name -> parsed spectrum record
spectra_dict = {}

if not n42_files:
    print("No N42 files found\n")
else:
    print("Loading spectra from N42 files...\n")

    for n42_file in n42_files[:8]:
        filename = n42_file.name.lower()

        # The first listed name contained in the filename wins; None if no match
        scint = next((s for s in scintillators if s.lower() in filename), None)
        source = next((src for src in sources if src.lower() in filename), None)

        # Files that name no known scintillator are not loaded
        if not scint:
            continue

        print(f"Loading {n42_file.name}...")

        spectrum_data = import_n42_spectrum(str(n42_file))

        # A later file with the same scintillator/source pair replaces the earlier one
        spectra_dict.setdefault(scint, {})[source or 'Unknown'] = spectrum_data

        counts = spectrum_data['counts']
        print(f"  Channels: {len(counts)}")
        print(f"  Total counts: {sum(counts):,}")
        live_time = spectrum_data.get('live_time')
        if live_time:
            print(f"  Live time: {live_time}\n")

print(f"Loaded spectra for {len(spectra_dict)} scintillators")

## 1.4 Waveform Visualization

In [None]:
# Plot representative waveforms
# One panel per scintillator on a 2x2 grid, addressed by flat index
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

# Fixed colour per scintillator
colors = {'LYSO': '#D55E00', 'BGO': '#0072B2', 'NaI': '#009E73', 'Plastic': '#CC79A7'}

for idx, scint in enumerate(scintillators):
    ax = axes[idx]
    has_data = scint in waveforms_dict and len(waveforms_dict[scint]) > 0

    if not has_data:
        # Replace the empty panel with a centred "No data" note
        ax.text(0.5, 0.5, f'{scint}\nNo data',
                ha='center', va='center', transform=ax.transAxes)
        ax.axis('off')
        continue

    # Only the first recorded event of each scintillator is drawn
    first_event = waveforms_dict[scint][0]
    waveform = first_event.waveform
    t_ns = np.arange(len(waveform)) * DT_NS

    ax.plot(t_ns, waveform, color=colors[scint], linewidth=1.5)

    # Dashed horizontal line at the event's stored baseline
    baseline = first_event.baseline
    ax.axhline(baseline, color='gray', linestyle='--', alpha=0.7, label='Baseline')

    ax.set_xlabel('Time (ns)', fontsize=11)
    ax.set_ylabel('ADC Value', fontsize=11)
    ax.set_title(f'{scint} Pulse', fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=9)

plt.tight_layout()
plt.show()

print("Individual pulse shapes plotted")

In [None]:
# Overlay normalized pulses for comparison
plt.figure(figsize=(12, 6))

for scint in scintillators:
    if scint not in waveforms_dict or len(waveforms_dict[scint]) == 0:
        continue

    # Baseline-subtract the first 10 waveforms
    pulses = [np.asarray(w.waveform, dtype=float) - w.baseline
              for w in waveforms_dict[scint][:10]]

    # Records may differ in length; average only over the samples they all
    # share, since np.mean(axis=0) cannot stack ragged arrays.
    n_common = min(len(p) for p in pulses)
    if n_common == 0:
        continue
    avg_waveform = np.mean([p[:n_common] for p in pulses], axis=0)

    # Normalize by the sample of largest magnitude so the peak maps to +1 for
    # either pulse polarity. np.max alone would pick a noise sample when the
    # pulses are negative-going.
    peak = avg_waveform[np.argmax(np.abs(avg_waveform))]
    if peak == 0:
        continue  # flat trace: nothing to normalize
    avg_waveform_norm = avg_waveform / peak

    t_ns = np.arange(len(avg_waveform_norm)) * DT_NS
    plt.plot(t_ns, avg_waveform_norm, label=scint, linewidth=2.5, color=colors[scint])

plt.xlabel('Time (ns)', fontsize=12)
plt.ylabel('Normalized Amplitude', fontsize=12)
plt.title('Normalized Pulse Shapes Comparison', fontsize=14, fontweight='bold')
plt.legend(fontsize=11)
plt.grid(True, alpha=0.3)
plt.xlim([0, 500])
plt.tight_layout()
plt.show()

print("Normalized pulse comparison plotted")

## 1.5 Spectrum Visualization (if N42 data available)

In [None]:
# Plot spectra if available
if len(spectra_dict) > 0:
    # squeeze=False always returns a 2-D array of Axes, so a single
    # scintillator needs no special handling.
    fig, axes = plt.subplots(len(spectra_dict), 1,
                             figsize=(12, 4*len(spectra_dict)),
                             squeeze=False)
    axes = axes[:, 0]

    for idx, (scint, sources_dict) in enumerate(spectra_dict.items()):
        ax = axes[idx]

        # With one source keep the scintillator's own colour. With several,
        # leave the colour unset so Matplotlib's colour cycle gives each source
        # a distinct line; one shared colour would make the legend useless.
        single_source = len(sources_dict) == 1

        for source, spectrum_data in sources_dict.items():
            counts = spectrum_data['counts']
            channels = np.arange(len(counts))

            line_style = {'linewidth': 0.8, 'label': source}
            if single_source:
                line_style['color'] = colors[scint]
            ax.plot(channels, counts, **line_style)

        ax.set_xlabel('Channel', fontsize=11)
        ax.set_ylabel('Counts', fontsize=11)
        ax.set_title(f'{scint} Spectrum', fontsize=12, fontweight='bold')
        ax.set_yscale('log')
        ax.legend(fontsize=10)
        ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()
else:
    print("No spectrum data available to plot")

## 1.6 Statistical Summary

In [None]:
# Calculate statistics
# One summary row per scintillator that has at least one waveform
statistics = []

for scint in scintillators:
    if scint in waveforms_dict and len(waveforms_dict[scint]) > 0:
        waveforms = waveforms_dict[scint]
        amplitudes = [w.amplitude for w in waveforms]
        baselines = [w.baseline for w in waveforms]

        statistics.append({
            'Scintillator': scint,
            'N Events': len(waveforms),
            'Amp Mean': np.mean(amplitudes),
            'Amp Std': np.std(amplitudes),
            'Amp Min': np.min(amplitudes),
            'Amp Max': np.max(amplitudes),
            'Baseline Mean': np.mean(baselines),
            'Baseline Std': np.std(baselines),
            # Length is taken from the first event only
            'Waveform Length': len(waveforms[0].waveform),
        })

stats_df = pd.DataFrame(statistics)

rule = "=" * 80
print("\n" + rule)
print("WAVEFORM STATISTICS SUMMARY")
print(rule)
print(stats_df.to_string(index=False))
print(rule)

## 1.7 Data Quality Checks

In [None]:
def check_data_quality(waveforms, *, baseline_std_max=10, saturation_adc=4090,
                       min_amplitude=20):
    """Comprehensive quality checks for CAEN data.

    Args:
        waveforms: Non-empty iterable of objects exposing ``baseline``,
            ``amplitude`` and ``waveform`` attributes.
        baseline_std_max: Baselines count as stable when their standard
            deviation is strictly below this value (ADC units).
        saturation_adc: An event counts as saturated when any sample is at or
            above this value (DT5720D: 12-bit ADC, max = 4095).
        min_amplitude: Events whose amplitude is strictly below this value are
            counted as anomalies.

    Returns:
        dict with keys ``baseline_mean``, ``baseline_std``,
        ``baseline_stable``, ``amplitude_min``, ``amplitude_max``,
        ``amplitude_range``, ``saturated_count``, ``saturation_fraction``,
        ``anomaly_count`` and ``anomaly_fraction``.

    Raises:
        ValueError: If ``waveforms`` is empty.
    """
    # Materialize once: the events are traversed several times below
    waveforms = list(waveforms)
    if not waveforms:
        # Fail with a clear message instead of an obscure NumPy reduction error
        raise ValueError("check_data_quality requires at least one waveform")
    n_events = len(waveforms)

    quality_metrics = {}

    # Baseline stability
    baselines = [w.baseline for w in waveforms]
    baseline_std = np.std(baselines)
    quality_metrics['baseline_mean'] = np.mean(baselines)
    quality_metrics['baseline_std'] = baseline_std
    quality_metrics['baseline_stable'] = baseline_std < baseline_std_max

    # Amplitude range
    amplitudes = [w.amplitude for w in waveforms]
    amplitude_min = np.min(amplitudes)
    amplitude_max = np.max(amplitudes)
    quality_metrics['amplitude_min'] = amplitude_min
    quality_metrics['amplitude_max'] = amplitude_max
    quality_metrics['amplitude_range'] = amplitude_max - amplitude_min

    # Saturation check: events with a sample at or above the threshold
    saturated = sum(1 for w in waveforms if np.max(w.waveform) >= saturation_adc)
    quality_metrics['saturated_count'] = saturated
    quality_metrics['saturation_fraction'] = saturated / n_events

    # Anomaly detection (very low amplitude)
    anomalies = sum(1 for w in waveforms if w.amplitude < min_amplitude)
    quality_metrics['anomaly_count'] = anomalies
    quality_metrics['anomaly_fraction'] = anomalies / n_events

    return quality_metrics

# Check quality
# One metrics row per scintillator that has at least one waveform
quality_results = []

for scint in scintillators:
    if scint not in waveforms_dict or len(waveforms_dict[scint]) == 0:
        continue

    metrics = check_data_quality(waveforms_dict[scint])
    metrics['Scintillator'] = scint
    quality_results.append(metrics)

quality_df = pd.DataFrame(quality_results)

print("\nDATA QUALITY REPORT")
print("="*80)
if quality_df.empty:
    # An empty frame has no columns, so selecting the report columns would
    # raise KeyError (e.g. when every loaded file yielded zero events).
    print("No waveforms available for quality checks")
else:
    print(quality_df[['Scintillator', 'baseline_stable', 'saturation_fraction',
                       'anomaly_fraction']].to_string(index=False))
print("="*80)

# PASS requires data to exist and every scintillator's baseline to be stable
all_stable = (not quality_df.empty) and bool(quality_df['baseline_stable'].all())
print(f"\n✓ Data quality: {'PASS' if all_stable else 'CHECK'}")

## 1.8 Save Processed Data for ML

In [None]:
# Combine all waveforms into ML-ready format
processed_dir = Path('../data/processed')
processed_dir.mkdir(parents=True, exist_ok=True)

all_waveforms = []
all_labels = []
# Integer class label stored for each scintillator
label_map = {'LYSO': 0, 'BGO': 1, 'NaI': 2, 'Plastic': 3}

for scint in scintillators:
    if scint not in waveforms_dict:
        continue

    for w in waveforms_dict[scint]:
        # Store baseline-corrected waveform
        wf_corrected = w.waveform - w.baseline
        all_waveforms.append(wf_corrected)
        all_labels.append(label_map[scint])

# Records can differ in length. np.array() on ragged input raises on current
# NumPy (or silently builds an object array on old versions), so pad shorter
# records with zeros, i.e. "at baseline" after the correction above.
lengths = {len(wf) for wf in all_waveforms}
if len(lengths) > 1:
    max_len = max(lengths)
    print(f"Waveform lengths vary ({min(lengths)}-{max_len} samples); "
          f"zero-padding to {max_len}")
    waveform_array = np.zeros((len(all_waveforms), max_len), dtype=float)
    for row, wf in zip(waveform_array, all_waveforms):
        row[:len(wf)] = wf
else:
    # Equal lengths (or no data): same result as before
    waveform_array = np.array(all_waveforms)
labels_array = np.array(all_labels)

print(f"Dataset shape: {waveform_array.shape}")
print(f"Labels shape: {labels_array.shape}")
print(f"\nClass distribution:")
for scint, label in label_map.items():
    count = np.sum(labels_array == label)
    print(f"  {scint} (label {label}): {count} events")

# Save
np.save(processed_dir / 'waveforms.npy', waveform_array)
np.save(processed_dir / 'labels.npy', labels_array)

# Save label map
import json
with open(processed_dir / 'label_map.json', 'w', encoding='utf-8') as f:
    json.dump(label_map, f, indent=2)

print(f"\n✓ Processed data saved to {processed_dir}/")

## 1.9 Summary Report

In [None]:
# Final recap of what was loaded, checked and written by the cells above
rule = "=" * 80

print(rule)
print("DATA LOADING AND EXPLORATION SUMMARY")
print(rule)
print(f"\nScintillators analyzed: {len(waveforms_dict)}")
print(f"Total waveforms loaded: {len(all_waveforms)}")
print("\nWaveforms per scintillator:")
for scint in scintillators:
    if scint not in waveforms_dict:
        continue
    print(f"  {scint}: {len(waveforms_dict[scint])}")

print(f"\nSpectrum files loaded: {len(spectra_dict)} scintillators")

verdict = 'PASS' if all_stable else 'CHECK'
print(f"\nData quality: {verdict}")

print("\nSaved outputs:")
print(f"  - Processed waveforms: {processed_dir / 'waveforms.npy'}")
print(f"  - Labels: {processed_dir / 'labels.npy'}")
print(f"  - Label map: {processed_dir / 'label_map.json'}")

print("\n" + rule)
print("✓ Notebook complete! Data ready for:")
print("  - Notebook 02: Energy calibration (N42 spectra)")
print("  - Notebook 03: Pulse shape analysis (CSV waveforms)")
print("  - Notebook 04: ML classification")
print(rule)