In [1]:
# Notebook-wide imports: standard library, then numerical stack, then project modules.
import importlib
from typing import List, Tuple
import pandas as pd
import numpy as np

# Star import: presumably the source of the upper-case configuration names used by
# later cells (DATASET_PATH, SAMPLING_RATE, STATES, ...), none of which are defined
# in this notebook -- TODO confirm against constants.py.
from constants import *
import utils
# Re-execute utils so that edits to utils.py are picked up without restarting the kernel.
importlib.reload(utils)

<module 'utils' from '/Users/adam/workspace/Classifier-Builder/lol_classifier/utils.py'>

In [2]:
import os
from glob import glob

# Look one directory level below DATASET_PATH for files matching DATASET_POSTFIX.
dataset_glob_pattern = os.path.join(DATASET_PATH, f"*/{DATASET_POSTFIX}")
datasets = glob(dataset_glob_pattern)

print(f"Found {len(datasets)} datasets")

Found 1 datasets


In [15]:
# Preprocess every recording found in `datasets`:
#   1. drop rows whose marker is 'd' and label the opening rows 'neutral',
#   2. band-pass filter all signal columns with MNE,
#   3. propagate event markers via utils.propagate_events,
#   4. keep only labelled rows, shuffle them and hand them to utils.split_and_save.
# The last column of each frame is treated as the marker/label column throughout.
import mne  # imported once for the cell instead of on every loop iteration

# Band-pass edges in Hz. No 50 Hz notch filter is applied: per the original author
# it is unnecessary at this SAMPLING_RATE (noted as 25).
low_freq = 1.0
high_freq = 10.0

# Validate the pass band before the output directories are cleared, so a bad
# configuration cannot wipe previously generated data.
nyquist_freq = SAMPLING_RATE / 2
if high_freq >= nyquist_freq:
    raise ValueError(
        f"Band-pass upper edge {high_freq} Hz must be below the Nyquist frequency "
        f"{nyquist_freq} Hz (SAMPLING_RATE={SAMPLING_RATE})"
    )

utils.clear_directories([TRAINING_DATA_PATH, TESTING_DATA_PATH])
for path in datasets:
    # Name the output after the directory containing the dataset file.
    # os.path is used instead of splitting on '/' so Windows separators also work.
    filename = os.path.basename(os.path.dirname(path))

    df = utils.read_eeg_data(path)
    df.drop(df[df.iloc[:, -1] == 'd'].index, inplace=True)  # drop flashes
    # Label the first SAMPLING_RATE * 3 rows (three seconds of samples) as 'neutral'.
    df.iloc[:SAMPLING_RATE * 3, -1] = 'neutral'

    # NOTE(review): the filter below treats the remaining rows as one contiguous
    # signal even though the 'd' rows were removed above -- confirm this is intended.
    ch_names = [str(ch) for ch in df.columns[:-1]]  # MNE requires string channel names
    info = mne.create_info(ch_names=ch_names, sfreq=SAMPLING_RATE, ch_types="eeg")
    raw = mne.io.RawArray(df.iloc[:, :-1].T.values, info)
    raw.filter(l_freq=low_freq, h_freq=high_freq)

    # Write the filtered signals back over the original signal columns.
    df.iloc[:, :-1] = raw.get_data().T

    markers = df.iloc[:, -1]
    print(f"From {filename} with shape:{df.shape} and markers:{len(markers[markers != '0'])} kills:{len(markers[markers == 'kill'])}, deaths:{len(markers[markers == 'death'])}")

    # propagate_events receives this list on every call and is expected to use it as
    # mutable state carried from one row to the next.
    prev_marker_and_countdown = ['0', 0]
    propagated = markers.apply(lambda x: utils.propagate_events(x, prev_marker_and_countdown))
    df.iloc[:, -1] = propagated

    # Filter on the propagated labels explicitly. Reusing `markers` here only worked
    # when it happened to be a view of the frame; under pandas Copy-on-Write it still
    # holds the pre-propagation labels and every propagated row would be dropped.
    df.drop(df[propagated == '0'].index, inplace=True)

    # NOTE(review): shuffling individual samples before the split can place
    # neighbouring samples of the same event in both train and test (assuming
    # split_and_save splits by row position -- confirm). That would inflate test scores.
    df = df.sample(frac=1, random_state=42).reset_index(drop=True)
    utils.split_and_save(df, f"{filename}.csv")


