In [26]:
import os
import torch
import monai
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd, ScaleIntensityRanged, CropForegroundd,
    RandFlipd, RandRotate90d, RandShiftIntensityd, EnsureTyped,  ResizeWithPadOrCropd
)
from monai.data import DataLoader, CacheDataset
from monai.networks.nets import SwinUNETR
from monai.utils import set_determinism
from monai.data.image_reader import NibabelReader
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np
import pandas as pd
import torch.nn.functional as F
from sklearn.model_selection import RandomizedSearchCV
from sklearn.ensemble import RandomForestRegressor
import re
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report

In [27]:
 # Define directories
train_path = 'dataset/MICCAI_BraTS2020_TrainingData/'
# val_path = 'dataset/MICCAI_BraTS2020_ValidationData/'
modality_keys = ["flair", "t1ce", "t2"]
# modality_keys = ["flair"]
in_channels = len(modality_keys)
out_channels = len(modality_keys) 

In [28]:
def create_data_list(data_dir, patient_ids, lables, modality_keys):
    data_list = []
    for idx, patient in enumerate(patient_ids):
        patient_dir = f"{data_dir}{patient}/"
        if os.path.isdir(patient_dir):
            data_dict = {key: os.path.join(patient_dir, f"{patient}_{key}.nii") for key in modality_keys}
            data_dict['lable'] = lables[idx]
            data_list.append(data_dict)

    return data_list

In [29]:
def categorize_days(days):
    categorized_days = []
    for day in days:
        if day > 455:
            categorized_days.append(2)
        elif day < 304:
            categorized_days.append(0)
        else:
            categorized_days.append(1)
    return categorized_days

In [30]:

def evaluate_predictions(y_true, y_pred):
    # Calculate metrics
    accuracy = accuracy_score(y_true, y_pred)
    precision = precision_score(y_true, y_pred, average='weighted')
    recall = recall_score(y_true, y_pred, average='weighted')
    f1 = f1_score(y_true, y_pred, average='weighted')
    conf_matrix = confusion_matrix(y_true, y_pred)
    class_report = classification_report(y_true, y_pred, output_dict=True)
    
    # Print metrics
    print(f'Accuracy: {accuracy:.2f}')
    print(f'Precision: {precision:.2f}')
    print(f'Recall: {recall:.2f}')
    print(f'F1 Score: {f1:.2f}')
    print('Confusion Matrix:')
    print(conf_matrix)
    print('Classification Report:')
    print(classification_report(y_true, y_pred))
    
    # Return metrics in a dictionary
    return {
        'accuracy': accuracy,
        'precision': precision,
        'recall': recall,
        'f1_score': f1,
        'confusion_matrix': conf_matrix,
        'classification_report': class_report
    }


In [31]:
def preprocess_labels(csv_file_path):
    df = pd.read_csv(csv_file_path)
    
    def extract_number(value):
        if isinstance(value, str):
            match = re.search(r'\d+', value)
            return int(match.group()) if match else None
        return value

    df['Survival_days'] = df['Survival_days'].apply(extract_number)
    df['Survival_days'] = pd.to_numeric(df['Survival_days'], errors='coerce')
    df = df.dropna(subset=['Survival_days'])
    df['Survival_days'] = df['Survival_days'].astype(int)
    
    return df['Survival_days'].values, df['Brats20ID'].values

In [32]:
all_labels, all_data_ids = preprocess_labels(f'{train_path}/survival_info.csv')
train_id, val_id, train_labels, val_labels = train_test_split(all_data_ids, all_labels, test_size=0.2, random_state=42)

train_data_list = create_data_list(train_path, train_id, train_labels, modality_keys)
val_data_list = create_data_list(train_path, val_id, val_labels, modality_keys)


# since we dont have lables for validation dataset not using it for validation
# val_data_list = create_data_list(val_path, modality_keys)
# val_labels = preprocess_labels(f'{val_path}/survival_evaluation.csv')



