In [None]:
import os
import tarfile
import nibabel as nib
import numpy as np
from tqdm import tqdm
import time

# -------------------------------------------------
# Define Constants and Parameters
# -------------------------------------------------
# NOTE(review): none of these constants is referenced by the functions defined in
# this cell; presumably later preprocessing/training cells consume them — confirm.
IMG_SIZE = 128  # presumably the side length slices are resized to — TODO confirm
VOLUME_SLICES = 100  # presumably the number of slices used per volume — TODO confirm
VOLUME_START_AT = 22  # presumably the index of the first slice used — TODO confirm
# Maps an integer segmentation label to its human-readable class name.
SEGMENT_CLASSES = {
    0: 'NOT tumor',
    1: 'NECROTIC/CORE',
    2: 'EDEMA',
    3: 'ENHANCING'
}

# -------------------------------------------------
# Function Definitions
# -------------------------------------------------
def extract_tar_files(tar_path, extract_to):
    """
    Extracts a tar archive member by member with progress tracking.

    The archive is opened in mode 'r', so ``tarfile`` detects the compression
    itself. Each member path is validated before extraction so that an archive
    entry cannot be written outside ``extract_to`` (for example through ``../``
    components or an absolute path).

    Args:
        tar_path: Path of the archive to extract.
        extract_to: Directory that receives the extracted files.

    Raises:
        ValueError: If a member would be extracted outside ``extract_to``.
    """
    print(f"\n{'='*50}")
    print(f"Starting extraction of {os.path.basename(tar_path)}")
    print(f"{'='*50}")

    destination_root = os.path.realpath(extract_to)
    # Extraction filters only exist on newer Python releases (and security
    # backports); use the conservative "data" filter when it is available.
    extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}

    with tarfile.open(tar_path, 'r') as tar:
        members = tar.getmembers()
        print(f"Total files in archive: {len(members)}")

        for member in tqdm(members, desc="Extracting files"):
            # Resolve where this member would land and refuse anything that
            # escapes the destination directory.
            target = os.path.realpath(os.path.join(destination_root, member.name))
            if os.path.commonpath([destination_root, target]) != destination_root:
                raise ValueError(
                    f"Unsafe path in archive, refusing to extract: {member.name}"
                )
            tar.extract(member, extract_to, **extract_kwargs)

    print(f"✓ Extraction complete: {tar_path}")
    print(f"✓ Files extracted to: {extract_to}\n")

def _path_size_bytes(path):
    """Returns the size of a file, or the summed size of all files below a directory."""
    if not os.path.isdir(path):
        return os.path.getsize(path)

    total = 0
    for root, _dirs, names in os.walk(path):
        for name in names:
            entry = os.path.join(root, name)
            # Broken symlinks cannot be stat'ed; leave them out of the total.
            if os.path.exists(entry):
                total += os.path.getsize(entry)
    return total


def list_files_in_directory(directory):
    """
    Lists the entries of a directory together with their size in MB.

    For sub-directories the reported size is the total size of the files they
    contain (recursively), not the size of the directory entry itself.

    Args:
        directory: Path of the directory to scan.

    Returns:
        The entry names exactly as returned by ``os.listdir(directory)``.
    """
    print(f"\n{'='*50}")
    print(f"Scanning directory: {directory}")
    print(f"{'='*50}")

    files = os.listdir(directory)

    print(f"Found {len(files)} files/directories:")
    for i, file in enumerate(files, 1):
        file_path = os.path.join(directory, file)
        size = _path_size_bytes(file_path) / (1024 * 1024)  # Convert to MB
        print(f"{i}. {file} ({size:.2f} MB)")

    return files

def load_nifti_image(file_path):
    """
    Reads a NIfTI file into an array and prints load statistics.

    Args:
        file_path: Path of the NIfTI file to read.

    Returns:
        The array returned by ``get_fdata()`` for the loaded image.
    """
    t0 = time.time()
    print(f"\nLoading: {os.path.basename(file_path)}")

    volume = nib.load(file_path).get_fdata()

    seconds = time.time() - t0
    megabytes = volume.nbytes / (1024 * 1024)  # bytes -> MB

    report = (
        f"✓ Load time: {seconds:.2f} seconds",
        f"✓ Data shape: {volume.shape}",
        f"✓ Memory usage: {megabytes:.2f} MB",
        f"✓ Data type: {volume.dtype}",
    )
    for line in report:
        print(line)

    return volume

def load_training_data(dataset_path, sample_id):
    """
    Loads the five volumes (flair, t1, t1ce, t2, seg) of one sample.

    Files are expected at ``<dataset_path>/<sample_id>/<sample_id>_<modality>.nii.gz``.

    Args:
        dataset_path: Directory containing one sub-directory per sample.
        sample_id: Name of the sample sub-directory and file-name prefix.

    Returns:
        Dict mapping each modality name to the array returned by
        ``load_nifti_image``.
    """
    banner = '=' * 50
    print(f"\n{banner}")
    print(f"Loading complete dataset for sample: {sample_id}")
    print(f"{banner}")

    sample_dir = os.path.join(dataset_path, sample_id)
    volumes = {
        name: load_nifti_image(os.path.join(sample_dir, f"{sample_id}_{name}.nii.gz"))
        for name in tqdm(['flair', 't1', 't1ce', 't2', 'seg'], desc="Loading modalities")
    }

    print("\n✓ All modalities loaded successfully")
    return volumes

# -------------------------------------------------
# Usage
# -------------------------------------------------
def main():
    """
    Runs step 1 of the pipeline: extract the archives, list the sample
    directory, load one training sample and print a per-modality summary.

    NOTE(review): the extracted sample archive is BraTS2021_00621, while the
    sample that is loaded is BraTS2021_01261 from the training directory —
    confirm this is intentional.
    """
    print("\n🧠 Starting Brain Tumor Image Processing Pipeline 🧠\n")

    # Working directories that receive the extracted archives
    training_data_path = './BraTS2021_Training_Data'
    sample_data_path = './sample_img'

    # (archive, destination) pairs, extracted in this order
    archives = (
        ('../input/brats-2021-task1/BraTS2021_Training_Data.tar', training_data_path),
        ('../input/brats-2021-task1/BraTS2021_00621.tar', sample_data_path),
    )
    for archive_path, destination in archives:
        extract_tar_files(archive_path, destination)

    # Called for its printed listing; the returned names are not needed here
    list_files_in_directory(sample_data_path)

    sample_id = 'BraTS2021_01261'
    print(f"\nProcessing sample: {sample_id}")

    volumes = load_training_data(training_data_path, sample_id)

    # Display label -> key in the loaded dict
    display_names = (
        ('FLAIR', 'flair'),
        ('T1', 't1'),
        ('T1CE', 't1ce'),
        ('T2', 't2'),
        ('Segmentation', 'seg'),
    )
    named_volumes = {label: volumes[key] for label, key in display_names}

    print("\n📊 Final Data Summary:")
    for label, volume in named_volumes.items():
        print(f"✓ {label:<12} Shape: {volume.shape}, Range: [{volume.min():.2f}, {volume.max():.2f}]")

    print("\n✨ Processing pipeline step 1 completed successfully! ✨")


if __name__ == "__main__":
    main()

In [None]:
import os
import tarfile
import nibabel as nib
import numpy as np
from skimage import exposure
from skimage.transform import resize
from scipy.ndimage import rotate, gaussian_filter, map_coordinates
import matplotlib.pyplot as plt
from tqdm import tqdm
import time
from datetime import datetime
import nilearn as nl
import nilearn.plotting as nlplt

