In [11]:
import os
import torch
import monai
from monai.transforms import (
    Compose, LoadImaged, EnsureChannelFirstd, Spacingd, Orientationd, ScaleIntensityRanged, CropForegroundd,
    RandCropByPosNegLabeld, RandFlipd, RandRotate90d, RandShiftIntensityd, EnsureTyped, DivisiblePadd,ResizeWithPadOrCropd
)
from monai.data import DataLoader, CacheDataset
from monai.networks.nets import SwinUNETR
from monai.utils import set_determinism
from monai.data.image_reader import NibabelReader
from sklearn.metrics import mean_absolute_error, mean_squared_error,accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import numpy as np
import pandas as pd
import lightgbm as lgb
from sklearn.preprocessing import StandardScaler
import torch.nn.functional as F
import re
from sklearn.model_selection import RandomizedSearchCV
import csv

from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report

In [12]:
 # Define directories
# Root folders of the training and validation splits (relative to the notebook).
train_path = 'dataset/MICCAI_BraTS2020_TrainingData/'
val_path = 'dataset/MICCAI_BraTS2020_ValidationData/'
# Run on the first CUDA device when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")


In [13]:
def create_data_list(data_dir, patient_ids, lables, modality_keys):
    """Build one dict per patient mapping each modality to its NIfTI path.

    Args:
        data_dir: Folder containing one sub-folder per patient. A trailing
            path separator is optional.
        patient_ids: Sequence of patient folder names.
        lables: Sequence of labels, positionally aligned with ``patient_ids``.
        modality_keys: Modality names; each becomes a dict key whose value is
            ``<data_dir>/<patient>/<patient>_<modality>.nii``.

    Returns:
        A list of dicts (modality keys plus ``'lable'``). Patients whose
        folder does not exist under ``data_dir`` are skipped.
    """
    data_list = []
    for idx, patient in enumerate(patient_ids):
        # os.path.join instead of string concatenation: a data_dir without a
        # trailing separator used to make every isdir() check fail silently.
        patient_dir = os.path.join(data_dir, patient)
        if not os.path.isdir(patient_dir):
            continue
        data_dict = {
            key: os.path.join(patient_dir, f"{patient}_{key}.nii")
            for key in modality_keys
        }
        # The misspelt key 'lable' is kept on purpose: other code reads it.
        data_dict['lable'] = lables[idx]
        data_list.append(data_dict)

    return data_list

In [14]:
def create_data_list_val(data_dir, modality_keys):
    """Build the unlabeled data list for the validation split.

    Patient ids are read from the ``BraTS20ID`` column of
    ``survival_evaluation.csv`` located in ``data_dir``.

    Args:
        data_dir: Validation folder holding ``survival_evaluation.csv`` and
            one sub-folder per patient. A trailing separator is optional.
        modality_keys: Modality names; each becomes a dict key whose value is
            ``<data_dir>/<patient>/<patient>_<modality>.nii``.

    Returns:
        A list of dicts, one per patient whose folder exists, in CSV order.
        No label key is added.
    """
    # Read the CSV from the directory that was passed in; the previous version
    # ignored `data_dir` here and depended on the global `val_path`.
    df = pd.read_csv(os.path.join(data_dir, 'survival_evaluation.csv'))
    patient_ids = df['BraTS20ID'].values

    data_list = []
    for patient in patient_ids:
        patient_dir = os.path.join(data_dir, patient)
        if not os.path.isdir(patient_dir):
            continue
        data_list.append({
            key: os.path.join(patient_dir, f"{patient}_{key}.nii")
            for key in modality_keys
        })

    return data_list

In [15]:
def preprocess_labels(csv_file_path):
    """Load survival labels and patient ids from a survival-info CSV.

    ``Survival_days`` entries that are strings are reduced to the first run
    of digits they contain (e.g. ``"ALIVE (361 days later)"`` -> ``361``).
    Rows whose value cannot be turned into a number are dropped.

    Args:
        csv_file_path: Path to a CSV with ``Survival_days`` and ``Brats20ID``
            columns.

    Returns:
        Tuple ``(survival_days, patient_ids)`` of equally long arrays:
        integer survival days and the matching ``Brats20ID`` values.
    """
    df = pd.read_csv(csv_file_path)

    def extract_number(value):
        # Only strings need parsing; numeric cells and NaN pass through.
        if isinstance(value, str):
            match = re.search(r'\d+', value)
            return int(match.group()) if match else None
        return value

    # Work on a standalone Series plus a boolean mask instead of assigning
    # back into the frame returned by dropna(), which triggered pandas'
    # SettingWithCopyWarning.
    survival_days = pd.to_numeric(
        df['Survival_days'].apply(extract_number), errors='coerce'
    )
    has_label = survival_days.notna()

    return (
        survival_days[has_label].astype(int).values,
        df.loc[has_label, 'Brats20ID'].values,
    )

