In [3]:
import os
import torch
import monai
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd, ScaleIntensityRanged, CropForegroundd,
    RandCropByPosNegLabeld, RandFlipd, RandRotate90d, RandShiftIntensityd, EnsureTyped, DivisiblePadd
)
from monai.data import DataLoader, CacheDataset
from monai.networks.nets import SwinUNETR
from monai.utils import set_determinism
from monai.data.image_reader import NibabelReader
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from sklearn.model_selection import train_test_split
import numpy as np
import pandas as pd
import torch.nn.functional as F
from xgboost import XGBRegressor
from sklearn.model_selection import RandomizedSearchCV
import re

In [4]:
 # Define directories
train_path = 'dataset/MICCAI_BraTS2020_TrainingData/'
# val_path = 'dataset/MICCAI_BraTS2020_ValidationData/'
modality_keys = ["flair"]
in_channels = len(modality_keys)
out_channels = len(modality_keys) 

In [5]:
def create_data_list(data_dir, patient_ids, lables, modality_keys):
    data_list = []
    for idx, patient in enumerate(patient_ids):
        patient_dir = f"{data_dir}{patient}/"
        if os.path.isdir(patient_dir):
            data_dict = {key: os.path.join(patient_dir, f"{patient}_{key}.nii") for key in modality_keys}
            data_dict['lable'] = lables[idx]
            data_list.append(data_dict)

    return data_list

In [6]:
def preprocess_labels(csv_file_path):
    df = pd.read_csv(csv_file_path)
    
    def extract_number(value):
        if isinstance(value, str):
            match = re.search(r'\d+', value)
            return int(match.group()) if match else None
        return value

    df['Survival_days'] = df['Survival_days'].apply(extract_number)
    df['Survival_days'] = pd.to_numeric(df['Survival_days'], errors='coerce')
    df = df.dropna(subset=['Survival_days'])
    df['Survival_days'] = df['Survival_days'].astype(int)
    
    return df['Survival_days'].values, df['Brats20ID'].values

In [7]:
all_labels, all_data_ids = preprocess_labels(f'{train_path}/survival_info.csv')
train_id, val_id, train_labels, val_labels = train_test_split(all_data_ids, all_labels, test_size=0.2, random_state=42)

train_data_list = create_data_list(train_path, train_id, train_labels, modality_keys)
val_data_list = create_data_list(train_path, val_id, val_labels, modality_keys)


# since we dont have lables for validation dataset not using it for validation
# val_data_list = create_data_list(val_path, modality_keys)
# val_labels = preprocess_labels(f'{val_path}/survival_evaluation.csv')



In [8]:
train_data_list[0]

{'flair': 'dataset/MICCAI_BraTS2020_TrainingData/BraTS20_Training_168/BraTS20_Training_168_flair.nii',
 'lable': 291}

In [9]:
def get_transforms(modality_keys, pixdim=(2.0, 2.0, 2.0), spatial_size=(64, 64, 64)):
    transforms = Compose(
        [
            LoadImaged(keys=modality_keys, reader=NibabelReader()),
            EnsureChannelFirstd(keys=modality_keys),
            Spacingd(keys=modality_keys, pixdim=pixdim, mode=("bilinear")),
            Orientationd(keys=modality_keys, axcodes="RAS"),
            ScaleIntensityRanged(keys=modality_keys, a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=modality_keys, source_key=modality_keys[0], allow_smaller=True),
            DivisiblePadd(keys=modality_keys, k=32),
            RandCropByPosNegLabeld(keys=modality_keys, label_key=modality_keys[0], spatial_size=spatial_size, pos=1, neg=1, num_samples=1, image_key=modality_keys[0], image_threshold=0),
            RandFlipd(keys=modality_keys, prob=0.5, spatial_axis=0),
            RandFlipd(keys=modality_keys, prob=0.5, spatial_axis=1),
            RandFlipd(keys=modality_keys, prob=0.5, spatial_axis=2),
            RandRotate90d(keys=modality_keys, prob=0.5, max_k=3),
            RandShiftIntensityd(keys=modality_keys, offsets=0.10, prob=0.5),
            EnsureTyped(keys=modality_keys),
        ]
    )
    return transforms