# -------------------------------------------------
# Constants and Parameters
# -------------------------------------------------
# NOTE(review): these constants repeat the values declared in the first cell and
# are not referenced by the functions defined in this cell; presumably later
# preprocessing/training cells consume them — confirm.
IMG_SIZE = 128  # presumably the side length slices are resized to — TODO confirm
VOLUME_SLICES = 100  # presumably the number of slices used per volume — TODO confirm
VOLUME_START_AT = 22  # presumably the index of the first slice used — TODO confirm
# Maps an integer segmentation label to its human-readable class name.
SEGMENT_CLASSES = {
    0: 'NOT tumor',
    1: 'NECROTIC/CORE',
    2: 'EDEMA',
    3: 'ENHANCING'
}

# -------------------------------------------------
# Utility Functions
# -------------------------------------------------
def extract_tar_files(tar_path, extract_to):
    """Extracts .tar or .tar.gz files with progress tracking.

    The archive is opened in mode 'r', so ``tarfile`` detects the compression
    itself. Each member path is validated before extraction so that an archive
    entry cannot be written outside ``extract_to`` (for example through ``../``
    components or an absolute path).

    Args:
        tar_path: Path of the archive to extract.
        extract_to: Directory that receives the extracted files.

    Raises:
        ValueError: If a member would be extracted outside ``extract_to``.
    """
    print(f"\n{'='*60}")
    print(f"📦 Extracting: {os.path.basename(tar_path)}")
    print(f"{'='*60}")

    destination_root = os.path.realpath(extract_to)
    # Extraction filters only exist on newer Python releases (and security
    # backports); use the conservative "data" filter when it is available.
    extract_kwargs = {"filter": "data"} if hasattr(tarfile, "data_filter") else {}

    with tarfile.open(tar_path, 'r') as file:
        members = file.getmembers()
        print(f"Total files to extract: {len(members)}")

        for member in tqdm(members, desc="Extracting files", unit="file"):
            # Resolve where this member would land and refuse anything that
            # escapes the destination directory.
            target = os.path.realpath(os.path.join(destination_root, member.name))
            if os.path.commonpath([destination_root, target]) != destination_root:
                raise ValueError(
                    f"Unsafe path in archive, refusing to extract: {member.name}"
                )
            file.extract(member, extract_to, **extract_kwargs)

    print(f"✅ Extraction complete: {len(members)} files extracted to {extract_to}\n")

def check_nifti(file_path):
    """Reports whether ``nib.load`` succeeds on the given file.

    Any exception raised while loading is caught, printed and turned into a
    ``False`` result.

    Args:
        file_path: Path of the NIfTI file to check.

    Returns:
        True when the file loaded without raising, False otherwise.
    """
    try:
        t0 = time.time()
        nib.load(file_path)
        elapsed = time.time() - t0
        print(f"✓ Validated {os.path.basename(file_path)} ({elapsed:.2f}s)")
        succeeded = True
    except Exception as err:
        print(f"❌ Error loading {os.path.basename(file_path)}: {str(err)}")
        succeeded = False
    return succeeded

def load_nifti(file_path):
    """Loads a NIfTI file via ``nib.load`` and prints how long the call took.

    Args:
        file_path: Path of the NIfTI file to load.

    Returns:
        The image object returned by ``nib.load``.
    """
    t0 = time.time()
    image = nib.load(file_path)
    elapsed = time.time() - t0
    file_name = os.path.basename(file_path)
    print(f"📂 Loaded {file_name} ({elapsed:.2f}s)")
    return image

def save_nifti(img, save_path):
    """Writes a NIfTI image via ``nib.save`` and prints how long the call took.

    Args:
        img: Image object to write.
        save_path: Destination file path.
    """
    t0 = time.time()
    nib.save(img, save_path)
    elapsed = time.time() - t0
    file_name = os.path.basename(save_path)
    print(f"💾 Saved {file_name} ({elapsed:.2f}s)")

# -------------------------------------------------
# Preprocessing Functions
# -------------------------------------------------
def normalize_nifti(img, desc=""):
    """Min-max normalizes the intensity values of a NIfTI image to [0, 1].

    A constant-valued image has a zero intensity range; it is mapped to all
    zeros instead of dividing by zero (which would fill the volume with NaNs).

    Args:
        img: Image providing ``get_fdata()`` and ``affine``.
        desc: Optional label appended to the progress message.

    Returns:
        A new ``nib.Nifti1Image`` holding the normalized data with the
        original affine.
    """
    start_time = time.time()
    data = img.get_fdata()
    lowest = np.min(data)
    value_range = np.max(data) - lowest
    if value_range == 0:
        # Constant image: there is no contrast to stretch.
        data = np.zeros_like(data)
    else:
        data = (data - lowest) / value_range
    normalized = nib.Nifti1Image(data, img.affine)
    process_time = time.time() - start_time
    print(f"⚖️  Normalized{' '+desc if desc else ''} ({process_time:.2f}s)")
    return normalized

def resize_nifti(img, target_shape=(256, 256, 128), desc=""):
    """Resamples the data of a NIfTI image to ``target_shape``.

    NOTE(review): the input affine is reused unchanged for the resized data,
    so the voxel spacing it encodes presumably no longer matches the new
    grid — confirm whether downstream code relies on the affine.

    Args:
        img: Image providing ``get_fdata()`` and ``affine``.
        target_shape: Shape passed to ``resize``.
        desc: Optional label appended to the progress message.

    Returns:
        A new ``nib.Nifti1Image`` holding the resized data.
    """
    t0 = time.time()
    volume = resize(img.get_fdata(), target_shape, anti_aliasing=True)
    result = nib.Nifti1Image(volume, img.affine)
    elapsed = time.time() - t0
    suffix = ' ' + desc if desc else ''
    print(f"📐 Resized{suffix} to {target_shape} ({elapsed:.2f}s)")
    return result

def adjust_brightness(img, brightness_factor=2, desc=""):
    """Applies gamma correction to the data of a NIfTI image.

    NOTE(review): ``brightness_factor`` is passed straight through as the
    ``gamma`` argument of ``exposure.adjust_gamma``; verify that a larger
    value has the brightening effect the parameter name suggests.

    Args:
        img: Image providing ``get_fdata()`` and ``affine``.
        brightness_factor: Value forwarded as ``gamma``.
        desc: Optional label appended to the progress message.

    Returns:
        A new ``nib.Nifti1Image`` holding the adjusted data with the original
        affine.
    """
    t0 = time.time()
    corrected = exposure.adjust_gamma(img.get_fdata(), gamma=brightness_factor)
    result = nib.Nifti1Image(corrected, img.affine)
    elapsed = time.time() - t0
    suffix = ' ' + desc if desc else ''
    print(f"🔆 Adjusted brightness{suffix} ({elapsed:.2f}s)")
    return result

