In [458]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import Dataset,DataLoader
from data import load_traindata
device = 'mps' if torch.backends.mps.is_available() else ('cuda' if torch.cuda.is_available() else 'cpu')
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import random_split
from torch.optim.lr_scheduler import LambdaLR

In [459]:
torch.manual_seed(42)

<torch._C.Generator at 0x110f8a470>

In [460]:
split_size = 0.8
batch_size = 32
num_subclasses = 500
learning_rate = 3e-4
epochs = 1000
test_epochs = 10
seq_size = 500
n_block = 4
channel_size = 12
num_aug = 5

In [461]:
# Encoder function (One-hot encoding)
def one_hot_encode(labels, unique_labels):
    # Create a mapping from unique labels to indices
    label_to_index = {label.item(): idx for idx, label in enumerate(unique_labels)}
    # Convert the labels to indices based on the mapping
    indices = torch.tensor([label_to_index[label.item()] for label in labels])
    # Create the one-hot encoded tensor
    return torch.eye(len(unique_labels))[indices]

# Decoder function (Converts one-hot back to original labels)
def one_hot_decode(one_hot, unique_labels):
    # Get the index of the '1' in the one-hot vector
    index = torch.argmax(one_hot)
    return torch.tensor(unique_labels[index])

In [462]:
def time_masking(ecg_data, w=0.05):
    if not (0 <= w <= 1):
        raise ValueError("w must be between 0 and 1.")
    T = ecg_data.shape[0]
    mask_length = int(w * T)
    ts = np.random.randint(0, T - mask_length + 1)
    ecg_data[ts:ts + mask_length, :] = 0
    return ecg_data

In [463]:
def time_shifting(ecg_data, w=0.08):
    if not (0 <= w <= 1):
        raise ValueError("w must be between 0 and 1.")
    T = ecg_data.shape[0]
    shift_length = int(w * T)
    # Perform a circular shift (rotation)
    ecg_data = np.roll(ecg_data, shift=shift_length, axis=0)
    return ecg_data

In [482]:
def augment(X, Y, num_masks=num_aug):
    augmented_X = []
    augmented_Y = []
    
    for i in range(len(X)):
        ecg_data = X[i].clone()  # To avoid modifying the original data
        label = Y[i]
        
        # Append the original data
        augmented_X.append(ecg_data)
        augmented_Y.append(label)
        
        # Apply rotation num_masks times and append to augmented dataset
        for _ in range(num_masks):
            augmented_data = time_shifting(ecg_data.numpy())  # Apply time masking to numpy array
            masked_data = time_masking(ecg_data.numpy())
            augmented_X.append(torch.from_numpy(augmented_data))  # Convert back to torch tensor
            augmented_Y.append(label)
            augmented_X.append(torch.from_numpy(masked_data))  # Convert back to torch tensor
            augmented_Y.append(label)
    
    # Convert augmented lists back to tensors
    augmented_X = torch.stack(augmented_X)
    augmented_Y = torch.tensor(augmented_Y)
    
    return augmented_X, augmented_Y

In [483]:
#load dataset
X,Y = load_traindata(num_subclasses)
X = np.array(X)
X = torch.from_numpy(X)
unique_labels = Y
#X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.3, random_state=42)

X_train = X[:, :int(5000*split_size), :]
X_test = X[:, int(5000*split_size):, :]
Y_train = Y
Y_test = Y
X_aug, Y_aug = augment(X_train,Y_train)

X_test = X_test.reshape(num_subclasses, int(5000 * round(1 - split_size, 2) / seq_size), seq_size, 12)
X_test = X_test.reshape(num_subclasses * int(5000 * round(1 - split_size, 2) / seq_size), seq_size, 12)
Y_test = one_hot_encode(Y_test, unique_labels)  # One-hot encode Y
Y_test = Y_test.unsqueeze(1).repeat(1, int(5000 * round(1 - split_size, 2) / seq_size), 1).view(-1, num_subclasses)


X_aug = X_aug.view(num_subclasses*(2*num_aug + 1), int(5000*split_size/seq_size),seq_size,12) #reshape after split
X_aug = X_aug.view(num_subclasses*((2*num_aug + 1))*(int(5000*split_size/seq_size)),seq_size,12)
Y_aug = one_hot_encode(Y_aug, unique_labels)  # One-hot encode Y
Y_aug = Y_aug.unsqueeze(1).repeat(1, (int(5000*split_size/seq_size)), 1).view(-1, num_subclasses)

In [484]:
assert(X_aug.shape[0]==Y_aug.shape[0])
assert(X_test.shape[0]==Y_test.shape[0])

In [485]:
class ECGDataset(Dataset):
    def __init__(self,X,Y):  
        self.X = X
        self.Y = Y

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[:,:,1][idx], self.Y[idx] #second lead
        return self.X[idx], self.Y[idx]

In [486]:
def plotWave(X):
   # Create a common time axis for all leads
    time = torch.arange(0, X.size(0))

    # Create subplots (12 rows, 1 column)
    fig, axes = plt.subplots(12, 1, figsize=(10, 24), sharex=True, sharey=True)
    fig.suptitle('12 Lead ECG Report', fontsize=16)

    # Plot each lead in a separate subplot
    for i in range(12):
        axes[i].plot(time, X[:, i].numpy(), color='b')
        axes[i].set_ylabel(f'Lead {i+1}')
        axes[i].grid(True)

    # Set common X-axis label for time
    axes[11].set_xlabel('Time')

    # Adjust layout
    plt.tight_layout(rect=[0, 0.03, 1, 0.95])
    plt.show()

In [487]:
testing_data = ECGDataset(X_test,Y_test)
dataset = ECGDataset(X_aug,Y_aug)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(testing_data, batch_size=batch_size, shuffle=True)

