<a href="https://colab.research.google.com/github/Zfeng0207/FIT3199-FYP/blob/dev%2Fryuji/FYP2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import numpy as np
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split

# Load the dataset
df = pd.read_csv("00_recurrent_stroke_patient.csv")

# Convert categorical target column "Stroke_Y/N" to binary (0 or 1)
df["Stroke_Y/N"] = df["Stroke_Y/N"].astype(int)

# Convert datetime columns to UNIX timestamps (seconds).
# "int64" rather than `int`: nanosecond timestamps overflow 32 bits, and plain `int`
# maps to a 32-bit type on some platforms (e.g. Windows with NumPy < 2).
if "charttime" in df.columns:
    df["charttime"] = pd.to_datetime(df["charttime"]).astype("int64") // 10**9

# Drop non-numeric columns
non_numeric_cols = ["subject_id", "stay_id", "icd_code", "icd_title", "rhythm", "gender", "anchor_year_group", "dod"]
df = df.drop(columns=[col for col in non_numeric_cols if col in df.columns])

# Ensure all columns are numeric (unparsable values become NaN)
df = df.apply(pd.to_numeric, errors='coerce')
features = [col for col in df.columns if col != "Stroke_Y/N"]

# Split row positions into train / validation / test BEFORE computing any statistics, so the
# imputation medians and the scaler's min/max come from training rows only. Fitting them on
# every row leaks validation/test information into training and inflates reported accuracy.
# 20% of rows are held out for test, then 20% of the remainder for validation.
all_idx = np.arange(len(df))
train_idx, test_idx = train_test_split(all_idx, test_size=0.2, random_state=42)
train_idx, val_idx = train_test_split(train_idx, test_size=0.2, random_state=42)

# Fill missing values with training-set medians
df = df.fillna(df.iloc[train_idx].median())

# Normalize numerical features with a scaler fitted on training rows only
scaler = MinMaxScaler()
scaler.fit(df.iloc[train_idx][features])
df[features] = scaler.transform(df[features])

# Define input features (X) and target (y)
X = df[features].values  # NumPy array
y = df["Stroke_Y/N"].values  # Target variable

# Convert to PyTorch tensors
X_tensor = torch.tensor(X, dtype=torch.float32)
y_tensor = torch.tensor(y, dtype=torch.float32).view(-1, 1)

# Materialise the training, validation, and test sets from the row positions chosen above
train_rows = torch.as_tensor(train_idx)
val_rows = torch.as_tensor(val_idx)
test_rows = torch.as_tensor(test_idx)
X_train, y_train = X_tensor[train_rows], y_tensor[train_rows]
X_val, y_val = X_tensor[val_rows], y_tensor[val_rows]
X_test, y_test = X_tensor[test_rows], y_tensor[test_rows]

# PyTorch DataLoader
batch_size = 32
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(TensorDataset(X_val, y_val), batch_size=batch_size)
test_loader = DataLoader(TensorDataset(X_test, y_test), batch_size=batch_size)

print("Data processing completed successfully!")

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset

# Define CNN Model for Stroke Prediction
class StrokeCNN(nn.Module):
    """Two-layer 1-D CNN over a flat feature vector that outputs a probability per sample.

    The feature vector is treated as a 1-channel sequence of length ``input_size``.

    Args:
        input_size: Number of input features per sample.
        num_filters: Output channels of the first conv (the second conv uses twice as many).
        kernel_size: Convolution kernel size; padding is ``kernel_size // 2``, so odd kernels
            preserve the sequence length through each convolution.
        dropout: Dropout probability applied after the first fully connected layer.

    ``forward`` takes a ``(batch, input_size)`` tensor and returns ``(batch, 1)`` sigmoid
    probabilities (pair it with ``nn.BCELoss``).
    """

    def __init__(self, input_size, num_filters=64, kernel_size=3, dropout=0.3):
        super().__init__()
        # Derive padding from the kernel so kernel sizes other than 3 keep "same" length.
        padding = kernel_size // 2
        self.conv1 = nn.Conv1d(in_channels=1, out_channels=num_filters, kernel_size=kernel_size, padding=padding)
        self.conv2 = nn.Conv1d(in_channels=num_filters, out_channels=num_filters * 2, kernel_size=kernel_size, padding=padding)
        self.pool = nn.MaxPool1d(kernel_size=2)
        self.dropout = nn.Dropout(dropout)
        self.sigmoid = nn.Sigmoid()

        # Run a dummy input through the conv stack to size the first linear layer.
        # Zeros (not torch.rand): only the shape matters, and this avoids consuming the global
        # RNG, which would otherwise shift the seeded initialisation of fc1/fc2 below.
        with torch.no_grad():
            probe = torch.zeros(1, 1, input_size)  # (batch=1, channels=1, features)
            probe = self.pool(torch.relu(self.conv1(probe)))
            probe = self.pool(torch.relu(self.conv2(probe)))
            self.flattened_size = probe.numel()

        self.fc1 = nn.Linear(self.flattened_size, 128)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        x = x.unsqueeze(1)  # (batch, features) -> (batch, 1, features)
        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        x = x.view(x.shape[0], -1)  # flatten channels x length
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return self.sigmoid(x)

# Build the CNN sized to the number of feature columns, then its loss and optimiser.
input_size = X_tensor.shape[1]  # features per sample
model = StrokeCNN(input_size=input_size)

# StrokeCNN already ends in a sigmoid, so plain BCE on probabilities is used.
criterion = nn.BCELoss()
optimizer = optim.Adam(params=model.parameters(), lr=5e-4)

# Training function with validation accuracy
def train_model(model, train_loader, val_loader, criterion, optimizer, epochs=30):
    """Train ``model`` and report train/validation loss and accuracy after every epoch.

    The model is expected to output probabilities (it is paired with ``nn.BCELoss`` in this
    notebook), so predictions are thresholded at 0.5. Batches are moved to the device of the
    model's parameters.

    Args:
        model: Module returning ``(batch, 1)`` probabilities.
        train_loader: Iterable of ``(X, y)`` training batches.
        val_loader: Iterable of ``(X, y)`` validation batches.
        criterion: Loss applied to ``(predictions, targets)``.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Number of passes over ``train_loader``.

    Returns:
        list[dict]: One entry per epoch with ``train_loss``, ``train_acc``, ``val_loss`` and
        ``val_acc``; a metric is NaN when its loader yielded no samples.
    """
    device = next(model.parameters()).device
    history = []

    for epoch in range(epochs):
        model.train()
        total_loss, n_train_batches, correct_train, total_train = 0.0, 0, 0, 0

        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            optimizer.zero_grad()
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_batch)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_train_batches += 1

            # Compute training accuracy
            predicted = (y_pred > 0.5).float()
            correct_train += (predicted == y_batch).sum().item()
            total_train += y_batch.size(0)

        # Validation step (loop names differ from the module-level X_val/y_val on purpose)
        model.eval()
        val_loss, n_val_batches, correct_val, total_val = 0.0, 0, 0, 0
        with torch.no_grad():
            for X_vb, y_vb in val_loader:
                X_vb, y_vb = X_vb.to(device), y_vb.to(device)
                y_val_pred = model(X_vb)
                val_loss += criterion(y_val_pred, y_vb).item()
                n_val_batches += 1
                predicted_val = (y_val_pred > 0.5).float()
                correct_val += (predicted_val == y_vb).sum().item()
                total_val += y_vb.size(0)

        # Averages; NaN instead of ZeroDivisionError when a loader is empty.
        nan = float("nan")
        metrics = {
            "train_loss": total_loss / n_train_batches if n_train_batches else nan,
            "train_acc": correct_train / total_train if total_train else nan,
            "val_loss": val_loss / n_val_batches if n_val_batches else nan,
            "val_acc": correct_val / total_val if total_val else nan,
        }
        history.append(metrics)

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {metrics['train_loss']:.4f}, "
              f"Train Acc: {metrics['train_acc']:.4f}, Val Loss: {metrics['val_loss']:.4f}, "
              f"Val Acc: {metrics['val_acc']:.4f}")

    return history

# Fit for 30 epochs, printing train/validation metrics after each epoch.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    epochs=30,
)

# Evaluation function with test accuracy
def evaluate_model(model, test_loader):
    """Print and return the accuracy of ``model`` on ``test_loader``.

    The model must output probabilities; predictions are thresholded at 0.5. Batches are
    moved to the device of the model's parameters.

    Args:
        model: Module returning ``(batch, 1)`` probabilities.
        test_loader: Iterable of ``(X, y)`` batches.

    Returns:
        float: Fraction of correct predictions, or NaN if the loader yielded no samples.
    """
    device = next(model.parameters()).device
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            predicted = (model(X_batch) > 0.5).float()
            correct += (predicted == y_batch).sum().item()
            total += y_batch.size(0)
    # NaN rather than ZeroDivisionError for an empty loader.
    accuracy = correct / total if total else float("nan")
    print(f'Test Accuracy: {accuracy:.4f}')
    return accuracy

# Report accuracy on the held-out test split.
evaluate_model(model=model, test_loader=test_loader)



Data processing completed successfully!
Epoch 1/30, Train Loss: 0.3267, Train Acc: 0.8649, Val Loss: 0.3085, Val Acc: 0.8565
Epoch 2/30, Train Loss: 0.2885, Train Acc: 0.8649, Val Loss: 0.2948, Val Acc: 0.8565
Epoch 3/30, Train Loss: 0.2747, Train Acc: 0.8713, Val Loss: 0.2841, Val Acc: 0.8663
Epoch 4/30, Train Loss: 0.2668, Train Acc: 0.8767, Val Loss: 0.2784, Val Acc: 0.8689
Epoch 5/30, Train Loss: 0.2627, Train Acc: 0.8774, Val Loss: 0.2716, Val Acc: 0.8703
Epoch 6/30, Train Loss: 0.2612, Train Acc: 0.8780, Val Loss: 0.2744, Val Acc: 0.8703
Epoch 7/30, Train Loss: 0.2595, Train Acc: 0.8784, Val Loss: 0.2858, Val Acc: 0.8607
Epoch 8/30, Train Loss: 0.2571, Train Acc: 0.8785, Val Loss: 0.2679, Val Acc: 0.8703
Epoch 9/30, Train Loss: 0.2562, Train Acc: 0.8786, Val Loss: 0.2681, Val Acc: 0.8703
Epoch 10/30, Train Loss: 0.2542, Train Acc: 0.8800, Val Loss: 0.2665, Val Acc: 0.8714
Epoch 11/30, Train Loss: 0.2528, Train Acc: 0.8796, Val Loss: 0.2621, Val Acc: 0.8707
Epoch 12/30, Train Loss

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, random_split
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score, precision_score, recall_score

# --------------------- Data Preprocessing ---------------------

# Read the ECG record / diagnosis table from the mounted Drive folder and preview it.
df = pd.read_csv(
    "/content/drive/MyDrive/Colab Notebooks/FIT3199-FYP/ecg_data/records_w_diag_icd10.csv"
)
df.head()
# # Convert categorical target column "Stroke_Y/N" to binary (0 or 1)
# df["Stroke_Y/N"] = df["Stroke_Y/N"].astype(int)

# # Convert 'charttime' to datetime and sort by subject_id, charttime
# df['charttime'] = pd.to_datetime(df['charttime'], errors='coerce')
# df = df.sort_values(by=['subject_id', 'charttime'])

# # Drop non-relevant columns
# columns_to_drop = ["stay_id_x", "stay_id_y", "charttime", "dod", "icd_title"]
# df = df.drop(columns=columns_to_drop, errors='ignore')

# # Handle missing values by filling with column mean for numeric columns only
# numeric_data = df.select_dtypes(include=np.number)
# df[numeric_data.columns] = numeric_data.fillna(numeric_data.mean())

# # Ensure all remaining columns are numeric
# df = df.apply(pd.to_numeric, errors='coerce')

# # Our Stroke Target Column
# target_column = "Stroke_Y/N"

# # --------------------- Time-Series Sequence Creation ---------------------

# def create_sequences(df, n_previous=3):
#     sequences, labels = [], []

#     patient_groups = df.groupby("subject_id")  # Group by patient
#     for _, group in patient_groups:
#         group = group.drop(columns=["subject_id"])  # Drop ID for training
#         if len(group) < n_previous:
#             continue  # Skip patients with too few records

#         # Ensure only numeric values
#         group = group.apply(pd.to_numeric, errors='coerce')

#         X_patient = group.drop(columns=["Stroke_Y/N"]).values
#         y_patient = group["Stroke_Y/N"].values

#         # Create sequences of length `n_previous`
#         for i in range(len(group) - n_previous + 1):
#             seq_X = X_patient[i:i + n_previous]  # Past admissions
#             seq_y = y_patient[i + n_previous - 1]  # Predict next admission stroke outcome
#             sequences.append(seq_X)
#             labels.append(seq_y)

#     return np.array(sequences, dtype=np.float32), np.array(labels, dtype=np.float32)

# # Generate time-series sequences
# X_seq, y_seq = create_sequences(df, n_previous=3)

# # Replace NaN values with 0
# X_seq = np.nan_to_num(X_seq, nan=0.0)
# y_seq = np.nan_to_num(y_seq, nan=0.0)

# # Convert to PyTorch tensors
# X_tensor = torch.tensor(X_seq, dtype=torch.float32)
# y_tensor = torch.tensor(y_seq, dtype=torch.float32).unsqueeze(1)  # Shape: (N,1)

# # Debugging Output
# print(f"X_tensor shape: {X_tensor.shape}")  # Should be (samples, time_steps, features)
# print(f"y_tensor shape: {y_tensor.shape}")  # Should be (samples, 1)