# -------------------------------------------------
# Visualization Functions
# -------------------------------------------------
def plot_modality_comparison(images, slice_w=25):
    """Plots one slice of each MRI modality and the mask side by side.

    The displayed slice is taken along the last axis, ``slice_w`` positions
    before the middle of that axis. The index is derived from the size of the
    axis that is actually indexed and clamped to the valid range, so it cannot
    wrap around or run out of bounds.

    Args:
        images: Dict with keys 'flair', 't1', 't1ce', 't2' and 'seg', each a
            3-D array.
        slice_w: Offset (in slices) subtracted from the middle slice index.

    Returns:
        The matplotlib figure holding the five panels.
    """
    print("\n📊 Generating modality comparison plot...")

    fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1, 5, figsize=(20, 10))

    modalities = [
        ('flair', ax1, 'Image FLAIR'),
        ('t1', ax2, 'Image T1'),
        ('t1ce', ax3, 'Image T1CE'),
        ('t2', ax4, 'Image T2'),
        ('seg', ax5, 'Mask')
    ]

    for modality, ax, title in tqdm(modalities, desc="Plotting modalities"):
        img = images[modality]
        # Slices are indexed on axis 2, so the index must be computed from the
        # length of axis 2 (not axis 0) and kept inside [0, depth - 1].
        depth = img.shape[2]
        slice_idx = min(max(depth // 2 - slice_w, 0), depth - 1)

        if modality != 'seg':
            ax.imshow(img[:, :, slice_idx], cmap='gray')
        else:
            # The mask keeps matplotlib's default colormap
            ax.imshow(img[:, :, slice_idx])

        ax.set_title(title)
        ax.axis('off')

    plt.tight_layout()
    # Display the plot
    plt.show()
    print("✅ Modality comparison plot generated")
    return fig

def plot_detailed_visualization(flair_path, mask_path):
    """Draws four stacked nilearn views of a FLAIR volume.

    The first three rows use ``plot_anat``, ``plot_epi`` and ``plot_img`` on
    the FLAIR image; the last row overlays the mask on the FLAIR image with
    ``plot_roi``.

    Args:
        flair_path: Path of the FLAIR NIfTI file.
        mask_path: Path of the segmentation NIfTI file.

    Returns:
        The matplotlib figure, or None if any step raised an exception (the
        error is printed).
    """
    print("\n🎨 Generating detailed nilearn visualizations...")

    try:
        with tqdm(total=2, desc="Loading images") as loading_bar:
            flair_img = nl.image.load_img(flair_path)
            loading_bar.update(1)
            mask_img = nl.image.load_img(mask_path)
            loading_bar.update(1)

        fig, axes = plt.subplots(nrows=4, figsize=(30, 40))

        # (plotting function, title suffix); the last row is the ROI overlay
        # and has no single-image plotting function.
        panels = [
            (nlplt.plot_anat, 'Anatomical View'),
            (nlplt.plot_epi, 'EPI View'),
            (nlplt.plot_img, 'Standard View'),
            (None, 'ROI with Mask'),
        ]
        roi_row = 3

        for row, (plotter, label) in enumerate(tqdm(panels, desc="Generating views")):
            if row != roi_row:
                plotter(flair_img, title=f'FLAIR {label}', axes=axes[row])
                continue
            nlplt.plot_roi(
                mask_img,
                title=f'FLAIR with mask {label}',
                bg_img=flair_img,
                axes=axes[row],
                cmap='Paired'
            )

        # Display the plot
        plt.show()
        print("✅ Detailed visualization completed")
        return fig
    except Exception as err:
        print(f"❌ Error generating detailed visualization: {str(err)}")
        return None

# -------------------------------------------------
# Main Processing Pipeline
# -------------------------------------------------
def process_single_patient(sample_id, input_path, output_path):
    """Loads one patient's volumes and saves two visualization PNGs.

    Files are read from ``<input_path>/<sample_id>/<sample_id>_<modality>.nii.gz``.
    The modality comparison and (when available) the detailed visualization are
    written to ``output_path``.

    Args:
        sample_id: Patient directory name and file-name prefix.
        input_path: Directory containing the patient directories.
        output_path: Directory that receives the PNG files (created if missing).

    Returns:
        Dict mapping modality name to its data array, or None if any step
        raised an exception (the error is printed).
    """
    banner = '=' * 60
    print(f"\n{banner}")
    print(f"🏥 Processing patient: {sample_id}")
    print(f"{banner}")

    def modality_file(name):
        # Path of one modality file of this patient
        return os.path.join(input_path, sample_id, f"{sample_id}_{name}.nii.gz")

    try:
        # Load all modalities
        images = {
            name: load_nifti(modality_file(name)).get_fdata()
            for name in tqdm(['flair', 't1', 't1ce', 't2', 'seg'], desc="Loading modalities")
        }

        os.makedirs(output_path, exist_ok=True)

        # Modality comparison plot
        comparison_fig = plot_modality_comparison(images)
        comparison_png = os.path.join(output_path, f"{sample_id}_modality_comparison.png")
        comparison_fig.savefig(comparison_png)
        plt.close(comparison_fig)

        # Detailed visualization; the helper returns None on failure
        detailed_fig = plot_detailed_visualization(modality_file('flair'), modality_file('seg'))
        if detailed_fig:
            detailed_png = os.path.join(output_path, f"{sample_id}_detailed_visualization.png")
            detailed_fig.savefig(detailed_png)
            plt.close(detailed_fig)

        print(f"✅ Processing complete for patient {sample_id}")
        return images

    except Exception as err:
        print(f"❌ Error processing patient {sample_id}: {str(err)}")
        return None

# -------------------------------------------------
# Main Execution
# -------------------------------------------------
if __name__ == "__main__":
    print("\n🧠 BraTS Image Processing Pipeline 🧠")

    # Input archives and working directories
    TRAIN_TAR_PATH = '../input/brats-2021-task1/BraTS2021_Training_Data.tar'
    SAMPLE_TAR_PATH = '../input/brats-2021-task1/BraTS2021_00621.tar'  # not used below
    TRAINING_PATH = './BraTS2021_Training_Data'
    OUTPUT_PATH = './processed_data'

    # Make sure the output directory exists before anything is written to it
    os.makedirs(OUTPUT_PATH, exist_ok=True)

    # Unpack the training archive
    extract_tar_files(TRAIN_TAR_PATH, TRAINING_PATH)

    # Run the pipeline for one patient; the result is None on failure
    sample_id = 'BraTS2021_01261'
    sample_data = process_single_patient(sample_id, TRAINING_PATH, OUTPUT_PATH)

    print(
        "\n✨ Processing pipeline completed successfully!"
        if sample_data
        else "\n❌ Processing pipeline encountered errors."
    )

In [None]:
# Using PIL (Python Imaging Library)
from PIL import Image
import matplotlib.pyplot as plt

# Method 1: Direct loading with PIL
# The context manager closes the underlying file handle once the image is drawn.
with Image.open('/kaggle/working/processed_data/BraTS2021_01261_detailed_visualization.png') as image:
    # To display the image
    plt.imshow(image)
    plt.axis('off')  # Hide axes
    plt.show()

# Method 2: Using OpenCV (if you need image processing)
import cv2
import numpy as np

# Read image (OpenCV loads in BGR format)
comparison_path = '/kaggle/working/processed_data/BraTS2021_01261_modality_comparison.png'
image = cv2.imread(comparison_path)
# cv2.imread signals a missing or unreadable file by returning None instead of
# raising; fail here with a clear message rather than inside cvtColor.
if image is None:
    raise FileNotFoundError(f"Could not read image: {comparison_path}")
# Convert BGR to RGB for display
image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

plt.imshow(image_rgb)
plt.show()

In [4]:
import os
import cv2
import glob 
# used to find all the pathnames matching a specified pattern
import PIL
import shutil
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from skimage import data
from skimage.util import montage 
import skimage.transform as skTrans
from skimage.transform import rotate
from skimage.transform import resize
from PIL import Image, ImageOps  

# neural imaging
import nilearn as nl
import nibabel as nib
import nilearn.plotting as nlplt


# ml libs
import keras
import keras.backend as K
from keras.callbacks import CSVLogger
import tensorflow as tf
from tensorflow.keras.utils import plot_model
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping, TensorBoard
# from tensorflow.keras.layers.experimental import preprocessing


# Make numpy printouts easier to read.
np.set_printoptions(precision=3, suppress=True)

In [5]:
# # dice loss as defined above for 4 classes
# def dice_coef(y_true, y_pred, smooth=1.0):
#     class_num = 4
#     for i in range(class_num):
#         y_true_f = K.flatten(y_true[:,:,:,i])
#         y_pred_f = K.flatten(y_pred[:,:,:,i])
#         intersection = K.sum(y_true_f * y_pred_f)
#         loss = ((2. * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth))
#    #     K.print_tensor(loss, message='loss value for class {} : '.format(SEGMENT_CLASSES[i]))
#         if i == 0:
#             total_loss = loss
#         else:
#             total_loss = total_loss + loss
            
#     total_loss = total_loss / class_num
# #    K.print_tensor(total_loss, message=' total dice coef: ')
#     return total_loss


 
# # These functions are used for evaluating the performance of a segmentation model on three different classes
# # in medical imaging (presumably related to brain tumor segmentation).
# # Input Parameters:
# # y_true: The ground truth segmentation masks; each function reads a single class channel (1, 2 or 3).
# # y_pred: The predicted segmentation masks; each function reads the same class channel as for y_true.
# # epsilon: A small constant to avoid division by zero.

# def dice_coef_necrotic(y_true, y_pred, epsilon=1e-6):
#     intersection = K.sum(K.abs(y_true[:,:,:,1] * y_pred[:,:,:,1]))
#     return (2. * intersection) / (K.sum(K.square(y_true[:,:,:,1])) + K.sum(K.square(y_pred[:,:,:,1])) + epsilon)

# def dice_coef_edema(y_true, y_pred, epsilon=1e-6):
#     intersection = K.sum(K.abs(y_true[:,:,:,2] * y_pred[:,:,:,2]))
#     return (2. * intersection) / (K.sum(K.square(y_true[:,:,:,2])) + K.sum(K.square(y_pred[:,:,:,2])) + epsilon)

# def dice_coef_enhancing(y_true, y_pred, epsilon=1e-6):
#     intersection = K.sum(K.abs(y_true[:,:,:,3] * y_pred[:,:,:,3]))
#     return (2. * intersection) / (K.sum(K.square(y_true[:,:,:,3])) + K.sum(K.square(y_pred[:,:,:,3])) + epsilon)



# # Computing Precision 
# def precision(y_true, y_pred):
#         true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
#         predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
#         precision = true_positives / (predicted_positives + K.epsilon())
#         return precision

    
# # Computing Sensitivity      
# def sensitivity(y_true, y_pred):
#     true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
#     possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
#     return true_positives / (possible_positives + K.epsilon())


# # Computing Specificity
# def specificity(y_true, y_pred):
#     true_negatives = K.sum(K.round(K.clip((1-y_true) * (1-y_pred), 0, 1)))
#     possible_negatives = K.sum(K.round(K.clip(1-y_true, 0, 1)))
#     return true_negatives / (possible_negatives + K.epsilon())

In [None]:
import tensorflow.keras.backend as K
import tensorflow as tf

def dice_coef(y_true, y_pred, smooth=1.0):
    """Mean smoothed Dice coefficient over the first four channels.

    For each channel index 0..3 on the last axis, both tensors are flattened
    and ``(2 * sum(t * p) + smooth) / (sum(t) + sum(p) + smooth)`` is
    computed; the four values are averaged.

    Args:
        y_true: 4-D tensor indexed as ``[:, :, :, channel]``.
        y_pred: 4-D tensor indexed the same way as ``y_true``.
        smooth: Constant added to numerator and denominator.

    Returns:
        Scalar tensor with the mean Dice score.
    """
    def channel_dice(index):
        # Flatten one channel of each tensor to 1-D before summing
        truth = tf.reshape(y_true[:, :, :, index], [-1])
        guess = tf.reshape(y_pred[:, :, :, index], [-1])
        overlap = K.sum(truth * guess)
        total = K.sum(truth) + K.sum(guess)
        return (2. * overlap + smooth) / (total + smooth)

    num_classes = 4  # Number of classes
    per_class = [channel_dice(index) for index in range(num_classes)]

    # Stack the per-class scores into one tensor and average them
    return K.mean(tf.stack(per_class))


def _dice_coef_for_class(y_true, y_pred, class_index, epsilon):
    """Dice coefficient of one channel, using squared sums in the denominator.

    Computes ``2 * sum(|t * p|) / (sum(t^2) + sum(p^2) + epsilon)`` where
    ``t`` and ``p`` are channel ``class_index`` of ``y_true`` and ``y_pred``.
    """
    truth = y_true[:, :, :, class_index]
    guess = y_pred[:, :, :, class_index]
    intersection = K.sum(K.abs(truth * guess))
    union = K.sum(K.square(truth)) + K.sum(K.square(guess))
    return (2. * intersection) / (union + epsilon)


def dice_coef_necrotic(y_true, y_pred, epsilon=1e-6):
    """Dice coefficient of channel 1 (labelled NECROTIC/CORE in SEGMENT_CLASSES).

    Args:
        y_true: 4-D tensor indexed as ``[:, :, :, channel]``.
        y_pred: 4-D tensor indexed the same way as ``y_true``.
        epsilon: Small constant that keeps the denominator non-zero.
    """
    return _dice_coef_for_class(y_true, y_pred, 1, epsilon)


def dice_coef_edema(y_true, y_pred, epsilon=1e-6):
    """Dice coefficient of channel 2 (labelled EDEMA in SEGMENT_CLASSES).

    Args:
        y_true: 4-D tensor indexed as ``[:, :, :, channel]``.
        y_pred: 4-D tensor indexed the same way as ``y_true``.
        epsilon: Small constant that keeps the denominator non-zero.
    """
    return _dice_coef_for_class(y_true, y_pred, 2, epsilon)


def dice_coef_enhancing(y_true, y_pred, epsilon=1e-6):
    """Dice coefficient of channel 3 (labelled ENHANCING in SEGMENT_CLASSES).

    Args:
        y_true: 4-D tensor indexed as ``[:, :, :, channel]``.
        y_pred: 4-D tensor indexed the same way as ``y_true``.
        epsilon: Small constant that keeps the denominator non-zero.
    """
    return _dice_coef_for_class(y_true, y_pred, 3, epsilon)


# Computing Precision 
def precision(y_true, y_pred):
    """Ratio of rounded true positives to rounded predicted positives.

    Both products are clipped to [0, 1] and rounded before summing;
    ``K.epsilon()`` keeps the denominator non-zero.
    """
    hits = K.round(K.clip(y_true * y_pred, 0, 1))
    flagged = K.round(K.clip(y_pred, 0, 1))
    hit_count = K.sum(hits)
    flagged_count = K.sum(flagged)
    return hit_count / (flagged_count + K.epsilon())

# Computing Sensitivity      
def sensitivity(y_true, y_pred):
    """Ratio of rounded true positives to rounded actual positives.

    Both products are clipped to [0, 1] and rounded before summing;
    ``K.epsilon()`` keeps the denominator non-zero.
    """
    hits = K.round(K.clip(y_true * y_pred, 0, 1))
    actual = K.round(K.clip(y_true, 0, 1))
    hit_count = K.sum(hits)
    actual_count = K.sum(actual)
    return hit_count / (actual_count + K.epsilon())

# Computing Specificity
def specificity(y_true, y_pred):
    """Ratio of rounded true negatives to rounded actual negatives.

    Negatives are formed as ``1 - tensor``; values are clipped to [0, 1] and
    rounded before summing, and ``K.epsilon()`` keeps the denominator non-zero.
    """
    correct_rejections = K.round(K.clip((1-y_true) * (1-y_pred), 0, 1))
    actual_negatives = K.round(K.clip(1-y_true, 0, 1))
    rejection_count = K.sum(correct_rejections)
    negative_count = K.sum(actual_negatives)
    return rejection_count / (negative_count + K.epsilon())

In [None]:
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import (
    Conv2D, Input, Lambda, UpSampling2D, concatenate, Dropout, LayerNormalization, Dense, Layer
)
from tensorflow.keras.models import Model
from transformers import TFAutoModel
import tensorflow as tf


class TransformerBlock(Layer):
    """Keras layer that runs a frozen pretrained transformer over a feature map.

    The (batch, h, w, channels) input is flattened to a sequence of h*w tokens,
    projected to the transformer's hidden size, combined with learned position
    embeddings, layer-normalized and fed to the transformer through
    ``inputs_embeds``. The transformer output is projected to ``embedding_dim``
    and reshaped back to (batch, h, w, embedding_dim).

    Args:
        transformer_name: Model identifier passed to ``TFAutoModel.from_pretrained``.
        embedding_dim: Number of output channels.
        **kwargs: Forwarded to ``Layer.__init__``.
    """

    def __init__(self, transformer_name="huawei-noah/TinyBERT_General_4L_312D", embedding_dim=256, **kwargs):
        super().__init__(**kwargs)
        self.transformer_name = transformer_name
        self.embedding_dim = embedding_dim
        # Load the pretrained model once and read its configuration from the
        # loaded instance, instead of loading the whole model a second time
        # only to obtain the config.
        self.transformer = TFAutoModel.from_pretrained(transformer_name, from_pt=True)
        self.transformer_config = self.transformer.config
        # The transformer weights stay frozen; only the projections and the
        # position embeddings created in build() are trainable.
        self.transformer.trainable = False
        # Created in build(), once the transformer hidden size is needed
        self.input_dense = None
        self.output_dense = None
        self.layer_norm = LayerNormalization(epsilon=1e-6)

    def build(self, input_shape):
        """Creates the projections and position embeddings for ``input_shape``."""
        # Spatial size is fixed at build time; call() reshapes with these values
        self.h, self.w = input_shape[1], input_shape[2]
        bert_hidden_size = self.transformer_config.hidden_size

        self.input_dense = Dense(bert_hidden_size)
        self.output_dense = Dense(self.embedding_dim, activation='relu')

        # One learned embedding vector per spatial position
        self.pos_embedding = self.add_weight(
            shape=(self.h * self.w, bert_hidden_size),
            initializer="random_normal",
            trainable=True,
            name="position_embeddings",
        )
        super().build(input_shape)

    def call(self, inputs, training=None):
        """Applies the transformer to ``inputs`` and returns a map of the same spatial size."""
        # Convert inputs to tensor if needed
        x = tf.convert_to_tensor(inputs)
        batch_size = tf.shape(x)[0]

        # Reshape to sequence: (batch, sequence_length, channels)
        x = tf.reshape(x, (batch_size, self.h * self.w, x.shape[-1]))

        # Project to BERT's hidden size
        x = self.input_dense(x)

        # Add position embeddings
        x = x + self.pos_embedding

        # Layer normalization
        x = self.layer_norm(x)

        # Create attention mask (all ones: every position is attended to)
        attention_mask = tf.ones((batch_size, self.h * self.w), dtype=tf.int32)

        # The transformer's call() method is invoked directly with precomputed
        # embeddings instead of token ids.
        transformer_outputs = self.transformer.call(
            input_ids=None,
            attention_mask=attention_mask,
            token_type_ids=None,
            position_ids=None,
            inputs_embeds=x,
            training=training
        )

        # Get the sequence output
        sequence_output = transformer_outputs[0]

        # Project to desired dimension
        x = self.output_dense(sequence_output)

        # Reshape back to spatial dimensions
        x = tf.reshape(x, (batch_size, self.h, self.w, self.embedding_dim))

        return x

    def compute_output_shape(self, input_shape):
        """Returns the input shape with the channel axis replaced by ``embedding_dim``."""
        return (input_shape[0], input_shape[1], input_shape[2], self.embedding_dim)

    def get_config(self):
        """Returns the base layer config extended with this layer's constructor arguments."""
        config = super().get_config()
        config.update({
            "transformer_name": self.transformer_name,
            "embedding_dim": self.embedding_dim,
        })
        return config


def build_enhanced_unetpp_with_transformer(input_shape, ker_init='he_normal', dropout_rate=0.2):
    """Builds a segmentation model with a frozen ResNet50 encoder, a transformer
    block at the bottleneck and a decoder with extra ("nested") skip connections.

    Args:
        input_shape: Shape passed to ``Input``; the Lambda below appends a copy
            of the first channel, so the code assumes a 2-channel input.
        ker_init: Kernel initializer used by the 3x3 decoder convolutions.
        dropout_rate: Dropout rate applied to each merged decoder tensor.

    Returns:
        A Keras ``Model`` mapping ``inputs`` to a 4-channel softmax map.
    """
    
    # Fix for 2-channel input (convert to 3 channels)
    inputs = Input(input_shape)  # Input with 2 channels
    # Appends a copy of channel 0 as the third channel
    inputs_3_channel = Lambda(lambda x: tf.concat([x, x[:, :, :, :1]], axis=-1))(inputs)  # Convert to 3 channels

    # Load ResNet50 as the encoder backbone with 3-channel input
    resnet = ResNet50(include_top=False, weights='imagenet', input_tensor=inputs_3_channel)
    resnet.trainable = False  # Freeze the ResNet backbone

    # inputs = resnet.input
    # Intermediate ResNet50 activations used as skip connections
    conv1 = resnet.get_layer("conv1_relu").output  # 1/2 resolution
    conv2 = resnet.get_layer("conv2_block3_out").output  # 1/4 resolution
    conv3 = resnet.get_layer("conv3_block4_out").output  # 1/8 resolution
    conv4 = resnet.get_layer("conv4_block6_out").output  # 1/16 resolution
    bottleneck = resnet.get_layer("conv5_block3_out").output  # 1/32 resolution

    # Transformer enhancement applied at bottleneck
    transformer_output = TransformerBlock()(bottleneck)
    # Concatenate CNN and transformer features, then fuse them with a 1x1 convolution
    enhanced_features = concatenate([bottleneck, transformer_output])
    enhanced_features = Conv2D(256, 1, activation='relu', padding='same')(enhanced_features)

    # Decoder Path with Nested Skip Connections
    # Decoder Level 4 (Up-sampling bottleneck to conv4 resolution)
    up4 = UpSampling2D(size=(2, 2))(enhanced_features)
    merge4 = concatenate([conv4, up4])
    merge4 = Dropout(dropout_rate)(merge4)
    conv4_1 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer=ker_init)(merge4)
    conv4_1 = Conv2D(256, 3, activation='relu', padding='same', kernel_initializer=ker_init)(conv4_1)

    # Decoder Level 3 (Up-sampling conv4_1 to conv3 resolution)
    up3 = UpSampling2D(size=(2, 2))(conv4_1)
    merge3 = concatenate([conv3, up3])
    # The nested connection adds the up-sampled encoder feature from one level deeper
    merge3_nested = concatenate([merge3, UpSampling2D(size=(2, 2))(conv4)])  # Nested connection
    merge3_nested = Dropout(dropout_rate)(merge3_nested)
    conv3_1 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer=ker_init)(merge3_nested)
    conv3_1 = Conv2D(128, 3, activation='relu', padding='same', kernel_initializer=ker_init)(conv3_1)

    # Decoder Level 2 (Up-sampling conv3_1 to conv2 resolution)
    up2 = UpSampling2D(size=(2, 2))(conv3_1)
    merge2 = concatenate([conv2, up2])
    merge2_nested = concatenate([merge2, UpSampling2D(size=(2, 2))(conv3)])  # Nested connection
    merge2_nested = Dropout(dropout_rate)(merge2_nested)
    conv2_1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer=ker_init)(merge2_nested)
    conv2_1 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer=ker_init)(conv2_1)

    # Decoder Level 1 (Up-sampling conv2_1 to conv1 resolution)
    up1 = UpSampling2D(size=(2, 2))(conv2_1)
    merge1 = concatenate([conv1, up1])
    merge1_nested = concatenate([merge1, UpSampling2D(size=(2, 2))(conv2)])  # Nested connection
    merge1_nested = Dropout(dropout_rate)(merge1_nested)
    conv1_1 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer=ker_init)(merge1_nested)
    conv1_1 = Conv2D(32, 3, activation='relu', padding='same', kernel_initializer=ker_init)(conv1_1)

    # Output Layer
    # NOTE(review): per the resolution comments above, conv1_1 sits at 1/2 of the
    # input resolution and nothing up-samples it before this layer, so the output
    # map is presumably half the input size — confirm it matches the label size.
    outputs = Conv2D(4, 1, activation='softmax')(conv1_1)

    return Model(inputs, outputs)