In [488]:
#CONV1D model for 2nd Lead
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv1d(1, 4, 8, padding=3)
        self.bn1 = nn.BatchNorm1d(4)

        self.conv2 = nn.Conv1d(4, 8, 4, padding=1)
        self.bn2 = nn.BatchNorm1d(8)

        self.conv3 = nn.Conv1d(8, 4, 1)
        self.bn3 = nn.BatchNorm1d(4)
        self.maxPool = nn.MaxPool1d(2) # (16, 45, 0)
        self.flat = nn.Flatten()
        #self.linearLayer = nn.Linear(16 * 45 * 3, 128)
        transformerEncoderLayer = nn.TransformerEncoderLayer(d_model=196, nhead=7)
        self.transformer = nn.TransformerEncoder(transformerEncoderLayer, num_layers=4)
        self.fnn1 = nn.Linear(196, 128)
        self.lreul = nn.LeakyReLU(0.2)
        self.xnorm1 = nn.LayerNorm(128)
        self.fnn2 = nn.Linear(128, num_subclasses)

    def forward(self, x):
        print(x.shape)
        x = x.unsqueeze(1)  # Add channel dimension at index 1
        x = self.conv1(x)
        x = self.bn1(x)
        x = torch.relu(x)


        x = self.conv2(x)
        x = self.bn2(x)
        x = torch.relu(x)


        x = self.conv3(x)
        x = self.bn3(x)
        x = torch.relu(x)

        x = self.maxPool(x)
        x = self.flat(x)

        x = x.unsqueeze(0)
        x = self.transformer(x)
        
        x = x.squeeze(0) 
        x = self.fnn1(x)
        x = self.lreul(x)
        x = self.xnorm1(x)
        x = self.fnn2(x)
        return x


#CONV2D model for 12 lead
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 4, 4) # (8, 96, 8)
        self.bn1 = nn.BatchNorm2d(4)
        
        self.conv2 = nn.Conv2d(4, 4, 3) # (16, 93, 5)
        self.bn2 = nn.BatchNorm2d(4)

        self.conv3 = nn.Conv2d(4, 2, 1)  # (16, 91, 3)
        self.bn3 = nn.BatchNorm2d(2)
        self.maxPool = nn.MaxPool2d(2) # (16, 45, 0)
        self.flat = nn.Flatten() 
        #self.linearLayer = nn.Linear(16 * 45 * 3, 128)
        transformerEncoderLayer = nn.TransformerEncoderLayer(d_model=246, nhead=6)
        self.transformer = nn.TransformerEncoder(transformerEncoderLayer, num_layers=4)
        self.fnn1 = nn.Linear(246, 128)
        self.lreul = nn.LeakyReLU(0.2)
        self.xnorm1 = nn.LayerNorm(128)
        self.fnn2 = nn.Linear(128, num_subclasses)

    def forward(self, x):
        x = x.unsqueeze(1)  # Add channel dimension at index 1
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.maxPool(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.maxPool(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.flat(x)
        #x = self.linearLayer(x)
        x = x.unsqueeze(0) 
        x = self.transformer(x)
        x = x.squeeze(0) 
        x = self.fnn1(x)
        x = self.lreul(x)
        x = self.xnorm1(x)
        x = self.fnn2(x)
        return x

In [592]:
model = Model()
m = model.to(device)
print(sum(p.numel() for p in m.parameters())/1e6, 'M parameters')
losses = []

1.267532 M parameters


In [594]:
def warmup_then_decay(epoch, warmup_epochs=10, total_epochs=epochs):
    if epoch < warmup_epochs:
        return epoch / warmup_epochs  # Linear warm-up
    else:
        decay_factor = 0.1 ** ((epoch - warmup_epochs) / (total_epochs - warmup_epochs))
        return decay_factor
    

def getaccuracy(loader):
    # Assuming the model and test_loader have been defined
    # model.eval() switches the model to evaluation mode
    model.eval()
    # Initialize variables to track correct predictions and total predictions
    correct = 0
    total = 0
    # Disable gradient computation during evaluation
    with torch.no_grad():
        # Loop over the test dataset
        for data, labels in loader:
            # Move data to the appropriate device (if using CUDA)
            data,labels = data.to(torch.float32), labels.to(torch.float32)
            data, labels = data.to(device), labels.to(device)
            # Get model predictions
            outputs = model(data)
            _, true_labels = torch.max(labels, 1)
            # Get the predicted class by taking the argmax (class with highest score)
            _, predicted = torch.max(outputs, 1)
            # Update the total number of samples and correct predictions
            total += labels.size(0)
            correct += (predicted == true_labels).sum().item()

    # Calculate accuracy
    accuracy = 100 * correct / total
    if(loader==dataloader):
        print(f'Accuracy on training dataset: {accuracy:.2f}%')
    else:
        print(f'Accuracy on test dataset: {accuracy:.2f}%')


In [595]:
optimizer = torch.optim.AdamW(m.parameters(), lr=learning_rate,weight_decay=1e-2)
scheduler = LambdaLR(optimizer, lr_lambda=lambda epoch: warmup_then_decay(epoch))
for epoch in range(epochs):
    if epoch % test_epochs == 0:
        print(f'Epoch {epoch}, loss: {loss.item()}')
        getaccuracy(dataloader)
        getaccuracy(test_loader)
    for x, y in dataloader:
        x,y = x.to(torch.float32), y.to(torch.float32)
        x,y = x.to(device), y.to(device)
        optimizer.zero_grad(set_to_none=True)
        out = m(x)
        loss = F.cross_entropy(out, y)
        loss.backward()
        optimizer.step()
    losses.append(loss.item())
    scheduler.step()

Epoch 0, loss: 0.28317350149154663


AssertionError: was expecting embedding dimension of 128, but got 490