In [16]:

# Since we don't have labels for the validation dataset, it is not used for validation.
# valdate_data_list = create_data_list_val(val_path, modality_keys)
# val_labels = preprocess_labels(f'{val_path}/survival_evaluation.csv')



In [17]:
# import nibabel as nib
# import torch

# # Load the NIfTI image
# file_path = train_data_list[0]['flair']
# img = nib.load(file_path)

# # Get the image data as a numpy array   
# img_data = img.get_fdata()

# # Convert the numpy array to a PyTorch tensor
# img_tensor = torch.tensor(img_data, dtype=torch.float32)

# # Print the shape of the tensor
# print(f'Tensor shape: {img_tensor.shape}')


In [18]:
def get_transforms(modality_keys, pixdim=(1.0, 1.0, 1.0), is_train=True):
    """Assemble the MONAI dictionary-transform pipeline for the given modalities.

    Args:
        modality_keys: Dict keys of the image volumes every transform acts on.
        pixdim: Target voxel spacing handed to ``Spacingd``.
        is_train: When true, random flip / rotate / intensity-shift
            augmentations are appended after the deterministic steps.

    Returns:
        A ``Compose`` wrapping the selected transforms.
    """
    # Deterministic preprocessing shared by training and inference.
    # NOTE(review): the a_min=-175 / a_max=250 window resembles a CT intensity
    # range -- confirm it is intended for these volumes and that it matches how
    # the saved checkpoint was trained before changing it.
    pipeline = [
        LoadImaged(keys=modality_keys, reader=NibabelReader()),
        EnsureChannelFirstd(keys=modality_keys),
        Spacingd(keys=modality_keys, pixdim=pixdim, mode=("bilinear")),
        Orientationd(keys=modality_keys, axcodes="RAS"),
        ScaleIntensityRanged(keys=modality_keys, a_min=-175, a_max=250, b_min=0.0, b_max=1.0, clip=True),
        CropForegroundd(keys=modality_keys, source_key=modality_keys[0], allow_smaller=True),
        ResizeWithPadOrCropd(keys=modality_keys, spatial_size=(256, 256, 160)),  # Kept original size
        EnsureTyped(keys=modality_keys),
    ]

    if not is_train:
        return Compose(pipeline)

    # Training only: one random flip per spatial axis, then a random 90-degree
    # rotation and a random intensity shift.
    pipeline += [
        RandFlipd(keys=modality_keys, prob=0.5, spatial_axis=axis)
        for axis in (0, 1, 2)
    ]
    pipeline.append(RandRotate90d(keys=modality_keys, prob=0.5, max_k=3))
    pipeline.append(RandShiftIntensityd(keys=modality_keys, offsets=0.10, prob=0.5))
    return Compose(pipeline)

In [19]:
class FeatureExtractorSwinUNETR(SwinUNETR):
    """SwinUNETR subclass that exposes pooled encoder hidden states as features.

    Construction is inherited unchanged from ``SwinUNETR``.
    """

    def extract_features(self, x):
        """Return the encoder hidden states pooled to 4x4x4 and joined on dim 1.

        Args:
            x: Input batch passed to ``self.swinViT``.

        Returns:
            A tensor made by adaptive-average-pooling every hidden state that
            ``self.swinViT`` returns to a (4, 4, 4) spatial grid and
            concatenating the results along dimension 1.
        """
        encoder_states = self.swinViT(x, self.normalize)
        # Pooling each level to one common grid is what makes the
        # concatenation along dim 1 possible.
        return torch.cat(
            [F.adaptive_avg_pool3d(level, (4, 4, 4)) for level in encoder_states],
            dim=1,
        )


In [20]:
def make_csv(y_pred_validation, modality_used):
    """Write validation predictions to a two-column (ID, Days) CSV file.

    Patient ids come from the ``BraTS20ID`` column of
    ``survival_evaluation.csv`` under the global ``val_path`` and are paired
    positionally with the predictions.

    Args:
        y_pred_validation: Sequence of predicted values, one per id, in the
            same order as the ids in the CSV.
        modality_used: String prepended to the output file name.

    Raises:
        ValueError: If the number of predictions differs from the number of
            ids, because the rows would otherwise be silently truncated or
            attributed to the wrong patient.
    """
    df = pd.read_csv(f'{val_path}/survival_evaluation.csv')
    validation_ids = df['BraTS20ID'].values

    if len(validation_ids) != len(y_pred_validation):
        raise ValueError(
            f"Got {len(y_pred_validation)} predictions for "
            f"{len(validation_ids)} validation ids; refusing to write a "
            "misaligned CSV."
        )

    filename = f"./global_predictons/Light_GBM/{modality_used}Light_GBM.csv"
    # Create the output folder on first use instead of failing in open().
    os.makedirs(os.path.dirname(filename), exist_ok=True)

    # Writing to csv file
    with open(filename, mode='w', newline='') as file:
        writer = csv.writer(file)
        writer.writerow(["ID", "Days"])  # Writing the header
        for patient_id, day in zip(validation_ids, y_pred_validation):
            writer.writerow([patient_id, day])

    print(f"CSV file '{filename}' created successfully.")

