MVAE model

In [1]:
import numpy as np

import torch
import torch.nn as nn
from torch.autograd import Variable
from torch.nn import functional as F
from torch.nn.parameter import Parameter


# Image Modality Encoder
class ImageVAEEncoder(nn.Module):
    """Convolutional encoder mapping an image to Gaussian latent parameters.

    Three stride-2 convolutions halve the spatial size each time; the linear
    heads are sized for a 64 x 28 x 28 feature map, i.e. 224 x 224 inputs.

    Args:
        input_channels: number of channels of the input image.
        latent_dim: size of the latent vector produced by each head.
    """

    def __init__(self, input_channels=1, latent_dim=256):
        super().__init__()
        downsample = dict(kernel_size=3, stride=2, padding=1)
        self.conv1 = nn.Conv2d(input_channels, 16, **downsample)
        self.conv2 = nn.Conv2d(16, 32, **downsample)
        self.conv3 = nn.Conv2d(32, 64, **downsample)
        self.flatten = nn.Flatten()
        # 64 channels on a 28 x 28 grid after three halvings.
        flat_features = 64 * 28 * 28
        self.fc_mu = nn.Linear(in_features=flat_features, out_features=latent_dim)
        self.fc_logvar = nn.Linear(in_features=flat_features, out_features=latent_dim)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``(mu, logvar)`` for the batch of images ``x``."""
        hidden = x
        for conv in (self.conv1, self.conv2, self.conv3):
            hidden = self.relu(conv(hidden))
        flat = self.flatten(hidden)
        return self.fc_mu(flat), self.fc_logvar(flat)
    
    

# Image Modality Decoder
class ImageVAEDecoder(nn.Module):
    """Transposed-convolution decoder turning a latent vector into an image.

    The latent vector is projected to a 64 x 28 x 28 feature map and upsampled
    by three stride-2 transposed convolutions; a sigmoid bounds the output to
    the range [0, 1].

    Args:
        latent_dim: size of the latent vector.
        output_channels: number of channels of the reconstructed image.
    """

    def __init__(self, latent_dim=256, output_channels=1):
        super().__init__()
        self.fc = nn.Linear(in_features=latent_dim, out_features=64 * 28 * 28)
        upsample = dict(kernel_size=3, stride=2, padding=1, output_padding=1)
        self.convtrans1 = nn.ConvTranspose2d(64, 32, **upsample)
        self.convtrans2 = nn.ConvTranspose2d(32, 16, **upsample)
        self.convtrans3 = nn.ConvTranspose2d(16, output_channels, **upsample)
        self.relu = nn.ReLU()
        self.output_activation = nn.Sigmoid()

    def forward(self, z):
        """Decode latent batch ``z`` into a batch of images."""
        feature_map = self.fc(z).view(-1, 64, 28, 28)
        for layer in (self.convtrans1, self.convtrans2):
            feature_map = self.relu(layer(feature_map))
        return self.output_activation(self.convtrans3(feature_map))

# ECG Modality Encoder
class ECGVAEEncoder(nn.Module):
    """1-D convolutional encoder mapping an ECG signal to Gaussian parameters.

    The input is expected as ``(batch, 1, input_dim)`` (one channel, as fixed
    by ``conv1``). Three stride-2 convolutions shorten the signal before two
    linear heads produce the latent mean and log-variance.

    Args:
        input_dim: length of the input signal.
        latent_dim: size of the latent vector produced by each head.
    """

    def __init__(self, input_dim=60000, latent_dim=256):
        super(ECGVAEEncoder, self).__init__()
        self.conv1 = nn.Conv1d(1, 16, kernel_size=3, stride=2, padding=1)
        self.conv2 = nn.Conv1d(16, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = nn.Conv1d(32, 64, kernel_size=3, stride=2, padding=1)
        self.flatten = nn.Flatten()
        # Derive the flattened size from the conv arithmetic instead of
        # `input_dim // 8`, which is only right when input_dim is a multiple
        # of 8 (each layer yields ceil(L / 2), not floor).
        flat_features = 64 * self._conv_output_length(input_dim)
        self.fc_mu = nn.Linear(in_features=flat_features, out_features=latent_dim)
        self.fc_logvar = nn.Linear(in_features=flat_features, out_features=latent_dim)
        self.relu = nn.ReLU()

    @staticmethod
    def _conv_output_length(length, num_layers=3, kernel_size=3, stride=2, padding=1):
        """Return the signal length after ``num_layers`` identical Conv1d layers."""
        for _ in range(num_layers):
            length = (length + 2 * padding - kernel_size) // stride + 1
        return length

    def forward(self, x):
        """Return ``(mu, logvar)`` for the batch of signals ``x``."""
        x = self.relu(self.conv1(x))
        x = self.relu(self.conv2(x))
        x = self.relu(self.conv3(x))
        x = self.flatten(x)
        mu = self.fc_mu(x)
        logvar = self.fc_logvar(x)
        return mu, logvar

# ECG Modality Decoder
class ECGVAEDecoder(nn.Module):
    """Transposed-convolution decoder turning a latent vector into an ECG signal.

    The latent vector is projected to a 64-channel sequence and upsampled by
    three transposed convolutions that each double the length. The output has
    shape ``(batch, 1, output_dim)`` and is not squashed (identity activation).

    Args:
        latent_dim: size of the latent vector.
        output_dim: length of the reconstructed signal.
    """

    def __init__(self, latent_dim=256, output_dim=60000):
        super(ECGVAEDecoder, self).__init__()
        self.output_dim = output_dim
        # Start from ceil(output_dim / 8) so that three doublings reach at
        # least output_dim; `output_dim // 8` produced a signal that was too
        # short whenever output_dim was not a multiple of 8.
        self.seed_length = -(-output_dim // 8)
        self.fc = nn.Linear(in_features=latent_dim, out_features=64 * self.seed_length)
        self.convtrans1 = nn.ConvTranspose1d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.convtrans2 = nn.ConvTranspose1d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.convtrans3 = nn.ConvTranspose1d(16, 1, kernel_size=3, stride=2, padding=1, output_padding=1)
        self.relu = nn.ReLU()
        self.output_activation = nn.Identity()  # Suitable for standardized data

    def forward(self, z):
        """Decode latent batch ``z`` into signals of length ``output_dim``."""
        z = self.fc(z)
        z = z.view(-1, 64, self.seed_length)
        z = self.relu(self.convtrans1(z))
        z = self.relu(self.convtrans2(z))
        z = self.output_activation(self.convtrans3(z))
        # Drop the surplus samples introduced by rounding the seed length up;
        # a no-op when output_dim is a multiple of 8.
        return z[..., :self.output_dim]

class MultimodalVAE(nn.Module):
    """Multimodal VAE over an image modality and an ECG modality.

    Each modality has its own encoder/decoder pair. The per-modality Gaussian
    posteriors are fused with a prior expert through a product of experts, so
    either modality may be omitted at inference time.

    Args:
        image_input_channels: channels of the image modality.
        ecg_input_dim: length of the ECG signal.
        latent_dim: size of the shared latent space.
    """

    def __init__(self, image_input_channels=1, ecg_input_dim=60000, latent_dim=256):
        super(MultimodalVAE, self).__init__()
        self.image_encoder = ImageVAEEncoder(image_input_channels, latent_dim)
        self.ecg_encoder = ECGVAEEncoder(ecg_input_dim, latent_dim)
        self.image_decoder = ImageVAEDecoder(latent_dim, image_input_channels)
        self.ecg_decoder = ECGVAEDecoder(latent_dim, ecg_input_dim)

        self.experts       = ProductOfExperts()
        self.n_latents     = latent_dim

    def reparametrize(self, mu, logvar):
        """Sample ``z ~ N(mu, exp(logvar))`` in training mode; return ``mu`` in eval mode."""
        if self.training:
            std = torch.exp(0.5 * logvar)
            eps = torch.randn_like(std)
            return mu + eps * std
        else:
            return mu

    def forward(self, image=None, ecg=None):
        """Encode the given modalities and reconstruct both of them.

        Returns:
            ``(img_recon, ecg_recon, mu, logvar)`` where ``mu``/``logvar``
            describe the fused posterior.

        Raises:
            ValueError: if both ``image`` and ``ecg`` are ``None``.
        """
        mu, logvar = self.infer(image, ecg)
        # reparametrization trick to sample
        z          = self.reparametrize(mu, logvar)
        # reconstruct inputs based on that gaussian
        img_recon  = self.image_decoder(z)
        ecg_recon  = self.ecg_decoder(z)
        return img_recon, ecg_recon, mu, logvar

    def infer(self, image=None, ecg=None):
        """Return the fused posterior ``(mu, logvar)`` for the given modalities.

        Raises:
            ValueError: if both ``image`` and ``ecg`` are ``None``.
        """
        if image is None and ecg is None:
            # Previously this surfaced as an AttributeError on `None.size`.
            raise ValueError("At least one of `image` or `ecg` must be provided.")

        batch_size = image.size(0) if image is not None else ecg.size(0)
        reference  = next(self.parameters())
        # initialize the universal prior expert
        mu, logvar = prior_expert((1, batch_size, self.n_latents),
                                  use_cuda=reference.is_cuda)
        # `.cuda()` targets the current CUDA device, which need not be the
        # device holding this model; move the prior explicitly (no-op when
        # it is already there).
        mu, logvar = mu.to(reference.device), logvar.to(reference.device)

        if image is not None:
            img_mu, img_logvar = self.image_encoder(image)
            mu     = torch.cat((mu, img_mu.unsqueeze(0)), dim=0)
            logvar = torch.cat((logvar, img_logvar.unsqueeze(0)), dim=0)

        if ecg is not None:
            ecg_mu, ecg_logvar = self.ecg_encoder(ecg)
            mu     = torch.cat((mu, ecg_mu.unsqueeze(0)), dim=0)
            logvar = torch.cat((logvar, ecg_logvar.unsqueeze(0)), dim=0)

        # product of experts to combine gaussians
        mu, logvar = self.experts(mu, logvar)
        return mu, logvar


class ProductOfExperts(nn.Module):
    """Fuse independent Gaussian experts into a single Gaussian.

    Experts are stacked along dimension 0 and combined by precision
    weighting; see https://arxiv.org/pdf/1410.7827.pdf for the equations.

    @param mu: expert means, experts along dim 0
    @param logvar: expert log-variances, same shape as ``mu``
    @param eps: small constant added for numerical stability
    """
    def forward(self, mu, logvar, eps=1e-8):
        variance = logvar.exp() + eps
        # precision (inverse variance) of every expert
        precision = 1. / (variance + eps)
        total_precision = precision.sum(dim=0)
        joint_mu = (mu * precision).sum(dim=0) / total_precision
        joint_var = 1. / total_precision
        joint_logvar = (joint_var + eps).log()
        return joint_mu, joint_logvar



def prior_expert(size, use_cuda=False):
    """Universal prior expert. Here we use a spherical
    Gaussian: N(0, 1).

    @param size: int or tuple of ints
                 shape of the returned mean / log-variance tensors
    @param use_cuda: boolean [default: False]
                     move the tensors to the current CUDA device
    @return: (mu, logvar), two separate all-zero tensors of shape ``size``
    """
    # Plain tensors: the `Variable` wrapper is deprecated and only returned
    # the tensor unchanged.
    mu     = torch.zeros(size)
    logvar = torch.zeros(size)
    if use_cuda:
        mu, logvar = mu.cuda(), logvar.cuda()
    return mu, logvar

MIMIC Dataloader

In [2]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn.functional as nnf

# Load the pre-extracted ECG and image features from disk.
# NOTE(review): the shapes/dtypes of these tensors are not visible here —
# confirm they match what ECGVAEEncoder / ImageVAEEncoder expect.
X_ecg_tensor = torch.load('data_feature/ecg_features_tensor.pt')
X_image_tensor = torch.load('data_feature/encoder_image_tensor.pt')


class ECGImageDataset(Dataset):
    """Index-aligned pairing of ECG features with image features.

    Item ``i`` is ``(ecg_features[i], image_features[i])``; the dataset length
    is taken from ``ecg_features``.
    """

    def __init__(self, ecg_features, image_features):
        self.ecg_features, self.image_features = ecg_features, image_features

    def __len__(self):
        return len(self.ecg_features)

    def __getitem__(self, idx):
        ecg_item = self.ecg_features[idx]
        image_item = self.image_features[idx]
        return ecg_item, image_item
    
# Pair the two feature tensors by index; each item is (ecg, image).
dataset = ECGImageDataset(X_ecg_tensor, X_image_tensor)

# Shuffled mini-batches of 128 pairs for pre-training.
dataloader = DataLoader(dataset, batch_size=128, shuffle=True)
        

Pre-Training

In [4]:
# Device configuration
import torch.optim as optim

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


def elbo_loss(recon_xray, xray, recon_ecg, ecg, mu, logvar,
              lambda_xray=1.0, lambda_ecg=1.0, annealing_factor=1):
    """Return the batch-averaged negative ELBO (the quantity to minimise).

    For every sample the loss is
    ``lambda_xray * SSE_xray + lambda_ecg * SSE_ecg + annealing_factor * KL``,
    where each SSE is the squared reconstruction error summed over that
    sample's elements. The result is the mean over the batch.

    The reconstruction error is reduced per sample rather than over the whole
    batch: a batch-wide sum added to a per-sample KL term makes the effective
    KL weight shrink with the batch size and the loss scale with it.

    Args:
        recon_xray, xray: reconstructed / target images with the batch on
            dim 0, or ``None`` to drop the image term.
        recon_ecg, ecg: reconstructed / target ECG signals with the batch on
            dim 0, or ``None`` to drop the ECG term.
        mu, logvar: posterior parameters of shape ``(batch, latent_dim)``.
        lambda_xray, lambda_ecg: weights of the two reconstruction terms.
        annealing_factor: weight of the KL term.

    Returns:
        Scalar tensor.
    """
    def per_sample_sse(recon, target):
        # A missing modality contributes nothing to the loss.
        if recon is None or target is None:
            return 0
        squared_error = nnf.mse_loss(recon, target, reduction='none')
        return squared_error.flatten(start_dim=1).sum(dim=1)

    xray_mse = per_sample_sse(recon_xray, xray)
    ecg_mse = per_sample_sse(recon_ecg, ecg)

    # Closed-form KL(q(z|x) || N(0, I)), one value per sample.
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1)

    ELBO = torch.mean(lambda_xray * xray_mse + lambda_ecg * ecg_mse + annealing_factor * KLD)

    return ELBO

# Utility Functions
class AverageMeter(object):
    """Track the latest value and the running weighted average of a metric."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the average."""
        self.val = val
        self.sum, self.count = self.sum + val * n, self.count + n
        self.avg = self.sum / self.count

