# Core Inference Function
This is the main function for prediction. It orchestrates the entire process:

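1. The imports and feature-extraction functions are identical to those in the training script. They must stay in sync so that inference produces exactly the feature vector the model was trained on.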

1. Loads the checkpoint file.

1. Unpacks the trained model, scaler, label encoder, and all feature extraction assets.

1. Loads a single EEGdata.csv file.

1. Loops over the 10 trials stored back to back in the file (each trial is 1750 samples, i.e. 7 s at 250 Hz).

1. For each trial, it skips the first 2 seconds, extracts the following 4-second window, applies Common Average Reference (CAR) preprocessing, and generates the full feature vector.

1. It then uses the loaded model to predict the class label and confidence score.

1. Finally, it prints a summary of the results and the total processing time.

**Important note:** Set `CHECKPOINT_PATH` to the location of your trained checkpoint file before running.

In [None]:
import pickle
import numpy as np
import pandas as pd
from scipy.signal import cheby1, filtfilt, welch, hilbert, butter
from scipy.linalg import eigh
from sklearn.cross_decomposition import CCA
import warnings
import time

# --- Configuration & Constants ---
# Silence every Python warning for the rest of the session to keep notebook output readable.
warnings.filterwarnings('ignore')

# ==> IMPORTANT: CHANGE THIS PATH to the trained SSVEP Checkpoint <==
CHECKPOINT_PATH = ' '

# EEG & Trial Constants
# Column names selected from EEGdata.csv, in the order they are passed to feature extraction.
EEG_CHANNELS = ['FZ', 'C3', 'CZ', 'C4', 'PZ', 'PO7', 'OZ', 'PO8']
SAMPLING_RATE = 250  # Hz
SAMPLES_PER_TRIAL_FULL = 7 * SAMPLING_RATE  # 7 s per trial -> 1750 samples
NUM_HARMONICS = 5  # harmonics used by the reference-signal, harmonic and PSD features
SSVEP_FREQUENCIES_LIST = [7, 8, 10, 13]  # target frequencies (Hz)


# Time Windowing Constants: drop the first SKIP_DURATION seconds of every trial,
# then analyse the following DATA_DURATION seconds.
SKIP_DURATION = 2.0
DATA_DURATION = 4.0
SAMPLES_TO_SKIP, SAMPLES_TO_PROCESS = (
    int(SKIP_DURATION * SAMPLING_RATE),
    int(DATA_DURATION * SAMPLING_RATE),
)

# --- ========================================================= ---
# --- COMPLETE FEATURE EXTRACTION & PREPROCESSING FUNCTIONS ---
# --- ========================================================= ---

def apply_car_preprocessing(eeg_data):
    """Re-reference EEG to the Common Average Reference (CAR).

    The mean across axis 1 is subtracted from every element of that row.
    Callers in this file pass arrays laid out as (sample, channel), so each
    sample has the average over all channels removed.

    Args:
        eeg_data: 2-D array-like; averaging is done over axis 1.

    Returns:
        np.ndarray with the same shape as ``eeg_data``.
    """
    common_average = np.mean(eeg_data, axis=1, keepdims=True)
    referenced = eeg_data - common_average
    return referenced

def apply_filterbank(eeg_data, filters):
    """Run zero-phase filtering with every (b, a) pair of a filter bank.

    Args:
        eeg_data: Array filtered along axis 0 (time, for the (sample, channel)
            arrays used in this file).
        filters: Iterable of ``(b, a)`` coefficient pairs.

    Returns:
        np.ndarray stacking one filtered copy of ``eeg_data`` per filter,
        i.e. shape ``(n_filters, *eeg_data.shape)``.
    """
    filtered_bands = []
    for numerator, denominator in filters:
        filtered_bands.append(filtfilt(numerator, denominator, eeg_data, axis=0))
    return np.array(filtered_bands)

def get_reference_signals(duration_samples, frequencies):
    """Build sine/cosine reference matrices for CCA, one per frequency.

    For each base frequency f the matrix columns are
    ``sin(2*pi*f*t), cos(2*pi*f*t), sin(2*pi*2f*t), cos(2*pi*2f*t), ...`` up to
    ``NUM_HARMONICS`` harmonics, sampled at ``SAMPLING_RATE``.

    Args:
        duration_samples: Number of time samples (rows) per matrix.
        frequencies: Iterable of base frequencies in Hz.

    Returns:
        dict mapping each frequency to an array of shape
        ``(duration_samples, 2 * NUM_HARMONICS)``.
    """
    time_axis = np.arange(duration_samples) / SAMPLING_RATE

    def _harmonic_columns(base_freq):
        # Interleave sin/cos for harmonics 1..NUM_HARMONICS, then put time on axis 0.
        columns = []
        for harmonic in range(1, NUM_HARMONICS + 1):
            omega = 2 * np.pi * (base_freq * harmonic)
            columns += [np.sin(omega * time_axis), np.cos(omega * time_axis)]
        return np.array(columns).T

    return {freq: _harmonic_columns(freq) for freq in frequencies}

def get_template_correlation_features(eeg_trial, subject_templates, global_templates, subject_id):
    """Extract correlation features between a trial and per-class templates.

    Templates are taken from ``subject_templates[subject_id]`` when that
    subject is present, otherwise from ``global_templates``. Classes are
    enumerated in the key order of ``global_templates``. For each class, four
    values are emitted:

    1. mean per-channel Pearson correlation (non-finite channels ignored),
    2. Pearson correlation of the flattened trial and template,
    3. Pearson correlation of the per-channel mean power,
    4. mean per-channel phase-locking of the Hilbert phases.

    Any measure that cannot be computed (e.g. a zero-variance input gives a
    NaN correlation) is reported as 0. A class missing from the selected
    template set contributes ``[0, 0, 0, 0]``.

    Args:
        eeg_trial: 2-D array indexed ``[sample, channel]``.
        subject_templates: Mapping ``subject_id -> {class_label: template}``.
        global_templates: Mapping ``class_label -> template``.
        subject_id: Key looked up in ``subject_templates``.
        Templates must have the same shape as ``eeg_trial`` (they are indexed
        per channel and correlated sample-by-sample).

    Returns:
        1-D np.ndarray of length ``4 * len(global_templates)``.
    """
    def _finite_or_zero(value):
        # Undefined correlations (NaN from zero-variance inputs) become 0.
        return value if np.isfinite(value) else 0

    templates_to_use = subject_templates.get(subject_id, global_templates)
    n_channels = eeg_trial.shape[1]

    # Trial-only quantities do not depend on the class, so compute them once.
    trial_flat = eeg_trial.flatten()
    trial_power = np.mean(eeg_trial**2, axis=0)
    trial_phases = [np.angle(hilbert(eeg_trial[:, ch])) for ch in range(n_channels)]

    correlations = []
    for class_label in global_templates:
        if class_label not in templates_to_use:
            correlations.extend([0, 0, 0, 0])
            continue
        template = templates_to_use[class_label]

        channel_corrs = [np.corrcoef(eeg_trial[:, ch], template[:, ch])[0, 1] for ch in range(n_channels)]
        finite_channel_corrs = [c for c in channel_corrs if np.isfinite(c)]
        # np.mean([]) returns NaN, which would leak into the feature vector;
        # report 0 like every other undefined measure here.
        mean_channel_corr = np.mean(finite_channel_corrs) if finite_channel_corrs else 0

        global_corr = np.corrcoef(trial_flat, template.flatten())[0, 1]

        template_power = np.mean(template**2, axis=0)
        power_corr = np.corrcoef(trial_power, template_power)[0, 1]

        phase_corrs = []
        for ch in range(n_channels):
            template_phase = np.angle(hilbert(template[:, ch]))
            phase_corr = np.abs(np.mean(np.exp(1j * (trial_phases[ch] - template_phase))))
            phase_corrs.append(_finite_or_zero(phase_corr))

        correlations.extend([
            mean_channel_corr,
            _finite_or_zero(global_corr),
            _finite_or_zero(power_corr),
            np.mean(phase_corrs) if phase_corrs else 0,
        ])
    return np.array(correlations)

def get_plv_features(eeg_trial):
    """Compute pairwise Phase Locking Values around each target frequency.

    For every frequency in ``SSVEP_FREQUENCIES_LIST``, each channel is
    band-passed (4th-order Butterworth, target +/- 1 Hz, zero-phase). Its
    instantaneous phase is then taken from the Hilbert transform, and the PLV
    is computed for every channel pair (i < j).

    Args:
        eeg_trial: 2-D array indexed ``[sample, channel]``.

    Returns:
        1-D np.ndarray ordered frequency-major, then channel pair (i, j) with
        i < j; length ``n_freqs * n_channels * (n_channels - 1) / 2``.
    """
    nyquist = 0.5 * SAMPLING_RATE
    n_channels = eeg_trial.shape[1]
    plv_values = []
    for stim_freq in SSVEP_FREQUENCIES_LIST:
        band = [max(1, stim_freq - 1) / nyquist, min(125, stim_freq + 1) / nyquist]
        b, a = butter(4, band, btype='band')
        phases = [np.angle(hilbert(filtfilt(b, a, eeg_trial[:, ch]))) for ch in range(n_channels)]
        plv_values.extend(
            np.abs(np.mean(np.exp(1j * (phases[i] - phases[j]))))
            for i in range(n_channels)
            for j in range(i + 1, n_channels)
        )
    return np.array(plv_values)

def get_enhanced_harmonic_features(eeg_trial):
    """Extract band-power features at every target-frequency harmonic.

    For each frequency in ``SSVEP_FREQUENCIES_LIST`` and each harmonic
    1..``NUM_HARMONICS`` (harmonics above 60 Hz are skipped), each channel is
    band-passed (4th-order Butterworth, harmonic +/- 0.5 Hz, zero-phase) and
    its mean squared amplitude is taken. Three values are emitted per
    harmonic: max channel power, mean channel power, and max / mean.
    Only power is used; no phase information is extracted.

    Args:
        eeg_trial: 2-D array indexed ``[sample, channel]``.

    Returns:
        1-D np.ndarray with 3 values per retained harmonic.
    """
    nyquist = 0.5 * SAMPLING_RATE
    features = []
    for stim_freq in SSVEP_FREQUENCIES_LIST:
        for harmonic_freq in (stim_freq * h for h in range(1, NUM_HARMONICS + 1)):
            if harmonic_freq > 60:
                continue
            low = max(1, harmonic_freq - 0.5)
            high = min(125, harmonic_freq + 0.5)
            b, a = butter(4, [low / nyquist, high / nyquist], btype='band')
            powers = [np.mean(filtfilt(b, a, eeg_trial[:, ch])**2) for ch in range(eeg_trial.shape[1])]
            peak, average = np.max(powers), np.mean(powers)
            # Small epsilon keeps the ratio finite when all channel powers are ~0.
            features += [peak, average, peak / (average + 1e-10)]
    return np.array(features)

def get_enhanced_cca_features(filtered_eeg_bank, reference_signals):
    """Filter-bank CCA features with several canonical components.

    For each band in ``filtered_eeg_bank`` and each frequency in
    ``SSVEP_FREQUENCIES_LIST``, a CCA with up to 3 components is fitted between
    the band's EEG and ``reference_signals[freq]``. From the finite canonical
    correlations (``[0]`` if none are finite), three values are emitted: the
    maximum, the mean, and a weighted sum using weights 0.5 / 0.3 / 0.2 over
    the leading components.

    Args:
        filtered_eeg_bank: Iterable of 2-D band-filtered trials; presumably
            ``(n_bands, n_samples, n_channels)`` as built by the filter bank.
        reference_signals: Mapping ``freq -> (n_samples, n_refs)`` array.

    Returns:
        1-D np.ndarray, band-major then frequency, 3 values each.
    """
    component_weights = (0.5, 0.3, 0.2)
    features = []
    for band_eeg in filtered_eeg_bank:
        for freq in SSVEP_FREQUENCIES_LIST:
            reference = reference_signals[freq]
            n_comp = min(3, band_eeg.shape[1], reference.shape[1])
            cca = CCA(n_components=n_comp)
            cca.fit(band_eeg, reference)
            x_scores, y_scores = cca.transform(band_eeg, reference)
            corrs = [np.corrcoef(x_scores[:, k], y_scores[:, k])[0, 1] for k in range(x_scores.shape[1])]
            corrs = [c for c in corrs if np.isfinite(c)] or [0]
            weighted = np.sum([w * c for w, c in zip(component_weights, corrs)])
            features.extend([np.max(corrs), np.mean(corrs), weighted])
    return np.array(features)

def extract_trca_features(eeg_trial, spatial_filters):
    """TRCA-style spatial-filter correlation features.

    For each class in ``spatial_filters`` (dict order), two values are emitted:
    the correlation between the spatially filtered trial and its "template",
    and the same correlation on squared signals. Non-finite correlations
    become 0. Filters whose length does not match the channel count give
    ``[0, 0]``.

    Args:
        eeg_trial: 2-D array indexed ``[sample, channel]``.
        spatial_filters: Mapping ``class_label -> 1-D weight vector``.

    Returns:
        1-D np.ndarray of length ``2 * len(spatial_filters)``.
    """
    n_samples, n_channels = eeg_trial.shape[0], eeg_trial.shape[1]

    # NOTE(review): the "template" below is recomputed from this same trial and
    # filter, so it equals the projection. Both correlations are therefore ~1
    # (or 0 for a constant projection), which makes them uninformative. Kept as-is
    # because the notebook states feature extraction mirrors the training script.
    # A real fix needs per-class TRCA templates in the checkpoint and a retrain.
    references = {
        label: (eeg_trial @ w if len(w) == n_channels else np.zeros(n_samples))
        for label, w in spatial_filters.items()
    }

    features = []
    for label, w in spatial_filters.items():
        if len(w) != n_channels:
            features.extend([0, 0])
            continue
        projection = eeg_trial @ w
        reference = references[label]
        corr = np.corrcoef(projection, reference)[0, 1]
        power_corr = np.corrcoef(projection**2, reference**2)[0, 1]
        features.append(corr if np.isfinite(corr) else 0)
        features.append(power_corr if np.isfinite(power_corr) else 0)
    return np.array(features)

def get_snr_features(eeg_trial):
    """Spectral SNR at each target frequency relative to neighbouring bins.

    Each channel's Welch PSD (segment length up to 512 samples) is computed.
    For each target frequency, the PSD at the nearest bin is divided by the
    mean PSD of the two bins on either side. A value of 0 is emitted when the
    nearest bin is too close to either edge of the spectrum.

    Args:
        eeg_trial: 2-D array indexed ``[sample, channel]``.

    Returns:
        1-D np.ndarray ordered channel-major, then frequency.
    """
    segment_length = min(512, eeg_trial.shape[0])
    snr_values = []
    for ch in range(eeg_trial.shape[1]):
        freqs, psd = welch(eeg_trial[:, ch], fs=SAMPLING_RATE, nperseg=segment_length)
        upper_limit = len(psd) - 2
        for target_freq in SSVEP_FREQUENCIES_LIST:
            idx = np.argmin(np.abs(freqs - target_freq))
            if not (2 < idx < upper_limit):
                snr_values.append(0)
                continue
            # Guard above ensures idx-2 .. idx+2 are all valid indices.
            neighbours = psd[[idx - 2, idx - 1, idx + 1, idx + 2]]
            snr_values.append(psd[idx] / (np.mean(neighbours) + 1e-10))
    return np.array(snr_values)

def get_psd_features(eeg_trial):
    """Welch-PSD power at every target-frequency harmonic, per channel.

    Each channel's PSD is computed with a segment of up to 1000 samples and
    50% overlap. For every target frequency and harmonic 1..``NUM_HARMONICS``,
    the mean PSD over the nearest bin and its immediate neighbours is taken.
    Harmonics above the highest Welch bin give 0.

    Args:
        eeg_trial: 2-D array indexed ``[sample, channel]``.

    Returns:
        1-D np.ndarray ordered channel-major, then frequency, then harmonic.
    """
    segment_length = min(1000, eeg_trial.shape[0])
    overlap = segment_length // 2
    features = []
    for channel_signal in eeg_trial.T:
        freqs, psd = welch(channel_signal, fs=SAMPLING_RATE, nperseg=segment_length, noverlap=overlap)
        max_freq = freqs[-1]
        for target_freq in SSVEP_FREQUENCIES_LIST:
            for h in range(1, NUM_HARMONICS + 1):
                harmonic_freq = target_freq * h
                if harmonic_freq > max_freq:
                    features.append(0)
                    continue
                idx = np.argmin(np.abs(freqs - harmonic_freq))
                lo, hi = max(0, idx - 1), min(len(psd), idx + 2)
                features.append(np.mean(psd[lo:hi]))
    return np.array(features)

def extract_all_features(eeg_trial, filters, reference_signals, spatial_filters, subject_templates, global_templates, subject_id):
    """Build the full feature vector for one trial by concatenating every extractor.

    The concatenation order is fixed: TRCA, filter-bank CCA, PLV, harmonic
    power, SNR, PSD, template correlation. The scaler and model consume the
    features positionally, so this order must match the one used in training.

    Returns:
        1-D np.ndarray with all feature groups concatenated.
    """
    filter_bank_eeg = apply_filterbank(eeg_trial, filters)
    feature_groups = (
        extract_trca_features(eeg_trial, spatial_filters),
        get_enhanced_cca_features(filter_bank_eeg, reference_signals),
        get_plv_features(eeg_trial),
        get_enhanced_harmonic_features(eeg_trial),
        get_snr_features(eeg_trial),
        get_psd_features(eeg_trial),
        get_template_correlation_features(eeg_trial, subject_templates, global_templates, subject_id),
    )
    return np.concatenate(feature_groups)

# --- ============================================ ---
# --- MAIN PREDICTION FUNCTION AND EXECUTION BLOCK ---
# --- ============================================ ---

def _decode_prediction(model, label_encoder, column_idx):
    """Map a ``predict_proba`` column index to a human-readable class label.

    ``predict_proba`` orders its columns like ``model.classes_``, which need not
    be ``0..K-1``.
    - Numeric classes are treated as label-encoded values and decoded through
      ``label_encoder``.
    - Non-numeric classes are already labels and are returned unchanged.
    - Without ``classes_``, the column index itself is used as the encoded label.
    """
    classes = getattr(model, 'classes_', None)
    if classes is None:
        return label_encoder.inverse_transform([column_idx])[0]
    model_class = classes[column_idx]
    if np.issubdtype(np.asarray(classes).dtype, np.number):
        return label_encoder.inverse_transform([int(model_class)])[0]
    return model_class


def predict_eeg_file(checkpoint_path, eeg_file_path, num_trials=10):
    """
    Load a checkpoint and predict each trial in a single EEGdata.csv file.

    The CSV is read as ``num_trials`` consecutive trials of
    ``SAMPLES_PER_TRIAL_FULL`` rows each. For every trial:
    - keep the rows ``[SAMPLES_TO_SKIP, SAMPLES_TO_SKIP + SAMPLES_TO_PROCESS)``;
    - CAR-reference the ``EEG_CHANNELS`` columns;
    - extract features with the checkpoint's global templates (the subject is
      treated as unknown);
    - scale and classify.
    Per-trial results and the total loop time are printed.

    Args:
        checkpoint_path: Path to a pickled dict with keys 'model', 'scaler',
            'label_encoder', 'filter_bank', 'ref_signals', 'spatial_filters'
            and 'global_templates'. Unpickling can execute arbitrary code, so
            only load checkpoints from a trusted source.
        eeg_file_path: Path to the EEGdata.csv file.
        num_trials: Number of back-to-back trials to process (default 10).
            Processing stops early, with a warning, if a trial's analysis
            window is shorter than ``SAMPLES_TO_PROCESS`` rows.

    Returns:
        A list of ``{"trial", "prediction", "confidence"}`` dicts, one per
        processed trial, or ``None`` if the checkpoint or EEG file is missing.
    """
    # 1. Load the checkpoint
    print(f"Loading checkpoint from: {checkpoint_path}")
    try:
        with open(checkpoint_path, 'rb') as f:
            # SECURITY: pickle.load executes code embedded in the file; never
            # point this at an untrusted checkpoint.
            checkpoint = pickle.load(f)
    except FileNotFoundError:
        print(f"❌ ERROR: Checkpoint file not found at '{checkpoint_path}'. Please check the path.")
        return None
    print("✅ Checkpoint loaded successfully.")

    # 2. Unpack assets from the checkpoint
    model = checkpoint['model']
    scaler = checkpoint['scaler']
    le = checkpoint['label_encoder']
    filters = checkpoint['filter_bank']
    reference_signals = checkpoint['ref_signals']
    spatial_filters = checkpoint['spatial_filters']
    global_templates = checkpoint['global_templates']

    # 3. Load the multi-trial EEG data file
    print(f"\nLoading EEG data from: {eeg_file_path}")
    try:
        eeg_df_full = pd.read_csv(eeg_file_path)
    except FileNotFoundError:
        print(f"❌ ERROR: EEG file not found at '{eeg_file_path}'. Please check the path.")
        return None
    print("✅ EEG data loaded.")

    # 4. Process each trial
    print(f"\n--- Processing {num_trials} Trials ---")
    results = []

    # perf_counter is monotonic and high-resolution, unlike time.time().
    start_time = time.perf_counter()

    for i in range(num_trials):
        trial_num = i + 1

        # Segment the data for the current trial (trials are stored back to back)
        start_idx = i * SAMPLES_PER_TRIAL_FULL
        trial_df = eeg_df_full.iloc[start_idx:start_idx + SAMPLES_PER_TRIAL_FULL]

        # Apply time windowing inside the trial
        eeg_data_windowed = trial_df.iloc[SAMPLES_TO_SKIP:SAMPLES_TO_SKIP + SAMPLES_TO_PROCESS]

        # A truncated file gives a short window. That can crash the filters, or
        # yield features computed on a different length than in training, so stop.
        if len(eeg_data_windowed) < SAMPLES_TO_PROCESS:
            print(f"⚠️ WARNING: Trial {trial_num} has only {len(eeg_data_windowed)} of "
                  f"{SAMPLES_TO_PROCESS} samples in its analysis window; stopping early.")
            break

        # Select channels, convert to numpy, and apply CAR
        eeg_array = apply_car_preprocessing(eeg_data_windowed[EEG_CHANNELS].values)

        # Extract features (using global templates as subject is unknown)
        features = extract_all_features(
            eeg_array, filters, reference_signals, spatial_filters,
            {}, global_templates, 'unknown_subject'
        ).reshape(1, -1)

        # Scale features and get class probabilities
        features_scaled = scaler.transform(features)
        probabilities = model.predict_proba(features_scaled)[0]

        # Columns follow model.classes_, so decode through it rather than
        # assuming column i is encoded label i.
        prediction_idx = int(np.argmax(probabilities))
        results.append({
            "trial": trial_num,
            "prediction": _decode_prediction(model, le, prediction_idx),
            "confidence": probabilities[prediction_idx],
        })

    prediction_time = time.perf_counter() - start_time

    # 5. Print the final results
    print("\n--- Prediction Results ---")
    for res in results:
        print(f"Trial {res['trial']:<2}:  Prediction = {res['prediction']:<10} | Confidence = {res['confidence']:.2%}")
    print("--------------------------")
    print(f"Total prediction time for {len(results)} trials: {prediction_time:.4f} seconds")

    return results


# Running the Prediction
This is the main execution block. To run a prediction, **change the `EEG_FILE_TO_TEST` variable** to the path of the `EEGdata.csv` file you want to analyze, and make sure `CHECKPOINT_PATH` in the configuration cell points to your trained checkpoint.

In [17]:
if __name__ == '__main__':
    # ==> IMPORTANT: CHANGE THIS PATH to your EEGdata.csv file <==
    EEG_FILE_TO_TEST = '/kaggle/input/mtcaic3-phase-ii/SSVEP/test/S44/1/EEGdata.csv'

    # Load the checkpoint at CHECKPOINT_PATH, predict every trial in the file, and print a summary.
    predict_eeg_file(checkpoint_path=CHECKPOINT_PATH, eeg_file_path=EEG_FILE_TO_TEST)

Loading checkpoint from: /kaggle/input/ssvep-checkpoint-no-outliers-finetuned-on-abdo/scikitlearn/default/1/ssvep_checkpoint_filtered.pkl
✅ Checkpoint loaded successfully.

Loading EEG data from: /kaggle/input/mtcaic3-phase-ii/SSVEP/test/S44/1/EEGdata.csv
✅ EEG data loaded.

--- Processing 10 Trials ---

--- Prediction Results ---
Trial 1 :  Prediction = Right      | Confidence = 88.99%
Trial 2 :  Prediction = Backward   | Confidence = 87.61%
Trial 3 :  Prediction = Backward   | Confidence = 91.42%
Trial 4 :  Prediction = Right      | Confidence = 98.68%
Trial 5 :  Prediction = Backward   | Confidence = 88.12%
Trial 6 :  Prediction = Backward   | Confidence = 99.20%
Trial 7 :  Prediction = Forward    | Confidence = 95.01%
Trial 8 :  Prediction = Forward    | Confidence = 98.44%
Trial 9 :  Prediction = Forward    | Confidence = 62.78%
Trial 10:  Prediction = Left       | Confidence = 72.56%
--------------------------
Total prediction time for 10 trials: 4.1625 seconds