In [18]:
from tensorflow.keras.layers import (
    Conv2D, Input, MaxPooling2D, Dropout, UpSampling2D, 
    concatenate, Reshape, Dense, Flatten, LayerNormalization,
    Layer
)

from tensorflow.keras.models import Model
from transformers import TFAutoModel
import tensorflow as tf

# google-bert/bert-base-uncased

class TransformerBlock(Layer):
    """Keras layer that runs a frozen pretrained transformer over a feature map.

    The (batch, h, w, channels) input is flattened to a sequence of h*w tokens,
    projected to the transformer's hidden size, combined with learned position
    embeddings, layer-normalized and fed to the transformer through
    ``inputs_embeds``. The transformer output is projected to ``embedding_dim``
    and reshaped back to (batch, h, w, embedding_dim).

    Args:
        transformer_name: Model identifier passed to ``TFAutoModel.from_pretrained``.
        embedding_dim: Number of output channels.
        **kwargs: Forwarded to ``Layer.__init__``.
    """

    def __init__(self, transformer_name="huawei-noah/TinyBERT_General_4L_312D", embedding_dim=256, **kwargs):
        super().__init__(**kwargs)
        self.transformer_name = transformer_name
        self.embedding_dim = embedding_dim
        # Load the pretrained model once and read its configuration from the
        # loaded instance, instead of loading the whole model a second time
        # only to obtain the config.
        self.transformer = TFAutoModel.from_pretrained(transformer_name, from_pt=True)
        self.transformer_config = self.transformer.config
        # The transformer weights stay frozen; only the projections and the
        # position embeddings created in build() are trainable.
        self.transformer.trainable = False
        # Created in build(), once the transformer hidden size is needed
        self.input_dense = None
        self.output_dense = None
        self.layer_norm = LayerNormalization(epsilon=1e-6)

    def build(self, input_shape):
        """Creates the projections and position embeddings for ``input_shape``."""
        # Spatial size is fixed at build time; call() reshapes with these values
        self.h, self.w = input_shape[1], input_shape[2]
        bert_hidden_size = self.transformer_config.hidden_size

        self.input_dense = Dense(bert_hidden_size)
        self.output_dense = Dense(self.embedding_dim, activation='relu')

        # One learned embedding vector per spatial position
        self.pos_embedding = self.add_weight(
            shape=(self.h * self.w, bert_hidden_size),
            initializer="random_normal",
            trainable=True,
            name="position_embeddings",
        )
        super().build(input_shape)

    def call(self, inputs, training=None):
        """Applies the transformer to ``inputs`` and returns a map of the same spatial size."""
        # Convert inputs to tensor if needed
        x = tf.convert_to_tensor(inputs)
        batch_size = tf.shape(x)[0]

        # Reshape to sequence: (batch, sequence_length, channels)
        x = tf.reshape(x, (batch_size, self.h * self.w, x.shape[-1]))

        # Project to BERT's hidden size
        x = self.input_dense(x)

        # Add position embeddings
        x = x + self.pos_embedding

        # Layer normalization
        x = self.layer_norm(x)

        # Create attention mask (all ones: every position is attended to)
        attention_mask = tf.ones((batch_size, self.h * self.w), dtype=tf.int32)

        # The transformer's call() method is invoked directly with precomputed
        # embeddings instead of token ids.
        transformer_outputs = self.transformer.call(
            input_ids=None,
            attention_mask=attention_mask,
            token_type_ids=None,
            position_ids=None,
            inputs_embeds=x,
            training=training
        )

        # Get the sequence output
        sequence_output = transformer_outputs[0]

        # Project to desired dimension
        x = self.output_dense(sequence_output)

        # Reshape back to spatial dimensions
        x = tf.reshape(x, (batch_size, self.h, self.w, self.embedding_dim))

        return x

    def compute_output_shape(self, input_shape):
        """Returns the input shape with the channel axis replaced by ``embedding_dim``."""
        return (input_shape[0], input_shape[1], input_shape[2], self.embedding_dim)

    def get_config(self):
        """Returns the base layer config extended with this layer's constructor arguments."""
        config = super().get_config()
        config.update({
            "transformer_name": self.transformer_name,
            "embedding_dim": self.embedding_dim,
        })
        return config
        
    
