In [1]:
import os
import torch
import monai
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd, ScaleIntensityRanged, CropForegroundd,
    RandCropByPosNegLabeld, RandFlipd, RandRotate90d, RandShiftIntensityd, EnsureTyped, DivisiblePadd
)
from monai.data import DataLoader, CacheDataset
from monai.networks.nets import SwinUNETR
from monai.utils import set_determinism
from monai.data.image_reader import NibabelReader
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn.preprocessing import StandardScaler
import torch.nn.functional as F

In [2]:
 # Define directories
train_path = 'dataset/MICCAI_BraTS2020_TrainingData/'
# val_path = 'dataset/MICCAI_BraTS2020_ValidationData/'
modality_keys = ["flair"]
in_channels = len(modality_keys)
out_channels = len(modality_keys) 

In [3]:
def create_data_list(data_dir, patient_ids, lables, modality_keys):
    data_list = []
    for idx, patient in enumerate(patient_ids):
        patient_dir = f"{data_dir}{patient}/"
        if os.path.isdir(patient_dir):
            data_dict = {key: os.path.join(patient_dir, f"{patient}_{key}.nii") for key in modality_keys}
            data_dict['lable'] = lables[idx]
            data_list.append(data_dict)

    return data_list

In [4]:
def preprocess_labels(csv_file_path):
    df = pd.read_csv(csv_file_path)
    df['Survival_days'] = pd.to_numeric(df['Survival_days'], errors='coerce')
    df = df.dropna(subset=['Survival_days'])
    df['Survival_days'] = df['Survival_days'].astype(int)
    df['Survival_class'] = df['Survival_days'].apply(lambda x: 0 if x < 300 else (1 if x < 450 else 2))
    return df['Survival_class'].values, df['Brats20ID'].values

In [5]:
all_labels, all_data_ids = preprocess_labels(f'{train_path}/survival_info.csv')
train_id, val_id, train_labels, val_labels = train_test_split(all_data_ids, all_labels, test_size=0.2, random_state=42)

train_data_list = create_data_list(train_path, train_id, train_labels, modality_keys)
val_data_list = create_data_list(train_path, val_id, val_labels, modality_keys)


# since we dont have lables for validation dataset not using it for validation
# val_data_list = create_data_list(val_path, modality_keys)
# val_labels = preprocess_labels(f'{val_path}/survival_evaluation.csv')



In [6]:
train_data_list[0]

{'flair': 'dataset/MICCAI_BraTS2020_TrainingData/BraTS20_Training_119/BraTS20_Training_119_flair.nii',
 'lable': 2}

In [7]:
import nibabel as nib
import torch

# Load the NIfTI image
file_path = train_data_list[0]['flair']
img = nib.load(file_path)

# Get the image data as a numpy array
img_data = img.get_fdata()

# Convert the numpy array to a PyTorch tensor
img_tensor = torch.tensor(img_data, dtype=torch.float32)

# Print the shape of the tensor
print(f'Tensor shape: {img_tensor.shape}')


Tensor shape: torch.Size([240, 240, 155])


In [8]:
def get_transforms(modality_keys, pixdim=(2.0, 2.0, 2.0), spatial_size=(64, 64, 64)):
    transforms = Compose(
        [
            LoadImaged(keys=modality_keys, reader=NibabelReader()),
            EnsureChannelFirstd(keys=modality_keys),
            Spacingd(keys=modality_keys, pixdim=pixdim, mode=("bilinear")),
            Orientationd(keys=modality_keys, axcodes="RAS"),
            ScaleIntensityRanged(keys=modality_keys, a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=modality_keys, source_key=modality_keys[0], allow_smaller=True),
            DivisiblePadd(keys=modality_keys, k=32),
            RandCropByPosNegLabeld(keys=modality_keys, label_key=modality_keys[0], spatial_size=spatial_size, pos=1, neg=1, num_samples=1, image_key=modality_keys[0], image_threshold=0),
            RandFlipd(keys=modality_keys, prob=0.5, spatial_axis=0),
            RandFlipd(keys=modality_keys, prob=0.5, spatial_axis=1),
            RandFlipd(keys=modality_keys, prob=0.5, spatial_axis=2),
            RandRotate90d(keys=modality_keys, prob=0.5, max_k=3),
            RandShiftIntensityd(keys=modality_keys, offsets=0.10, prob=0.5),
            EnsureTyped(keys=modality_keys),
        ]
    )
    return transforms