# # Check if GPU is available
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# print(f"Using device: {device}")

# X_tensor, y_tensor = X_tensor.to(device), y_tensor.to(device)

# # --------------------- Train-Validation-Test Split ---------------------

# train_size = int(0.7 * len(X_tensor))
# val_size = int(0.15 * len(X_tensor))
# test_size = len(X_tensor) - train_size - val_size

# train_data, val_data, test_data = random_split(TensorDataset(X_tensor, y_tensor), [train_size, val_size, test_size])

# train_loader = DataLoader(train_data, batch_size=64, shuffle=True)
# val_loader = DataLoader(val_data, batch_size=64, shuffle=False)
# test_loader = DataLoader(test_data, batch_size=64, shuffle=False)

# # --------------------- Define CNN Model ---------------------

# class StrokeCNN(nn.Module):
#     def __init__(self, num_features, num_filters=64, kernel_size=2, dropout=0.3):
#         super(StrokeCNN, self).__init__()

#         self.conv1 = nn.Conv1d(in_channels=num_features, out_channels=num_filters, kernel_size=kernel_size, padding=1)
#         self.conv2 = nn.Conv1d(in_channels=num_filters, out_channels=num_filters * 2, kernel_size=kernel_size, padding=1)

#         self.pool = nn.MaxPool1d(kernel_size=2)
#         self.dropout = nn.Dropout(dropout)
#         self.sigmoid = nn.Sigmoid()

#         with torch.no_grad():
#             sample_input = torch.rand(1, num_features, 3)
#             sample_output = self.pool(torch.relu(self.conv1(sample_input)))
#             sample_output = self.pool(torch.relu(self.conv2(sample_output)))
#             self.flattened_size = sample_output.numel()

#         self.fc1 = nn.Linear(self.flattened_size, 128)
#         self.fc2 = nn.Linear(128, 1)

#     def forward(self, x):
#         x = x.permute(0, 2, 1)
#         x = self.pool(torch.relu(self.conv1(x)))
#         x = self.pool(torch.relu(self.conv2(x)))
#         x = x.view(x.shape[0], -1)
#         x = torch.relu(self.fc1(x))
#         x = self.dropout(x)
#         x = self.fc2(x)
#         return self.sigmoid(x)

# # --------------------- Training ---------------------

# num_features = X_seq.shape[2]
# model = StrokeCNN(num_features).to(device)
# criterion = nn.BCELoss()
# optimizer = optim.Adam(model.parameters(), lr=0.001)

# num_epochs = 10

# for epoch in range(num_epochs):
#     model.train()
#     train_loss = 0.0
#     train_correct = 0
#     train_total = 0

#     for inputs, targets in train_loader:
#         inputs, targets = inputs.to(device), targets.to(device)  # Move to GPU
#         optimizer.zero_grad()
#         outputs = model(inputs)
#         loss = criterion(outputs, targets)
#         loss.backward()
#         optimizer.step()
#         train_loss += loss.item()
#         predicted = (outputs >= 0.5).float()
#         train_correct += (predicted == targets).sum().item()
#         train_total += targets.size(0)

#     train_loss /= len(train_loader)
#     train_accuracy = train_correct / train_total

#     model.eval()
#     val_loss = 0.0
#     val_correct = 0
#     val_total = 0

#     with torch.no_grad():
#         for inputs, targets in val_loader:
#             inputs, targets = inputs.to(device), targets.to(device)
#             outputs = model(inputs)
#             loss = criterion(outputs, targets)
#             val_loss += loss.item()
#             predicted = (outputs >= 0.5).float()
#             val_correct += (predicted == targets).sum().item()
#             val_total += targets.size(0)

#     val_loss /= len(val_loader)
#     val_accuracy = val_correct / val_total

#     print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f} - "
#           f"Val Loss: {val_loss:.4f}, Val Acc: {val_accuracy:.4f}")

# # --------------------- Evaluation ---------------------

# def evaluate(model, dataloader):
#     model.eval()
#     y_true, y_pred = [], []

#     with torch.no_grad():
#         for inputs, targets in dataloader:
#             inputs, targets = inputs.to(device), targets.to(device)
#             outputs = model(inputs).cpu().numpy()
#             y_pred.extend(outputs)
#             y_true.extend(targets.cpu().numpy())

#     y_pred_binary = (np.array(y_pred) >= 0.5).astype(int)
#     return accuracy_score(y_true, y_pred_binary), precision_score(y_true, y_pred_binary, zero_division=1), recall_score(y_true, y_pred_binary), roc_auc_score(y_true, y_pred)

# print("Final Test Performance:", evaluate(model, test_loader))



Unnamed: 0,filename,study_id,patient_id,ecg_time,ed_stay_id,ed_hadm_id,hosp_hadm_id,ed_diag_ed,ed_diag_hosp,hosp_diag_hosp,...,age,anchor_year,anchor_age,dod,ecg_no_within_stay,ecg_taken_in_ed,ecg_taken_in_hosp,ecg_taken_in_ed_or_hosp,fold,strat_fold
0,mimic-iv-ecg/files/p1000/p10000032/s40689238/4...,40689238,10000032,2180-07-23 08:44:00,,,,[],[],[],...,52.0,2180.0,52.0,2180-09-09,-1,False,False,False,17,19
1,mimic-iv-ecg/files/p1000/p10000032/s44458630/4...,44458630,10000032,2180-07-23 09:54:00,,,,[],[],[],...,52.0,2180.0,52.0,2180-09-09,-1,False,False,False,17,19
2,mimic-iv-ecg/files/p1000/p10000032/s49036311/4...,49036311,10000032,2180-08-06 09:07:00,,,25742920.0,[],[],"['J449', 'E875', 'Z21', 'R188', 'R197', 'E871'...",...,52.0,2180.0,52.0,2180-09-09,0,False,True,True,17,19
3,mimic-iv-ecg/files/p1000/p10000117/s45090959/4...,45090959,10000117,2181-03-04 17:14:00,,,,[],[],[],...,55.0,2174.0,48.0,,-1,False,False,False,18,5
4,mimic-iv-ecg/files/p1000/p10000117/s48446569/4...,48446569,10000117,2183-09-18 13:52:00,,,,[],[],[],...,57.0,2174.0,48.0,,-1,False,False,False,18,5


In [None]:
import numpy as np
import os

def compress_npy(file_path, remove_original=True):
    """Compresses a .npy file into a .npz archive next to it using numpy.savez_compressed.

    The array is stored in the archive under the key ``"data"``. An existing archive with
    the same name is overwritten.

    Args:
        file_path (str): The path to an existing .npy file.
        remove_original (bool): Delete ``file_path`` once the archive has been written.
            Defaults to True.

    Returns:
        str: Path of the written .npz file.

    Raises:
        ValueError: If ``file_path`` does not end in ``.npy`` or is not an existing file.
    """
    # isfile (not exists): a directory whose name ends in ".npy" is not a valid input.
    if not file_path.endswith('.npy') or not os.path.isfile(file_path):
        raise ValueError("Invalid file path. Must be an existing .npy file.")

    data = np.load(file_path)
    base_name = os.path.splitext(file_path)[0]
    compressed_file_path = f"{base_name}.npz"
    np.savez_compressed(compressed_file_path, data=data)
    print(f"Compressed file saved to: {compressed_file_path}")

    # Only remove the source after the archive was written successfully (savez raises first).
    if remove_original:
        os.remove(file_path)
        print(f"Original file removed: {file_path}")
    return compressed_file_path

# Example usage, run on a throwaway file inside a temporary directory.
# Do not use the name "memmap (1).npy" here: its compressed output, "memmap (1).npz", is the
# ECG data file that later cells load, and running this example would replace it with a
# 100x100 random array.
import tempfile

with tempfile.TemporaryDirectory() as example_dir:
    file_path = os.path.join(example_dir, "example.npy")
    arr = np.random.rand(100, 100)
    np.save(file_path, arr)
    compress_npy(file_path)

Compressed file saved to: memmap (1).npz
Original file removed: memmap (1).npy


In [None]:
# Open the memmap metadata archive, list the arrays it holds, then dump each one.
memmap_meta = np.load("memmap_meta.npz", allow_pickle=True)
print(memmap_meta.files)

for key in memmap_meta.files:
    print(key + ":", memmap_meta[key])


['start', 'length', 'shape', 'dtype']
start: [       0     1000     2000 ... 21646000 21647000 21648000]
length: [1000 1000 1000 ... 1000 1000 1000]
shape: [21649000       12]
dtype: float32


In [None]:
import pandas as pd
import numpy as np

# Load your stroke-labeled diagnosis dataframe.
# NOTE(review): this read has failed with "ParserError: EOF inside string starting at row
# 99886", i.e. an unterminated quoted field. `on_bad_lines` does not cover that case; it
# usually means the CSV is truncated or corrupted, so re-export/re-download it.
diagnosis_df = pd.read_csv("records_w_diag_icd10 (1).csv", low_memory=False, on_bad_lines='warn')

# Load memmap metadata and actual ECG data. The context managers close the underlying zip
# archives once the arrays have been read into memory.
# SECURITY: allow_pickle=True lets a crafted archive execute code on load — trusted files only.
with np.load("memmap_meta.npz", allow_pickle=True) as memmap_meta:
    print("Meta keys:", memmap_meta.files)
    meta = {name: memmap_meta[name] for name in memmap_meta.files}

with np.load("memmap (1).npz", allow_pickle=True) as memmap_data:
    print("ECG data keys:", memmap_data.files)
    ecg_array = memmap_data["data"]

# Show a sample of your diagnosis data
print(diagnosis_df.head())

# The metadata describes ONE flat (total_samples, n_leads) array (printed shape [21649000, 12])
# plus per-record `start` offsets and `length`s (1000 each). Assumes record i occupies rows
# start[i] : start[i] + length[i] — consistent with the printed metadata; confirm.
# When records are equal-length and stored back-to-back, view the array as
# (n_records, length, n_leads) so that ecg_array[i] is a whole record, which is how the
# downstream cells index it (one ECG per diagnosis row).
starts = np.asarray(meta.get("start", []))
lengths = np.asarray(meta.get("length", []))
flat_shape = tuple(int(s) for s in np.ravel(meta.get("shape", [])))
is_contiguous_fixed_length = (
    lengths.size > 0
    and ecg_array.shape == flat_shape
    and bool(np.all(lengths == lengths[0]))
    and np.array_equal(starts, np.arange(lengths.size) * lengths[0])
    and lengths.size * int(lengths[0]) == ecg_array.shape[0]
)
if is_contiguous_fixed_length:
    ecg_array = ecg_array.reshape(lengths.size, int(lengths[0]), *ecg_array.shape[1:])
else:
    print("⚠️ ECG array does not match the memmap_meta layout; leaving it un-reshaped")

# Inspect ECG data shape: (n_records, length, n_leads) when the reshape above applied
print("ECG shape:", ecg_array.shape)


ParserError: Error tokenizing data. C error: EOF inside string starting at row 99886

In [None]:
diagnosis_df = diagnosis_df.iloc[: len(ecg_array)].reset_index(drop=True)  # NOTE(review): assumes CSV row i pairs with ECG record i — confirm ordering
def has_stroke_icd10(icd_list):
    """Return 1 if any ICD-10 code in ``icd_list`` starts with I60–I69, else 0.

    Args:
        icd_list: A list of code strings, or its string representation as stored in the
            CSV (e.g. ``"['I639', 'E875']"``). Anything else — NaN/None for missing values,
            unparsable text, non-list literals — yields 0.
    """
    import ast

    stroke_prefixes = {f"I6{i}" for i in range(10)}  # I60 .. I69
    if isinstance(icd_list, str):
        try:
            # literal_eval only parses Python literals; eval() would execute any code
            # embedded in the CSV cell.
            icd_list = ast.literal_eval(icd_list)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            return 0
    if not isinstance(icd_list, (list, tuple, set)):
        return 0  # e.g. float NaN, which previously raised TypeError when iterated
    return int(any(isinstance(code, str) and code[:3] in stroke_prefixes for code in icd_list))

# Flag each row whose hospital diagnoses contain an I60–I69 code, then show the split.
diagnosis_df["stroke_label"] = diagnosis_df["hosp_diag_hosp"].map(has_stroke_icd10)
print("Class balance:\n", diagnosis_df["stroke_label"].value_counts())


Class balance:
 stroke_label
0    99
1     1
Name: count, dtype: int64


In [None]:
#use whole dataset with loss weights
import torch
import torch.nn as nn

# Count positives and negatives
num_pos = diagnosis_df["stroke_label"].sum()
num_neg = len(diagnosis_df) - num_pos

# pos_weight = negatives per positive, so a missed stroke costs as much as all the
# non-strokes it is outnumbered by. With zero positives the ratio is undefined (NumPy gives
# inf or NaN, which turns every loss value into NaN), so fall back to a neutral 1.0.
pos_weight_value = num_neg / num_pos if num_pos > 0 else 1.0
pos_weight = torch.tensor([pos_weight_value], dtype=torch.float32)

# Use in your loss function (expects raw logits, not probabilities)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

In [None]:
# Downsampling strategy: an equal number of stroke and non-stroke rows.
stroke_df = diagnosis_df[diagnosis_df["stroke_label"] == 1]
nonstroke_df = diagnosis_df[diagnosis_df["stroke_label"] == 0]

# DataFrame.sample(n=...) raises ValueError when n exceeds the rows available, so cap the
# per-class count at the smaller class and only subsample the stroke class if it is larger.
n_per_class = min(len(stroke_df), len(nonstroke_df))
stroke_kept = stroke_df
if len(stroke_df) > n_per_class:
    stroke_kept = stroke_df.sample(n=n_per_class, random_state=42)