def _double_conv(tensor, filters, ker_init):
    """Apply two consecutive 3x3 'same'-padded ReLU convolutions with `filters` channels."""
    for _ in range(2):
        tensor = Conv2D(filters, 3, activation='relu', padding='same',
                        kernel_initializer=ker_init)(tensor)
    return tensor


def _pool_and_drop(tensor, dropout_rate):
    """Halve the spatial size with 2x2 max-pooling, then apply dropout."""
    tensor = MaxPooling2D(pool_size=(2, 2))(tensor)
    return Dropout(dropout_rate)(tensor)


def _up_conv(tensor, filters, ker_init):
    """Double the spatial size, then apply a 2x2 'same'-padded ReLU convolution."""
    tensor = UpSampling2D(size=(2, 2))(tensor)
    return Conv2D(filters, 2, activation='relu', padding='same',
                  kernel_initializer=ker_init)(tensor)


def build_unetpp_with_transformer(
    input_shape,
    ker_init='he_normal',
    dropout_rate=0.2,
    num_classes=4
):
    """Build a U-Net-style segmentation model with a transformer-enhanced bottleneck.

    The encoder has three down-sampling stages (32/64/128 filters), a 256-filter
    bottleneck whose features are concatenated with the output of
    ``TransformerBlock`` and fused by a 1x1 convolution, and a mirrored decoder
    with skip connections. One extra nested skip node (``conv2_2``) is added at
    the second resolution level.

    Args:
        input_shape: Shape of one input sample, e.g. ``(128, 128, 2)``. Height
            and width must be divisible by 8 so that the three pool/upsample
            pairs line up for the skip concatenations.
        ker_init: Kernel initializer used by the encoder/decoder convolutions.
        dropout_rate: Dropout rate applied after each pooling and each merge.
        num_classes: Number of output channels of the final softmax layer.
            Defaults to 4, the value that was previously hard-coded.

    Returns:
        An uncompiled ``Model`` mapping ``inputs`` to per-pixel class
        probabilities of shape ``(H, W, num_classes)``.
    """
    # Input Layer
    inputs = Input(input_shape)

    # Encoder Path
    conv1_1 = _double_conv(inputs, 32, ker_init)
    pool1 = _pool_and_drop(conv1_1, dropout_rate)

    conv2_1 = _double_conv(pool1, 64, ker_init)
    pool2 = _pool_and_drop(conv2_1, dropout_rate)

    # U-Net++ nested skip connection.
    # NOTE(review): this node up-samples pool2 (the pooled conv2_1) rather than
    # the deeper conv3_1 features; kept as-is so existing weights stay loadable.
    conv2_2 = Conv2D(64, 3, activation='relu', padding='same', kernel_initializer=ker_init)(
        concatenate([
            conv2_1,
            UpSampling2D(size=(2, 2))(pool2)
        ])
    )

    conv3_1 = _double_conv(pool2, 128, ker_init)
    pool3 = _pool_and_drop(conv3_1, dropout_rate)

    # Bottleneck
    bottleneck = _double_conv(pool3, 256, ker_init)

    # Transformer Feature Enhancement
    transformer_output = TransformerBlock()(bottleneck)

    # Combine CNN and transformer features, then fuse back to 256 channels
    enhanced_features = concatenate([bottleneck, transformer_output])
    enhanced_features = Conv2D(256, 1, activation='relu', padding='same')(enhanced_features)

    # Decoder Path
    up3 = _up_conv(enhanced_features, 128, ker_init)
    merge3 = concatenate([conv3_1, up3])
    merge3 = Dropout(dropout_rate)(merge3)
    conv7 = _double_conv(merge3, 128, ker_init)

    up2 = _up_conv(conv7, 64, ker_init)
    merge2 = concatenate([conv2_1, conv2_2, up2])
    merge2 = Dropout(dropout_rate)(merge2)
    conv8 = _double_conv(merge2, 64, ker_init)

    up1 = _up_conv(conv8, 32, ker_init)
    merge1 = concatenate([conv1_1, up1])
    merge1 = Dropout(dropout_rate)(merge1)
    conv9 = _double_conv(merge1, 32, ker_init)

    # Per-pixel class probabilities
    outputs = Conv2D(num_classes, 1, activation='softmax')(conv9)

    return Model(inputs=inputs, outputs=outputs)

