PH Detection

In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader

# Directory holding one tensor file per lead plus the label tensor.
data_dir = "../Data_processing/data_aspire_PAP_1"

# Leads in the order their files are loaded (dict insertion order is kept).
lead_names_on_disk = [
    "LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVL", "LEAD_aVF",
    "LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6",
]
lead_file_paths = {lead: f"{data_dir}/{lead}.pt" for lead in lead_names_on_disk}
labels_file_path = f"{data_dir}/labels.pt"

# Load all lead tensors and labels
ecg_lead_tensors = {lead: torch.load(path) for lead, path in lead_file_paths.items()}
labels = torch.load(labels_file_path)

# Every lead must hold exactly one recording per label. Explicit exceptions
# (instead of `assert`) stay active under `python -O` and name the bad lead.
sample_count = len(labels)
for lead, tensor in ecg_lead_tensors.items():
    if len(tensor) != sample_count:
        raise ValueError(
            f"{lead} has {len(tensor)} samples but there are {sample_count} labels."
        )

# Define the dataset class
class ECGMultiLeadDatasetWithLabels(Dataset):
    """Map-style dataset pairing per-lead ECG recordings with their labels.

    Item ``idx`` is ``(lead_data, label)``. In it, ``lead_data`` maps each lead name
    to that lead's ``idx``-th recording with a new leading axis
    (``unsqueeze(0)``), and ``label`` is ``labels[idx]``.

    Args:
        ecg_leads: dict mapping lead name -> indexable collection of recordings
            (tensors here); every lead must have ``len(labels)`` entries.
        labels: indexable collection of labels, one per recording.

    Raises:
        ValueError: if ``ecg_leads`` is empty or any lead's length differs
            from ``len(labels)``.
    """

    def __init__(self, ecg_leads, labels):
        if not ecg_leads:
            raise ValueError("ecg_leads must contain at least one lead.")
        n_labels = len(labels)
        for lead, recordings in ecg_leads.items():
            if len(recordings) != n_labels:
                raise ValueError(
                    f"{lead} has {len(recordings)} samples but there are {n_labels} labels."
                )
        self.ecg_leads = ecg_leads
        self.labels = labels

    def __len__(self):
        # Lengths were validated in __init__, so the label count is authoritative.
        return len(self.labels)

    def __getitem__(self, idx):
        # Add a leading singleton axis to each lead sample; presumably the
        # channel axis the encoders expect (batch, 1, length) — TODO confirm.
        lead_data = {lead: recordings[idx].unsqueeze(0) for lead, recordings in self.ecg_leads.items()}
        return lead_data, self.labels[idx]

# Wrap the loaded lead tensors and labels in a map-style Dataset.
# No DataLoader is built; later cells index single samples directly.
dataset = ECGMultiLeadDatasetWithLabels(ecg_leads=ecg_lead_tensors, labels=labels)