nonstroke_sampled = nonstroke_df.sample(n=n_per_class, random_state=42)

# Combine and shuffle. The row's original position in diagnosis_df is kept as column
# `ecg_row`: after shuffling, position i in balanced_df no longer corresponds to ECG record i,
# so anything pairing these rows with ecg_array must look the record up via `ecg_row`.
balanced_df = (
    pd.concat([stroke_kept, nonstroke_sampled])
    .sample(frac=1, random_state=42)
    .rename_axis("ecg_row")
    .reset_index()
)

print(f"Balanced dataset size: {len(balanced_df)}")


Balanced dataset size: 2


In [None]:
from torch.utils.data import Dataset

class ECGDataset(Dataset):
    """Yield ``(ecg, stroke_label)`` float32 tensor pairs, pairing df row i with ecg_data[i]."""

    def __init__(self, df, ecg_data):
        # Re-index 0..n-1 so an integer position can be used with .loc below.
        self.df = df.reset_index(drop=True)
        self.ecg_data = ecg_data

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        signal = self.ecg_data[idx]
        target = float(self.df.loc[idx, "stroke_label"])
        return (
            torch.tensor(signal, dtype=torch.float32),
            torch.tensor(target, dtype=torch.float32),
        )


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleECG1DCNN(nn.Module):
    """Three conv blocks over a multi-lead ECG, reduced to a single stroke logit.

    Input: ``(batch, seq_len, input_channels)``; output: ``(batch, 1)`` raw logits (apply
    ``BCEWithLogitsLoss`` or a sigmoid downstream).

    Args:
        input_channels: Number of ECG leads.
        sequence_length: Accepted for API compatibility but unused — the final adaptive
            average pool makes the classifier independent of the signal length.
    """

    def __init__(self, input_channels=12, sequence_length=5000):
        super().__init__()

        # Block 1: wide kernel, halve the time axis.
        self.conv1 = nn.Conv1d(input_channels, 32, kernel_size=7, stride=1, padding=3)
        self.bn1 = nn.BatchNorm1d(32)
        self.pool1 = nn.MaxPool1d(kernel_size=2)

        # Block 2
        self.conv2 = nn.Conv1d(32, 64, kernel_size=5, stride=1, padding=2)
        self.bn2 = nn.BatchNorm1d(64)
        self.pool2 = nn.MaxPool1d(kernel_size=2)

        # Block 3: global average over time -> (batch, 128, 1)
        self.conv3 = nn.Conv1d(64, 128, kernel_size=3, stride=1, padding=1)
        self.bn3 = nn.BatchNorm1d(128)
        self.pool3 = nn.AdaptiveAvgPool1d(1)

        self.fc = nn.Linear(128, 1)

    def forward(self, x):
        # Conv1d expects channels before time: (B, Seq, C) -> (B, C, Seq).
        h = x.transpose(1, 2)
        blocks = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
            (self.conv3, self.bn3, self.pool3),
        )
        for conv, norm, pool in blocks:
            h = pool(F.relu(norm(conv(h))))
        return self.fc(torch.flatten(h, start_dim=1))


In [None]:
from torch.utils.data import Dataset
import torch

class ECGDataset(Dataset):
    """Map-style dataset yielding ``(ecg, stroke_label)`` float32 tensor pairs.

    Args:
        df: DataFrame with a ``stroke_label`` column.
        ecg_data: Indexable ECG store (e.g. a NumPy array); ``ecg_data[k]`` is one record.
        ecg_index_col: Optional column of ``df`` giving, for each row, the index ``k`` of that
            row's record in ``ecg_data``. Used when present; otherwise row i of ``df`` is
            paired with ``ecg_data[i]``. Needed whenever ``df`` has been filtered or shuffled,
            because positional pairing would then attach labels to the wrong signals.
    """

    def __init__(self, df, ecg_data, ecg_index_col="ecg_row"):
        self.df = df.reset_index(drop=True)
        self.ecg_data = ecg_data
        if ecg_index_col is not None and ecg_index_col in self.df.columns:
            self._ecg_rows = self.df[ecg_index_col].to_numpy()
        else:
            self._ecg_rows = None  # positional pairing

    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        ecg_idx = idx if self._ecg_rows is None else int(self._ecg_rows[idx])
        ecg = self.ecg_data[ecg_idx]  # one record, e.g. [length, channels] — TODO confirm layout
        label = float(self.df.loc[idx, "stroke_label"])
        return torch.tensor(ecg, dtype=torch.float32), torch.tensor(label, dtype=torch.float32)


In [None]:
# Trim both label tables to at most one row per ECG record.
diagnosis_df = diagnosis_df.iloc[: len(ecg_array)].reset_index(drop=True)
balanced_df = balanced_df.iloc[: len(ecg_array)].reset_index(drop=True)

# One dataset over every row (weighted-loss strategy), one over the downsampled rows.
# NOTE(review): balanced_df is a shuffled subset — check that ECGDataset pairs its rows
# with the right ECG records.
full_dataset = ECGDataset(diagnosis_df, ecg_array)
balanced_dataset = ECGDataset(balanced_df, ecg_array)