In [None]:
# Create and compile the model
input_shape = (128, 128, 2)
model = build_unetpp_with_transformer(input_shape)

model.compile(
    loss="categorical_crossentropy",
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=[
        'accuracy',
        # Targets are one-hot and predictions are softmax probabilities, so the
        # one-hot variant (which takes the argmax of both) is required; plain
        # MeanIoU expects integer class labels and would truncate the
        # probabilities to class 0. The name keeps the 'mean_io_u' log key.
        tf.keras.metrics.OneHotMeanIoU(num_classes=4, name='mean_io_u'),
        dice_coef,
        precision,
        sensitivity,
        specificity,
        dice_coef_necrotic,
        dice_coef_edema,
        dice_coef_enhancing
    ]
)

In [None]:
# Smoke-test the model on random data: 100 samples shaped like one model input
test_input = np.random.random(size=(100, 128, 128, 2))
print("Input shape:", test_input.shape)
test_output = model.predict(test_input)
print("Output shape:", test_output.shape)
# Print the layer-by-layer summary of the architecture
model.summary()

In [None]:
from tensorflow.keras.utils import plot_model

# Draw the model graph top-to-bottom with tensor shapes and layer names
plot_model(
    model,
    show_shapes=True,
    show_dtype=False,
    show_layer_names=True,
    rankdir='TB',
    expand_nested=False,
    dpi=70,
)

