In [None]:
# Import required libraries
import os
import pandas as pd
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report
from sklearn.model_selection import train_test_split
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
import numpy as np

# Read data: resolve the combined feature CSV relative to the recording
# directory and load it into a DataFrame.
recording_location = './'
all_features_csv = os.path.join(
    recording_location, './Data/FeatureFiles/feature_list_all.csv'
)
df = pd.read_csv(filepath_or_buffer=all_features_csv)

# Feature columns selected as most significant in the preceding analysis;
# only these columns are taken from the table for scaling and training.
significant_features = [
    "freqDisPerSec",
    "meanFix", "maxFix", "varFix",
    "fixDensPerBB",
    "blinkRate",
    "meanDis", "minDis",
]

# Prepare features and labels
features = df[significant_features]
labels = df['label']

# Split the *unscaled* rows first. The row assignment depends only on the
# number of rows, the labels (stratify) and random_state, not on feature values.
raw_train, raw_test, label_train, label_test = train_test_split(
    features,
    labels,
    train_size=0.8,
    random_state=42,
    stratify=labels,
)

# Fit the scaler on the training split only: fitting on the full dataset
# would leak test-set mean/variance into preprocessing (data leakage), and the
# scaler saved later would carry those leaked statistics too.
scaler = StandardScaler()
scaler.fit(raw_train)

# Transform both splits with the train-fitted scaler, keeping column names and
# the original row index so each feature row stays aligned with its label.
feature_train = pd.DataFrame(
    scaler.transform(raw_train), columns=features.columns, index=raw_train.index
)
feature_test = pd.DataFrame(
    scaler.transform(raw_test), columns=features.columns, index=raw_test.index
)

# Whole-dataset scaled view, kept so code that references these names keeps
# working; it is transformed with (not fitted on) the training statistics.
scaled = scaler.transform(features)
scaled_features = pd.DataFrame(scaled, columns=features.columns)

# Train an ExtraTrees ensemble of 100 trees on the training split.
# random_state fixes the estimator's randomness so reruns give the same model;
# n_jobs=-1 lets scikit-learn use all available CPU cores.
extra_trees = ExtraTreesClassifier(n_estimators=100, random_state=42, n_jobs=-1)
extra_trees.fit(feature_train, label_train)

# Predict once on the held-out split and reuse the result for every metric
# below, instead of letting each metric helper call predict() again.
predictions = extra_trees.predict(feature_test)
# Per-class probabilities (columns ordered as extra_trees.classes_); not used
# further in this section — presumably retained for later inspection.
probabilities = extra_trees.predict_proba(feature_test)

# Accuracy = fraction of test rows whose predicted label matches the true one
# (the same value extra_trees.score() returns, without re-predicting).
accuracy = np.mean(predictions == np.asarray(label_test))
print(f"Accuracy: {accuracy:.3f}")

# Confusion matrix over every class the model knows, so rows/columns line up
# with extra_trees.classes_; plot this same matrix rather than recomputing it.
cm = confusion_matrix(label_test, predictions, labels=extra_trees.classes_)
ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=extra_trees.classes_).plot()
plt.show()

# Display per-class precision / recall / F1
print("\nClassification Report:")
print(classification_report(label_test, predictions))

# Save the model and scaler for later use
import joblib

# joblib.dump opens the target path for writing and does not create missing
# parent directories, so make sure ./Models exists before dumping.
os.makedirs('./Models', exist_ok=True)
joblib.dump(extra_trees, './Models/extra_trees_classifier.joblib')
joblib.dump(scaler, './Models/feature_scaler.joblib')