In [1]:
import os
import torch
import monai
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd, ScaleIntensityRanged, CropForegroundd,
    RandCropByPosNegLabeld, RandFlipd, RandRotate90d, RandShiftIntensityd, EnsureTyped, DivisiblePadd
)
from monai.data import DataLoader, CacheDataset
from monai.networks.nets import SwinUNETR
from monai.utils import set_determinism
from monai.data.image_reader import NibabelReader
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np
import pandas as pd
import lightgbm as lgb

In [None]:
# Define function to extract features from the SwinUNETR model
def extract_features(modality_keys, data_path, model_path, layer_name='encoder1',
                     global_pool=True, label_key=None):
    """Run a pre-trained SwinUNETR over a dataset and collect intermediate activations.

    A forward hook on the submodule ``layer_name`` captures its output for every
    volume. Each captured activation is converted into one feature vector per volume.

    Args:
        modality_keys: Image keys to load and concatenate along the channel axis
            (one input channel per key), e.g. ``["flair"]``.
        data_path: Dataset root, passed to ``create_data_list`` (defined elsewhere
            in the notebook).
        model_path: Path to a ``state_dict`` saved for a SwinUNETR built with the
            same hyper-parameters used below.
        layer_name: Submodule to hook. Dotted paths (e.g. ``"swinViT.layers1"``)
            are resolved with ``Module.get_submodule``. The hooked module is
            assumed to return a single tensor of shape ``(B, C, *spatial)``.
        global_pool: If True (default), average the activation over all spatial
            dims, giving a fixed-length ``(C,)`` vector per volume. If False, each
            activation is flattened. Volumes are cropped to their foreground, so
            flattened vectors can differ in length between subjects, and the final
            ``np.array`` call cannot stack them in that case.
        label_key: Optional key in each data-list entry holding the target. When
            given, its (collated) value is collected alongside the features.

    Returns:
        Tuple ``(features, labels)`` of numpy arrays. ``features`` has one row per
        volume. ``labels`` is empty unless ``label_key`` is provided.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load the pre-trained model. map_location lets a GPU-saved checkpoint load
    # on a CPU-only machine.
    model = SwinUNETR(
        img_size=(64, 64, 64),
        in_channels=len(modality_keys),
        out_channels=len(modality_keys),
        feature_size=48,
        use_checkpoint=True,
    ).to(device)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    transforms = Compose(
        [
            LoadImaged(keys=modality_keys, reader=NibabelReader()),
            EnsureChannelFirstd(keys=modality_keys),
            Spacingd(keys=modality_keys, pixdim=(2.0, 2.0, 2.0), mode=("bilinear")),
            Orientationd(keys=modality_keys, axcodes="RAS"),
            ScaleIntensityRanged(keys=modality_keys, a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=modality_keys, source_key=modality_keys[0], allow_smaller=True),
            # Pad every spatial dim to a multiple of 32 so SwinUNETR accepts the input.
            DivisiblePadd(keys=modality_keys, k=32),
            EnsureTyped(keys=modality_keys),
        ]
    )

    data_list = create_data_list(data_path, modality_keys)
    dataset = CacheDataset(data=data_list, transform=transforms, cache_rate=0.5, num_workers=4)
    loader = DataLoader(dataset, batch_size=1, shuffle=False, num_workers=4)

    features = []
    labels = []

    # The hook stores the activation here. It returns None, so the model's own
    # forward output is left unchanged.
    captured = {}

    def forward_hook(module, hook_inputs, output):
        captured["activation"] = output.detach()

    handle = model.get_submodule(layer_name).register_forward_hook(forward_hook)
    try:
        with torch.no_grad():
            for batch_data in loader:
                inputs = torch.cat([batch_data[key] for key in modality_keys], dim=1).to(device)
                model(inputs)  # run only for the hook's side effect; decoder output is unused
                activation = captured.pop("activation")
                if global_pool:
                    # (B, C, *spatial) -> (B, C): fixed length regardless of crop size.
                    batch_features = activation.flatten(start_dim=2).mean(dim=-1)
                else:
                    # (B, C, *spatial) -> (B, C * prod(spatial)).
                    batch_features = activation.flatten(start_dim=1)
                # Iterating a 2-D array yields one 1-D row per volume in the batch.
                features.extend(batch_features.cpu().numpy())

                if label_key is not None:
                    batch_labels = batch_data[label_key]
                    if isinstance(batch_labels, torch.Tensor):
                        batch_labels = batch_labels.cpu().numpy()
                    labels.extend(list(batch_labels))
    finally:
        # Always detach the hook, even if loading or inference fails mid-way.
        handle.remove()

    return np.array(features), np.array(labels)


In [None]:
 

# Feature extraction using the FLAIR-only SwinUNETR checkpoint.
# NOTE(review): `preprocess_labels` (and `create_data_list`, which `extract_features`
# calls) are not defined in this part of the notebook. Make sure the cell defining
# them runs first, or this cell fails on a fresh kernel.
modality_keys = ["flair"]
model_path = "model_saved/swin_unetr_flair.pth"
train_path = 'dataset/MICCAI_BraTS2020_TrainingData/'

# Labels and subject ids parsed from survival_info.csv, then split 80/20 with a fixed seed.
all_labels, all_data_ids = preprocess_labels('{}/survival_info.csv'.format(train_path))
(
    train_id,
    val_id,
    train_labels,
    val_labels,
) = train_test_split(
    all_data_ids,
    all_labels,
    test_size=0.2,
    random_state=42,
)

# TODO(review): features are extracted for every subject found under `train_path`,
# and the split above is not used to select or order them, so the returned features
# are not yet aligned with `train_labels` / `val_labels`.
features, labels = extract_features(
    modality_keys=modality_keys,
    data_path=train_path,
    model_path=model_path,
)