In [33]:
def get_transforms(modality_keys, pixdim=(1.0, 1.0, 1.0), is_train=True):
    transform_list = [
        LoadImaged(keys=modality_keys, reader=NibabelReader()),
        EnsureChannelFirstd(keys=modality_keys),
        Spacingd(keys=modality_keys, pixdim=pixdim, mode=("bilinear")),
        Orientationd(keys=modality_keys, axcodes="RAS"),
        ScaleIntensityRanged(keys=modality_keys, a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=modality_keys, source_key=modality_keys[0], allow_smaller=True),
        ResizeWithPadOrCropd(keys=modality_keys, spatial_size=(256, 256, 160)),  # Kept original size
        EnsureTyped(keys=modality_keys),
    ]
    
    if is_train:
        transform_list.extend([
            RandFlipd(keys=modality_keys, prob=0.5, spatial_axis=0),
            RandFlipd(keys=modality_keys, prob=0.5, spatial_axis=1),
            RandFlipd(keys=modality_keys, prob=0.5, spatial_axis=2),
            RandRotate90d(keys=modality_keys, prob=0.5, max_k=3),
            RandShiftIntensityd(keys=modality_keys, offsets=0.10, prob=0.5),
        ])
    
    return Compose(transform_list)

In [34]:
train_transforms = get_transforms(modality_keys, is_train=True)
val_transforms = get_transforms(modality_keys, is_train=False)


In [35]:
train_ds = CacheDataset(
    data=train_data_list,
    transform=train_transforms,
    cache_rate=0.5,
    num_workers=4,
)
train_loader = DataLoader(train_ds,  batch_size=2, shuffle=True, num_workers=4)


Loading dataset: 100%|██████████| 94/94 [05:03<00:00,  3.23s/it]


In [36]:
val_ds = CacheDataset(
    data=val_data_list,
    transform=val_transforms,
    cache_rate=0.5,
    num_workers=4,
)
val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, num_workers=2)

Loading dataset: 100%|██████████| 24/24 [01:08<00:00,  2.87s/it]


In [37]:
# Load the pretrained model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
model = SwinUNETR(
    img_size=(256, 256, 160),  # Kept original size
    in_channels=len(modality_keys),
    out_channels=len(modality_keys),
    feature_size=24,  # Reduced from 48 to save memory
    use_checkpoint=True,
).to(device)

# Load the saved weights
modality_used = "_".join(modality_keys)
model_save_path = f"model_saved/swin_unetr_{modality_used}_best.pth"

model.load_state_dict(torch.load(model_save_path))
model.eval()

