# A Novel Approach for Three-Way Classification of Lumbar Spine Degeneration Using Pseudo-Modality Learning to Handle Missing MRI Data

## Libs

In [7]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import random
import torch.nn.functional as F

## Loading MedicalNet50 Embeddings

In [2]:
csv1 = pd.read_csv('/kaggle/input/medicalnet-attention-layers-for-rsna/AT2_attention_embeddings_hist.csv')
csv2 = pd.read_csv('/kaggle/input/medicalnet-attention-layers-for-rsna/ST1_attention_embeddings_hist.csv')
csv3 = pd.read_csv('/kaggle/input/medicalnet-attention-layers-for-rsna/ST2_attention_embeddings_hist.csv')

merged_data = pd.concat([csv1, csv2, csv3], ignore_index=True)

unique_study_ids = merged_data['study_id'].unique()
train_ids, test_ids = train_test_split(unique_study_ids, test_size=0.2, random_state=42)

train_data = merged_data[merged_data['study_id'].isin(train_ids)]
test_data = merged_data[merged_data['study_id'].isin(test_ids)]

def simulate_missing_modalities(data, missing_rate=0.5):
    grouped = data.groupby('study_id')
    simulated_data = []
    
    for study_id, group in grouped:
        num_embeddings = len(group)
        num_to_keep = max(1, int(num_embeddings * (1 - missing_rate)))  
        keep_indices = random.sample(range(num_embeddings), num_to_keep)
        for idx in keep_indices:
            simulated_data.append(group.iloc[idx])
    
    return pd.DataFrame(simulated_data)

simulated_test_data = simulate_missing_modalities(test_data, missing_rate=0.5)

## Defining Architecture

In [3]:
class DenoisingAutoencoder(nn.Module):
    def __init__(self, input_dim=512, hidden_dim=256, output_dim=512):
        super(DenoisingAutoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.decoder = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, output_dim),
            nn.Tanh()
        )

    def forward(self, x):
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
    
class VAE(nn.Module):
    def __init__(self, input_dim=512, latent_dim=128):
        super(VAE, self).__init__()
        
        self.fc1 = nn.Linear(input_dim, 256)
        self.fc21 = nn.Linear(256, latent_dim)
        self.fc22 = nn.Linear(256, latent_dim)
        
        self.fc3 = nn.Linear(latent_dim, 256)
        self.fc4 = nn.Linear(256, input_dim)

    def encode(self, x):
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)
    
    def reparameterize(self, mu, logvar):
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std
    
    def decode(self, z):
        h3 = F.relu(self.fc3(z))
        return torch.tanh(self.fc4(h3))

    def forward(self, x):
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z), mu, logvar

def vae_loss_function(reconstructed_x, x, mu, logvar):
    recon_loss = F.mse_loss(reconstructed_x, x, reduction='sum')
    kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    
    return recon_loss + kl_loss

class AttentionFusion(nn.Module):
    def __init__(self, embedding_dim=512):
        super(AttentionFusion, self).__init__()
        self.attention_weights = nn.Parameter(torch.ones(1, embedding_dim), requires_grad=True)

    def forward(self, embeddings):
        attention_scores = torch.matmul(embeddings, self.attention_weights.T)
        attention_weights = torch.softmax(attention_scores, dim=0)

        fused_embedding = torch.sum(attention_weights * embeddings, dim=0)
        return fused_embedding

## Training Functions