# Training function
def train(epoch):
    """Run one pre-training epoch over ``dataloader`` and return the mean loss.

    Uses the module-level ``model``, ``optimizer``, ``dataloader``, ``device``
    and hyperparameters. Only the joint (both modalities present) ELBO is
    optimised.
    """
    model.train()
    meter = AverageMeter()
    num_batches = len(dataloader)

    for step, (ecg_batch, xray_batch) in enumerate(dataloader):
        # KL weight ramps up linearly with the epoch, then stays at 1.
        if epoch < annealing_epochs:
            kl_weight = min(epoch / annealing_epochs, 1)
        else:
            kl_weight = 1.0

        ecg_batch = ecg_batch.to(device)
        xray_batch = xray_batch.to(device)

        optimizer.zero_grad()
        xray_hat, ecg_hat, mu, logvar = model(xray_batch, ecg_batch)

        loss = elbo_loss(xray_hat, xray_batch, ecg_hat, ecg_batch, mu, logvar,
                         lambda_xray, lambda_ecg, kl_weight)
        loss.backward()
        optimizer.step()

        batch_len = len(ecg_batch)
        meter.update(loss.item(), batch_len)
        if step % log_interval == 0:
            print(f'Train Epoch: {epoch} [{step * batch_len}/{len(dataloader.dataset)} ({100. * step / num_batches:.0f}%)]\tLoss: {meter.avg:.6f}')

    return meter.avg
   
