<a href="https://colab.research.google.com/github/abdulkader902017/Brain-Hemorrhage-Dataset/blob/main/ecgbert.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
import wfdb
import numpy as np
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertModel
from sklearn.metrics import classification_report, accuracy_score

# Path to the data directory
from google.colab import files

uploaded = files.upload()
#uploaded = files.upload()
!unzip /content/mit-bih-arrhythmia-database-1.0.0.zip
data_dir = '/content/mit-bih-arrhythmia-database-1.0.0'  # Update path if necessary# Mapping of beat types to classes
beat_type_to_class = {
    'N': 0,  # Normal
    'L': 4,  # Unknown
    'R': 4,  # Unknown
    'A': 1,  # SVEB
    'a': 1,  # SVEB
    'J': 1,  # SVEB
    'S': 1,  # SVEB
    'V': 2,  # VEB
    'E': 2,  # VEB
    'F': 3,  # Fusion
    '/': 4,  # Unknown (paced beat)
    'Q': 4,  # Unknown
    'f': 4,  # Fusion of paced and normal
}



Saving mit-bih-arrhythmia-database-1.0.0.zip to mit-bih-arrhythmia-database-1.0.0.zip
Archive:  /content/mit-bih-arrhythmia-database-1.0.0.zip
   creating: mit-bih-arrhythmia-database-1.0.0/
  inflating: mit-bih-arrhythmia-database-1.0.0/100.atr  
  inflating: mit-bih-arrhythmia-database-1.0.0/100.dat  
  inflating: mit-bih-arrhythmia-database-1.0.0/100.hea  
  inflating: mit-bih-arrhythmia-database-1.0.0/100.xws  
  inflating: mit-bih-arrhythmia-database-1.0.0/101.atr  
  inflating: mit-bih-arrhythmia-database-1.0.0/101.dat  
  inflating: mit-bih-arrhythmia-database-1.0.0/101.hea  
  inflating: mit-bih-arrhythmia-database-1.0.0/101.xws  
  inflating: mit-bih-arrhythmia-database-1.0.0/102-0.atr  
  inflating: mit-bih-arrhythmia-database-1.0.0/102.atr  
  inflating: mit-bih-arrhythmia-database-1.0.0/102.dat  
  inflating: mit-bih-arrhythmia-database-1.0.0/102.hea  
  inflating: mit-bih-arrhythmia-database-1.0.0/102.xws  
  inflating: mit-bih-arrhythmia-database-1.0.0/103.atr  
  inflati

In [2]:
f#rom transformers import DistilBertTokenizer, BertModel
#!pip install wfdb

Collecting wfdb
  Downloading wfdb-4.1.2-py3-none-any.whl.metadata (4.3 kB)
