In [1]:
# ---- Nested function ---- #
import os
import pandas as pd
import wfdb
import ast
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import gc
from pprint import pprint
from collections import Counter
import math
from copy import deepcopy
import random

# # ---- BWR ---- #
# import bwr
# import emd
# import pywt
# ---- Scipy ---- #
from scipy import signal
from scipy.signal import butter, lfilter, freqz, filtfilt
from scipy.fftpack import fft
from scipy.signal import find_peaks
from scipy.interpolate import interp1d


# ---- PyTorch ---- #
import torch
import torchvision
from torch import nn
from torch import optim
from torch import functional as F
from torch.utils.data import TensorDataset, DataLoader, Dataset
from torchvision.transforms import ToTensor
from torch.nn.functional import softmax
from torch.nn.parallel import DistributedDataParallel
from pytorchtools import EarlyStopping
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange, Reduce
import torchvision.ops as ops

# ---- Scikit Learn ---- #
from sklearn.preprocessing import MinMaxScaler, RobustScaler
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import Normalizer, StandardScaler
from sklearn.model_selection import KFold


# ---- Matplotlib ---- #
import matplotlib.pyplot as plt
import seaborn as sns

# ---- Summary ---- #
import pytorch_model_summary



In [2]:
x_train = np.load("/data/graduate/MI_Detection_Transformer/npy_data/x_train.npy")
y_train = np.load("/data/graduate/MI_Detection_Transformer/npy_data/y_train.npy")
x_valid = np.load("/data/graduate/MI_Detection_Transformer/npy_data/x_valid.npy")
y_valid = np.load("/data/graduate/MI_Detection_Transformer/npy_data/y_valid.npy")
x_test = np.load("/data/graduate/MI_Detection_Transformer/npy_data/x_test.npy")
y_test = np.load("/data/graduate/MI_Detection_Transformer/npy_data/y_test.npy")

In [3]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [9]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
import seaborn as sns

class ViTModel(nn.Module):
    def __init__(self, input_size, num_classes, patch_size, hidden_size, num_heads, num_layers):
        super(ViTModel, self).__init__()

        self.patch_embedding = nn.Conv1d(input_size, hidden_size, kernel_size=patch_size, stride=patch_size)
        num_patches = (500 * 10) // patch_size

        self.positional_embedding = nn.Parameter(torch.randn(1, num_patches, hidden_size))
        self.norm = nn.LayerNorm(hidden_size)
        transformer_layer = nn.TransformerEncoderLayer(d_model=hidden_size, nhead=num_heads, dim_feedforward=256, dropout=0.2, activation="gelu",norm_first=True)
        self.transformer_encoder = nn.TransformerEncoder(transformer_layer, num_layers=num_layers, norm=self.norm)

        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        # x: (batch_size, num_channels, sequence_length)

        # Patch Embedding
        x = self.patch_embedding(x)
        x = x.permute(0, 2, 1)  # (batch_size, num_patches, hidden_size)
        
        # Add positional encoding
        x += self.positional_embedding

        # Transformer Encoder
        x = self.transformer_encoder(x,)

        # Global Average Pooling
        x = x.mean(dim=1)  # (batch_size, hidden_size)

        # Classifier Head
        x = self.fc(x)  # (batch_size, num_classes)

        return x, x   # For visualization purposes, returning x twice

# Hyperparameters
input_size = 12
num_classes = 5
patch_size = 20
hidden_size = 768
num_heads = 6
num_layers = 6
learning_rate = 1e-4
num_epochs = 100

# Create ViT model
vit_model = ViTModel(input_size, num_classes, patch_size, hidden_size, num_heads, num_layers)
vit_model.to(device)
# Define loss function and optimizer
criterion = nn.MultiLabelSoftMarginLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=learning_rate)


def get_DataLoader(x, y, batch, num_workers, shuffle=False):
    x_tensor = torch.FloatTensor(x)
    y_tensor = torch.FloatTensor(y)
    dataset = TensorDataset(x_tensor, y_tensor)
    dataloader = DataLoader(dataset, batch_size=batch, num_workers=num_workers, shuffle=shuffle)
    return dataloader