In [10]:
def get_val_transforms(modality_keys, pixdim=(2.0, 2.0, 2.0)):
    transforms = Compose(
        [
            LoadImaged(keys=modality_keys, reader=NibabelReader()),
            EnsureChannelFirstd(keys=modality_keys),
            Spacingd(keys=modality_keys, pixdim=pixdim, mode=("bilinear")),
            Orientationd(keys=modality_keys, axcodes="RAS"),
            ScaleIntensityRanged(keys=modality_keys, a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=modality_keys, source_key=modality_keys[0], allow_smaller=True),
            DivisiblePadd(keys=modality_keys, k=32),
            EnsureTyped(keys=modality_keys),
        ]
    )
    return transforms

In [11]:
train_transforms = get_transforms(modality_keys)
val_transforms = get_val_transforms(modality_keys)


In [12]:
train_ds = CacheDataset(
    data=train_data_list,
    transform=train_transforms,
    cache_rate=0.5,
    num_workers=4,
)
train_loader = DataLoader(train_ds,  batch_size=2, shuffle=True, num_workers=4)


Loading dataset: 100%|██████████| 94/94 [00:16<00:00,  5.69it/s]


In [13]:
# for data in train_loader:
#     print(data['flair'].shape)

In [14]:
val_ds = CacheDataset(
    data=val_data_list,
    transform=val_transforms,
    cache_rate=0.5,
    num_workers=4,
)
val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, num_workers=2)

Loading dataset: 100%|██████████| 24/24 [00:04<00:00,  5.24it/s]


In [15]:
# Load the pretrained model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = SwinUNETR(
    img_size=(64, 64, 64),
    in_channels=in_channels,
    out_channels=out_channels,
    feature_size=48,
    use_checkpoint=True,
).to(device)


# Load the saved weights
modality_used = "_".join(modality_keys)
model_save_path = f"model_saved/swin_unetr_{modality_used}_temp.pth"

model.load_state_dict(torch.load(model_save_path))
model.eval()