In [None]:
def train(model, dataloader, criterion, optimizer, device, max_grad_norm=None):
    """Run one training epoch and return the mean per-batch loss.

    Targets of shape ``[B]`` are reshaped to ``[B, 1]`` to match a single-logit model output.

    Args:
        model: Module mapping a batch to ``[B, 1]`` outputs.
        dataloader: Iterable of ``(X, y)`` batches.
        criterion: Loss applied to ``(output, y)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.
        max_grad_norm: If given, clip the global gradient norm to this value before each step.

    Returns:
        float: Mean of the per-batch losses, or NaN if ``dataloader`` yielded no batches.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    for X, y in dataloader:
        X, y = X.to(device), y.to(device).unsqueeze(1)  # Make y shape [B, 1]
        optimizer.zero_grad()
        output = model(X)
        loss = criterion(output, y)
        loss.backward()
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    # Counting batches avoids ZeroDivisionError on an empty loader.
    return total_loss / n_batches if n_batches else float("nan")


In [None]:
# --- Ensure 'stroke_label' column exists ---
import ast

if "stroke_label" not in diagnosis_df.columns:
    def has_stroke_icd10(icd_list):
        """Return 1 if any ICD-10 code in ``icd_list`` starts with I60–I69, else 0.

        ``icd_list`` is a list of code strings or its string form as stored in the CSV
        (e.g. ``"['I639', 'E875']"``); NaN/None, unparsable text and non-list literals give 0.
        """
        valid_prefixes = {f"I6{i}" for i in range(10)}  # I60 to I69
        if isinstance(icd_list, str):
            try:
                # literal_eval parses literals only; eval() would run code found in the CSV.
                icd_list = ast.literal_eval(icd_list)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                return 0
        if not isinstance(icd_list, (list, tuple, set)):
            return 0  # e.g. float NaN for a missing cell (iterating it raised TypeError)
        return int(any(isinstance(code, str) and code[:3] in valid_prefixes for code in icd_list))

    diagnosis_df["stroke_label"] = diagnosis_df["hosp_diag_hosp"].apply(has_stroke_icd10)


In [None]:
from torch.utils.data import DataLoader

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = SimpleECG1DCNN().to(device)

# For Strategy 1 (weighted loss)
num_pos = diagnosis_df["stroke_label"].sum()
num_neg = len(diagnosis_df) - num_pos
# With zero positives the ratio is inf/NaN and every loss becomes NaN; use a neutral weight.
pos_weight_value = num_neg / num_pos if num_pos > 0 else 1.0
pos_weight = torch.tensor([pos_weight_value], dtype=torch.float32).to(device)
criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# No earlier cell defines `full_loader` (only `full_dataset`), which made this loop fail
# with a NameError; build it here. Batch size matches the first cell's batch_size = 32.
full_loader = DataLoader(full_dataset, batch_size=32, shuffle=True)

# Training loop (example for 5 epochs)
for epoch in range(5):
    loss = train(model, full_loader, criterion, optimizer, device)
    print(f"Epoch {epoch+1} - Loss: {loss:.4f}")


KeyError: 'stroke_label'

In [None]:
import numpy as np
import os
from tqdm import tqdm

# 1. VERIFY FILE INTEGRITY
file_path = "memmap (1).npz"
file_size = os.path.getsize(file_path)
expected_size = 21649000 * 12 * 4  # 1.04GB for float32, uncompressed

# NOTE: savez_compressed archives are compressed, so a file smaller than the raw size is not
# by itself evidence of corruption; the loaded shape below is the meaningful check.
print(f"File size: {file_size/1e6:.2f}MB | Expected: {expected_size/1e6:.2f}MB")

# Start from a clean slate: in a notebook, an `ecg_data` left over from an earlier run would
# otherwise satisfy the "already loaded" check below even if this load failed.
ecg_data = None

# 2. ATTEMPT PROPER LOADING
try:
    # NOTE(review): np.load memory-maps plain .npy files only; for a .npz archive mmap_mode
    # appears to have no effect and data['data'] is read fully into memory.
    with np.load(file_path, mmap_mode='r') as data:
        if 'data' in data:
            ecg_data = data['data']
            print(f"Loaded shape: {ecg_data.shape}")
        else:
            print("No 'data' array found in file. Available keys:", list(data.keys()))
except Exception as e:
    print(f"Load error: {e}")

# 3. DATA RECOVERY STRATEGY
if ecg_data is None or ecg_data.shape[0] < 1000:
    print("\n⚠️ Using partial data recovery approach")

    # Try loading what we can
    try:
        with np.load(file_path, allow_pickle=True) as data:
            all_arrays = {k: data[k] for k in data.files}
            print("Found arrays:", list(all_arrays.keys()))

            # Find the largest array (raises ValueError if the archive is empty)
            ecg_data = max(all_arrays.values(), key=lambda x: x.size)
            print(f"Using largest array found: {ecg_data.shape}")
    except Exception as e:  # was a bare `except:`, which also swallowed KeyboardInterrupt
        print(f"Recovery error: {e}")
        print("Could not recover any data. Creating synthetic placeholder...")
        # WARNING: an all-zeros placeholder carries no signal; results trained on it are meaningless.
        ecg_data = np.zeros((100, 12), dtype=np.float32)

File size: 0.08MB | Expected: 1039.15MB
Loaded shape: (100, 100)

⚠️ Using partial data recovery approach
Found arrays: ['data']
Using largest array found: (100, 100)


In [None]:
# -------------------- COMPLETE ECG STROKE PREDICTION SYSTEM --------------------
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import copy
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report, roc_auc_score, f1_score
from sklearn.preprocessing import StandardScaler
from collections import Counter
from imblearn.combine import SMOTEENN
from imblearn.over_sampling import SMOTE
import os
import re

# -------------------- 1. DATA LOADING --------------------
def load_ecg_data(file_path="memmap (1).npy", record_length=3000):
    """Load ECG samples, cut them into fixed-length rows and standardise each column.

    Both a real ``.npy`` file (with header) and a raw float32 dump (as written by
    ``np.memmap``) are accepted; the format is detected from the ``.npy`` magic bytes.

    Args:
        file_path: Path of the ECG file.
        record_length: Values per row after reshaping.
            NOTE(review): memmap_meta.npz describes 1000-sample records with 12 leads;
            confirm 3000 is the intended row length.

    Returns:
        np.ndarray of shape ``(n_rows, record_length)``, standardised per column with
        ``StandardScaler`` and with NaN/±inf replaced by 0.

    Raises:
        ValueError: If the file cannot be read or its size is not a multiple of
            ``record_length``.
    """
    try:
        # A .npy file starts with b"\x93NUMPY" followed by a header; reading it with
        # np.fromfile would turn those header bytes into bogus samples and shift every row.
        with open(file_path, "rb") as fh:
            is_npy = fh.read(6) == b"\x93NUMPY"
        if is_npy:
            data = np.load(file_path, mmap_mode="r")
        else:
            data = np.fromfile(file_path, dtype=np.float32)
        ecg_array = np.asarray(data, dtype=np.float32).reshape(-1, record_length)
        print("✅ Loaded ECG data from .npy file")
        print("Shape:", ecg_array.shape)

        # NOTE(review): fitted on all rows before any train/test split; run_training
        # re-standardises per fold, but this global fit still sees test rows.
        scaler = StandardScaler()
        ecg_array = scaler.fit_transform(ecg_array)
        ecg_array = np.nan_to_num(ecg_array, nan=0.0, posinf=0.0, neginf=0.0)

        return ecg_array
    except Exception as e:
        raise ValueError(f"Could not load ECG data: {e}") from e

def load_diagnosis_data(csv_path, expected_rows=None):
    """Read the diagnosis CSV, optionally keeping only its first ``expected_rows`` rows.

    Any failure (missing file, parse error, ...) is reported and an empty DataFrame is
    returned instead of raising. A falsy ``expected_rows`` (None or 0) keeps every row.
    """
    print(f"\n🔍 Loading diagnosis data from {csv_path}")
    try:
        frame = pd.read_csv(csv_path, low_memory=False)
        if not expected_rows:
            return frame
        frame = frame.iloc[:expected_rows].reset_index(drop=True)
        print(f"Aligned to {expected_rows} rows to match ECG data")
        return frame
    except Exception as e:
        print(f"Failed to load diagnosis data: {e}")
        return pd.DataFrame()

# -------------------- 2. LABEL CREATION --------------------
def create_stroke_labels(df):
    """Add a binary ``stroke_label`` column derived from ``hosp_diag_hosp`` and return ``df``.

    A row is labelled 1 when any of its diagnosis codes starts with I60–I69. Cells hold a
    list of code strings or its string form (e.g. ``"['I639', 'E875']"``); unparsable or
    non-list values are labelled 0. ``df`` is modified in place and also returned.
    """
    import ast

    # Compiled once per call rather than once per row.
    pattern = re.compile(r"^I6[0-9]")

    def has_stroke(icd_list):
        if isinstance(icd_list, str):
            try:
                # literal_eval parses literals only; eval() would run code found in the CSV.
                icd_list = ast.literal_eval(icd_list)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                return 0
        if isinstance(icd_list, list):
            return int(any(pattern.match(code) for code in icd_list if isinstance(code, str)))
        return 0

    df["stroke_label"] = df["hosp_diag_hosp"].apply(has_stroke)
    return df

# -------------------- 3. DATASET + MODEL --------------------
class ECGDataset(Dataset):
    """Wrap parallel feature/label arrays so each item is a float32 ``(x, y)`` tensor pair."""

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        features = torch.tensor(self.X[idx], dtype=torch.float32)
        target = torch.tensor(self.y[idx], dtype=torch.float32)
        return features, target

class ECGCNNClassifier(nn.Module):
    """1-D CNN over a single-channel signal, producing one logit per sample.

    Each input row is viewed as ``(batch, 1, length)``. Four conv/BN/ReLU/max-pool blocks
    (16 → 32 → 64 → 128 channels) are followed by adaptive average pooling to one time step
    and a two-layer MLP head that outputs a raw logit.

    Args:
        input_length: Accepted for API compatibility; unused, since the adaptive pooling
            makes the network independent of the signal length.

    NOTE(review): the head's ``BatchNorm1d(64)`` raises in training mode when a batch holds a
    single sample, so training DataLoaders may need ``drop_last=True``.
    """

    def __init__(self, input_length=3000):
        super().__init__()
        # (in_channels, out_channels, kernel_size, padding) per conv block, in order.
        conv_specs = ((1, 16, 5, 2), (16, 32, 5, 2), (32, 64, 5, 2), (64, 128, 3, 1))
        conv_stack = []
        for in_ch, out_ch, k, pad in conv_specs:
            conv_stack += [
                nn.Conv1d(in_ch, out_ch, kernel_size=k, stride=1, padding=pad),
                nn.BatchNorm1d(out_ch),
                nn.ReLU(),
                nn.MaxPool1d(kernel_size=2),
            ]
        conv_stack.append(nn.AdaptiveAvgPool1d(1))
        self.conv_layers = nn.Sequential(*conv_stack)

        self.fc_layers = nn.Sequential(
            nn.Linear(128, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 1),
        )

    def forward(self, x):
        # Treat each sample's values as one channel: (batch, 1, sequence_length).
        x = x.view(x.size(0), 1, -1)
        pooled = self.conv_layers(x)  # (batch, 128, 1)
        return self.fc_layers(torch.flatten(pooled, start_dim=1))

def extract_ecg_features(ecg_array, exclude_dc=False):
    """Extract 8 summary features from each ECG record.

    Columns, in order: mean, std, min, max, peak-to-peak, dominant frequency (in cycles per
    sample, from ``np.fft.rfftfreq``), mean and std of the FFT magnitudes.

    Args:
        ecg_array: Iterable of 1-D signals (e.g. the rows of a 2-D array).
        exclude_dc: If True, skip the 0-frequency bin when picking the dominant frequency.
            That bin equals ``|sum(signal)|``, so for signals with a non-zero mean it usually
            wins and the feature is just 0. Defaults to False.

    Returns:
        np.ndarray of shape ``(n_records, 8)``; ``(0, 8)`` for empty input.
    """
    features = []

    for ecg in ecg_array:
        # Simple statistical features
        min_val = np.min(ecg)
        max_val = np.max(ecg)

        # Simple frequency domain features using FFT
        fft_vals = np.abs(np.fft.rfft(ecg))
        fft_freq = np.fft.rfftfreq(len(ecg))
        first_bin = 1 if exclude_dc and len(fft_vals) > 1 else 0
        dominant_freq = fft_freq[first_bin + np.argmax(fft_vals[first_bin:])]

        features.append([
            np.mean(ecg), np.std(ecg), min_val, max_val, max_val - min_val,
            dominant_freq, np.mean(fft_vals), np.std(fft_vals),
        ])

    # np.array([]) has shape (0,); reshape keeps the feature axis for empty input.
    return np.array(features).reshape(-1, 8)

# -------------------- 4. LOSS + TRAINING --------------------
class FocalLoss(nn.Module):
    """Binary focal loss on logits with optional positive-class weighting.

    Per element: ``w * (1 - p_t) ** gamma * BCE``, where ``p_t`` is the predicted probability
    of the true class, ``BCE = -log(p_t)``, and ``w = pos_weight`` for positive targets
    (1 for negatives). The result is the mean over all elements.

    Args:
        gamma: Focusing exponent; 0 reduces to (weighted) BCE.
        pos_weight: Optional tensor weighting positive targets, as in
            ``binary_cross_entropy_with_logits``.
    """

    def __init__(self, gamma=0.5, pos_weight=None):
        super().__init__()
        self.gamma = gamma
        self.pos_weight = pos_weight

    def forward(self, inputs, targets):
        # Clip logits to prevent extreme values (note: the gradient is zero outside [-50, 50]).
        inputs = torch.clamp(inputs, -50, 50)

        # p_t must come from the UNWEIGHTED BCE. exp(-pos_weight * bce) equals p_t ** pos_weight,
        # which would make the focal factor depend on the class weight.
        bce = nn.functional.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        pt = torch.exp(-torch.clamp(bce, max=50))

        # Small epsilon keeps the gradient of (1 - pt) ** gamma finite at pt == 1 when gamma < 1.
        focal_term = (1 - pt + 1e-7) ** self.gamma

        # The class weighting is applied once, to the BCE term.
        weighted_bce = nn.functional.binary_cross_entropy_with_logits(
            inputs, targets,
            pos_weight=self.pos_weight,
            reduction='none'
        )
        loss = focal_term * weighted_bce

        # NaNs are zeroed rather than propagated (this can also hide upstream data problems).
        loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)

        return loss.mean()

def train_epoch(model, loader, criterion, optimizer, device, max_grad_norm=1.0):
    """Run one training epoch with gradient-norm clipping; return the mean per-batch loss.

    Args:
        model: Module mapping a batch to ``[B, 1]`` logits.
        loader: Iterable of ``(X, y)`` batches; ``y`` of shape ``[B]`` is reshaped to ``[B, 1]``.
        criterion: Loss applied to ``(outputs, y)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to.
        max_grad_norm: Global gradient-norm clip applied before each step (``None`` disables it).

    Returns:
        float: Mean of the per-batch losses, or NaN if ``loader`` yielded no batches.
    """
    model.train()
    total_loss = 0.0
    n_batches = 0
    for X, y in loader:
        X, y = X.to(device), y.to(device).unsqueeze(1)
        optimizer.zero_grad()
        outputs = model(X)
        loss = criterion(outputs, y)
        loss.backward()

        # Gradient clipping guards against exploding updates from the weighted focal loss.
        if max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

        optimizer.step()
        total_loss += loss.item()
        n_batches += 1
    # Avoid ZeroDivisionError on an empty loader.
    return total_loss / n_batches if n_batches else float("nan")

def evaluate(model, loader, criterion, device, threshold=0.15):
    """Evaluate a logit-output model and collect predictions, labels and probabilities.

    Args:
        model: Module mapping a batch to ``[B, 1]`` logits.
        loader: Iterable of ``(X, y)`` batches; ``y`` of shape ``[B]`` is reshaped to ``[B, 1]``.
        criterion: Loss applied to ``(outputs, y)``.
        device: Device the batches are moved to.
        threshold: Probability above which a sample is predicted positive.

    Returns:
        tuple: ``(mean_loss, preds, labels, auc, probs)``. ``preds``/``labels``/``probs`` are
        ``(N, 1)`` arrays. ``auc`` is NaN when the labels contain fewer than two classes, and
        ``mean_loss`` is NaN when the loader yielded no batches.
    """
    model.eval()
    all_preds, all_probs, all_labels = [], [], []
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device).unsqueeze(1)
            outputs = model(X)
            loss = criterion(outputs, y)
            probs = torch.sigmoid(outputs)
            preds = (probs > threshold).cpu().numpy()
            all_preds.extend(preds)
            all_probs.extend(probs.cpu().numpy())
            all_labels.extend(y.cpu().numpy())
            total_loss += loss.item()
            n_batches += 1

    labels = np.array(all_labels)
    probs = np.array(all_probs)
    # roc_auc_score raises ValueError when y_true holds a single class (common in small or
    # highly imbalanced folds); report NaN instead of aborting the whole run.
    if np.unique(labels).size < 2:
        auc = float("nan")
    else:
        auc = roc_auc_score(labels, probs)
    mean_loss = total_loss / n_batches if n_batches else float("nan")
    return mean_loss, np.array(all_preds), labels, auc, probs

def find_best_threshold(probs, labels, thresholds=None):
    """Find the decision threshold that maximizes F1 score.

    Args:
        probs: Predicted positive-class probabilities, any shape (flattened, so the
            ``(N, 1)`` arrays returned by ``evaluate`` work directly).
        labels: Binary ground-truth labels with as many elements as ``probs``.
        thresholds: Candidate thresholds, tried in order. Defaults to
            ``np.arange(0.05, 0.95, 0.05)`` (0.05 … 0.90).

    Returns:
        The first threshold reaching the highest F1, or 0.5 if no threshold gives F1 > 0.
    """
    probs = np.ravel(probs)
    labels = np.ravel(labels)  # avoids sklearn's column-vector conversion warning
    if thresholds is None:
        thresholds = np.arange(0.05, 0.95, 0.05)

    best_f1 = 0
    best_threshold = 0.5

    for threshold in thresholds:
        preds = (probs > threshold).astype(int)
        # zero_division=0 gives the same score as sklearn's default, minus the warning
        # emitted when nothing is predicted positive at high thresholds.
        f1 = f1_score(labels, preds, zero_division=0)
        if f1 > best_f1:
            best_f1 = f1
            best_threshold = threshold

    print(f"✅ Best threshold: {best_threshold:.2f} with F1 score: {best_f1:.4f}")
    return best_threshold

# -------------------- 5. MAIN TRAINING --------------------
def run_training(X, y, n_splits=5, epochs=20, batch_size=128):
    """Cross-validate ``ECGCNNClassifier`` with SMOTE-balanced training folds.

    For each stratified fold:
    - non-finite values are zeroed;
    - the training split is oversampled with SMOTE;
    - both splits are standardized with a scaler fitted on the resampled
      training split;
    - the model is trained for ``epochs`` epochs, keeping the weights with
      the best test-fold ROC-AUC.

    Each fold's best model is saved to ``ecg_stroke_model_fold{k}.pt``.

    Args:
        X: 2-D array, one row per sample.
        y: 1-D array of 0/1 labels aligned with ``X``.
        n_splits: number of stratified folds.
        epochs: training epochs per fold (must be >= 1).
        batch_size: mini-batch size for both loaders.

    Returns:
        List of per-fold dicts with keys ``fold``, ``test_loss``, ``roc_auc``,
        ``best_threshold``, ``report``, ``model`` and ``preds_sum`` (the
        number of predicted positives).

    Raises:
        ValueError: if ``epochs < 1``.
    """
    if epochs < 1:
        raise ValueError("epochs must be >= 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    results = []
    sampler = SMOTE(random_state=42)
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

    # Best-epoch test-fold probabilities pooled across folds; used to tune
    # the decision threshold on the last fold.
    all_val_probs = []
    all_val_labels = []

    for fold, (train_idx, test_idx) in enumerate(skf.split(X, y)):
        print(f"\n=== FOLD {fold+1}/{n_splits} ===")

        # Clean BEFORE resampling: SMOTE validates its input and rejects
        # NaN/inf, so cleaning only after fit_resample can never take effect.
        X_train = np.nan_to_num(X[train_idx], nan=0.0, posinf=0.0, neginf=0.0)
        X_test = np.nan_to_num(X[test_idx], nan=0.0, posinf=0.0, neginf=0.0)
        y_train, y_test = y[train_idx], y[test_idx]

        X_train, y_train = sampler.fit_resample(X_train, y_train)
        print(f"✅ Resampled training set — Class balance: {Counter(y_train)}")

        # Standardize using statistics from the training split only.
        scaler = StandardScaler()
        X_train = scaler.fit_transform(X_train)
        X_test = scaler.transform(X_test)

        train_loader = DataLoader(ECGDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
        test_loader = DataLoader(ECGDataset(X_test, y_test), batch_size=batch_size)

        model = ECGCNNClassifier(X.shape[1]).to(device)
        pos_count = y_train.sum()
        neg_count = len(y_train) - pos_count
        pos_weight = torch.tensor([neg_count / pos_count if pos_count > 0 else 1.0]).to(device)

        criterion = FocalLoss(gamma=0.5, pos_weight=pos_weight)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

        # Plain cosine annealing (not warm restarts). T_max is floored at 1
        # because epochs == 1 would otherwise give T_max == 0, and the
        # scheduler divides by T_max on step().
        scheduler = CosineAnnealingLR(optimizer, T_max=max(1, epochs // 2), eta_min=1e-6)

        best_auc = 0
        best_model = None
        best_val_probs = None
        best_val_labels = None

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
            scheduler.step()

            val_loss, val_preds, val_labels, val_auc, val_probs = evaluate(
                model, test_loader, criterion, device, threshold=0.15
            )

            print(f"Epoch {epoch+1}/{epochs} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | AUC: {val_auc:.4f}")

            if val_auc > best_auc:
                best_auc = val_auc
                best_model = copy.deepcopy(model)
                best_val_probs = val_probs
                best_val_labels = val_labels
                print(f"✅ New best model saved! AUC: {val_auc:.4f}")

        if best_model is None:
            # No epoch beat an AUC of 0 (e.g. every AUC was NaN): keep the
            # final weights instead of crashing on a None model below.
            best_model = copy.deepcopy(model)
            best_val_probs, best_val_labels = val_probs, val_labels

        all_val_probs.extend(best_val_probs)
        all_val_labels.extend(best_val_labels)

        # NOTE(review): the tuned threshold comes from pooled test-fold
        # probabilities that include this fold's own test set, and is then
        # used to score that same test set, so the last fold's metrics are
        # optimistically biased. Consider tuning on held-out training data.
        if fold == n_splits - 1:
            best_threshold = find_best_threshold(np.array(all_val_probs), np.array(all_val_labels))
        else:
            best_threshold = 0.15  # Default threshold

        test_loss, preds, labels, auc, probs = evaluate(best_model, test_loader, criterion, device, threshold=best_threshold)

        print(f"→ Predicted Positives: {int(preds.sum())} / {len(preds)}")
        print(f"→ Actual Positives: {int(labels.sum())} / {len(labels)}")
        print(f"→ ROC-AUC: {auc:.4f}")

        unique_preds = np.unique(preds, return_counts=True)
        print(f"→ Unique predictions: {dict(zip(unique_preds[0], unique_preds[1]))}")
        print(f"→ Shape of preds array: {preds.shape}")
        print(f"→ Shape of labels array: {labels.shape}")

        report = classification_report(labels, preds, target_names=['Non-Stroke', 'Stroke'], output_dict=True, zero_division=0)

        results.append({
            'fold': fold+1,
            'test_loss': test_loss,
            'roc_auc': auc,
            'best_threshold': best_threshold,
            'report': report,
            'model': best_model,
            # Read by the reporting code as a fallback diagnostic.
            'preds_sum': int(preds.sum()),
        })

        torch.save({
            'model_state_dict': best_model.state_dict(),
            # Plain Python floats: NumPy scalars may be refused by
            # torch.load(weights_only=True), the default in recent PyTorch.
            'threshold': float(best_threshold),
            'auc': float(auc),
        }, f'ecg_stroke_model_fold{fold+1}.pt')
        print(f"✅ Model saved to ecg_stroke_model_fold{fold+1}.pt")

    return results

# -------------------- 6. EXECUTION --------------------
if __name__ == "__main__":
    # Load the signals first so the diagnosis table can be cut to the same length.
    ecg_data = load_ecg_data()
    diagnosis_df = load_diagnosis_data("records_w_diag_icd10 (1).csv", len(ecg_data))
    diagnosis_df = create_stroke_labels(diagnosis_df)

    print("\nClass Distribution:")
    print(diagnosis_df["stroke_label"].value_counts())

    results = run_training(
        ecg_data,
        diagnosis_df["stroke_label"].values,
        n_splits=5,
        epochs=5,
        batch_size=128,
    )

    print("\n📊 Final Results:")
    for res in results:
        print(f"\nFold {res['fold']}:")
        print(f"Test Loss: {res['test_loss']:.4f}")
        print(f"ROC AUC: {res['roc_auc']:.4f}")
        print(f"Best Threshold: {res['best_threshold']:.4f}")

        # The positive class may appear in the report as '1' or 'Stroke';
        # '1' is checked first.
        stroke_key = next((key for key in ('1', 'Stroke') if key in res['report']), None)
        if stroke_key is not None:
            stroke_stats = res['report'][stroke_key]
            print(f"Stroke Precision: {stroke_stats['precision']:.4f}")
            print(f"Stroke Recall: {stroke_stats['recall']:.4f}")
            print(f"Stroke F1 Score: {stroke_stats['f1-score']:.4f}")
        else:
            preds_count = res.get('preds_sum', 'unknown')
            print(f"⚠️ No stroke predictions found in report. Predicted positives: {preds_count}")

    avg_auc = sum(fold_res['roc_auc'] for fold_res in results)
    print(f"\n✅ Average ROC AUC across all folds: {avg_auc / len(results):.4f}")

def predict_with_ensemble(ecg_data, model_paths, threshold=None):
    """
    Make predictions using an ensemble of models

    Args:
        ecg_data: ECG data to predict on (already preprocessed); 2-D, one
            sample per row (``shape[1]`` sizes the model input)
        model_paths: List of paths to model checkpoints
        threshold: Optional prediction threshold. If None, the mean of the
            thresholds saved in the checkpoints is used (0.5 for any
            checkpoint without one).

    Returns:
        Average probabilities and binary predictions

    Raises:
        ValueError: if ``model_paths`` yields no checkpoints.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # The input is identical for every model, so build the loader once.
    # Labels are dummies; ECGDataset requires them.
    loader = DataLoader(ECGDataset(ecg_data, np.zeros(len(ecg_data))), batch_size=128)

    all_probs = []
    saved_thresholds = []

    for model_path in model_paths:
        # NOTE(review): checkpoints that store NumPy scalars may need
        # weights_only=False (or allow-listing) on newer PyTorch versions.
        checkpoint = torch.load(model_path, map_location=device)
        model = ECGCNNClassifier(ecg_data.shape[1]).to(device)
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()

        saved_thresholds.append(float(checkpoint.get('threshold', 0.5)))

        probs = []
        with torch.no_grad():
            for X, _ in loader:
                outputs = model(X.to(device))
                probs.extend(torch.sigmoid(outputs).cpu().numpy())

        all_probs.append(np.array(probs))

    if not all_probs:
        raise ValueError("model_paths must contain at least one checkpoint")

    # Average probabilities from all models.
    avg_probs = np.mean(np.array(all_probs), axis=0)

    # As documented: an explicit threshold wins; otherwise use the ensemble's
    # saved thresholds (averaged).
    final_threshold = threshold if threshold is not None else float(np.mean(saved_thresholds))
    binary_preds = (avg_probs > final_threshold).astype(int)

    return avg_probs, binary_preds