# Hyperparameters
# NOTE(review): n_latents is not referenced below — the model is built with a
# literal latent_dim=256.
n_latents = 256
epochs = 100
# train() ramps the KL weight up over this many epochs.
annealing_epochs = 50
lr = 1e-3
# train() prints progress every `log_interval` mini-batches.
log_interval = 10
# Weights of the two reconstruction terms passed to elbo_loss().
lambda_xray = 1.0
lambda_ecg = 10.0

# Model and optimizer setup
model = MultimodalVAE(image_input_channels=1, ecg_input_dim=60000,latent_dim=256).to(device)
optimizer = optim.Adam(model.parameters(), lr=lr)

# Main training and validation loop
# NOTE(review): only training runs here; best_loss is never updated in the
# visible code.
best_loss = float('inf')
for epoch in range(1, epochs + 1):
    train_loss = train(epoch)


# After training is complete
# NOTE(review): the file is written under 'pretrained_models/' although the
# message below prints the bare file name; the directory presumably has to
# exist beforehand — confirm.
torch.save(model.state_dict(), 'pretrained_models/mvae_only_joint.pth')
print("Saved model state dictionary to 'mvae_only_joint.pth'")

Using device: cuda
Saved model state dictionary to 'mvae_only_joint.pth'


Fine-tuning on Aspire Dataset