In [21]:
def extract_features_from_dataset(modality_keys, dataloader, feature_extractor):
    """Run the feature extractor over a loader and return one vector per sample.

    Args:
        modality_keys: Batch keys whose tensors are concatenated along dim 1
            to form the network input.
        dataloader: Iterable of batches (dicts indexed by ``modality_keys``).
        feature_extractor: Object with an ``extract_features(inputs)`` method.

    Returns:
        A NumPy array made by concatenating, across batches, the extractor
        output averaged over dimensions 2, 3 and 4.
    """
    def pooled_batch_features(batch):
        # Join the modalities along dim 1 and move them to the global device.
        stacked = torch.cat([batch[name] for name in modality_keys], dim=1).to(device)
        with torch.no_grad():
            volume_features = feature_extractor.extract_features(stacked)
        # Global average pooling to reduce spatial dimensions
        return torch.mean(volume_features, dim=[2, 3, 4]).cpu().numpy()

    return np.concatenate([pooled_batch_features(batch) for batch in dataloader])


In [22]:
def build_model(modality_keys, train_data_list, valdate_data_list):
    """Extract SwinUNETR features, fit a LightGBM regressor, write predictions.

    Steps:
      1. Build cached datasets and loaders for both data lists.
      2. Load the saved SwinUNETR weights for this modality combination into
         a ``FeatureExtractorSwinUNETR``.
      3. Extract one feature vector per case and standardize them (scaler is
         fitted on the training features only).
      4. Tune an ``LGBMRegressor`` with ``RandomizedSearchCV`` on the training
         features and the labels stored in ``train_data_list``.
      5. Predict on the validation features and hand them to ``make_csv``.

    Args:
        modality_keys: Modality names; they select the dict keys to load, the
            number of network channels and the checkpoint file name.
        train_data_list: List of dicts with one path per modality plus a
            ``'lable'`` entry holding the regression target.
        valdate_data_list: List of dicts with one path per modality (no label).

    Raises:
        ValueError: If the number of extracted training feature rows differs
            from the number of training labels.
    """
    train_transforms = get_transforms(modality_keys, is_train=True)
    val_transforms = get_transforms(modality_keys, is_train=False)
    # NOTE(review): the training pipeline includes random augmentations, so the
    # extracted training features differ from run to run. Consider
    # is_train=False for feature extraction -- confirm intent before changing.

    train_ds = CacheDataset(
        data=train_data_list,
        transform=train_transforms,
        cache_rate=0.5,
        num_workers=4,
    )
    # shuffle must stay False: features are matched to labels by position, and
    # a shuffled loader pairs each feature row with the wrong label.
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=False, num_workers=4)

    validate_ds = CacheDataset(
        data=valdate_data_list,
        transform=val_transforms,
        cache_rate=0.5,
        num_workers=4,
    )
    validate_loader = DataLoader(validate_ds, batch_size=2, shuffle=False, num_workers=2)

    modality_used = "_".join(modality_keys)
    model_save_path = f"model_saved/swin_unetr_{modality_used}_best.pth"

    # Load the checkpoint straight into the feature extractor. It subclasses
    # SwinUNETR without adding parameters, so the state-dict keys are the same;
    # this avoids building a second identical network on the device.
    feature_extractor = FeatureExtractorSwinUNETR(
        img_size=(256, 256, 160),  # Kept original size
        in_channels=len(modality_keys),
        out_channels=len(modality_keys),
        feature_size=24,  # Reduced from 48 to save memory
        use_checkpoint=True,
    ).to(device)
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine.
    feature_extractor.load_state_dict(torch.load(model_save_path, map_location=device))
    feature_extractor.eval()

    # Extract features for training and validation sets
    train_features = extract_features_from_dataset(modality_keys, train_loader, feature_extractor)
    validate_features = extract_features_from_dataset(modality_keys, validate_loader, feature_extractor)

    # Take the targets from the data list itself so they cover exactly the
    # cases that were loaded, in the same order as the un-shuffled loader.
    train_targets = np.asarray([sample['lable'] for sample in train_data_list])
    if len(train_targets) != len(train_features):
        raise ValueError(
            f"Extracted {len(train_features)} training feature rows but found "
            f"{len(train_targets)} labels."
        )

    # Preprocess the data
    scaler = StandardScaler()
    train_features_scaled = scaler.fit_transform(train_features)
    validate_features_scaled = scaler.transform(validate_features)

    # Hyperparameter tuning for LightGBM
    param_dist = {
        'num_leaves': [31, 63, 127],
        'max_depth': [-1, 5, 10, 20],
        'learning_rate': [0.01, 0.05, 0.1],
        'n_estimators': [100, 200, 300],
        'min_child_samples': [10, 20, 30]
    }

    # Train LightGBM regression model
    lgb_model = lgb.LGBMRegressor(random_state=42)
    random_search = RandomizedSearchCV(lgb_model, param_distributions=param_dist, n_iter=20, cv=5, random_state=42, n_jobs=-1)
    random_search.fit(train_features_scaled, train_targets)

    # making prediction for validation data
    y_pred_validation = random_search.predict(validate_features_scaled)

    make_csv(y_pred_validation, modality_used)