In [ ]:
# **Classifier Model Class**
class ECGLeadClassifier(nn.Module):
    """MLP classifier on top of frozen per-lead encoders taken from a pretrained model.

    Each selected lead goes through its own encoder. The first of the two values
    the encoder returns (``mu``) is used as that lead's feature vector. The
    per-lead features are concatenated and fed to a two-layer MLP.

    Args:
        pretrained_mopoe: model exposing ``encoders`` (indexable, one encoder
            per lead, each returning a 2-tuple whose first item is used) and
            ``latent_dim`` (size of each lead's feature vector).
        num_classes: number of output logits.
        use_12_leads: if True use all 12 leads (encoders 0-11); otherwise only
            the first 6 (I, II, III, aVR, aVF, aVL -> encoders 0-5).
        hidden_dim: width of the MLP hidden layer (default 128).
        dropout: dropout probability inside the MLP (default 0.5).

    Raises:
        ValueError: if ``pretrained_mopoe`` has fewer encoders than selected leads.
    """

    def __init__(self, pretrained_mopoe, num_classes, use_12_leads=True, hidden_dim=128, dropout=0.5):
        super().__init__()

        # Lead i is paired with pretrained_mopoe.encoders[i].
        # NOTE(review): aVF precedes aVL here, while the lead files are listed
        # aVR, aVL, aVF — confirm this matches the encoder order used in
        # pretraining. Left unchanged because saved classifiers depend on it.
        self.use_12_leads = use_12_leads
        limb_leads = ["LEAD_I", "LEAD_II", "LEAD_III", "LEAD_aVR", "LEAD_aVF", "LEAD_aVL"]
        chest_leads = ["LEAD_V1", "LEAD_V2", "LEAD_V3", "LEAD_V4", "LEAD_V5", "LEAD_V6"]
        self.lead_names = limb_leads + chest_leads if use_12_leads else limb_leads

        # Derive the encoder count from lead_names so the two can never drift apart.
        n_leads = len(self.lead_names)
        if len(pretrained_mopoe.encoders) < n_leads:
            raise ValueError(
                f"pretrained_mopoe has {len(pretrained_mopoe.encoders)} encoders, "
                f"but {n_leads} leads were requested."
            )
        self.lead_encoders = nn.ModuleList([pretrained_mopoe.encoders[i] for i in range(n_leads)])

        self.feature_dim = pretrained_mopoe.latent_dim

        # Freeze encoder weights: only the classifier head is trainable.
        for param in self.lead_encoders.parameters():
            param.requires_grad = False

        # Layer order (Linear, ReLU, Dropout, Linear) fixes the state_dict keys
        # classifier.0.* / classifier.3.*, so keep it stable for saved checkpoints.
        self.classifier = nn.Sequential(
            nn.Linear(n_leads * self.feature_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, lead_data):
        """Compute class logits.

        Args:
            lead_data: dict mapping every name in ``self.lead_names`` to the
                input batch for that lead's encoder.

        Returns:
            Logits from the classifier head; presumably (batch, num_classes)
            when each encoder's ``mu`` is (batch, latent_dim) — TODO confirm.
        """
        lead_features = []
        for lead_name, encoder in zip(self.lead_names, self.lead_encoders):
            mu, _ = encoder(lead_data[lead_name])
            lead_features.append(mu)

        combined_features = torch.cat(lead_features, dim=1)
        return self.classifier(combined_features)

In [None]:
import random
import numpy as np

# **Set device configuration** — use the GPU whenever CUDA reports one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

# **Set seed for reproducibility**
def set_seed(seed_value):
    """Seed Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Args:
        seed_value: integer seed applied to every random generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)

    if torch.cuda.is_available():
        # Current device first, then every visible device.
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)

    # Disable autotuning and request deterministic cuDNN kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

set_seed(123)

# Model hyper-parameters; these must match the configuration the checkpoint was trained with.
params = {'latent_dim': 256, 'input_dim_per_lead': 5000, 'num_leads': 12}

# Instantiate the prior distribution
prior_dist = prior_expert(params['latent_dim'])

# Create the MoPoE model (specific for 12-lead ECG data)
pretrained_mopoe = LSEMVAE(
    prior_dist=prior_dist,
    latent_dim=params['latent_dim'],
    num_leads=params['num_leads'],
    input_dim_per_lead=params['input_dim_per_lead']
)

# Load the pretrained weights. The checkpoint keys carry an "_orig_mod." segment
# (presumably saved from a torch.compile()-wrapped model); strip it so they
# match the plain module's parameter names.
state_dict = torch.load("../Main/pretrain/LS_EMVAE_with_reg_12_lead.pth", map_location=device)
new_state_dict = {k.replace("_orig_mod.", ""): v for k, v in state_dict.items()}

# strict=False tolerates key mismatches, so surface them instead of silently
# leaving some parameters at their random initialisation.
load_result = pretrained_mopoe.load_state_dict(new_state_dict, strict=False)
if load_result.missing_keys or load_result.unexpected_keys:
    print(
        f"Warning: pretrained checkpoint mismatch - "
        f"{len(load_result.missing_keys)} missing keys "
        f"(e.g. {load_result.missing_keys[:3]}), "
        f"{len(load_result.unexpected_keys)} unexpected keys "
        f"(e.g. {load_result.unexpected_keys[:3]})"
    )

# Move the model to the device (GPU/CPU)
pretrained_mopoe.to(device)

model = ECGLeadClassifier(pretrained_mopoe=pretrained_mopoe, num_classes=2, use_12_leads=False).to(device)

# Load the fine-tuned classifier weights (strict by default, so any key mismatch raises)
# and switch to inference mode, since the model is only used for prediction/attribution below.
model.load_state_dict(torch.load('../Main/HPC/multimodal_classifier.pth', map_location=device))
model.eval()

In [None]:
from captum.attr import IntegratedGradients
import torch
import matplotlib.pyplot as plt
import numpy as np
import neurokit2 as nk 


# Pick one recording from the dataset and prepare each lead for attribution.
sample_index = 609
lead_data_sample, label = dataset[sample_index]

for lead_name, lead_tensor in lead_data_sample.items():
    # Add a batch axis when the sample is 2-D; presumably giving the
    # (batch, channel, length) layout seen in training — TODO confirm.
    if lead_tensor.dim() == 2:
        lead_tensor = lead_tensor.unsqueeze(0)
    lead_tensor = lead_tensor.to(device)
    lead_tensor.requires_grad_(True)
    # Reassigning an existing key while iterating is safe (dict size is unchanged).
    lead_data_sample[lead_name] = lead_tensor

# Define a wrapper function for Integrated Gradients
def model_wrapper(*inputs):
    """Adapt the dict-based classifier to Captum's positional-tensor interface.

    Args:
        *inputs: one tensor per lead, in ``model.lead_names`` order.

    Returns:
        Whatever the global ``model`` returns for the rebuilt lead dict (logits).

    Raises:
        ValueError: if the number of inputs differs from the number of leads;
            a bare ``zip`` would otherwise silently drop the surplus.
    """
    if len(inputs) != len(model.lead_names):
        raise ValueError(
            f"Expected {len(model.lead_names)} lead tensors, got {len(inputs)}."
        )
    lead_data = dict(zip(model.lead_names, inputs))
    return model(lead_data)

# Prepare inputs as a tuple in the order of model.lead_names
inputs_tuple = tuple(lead_data_sample[lead] for lead in model.lead_names)

# Get model prediction
model.eval()
with torch.no_grad():
    logits = model_wrapper(*inputs_tuple)
predicted_label = torch.argmax(logits, dim=1).item()
print(f"Actual Label: {label}")
print(f"Predicted Label: {predicted_label}")

# Integrated Gradients w.r.t. the predicted class. The convergence delta measures
# how far the attributions are from summing to f(input) - f(baseline); a large
# value suggests more integration steps are needed.
integrated_gradients = IntegratedGradients(model_wrapper)
attributions, convergence_delta = integrated_gradients.attribute(inputs=inputs_tuple,
                                                                 target=predicted_label,
                                                                 return_convergence_delta=True)
print(f"IG convergence delta (max abs): {convergence_delta.abs().max().item():.4g}")

# Plot settings shared by every lead
sampling_rate = 500  # Hz; adjust if necessary
threshold = 0.75     # normalized-attribution cutoff for an "important" sample
stretch_window = 6   # samples highlighted on each side of an important sample

# Create a plot with one subplot per lead
num_leads = len(model.lead_names)
fig, axes = plt.subplots(num_leads, 1, figsize=(12, num_leads * 2), sharex=True)
if num_leads == 1:
    axes = [axes]

for idx, lead in enumerate(model.lead_names):
    attr = attributions[idx].cpu().detach().numpy().squeeze()
    lead_waveform = lead_data_sample[lead].cpu().detach().numpy().squeeze()
    lead_waveform_clean = nk.ecg_clean(lead_waveform, sampling_rate=sampling_rate)

    # Min-max normalize the signed attributions to [0, 1]; epsilon guards a flat map.
    norm_attr = (attr - attr.min()) / (attr.max() - attr.min() + 1e-10)
    important_indices = np.where(norm_attr >= threshold)[0]

    signal_length = lead_waveform_clean.shape[0]
    time_axis = np.linspace(0, signal_length / sampling_rate, signal_length)

    # Mark +/- stretch_window samples around each important index and draw all of
    # them as ONE red line (NaN gaps split it into segments). One plot call per
    # index created thousands of artists and bloated the saved SVG.
    highlight_mask = np.zeros(signal_length, dtype=bool)
    for imp_idx in important_indices:
        highlight_mask[max(0, imp_idx - stretch_window):imp_idx + stretch_window + 1] = True
    highlighted = np.where(highlight_mask, lead_waveform_clean, np.nan)

    ax = axes[idx]
    ax.plot(time_axis, lead_waveform_clean, color='black', label='ECG Waveform')
    ax.plot(time_axis, highlighted, color='red', linewidth=2)

    ax.set_title(f"{lead}", fontsize=16, fontweight='bold')  # Bigger, bold title
    ax.tick_params(axis='both', which='major', labelsize=20)

# Set a single bold y-axis label in the middle
fig.text(0.002, 0.5, 'Amplitude (mm)', fontsize=18, fontweight='bold', va='center', rotation='vertical')

axes[-1].set_xlabel('Time (seconds)', fontsize=20, fontweight='bold')
plt.tight_layout()
# Save before showing: some backends release or blank the figure once it is shown.
fig.savefig("plot_q_1.svg", format="svg", dpi=600, bbox_inches='tight')
plt.show()

In [None]:
from captum.attr import IntegratedGradients
import torch
import matplotlib.pyplot as plt
import numpy as np
import neurokit2 as nk  # for optional ECG cleaning

# Relies on `device`, `dataset` and `model` (an ECGLeadClassifier) from earlier cells.
sample_index = 609
lead_data_sample, label = dataset[sample_index]

for lead_name, lead_tensor in lead_data_sample.items():
    # A 2-D sample, e.g. (1, signal_length), gets a batch axis: (1, 1, signal_length).
    if lead_tensor.dim() == 2:
        lead_tensor = lead_tensor.unsqueeze(0)
    lead_tensor = lead_tensor.to(device)
    lead_tensor.requires_grad_(True)
    # Overwriting an existing key during iteration does not change the dict's size.
    lead_data_sample[lead_name] = lead_tensor

# Define a wrapper to convert a tuple of tensors (in model.lead_names order) into a dictionary.
def model_wrapper(*inputs):
    """Adapt the dict-based classifier to Captum's positional-tensor interface.

    Args:
        *inputs: one tensor per lead, in ``model.lead_names`` order.

    Returns:
        Whatever the global ``model`` returns for the rebuilt lead dict (logits).

    Raises:
        ValueError: if the number of inputs differs from the number of leads;
            a bare ``zip`` would otherwise silently drop the surplus.
    """
    if len(inputs) != len(model.lead_names):
        raise ValueError(
            f"Expected {len(model.lead_names)} lead tensors, got {len(inputs)}."
        )
    lead_data = dict(zip(model.lead_names, inputs))
    return model(lead_data)

# Prepare inputs as a tuple.
inputs_tuple = tuple(lead_data_sample[lead] for lead in model.lead_names)

# Get the model prediction via the wrapper.
model.eval()
with torch.no_grad():
    logits = model_wrapper(*inputs_tuple)
predicted_label = torch.argmax(logits, dim=1).item()
print(f"Actual Label: {label}")
print(f"Predicted Label: {predicted_label}")

# Integrated Gradients for the predicted label; report the convergence delta as
# a sanity check (large values suggest more integration steps are needed).
integrated_gradients = IntegratedGradients(model_wrapper)
attributions, convergence_delta = integrated_gradients.attribute(
    inputs=inputs_tuple,
    target=predicted_label,
    return_convergence_delta=True
)
print(f"IG convergence delta (max abs): {convergence_delta.abs().max().item():.4g}")

# Zoom window in seconds. At 500 Hz, 2 s to 3 s is sample indices 1000 to 1500.
# (A 3.7 s - 4.4 s window is another region of interest for this sample.)
sampling_rate = 500
start_time = 2  # seconds
end_time = 3    # seconds
start_index = int(start_time * sampling_rate)
end_index = int(end_time * sampling_rate)

# Visualization parameters: thicker lines and a threshold to mark "important" regions.
waveform_linewidth = 5
important_linewidth = 5
threshold = 0.75
stretch_window = 6

# Create a subplot for each lead.
num_leads = len(model.lead_names)
fig, axes = plt.subplots(num_leads, 1, figsize=(12, num_leads * 2), sharex=True)
if num_leads == 1:
    axes = [axes]

for idx, lead in enumerate(model.lead_names):
    # Attribution and cleaned waveform for the current lead.
    attr_np = attributions[idx].cpu().detach().numpy().squeeze()
    lead_waveform = lead_data_sample[lead].cpu().detach().numpy().squeeze()
    lead_waveform_clean = nk.ecg_clean(lead_waveform, sampling_rate=sampling_rate)

    # Min-max normalize over the WHOLE recording, so the threshold is global.
    norm_attr = (attr_np - attr_np.min()) / (attr_np.max() - attr_np.min() + 1e-10)

    # Clamp the window to this recording so x-values and samples always line up
    # (an out-of-range end_index would otherwise make plot() fail on a length mismatch).
    seg_end = min(end_index, len(lead_waveform_clean))
    if seg_end <= start_index:
        raise ValueError(
            f"Zoom window starts at sample {start_index}, beyond the "
            f"{len(lead_waveform_clean)}-sample recording of {lead}."
        )
    segment_waveform = lead_waveform_clean[start_index:seg_end]
    segment_attr = norm_attr[start_index:seg_end]
    segment_indices = np.arange(start_index, seg_end)

    # Mark +/- stretch_window samples (clipped to the segment) around each
    # important index and draw them as ONE red line; NaN gaps split it into
    # segments instead of issuing one plot call per index.
    important_indices = np.where(segment_attr >= threshold)[0]
    highlight_mask = np.zeros(len(segment_waveform), dtype=bool)
    for imp_idx in important_indices:
        highlight_mask[max(0, imp_idx - stretch_window):imp_idx + stretch_window + 1] = True
    highlighted = np.where(highlight_mask, segment_waveform, np.nan)

    ax = axes[idx]
    ax.plot(segment_indices, segment_waveform, color='black', alpha=1, linewidth=waveform_linewidth)
    ax.plot(segment_indices, highlighted, color='red', linewidth=important_linewidth, zorder=3)

    # Remove axis ticks for a cleaner view.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f"{lead}", fontsize=16, fontweight='bold')

plt.tight_layout()
# Save before showing: some backends release or blank the figure once it is shown.
fig.savefig("plot_q_2.svg", format="svg", dpi=600, bbox_inches='tight')
plt.show()