In [1]:
import json
import torch
import numpy as np
import mne
from scipy.stats import zscore


In [7]:
import os
import json
import numpy as np
import mne
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import seaborn as sns
# import shap

from scipy.stats import zscore
from scipy.signal import welch
from sklearn.model_selection import GroupShuffleSplit
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
from torch.utils.data import Dataset, DataLoader


In [14]:
# Global configuration
DATASET_PATH = "dataset"
SFREQ = 250
WINDOW_SIZE = 100
STEP_SIZE = 50

CHANNELS = ['Fp1','Fp2','F3','F4','T7','T8','P3','P4']
N_CHANNELS = len(CHANNELS)


In [15]:
class ChannelAttention(nn.Module):
    def __init__(self, channels, ratio=8):
        super().__init__()
        self.fc1 = nn.Linear(channels, channels//ratio)
        self.fc2 = nn.Linear(channels//ratio, channels)

    def forward(self, x):
        avg = x.mean(dim=1)
        attn = torch.sigmoid(self.fc2(F.relu(self.fc1(avg))))
        return x * attn.unsqueeze(1)


In [16]:
class EEG_ASD_Model(nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv1d(8, 32, 15, padding=7)
        self.bn1 = nn.BatchNorm1d(32)

        self.ms3 = nn.Conv1d(32,32,3,padding=1)
        self.ms5 = nn.Conv1d(32,32,5,padding=2)
        self.ms7 = nn.Conv1d(32,32,7,padding=3)
        self.bn2 = nn.BatchNorm1d(32)

        self.attn = ChannelAttention(32)
        self.dw = nn.Conv1d(32,32,3,padding=1,groups=32)
        self.bn3 = nn.BatchNorm1d(32)

        self.pool = nn.AvgPool1d(4)
        self.lstm1 = nn.LSTM(32,64,batch_first=True,bidirectional=True)
        self.lstm2 = nn.LSTM(128,32,batch_first=True,bidirectional=True)

        self.fc1 = nn.Linear(64,64)
        self.fc2 = nn.Linear(64,2)

    def forward(self, x):
        x = x.permute(0,2,1)
        x = F.elu(self.bn1(self.conv1(x)))
        x = self.bn2(self.ms3(x)+self.ms5(x)+self.ms7(x))
        x = self.attn(x.permute(0,2,1)).permute(0,2,1)
        x = F.elu(self.bn3(self.dw(x)))
        x = self.pool(x)
        x,_ = self.lstm1(x.permute(0,2,1))
        x,_ = self.lstm2(x)
        x = F.elu(self.fc1(x[:,-1]))
        return self.fc2(x)


In [17]:
model = EEG_ASD_Model()
model.load_state_dict(torch.load("artifacts/best_model.pth", map_location="cpu"))
model.eval()


EEG_ASD_Model(
  (conv1): Conv1d(8, 32, kernel_size=(15,), stride=(1,), padding=(7,))
  (bn1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (ms3): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,))
  (ms5): Conv1d(32, 32, kernel_size=(5,), stride=(1,), padding=(2,))
  (ms7): Conv1d(32, 32, kernel_size=(7,), stride=(1,), padding=(3,))
  (bn2): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (attn): ChannelAttention(
    (fc1): Linear(in_features=32, out_features=4, bias=True)
    (fc2): Linear(in_features=4, out_features=32, bias=True)
  )
  (dw): Conv1d(32, 32, kernel_size=(3,), stride=(1,), padding=(1,), groups=32)
  (bn3): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (pool): AvgPool1d(kernel_size=(4,), stride=(4,), padding=(0,))
  (lstm1): LSTM(32, 64, batch_first=True, bidirectional=True)
  (lstm2): LSTM(128, 32, batch_first=True, bidirectional=True)
  (fc1): Linear(i

In [18]:
def preprocess_unseen_eeg(file_path):
    raw = mne.io.read_raw_eeglab(file_path, preload=True, verbose=False)
    raw.pick_types(eeg=True)
    raw.filter(1., 40.)

    data = raw.get_data()
    ch_names = raw.ch_names

    fixed = np.zeros((len(CHANNELS), data.shape[1]))

    for i, ch in enumerate(CHANNELS):
        if ch in ch_names:
            fixed[i] = data[ch_names.index(ch)]

    fixed = np.where(
        fixed.std(axis=1, keepdims=True) == 0,
        fixed,
        zscore(fixed, axis=1)
    )

    return fixed


In [19]:
def create_windows_unseen(data):
    windows = []
    for start in range(0, data.shape[1] - WINDOW_SIZE, STEP_SIZE):
        window = data[:, start:start + WINDOW_SIZE]
        windows.append(window.T)   # (time, channels)
    return np.array(windows, dtype=np.float32)


In [None]:
file_path = "dataset/2Abby_Resting.set"
##33 normal 
##55 normal 
##2 ASD

data = preprocess_unseen_eeg(file_path)
X_unseen = create_windows_unseen(data)

print("Unseen windows:", X_unseen.shape)


NOTE: pick_types() is a legacy function. New code should use inst.pick(...).
Filtering raw data in 1 contiguous segment
Setting up band-pass filter from 1 - 40 Hz

FIR filter parameters
---------------------
Designing a one-pass, zero-phase, non-causal bandpass filter:
- Windowed time-domain design (firwin) method
- Hamming window with 0.0194 passband ripple and 53 dB stopband attenuation
- Lower passband edge: 1.00
- Lower transition bandwidth: 1.00 Hz (-6 dB cutoff frequency: 0.50 Hz)
- Upper passband edge: 40.00 Hz
- Upper transition bandwidth: 10.00 Hz (-6 dB cutoff frequency: 45.00 Hz)
- Filter length: 1691 samples (3.303 s)



  raw = mne.io.read_raw_eeglab(file_path, preload=True, verbose=False)
  raw = mne.io.read_raw_eeglab(file_path, preload=True, verbose=False)
[Parallel(n_jobs=1)]: Done  17 tasks      | elapsed:    0.0s


Unseen windows: (1637, 100, 8)


In [32]:
with torch.no_grad():
    outputs = model(torch.tensor(X_unseen))
    probs = torch.softmax(outputs, dim=1).numpy()
    preds = probs.argmax(axis=1)


In [33]:
import json

with open("artifacts/label_map.json", "r") as f:
    LABEL_MAP = json.load(f)

print(LABEL_MAP)


{'0': 'Control', '1': 'ASD'}


In [34]:
unique, counts = np.unique(preds, return_counts=True)
majority_class = unique[np.argmax(counts)]

print("Window-wise counts:", dict(zip(unique, counts)))
print("Final Prediction:", LABEL_MAP[str(majority_class)])


Window-wise counts: {1: 1637}
Final Prediction: ASD


In [30]:
mean_prob = probs.mean(axis=0)
final_class = np.argmax(mean_prob)

print("Mean probability:", mean_prob)
print("Final Prediction:", LABEL_MAP[str(final_class)])


Mean probability: [0.99619144 0.00381087]
Final Prediction: Control