✅ Loaded ECG data from .npy file
Shape: (86596, 3000)

🔍 Loading diagnosis data from records_w_diag_icd10 (1).csv
Aligned to 86596 rows to match ECG data

Class Distribution:
stroke_label
0    83366
1     3230
Name: count, dtype: int64
Using device: cuda

=== FOLD 1/5 ===
✅ Resampled training set — Class balance: Counter({np.int64(0): 66692, np.int64(1): 66692})
Epoch 1/5 | Train Loss: 0.3785 | Val Loss: 0.3457 | AUC: 0.5451
✅ New best model saved! AUC: 0.5451
Epoch 2/5 | Train Loss: 0.3404 | Val Loss: 0.3061 | AUC: 0.5665
✅ New best model saved! AUC: 0.5665
Epoch 3/5 | Train Loss: 0.3294 | Val Loss: 0.3897 | AUC: 0.5670
✅ New best model saved! AUC: 0.5670
Epoch 4/5 | Train Loss: 0.3262 | Val Loss: 0.3262 | AUC: 0.5763
✅ New best model saved! AUC: 0.5763
Epoch 5/5 | Train Loss: 0.3139 | Val Loss: 0.3147 | AUC: 0.5943
✅ New best model saved! AUC: 0.5943
→ Predicted Positives: 13892 / 17320
→ Actual Positives: 646 / 17320
→ ROC-AUC: 0.5943
→ Unique predictions: {np.False_: np.int64(3428)

In [None]:
# -------------------- ENHANCED ECG STROKE PREDICTION SYSTEM --------------------
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import copy
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import classification_report, roc_auc_score, f1_score, precision_recall_curve, average_precision_score
from sklearn.preprocessing import StandardScaler
from collections import Counter
from imblearn.over_sampling import SMOTE
from imblearn.under_sampling import RandomUnderSampler
from imblearn.pipeline import Pipeline
import os
import re
!pip install PyWavelets
import pywt
from tqdm import tqdm
import seaborn as sns

# -------------------- 1. DATA LOADING AND PREPROCESSING --------------------
def load_ecg_data(file_path="memmap (1).npy", signal_length=3000):
    """Load ECG signals and reshape them to ``(n_signals, signal_length)``.

    The default file is read as a raw float32 buffer (``np.fromfile``). If
    the file actually carries a NumPy ``.npy`` header, it is loaded with
    ``np.load`` instead, so the header bytes are not misread as samples.

    Args:
        file_path: path to the signal file (defaults to the original path).
        signal_length: samples per signal (defaults to 3000, as before).

    Returns:
        float32 ndarray of shape ``(n_signals, signal_length)``.

    Raises:
        ValueError: if the file cannot be read or reshaped; the original
            exception is chained as ``__cause__``.
    """
    try:
        with open(file_path, "rb") as fh:
            is_npy = fh.read(6) == b"\x93NUMPY"
        data = np.load(file_path) if is_npy else np.fromfile(file_path, dtype=np.float32)
        ecg_array = np.asarray(data, dtype=np.float32).reshape(-1, signal_length)
    except Exception as e:
        raise ValueError(f"Could not load ECG data: {e}") from e

    print("✅ Loaded ECG data from .npy file")
    print("Shape:", ecg_array.shape)
    return ecg_array

def load_diagnosis_data(csv_path, expected_rows=None):
    """Read the diagnosis CSV, optionally truncated to ``expected_rows`` rows.

    Truncation keeps the first rows and resets the index, so row ``i`` lines
    up with ECG signal ``i``. If the CSV has fewer rows than expected, a
    warning is printed instead of claiming the data is aligned.

    Args:
        csv_path: path to the CSV file.
        expected_rows: number of rows to keep, or None to keep all.

    Returns:
        The DataFrame. If the file cannot be read, an error is printed and an
        empty DataFrame is returned (deliberate best-effort behaviour).
    """
    print(f"\n🔍 Loading diagnosis data from {csv_path}")
    try:
        df = pd.read_csv(csv_path, low_memory=False)
    except Exception as e:
        print(f"Failed to load diagnosis data: {e}")
        return pd.DataFrame()

    if expected_rows is not None:
        available = len(df)
        df = df.iloc[:expected_rows].reset_index(drop=True)
        if available < expected_rows:
            print(f"⚠️ Diagnosis data has only {available} rows; expected {expected_rows} to match ECG data")
        else:
            print(f"Aligned to {expected_rows} rows to match ECG data")
    return df

def preprocess_ecg(ecg_array, wavelet_transform=True):
    """Optionally wavelet-denoise, then standardize and clip each ECG signal.

    Args:
        ecg_array: 2-D array-like, one signal per row (rows must share a
            length, since they are stacked back into one array).
        wavelet_transform: if True, soft-threshold the db4 wavelet detail
            coefficients of every signal before standardization.

    Returns:
        ndarray with one row per input signal. Each row is zero-mean and
        unit-variance (or only centred when it is constant), clipped to its
        mean ± 5 std, with any NaN/inf replaced by 0. The input is not
        modified.
    """
    print("Preprocessing ECG data...")

    # 1. Wavelet denoising. Only the *detail* coefficients are thresholded and
    #    the approximation band is kept, so this suppresses high-frequency
    #    noise; it does not by itself remove low-frequency baseline wander.
    if wavelet_transform:
        print("Applying wavelet denoising...")
        denoised = []
        for signal in ecg_array:
            signal = np.nan_to_num(signal, nan=0.0, posinf=0.0, neginf=0.0)

            level = min(5, pywt.dwt_max_level(len(signal), 'db4'))
            coeffs = pywt.wavedec(signal, 'db4', level=level)

            # Threshold each detail band at half its std; the epsilon keeps
            # the threshold positive for a perfectly flat band.
            for j in range(1, len(coeffs)):
                coeffs[j] = pywt.threshold(coeffs[j], (np.std(coeffs[j]) + 1e-10) / 2, mode='soft')

            reconstructed = pywt.waverec(coeffs, 'db4')
            # The reconstruction length may differ slightly from the input's;
            # trim or zero-pad back to the original length.
            if len(reconstructed) > len(signal):
                reconstructed = reconstructed[:len(signal)]
            else:
                reconstructed = np.pad(reconstructed, (0, len(signal) - len(reconstructed)))

            denoised.append(reconstructed)

        ecg_array = np.array(denoised)

    # 2. Replace invalid values before computing statistics.
    ecg_array = np.nan_to_num(ecg_array, nan=0.0, posinf=0.0, neginf=0.0)

    if ecg_array.size == 0:
        print("✅ ECG preprocessing complete")
        return ecg_array

    # 3. Per-row standardization (vectorized). Rows with (near-)zero variance
    #    are only centred, to avoid dividing by ~0.
    print("Standardizing signals...")
    mean = ecg_array.mean(axis=1, keepdims=True)
    std = ecg_array.std(axis=1, keepdims=True)
    ecg_array = (ecg_array - mean) / np.where(std < 1e-10, 1.0, std)

    # 4. Clip each row to its own mean ± 5 std, skipping constant rows.
    mean = ecg_array.mean(axis=1, keepdims=True)
    std = ecg_array.std(axis=1, keepdims=True)
    clipped = np.clip(ecg_array, mean - 5 * std, mean + 5 * std)
    ecg_array = np.where(std > 1e-10, clipped, ecg_array)

    # 5. Final safety net for any remaining invalid values.
    ecg_array = np.nan_to_num(ecg_array, nan=0.0, posinf=0.0, neginf=0.0)

    print("✅ ECG preprocessing complete")
    return ecg_array

# -------------------- 2. LABEL CREATION --------------------
def create_stroke_labels(df):
    """Add a binary ``stroke_label`` column derived from ``hosp_diag_hosp``.

    A row is labelled 1 when its diagnosis list contains at least one string
    code matching ``^I6[0-9]`` (codes starting I60–I69), otherwise 0. String
    cells are parsed as Python literals. Cells that are not lists, or do not
    parse to a list, are labelled 0.

    ``df`` is modified in place and also returned.
    """
    import ast

    pattern = re.compile(r"^I6[0-9]")

    def has_stroke(icd_list):
        if isinstance(icd_list, str):
            try:
                # literal_eval accepts only literals, so file content can never
                # execute code (the previous eval() could).
                icd_list = ast.literal_eval(icd_list)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                return 0
        if isinstance(icd_list, list):
            return int(any(pattern.match(code) for code in icd_list if isinstance(code, str)))
        return 0

    df["stroke_label"] = df["hosp_diag_hosp"].apply(has_stroke)
    return df

# -------------------- 3. DATASET AND FEATURE EXTRACTION --------------------
class ECGDataset(Dataset):
    """Map-style dataset yielding ``(sample, label)`` as float32 tensors.

    ``X`` and ``y`` are stored as given and indexed lazily per item.
    """

    def __init__(self, X, y):
        self.X, self.y = X, y

    def __len__(self):
        # Length follows the sample container.
        return len(self.X)

    def __getitem__(self, idx):
        sample = torch.tensor(self.X[idx], dtype=torch.float32)
        target = torch.tensor(self.y[idx], dtype=torch.float32)
        return sample, target

def extract_ecg_features(ecg_array, sampling_rate=250, peak_distance=50):
    """Compute a fixed hand-crafted feature vector for every ECG signal.

    Features per signal:
    - time domain: mean, std, min, max, peak-to-peak, RMS;
    - heart rate: a bpm estimate from naive peak detection and the std of
      the peak spacing;
    - FFT magnitude: summed over two normalized-frequency bands, their
      ratio, and the dominant normalized frequency;
    - wavelets: the energy of each sub-band of a level-5 db4 decomposition.

    Args:
        ecg_array: iterable of 1-D signals.
        sampling_rate: samples per second; used only to convert mean peak
            spacing to beats per minute (default 250, as before).
        peak_distance: minimum spacing, in samples, between detected peaks
            (default 50, as before).

    Returns:
        2-D array ``(n_signals, n_features)``, standardized column-wise by a
        StandardScaler fitted on all rows passed in.
    """
    # Hoisted: previously re-imported for every signal inside the loop.
    from scipy.signal import find_peaks

    print("Extracting ECG features...")
    features = []

    for ecg in tqdm(ecg_array):
        # Time-domain statistics
        mean = np.mean(ecg)
        std = np.std(ecg)
        min_val = np.min(ecg)
        max_val = np.max(ecg)
        p2p = max_val - min_val
        rms = np.sqrt(np.mean(np.square(ecg)))

        # Naive peak picking (local maxima at least peak_distance samples
        # apart) as a rough stand-in for R-peak detection.
        peaks, _ = find_peaks(ecg, distance=peak_distance)

        if len(peaks) > 1:
            rr_intervals = np.diff(peaks)  # spacing in samples
            hr_feature = 60 / (np.mean(rr_intervals) / sampling_rate)  # beats per minute
            hr_variability = np.std(rr_intervals) if len(rr_intervals) > 1 else 0
        else:
            hr_feature = 0
            hr_variability = 0

        # Frequency-domain features. rfftfreq(n) uses d=1, so frequencies are
        # in cycles/sample: despite the lf/hf names, these bands are NOT the
        # 0.04-0.15 / 0.15-0.4 Hz HRV bands.
        fft_vals = np.abs(np.fft.rfft(ecg))
        fft_freq = np.fft.rfftfreq(len(ecg))

        lf_power = np.sum(fft_vals[(fft_freq >= 0.04) & (fft_freq < 0.15)])
        hf_power = np.sum(fft_vals[(fft_freq >= 0.15) & (fft_freq < 0.4)])
        lf_hf_ratio = lf_power / hf_power if hf_power > 0 else 0

        dominant_freq = fft_freq[np.argmax(fft_vals)]

        # Energy in each wavelet sub-band (approximation + 5 detail levels)
        coeffs = pywt.wavedec(ecg, 'db4', level=5)
        wavelet_energy = [np.sum(np.square(c)) for c in coeffs]

        features.append(np.array([
            mean, std, min_val, max_val, p2p, rms,
            hr_feature, hr_variability,
            lf_power, hf_power, lf_hf_ratio, dominant_freq,
            *wavelet_energy
        ]))

    features = np.array(features)

    # NOTE(review): the scaler is fitted on every row passed in. If this is
    # called on the full dataset before the CV split, test-fold statistics
    # leak into the training features.
    scaler = StandardScaler()
    features = scaler.fit_transform(features)
    print(f"✅ Extracted {features.shape[1]} features for each ECG")

    return features

# -------------------- 4. IMPROVED MODEL ARCHITECTURE --------------------
class ResidualBlock(nn.Module):
    """Two-convolution 1-D residual block.

    Main path: conv(k=15, given stride) → BN → ReLU → conv(k=15) → BN.
    The skip path is the identity, unless the stride or channel count
    changes; then it is a stride-matched 1x1 conv followed by BN. The sum of
    the two paths goes through a final ReLU.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=15, stride=stride, padding=7, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=15, padding=7, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)

        needs_projection = stride != 1 or in_channels != out_channels
        # An empty Sequential passes its input through unchanged (identity skip).
        self.shortcut = (
            nn.Sequential(
                nn.Conv1d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm1d(out_channels),
            )
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        skip = self.shortcut(x)
        branch = self.relu(self.bn1(self.conv1(x)))
        branch = self.bn2(self.conv2(branch))
        return self.relu(branch + skip)

class ECGResNet(nn.Module):
    """1-D ResNet mapping a single-channel signal to ``num_classes`` logits.

    The network has three parts:
    - stem: a stride-2 conv and a stride-2 max-pool;
    - body: four stages of two ResidualBlocks each
      (64→64→128→256→512 channels; stages 2-4 use stride 2);
    - head: global average pooling followed by a small MLP.

    Args:
        input_length: accepted for API compatibility but not used; the
            adaptive pooling makes the network length-agnostic.
        num_classes: width of the final linear layer.
    """

    def __init__(self, input_length=3000, num_classes=1):
        super(ECGResNet, self).__init__()

        # Stem
        self.conv1 = nn.Conv1d(1, 64, kernel_size=15, stride=2, padding=7, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # Residual stages, registered as layer1..layer4 in order.
        stage_specs = ((64, 64, 1), (64, 128, 2), (128, 256, 2), (256, 512, 2))
        for index, (c_in, c_out, step) in enumerate(stage_specs, start=1):
            setattr(self, f"layer{index}", self._make_layer(c_in, c_out, 2, stride=step))

        # Head
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Sequential(
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
        )

        # He init for convolutions; BatchNorm starts as the identity affine map.
        for module in self.modules():
            if isinstance(module, nn.Conv1d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm1d):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

    def _make_layer(self, in_channels, out_channels, blocks, stride=1):
        # Only the first block changes width/stride; the rest keep out_channels.
        head = ResidualBlock(in_channels, out_channels, stride)
        tail = [ResidualBlock(out_channels, out_channels) for _ in range(blocks - 1)]
        return nn.Sequential(head, *tail)

    def forward(self, x):
        # Accept (batch, length) by adding a channel axis.
        if x.dim() == 2:
            x = x.unsqueeze(1)

        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)

        pooled = self.avgpool(x)
        return self.fc(pooled.view(pooled.size(0), -1))

class HybridModel(nn.Module):
    """Late-fusion classifier over a raw signal and a hand-crafted feature vector.

    The signal is embedded to 64 values by an ECGResNet. The features pass
    through a 32-unit MLP. The concatenated 96-value vector feeds an MLP head
    that emits ``num_classes`` logits.

    Args:
        input_length: forwarded to ECGResNet.
        feature_dim: length of the hand-crafted feature vector.
        num_classes: number of output logits.
    """

    def __init__(self, input_length=3000, feature_dim=20, num_classes=1):
        super(HybridModel, self).__init__()
        signal_width, feature_width = 64, 32

        # Raw-signal pathway: ResNet whose "classes" serve as embedding units.
        self.signal_model = ECGResNet(input_length, signal_width)

        # Hand-crafted feature pathway.
        self.feature_layers = nn.Sequential(
            nn.Linear(feature_dim, feature_width),
            nn.BatchNorm1d(feature_width),
            nn.ReLU(),
            nn.Dropout(0.3),
        )

        # Fusion head.
        self.combined_layers = nn.Sequential(
            nn.Linear(signal_width + feature_width, 64),
            nn.BatchNorm1d(64),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(64, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.Linear(32, num_classes),
        )

    def forward(self, signal, features):
        fused = torch.cat((self.signal_model(signal), self.feature_layers(features)), dim=1)
        return self.combined_layers(fused)

# -------------------- 5. LOSS + TRAINING --------------------
class FocalLoss(nn.Module):
    """Binary focal loss on logits, with optional positive-class weighting.

    ``loss = mean((1 - p_t + 1e-7) ** gamma * BCE_w)``, where:
    - ``BCE_w`` is ``binary_cross_entropy_with_logits`` with ``pos_weight``
      applied;
    - ``p_t`` is the predicted probability of the true class, derived from
      the *unweighted* BCE.

    Args:
        gamma: focusing exponent; 0 reduces to (weighted) BCE.
        pos_weight: optional tensor forwarded to the BCE to up-weight positives.
    """

    def __init__(self, gamma=2.0, pos_weight=None):
        super().__init__()
        self.gamma = gamma
        self.pos_weight = pos_weight

    def forward(self, inputs, targets):
        # Clip logits to prevent extreme values.
        inputs = torch.clamp(inputs, -50, 50)

        bce = nn.functional.binary_cross_entropy_with_logits(
            inputs, targets,
            pos_weight=self.pos_weight,
            reduction='none'
        )

        # p_t must come from the unweighted BCE: exp(-pos_weight * CE) equals
        # p ** pos_weight for positives, not the true-class probability.
        if self.pos_weight is None:
            unweighted = bce
        else:
            unweighted = nn.functional.binary_cross_entropy_with_logits(
                inputs, targets, reduction='none'
            )
        pt = torch.exp(-torch.clamp(unweighted, max=50))

        # The small epsilon keeps the base strictly positive; presumably this
        # avoids infinite gradients of x ** gamma at 0 when gamma < 1.
        focal_term = (1 - pt + 1e-7) ** self.gamma

        loss = focal_term * bce

        # Replace NaNs with zero so a single bad element cannot poison the mean.
        loss = torch.where(torch.isnan(loss), torch.zeros_like(loss), loss)

        return loss.mean()

def train_epoch(model, train_loader, criterion, optimizer, device, hybrid=False):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: network returning one logit per sample.
        train_loader: iterable of batches. Each batch is
            ``((signal, features), label)`` when ``hybrid`` is true, and
            ``(inputs, label)`` otherwise.
        criterion: loss applied to ``(logits, labels.unsqueeze(1))``.
        optimizer: optimizer stepped once per batch, after gradients are
            clipped to a global L2 norm of 1.0.
        device: device the batches are moved to.
        hybrid: whether the model takes two inputs.

    Returns:
        ``(mean batch loss, training accuracy at a 0.5 probability threshold)``.

    Raises:
        ValueError: if ``train_loader`` yields no samples.
    """
    model.train()
    total_loss = 0.0
    n_correct = 0
    n_seen = 0

    for data in train_loader:
        if hybrid:
            (X_signal, X_features), y = data
            X_signal, X_features = X_signal.to(device), X_features.to(device)
            y = y.to(device).unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(X_signal, X_features)
        else:
            X, y = data
            X, y = X.to(device), y.to(device).unsqueeze(1)

            optimizer.zero_grad()
            outputs = model(X)

        loss = criterion(outputs, y)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        total_loss += loss.item()

        # Metrics only: no autograd graph for the sigmoid, and count matches
        # instead of keeping every prediction in memory.
        with torch.no_grad():
            preds = (torch.sigmoid(outputs.detach()) > 0.5).float()
            n_correct += (preds == y).sum().item()
            n_seen += y.numel()

    if n_seen == 0:
        raise ValueError("train_loader yielded no samples")

    return total_loss / len(train_loader), n_correct / n_seen

def evaluate(model, val_loader, criterion, device, threshold=0.5, hybrid=False):
    """Score ``model`` on ``val_loader`` without gradient tracking.

    Args:
        model: network returning one logit per sample.
        val_loader: iterable of batches. Each batch is
            ``((signal, features), label)`` when ``hybrid`` is true, and
            ``(inputs, label)`` otherwise.
        criterion: loss applied to ``(logits, labels.unsqueeze(1))``.
        device: device the batches are moved to.
        threshold: probability cut-off for the boolean predictions.
        hybrid: whether the model takes two inputs.

    Returns:
        ``(mean batch loss, predictions, labels, ROC-AUC, probabilities)``.
        The arrays are stacked per sample.
    """
    model.eval()
    pred_rows, prob_rows, label_rows = [], [], []
    loss_total = 0

    with torch.no_grad():
        for batch in val_loader:
            if hybrid:
                (signals, feats), targets = batch
                logits = model(signals.to(device), feats.to(device))
            else:
                inputs, targets = batch
                logits = model(inputs.to(device))
            targets = targets.to(device).unsqueeze(1)

            batch_loss = criterion(logits, targets)
            probabilities = torch.sigmoid(logits)

            pred_rows.extend((probabilities > threshold).cpu().numpy())
            prob_rows.extend(probabilities.cpu().numpy())
            label_rows.extend(targets.cpu().numpy())
            loss_total += batch_loss.item()

    all_preds, all_probs, all_labels = (np.array(rows) for rows in (pred_rows, prob_rows, label_rows))

    auc = roc_auc_score(all_labels, all_probs)

    return loss_total / len(val_loader), all_preds, all_labels, auc, all_probs

class HybridDataset(Dataset):
    """Dataset yielding ``((signal, features), label)`` float32 tensors per index.

    The arrays are stored as given. The length is taken from ``y``.
    """

    def __init__(self, X_signal, X_features, y):
        self.X_signal = X_signal
        self.X_features = X_features
        self.y = y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, idx):
        signal, features, label = (
            torch.tensor(source[idx], dtype=torch.float32)
            for source in (self.X_signal, self.X_features, self.y)
        )
        return (signal, features), label