In [9]:
def get_val_transforms(modality_keys, pixdim=(2.0, 2.0, 2.0)):
    transforms = Compose(
        [
            LoadImaged(keys=modality_keys, reader=NibabelReader()),
            EnsureChannelFirstd(keys=modality_keys),
            Spacingd(keys=modality_keys, pixdim=pixdim, mode=("bilinear")),
            Orientationd(keys=modality_keys, axcodes="RAS"),
            ScaleIntensityRanged(keys=modality_keys, a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
            CropForegroundd(keys=modality_keys, source_key=modality_keys[0], allow_smaller=True),
            DivisiblePadd(keys=modality_keys, k=32),
            EnsureTyped(keys=modality_keys),
        ]
    )
    return transforms

In [10]:
train_transforms = get_transforms(modality_keys)
val_transforms = get_val_transforms(modality_keys)


In [11]:
train_ds = CacheDataset(
    data=train_data_list,
    transform=train_transforms,
    cache_rate=0.5,
    num_workers=4,
)
train_loader = DataLoader(train_ds,  batch_size=2, shuffle=True, num_workers=4)


Loading dataset:   0%|          | 0/94 [00:00<?, ?it/s]

Loading dataset: 100%|██████████| 94/94 [00:17<00:00,  5.40it/s]


In [12]:
# for data in train_loader:
#     print(data['flair'].shape)

In [13]:
val_ds = CacheDataset(
    data=val_data_list,
    transform=val_transforms,
    cache_rate=0.5,
    num_workers=4,
)
val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, num_workers=2)

Loading dataset: 100%|██████████| 23/23 [00:03<00:00,  5.88it/s]


In [14]:
# Load the pretrained model
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = SwinUNETR(
    img_size=(64, 64, 64),
    in_channels=in_channels,
    out_channels=out_channels,
    feature_size=48,
    use_checkpoint=True,
).to(device)


# Load the saved weights
modality_used = "_".join(modality_keys)
model_save_path = f"model_saved/swin_unetr_{modality_used}_temp.pth"

model.load_state_dict(torch.load(model_save_path))
model.eval()

  return torch._C._cuda_getDeviceCount() > 0