SwinUNETR(
  (swinViT): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv3d(1, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers1): ModuleList(
      (0): BasicLayer(
        (blocks): ModuleList(
          (0-1): 2 x SwinTransformerBlock(
            (norm1): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=48, out_features=144, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=48, out_features=48, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (mlp): MLPBlock(
              (linear1): Linear(in_features=48, out_features=192, bias=True)
              (linear2): Linear(in_feature

In [16]:
# # 2. Modify the SwinUNETR class to add a method for feature extraction
# class FeatureExtractorSwinUNETR(SwinUNETR):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)
    
#     def extract_features(self, x):
#         print(" at start",x.size())
#         hidden_states = self.swinViT(x, self.normalize)
#         enc0 = self.encoder1(x)
#         print(" after enco 1",x.size())
#         enc1 = self.encoder2(hidden_states[0])
#         print(" after enco 2",enc1.size())
#         enc2 = self.encoder3(hidden_states[1])
#         print(" after enco 3",enc2.size())
#         enc3 = self.encoder4(hidden_states[2])
#         print(" after enco 4",enc3.size())
#         # dec3 = self.encoder10(hidden_states[3])
#         # print(" after dec3",dec3.size())
#         # return torch.cat([enc0, enc1, enc2, enc3, dec3], dim=1)
#         return torch.cat([enc0, enc1, enc2, enc3], dim=1)

In [17]:
class FeatureExtractorSwinUNETR(SwinUNETR):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    
    def extract_features(self, x):
        hidden_states = self.swinViT(x, self.normalize)
        
        # Adaptive pooling to resize all hidden states to a common size
        pooled_states = []
        for state in hidden_states:
            # Adaptive average pooling to 4x4x4
            pooled = F.adaptive_avg_pool3d(state, (4, 4, 4))
            pooled_states.append(pooled)
        
        # Concatenate the pooled states
        return torch.cat(pooled_states, dim=1)


In [18]:
#  Create an instance of the feature extractor
feature_extractor = FeatureExtractorSwinUNETR(
    img_size=(64, 64, 64),
    in_channels=in_channels,
    out_channels=out_channels,
    feature_size=48,
    use_checkpoint=True,
).to(device)
feature_extractor.load_state_dict(model.state_dict())
feature_extractor.eval()

FeatureExtractorSwinUNETR(
  (swinViT): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv3d(1, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers1): ModuleList(
      (0): BasicLayer(
        (blocks): ModuleList(
          (0-1): 2 x SwinTransformerBlock(
            (norm1): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=48, out_features=144, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=48, out_features=48, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (mlp): MLPBlock(
              (linear1): Linear(in_features=48, out_features=192, bias=True)
              (linear2): L

In [19]:

# # 3. Extract features from your dataset
# def extract_features_from_dataset(dataset, feature_extractor):
#     features = []
#     labels = []
    
#     for data in dataset:
#         inputs = torch.cat([data[key] for key in modality_keys], dim=1).to(device)
#         with torch.no_grad():
#             feature = feature_extractor.extract_features(inputs)
        
#         # Assuming the feature is 5D (batch, channels, depth, height, width)
#         # We'll flatten it to 2D (batch, features)
#         feature = feature.view(feature.size(0), -1)
        
#         features.append(feature.cpu().numpy())
#         labels.append(data['label'].cpu().numpy())  # Assuming you have labels in your dataset
    
#     return np.concatenate(features), np.concatenate(labels)


In [20]:
def extract_features_from_dataset(dataloader, feature_extractor):
    features = []
    labels = []
    
    for batch in dataloader:
        inputs = torch.cat([batch[key] for key in modality_keys], dim=1).to(device)
        with torch.no_grad():
            feature = feature_extractor.extract_features(inputs)
        
        # Global average pooling to reduce spatial dimensions
        feature = torch.mean(feature, dim=[2, 3, 4])
        
        features.append(feature.cpu().numpy())
        # labels.append(batch['label'].cpu().numpy())  # Assuming you have labels in your dataset
    
    return np.concatenate(features)


In [21]:

# Extract features for training and validation sets
train_features = extract_features_from_dataset(train_loader, feature_extractor)
val_features = extract_features_from_dataset(val_loader, feature_extractor)




In [22]:
# Define XGBoost regression model
xgb_model = XGBRegressor(random_state=42)


In [23]:

# Define hyperparameter search space for regression
param_dist = {
    'n_estimators': [100, 200, 300, 400, 500],
    'max_depth': [3, 4, 5, 6, 7, 8],
    'learning_rate': [0.01, 0.05, 0.1, 0.2],
    'subsample': [0.6, 0.7, 0.8, 0.9, 1.0],
    'colsample_bytree': [0.6, 0.7, 0.8, 0.9, 1.0],
    'min_child_weight': [1, 2, 3, 4, 5]
}

In [24]:
# Perform RandomizedSearchCV
random_search = RandomizedSearchCV(xgb_model, param_distributions=param_dist, n_iter=25, cv=5, 
                                   scoring='neg_mean_squared_error', random_state=42, n_jobs=-1)


In [25]:
# Fit the model
random_search.fit(train_features, train_labels)


In [26]:
# Get the best model
best_xgb_model = random_search.best_estimator_
print("Best parameters found: ", random_search.best_params_)
 

Best parameters found:  {'subsample': 0.7, 'n_estimators': 400, 'min_child_weight': 4, 'max_depth': 6, 'learning_rate': 0.01, 'colsample_bytree': 0.6}


In [27]:
# Make predictions on validation set
val_pred = best_xgb_model.predict(val_features)


In [29]:
# Evaluate the model using regression metrics
mse = mean_squared_error(val_labels, val_pred)
rmse = np.sqrt(mse)
r2 = r2_score(val_labels, val_pred)
mae = mean_absolute_error(val_labels, val_pred)

In [30]:

print(f"Mean Squared Error: {mse}")
print(f"Root Mean Squared Error: {rmse}")
print(f"R-squared Score: {r2}")
print(f"Mean Absolute Error: {mae}")

Mean Squared Error: 171330.5080733764
Root Mean Squared Error: 413.92089591294666
R-squared Score: -0.036553144454956055
Mean Absolute Error: 297.25440470377606


In [31]:
# Feature importance
feature_importance = best_xgb_model.feature_importances_
sorted_idx = np.argsort(feature_importance)
sorted_features = [f"Feature {i}" for i in sorted_idx]
sorted_importance = feature_importance[sorted_idx]

In [32]:
print("Feature Importances:")
for feat, imp in zip(sorted_features[-10:], sorted_importance[-10:]):
    print(f"{feat}: {imp}")

Feature Importances:
Feature 245: 0.0036692579742521048
Feature 683: 0.00410101655870676
Feature 696: 0.004411330446600914
Feature 83: 0.004494997672736645
Feature 1136: 0.004573773127049208
Feature 318: 0.0047732265666127205
Feature 946: 0.005056614521890879
Feature 1267: 0.005556359887123108
Feature 260: 0.005830122157931328
Feature 279: 0.007581779733300209


In [33]:
print("Train features shape:", train_features.shape)
print("Validation features shape:", val_features.shape)

Train features shape: (188, 1488)
Validation features shape: (48, 1488)
