In [None]:
!pip install torch
!pip install mne
!pip install tensorflow





[notice] A new release of pip is available: 24.2 -> 25.1.1
[notice] To update, run: python.exe -m pip install --upgrade pip


In [22]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from mne.io import read_raw_edf

from datetime import datetime
from sklearn.model_selection import train_test_split


In [28]:
# ===================== Paths & Labels =======================
data_path = r"C:\Users\Admin\Desktop\Data_Preprocessing\data"

# Replace this with your real seizure start/end times (wall-clock HH:MM:SS)
# Updated Seizure Times and Filenames
seizure_files = {
    "PN00-1.edf": ("19:58:36", "19:59:46"),
    "PN00-2.edf": ("02:38:37", "02:39:31"),
    "PN00-3.edf": ("18:28:29", "19:29:29"),
    "PN00-4.edf": ("21:08:29", "21:09:43"),
    "PN00-5.edf": ("22:37:08", "22:38:15"),
}

# Updated Baseline Start Times
baseline_start_times = {
    "PN00-1.edf": "19:39:33",
    "PN00-2.edf": "02:18:17",
    "PN00-3.edf": "18:15:44",
    "PN00-4.edf": "20:51:43",
    "PN00-5.edf": "22:22:04",
}

channels_to_use = list(range(18))  # First 18 channels

# =============== Time Conversion =====================
def relative_seconds(timestr, meas_date):
    """Convert wall-clock HH:MM:SS to seconds from EDF start."""
    target = datetime.strptime(timestr, "%H:%M:%S").time()
    start = meas_date.time()
    target_dt = datetime.combine(datetime.today(), target)
    start_dt = datetime.combine(datetime.today(), start)
    return (target_dt - start_dt).total_seconds()

# =============== Window Extraction ====================
def extract_windows(raw, seizure_start, seizure_end, baseline_start):
    window_size = 30  # seconds
    sampling_rate = int(raw.info['sfreq'])
    total_duration = raw.times[-1]

    # Use maximum possible preictal window
    max_preictal_duration = seizure_start
    preictal_window = min(30 * 60, max_preictal_duration)
    preictal_start = seizure_start - preictal_window

    print(f"EDF Duration: {total_duration:.2f}s | Seizure Start: {seizure_start:.2f}s | Preictal Start: {preictal_start:.2f}s | Baseline Start: {baseline_start:.2f}s")

    if seizure_start > total_duration:
        print("[SKIP] Seizure start is after end of recording.")
        return [], []

    if baseline_start + 30 * 60 > total_duration:
        print("[SKIP] Baseline + 30min is beyond recording duration.")
        return [], []

    try:
        preictal_data = raw.copy().crop(tmin=preictal_start, tmax=seizure_start).get_data()[channels_to_use]
        interictal_data = raw.copy().crop(tmin=baseline_start, tmax=baseline_start + 30 * 60).get_data()[channels_to_use]
    except Exception as e:
        print(f"[ERROR] Cropping failed: {e}")
        return [], []

    def segment(data):
        n_samples = data.shape[1]
        samples_per_win = sampling_rate * window_size
        segments = []
        for start in range(0, n_samples - samples_per_win, samples_per_win):
            segment = data[:, start:start + samples_per_win]
            segments.append(segment)
        return segments

    return segment(preictal_data), segment(interictal_data)


# =============== Data Processing ====================
X, y = [], []

for file, (sz_start, sz_end) in seizure_files.items():
    reg_start = baseline_start_times[file]
    file_path = os.path.join(data_path, file)
    raw = read_raw_edf(file_path, preload=True)
    meas_date = raw.info['meas_date']
    if meas_date is None:
        print(f"[WARNING] {file} has no meas_date metadata — skipping.")
        continue

    try:
        sz_start_sec = relative_seconds(sz_start, meas_date)
        sz_end_sec = relative_seconds(sz_end, meas_date)
        reg_start_sec = relative_seconds(reg_start, meas_date)

        preictal, interictal = extract_windows(raw, sz_start_sec, sz_end_sec, reg_start_sec)
        print(f"{file} — Preictal: {len(preictal)} | Interictal: {len(interictal)}")

        X.extend(preictal)
        y.extend([1] * len(preictal))

        X.extend(interictal)
        y.extend([0] * len(interictal))
    except Exception as e:
        print(f"[ERROR] {file} — {e}")