def find_best_threshold(probs, labels):
    """Return the precision-recall-curve threshold that maximises F1.

    Candidates are the thresholds reported by ``precision_recall_curve``.
    Points where precision + recall == 0 are skipped. On ties, the first
    candidate in ``thresholds`` order wins. Falls back to 0.5 when no
    candidate has a defined F1.
    """
    precisions, recalls, thresholds = precision_recall_curve(labels, probs)

    # precision_recall_curve returns one more precision/recall value than
    # thresholds (the final point has no threshold), so align by truncating.
    n_thresholds = len(thresholds)
    prec = precisions[:n_thresholds]
    rec = recalls[:n_thresholds]
    usable = (prec + rec) > 0

    if not usable.any():
        print("Warning: No valid F1 scores calculated. Using default threshold of 0.5")
        return 0.5

    prec, rec, candidates = prec[usable], rec[usable], thresholds[usable]
    f1 = 2 * prec * rec / (prec + rec)
    best = int(np.argmax(f1))  # first maximum, like max() over a list

    print(f"✅ Best threshold: {candidates[best]:.4f} with F1 score: {f1[best]:.4f}")
    print(f"✅ At this threshold - Precision: {prec[best]:.4f}, Recall: {rec[best]:.4f}")

    return candidates[best]

def plot_results(val_probs, val_labels, model_name, fold):
    """Save side-by-side ROC and precision-recall curves for one fold.

    The figure is written to ``{model_name}_fold{fold}_curves.png``. It is
    always closed, even when a metric computation or the save fails, so
    repeated folds do not accumulate open figures.

    Args:
        val_probs: predicted positive-class probabilities.
        val_labels: matching 0/1 ground-truth labels.
        model_name: prefix for the output file name.
        fold: fold number used in the output file name.
    """
    from sklearn.metrics import roc_curve

    fig, (ax_roc, ax_pr) = plt.subplots(1, 2, figsize=(12, 5))
    try:
        # ROC curve with the chance diagonal
        fpr, tpr, _ = roc_curve(val_labels, val_probs)
        ax_roc.plot(fpr, tpr)
        ax_roc.plot([0, 1], [0, 1], 'k--')
        ax_roc.set_xlabel('False Positive Rate')
        ax_roc.set_ylabel('True Positive Rate')
        ax_roc.set_title(f'ROC Curve (AUC = {roc_auc_score(val_labels, val_probs):.4f})')

        # Precision-recall curve
        precision, recall, _ = precision_recall_curve(val_labels, val_probs)
        ax_pr.plot(recall, precision)
        ax_pr.set_xlabel('Recall')
        ax_pr.set_ylabel('Precision')
        ax_pr.set_title(f'PR Curve (AP = {average_precision_score(val_labels, val_probs):.4f})')

        fig.tight_layout()
        fig.savefig(f'{model_name}_fold{fold}_curves.png')
    finally:
        plt.close(fig)

