# QTCR-Net: Quantum Temporal Convolutional Reservoir Network

## End-to-End Demo for DVS128 Event Camera Classification

This notebook demonstrates the complete workflow for training and evaluating QTCR-Net, a novel quantum-hybrid neural network architecture for event-based vision.

**Architecture Highlights:**
- Fully convolutional spatio-temporal feature extraction (NO MLP encoders)
- Temporal Convolutional Network (TCN) with multi-scale dilations
- Multiple quantum temporal reservoirs (PennyLane QNodes)
- Dual-head classification (waveform + voltage)
- Hybrid quantum-classical causal temporal modeling

**Author:** QTCR-Net Research Team  
**Date:** 2025

## 1. Setup and Imports

In [None]:
import os
import sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import yaml
import torch
import pennylane as qml
from tqdm.auto import tqdm

# Set plotting style
sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

# Check GPU availability
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"CUDA Version: {torch.version.cuda}")

## 2. Load Configuration

In [None]:
# Load configuration
with open('config.yaml', 'r') as f:
    config = yaml.safe_load(f)

print("Configuration loaded successfully!")
print(f"\nExperiment: {config['experiment']['name']}")
print(f"Description: {config['experiment']['description']}")
print(f"\nData Configuration:")
print(f"  Window duration: {config['data']['window']['duration_sec']}s")
print(f"  Temporal bins: {config['data']['window']['temporal_bins']}")
print(f"  Spatial patch size: {config['data']['spatial']['patch_size']}x{config['data']['spatial']['patch_size']}")
print(f"\nModel Configuration:")
print(f"  Quantum groups: {config['model']['quantum_reservoir']['num_groups']}")
print(f"  Qubits per group: {config['model']['quantum_reservoir']['qubits_per_group']}")
print(f"  Quantum layers: {config['model']['quantum_reservoir']['circuit']['num_layers']}")

## 3. Data Preprocessing

Convert raw CSV event streams to voxel grid representations.
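`preprocess.py` performs this conversion, and its code is not shown in this notebook. Below is a minimal sketch of one common voxelization scheme. It assumes each event is a (t, x, y, p) tuple with t in seconds and p = 1 for ON and p = 0 for OFF. Events inside a window are counted into `temporal_bins` time slices per polarity, which produces the [C, T, H, W] layout that the Section 4 visualization reads (channel 0 = ON, channel 1 = OFF). The helper name `events_to_voxel`, the column meanings, and the lack of normalization are all assumptions, and `preprocess.py` may do this differently.

```python
import numpy as np


def events_to_voxel(t, x, y, p, t_start, duration_sec, temporal_bins,
                    height=128, width=128):
    """Hypothetical helper: bin DVS events into a [2, T, H, W] count grid.

    Assumes timestamps in seconds and polarity p == 1 for ON, p == 0 for OFF.
    Channel 0 holds ON counts, channel 1 holds OFF counts.
    """
    t = np.asarray(t, dtype=np.float64)
    x = np.asarray(x, dtype=np.int64)
    y = np.asarray(y, dtype=np.int64)
    p = np.asarray(p, dtype=np.int64)

    # Keep only events inside the half-open window [t_start, t_start + duration)
    keep = (t >= t_start) & (t < t_start + duration_sec)
    t, x, y, p = t[keep], x[keep], y[keep], p[keep]

    # Map each timestamp to a temporal bin index in [0, temporal_bins - 1]
    bins = ((t - t_start) / duration_sec * temporal_bins).astype(np.int64)
    bins = np.clip(bins, 0, temporal_bins - 1)

    # ON events -> channel 0, OFF events -> channel 1
    channel = np.where(p == 1, 0, 1)

    voxel = np.zeros((2, temporal_bins, height, width), dtype=np.float32)
    # np.add.at accumulates correctly when several events hit the same cell
    np.add.at(voxel, (channel, bins, y, x), 1.0)
    return voxel


# Usage: three of the four events fall inside a 0.1 s window split into 2 bins
v = events_to_voxel(t=[0.0, 0.06, 0.09, 0.2], x=[1, 2, 2, 3], y=[0, 0, 0, 0],
                    p=[1, 0, 0, 1], t_start=0.0, duration_sec=0.1,
                    temporal_bins=2, height=4, width=4)
print(v.shape, v.sum())  # (2, 2, 4, 4) 3.0
```

Plain `voxel[...] += 1` with fancy indexing would silently drop repeated events that land in the same cell. That is why the sketch uses `np.add.at`.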

In [None]:
# Check if preprocessing is needed
manifest_path = Path(config['data']['manifest_path'])

if not manifest_path.exists():
    print("Manifest not found. Running preprocessing...")
    print("\nNote: This may take a while for large datasets.")
    print("For testing, you can limit the number of files/windows in preprocess.py\n")
    
    # Run preprocessing
    !python preprocess.py --config config.yaml
else:
    print(f"Manifest already exists at: {manifest_path}")
    print("Skipping preprocessing. To reprocess, delete the manifest file.")

In [None]:
# Load and inspect manifest
manifest = pd.read_csv(manifest_path)

print(f"Total samples: {len(manifest)}")
print(f"\nFirst few samples:")
display(manifest.head())

print(f"\nLabel distributions:")
print("\nWaveform:")
print(manifest['waveform_label'].value_counts())
print("\nVoltage:")
print(manifest['voltage_label'].value_counts())

## 4. Visualize Voxel Grids

Load and visualize some example voxel grids.

In [None]:
def visualize_voxel_grid(voxel_path, title="Voxel Grid"):
    """
    Visualize a voxel grid from a .npy file.
    
    Args:
        voxel_path: Path to .npy file
        title: Plot title
    """
    voxel = np.load(voxel_path)  # [C, T, H, W]
    C, T, H, W = voxel.shape
    
    fig, axes = plt.subplots(2, 4, figsize=(16, 8))
    
    # Show 4 time slices for each polarity
    time_indices = np.linspace(0, T-1, 4, dtype=int)
    
    for i, t_idx in enumerate(time_indices):
        # Channel 0: ON events
        axes[0, i].imshow(voxel[0, t_idx], cmap='hot', interpolation='nearest')
        axes[0, i].set_title(f'ON Events - t={t_idx}/{T}')
        axes[0, i].axis('off')
        
        # Channel 1: OFF events
        axes[1, i].imshow(voxel[1, t_idx], cmap='cool', interpolation='nearest')
        axes[1, i].set_title(f'OFF Events - t={t_idx}/{T}')
        axes[1, i].axis('off')
    
    plt.suptitle(title, fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.show()
    
    print(f"Voxel shape: {voxel.shape}")
    print(f"Value range: [{voxel.min():.3f}, {voxel.max():.3f}]")
    print(f"Non-zero elements: {(voxel != 0).sum()} / {voxel.size}")

In [None]:
# Visualize a few random samples
num_samples_to_viz = 2
sample_indices = np.random.choice(len(manifest), num_samples_to_viz, replace=False)

for idx in sample_indices:
    sample = manifest.iloc[idx]
    voxel_path = sample['npy_path']
    waveform = sample['waveform_label']
    voltage = sample['voltage_label']
    
    title = f"Waveform: {waveform}, Voltage: {voltage}"
    visualize_voxel_grid(voxel_path, title=title)

## 5. Create Dataset and DataLoaders

In [None]:
from dataset import create_dataloaders

# Create data loaders
train_loader, val_loader, test_loader = create_dataloaders(config, train_augment=True)

print(f"Train batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")
print(f"Test batches: {len(test_loader)}")

In [None]:
# Inspect a batch
batch_voxels, batch_labels = next(iter(train_loader))

print(f"Batch voxel shape: {batch_voxels.shape}")
print(f"Batch voxel dtype: {batch_voxels.dtype}")
print(f"Waveform labels: {batch_labels['waveform']}")
print(f"Voltage labels: {batch_labels['voltage']}")

## 6. Build QTCR-Net Model

Construct the quantum-hybrid architecture.

In [None]:
from qtcr_model import QTCRNet

# Create model
model = QTCRNet(config).to(device)

# Print model summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel Summary:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Parameter memory (approx, fp32): {total_params * 4 / 1024**2:.2f} MB")

In [None]:
# Test forward pass (eval mode so dropout and batch-norm statistics are not affected)
model.eval()
with torch.no_grad():
    test_input = batch_voxels[:2].to(device)  # Test with 2 samples
    waveform_logits, voltage_logits = model(test_input)

print(f"Test input shape: {test_input.shape}")
print(f"Waveform logits shape: {waveform_logits.shape}")
print(f"Voltage logits shape: {voltage_logits.shape}")
print("\nForward pass successful!")

## 7. Visualize Quantum Circuit

Visualize one of the quantum reservoir circuits.
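The reservoirs are PennyLane QNodes reached through `model.quantum_layer`, and their exact circuit is not reproduced in this notebook. The following minimal NumPy statevector sketch makes the reservoir idea concrete without PennyLane. It assumes RY angle encoding of one feature per qubit, fixed random RY/RZ layers, a ring of CNOTs for entanglement, and Pauli-Z expectation values as outputs. These gate choices and all helper names are illustrative assumptions, not the QTCR-Net circuit. `multi_reservoir` mirrors the multi-group design: it splits a feature vector into equal groups and gives each group its own frozen random parameters.

```python
import numpy as np


def ry(theta):
    """Single-qubit RY rotation matrix."""
    c, s = np.cos(theta / 2), np.sin(theta / 2)
    return np.array([[c, -s], [s, c]], dtype=complex)


def rz(theta):
    """Single-qubit RZ rotation matrix."""
    return np.array([[np.exp(-1j * theta / 2), 0], [0, np.exp(1j * theta / 2)]])


def apply_1q(state, gate, qubit, n):
    """Apply a 2x2 gate to one qubit of an n-qubit statevector."""
    psi = state.reshape([2] * n)
    psi = np.tensordot(gate, psi, axes=([1], [qubit]))  # gate output becomes axis 0
    return np.moveaxis(psi, 0, qubit).reshape(-1)


def apply_cnot(state, control, target, n):
    """Flip the target qubit on every amplitude where the control qubit is 1."""
    psi = state.reshape([2] * n).copy()
    idx = [slice(None)] * n
    idx[control] = 1
    # Fixing the control axis removes it, so later axes shift down by one
    t_axis = target if target < control else target - 1
    psi[tuple(idx)] = np.flip(psi[tuple(idx)], axis=t_axis).copy()
    return psi.reshape(-1)


def reservoir_expvals(features, frozen_params):
    """Hypothetical reservoir: RY encoding, fixed random layers, <Z_i> readout."""
    n = len(features)
    assert n >= 2, "the CNOT ring needs at least two qubits"
    state = np.zeros(2 ** n, dtype=complex)
    state[0] = 1.0  # start in |00...0>
    for q, f in enumerate(features):
        state = apply_1q(state, ry(f), q, n)  # angle-encode one feature per qubit
    for layer in frozen_params:  # frozen_params shape: [num_layers, n, 2]
        for q in range(n):
            state = apply_1q(state, ry(layer[q, 0]), q, n)
            state = apply_1q(state, rz(layer[q, 1]), q, n)
        for q in range(n):  # ring entanglement
            state = apply_cnot(state, q, (q + 1) % n, n)
    probs = np.abs(state.reshape([2] * n)) ** 2
    # <Z_q> = P(qubit q = 0) - P(qubit q = 1)
    return np.array([
        np.moveaxis(probs, q, 0).reshape(2, -1).sum(axis=1) @ [1.0, -1.0]
        for q in range(n)
    ])


def multi_reservoir(features, num_groups, num_layers=2, seed=0):
    """Split features into equal groups; each group gets its own frozen reservoir."""
    rng = np.random.default_rng(seed)
    outputs = []
    for group in np.split(np.asarray(features, dtype=float), num_groups):
        params = rng.uniform(0, 2 * np.pi, size=(num_layers, len(group), 2))  # never trained
        outputs.append(reservoir_expvals(group, params))
    return np.concatenate(outputs)


out = multi_reservoir(np.linspace(-1, 1, 8), num_groups=2)  # 2 reservoirs x 4 qubits
print(out.shape)  # (8,)
```

Each reservoir simulates only 2^4 amplitudes in this example. This is the practical side of using several small circuits: the cost grows with the number of groups instead of exponentially in the total qubit count.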

In [None]:
# Access the first quantum reservoir
quantum_layer = model.quantum_layer
first_reservoir = quantum_layer.reservoirs[0]

print(f"First Quantum Reservoir:")
print(f"  Qubits: {first_reservoir.num_qubits}")
print(f"  Layers: {first_reservoir.num_layers}")
print(f"  Entanglement: {first_reservoir.entanglement}")
print(f"  Trainable: {first_reservoir.trainable}")

# Draw circuit (requires pennylane[matplotlib])
try:
    # Create dummy input
    dummy_features = torch.randn(first_reservoir.num_qubits)
    dummy_params = first_reservoir.quantum_params.detach()
    
    # Draw circuit
    fig, ax = qml.draw_mpl(first_reservoir.qnode)(dummy_features, dummy_params)
    plt.title("Quantum Reservoir Circuit (Single Reservoir)", fontweight='bold')
    plt.tight_layout()
    plt.show()
except Exception as e:
    print(f"Could not draw circuit: {e}")
    print("Install pennylane[matplotlib] for circuit visualization.")

## 8. Train QTCR-Net

Train the model with dual-head losses, automatic mixed precision (AMP), and early stopping.

**Note:** For a full training run, use `train.py` from the command line. This notebook shows a short training demo.
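This notebook does not show how `train.py` combines the two heads. A common formulation, assumed here, is a weighted sum of one cross-entropy term per head. The weights `w_wave` and `w_volt` are hypothetical, and so are the class counts (3 waveforms, 5 voltages) in the usage example. A minimal NumPy sketch:

```python
import numpy as np


def cross_entropy(logits, targets):
    """Mean softmax cross-entropy; logits [B, K], integer targets [B]."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(targets)), targets].mean()


def dual_head_loss(wave_logits, volt_logits, wave_y, volt_y, w_wave=1.0, w_volt=1.0):
    """Hypothetical dual-head objective: weighted sum of the two head losses."""
    return (w_wave * cross_entropy(wave_logits, wave_y)
            + w_volt * cross_entropy(volt_logits, volt_y))


# Uniform (all-zero) logits: each head's loss is ln(num_classes)
wave_logits = np.zeros((4, 3))  # illustrative: 3 waveform classes
volt_logits = np.zeros((4, 5))  # illustrative: 5 voltage classes
loss = dual_head_loss(wave_logits, volt_logits,
                      np.array([0, 1, 2, 0]), np.array([0, 1, 2, 3]))
print(round(loss, 4))  # ln(3) + ln(5) = ln(15), about 2.7081
```

Checking that an untrained model's loss starts near ln(K_wave) + ln(K_volt) is a quick sanity test that both heads are wired correctly.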

In [None]:
from train import QTCRNetTrainer

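import copy

# Option 1: Train in notebook (short demo)
# Reduce epochs for a quick demo. Use a deep copy: dict.copy() is shallow, so
# editing the nested 'training' dict would also modify the original config.
config_demo = copy.deepcopy(config)
config_demo['training']['num_epochs'] = 5  # Short demo
config_demo['training']['early_stopping']['enabled'] = False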

print("Training QTCR-Net (demo with 5 epochs)...")
print("For full training, use: python train.py --config config.yaml\n")

trainer = QTCRNetTrainer(config_demo)
trainer.train()

In [None]:
# Option 2: Run full training from command line
# Uncomment to run full training

# !python train.py --config config.yaml

## 9. Visualize Training Progress

Load and plot training metrics from TensorBoard logs.

In [None]:
# Load TensorBoard logs
from tensorboard.backend.event_processing import event_accumulator

log_dir = Path(config['training']['logging']['tensorboard_dir'])

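# Find the most recent run directory (by modification time, since run names may not sort chronologically)
runs = sorted((p for p in log_dir.glob('*') if p.is_dir()), key=lambda p: p.stat().st_mtime)
if len(runs) > 0:
    latest_run = runs[-1]
    print(f"Loading logs from: {latest_run}")
    
    # Load the newest event file in that run
    event_files = sorted(latest_run.glob('events.out.tfevents.*'), key=lambda p: p.stat().st_mtime)
    event_file = event_files[-1]
    ea = event_accumulator.EventAccumulator(str(event_file))
    ea.Reload()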
    
    # Extract metrics
    train_loss = ea.Scalars('Epoch/Train_Loss')
    val_loss = ea.Scalars('Epoch/Val_Loss')
    train_wave_acc = ea.Scalars('Epoch/Train_Waveform_Acc')
    val_wave_acc = ea.Scalars('Epoch/Val_Waveform_Acc')
    
    # Plot
    fig, axes = plt.subplots(1, 2, figsize=(16, 5))
    
    # Loss
    axes[0].plot([x.step for x in train_loss], [x.value for x in train_loss], label='Train', linewidth=2)
    axes[0].plot([x.step for x in val_loss], [x.value for x in val_loss], label='Val', linewidth=2)
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('Training and Validation Loss', fontweight='bold')
    axes[0].legend()
    axes[0].grid(alpha=0.3)
    
    # Accuracy
    axes[1].plot([x.step for x in train_wave_acc], [x.value for x in train_wave_acc], label='Train', linewidth=2)
    axes[1].plot([x.step for x in val_wave_acc], [x.value for x in val_wave_acc], label='Val', linewidth=2)
    axes[1].set_xlabel('Epoch')
    axes[1].set_ylabel('Accuracy (%)')
    axes[1].set_title('Waveform Classification Accuracy', fontweight='bold')
    axes[1].legend()
    axes[1].grid(alpha=0.3)
    
    plt.tight_layout()
    plt.show()
else:
    print("No training logs found. Train the model first.")

## 10. Evaluate Trained Model

Load the best checkpoint and evaluate on the test set.

In [None]:
from eval import QTCRNetEvaluator

# Load best checkpoint
checkpoint_dir = Path(config['training']['checkpoint']['save_dir'])
best_checkpoint = checkpoint_dir / 'best_model.pth'

if best_checkpoint.exists():
    print(f"Loading best model from: {best_checkpoint}")
    
    # Create evaluator
    evaluator = QTCRNetEvaluator(config, str(best_checkpoint))
    
    # Evaluate
    results = evaluator.evaluate()
    evaluator.results = results
    
    # Print results
    evaluator.print_results(results)
else:
    print(f"Checkpoint not found at: {best_checkpoint}")
    print("Train the model first using train.py")

## 11. Confusion Matrices
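`evaluator.plot_confusion_matrices` is defined in `eval.py`, which this notebook does not show. As a reference for reading the plots: a confusion matrix counts, for each true class (row), how often each class was predicted (column). Here is a minimal NumPy sketch, independent of the evaluator, with made-up labels:

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes):
    """cm[i, j] = number of samples with true class i predicted as class j."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at handles repeated (true, pred) pairs correctly
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


cm = confusion_matrix([0, 0, 1, 2, 2, 2], [0, 1, 1, 2, 2, 0], num_classes=3)
print(cm)

# Dividing the diagonal by row sums gives per-class recall
recall = cm.diagonal() / cm.sum(axis=1)
print(recall)
```

Large off-diagonal entries point to specific class pairs the model confuses. For example, two voltage levels that produce similar event rates would show up this way.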

In [None]:
if best_checkpoint.exists():
    # Plot confusion matrices
    evaluator.plot_confusion_matrices(save=False)
else:
    print("Train and evaluate the model first.")

## 12. Per-Class Performance
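Per-class metrics can be read straight off a confusion matrix. Precision divides the diagonal by the column sums, recall divides it by the row sums, and F1 is their harmonic mean. The sketch below is a generic NumPy illustration, not necessarily how `plot_per_class_metrics` computes its values, and the example matrix is made up.

```python
import numpy as np


def per_class_metrics(cm):
    """Precision, recall, and F1 per class from a confusion matrix (rows = true)."""
    cm = np.asarray(cm, dtype=float)
    tp = np.diag(cm)
    predicted = cm.sum(axis=0)  # column sums: how often each class was predicted
    actual = cm.sum(axis=1)     # row sums: how often each class actually occurs
    # Guard against division by zero for classes never predicted or never present
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    return precision, recall, f1


p, r, f1 = per_class_metrics([[1, 1, 0], [0, 1, 0], [1, 0, 2]])
print(p, r, f1)
```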

In [None]:
if best_checkpoint.exists():
    # Plot per-class metrics
    evaluator.plot_per_class_metrics(save=False)
else:
    print("Train and evaluate the model first.")

## 13. Feature Visualization

Extract and visualize intermediate feature maps.

In [None]:
if best_checkpoint.exists():
    # Load model
    model_eval = evaluator.model
    model_eval.eval()
    
    # Get a test sample
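    test_voxel, test_labels = next(iter(test_loader))
    # Move the input to whichever device the evaluator placed the model on
    model_device = next(model_eval.parameters()).device
    test_voxel = test_voxel[:1].to(model_device)  # Single sample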
    
    # Extract feature maps
    with torch.no_grad():
        feature_maps = model_eval.get_feature_maps(test_voxel)
    
    print("Feature map shapes:")
    for name, feat in feature_maps.items():
        print(f"  {name}: {feat.shape}")
    
    # Visualize quantum features
    quantum_features = feature_maps['quantum'].cpu().numpy()[0]
    
    plt.figure(figsize=(14, 4))
    plt.bar(range(len(quantum_features)), quantum_features)
    plt.xlabel('Quantum Feature Index')
    plt.ylabel('Value')
    plt.title('Quantum Reservoir Output Features', fontweight='bold')
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.show()
else:
    print("Train the model first.")

## 14. Summary and Next Steps

### QTCR-Net Architecture Summary

**Novel Contributions:**
1. **Fully Convolutional Feature Extraction**: Unlike traditional MLP-based approaches, QTCR-Net uses pure convolutional layers for spatio-temporal feature extraction, preserving spatial structure.

2. **Temporal Convolutional Network (TCN)**: Multi-scale dilated convolutions (1, 2, 4, 8) capture temporal patterns at different timescales, crucial for event-based data.

3. **Multi-Reservoir Quantum Architecture**: Instead of one large quantum circuit (prone to barren plateaus), QTCR-Net uses multiple small quantum reservoirs (4-8 qubits each), each processing different feature groups.

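4. **Quantum Temporal Reservoirs**: Frozen random quantum circuits act as high-dimensional nonlinear feature extractors, following the reservoir computing paradigm. Because the circuit parameters are not optimized by default, the model sidesteps the vanishing-gradient (barren plateau) problems of training variational circuits.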

5. **Hybrid Quantum-Classical Learning**: The classical TCN learns temporal patterns, and the quantum reservoirs apply nonlinear transformations to the features it produces before the dual classification heads.
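The timescales in item 2 follow from the receptive field of stacked dilated causal convolutions. With kernel size k and one convolution per dilation level, each output step sees 1 + (k - 1) * sum(d) past inputs. TCN residual blocks that use two convolutions per level double the second term. The notebook does not state the kernel size or the number of convolutions per block, so both values below are assumptions. The sketch computes the formula and checks it by pushing an impulse through a NumPy causal dilated convolution stack.

```python
import numpy as np


def tcn_receptive_field(kernel_size, dilations, convs_per_block=1):
    """Number of time steps (including the current one) a single output can see."""
    return 1 + convs_per_block * (kernel_size - 1) * sum(dilations)


def causal_dilated_conv(x, kernel, dilation):
    """y[t] = sum_j kernel[j] * x[t - j * dilation], with zeros before t = 0 (causal)."""
    pad = (len(kernel) - 1) * dilation
    xp = np.concatenate([np.zeros(pad), x])
    return np.array([
        sum(kernel[j] * xp[pad + t - j * dilation] for j in range(len(kernel)))
        for t in range(len(x))
    ])


print(tcn_receptive_field(3, [1, 2, 4, 8]))                     # 31
print(tcn_receptive_field(3, [1, 2, 4, 8], convs_per_block=2))  # 61

# Empirical check: an impulse at t = 0 influences exactly 31 output steps
y = np.zeros(64)
y[0] = 1.0
for d in [1, 2, 4, 8]:
    y = causal_dilated_conv(y, kernel=np.ones(3), dilation=d)
print(np.count_nonzero(y))  # 31
```

Doubling the dilation at each level makes the receptive field grow exponentially with depth, while the parameter count grows only linearly.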

### Next Steps

1. **Hyperparameter Tuning**: Experiment with different numbers of quantum groups, qubits, and TCN dilations.

2. **Fine-tuning Quantum Layers**: Enable partial training of quantum parameters (set `trainable_quantum: true` in config).

3. **Comparison Studies**: Compare QTCR-Net against classical CNNs, TCNs, and MLP+FFT baselines.

4. **Hardware Deployment**: Test on quantum hardware (IBM Quantum, IonQ) to measure the effect of real device noise and assess whether any quantum advantage holds in practice.

5. **Transfer Learning**: Pre-train on large event-camera datasets and fine-tune on your specific task.

6. **Architecture Ablation**: Study the contribution of each component (TCN, quantum reservoirs, number of groups).

### Citation

If you use QTCR-Net in your research, please cite:

```
@misc{qtcrnet2025,
  title={QTCR-Net: Quantum Temporal Convolutional Reservoir Network for Event-Based Vision},
  author={QTCR-Net Research Team},
  year={2025},
  note={Novel quantum-hybrid architecture for DVS event camera classification}
}
```

---

## End of Demo

For questions or issues, please refer to the README.md or contact the research team.