Downloading wfdb-4.1.2-py3-none-any.whl (159 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m160.0/160.0 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: wfdb
Successfully installed wfdb-4.1.2


In [4]:
import os
# Function to process a single record
def process_record(record_id):
    record = wfdb.rdrecord(os.path.join(data_dir, record_id))
    annotation = wfdb.rdann(os.path.join(data_dir, record_id), 'atr')

    ecg_signal = record.p_signal[:, 0]  # Access the first channel (MLII lead)
    beat_indices = annotation.sample
    beat_types = annotation.symbol

    # Map beat types to classes
    labels = [beat_type_to_class.get(bt, 4) for bt in beat_types]

    return ecg_signal, beat_indices, labels

# Process all records in the directory
all_ecg_signals = []
all_labels = []

records = [f.split('.')[0] for f in os.listdir(data_dir) if f.endswith('.dat')]

for record_id in records:
    ecg_signal, beat_indices, labels = process_record(record_id)
    all_ecg_signals.append((ecg_signal, beat_indices))
    all_labels.append(labels)
    print(f"Processed record {record_id}")

class ECGTextDataset(Dataset):
    def __init__(self, ecg_signals, labels, tokenizer, max_len=128, segment_length=100):
        self.ecg_signals = ecg_signals
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len
        self.segment_length = segment_length

    def __len__(self):
        return sum(len(lbls) for lbls in self.labels)

    def __getitem__(self, idx):
        cumulative_length = 0
        for ecg_signal, label_set in zip(self.ecg_signals, self.labels):
            if idx < cumulative_length + len(label_set):
                start_idx = ecg_signal[1][idx - cumulative_length] - 100
                end_idx = ecg_signal[1][idx - cumulative_length] + 100
                signal_segment = ecg_signal[0][start_idx:end_idx]

                if len(signal_segment) < self.segment_length:
                    signal_segment = np.pad(signal_segment, (0, self.segment_length - len(signal_segment)), 'constant')
                elif len(signal_segment) > self.segment_length:
                    signal_segment = signal_segment[:self.segment_length]

                signal_segment = torch.tensor(signal_segment, dtype=torch.float32).unsqueeze(0)
                label = label_set[idx - cumulative_length]
                break
            cumulative_length += len(label_set)

        text_description = f"Class {label}"

        encoding = self.tokenizer.encode_plus(
            text_description,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=False,
            return_attention_mask=True,
            return_tensors='pt',
        )

        input_ids = encoding['input_ids'].flatten()
        attention_mask = encoding['attention_mask'].flatten()

        return signal_segment, input_ids, attention_mask, torch.tensor(label, dtype=torch.long)

# Instantiate tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
# Create dataset
dataset = ECGTextDataset(ecg_signals=all_ecg_signals, labels=all_labels, tokenizer=tokenizer)

# Create dataloader
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

class ECGModel(nn.Module):
    def __init__(self):
        super(ECGModel, self).__init__()
        self.conv1 = nn.Conv1d(1, 32, kernel_size=5, stride=1, padding=2)
        self.conv2 = nn.Conv1d(32, 64, kernel_size=5, stride=1, padding=2)
        self.pool = nn.MaxPool1d(2, 2)
        self.fc1 = nn.Linear(64 * (200 // 4), 128)  # Adjusted based on pooling

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 64 * (200 // 4))  # Adjusted based on pooling
        x = F.relu(self.fc1(x))
        return x


Processed record 108
Processed record 102
Processed record 214
Processed record 205
Processed record 113
Processed record 203
Processed record 228
Processed record 202
Processed record 123
Processed record 231
Processed record 233
Processed record 118
Processed record 230
Processed record 115
Processed record 101
Processed record 121
Processed record 222
Processed record 111
Processed record 217
Processed record 112
Processed record 200
Processed record 117
Processed record 116
Processed record 107
Processed record 232
Processed record 124
Processed record 119
Processed record 220
Processed record 215
Processed record 104
Processed record 103
Processed record 234
Processed record 201
Processed record 213
Processed record 207
Processed record 209
Processed record 122
Processed record 210
Processed record 221
Processed record 212
Processed record 109
Processed record 208
Processed record 100
Processed record 106
Processed record 105
Processed record 114
Processed record 219
Processed rec

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]



In [5]:
class BERTEmbedder(nn.Module):
    def __init__(self):
        super(BERTEmbedder, self).__init__()
        self.bert = DistilBertModel.from_pretrained('distilbert-base-uncased')  # Use DistilBERT

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state[:, 0, :]  # Use the first token's embedding

In [6]:
class TextModel(nn.Module):
    def __init__(self):
        super(TextModel, self).__init__()
        self.bert_embedder = BERTEmbedder()  # Using the updated BERTEmbedder
        self.fc = nn.Linear(self.bert_embedder.bert.config.dim, 128) #Updated to dim

    def forward(self, input_ids, attention_mask):
        embeddings = self.bert_embedder(input_ids, attention_mask)
        output = self.fc(embeddings)
        return output

class MultimodalModel(nn.Module):
    def __init__(self, ecg_model, text_model):
        super(MultimodalModel, self).__init__()
        self.ecg_model = ecg_model
        self.text_model = text_model
        # Accessing hidden_size from text_model.bert.config
        self.fc = nn.Linear(128 + self.text_model.bert.config.hidden_size, 5)

    def forward(self, ecg_signal, input_ids, attention_mask):
        ecg_features = self.ecg_model(ecg_signal)
        text_embeddings = self.text_model(input_ids, attention_mask)

        # Pad the shorter sequence to match the longer one
        batch_size_ecg = ecg_features.shape[0]
        batch_size_text = text_embeddings.shape[0]

        if batch_size_ecg < batch_size_text:
            # Pad ECG features
            padding_size = (0, 0, 0, batch_size_text - batch_size_ecg)  # Pad along the batch dimension
            ecg_features = F.pad(ecg_features, padding_size)
        elif batch_size_text < batch_size_ecg:
            # Pad text embeddings
            padding_size = (0, 0, 0, batch_size_ecg - batch_size_text)  # Pad along the batch dimension
            text_embeddings = F.pad(text_embeddings, padding_size)

        combined_features = torch.cat((ecg_features, text_embeddings), dim=1)
        output = self.fc(combined_features)
        return output

In [7]:
def train_model(dataloader, model, criterion, optimizer, epochs=10):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for ecg_signal, input_ids, attention_mask, label in dataloader:
            # Forward pass
            output = model(ecg_signal, input_ids, attention_mask)
            loss = criterion(output, label)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(dataloader)
        print(f'Epoch {epoch+1}/{epochs}, Loss: {avg_loss}')

In [1]:
# Instantiate models
from transformers import DistilBertTokenizer, DistilBertModel

ecg_model = ECGModel()
text_model = BERTEmbedder()
fc_layer = nn.Linear(text_model.bert.config.hidden_size, 128)

model = MultimodalModel(ecg_model, text_model)

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
train_model(dataloader, model, criterion, optimizer, epochs=10)

# Evaluate the model
evaluate_model(dataloader, model)

NameError: name 'ECGModel' is not defined