Dataloader

In [2]:
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from PIL import Image
import os
import pydicom
import torch
from torch.utils.data import Dataset
import numpy as np

class XRayDataset(Dataset):
    """Binary-labelled X-ray images laid out as
    ``root_dir/processed_label_{0,1}/<folder>/<image>.jpg``.

    One image is taken per folder: the first ``.jpg`` file in sorted order.
    Folders without a ``.jpg`` and non-directory entries are skipped.

    Args:
        root_dir: directory containing the ``processed_label_*`` folders.
        transform: optional callable applied to the opened PIL image.

    Raises:
        FileNotFoundError: if a ``processed_label_*`` directory is missing.
    """

    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []

        # Loop through each label directory
        for label in [0, 1]:
            label_dir = os.path.join(root_dir, f'processed_label_{label}')
            # os.listdir returns entries in arbitrary order; sort so the
            # sample order is deterministic (samples are later paired with
            # the ECG dataset by index).
            for folder_name in sorted(os.listdir(label_dir)):
                folder_path = os.path.join(label_dir, folder_name)
                if not os.path.isdir(folder_path):
                    continue  # stray file next to the sample folders
                jpg_names = [name for name in sorted(os.listdir(folder_path))
                             if name.endswith('.jpg')]
                if not jpg_names:
                    continue  # empty folder or no image: nothing to load
                self.images.append(os.path.join(folder_path, jpg_names[0]))
                self.labels.append(label)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``."""
        image_path = self.images[idx]
        image = Image.open(image_path)
        label = self.labels[idx]

        if self.transform:
            image = self.transform(image)

        return image, label
    
 
# Initialize your dataset
# Images are resized to 224 x 224 and reduced to a single channel, which is
# the input size the 64 * 28 * 28 linear layers of ImageVAEEncoder are built for.
xray_dataset = XRayDataset(root_dir='D:/Aspire_xray/xray', transform=transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.Grayscale(num_output_channels=1),  # Convert to grayscale
    transforms.ToTensor(),
]))


def process_dicom(file_path, sampling_rate=500):
    """Read an ECG waveform from a DICOM file and standardise it per channel.

    The second entry of ``WaveformSequence`` is decoded as int16 samples,
    reshaped to ``(samples, channels)``, trimmed or zero-padded to
    ``10 * sampling_rate`` samples and z-scored channel by channel.

    Args:
        file_path: path of the DICOM file.
        sampling_rate: samples per second, used to size the 10-second window.

    Returns:
        Float array of shape ``(10 * sampling_rate, channels)``, or ``None``
        when the file cannot be processed (a message is printed).
    """
    desired_length = 10 * sampling_rate  # 10 seconds of data
    try:
        dicom_data = pydicom.dcmread(file_path)
        if "WaveformSequence" not in dicom_data:
            print(f"No Waveform data found in {file_path}. Skipping file.")
            return None

        rhythm_waveform = dicom_data.WaveformSequence[1]
        wave_data = rhythm_waveform.get("WaveformData")
        num_channels = rhythm_waveform.NumberOfWaveformChannels
        wave_array = np.frombuffer(wave_data, dtype=np.int16)

        if wave_array.size % num_channels != 0:
            print(f"Unexpected data size in {file_path}. Skipping file.")
            return None

        wave_array = wave_array.reshape(-1, num_channels)

        # Trim or Pad the array to 10 seconds
        if wave_array.shape[0] > desired_length:
            wave_array = wave_array[:desired_length, :]
        elif wave_array.shape[0] < desired_length:
            padding = np.zeros((desired_length - wave_array.shape[0], num_channels), dtype=wave_array.dtype)
            wave_array = np.vstack((wave_array, padding))

        # Normalize the array. A constant channel has zero std; dividing by
        # it would fill the channel with NaN/inf, so such channels are left
        # at zero after mean removal instead.
        channel_mean = np.mean(wave_array, axis=0)
        channel_std = np.std(wave_array, axis=0)
        channel_std = np.where(channel_std == 0, 1.0, channel_std)
        return (wave_array - channel_mean) / channel_std
    except Exception as e:
        # Deliberate best effort: one unreadable file must not abort the
        # construction of the whole dataset.
        print(f"Error processing file {file_path}: {e}")
        return None


class ECGDataset(Dataset):
    """Binary-labelled ECG waveforms laid out as
    ``root_dir/processed_label_{0,1}/<folder>/<file>.dcm``.

    Every ``.dcm`` file that ``process_dicom`` can decode is loaded into
    memory at construction time. Non-directory entries next to the sample
    folders are skipped.

    Args:
        root_dir: directory containing the ``processed_label_*`` folders.

    Raises:
        FileNotFoundError: if a ``processed_label_*`` directory is missing.
    """

    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.ecg_data = []
        self.labels = []

        # Loop through each label directory
        for label in [0, 1]:
            label_dir = os.path.join(root_dir, f'processed_label_{label}')
            # os.listdir returns entries in arbitrary order; sort so the
            # sample order is deterministic (samples are later paired with
            # the X-ray dataset by index).
            for folder_name in sorted(os.listdir(label_dir)):
                folder_path = os.path.join(label_dir, folder_name)
                if not os.path.isdir(folder_path):
                    continue  # stray file next to the sample folders
                for file_name in sorted(os.listdir(folder_path)):
                    if not file_name.endswith('.dcm'):
                        continue
                    file_path = os.path.join(folder_path, file_name)
                    ecg_waveform = process_dicom(file_path)
                    if ecg_waveform is not None:
                        self.ecg_data.append(ecg_waveform)
                        self.labels.append(label)

    def __len__(self):
        return len(self.ecg_data)

    def __getitem__(self, idx):
        """Return ``(waveform, label)``; the waveform is a float32 tensor of shape ``[1, N]``."""
        ecg_waveform = self.ecg_data[idx]
        label = self.labels[idx]
        # Reshape waveform to [1, signal_length]
        ecg_waveform = ecg_waveform.reshape(1, -1)
        return torch.tensor(ecg_waveform, dtype=torch.float32), label

# Usage example
# Construction reads and decodes every .dcm file under the root directory.
ecg_dataset = ECGDataset(root_dir='D:/Aspire_ecg/ecg')

class CombinedDataset(Dataset):
    """Index-aligned pairing of an X-ray dataset with an ECG dataset.

    Item ``i`` is ``(xray_image, ecg_waveform, label)``, built from item ``i``
    of both datasets; their labels must agree.

    Args:
        xray_dataset: dataset yielding ``(image, label)``.
        ecg_dataset: dataset yielding ``(waveform, label)``, same length.
    """

    def __init__(self, xray_dataset, ecg_dataset):
        self.xray_dataset = xray_dataset
        self.ecg_dataset = ecg_dataset
        assert len(xray_dataset) == len(ecg_dataset), "Datasets must be of the same length."

        # Reuse the label list the X-ray dataset already holds: iterating the
        # dataset would load and transform every image just to read labels.
        # Fall back to iteration for datasets without a matching attribute.
        cached_labels = getattr(xray_dataset, "labels", None)
        if cached_labels is not None and len(cached_labels) == len(xray_dataset):
            self.labels = list(cached_labels)
        else:
            self.labels = [label for _, label in xray_dataset]

    def __len__(self):
        return len(self.xray_dataset)

    def __getitem__(self, idx):
        xray_image, xray_label = self.xray_dataset[idx]
        ecg_waveform, ecg_label = self.ecg_dataset[idx]

        # Ensure the labels match if they are supposed to be the same
        assert xray_label == ecg_label, "Labels do not match for the same index."

        return xray_image, ecg_waveform, xray_label  # Use either xray_label or ecg_label

    def get_labels(self):
        """Return the list of labels, one per sample."""
        return self.labels

# Instantiate the combined dataset
# Samples are paired purely by index, so both datasets must list the same
# subjects in the same order.
combined_dataset = CombinedDataset(xray_dataset, ecg_dataset)
# One label per sample; used further down to stratify the cross-validation folds.
labels = combined_dataset.get_labels()

Model 

In [3]:

class MultimodalClassifier(nn.Module):
    """Classifier head on top of the frozen encoders of a pre-trained MVAE.

    The latent means of the image and ECG encoders are concatenated and fed
    to a small MLP; only the MLP is trainable.

    Args:
        pretrained_mvae: object exposing ``image_encoder``, ``ecg_encoder``
            and ``n_latents``.
        num_classes: number of output logits.
    """

    def __init__(self, pretrained_mvae, num_classes):
        super().__init__()
        self.image_encoder = pretrained_mvae.image_encoder
        self.ecg_encoder = pretrained_mvae.ecg_encoder

        # Two latent means are concatenated, hence twice the latent size.
        fused_dim = pretrained_mvae.n_latents * 2

        # Freeze the encoder weights
        for encoder in (self.image_encoder, self.ecg_encoder):
            for weight in encoder.parameters():
                weight.requires_grad = False

        head_layers = [
            nn.Linear(fused_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, xray, ecg):
        """Return class logits for a batch of paired X-ray / ECG inputs."""
        xray_code, _ = self.image_encoder(xray)
        ecg_code, _ = self.ecg_encoder(ecg)
        fused = torch.cat((xray_code, ecg_code), dim=1)
        return self.classifier(fused)

Training

In [4]:
from sklearn.metrics import accuracy_score, roc_auc_score

def train_classifier(model, train_loader, criterion, optimizer, epochs=50):
    """Train ``model`` for ``epochs`` epochs and print the mean loss per epoch.

    Batches are ``(xray, ecg, labels)`` triples and are moved to the
    module-level ``device`` before the forward pass.
    """
    model.train()
    for epoch_idx in range(epochs):
        running_loss = 0
        for xray, ecg, labels in train_loader:
            xray = xray.to(device)
            ecg = ecg.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            loss = criterion(model(xray, ecg), labels)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by its size so the epoch figure is a
            # per-sample average.
            running_loss += loss.item() * labels.size(0)

        avg_loss = running_loss / len(train_loader.dataset)
        print(f'Epoch {epoch_idx+1}/{epochs}, Loss: {avg_loss:.4f}')

def evaluate_model(model, data_loader):
    """Return ``(accuracy, auc)`` of ``model`` over ``data_loader``.

    The AUC is computed from the softmax probability of class index 1.
    Batches are moved to the module-level ``device``.
    """
    model.eval()
    y_true, y_score, y_pred = [], [], []

    with torch.no_grad():
        for xray, ecg, labels in data_loader:
            xray = xray.to(device)
            ecg = ecg.to(device)
            labels = labels.to(device)

            class_probs = torch.softmax(model(xray, ecg), dim=1)
            predicted = torch.argmax(class_probs, dim=1)

            y_true.extend(labels.cpu().numpy())
            y_score.extend(class_probs[:, 1].cpu().numpy())
            y_pred.extend(predicted.cpu().numpy())

    auc_score = roc_auc_score(y_true, y_score)
    accuracy = accuracy_score(y_true, y_pred)

    return accuracy, auc_score

In [5]:
import torch
import random
from sklearn.model_selection import StratifiedKFold
import numpy as np
from torch.utils.data import DataLoader, Subset
from sklearn.metrics import accuracy_score, roc_auc_score


# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

def set_seed(seed_value):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN."""
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed_value)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed_value)
        torch.cuda.manual_seed_all(seed_value)
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    
# Set a seed value
seed = 42
set_seed(seed)


