# A Novel Approach for Three-Way Classification of Lumbar Spine Degeneration Using Pseudo-Modality Learning to Handle Missing MRI Data

## Libs

In [1]:
import torch.nn as nn
import os
import numpy as np
import pandas as pd
import torch
from tqdm import tqdm
import matplotlib.pyplot as plt
from torchvision import transforms
from transformers import AutoModel
import torchvision.models as models

### Downloading Model

In [2]:
import requests

url = "https://huggingface.co/TencentMedicalNet/MedicalNet-Resnet152/resolve/main/resnet_152.pth"

local_filename = "resnet_152.pth"

response = requests.get(url)

response.raise_for_status()

with open(local_filename, 'wb') as f:
    f.write(response.content)

print(f"Downloaded {local_filename}")

Downloaded resnet_152.pth


### Loading Model

In [3]:
resnet152 = models.resnet152(pretrained=False)

weights_path = 'resnet_152.pth'
try:
    resnet152.load_state_dict(torch.load(weights_path), strict=False)
    print("ResNet-152 model loaded with custom weights (missing keys ignored).")
except RuntimeError as e:
    print(f"Failed to load weights: {e}")

  resnet152.load_state_dict(torch.load(weights_path), strict=False)


ResNet-152 model loaded with custom weights (missing keys ignored).


## Attention Mechanism for MedicalNet152

In [4]:
class MRIEmbeddingModel(torch.nn.Module):
    def __init__(self, base_model, embedding_dim):
        super(MRIEmbeddingModel, self).__init__()
        self.base_model = base_model
        self.attention_layer = torch.nn.Linear(embedding_dim, 1) 
        self.embedding_dim = embedding_dim

    def forward(self, x):
        attention_weights = self.attention_layer(x)
        final_embedding = torch.sum(x * attention_weights, dim=1)
        return final_embedding, attention_weights

In [5]:
def attention_embeddings(csv_path, img_path, result_path_csv, result_path_pth):
    model = resnet152
    model.fc = torch.nn.Linear(model.fc.in_features, 512)
    embedding_model = MRIEmbeddingModel(model, embedding_dim=512)

    model = model.to('cuda')
    embedding_model = embedding_model.to('cuda')

    df = pd.read_csv(csv_path)

    results = []
    for index, row in tqdm(df.iterrows()):
        patient_id = str(row['study_id'])
        series_id = str(row['series_id'])
    
        series_path = os.path.join(img_path, patient_id, series_id)

        embeddings = []

        for slice_file in os.listdir(series_path):
            if slice_file.endswith('.npy'):
                slice_path = os.path.join(series_path, slice_file)
                slice_data = np.load(slice_path)

                if slice_data.ndim == 2:
                    slice_data = np.stack([slice_data] * 3, axis=0)
                elif slice_data.ndim == 3 and slice_data.shape[0] == 1:
                    slice_data = np.repeat(slice_data, 3, axis=0)

                input_tensor = torch.from_numpy(slice_data).float().to('cuda')
                input_tensor = transforms.Resize((224, 224))(input_tensor)
                input_tensor = (input_tensor - torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1).to('cuda')) / \
                               torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1).to('cuda')
                input_tensor = input_tensor.unsqueeze(0)

                with torch.no_grad():
                    embedding = model(input_tensor)
                    embeddings.append(embedding)

        if embeddings:
            slice_embeddings = torch.stack(embeddings, dim=1)
            slice_embeddings = slice_embeddings.to('cuda') 
            with torch.no_grad():
                final_embedding, attention_weights = embedding_model(slice_embeddings)

                final_embedding = final_embedding.squeeze().cpu()
                embedding_dict = {f'{i}': final_embedding[i].item() for i in range(final_embedding.shape[0])}

            embedding_dict.update({'study_id': patient_id, 'series_id': series_id})
            results.append(embedding_dict)

    results_df = pd.DataFrame(results)
    results_df.to_csv(result_path_csv, index=False)

    torch.save(embedding_model.state_dict(), result_path_pth)
    print(f"Embeddings with attention completed and saved to {result_path_csv}")

### Generating Embddings

In [6]:
attention_embeddings('/kaggle/input/preprocessed-dataset/train_data_AT2.csv', '/kaggle/input/preprocessed-dataset/grey_scale_train', 'AT2_attention_embeddings_gsl.csv', 'AT2_attention_model_gsl.pth')
attention_embeddings('/kaggle/input/preprocessed-dataset/train_data_AT2.csv', '/kaggle/input/preprocessed-dataset/hist_norm_train', 'AT2_attention_embeddings_hist.csv', 'AT2_attention_model_hist.pth')
attention_embeddings('/kaggle/input/preprocessed-dataset/train_data_ST2.csv', '/kaggle/input/preprocessed-dataset/grey_scale_train', 'ST2_attention_embeddings_gsl.csv', 'ST2_attention_model_gsl.pth')
attention_embeddings('/kaggle/input/preprocessed-dataset/train_data_ST2.csv', '/kaggle/input/preprocessed-dataset/hist_norm_train', 'ST2_attention_embeddings_hist.csv', 'ST2_attention_model_hist.pth')
attention_embeddings('/kaggle/input/preprocessed-dataset/train_data_ST1.csv', '/kaggle/input/preprocessed-dataset/grey_scale_train', 'ST1_attention_embeddings_gsl.csv', 'ST1_attention_model_gsl.pth')
attention_embeddings('/kaggle/input/preprocessed-dataset/train_data_ST1.csv', '/kaggle/input/preprocessed-dataset/hist_norm_train', 'ST1_attention_embeddings_hist.csv', 'ST1_attention_model_hist.pth')

2226it [43:50,  1.18s/it]


Embeddings with attention completed and saved to AT2_attention_embeddings_gsl.csv


2226it [43:54,  1.18s/it]


Embeddings with attention completed and saved to AT2_attention_embeddings_hist.csv


1876it [18:31,  1.69it/s]


Embeddings with attention completed and saved to ST2_attention_embeddings_gsl.csv


1876it [18:33,  1.68it/s]


Embeddings with attention completed and saved to ST2_attention_embeddings_hist.csv


1881it [19:11,  1.63it/s]


Embeddings with attention completed and saved to ST1_attention_embeddings_gsl.csv


1881it [18:52,  1.66it/s]


Embeddings with attention completed and saved to ST1_attention_embeddings_hist.csv
