In [None]:
# Machine Learning for Neural Data Analysis
# --- Standard Library ---
import warnings

# --- Scientific Libraries ---
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import pearsonr, zscore
from scipy.io import loadmat

# --- Machine Learning ---
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.linear_model import Ridge, LogisticRegression
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import r2_score, mean_squared_error, accuracy_score, confusion_matrix
from sklearn.pipeline import make_pipeline

# Suppress library warnings notebook-wide to keep cell output readable.
# NOTE(review): this also hides e.g. convergence / divide-by-zero warnings raised in later cells.
warnings.filterwarnings("ignore")

# Global figure style. The seaborn theme is applied first because set_theme itself
# writes font-family and sizing rcParams; the matplotlib overrides must come after it.
sns.set_theme(style="white", context="talk", palette="Set2")

_rc_overrides = {
    # Typography: serif fonts with fallbacks, rendered by mathtext (no LaTeX install needed)
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman", "Times New Roman", "DejaVu Serif"],
    "text.usetex": False,
    # Font sizes for labels, titles, legends and ticks
    "axes.labelsize": 13,
    "axes.titlesize": 15,
    "legend.fontsize": 11,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    # Thin, dark-grey axis frame
    "axes.linewidth": 0.6,
    "axes.edgecolor": "0.2",
}
plt.rcParams.update(_rc_overrides)

print("=== MACHINE LEARNING FOR NEURAL DATA ANALYSIS ===")

# 1. Data Preparation
Load the neural and behavioral data, then clean and preprocess them (visualization follows in Section 2).

In [None]:
# --- Data Preparation: Load, clean, and preprocess neural and behavioral data ---
print("1. Data Preparation...")

# Set random seed for reproducibility
np.random.seed(42)

from scipy.ndimage import gaussian_filter1d

# Selector for choosing the dataset: "population" or "dendrites"
dataset_choice = "population"  # Change to "dendrites" to use the other dataset
# Standard deviation (in samples) of the Gaussian kernel used to smooth behavior
behavior_smoothing_sigma = 10

if dataset_choice == "dendrites":
    # Variables are stored directly at the top level of the .mat file
    mat_path = "Data/M1_dendritic_tree_data.mat"  # Update this path to the location of your .mat file
    mat_data = loadmat(mat_path)
    behaviour = mat_data['behaviour'].squeeze()  # drop single-dimensional entries
    result = mat_data['result']
    tax = np.squeeze(mat_data['tax'].T)
    coords = mat_data['coords']
elif dataset_choice == "population":
    # Variables are fields of a MATLAB struct named `mat_data`; with loadmat's default
    # squeeze_me=False each field comes wrapped in a 1x1 array, hence the [0, 0].
    mat_path = "Data/M1_population_data.mat"  # Data exported from matlab
    mat_data = loadmat(mat_path)['mat_data']
    behaviour = mat_data['behaviour'][0, 0].squeeze()
    result = mat_data['result'][0, 0]
    tax = mat_data['tax'][0, 0].squeeze()
else:
    # Without this branch a typo would only surface later as a confusing NameError
    raise ValueError(
        f"Unknown dataset_choice {dataset_choice!r}; expected 'dendrites' or 'population'"
    )

# Get the dimensions; result is expected to be (neurons x timepoints)
T, N = len(behaviour), result.shape[0]
print(result.shape)     # N x T
print(tax.shape)        # T
print(behaviour.shape)  # T

# Fail early with a clear message instead of an IndexError during cleaning below
if result.ndim != 2 or result.shape[1] != T or np.size(tax) != T:
    raise ValueError(
        f"Shape mismatch: result {result.shape}, tax {np.shape(tax)}, behaviour {behaviour.shape}; "
        "expected result as (neurons, timepoints) with a matching number of timepoints"
    )

# Align with existing variable names in this notebook. Cast to float:
# gaussian_filter1d returns an array of the *input* dtype (an integer trace would be
# truncated by smoothing), and integer arrays cannot hold NaN.
behavior = np.asarray(behaviour, dtype=float)
neural_data = np.asarray(result, dtype=float)
time_axis = tax

# Mark non-finite neural values as NaN before any reference is computed, so that
# +/-inf cannot contaminate the median below.
neural_data = np.where(np.isfinite(neural_data), neural_data, np.nan)

if dataset_choice == "dendrites":
    # Subtract, at every timepoint, the median across neurons (a common-mode reference)
    median_ref = np.nanmedian(neural_data, axis=0, keepdims=True)
    neural_data = neural_data - median_ref

# Smooth behavioral data with a Gaussian filter.
# NOTE: a NaN in behavior spreads over the kernel support (default truncate=4.0,
# i.e. ~4*sigma samples per side); those timepoints are dropped below.
behavior = gaussian_filter1d(behavior, sigma=behavior_smoothing_sigma)
behavior = np.where(np.isfinite(behavior), behavior, np.nan)

# Drop neurons that are entirely NaN, then timepoints with any remaining NaN
initial_n_neurons, initial_n_timepoints = neural_data.shape
valid_neurons = ~np.all(np.isnan(neural_data), axis=1)
neural_data = neural_data[valid_neurons]
removed_neurons = initial_n_neurons - neural_data.shape[0]

valid_time = ~np.isnan(behavior)
valid_time &= ~np.any(np.isnan(neural_data), axis=0)
behavior = behavior[valid_time]
neural_data = neural_data[:, valid_time]
time_axis = time_axis[valid_time]
removed_timepoints = initial_n_timepoints - neural_data.shape[1]

# Print summary of cleaning
n_neurons, n_timepoints = neural_data.shape
print(f"Removed {removed_neurons} neurons and {removed_timepoints} timepoints with NaN")
print(f"Dataset cleaned: {n_neurons} neurons, {n_timepoints} timepoints")

# Verify no NaNs remain
assert not np.isnan(neural_data).any()
assert not np.isnan(behavior).any()

# 2. Data Overview
Visualize behavior and neural activity.

In [None]:
# --- Data Overview: Visualize behavior and neural activity ---

# Plot behavior and neural activity heatmap
fig, (ax_beh, ax_heat) = plt.subplots(
    2, 1, figsize=(14, 6), gridspec_kw={"height_ratios": [1, 3]}, sharex=True
)

ax_beh.plot(time_axis, behavior, color="tab:red", linewidth=1.5)
ax_beh.set_ylabel("Behavior")
sns.despine(ax=ax_beh, bottom=True)
ax_beh.grid(False)

# z-score each neuron across time so all rows share one colour scale
neural_z = zscore(neural_data, axis=1)
# imshow's default origin="upper" draws row 0 at the top of the image, so the
# y-extent runs from n_neurons (bottom) to 0 (top); this makes the y tick labels
# match the neuron (row) indices.
# NOTE: the extent assumes evenly spaced samples; timepoints removed during
# cleaning are drawn as if contiguous.
im = ax_heat.imshow(
    neural_z,
    aspect="auto",
    cmap="viridis",
    extent=[time_axis[0], time_axis[-1], n_neurons, 0],
    vmin=-2,
    vmax=2,
    interpolation="none",
)
ax_heat.set_ylabel("Neuron #")
ax_heat.set_xlabel("Time (s)")
ax_heat.set_title("Neural activity (z-scored)")
sns.despine(ax=ax_heat)
ax_heat.grid(False)
cbar = fig.colorbar(im, ax=ax_heat, label="z-score", orientation="horizontal", pad=0.2, fraction=0.05)
cbar.outline.set_visible(False)

fig.suptitle("Can we predict behavior from neural activity?", y=1.02)
fig.tight_layout()
plt.show()

# Plot example neuron activity traces (first few neurons, vertically offset)
fig, ax = plt.subplots(figsize=(14, 4))
offset = np.nanmax(np.abs(neural_data)) * 1.2  # spacing larger than any single trace's amplitude
palette = sns.color_palette("Set2", n_colors=min(5, n_neurons))
for i, color in enumerate(palette):
    ax.plot(time_axis, neural_data[i] + i * offset, label=f"Neuron {i}", color=color, lw=1)
ax.set_xlabel("Time (s)")
ax.set_ylabel("Activity (offset)")
ax.set_title("Example neuron activity")
ax.legend(ncol=3, frameon=False)
sns.despine(ax=ax)
ax.grid(False)
fig.tight_layout()
plt.show()


# 3. Linear Regression

### Fit a linear model to predict behavior from neural activity.

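This cell builds a linear decoder that predicts behavior from population neural activity. The feature matrix X is set to neural_data.T so that each row is a time point and each column is a neuron, while y is the corresponding behavioral readout at each time. The data are split at random into a training and a test set (70/30) to obtain a held-out estimate of generalization. Because neighbouring time points are strongly correlated (and behavior has been smoothed), a random split places near-duplicate samples in both sets, so this estimate is optimistic; see the note on time-aware splits below.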

The model is a Pipeline with StandardScaler followed by Ridge, which standardizes each feature then fits a linear regression with L2 regularization. Standardization makes the penalty comparable across features; Ridge solves for weights by minimizing $||y - Xw||_2^2 + \alpha ||w||_2^2$ with $\alpha=1.0$, which stabilizes estimates in the presence of multicollinearity and reduces overfitting. The pipeline is fit only on the training data to avoid leakage, then used to predict on train, test, and the full series (the latter for visualization). Other common approaches include L1 (Lasso) to promote sparse weights, Elastic Net to balance sparsity and stability, or RidgeCV to choose α via cross-validation.

Performance is summarized with test-set metrics. The coefficient of determination $R^2 = 1 - \frac{\sum_i (y_i-\hat{y}_i)^2}{\sum_i (y_i-\bar{y})^2}$ reports variance explained relative to a constant-mean predictor, while RMSE quantifies average prediction error as $\sqrt{\frac{1}{n}\sum_i (y_i-\hat{y}_i)^2}$. Pearson’s $r$ measures linear association between predicted and true test targets and is scale-invariant; $r$ can be high even when predictions are biased in scale or offset, whereas $R^2$ penalizes such mismatches.

The left plot compares test predictions to ground truth with an identity line: tight clustering around this line indicates accurate predictions and well-calibrated scale. The right plot overlays the model’s predictions (trained on the training split) across the entire time axis to illustrate how closely dynamics are tracked; interpret this as qualitative since it includes the training portion. For temporally ordered data, consider time-aware splits (e.g., TimeSeriesSplit or blocked splits) to respect autocorrelation and prevent optimistic estimates.

In [None]:
# --- Linear Regression: Predict behavior from neural activity ---

print("\n2. Linear Regression...")

# Samples are timepoints (rows) and features are neurons (columns); the target is behavior
X, y = neural_data.T, behavior

# Random 70/30 hold-out split with a fixed seed
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=42
)

# Standardize each neuron, then fit an L2-regularized linear model.
# The scaler lives inside the pipeline, so it only ever sees the training split.
model = make_pipeline(StandardScaler(), Ridge(alpha=1.0))
model.fit(X_train, y_train)

# Predictions on the train split, the held-out split, and the whole series (for display)
y_pred_train, y_pred_test, y_pred_full = map(model.predict, (X_train, X_test, X))

# Held-out performance summary
r2_train = r2_score(y_train, y_pred_train)
r2_test = r2_score(y_test, y_pred_test)
rmse_test = np.sqrt(mean_squared_error(y_test, y_pred_test))
corr_test, _ = pearsonr(y_test, y_pred_test)

# --- Plot actual vs predicted behavior ---
fig, (ax_scatter, ax_series) = plt.subplots(1, 2, figsize=(14, 5))

# Left: held-out predictions against the identity line
lo, hi = y_test.min(), y_test.max()
ax_scatter.scatter(y_test, y_pred_test, alpha=0.6, s=40)
ax_scatter.plot([lo, hi], [lo, hi], "r--", linewidth=2)
ax_scatter.set_xlabel("Actual Behavior")
ax_scatter.set_ylabel("Predicted Behavior")
ax_scatter.set_title(f"Linear model captures behavioral trends\nR² = {r2_test:.3f}")
ax_scatter.text(0.05, 0.95, f"r = {corr_test:.2f}", transform=ax_scatter.transAxes, va="top")
ax_scatter.grid(True, alpha=0.3)

# Right: whole time series (includes training samples, so qualitative only)
ax_series.plot(time_axis, behavior, 'r-', linewidth=0.5, label='Actual', alpha=0.8)
ax_series.plot(time_axis, y_pred_full, 'b-', linewidth=0.5, label='Predicted', alpha=0.8)
ax_series.set_xlabel('Time (s)')
ax_series.set_ylabel('Behavior')
ax_series.set_title('Full Time Series Prediction')
ax_series.legend()
ax_series.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

# Print metrics
print(f"Training R²: {r2_train:.3f}")
print(f"Test R²: {r2_test:.3f}")
print(f"Test RMSE: {rmse_test:.3f}")

# 4. Model Validation

### Cross-validation and regularization analysis.

This section evaluates generalization with 5-fold cross-validation using the coefficient of determination. The call to cross_val_score(model, X, y, cv=5, scoring='r2') returns one $R^2$ per fold, where $R^2 = 1 - \frac{\sum_i (y_i-\hat y_i)^2}{\sum_i (y_i-\bar y)^2}$. Values near 1 indicate strong predictive fit, 0 matches a mean-only baseline, and negatives mean the model underperforms the baseline. The check for any score below −1 flags extreme misfit that often points to data or setup issues (e.g., scaling, leakage, or mismatched targets). Note that with an integer `cv`, scikit-learn uses unshuffled `KFold` for regressors, so each fold here is a contiguous block of time. This is a stricter test than the random train/test split above.

The regularization sweep tests Ridge models across preset $\alpha$ values by building a Pipeline with StandardScaler and Ridge for each $\alpha$. Standardizing inside the pipeline ensures each feature contributes comparably and that scaling is learned only on the training fold, preventing leakage. Ridge minimizes $||y - Xw||_2^2 + \alpha ||w||_2^2$, so increasing $\alpha$ typically reduces variance at the cost of bias. For each $\alpha$, cross_val_score yields fold scores whose mean and standard deviation summarize central performance and stability.

The plots visualize these results. The left boxplot shows the distribution of $R^2$ across folds for the current model, with the title showing mean ± standard deviation. The right panel plots mean $R^2$ versus $\alpha$ on a log scale and shades $\pm$1 standard deviation to indicate variability; the peak suggests a good regularization strength, while drops for very small or large $\alpha$ indicate under- or over-regularization. If data are temporally ordered, prefer time-aware splits; for broader tuning and automation, consider a denser log-spaced grid with RidgeCV or GridSearchCV.

In [None]:
# --- Model Validation: Cross-validation and regularization analysis ---

print("\n3. Cross-Validation...")

# Per-fold R² of the current pipeline (integer cv -> unshuffled KFold for a regressor)
cv_scores = cross_val_score(model, X, y, cv=5, scoring='r2')

# An R² below -1 is far worse than predicting the mean: worth flagging
if (cv_scores < -1).any():
    print('Warning: Some cross-validation scores are very low. Check data scaling and model setup.')

# Regularization sweep: identical 5-fold CV for each candidate alpha
alphas = [0.01, 0.1, 1.0, 10.0, 100.0]
fold_scores_per_alpha = [
    cross_val_score(make_pipeline(StandardScaler(), Ridge(alpha=a)), X, y, cv=5, scoring='r2')
    for a in alphas
]
mean_scores = np.array([fold_scores.mean() for fold_scores in fold_scores_per_alpha])
std_scores = np.array([fold_scores.std() for fold_scores in fold_scores_per_alpha])

# --- Plot cross-validation results ---
fig, (ax_box, ax_reg) = plt.subplots(1, 2, figsize=(12, 5))

# Left: spread of fold scores for the current model
sns.boxplot(y=cv_scores, color="lightgray", ax=ax_box)
ax_box.set_ylabel("R² Score")
ax_box.set_title(f"Consistent prediction accuracy\nMean R² = {cv_scores.mean():.3f} ± {cv_scores.std():.3f}")
ax_box.grid(True, alpha=0.3)

# Right: mean ± 1 std of the fold scores versus alpha (log x-axis)
ax_reg.semilogx(alphas, mean_scores, "o-", linewidth=2)
ax_reg.fill_between(alphas, mean_scores - std_scores, mean_scores + std_scores, alpha=0.2)
ax_reg.set_xlabel("Regularization Strength (alpha)")
ax_reg.set_ylabel("Cross-Validation R²")
ax_reg.set_title("Regularization Effect")
ax_reg.grid(True, alpha=0.3)

fig.tight_layout()
plt.show()

### How to deal with poor cross validation

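If the scores above are poor, the next cell tries three changes at once:
- a `RobustScaler` (median/IQR scaling, less sensitive to outlying activity);
- `RidgeCV`, which selects $\alpha$ from a log-spaced grid by inner cross-validation;
- shuffled `KFold` folds.

Be aware that shuffling the folds of an autocorrelated (and here Gaussian-smoothed) time series puts neighbouring, near-identical samples in both training and test folds. Part of any apparent improvement can therefore come from leakage rather than better generalization, so compare against the unshuffled folds as well.

The background from the previous section still applies. This section evaluates generalization with 5-fold cross-validation using the coefficient of determination. The call to cross_val_score(model, X, y, cv=5, scoring='r2') returns one $R^2$ per fold, where $R^2 = 1 - \frac{\sum_i (y_i-\hat y_i)^2}{\sum_i (y_i-\bar y)^2}$. Values near 1 indicate strong predictive fit, 0 matches a mean-only baseline, and negatives mean the model underperforms the baseline. The check for any score below −1 flags extreme misfit that often points to data or setup issues (e.g., scaling, leakage, or mismatched targets).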

The regularization sweep tests Ridge models across preset $\alpha$ values by building a Pipeline with StandardScaler and Ridge for each $\alpha$. Standardizing inside the pipeline ensures each feature contributes comparably and that scaling is learned only on the training fold, preventing leakage.

The plots visualize these results. The left boxplot shows the distribution of $R^2$ across folds for the current model, with the title showing mean ± standard deviation. The right panel plots mean $R^2$ versus $\alpha$ on a log scale and shades $\pm$1 standard deviation to indicate variability; the peak suggests a good regularization strength, while drops for very small or large $\alpha$ indicate under- or over-regularization. If data are temporally ordered, prefer time-aware splits; for broader tuning and automation, consider a denser log-spaced grid with RidgeCV or GridSearchCV.

In [None]:
# Improve cross-validation scores ---
# If CV scores are very low, try tuning alpha, using a robust scaler, and using
# shuffled KFold to distribute samples across folds.
#
# Caveat: behavior was Gaussian-smoothed and neural activity is autocorrelated, so
# shuffled folds place near-identical neighbouring samples in both train and test.
# To separate the effect of the new model from the effect of shuffling, the tuned
# model is also scored with the same unshuffled 5-fold split as the previous cell.

print("\nAttempting to improve cross-validation scores...")

from sklearn.preprocessing import RobustScaler
from sklearn.linear_model import RidgeCV
from sklearn.model_selection import KFold

# Shuffled folds (see caveat above) and the unshuffled folds that cv=5 implies for a regressor
cv_shuffled = KFold(n_splits=5, shuffle=True, random_state=42)
cv_contiguous = KFold(n_splits=5)

# Tune alpha over a wide log-scale range (1e-3 .. 1e3)
alphas_tune = np.logspace(-3, 3, 13)

# Build updated model: RobustScaler (median/IQR) + RidgeCV (inner CV selects alpha)
model = make_pipeline(RobustScaler(), RidgeCV(alphas=alphas_tune, cv=cv_shuffled, scoring='r2'))

# Fit on training data (pipeline handles scaling internally)
model.fit(X_train, y_train)
print(f"Selected alpha (fit on training split): {model.named_steps['ridgecv'].alpha_:.3g}")

# Re-evaluate cross-validation performance with the shuffled CV
cv_scores_fixed = cross_val_score(model, X, y, cv=cv_shuffled, scoring='r2')
print(f"New CV Mean R²: {cv_scores_fixed.mean():.3f} ± {cv_scores_fixed.std():.3f}")

# Apples-to-apples: tuned model under the previous cell's unshuffled folds
cv_scores_tuned_contig = cross_val_score(model, X, y, cv=cv_contiguous, scoring='r2')
print(
    f"Tuned model, unshuffled 5-fold R²: {cv_scores_tuned_contig.mean():.3f} ± {cv_scores_tuned_contig.std():.3f} "
    f"(previous model: {cv_scores.mean():.3f} ± {cv_scores.std():.3f})"
)

if np.any(cv_scores_fixed < 0):
    print("Note: Some folds still have negative R². Consider domain-specific preprocessing (e.g., detrending, denoising) or TimeSeriesSplit.")

In [None]:
# --- Refresh metrics and plots using the tuned model ---

# Re-predict with whatever `model` currently holds (the tuned RidgeCV pipeline
# once the previous cell has run)
y_pred_train, y_pred_test, y_pred_full = map(model.predict, (X_train, X_test, X))

# Recompute metrics; r2_score, mean_squared_error and pearsonr are imported at the
# top of the notebook
r2_train = r2_score(y_train, y_pred_train)
r2_test = r2_score(y_test, y_pred_test)
rmse_test = np.sqrt(mean_squared_error(y_test, y_pred_test))
corr_test, _ = pearsonr(y_test, y_pred_test)

for summary_line in (
    f"Updated Training R²: {r2_train:.3f}",
    f"Updated Test R²: {r2_test:.3f}",
    f"Updated Test RMSE: {rmse_test:.3f}",
    f"Updated Test r (Pearson): {corr_test:.2f}",
):
    print(summary_line)

# Optional: quick overlay plot with updated predictions
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(time_axis, y, 'r-', linewidth=0.5, label='Actual', alpha=0.8)
ax.plot(time_axis, y_pred_full, 'b-', linewidth=0.5, label='Predicted (tuned)', alpha=0.8)
ax.set_xlabel('Time (s)')
ax.set_ylabel('Behavior')
ax.set_title('Full Time Series Prediction (Updated)')
ax.legend()
ax.grid(True, alpha=0.3)
fig.tight_layout()
plt.show()

# 5. Feature Importance

### Analyze which neurons are most predictive for behavior.

This cell extracts feature importance from a linear Ridge model by fitting it on standardized neural features and visualizing the resulting coefficients per neuron. Standardizing with StandardScaler (fit on X_train only) removes mean and scales each feature to unit variance so that coefficients are comparable across neurons; without this, large-scale features would dominate. The scaler’s statistics are learned on the training set and then applied to both train and test splits to avoid leakage and keep the scale consistent.

Ridge is then fit on the scaled training data to obtain coefficients coefs. In a linear model, each coefficient reflects the signed contribution of a neuron to the predicted behavior, conditional on the other neurons. Positive values indicate a direct association and negative values an inverse association. Because Ridge minimizes $||y - Xw||_2^2 + \alpha ||w||_2^2$, the L2 penalty shrinks weights and tends to distribute them across correlated neurons; as a result, magnitudes are stabilized but may be diluted when features are collinear. With standardized inputs, coefficient magnitudes are directly comparable and serve as a sensible proxy for feature importance (though not a causal measure).

The bar plot maps each neuron’s index to its coefficient, colored by sign (red positive, blue negative), and outlines the three coefficients with the largest absolute value. Interpret larger absolute bars as more influential neurons under the fitted linear decoder. If the target were multi-output, coefs would be 2D and you’d typically select one target or aggregate across targets before plotting. Note that this cell fits its own StandardScaler + Ridge instead of reading coefficients from the tuned pipeline `model` (which uses RobustScaler). The ranking here therefore need not match exactly the one used in the next cell. If your features are sparse, set StandardScaler(with_mean=False) to avoid errors.

In [None]:
# --- Feature Importance: Ridge Regression coefficients (with scaling) ---

# Refit a Ridge on StandardScaler-transformed features: every neuron then has unit
# variance on the training split, so coefficient magnitudes are comparable.
# This is NOT the exact pipeline in `model` (the tuned one uses RobustScaler); only
# its regularization strength is reused, so the two analyses use the same alpha.
_model_steps = getattr(model, "named_steps", {})
if "ridgecv" in _model_steps and hasattr(_model_steps["ridgecv"], "alpha_"):
    ridge_alpha = float(_model_steps["ridgecv"].alpha_)  # cross-validated alpha
elif "ridge" in _model_steps:
    ridge_alpha = float(_model_steps["ridge"].alpha)
else:
    ridge_alpha = 1.0

scaler = StandardScaler()
ridge = Ridge(alpha=ridge_alpha)

# Fit scaler on training data only, then fit Ridge on the scaled features
X_train_scaled = scaler.fit_transform(X_train)
ridge.fit(X_train_scaled, y_train)
coefs = ridge.coef_

# --- Plot coefficients ---
plt.figure(figsize=(12, 5))

# Color by sign: positive = red, negative (or zero) = blue
colors = ['#d62728' if c > 0 else '#1f77b4' for c in coefs]

# Create bars
x_positions = np.arange(len(coefs))
bars = plt.bar(x_positions, coefs, color=colors, alpha=0.7, edgecolor='white', linewidth=0.5)

# Highlight the top 3 features by absolute value
top_indices = np.argsort(np.abs(coefs))[-3:]
for idx in top_indices:
    bars[idx].set_edgecolor('black')
    bars[idx].set_linewidth(1.5)
    bars[idx].set_alpha(1.0)

plt.xlabel('Neuron Index')
plt.ylabel('Ridge Coefficient')
plt.title('Feature Importance: Ridge Coefficients')
plt.axhline(y=0, color='black', linestyle='-', alpha=0.3)

# Simple legend
from matplotlib.patches import Patch
legend_elements = [
    Patch(facecolor='#d62728', label='Positive weights'),
    Patch(facecolor='#1f77b4', label='Negative weights')
]
plt.legend(handles=legend_elements, loc='upper right')

plt.grid(True, alpha=0.2, axis='y')
plt.tight_layout()
plt.show()

print(f"Ridge alpha used for coefficients: {ridge_alpha:.3g}")
print(f"Feature importance: {len(coefs)} neurons, range [{coefs.min():.3f}, {coefs.max():.3f}]")

### Visualize top predictors

This cell visualizes which individual neurons the fitted linear decoder relies on most, by overlaying their raw activity traces with the behavioral signal. It first extracts the model’s coefficients. Because the decoder is linear, each coefficient reflects the signed contribution of a neuron’s standardized activity to the predicted behavior: positive weights increase the prediction when the neuron is active, negative weights decrease it.

The code then ranks neurons by these weights. Using np.argsort on the 1D coefficient vector, it selects the three largest positive coefficients (top predictors) and the three most negative coefficients (top suppressors). The positive set is reversed to present the strongest first.

Notes:
- Coefficients come from a model trained on scaled features. The tuned pipeline uses RobustScaler (median/IQR scaling); the fallback uses StandardScaler. Magnitudes are therefore roughly comparable across neurons but do not imply causality. With correlated neurons, Ridge spreads weight across features, so “top 3” reflects the regularized solution rather than unique contributions.
- Ensure y (behavior) and each neural trace share the same time index; here, earlier cleaning aligned time_axis, behavior, and neural_data across valid timepoints.

In [None]:
# Plot top 3 positive and top 3 negative predictor signals, aligned with behavior


def _fitted_ridge_coefs(pipe):
    """Return `coef_` of a fitted 'ridge' or 'ridgecv' pipeline step, or None if absent."""
    steps = getattr(pipe, "named_steps", {})
    for step_name in ("ridge", "ridgecv"):
        step = steps.get(step_name)
        if step is not None and hasattr(step, "coef_"):
            return step.coef_
    return None


# Get coefficients from the trained model. Fall back to a fresh fit whenever `model`
# has no fitted ridge/ridgecv step, so `coefs` never silently carries over from an
# earlier cell.
coefs = _fitted_ridge_coefs(model)
if coefs is None:
    # Fallback: fit a quick standardized Ridge on the training split
    _sc = StandardScaler()
    _rd = Ridge(alpha=1.0)
    _rd.fit(_sc.fit_transform(X_train), y_train)
    coefs = _rd.coef_

# Select top 3 positive and negative neurons
n_top = 3
coef_order = np.argsort(coefs)
pos_indices = coef_order[-n_top:][::-1]  # largest coefficients, descending
neg_indices = coef_order[:n_top]         # most negative coefficients, ascending

# Create single plot with stacked traces
fig, ax = plt.subplots(figsize=(14, 8))

# Vertical spacing between stacked traces: 10x the global std of neural_data
offset = np.std(neural_data) * 10

# Layout in units of `offset`: positive predictors at n..1, negative at 0..-(n-1),
# and behavior one step above the top trace, so spacing scales with the data.
ax.plot(time_axis, behavior + offset * (len(pos_indices) + 1), 'k-', lw=2, label='Behavior')

for i, idx in enumerate(pos_indices):
    color = 'green' if coefs[idx] > 0 else 'red'
    ax.plot(time_axis, neural_data[idx] + offset * (len(pos_indices) - i), color=color, lw=1.5,
            label=f'Neuron {idx} ({coefs[idx]:.3f})')

for i, idx in enumerate(neg_indices):
    color = 'green' if coefs[idx] > 0 else 'red'
    ax.plot(time_axis, neural_data[idx] - offset * i, color=color, lw=1.5,
            label=f'Neuron {idx} ({coefs[idx]:.3f})')

ax.set_xlabel('Time (s)')
ax.set_ylabel('Signal (offset)')
ax.set_title('Behavior and Top Predictor Neurons')
ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

print("Top positive neurons:", pos_indices.tolist())
print("Top negative neurons:", neg_indices.tolist())

# 6. Classifier

This section builds a minimal baseline classifier and then a corrected, time-aware version for imbalanced time-series labels.

- Label: $y_\text{bin} = \mathbb{1}(|y| > 0.1)$ → "active" vs "inactive".
- Basic run: one train/test split with `StandardScaler → LogisticRegression`.
- Corrected run: block-aware split to limit temporal leakage + `class_weight='balanced'` + proper metrics.

Why two runs? The first is a quick, rough baseline. The second reflects best practices for imbalanced, autocorrelated data.

Notes and ideas to go further:
- Tune the amplitude threshold or decision threshold using PR curves (optimize recall/F1).
- Report balanced accuracy, precision/recall/F1; accuracy alone can be misleading under imbalance.
- Prefer time-aware splits (GroupShuffleSplit/TimeSeriesSplit) to avoid temporal leakage.
- If positives are rare, try `class_weight='balanced'` or resampling; calibrate probabilities if needed.

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import make_pipeline
from sklearn.metrics import accuracy_score, balanced_accuracy_score, confusion_matrix, classification_report, matthews_corrcoef
from collections import Counter

# --- Publication-quality plot settings ---
plt.rcParams.update({
    "font.family": "serif",
    "font.serif": ["Computer Modern Roman", "Times New Roman", "DejaVu Serif"],
    "axes.labelsize": 13,
    "axes.titlesize": 15,
    "legend.fontsize": 11,
    "xtick.labelsize": 11,
    "ytick.labelsize": 11,
    "axes.linewidth": 0.7,
    "axes.edgecolor": "0.2",
})

# =============================================================================
# Basic gross classification (single run, no CV)
# =============================================================================

# Amplitude threshold on |behavior| separating "active" (1) from "inactive" (0)
ACTIVE_THRESHOLD = 0.1

# Binary labels using a symmetric threshold ±ACTIVE_THRESHOLD
y_binary = (np.abs(y) > ACTIVE_THRESHOLD).astype(int)

# Visual check: behavior with active intervals and threshold lines
plt.figure(figsize=(14, 4))
plt.plot(time_axis, y, 'k-', linewidth=1.3, label='Behavior')
plt.fill_between(time_axis, y.min(), y.max(), where=(y_binary == 1), color='orange', alpha=0.25,
                 label=f'Active (|y|>{ACTIVE_THRESHOLD:g})')
plt.axhline(ACTIVE_THRESHOLD, color='gray', ls='--', lw=1)
plt.axhline(-ACTIVE_THRESHOLD, color='gray', ls='--', lw=1)
plt.xlabel('Time (s)'); plt.ylabel('Behavior'); plt.title(f'Behavior with Binary Labels (±{ACTIVE_THRESHOLD:g})')
plt.legend(frameon=False); plt.grid(False)
for s in plt.gca().spines.values(): s.set_visible(False)
plt.tight_layout(); plt.show()

# Show class distribution
class_counts = Counter(y_binary)
n_samples = len(y_binary)
print(f"Class distribution: {class_counts[0]} vs {class_counts[1]} ({class_counts[0]/n_samples*100:.1f}% vs {class_counts[1]/n_samples*100:.1f}%)")

# A stratified split needs at least 2 samples per class, and LogisticRegression
# cannot be fit on a single class; fail with an actionable message instead.
if min(class_counts[0], class_counts[1]) < 2:
    raise ValueError(
        f"Need at least 2 samples in each class (got class 0: {class_counts[0]}, "
        f"class 1: {class_counts[1]}); adjust ACTIVE_THRESHOLD."
    )

# Train basic classifier. Classifier-specific names keep the regression split
# (X_train, X_test, y_train, y_test) from the earlier cells intact, so re-running
# those cells afterwards does not silently pick up binary labels.
X_train_clf, X_test_clf, y_train_clf, y_test_clf = train_test_split(
    X, y_binary, test_size=0.3, random_state=42, stratify=y_binary
)
clf_basic = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
clf_basic.fit(X_train_clf, y_train_clf)
y_pred_basic = clf_basic.predict(X_test_clf)

acc = accuracy_score(y_test_clf, y_pred_basic)
bacc = balanced_accuracy_score(y_test_clf, y_pred_basic)
mcc = matthews_corrcoef(y_test_clf, y_pred_basic)
print(f'Basic accuracy: {acc:.3f} | Balanced accuracy: {bacc:.3f} | MCC: {mcc:.3f}')

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Original distribution (darker color, log scale). No horizontal line at y=0:
# zero is not representable on a log-scaled count axis.
axes[0].hist(y, bins=50, alpha=0.85, color='#003366')
axes[0].axvline(0.0, color='gray', ls=':')
axes[0].axvline(ACTIVE_THRESHOLD, color='gray', ls='--'); axes[0].axvline(-ACTIVE_THRESHOLD, color='gray', ls='--')
axes[0].set_title('Behavior distribution (log y-axis)')
axes[0].set_xlabel('y'); axes[0].set_ylabel('Count (log)')
axes[0].set_yscale('log')
axes[0].grid(False)
for s in axes[0].spines.values(): s.set_visible(False)

# Class imbalance
axes[1].bar(['Class 0', 'Class 1'], [class_counts[0], class_counts[1]], color=['#D2691E', '#228B22'], alpha=0.85)
axes[1].set_title('Class Distribution')
axes[1].set_ylabel('Count')
axes[1].grid(False)
for s in axes[1].spines.values(): s.set_visible(False)

# Basic confusion matrix
sns.heatmap(confusion_matrix(y_test_clf, y_pred_basic), annot=True, fmt='d', ax=axes[2], cbar=False, cmap='Blues')
axes[2].set_title('Confusion matrix (basic)')
axes[2].set_xlabel('Predicted')
axes[2].set_ylabel('Actual')
axes[2].grid(False)
for s in axes[2].spines.values(): s.set_visible(False)

fig.suptitle('Basic single run (no CV)', y=1.02, fontsize=16)
fig.tight_layout(); plt.show()

# Optional: detailed report (can be long if printed inline)
print(classification_report(y_test_clf, y_pred_basic, digits=3))



MCC (Matthews Correlation Coefficient) uses all four quadrants of the confusion matrix. It returns a value between -1 and 1 and remains reliable even with imbalanced classes, making it a robust measure of binary classification quality.

In [None]:
# =============================================================================
# Time-aware, class-balanced single run
# =============================================================================
from sklearn.model_selection import GroupShuffleSplit
from sklearn.metrics import precision_recall_curve, average_precision_score, f1_score, balanced_accuracy_score, matthews_corrcoef
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression

# Group by time blocks to limit temporal leakage (tune block_sec to your data)
block_sec = 5.0  # e.g., 0.5–10.0 depending on autocorrelation
groups = np.floor((time_axis - time_axis.min()) / block_sec).astype(int)


def _ix(a, idx):
    """Row-select positions `idx` from a NumPy array or a pandas object."""
    return a.iloc[idx] if hasattr(a, 'iloc') else a[idx]


def _threshold_for_recall(recall, thresholds, target):
    """Return (threshold, recall) for the highest threshold whose recall >= target.

    `recall` and `thresholds` are as returned by sklearn's precision_recall_curve:
    thresholds increase, recall is non-increasing, recall[i] is the recall at
    thresholds[i], and the final recall entry (0) has no threshold.
    Returns None when no threshold reaches the target.
    """
    reachable = np.flatnonzero(np.asarray(recall)[:-1] >= target)
    if reachable.size == 0:
        return None
    # The LAST reachable index is the highest threshold that still meets the target;
    # the first one would be the lowest threshold (recall ~1, everything positive).
    best = reachable[-1]
    return thresholds[best], recall[best]


# test_size is a fraction of *groups* (time blocks), not of samples
gss = GroupShuffleSplit(n_splits=1, test_size=0.30, random_state=42)
tr_idx, te_idx = next(gss.split(X, y_binary, groups=groups))
X_tr, X_te = _ix(X, tr_idx), _ix(X, te_idx)
y_tr, y_te = y_binary[tr_idx], y_binary[te_idx]

# Whole blocks go to one side, so rare positives can land entirely in one split;
# LogisticRegression and the PR metrics below need both classes present.
for _split_name, _split_labels in (("train", y_tr), ("test", y_te)):
    if np.unique(_split_labels).size < 2:
        raise ValueError(f"The {_split_name} split contains a single class; change block_sec or random_state.")

# Balanced logistic regression in a scaling pipeline
clf_bal = make_pipeline(StandardScaler(), LogisticRegression(max_iter=2000, class_weight='balanced', random_state=42))
clf_bal.fit(X_tr, y_tr)
# Distinct name so the regression predictions `y_pred_test` are not overwritten
y_pred_bal = clf_bal.predict(X_te)
proba_te = clf_bal.predict_proba(X_te)[:, 1]

# Metrics that are robust to imbalance
bacc_w = balanced_accuracy_score(y_te, y_pred_bal)
f1_w = f1_score(y_te, y_pred_bal, zero_division=0)
ap = average_precision_score(y_te, proba_te)
mcc_w = matthews_corrcoef(y_te, y_pred_bal)
print(f'Time-aware balanced run → BalAcc: {bacc_w:.3f} | F1(+): {f1_w:.3f} | MCC: {mcc_w:.3f} | PR-AUC: {ap:.3f}')

# Precision-Recall curve
P, R, thr = precision_recall_curve(y_te, proba_te)
plt.figure(figsize=(6,4))
plt.plot(R, P, lw=2, label=f'Balanced (AP={ap:.3f})')
plt.xlabel('Recall (positive)'); plt.ylabel('Precision')
plt.title('PR curve (time-aware split)')
plt.legend(frameon=False); plt.grid(False)
for s in plt.gca().spines.values(): s.set_visible(False)
plt.tight_layout(); plt.show()

# Optional: tune decision threshold for recall (example target)
target_recall = 0.90
_picked = _threshold_for_recall(R, thr, target_recall)
if _picked is not None:
    thr_opt, rec_opt = _picked
    yhat_opt = (proba_te >= thr_opt).astype(int)
    print(f'Threshold tuned for ~{target_recall:.0%} recall → thr={thr_opt:.3f} (recall={rec_opt:.3f}), F1(+): {f1_score(y_te, yhat_opt, zero_division=0):.3f}')
else:
    print('Target recall not achievable from PR curve.')