Creating RawArray with float64 data, n_channels=40, n_times=41331
    Range : 0 ... 41330 =      0.000 ...  1653.200 secs
Ready.
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 10 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 10.00 Hz
- Upper transition bandwidth: 2.50 Hz (-6 dB cutoff frequency: 11.25 Hz)
- Filter length: 83 samples (3.320 s)

From untitled folder with shape:(41331, 41) and markers:83 kills:1, deaths:7


[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


In [4]:
# Load the two splits returned by utils.load_datasets and display their shapes.
train, test = utils.load_datasets()
tuple(split.shape for split in (train, test))

Number of kills=60; deaths=420; neutral=119


((599, 41), (150, 41))

In [5]:
# Sanity check: display every training row that contains at least one NaN.
rows_with_nan = np.isnan(train).any(axis=1)
train[rows_with_nan]


array([], shape=(0, 41), dtype=float64)

In [7]:
from sklearn.svm import SVC

# The last column is used as the target; every column before it is a feature.
X_train, X_test = train[:, :-1], test[:, :-1]
y_train, y_test = train[:, -1], test[:, -1]

# Cubic polynomial-kernel SVM with probability estimates enabled.
svm = SVC(
    C=1,
    kernel='poly',
    degree=3,
    probability=True,
    decision_function_shape='ovr',
    random_state=42,
)
svm.fit(X_train, y_train)

In [8]:
from sklearn.metrics import classification_report, confusion_matrix

# Mean accuracy of the fitted classifier on the held-out split.
print('accuracy:', svm.score(X_test, y_test))

y_pred = svm.predict(X_test)

# Confusion matrix first, then the per-class report labelled with STATES.
conf_mat = confusion_matrix(y_test, y_pred)
utils.present_confusion_matrix(conf_mat)

metrics_report = classification_report(
    y_test,
    y_pred,
    target_names=STATES,
    output_dict=True,
)
utils.present_metrics(metrics_report)

accuracy: 1.0
Confusion Matrix:
+-----------------+-----------------+------------------+--------------------+
|                 | Predicted: kill | Predicted: death | Predicted: neutral |
+-----------------+-----------------+------------------+--------------------+
|  Actual: kill   |       15        |        0         |         0          |
|  Actual: death  |        0        |       105        |         0          |
| Actual: neutral |        0        |        0         |         30         |
+-----------------+-----------------+------------------+--------------------+

Classification Report:
+--------------+-----------+--------+----------+---------+
|              | precision | recall | f1-score | support |
+--------------+-----------+--------+----------+---------+
|     kill     |    1.0    |  1.0   |   1.0    |  15.0   |
|    death     |    1.0    |  1.0   |   1.0    |  105.0  |
|   neutral    |    1.0    |  1.0   |   1.0    |  30.0   |
|   accuracy   |    1.0    |  1.0   |   1.0 

In [16]:
from pathlib import Path

from skl2onnx import to_onnx

# Convert the fitted SVM to ONNX; a single float32 training row is passed as the
# sample input for the converter.
onx = to_onnx(svm, X_train[:1].astype(np.float32), target_opset=12)

# Create the target directory first so the export also works on a fresh checkout
# where "output/" does not exist yet (open() would raise FileNotFoundError).
onnx_path = Path("output/svm-without-preprocessing-90p.onnx")
onnx_path.parent.mkdir(parents=True, exist_ok=True)
with onnx_path.open("wb") as f:
    f.write(onx.SerializeToString())
