In [1]:
import os
from fastai.vision.all import *
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
import torchvision.models as models
import torch.nn.functional as F
import torch.nn as nn
import json
from fastai.learner import Learner
from fastai.data.core import DataLoaders
from fastai.metrics import accuracy
from fastai.losses import CrossEntropyLossFlat
from fastai.callback.all import SaveModelCallback, EarlyStoppingCallback
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np
import torch
from PIL import Image  
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import pydicom
%matplotlib inline

In [2]:
# # Load datasets
# df = pd.read_csv('Final_Datasets/train_resnet_dxa.csv')
# df_test = pd.read_csv('Final_Datasets/test_data_incidence.csv')
# controls = pd.read_csv('Final_Datasets/imbalanced_control_iid.csv')

# # Remove test cases and controls
# test_iids = set(df_test['IID'])

# # Get all cases (CAD = 1) directly from df, excluding test data
# cases_train = df[(df['CAD'] == 1) & (~df['IID'].isin(test_iids))]

# # Get all controls specifically from controls.csv, excluding test data
# controls_train = controls[(~controls['IID'].isin(test_iids))]

# # Combine the imbalanced dataset
# df_imbalanced_train = pd.concat([cases_train, controls_train]).reset_index(drop=True)

# # Save the imbalanced dataset
# df_imbalanced_train.to_csv('Final_Datasets/train_imbalanced_dxa.csv', sep=',', index=False)

# # Report the dataset statistics
# cases_count = len(cases_train)
# controls_count = len(controls_train)
# print(f"Number of cases: {cases_count}")
# print(f"Number of controls: {controls_count}")
# print(f"Imbalance ratio (controls to cases): {controls_count / cases_count:.2f}")

In [3]:
# Backbone Class for Feature Extraction
class Backbone(nn.Module):
    def __init__(self):
        super().__init__()
        base_model = models.resnet50(pretrained=False)
        encoder_layers = list(base_model.children())
        self.backbone = nn.Sequential(*encoder_layers[:9])  # Use the first 9 layers of ResNet50

    def forward(self, x):
        return self.backbone(x)


# Classifier Class (not used in embeddings extraction but included for completeness)
class Classifier(nn.Module):
    def __init__(self, num_classes):
        super().__init__()
        self.drop_out = nn.Dropout()
        self.linear = nn.Linear(2048, num_classes)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.drop_out(x)
        x = self.linear(x)
        return x


# DXA Dataset Class
class DXADataset(Dataset):
    def __init__(self, dataframe, image_column_name, label_column_name, transform=None):
        self.dataframe = dataframe
        self.image_column_name = image_column_name
        self.label_column_name = label_column_name
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        img_path = self.dataframe[self.image_column_name].iloc[idx]
        label = self.dataframe[self.label_column_name].iloc[idx]

        # Load the DICOM file
        dicom = pydicom.dcmread(img_path)
        image = dicom.pixel_array.astype(np.float32)

        # Normalize the image
        image -= np.min(image)
        if np.max(image) != 0:
            image /= np.max(image)

        # Convert to PIL Image and ensure grayscale
        image = Image.fromarray((image * 255).astype(np.uint8))

        # Convert grayscale to 3-channel
        transform_to_3_channel = transforms.Compose([
            transforms.Grayscale(num_output_channels=3)
        ])
        image = transform_to_3_channel(image)

        # Apply transformations if specified
        if self.transform:
            image = self.transform(image)

        label = torch.tensor(label, dtype=torch.long)
        return image, label, img_path