batch_size = 64
num_workers = 2

train_loader = get_DataLoader(x_train, y_train, batch=batch_size, num_workers=num_workers, shuffle=False)
val_loader = get_DataLoader(x_valid, y_valid, batch=16, num_workers=num_workers, shuffle=False)
test_loader = get_DataLoader(x_test, y_test, batch=16, num_workers=num_workers, shuffle=False)



In [15]:
seed = 0
deterministic = True
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
if deterministic:
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
# Create ViT model
vit_model = ViTModel(input_size, num_classes, patch_size, hidden_size, num_heads, num_layers)
vit_model.to(device)
# Define loss function and optimizer
criterion = nn.MultiLabelSoftMarginLoss()
optimizer = optim.Adam(vit_model.parameters(), lr=learning_rate)

compiled_model =torch.compile(vit_model)
# Training loop with attention weights visualization
val_loss_list = []
best_loss = np.inf
for epoch in range(num_epochs):
    vit_model.train()
    total_loss = 0.0
    train_bar = tqdm(train_loader)
    for step, (inputs, labels) in enumerate(train_bar):
        optimizer.zero_grad()
        inputs, labels = inputs.to(device), labels.to(device)
        # Forward pass
        outputs, attention_weights = vit_model(inputs)
        
        # Compute loss
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        train_bar.desc = "Train Epoch[{}/{}] loss: {:.3f}".format(epoch+1, num_epochs, loss.mean().item())
        # # Visualization of attention weights (for the first batch in each epoch)
        # if epoch == 0 and inputs.size(0) == batch_size:
        #     # Get the attention weights for the first batch
        #     attention_weights = attention_weights[0, 0, 0].detach().numpy()
        #     # Use seaborn to create a heatmap
        #     sns.heatmap(attention_weights, cmap="viridis")
        #     plt.xlabel("To Patch")
        #     plt.ylabel("From Patch")
        #     plt.title("Attention Weights - First Attention Head, First Transformer Layer")
        #     plt.show()
    with torch.no_grad():
        vit_model.eval()
        val_loss = 0.0
        val_bar = tqdm(val_loader)
        for v_step, (val_x, val_y) in enumerate(val_bar):
            val_x, val_y = val_x.to(device), val_y.to(device)
            val_logits, at_wei = vit_model(val_x)
            loss_v = criterion(val_logits, val_y)
            val_loss_list.append(loss_v.mean().item())
            val_loss += loss_v.mean().item()

        print("Validation loss :",val_loss/len(val_loader))
        if best_loss > val_loss:
            print("Validation Loss Decrease.. Best Model, Best Loss update")
            lossv = val_loss/len(val_loader)
            best_loss = val_loss
            # torch.save(model.state_dict(), f"Saved_ViT_210model_{round(lossv,3) * 100}.pth")
            
    average_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {average_loss}")

# Save the trained model if needed
# torch.save(vit_model.state_dict(), 'vit_model.pth')

Train Epoch[1/100] loss: 0.507: 100%|████████████████████████████████| 219/219 [00:32<00:00,  6.66it/s]
100%|████████████████████████████████████████████████████████████████| 219/219 [00:03<00:00, 61.38it/s]


Validation Loss Decrease.. Best Model, Best Loss update
Validation loss : 0.5359410975621716
Epoch 1/100, Loss: 0.5294206598305811


Train Epoch[2/100] loss: 0.483: 100%|████████████████████████████████| 219/219 [00:33<00:00,  6.60it/s]
100%|████████████████████████████████████████████████████████████████| 219/219 [00:03<00:00, 59.56it/s]


Validation Loss Decrease.. Best Model, Best Loss update
Validation loss : 0.48839754410530334
Epoch 2/100, Loss: 0.4774655375850799


Train Epoch[3/100] loss: 0.482: 100%|████████████████████████████████| 219/219 [00:33<00:00,  6.46it/s]
100%|████████████████████████████████████████████████████████████████| 219/219 [00:03<00:00, 57.54it/s]