# -------------------- 6. MAIN TRAINING FUNCTION --------------------
def run_training(X_signal, X_features, y, use_hybrid=True, n_splits=5, epochs=30, batch_size=64):
    """Stratified k-fold training of the hybrid (signal + features) or ResNet model.

    For each fold:
    - the training split is rebalanced by SMOTE (``sampling_strategy=0.5``)
      followed by random undersampling (``sampling_strategy=0.8``); the test
      split is left untouched;
    - the weights with the best test-fold ROC-AUC are kept, with early
      stopping after 5 epochs without improvement;
    - an F1-optimal threshold is computed from those weights' test-fold
      probabilities;
    - the model is saved to ``{model_name}_fold{k}.pt``.

    Args:
        X_signal: 2-D array of signals, one row per sample.
        X_features: 2-D hand-crafted features aligned with ``X_signal``, or None.
        y: 1-D array of 0/1 labels.
        use_hybrid: use HybridModel when features are available. If
            ``X_features`` is None, ECGResNet is used instead.
        n_splits: number of stratified folds.
        epochs: maximum training epochs per fold (must be >= 1).
        batch_size: mini-batch size.

    Returns:
        List of per-fold dicts with keys ``fold``, ``auc``, ``threshold``,
        ``report`` and ``model``.

    Raises:
        ValueError: if ``epochs < 1``.
    """
    if epochs < 1:
        raise ValueError("epochs must be >= 1")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Decide the pipeline once. Before, use_hybrid=True with X_features=None
    # built plain ECGDatasets but still unpacked batches as hybrid, and
    # saved 'hybrid': True in the checkpoint.
    hybrid = bool(use_hybrid and X_features is not None)

    results = []
    sampling_strategy = Pipeline([
        ('oversample', SMOTE(sampling_strategy=0.5, random_state=42)),
        ('undersample', RandomUnderSampler(sampling_strategy=0.8, random_state=42))
    ])

    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)

    for fold, (train_idx, test_idx) in enumerate(skf.split(X_signal, y)):
        print(f"\n{'='*50}")
        print(f"🔄 FOLD {fold+1}/{n_splits}")
        print(f"{'='*50}")

        X_train_signal, X_test_signal = X_signal[train_idx], X_signal[test_idx]
        y_train, y_test = y[train_idx], y[test_idx]
        print(f"Original training set — Class balance: {Counter(y_train)}")

        if hybrid:
            X_train_features, X_test_features = X_features[train_idx], X_features[test_idx]

            # Resample signal and features jointly so rows stay paired, then split back.
            n_signal = X_train_signal.shape[1]
            combined_train = np.hstack([X_train_signal, X_train_features])
            combined_train, y_train = sampling_strategy.fit_resample(combined_train, y_train)
            X_train_signal = combined_train[:, :n_signal]
            X_train_features = combined_train[:, n_signal:]
            print(f"✅ Resampled training set — Class balance: {Counter(y_train)}")

            train_dataset = HybridDataset(X_train_signal, X_train_features, y_train)
            test_dataset = HybridDataset(X_test_signal, X_test_features, y_test)

            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            test_loader = DataLoader(test_dataset, batch_size=batch_size)

            model = HybridModel(n_signal, X_train_features.shape[1]).to(device)
            model_name = "hybrid_model"
        else:
            X_train_signal, y_train = sampling_strategy.fit_resample(X_train_signal, y_train)
            print(f"✅ Resampled training set — Class balance: {Counter(y_train)}")

            train_dataset = ECGDataset(X_train_signal, y_train)
            test_dataset = ECGDataset(X_test_signal, y_test)

            train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
            test_loader = DataLoader(test_dataset, batch_size=batch_size)

            model = ECGResNet(X_train_signal.shape[1]).to(device)
            model_name = "resnet_model"

        # Class weighting for the loss; guard against a fold with no positives.
        pos_count = np.sum(y_train)
        neg_count = len(y_train) - pos_count
        pos_weight = torch.tensor([neg_count / pos_count if pos_count > 0 else 1.0]).to(device)

        criterion = FocalLoss(gamma=2.0, pos_weight=pos_weight)
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-5)
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=10, T_mult=1, eta_min=1e-6)

        best_auc = 0
        best_model = None
        best_val_probs = None
        best_val_labels = None
        patience = 5
        patience_counter = 0

        for epoch in range(epochs):
            train_loss, train_acc = train_epoch(
                model, train_loader, criterion, optimizer, device, hybrid=hybrid
            )

            val_loss, val_preds, val_labels, val_auc, val_probs = evaluate(
                model, test_loader, criterion, device, threshold=0.5, hybrid=hybrid
            )

            scheduler.step()

            print(f"Epoch {epoch+1}/{epochs} | "
                  f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | "
                  f"Val Loss: {val_loss:.4f} | Val AUC: {val_auc:.4f}")

            if val_auc > best_auc:
                best_auc = val_auc
                best_model = copy.deepcopy(model)
                best_val_probs = val_probs
                best_val_labels = val_labels
                patience_counter = 0
                print(f"✅ New best model saved! AUC: {val_auc:.4f}")
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    print(f"Early stopping at epoch {epoch+1}")
                    break

        if best_model is None:
            # No epoch beat an AUC of 0 (e.g. every AUC was NaN): keep the
            # final weights instead of crashing on a None model below.
            best_model = copy.deepcopy(model)
            best_val_probs, best_val_labels = val_probs, val_labels

        plot_results(best_val_probs, best_val_labels, model_name, fold+1)

        # NOTE(review): the threshold is tuned on this fold's test set and then
        # used to score that same test set, so the reported precision, recall
        # and F1 are optimistically biased.
        best_threshold = find_best_threshold(np.array(best_val_probs), np.array(best_val_labels))

        _, final_preds, final_labels, final_auc, _ = evaluate(
            best_model, test_loader, criterion, device, threshold=best_threshold, hybrid=hybrid
        )

        print("\n📊 Final Evaluation Results:")
        print(f"→ ROC-AUC: {final_auc:.4f}")
        print(f"→ Predicted Positives: {int(final_preds.sum())} / {len(final_preds)}")
        print(f"→ Actual Positives: {int(final_labels.sum())} / {len(final_labels)}")

        report = classification_report(
            final_labels, final_preds,
            target_names=['Non-Stroke', 'Stroke'],
            output_dict=True,
            zero_division=0
        )

        results.append({
            'fold': fold+1,
            'auc': final_auc,
            'threshold': best_threshold,
            'report': report,
            'model': best_model
        })

        torch.save({
            'model_state_dict': best_model.state_dict(),
            # Plain Python floats: NumPy scalars may be refused by
            # torch.load(weights_only=True), the default in recent PyTorch.
            'threshold': float(best_threshold),
            'auc': float(final_auc),
            'hybrid': hybrid,
        }, f'{model_name}_fold{fold+1}.pt')
        print(f"✅ Model saved to {model_name}_fold{fold+1}.pt")

    return results

# -------------------- 7. PREDICTION WITH ENSEMBLE (CONTINUED) --------------------
def predict_with_ensemble(X_signal, X_features=None, model_paths=None, threshold=None):
    """
    Make predictions using an ensemble of saved models.

    Each checkpoint is loaded and rebuilt as ``HybridModel`` (when its saved
    ``'hybrid'`` flag is true) or ``ECGResNet``. It is run over ``X_signal``,
    and the per-model sigmoid probabilities are averaged.

    Args:
        X_signal: ECG signal data; ``X_signal.shape[1]`` is passed to the model
            constructors (presumably the signal length -- confirm).
        X_features: Optional extracted features, required by hybrid checkpoints.
            Hybrid checkpoints are skipped when this is None.
        model_paths: List of paths to model checkpoints. If empty or None, every
            ``*.pt`` file in the current working directory is used.
        threshold: Optional prediction threshold. If None, the mean of the
            ``'threshold'`` values saved in the checkpoints that produced
            predictions is used (0.5 for a checkpoint without one).

    Returns:
        Tuple ``(avg_probs, binary_preds)``. ``avg_probs`` is a 1-D array of
        averaged probabilities. ``binary_preds`` is a 1-D int array holding
        ``avg_probs > threshold``.

    Raises:
        ValueError: if no checkpoints are found, or none of them produced
            predictions (load failures are printed and skipped).
    """
    if not model_paths:
        # Default: every checkpoint in the working directory (hybrid and ResNet alike)
        model_paths = [f for f in os.listdir('.') if f.endswith('.pt')]
        if not model_paths:
            raise ValueError("No model checkpoints found!")
        print(f"Found {len(model_paths)} model checkpoints")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    all_probs = []         # one 1-D probability array per contributing model
    saved_thresholds = []  # threshold stored with each contributing model

    for model_path in model_paths:
        try:
            # NOTE(review): torch>=2.6 defaults torch.load to weights_only=True, which
            # rejects numpy scalars. If 'threshold'/'auc' were saved as numpy floats,
            # loading fails here and the checkpoint is skipped -- confirm.
            checkpoint = torch.load(model_path, map_location=device)
            is_hybrid = checkpoint.get('hybrid', False)

            if is_hybrid and X_features is None:
                print(f"Warning: {model_path} is a hybrid model but no features provided. Skipping.")
                continue

            # Threshold tuned for this checkpoint during training (0.5 if absent)
            model_threshold = float(checkpoint.get('threshold', 0.5))

            # Rebuild the architecture the checkpoint was trained with
            if is_hybrid:
                model = HybridModel(X_signal.shape[1], X_features.shape[1]).to(device)
            else:
                model = ECGResNet(X_signal.shape[1]).to(device)

            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()

            # The datasets require labels; zeros are placeholders and are ignored below
            if is_hybrid:
                dataset = HybridDataset(X_signal, X_features, np.zeros(len(X_signal)))
            else:
                dataset = ECGDataset(X_signal, np.zeros(len(X_signal)))

            loader = DataLoader(dataset, batch_size=64)

            probs = []
            with torch.no_grad():
                for data in loader:
                    if is_hybrid:
                        (X_sig, X_feat), _ = data
                        outputs = model(X_sig.to(device), X_feat.to(device))
                    else:
                        X, _ = data
                        outputs = model(X.to(device))
                    probs.extend(torch.sigmoid(outputs).cpu().numpy())

            # Flatten so (N, 1) and (N,) model outputs average together
            all_probs.append(np.array(probs).reshape(-1))
            saved_thresholds.append(model_threshold)
            print(f"✅ Generated predictions with model: {model_path}")

        except Exception as e:
            # Best-effort: one broken/incompatible checkpoint should not abort the ensemble
            print(f"Error loading model {model_path}: {e}")
            continue

    if not all_probs:
        raise ValueError("No valid predictions were generated from any model!")

    avg_probs = np.mean(np.stack(all_probs), axis=0)

    if threshold is not None:
        final_threshold = threshold
    else:
        # Honour the documented contract: use the thresholds saved with the models
        final_threshold = float(np.mean(saved_thresholds))
    binary_preds = (avg_probs > final_threshold).astype(int)

    print(f"✅ Ensemble predictions complete: {int(binary_preds.sum())} positives out of {len(binary_preds)}")
    return avg_probs, binary_preds