In [10]:
# Paths of all study directories found directly under TRAINING_PATH
train_and_val_directories = [entry.path for entry in os.scandir(TRAINING_PATH) if entry.is_dir()]

# Case BraTS20_Training_355 has an ill-formatted name for its seg.nii file
#train_and_val_directories.remove(TRAIN_DATASET_PATH+'BraTS20_Training_355')


def pathListIntoIds(dirList):
    """Return the final path component (the case ID) of each directory path.

    Uses ``os.path`` instead of searching for a literal ``'/'`` so that the
    platform's separators are honoured and a trailing separator does not
    produce an empty ID.

    Args:
        dirList: Iterable of directory paths.

    Returns:
        List of directory names, in the same order as ``dirList``.
    """
    # normpath strips any trailing separator before basename takes the last part
    return [os.path.basename(os.path.normpath(path)) for path in dirList]

# Sort the IDs so the split does not depend on the filesystem's listing order
train_and_test_ids = sorted(pathListIntoIds(train_and_val_directories))

# Fixed seed: re-running the notebook must reproduce the same partition,
# otherwise test cases can leak into the training set of a resumed run.
SPLIT_SEED = 42

# 20% of all cases go to validation; 15% of the remainder goes to test
train_test_ids, val_ids = train_test_split(train_and_test_ids, test_size=0.2, random_state=SPLIT_SEED)
train_ids, test_ids = train_test_split(train_test_ids, test_size=0.15, random_state=SPLIT_SEED)

In [None]:
# Peek at the first extracted case ID to check the path-to-ID conversion
train_and_test_ids[0] 

In [None]:
# First case ID of the training split
train_ids[0]

In [None]:
# First case ID of the test split
test_ids[0]