In [23]:
# val_ds = CacheDataset(
#     data=val_data_list,
#     transform=val_transforms,
#     cache_rate=0.5,
#     num_workers=4,
# )
# val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, num_workers=2)

In [24]:

# def evaluate_predictions(y_true, y_pred):
#     # Calculate metrics
#     accuracy = accuracy_score(y_true, y_pred)
#     precision = precision_score(y_true, y_pred, average='weighted')
#     recall = recall_score(y_true, y_pred, average='weighted')
#     f1 = f1_score(y_true, y_pred, average='weighted')
#     conf_matrix = confusion_matrix(y_true, y_pred)
#     class_report = classification_report(y_true, y_pred, output_dict=True)
    
#     # Print metrics
#     print(f'Accuracy: {accuracy:.2f}')
#     print(f'Precision: {precision:.2f}')
#     print(f'Recall: {recall:.2f}')
#     print(f'F1 Score: {f1:.2f}')
#     print('Confusion Matrix:')
#     print(conf_matrix)
#     print('Classification Report:')
#     print(classification_report(y_true, y_pred))
    
#     # Return metrics in a dictionary
#     return {
#         'accuracy': accuracy,
#         'precision': precision,
#         'recall': recall,
#         'f1_score': f1,
#         'confusion_matrix': conf_matrix,
#         'classification_report': class_report
#     }


In [25]:
# def categorize_days(days):
#     categorized_days = []
#     for day in days:
#         if day > 455:
#             categorized_days.append(2)
#         elif day < 304:
#             categorized_days.append(0)
#         else:
#             categorized_days.append(1)
#     return categorized_days

In [26]:
# Make predictions
# y_pred_1 = random_search.predict(val_features)


In [27]:
# y_pred_validation = random_search.predict(validate_features_scaled)
# validation_res = categorize_days(y_pred_validation)
# real_lable = categorize_days(val_labels)
# results_1 = evaluate_predictions(real_lable, pred_lable)

In [28]:
# pred_lable = categorize_days(y_pred_1)
# real_lable = categorize_days(val_labels)

In [29]:
# results_1 = evaluate_predictions(real_lable, pred_lable)


In [30]:
# # Evaluate the model
# mae = mean_absolute_error(val_labels, y_pred_1)
# rmse = mean_squared_error(val_labels, y_pred_1, squared=False)
# print(f"Mean Absolute Error: {mae}")
# print(f"Root Mean Squared Error: {rmse}")

In [31]:
# Modality combinations to evaluate, one pipeline run per entry.
modality_keys_list = [
    ["flair"],
    ["t1ce"],
    ["flair", "t1ce"],
    ["flair", "t1ce", "t2"],
    ["flair", "t1", "t1ce", "t2"],
]

In [32]:
# Survival targets and the matching patient ids for the training split.
survival_info_csv = f'{train_path}/survival_info.csv'
train_labels, train_id = preprocess_labels(survival_info_csv)

In [33]:
# Run the full feature-extraction + regression pipeline once per modality set.
for modality_keys in modality_keys_list:
    print("now working on", modality_keys)
    # Kept as notebook-level variables; both equal the number of modalities.
    in_channels = out_channels = len(modality_keys)
    valdate_data_list = create_data_list_val(val_path, modality_keys)
    train_data_list = create_data_list(train_path, train_id, train_labels, modality_keys)

    build_model(modality_keys, train_data_list, valdate_data_list)


Loading dataset: 100%|██████████| 118/118 [02:12<00:00,  1.12s/it]
Loading dataset: 100%|██████████| 14/14 [00:22<00:00,  1.62s/it]


KeyboardInterrupt: 

In [None]:
# modality_keys = ["flair", "t1ce", "t2"]
# in_channels = len(modality_keys)
# out_channels = len(modality_keys) 

In [None]:
# train_labels, train_id = preprocess_labels(f'{train_path}/survival_info.csv')


In [None]:

# train_data_list = create_data_list(train_path, train_id, train_labels, modality_keys)
# valdate_data_list = create_data_list_val(val_path, modality_keys)