X = np.array(X)
y = np.array(y)


Extracting EDF parameters from C:\Users\Admin\Desktop\Data_Preprocessing\data\PN00-1.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 1343999  =      0.000 ...  2624.998 secs...


  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)


EDF Duration: 2625.00s | Seizure Start: 1143.00s | Preictal Start: 0.00s | Baseline Start: 0.00s
PN00-1.edf — Preictal: 38 | Interictal: 60
Extracting EDF parameters from C:\Users\Admin\Desktop\Data_Preprocessing\data\PN00-2.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 1178623  =      0.000 ...  2301.998 secs...


  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)


EDF Duration: 2302.00s | Seizure Start: 1220.00s | Preictal Start: 0.00s | Baseline Start: 0.00s
PN00-2.edf — Preictal: 40 | Interictal: 60
Extracting EDF parameters from C:\Users\Admin\Desktop\Data_Preprocessing\data\PN00-3.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 1284607  =      0.000 ...  2508.998 secs...


  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)


EDF Duration: 2509.00s | Seizure Start: 765.00s | Preictal Start: 0.00s | Baseline Start: 0.00s
PN00-3.edf — Preictal: 25 | Interictal: 60
Extracting EDF parameters from C:\Users\Admin\Desktop\Data_Preprocessing\data\PN00-4.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 1076223  =      0.000 ...  2101.998 secs...


  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)


EDF Duration: 2102.00s | Seizure Start: 1006.00s | Preictal Start: 0.00s | Baseline Start: 0.00s
PN00-4.edf — Preictal: 33 | Interictal: 60
Extracting EDF parameters from C:\Users\Admin\Desktop\Data_Preprocessing\data\PN00-5.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 1097215  =      0.000 ...  2142.998 secs...


  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)


EDF Duration: 2143.00s | Seizure Start: 904.00s | Preictal Start: 0.00s | Baseline Start: 0.00s
PN00-5.edf — Preictal: 30 | Interictal: 60


In [33]:
summary = []  # ✅ Initialize

for file, (sz_start, sz_end) in seizure_files.items():
    print(f"\nProcessing {file}")
    reg_start = baseline_start_times[file]
    file_path = os.path.join(data_path, file)
    
    raw = read_raw_edf(file_path, preload=True)

    meas_date = raw.info['meas_date']
    if meas_date is None:
        print(f"[WARNING] {file} has no meas_date metadata — skipping.")
        continue

    try:
        sz_start_sec = relative_seconds(sz_start, meas_date)
        sz_end_sec = relative_seconds(sz_end, meas_date)
        reg_start_sec = relative_seconds(reg_start, meas_date)

        preictal, interictal = extract_windows(raw, sz_start_sec, sz_end_sec, reg_start_sec)
        print(f"{file} — Preictal: {len(preictal)} | Interictal: {len(interictal)}")

        X.extend(preictal)
        y.extend([1] * len(preictal))

        X.extend(interictal)
        y.extend([0] * len(interictal))

        # ✅ Append summary
        summary.append({
            "file": file,
            "preictal_segments": len(preictal),
            "interictal_segments": len(interictal),
            "seizure_start_sec": sz_start_sec,
            "baseline_start_sec": reg_start_sec
        })

    except Exception as e:
        print(f"[ERROR] {file} — {e}")



Processing PN00-1.edf
Extracting EDF parameters from C:\Users\Admin\Desktop\Data_Preprocessing\data\PN00-1.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 1343999  =      0.000 ...  2624.998 secs...


  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)


EDF Duration: 2625.00s | Seizure Start: 1143.00s | Preictal Start: 0.00s | Baseline Start: 0.00s
PN00-1.edf — Preictal: 38 | Interictal: 60
[ERROR] PN00-1.edf — 'numpy.ndarray' object has no attribute 'extend'