# 10-fold stratified split with a fixed shuffle seed.
skf = StratifiedKFold(n_splits=10, shuffle=True, random_state=42)
# One (accuracy, auc) tuple is appended per fold.
fold_results = []

# Stratified cross-validation: for every fold a fresh classifier head is
# trained on top of the frozen pre-trained encoders and scored on the
# held-out split. Only the labels drive the split; the zeros array is a
# placeholder for the features.
for fold, (train_ids, test_ids) in enumerate(skf.split(np.zeros(len(labels)), labels)):
    print(f'FOLD {fold}')
    print('--------------------------------')
    
    train_subset = Subset(combined_dataset, train_ids)
    test_subset = Subset(combined_dataset, test_ids)
    
    train_loader = DataLoader(train_subset, batch_size=32, shuffle=True)
    test_loader = DataLoader(test_subset, batch_size=32, shuffle=False)
    
    # Load the pre-trained Multimodal VAE
    # NOTE(review): this checkpoint name differs from the one written by the
    # pre-training cell ('pretrained_models/mvae_only_joint.pth') — confirm
    # the intended file.
    pretrained_mvae = MultimodalVAE(image_input_channels=1, ecg_input_dim=60000, latent_dim=256)
    pretrained_mvae.load_state_dict(torch.load('pretrained_models/multimodal_vae_only_joint_100.pth', map_location=torch.device('cuda' if torch.cuda.is_available() else 'cpu')))
    pretrained_mvae.to(device)
    
    # Initialize the classifier for this fold
    # NOTE: this rebinds the module-level names `model` and `optimizer`,
    # which the pre-training train() function also reads.
    model = MultimodalClassifier(pretrained_mvae=pretrained_mvae, num_classes=2).to(device)
    
    # Only parameters that still require gradients (the classifier head) are optimised.
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
    criterion = nn.CrossEntropyLoss()
    
    # Train the model
    train_classifier(model, train_loader, criterion, optimizer, epochs=50)
    
    # Evaluate the model on the test set
    accuracy, auc_score = evaluate_model(model, test_loader)
    fold_results.append((accuracy, auc_score))
    print(f'Fold {fold} Results: Accuracy: {accuracy:.4f}, AUC: {auc_score:.4f}\n')
    