In [14]:
class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that turns case IDs into batches of 2D slices.

    Every batch contains ``batch_size * VOLUME_SLICES`` slices. For each case
    the FLAIR, T1ce and segmentation volumes are read from
    ``TRAINING_PATH/<id>/<id>_{flair,t1ce,seg}.nii.gz``; slices
    ``VOLUME_START_AT`` to ``VOLUME_START_AT + VOLUME_SLICES - 1`` are resized
    to ``dim``. FLAIR goes to channel 0 and T1ce to channel 1. The segmentation
    is one-hot encoded into 4 classes after remapping label 4 to 3.

    Args:
        list_IDs: Case IDs (directory names under ``TRAINING_PATH``).
        dim: Target ``(height, width)`` of every slice.
        batch_size: Number of cases (not slices) per batch.
        n_channels: Number of input channels allocated; channels 0 and 1 are filled.
        shuffle: Whether to reshuffle the case order after every epoch.
    """
    def __init__(self, list_IDs, dim=(IMG_SIZE,IMG_SIZE), batch_size = 1, n_channels = 2, shuffle=True):
        'Initialization'
        # Newer Keras versions warn when the base initializer is not called
        super().__init__()
        self.dim = dim
        self.batch_size = batch_size
        self.list_IDs = list_IDs
        self.n_channels = n_channels
        self.shuffle = shuffle
        self.on_epoch_end()

    def __len__(self):
        'Denotes the number of batches per epoch'
        # Incomplete trailing batches are dropped
        return int(np.floor(len(self.list_IDs) / self.batch_size))

    def __getitem__(self, index):
        'Generate one batch of data'
        # Generate indexes of the batch
        indexes = self.indexes[index*self.batch_size:(index+1)*self.batch_size]

        # Find list of IDs
        Batch_ids = [self.list_IDs[k] for k in indexes]

        # Generate data
        X, y = self.__data_generation(Batch_ids)

        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.list_IDs))
        if self.shuffle:
            np.random.shuffle(self.indexes)

    def __data_generation(self, Batch_ids):
        'Generates data containing batch_size samples' # X : (n_samples, *dim, n_channels)
        n_samples = self.batch_size * VOLUME_SLICES
        # cv2.resize takes (width, height) while self.dim is (height, width)
        cv_size = (self.dim[1], self.dim[0])

        X = np.zeros((n_samples, *self.dim, self.n_channels))
        # Assumes 240x240 in-plane segmentation slices; the assignment below
        # fails for volumes of any other in-plane size.
        y = np.zeros((n_samples, 240, 240))

        # Generate data
        for c, i in enumerate(Batch_ids):
            case_path = os.path.join(TRAINING_PATH, i)

            flair = nib.load(os.path.join(case_path, f'{i}_flair.nii.gz')).get_fdata()
            ce = nib.load(os.path.join(case_path, f'{i}_t1ce.nii.gz')).get_fdata()
            seg = nib.load(os.path.join(case_path, f'{i}_seg.nii.gz')).get_fdata()

            offset = VOLUME_SLICES * c
            for j in range(VOLUME_SLICES):
                z = j + VOLUME_START_AT
                X[offset + j, :, :, 0] = cv2.resize(flair[:, :, z], cv_size)
                X[offset + j, :, :, 1] = cv2.resize(ce[:, :, z], cv_size)

                y[offset + j] = seg[:, :, z]

        # Generate masks: merge label 4 into class index 3, then one-hot encode.
        y[y==4] = 3
        # tf.one_hot expects integer indices; the labels are whole numbers held
        # in a float array, so the cast is lossless.
        mask = tf.one_hot(y.astype(np.int32), 4)
        # NOTE(review): tf.image.resize defaults to bilinear interpolation, so
        # the resized masks are soft (non-binary) along class boundaries.
        Y = tf.image.resize(mask, self.dim)

        # Scale by the batch-wide maximum; skip when it is zero to avoid 0/0 NaNs
        peak = np.max(X)
        if peak != 0:
            X = X / peak
        return X, Y
        
# One generator per split, all with the DataGenerator defaults
training_generator = DataGenerator(list_IDs=train_ids)
valid_generator = DataGenerator(list_IDs=val_ids)
test_generator = DataGenerator(list_IDs=test_ids)

In [None]:
# Number of cases in the train, validation and test splits (one per line)
print(len(train_ids), len(val_ids), len(test_ids), sep="\n")

In [None]:
# show number of data for each dir 
def showDataLayout():
    """Plot, save and show a bar chart of the train/validation/test split sizes.

    Reads the module-level ``train_ids``, ``val_ids`` and ``test_ids`` lists
    and writes the figure to ``data2018.png`` before displaying it.
    """
    split_names = ["Train", "Valid", "Test"]
    split_sizes = [len(train_ids), len(val_ids), len(test_ids)]
    plt.bar(split_names, split_sizes, align='center', color=['green', 'red', 'blue'])
    # No plt.legend(): the bars carry no labels, so it only emitted a
    # "no artists with labels" warning; the x tick labels name each bar.

    plt.ylabel('Number of images')
    plt.title('Data distribution')
    plt.savefig('data2018.png')
    plt.show()

showDataLayout()

In [None]:
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, LambdaCallback, CSVLogger
from tensorflow.keras import mixed_precision
import tensorflow as tf

# Checkpoint file name embeds the epoch number and the validation loss
filepath="3D-MedVision++-weights-improvement-{epoch:02d}-{val_loss:.3f}.keras"

# Only write a checkpoint when val_loss improves
checkpoint = ModelCheckpoint(filepath, monitor='val_loss', verbose=1, save_best_only=True, mode='min')

# NOTE(review): patience (300) exceeds the epoch count (100), so this callback
# can never stop training early as configured - confirm the intended patience.
early_stop = EarlyStopping(monitor='val_loss', patience=300, verbose=1, restore_best_weights=True)
log_callback = LambdaCallback(on_epoch_end=lambda epoch, logs: print(f"Epoch {epoch} logs: {logs}"))
csv_logger = CSVLogger('training_UNet++.log')

# The generators already produce complete batches, so `batch_size` must not be
# passed to fit(). Step counts come from the generators themselves so that each
# epoch covers every training batch and val_loss is computed on the whole
# validation set rather than a shuffled half of it.
history = model.fit(
    training_generator,
    epochs=100,
    steps_per_epoch=len(training_generator),
    callbacks=[checkpoint, csv_logger, early_stop],
    validation_data=valid_generator,
    validation_steps=len(valid_generator),
)

# Metrics: Various metrics are reported during training and validation to assess the model's performance. These include:

# Loss: A measure of how well the model is performing. It represents an error value that the model is trying to minimize during training.
# Accuracy: The proportion of correctly classified samples.
# Mean Intersection over Union (mean_io_u): A metric used for image segmentation tasks, measuring the overlap between predicted and true segmentation masks.
# Dice Coefficient (dice_coef): Another metric for segmentation tasks, measuring the similarity between predicted and true masks.
# Precision, Sensitivity, Specificity: These are commonly used metrics in binary classification tasks. Precision is the ratio of true positives to the sum of true positives and false positives. Sensitivity (recall) is the ratio of true positives to the sum of true positives and false negatives. Specificity is the ratio of true negatives to the sum of true negatives and false positives.
# Dice Coefficients for Necrotic, Edema, Enhancing: Specific dice coefficients for different classes in the segmentation task.
# This is common in medical image segmentation where different regions of interest are segmented separately.
# Validation Metrics: These are metrics evaluated on a separate dataset not used for training. 
# They give an indication of how well the model generalizes to new, unseen data.

In [None]:
import tensorflow as tf
# Ask TensorFlow which GPU devices it can see
physical_devices = tf.config.list_physical_devices(device_type='GPU')
print(f"Available GPUs: {physical_devices}")

In [None]:
# Run the `nvidia-smi` command-line tool through IPython's shell escape
!nvidia-smi


In [None]:
import tensorflow as tf
from tensorflow.keras import mixed_precision

# Set the policy to mixed precision
# policy = mixed_precision.Policy('mixed_float16')
# mixed_precision.set_global_policy(policy)

# Define a MirroredStrategy
strategy = tf.distribute.MirroredStrategy()

print('Number of devices: ', strategy.num_replicas_in_sync)


with strategy.scope():
    # Build and compile the model inside the strategy scope
    model = build_unetpp_with_transformer(input_shape=(128, 128, 2))
    model.compile(
        loss="categorical_crossentropy",
        optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
        metrics=[
            'accuracy',
            # Targets are one-hot and predictions are softmax probabilities, so
            # the one-hot IoU variant is required; plain MeanIoU expects integer
            # class labels. The name keeps the 'mean_io_u' log key.
            tf.keras.metrics.OneHotMeanIoU(num_classes=4, name='mean_io_u'),
            dice_coef,
            precision,
            sensitivity,
            specificity,
            dice_coef_necrotic,
            dice_coef_edema,
            dice_coef_enhancing
        ]
    )

    # Train the model. The generators already produce complete batches, so
    # `batch_size` is not passed; step counts come from the generators so each
    # epoch covers all training batches and the full validation set.
    history = model.fit(
        training_generator,
        epochs=100,
        steps_per_epoch=len(training_generator),
        validation_data=valid_generator,
        validation_steps=len(valid_generator),
        callbacks=[checkpoint, csv_logger, early_stop, log_callback],
    )