Processing PN00-2.edf
Extracting EDF parameters from C:\Users\Admin\Desktop\Data_Preprocessing\data\PN00-2.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 1178623  =      0.000 ...  2301.998 secs...


  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)


EDF Duration: 2302.00s | Seizure Start: 1220.00s | Preictal Start: 0.00s | Baseline Start: 0.00s
PN00-2.edf — Preictal: 40 | Interictal: 60
[ERROR] PN00-2.edf — 'numpy.ndarray' object has no attribute 'extend'

Processing PN00-3.edf
Extracting EDF parameters from C:\Users\Admin\Desktop\Data_Preprocessing\data\PN00-3.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 1284607  =      0.000 ...  2508.998 secs...


  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)


EDF Duration: 2509.00s | Seizure Start: 765.00s | Preictal Start: 0.00s | Baseline Start: 0.00s
PN00-3.edf — Preictal: 25 | Interictal: 60
[ERROR] PN00-3.edf — 'numpy.ndarray' object has no attribute 'extend'

Processing PN00-4.edf
Extracting EDF parameters from C:\Users\Admin\Desktop\Data_Preprocessing\data\PN00-4.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 1076223  =      0.000 ...  2101.998 secs...


  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)


EDF Duration: 2102.00s | Seizure Start: 1006.00s | Preictal Start: 0.00s | Baseline Start: 0.00s
PN00-4.edf — Preictal: 33 | Interictal: 60
[ERROR] PN00-4.edf — 'numpy.ndarray' object has no attribute 'extend'

Processing PN00-5.edf
Extracting EDF parameters from C:\Users\Admin\Desktop\Data_Preprocessing\data\PN00-5.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 1097215  =      0.000 ...  2142.998 secs...


  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)
  raw = read_raw_edf(file_path, preload=True)


EDF Duration: 2143.00s | Seizure Start: 904.00s | Preictal Start: 0.00s | Baseline Start: 0.00s
PN00-5.edf — Preictal: 30 | Interictal: 60
[ERROR] PN00-5.edf — 'numpy.ndarray' object has no attribute 'extend'


In [34]:
import pandas as pd
pd.DataFrame(summary).to_csv("segmentation_summary.csv", index=False)


In [35]:
# (samples, channels, time_steps) → (samples, time_steps, channels)
X = np.transpose(X, (0, 2, 1))  # Now shape is (samples, 30*sampling_rate, channels)
print("X shape for LSTM:", X.shape)


X shape for LSTM: (466, 15360, 18)


In [36]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, stratify=y, test_size=0.2, random_state=42)


In [41]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import numpy as np

# Convert numpy to torch tensors
X_tensor = torch.tensor(X, dtype=torch.float32)  # shape: (samples, time_steps, channels)
y_tensor = torch.tensor(y, dtype=torch.float32)

# Dataset and dataloader
dataset = TensorDataset(X_tensor, y_tensor)
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# Define LSTM model
class SeizureLSTM(nn.Module):
    def __init__(self, input_size, hidden_size=64):
        super(SeizureLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        _, (h_n, _) = self.lstm(x)
        out = self.fc(h_n[-1])
        return self.sigmoid(out)

model = SeizureLSTM(input_size=X.shape[2])
criterion = nn.BCELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)



In [42]:

# Training loop
for epoch in range(10):
    for batch_X, batch_y in train_loader:
        outputs = model(batch_X).squeeze()
        loss = criterion(outputs, batch_y)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

Epoch 1, Loss: 0.6752
Epoch 2, Loss: 0.7059
Epoch 3, Loss: 0.6168
Epoch 4, Loss: 0.6046
Epoch 5, Loss: 0.6695
Epoch 6, Loss: 0.6692
Epoch 7, Loss: 0.7051
Epoch 8, Loss: 0.7010
Epoch 9, Loss: 0.6385
Epoch 10, Loss: 0.7320


In [44]:
from torch.utils.data import DataLoader, TensorDataset

# Create dataset and dataloader
dataset = TensorDataset(X_tensor, y_tensor)
dataloader = DataLoader(dataset, batch_size=64, shuffle=False)