# DXA Embeddings Extraction Class
class DXADiseaseModelEmbeddings:
    def __init__(self, train_df_path, test_df_path, image_column_name, label_column_name, batch_size=32, model_name='dxa_radIM_resnet50_model_nov'):
        self.train_df_path = train_df_path
        self.test_df_path = test_df_path
        self.image_column_name = image_column_name
        self.label_column_name = label_column_name
        self.batch_size = batch_size
        self.model_name = model_name
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self._prepare_data()
        self._prepare_model()

    def _prepare_data(self):
        train_df = pd.read_csv(self.train_df_path)
        test_df = pd.read_csv(self.test_df_path)

        # DXA Dataset
        self.train_dataset = DXADataset(train_df, self.image_column_name, self.label_column_name, transform=self._get_transforms())
        self.test_dataset = DXADataset(test_df, self.image_column_name, self.label_column_name, transform=self._get_transforms())

        self.train_loader = DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=8)
        self.test_loader = DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=8)

    def _get_transforms(self):
        return transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def _prepare_model(self):
        backbone = Backbone()
        classifier = Classifier(num_classes=2)
        model = nn.Sequential(backbone, classifier)
        model.to(self.device)
        self.model = model

        # Load the fine-tuned model
        self.model.load_state_dict(torch.load(f'{self.model_name}.pth'))
        self.model.eval()

    def extract_embeddings(self, loader):
        embeddings, labels, paths = [], [], []
        with torch.no_grad():
            for images, label_batch, path_batch in loader:  # Add path_batch
                images = images.to(self.device)
            
                # Pass through the backbone only
                x = self.model[0](images)  # Extract features from the Backbone
                x = torch.flatten(x, 1)   # Flatten after global pooling
            
                embeddings.append(x.cpu().numpy())
                labels.append(label_batch.cpu().numpy())
                paths.extend(path_batch)  # Collect the image paths
        embeddings = np.concatenate(embeddings)
        labels = np.concatenate(labels)
        return embeddings, labels, paths

    # Updated generate_embeddings_dataframe function
    def generate_embeddings_dataframe(self, embeddings, labels, paths):
        """
        Creates a Pandas DataFrame from embeddings, labels, and image paths.

        Args:
            embeddings (numpy.ndarray): The extracted embeddings.
            labels (numpy.ndarray): The labels corresponding to the embeddings.
            paths (list of str): The image paths.

        Returns:
            pd.DataFrame: A DataFrame with serialized embeddings and metadata.
        """
        # Serialize embeddings as JSON strings for safe CSV storage
        df = pd.DataFrame({
            'image_path': paths,
            'embedding': [json.dumps(emb.tolist()) for emb in embeddings],
            'label': labels
        })
        return df

    # Updated extract_and_save_embeddings function
    def extract_and_save_embeddings(self):
        """
        Extracts embeddings for train and test datasets and saves them as CSV files.

        The embeddings are serialized as JSON strings for robust CSV storage.
        """
        # Extract training embeddings
        train_embeddings, train_labels, train_paths = self.extract_embeddings(self.train_loader)
        train_df = self.generate_embeddings_dataframe(train_embeddings, train_labels, train_paths)

        # Extract test embeddings
        test_embeddings, test_labels, test_paths = self.extract_embeddings(self.test_loader)
        test_df = self.generate_embeddings_dataframe(test_embeddings, test_labels, test_paths)

        # Save DataFrames
        train_df.to_csv('train_embeddings_dxa_nov_imbalanced.csv', index=False)
        test_df.to_csv('test_embeddings_dxa_nov.csv', index=False)


In [4]:
if __name__ == "__main__":
    model = DXADiseaseModelEmbeddings(
        train_df_path='Final_Datasets/train_imbalanced_dxa.csv',  # Path to DXA train dataset
        test_df_path='Final_Datasets/test_data_incidence.csv',    # Path to DXA test dataset
        image_column_name='FilePath_dxa',                        # DXA image file paths
        label_column_name='CAD',                                 # CAD labels
        model_name='models/dxa_radIM_resnet50_model_nov'         # Fine-tuned DXA model name
    )
    
    # Extract embeddings and save to CSV
    model.extract_and_save_embeddings()

  self.model.load_state_dict(torch.load(f'{self.model_name}.pth'))