In [10]:
def train_denoising_autoencoder(autoencoder, optimizer, train_data, num_epochs=50, noise_factor=0.2):
    autoencoder.train()
    grouped = train_data.groupby('study_id')
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    autoencoder.to(device)

    mse_loss_fn = nn.MSELoss()

    for epoch in tqdm(range(num_epochs)):
        total_loss = 0
        for study_id, group in (grouped):

            embeddings = []
            for _, row in group.iterrows():
                embedding = row[0:512].to_numpy(dtype=float)  
                embeddings.append(embedding)

            embeddings = np.array(embeddings)  

            num_embeddings = embeddings.shape[0]

            if num_embeddings > 1:
                input_embeddings = embeddings[0]
                noisy_embeddings = input_embeddings + noise_factor * np.random.randn(*input_embeddings.shape)  # Add noise
                noisy_embeddings = np.clip(noisy_embeddings, 0., 1.)
                
                input_embeddings = torch.tensor(input_embeddings, dtype=torch.float32).to(device)
                noisy_embeddings = torch.tensor(noisy_embeddings, dtype=torch.float32).to(device)

                reconstructed_embeddings = autoencoder(noisy_embeddings)

                loss = mse_loss_fn(reconstructed_embeddings, input_embeddings)
                total_loss += loss.item()

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    print(f'Total Loss: {total_loss/len(grouped)}')
        
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def train_vae(vae, optimizer, train_data, num_epochs=50):
    vae.train()
    grouped = train_data.groupby('study_id')
    
    for epoch in tqdm(range(num_epochs)):
        total_loss = 0
        for study_id, group in (grouped):

            embeddings = []
            for _, row in group.iterrows():
                embedding = row[0:512].to_numpy(dtype=float)  
                embeddings.append(embedding)

            embeddings = np.array(embeddings)

            if len(embeddings) > 1:
                input_embedding = embeddings[0]
                noisy_embedding = input_embedding + 0.3 * np.random.randn(*input_embedding.shape)  # Add noise

                input_embedding = torch.tensor(input_embedding, dtype=torch.float32).to(device)
                noisy_embedding = torch.tensor(noisy_embedding, dtype=torch.float32).to(device)

                optimizer.zero_grad()

                reconstructed_embedding, mu, logvar = vae(noisy_embedding)

                loss = vae_loss_function(reconstructed_embedding, input_embedding, mu, logvar)

                total_loss += loss.item()

                loss.backward()
                optimizer.step()

    print(f'Total Loss: {total_loss/len(grouped)}')

fusion_layer = AttentionFusion(embedding_dim=512).to(device)

def train_with_attention_fusion(vae, optimizer, train_data, num_epochs=50):
    vae.train()
    fusion_layer.train()
    grouped = train_data.groupby('study_id')
    
    for epoch in tqdm(range(num_epochs)):
        total_loss = 0
        for study_id, group in (grouped):
            embeddings = []
            for _, row in group.iterrows():
                embedding = row[0:512].to_numpy(dtype=float)  
                embeddings.append(embedding)

            embeddings = np.array(embeddings)

            if len(embeddings) > 1:
                embeddings_tensor = torch.tensor(embeddings, dtype=torch.float32).to(device)
                
                fused_embedding = fusion_layer(embeddings_tensor)

                noisy_embedding = fused_embedding + 0.3 * torch.randn(*fused_embedding.shape).to(device)

                optimizer.zero_grad()
                reconstructed_embedding, mu, logvar = vae(noisy_embedding)

                loss = vae_loss_function(reconstructed_embedding, fused_embedding, mu, logvar)

                total_loss += loss.item()

                loss.backward()
                optimizer.step()

    print(f'Total Loss: {total_loss/len(grouped)}')

## Training

In [12]:
autoencoder = DenoisingAutoencoder(input_dim=512, hidden_dim=256, output_dim=512)
optimizer = optim.Adam(autoencoder.parameters(), lr=0.0001)

train_denoising_autoencoder(autoencoder, optimizer, train_data, num_epochs=20)

vae = VAE(input_dim=512, latent_dim=256).to(device)
optimizer = optim.Adam(vae.parameters(), lr=0.001)

train_vae(vae, optimizer, train_data, num_epochs=20)

attention_fusion = VAE(input_dim=512, latent_dim=128).to(device)
optimizer = optim.Adam(vae.parameters(), lr=0.001)

train_with_attention_fusion(attention_fusion, optimizer, train_data, num_epochs=20)

100%|██████████| 20/20 [01:10<00:00,  3.51s/it]


Total Loss: 52.91559053136459


100%|██████████| 20/20 [01:28<00:00,  4.44s/it]


Total Loss: 27095.971876089767


100%|██████████| 20/20 [01:29<00:00,  4.46s/it]

Total Loss: 13440.086152139622





## Storing Models

In [13]:
torch.save(autoencoder.state_dict(), 'autoencoder.pth')
torch.save(vae.state_dict(), 'vae.pth')
torch.save(attention_fusion.state_dict(), 'attention_fusion.pth')

In [24]:
autoencoder.load_state_dict(torch.load('autoencoder.pth', weights_only=True))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
autoencoder.to(device)