correct = 0
total = 0

model.eval()
with torch.no_grad():
    for X_batch, y_batch in dataloader:
        outputs = model(X_batch).squeeze()
        predicted = (outputs >= 0.5).float()
        correct += (predicted == y_batch).sum().item()
        total += y_batch.size(0)

accuracy = correct / total
print(f"Accuracy: {accuracy:.4f}")


Accuracy: 0.6438


In [45]:
# ======================= SAVE MODEL =======================
torch.save(model.state_dict(), "seizure_lstm_model.pth")
print("Model saved as seizure_lstm_model.pth")


Model saved as seizure_lstm_model.pth


In [46]:
# Re-initialize the model class with same input_size and architecture
loaded_model = SeizureLSTM(input_size=X.shape[2])
loaded_model.load_state_dict(torch.load("seizure_lstm_model.pth"))
loaded_model.eval()  # Set to evaluation mode


SeizureLSTM(
  (lstm): LSTM(18, 64, batch_first=True)
  (fc): Linear(in_features=64, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [47]:
#Rebuild the LSTM Class & Load Model
import torch
import torch.nn as nn

# Same architecture as used in training
class SeizureLSTM(nn.Module):
    def __init__(self, input_size, hidden_size=64):
        super(SeizureLSTM, self).__init__()
        self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        _, (h_n, _) = self.lstm(x)
        out = self.fc(h_n[-1])
        return self.sigmoid(out)

# Initialize and load trained weights
model = SeizureLSTM(input_size=18)
model.load_state_dict(torch.load("seizure_lstm_model.pth"))
model.eval()


SeizureLSTM(
  (lstm): LSTM(18, 64, batch_first=True)
  (fc): Linear(in_features=64, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [48]:
import mne
import numpy as np

# Load the EDF file
raw = mne.io.read_raw_edf(r"C:\Users\Admin\Desktop\Data_Preprocessing\data\PN05-2.edf", preload=True)

# Get the first 30 seconds of selected 18 channels
sfreq = int(raw.info['sfreq'])  # e.g., 256 Hz
window_sec = 30
n_samples = sfreq * window_sec

# Select only the first 18 channels (used during training)
raw.pick_channels(raw.ch_names[:18])  # or channels_to_use if predefined

# Get data from the first 30 seconds
data, _ = raw[:, :n_samples]  # shape: (channels, time_steps)
data = data.T[np.newaxis, ...]  # shape: (1, time_steps, channels)

# Convert to PyTorch tensor
input_tensor = torch.tensor(data, dtype=torch.float32)


Extracting EDF parameters from C:\Users\Admin\Desktop\Data_Preprocessing\data\PN05-2.edf...
EDF file detected
Setting channel info structure...
Creating raw.info structure...
Reading 0 ... 4733439  =      0.000 ...  9244.998 secs...


  raw = mne.io.read_raw_edf(r"C:\Users\Admin\Desktop\Data_Preprocessing\data\PN05-2.edf", preload=True)
  raw = mne.io.read_raw_edf(r"C:\Users\Admin\Desktop\Data_Preprocessing\data\PN05-2.edf", preload=True)
  raw = mne.io.read_raw_edf(r"C:\Users\Admin\Desktop\Data_Preprocessing\data\PN05-2.edf", preload=True)


NOTE: pick_channels() is a legacy function. New code should use inst.pick(...).


In [None]:
with torch.no_grad():
    output = model(input_tensor).item()
    print(f"Predicted Seizure Probability: {output:.4f}")
    
    if output >= 0.5:
        print(" Seizure likely within 30 minutes (Preictal)")
    else:
        print(" Normal brain activity (Interictal)")
 

Predicted Seizure Probability: 0.3624
✅ Normal brain activity (Interictal)


In [None]:
#Class 0 (Interictal): 300 samples
#Class 1 (Preictal): 166 samples
import numpy as np

unique, counts = np.unique(y, return_counts=True)
for label, count in zip(unique, counts):
    print(f"Class {label} — Count: {count}")


Class 0 — Count: 300
Class 1 — Count: 166