Validation Loss Decrease.. Best Model, Best Loss update
Validation loss : 0.46701368750774697
Epoch 3/100, Loss: 0.4427006707343881


Train Epoch[4/100] loss: 0.460: 100%|████████████████████████████████| 219/219 [00:38<00:00,  5.70it/s]
100%|████████████████████████████████████████████████████████████████| 219/219 [00:04<00:00, 53.95it/s]


Validation Loss Decrease.. Best Model, Best Loss update
Validation loss : 0.46277126030290505
Epoch 4/100, Loss: 0.4321748083584929


Train Epoch[5/100] loss: 0.442: 100%|████████████████████████████████| 219/219 [00:38<00:00,  5.74it/s]
100%|████████████████████████████████████████████████████████████████| 219/219 [00:03<00:00, 55.00it/s]


Validation Loss Decrease.. Best Model, Best Loss update
Validation loss : 0.4572925434264963
Epoch 5/100, Loss: 0.4274387616817265


Train Epoch[6/100] loss: 0.442: 100%|████████████████████████████████| 219/219 [00:39<00:00,  5.61it/s]
100%|████████████████████████████████████████████████████████████████| 219/219 [00:04<00:00, 50.06it/s]


Validation Loss Decrease.. Best Model, Best Loss update
Validation loss : 0.45163294652553454
Epoch 6/100, Loss: 0.42103462554004095


Train Epoch[7/100] loss: 0.425: 100%|████████████████████████████████| 219/219 [00:38<00:00,  5.62it/s]
100%|████████████████████████████████████████████████████████████████| 219/219 [00:03<00:00, 55.47it/s]


Validation Loss Decrease.. Best Model, Best Loss update
Validation loss : 0.4508969108127568
Epoch 7/100, Loss: 0.4125208631498084


Train Epoch[8/100] loss: 0.422: 100%|████████████████████████████████| 219/219 [00:37<00:00,  5.78it/s]
100%|████████████████████████████████████████████████████████████████| 219/219 [00:04<00:00, 51.12it/s]


Validation Loss Decrease.. Best Model, Best Loss update
Validation loss : 0.44787279873678126
Epoch 8/100, Loss: 0.4053074460323543


Train Epoch[9/100] loss: 0.397: 100%|████████████████████████████████| 219/219 [00:38<00:00,  5.73it/s]
100%|████████████████████████████████████████████████████████████████| 219/219 [00:03<00:00, 55.01it/s]


Epoch 9/100, Loss: 0.3991586858551252


Train Epoch[10/100] loss: 0.390: 100%|███████████████████████████████| 219/219 [00:37<00:00,  5.86it/s]
100%|████████████████████████████████████████████████████████████████| 219/219 [00:03<00:00, 55.49it/s]


Epoch 10/100, Loss: 0.3956008931137111


Train Epoch[11/100] loss: 0.391: 100%|███████████████████████████████| 219/219 [00:36<00:00,  5.93it/s]
100%|████████████████████████████████████████████████████████████████| 219/219 [00:03<00:00, 55.48it/s]


Epoch 11/100, Loss: 0.39100658356054735


Train Epoch[12/100] loss: 0.363: 100%|███████████████████████████████| 219/219 [00:38<00:00,  5.70it/s]
100%|████████████████████████████████████████████████████████████████| 219/219 [00:04<00:00, 53.19it/s]


Epoch 12/100, Loss: 0.3864682842063033


Train Epoch[13/100] loss: 0.356: 100%|███████████████████████████████| 219/219 [00:39<00:00,  5.57it/s]
100%|████████████████████████████████████████████████████████████████| 219/219 [00:04<00:00, 54.02it/s]


Epoch 13/100, Loss: 0.38249374742377296


Train Epoch[14/100] loss: 0.458:  36%|███████████▍                    | 78/219 [00:13<00:25,  5.61it/s]


KeyboardInterrupt: 