# -------------------- 8. VISUALIZATION UTILITIES --------------------
def visualize_ecg_with_prediction(ecg_signal, prediction, probability, true_label=None, idx=0):
    """
    Draw one ECG trace in a new figure, annotated with the model's prediction.

    Args:
        ecg_signal: Sequence of samples to plot (presumably a single 1-D lead -- confirm).
        prediction: Truthy for a predicted stroke; the plot is then shaded red.
        probability: Predicted stroke probability, shown in the title.
        true_label: Optional ground-truth label, appended to the title when given.
        idx: Identifier shown in the title.

    Returns:
        The newly created matplotlib figure. It is left open; the caller
        saves and closes it.
    """
    def _describe(flag):
        return 'Stroke' if flag else 'Non-Stroke'

    figure = plt.figure(figsize=(12, 4))
    plt.plot(ecg_signal)

    title_parts = [f"ECG #{idx} - Prediction: {_describe(prediction)} (Prob: {probability:.4f})"]
    if true_label is not None:
        title_parts.append(f" - True: {_describe(true_label)}")
    plt.title("".join(title_parts))

    if prediction:
        # Shade the whole amplitude range of the trace to flag a positive call
        lowest, highest = min(ecg_signal), max(ecg_signal)
        plt.axhspan(lowest, highest, alpha=0.2, color='red')

    plt.xlabel("Time")
    plt.ylabel("Amplitude")
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    return figure

def visualize_feature_importance(model, feature_names):
    """
    Plot a weight-magnitude importance score for each hand-crafted ECG feature.

    The score for input feature ``j`` is the mean absolute weight in column ``j``
    of ``model.feature_layers[0].weight``. This is averaged over the layer's
    outputs, so it is a rough heuristic rather than a true attribution method.

    Args:
        model: Model exposing ``feature_layers``. ``feature_layers[0]`` is assumed
            to be an ``nn.Linear`` (weight shape ``(out_features, in_features)``)
            -- confirm against the HybridModel definition.
        feature_names: Names for the input features. If fewer names than inputs
            are given (or None), the remaining inputs are labelled
            ``Feature_<i+1>``.

    Returns:
        The matplotlib figure (also saved to 'feature_importance.png'), or None
        when the model has no ``feature_layers``.
    """
    if not hasattr(model, 'feature_layers'):
        print("Feature importance visualization only works with hybrid models")
        return None

    # Weights of the first layer applied to the feature vector
    weights = model.feature_layers[0].weight.detach().cpu().numpy()

    # Mean |weight| per input column (axis 0 runs over the layer's outputs)
    importance = np.mean(np.abs(weights), axis=0)

    # Make sure every input column has a label, so indexing below cannot fail
    names = [] if feature_names is None else list(feature_names)
    names += [f"Feature_{i+1}" for i in range(len(names), len(importance))]

    # Ascending order, so the most important feature ends up at the top of barh
    sorted_idx = np.argsort(importance)

    fig = plt.figure(figsize=(10, 8))
    plt.barh(range(len(sorted_idx)), importance[sorted_idx])
    plt.yticks(range(len(sorted_idx)), [names[i] for i in sorted_idx])
    plt.xlabel('Feature Importance')
    plt.title('ECG Feature Importance for Stroke Prediction')
    plt.tight_layout()

    fig.savefig('feature_importance.png')
    print("✅ Feature importance plot saved as 'feature_importance.png'")
    return fig

# -------------------- 9. EXECUTION --------------------
if __name__ == "__main__":
    print("🚀 Starting Enhanced ECG Stroke Prediction System")

    # 1. Load raw data
    ecg_data = load_ecg_data()
    diagnosis_df = load_diagnosis_data("records_w_diag_icd10 (1).csv", len(ecg_data))
    diagnosis_df = create_stroke_labels(diagnosis_df)
    stroke_labels = diagnosis_df["stroke_label"].values

    # 2. Print class distribution
    print("\nClass Distribution:")
    class_counts = diagnosis_df["stroke_label"].value_counts()
    print(class_counts)
    # .get avoids a KeyError when no positive (label 1) rows are present
    print(f"Stroke prevalence: {class_counts.get(1, 0)/len(diagnosis_df):.4%}")

    # 3. Preprocess ECG data
    ecg_data_processed = preprocess_ecg(ecg_data, wavelet_transform=True)

    # 4. Extract features
    ecg_features = extract_ecg_features(ecg_data_processed)

    # 5. Run cross-validation with the hybrid model
    print("\n🔄 Starting cross-validation training with hybrid model...")
    hybrid_results = run_training(
        ecg_data_processed,
        ecg_features,
        stroke_labels,
        use_hybrid=True,
        n_splits=5,
        epochs=10,
        batch_size=64
    )

    # 6. Run cross-validation with the ResNet model only for comparison
    print("\n🔄 Starting cross-validation training with ResNet model only...")
    resnet_results = run_training(
        ecg_data_processed,
        None,  # No features
        stroke_labels,
        use_hybrid=False,
        n_splits=5,
        epochs=10,
        batch_size=64
    )

    # 7. Compare results (identical report for both model families)
    print("\n📊 Comparing Model Performance:")
    mean_aucs = {}
    for family, fold_results in (("Hybrid", hybrid_results), ("ResNet", resnet_results)):
        print(f"\n{family} Model Results:")
        for res in fold_results:
            print(f"Fold {res['fold']} - AUC: {res['auc']:.4f}, Best Threshold: {res['threshold']:.4f}")
            # classification_report is keyed by target_names; fall back to the raw label '1'
            stroke_stats = res['report'].get('Stroke', res['report'].get('1'))
            if stroke_stats is not None:
                print(f"Stroke Precision: {stroke_stats['precision']:.4f}")
                print(f"Stroke Recall: {stroke_stats['recall']:.4f}")
                print(f"Stroke F1 Score: {stroke_stats['f1-score']:.4f}")
        mean_aucs[family] = np.mean([res['auc'] for res in fold_results])
        print(f"\n{family} Model Average AUC: {mean_aucs[family]:.4f}")

    # 8. Create and save final ensemble model prediction
    # NOTE(review): these prefixes assume run_training names checkpoints
    # 'hybrid_model_fold*.pt' / 'resnet_model_fold*.pt' -- confirm. Stale checkpoints
    # from earlier runs in the working directory are picked up too.
    checkpoint_files = sorted(f for f in os.listdir('.') if f.endswith('.pt'))
    hybrid_models = [f for f in checkpoint_files if f.startswith('hybrid_model')]
    resnet_models = [f for f in checkpoint_files if f.startswith('resnet_model')]

    # Use the model family with better mean CV AUC for final predictions
    if mean_aucs["Hybrid"] > mean_aucs["ResNet"]:
        best_model_type = "hybrid"
        selected_models = hybrid_models
    else:
        best_model_type = "resnet"
        selected_models = resnet_models

    # An empty list would make predict_with_ensemble silently fall back to *every* .pt file
    if not selected_models:
        raise FileNotFoundError(f"No {best_model_type} model checkpoints found in the working directory")

    # NOTE(review): the ensemble scores the full dataset, which includes each fold's
    # training data, so metrics computed from stroke_predictions.csv are optimistic.
    if best_model_type == "hybrid":
        print("\n🔍 Using Hybrid models for final ensemble prediction")
        ensemble_probs, ensemble_preds = predict_with_ensemble(
            ecg_data_processed,
            X_features=ecg_features,
            model_paths=selected_models
        )
    else:
        print("\n🔍 Using ResNet models for final ensemble prediction")
        ensemble_probs, ensemble_preds = predict_with_ensemble(
            ecg_data_processed,
            model_paths=selected_models
        )

    # 9. Save final predictions to CSV
    predictions_df = pd.DataFrame({
        'ecg_id': range(len(ecg_data_processed)),
        'stroke_probability': ensemble_probs,
        'stroke_prediction': ensemble_preds
    })

    if 'stroke_label' in diagnosis_df.columns:
        predictions_df['true_label'] = stroke_labels

    predictions_df.to_csv('stroke_predictions.csv', index=False)
    print("✅ Final predictions saved to 'stroke_predictions.csv'")

    # 10. Visualize up to 3 examples each of TP, FP, TN and FN
    print("\n🖼️ Generating example visualizations...")
    if 'true_label' in predictions_df.columns:
        example_groups = (("tp", 1, 1), ("fp", 1, 0), ("tn", 0, 0), ("fn", 0, 1))
        for tag, predicted, actual in example_groups:
            group_mask = ((predictions_df['stroke_prediction'] == predicted) &
                          (predictions_df['true_label'] == actual))
            for i, idx in enumerate(predictions_df[group_mask].index[:3]):
                fig = visualize_ecg_with_prediction(
                    ecg_data_processed[idx], predicted,
                    predictions_df.loc[idx, 'stroke_probability'],
                    true_label=actual, idx=idx
                )
                fig.savefig(f'example_{tag}_{i}.png')
                plt.close(fig)  # don't accumulate open figures across examples

    # 11. If hybrid model was better, visualize feature importance
    if best_model_type == "hybrid":
        print("\n📊 Visualizing feature importance...")

        # Define feature names
        feature_names = [
            "Mean", "Std Dev", "Min Value", "Max Value", "Peak-to-Peak",
            "RMS", "Heart Rate", "HR Variability", "LF Power", "HF Power",
            "LF/HF Ratio", "Dominant Freq", "Wavelet Energy L1", "Wavelet Energy L2",
            "Wavelet Energy L3", "Wavelet Energy L4", "Wavelet Energy L5", "Wavelet Energy L6"
        ]

        # Pad with generic names if more features were extracted
        feature_names += [f"Feature_{i+1}" for i in range(len(feature_names), ecg_features.shape[1])]

        # Saved checkpoints hold only 'model_state_dict' (no 'model' key), and this
        # block never defines `device`, so use the in-memory model of the best fold.
        best_fold = max(hybrid_results, key=lambda res: res['auc'])
        importance_fig = visualize_feature_importance(best_fold['model'], feature_names)
        if importance_fig is not None:
            plt.close(importance_fig)

    print("\n✅ ECG Stroke Prediction System Completed Successfully")

🚀 Starting Enhanced ECG Stroke Prediction System
✅ Loaded ECG data from .npy file
Shape: (86596, 3000)

🔍 Loading diagnosis data from records_w_diag_icd10 (1).csv
Aligned to 86596 rows to match ECG data

Class Distribution:
stroke_label
0    83366
1     3230
Name: count, dtype: int64
Stroke prevalence: 3.7300%
Preprocessing ECG data...
Applying wavelet denoising...
Standardizing signals...
✅ ECG preprocessing complete
Extracting ECG features...


100%|██████████| 86596/86596 [00:43<00:00, 2013.00it/s]


✅ Extracted 18 features for each ECG

🔄 Starting cross-validation training with hybrid model...
Using device: cuda

🔄 FOLD 1/5
Original training set — Class balance: Counter({np.int64(0): 66692, np.int64(1): 2584})
✅ Resampled training set — Class balance: Counter({np.int64(0): 41682, np.int64(1): 33346})
Epoch 1/10 | Train Loss: 0.1223 | Train Acc: 0.8474 | Val Loss: 0.1527 | Val AUC: 0.5416
✅ New best model saved! AUC: 0.5416
Epoch 2/10 | Train Loss: 0.1050 | Train Acc: 0.8773 | Val Loss: 0.0678 | Val AUC: 0.5440
✅ New best model saved! AUC: 0.5440
Epoch 3/10 | Train Loss: 0.0989 | Train Acc: 0.8866 | Val Loss: 0.3799 | Val AUC: 0.5407
Epoch 4/10 | Train Loss: 0.0947 | Train Acc: 0.8932 | Val Loss: 0.0651 | Val AUC: 0.5632
✅ New best model saved! AUC: 0.5632
Epoch 5/10 | Train Loss: 0.0903 | Train Acc: 0.9009 | Val Loss: 0.0991 | Val AUC: 0.5567
Epoch 6/10 | Train Loss: 0.0863 | Train Acc: 0.9059 | Val Loss: 0.0899 | Val AUC: 0.5744
✅ New best model saved! AUC: 0.5744
Epoch 7/10 | Tr