SwinUNETR(
  (swinViT): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv3d(3, 24, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers1): ModuleList(
      (0): BasicLayer(
        (blocks): ModuleList(
          (0-1): 2 x SwinTransformerBlock(
            (norm1): LayerNorm((24,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=24, out_features=72, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=24, out_features=24, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((24,), eps=1e-05, elementwise_affine=True)
            (mlp): MLPBlock(
              (linear1): Linear(in_features=24, out_features=96, bias=True)
              (linear2): Linear(in_features=

In [38]:
class FeatureExtractorSwinUNETR(SwinUNETR):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    
    def extract_features(self, x):
        hidden_states = self.swinViT(x, self.normalize)
        # print("starting ------")
        # Adaptive pooling to resize all hidden states to a common size
        pooled_states = []
        for state in hidden_states:
            # Adaptive average pooling to 4x4x4
       
            pooled = F.adaptive_avg_pool3d(state, (4, 4, 4))
            pooled_states.append(pooled)
        # print("end ====\n")
        # Concatenate the pooled states
        return torch.cat(pooled_states, dim=1)


In [39]:
#  Create an instance of the feature extractor
feature_extractor = FeatureExtractorSwinUNETR(
     img_size=(256, 256, 160),  # Kept original size
    in_channels=len(modality_keys),
    out_channels=len(modality_keys),
    feature_size=24,  # Reduced from 48 to save memory
    use_checkpoint=True,
).to(device)
feature_extractor.load_state_dict(model.state_dict())
feature_extractor.eval()

FeatureExtractorSwinUNETR(
  (swinViT): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv3d(3, 24, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers1): ModuleList(
      (0): BasicLayer(
        (blocks): ModuleList(
          (0-1): 2 x SwinTransformerBlock(
            (norm1): LayerNorm((24,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=24, out_features=72, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=24, out_features=24, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((24,), eps=1e-05, elementwise_affine=True)
            (mlp): MLPBlock(
              (linear1): Linear(in_features=24, out_features=96, bias=True)
              (linear2): Lin

In [40]:
def extract_features_from_dataset(dataloader, feature_extractor):
    features = []
    
    for batch in dataloader:
        inputs = torch.cat([batch[key] for key in modality_keys], dim=1).to(device)
        with torch.no_grad():
            feature = feature_extractor.extract_features(inputs)
        
        # Global average pooling to reduce spatial dimensions
        feature = torch.mean(feature, dim=[2, 3, 4])
        
        features.append(feature.cpu().numpy())
    
    return np.concatenate(features)


In [41]:

# Extract features for training and validation sets
train_features = extract_features_from_dataset(train_loader, feature_extractor)
val_features = extract_features_from_dataset(val_loader, feature_extractor)


In [42]:
# Define Random Forest Regressor model
rf_model = RandomForestRegressor(random_state=42)


In [43]:
# Define hyperparameter search space for regression
param_dist = {
    'n_estimators': [100, 200, 300, 400, 500],
    'max_depth': [None, 10, 20, 30, 40, 50],
    'min_samples_split': [2, 5, 10],
    'min_samples_leaf': [1, 2, 4],
    'max_features': ['auto', 'sqrt', 'log2']
}

In [44]:
# Perform RandomizedSearchCV
random_search = RandomizedSearchCV(rf_model, param_distributions=param_dist, n_iter=25, cv=5, 
                                   scoring='neg_mean_squared_error', random_state=42, n_jobs=-1)


In [45]:
# Fit the model
random_search.fit(train_features, train_labels)


35 fits failed out of a total of 125.
The score on these train-test partitions for these parameters will be set to nan.
If these failures are not expected, you can try to debug them by setting error_score='raise'.

Below are more details about the failures:
--------------------------------------------------------------------------------
6 fits failed with the following error:
Traceback (most recent call last):
  File "/home/m1/23CS60R48/anaconda3/envs/gpu/lib/python3.9/site-packages/sklearn/model_selection/_validation.py", line 888, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
  File "/home/m1/23CS60R48/anaconda3/envs/gpu/lib/python3.9/site-packages/sklearn/base.py", line 1466, in wrapper
    estimator._validate_params()
  File "/home/m1/23CS60R48/anaconda3/envs/gpu/lib/python3.9/site-packages/sklearn/base.py", line 666, in _validate_params
    validate_parameter_constraints(
  File "/home/m1/23CS60R48/anaconda3/envs/gpu/lib/python3.9/site-packages/sklearn/utils/

In [46]:
# Get the best model
best_rf_model = random_search.best_estimator_
print("Best parameters found: ", random_search.best_params_)


Best parameters found:  {'n_estimators': 200, 'min_samples_split': 5, 'min_samples_leaf': 2, 'max_features': 'sqrt', 'max_depth': None}


In [47]:
# Predict on validation set
val_pred = best_rf_model.predict(val_features)


In [48]:
# Calculate regression metrics
mse = mean_squared_error(val_labels, val_pred)
rmse = np.sqrt(mse)
r2 = r2_score(val_labels, val_pred)
mae = mean_absolute_error(val_labels, val_pred)


In [49]:
print(f"Validation Mean Squared Error: {mse}")
print(f"Validation Root Mean Squared Error: {rmse}")
print(f"Validation R-squared Score: {r2}")
print(f"Validation Mean Absolute Error: {mae}")

Validation Mean Squared Error: 193447.61485996775
Validation Root Mean Squared Error: 439.8268009796217
Validation R-squared Score: -0.17036217351057448
Validation Mean Absolute Error: 363.55815714782904


In [50]:
pred_lable = categorize_days(val_pred)
real_lable = categorize_days(val_labels)
results_RF = evaluate_predictions(real_lable, pred_lable)


Accuracy: 0.23
Precision: 0.12
Recall: 0.23
F1 Score: 0.16
Confusion Matrix:
[[ 0  5 17]
 [ 0  3  8]
 [ 0  7  8]]
Classification Report:
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        22
           1       0.20      0.27      0.23        11
           2       0.24      0.53      0.33        15

    accuracy                           0.23        48
   macro avg       0.15      0.27      0.19        48
weighted avg       0.12      0.23      0.16        48



  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [51]:
# Print feature shapes for reference
print("Train features shape:", train_features.shape)
print("Validation features shape:", val_features.shape)

Train features shape: (188, 744)
Validation features shape: (48, 744)
