# CardioXplainAI - ECG Abnormality Detection Demo
This notebook demonstrates ECG preprocessing on the MIT-BIH Arrhythmia Database, training of the RhythmTCN-GRUAttNet model, evaluation, and attention-based explainability.

In [1]:

import os
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split

from preprocess import load_and_segment_ecg_data
from model import RhythmTCN_GRUAttNet
from train import train_model
from evaluate import evaluate_model, plot_confusion, print_summary
from explainability import extract_attention_weights, plot_ecg_with_attention


## Step 1: Load and Preprocess ECG Data

In [2]:

# Directory of the MIT-BIH Arrhythmia Database in WFDB format (per-record `.hea`
# header files etc.). Set the MITBIH_DATA_DIR environment variable to point at a
# local copy instead of hardcoding an absolute, machine-specific path.
data_dir = Path(os.environ.get("MITBIH_DATA_DIR", "data/mit-bih-arrhythmia-database-1.0.0"))
segments, labels = load_and_segment_ecg_data(str(data_dir))

# The loader reports and skips records it cannot read (see output below), so
# fail loudly here instead of training on an empty or misaligned dataset.
assert len(segments) == len(labels), "segments and labels are misaligned"
assert len(segments) > 0, f"no ECG segments extracted from {data_dir}"


 10%|█         | 5/49 [00:00<00:02, 20.84it/s]

Error processing 102-0: [Errno 2] No such file or directory: 'C:/Users/krake/Downloads/mizoram/mit-bih-arrhythmia-database-1.0.0/102-0.hea'


100%|██████████| 49/49 [00:03<00:00, 15.74it/s]


Extracted 100814 segments.


## Step 2: Prepare Dataset and DataLoader

In [3]:

SEED = 42        # fixes the train/val split and the shuffle order for reproducibility
BATCH_SIZE = 32

# Stack into one ndarray first; building a tensor from a Python sequence of
# arrays is slow. Presumably `segments` holds equal-length 1-D windows — confirm.
X_tensor = torch.tensor(np.asarray(segments), dtype=torch.float32)
y_tensor = torch.tensor(np.asarray(labels), dtype=torch.long)
dataset = TensorDataset(X_tensor, y_tensor)

# NOTE(review): this is a random segment-level split, so segments from the same
# record can appear in both train and validation, which likely inflates
# validation metrics. Consider a record-wise (inter-patient) split.
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(
    dataset, [train_size, val_size], generator=torch.Generator().manual_seed(SEED)
)
train_loader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SEED),
)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)


## Step 3: Initialize RhythmTCN-GRUAttNet Model

In [4]:

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# nn.CrossEntropyLoss (used in training) expects class indices 0..K-1, so check
# the label set is contiguous before sizing the output layer from its length.
class_ids = sorted({int(c) for c in labels})
assert class_ids == list(range(len(class_ids))), f"labels must be 0..K-1, got {class_ids}"

# NOTE(review): 360 is hard-coded; presumably the segment length in samples
# (one second at MIT-BIH's 360 Hz) — confirm it matches X_tensor.shape[-1].
model = RhythmTCN_GRUAttNet(input_size=360, num_classes=len(class_ids)).to(device)


  WeightNorm.apply(module, name, dim)


In [10]:
import torch
import torch.nn as nn
import torch.optim as optim


In [11]:
import torch
import torch.nn as nn
import torch.optim as optim

def train_model(model, train_loader, val_loader, num_epochs=10, lr=0.001):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    for epoch in range(num_epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(device), labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        avg_loss = running_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {avg_loss:.4f}")


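## Step 4: Train the Model
Note: the cell above redefines `train_model`, shadowing the version imported from `train`; the call below uses the notebook definition.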

In [12]:

NUM_EPOCHS = 10
LEARNING_RATE = 1e-3

# Bind the result so any returned training history can be inspected later
# instead of being echoed as cell output.
history = train_model(model, train_loader, val_loader, num_epochs=NUM_EPOCHS, lr=LEARNING_RATE)



KeyboardInterrupt: 

## Step 5: Evaluate the Model

In [None]:

# One display name per distinct label value, in ascending label order.
label_values = sorted(set(labels))
class_names = ["Class {}".format(label_value) for label_value in label_values]

# Score the model on the held-out loader, then summarise and plot the results.
report, conf_matrix, roc_auc = evaluate_model(model, val_loader, class_names)
print_summary(report, roc_auc)
plot_confusion(conf_matrix, class_names)


## Step 6: Visualize Explainability via Attention

In [None]:

# Explain a held-out validation segment rather than one the model may have trained on.
sample_index = 0
sample_ecg, sample_label = val_dataset[sample_index]

# Make sure the model is in eval mode (training switches it to train mode) so any
# dropout / batch-norm layers behave deterministically while reading attention.
model.eval()
att_weights = extract_attention_weights(model, sample_ecg.to(device))
plot_ecg_with_attention(
    sample_ecg.numpy(),
    att_weights,
    title=f"Attention Map on ECG Signal (label {int(sample_label)})",
)