# Calculate and print the mean and STD for each metric across folds
# zip(*...) turns the list of (accuracy, auc) pairs into two tuples.
accuracies, aucs = zip(*fold_results)
mean_accuracy = np.mean(accuracies)
# np.std uses its default settings here (population standard deviation).
std_accuracy = np.std(accuracies)
mean_auc = np.mean(aucs)
std_auc = np.std(aucs)

print(f'Mean Accuracy: {mean_accuracy:.4f}, STD: {std_accuracy:.4f}')
print(f'Mean AUC: {mean_auc:.4f}, STD: {std_auc:.4f}')

Using device: cuda
FOLD 0
--------------------------------
Epoch 1/50, Loss: 0.7196
Epoch 2/50, Loss: 0.5867
Epoch 3/50, Loss: 0.5327
Epoch 4/50, Loss: 0.5167
Epoch 5/50, Loss: 0.5111
Epoch 6/50, Loss: 0.4896
Epoch 7/50, Loss: 0.4861
Epoch 8/50, Loss: 0.4643
Epoch 9/50, Loss: 0.4556
Epoch 10/50, Loss: 0.4406
Epoch 11/50, Loss: 0.4290
Epoch 12/50, Loss: 0.4026
Epoch 13/50, Loss: 0.4136
Epoch 14/50, Loss: 0.3868
Epoch 15/50, Loss: 0.4108
Epoch 16/50, Loss: 0.3913
Epoch 17/50, Loss: 0.3852
Epoch 18/50, Loss: 0.3664
Epoch 19/50, Loss: 0.3649
Epoch 20/50, Loss: 0.3725
Epoch 21/50, Loss: 0.3504
Epoch 22/50, Loss: 0.3615
Epoch 23/50, Loss: 0.3364
Epoch 24/50, Loss: 0.3365
Epoch 25/50, Loss: 0.2899
Epoch 26/50, Loss: 0.3257
Epoch 27/50, Loss: 0.3001
Epoch 28/50, Loss: 0.2791
Epoch 29/50, Loss: 0.2882
Epoch 30/50, Loss: 0.2637
Epoch 31/50, Loss: 0.2806
Epoch 32/50, Loss: 0.2791
Epoch 33/50, Loss: 0.2777
Epoch 34/50, Loss: 0.2405
Epoch 35/50, Loss: 0.2677
Epoch 36/50, Loss: 0.2626
Epoch 37/50, L