vae.load_state_dict(torch.load('vae.pth', weights_only=True))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
vae.to(device)

attention_fusion.load_state_dict(torch.load('attention_fusion.pth', weights_only=True))
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
attention_fusion.to(device)

VAE(
  (fc1): Linear(in_features=512, out_features=256, bias=True)
  (fc21): Linear(in_features=256, out_features=128, bias=True)
  (fc22): Linear(in_features=256, out_features=128, bias=True)
  (fc3): Linear(in_features=128, out_features=256, bias=True)
  (fc4): Linear(in_features=256, out_features=512, bias=True)
)

## Evaluation Pipeline

In [29]:
def evaluate_with_cosine_distance(autoencoder, test_data):
    autoencoder.eval()
    grouped = test_data.groupby('study_id')
    
    mse_scores = []
    cosine_distances = []
    
    with torch.no_grad():
        for study_id, group in tqdm(grouped):
            embeddings = [row[:512].to_numpy(dtype=float) for _, row in group.iterrows()]
            embeddings = np.array(embeddings)

            if len(embeddings) > 1:
                input_embedding = embeddings[0] 
                noisy_embedding = input_embedding + 0.3 * np.random.randn(*input_embedding.shape)
                
                input_embedding = torch.tensor(input_embedding, dtype=torch.float32).to(device)
                noisy_embedding = torch.tensor(noisy_embedding, dtype=torch.float32).to(device)

                generated_embedding = autoencoder(noisy_embedding)

                mse = F.mse_loss(input_embedding, generated_embedding).item()
                mse_scores.append(mse)

                cosine_similarity = F.cosine_similarity(input_embedding.unsqueeze(0), generated_embedding.unsqueeze(0)).item()
                cosine_distance = 1 - cosine_similarity
                cosine_distances.append(cosine_distance)

    print(f'Average MSE on Unseen Data: {np.mean(mse_scores)}')
    print(f'Average Cosine Distance on Unseen Data: {np.mean(cosine_distances)}')
    
    
def evaluate_vae_with_cosine_distance(vae, test_data):
    vae.eval()
    grouped = test_data.groupby('study_id')
    
    mse_scores = []
    cosine_distances = []

    with torch.no_grad():
        for study_id, group in tqdm(grouped):
            embeddings = [row[:512].to_numpy(dtype=float) for _, row in group.iterrows()]
            embeddings = np.array(embeddings)

            if len(embeddings) > 1:
                input_embedding = embeddings[0]
                noisy_embedding = input_embedding + 0.3 * np.random.randn(*input_embedding.shape)  # Add noise

                input_embedding = torch.tensor(input_embedding, dtype=torch.float32).to(device)
                noisy_embedding = torch.tensor(noisy_embedding, dtype=torch.float32).to(device)

                reconstructed_embedding, _, _ = vae(noisy_embedding)

                mse = F.mse_loss(input_embedding, reconstructed_embedding).item()
                mse_scores.append(mse)

                cosine_similarity = F.cosine_similarity(input_embedding.unsqueeze(0), reconstructed_embedding.unsqueeze(0)).item()
                cosine_distance = 1 - cosine_similarity
                cosine_distances.append(cosine_distance)

    print(f'Average MSE on Unseen Data: {np.mean(mse_scores)}')
    print(f'Average Cosine Distance on Unseen Data: {np.mean(cosine_distances)}')


## Unseen Data Metrics

In [31]:
evaluate_with_cosine_distance(autoencoder, simulated_test_data)    
evaluate_vae_with_cosine_distance(vae, simulated_test_data)
evaluate_vae_with_cosine_distance(attention_fusion, simulated_test_data)

100%|██████████| 376/376 [00:00<00:00, 3875.90it/s]


Average MSE on Unseen Data: 24.237207994378846
Average Cosine Distance on Unseen Data: 0.6346563293502249


100%|██████████| 376/376 [00:00<00:00, 3578.11it/s]


Average MSE on Unseen Data: 24.343631791657415
Average Cosine Distance on Unseen Data: 0.6387393695848256


100%|██████████| 376/376 [00:00<00:00, 3699.71it/s]

Average MSE on Unseen Data: 27.50167006254196
Average Cosine Distance on Unseen Data: 0.9744737282363248



