# Adaptive LDA Online Classification

This notebook runs real-time EEG classification with Adaptive LDA:
- Loads pre-trained model and artifact rejection thresholds
- Connects to LSL streams (EEG and Markers)
- Performs online classification with sliding windows
- Adapts model parameters after each trial based on true labels
- Sends commands to the game via UDP
- Tracks and visualizes accuracy over time

**Note**: This requires active EEG and marker streams from LSL.

## 1. Imports and Setup

In [2]:
# Import the LSL bindings; if they are missing, install them on the fly.
# The fallback installs with `sys.executable -m pip`, i.e. into the same
# interpreter that runs this kernel, and then retries the import.
# NOTE(review): the install is unpinned, so the pylsl version depends on when
# the cell is run — consider pinning a version in the project requirements.
try:
    from pylsl import StreamInlet, resolve_streams
except ImportError:
    import subprocess
    import sys
    print("Installing pylsl...")
    # check_call raises CalledProcessError if pip exits with a non-zero status.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "pylsl"])
    from pylsl import StreamInlet, resolve_streams
    print("‚úì pylsl installed successfully")

import pickle
import socket
import sys
import time
from pathlib import Path
import importlib

import numpy as np
import matplotlib.pyplot as plt

# Add src directory to Python path
src_dir = Path.cwd() / "src"
if str(src_dir) not in sys.path:
    sys.path.insert(0, str(src_dir))

# Preprocessing
from bci.Preprocessing.filters import Filter

# Transfer function for game control
from bci.transfer.transfer import BCIController

# Models
from bci.Models.AdaptiveLDA import AdaptiveLDA

# Utils
from bci.utils.bci_config import load_config

# Configure plotting
%matplotlib inline
plt.rcParams['figure.figsize'] = (12, 6)

print("‚úì All imports successful!")

Installing pylsl...
Collecting pylsl
  Downloading pylsl-1.18.1-py2.py3-none-macosx_11_0_universal2.whl.metadata (6.1 kB)
