# 🧠 BCI Competition: Motor Imagery (MI) Game Control Template

Welcome, teams! This notebook is your template for the real-time Motor Imagery (MI) competition.

### Your Task

Your goal is to classify a user's intended "left" or "right" hand movement to control a game. Apart from the imports in Part 1, you only modify two functions: `load_model` and `predict`. The surrounding infrastructure for data recording and game communication is fixed.

1.  Add any necessary library imports in **Part 1**.
2.  Complete the `load_model` function in **Part 3**.
3.  Complete the `predict` function in **Part 4**.
4.  Submit this completed notebook and your model file(s).

## Part 1: Team-Specific Imports <span style="color:green">*(✍️ EDITABLE)*</span>.

In the cell below, add any libraries required to load your model and process the data (e.g., `tensorflow`, `sklearn`, `mne`, `scipy`).


```python
# ==> YOUR CODE HERE <==
# Example:
# import mne
# from scipy.signal import butter, lfilter
# import joblib
import os

import numpy as np
import pandas as pd
import tensorflow as tf  # used in load_model() via tf.keras.models.load_model
import mne

# Preprocessing constants used by predict().
SAMPLING_RATE = 250                               # Hz
EEG_CHANNELS = ['FZ', 'C3', 'CZ', 'C4', 'PZ', 'PO7', 'OZ', 'PO8']
T_MIN, T_MAX = 3.5, 7.5                           # epoch window (s) within each 9 s trial
LOW_CUT, HIGH_CUT, NOTCH_FREQ = 1.0, 40.0, 50.0   # band-pass edges and mains notch (Hz)

print("Team-specific libraries would be imported here.")
```

```
Team-specific libraries would be imported here.
```
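The epoch window above is given in seconds. The short sketch below converts it to sample indices at 250 Hz and checks that it fits inside one 9-second trial; `TRIAL_SECONDS` is a hypothetical name, not part of the template. Whether the sample at exactly `T_MAX` is included depends on how the cropping is done, so verify the final epoch length against the input size your model was trained on.

```python
# Sketch: convert the epoch window (seconds) into sample indices and check
# that it fits inside one recorded trial. Values mirror the Part 1 constants.
SAMPLING_RATE = 250          # Hz
TRIAL_SECONDS = 9            # hypothetical name: length of one recorded trial (Part 2)
T_MIN, T_MAX = 3.5, 7.5      # epoch window in seconds, relative to trial start

n_trial_samples = SAMPLING_RATE * TRIAL_SECONDS      # rows per trial DataFrame
start_idx = int(round(T_MIN * SAMPLING_RATE))        # first sample in the window
stop_idx = int(round(T_MAX * SAMPLING_RATE))         # sample index at T_MAX

assert 0 <= start_idx < stop_idx <= n_trial_samples, "window must fit in the trial"
print(n_trial_samples, start_idx, stop_idx, stop_idx - start_idx)
# → 2250 875 1875 1000
```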


## Part 2: Data Recording Infrastructure <span style="color:red">*(❌ DO NOT EDIT or ADD anything)*</span>.

The system uses a dedicated function to record each trial of raw EEG data from the headset. This function captures the data and passes it to your `predict` function as a pandas DataFrame. You do not need to modify this.

```python
def record_trial(inlet):
    # This function's internal logic is fixed.
    # It records 9 seconds of data and returns a DataFrame.
    ...
```

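The DataFrame passed to `predict()` contains **2250 rows**, one per sample: a single 9-second trial recorded at 250 Hz (9 × 250 = 2250). Its columns include a timestamp (`Time`), the EEG channels (`FZ`, `C3`, `CZ`, ...), and auxiliary signals such as `Gyro3`, `Battery`, `Counter`, and `Validation`; the excerpt below omits some columns (`...`).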

```
              Time            FZ            C3           CZ  ...     Gyro3    Battery  Counter  Validation
0     1.664401e+06  332287.40625  357817.12500  640024.6250  ...  0.366211  66.666672  19733.0         1.0
1     1.664401e+06  335842.93750  361377.65625  647339.8750  ...  0.274658  66.666672  19734.0         1.0
...            ...           ...           ...          ...  ...       ...        ...      ...         ...
2248  1.664410e+06  335227.37500  358243.93750  639073.8750  ...  0.732422  66.666672  21981.0         1.0
2249  1.664410e+06  331707.93750  355082.68750  631451.1250  ...  0.823975  66.666672  21982.0         1.0
```
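To test `predict` offline without a headset, a synthetic trial with the same number of rows can be generated. This is a sketch under stated assumptions: `make_mock_trial` is a hypothetical helper, it creates only a subset of columns (the real recording has more, elided as `...` above), `Time` counts seconds from zero rather than using the headset's timestamps, and the signal values are synthetic, chosen only to roughly resemble the large DC offset in the excerpt.

```python
import numpy as np
import pandas as pd

# Hypothetical helper for offline testing: builds a DataFrame shaped like one
# recorded trial, with only the columns predict() needs plus a few extras.
EEG_CHANNELS = ['FZ', 'C3', 'CZ', 'C4', 'PZ', 'PO7', 'OZ', 'PO8']
SAMPLING_RATE = 250
N_SAMPLES = 2250  # 9 s x 250 Hz


def make_mock_trial(seed=0):
    rng = np.random.default_rng(seed)
    t = np.arange(N_SAMPLES) / SAMPLING_RATE  # seconds from trial start
    data = {"Time": t}
    for ch in EEG_CHANNELS:
        # Large DC offset + 10 Hz (mu-band) sinusoid + Gaussian noise, arbitrary units.
        data[ch] = 3.3e5 + 20 * np.sin(2 * np.pi * 10 * t) + rng.normal(0, 5, N_SAMPLES)
    data["Counter"] = np.arange(N_SAMPLES, dtype=float)
    data["Validation"] = np.ones(N_SAMPLES)
    return pd.DataFrame(data)


df_mock = make_mock_trial()
print(df_mock.shape)  # → (2250, 11)
```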

## Part 3: Model Loading <span style="color:green">*(✍️ EDITABLE)*</span>.

Complete this function to load your trained model from the provided file path.

**Instructions:**

  * The function takes a `model_path` string as input.
  * It must load your model and return the model object.
  * Ensure the model is in evaluation/inference mode (e.g., `model.eval()` in PyTorch; Keras models already run in inference mode inside `model.predict()`).

```python
def load_model(model_path):
    """
    Load the trained MI model from the given file path.

    Args:
        model_path (str): The path to the model file.

    Returns:
        model: The loaded model object.
    """
    # ==> YOUR CODE HERE <==

    # --- Example for a PyTorch Model ---
    # import torch
    # # The model architecture must be defined here or imported.
    # # For MI, the model should have 2 outputs (left/right).
    # model = YourMIModelClass(n_outputs=2, ...etc)
    # model.load_state_dict(torch.load(model_path, map_location="cpu"))
    # model.eval()
    # return model

    # --- Keras model used by this notebook ---
    # compile=False loads the model without an optimizer, loss, or metrics,
    # which is all that inference needs. Keras' model.predict() runs layers
    # such as Dropout and BatchNormalization in inference mode, so no eval()
    # call is required.
    loaded_model = tf.keras.models.load_model(model_path, compile=False)
    return loaded_model
```
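Before the real model is ready, a stand-in object can exercise the rest of the pipeline. The sketch below is hypothetical (`DummyMIModel` is not part of the template or of Keras): it only mimics the one call `predict` makes, `model.predict(x, verbose=0)`, and returns fixed probabilities of shape `(batch, 2)`. The `(1, 8, 1000, 1)` input shape is illustrative and depends on your epoch window.

```python
import numpy as np


class DummyMIModel:
    """Hypothetical stand-in mimicking the Keras call used in predict():
    model.predict(x, verbose=0) -> array of shape (batch, 2)."""

    def __init__(self, probs=(0.7, 0.3)):
        self.probs = np.asarray(probs, dtype=np.float32)

    def predict(self, x, verbose=0):
        # Ignore the input values; return the same class probabilities per example.
        return np.tile(self.probs, (x.shape[0], 1))


model = DummyMIModel()
x = np.zeros((1, 8, 1000, 1), dtype=np.float32)  # (batch, channels, time, 1)
print(model.predict(x, verbose=0).shape)  # → (1, 2)
```

For example, `predict(DummyMIModel(probs=(0.2, 0.8)), df)` should return `"right"` for any valid trial `df`, which checks the preprocessing and output mapping without a trained model.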

## Part 4: Preprocessing and Prediction  <span style="color:green">*(✍️ EDITABLE)*</span>.

Complete this function to process the raw trial data and make a final prediction for game control.

**Instructions:**

  * The function receives the `model` object and the trial DataFrame (`df_trial_raw`) produced by the `record_trial` function.
  * You must perform all necessary **preprocessing** (e.g., filtering, feature extraction) inside this function.
  * The function must return one of three specific strings: `"left"`, `"right"`, or `"?"`.
  * Returning `"?"` is crucial for preventing incorrect moves in the game when the model is not confident.
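As written, the template's `predict` returns `"?"` only on errors. A minimal sketch of the confidence check described above, assuming the model outputs class probabilities (e.g. softmax) ordered `[left, right]` as in the template's index mapping; `decide` and `CONFIDENCE_THRESHOLD` are hypothetical names, and 0.6 is an arbitrary starting value to tune on held-out data.

```python
import numpy as np

CONFIDENCE_THRESHOLD = 0.6  # hypothetical value; tune on validation data


def decide(probs, threshold=CONFIDENCE_THRESHOLD):
    """Map a probability vector [p_left, p_right] to a game command.
    Returns "?" when the winning probability is below the threshold."""
    probs = np.asarray(probs, dtype=float).ravel()
    if probs.shape != (2,) or not np.all(np.isfinite(probs)):
        return "?"
    idx = int(np.argmax(probs))
    if probs[idx] < threshold:
        return "?"
    return "left" if idx == 0 else "right"


print(decide([0.85, 0.15]))  # → left
print(decide([0.45, 0.55]))  # → ?
print(decide([0.10, 0.90]))  # → right
```

Inside `predict`, this could replace the argmax mapping, e.g. `return decide(predictions[0])`. If your model outputs logits rather than probabilities, apply a softmax first so the threshold is meaningful.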

```python
def predict(model, df_trial_raw):
    """
    Preprocess the raw MI data and make a prediction.

    Args:
        model: The loaded model object (TensorFlow Keras model).
        df_trial_raw (pd.DataFrame): The raw data for a single trial
            (2250 rows, one column per channel or sensor).

    Returns:
        str: "left" or "right", or "?" if the input is invalid or any
        processing step fails.
    """
    # ==> YOUR CODE HERE <==
    if df_trial_raw.empty:
        print("Warning: Input DataFrame for prediction is empty.")
        return "?"

    # 1. Select the EEG channels as float32. Column names are stripped on a
    #    renamed copy, so the caller's DataFrame is not modified.
    try:
        df = df_trial_raw.rename(columns=lambda c: str(c).strip())
        eeg_data_subset = df[EEG_CHANNELS].astype(np.float32)
    except KeyError as e:
        print(f"Error: Missing expected EEG channel column: {e}")
        return "?"
    except Exception as e:
        print(f"Error during initial data selection and type conversion: {e}")
        return "?"

    # 2. Filter with MNE, crop the epoch window, and z-score each channel.
    try:
        # MNE expects (n_channels, n_times) in volts; raw values are assumed to be in µV.
        info = mne.create_info(ch_names=EEG_CHANNELS, sfreq=SAMPLING_RATE, ch_types='eeg', verbose=False)
        raw_trial_mne = mne.io.RawArray(eeg_data_subset.values.T * 1e-6, info, verbose=False)

        raw_trial_mne.filter(l_freq=LOW_CUT, h_freq=HIGH_CUT, fir_design='firwin', verbose=False)
        raw_trial_mne.notch_filter(freqs=NOTCH_FREQ, verbose=False)

        # Crop to [T_MIN, T_MAX] seconds. The resulting number of samples must
        # match the input length the model was trained on.
        epoch_data = raw_trial_mne.get_data(tmin=T_MIN, tmax=T_MAX)

        # Per-channel z-score along the time axis (epsilon avoids division by zero).
        mean_val = epoch_data.mean(axis=1, keepdims=True)
        std_val = epoch_data.std(axis=1, keepdims=True)
        normalized_epoch_data = (epoch_data - mean_val) / (std_val + 1e-8)

        # Add batch and trailing singleton axes: (1, n_channels, n_times, 1).
        model_input_data = normalized_epoch_data[np.newaxis, :, :, np.newaxis]
    except Exception as e:
        print(f"Error during MNE processing or normalization: {e}")
        return "?"

    # 3. Run the model.
    try:
        predictions = model.predict(model_input_data, verbose=0)
    except Exception as e:
        print(f"Error during model prediction: {e}")
        return "?"

    # 4. Map the output to a game command (index 0 = "left", 1 = "right").
    #    Note: for a valid 2-class output this always returns a class, even when
    #    the model is not confident; add a probability threshold if low-confidence
    #    trials should also return "?".
    if predictions.ndim != 2 or predictions.shape[1] != 2:
        print(f"Error: Model returned unexpected output shape: {predictions.shape}")
        return "?"

    predicted_class_index = int(np.argmax(predictions, axis=1)[0])
    return "left" if predicted_class_index == 0 else "right"
```
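The normalization step in `predict` z-scores each channel over time. The sketch below checks, on random data standing in for a filtered epoch (the `(8, 1000)` shape is illustrative), that transposing to `(time, channels)` and normalizing along axis 0 gives the same result as normalizing along axis 1 with `keepdims=True`, and shows the resulting model input shape.

```python
import numpy as np

# Hypothetical epoch: 8 channels x 1000 samples, like the output of get_data().
rng = np.random.default_rng(42)
epoch_data = rng.normal(size=(8, 1000)) * 1e-5

# Transpose approach: (time, channels), z-score each column, transpose back.
t = epoch_data.T
via_transpose = ((t - t.mean(axis=0)) / (t.std(axis=0) + 1e-8)).T

# Equivalent per-channel z-score along the time axis, without transposes.
direct = (epoch_data - epoch_data.mean(axis=1, keepdims=True)) / (
    epoch_data.std(axis=1, keepdims=True) + 1e-8
)

model_input = direct[np.newaxis, :, :, np.newaxis]  # (batch, channels, time, 1)
print(np.allclose(via_transpose, direct), model_input.shape)  # → True (1, 8, 1000, 1)
```

The `keepdims=True` form avoids the double transpose and makes the axis being normalized explicit; both compute the same per-channel statistics.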

## Part 5: Game Interface <span style="color:red">*(❌ DO NOT EDIT or ADD anything)*</span>.

After your `predict` function returns a command (e.g., `"left"`), the system passes this string to a game interface function. This function sends your prediction over the network to the game engine.

```python
def send_to_game(socket, prediction):
    # This function's internal logic is fixed.
    # It sends your prediction string ("left", "right", or "?")
    # over a ZeroMQ socket to the game.
    ...