SwinUNETR(
  (swinViT): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv3d(1, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers1): ModuleList(
      (0): BasicLayer(
        (blocks): ModuleList(
          (0-1): 2 x SwinTransformerBlock(
            (norm1): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=48, out_features=144, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=48, out_features=48, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (mlp): MLPBlock(
              (linear1): Linear(in_features=48, out_features=192, bias=True)
              (linear2): Linear(in_feature

In [15]:
# # 2. Modify the SwinUNETR class to add a method for feature extraction
# class FeatureExtractorSwinUNETR(SwinUNETR):
#     def __init__(self, *args, **kwargs):
#         super().__init__(*args, **kwargs)
    
#     def extract_features(self, x):
#         print(" at start",x.size())
#         hidden_states = self.swinViT(x, self.normalize)
#         enc0 = self.encoder1(x)
#         print(" after enco 1",x.size())
#         enc1 = self.encoder2(hidden_states[0])
#         print(" after enco 2",enc1.size())
#         enc2 = self.encoder3(hidden_states[1])
#         print(" after enco 3",enc2.size())
#         enc3 = self.encoder4(hidden_states[2])
#         print(" after enco 4",enc3.size())
#         # dec3 = self.encoder10(hidden_states[3])
#         # print(" after dec3",dec3.size())
#         # return torch.cat([enc0, enc1, enc2, enc3, dec3], dim=1)
#         return torch.cat([enc0, enc1, enc2, enc3], dim=1)

In [16]:
class FeatureExtractorSwinUNETR(SwinUNETR):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
    
    def extract_features(self, x):
        hidden_states = self.swinViT(x, self.normalize)
        
        # Adaptive pooling to resize all hidden states to a common size
        pooled_states = []
        for state in hidden_states:
            # Adaptive average pooling to 4x4x4
            pooled = F.adaptive_avg_pool3d(state, (4, 4, 4))
            pooled_states.append(pooled)
        
        # Concatenate the pooled states
        return torch.cat(pooled_states, dim=1)


In [17]:
#  Create an instance of the feature extractor
feature_extractor = FeatureExtractorSwinUNETR(
    img_size=(64, 64, 64),
    in_channels=in_channels,
    out_channels=out_channels,
    feature_size=48,
    use_checkpoint=True,
).to(device)
feature_extractor.load_state_dict(model.state_dict())
feature_extractor.eval()

FeatureExtractorSwinUNETR(
  (swinViT): SwinTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv3d(1, 48, kernel_size=(2, 2, 2), stride=(2, 2, 2))
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (layers1): ModuleList(
      (0): BasicLayer(
        (blocks): ModuleList(
          (0-1): 2 x SwinTransformerBlock(
            (norm1): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (attn): WindowAttention(
              (qkv): Linear(in_features=48, out_features=144, bias=True)
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=48, out_features=48, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
              (softmax): Softmax(dim=-1)
            )
            (drop_path): Identity()
            (norm2): LayerNorm((48,), eps=1e-05, elementwise_affine=True)
            (mlp): MLPBlock(
              (linear1): Linear(in_features=48, out_features=192, bias=True)
              (linear2): L

In [18]:

# # 3. Extract features from your dataset
# def extract_features_from_dataset(dataset, feature_extractor):
#     features = []
#     labels = []
    
#     for data in dataset:
#         inputs = torch.cat([data[key] for key in modality_keys], dim=1).to(device)
#         with torch.no_grad():
#             feature = feature_extractor.extract_features(inputs)
        
#         # Assuming the feature is 5D (batch, channels, depth, height, width)
#         # We'll flatten it to 2D (batch, features)
#         feature = feature.view(feature.size(0), -1)
        
#         features.append(feature.cpu().numpy())
#         labels.append(data['label'].cpu().numpy())  # Assuming you have labels in your dataset
    
#     return np.concatenate(features), np.concatenate(labels)


In [19]:
def extract_features_from_dataset(dataloader, feature_extractor):
    features = []
    labels = []
    
    for batch in dataloader:
        inputs = torch.cat([batch[key] for key in modality_keys], dim=1).to(device)
        with torch.no_grad():
            feature = feature_extractor.extract_features(inputs)
        
        # Global average pooling to reduce spatial dimensions
        feature = torch.mean(feature, dim=[2, 3, 4])
        
        features.append(feature.cpu().numpy())
        # labels.append(batch['label'].cpu().numpy())  # Assuming you have labels in your dataset
    
    return np.concatenate(features)


In [20]:

# Extract features for training and validation sets
train_features = extract_features_from_dataset(train_loader, feature_extractor)
val_features = extract_features_from_dataset(val_loader, feature_extractor)




In [21]:
# Preprocess the data
scaler = StandardScaler()
train_features_scaled = scaler.fit_transform(train_features)
val_features_scaled = scaler.transform(val_features)


In [22]:
# Hyperparameter tuning for LightGBM
param_dist = {
    'num_leaves': [31, 63, 127],
    'max_depth': [-1, 5, 10, 20],
    'learning_rate': [0.01, 0.05, 0.1],
    'n_estimators': [100, 200, 300],
    'min_child_samples': [10, 20, 30]
}

In [23]:
from sklearn.model_selection import RandomizedSearchCV

In [24]:
# Train LightGBM model
lgb_model = lgb.LGBMClassifier(random_state=42)
random_search = RandomizedSearchCV(lgb_model, param_distributions=param_dist, n_iter=20, cv=5, random_state=42, n_jobs=-1)
random_search.fit(train_features, train_labels)

[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.258346 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.265861 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.262320 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.247358 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 76625
[LightGBM] [Info] Total Bins 76620
[LightGBM] [Info] Total Bins 76620
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.268522 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was

In [25]:
from sklearn.ensemble import RandomForestClassifier, VotingClassifier
best_lgb_model = random_search.best_estimator_

# Create an ensemble
rf_model = RandomForestClassifier(n_estimators=100, random_state=42)

ensemble = VotingClassifier(
    estimators=[('lgb', best_lgb_model), ('rf', rf_model)],
    voting='soft'
)

In [26]:
# Make predictions
y_pred = ensemble.predict(val_features)


NotFittedError: This VotingClassifier instance is not fitted yet. Call 'fit' with appropriate arguments before using this estimator.

In [None]:
# Evaluate the model
accuracy = accuracy_score(val_labels, y_pred)
print(f"Accuracy: {accuracy}")
print(classification_report(val_labels, y_pred))

Accuracy: 0.2765957446808511
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        19
           1       0.00      0.00      0.00        15
           2       0.28      1.00      0.43        13

    accuracy                           0.28        47
   macro avg       0.09      0.33      0.14        47
weighted avg       0.08      0.28      0.12        47



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


In [None]:
train_features.shape


(188, 1488)

In [None]:
val_features.shape

(47, 1488)

In [None]:
# Create LightGBM datasets
train_data = lgb.Dataset(train_features, label=train_labels)
val_data = lgb.Dataset(val_features, label=val_labels)

In [None]:
# Set parameters for LightGBM
params = {
    'objective': 'multiclass',
    'num_class': 3,
    'metric': 'multi_logloss',
    'boosting_type': 'gbdt',
    'num_leaves': 31,
    'learning_rate': 0.05,
    'feature_fraction': 0.9
}

In [None]:
# Train the model
num_round = 100
bst = lgb.train(params, train_data, num_round, valid_sets=[val_data], early_stopping_rounds=10)


TypeError: train() got an unexpected keyword argument 'early_stopping_rounds'

In [None]:
# Make predictions
val_preds = bst.predict(val_features)
val_preds_classes = np.argmax(val_preds, axis=1)


In [None]:
# Evaluate the model
accuracy = accuracy_score(val_labels, val_preds_classes)
print(f"Validation Accuracy: {accuracy:.4f}")

print("\nClassification Report:")
print(classification_report(val_labels, val_preds_classes, target_names=le.classes_))