# **A Novel Approach for Three-Way Classification of Lumbar Spine Degeneration Using Pseudo-Modality Learning to Handle Missing MRI Data**

## **Embeddings Architecture**

![MRI Processor Architecture](https://github.com/ahmedembeddedxx/lumbar-spine-degenerative-classification/blob/main/architecture/classifiers-architecture/mri-processor.png?raw=true)


## **Modules**

In [1]:
import torch.nn as nn
import os
import numpy as np
import pandas as pd
import torch
from torchvision import models, transforms
from tqdm import tqdm
import matplotlib.pyplot as plt

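#### **Attention Layers Architecture**

`SliceAttention` scores each slice embedding with a single linear layer and normalises the scores with a softmax across slices. It returns the attention-weighted sum as one series-level embedding.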

In [2]:
class SliceAttention(nn.Module):
    """Pool a stack of slice embeddings into one embedding with softmax attention.

    A single linear layer assigns a scalar score to every slice embedding; the
    scores are softmax-normalised over dim 1 and used as weights for a weighted
    sum over that same dim.
    """

    def __init__(self, embedding_dim):
        super().__init__()
        # Maps each embedding_dim-sized vector to one unnormalised score.
        self.attention_fc = nn.Linear(embedding_dim, 1)

    def forward(self, slice_embeddings):
        """Return ``(pooled_embedding, attention_weights)``.

        Assumes ``slice_embeddings`` is (batch, n_slices, embedding_dim) — TODO
        confirm; both the softmax and the sum are taken over dim 1, and the
        weights keep a trailing size-1 dim so they broadcast over the features.
        """
        weights = self.attention_fc(slice_embeddings).softmax(dim=1)
        pooled = (weights * slice_embeddings).sum(dim=1)
        return pooled, weights

class MRIEmbeddingModel(nn.Module):
    """Bundle a slice encoder with a ``SliceAttention`` pooling head.

    ``base_model`` is only stored, which registers it as a submodule so its
    parameters are included in ``state_dict()``. ``forward`` never calls it:
    callers pass slice embeddings they have already computed.
    """

    def __init__(self, base_model, embedding_dim):
        super().__init__()
        self.base_model = base_model
        self.attention = SliceAttention(embedding_dim)

    def forward(self, slice_embeddings):
        """Pool pre-computed slice embeddings; returns ``(embedding, attention_weights)``."""
        return self.attention(slice_embeddings)

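#### **Embeddings Generator using ResNet50**

Each slice is encoded by an ImageNet-pretrained ResNet50 whose classification head is replaced by a new linear layer with 512 outputs. The slice embeddings of a series are then pooled with `SliceAttention`. Note that the new linear head and the attention layer are not trained in this notebook.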

In [3]:
def _load_slice(slice_path):
    """Load one ``.npy`` slice and replicate a single channel to three.

    (H, W) arrays are stacked to (3, H, W) and (1, H, W) arrays are repeated to
    (3, H, W); any other shape is returned unchanged.
    """
    slice_data = np.load(slice_path)
    if slice_data.ndim == 2:
        slice_data = np.stack([slice_data] * 3, axis=0)
    elif slice_data.ndim == 3 and slice_data.shape[0] == 1:
        slice_data = np.repeat(slice_data, 3, axis=0)
    return slice_data


def _embed_series(series_path, model, preprocess, device, batch_size):
    """Encode every ``.npy`` slice in ``series_path``; return (n_slices, dim) or None if there are none."""
    # Sorted so the run is deterministic (os.listdir order is arbitrary).
    slice_files = sorted(f for f in os.listdir(series_path) if f.endswith('.npy'))
    if not slice_files:
        return None

    chunks = []
    for start in range(0, len(slice_files), batch_size):
        batch = torch.stack([
            preprocess(torch.from_numpy(_load_slice(os.path.join(series_path, f))).float().to(device))
            for f in slice_files[start:start + batch_size]
        ])
        chunks.append(model(batch))
    return torch.cat(chunks, dim=0)


def attention_embeddings(csv_path, img_path, result_path_csv, result_path_pth,
                         device=None, seed=None, batch_size=32):
    """Embed every MRI series listed in ``csv_path`` and save embeddings + model weights.

    Each ``.npy`` slice under ``img_path/<study_id>/<series_id>/`` is encoded by an
    ImageNet-pretrained ResNet50 whose ``fc`` head is replaced by a new linear layer
    with 512 outputs. The slice embeddings of a series are then pooled by
    ``SliceAttention`` into one 512-d vector. Series whose directory contains no
    ``.npy`` files are skipped.

    The new ``fc`` layer and the attention layer are randomly initialised and are
    not trained here. Pass ``seed`` for reproducible embeddings; the saved
    checkpoint also captures them.

    Args:
        csv_path: CSV with at least ``study_id`` and ``series_id`` columns.
        img_path: Root directory containing ``<study_id>/<series_id>/*.npy``.
        result_path_csv: Output CSV with columns ``'0'``..``'511'``, ``study_id``, ``series_id``.
        result_path_pth: Output path for ``embedding_model.state_dict()``.
        device: Torch device; defaults to ``'cuda'`` if available, else ``'cpu'``.
        seed: Optional ``torch.manual_seed`` value applied before the heads are created.
        batch_size: Number of slices pushed through the ResNet at once.

    Raises:
        FileNotFoundError: if a series directory listed in the CSV does not exist.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    if seed is not None:
        torch.manual_seed(seed)

    # `pretrained=True` is deprecated in newer torchvision in favour of `weights=`.
    model = models.resnet50(pretrained=True)
    model.fc = torch.nn.Linear(model.fc.in_features, 512)
    # Module.to() moves submodules in place, so `model` ends up on `device` too.
    embedding_model = MRIEmbeddingModel(model, embedding_dim=512).to(device)
    # eval() propagates to the ResNet. BatchNorm must use its pretrained running
    # statistics; in train mode it would normalise with per-batch statistics and
    # overwrite the running ones.
    embedding_model.eval()

    # Loop-invariant preprocessing objects, built once.
    resize = transforms.Resize((224, 224))
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(3, 1, 1)

    def preprocess(slice_tensor):
        # NOTE(review): ImageNet mean/std assume pixel values in [0, 1]; confirm
        # the preprocessed .npy slices are scaled that way rather than 0-255.
        return (resize(slice_tensor) - mean) / std

    df = pd.read_csv(csv_path)

    results = []
    with torch.no_grad():
        for _, row in tqdm(df.iterrows(), total=len(df)):
            patient_id = str(row['study_id'])
            series_id = str(row['series_id'])
            series_path = os.path.join(img_path, patient_id, series_id)

            slice_embeddings = _embed_series(series_path, model, preprocess, device, batch_size)
            if slice_embeddings is None:
                continue  # no .npy slices in this series

            # Add the batch dim expected by SliceAttention: (1, n_slices, 512).
            final_embedding, _ = embedding_model(slice_embeddings.unsqueeze(0))
            values = final_embedding[0].cpu().tolist()

            embedding_dict = {str(i): v for i, v in enumerate(values)}
            embedding_dict.update({'study_id': patient_id, 'series_id': series_id})
            results.append(embedding_dict)

    pd.DataFrame(results).to_csv(result_path_csv, index=False)

    torch.save(embedding_model.state_dict(), result_path_pth)
    print(f"Embeddings with attention completed and saved to {result_path_csv}")

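## **Running Inference**

Embeddings are generated for each series type (AT2, ST2, ST1) from two preprocessed versions of the slices: greyscale (`grey_scale_train`, suffix `gsl`) and histogram-normalised (`hist_norm_train`, suffix `hist`).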

In [4]:
# Each sequence CSV is embedded twice: from the greyscale slices ("gsl") and
# from the histogram-normalised slices ("hist").
DATA_DIR = '/kaggle/input/preprocessed-dataset'
SEQUENCES = ['AT2', 'ST2', 'ST1']
PREPROCESSING_VARIANTS = [('grey_scale_train', 'gsl'), ('hist_norm_train', 'hist')]

for sequence in SEQUENCES:
    for image_dir, suffix in PREPROCESSING_VARIANTS:
        attention_embeddings(
            os.path.join(DATA_DIR, f'train_data_{sequence}.csv'),
            os.path.join(DATA_DIR, image_dir),
            f'{sequence}_attention_embeddings_{suffix}.csv',
            f'{sequence}_attention_model_{suffix}.pth',
        )

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 129MB/s]
2226it [24:33,  1.51it/s]


Embeddings with attention completed and saved to AT2_attention_embeddings_gsl.csv


2226it [24:21,  1.52it/s]


Embeddings with attention completed and saved to AT2_attention_embeddings_hist.csv


1876it [10:24,  3.00it/s]


Embeddings with attention completed and saved to ST2_attention_embeddings_gsl.csv


1876it [10:23,  3.01it/s]


Embeddings with attention completed and saved to ST2_attention_embeddings_hist.csv


1881it [10:34,  2.97it/s]


Embeddings with attention completed and saved to ST1_attention_embeddings_gsl.csv


1881it [10:30,  2.98it/s]


Embeddings with attention completed and saved to ST1_attention_embeddings_hist.csv


In [5]:
def load_attention_model(result_path_pth, device=None):
    """Rebuild an ``MRIEmbeddingModel`` and load a checkpoint written by ``attention_embeddings``.

    That checkpoint is the full ``state_dict()``. It contains the ResNet50
    ``base_model`` (with its 512-output ``fc`` head) as well as the attention
    layer, so the same architecture is reconstructed here before loading. No
    pretrained weights are downloaded because every weight comes from the file.

    Args:
        result_path_pth: Path to a ``.pth`` file saved by ``attention_embeddings``.
        device: Device to map the weights to; defaults to ``'cuda'`` if available, else ``'cpu'``.

    Returns:
        The loaded model, on ``device`` and in eval mode.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    base_model = models.resnet50()
    base_model.fc = torch.nn.Linear(base_model.fc.in_features, 512)
    embedding_model = MRIEmbeddingModel(base_model, embedding_dim=512)
    embedding_model.load_state_dict(torch.load(result_path_pth, map_location=device))
    return embedding_model.to(device).eval()