Downloading pylsl-1.18.1-py2.py3-none-macosx_11_0_universal2.whl (663 kB)
[2K   [90m‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ‚îÅ[0m [32m663.6/663.6 kB[0m [31m12.6 MB/s[0m  [33m0:00:00[0m
[?25hInstalling collected packages: pylsl
Successfully installed pylsl-1.18.1
‚úì pylsl installed successfully
‚úì All imports successful!


## 2. Load Configuration and Model

In [3]:
# Resolve and load the BCI configuration relative to the notebook's working directory.
current_wd = Path.cwd()

try:
    config_path = current_wd.joinpath("resources", "configs", "bci_config.yaml")
    print(f"Loading configuration from: {config_path}")
    config = load_config(config_path)
    print("‚úì Configuration loaded successfully!")

    # Echo the parameters that drive the online windowing.
    print("\nKey Parameters:")
    print(f"  Sampling frequency: {config.fs} Hz")
    print(f"  Frequency range: {config.frequencies} Hz")
    window_seconds = config.window_size / config.fs
    print(f"  Window size: {config.window_size} samples ({window_seconds:.2f} seconds)")
    step_seconds = config.step_size / config.fs
    print(f"  Step size: {config.step_size} samples ({step_seconds:.2f} seconds)")
    print(f"  Online mode: {config.online}")
except Exception as exc:
    # Report the failure in the cell output, then let it propagate so the run stops here.
    print(f"‚ùå Error loading config: {exc}")
    raise

# Marker code -> human readable name (codes are the list positions 0..3).
markers = dict(enumerate(["unknown", "rest", "left_hand", "right_hand"]))

print("\n‚úì Marker definitions:")
for code in markers:
    print(f"  {code}: {markers[code]}")

# Seed NumPy's global RNG from the configuration.
np.random.seed(config.random_state)

Loading configuration from: /Users/amalbenslimen/BCI Challenge /BCI-Challenge/resources/configs/bci_config.yaml
‚úì Configuration loaded successfully!

Key Parameters:
  Sampling frequency: 160.0 Hz
  Frequency range: [8.0, 30.0] Hz
  Window size: 250 samples (1.56 seconds)
  Step size: 32 samples (0.20 seconds)
  Online mode: dino

‚úì Marker definitions:
  0: unknown
  1: rest
  2: left_hand
  3: right_hand


In [4]:
# Load the pre-trained Adaptive LDA classifier and the (optional) artifact
# rejection thresholds from resources/models under the working directory.
model_path = current_wd / "resources" / "models" / "adaptivelda_model.pkl"
artefact_rejection_path = current_wd / "resources" / "models" / "adaptivelda_artefact_removal.pkl"

print(f"Loading Adaptive LDA model from: {model_path}")
if not model_path.exists():
    # The classifier is mandatory: fail early with a pointer to the training notebook.
    raise FileNotFoundError(f"Model not found: {model_path}\nPlease train the model first using main_offline_AdaptiveLDA.ipynb")

clf = AdaptiveLDA.load(model_path)
print("‚úì Model loaded successfully!")
print(f"  Classes: {clf.classes}")
print(f"  Number of features: {clf.n_features}")

# Load artifact rejection thresholds (optional: the notebook continues without them).
print(f"\nLoading artifact rejection from: {artefact_rejection_path}")
if not artefact_rejection_path.exists():
    print("‚ö†Ô∏è  Warning: Artifact rejection file not found. Continuing without it.")
    ar = None
else:
    # SECURITY: unpickling executes arbitrary code embedded in the file, so only
    # load pickles produced by this project's own training notebook.
    # The context manager closes the file handle deterministically
    # (the previous `pickle.load(open(...))` left it open until garbage collection).
    with open(artefact_rejection_path, "rb") as ar_file:
        ar = pickle.load(ar_file)
    print("‚úì Artifact rejection thresholds loaded!")

# NOTE(review): `ar` is not referenced by any other cell visible in this
# notebook — confirm whether artifact rejection is meant to be applied online.

Loading Adaptive LDA model from: /Users/amalbenslimen/BCI Challenge /BCI-Challenge/resources/models/adaptivelda_model.pkl
‚úì Model loaded successfully!
  Classes: [0 1 2]
  Number of features: 176

Loading artifact rejection from: /Users/amalbenslimen/BCI Challenge /BCI-Challenge/resources/models/adaptivelda_artefact_removal.pkl
‚úì Artifact rejection thresholds loaded!


## 3. Initialize Preprocessing and Controller

In [5]:
# Band-pass filter configured for online (streaming) use.
# NOTE(review): the name `filter` shadows the Python builtin; it is kept
# because the classification loop refers to it by this name.
filter = Filter(config, online=True)
print("‚úì Filter initialized (online mode)")

# Controller that turns class probabilities into game commands.
controller = BCIController(config)
print("‚úì BCI Controller initialized")

# Sliding-window buffers: one row per EEG channel, and a single row of labels.
n_channels = len(config.channels)
n_window_samples = int(config.window_size)
buffer = np.zeros((n_channels, n_window_samples), dtype=np.float32)
label_buffer = np.zeros((1, n_window_samples), dtype=np.int32)

print("‚úì Data buffers initialized:")
print(f"  Buffer shape: {buffer.shape}")
print(f"  Label buffer shape: {label_buffer.shape}")

# Session counters (all start at zero).
avg_time_per_classification = 0.0
number_of_classifications = total_fails = total_successes = 0
total_predictions = total_rejected = total_adaptations = 0

# Minimum winning-class probability for a prediction to count as accepted;
# falls back to 0.6 when the config does not define `classification_threshold`.
probability_threshold = getattr(config, 'classification_threshold', 0.6)
print("\n‚úì Statistics initialized")
print(f"  Probability threshold: {probability_threshold}")

# Accuracy tracking for the plots at the end of the notebook.
accuracy_history = []
window_accuracies = []
window_size_viz = 20  # number of most recent labelled predictions averaged for the rolling accuracy

print("\n‚úì Preprocessing and model objects initialized!")

‚úì Filter initialized (online mode)
‚úì BCI Controller initialized
‚úì Data buffers initialized:
  Buffer shape: (16, 250)
  Label buffer shape: (1, 250)

‚úì Statistics initialized
  Probability threshold: 0.6

‚úì Preprocessing and model objects initialized!


## 4. Connect to LSL Streams

In [6]:
# Discover the LSL streams on the network and open inlets for EEG and markers.
print("Looking for EEG and Markers streams...")
streams = resolve_streams(wait_time=5.0)

eeg_streams = [s for s in streams if s.type() == "EEG"]

# The marker stream is identified by name; which name depends on the online mode.
if config.online == "dino":
    expected_marker_name = "MyDinoGameMarkerStream"
else:
    expected_marker_name = "Labels_Stream"
label_streams = [
    s for s in streams
    if s.type() == "Markers" and s.name() == expected_marker_name
]

# Both streams are required; stop with an actionable message if either is missing.
if not eeg_streams:
    raise RuntimeError("‚ùå Could not find EEG stream. Make sure your EEG stream is running.")
if not label_streams:
    raise RuntimeError("‚ùå Could not find Markers stream. Make sure your marker stream is running.")

# Use the first match of each kind.
eeg_info = eeg_streams[0]
label_info = label_streams[0]
inlet = StreamInlet(eeg_info, max_chunklen=32)
inlet_labels = StreamInlet(label_info, max_chunklen=32)

print(f"‚úì Connected to EEG stream: {eeg_info.name()}")
print(f"‚úì Connected to Labels stream: {label_info.name()}")
print(f"  EEG channels: {inlet.info().channel_count()}")
print(f"  Sampling rate: {inlet.info().nominal_srate()} Hz")

rule = "=" * 60
print("\n" + rule)
print("READY FOR ONLINE ADAPTIVE LDA CLASSIFICATION")
print(rule)
print("The model will adapt its parameters after each trial!")
print(rule + "\n")

Looking for EEG and Markers streams...


2026-01-22 00:10:13.261 (  24.461s) [          2FD95D]      netinterfaces.cpp:91    INFO| netif 'lo0' (status: 1, multicast: 32768, broadcast: 0)
2026-01-22 00:10:13.261 (  24.461s) [          2FD95D]      netinterfaces.cpp:91    INFO| netif 'lo0' (status: 1, multicast: 32768, broadcast: 0)
2026-01-22 00:10:13.261 (  24.461s) [          2FD95D]      netinterfaces.cpp:102   INFO| 	IPv4 addr: 7f000001
2026-01-22 00:10:13.261 (  24.461s) [          2FD95D]      netinterfaces.cpp:91    INFO| netif 'lo0' (status: 1, multicast: 32768, broadcast: 0)
2026-01-22 00:10:13.261 (  24.461s) [          2FD95D]      netinterfaces.cpp:105   INFO| 	IPv6 addr: ::1
2026-01-22 00:10:13.261 (  24.461s) [          2FD95D]      netinterfaces.cpp:91    INFO| netif 'lo0' (status: 1, multicast: 32768, broadcast: 0)
2026-01-22 00:10:13.261 (  24.461s) [          2FD95D]      netinterfaces.cpp:105   INFO| 	IPv6 addr: fe80::1%lo0
2026-01-22 00:10:13.261 (  24.461s) [          2FD95D]      netinterfaces.cpp:91    I

RuntimeError: ‚ùå Could not find EEG stream. Make sure your EEG stream is running.

## 5. Online Classification Loop

**⚠️ Warning**: The cell below runs continuously until it is interrupted. To stop it, use the **Stop** button (Kernel → Interrupt); `Ctrl+C` only works when the code is run from a terminal.

In [7]:
# Online classification loop.
#
# Each pass pulls whatever EEG / marker samples are available, slides them into
# the fixed-size window buffers, adapts the classifier when the majority label
# switches between two non-zero classes, classifies the current window and
# forwards the class probabilities to the game over UDP.
# The loop runs until the kernel is interrupted (KeyboardInterrupt).

# Reset the session statistics so that re-running this cell starts from a clean slate.
avg_time_per_classification = 0.0  # accumulated loop time in ms; divided by the count when reported
number_of_classifications = 0
total_fails = 0
total_successes = 0
total_predictions = 0
total_rejected = 0
total_adaptations = 0
accuracy_history = []
window_accuracies = []

previous_label = 0  # Track previous label to detect trial boundaries
trial_buffer = None  # Latest raw EEG window captured while a non-zero label was active
trial_true_label = None  # Store true label for adaptation

# Same integer window length that was used to allocate `buffer` / `label_buffer`.
window_len = int(config.window_size)
# Pause between polls while no new EEG samples are available (seconds).
idle_sleep_s = 0.001


def _push_chunk(ring, chunk, capacity):
    """Slide `chunk` into the right-hand side of `ring` and return the result.

    Args:
        ring: 2-D array (rows x capacity) holding the current window; shifted in place.
        chunk: 2-D array (rows x n_new) with the newly received samples.
        capacity: number of columns the window must keep.

    Returns:
        The updated window. If `chunk` alone fills the window, its last
        `capacity` columns are returned instead of `ring`.
    """
    n_new = chunk.shape[1]
    if n_new >= capacity:
        # Safety: more new data than the window holds, keep only the newest part.
        return chunk[:, -capacity:]
    ring[:, :-n_new] = ring[:, n_new:]
    ring[:, -n_new:] = chunk
    return ring


# Create UDP socket for game communication
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)

print("üü¢ Starting online classification...")
print("Press Stop button or interrupt kernel to stop.\n")

try:
    while True:
        # perf_counter is monotonic, so the measured duration cannot be skewed
        # by wall-clock adjustments.
        start_classification_time = time.perf_counter() * 1000  # in milliseconds
        eeg_chunk, timestamp = inlet.pull_chunk()
        labels_chunk, label_timestamp = inlet_labels.pull_chunk()
        crt_label = None

        # Always consume markers, even when no EEG arrived, so none are lost.
        if labels_chunk:
            # Transpose to (1, n_samples) before sliding into the label window.
            label_buffer = _push_chunk(label_buffer, np.array(labels_chunk).T, window_len)

        # pull_chunk() is called without a timeout, so it may return no samples.
        # Classifying again on an unchanged window would only repeat the previous
        # result (flooding the game and duplicating accuracy entries), so wait
        # briefly for new EEG data instead of busy-spinning.
        if not eeg_chunk:
            time.sleep(idle_sleep_s)
            continue

        # Transpose to (n_channels, n_samples) before sliding into the EEG window.
        buffer = _push_chunk(buffer, np.array(eeg_chunk).T, window_len)

        # Extract the current label (most present in the buffer)
        unique, counts = np.unique(label_buffer, return_counts=True)
        if len(unique) > 0:
            # argmax returns the first maximum, i.e. the smallest label on ties.
            crt_label = unique[np.argmax(counts)]
        else:
            crt_label = 0  # fallback to unknown

        # Detect trial boundary (label changed from non-zero to different non-zero)
        # NOTE(review): a transition that passes through label 0 (e.g. 2 -> 0 -> 3)
        # never satisfies this condition, so no adaptation happens for it — confirm
        # that this is intended.
        if previous_label != 0 and crt_label != previous_label and crt_label != 0:
            # Trial just ended! Adapt the model with previous trial data
            if trial_buffer is not None and trial_true_label is not None and trial_true_label != 0:
                try:
                    # NOTE(review): `trial_buffer` is the raw, unfiltered 2-D window,
                    # whereas predict_proba below receives filtered 3-D data, and the
                    # label passed here is the raw marker code — verify both against
                    # what AdaptiveLDA.update expects.
                    clf.update(trial_buffer, trial_true_label)
                    total_adaptations += 1
                    print(f"üîÑ Adapted model (Trial ended: {markers.get(trial_true_label, 'unknown')} ‚Üí {markers.get(crt_label, 'unknown')})")
                except Exception as e:
                    # Best effort: a failed adaptation must not stop the online session.
                    print(f"‚ö†Ô∏è  Adaptation failed: {e}")

            # Reset trial buffer for new trial
            trial_buffer = None
            trial_true_label = None

        # Store current trial data (only the most recent window is kept)
        if crt_label != 0:
            trial_buffer = buffer.copy()
            trial_true_label = crt_label

        previous_label = crt_label

        # Filter the data
        filtered_data = filter.apply_filter_online(buffer)

        # Reshape for prediction: (1, n_channels, n_samples)
        filtered_data = filtered_data[np.newaxis, :, :]

        # Create the features and classify
        probabilities = clf.predict_proba(filtered_data)

        if probabilities is None:
            print("‚ö†Ô∏è  Warning: Model returned None for probability.")
            continue  # skip this iteration

        # Send command to game
        controller.send_command(probabilities, sock)

        # Get prediction (index of the most probable class)
        prediction = np.argmax(probabilities, axis=1)[0]

        # Print classification result
        print(
            f"Label: {crt_label} ({markers.get(crt_label, 'unknown')}) | "
            f"Predicted: {prediction} ({markers.get(prediction, 'unknown')}) | "
            f"Conf: {probabilities[0][prediction]:.2%} | "
            f"Adaptations: {total_adaptations}"
        )

        total_predictions += 1

        # Track accuracy for non-unknown labels
        if crt_label != 0:
            # NOTE(review): `prediction` is an argmax index while `crt_label` is a
            # marker code (1..3 for the known classes); the loaded model reports
            # classes [0 1 2]. Confirm that no offset is needed in this comparison.
            is_correct = prediction == crt_label
            accuracy_history.append(1 if is_correct else 0)

            # Low-confidence predictions count as rejected rather than success/fail.
            if probabilities[0][prediction] < probability_threshold:
                total_rejected += 1
            else:
                total_successes += int(is_correct)
                total_fails += int(not is_correct)

            # Calculate rolling accuracy over the last `window_size_viz` entries
            if len(accuracy_history) >= window_size_viz:
                rolling_acc = np.mean(accuracy_history[-window_size_viz:])
                window_accuracies.append(rolling_acc)
                if len(window_accuracies) % 5 == 0:  # Print every 5 windows
                    print(f"üìä Rolling accuracy (last {window_size_viz}): {rolling_acc:.2%}")

        number_of_classifications += 1

        end_classification_time = time.perf_counter() * 1000  # in milliseconds
        avg_time_per_classification += (
            end_classification_time - start_classification_time
        )

except KeyboardInterrupt:
    print("\n" + "="*60)
    print("STOPPING ONLINE PROCESSING")
    print("="*60)
finally:
    # Always release the UDP socket, whether the loop was interrupted or failed.
    sock.close()
    print("‚úì Socket closed")

üü¢ Starting online classification...
Press Stop button or interrupt kernel to stop.

‚úì Socket closed


NameError: name 'inlet' is not defined

## 6. Display Results and Statistics

In [None]:
# Summarise the counters collected by the online loop.
rule = "=" * 60
print(rule)
print("FINAL STATISTICS")
print(rule)

# `avg_time_per_classification` holds the accumulated time; max(1, ...) avoids
# dividing by zero when the loop never completed a classification.
mean_loop_ms = avg_time_per_classification / max(1, number_of_classifications)
print(f"Avg time per loop: {mean_loop_ms:.2f} ms")
print(f"Total Predictions: {total_predictions}")
print(f"  Rejected: {total_rejected}")
print(f"  Accepted Successes: {total_successes}")
print(f"  Accepted Fails: {total_fails}")
print(f"Total Adaptations: {total_adaptations}")

# Accuracy over predictions that passed the probability threshold.
accepted_total = total_successes + total_fails
if accepted_total > 0:
    final_accuracy = total_successes / accepted_total
    print(f"\nFinal Accuracy (accepted predictions): {final_accuracy:.2%}")

# Accuracy over every labelled prediction, accepted or not.
if accuracy_history:
    overall_accuracy = np.mean(accuracy_history)
    print(f"Overall Accuracy (all predictions): {overall_accuracy:.2%}")
    print(f"Total labeled trials: {len(accuracy_history)}")

print(rule)

## 7. Visualize Accuracy Over Time

In [None]:
# Plot per-prediction correctness and the rolling accuracy side by side,
# then save the figure next to the notebook.
if not accuracy_history:
    print("‚ö†Ô∏è  No accuracy data collected. Make sure labels were received during classification.")
else:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
    raw_ax, rolling_ax = axes[0], axes[1]

    # Left panel: 1/0 correctness of each labelled prediction, with the mean as a reference line.
    raw_ax.plot(accuracy_history, 'o-', alpha=0.6, markersize=3)
    mean_acc = np.mean(accuracy_history)
    raw_ax.axhline(y=mean_acc, color='r', linestyle='--', label=f'Mean: {mean_acc:.2%}')
    raw_ax.set_xlabel('Trial Number', fontweight='bold')
    raw_ax.set_ylabel('Correct (1) / Incorrect (0)', fontweight='bold')
    raw_ax.set_title('Classification Accuracy Over Time', fontweight='bold')
    raw_ax.legend()
    raw_ax.grid(alpha=0.3)
    raw_ax.set_ylim([-0.1, 1.1])

    # Right panel: rolling accuracy, or a placeholder when too little data was collected.
    if window_accuracies:
        rolling_ax.plot(window_accuracies, 'o-', color='steelblue', linewidth=2)
        mean_rolling = np.mean(window_accuracies)
        rolling_ax.axhline(y=mean_rolling, color='r', linestyle='--', label=f'Mean: {mean_rolling:.2%}')
        rolling_ax.set_xlabel(f'Window Number (size={window_size_viz})', fontweight='bold')
        rolling_ax.set_ylabel('Rolling Accuracy', fontweight='bold')
        rolling_ax.set_title(f'Rolling Accuracy (Window={window_size_viz} trials)', fontweight='bold')
        rolling_ax.legend()
        rolling_ax.grid(alpha=0.3)
        rolling_ax.set_ylim([0, 1.0])
    else:
        placeholder_box = {"boxstyle": 'round', "facecolor": 'wheat', "alpha": 0.5}
        rolling_ax.text(
            0.5, 0.5,
            f'Need at least {window_size_viz} trials\nfor rolling accuracy',
            ha='center', va='center', transform=rolling_ax.transAxes,
            fontsize=12, bbox=placeholder_box,
        )
        rolling_ax.set_title('Rolling Accuracy', fontweight='bold')

    plt.tight_layout()

    # Save before show(): the figure is written into the working directory.
    plot_path = current_wd / "adaptive_lda_online_accuracy.png"
    plt.savefig(plot_path, dpi=150, bbox_inches='tight')
    print(f"‚úì Accuracy plot saved: {plot_path}")

    plt.show()